[Format] Convert all Python code w/o CI (#6448)
author     Jared Roesch <jroesch@octoml.ai>
           Fri, 11 Sep 2020 13:17:24 +0000 (06:17 -0700)
committer  GitHub <noreply@github.com>
           Fri, 11 Sep 2020 13:17:24 +0000 (22:17 +0900)
* Add black setup

* Tweak pyproject.toml

* Fix syntax issues

* Fix

* Tweak

* Black all Python code
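
The commit message above notes that the Black setup lives in a new pyproject.toml (listed among the changed files below). The contents of that file are not shown in this commitdiff view; the following is only a hypothetical sketch of a minimal [tool.black] section consistent with the reformatted output. The ~100-character line wrapping is inferred from the diffs further down, and every key and value here is an assumption rather than the committed configuration:

    # Hypothetical sketch only -- not the committed pyproject.toml.
    # line-length is inferred from the ~100-character wrapping seen in the diffs below.
    [tool.black]
    line-length = 100
    target-version = ['py36']
    # exclude is a regex; vendored and generated trees are assumed candidates here.
    exclude = '''
    /(
        \.git
      | 3rdparty
      | build
    )/
    '''

With a configuration like this at the repository root, running black picks the settings up automatically, which is the kind of invocation that produces diffs like the ones below.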

1013 files changed:
apps/android_camera/models/prepare_model.py
apps/android_rpc/tests/android_rpc_test.py
apps/benchmark/arm_cpu_imagenet_bench.py
apps/benchmark/gpu_imagenet_bench.py
apps/benchmark/mobile_gpu_imagenet_bench.py
apps/benchmark/util.py
apps/bundle_deploy/build_model.py
apps/dso_plugin_module/test_plugin_module.py
apps/extension/python/tvm_ext/__init__.py
apps/extension/tests/test_ext.py
apps/howto_deploy/prepare_test_libs.py
apps/howto_deploy/python_deploy.py
apps/ios_rpc/init_proj.py
apps/ios_rpc/tests/ios_rpc_mobilenet.py
apps/ios_rpc/tests/ios_rpc_test.py
apps/lldb/tvm.py
apps/sgx/read_results.py
apps/sgx/src/build_model.py
apps/tf_tvmdsoop/tests/test_tfop_module.py
apps/topi_recipe/broadcast/test_broadcast_map.py
apps/topi_recipe/conv/depthwise_conv2d_test.py
apps/topi_recipe/conv/test_conv2d_hwcn_map.py
apps/topi_recipe/conv/test_conv_int8_arm.py
apps/topi_recipe/conv/test_conv_int8_intel.py
apps/topi_recipe/gemm/android_gemm_square.py
apps/topi_recipe/gemm/cuda_gemm_square.py
apps/topi_recipe/gemm/gemm_int8.py
apps/topi_recipe/reduce/test_reduce_map.py
apps/topi_recipe/rnn/lstm.py
apps/topi_recipe/rnn/matexp.py
apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py
conda/render_cuda.py
conftest.py
docs/conf.py
golang/sample/deploy.py
golang/sample/gen_mobilenet_lib.py
jvm/core/src/test/scripts/test_add_cpu.py
jvm/core/src/test/scripts/test_add_gpu.py
jvm/core/src/test/scripts/test_graph_runtime.py
jvm/core/src/test/scripts/test_rpc_proxy_server.py
nnvm/amalgamation/amalgamation.py
nnvm/amalgamation/generate.py
pyproject.toml [new file with mode: 0644]
python/setup.py
python/tvm/__init__.py
python/tvm/_ffi/_ctypes/ndarray.py
python/tvm/_ffi/_ctypes/object.py
python/tvm/_ffi/_ctypes/packed_func.py
python/tvm/_ffi/_ctypes/types.py
python/tvm/_ffi/_pyversion.py
python/tvm/_ffi/base.py
python/tvm/_ffi/libinfo.py
python/tvm/_ffi/registry.py
python/tvm/_ffi/runtime_ctypes.py
python/tvm/arith/analyzer.py
python/tvm/arith/int_set.py
python/tvm/arith/int_solver.py
python/tvm/arith/pattern.py
python/tvm/auto_scheduler/__init__.py
python/tvm/auto_scheduler/auto_schedule.py
python/tvm/auto_scheduler/compute_dag.py
python/tvm/auto_scheduler/cost_model/cost_model.py
python/tvm/auto_scheduler/cost_model/xgb_model.py
python/tvm/auto_scheduler/feature.py
python/tvm/auto_scheduler/loop_state.py
python/tvm/auto_scheduler/measure.py
python/tvm/auto_scheduler/measure_record.py
python/tvm/auto_scheduler/search_policy.py
python/tvm/auto_scheduler/utils.py
python/tvm/auto_scheduler/workload_registry.py
python/tvm/autotvm/__init__.py
python/tvm/autotvm/database.py
python/tvm/autotvm/env.py
python/tvm/autotvm/feature.py
python/tvm/autotvm/graph_tuner/base_graph_tuner.py
python/tvm/autotvm/graph_tuner/dynamic_programming_stage.py
python/tvm/autotvm/graph_tuner/dynamic_programming_tuner.py
python/tvm/autotvm/graph_tuner/pbqp_tuner.py
python/tvm/autotvm/graph_tuner/utils/__init__.py
python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
python/tvm/autotvm/graph_tuner/utils/utils.py
python/tvm/autotvm/measure/__init__.py
python/tvm/autotvm/measure/executor.py
python/tvm/autotvm/measure/local_executor.py
python/tvm/autotvm/measure/measure.py
python/tvm/autotvm/measure/measure_methods.py
python/tvm/autotvm/record.py
python/tvm/autotvm/task/__init__.py
python/tvm/autotvm/task/code_hash.py
python/tvm/autotvm/task/dispatcher.py
python/tvm/autotvm/task/relay_integration.py
python/tvm/autotvm/task/space.py
python/tvm/autotvm/task/task.py
python/tvm/autotvm/task/topi_integration.py
python/tvm/autotvm/tophub.py
python/tvm/autotvm/tuner/callback.py
python/tvm/autotvm/tuner/ga_tuner.py
python/tvm/autotvm/tuner/index_based_tuner.py
python/tvm/autotvm/tuner/metric.py
python/tvm/autotvm/tuner/model_based_tuner.py
python/tvm/autotvm/tuner/sa_model_optimizer.py
python/tvm/autotvm/tuner/tuner.py
python/tvm/autotvm/tuner/xgboost_cost_model.py
python/tvm/autotvm/tuner/xgboost_tuner.py
python/tvm/autotvm/util.py
python/tvm/contrib/binutil.py
python/tvm/contrib/cblas.py
python/tvm/contrib/cc.py
python/tvm/contrib/clang.py
python/tvm/contrib/coreml_runtime.py
python/tvm/contrib/cublas.py
python/tvm/contrib/cublaslt.py
python/tvm/contrib/cudnn.py
python/tvm/contrib/debugger/debug_result.py
python/tvm/contrib/debugger/debug_runtime.py
python/tvm/contrib/dlpack.py
python/tvm/contrib/download.py
python/tvm/contrib/emcc.py
python/tvm/contrib/graph_runtime.py
python/tvm/contrib/hexagon.py
python/tvm/contrib/miopen.py
python/tvm/contrib/mkl.py
python/tvm/contrib/mkldnn.py
python/tvm/contrib/mps.py
python/tvm/contrib/mxnet.py
python/tvm/contrib/ndk.py
python/tvm/contrib/nnpack.py
python/tvm/contrib/nvcc.py
python/tvm/contrib/peak.py
python/tvm/contrib/pickle_memoize.py
python/tvm/contrib/random.py
python/tvm/contrib/rocblas.py
python/tvm/contrib/rocm.py
python/tvm/contrib/rpc.py
python/tvm/contrib/sdaccel.py
python/tvm/contrib/sparse.py
python/tvm/contrib/spirv.py
python/tvm/contrib/tar.py
python/tvm/contrib/target/coreml.py
python/tvm/contrib/target/onnx.py
python/tvm/contrib/tedd.py
python/tvm/contrib/tf_op/module.py
python/tvm/contrib/tflite_runtime.py
python/tvm/contrib/util.py
python/tvm/contrib/xcode.py
python/tvm/driver/build_module.py
python/tvm/driver/tvmc/common.py
python/tvm/driver/tvmc/main.py
python/tvm/error.py
python/tvm/exec/autotvm_log_editor.py
python/tvm/exec/measure_peak.py
python/tvm/exec/query_rpc_tracker.py
python/tvm/exec/rpc_proxy.py
python/tvm/exec/rpc_server.py
python/tvm/exec/rpc_tracker.py
python/tvm/hybrid/parser.py
python/tvm/hybrid/registry.py
python/tvm/hybrid/special_stmt.py
python/tvm/hybrid/ty.py
python/tvm/ir/_ffi_transform_api.py
python/tvm/ir/adt.py
python/tvm/ir/attrs.py
python/tvm/ir/base.py
python/tvm/ir/container.py
python/tvm/ir/expr.py
python/tvm/ir/function.py
python/tvm/ir/json_compact.py
python/tvm/ir/module.py
python/tvm/ir/op.py
python/tvm/ir/tensor_type.py
python/tvm/ir/transform.py
python/tvm/ir/type.py
python/tvm/ir/type_relation.py
python/tvm/micro/base.py
python/tvm/micro/device/arm/stm32f746xx.py
python/tvm/micro/device/base.py
python/tvm/micro/device/host.py
python/tvm/micro/device/riscv_spike.py
python/tvm/micro/func_registry.py
python/tvm/parser/__init__.py
python/tvm/relay/analysis/analysis.py
python/tvm/relay/analysis/annotated_regions.py
python/tvm/relay/analysis/feature.py
python/tvm/relay/analysis/sparse_dense.py
python/tvm/relay/backend/compile_engine.py
python/tvm/relay/backend/graph_runtime_factory.py
python/tvm/relay/backend/interpreter.py
python/tvm/relay/backend/vm.py
python/tvm/relay/base.py
python/tvm/relay/build_module.py
python/tvm/relay/data_dep_optimization/__init__.py
python/tvm/relay/data_dep_optimization/bsr_dense.py
python/tvm/relay/data_dep_optimization/simplify_fc_transpose.py
python/tvm/relay/data_dep_optimization/utils.py
python/tvm/relay/dataflow_pattern/__init__.py
python/tvm/relay/debug.py
python/tvm/relay/expr.py
python/tvm/relay/expr_functor.py
python/tvm/relay/frontend/caffe.py
python/tvm/relay/frontend/caffe2.py
python/tvm/relay/frontend/common.py
python/tvm/relay/frontend/coreml.py
python/tvm/relay/frontend/darknet.py
python/tvm/relay/frontend/keras.py
python/tvm/relay/frontend/mxnet.py
python/tvm/relay/frontend/mxnet_qnn_op_utils.py
python/tvm/relay/frontend/nnvm_common.py
python/tvm/relay/frontend/onnx.py
python/tvm/relay/frontend/pytorch.py
python/tvm/relay/frontend/qnn_torch.py
python/tvm/relay/frontend/tensorflow.py
python/tvm/relay/frontend/tensorflow_parser.py
python/tvm/relay/frontend/tflite.py
python/tvm/relay/frontend/tflite_flexbuffer.py
python/tvm/relay/function.py
python/tvm/relay/loops.py
python/tvm/relay/op/__init__.py
python/tvm/relay/op/_algorithm.py
python/tvm/relay/op/_reduce.py
python/tvm/relay/op/_tensor.py
python/tvm/relay/op/_tensor_grad.py
python/tvm/relay/op/_transform.py
python/tvm/relay/op/algorithm.py
python/tvm/relay/op/annotation/annotation.py
python/tvm/relay/op/contrib/arm_compute_lib.py
python/tvm/relay/op/contrib/coreml.py
python/tvm/relay/op/contrib/dnnl.py
python/tvm/relay/op/contrib/ethosn.py
python/tvm/relay/op/contrib/register.py
python/tvm/relay/op/dyn/__init__.py
python/tvm/relay/op/dyn/_algorithm.py
python/tvm/relay/op/dyn/_tensor.py
python/tvm/relay/op/dyn/_transform.py
python/tvm/relay/op/dyn/image/_image.py
python/tvm/relay/op/dyn/nn/_nn.py
python/tvm/relay/op/image/_image.py
python/tvm/relay/op/image/image.py
python/tvm/relay/op/memory/memory.py
python/tvm/relay/op/nn/_nn.py
python/tvm/relay/op/nn/nn.py
python/tvm/relay/op/op.py
python/tvm/relay/op/op_attrs.py
python/tvm/relay/op/reduce.py
python/tvm/relay/op/strategy/arm_cpu.py
python/tvm/relay/op/strategy/bifrost.py
python/tvm/relay/op/strategy/cuda.py
python/tvm/relay/op/strategy/generic.py
python/tvm/relay/op/strategy/hls.py
python/tvm/relay/op/strategy/intel_graphics.py
python/tvm/relay/op/strategy/mali.py
python/tvm/relay/op/strategy/rocm.py
python/tvm/relay/op/strategy/x86.py
python/tvm/relay/op/tensor.py
python/tvm/relay/op/transform.py
python/tvm/relay/op/vision/_rcnn.py
python/tvm/relay/op/vision/_vision.py
python/tvm/relay/op/vision/_yolo.py
python/tvm/relay/op/vision/multibox.py
python/tvm/relay/op/vision/nms.py
python/tvm/relay/op/vision/rcnn.py
python/tvm/relay/op/vision/yolo.py
python/tvm/relay/param_dict.py
python/tvm/relay/prelude.py
python/tvm/relay/qnn/op/layout_conversions.py
python/tvm/relay/qnn/op/legalizations.py
python/tvm/relay/qnn/op/op.py
python/tvm/relay/qnn/op/qnn.py
python/tvm/relay/qnn/transform.py
python/tvm/relay/quantize/__init__.py
python/tvm/relay/quantize/_annotate.py
python/tvm/relay/quantize/_calibrate.py
python/tvm/relay/quantize/_partition.py
python/tvm/relay/quantize/_partition_conversions.py
python/tvm/relay/quantize/_quantize.py
python/tvm/relay/quantize/kl_divergence.py
python/tvm/relay/quantize/quantize.py
python/tvm/relay/scope_builder.py
python/tvm/relay/testing/__init__.py
python/tvm/relay/testing/darknet.py
python/tvm/relay/testing/dcgan.py
python/tvm/relay/testing/densenet.py
python/tvm/relay/testing/dqn.py
python/tvm/relay/testing/inception_v3.py
python/tvm/relay/testing/init.py
python/tvm/relay/testing/layers.py
python/tvm/relay/testing/lstm.py
python/tvm/relay/testing/mlp.py
python/tvm/relay/testing/mobilenet.py
python/tvm/relay/testing/nat.py
python/tvm/relay/testing/py_converter.py
python/tvm/relay/testing/resnet.py
python/tvm/relay/testing/resnet_3d.py
python/tvm/relay/testing/squeezenet.py
python/tvm/relay/testing/synthetic.py
python/tvm/relay/testing/temp_op_attr.py
python/tvm/relay/testing/tf.py
python/tvm/relay/testing/vgg.py
python/tvm/relay/testing/yolo_detection.py
python/tvm/relay/transform/memory_alloc.py
python/tvm/relay/transform/memory_plan.py
python/tvm/relay/transform/transform.py
python/tvm/relay/ty.py
python/tvm/relay/type_functor.py
python/tvm/rpc/base.py
python/tvm/rpc/client.py
python/tvm/rpc/minrpc.py
python/tvm/rpc/proxy.py
python/tvm/rpc/server.py
python/tvm/rpc/tornado_util.py
python/tvm/rpc/tracker.py
python/tvm/runtime/_ffi_node_api.py
python/tvm/runtime/container.py
python/tvm/runtime/module.py
python/tvm/runtime/ndarray.py
python/tvm/runtime/object.py
python/tvm/runtime/object_generic.py
python/tvm/runtime/packed_func.py
python/tvm/runtime/vm.py
python/tvm/target/arm_isa.py
python/tvm/target/codegen.py
python/tvm/target/datatype.py
python/tvm/target/generic_func.py
python/tvm/target/tag.py
python/tvm/target/target.py
python/tvm/te/hybrid/__init__.py
python/tvm/te/hybrid/calls.py
python/tvm/te/hybrid/module.py
python/tvm/te/hybrid/parser.py
python/tvm/te/hybrid/preprocessor.py
python/tvm/te/hybrid/runtime.py
python/tvm/te/hybrid/util.py
python/tvm/te/operation.py
python/tvm/te/schedule.py
python/tvm/te/tag.py
python/tvm/te/tensor.py
python/tvm/te/tensor_intrin.py
python/tvm/testing.py
python/tvm/tir/buffer.py
python/tvm/tir/data_layout.py
python/tvm/tir/expr.py
python/tvm/tir/function.py
python/tvm/tir/generic.py
python/tvm/tir/ir_builder.py
python/tvm/tir/op.py
python/tvm/tir/stmt.py
python/tvm/tir/stmt_functor.py
python/tvm/tir/transform/function_pass.py
python/tvm/tir/transform/transform.py
python/tvm/topi/__init__.py
python/tvm/topi/argwhere.py
python/tvm/topi/arm_cpu/bitserial_conv2d.py
python/tvm/topi/arm_cpu/bitserial_dense.py
python/tvm/topi/arm_cpu/conv2d.py
python/tvm/topi/arm_cpu/conv2d_alter_op.py
python/tvm/topi/arm_cpu/conv2d_gemm.py
python/tvm/topi/arm_cpu/conv2d_int8.py
python/tvm/topi/arm_cpu/conv2d_spatial_pack.py
python/tvm/topi/arm_cpu/conv2d_transpose.py
python/tvm/topi/arm_cpu/cortex_m7/conv2d/direct.py
python/tvm/topi/arm_cpu/cortex_m7/conv2d/direct_simd.py
python/tvm/topi/arm_cpu/cortex_m7/micro_kernel/gemm.py
python/tvm/topi/arm_cpu/depthwise_conv2d.py
python/tvm/topi/arm_cpu/injective.py
python/tvm/topi/arm_cpu/tensor_intrin.py
python/tvm/topi/bifrost/conv2d.py
python/tvm/topi/bifrost/dense.py
python/tvm/topi/bifrost/depthwise_conv2d.py
python/tvm/topi/bifrost/gemm.py
python/tvm/topi/bifrost/transforms.py
python/tvm/topi/broadcast.py
python/tvm/topi/cpp/__init__.py
python/tvm/topi/cuda/batch_matmul.py
python/tvm/topi/cuda/conv1d.py
python/tvm/topi/cuda/conv1d_transpose_ncw.py
python/tvm/topi/cuda/conv2d.py
python/tvm/topi/cuda/conv2d_alter_op.py
python/tvm/topi/cuda/conv2d_direct.py
python/tvm/topi/cuda/conv2d_hwcn.py
python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py
python/tvm/topi/cuda/conv2d_int8.py
python/tvm/topi/cuda/conv2d_nhwc.py
python/tvm/topi/cuda/conv2d_nhwc_tensorcore.py
python/tvm/topi/cuda/conv2d_nhwc_winograd.py
python/tvm/topi/cuda/conv2d_transpose_nchw.py
python/tvm/topi/cuda/conv2d_winograd.py
python/tvm/topi/cuda/conv3d.py
python/tvm/topi/cuda/conv3d_alter_op.py
python/tvm/topi/cuda/conv3d_direct.py
python/tvm/topi/cuda/conv3d_ndhwc_tensorcore.py
python/tvm/topi/cuda/conv3d_transpose_ncdhw.py
python/tvm/topi/cuda/conv3d_winograd.py
python/tvm/topi/cuda/correlation.py
python/tvm/topi/cuda/deformable_conv2d.py
python/tvm/topi/cuda/dense.py
python/tvm/topi/cuda/dense_tensorcore.py
python/tvm/topi/cuda/depthwise_conv2d.py
python/tvm/topi/cuda/group_conv2d_nchw.py
python/tvm/topi/cuda/injective.py
python/tvm/topi/cuda/nms.py
python/tvm/topi/cuda/nn.py
python/tvm/topi/cuda/pooling.py
python/tvm/topi/cuda/rcnn/proposal.py
python/tvm/topi/cuda/reduction.py
python/tvm/topi/cuda/softmax.py
python/tvm/topi/cuda/sort.py
python/tvm/topi/cuda/sparse.py
python/tvm/topi/cuda/ssd/multibox.py
python/tvm/topi/cuda/tensor_intrin.py
python/tvm/topi/cuda/vision.py
python/tvm/topi/generic/conv2d.py
python/tvm/topi/generic/extern.py
python/tvm/topi/generic/injective.py
python/tvm/topi/generic/nn.py
python/tvm/topi/generic/sort.py
python/tvm/topi/generic/vision.py
python/tvm/topi/generic_op_impl.py
python/tvm/topi/hls/injective.py
python/tvm/topi/hls/nn.py
python/tvm/topi/image/dilation2d.py
python/tvm/topi/image/grid_sample.py
python/tvm/topi/image/resize.py
python/tvm/topi/intel_graphics/conv2d.py
python/tvm/topi/intel_graphics/conv2d_alter_op.py
python/tvm/topi/intel_graphics/depthwise_conv2d.py
python/tvm/topi/mali/conv2d.py
python/tvm/topi/mali/dense.py
python/tvm/topi/mali/depthwise_conv2d.py
python/tvm/topi/math.py
python/tvm/topi/nn/batch_matmul.py
python/tvm/topi/nn/bitserial_conv2d.py
python/tvm/topi/nn/bitserial_dense.py
python/tvm/topi/nn/bitserial_util.py
python/tvm/topi/nn/bnn.py
python/tvm/topi/nn/conv1d.py
python/tvm/topi/nn/conv1d_transpose.py
python/tvm/topi/nn/conv2d.py
python/tvm/topi/nn/conv2d_transpose.py
python/tvm/topi/nn/conv3d.py
python/tvm/topi/nn/conv3d_transpose.py
python/tvm/topi/nn/correlation.py
python/tvm/topi/nn/deformable_conv2d.py
python/tvm/topi/nn/dense.py
python/tvm/topi/nn/depth_to_space.py
python/tvm/topi/nn/depthwise_conv2d.py
python/tvm/topi/nn/dilate.py
python/tvm/topi/nn/elemwise.py
python/tvm/topi/nn/fifo_buffer.py
python/tvm/topi/nn/flatten.py
python/tvm/topi/nn/local_response_norm.py
python/tvm/topi/nn/mapping.py
python/tvm/topi/nn/pad.py
python/tvm/topi/nn/pooling.py
python/tvm/topi/nn/softmax.py
python/tvm/topi/nn/space_to_depth.py
python/tvm/topi/nn/sparse.py
python/tvm/topi/nn/upsampling.py
python/tvm/topi/nn/util.py
python/tvm/topi/nn/winograd_util.py
python/tvm/topi/reduction.py
python/tvm/topi/rocm/conv2d.py
python/tvm/topi/rocm/dense.py
python/tvm/topi/rocm/nn.py
python/tvm/topi/scatter.py
python/tvm/topi/scatter_add.py
python/tvm/topi/sort.py
python/tvm/topi/sparse/csrmm.py
python/tvm/topi/sparse/csrmv.py
python/tvm/topi/sparse/dense.py
python/tvm/topi/tag.py
python/tvm/topi/tensor.py
python/tvm/topi/testing/__init__.py
python/tvm/topi/testing/adaptive_pool_python.py
python/tvm/topi/testing/batch_matmul.py
python/tvm/topi/testing/bilinear_resize_python.py
python/tvm/topi/testing/common.py
python/tvm/topi/testing/conv1d_ncw_python.py
python/tvm/topi/testing/conv1d_transpose_ncw_python.py
python/tvm/topi/testing/conv2d_hwcn_python.py
python/tvm/topi/testing/conv2d_nchw_python.py
python/tvm/topi/testing/conv2d_nhwc_python.py
python/tvm/topi/testing/conv2d_transpose_python.py
python/tvm/topi/testing/conv3d_ncdhw_python.py
python/tvm/topi/testing/conv3d_ndhwc_python.py
python/tvm/topi/testing/conv3d_transpose_ncdhw_python.py
python/tvm/topi/testing/correlation_nchw_python.py
python/tvm/topi/testing/crop_and_resize_python.py
python/tvm/topi/testing/deformable_conv2d_nchw_python.py
python/tvm/topi/testing/depth_to_space.py
python/tvm/topi/testing/depthwise_conv2d_python.py
python/tvm/topi/testing/dilate_python.py
python/tvm/topi/testing/gather_nd_python.py
python/tvm/topi/testing/gather_python.py
python/tvm/topi/testing/grid_sample_python.py
python/tvm/topi/testing/l2_normalize_python.py
python/tvm/topi/testing/lrn_python.py
python/tvm/topi/testing/matrix_set_diag.py
python/tvm/topi/testing/one_hot.py
python/tvm/topi/testing/pool1d_python.py
python/tvm/topi/testing/pool_grad_python.py
python/tvm/topi/testing/reorg_python.py
python/tvm/topi/testing/roi_align_python.py
python/tvm/topi/testing/roi_pool_python.py
python/tvm/topi/testing/sequence_mask_python.py
python/tvm/topi/testing/slice_axis_python.py
python/tvm/topi/testing/softmax_python.py
python/tvm/topi/testing/space_to_depth.py
python/tvm/topi/testing/strided_slice_python.py
python/tvm/topi/testing/trilinear_resize3d_python.py
python/tvm/topi/testing/upsampling_python.py
python/tvm/topi/transform.py
python/tvm/topi/util.py
python/tvm/topi/vision/nms.py
python/tvm/topi/vision/rcnn/proposal.py
python/tvm/topi/vision/rcnn/roi_align.py
python/tvm/topi/vision/rcnn/roi_pool.py
python/tvm/topi/vision/reorg.py
python/tvm/topi/vision/ssd/multibox.py
python/tvm/topi/x86/batch_matmul.py
python/tvm/topi/x86/binarize_pack.py
python/tvm/topi/x86/binary_dense.py
python/tvm/topi/x86/bitserial_conv2d.py
python/tvm/topi/x86/bitserial_dense.py
python/tvm/topi/x86/conv1d.py
python/tvm/topi/x86/conv2d.py
python/tvm/topi/x86/conv2d_alter_op.py
python/tvm/topi/x86/conv2d_avx_1x1.py
python/tvm/topi/x86/conv2d_avx_common.py
python/tvm/topi/x86/conv2d_int8.py
python/tvm/topi/x86/conv2d_transpose.py
python/tvm/topi/x86/conv3d.py
python/tvm/topi/x86/conv3d_transpose.py
python/tvm/topi/x86/dense.py
python/tvm/topi/x86/depthwise_conv2d.py
python/tvm/topi/x86/injective.py
python/tvm/topi/x86/nn.py
python/tvm/topi/x86/pooling.py
python/tvm/topi/x86/reduction.py
python/tvm/topi/x86/roi_align.py
python/tvm/topi/x86/sparse.py
python/tvm/topi/x86/tensor_intrin.py
python/tvm/topi/x86/util.py
rust/tvm-graph-rt/tests/build_model.py
rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py
rust/tvm-graph-rt/tests/test_tvm_basic/src/build_test_lib.py
rust/tvm-graph-rt/tests/test_tvm_dso/src/build_test_lib.py
rust/tvm-graph-rt/tests/test_wasm32/src/build_test_lib.py
rust/tvm/examples/resnet/src/build_resnet.py
rust/tvm/tests/basics/src/tvm_add.py
tests/lint/filter_untracked.py
tests/lint/pylintrc
tests/micro/test_runtime_micro_on_arm.py
tests/python/contrib/test_arm_compute_lib/infrastructure.py
tests/python/contrib/test_arm_compute_lib/test_conv2d.py
tests/python/contrib/test_arm_compute_lib/test_dense.py
tests/python/contrib/test_arm_compute_lib/test_network.py
tests/python/contrib/test_arm_compute_lib/test_pooling.py
tests/python/contrib/test_arm_compute_lib/test_reshape.py
tests/python/contrib/test_arm_compute_lib/test_runtime.py
tests/python/contrib/test_binutil.py
tests/python/contrib/test_cblas.py
tests/python/contrib/test_coreml_codegen.py
tests/python/contrib/test_coreml_runtime.py
tests/python/contrib/test_cublas.py
tests/python/contrib/test_cudnn.py
tests/python/contrib/test_dlpack.py
tests/python/contrib/test_edgetpu_runtime.py
tests/python/contrib/test_ethosn/__init__.py
tests/python/contrib/test_ethosn/infrastructure.py
tests/python/contrib/test_ethosn/test_concatenate.py
tests/python/contrib/test_ethosn/test_conv2d.py
tests/python/contrib/test_ethosn/test_split.py
tests/python/contrib/test_ethosn/test_topologies.py
tests/python/contrib/test_gemm_acc16.py
tests/python/contrib/test_gemm_acc32_vnni.py
tests/python/contrib/test_miopen.py
tests/python/contrib/test_mps.py
tests/python/contrib/test_mxnet_bridge.py
tests/python/contrib/test_nnpack.py
tests/python/contrib/test_onnx.py
tests/python/contrib/test_onnx_model.py
tests/python/contrib/test_random.py
tests/python/contrib/test_rocblas.py
tests/python/contrib/test_rpc_proxy.py
tests/python/contrib/test_rpc_tracker.py
tests/python/contrib/test_sort.py
tests/python/contrib/test_sparse.py
tests/python/contrib/test_tedd.py
tests/python/contrib/test_tflite_runtime.py
tests/python/contrib/test_util.py
tests/python/frontend/caffe/test_forward.py
tests/python/frontend/caffe2/model_zoo/__init__.py
tests/python/frontend/caffe2/model_zoo/squeezenet.py
tests/python/frontend/caffe2/test_forward.py
tests/python/frontend/caffe2/test_graph.py
tests/python/frontend/coreml/model_zoo/__init__.py
tests/python/frontend/coreml/test_forward.py
tests/python/frontend/darknet/test_forward.py
tests/python/frontend/keras/test_forward.py
tests/python/frontend/mxnet/model_zoo/__init__.py
tests/python/frontend/mxnet/model_zoo/dcgan.py
tests/python/frontend/mxnet/model_zoo/dqn.py
tests/python/frontend/mxnet/model_zoo/inception_v3.py
tests/python/frontend/mxnet/model_zoo/mlp.py
tests/python/frontend/mxnet/model_zoo/resnet.py
tests/python/frontend/mxnet/model_zoo/squeezenet.py
tests/python/frontend/mxnet/model_zoo/vgg.py
tests/python/frontend/mxnet/test_forward.py
tests/python/frontend/mxnet/test_graph.py
tests/python/frontend/mxnet/test_qnn_ops_utils.py
tests/python/frontend/onnx/test_forward.py
tests/python/frontend/pytorch/qnn_test.py
tests/python/frontend/pytorch/test_forward.py
tests/python/frontend/pytorch/test_lstm.py
tests/python/frontend/tensorflow/test_bn_dynamic.py
tests/python/frontend/tensorflow/test_control_flow.py
tests/python/frontend/tensorflow/test_debugging.py
tests/python/frontend/tensorflow/test_forward.py
tests/python/frontend/tensorflow/test_no_op.py
tests/python/frontend/test_common.py
tests/python/frontend/tflite/test_forward.py
tests/python/integration/test_dot.py
tests/python/integration/test_ewise.py
tests/python/integration/test_ewise_fpga.py
tests/python/integration/test_gemm.py
tests/python/integration/test_reduce.py
tests/python/integration/test_scan.py
tests/python/integration/test_tuning.py
tests/python/integration/test_winograd_nnpack.py
tests/python/nightly/quantization/test_quantization_accuracy.py
tests/python/relay/benchmarking/benchmark_vm.py
tests/python/relay/dyn/test_dynamic_op_level10.py
tests/python/relay/dyn/test_dynamic_op_level2.py
tests/python/relay/dyn/test_dynamic_op_level3.py
tests/python/relay/dyn/test_dynamic_op_level4.py
tests/python/relay/dyn/test_dynamic_op_level5.py
tests/python/relay/dyn/test_dynamic_op_level6.py
tests/python/relay/test_adt.py
tests/python/relay/test_analysis_basic_block_normal_form.py
tests/python/relay/test_analysis_extract_fused_functions.py
tests/python/relay/test_analysis_feature.py
tests/python/relay/test_analysis_get_calibration_data.py
tests/python/relay/test_annotated_regions.py
tests/python/relay/test_any.py
tests/python/relay/test_autotvm_task_extraction.py
tests/python/relay/test_backend_compile_engine.py
tests/python/relay/test_backend_graph_runtime.py
tests/python/relay/test_backend_interpreter.py
tests/python/relay/test_call_graph.py
tests/python/relay/test_change_batch.py
tests/python/relay/test_cmp_op.py
tests/python/relay/test_cpp_build_module.py
tests/python/relay/test_dataflow_pattern.py
tests/python/relay/test_debug.py
tests/python/relay/test_error_reporting.py
tests/python/relay/test_expr_functor.py
tests/python/relay/test_external_codegen.py
tests/python/relay/test_ir_bind.py
tests/python/relay/test_ir_module.py
tests/python/relay/test_ir_nodes.py
tests/python/relay/test_ir_op.py
tests/python/relay/test_ir_parser.py
tests/python/relay/test_ir_structural_equal_hash.py
tests/python/relay/test_ir_text_printer.py
tests/python/relay/test_ir_well_formed.py
tests/python/relay/test_json_compact.py
tests/python/relay/test_json_runtime.py
tests/python/relay/test_memory_passes.py
tests/python/relay/test_op_fast_math.py
tests/python/relay/test_op_grad_level1.py
tests/python/relay/test_op_grad_level10.py
tests/python/relay/test_op_grad_level2.py
tests/python/relay/test_op_grad_level3.py
tests/python/relay/test_op_grad_level4.py
tests/python/relay/test_op_level1.py
tests/python/relay/test_op_level10.py
tests/python/relay/test_op_level2.py
tests/python/relay/test_op_level3.py
tests/python/relay/test_op_level4.py
tests/python/relay/test_op_level5.py
tests/python/relay/test_op_level6.py
tests/python/relay/test_op_qnn_add.py
tests/python/relay/test_op_qnn_concatenate.py
tests/python/relay/test_op_qnn_conv2d.py
tests/python/relay/test_op_qnn_dense.py
tests/python/relay/test_op_qnn_dequantize.py
tests/python/relay/test_op_qnn_mul.py
tests/python/relay/test_op_qnn_quantize.py
tests/python/relay/test_op_qnn_requantize.py
tests/python/relay/test_op_qnn_subtract.py
tests/python/relay/test_param_dict.py
tests/python/relay/test_pass_alter_op_layout.py
tests/python/relay/test_pass_annotate_target.py
tests/python/relay/test_pass_annotation.py
tests/python/relay/test_pass_auto_quantize.py
tests/python/relay/test_pass_canonicalize_cast.py
tests/python/relay/test_pass_check_kind.py
tests/python/relay/test_pass_combine_parallel_batch_matmul.py
tests/python/relay/test_pass_combine_parallel_conv2d.py
tests/python/relay/test_pass_combine_parallel_dense.py
tests/python/relay/test_pass_convert_op_layout.py
tests/python/relay/test_pass_dead_code_elimination.py
tests/python/relay/test_pass_defunctionalization.py
tests/python/relay/test_pass_dynamic_to_static.py
tests/python/relay/test_pass_eliminate_common_subexpr.py
tests/python/relay/test_pass_eta_expand.py
tests/python/relay/test_pass_fast_math.py
tests/python/relay/test_pass_fold_constant.py
tests/python/relay/test_pass_fold_scale_axis.py
tests/python/relay/test_pass_fuse_ops.py
tests/python/relay/test_pass_gradient.py
tests/python/relay/test_pass_inline.py
tests/python/relay/test_pass_lambda_lift.py
tests/python/relay/test_pass_lazy_gradient_init.py
tests/python/relay/test_pass_legalize.py
tests/python/relay/test_pass_mac_count.py
tests/python/relay/test_pass_manager.py
tests/python/relay/test_pass_merge_compiler_regions.py
tests/python/relay/test_pass_merge_composite.py
tests/python/relay/test_pass_partial_eval.py
tests/python/relay/test_pass_partition_graph.py
tests/python/relay/test_pass_qnn_legalize.py
tests/python/relay/test_pass_remove_unused_functions.py
tests/python/relay/test_pass_simplify_expr.py
tests/python/relay/test_pass_simplify_inference.py
tests/python/relay/test_pass_to_a_normal_form.py
tests/python/relay/test_pass_to_basic_block_normal_form.py
tests/python/relay/test_pass_to_cps.py
tests/python/relay/test_pass_to_graph_normal_form.py
tests/python/relay/test_pass_unmatched_cases.py
tests/python/relay/test_pass_vars.py
tests/python/relay/test_py_converter.py
tests/python/relay/test_simplify_fc_transpose.py
tests/python/relay/test_sparse_dense_convert.py
tests/python/relay/test_type_functor.py
tests/python/relay/test_type_infer.py
tests/python/relay/test_type_solver.py
tests/python/relay/test_typecall.py
tests/python/relay/test_vm.py
tests/python/relay/test_vm_serialization.py
tests/python/topi/python/common.py
tests/python/topi/python/test_fifo_buffer.py
tests/python/topi/python/test_topi_basic.py
tests/python/topi/python/test_topi_batch_matmul.py
tests/python/topi/python/test_topi_bitserial_conv2d.py
tests/python/topi/python/test_topi_bitserial_conv2d_rasp.py
tests/python/topi/python/test_topi_bitserial_dense.py
tests/python/topi/python/test_topi_bnn.py
tests/python/topi/python/test_topi_broadcast.py
tests/python/topi/python/test_topi_clip.py
tests/python/topi/python/test_topi_conv1d.py
tests/python/topi/python/test_topi_conv1d_transpose_ncw.py
tests/python/topi/python/test_topi_conv2d_NCHWc.py
tests/python/topi/python/test_topi_conv2d_hwcn.py
tests/python/topi/python/test_topi_conv2d_hwnc_tensorcore.py
tests/python/topi/python/test_topi_conv2d_int8.py
tests/python/topi/python/test_topi_conv2d_nchw.py
tests/python/topi/python/test_topi_conv2d_nhwc.py
tests/python/topi/python/test_topi_conv2d_nhwc_pack_int8.py
tests/python/topi/python/test_topi_conv2d_nhwc_tensorcore.py
tests/python/topi/python/test_topi_conv2d_nhwc_winograd.py
tests/python/topi/python/test_topi_conv2d_transpose_nchw.py
tests/python/topi/python/test_topi_conv2d_winograd.py
tests/python/topi/python/test_topi_conv3d_ncdhw.py
tests/python/topi/python/test_topi_conv3d_ndhwc.py
tests/python/topi/python/test_topi_conv3d_ndhwc_tensorcore.py
tests/python/topi/python/test_topi_conv3d_transpose_ncdhw.py
tests/python/topi/python/test_topi_conv3d_winograd.py
tests/python/topi/python/test_topi_correlation.py
tests/python/topi/python/test_topi_deformable_conv2d.py
tests/python/topi/python/test_topi_dense.py
tests/python/topi/python/test_topi_dense_tensorcore.py
tests/python/topi/python/test_topi_depth_to_space.py
tests/python/topi/python/test_topi_depthwise_conv2d.py
tests/python/topi/python/test_topi_depthwise_conv2d_back_input.py
tests/python/topi/python/test_topi_depthwise_conv2d_back_weight.py
tests/python/topi/python/test_topi_dilate.py
tests/python/topi/python/test_topi_group_conv2d.py
tests/python/topi/python/test_topi_group_conv2d_NCHWc_int8.py
tests/python/topi/python/test_topi_image.py
tests/python/topi/python/test_topi_lrn.py
tests/python/topi/python/test_topi_math.py
tests/python/topi/python/test_topi_matmul.py
tests/python/topi/python/test_topi_reduce.py
tests/python/topi/python/test_topi_relu.py
tests/python/topi/python/test_topi_reorg.py
tests/python/topi/python/test_topi_softmax.py
tests/python/topi/python/test_topi_sort.py
tests/python/topi/python/test_topi_space_to_depth.py
tests/python/topi/python/test_topi_sparse.py
tests/python/topi/python/test_topi_tensor.py
tests/python/topi/python/test_topi_transform.py
tests/python/topi/python/test_topi_upsampling.py
tests/python/topi/python/test_topi_util.py
tests/python/topi/python/test_topi_vision.py
tests/python/unittest/test_arith_canonical_simplify.py
tests/python/unittest/test_arith_const_int_bound.py
tests/python/unittest/test_arith_deduce_bound.py
tests/python/unittest/test_arith_detect_clip_bound.py
tests/python/unittest/test_arith_detect_linear_equation.py
tests/python/unittest/test_arith_domain_touched.py
tests/python/unittest/test_arith_intset.py
tests/python/unittest/test_arith_modular_set.py
tests/python/unittest/test_arith_rewrite_simplify.py
tests/python/unittest/test_arith_solve_linear_equations.py
tests/python/unittest/test_arith_solve_linear_inequality.py
tests/python/unittest/test_auto_scheduler_common.py
tests/python/unittest/test_auto_scheduler_compute_dag.py
tests/python/unittest/test_auto_scheduler_cost_model.py
tests/python/unittest/test_auto_scheduler_evolutionary_search.py
tests/python/unittest/test_auto_scheduler_feature.py
tests/python/unittest/test_auto_scheduler_layout_rewrite.py
tests/python/unittest/test_auto_scheduler_loop_state.py
tests/python/unittest/test_auto_scheduler_measure.py
tests/python/unittest/test_auto_scheduler_search_policy.py
tests/python/unittest/test_auto_scheduler_sketch_generation.py
tests/python/unittest/test_autotvm_common.py
tests/python/unittest/test_autotvm_database.py
tests/python/unittest/test_autotvm_dispatch_context.py
tests/python/unittest/test_autotvm_executor.py
tests/python/unittest/test_autotvm_feature.py
tests/python/unittest/test_autotvm_flop_calculator.py
tests/python/unittest/test_autotvm_graph_tuner_core.py
tests/python/unittest/test_autotvm_graph_tuner_utils.py
tests/python/unittest/test_autotvm_index_tuner.py
tests/python/unittest/test_autotvm_measure.py
tests/python/unittest/test_autotvm_record.py
tests/python/unittest/test_autotvm_space.py
tests/python/unittest/test_autotvm_xgboost_model.py
tests/python/unittest/test_filter_untracked.py
tests/python/unittest/test_format_si_prefix.py
tests/python/unittest/test_hybrid_error_report.py
tests/python/unittest/test_hybrid_roundtrip.py
tests/python/unittest/test_ir_attrs.py
tests/python/unittest/test_ir_container.py
tests/python/unittest/test_ir_type.py
tests/python/unittest/test_node_reflection.py
tests/python/unittest/test_runtime_container.py
tests/python/unittest/test_runtime_error.py
tests/python/unittest/test_runtime_extension.py
tests/python/unittest/test_runtime_graph.py
tests/python/unittest/test_runtime_graph_debug.py
tests/python/unittest/test_runtime_heterogeneous.py
tests/python/unittest/test_runtime_measure.py
tests/python/unittest/test_runtime_micro.py
tests/python/unittest/test_runtime_module_based_interface.py
tests/python/unittest/test_runtime_module_export.py
tests/python/unittest/test_runtime_module_load.py
tests/python/unittest/test_runtime_ndarray.py
tests/python/unittest/test_runtime_packed_func.py
tests/python/unittest/test_runtime_rpc.py
tests/python/unittest/test_runtime_vm_profiler.py
tests/python/unittest/test_target_codegen_arm.py
tests/python/unittest/test_target_codegen_blob.py
tests/python/unittest/test_target_codegen_bool.py
tests/python/unittest/test_target_codegen_c_host.py
tests/python/unittest/test_target_codegen_cross_llvm.py
tests/python/unittest/test_target_codegen_cuda.py
tests/python/unittest/test_target_codegen_device.py
tests/python/unittest/test_target_codegen_extern.py
tests/python/unittest/test_target_codegen_hexagon.py
tests/python/unittest/test_target_codegen_llvm.py
tests/python/unittest/test_target_codegen_opencl.py
tests/python/unittest/test_target_codegen_rocm.py
tests/python/unittest/test_target_codegen_static_init.py
tests/python/unittest/test_target_codegen_vm_basic.py
tests/python/unittest/test_target_codegen_vulkan.py
tests/python/unittest/test_target_codegen_x86.py
tests/python/unittest/test_target_custom_datatypes.py
tests/python/unittest/test_target_target.py
tests/python/unittest/test_te_autodiff.py
tests/python/unittest/test_te_build_lower.py
tests/python/unittest/test_te_group.py
tests/python/unittest/test_te_hybrid_script.py
tests/python/unittest/test_te_schedule.py
tests/python/unittest/test_te_schedule_bound_inference.py
tests/python/unittest/test_te_schedule_bound_inference_tiling.py
tests/python/unittest/test_te_schedule_graph.py
tests/python/unittest/test_te_schedule_lstm.py
tests/python/unittest/test_te_schedule_ops.py
tests/python/unittest/test_te_schedule_postproc_rewrite_for_tensor_core.py
tests/python/unittest/test_te_schedule_tensor_core.py
tests/python/unittest/test_te_schedule_tensorize.py
tests/python/unittest/test_te_tag.py
tests/python/unittest/test_te_tensor.py
tests/python/unittest/test_te_tensor_overload.py
tests/python/unittest/test_te_verify_compute.py
tests/python/unittest/test_testing.py
tests/python/unittest/test_tir_analysis_expr_deep_equal.py
tests/python/unittest/test_tir_analysis_usedef.py
tests/python/unittest/test_tir_analysis_verify_gpu_code.py
tests/python/unittest/test_tir_analysis_verify_memory.py
tests/python/unittest/test_tir_analysis_verify_ssa.py
tests/python/unittest/test_tir_buffer.py
tests/python/unittest/test_tir_constructor.py
tests/python/unittest/test_tir_data_layout.py
tests/python/unittest/test_tir_intrin.py
tests/python/unittest/test_tir_ir_builder.py
tests/python/unittest/test_tir_nodes.py
tests/python/unittest/test_tir_ops.py
tests/python/unittest/test_tir_stmt_functor_ir_transform.py
tests/python/unittest/test_tir_structural_equal_hash.py
tests/python/unittest/test_tir_transform_bf16_legalize.py
tests/python/unittest/test_tir_transform_combine_context_call.py
tests/python/unittest/test_tir_transform_coproc_sync.py
tests/python/unittest/test_tir_transform_decorate_device_scope.py
tests/python/unittest/test_tir_transform_hoist_if.py
tests/python/unittest/test_tir_transform_inject_copy_intrin.py
tests/python/unittest/test_tir_transform_inject_double_buffer.py
tests/python/unittest/test_tir_transform_inject_virtual_thread.py
tests/python/unittest/test_tir_transform_instrument_bound_checkers.py
tests/python/unittest/test_tir_transform_lift_attr_scope.py
tests/python/unittest/test_tir_transform_loop_partition.py
tests/python/unittest/test_tir_transform_lower_intrin.py
tests/python/unittest/test_tir_transform_lower_warp_memory.py
tests/python/unittest/test_tir_transform_make_packed_api.py
tests/python/unittest/test_tir_transform_narrow_datatype.py
tests/python/unittest/test_tir_transform_prim_func_pass.py
tests/python/unittest/test_tir_transform_remove_no_op.py
tests/python/unittest/test_tir_transform_rewrite_unsafe_select.py
tests/python/unittest/test_tir_transform_simplify.py
tests/python/unittest/test_tir_transform_storage_flatten.py
tests/python/unittest/test_tir_transform_storage_rewrite.py
tests/python/unittest/test_tir_transform_thread_sync.py
tests/python/unittest/test_tir_transform_unroll_loop.py
tests/python/unittest/test_tir_transform_vectorize.py
tutorials/autotvm/tune_conv2d_cuda.py
tutorials/autotvm/tune_relay_arm.py
tutorials/autotvm/tune_relay_cuda.py
tutorials/autotvm/tune_relay_mobile_gpu.py
tutorials/autotvm/tune_relay_x86.py
tutorials/autotvm/tune_simple_template.py
tutorials/dev/low_level_custom_pass.py
tutorials/dev/use_pass_infra.py
tutorials/frontend/build_gcn.py
tutorials/frontend/deploy_model_on_android.py
tutorials/frontend/deploy_model_on_rasp.py
tutorials/frontend/deploy_prequantized.py
tutorials/frontend/deploy_prequantized_tflite.py
tutorials/frontend/deploy_quantized.py
tutorials/frontend/deploy_sparse.py
tutorials/frontend/deploy_ssd_gluoncv.py
tutorials/frontend/from_caffe2.py
tutorials/frontend/from_coreml.py
tutorials/frontend/from_darknet.py
tutorials/frontend/from_keras.py
tutorials/frontend/from_mxnet.py
tutorials/frontend/from_onnx.py
tutorials/frontend/from_pytorch.py
tutorials/frontend/from_tensorflow.py
tutorials/frontend/from_tflite.py
tutorials/frontend/using_external_lib.py
tutorials/get_started/cross_compilation_and_rpc.py
tutorials/get_started/relay_quick_start.py
tutorials/get_started/tensor_expr_get_started.py
tutorials/language/extern_op.py
tutorials/language/intrin_math.py
tutorials/language/reduction.py
tutorials/language/scan.py
tutorials/language/schedule_primitives.py
tutorials/language/tedd.py
tutorials/language/tensorize.py
tutorials/language/tuple_inputs.py
tutorials/micro/micro_tflite.py
tutorials/optimize/opt_conv_cuda.py
tutorials/optimize/opt_conv_tensorcore.py
tutorials/optimize/opt_gemm.py
tutorials/optimize/opt_matmul_auto_tensorcore.py
tutorials/topi/intro_topi.py
version.py
vta/python/vta/bitstream.py
vta/python/vta/build_module.py
vta/python/vta/environment.py
vta/python/vta/exec/rpc_server.py
vta/python/vta/intrin.py
vta/python/vta/libinfo.py
vta/python/vta/program_bitstream.py
vta/python/vta/rpc_client.py
vta/python/vta/testing/__init__.py
vta/python/vta/testing/simulator.py
vta/python/vta/testing/util.py
vta/python/vta/top/bitpack.py
vta/python/vta/top/graphpack.py
vta/python/vta/top/op.py
vta/python/vta/top/util.py
vta/python/vta/top/vta_conv2d.py
vta/python/vta/top/vta_conv2d_transpose.py
vta/python/vta/top/vta_dense.py
vta/python/vta/top/vta_group_conv2d.py
vta/python/vta/transform.py
vta/scripts/tune_conv2d.py
vta/scripts/tune_conv2d_transpose.py
vta/scripts/tune_dense.py
vta/scripts/tune_group_conv2d.py
vta/scripts/tune_resnet.py
vta/tests/python/de10nano/test_program_rpc.py
vta/tests/python/integration/test_benchmark_gemm.py
vta/tests/python/integration/test_benchmark_topi_conv2d.py
vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py
vta/tests/python/integration/test_benchmark_topi_dense.py
vta/tests/python/integration/test_benchmark_topi_group_conv2d.py
vta/tests/python/pynq/test_program_rpc.py
vta/tests/python/unittest/test_environment.py
vta/tests/python/unittest/test_vta_insn.py
vta/tutorials/autotvm/tune_relay_vta.py
vta/tutorials/frontend/deploy_classification.py
vta/tutorials/frontend/legacy/deploy_detection.py
vta/tutorials/matrix_multiply.py
vta/tutorials/optimize/convolution_opt.py
vta/tutorials/optimize/matrix_multiply_opt.py
vta/tutorials/vta_get_started.py
web/tests/python/prepare_test_libs.py
web/tests/python/webgpu_rpc_test.py
web/tests/python/websock_rpc_test.py

apps/android_camera/models/prepare_model.py
index ab1ed78..19be368 100644 (file)
@@ -28,13 +28,14 @@ import tvm.relay as relay
 from tvm.contrib import util, ndk, graph_runtime as runtime
 from tvm.contrib.download import download_testdata, download
 
-target = 'llvm -mtriple=arm64-linux-android'
+target = "llvm -mtriple=arm64-linux-android"
 target_host = None
 
+
 def del_dir(target: Union[Path, str], only_if_empty: bool = False):
     target = Path(target).expanduser()
     assert target.is_dir()
-    for p in sorted(target.glob('**/*'), reverse=True):
+    for p in sorted(target.glob("**/*"), reverse=True):
         if not p.exists():
             continue
         p.chmod(0o666)
@@ -42,38 +43,47 @@ def del_dir(target: Union[Path, str], only_if_empty: bool = False):
             p.rmdir()
         else:
             if only_if_empty:
-                raise RuntimeError(f'{p.parent} is not empty!')
+                raise RuntimeError(f"{p.parent} is not empty!")
             p.unlink()
     target.rmdir()
 
+
 def get_model(model_name, batch_size=1):
-    if model_name == 'resnet18_v1':
+    if model_name == "resnet18_v1":
         import mxnet as mx
         from mxnet import gluon
         from mxnet.gluon.model_zoo import vision
+
         gluon_model = vision.get_model(model_name, pretrained=True)
         img_size = 224
         data_shape = (batch_size, 3, img_size, img_size)
         net, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
         return (net, params)
-    elif model_name == 'mobilenet_v2':
+    elif model_name == "mobilenet_v2":
         import keras
         from keras.applications.mobilenet_v2 import MobileNetV2
+
         keras.backend.clear_session()  # Destroys the current TF graph and creates a new one.
-        weights_url = ''.join(['https://github.com/JonathanCMitchell/',
-                               'mobilenet_v2_keras/releases/download/v1.1/',
-                               'mobilenet_v2_weights_tf_dim_ordering_tf_kernels_0.5_224.h5'])
-        weights_file = 'mobilenet_v2_weights.h5'
-        weights_path = download_testdata(weights_url, weights_file, module='keras')
-        keras_mobilenet_v2 = MobileNetV2(alpha=0.5, include_top=True, weights=None,
-                                        input_shape=(224, 224, 3), classes=1000)
+        weights_url = "".join(
+            [
+                "https://github.com/JonathanCMitchell/",
+                "mobilenet_v2_keras/releases/download/v1.1/",
+                "mobilenet_v2_weights_tf_dim_ordering_tf_kernels_0.5_224.h5",
+            ]
+        )
+        weights_file = "mobilenet_v2_weights.h5"
+        weights_path = download_testdata(weights_url, weights_file, module="keras")
+        keras_mobilenet_v2 = MobileNetV2(
+            alpha=0.5, include_top=True, weights=None, input_shape=(224, 224, 3), classes=1000
+        )
         keras_mobilenet_v2.load_weights(weights_path)
-        
+
         img_size = 224
         data_shape = (batch_size, 3, img_size, img_size)
-        mod, params = relay.frontend.from_keras(keras_mobilenet_v2,  {'input_1': data_shape})
+        mod, params = relay.frontend.from_keras(keras_mobilenet_v2, {"input_1": data_shape})
         return (mod, params)
 
+
 def main(model_str, output_path):
     if output_path.exists():
         del_dir(output_path)
@@ -90,34 +100,40 @@ def main(model_str, output_path):
     with tvm.transform.PassContext(opt_level=3):
         graph, lib, params = relay.build(net, target, target_host=target_host, params=params)
     print("dumping lib...")
-    lib.export_library(output_path_str + '/' + 'deploy_lib_cpu.so', ndk.create_shared)
+    lib.export_library(output_path_str + "/" + "deploy_lib_cpu.so", ndk.create_shared)
     print("dumping graph...")
-    with open(output_path_str + '/' + 'deploy_graph.json', 'w') as f:
+    with open(output_path_str + "/" + "deploy_graph.json", "w") as f:
         f.write(graph)
     print("dumping params...")
-    with open(output_path_str + '/' + 'deploy_param.params', 'wb') as f:
+    with open(output_path_str + "/" + "deploy_param.params", "wb") as f:
         f.write(relay.save_param_dict(params))
     print("dumping labels...")
-    synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
-        '4d0b62f3d01426887599d4f7ede23ee5/raw/',
-        '596b27d23537e5a1b5751d2b0481ef172f58b539/',
-        'imagenet1000_clsid_to_human.txt'])
-    synset_path = output_path_str + '/image_net_labels'
-    download(synset_url, output_path_str + '/image_net_labels')
+    synset_url = "".join(
+        [
+            "https://gist.githubusercontent.com/zhreshold/",
+            "4d0b62f3d01426887599d4f7ede23ee5/raw/",
+            "596b27d23537e5a1b5751d2b0481ef172f58b539/",
+            "imagenet1000_clsid_to_human.txt",
+        ]
+    )
+    synset_path = output_path_str + "/image_net_labels"
+    download(synset_url, output_path_str + "/image_net_labels")
     with open(synset_path) as fi:
         synset = eval(fi.read())
-        with open(output_path_str + '/image_net_labels.json', "w") as fo:
+        with open(output_path_str + "/image_net_labels.json", "w") as fo:
             json.dump(synset, fo, indent=4)
     os.remove(synset_path)
 
-if __name__ == '__main__':
-    if environ.get('TVM_NDK_CC') is None:
+
+if __name__ == "__main__":
+    if environ.get("TVM_NDK_CC") is None:
         raise RuntimeError("Require environment variable TVM_NDK_CC")
-    models_path = Path().absolute().parent.joinpath('app/src/main/assets/models/')
+    models_path = Path().absolute().parent.joinpath("app/src/main/assets/models/")
     if not models_path.exists():
         models_path.mkdir()
-    models = {'mobilenet_v2': models_path.joinpath('mobilenet_v2'),
-              'resnet18_v1': models_path.joinpath('resnet18_v1')
-            }
+    models = {
+        "mobilenet_v2": models_path.joinpath("mobilenet_v2"),
+        "resnet18_v1": models_path.joinpath("resnet18_v1"),
+    }
     for model, output_path in models.items():
         main(model, output_path)
apps/android_rpc/tests/android_rpc_test.py
index 754d092..2827c14 100644 (file)
@@ -43,18 +43,18 @@ test_opencl = False
 # whether enable to execute test on Vulkan target
 test_vulkan = False
 
+
 def test_rpc_module():
     # graph
     n = tvm.runtime.convert(1024)
-    A = te.placeholder((n,), name='A')
-    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
     a_np = np.random.uniform(size=1024).astype(A.dtype)
     temp = util.tempdir()
 
     # Establish remote connection with target hardware
     tracker = rpc.connect_tracker(tracker_host, tracker_port)
-    remote = tracker.request(key, priority=0,
-                             session_timeout=60)
+    remote = tracker.request(key, priority=0, session_timeout=60)
 
     # Compile the Graph for CPU target
     s = te.create_schedule(B.op)
@@ -67,7 +67,7 @@ def test_rpc_module():
     f.export_library(path_dso_cpu, ndk.create_shared)
 
     # Execute the portable graph on cpu target
-    print('Run CPU test ...')
+    print("Run CPU test ...")
     ctx = remote.cpu(0)
     remote.upload(path_dso_cpu)
     f2 = remote.load_module("cpu_lib.so")
@@ -75,7 +75,7 @@ def test_rpc_module():
     b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
     time_f = f2.time_evaluator(f2.entry_name, ctx, number=10)
     cost = time_f(a, b).mean
-    print('%g secs/op\n' % cost)
+    print("%g secs/op\n" % cost)
     np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
 
     # Compile the Graph for OpenCL target
@@ -90,7 +90,7 @@ def test_rpc_module():
         path_dso_cl = temp.relpath("dev_lib_cl.so")
         f.export_library(path_dso_cl, ndk.create_shared)
 
-        print('Run GPU(OpenCL Flavor) test ...')
+        print("Run GPU(OpenCL Flavor) test ...")
         ctx = remote.cl(0)
         remote.upload(path_dso_cl)
         f1 = remote.load_module("dev_lib_cl.so")
@@ -98,7 +98,7 @@ def test_rpc_module():
         b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
         time_f = f1.time_evaluator(f1.entry_name, ctx, number=10)
         cost = time_f(a, b).mean
-        print('%g secs/op\n' % cost)
+        print("%g secs/op\n" % cost)
         np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
 
     # Compile the Graph for Vulkan target
@@ -113,7 +113,7 @@ def test_rpc_module():
         path_dso_vulkan = temp.relpath("dev_lib_vulkan.so")
         f.export_library(path_dso_vulkan, ndk.create_shared)
 
-        print('Run GPU(Vulkan Flavor) test ...')
+        print("Run GPU(Vulkan Flavor) test ...")
         ctx = remote.vulkan(0)
         remote.upload(path_dso_vulkan)
         f1 = remote.load_module("dev_lib_vulkan.so")
@@ -121,7 +121,7 @@ def test_rpc_module():
         b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
         time_f = f1.time_evaluator(f1.entry_name, ctx, number=10)
         cost = time_f(a, b).mean
-        print('%g secs/op\n' % cost)
+        print("%g secs/op\n" % cost)
         np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
 
 
apps/benchmark/arm_cpu_imagenet_bench.py
index f319d5a..a4a88d8 100644 (file)
@@ -40,12 +40,12 @@ def evaluate_network(network, target, target_host, repeat):
 
     print_progress("%-20s building..." % network)
     with tvm.transform.PassContext(opt_level=3):
-        graph, lib, params = relay.build(
-            net, target=target, target_host=target_host, params=params)
+        graph, lib, params = relay.build(net, target=target, target_host=target_host, params=params)
 
     tmp = tempdir()
-    if 'android' in str(target):
+    if "android" in str(target):
         from tvm.contrib import ndk
+
         filename = "%s.so" % network
         lib.export_library(tmp.relpath(filename), ndk.create_shared)
     else:
@@ -60,38 +60,55 @@ def evaluate_network(network, target, target_host, repeat):
     rlib = remote.load_module(filename)
     module = runtime.create(graph, rlib, ctx)
     data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
-    module.set_input('data', data_tvm)
+    module.set_input("data", data_tvm)
     module.set_input(**params)
 
     # evaluate
     print_progress("%-20s evaluating..." % network)
     ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=repeat)
     prof_res = np.array(ftimer().results) * 1000  # multiply 1000 for converting to millisecond
-    print("%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)))
+    print(
+        "%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res))
+    )
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--network", type=str, choices=
-                        ['resnet-18', 'resnet-34', 'resnet-50',
-                         'vgg-16', 'vgg-19', 'densenet-121', 'inception_v3',
-                         'mobilenet', 'squeezenet_v1.0', 'squeezenet_v1.1'],
-                        help='The name of neural network')
-    parser.add_argument("--model", type=str, choices=
-                        ['rk3399', 'mate10', 'mate10pro', 'p20', 'p20pro',
-                         'pixel2', 'rasp3b', 'pynq'], default='rk3399',
-                        help="The model of the test device. If your device is not listed in "
-                             "the choices list, pick the most similar one as argument.")
-    parser.add_argument("--host", type=str, default='localhost')
+    parser.add_argument(
+        "--network",
+        type=str,
+        choices=[
+            "resnet-18",
+            "resnet-34",
+            "resnet-50",
+            "vgg-16",
+            "vgg-19",
+            "densenet-121",
+            "inception_v3",
+            "mobilenet",
+            "squeezenet_v1.0",
+            "squeezenet_v1.1",
+        ],
+        help="The name of neural network",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        choices=["rk3399", "mate10", "mate10pro", "p20", "p20pro", "pixel2", "rasp3b", "pynq"],
+        default="rk3399",
+        help="The model of the test device. If your device is not listed in "
+        "the choices list, pick the most similar one as argument.",
+    )
+    parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=9190)
     parser.add_argument("--rpc-key", type=str, required=True)
     parser.add_argument("--repeat", type=int, default=10)
     args = parser.parse_args()
 
-    dtype = 'float32'
+    dtype = "float32"
 
     if args.network is None:
-        networks = ['squeezenet_v1.1', 'mobilenet', 'resnet-18', 'vgg-16']
+        networks = ["squeezenet_v1.1", "mobilenet", "resnet-18", "vgg-16"]
     else:
         networks = [args.network]
 
@@ -103,4 +120,3 @@ if __name__ == "__main__":
     print("--------------------------------------------------")
     for network in networks:
         evaluate_network(network, target, target_host, args.repeat)
-
apps/benchmark/gpu_imagenet_bench.py
index a65a9e8..a1c0cc6 100644 (file)
@@ -40,45 +40,71 @@ def benchmark(network, target):
     ctx = tvm.context(str(target), 0)
     module = runtime.create(graph, lib, ctx)
     data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
-    module.set_input('data', data_tvm)
+    module.set_input("data", data_tvm)
     module.set_input(**params)
 
     # evaluate
     ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=args.repeat)
     prof_res = np.array(ftimer().results) * 1000  # multiply 1000 for converting to millisecond
-    print("%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)))
+    print(
+        "%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res))
+    )
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--network", type=str, choices=
-                        ['resnet-18', 'resnet-34', 'resnet-50',
-                         'vgg-16', 'vgg-19', 'densenet-121', 'inception_v3',
-                         'mobilenet', 'squeezenet_v1.0', 'squeezenet_v1.1'],
-                        help='The name of neural network')
-    parser.add_argument("--device", type=str,
-                        choices=['amd_apu'], default='amd_apu',
-                        help="The name of the test device. If your device is not listed in "
-                             "the choices list, pick the most similar one as argument.")
-    parser.add_argument("--model", type=str,
-                        choices=['1080ti', 'titanx', 'tx2', 'gfx900', 'v1000'], default='1080ti',
-                        help="The model of the test device. If your device is not listed in "
-                             "the choices list, pick the most similar one as argument.")
+    parser.add_argument(
+        "--network",
+        type=str,
+        choices=[
+            "resnet-18",
+            "resnet-34",
+            "resnet-50",
+            "vgg-16",
+            "vgg-19",
+            "densenet-121",
+            "inception_v3",
+            "mobilenet",
+            "squeezenet_v1.0",
+            "squeezenet_v1.1",
+        ],
+        help="The name of neural network",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        choices=["amd_apu"],
+        default="amd_apu",
+        help="The name of the test device. If your device is not listed in "
+        "the choices list, pick the most similar one as argument.",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        choices=["1080ti", "titanx", "tx2", "gfx900", "v1000"],
+        default="1080ti",
+        help="The model of the test device. If your device is not listed in "
+        "the choices list, pick the most similar one as argument.",
+    )
     parser.add_argument("--repeat", type=int, default=600)
-    parser.add_argument("--target", type=str,
-                        choices=['cuda', 'opencl', 'rocm', 'nvptx', 'metal', 'vulkan'], default='cuda',
-                        help="The tvm compilation target")
+    parser.add_argument(
+        "--target",
+        type=str,
+        choices=["cuda", "opencl", "rocm", "nvptx", "metal", "vulkan"],
+        default="cuda",
+        help="The tvm compilation target",
+    )
     parser.add_argument("--thread", type=int, default=1, help="The number of threads to be run.")
     args = parser.parse_args()
 
-    dtype = 'float32'
+    dtype = "float32"
 
     if args.network is None:
-        networks = ['resnet-50', 'mobilenet', 'vgg-19', 'inception_v3']
+        networks = ["resnet-50", "mobilenet", "vgg-19", "inception_v3"]
     else:
         networks = [args.network]
 
-    target = tvm.target.Target('%s -device=%s -model=%s' % (args.target, args.device, args.model))
+    target = tvm.target.Target("%s -device=%s -model=%s" % (args.target, args.device, args.model))
 
     print("--------------------------------------------------")
     print("%-20s %-20s" % ("Network Name", "Mean Inference Time (std dev)"))
@@ -89,7 +115,9 @@ if __name__ == "__main__":
         else:
             threads = list()
             for n in range(args.thread):
-                thread = threading.Thread(target=benchmark, args=([network, target]), name="thread%d" % n)
+                thread = threading.Thread(
+                    target=benchmark, args=([network, target]), name="thread%d" % n
+                )
                 threads.append(thread)
 
             for thread in threads:
index 83127ff..fa1af54 100644
@@ -29,6 +29,7 @@ from tvm import relay
 
 from util import get_network, print_progress
 
+
 def evaluate_network(network, target, target_host, dtype, repeat):
     # connect to remote device
     tracker = tvm.rpc.connect_tracker(args.host, args.port)
@@ -39,12 +40,12 @@ def evaluate_network(network, target, target_host, dtype, repeat):
 
     print_progress("%-20s building..." % network)
     with tvm.transform.PassContext(opt_level=3):
-        graph, lib, params = relay.build(
-            net, target=target, target_host=target_host, params=params)
+        graph, lib, params = relay.build(net, target=target, target_host=target_host, params=params)
 
     tmp = tempdir()
-    if 'android' in str(target) or 'android' in str(target_host):
+    if "android" in str(target) or "android" in str(target_host):
         from tvm.contrib import ndk
+
         filename = "%s.so" % network
         lib.export_library(tmp.relpath(filename), ndk.create_shared)
     else:
@@ -59,36 +60,54 @@ def evaluate_network(network, target, target_host, dtype, repeat):
     rlib = remote.load_module(filename)
     module = runtime.create(graph, rlib, ctx)
     data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
-    module.set_input('data', data_tvm)
+    module.set_input("data", data_tvm)
     module.set_input(**params)
 
     # evaluate
     print_progress("%-20s evaluating..." % network)
     ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=repeat)
     prof_res = np.array(ftimer().results) * 1000  # multiply 1000 for converting to millisecond
-    print("%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)))
+    print(
+        "%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res))
+    )
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--network", type=str, choices=
-                        ['resnet-18', 'resnet-34', 'resnet-50',
-                         'vgg-16', 'vgg-19', 'densenet-121', 'inception_v3',
-                         'mobilenet', 'squeezenet_v1.0', 'squeezenet_v1.1'],
-                        help='The name of neural network')
-    parser.add_argument("--model", type=str, choices=
-                        ['rk3399'], default='rk3399',
-                        help="The model of the test device. If your device is not listed in "
-                             "the choices list, pick the most similar one as argument.")
-    parser.add_argument("--host", type=str, default='localhost')
+    parser.add_argument(
+        "--network",
+        type=str,
+        choices=[
+            "resnet-18",
+            "resnet-34",
+            "resnet-50",
+            "vgg-16",
+            "vgg-19",
+            "densenet-121",
+            "inception_v3",
+            "mobilenet",
+            "squeezenet_v1.0",
+            "squeezenet_v1.1",
+        ],
+        help="The name of neural network",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        choices=["rk3399"],
+        default="rk3399",
+        help="The model of the test device. If your device is not listed in "
+        "the choices list, pick the most similar one as argument.",
+    )
+    parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=9190)
     parser.add_argument("--rpc-key", type=str, required=True)
     parser.add_argument("--repeat", type=int, default=30)
-    parser.add_argument("--dtype", type=str, default='float32')
+    parser.add_argument("--dtype", type=str, default="float32")
     args = parser.parse_args()
 
     if args.network is None:
-        networks = ['squeezenet_v1.1', 'mobilenet', 'resnet-18', 'vgg-16']
+        networks = ["squeezenet_v1.1", "mobilenet", "resnet-18", "vgg-16"]
     else:
         networks = [args.network]
 
index 86d139f..01f0a11 100644
@@ -20,7 +20,8 @@ import sys
 from tvm import relay
 from tvm.relay import testing
 
-def get_network(name, batch_size, dtype='float32'):
+
+def get_network(name, batch_size, dtype="float32"):
     """Get the symbol definition and random weight of a network
 
     Parameters
@@ -46,36 +47,48 @@ def get_network(name, batch_size, dtype='float32'):
     input_shape = (batch_size, 3, 224, 224)
     output_shape = (batch_size, 1000)
 
-    if name == 'mobilenet':
+    if name == "mobilenet":
         net, params = testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
-    elif name == 'inception_v3':
+    elif name == "inception_v3":
         input_shape = (batch_size, 3, 299, 299)
         net, params = testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
     elif "resnet" in name:
-        n_layer = int(name.split('-')[1])
-        net, params = testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
+        n_layer = int(name.split("-")[1])
+        net, params = testing.resnet.get_workload(
+            num_layers=n_layer, batch_size=batch_size, dtype=dtype
+        )
     elif "vgg" in name:
-        n_layer = int(name.split('-')[1])
-        net, params = testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
+        n_layer = int(name.split("-")[1])
+        net, params = testing.vgg.get_workload(
+            num_layers=n_layer, batch_size=batch_size, dtype=dtype
+        )
     elif "densenet" in name:
-        n_layer = int(name.split('-')[1])
-        net, params = testing.densenet.get_workload(densenet_size=n_layer, batch_size=batch_size, dtype=dtype)
+        n_layer = int(name.split("-")[1])
+        net, params = testing.densenet.get_workload(
+            densenet_size=n_layer, batch_size=batch_size, dtype=dtype
+        )
     elif "squeezenet" in name:
         version = name.split("_v")[1]
-        net, params = testing.squeezenet.get_workload(batch_size=batch_size, version=version, dtype=dtype)
-    elif name == 'mxnet':
+        net, params = testing.squeezenet.get_workload(
+            batch_size=batch_size, version=version, dtype=dtype
+        )
+    elif name == "mxnet":
         # an example for mxnet model
         from mxnet.gluon.model_zoo.vision import get_model
-        block = get_model('resnet18_v1', pretrained=True)
-        net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+
+        block = get_model("resnet18_v1", pretrained=True)
+        net, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype)
         net = net["main"]
-        net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
+        net = relay.Function(
+            net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs
+        )
         net = tvm.IRModule.from_expr(net)
     else:
         raise ValueError("Unsupported network: " + name)
 
     return net, params, input_shape, output_shape
 
+
 def print_progress(msg):
     """print progress message
 
index e99623f..623d246 100644
@@ -25,65 +25,87 @@ import logging
 import json
 
 RUNTIMES = {
-    'c': '{name}_c.{ext}',
-    'c++': '{name}_cpp.{ext}',
+    "c": "{name}_c.{ext}",
+    "c++": "{name}_cpp.{ext}",
 }
 
+
 def build_module(opts):
     dshape = (1, 3, 224, 224)
     from mxnet.gluon.model_zoo.vision import get_model
-    block = get_model('mobilenet0.25', pretrained=True)
-    shape_dict = {'data': dshape}
+
+    block = get_model("mobilenet0.25", pretrained=True)
+    shape_dict = {"data": dshape}
     mod, params = relay.frontend.from_mxnet(block, shape_dict)
     func = mod["main"]
-    func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs)
+    func = relay.Function(
+        func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs
+    )
 
     for runtime_name, file_format_str in RUNTIMES.items():
-        with tvm.transform.PassContext(opt_level=3, config={'tir.disable_vectorize': True}):
+        with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
             graph, lib, params = relay.build(
-                func, f'llvm --runtime={runtime_name} --system-lib', params=params)
+                func, f"llvm --runtime={runtime_name} --system-lib", params=params
+            )
 
         build_dir = os.path.abspath(opts.out_dir)
         if not os.path.isdir(build_dir):
             os.makedirs(build_dir)
 
-        lib.save(os.path.join(build_dir, file_format_str.format(name='model', ext='o')))
-        with open(os.path.join(build_dir, file_format_str.format(name='graph', ext='json')), 'w') as f_graph_json:
+        lib.save(os.path.join(build_dir, file_format_str.format(name="model", ext="o")))
+        with open(
+            os.path.join(build_dir, file_format_str.format(name="graph", ext="json")), "w"
+        ) as f_graph_json:
             f_graph_json.write(graph)
-        with open(os.path.join(build_dir, file_format_str.format(name='params', ext='bin')), 'wb') as f_params:
+        with open(
+            os.path.join(build_dir, file_format_str.format(name="params", ext="bin")), "wb"
+        ) as f_params:
             f_params.write(relay.save_param_dict(params))
 
+
 def build_test_module(opts):
     import numpy as np
 
-    x = relay.var('x', shape=(10, 5))
-    y = relay.var('y', shape=(1, 5))
+    x = relay.var("x", shape=(10, 5))
+    y = relay.var("y", shape=(1, 5))
     z = relay.add(x, y)
     func = relay.Function([x, y], z)
-    x_data = np.random.rand(10, 5).astype('float32')
-    y_data = np.random.rand(1, 5).astype('float32')
+    x_data = np.random.rand(10, 5).astype("float32")
+    y_data = np.random.rand(1, 5).astype("float32")
     params = {"y": y_data}
 
     for runtime_name, file_format_str in RUNTIMES.items():
-        with tvm.transform.PassContext(opt_level=3, config={'tir.disable_vectorize': True}):
+        with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
             graph, lib, lowered_params = relay.build(
-                tvm.IRModule.from_expr(func), f"llvm --runtime={runtime_name} --system-lib", params=params)
+                tvm.IRModule.from_expr(func),
+                f"llvm --runtime={runtime_name} --system-lib",
+                params=params,
+            )
 
         build_dir = os.path.abspath(opts.out_dir)
         if not os.path.isdir(build_dir):
             os.makedirs(build_dir)
 
-        lib.save(os.path.join(build_dir, file_format_str.format(name='test_model', ext='o')))
-        with open(os.path.join(build_dir, file_format_str.format(name='test_graph', ext='json')), 'w') as f_graph_json:
+        lib.save(os.path.join(build_dir, file_format_str.format(name="test_model", ext="o")))
+        with open(
+            os.path.join(build_dir, file_format_str.format(name="test_graph", ext="json")), "w"
+        ) as f_graph_json:
             f_graph_json.write(graph)
-        with open(os.path.join(build_dir, file_format_str.format(name='test_params', ext='bin')), 'wb') as f_params:
+        with open(
+            os.path.join(build_dir, file_format_str.format(name="test_params", ext="bin")), "wb"
+        ) as f_params:
             f_params.write(relay.save_param_dict(lowered_params))
-        with open(os.path.join(build_dir, file_format_str.format(name="test_data", ext="bin")), "wb") as fp:
+        with open(
+            os.path.join(build_dir, file_format_str.format(name="test_data", ext="bin")), "wb"
+        ) as fp:
             fp.write(x_data.astype(np.float32).tobytes())
         x_output = x_data + y_data
-        with open(os.path.join(build_dir, file_format_str.format(name="test_output", ext="bin")), "wb") as fp:
+        with open(
+            os.path.join(build_dir, file_format_str.format(name="test_output", ext="bin")), "wb"
+        ) as fp:
             fp.write(x_output.astype(np.float32).tobytes())
 
+
 def build_inputs(opts):
     from tvm.contrib import download
     from PIL import Image
@@ -92,29 +114,30 @@ def build_inputs(opts):
     build_dir = os.path.abspath(opts.out_dir)
 
     # Download test image
-    image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg'
+    image_url = "https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg"
     image_fn = os.path.join(build_dir, "cat.png")
     download.download(image_url, image_fn)
     image = Image.open(image_fn).resize((224, 224))
 
     def transform_image(image):
-        image = np.array(image) - np.array([123., 117., 104.])
+        image = np.array(image) - np.array([123.0, 117.0, 104.0])
         image /= np.array([58.395, 57.12, 57.375])
         image = image.transpose((2, 0, 1))
         image = image[np.newaxis, :]
         return image
 
     x = transform_image(image)
-    print('x', x.shape)
+    print("x", x.shape)
     with open(os.path.join(build_dir, "cat.bin"), "wb") as fp:
         fp.write(x.astype(np.float32).tobytes())
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
 
     parser = argparse.ArgumentParser()
-    parser.add_argument('-o', '--out-dir', default='.')
-    parser.add_argument('-t', '--test', action='store_true')
+    parser.add_argument("-o", "--out-dir", default=".")
+    parser.add_argument("-t", "--test", action="store_true")
     opts = parser.parse_args()
 
     if opts.test:
index 0704dd0..21e6661 100644
@@ -18,6 +18,7 @@ import tvm
 from tvm import te
 import os
 
+
 def test_plugin_module():
     curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
     mod = tvm.runtime.load_module(os.path.join(curr_path, "lib", "plugin_module.so"))
@@ -36,7 +37,7 @@ def test_plugin_module():
         assert mod["AddOne"](10) == 11
         assert mod["SubOne"](10) == 9
         # advanced usecase: return a module
-        mymod = mod["CreateMyModule"](10);
+        mymod = mod["CreateMyModule"](10)
         fadd = mymod["add"]
         assert fadd(10) == 20
         assert mymod["mul"](10) == 100
index 1df304a..0315a8f 100644
 from __future__ import absolute_import
 import os
 import ctypes
+
 # Import TVM first to get library symbols
 import tvm
 from tvm import te
 
+
 def load_lib():
     """Load library, the functions will be registered into TVM"""
     curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
     # load in as global so the global extern symbol is visible to other dll.
-    lib = ctypes.CDLL(
-        os.path.join(curr_path, "../../lib/libtvm_ext.so"), ctypes.RTLD_GLOBAL)
+    lib = ctypes.CDLL(os.path.join(curr_path, "../../lib/libtvm_ext.so"), ctypes.RTLD_GLOBAL)
     return lib
 
+
 _LIB = load_lib()
 
 # Expose two functions into python
@@ -39,9 +41,11 @@ sym_add = tvm.get_global_func("tvm_ext.sym_add")
 ivec_create = tvm.get_global_func("tvm_ext.ivec_create")
 ivec_get = tvm.get_global_func("tvm_ext.ivec_get")
 
+
 @tvm.register_object("tvm_ext.IntVector")
 class IntVec(tvm.Object):
     """Example for using extension class in c++ """
+
     @property
     def _tvm_handle(self):
         return self.handle.value
@@ -54,6 +58,7 @@ nd_create = tvm.get_global_func("tvm_ext.nd_create")
 nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two")
 nd_get_additional_info = tvm.get_global_func("tvm_ext.nd_get_additional_info")
 
+
 @tvm.register_object("tvm_ext.NDSubClass")
 class NDSubClass(tvm.nd.NDArrayBase):
     """Example for subclassing TVM's NDArray infrastructure.
index defac94..c73e820 100644
@@ -21,17 +21,21 @@ import tvm.testing
 from tvm import te
 import numpy as np
 
+
 def test_bind_add():
     def add(a, b):
         return a + b
+
     f = tvm_ext.bind_add(add, 1)
-    assert f(2)  == 3
+    assert f(2) == 3
+
 
 def test_ext_dev():
     n = 10
-    A = te.placeholder((n,), name='A')
-    B = te.compute((n,), lambda *i: A(*i) + 1.0, name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.compute((n,), lambda *i: A(*i) + 1.0, name="B")
     s = te.create_schedule(B.op)
+
     def check_llvm():
         if not tvm.testing.device_enabled("llvm"):
             return
@@ -42,39 +46,41 @@ def test_ext_dev():
         b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
         f(a, b)
         tvm.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1)
+
     check_llvm()
 
 
 def test_sym_add():
-    a = te.var('a')
-    b = te.var('b')
+    a = te.var("a")
+    b = te.var("b")
     c = tvm_ext.sym_add(a, b)
     assert c.a == a and c.b == b
 
 
 def test_ext_vec():
     ivec = tvm_ext.ivec_create(1, 2, 3)
-    assert(isinstance(ivec, tvm_ext.IntVec))
+    assert isinstance(ivec, tvm_ext.IntVec)
     assert ivec[0] == 1
     assert ivec[1] == 2
 
     def ivec_cb(v2):
-        assert(isinstance(v2, tvm_ext.IntVec))
+        assert isinstance(v2, tvm_ext.IntVec)
         assert v2[2] == 3
 
     tvm.runtime.convert(ivec_cb)(ivec)
 
 
 def test_extract_ext():
-    fdict = tvm._ffi.registry.extract_ext_funcs(
-        tvm_ext._LIB.TVMExtDeclare)
+    fdict = tvm._ffi.registry.extract_ext_funcs(tvm_ext._LIB.TVMExtDeclare)
     assert fdict["mul"](3, 4) == 12
 
 
 def test_extern_call():
     n = 10
-    A = te.placeholder((n,), name='A')
-    B = te.compute((n,), lambda *i: tvm.tir.call_extern("float32", "TVMTestAddOne", A(*i)), name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.compute(
+        (n,), lambda *i: tvm.tir.call_extern("float32", "TVMTestAddOne", A(*i)), name="B"
+    )
     s = te.create_schedule(B.op)
 
     def check_llvm():
@@ -87,6 +93,7 @@ def test_extern_call():
         b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
         f(a, b)
         tvm.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1)
+
     check_llvm()
 
 
@@ -97,11 +104,11 @@ def test_nd_subclass():
     c = a + b
     d = a + a
     e = b + b
-    assert(a.additional_info == 3)
-    assert(b.additional_info == 5)
-    assert(c.additional_info == 8)
-    assert(d.additional_info == 6)
-    assert(e.additional_info == 10)
+    assert a.additional_info == 3
+    assert b.additional_info == 5
+    assert c.additional_info == 8
+    assert d.additional_info == 6
+    assert e.additional_info == 10
 
 
 if __name__ == "__main__":
index 88d9f8e..db3c944 100644
@@ -19,10 +19,11 @@ import tvm
 from tvm import te
 import os
 
+
 def prepare_test_libs(base_path):
     n = te.var("n")
-    A = te.placeholder((n,), name='A')
-    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
     s = te.create_schedule(B.op)
     # Compile library as dynamic library
     fadd_dylib = tvm.build(s, [A, B], "llvm", name="addone")
@@ -34,6 +35,7 @@ def prepare_test_libs(base_path):
     syslib_path = os.path.join(base_path, "test_addone_sys.o")
     fadd_syslib.save(syslib_path)
 
+
 if __name__ == "__main__":
     curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
     prepare_test_libs(os.path.join(curr_path, "./lib"))
index 2a44325..0a67803 100644
@@ -22,27 +22,28 @@ import tvm
 from tvm import te
 import numpy as np
 
+
 def verify(mod, fname):
-  # Get the function from the module
-  f = mod.get_function(fname)
-  # Use tvm.nd.array to convert numpy ndarray to tvm
-  # NDArray type, so that function can be invoked normally
-  N = 10
-  x = tvm.nd.array(np.arange(N, dtype=np.float32))
-  y = tvm.nd.array(np.zeros(N, dtype=np.float32))
-  # Invoke the function
-  f(x, y)
-  np_x = x.asnumpy()
-  np_y = y.asnumpy()
-  # Verify correctness of function
-  assert(np.all([xi+1 == yi for xi, yi in zip(np_x, np_y)]))
-  print("Finish verification...")
+    # Get the function from the module
+    f = mod.get_function(fname)
+    # Use tvm.nd.array to convert numpy ndarray to tvm
+    # NDArray type, so that function can be invoked normally
+    N = 10
+    x = tvm.nd.array(np.arange(N, dtype=np.float32))
+    y = tvm.nd.array(np.zeros(N, dtype=np.float32))
+    # Invoke the function
+    f(x, y)
+    np_x = x.asnumpy()
+    np_y = y.asnumpy()
+    # Verify correctness of function
+    assert np.all([xi + 1 == yi for xi, yi in zip(np_x, np_y)])
+    print("Finish verification...")
 
 
 if __name__ == "__main__":
-  # The normal dynamic loading method for deployment
-  mod_dylib = tvm.runtime.load_module("lib/test_addone_dll.so")
-  print("Verify dynamic loading from test_addone_dll.so")
-  verify(mod_dylib, "addone")
-  # There might be methods to use the system lib way in
-  # python, but dynamic loading is good enough for now.
+    # The normal dynamic loading method for deployment
+    mod_dylib = tvm.runtime.load_module("lib/test_addone_dll.so")
+    print("Verify dynamic loading from test_addone_dll.so")
+    verify(mod_dylib, "addone")
+    # There might be methods to use the system lib way in
+    # python, but dynamic loading is good enough for now.
index b0b2c77..deee86c 100644
@@ -18,21 +18,35 @@ import argparse
 import re
 
 default_team_id = "3FR42MXLK9"
-default_bundle_identifier = 'org.apache.tvmrpc'
+default_bundle_identifier = "org.apache.tvmrpc"
 
-parser = argparse.ArgumentParser(description='Update tvmrpc.xcodeproj\
- developer information')
-parser.add_argument('--team_id', type=str, required=True,
-                    help='Apple Developer Team ID.\n\
+parser = argparse.ArgumentParser(
+    description="Update tvmrpc.xcodeproj\
+ developer information"
+)
+parser.add_argument(
+    "--team_id",
+    type=str,
+    required=True,
+    help="Apple Developer Team ID.\n\
                     Can be found here:\n\
                     \n\
                     https://developer.apple.com/account/#/membership\n\
-                    (example: {})'.format(default_team_id))
+                    (example: {})".format(
+        default_team_id
+    ),
+)
 
-parser.add_argument('--bundle_identifier', type=str, required=False,
-                    default=default_bundle_identifier,
-                    help='The new bundle identifier\n\
-                    (example: {})'.format(default_bundle_identifier))
+parser.add_argument(
+    "--bundle_identifier",
+    type=str,
+    required=False,
+    default=default_bundle_identifier,
+    help="The new bundle identifier\n\
+                    (example: {})".format(
+        default_bundle_identifier
+    ),
+)
 
 args = parser.parse_args()
 team_id = args.team_id
index 1ce3651..642c7da 100644
@@ -48,8 +48,8 @@ proxy_port = 9090
 key = "iphone"
 
 # Change target configuration, this is setting for iphone6s
-#arch = "x86_64"
-#sdk = "iphonesimulator"
+# arch = "x86_64"
+# sdk = "iphonesimulator"
 arch = "arm64"
 sdk = "iphoneos"
 target_host = "llvm -mtriple=%s-apple-darwin" % arch
@@ -59,25 +59,30 @@ target_host = "llvm -mtriple=%s-apple-darwin" % arch
 def compile_metal(src):
     return xcode.compile_metal(src, sdk=sdk)
 
+
 def prepare_input():
-    img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
-    img_name = 'cat.png'
-    synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
-                          '4d0b62f3d01426887599d4f7ede23ee5/raw/',
-                          '596b27d23537e5a1b5751d2b0481ef172f58b539/',
-                          'imagenet1000_clsid_to_human.txt'])
-    synset_name = 'imagenet1000_clsid_to_human.txt'
-    img_path = download_testdata(img_url, 'cat.png', module='data')
-    synset_path = download_testdata(synset_url, synset_name, module='data')
+    img_url = "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true"
+    img_name = "cat.png"
+    synset_url = "".join(
+        [
+            "https://gist.githubusercontent.com/zhreshold/",
+            "4d0b62f3d01426887599d4f7ede23ee5/raw/",
+            "596b27d23537e5a1b5751d2b0481ef172f58b539/",
+            "imagenet1000_clsid_to_human.txt",
+        ]
+    )
+    synset_name = "imagenet1000_clsid_to_human.txt"
+    img_path = download_testdata(img_url, "cat.png", module="data")
+    synset_path = download_testdata(synset_url, synset_name, module="data")
     with open(synset_path) as f:
         synset = eval(f.read())
         image = Image.open(img_path).resize((224, 224))
 
-    image = np.array(image) - np.array([123., 117., 104.])
+    image = np.array(image) - np.array([123.0, 117.0, 104.0])
     image /= np.array([58.395, 57.12, 57.375])
     image = image.transpose((2, 0, 1))
     image = image[np.newaxis, :]
-    return image.astype('float32'), synset
+    return image.astype("float32"), synset
 
 
 def get_model(model_name, data_shape):
@@ -85,7 +90,9 @@ def get_model(model_name, data_shape):
     mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
     # we want a probability so add a softmax operator
     func = mod["main"]
-    func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs)
+    func = relay.Function(
+        func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs
+    )
 
     return func, params
 
@@ -93,19 +100,19 @@ def get_model(model_name, data_shape):
 def test_mobilenet():
     temp = util.tempdir()
     image, synset = prepare_input()
-    model, params = get_model('mobilenetv2_1.0', image.shape)
+    model, params = get_model("mobilenetv2_1.0", image.shape)
 
     def run(mod, target):
         with relay.build_config(opt_level=3):
-            graph, lib, _params = relay.build(mod, target=target,
-                                             target_host=target_host, params=params)
+            graph, lib, _params = relay.build(
+                mod, target=target, target_host=target_host, params=params
+            )
         path_dso = temp.relpath("deploy.dylib")
         lib.export_library(path_dso, xcode.create_dylib, arch=arch, sdk=sdk)
         xcode.codesign(path_dso)
 
         # Start RPC test server that contains the compiled library.
-        xcode.popen_test_rpc(proxy_host, proxy_port, key,
-                             destination=destination, libs=[path_dso])
+        xcode.popen_test_rpc(proxy_host, proxy_port, key, destination=destination, libs=[path_dso])
 
         # connect to the proxy
         remote = rpc.connect(proxy_host, proxy_port, key=key)
@@ -117,12 +124,12 @@ def test_mobilenet():
         lib = remote.load_module("deploy.dylib")
         m = graph_runtime.create(graph, lib, ctx)
 
-        m.set_input('data', tvm.nd.array(image, ctx))
+        m.set_input("data", tvm.nd.array(image, ctx))
         m.set_input(**_params)
         m.run()
         tvm_output = m.get_output(0)
         top1 = np.argmax(tvm_output.asnumpy()[0])
-        print('TVM prediction top-1:', top1, synset[top1])
+        print("TVM prediction top-1:", top1, synset[top1])
 
         # evaluate
         ftimer = m.module.time_evaluator("run", ctx, number=3, repeat=10)
@@ -146,14 +153,16 @@ def test_mobilenet():
         mod = tvm.IRModule()
         mod["main"] = func
 
-        seq = tvm.transform.Sequential([
-            transform.SimplifyInference(),
-            transform.FoldConstant(),
-            transform.FoldScaleAxis(),
-            transform.AnnotateTarget(compiler),
-            transform.MergeCompilerRegions(),
-            transform.PartitionGraph()
-        ])
+        seq = tvm.transform.Sequential(
+            [
+                transform.SimplifyInference(),
+                transform.FoldConstant(),
+                transform.FoldScaleAxis(),
+                transform.AnnotateTarget(compiler),
+                transform.MergeCompilerRegions(),
+                transform.PartitionGraph(),
+            ]
+        )
 
         with relay.build_config(opt_level=3):
             mod = seq(mod)
@@ -167,5 +176,6 @@ def test_mobilenet():
     # CoreML
     run(annotate(model, "coremlcompiler"), target_host)
 
+
 if __name__ == "__main__":
     test_mobilenet()
index 181f3c0..620fe49 100644
@@ -53,11 +53,12 @@ target = "llvm -mtriple=%s-apple-darwin" % arch
 def compile_metal(src):
     return xcode.compile_metal(src, sdk=sdk)
 
+
 def test_rpc_module():
     # graph
     n = tvm.runtime.convert(1024)
-    A = te.placeholder((n,), name='A')
-    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
     temp = util.tempdir()
     s = te.create_schedule(B.op)
     xo, xi = s[B].split(B.op.axis[0], factor=64)
@@ -67,8 +68,7 @@ def test_rpc_module():
     # If we don't want to do metal and only use cpu, just set target to be target
     f = tvm.build(s, [A, B], "metal", target_host=target, name="myadd")
     path_dso1 = temp.relpath("dev_lib.dylib")
-    f.export_library(path_dso1, xcode.create_dylib,
-                     arch=arch, sdk=sdk)
+    f.export_library(path_dso1, xcode.create_dylib, arch=arch, sdk=sdk)
     xcode.codesign(path_dso1)
 
     s = te.create_schedule(B.op)
@@ -78,14 +78,13 @@ def test_rpc_module():
     s[B].pragma(xi, "parallel_barrier_when_finish")
     f = tvm.build(s, [A, B], target, name="myadd_cpu")
     path_dso2 = temp.relpath("cpu_lib.dylib")
-    f.export_library(path_dso2, xcode.create_dylib,
-                     arch=arch, sdk=sdk)
+    f.export_library(path_dso2, xcode.create_dylib, arch=arch, sdk=sdk)
     xcode.codesign(path_dso2)
 
     # Start RPC test server that contains the compiled library.
-    server = xcode.popen_test_rpc(proxy_host, proxy_port, key,
-                                  destination=destination,
-                                  libs=[path_dso1, path_dso2])
+    server = xcode.popen_test_rpc(
+        proxy_host, proxy_port, key, destination=destination, libs=[path_dso1, path_dso2]
+    )
 
     # connect to the proxy
     remote = rpc.connect(proxy_host, proxy_port, key=key)
@@ -96,7 +95,7 @@ def test_rpc_module():
     b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
     time_f = f1.time_evaluator(f1.entry_name, ctx, number=10)
     cost = time_f(a, b).mean
-    print('%g secs/op' % cost)
+    print("%g secs/op" % cost)
     np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
     # CPU
     ctx = remote.cpu(0)
@@ -106,7 +105,8 @@ def test_rpc_module():
     b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
     time_f = f2.time_evaluator(f1.entry_name, ctx, number=10)
     cost = time_f(a, b).mean
-    print('%g secs/op' % cost)
+    print("%g secs/op" % cost)
     np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
 
+
 test_rpc_module()
index fb5c4de..6ec2ae2 100644
@@ -117,11 +117,9 @@ def __lldb_init_module(debugger, _):
         "tvm::relay::alter_op_layout::TransformMemorizer",
         "tvm::relay::fold_scale_axis::Message",
         "tvm::relay::fold_scale_axis:BackwardTransformer",
-                ]:
+    ]:
         debugger.HandleCommand(
-            "type summary add -F tvm.NodeRef_SummaryProvider {node} -w tvm".format(
-                node=node
-            )
+            "type summary add -F tvm.NodeRef_SummaryProvider {node} -w tvm".format(node=node)
         )
     debugger.HandleCommand("command script add -f tvm.PrettyPrint pp")
     debugger.HandleCommand("type category enable tvm")
@@ -141,9 +139,7 @@ def _GetContext(debugger):
 
 def PrettyPrint(debugger, command, result, internal_dict):
     ctx = _GetContext(debugger)
-    rc = ctx.EvaluateExpression(
-        "tvm::PrettyPrint({command})".format(command=command)
-    )
+    rc = ctx.EvaluateExpression("tvm::PrettyPrint({command})".format(command=command))
     result.AppendMessage(str(rc))
 
 
@@ -160,9 +156,7 @@ def _EvalExpression(logger, ctx, expr, value_name):
     if err.Fail():
         _log(logger, "_EvalExpression failed: {err}".format(err=err))
         raise EvaluateError(err)
-    _log(
-        logger, "_EvalExpression success: {typename}".format(typename=rc.GetTypeName())
-    )
+    _log(logger, "_EvalExpression success: {typename}".format(typename=rc.GetTypeName()))
     return rc
 
 
@@ -172,9 +166,7 @@ def _EvalExpressionAsString(logger, ctx, expr):
 
 
 def _EvalAsNodeRef(logger, ctx, value):
-    return _EvalExpressionAsString(
-        logger, ctx, "tvm::PrettyPrint({name})".format(name=value.name)
-    )
+    return _EvalExpressionAsString(logger, ctx, "tvm::PrettyPrint({name})".format(name=value.name))
 
 
 def NodeRef_SummaryProvider(value, _):
index 2e41036..8f9034b 100644
@@ -20,9 +20,11 @@ import sys
 
 import numpy as np
 
+
 def float_bytes(l):
     for i in range(0, len(l), 4):
-        yield l[i:i + 4]
+        yield l[i : i + 4]
+
 
-floats = [struct.unpack('f', f)[0] for f in float_bytes(sys.stdin.buffer.read())]
+floats = [struct.unpack("f", f)[0] for f in float_bytes(sys.stdin.buffer.read())]
 print(np.array(floats))
index b988574..868d3bc 100755
@@ -31,26 +31,26 @@ from tvm import te
 
 def main():
     dshape = (1, 28, 28)
-    net, params = relay.testing.mlp.get_workload(batch_size=dshape[0], dtype='float32')
+    net, params = relay.testing.mlp.get_workload(batch_size=dshape[0], dtype="float32")
 
     dshape = (1, 3, 224, 224)
     net, params = relay.testing.resnet.get_workload(
-        layers=18, batch_size=dshape[0], image_shape=dshape[1:])
+        layers=18, batch_size=dshape[0], image_shape=dshape[1:]
+    )
 
     with tvm.transform.PassContext(opt_level=3):
-        graph, lib, params = relay.build(
-            net, 'llvm --system-lib', params=params)
+        graph, lib, params = relay.build(net, "llvm --system-lib", params=params)
 
     build_dir = osp.abspath(sys.argv[1])
     if not osp.isdir(build_dir):
         os.makedirs(build_dir, exist_ok=True)
 
-    lib.save(osp.join(build_dir, 'model.o'))
-    with open(osp.join(build_dir, 'graph.json'), 'w') as f_graph_json:
+    lib.save(osp.join(build_dir, "model.o"))
+    with open(osp.join(build_dir, "graph.json"), "w") as f_graph_json:
         f_graph_json.write(graph)
-        with open(osp.join(build_dir, 'params.bin'), 'wb') as f_params:
+        with open(osp.join(build_dir, "params.bin"), "wb") as f_params:
             f_params.write(relay.save_param_dict(params))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
index 1672b58..507bae2 100644
@@ -33,22 +33,21 @@ def test_use_tvmdso_op():
     def export_cpu_add_lib():
         """create cpu add op lib"""
         n = te.var("n")
-        ph_a = te.placeholder((n,), name='ph_a')
-        ph_b = te.placeholder((n,), name='ph_b')
-        ph_c = te.compute(ph_a.shape, lambda i: ph_a[i] + ph_b[i], name='ph_c')
+        ph_a = te.placeholder((n,), name="ph_a")
+        ph_b = te.placeholder((n,), name="ph_b")
+        ph_c = te.compute(ph_a.shape, lambda i: ph_a[i] + ph_b[i], name="ph_c")
         sched = te.create_schedule(ph_c.op)
         fadd_dylib = tvm.build(sched, [ph_a, ph_b, ph_c], "c", name="vector_add")
         lib_path = tempfile.mktemp("tvm_add_dll.so")
         fadd_dylib.export_library(lib_path)
         return lib_path
 
-
     def export_gpu_add_lib():
         """create gpu add op lib"""
         n = te.var("n")
-        ph_a = te.placeholder((n,), name='ph_a')
-        ph_b = te.placeholder((n,), name='ph_b')
-        ph_c = te.compute(ph_a.shape, lambda i: ph_a[i] + ph_b[i], name='ph_c')
+        ph_a = te.placeholder((n,), name="ph_a")
+        ph_b = te.placeholder((n,), name="ph_b")
+        ph_c = te.compute(ph_a.shape, lambda i: ph_a[i] + ph_b[i], name="ph_c")
         sched = te.create_schedule(ph_c.op)
         b_axis, t_axis = sched[ph_c].split(ph_c.op.axis[0], factor=64)
         sched[ph_c].bind(b_axis, te.thread_axis("blockIdx.x"))
@@ -58,7 +57,6 @@ def test_use_tvmdso_op():
         fadd_dylib.export_library(lib_path)
         return lib_path
 
-
     def test_add(session, lib_path, tf_device):
         """test add lib with TensorFlow wrapper"""
         module = tf_op.OpModule(lib_path)
@@ -83,7 +81,6 @@ def test_use_tvmdso_op():
             output3 = session.run(add3(left, right), feed_dict)
             np.testing.assert_equal(output3, expect)
 
-
     def cpu_test(session):
         """test function for cpu"""
         cpu_lib = None
@@ -94,7 +91,6 @@ def test_use_tvmdso_op():
             if cpu_lib is not None:
                 os.remove(cpu_lib)
 
-
     def gpu_test(session):
         """test function for gpu"""
         gpu_lib = None
index 7303191..e7b5c3a 100644
@@ -50,8 +50,12 @@ def tvm_callback_cuda_postproc(code):
 
 def test_broadcast_to(in_shape, out_shape):
     global TASK
-    TASK = "bcast_to_i" + "_".join([str(ele) for ele in in_shape])\
-           + "o" + "_".join([str(ele) for ele in out_shape])
+    TASK = (
+        "bcast_to_i"
+        + "_".join([str(ele) for ele in in_shape])
+        + "o"
+        + "_".join([str(ele) for ele in out_shape])
+    )
     # Build the logic and compile the function
     A = te.placeholder(shape=in_shape, name="A")
     B = topi.broadcast_to(A, out_shape)
@@ -70,9 +74,14 @@ def test_broadcast_to(in_shape, out_shape):
 
 def test_broadcast_binary_op(lhs_shape, rhs_shape, typ="add"):
     global TASK
-    TASK = "bcast_binary_" + typ + "_lhs" +\
-           "_".join([str(ele) for ele in lhs_shape]) +\
-           "rhs" + "_".join([str(ele) for ele in rhs_shape])
+    TASK = (
+        "bcast_binary_"
+        + typ
+        + "_lhs"
+        + "_".join([str(ele) for ele in lhs_shape])
+        + "rhs"
+        + "_".join([str(ele) for ele in rhs_shape])
+    )
     A = te.placeholder(shape=lhs_shape, name="A")
     B = te.placeholder(shape=rhs_shape, name="B")
     if typ == "add":
@@ -117,8 +126,8 @@ def test_broadcast_binary_op(lhs_shape, rhs_shape, typ="add"):
 
 if __name__ == "__main__":
     test_broadcast_to((1,), (10,))
-    test_broadcast_to((1, 1, 5, 4),  (3, 4, 4, 4, 5, 4))
-    test_broadcast_to((1, 128, 1, 32),  (64, 128, 64, 32))
+    test_broadcast_to((1, 1, 5, 4), (3, 4, 4, 4, 5, 4))
+    test_broadcast_to((1, 128, 1, 32), (64, 128, 64, 32))
     test_broadcast_binary_op((5, 2, 3), (2, 1), typ="add")
     test_broadcast_binary_op((5, 64, 128), (2, 5, 64, 1), typ="mul")
     test_broadcast_binary_op((2, 3, 1, 32), (64, 32), typ="div")
index c5f8b07..036f1a4 100644
@@ -23,20 +23,26 @@ from tvm.contrib import nvcc
 
 from tvm import topi
 from tvm.topi.util import get_const_tuple
-from tvm.topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc
+from tvm.topi.cuda.depthwise_conv2d import (
+    schedule_depthwise_conv2d_nchw,
+    schedule_depthwise_conv2d_nhwc,
+)
 
 TASK = "depthwise_conv2d"
 USE_MANUAL_CODE = False
 
+
 @tvm.register_func
 def tvm_callback_cuda_compile(code):
     ptx = nvcc.compile_cuda(code, target="ptx")
     return ptx
 
+
 def write_code(code, fname):
     with open(fname, "w") as f:
         f.write(code)
 
+
 @tvm.register_func
 def tvm_callback_cuda_postproc(code):
     if not os.path.exists("perf"):
@@ -46,6 +52,7 @@ def tvm_callback_cuda_postproc(code):
         code = open("perf/%s_manual.cu" % TASK).read()
     return code
 
+
 def test_depthwise_conv2d_nchw():
     """You may test different settings."""
     batch = 1
@@ -61,14 +68,16 @@ def test_depthwise_conv2d_nchw():
     stride_h = 1
     stride_w = 1
 
-    padding = 'SAME' # or 'VALID'
+    padding = "SAME"  # or 'VALID'
 
     # Placeholder
-    Input = te.placeholder((batch, in_channel, in_height, in_width), name='Input')
-    Filter = te.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter')
+    Input = te.placeholder((batch, in_channel, in_height, in_width), name="Input")
+    Filter = te.placeholder(
+        (filter_channel, channel_multiplier, filter_height, filter_width), name="Filter"
+    )
     Stride = [stride_h, stride_w]
-    Scale = te.placeholder((in_channel * channel_multiplier,), name='Scale')
-    Shift = te.placeholder((in_channel * channel_multiplier,), name='Shift')
+    Scale = te.placeholder((in_channel * channel_multiplier,), name="Scale")
+    Shift = te.placeholder((in_channel * channel_multiplier,), name="Shift")
     # Declare
     DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, Stride, padding)
     ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
@@ -97,8 +106,12 @@ def test_depthwise_conv2d_nchw():
         scale_tvm = tvm.nd.array(scale_np, ctx)
         shift_tvm = tvm.nd.array(shift_np, ctx)
 
-        depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape),dtype=DepthwiseConv2d.dtype), ctx)
-        scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx)
+        depthwise_conv2d_tvm = tvm.nd.array(
+            np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx
+        )
+        scale_shift_tvm = tvm.nd.array(
+            np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx
+        )
         relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
         # Measure time cost of kernel 1 (depthwise_conv2d)
         timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1000)
@@ -114,27 +127,39 @@ def test_depthwise_conv2d_nchw():
         print("Stride = (%d, %d)" % (stride_h, stride_w))
         print("padding = %s\n" % padding)
         print("Output shape = " + str(get_const_tuple(DepthwiseConv2d.shape)))
-        print("average time cost of 1000 runs (depthwise_conv2d) = %g us" % (tcost_1*1e6))
-        print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g us" % (tcost_2*1e6))
-        print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g us" % (tcost_3*1e6))
+        print("average time cost of 1000 runs (depthwise_conv2d) = %g us" % (tcost_1 * 1e6))
+        print(
+            "average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g us"
+            % (tcost_2 * 1e6)
+        )
+        print(
+            "average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g us"
+            % (tcost_3 * 1e6)
+        )
         # correctness
-        depthwise_conv2d_scipy = tvm.topi.testing.depthwise_conv2d_python_nchw(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
+        depthwise_conv2d_scipy = tvm.topi.testing.depthwise_conv2d_python_nchw(
+            input_np, filter_np, stride=[stride_h, stride_w], padding=padding
+        )
         scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
         for c in range(in_channel * channel_multiplier):
-            scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
+            scale_shift_scipy[:, c, :, :] = (
+                depthwise_conv2d_scipy[:, c, :, :] * scale_np[c] + shift_np[c]
+            )
         relu_scipy = np.maximum(scale_shift_scipy, 0)
-        tvm.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
+        tvm.testing.assert_allclose(
+            depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5
+        )
         tvm.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
         tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
         print("success")
 
-    for device in ['cuda', 'opencl', 'rocm']:
-        with tvm.transform.PassContext(config={"tir.UnrollLoop": {
-            "auto_max_step": 128,
-            "explicit_unroll": device != "rocm"
-        }}):
+    for device in ["cuda", "opencl", "rocm"]:
+        with tvm.transform.PassContext(
+            config={"tir.UnrollLoop": {"auto_max_step": 128, "explicit_unroll": device != "rocm"}}
+        ):
             check_device(device)
 
+
 def test_depthwise_conv2d_nhwc():
     """You may test different settings."""
     batch = 1
@@ -150,14 +175,16 @@ def test_depthwise_conv2d_nhwc():
     stride_h = 1
     stride_w = 1
 
-    padding = 'SAME' # or 'VALID'
+    padding = "SAME"  # or 'VALID'
 
     # Placeholder
-    Input = te.placeholder((batch, in_height, in_width, in_channel), name='Input')
-    Filter = te.placeholder((filter_height, filter_width,filter_channel, channel_multiplier), name='Filter')
+    Input = te.placeholder((batch, in_height, in_width, in_channel), name="Input")
+    Filter = te.placeholder(
+        (filter_height, filter_width, filter_channel, channel_multiplier), name="Filter"
+    )
     Stride = [stride_h, stride_w]
-    Scale = te.placeholder((in_channel * channel_multiplier,), name='Scale')
-    Shift = te.placeholder((in_channel * channel_multiplier,), name='Shift')
+    Scale = te.placeholder((in_channel * channel_multiplier,), name="Scale")
+    Shift = te.placeholder((in_channel * channel_multiplier,), name="Shift")
     # Declare
     DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, Filter, Stride, padding)
     ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
@@ -186,8 +213,12 @@ def test_depthwise_conv2d_nhwc():
         filter_tvm = tvm.nd.array(filter_np, ctx)
         scale_tvm = tvm.nd.array(scale_np, ctx)
         shift_tvm = tvm.nd.array(shift_np, ctx)
-        depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape),dtype=DepthwiseConv2d.dtype), ctx)
-        scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx)
+        depthwise_conv2d_tvm = tvm.nd.array(
+            np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx
+        )
+        scale_shift_tvm = tvm.nd.array(
+            np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx
+        )
         relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
         # Measure time cost of kernel 1 (depthwise_conv2d)
         timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1000)
@@ -203,27 +234,39 @@ def test_depthwise_conv2d_nhwc():
         print("Stride = (%d, %d)" % (stride_h, stride_w))
         print("padding = %s\n" % padding)
         print("Output shape = " + str(get_const_tuple(DepthwiseConv2d.shape)))
-        print("average time cost of 1000 runs (depthwise_conv2d) = %g us" % (tcost_1*1e6))
-        print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g us" % (tcost_2*1e6))
-        print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g us" % (tcost_3*1e6))
+        print("average time cost of 1000 runs (depthwise_conv2d) = %g us" % (tcost_1 * 1e6))
+        print(
+            "average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g us"
+            % (tcost_2 * 1e6)
+        )
+        print(
+            "average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g us"
+            % (tcost_3 * 1e6)
+        )
         # correctness
-        depthwise_conv2d_scipy = tvm.topi.testing.depthwise_conv2d_python_nhwc(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
+        depthwise_conv2d_scipy = tvm.topi.testing.depthwise_conv2d_python_nhwc(
+            input_np, filter_np, stride=[stride_h, stride_w], padding=padding
+        )
         scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
         for c in range(in_channel * channel_multiplier):
-            scale_shift_scipy[:,:,:,c] = depthwise_conv2d_scipy[:,:,:,c] * scale_np[c] + shift_np[c]
+            scale_shift_scipy[:, :, :, c] = (
+                depthwise_conv2d_scipy[:, :, :, c] * scale_np[c] + shift_np[c]
+            )
         relu_scipy = np.maximum(scale_shift_scipy, 0)
-        tvm.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
+        tvm.testing.assert_allclose(
+            depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5
+        )
         tvm.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
         tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
         print("success")
 
-    for device in ['cuda', 'opencl', 'rocm']:
-        with tvm.transform.PassContext(config={"tir.UnrollLoop": {
-            "auto_max_step": 128,
-            "explicit_unroll": device != "cuda"
-        }}):
+    for device in ["cuda", "opencl", "rocm"]:
+        with tvm.transform.PassContext(
+            config={"tir.UnrollLoop": {"auto_max_step": 128, "explicit_unroll": device != "cuda"}}
+        ):
             check_device(device)
 
+
 if __name__ == "__main__":
     test_depthwise_conv2d_nchw()
     test_depthwise_conv2d_nhwc()
index 605044c..1d2032d 100644
@@ -27,15 +27,18 @@ from tvm.topi.util import get_const_tuple
 TASK = "conv2d_hwcn_map"
 USE_MANUAL_CODE = False
 
+
 @tvm.register_func
 def tvm_callback_cuda_compile(code):
     ptx = nvcc.compile_cuda(code, target="ptx")
     return ptx
 
+
 def write_code(code, fname):
     with open(fname, "w") as f:
         f.write(code)
 
+
 @tvm.register_func
 def tvm_callback_cuda_postproc(code):
     if not os.path.exists("perf"):
@@ -54,10 +57,10 @@ def test_conv2d_hwcn_map():
     num_filter = 128
     kernel = 3
     stride = 2
-    padding = 'SAME'
+    padding = "SAME"
 
-    A = te.placeholder((in_height, in_width, in_channel, batch), name='A')
-    W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W')
+    A = te.placeholder((in_height, in_width, in_channel, batch), name="A")
+    W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W")
     B = topi.nn.conv2d_hwcn(A, W, stride, padding)
     C = topi.nn.relu(B)
     s1 = topi.cuda.schedule_conv2d_hwcn([B])
@@ -78,10 +81,11 @@ def test_conv2d_hwcn_map():
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
 
-        with tvm.transform.PassContext(config={"tir.UrollLoop": {
-                "auto_unroll_max_step": 128,
-                "explicit_unroll": device == "rocm"
-        }}):
+        with tvm.transform.PassContext(
+            config={
+                "tir.UrollLoop": {"auto_unroll_max_step": 128, "explicit_unroll": device == "rocm"}
+            }
+        ):
             func1 = tvm.build(s1, [A, W, B], device)
             func1(a, w, b)
             tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
@@ -89,7 +93,7 @@ def test_conv2d_hwcn_map():
             func2(a, w, c)
             tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
 
-    for device in ['cuda', 'opencl', 'rocm']:
+    for device in ["cuda", "opencl", "rocm"]:
         check_device(device)
 
 
index ba8d0a4..289e69a 100644
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable-msg=too-many-arguments, too-many-locals, assignment-from-no-return
+# pylint: disable-msg=too-many-arguments, too-many-locals, assignment-from-no-return
 """ Conv Int8 functional and performance testing"""
 import sys
 import logging
@@ -24,80 +24,118 @@ from tvm import te
 from tvm import topi
 
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-LOGGER = logging.getLogger('test_conv_int8_intel')
+LOGGER = logging.getLogger("test_conv_int8_intel")
 LOGGER.disabled = False
 
 # All the WORKLOADS from Resnet except first layer
 # Workload is ['height', 'width', 'in_filter', 'out_filter',
 #              'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
-WORKLOADS = [(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
-             (56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
-             (56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
-             (56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
-             (28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
-             (28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
-             (28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
-             (14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
-             (14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
-             (14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
-             (7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
-             (56, 56, 64, 256, 1, 1, 0, 0, 1, 1),
-             (56, 56, 256, 64, 1, 1, 0, 0, 1, 1),
-             (56, 56, 256, 128, 1, 1, 0, 0, 2, 2),
-             (28, 28, 128, 512, 1, 1, 0, 0, 1, 1),
-             (56, 56, 256, 512, 1, 1, 0, 0, 2, 2),
-             (28, 28, 512, 128, 1, 1, 0, 0, 1, 1),
-             (28, 28, 512, 256, 1, 1, 0, 0, 2, 2),
-             (14, 14, 256, 1024, 1, 1, 0, 0, 1, 1),
-             (28, 28, 512, 1024, 1, 1, 0, 0, 2, 2),
-             (14, 14, 1024, 256, 1, 1, 0, 0, 1, 1),
-             (14, 14, 1024, 512, 1, 1, 0, 0, 2, 2),
-             (7, 7, 512, 2048, 1, 1, 0, 0, 1, 1),
-             (14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2),
-             (7, 7, 2048, 512, 1, 1, 0, 0, 1, 1)
-            ]
-
-
-TARGET_NAME = 'llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'
+WORKLOADS = [
+    (56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
+    (56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
+    (56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
+    (56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
+    (28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
+    (28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
+    (28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
+    (14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
+    (14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
+    (14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
+    (7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
+    (56, 56, 64, 256, 1, 1, 0, 0, 1, 1),
+    (56, 56, 256, 64, 1, 1, 0, 0, 1, 1),
+    (56, 56, 256, 128, 1, 1, 0, 0, 2, 2),
+    (28, 28, 128, 512, 1, 1, 0, 0, 1, 1),
+    (56, 56, 256, 512, 1, 1, 0, 0, 2, 2),
+    (28, 28, 512, 128, 1, 1, 0, 0, 1, 1),
+    (28, 28, 512, 256, 1, 1, 0, 0, 2, 2),
+    (14, 14, 256, 1024, 1, 1, 0, 0, 1, 1),
+    (28, 28, 512, 1024, 1, 1, 0, 0, 2, 2),
+    (14, 14, 1024, 256, 1, 1, 0, 0, 1, 1),
+    (14, 14, 1024, 512, 1, 1, 0, 0, 2, 2),
+    (7, 7, 512, 2048, 1, 1, 0, 0, 1, 1),
+    (14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2),
+    (7, 7, 2048, 512, 1, 1, 0, 0, 1, 1),
+]
+
+
+TARGET_NAME = "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod"
 NUM_VEC_LANES = 16
 CTX = tvm.context(TARGET_NAME, 0)
 
-def get_shape(im_height, im_width, in_filter, out_filter, k_h, k_w, hpad, wpad,
-              hstride, wstride, out_dtype):
+
+def get_shape(
+    im_height, im_width, in_filter, out_filter, k_h, k_w, hpad, wpad, hstride, wstride, out_dtype
+):
     """
     Finds out the shape of all data structures
     """
-    data_shape = (1, in_filter//NUM_VEC_LANES, im_height, im_width, NUM_VEC_LANES)
-
-    if out_dtype == 'int32' or out_dtype == 'uint32':
-        kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
-                        NUM_VEC_LANES//4, NUM_VEC_LANES, 4)
-    elif out_dtype == 'float32':
-        kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
-                        NUM_VEC_LANES, NUM_VEC_LANES)
+    data_shape = (1, in_filter // NUM_VEC_LANES, im_height, im_width, NUM_VEC_LANES)
+
+    if out_dtype == "int32" or out_dtype == "uint32":
+        kernel_shape = (
+            out_filter // NUM_VEC_LANES,
+            in_filter // NUM_VEC_LANES,
+            k_h,
+            k_w,
+            NUM_VEC_LANES // 4,
+            NUM_VEC_LANES,
+            4,
+        )
+    elif out_dtype == "float32":
+        kernel_shape = (
+            out_filter // NUM_VEC_LANES,
+            in_filter // NUM_VEC_LANES,
+            k_h,
+            k_w,
+            NUM_VEC_LANES,
+            NUM_VEC_LANES,
+        )
     out_height = (im_height + 2 * hpad - k_h) // hstride + 1
     out_width = (im_width + 2 * wpad - k_w) // wstride + 1
-    o_shape = (1, out_filter//NUM_VEC_LANES, out_height, out_width, NUM_VEC_LANES)
+    o_shape = (1, out_filter // NUM_VEC_LANES, out_height, out_width, NUM_VEC_LANES)
     return (data_shape, kernel_shape, o_shape)
 
 
-
-def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_filter,
-                  out_filter, k_h, k_w, hpad, wpad, hstride, wstride):
+def run_inference(
+    data_dtype,
+    kernel_dtype,
+    out_dtype,
+    im_height,
+    im_width,
+    in_filter,
+    out_filter,
+    k_h,
+    k_w,
+    hpad,
+    wpad,
+    hstride,
+    wstride,
+):
     """
     Runs the inference and checks the functional correctness between
     compute and schedule outputs
     """
-    (data_shape, kernel_shape, o_shape) = get_shape(im_height, im_width, in_filter,
-                                                    out_filter, k_h, k_w, hpad, wpad,
-                                                    hstride, wstride, out_dtype)
+    (data_shape, kernel_shape, o_shape) = get_shape(
+        im_height,
+        im_width,
+        in_filter,
+        out_filter,
+        k_h,
+        k_w,
+        hpad,
+        wpad,
+        hstride,
+        wstride,
+        out_dtype,
+    )
 
     # Create TVM placeholders
-    data = te.placeholder(data_shape, name='data', dtype=data_dtype)
-    kernel = te.placeholder(kernel_shape, name='kernel', dtype=kernel_dtype)
+    data = te.placeholder(data_shape, name="data", dtype=data_dtype)
+    kernel = te.placeholder(kernel_shape, name="kernel", dtype=kernel_dtype)
 
     # Create the numpy arrays to be used for executing conv models
-    if data_dtype == 'float32':
+    if data_dtype == "float32":
         data_array = tvm.nd.array(np.random.rand(*data_shape).astype(dtype=data_dtype), CTX)
         kernel_array = tvm.nd.array(np.random.rand(*kernel_shape).astype(dtype=kernel_dtype), CTX)
     else:
@@ -109,19 +147,32 @@ def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_f
     c_orig = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX)
     c_sch = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX)
 
-
     with tvm.target.Target(TARGET_NAME):
         if out_dtype == "float32":
-            conv = topi.nn.conv2d_NCHWc(data, kernel, stride=hstride,
-                                        padding=hpad, dilation=(1, 1),
-                                        layout='NCHWc', out_layout='NCHWc', out_dtype=out_dtype)
+            conv = topi.nn.conv2d_NCHWc(
+                data,
+                kernel,
+                stride=hstride,
+                padding=hpad,
+                dilation=(1, 1),
+                layout="NCHWc",
+                out_layout="NCHWc",
+                out_dtype=out_dtype,
+            )
         else:
-            conv = topi.nn.conv2d_NCHWc_int8(data, kernel, strides=hstride,
-                                             padding=hpad, dilation=(1, 1),
-                                             layout='NCHWc', out_layout='NCHWc', out_dtype=out_dtype)
+            conv = topi.nn.conv2d_NCHWc_int8(
+                data,
+                kernel,
+                strides=hstride,
+                padding=hpad,
+                dilation=(1, 1),
+                layout="NCHWc",
+                out_layout="NCHWc",
+                out_dtype=out_dtype,
+            )
         out = topi.nn.relu(conv)
         sch = te.create_schedule(out.op)
-        func = tvm.build(sch, [data, kernel, out], target=TARGET_NAME, name='out')
+        func = tvm.build(sch, [data, kernel, out], target=TARGET_NAME, name="out")
         func(data_array, kernel_array, c_orig)
         LOGGER.debug(tvm.lower(sch, [data, kernel], simple_mode=True))
 
@@ -130,11 +181,11 @@ def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_f
             sconv = topi.generic.nn.schedule_conv2d_NCHWc(outs=[out])
         else:
             sconv = topi.generic.nn.schedule_conv2d_NCHWc_int8(outs=[out])
-        func = tvm.build(sconv, [data, kernel, out], target=TARGET_NAME, name='conv')
+        func = tvm.build(sconv, [data, kernel, out], target=TARGET_NAME, name="conv")
         func(data_array, kernel_array, c_sch)
 
         # Functional check
-        if data_dtype == 'uint8':
+        if data_dtype == "uint8":
             np.testing.assert_equal(c_orig.asnumpy(), c_sch.asnumpy())
         else:
             assert np.allclose(c_orig.asnumpy(), c_sch.asnumpy())
@@ -143,17 +194,30 @@ def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_f
         LOGGER.debug(tvm.lower(sconv, [data, kernel], simple_mode=True))
         return evaluator(data_array, kernel_array, c_sch).mean
 
+
 if __name__ == "__main__":
     LOGGER.info("Workload, Kernel_size, FP32_time, INT8_time, Speedup")
     SPEEDUP_ARRAY = []
     for i, wkl in enumerate(WORKLOADS):
         for dtype in ["uint", "int"]:
-            fp32_time = run_inference('float32', 'float32', 'float32', *wkl)
-            int8_time = run_inference('%s8' % dtype, '%s8' % dtype, '%s32' % dtype, *wkl)
+            fp32_time = run_inference("float32", "float32", "float32", *wkl)
+            int8_time = run_inference("%s8" % dtype, "%s8" % dtype, "%s32" % dtype, *wkl)
             kernel_h = wkl[4]
             kernel_w = wkl[5]
-            LOGGER.info("[%s] Workload#" % dtype + str(i) + ", " + str(kernel_h) + "x" + str(kernel_w) + ", "
-                        + str(fp32_time) + ", " + str(int8_time) + ", " + str(fp32_time/int8_time))
-
-            SPEEDUP_ARRAY.append(fp32_time/int8_time)
-    LOGGER.info("Average speedup --> %s" % str(sum(SPEEDUP_ARRAY)/float(len(SPEEDUP_ARRAY))))
+            LOGGER.info(
+                "[%s] Workload#" % dtype
+                + str(i)
+                + ", "
+                + str(kernel_h)
+                + "x"
+                + str(kernel_w)
+                + ", "
+                + str(fp32_time)
+                + ", "
+                + str(int8_time)
+                + ", "
+                + str(fp32_time / int8_time)
+            )
+
+            SPEEDUP_ARRAY.append(fp32_time / int8_time)
+    LOGGER.info("Average speedup --> %s" % str(sum(SPEEDUP_ARRAY) / float(len(SPEEDUP_ARRAY))))
index 3edfc03..562812a 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable-msg=too-many-arguments, too-many-locals, assignment-from-no-return
+# pylint: disable-msg=too-many-arguments, too-many-locals, assignment-from-no-return
 """ Conv Int8 functional and performance testing"""
 import sys
 import logging
@@ -24,81 +24,119 @@ from tvm import te
 from tvm import topi
 
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-LOGGER = logging.getLogger('test_conv_int8_intel')
+LOGGER = logging.getLogger("test_conv_int8_intel")
 LOGGER.disabled = False
 
 # All the WORKLOADS from Resnet except first layer
 # Workload is ['height', 'width', 'in_filter', 'out_filter',
 #              'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
-WORKLOADS = [(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
-             (56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
-             (56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
-             (56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
-             (28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
-             (28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
-             (28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
-             (14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
-             (14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
-             (14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
-             (7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
-             (56, 56, 64, 256, 1, 1, 0, 0, 1, 1),
-             (56, 56, 256, 64, 1, 1, 0, 0, 1, 1),
-             (56, 56, 256, 128, 1, 1, 0, 0, 2, 2),
-             (28, 28, 128, 512, 1, 1, 0, 0, 1, 1),
-             (56, 56, 256, 512, 1, 1, 0, 0, 2, 2),
-             (28, 28, 512, 128, 1, 1, 0, 0, 1, 1),
-             (28, 28, 512, 256, 1, 1, 0, 0, 2, 2),
-             (14, 14, 256, 1024, 1, 1, 0, 0, 1, 1),
-             (28, 28, 512, 1024, 1, 1, 0, 0, 2, 2),
-             (14, 14, 1024, 256, 1, 1, 0, 0, 1, 1),
-             (14, 14, 1024, 512, 1, 1, 0, 0, 2, 2),
-             (7, 7, 512, 2048, 1, 1, 0, 0, 1, 1),
-             (14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2),
-             (7, 7, 2048, 512, 1, 1, 0, 0, 1, 1)
-            ]
-
-
-TARGET_NAME = 'llvm -mcpu=skylake-avx512'
+WORKLOADS = [
+    (56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
+    (56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
+    (56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
+    (56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
+    (28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
+    (28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
+    (28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
+    (14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
+    (14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
+    (14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
+    (7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
+    (56, 56, 64, 256, 1, 1, 0, 0, 1, 1),
+    (56, 56, 256, 64, 1, 1, 0, 0, 1, 1),
+    (56, 56, 256, 128, 1, 1, 0, 0, 2, 2),
+    (28, 28, 128, 512, 1, 1, 0, 0, 1, 1),
+    (56, 56, 256, 512, 1, 1, 0, 0, 2, 2),
+    (28, 28, 512, 128, 1, 1, 0, 0, 1, 1),
+    (28, 28, 512, 256, 1, 1, 0, 0, 2, 2),
+    (14, 14, 256, 1024, 1, 1, 0, 0, 1, 1),
+    (28, 28, 512, 1024, 1, 1, 0, 0, 2, 2),
+    (14, 14, 1024, 256, 1, 1, 0, 0, 1, 1),
+    (14, 14, 1024, 512, 1, 1, 0, 0, 2, 2),
+    (7, 7, 512, 2048, 1, 1, 0, 0, 1, 1),
+    (14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2),
+    (7, 7, 2048, 512, 1, 1, 0, 0, 1, 1),
+]
+
+
+TARGET_NAME = "llvm -mcpu=skylake-avx512"
 NUM_VEC_LANES = 16
 CTX = tvm.context(TARGET_NAME, 0)
 
-def get_shape(im_height, im_width, in_filter, out_filter, k_h, k_w, hpad, wpad,
-              hstride, wstride, out_dtype):
+
+def get_shape(
+    im_height, im_width, in_filter, out_filter, k_h, k_w, hpad, wpad, hstride, wstride, out_dtype
+):
     """
     Finds out the shape of all data structures
     """
     ## Find shapes
-    data_shape = (1, in_filter//NUM_VEC_LANES, im_height, im_width, NUM_VEC_LANES)
-
-    if out_dtype == 'int32':
-        kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
-                        NUM_VEC_LANES//4, NUM_VEC_LANES, 4)
-    elif out_dtype == 'float32':
-        kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
-                        NUM_VEC_LANES, NUM_VEC_LANES)
+    data_shape = (1, in_filter // NUM_VEC_LANES, im_height, im_width, NUM_VEC_LANES)
+
+    if out_dtype == "int32":
+        kernel_shape = (
+            out_filter // NUM_VEC_LANES,
+            in_filter // NUM_VEC_LANES,
+            k_h,
+            k_w,
+            NUM_VEC_LANES // 4,
+            NUM_VEC_LANES,
+            4,
+        )
+    elif out_dtype == "float32":
+        kernel_shape = (
+            out_filter // NUM_VEC_LANES,
+            in_filter // NUM_VEC_LANES,
+            k_h,
+            k_w,
+            NUM_VEC_LANES,
+            NUM_VEC_LANES,
+        )
     out_height = (im_height + 2 * hpad - k_h) // hstride + 1
     out_width = (im_width + 2 * wpad - k_w) // wstride + 1
-    o_shape = (1, out_filter//NUM_VEC_LANES, out_height, out_width, NUM_VEC_LANES)
+    o_shape = (1, out_filter // NUM_VEC_LANES, out_height, out_width, NUM_VEC_LANES)
     return (data_shape, kernel_shape, o_shape)
 
 
-
-def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_filter,
-                  out_filter, k_h, k_w, hpad, wpad, hstride, wstride):
+def run_inference(
+    data_dtype,
+    kernel_dtype,
+    out_dtype,
+    im_height,
+    im_width,
+    in_filter,
+    out_filter,
+    k_h,
+    k_w,
+    hpad,
+    wpad,
+    hstride,
+    wstride,
+):
     """
     Runs the inference and checks the functional correctness between
     compute and schedule outputs
     """
-    (data_shape, kernel_shape, o_shape) = get_shape(im_height, im_width, in_filter,
-                                                    out_filter, k_h, k_w, hpad, wpad,
-                                                    hstride, wstride, out_dtype)
+    (data_shape, kernel_shape, o_shape) = get_shape(
+        im_height,
+        im_width,
+        in_filter,
+        out_filter,
+        k_h,
+        k_w,
+        hpad,
+        wpad,
+        hstride,
+        wstride,
+        out_dtype,
+    )
 
     # Create TVM placeholders
-    data = te.placeholder(data_shape, name='data', dtype=data_dtype)
-    kernel = te.placeholder(kernel_shape, name='kernel', dtype=kernel_dtype)
+    data = te.placeholder(data_shape, name="data", dtype=data_dtype)
+    kernel = te.placeholder(kernel_shape, name="kernel", dtype=kernel_dtype)
 
     # Create the numpy arrays to be used for executing conv models
-    if data_dtype == 'float32':
+    if data_dtype == "float32":
         data_array = tvm.nd.array(np.random.rand(*data_shape).astype(dtype=data_dtype), CTX)
         kernel_array = tvm.nd.array(np.random.rand(*kernel_shape).astype(dtype=kernel_dtype), CTX)
     else:
@@ -110,24 +148,30 @@ def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_f
     c_orig = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX)
     c_sch = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX)
 
-
     with tvm.target.Target(TARGET_NAME):
-        conv = topi.nn.conv2d_NCHWc(data, kernel, stride=hstride,
-                                    padding=hpad, dilation=(1, 1),
-                                    layout='NCHWc', out_layout='NCHWc', out_dtype=out_dtype)
+        conv = topi.nn.conv2d_NCHWc(
+            data,
+            kernel,
+            stride=hstride,
+            padding=hpad,
+            dilation=(1, 1),
+            layout="NCHWc",
+            out_layout="NCHWc",
+            out_dtype=out_dtype,
+        )
         out = topi.nn.relu(conv)
         sch = te.create_schedule(out.op)
-        func = tvm.build(sch, [data, kernel, out], target=TARGET_NAME, name='out')
+        func = tvm.build(sch, [data, kernel, out], target=TARGET_NAME, name="out")
         func(data_array, kernel_array, c_orig)
         LOGGER.debug(tvm.lower(sch, [data, kernel], simple_mode=True))
 
         # Generate and run the optimized schedule
         sconv = topi.generic.nn.schedule_conv2d_NCHWc(outs=[out])
-        func = tvm.build(sconv, [data, kernel, out], target=TARGET_NAME, name='conv')
+        func = tvm.build(sconv, [data, kernel, out], target=TARGET_NAME, name="conv")
         func(data_array, kernel_array, c_sch)
 
         # Functional check
-        if data_dtype == 'uint8':
+        if data_dtype == "uint8":
             np.testing.assert_equal(c_orig.asnumpy(), c_sch.asnumpy())
         else:
             assert np.allclose(c_orig.asnumpy(), c_sch.asnumpy())
@@ -136,16 +180,29 @@ def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_f
         LOGGER.debug(tvm.lower(sconv, [data, kernel], simple_mode=True))
         return evaluator(data_array, kernel_array, c_sch).mean
 
+
 if __name__ == "__main__":
     LOGGER.info("Workload, Kernel_size, FP32_time, INT8_time, Speedup")
     SPEEDUP_ARRAY = []
     for i, wkl in enumerate(WORKLOADS):
-        fp32_time = run_inference('float32', 'float32', 'float32', *wkl)
-        int8_time = run_inference('uint8', 'int8', 'int32', *wkl)
+        fp32_time = run_inference("float32", "float32", "float32", *wkl)
+        int8_time = run_inference("uint8", "int8", "int32", *wkl)
         kernel_h = wkl[4]
         kernel_w = wkl[5]
-        LOGGER.info("Workload#" + str(i) + ", " + str(kernel_h) + "x" + str(kernel_w) + ", "
-                    + str(fp32_time) + ", " + str(int8_time) + ", " + str(fp32_time/int8_time))
-
-        SPEEDUP_ARRAY.append(fp32_time/int8_time)
-    LOGGER.info("Average speedup --> %s" % str(sum(SPEEDUP_ARRAY)/float(len(SPEEDUP_ARRAY))))
+        LOGGER.info(
+            "Workload#"
+            + str(i)
+            + ", "
+            + str(kernel_h)
+            + "x"
+            + str(kernel_w)
+            + ", "
+            + str(fp32_time)
+            + ", "
+            + str(int8_time)
+            + ", "
+            + str(fp32_time / int8_time)
+        )
+
+        SPEEDUP_ARRAY.append(fp32_time / int8_time)
+    LOGGER.info("Average speedup --> %s" % str(sum(SPEEDUP_ARRAY) / float(len(SPEEDUP_ARRAY))))
index 1443aea..5228188 100644 (file)
@@ -32,10 +32,14 @@ key = "android"
 arch = "arm64"
 target = "llvm -mtriple=%s-linux-android" % arch
 
+
 def ngflops(N):
-    return 2.0 * float(N * N * N) / (10**9)
+    return 2.0 * float(N * N * N) / (10 ** 9)
+
+
+dtype = "float32"
+
 
-dtype = 'float32'
 def evaluate(func, ctx, N, times):
     a_np = np.random.uniform(size=(N, N)).astype(dtype)
     b_np = np.random.uniform(size=(N, N)).astype(dtype)
@@ -46,24 +50,23 @@ def evaluate(func, ctx, N, times):
     time_f = func.time_evaluator(func.entry_name, ctx, number=times)
     cost = time_f(a, b, c).mean
     gf = ngflops(N) / cost
-    print('%g secs/op, %g GFLOPS' % (cost, gf))
+    print("%g secs/op, %g GFLOPS" % (cost, gf))
     np.testing.assert_almost_equal(c.asnumpy(), a_np.dot(b_np), decimal=2)
 
+
 def test_gemm_gpu(N, times, bn, num_block, num_thread):
-    assert(bn <= N)
-    assert(num_thread * num_thread * 16 <= N)
-    assert(num_block * num_block * 2 <= N)
-    A = te.placeholder((N, N), name='A')
-    B = te.placeholder((N, N), name='Btmp')
-    k = te.reduce_axis((0, N), name='k')
+    assert bn <= N
+    assert num_thread * num_thread * 16 <= N
+    assert num_block * num_block * 2 <= N
+    A = te.placeholder((N, N), name="A")
+    B = te.placeholder((N, N), name="Btmp")
+    k = te.reduce_axis((0, N), name="k")
 
-    packedB = te.compute((N, N / bn, bn),
-              lambda x, y, z: B[x, y * bn + z], name = 'B')
+    packedB = te.compute((N, N / bn, bn), lambda x, y, z: B[x, y * bn + z], name="B")
 
     C = te.compute(
-        (N, N),
-        lambda ii, jj: te.sum(A[ii, k] * packedB[k, jj / bn, jj % bn], axis=k),
-        name='C')
+        (N, N), lambda ii, jj: te.sum(A[ii, k] * packedB[k, jj / bn, jj % bn], axis=k), name="C"
+    )
 
     s = te.create_schedule(C.op)
     CC = s.cache_write(C, "local")
@@ -130,5 +133,6 @@ def test_gemm_gpu(N, times, bn, num_block, num_thread):
 
     evaluate(f, ctx, N, times)
 
+
 if __name__ == "__main__":
     test_gemm_gpu(1024, times=5, bn=8, num_block=2, num_thread=8)
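
A small worked check, not part of this patch, of the GFLOPS figure that evaluate() prints for the default test_gemm_gpu(1024, ...) call above; the measured cost used here is a placeholder.

# Hedged sketch: the FLOP accounting behind ngflops()/evaluate() above.
def ngflops(N):
    return 2.0 * float(N * N * N) / (10 ** 9)

N = 1024
cost = 0.005  # placeholder seconds/op; the real value comes from time_evaluator
print("%g secs/op, %g GFLOPS" % (cost, ngflops(N) / cost))
# ngflops(1024) = 2 * 1024**3 / 1e9 ~= 2.147 GFLOP, so a 5 ms run is ~= 429 GFLOPS.
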
index b35cd60..25d14f9 100644 (file)
@@ -22,18 +22,21 @@ from tvm.contrib import nvcc
 from tvm.contrib import spirv
 import numpy as np
 
-TASK="gemm"
+TASK = "gemm"
 USE_MANUAL_CODE = False
 
+
 @tvm.register_func
 def tvm_callback_cuda_compile(code):
-    ptx =  nvcc.compile_cuda(code, target="ptx")
+    ptx = nvcc.compile_cuda(code, target="ptx")
     return ptx
 
+
 def write_code(code, fname):
     with open(fname, "w") as f:
         f.write(code)
 
+
 @tvm.register_func
 def tvm_callback_cuda_postproc(code):
     if not os.path.exists("perf"):
@@ -47,16 +50,13 @@ def tvm_callback_cuda_postproc(code):
 def test_gemm():
     # graph
     nn = 2048
-    n = te.var('n')
+    n = te.var("n")
     n = tvm.runtime.convert(nn)
     m, l = n, n
-    A = te.placeholder((l, n), name='A')
-    B = te.placeholder((l, m), name='B')
-    k = te.reduce_axis((0, l), name='k')
-    C = te.compute(
-        (m, n),
-        lambda ii, jj: te.sum(A[k, jj] * B[k, ii], axis=k),
-        name='C')
+    A = te.placeholder((l, n), name="A")
+    B = te.placeholder((l, m), name="B")
+    k = te.reduce_axis((0, l), name="k")
+    C = te.compute((m, n), lambda ii, jj: te.sum(A[k, jj] * B[k, ii], axis=k), name="C")
 
     # schedule
     s = te.create_schedule(C.op)
@@ -135,8 +135,7 @@ def test_gemm():
         c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
         for i in range(2):
             f(a, b, c)
-        tvm.testing.assert_allclose(
-            c.asnumpy(), np.dot(b_np.T, a_np), rtol=1e-5)
+        tvm.testing.assert_allclose(c.asnumpy(), np.dot(b_np.T, a_np), rtol=1e-5)
 
         num_flops = 2 * nn * nn * nn
         num_runs = 10
@@ -146,11 +145,11 @@ def test_gemm():
         print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, GFLOPS))
 
     for device in ["cuda", "opencl", "rocm", "nvptx", "vulkan"]:
-        with tvm.transform.PassContext(config={"tir.UnrollLoop": {
-            "auto_max_step": 128,
-            "explicit_unroll": device != "cuda"
-        }}):
+        with tvm.transform.PassContext(
+            config={"tir.UnrollLoop": {"auto_max_step": 128, "explicit_unroll": device != "cuda"}}
+        ):
             check_device(device)
 
+
 if __name__ == "__main__":
     test_gemm()
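
A numpy transcription, not part of this patch, of the GEMM that test_gemm() declares above: C[ii, jj] = sum_k A[k, jj] * B[k, ii], which is why check_device() compares against np.dot(b_np.T, a_np).

# Hedged sketch: numpy reference for the te.compute in test_gemm() above.
import numpy as np

n = 4  # tiny size for illustration; the kernel uses nn = 2048
a_np = np.random.rand(n, n).astype("float32")
b_np = np.random.rand(n, n).astype("float32")
c_ref = np.dot(b_np.T, a_np)                 # the oracle used in check_device()
c_lit = np.einsum("kj,ki->ij", a_np, b_np)   # literal form of lambda ii, jj: sum_k A[k, jj] * B[k, ii]
np.testing.assert_allclose(c_lit, c_ref, rtol=1e-5)
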
index 9362d71..0d0941d 100644 (file)
@@ -26,54 +26,65 @@ from tvm.topi.cuda.tensor_intrin import dp4a
 DO_TUNING = True
 PRETUNED_INDEX = 75333
 
-intrin_dp4a = dp4a('local', 'local', 'local')
+intrin_dp4a = dp4a("local", "local", "local")
+
 
 @autotvm.template
 def gemm_int8(n, m, l):
-    A = te.placeholder((n, l), name='A', dtype='int8')
-    B = te.placeholder((m, l), name='B', dtype='int8')
-
-    k = te.reduce_axis((0, l), name='k')
-    C = te.compute((n, m), lambda i, j: te.sum(A[i, k].astype('int32') * B[j, k].astype(
-        'int32'), axis=k), name='C')
+    A = te.placeholder((n, l), name="A", dtype="int8")
+    B = te.placeholder((m, l), name="B", dtype="int8")
+
+    k = te.reduce_axis((0, l), name="k")
+    C = te.compute(
+        (n, m),
+        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
+        name="C",
+    )
 
     cfg = autotvm.get_config()
     s = te.create_schedule(C.op)
     y, x = C.op.axis
 
-    AA = s.cache_read(A, 'shared', [C])
-    BB = s.cache_read(B, 'shared', [C])
-    AL = s.cache_read(AA, 'local', [C])
-    BL = s.cache_read(BB, 'local', [C])
-    CC = s.cache_write(C, 'local')
+    AA = s.cache_read(A, "shared", [C])
+    BB = s.cache_read(B, "shared", [C])
+    AL = s.cache_read(AA, "local", [C])
+    BL = s.cache_read(BB, "local", [C])
+    CC = s.cache_write(C, "local")
 
     k = CC.op.reduce_axis[0]
 
-    cfg.define_split('tile_k', cfg.axis(k), num_outputs=3,
-                     filter=lambda entity: entity.size[2] == 4 and \
-                     entity.size[0] * 2 >= entity.size[1])
+    cfg.define_split(
+        "tile_k",
+        cfg.axis(k),
+        num_outputs=3,
+        filter=lambda entity: entity.size[2] == 4 and entity.size[0] * 2 >= entity.size[1],
+    )
 
-    ko, kt, ki = cfg['tile_k'].apply(s, CC, k)
+    ko, kt, ki = cfg["tile_k"].apply(s, CC, k)
 
     s[CC].tensorize(ki, intrin_dp4a)
 
-    block_x = te.thread_axis('blockIdx.x')
-    block_y = te.thread_axis('blockIdx.y')
-    thread_x = te.thread_axis('threadIdx.x')
-    thread_y = te.thread_axis('threadIdx.y')
+    block_x = te.thread_axis("blockIdx.x")
+    block_y = te.thread_axis("blockIdx.y")
+    thread_x = te.thread_axis("threadIdx.x")
+    thread_y = te.thread_axis("threadIdx.y")
 
     def block_size_filter(entity):
-        return entity.size[0] * 2 >= entity.size[1] * 2 and \
-                entity.size[1] <= 16 and entity.size[3] <= 4
-    cfg.define_split('tile_y', cfg.axis(y), num_outputs=4, filter=block_size_filter)
-    cfg.define_split('tile_x', cfg.axis(x), num_outputs=4, filter=block_size_filter)
-    by, tyz, ty, yi = cfg['tile_y'].apply(s, C, y)
-    bx, txz, tx, xi = cfg['tile_x'].apply(s, C, x)
+        return (
+            entity.size[0] * 2 >= entity.size[1] * 2
+            and entity.size[1] <= 16
+            and entity.size[3] <= 4
+        )
+
+    cfg.define_split("tile_y", cfg.axis(y), num_outputs=4, filter=block_size_filter)
+    cfg.define_split("tile_x", cfg.axis(x), num_outputs=4, filter=block_size_filter)
+    by, tyz, ty, yi = cfg["tile_y"].apply(s, C, y)
+    bx, txz, tx, xi = cfg["tile_x"].apply(s, C, x)
 
     s[C].bind(by, block_y)
     s[C].bind(bx, block_x)
-    s[C].bind(tyz, te.thread_axis('vthread'))
-    s[C].bind(txz, te.thread_axis('vthread'))
+    s[C].bind(tyz, te.thread_axis("vthread"))
+    s[C].bind(txz, te.thread_axis("vthread"))
     s[C].bind(ty, thread_y)
     s[C].bind(tx, thread_x)
     s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)
@@ -90,51 +101,53 @@ def gemm_int8(n, m, l):
         s[stage].vectorize(xi)
         s[stage].double_buffer()
 
-    cfg.define_knob('storage_align', [16, 48])
+    cfg.define_knob("storage_align", [16, 48])
     for stage in [AA, BB]:
-        s[stage].storage_align(s[stage].op.axis[0],
-                               cfg['storage_align'].val, 0)
+        s[stage].storage_align(s[stage].op.axis[0], cfg["storage_align"].val, 0)
         s[stage].compute_at(s[CC], ko)
 
         fused = s[stage].fuse(*s[stage].op.axis)
-        ty, tx = s[stage].split(fused, nparts=cfg['tile_y'].size[2])
-        tx, xi = s[stage].split(tx, nparts=cfg['tile_x'].size[2])
+        ty, tx = s[stage].split(fused, nparts=cfg["tile_y"].size[2])
+        tx, xi = s[stage].split(tx, nparts=cfg["tile_x"].size[2])
         _, xi = s[stage].split(xi, factor=16)
 
         s[stage].bind(ty, thread_y)
         s[stage].bind(tx, thread_x)
         s[stage].vectorize(xi)
 
-    cfg.define_knob('auto_unroll_max_step', [512, 1500])
-    s[C].pragma(by, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
-    s[C].pragma(by, 'unroll_explicit', False)
+    cfg.define_knob("auto_unroll_max_step", [512, 1500])
+    s[C].pragma(by, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[C].pragma(by, "unroll_explicit", False)
 
-    cfg.add_flop(n*m*l*2)
+    cfg.add_flop(n * m * l * 2)
     return s, [A, B, C]
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     N = 2048
     n = m = l = N
 
     logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
-    task = autotvm.task.create(gemm_int8, args=(n, m, l), target='cuda')
+    task = autotvm.task.create(gemm_int8, args=(n, m, l), target="cuda")
     print(task.config_space)
 
     measure_option = autotvm.measure_option(
         builder=autotvm.LocalBuilder(),
-        runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
+        runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4),
     )
 
-    log_name = 'gemm_int8.log'
+    log_name = "gemm_int8.log"
     if DO_TUNING:
         tuner = autotvm.tuner.XGBTuner(task)
-        tuner.tune(n_trial=1000, measure_option=measure_option,
-                   callbacks=[autotvm.callback.log_to_file(log_name)])
+        tuner.tune(
+            n_trial=1000,
+            measure_option=measure_option,
+            callbacks=[autotvm.callback.log_to_file(log_name)],
+        )
 
         dispatch_context = autotvm.apply_history_best(log_name)
         best_config = dispatch_context.query(task.target, task.workload)
-        print('\nBest config:')
+        print("\nBest config:")
         print(best_config)
     else:
         config = task.config_space.get(PRETUNED_INDEX)
@@ -143,31 +156,27 @@ if __name__ == '__main__':
         print(config)
 
     with dispatch_context:
-        with tvm.target.Target('cuda'):
+        with tvm.target.Target("cuda"):
             s, arg_bufs = gemm_int8(n, m, l)
-            f = tvm.build(s, arg_bufs, 'cuda', name='gemm_int8')
+            f = tvm.build(s, arg_bufs, "cuda", name="gemm_int8")
 
-    ctx = tvm.context('cuda', 0)
+    ctx = tvm.context("cuda", 0)
 
-    a_np = np.random.randint(size=(n, l), low=-128, high=127, dtype='int8')
-    b_np = np.random.randint(size=(m, l), low=-128, high=127, dtype='int8')
+    a_np = np.random.randint(size=(n, l), low=-128, high=127, dtype="int8")
+    b_np = np.random.randint(size=(m, l), low=-128, high=127, dtype="int8")
 
     a = tvm.nd.array(a_np, ctx)
     b = tvm.nd.array(b_np, ctx)
-    c = tvm.nd.array(np.zeros((n, m), dtype='int32'), ctx)
+    c = tvm.nd.array(np.zeros((n, m), dtype="int32"), ctx)
     f(a, b, c)
 
     tvm.testing.assert_allclose(
-        c.asnumpy(),
-        np.dot(
-            a_np.astype('int32'),
-            b_np.T.astype('int32')),
-        rtol=1e-5)
+        c.asnumpy(), np.dot(a_np.astype("int32"), b_np.T.astype("int32")), rtol=1e-5
+    )
 
     num_ops = 2 * l * m * n
     num_runs = 1000
     timer_f = f.time_evaluator(f.entry_name, ctx, number=num_runs)
     t = timer_f(a, b, c).mean
     GOPS = num_ops / (t * 1e3) / 1e6
-    print("average time cost of %d runs = %g ms, %g GOPS." %
-          (num_runs, t * 1e3, GOPS))
+    print("average time cost of %d runs = %g ms, %g GOPS." % (num_runs, t * 1e3, GOPS))
index b6d0602..00b1ca2 100644 (file)
@@ -47,20 +47,24 @@ def test_reduce_map(in_shape, axis, keepdims, type="sum", test_id=0):
     # Build the logic and compile the function
     A = te.placeholder(shape=in_shape, name="A")
     if type == "sum":
-        TASK = "sum_map_id%d" %test_id
+        TASK = "sum_map_id%d" % test_id
         B = topi.sum(A, axis=axis, keepdims=keepdims)
     elif type == "max":
-        TASK = "max_map_id%d" %test_id
+        TASK = "max_map_id%d" % test_id
         B = topi.max(A, axis=axis, keepdims=keepdims)
     elif type == "min":
-        TASK = "min_map_id%d" %test_id
+        TASK = "min_map_id%d" % test_id
         B = topi.min(A, axis=axis, keepdims=keepdims)
     else:
         raise NotImplementedError
     s = topi.cuda.schedule_reduce(B)
-    with tvm.transform.PassContext(config={"tir.UnrollLoop": {
-        "auto_max_step": 16,
-    }}):
+    with tvm.transform.PassContext(
+        config={
+            "tir.UnrollLoop": {
+                "auto_max_step": 16,
+            }
+        }
+    ):
         fcuda = tvm.build(s, [A, B], "cuda", name="sum")
 
     # Test
@@ -81,24 +85,11 @@ def test_reduce_map(in_shape, axis, keepdims, type="sum", test_id=0):
         fcuda(data_tvm, out_tvm)
     tvm.testing.assert_allclose(out_tvm.asnumpy(), out_npy, rtol=4e-4, atol=4e-4)
 
+
 if __name__ == "__main__":
-    test_reduce_map(in_shape=(128, 24, 128, 24),
-                    axis=(1, 2, 3),
-                    keepdims=True,
-                    type="sum",
-                    test_id=0)
-    test_reduce_map(in_shape=(128, 24 * 128 * 24),
-                    axis=(1,),
-                    keepdims=False,
-                    type="max",
-                    test_id=1)
-    test_reduce_map(in_shape=(32, 128, 24),
-                    axis=None,
-                    keepdims=True,
-                    type="sum",
-                    test_id=2)
-    test_reduce_map(in_shape=(128, 24, 128, 24),
-                    axis=(0, 2),
-                    keepdims=False,
-                    type="min",
-                    test_id=3)
+    test_reduce_map(
+        in_shape=(128, 24, 128, 24), axis=(1, 2, 3), keepdims=True, type="sum", test_id=0
+    )
+    test_reduce_map(in_shape=(128, 24 * 128 * 24), axis=(1,), keepdims=False, type="max", test_id=1)
+    test_reduce_map(in_shape=(32, 128, 24), axis=None, keepdims=True, type="sum", test_id=2)
+    test_reduce_map(in_shape=(128, 24, 128, 24), axis=(0, 2), keepdims=False, type="min", test_id=3)
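
A plausible numpy counterpart, not part of this patch, for the four reductions exercised above (the out_npy oracle itself sits outside the hunks shown), mainly to make the axis/keepdims combinations and resulting shapes concrete.

# Hedged sketch: shapes produced by the reduction settings used in the calls above.
import numpy as np

x = np.random.rand(128, 24, 128, 24).astype("float32")
np.sum(x, axis=(1, 2, 3), keepdims=True).shape   # test_id=0 -> (128, 1, 1, 1)
np.min(x, axis=(0, 2), keepdims=False).shape     # test_id=3 -> (24, 24)
y = np.random.rand(128, 24 * 128 * 24).astype("float32")
np.max(y, axis=(1,)).shape                       # test_id=1 -> (128,)
np.sum(np.random.rand(32, 128, 24), axis=None, keepdims=True).shape  # test_id=2 -> (1, 1, 1)
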
index be46d89..701797e 100644 (file)
@@ -22,7 +22,7 @@ from tvm.contrib import nvcc
 import numpy as np
 
 # Quick knobs
-TASK="lstm"
+TASK = "lstm"
 USE_MANUAL_CODE = False
 PERSIST_KERNEL = True
 DETECT_GLOBAL_BARRIER = PERSIST_KERNEL
@@ -33,7 +33,7 @@ UNROLL_WLOAD = True
 @tvm.register_func
 def tvm_callback_cuda_compile(code):
     """Use nvcc compiler for better perf."""
-    ptx =  nvcc.compile_cuda(code, target="ptx")
+    ptx = nvcc.compile_cuda(code, target="ptx")
     return ptx
 
 
@@ -59,7 +59,7 @@ def lstm():
     num_thread_x = 16 * 3 // 2
     num_sm = 24
     n_num_step = 128
-    num_step = te.var('num_step')
+    num_step = te.var("num_step")
     num_hidden = 1152 // 2
     batch_size = 1
     # Global transition matrix
@@ -70,30 +70,35 @@ def lstm():
     # h: output hidden state, c: cell state.
     s_state_h = te.placeholder((num_step, batch_size, num_hidden))
     s_state_c = te.placeholder((num_step, batch_size, num_hidden))
-    s_init_c = te.compute((1, batch_size, num_hidden),
-                           lambda *i: 0.0, name="init_c")
-    s_init_h = te.compute((1, batch_size, num_hidden),
-                           lambda *i: 0.0, name="init_h")
+    s_init_c = te.compute((1, batch_size, num_hidden), lambda *i: 0.0, name="init_c")
+    s_init_h = te.compute((1, batch_size, num_hidden), lambda *i: 0.0, name="init_h")
     # LSTM transition
     k = te.reduce_axis((0, num_hidden), name="ki2h")
     s_h2h = te.compute(
         (num_step, batch_size, 4, num_hidden),
         lambda t, i, x, j: te.sum(s_state_h[t - 1, i, k] * Wh2h[x, j, k], axis=k),
-        name="s_h2h")
+        name="s_h2h",
+    )
     # Gate rules
-    gates = te.compute(Xi2h.shape, lambda *i:
-                        Xi2h(*i) + s_h2h(*i), name="gates")
+    gates = te.compute(Xi2h.shape, lambda *i: Xi2h(*i) + s_h2h(*i), name="gates")
     gshape = (num_step, batch_size, num_hidden)
     in_gate = te.compute(gshape, lambda t, i, j: te.sigmoid(gates[t, i, 0, j]), name="in_gate")
-    in_transform = te.compute(gshape, lambda t, i, j: te.tanh(gates[t, i, 1, j]), name="in_transform")
-    forget_gate = te.compute(gshape, lambda t, i, j: te.sigmoid(gates[t, i, 2, j]), name="forget_gate")
+    in_transform = te.compute(
+        gshape, lambda t, i, j: te.tanh(gates[t, i, 1, j]), name="in_transform"
+    )
+    forget_gate = te.compute(
+        gshape, lambda t, i, j: te.sigmoid(gates[t, i, 2, j]), name="forget_gate"
+    )
     out_gate = te.compute(gshape, lambda t, i, j: te.sigmoid(gates[t, i, 3, j]), name="out_gate")
-    next_c = te.compute(gshape,
-                         lambda t, i, j:
-                         forget_gate[t, i, j] * s_state_c[t - 1, i, j] +
-                         in_gate[t, i, j] * in_transform[t, i, j], name="next_c")
-    next_h = te.compute(gshape,
-                         lambda t, i, j: out_gate[t, i, j] * te.tanh(next_c[t, i, j]), name="next_h")
+    next_c = te.compute(
+        gshape,
+        lambda t, i, j: forget_gate[t, i, j] * s_state_c[t - 1, i, j]
+        + in_gate[t, i, j] * in_transform[t, i, j],
+        name="next_c",
+    )
+    next_h = te.compute(
+        gshape, lambda t, i, j: out_gate[t, i, j] * te.tanh(next_c[t, i, j]), name="next_h"
+    )
     update_c = te.compute(gshape, lambda *i: next_c(*i), name="update_c")
     update_h = te.compute(gshape, lambda *i: next_h(*i), name="update_h")
     # schedule
@@ -102,7 +107,8 @@ def lstm():
         [update_h, update_c],
         [s_state_h, s_state_c],
         inputs=[Xi2h],
-        name="lstm_scan")
+        name="lstm_scan",
+    )
     # schedule
     s = te.create_schedule(scan_h.op)
     # Inline gate computations
@@ -164,18 +170,13 @@ def lstm():
     # verify we can lower correctly
     def check_device(target):
         num_step = n_num_step
-        flstm = tvm.build(s, [Xi2h, Wh2h, scan_h, scan_c],
-                          target)
+        flstm = tvm.build(s, [Xi2h, Wh2h, scan_h, scan_c], target)
         ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
         # launch the kernel.
-        scan_h_np = np.zeros(
-            (num_step, batch_size, num_hidden)).astype("float32")
-        scan_c_np = np.zeros(
-            (num_step, batch_size, num_hidden)).astype("float32")
-        Xi2h_np = np.random.normal(
-            size=(num_step, batch_size, 4, num_hidden)).astype("float32")
-        Wh2h_np = np.random.normal(
-            size=(4, num_hidden, num_hidden)).astype("float32")
+        scan_h_np = np.zeros((num_step, batch_size, num_hidden)).astype("float32")
+        scan_c_np = np.zeros((num_step, batch_size, num_hidden)).astype("float32")
+        Xi2h_np = np.random.normal(size=(num_step, batch_size, 4, num_hidden)).astype("float32")
+        Wh2h_np = np.random.normal(size=(4, num_hidden, num_hidden)).astype("float32")
         scan_h_a = tvm.nd.array(scan_h_np, ctx)
         scan_c_a = tvm.nd.array(scan_c_np, ctx)
         Xi2h_a = tvm.nd.array(Xi2h_np, ctx)
@@ -188,13 +189,16 @@ def lstm():
         print("Time cost=%g" % eval_result.mean)
 
     # set unroll_explicit for more readable code.
-    with tvm.transform.PassContext(config={
-        "tir.UnrollLoop": {
-            "auto_max_step": 128,
-        },
-        "tir.detect_global_barrier": DETECT_GLOBAL_BARRIER
-    }):
+    with tvm.transform.PassContext(
+        config={
+            "tir.UnrollLoop": {
+                "auto_max_step": 128,
+            },
+            "tir.detect_global_barrier": DETECT_GLOBAL_BARRIER,
+        }
+    ):
         check_device("cuda")
 
+
 if __name__ == "__main__":
     lstm()
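
A numpy transcription, not part of this patch, of one LSTM step as defined by the te.compute expressions above (gate slots 0..3 = in, transform, forget, out); the batch dimension is dropped since the kernel uses batch_size = 1.

# Hedged sketch: single LSTM step matching the gate math in lstm() above.
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_i2h, h_prev, c_prev, Wh2h):
    # x_i2h: (4, num_hidden) input projection for this step (an Xi2h slice)
    # Wh2h:  (4, num_hidden, num_hidden) hidden-to-hidden weights
    s_h2h = np.einsum("xjk,k->xj", Wh2h, h_prev)          # same reduction over k as s_h2h above
    gates = x_i2h + s_h2h
    in_gate, in_transform = sigmoid(gates[0]), np.tanh(gates[1])
    forget_gate, out_gate = sigmoid(gates[2]), sigmoid(gates[3])
    next_c = forget_gate * c_prev + in_gate * in_transform
    next_h = out_gate * np.tanh(next_c)
    return next_h, next_c
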
index 444e27f..e2cea9b 100644 (file)
@@ -32,22 +32,25 @@ from tvm.contrib import nvcc
 import numpy as np
 
 # Quick knobs
-TASK="matexp"
+TASK = "matexp"
 USE_MANUAL_CODE = False
 PERSIST_KERNEL = True
 DETECT_GLOBAL_BARRIER = PERSIST_KERNEL
 SKIP_CHECK = False
 
+
 @tvm.register_func
 def tvm_callback_cuda_compile(code):
     """Use nvcc compiler for better perf."""
-    ptx =  nvcc.compile_cuda(code, target="ptx")
+    ptx = nvcc.compile_cuda(code, target="ptx")
     return ptx
 
+
 def write_code(code, fname):
     with open(fname, "w") as f:
         f.write(code)
 
+
 @tvm.register_func
 def tvm_callback_cuda_postproc(code):
     if not os.path.exists("perf"):
@@ -57,6 +60,7 @@ def tvm_callback_cuda_postproc(code):
         code = open("perf/%s_manual.cu" % TASK).read()
     return code
 
+
 def rnn_matexp():
     n_num_step = 128
     n_num_hidden = 1152
@@ -71,14 +75,14 @@ def rnn_matexp():
     num_sm = 24
 
     Whh = te.placeholder((num_hidden, num_hidden), name="Whh")
-    s_init = te.compute((1, batch_size, num_hidden),
-                         lambda _, i, j: 1.0, name="init")
+    s_init = te.compute((1, batch_size, num_hidden), lambda _, i, j: 1.0, name="init")
     s_state = te.placeholder((num_step, batch_size, num_hidden))
     kh = te.reduce_axis((0, num_hidden), name="kh")
     s_update = te.compute(
         (num_step, batch_size, num_hidden),
-        lambda t, i, j: te.sum(s_state[t-1, i, kh] * Whh[kh, j], axis=kh),
-        name="update")
+        lambda t, i, j: te.sum(s_state[t - 1, i, kh] * Whh[kh, j], axis=kh),
+        name="update",
+    )
     s_scan = tvm.te.scan(s_init, s_update, s_state)
     # schedule
     s = te.create_schedule(s_scan.op)
@@ -127,20 +131,21 @@ def rnn_matexp():
     s[SS].bind(tx, thread_x)
 
     def check_device(target):
-        with tvm.transform.PassContext(config={
-            "tir.UnrollLoop": {
-                "auto_max_step": 128,
-            },
-            "tir.detect_global_barrier": detect_global_barrier
-        }):
+        with tvm.transform.PassContext(
+            config={
+                "tir.UnrollLoop": {
+                    "auto_max_step": 128,
+                },
+                "tir.detect_global_barrier": detect_global_barrier,
+            }
+        ):
             f = tvm.build(s, [s_scan, Whh], target)
         ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
         # launch the kernel.
-        res_np = np.zeros(
-            (n_num_step, n_batch_size, n_num_hidden)).astype("float32")
+        res_np = np.zeros((n_num_step, n_batch_size, n_num_hidden)).astype("float32")
         Whh_np = np.zeros((n_num_hidden, n_num_hidden)).astype("float32")
         Whh_np[:] = 2.0 / n_num_hidden
-        Whh_np[:, n_num_hidden//2:] = 0
+        Whh_np[:, n_num_hidden // 2 :] = 0
 
         res_a = tvm.nd.array(res_np, ctx)
         Whh_a = tvm.nd.array(Whh_np, ctx)
@@ -160,12 +165,14 @@ def rnn_matexp():
             Whh_np = Whh_np.astype("float64")
             for t in range(1, n_num_step):
                 res_cmp[t][:] = np.dot(res_cmp[t - 1], Whh_np)
-            for i  in range(n_num_step):
+            for i in range(n_num_step):
                 for j in range(n_num_hidden):
-                    if abs(res_cmp[i,0,j] - res_gpu[i,0,j]) > 1e-5:
-                        print("%d, %d: %g vs %g" % (i,j, res_cmp[i,0,j], res_gpu[i,0,j]))
+                    if abs(res_cmp[i, 0, j] - res_gpu[i, 0, j]) > 1e-5:
+                        print("%d, %d: %g vs %g" % (i, j, res_cmp[i, 0, j], res_gpu[i, 0, j]))
             tvm.testing.assert_allclose(res_gpu, res_cmp, rtol=1e-3)
+
     check_device("cuda")
 
+
 if __name__ == "__main__":
     rnn_matexp()
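
A numpy transcription, not part of this patch, of the recurrence rnn_matexp() evaluates above: each step is a dense matmul against Whh starting from an all-ones state, which is what the res_cmp loop in check_device() recomputes.

# Hedged sketch: the matexp recurrence, s[t] = s[t-1] @ Whh, with the same Whh init.
import numpy as np

num_step, batch_size, num_hidden = 8, 1, 16  # tiny sizes for illustration
Whh = np.full((num_hidden, num_hidden), 2.0 / num_hidden, dtype="float64")
Whh[:, num_hidden // 2:] = 0
state = np.ones((batch_size, num_hidden), dtype="float64")
for _ in range(1, num_step):
    state = state @ Whh   # s_update[t, i, j] = sum_kh s_state[t-1, i, kh] * Whh[kh, j]
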
index 78f80fa..cfea02a 100644 (file)
@@ -44,30 +44,35 @@ def build_graph_lib(model_file, opt_level):
 
     # Compile the relay mod
     mod, params = _get_mod_and_params(model_file)
-    target = 'llvm -target=wasm32-unknown-unknown -mattr=+simd128 --system-lib'
+    target = "llvm -target=wasm32-unknown-unknown -mattr=+simd128 --system-lib"
     with tvm.transform.PassContext(opt_level=opt_level):
         graph_json, lib, params = relay.build(mod, target=target, params=params)
 
     # Save the model artifacts to obj_file
-    obj_file = os.path.join(out_dir, 'graph.o')
+    obj_file = os.path.join(out_dir, "graph.o")
     lib.save(obj_file)
     # Run llvm-ar to archive obj_file into lib_file
-    lib_file = os.path.join(out_dir, 'libgraph_wasm32.a')
-    cmds = [os.environ.get("LLVM_AR", "llvm-ar-10"), 'rcs', lib_file, obj_file]
+    lib_file = os.path.join(out_dir, "libgraph_wasm32.a")
+    cmds = [os.environ.get("LLVM_AR", "llvm-ar-10"), "rcs", lib_file, obj_file]
     subprocess.run(cmds)
 
-    with open(os.path.join(out_dir, 'graph.json'), 'w') as f_graph:
+    with open(os.path.join(out_dir, "graph.json"), "w") as f_graph:
         f_graph.write(graph_json)
 
-    with open(os.path.join(out_dir, 'graph.params'), 'wb') as f_params:
+    with open(os.path.join(out_dir, "graph.params"), "wb") as f_params:
         f_params.write(relay.save_param_dict(params))
 
 
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='ONNX model build example')
-    parser.add_argument('model_file', type=str, help='the path of onnx model file')
-    parser.add_argument('-O', '--opt-level', type=int, default=0,
-                        help='level of optimization. 0 is unoptimized and 3 is the highest level')
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="ONNX model build example")
+    parser.add_argument("model_file", type=str, help="the path of onnx model file")
+    parser.add_argument(
+        "-O",
+        "--opt-level",
+        type=int,
+        default=0,
+        help="level of optimization. 0 is unoptimized and 3 is the highest level",
+    )
     args = parser.parse_args()
 
     build_graph_lib(args.model_file, args.opt_level)
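
A hedged usage sketch, not part of this patch: the builder above can also be driven programmatically; the ONNX path here is hypothetical.

# Hedged sketch: calling build_graph_lib() directly instead of via argparse.
from build_graph_lib import build_graph_lib  # assumes the script is importable as a module

build_graph_lib("mobilenet_v2.onnx", opt_level=3)
# Roughly equivalent to: python3 build_graph_lib.py -O 3 mobilenet_v2.onnx
# Writes graph.o, libgraph_wasm32.a, graph.json and graph.params under out_dir.
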
index 8057892..efd6169 100644 (file)
@@ -21,7 +21,7 @@ import subprocess
 
 from jinja2 import Template
 
-CUDA_VERSIONS = ['10.0', '9.0']
+CUDA_VERSIONS = ["10.0", "9.0"]
 
 
 # Make sure that the cudnn version you set here is available
@@ -29,8 +29,8 @@ CUDA_VERSIONS = ['10.0', '9.0']
 # and from conda.
 
 # These two must be in sync
-CUDNN_FULL_VERSION = '7.6.0.64'
-CUDNN_VERSION = '7.6.0'
+CUDNN_FULL_VERSION = "7.6.0.64"
+CUDNN_VERSION = "7.6.0"
 
 
 condadir = os.path.dirname(sys.argv[0])
@@ -38,22 +38,21 @@ condadir = os.path.abspath(condadir)
 srcdir = os.path.dirname(condadir)
 
 
-with open(os.path.join(condadir, 'Dockerfile.template')) as f:
+with open(os.path.join(condadir, "Dockerfile.template")) as f:
     docker_template = Template(f.read())
 
 
 def render_dockerfile(version):
-    txt = docker_template.render(cuda_version=version,
-                                 cudnn_short_version=CUDNN_VERSION,
-                                 cudnn_version=CUDNN_FULL_VERSION)
-    fname = os.path.join(condadir,
-                         '../docker/Dockerfile.conda_cuda' + version.replace('.', ''))
-    with open(fname, 'w') as f:
+    txt = docker_template.render(
+        cuda_version=version, cudnn_short_version=CUDNN_VERSION, cudnn_version=CUDNN_FULL_VERSION
+    )
+    fname = os.path.join(condadir, "../docker/Dockerfile.conda_cuda" + version.replace(".", ""))
+    with open(fname, "w") as f:
         f.write(txt)
     return fname
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     build_versions = CUDA_VERSIONS
     if len(sys.argv) > 1:
         build_versions = sys.argv[1:]
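
A minimal sketch, not part of this patch, of the Jinja2 rendering step that render_dockerfile() performs above; the template text here is illustrative only, the real one lives in conda/Dockerfile.template.

# Hedged sketch: same Template/render pattern as render_dockerfile() above.
from jinja2 import Template

template = Template("CUDA={{ cuda_version }} CUDNN={{ cudnn_version }} ({{ cudnn_short_version }})\n")
print(template.render(cuda_version="10.0", cudnn_version="7.6.0.64", cudnn_short_version="7.6.0"))
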
index edf1a73..133a832 100644 (file)
 import tvm.testing
 from pytest import ExitCode
 
+
 def pytest_configure(config):
     print("enabled targets:", "; ".join(map(lambda x: x[0], tvm.testing.enabled_targets())))
     print("pytest marker:", config.option.markexpr)
 
+
 def pytest_sessionfinish(session, exitstatus):
     # Don't exit with an error if we select a subset of tests that doesn't
     # include anything
-    if session.config.option.markexpr != '':
+    if session.config.option.markexpr != "":
         if exitstatus == ExitCode.NO_TESTS_COLLECTED:
             session.exitstatus = ExitCode.OK
index c03f1b7..18d8960 100644 (file)
@@ -39,53 +39,54 @@ import sphinx_gallery
 # add these directories to sys.path here. If the directory is relative to the
 # documentation root, use os.path.abspath to make it absolute, like shown here.
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-sys.path.insert(0, os.path.join(curr_path, '../python/'))
-sys.path.insert(0, os.path.join(curr_path, '../vta/python'))
+sys.path.insert(0, os.path.join(curr_path, "../python/"))
+sys.path.insert(0, os.path.join(curr_path, "../vta/python"))
 
 # -- General configuration ------------------------------------------------
 
 # General information about the project.
-project = u'tvm'
-author = u'Apache Software Foundation'
-copyright = u'2020, %s' % author
-github_doc_root = 'https://github.com/apache/incubator-tvm/tree/master/docs/'
+project = "tvm"
+author = "Apache Software Foundation"
+copyright = "2020, %s" % author
+github_doc_root = "https://github.com/apache/incubator-tvm/tree/master/docs/"
 
-os.environ['TVM_BUILD_DOC'] = '1'
+os.environ["TVM_BUILD_DOC"] = "1"
 # Version information.
 import tvm
 from tvm import topi
 from tvm import te
+
 version = tvm.__version__
 release = tvm.__version__
 
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones
 extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.autosummary',
-    'sphinx.ext.intersphinx',
-    'sphinx.ext.napoleon',
-    'sphinx.ext.mathjax',
-    'sphinx_gallery.gen_gallery',
-    'autodocsumm'
+    "sphinx.ext.autodoc",
+    "sphinx.ext.autosummary",
+    "sphinx.ext.intersphinx",
+    "sphinx.ext.napoleon",
+    "sphinx.ext.mathjax",
+    "sphinx_gallery.gen_gallery",
+    "autodocsumm",
 ]
 
 # Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
 
 # The suffix(es) of source filenames.
 # You can specify multiple suffix as a list of string:
 # source_suffix = ['.rst', '.md']
-source_suffix = ['.rst', '.md']
+source_suffix = [".rst", ".md"]
 
 # The encoding of source files.
-#source_encoding = 'utf-8-sig'
+# source_encoding = 'utf-8-sig'
 
 # generate autosummary even if no references
 autosummary_generate = True
 
 # The master toctree document.
-master_doc = 'index'
+master_doc = "index"
 
 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.
@@ -96,37 +97,37 @@ language = None
 
 # There are two options for replacing |today|: either, you set today to some
 # non-false value, then it is used:
-#today = ''
+# today = ''
 # Else, today_fmt is used as the format for a strftime call.
-#today_fmt = '%B %d, %Y'
+# today_fmt = '%B %d, %Y'
 
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
-exclude_patterns = ['_build']
+exclude_patterns = ["_build"]
 
 # The reST default role (used for this markup: `text`) to use for all
 # documents.
-#default_role = None
+# default_role = None
 
 # If true, '()' will be appended to :func: etc. cross-reference text.
-#add_function_parentheses = True
+# add_function_parentheses = True
 
 # If true, the current module name will be prepended to all description
 # unit titles (such as .. function::).
-#add_module_names = True
+# add_module_names = True
 
 # If true, sectionauthor and moduleauthor directives will be shown in the
 # output. They are ignored by default.
-#show_authors = False
+# show_authors = False
 
 # The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'sphinx'
+pygments_style = "sphinx"
 
 # A list of ignored prefixes for module index sorting.
-#modindex_common_prefix = []
+# modindex_common_prefix = []
 
 # If true, keep warnings as "system message" paragraphs in the built documents.
-#keep_warnings = False
+# keep_warnings = False
 
 # If true, `todo` and `todoList` produce output, else they produce nothing.
 todo_include_todos = False
@@ -134,23 +135,24 @@ todo_include_todos = False
 # -- Options for HTML output ----------------------------------------------
 
 # The theme is set by the make target
-html_theme = os.environ.get('TVM_THEME', 'rtd')
+html_theme = os.environ.get("TVM_THEME", "rtd")
 
-on_rtd = os.environ.get('READTHEDOCS', None) == 'True'
+on_rtd = os.environ.get("READTHEDOCS", None) == "True"
 # only import rtd theme and set it if want to build docs locally
-if not on_rtd and html_theme == 'rtd':
+if not on_rtd and html_theme == "rtd":
     import sphinx_rtd_theme
-    html_theme = 'sphinx_rtd_theme'
+
+    html_theme = "sphinx_rtd_theme"
     html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
 
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]
 
 html_theme_options = {
-    'analytics_id': 'UA-75982049-2',
-    'logo_only': True,
+    "analytics_id": "UA-75982049-2",
+    "logo_only": True,
 }
 
 html_logo = "_static/img/tvm-logo-small.png"
@@ -159,25 +161,23 @@ html_favicon = "_static/img/tvm-logo-square.png"
 
 
 # Output file base name for HTML help builder.
-htmlhelp_basename = project + 'doc'
+htmlhelp_basename = project + "doc"
 
 # -- Options for LaTeX output ---------------------------------------------
-latex_elements = {
-}
+latex_elements = {}
 
 # Grouping the document tree into LaTeX files. List of tuples
 # (source start file, target name, title,
 #  author, documentclass [howto, manual, or own class]).
 latex_documents = [
-  (master_doc, '%s.tex' % project, project,
-   author, 'manual'),
+    (master_doc, "%s.tex" % project, project, author, "manual"),
 ]
 
 intersphinx_mapping = {
-    'python': ('https://docs.python.org/{.major}'.format(sys.version_info), None),
-    'numpy': ('https://numpy.org/doc/stable', None),
-    'scipy': ('https://docs.scipy.org/doc/scipy/reference', None),
-    'matplotlib': ('https://matplotlib.org/', None),
+    "python": ("https://docs.python.org/{.major}".format(sys.version_info), None),
+    "numpy": ("https://numpy.org/doc/stable", None),
+    "scipy": ("https://docs.scipy.org/doc/scipy/reference", None),
+    "matplotlib": ("https://matplotlib.org/", None),
 }
 
 from sphinx_gallery.sorting import ExplicitOrder
@@ -186,39 +186,42 @@ examples_dirs = ["../tutorials/", "../vta/tutorials/"]
 gallery_dirs = ["tutorials", "vta/tutorials"]
 
 subsection_order = ExplicitOrder(
-    ['../tutorials/get_started',
-     '../tutorials/frontend',
-     '../tutorials/language',
-     '../tutorials/optimize',
-     '../tutorials/autotvm',
-     '../tutorials/dev',
-     '../tutorials/topi',
-     '../tutorials/deployment',
-     '../tutorials/micro',
-     '../vta/tutorials/frontend',
-     '../vta/tutorials/optimize',
-     '../vta/tutorials/autotvm'])
+    [
+        "../tutorials/get_started",
+        "../tutorials/frontend",
+        "../tutorials/language",
+        "../tutorials/optimize",
+        "../tutorials/autotvm",
+        "../tutorials/dev",
+        "../tutorials/topi",
+        "../tutorials/deployment",
+        "../tutorials/micro",
+        "../vta/tutorials/frontend",
+        "../vta/tutorials/optimize",
+        "../vta/tutorials/autotvm",
+    ]
+)
 
 sphinx_gallery_conf = {
-    'backreferences_dir': 'gen_modules/backreferences',
-    'doc_module': ('tvm', 'numpy'),
-    'reference_url': {
-        'tvm': None,
-        'matplotlib': 'https://matplotlib.org/',
-        'numpy': 'https://numpy.org/doc/stable'
+    "backreferences_dir": "gen_modules/backreferences",
+    "doc_module": ("tvm", "numpy"),
+    "reference_url": {
+        "tvm": None,
+        "matplotlib": "https://matplotlib.org/",
+        "numpy": "https://numpy.org/doc/stable",
     },
-    'examples_dirs': examples_dirs,
-    'gallery_dirs': gallery_dirs,
-    'subsection_order': subsection_order,
-    'filename_pattern': os.environ.get("TVM_TUTORIAL_EXEC_PATTERN", ".py"),
-    'find_mayavi_figures': False,
-    'download_all_examples': False,
+    "examples_dirs": examples_dirs,
+    "gallery_dirs": gallery_dirs,
+    "subsection_order": subsection_order,
+    "filename_pattern": os.environ.get("TVM_TUTORIAL_EXEC_PATTERN", ".py"),
+    "find_mayavi_figures": False,
+    "download_all_examples": False,
     "min_reported_time": 60,
-    'expected_failing_examples': []
+    "expected_failing_examples": [],
 }
 
 autodoc_default_options = {
-    'member-order': 'bysource',
+    "member-order": "bysource",
 }
 
 # Maps the original namespace to list of potential modules
@@ -229,6 +232,7 @@ tvm_alias_check_map = {
     "tvm.relay": ["tvm.ir", "tvm.tir"],
 }
 
+
 def update_alias_docstring(name, obj, lines):
     """Update the docstring of alias functions.
 
@@ -265,8 +269,7 @@ def update_alias_docstring(name, obj, lines):
 
         if hasattr(sys.modules[amod], target_name):
             obj_type = ":py:func" if callable(obj) else ":py:class"
-            lines.append(
-                ".. rubric:: Alias of %s:`%s.%s`" % (obj_type, amod, target_name))
+            lines.append(".. rubric:: Alias of %s:`%s.%s`" % (obj_type, amod, target_name))
 
 
 def process_docstring(app, what, name, obj, options, lines):
@@ -276,5 +279,5 @@ def process_docstring(app, what, name, obj, options, lines):
 
 
 def setup(app):
-    app.connect('autodoc-process-docstring', process_docstring)
-    app.add_css_file('css/tvm_theme.css')
+    app.connect("autodoc-process-docstring", process_docstring)
+    app.add_css_file("css/tvm_theme.css")
index d523b9c..a0553cf 100644 (file)
@@ -26,15 +26,15 @@ import numpy as np
 
 # Global declarations of environment.
 
-tgt_host="llvm"
-tgt="llvm"
+tgt_host = "llvm"
+tgt = "llvm"
 
 ######################################################################
 # Describe the Computation
 # ------------------------
 n = te.var("n")
-A = te.placeholder((n,), name='A')
-B = te.placeholder((n,), name='B')
+A = te.placeholder((n,), name="A")
+B = te.placeholder((n,), name="B")
 C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
 
 ######################################################################
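
A standalone sketch, not part of this patch, of the usual continuation of the vector-add declaration above: create a schedule, build for the llvm target, and run once to check the result.

# Hedged sketch: schedule, build, and run the vector add declared above.
import numpy as np
import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.placeholder((n,), name="B")
C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")

s = te.create_schedule(C.op)
fadd = tvm.build(s, [A, B, C], "llvm", target_host="llvm", name="myadd")

ctx = tvm.cpu(0)
size = 1024
a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=size).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(size, dtype=C.dtype), ctx)
fadd(a, b, c)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
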
index d4dcf21..b82e0c4 100644 (file)
@@ -25,13 +25,14 @@ from tvm.contrib.download import download_testdata
 # ----------------------------------------------
 def extract(path):
     import tarfile
+
     if path.endswith("tgz") or path.endswith("gz"):
         dir_path = os.path.dirname(path)
         tar = tarfile.open(path)
         tar.extractall(path=dir_path)
         tar.close()
     else:
-        raise RuntimeError('Could not decompress the file: ' + path)
+        raise RuntimeError("Could not decompress the file: " + path)
 
 
 ###################################
@@ -39,7 +40,7 @@ def extract(path):
 # ---------------------------------
 
 model_url = "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz"
-model_path = download_testdata(model_url, "mobilenet_v2_1.4_224.tgz", module=['tf', 'official'])
+model_path = download_testdata(model_url, "mobilenet_v2_1.4_224.tgz", module=["tf", "official"])
 model_dir = os.path.dirname(model_path)
 extract(model_path)
 
@@ -50,9 +51,11 @@ model_file = os.path.join(model_dir, "mobilenet_v2_1.4_224.tflite")
 tflite_model_buf = open(model_file, "rb").read()
 try:
     import tflite
+
     tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
 except AttributeError:
     import tflite.Model
+
     tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
 
 
@@ -66,30 +69,29 @@ input_shape = (1, 224, 224, 3)
 input_dtype = "float32"
 
 # parse TFLite model and convert into Relay computation graph
-mod, params = relay.frontend.from_tflite(tflite_model,
-                                         shape_dict={input_tensor: input_shape},
-                                         dtype_dict={input_tensor: input_dtype})
+mod, params = relay.frontend.from_tflite(
+    tflite_model, shape_dict={input_tensor: input_shape}, dtype_dict={input_tensor: input_dtype}
+)
 
 #############
 # Compilation
 # -----------
 
-target = 'llvm'
+target = "llvm"
 
 # Build with Relay
 with transform.PassContext(opt_level=3):
-    graph, lib, params = relay.build_module.build(
-        mod, target, params=params)
+    graph, lib, params = relay.build_module.build(mod, target, params=params)
 
 ###############################################
 # Save the graph, lib and parameters into files
 # ---------------------------------------------
 
 lib.export_library("./mobilenet.so")
-print('lib export succeefully')
+print("lib export succeefully")
 
 with open("./mobilenet.json", "w") as fo:
-   fo.write(graph)
+    fo.write(graph)
 
 with open("./mobilenet.params", "wb") as fo:
-   fo.write(relay.save_param_dict(params))
+    fo.write(relay.save_param_dict(params))
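
For context, a minimal sketch of consuming the three artifacts written above (mobilenet.so, mobilenet.json, mobilenet.params) with the graph runtime; the input tensor name ("input") and shape are assumptions taken from the surrounding script, and this snippet is not part of the diff:

import numpy as np
import tvm
from tvm.contrib import graph_runtime

# Assumed sketch: load the exported module, graph JSON and serialized params,
# then run a single inference on CPU.
lib = tvm.runtime.load_module("./mobilenet.so")
graph = open("./mobilenet.json").read()
params = bytearray(open("./mobilenet.params", "rb").read())

module = graph_runtime.create(graph, lib, tvm.cpu(0))
module.load_params(params)
module.set_input("input", np.zeros((1, 224, 224, 3), dtype="float32"))
module.run()
top1 = module.get_output(0).asnumpy().argmax()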
index bda66f8..40edd08 100644 (file)
@@ -20,20 +20,24 @@ import tvm
 from tvm import te
 from tvm.contrib import cc, util
 
+
 def test_add(target_dir):
     n = te.var("n")
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
     C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
     s = te.create_schedule(C.op)
     fadd = tvm.build(s, [A, B, C], "llvm", target_host="llvm", name="myadd")
 
     fadd.save(os.path.join(target_dir, "add_cpu.o"))
-    cc.create_shared(os.path.join(target_dir, "add_cpu.so"),
-            [os.path.join(target_dir, "add_cpu.o")])
+    cc.create_shared(
+        os.path.join(target_dir, "add_cpu.so"), [os.path.join(target_dir, "add_cpu.o")]
+    )
+
 
 if __name__ == "__main__":
     import sys
+
     if len(sys.argv) != 2:
         sys.exit(-1)
     test_add(sys.argv[1])
index d520054..7983930 100644 (file)
@@ -20,13 +20,14 @@ import tvm
 from tvm import te
 from tvm.contrib import cc, util
 
+
 def test_add(target_dir):
     if not tvm.runtime.enabled("cuda"):
         print("skip %s because cuda is not enabled..." % __file__)
         return
     n = te.var("n")
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
     C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
 
     s = te.create_schedule(C.op)
@@ -38,11 +39,14 @@ def test_add(target_dir):
 
     fadd_cuda.save(os.path.join(target_dir, "add_gpu.o"))
     fadd_cuda.imported_modules[0].save(os.path.join(target_dir, "add_gpu.ptx"))
-    cc.create_shared(os.path.join(target_dir, "add_gpu.so"),
-            [os.path.join(target_dir, "add_gpu.o")])
+    cc.create_shared(
+        os.path.join(target_dir, "add_gpu.so"), [os.path.join(target_dir, "add_gpu.o")]
+    )
+
 
 if __name__ == "__main__":
     import sys
+
     if len(sys.argv) != 2:
         sys.exit(-1)
     test_add(sys.argv[1])
index 63a76d1..07a19fe 100644 (file)
@@ -21,34 +21,37 @@ from tvm import te
 import json
 from tvm.contrib import graph_runtime
 
+
 def dump_graph_lib(target_dir):
     dim = 4
-    A = te.placeholder((dim,), name='A')
-    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+    A = te.placeholder((dim,), name="A")
+    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
     sched = te.create_schedule(B.op)
 
     node0 = {"op": "null", "name": "x", "inputs": []}
-    node1 = {"op": "tvm_op", "name": "add",
-             "inputs": [[0, 0, 0]],
-             "attrs": {"func_name": "myadd",
-                       "flatten_data": "1",
-                       "num_inputs" : "1",
-                    "num_outputs" : "1"}}
+    node1 = {
+        "op": "tvm_op",
+        "name": "add",
+        "inputs": [[0, 0, 0]],
+        "attrs": {"func_name": "myadd", "flatten_data": "1", "num_inputs": "1", "num_outputs": "1"},
+    }
     nodes = [node0, node1]
     arg_nodes = [0]
     node_row_ptr = [0, 1, 2]
     outputs = [[1, 0, 0]]
     shape = (4,)
     attrs = {
-        "shape" : ["list_shape", [shape, shape]],
-        "dltype" : ["list_str", ["float32", "float32"]],
-        "storage_id" : ["list_int", [0, 1]],
+        "shape": ["list_shape", [shape, shape]],
+        "dltype": ["list_str", ["float32", "float32"]],
+        "storage_id": ["list_int", [0, 1]],
+    }
+    graph = {
+        "nodes": nodes,
+        "arg_nodes": arg_nodes,
+        "node_row_ptr": node_row_ptr,
+        "heads": outputs,
+        "attrs": attrs,
     }
-    graph = {"nodes": nodes,
-             "arg_nodes": arg_nodes,
-             "node_row_ptr": node_row_ptr,
-             "heads": outputs,
-             "attrs": attrs}
 
     graph = json.dumps(graph)
     mlib = tvm.build(sched, [A, B], "llvm", name="myadd")
@@ -57,8 +60,10 @@ def dump_graph_lib(target_dir):
     with open(os.path.join(target_dir, "graph_addone.json"), "w") as fo:
         fo.write(graph)
 
+
 if __name__ == "__main__":
     import sys
+
     if len(sys.argv) != 2:
         sys.exit(-1)
     dump_graph_lib(sys.argv[1])
index 68dd19e..d3e23e1 100644 (file)
 import time
 from tvm.rpc import proxy
 
+
 def start_proxy_server(port, timeout):
-    prox = proxy.Proxy("localhost", port=port, port_end=port+1)
+    prox = proxy.Proxy("localhost", port=port, port_end=port + 1)
     if timeout > 0:
         import time
+
         time.sleep(timeout)
         prox.terminate()
     else:
         prox.proc.join()
 
+
 if __name__ == "__main__":
     import sys
+
     if len(sys.argv) < 2:
         sys.exit(-1)
     port = int(sys.argv[1])
index bb2fdf4..96b6006 100644 (file)
@@ -19,80 +19,96 @@ import sys
 import os.path, re, StringIO
 
 blacklist = [
-    'Windows.h',
-    'mach/clock.h', 'mach/mach.h',
-    'malloc.h',
-    'glog/logging.h', 'io/azure_filesys.h', 'io/hdfs_filesys.h', 'io/s3_filesys.h',
-    'sys/stat.h', 'sys/types.h',
-    'omp.h', 'execinfo.h', 'packet/sse-inl.h'
-    ]
+    "Windows.h",
+    "mach/clock.h",
+    "mach/mach.h",
+    "malloc.h",
+    "glog/logging.h",
+    "io/azure_filesys.h",
+    "io/hdfs_filesys.h",
+    "io/s3_filesys.h",
+    "sys/stat.h",
+    "sys/types.h",
+    "omp.h",
+    "execinfo.h",
+    "packet/sse-inl.h",
+]
 
 
 def get_sources(def_file):
     sources = []
     files = []
     visited = set()
-    mxnet_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))
+    mxnet_path = os.path.abspath(
+        os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
+    )
     for line in open(def_file):
-        files = files + line.strip().split(' ')
+        files = files + line.strip().split(" ")
 
     for f in files:
         f = f.strip()
-        if not f or f.endswith('.o:') or f == '\\': continue
+        if not f or f.endswith(".o:") or f == "\\":
+            continue
         fn = os.path.relpath(f)
         if os.path.abspath(f).startswith(mxnet_path) and fn not in visited:
             sources.append(fn)
             visited.add(fn)
     return sources
 
+
 sources = get_sources(sys.argv[1])
 
+
 def find_source(name, start):
     candidates = []
     for x in sources:
-        if x == name or x.endswith('/' + name): candidates.append(x)
-    if not candidates: return ''
-    if len(candidates) == 1: return candidates[0]
+        if x == name or x.endswith("/" + name):
+            candidates.append(x)
+    if not candidates:
+        return ""
+    if len(candidates) == 1:
+        return candidates[0]
     for x in candidates:
-        if x.split('/')[1] == start.split('/')[1]: return x
-    return ''
+        if x.split("/")[1] == start.split("/")[1]:
+            return x
+    return ""
 
 
-re1 = re.compile('<([./a-zA-Z0-9_-]*)>')
+re1 = re.compile("<([./a-zA-Z0-9_-]*)>")
 re2 = re.compile('"([./a-zA-Z0-9_-]*)"')
 
 sysheaders = []
 history = set([])
 out = StringIO.StringIO()
 
+
 def expand(x, pending):
-    if x in history and x not in ['mshadow/mshadow/expr_scalar-inl.h']: # MULTIPLE includes
+    if x in history and x not in ["mshadow/mshadow/expr_scalar-inl.h"]:  # MULTIPLE includes
         return
 
     if x in pending:
-        #print('loop found: %s in ' % x, pending)
+        # print('loop found: %s in ' % x, pending)
         return
 
     print("//===== EXPANDING: %s =====\n" % x, file=out)
     for line in open(x):
-        if line.find('#include') < 0:
+        if line.find("#include") < 0:
             out.write(line)
             continue
-        if line.strip().find('#include') > 0:
+        if line.strip().find("#include") > 0:
             print(line)
             continue
         m = re1.search(line)
-        if not m: m = re2.search(line)
         if not m:
-            print(line + ' not found')
+            m = re2.search(line)
+        if not m:
+            print(line + " not found")
             continue
-        h = m.groups()[0].strip('./')
+        h = m.groups()[0].strip("./")
         source = find_source(h, x)
         if not source:
-            if (h not in blacklist and
-                h not in sysheaders and
-                'mkl' not in h and
-                'nnpack' not in h): sysheaders.append(h)
+            if h not in blacklist and h not in sysheaders and "mkl" not in h and "nnpack" not in h:
+                sysheaders.append(h)
         else:
             expand(source, pending + [x])
     print("//===== EXPANDED: %s =====\n" % x, file=out)
@@ -101,17 +117,15 @@ def expand(x, pending):
 
 expand(sys.argv[2], [])
 
-f = open(sys.argv[3], 'wb')
-
+f = open(sys.argv[3], "wb")
 
 
 for k in sorted(sysheaders):
     print("#include <%s>" % k, file=f)
 
-print('', file=f)
+print("", file=f)
 print(out.getvalue(), file=f)
 
 for x in sources:
-    if x not in history and not x.endswith('.o'):
-        print('Not processed:', x)
-
+    if x not in history and not x.endswith(".o"):
+        print("Not processed:", x)
index bbd026c..cfe8ee4 100644 (file)
@@ -22,13 +22,12 @@ FOLDERS = ["core", "pass", "c_api"]
 fo = open(sys.argv[1], "w")
 
 
-
 for folder in FOLDERS:
     path = str(os.path.join("../src", folder))
     flst = os.listdir(path)
     for f in flst:
-       if f.endswith(".cc") == True:
-               fo.write('#include "' + str(os.path.join("src", folder, f)) + '"\n')
+        if f.endswith(".cc") == True:
+            fo.write('#include "' + str(os.path.join("src", folder, f)) + '"\n')
 
 
 fo.close()
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644 (file)
index 0000000..6c8cfdc
--- /dev/null
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[tool.black]
+line-length = 100
+target-version = ['py36']
+include = '(\.pyi?$)'
+exclude = '''
+
+(
+  /(
+      \.github
+    | \.tvm
+    | \.tvm_test_data
+    | \.vscode
+    | \.venv
+    | 3rdparty\/
+    | build\/
+    | cmake\/
+    | conda\/
+    | docker\/
+    | docs\/
+    | golang\/
+    | include\/
+    | jvm\/
+    | licenses\/
+    | nnvm\/
+    | rust\/
+    | src\/
+    | vta\/
+    | web\/
+  )/|tests/lint/add_asf_header.py|tests/lint/check_file_type.py|python/tvm/topi/testing/pool3d_python.py|python/topi/python/test_topi_pooling.py
+)
+'''
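
The new [tool.black] section above is what local formatting runs are expected to pick up. A minimal sketch of checking formatting against it from the repository root, assuming Black is installed; the path list is illustrative and not part of the diff:

import subprocess
import sys

# Assumed sketch: run Black in check-only mode; it reads line-length,
# target-version, include and exclude from pyproject.toml automatically.
result = subprocess.run(
    [sys.executable, "-m", "black", "--check", "python", "apps", "tests"],
    check=False,
)
sys.exit(result.returncode)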
index bef1f37..402d993 100644 (file)
@@ -41,12 +41,12 @@ def get_lib_path():
     """Get library path, name and version"""
     # We cannot import `libinfo.py` in setup.py directly since __init__.py
     # will be invoked, which introduces dependencies
-    libinfo_py = os.path.join(CURRENT_DIR, './tvm/_ffi/libinfo.py')
-    libinfo = {'__file__': libinfo_py}
-    exec(compile(open(libinfo_py, "rb").read(), libinfo_py, 'exec'), libinfo, libinfo)
-    version = libinfo['__version__']
-    if not os.getenv('CONDA_BUILD'):
-        lib_path = libinfo['find_lib_path']()
+    libinfo_py = os.path.join(CURRENT_DIR, "./tvm/_ffi/libinfo.py")
+    libinfo = {"__file__": libinfo_py}
+    exec(compile(open(libinfo_py, "rb").read(), libinfo_py, "exec"), libinfo, libinfo)
+    version = libinfo["__version__"]
+    if not os.getenv("CONDA_BUILD"):
+        lib_path = libinfo["find_lib_path"]()
         libs = [lib_path[0]]
         if libs[0].find("runtime") == -1:
             for name in lib_path[1:]:
@@ -63,7 +63,7 @@ LIB_LIST, __version__ = get_lib_path()
 
 def config_cython():
     """Try to configure cython and return cython configuration"""
-    if os.name == 'nt':
+    if os.name == "nt":
         print("WARNING: Cython is not supported on Windows, will compile without cython module")
         return []
     sys_cflags = sysconfig.get_config_var("CFLAGS")
@@ -73,6 +73,7 @@ def config_cython():
         return []
     try:
         from Cython.Build import cythonize
+
         # from setuptools.extension import Extension
         if sys.version_info >= (3, 0):
             subdir = "_cy3"
@@ -80,26 +81,30 @@ def config_cython():
             subdir = "_cy2"
         ret = []
         path = "tvm/_ffi/_cython"
-        if os.name == 'nt':
-            library_dirs = ['tvm', '../build/Release', '../build']
-            libraries = ['libtvm']
+        if os.name == "nt":
+            library_dirs = ["tvm", "../build/Release", "../build"]
+            libraries = ["libtvm"]
         else:
             library_dirs = None
             libraries = None
         for fn in os.listdir(path):
             if not fn.endswith(".pyx"):
                 continue
-            ret.append(Extension(
-                "tvm._ffi.%s.%s" % (subdir, fn[:-4]),
-                ["tvm/_ffi/_cython/%s" % fn],
-                include_dirs=["../include/",
-                              "../3rdparty/dmlc-core/include",
-                              "../3rdparty/dlpack/include",
-                ],
-                extra_compile_args=["-std=c++14"],
-                library_dirs=library_dirs,
-                libraries=libraries,
-                language="c++"))
+            ret.append(
+                Extension(
+                    "tvm._ffi.%s.%s" % (subdir, fn[:-4]),
+                    ["tvm/_ffi/_cython/%s" % fn],
+                    include_dirs=[
+                        "../include/",
+                        "../3rdparty/dmlc-core/include",
+                        "../3rdparty/dlpack/include",
+                    ],
+                    extra_compile_args=["-std=c++14"],
+                    library_dirs=library_dirs,
+                    libraries=libraries,
+                    language="c++",
+                )
+            )
         return cythonize(ret, compiler_directives={"language_level": 3})
     except ImportError:
         print("WARNING: Cython is not installed, will compile without cython module")
@@ -116,7 +121,7 @@ class BinaryDistribution(Distribution):
 
 include_libs = False
 wheel_include_libs = False
-if not os.getenv('CONDA_BUILD'):
+if not os.getenv("CONDA_BUILD"):
     if "bdist_wheel" in sys.argv:
         wheel_include_libs = True
     else:
@@ -128,56 +133,49 @@ setup_kwargs = {}
 if wheel_include_libs:
     with open("MANIFEST.in", "w") as fo:
         for path in LIB_LIST:
-            shutil.copy(path, os.path.join(CURRENT_DIR, 'tvm'))
+            shutil.copy(path, os.path.join(CURRENT_DIR, "tvm"))
             _, libname = os.path.split(path)
             fo.write("include tvm/%s\n" % libname)
-    setup_kwargs = {
-        "include_package_data": True
-    }
+    setup_kwargs = {"include_package_data": True}
 
 if include_libs:
     curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
     for i, path in enumerate(LIB_LIST):
         LIB_LIST[i] = os.path.relpath(path, curr_path)
-    setup_kwargs = {
-        "include_package_data": True,
-        "data_files": [('tvm', LIB_LIST)]
-    }
+    setup_kwargs = {"include_package_data": True, "data_files": [("tvm", LIB_LIST)]}
 
 
 def get_package_data_files():
     # Relay standard libraries
-    return ['relay/std/prelude.rly', 'relay/std/core.rly']
-
-
-setup(name='tvm',
-      version=__version__,
-      description="TVM: An End to End Tensor IR/DSL Stack for Deep Learning Systems",
-      zip_safe=False,
-      entry_points={"console_scripts": ["tvmc = tvm.driver.tvmc.main:main"]},
-      install_requires=[
-        'numpy',
-        'scipy',
-        'decorator',
-        'attrs',
-        'psutil',
-        'typed_ast',
-        ],
-      extras_require={'test': ['pillow<7',
-                               'matplotlib'],
-                      'extra_feature': ['tornado',
-                                        'psutil',
-                                        'xgboost>=1.1.0',
-                                        'mypy',
-                                        'orderedset']},
-
-      packages=find_packages(),
-      package_dir={'tvm': 'tvm'},
-      package_data={'tvm': get_package_data_files()},
-      distclass=BinaryDistribution,
-      url='https://github.com/apache/incubator-tvm',
-      ext_modules=config_cython(),
-      **setup_kwargs)
+    return ["relay/std/prelude.rly", "relay/std/core.rly"]
+
+
+setup(
+    name="tvm",
+    version=__version__,
+    description="TVM: An End to End Tensor IR/DSL Stack for Deep Learning Systems",
+    zip_safe=False,
+    entry_points={"console_scripts": ["tvmc = tvm.driver.tvmc.main:main"]},
+    install_requires=[
+        "numpy",
+        "scipy",
+        "decorator",
+        "attrs",
+        "psutil",
+        "typed_ast",
+    ],
+    extras_require={
+        "test": ["pillow<7", "matplotlib"],
+        "extra_feature": ["tornado", "psutil", "xgboost>=1.1.0", "mypy", "orderedset"],
+    },
+    packages=find_packages(),
+    package_dir={"tvm": "tvm"},
+    package_data={"tvm": get_package_data_files()},
+    distclass=BinaryDistribution,
+    url="https://github.com/apache/incubator-tvm",
+    ext_modules=config_cython(),
+    **setup_kwargs,
+)
 
 
 if wheel_include_libs:
index e10f387..d3473c6 100644 (file)
@@ -76,7 +76,7 @@ def tvm_wrap_excepthook(exception_hook):
     def wrapper(exctype, value, trbk):
         """Clean subprocesses when TVM is interrupted."""
         exception_hook(exctype, value, trbk)
-        if hasattr(multiprocessing, 'active_children'):
+        if hasattr(multiprocessing, "active_children"):
             # pylint: disable=not-callable
             for p in multiprocessing.active_children():
                 p.terminate()
index 949cc8b..1aa6964 100644 (file)
@@ -23,12 +23,12 @@ from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_ha
 
 
 TVMPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
-_c_str_dltensor = c_str('dltensor')
-_c_str_used_dltensor = c_str('used_dltensor')
+_c_str_dltensor = c_str("dltensor")
+_c_str_used_dltensor = c_str("used_dltensor")
 
 
 # used for PyCapsule manipulation
-if hasattr(ctypes, 'pythonapi'):
+if hasattr(ctypes, "pythonapi"):
     ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p
     ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
     ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
@@ -57,11 +57,13 @@ def _dlpack_deleter(pycapsule):
         _LIB.TVMDLManagedTensorCallDeleter(ptr)
         ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0))
 
+
 _c_dlpack_deleter = TVMPyCapsuleDestructor(_dlpack_deleter)
 
 
 class NDArrayBase(object):
     """A simple Device/CPU Array object in runtime."""
+
     __slots__ = ["handle", "is_view"]
     # pylint: disable=no-member
     def __init__(self, handle, is_view=False):
@@ -120,8 +122,10 @@ def _make_array(handle, is_view, is_container):
     ret.is_view = is_view
     return ret
 
+
 _TVM_COMPATS = ()
 
+
 def _reg_extension(cls, fcreate):
     global _TVM_COMPATS
     _TVM_COMPATS += (cls,)
@@ -130,14 +134,18 @@ def _reg_extension(cls, fcreate):
         RETURN_SWITCH[cls._tvm_tcode] = fret
         C_TO_PY_ARG_SWITCH[cls._tvm_tcode] = _wrap_arg_func(fret, cls._tvm_tcode)
 
+
 _TVM_ND_CLS = {}
 
+
 def _register_ndarray(index, cls):
     global _TVM_ND_CLS
     _TVM_ND_CLS[index] = cls
 
+
 _CLASS_NDARRAY = None
 
+
 def _set_class_ndarray(cls):
     global _CLASS_NDARRAY
     _CLASS_NDARRAY = cls
index 359b018..d30026a 100644 (file)
@@ -30,6 +30,7 @@ OBJECT_TYPE = {}
 
 _CLASS_OBJECT = None
 
+
 def _set_class_object(object_class):
     global _CLASS_OBJECT
     _CLASS_OBJECT = object_class
@@ -60,16 +61,20 @@ def _return_object(x):
     obj.handle = handle
     return obj
 
+
 RETURN_SWITCH[ArgTypeCode.OBJECT_HANDLE] = _return_object
 C_TO_PY_ARG_SWITCH[ArgTypeCode.OBJECT_HANDLE] = _wrap_arg_func(
-    _return_object, ArgTypeCode.OBJECT_HANDLE)
+    _return_object, ArgTypeCode.OBJECT_HANDLE
+)
 
 C_TO_PY_ARG_SWITCH[ArgTypeCode.OBJECT_RVALUE_REF_ARG] = _wrap_arg_func(
-    _return_object, ArgTypeCode.OBJECT_RVALUE_REF_ARG)
+    _return_object, ArgTypeCode.OBJECT_RVALUE_REF_ARG
+)
 
 
 class PyNativeObject:
     """Base class of all TVM objects that also subclass python's builtin types."""
+
     __slots__ = []
 
     def __init_tvm_object_by_constructor__(self, fconstructor, *args):
@@ -94,9 +99,9 @@ class PyNativeObject:
         self.__tvm_object__ = obj
 
 
-
 class ObjectBase(object):
     """Base object for all object types"""
+
     __slots__ = ["handle"]
 
     def __del__(self):
index 8a2f49a..acf9776 100644 (file)
@@ -37,11 +37,13 @@ ModuleHandle = ctypes.c_void_p
 ObjectHandle = ctypes.c_void_p
 TVMRetValueHandle = ctypes.c_void_p
 
+
 def _ctypes_free_resource(rhandle):
     """callback to free resources when it it not needed."""
     pyobj = ctypes.cast(rhandle, ctypes.py_object)
     ctypes.pythonapi.Py_DecRef(pyobj)
 
+
 # Global callback that is always alive
 TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource)
 ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ))
@@ -69,6 +71,7 @@ def convert_to_tvm_func(pyfunc):
         The converted tvm function.
     """
     local_pyfunc = pyfunc
+
     def cfun(args, type_codes, num_args, ret, _):
         """ ctypes function """
         num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args
@@ -101,8 +104,7 @@ def convert_to_tvm_func(pyfunc):
     # TVM_FREE_PYOBJ will be called after it is no longer needed.
     pyobj = ctypes.py_object(f)
     ctypes.pythonapi.Py_IncRef(pyobj)
-    if _LIB.TVMFuncCreateFromCFunc(
-            f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)) != 0:
+    if _LIB.TVMFuncCreateFromCFunc(f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)) != 0:
         raise get_last_ffi_error()
     return _make_packed_func(handle, False)
 
@@ -121,8 +123,9 @@ def _make_tvm_args(args, temp_args):
             type_codes[i] = ArgTypeCode.NULL
         elif isinstance(arg, NDArrayBase):
             values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
-            type_codes[i] = (ArgTypeCode.NDARRAY_HANDLE
-                             if not arg.is_view else ArgTypeCode.DLTENSOR_HANDLE)
+            type_codes[i] = (
+                ArgTypeCode.NDARRAY_HANDLE if not arg.is_view else ArgTypeCode.DLTENSOR_HANDLE
+            )
         elif isinstance(arg, PyNativeObject):
             values[i].v_handle = arg.__tvm_object__.handle
             type_codes[i] = ArgTypeCode.OBJECT_HANDLE
@@ -150,8 +153,8 @@ def _make_tvm_args(args, temp_args):
 
             arr = TVMByteArray()
             arr.data = ctypes.cast(
-                (ctypes.c_byte * len(arg)).from_buffer(arg),
-                ctypes.POINTER(ctypes.c_byte))
+                (ctypes.c_byte * len(arg)).from_buffer(arg), ctypes.POINTER(ctypes.c_byte)
+            )
             arr.size = len(arg)
             values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr))
             temp_args.append(arr)
@@ -188,6 +191,7 @@ def _make_tvm_args(args, temp_args):
 
 class PackedFuncBase(object):
     """Function base."""
+
     __slots__ = ["handle", "is_global"]
     # pylint: disable=no-member
     def __init__(self, handle, is_global):
@@ -219,9 +223,17 @@ class PackedFuncBase(object):
         values, tcodes, num_args = _make_tvm_args(args, temp_args)
         ret_val = TVMValue()
         ret_tcode = ctypes.c_int()
-        if _LIB.TVMFuncCall(
-                self.handle, values, tcodes, ctypes.c_int(num_args),
-                ctypes.byref(ret_val), ctypes.byref(ret_tcode)) != 0:
+        if (
+            _LIB.TVMFuncCall(
+                self.handle,
+                values,
+                tcodes,
+                ctypes.c_int(num_args),
+                ctypes.byref(ret_val),
+                ctypes.byref(ret_tcode),
+            )
+            != 0
+        ):
             raise get_last_ffi_error()
         _ = temp_args
         _ = args
@@ -234,9 +246,17 @@ def __init_handle_by_constructor__(fconstructor, args):
     values, tcodes, num_args = _make_tvm_args(args, temp_args)
     ret_val = TVMValue()
     ret_tcode = ctypes.c_int()
-    if _LIB.TVMFuncCall(
-            fconstructor.handle, values, tcodes, ctypes.c_int(num_args),
-            ctypes.byref(ret_val), ctypes.byref(ret_tcode)) != 0:
+    if (
+        _LIB.TVMFuncCall(
+            fconstructor.handle,
+            values,
+            tcodes,
+            ctypes.c_int(num_args),
+            ctypes.byref(ret_val),
+            ctypes.byref(ret_tcode),
+        )
+        != 0
+    ):
         raise get_last_ffi_error()
     _ = temp_args
     _ = args
@@ -273,15 +293,18 @@ def _get_global_func(name, allow_missing=False):
 
     raise ValueError("Cannot find global function %s" % name)
 
+
 # setup return handle for function type
 _object.__init_by_constructor__ = __init_handle_by_constructor__
 RETURN_SWITCH[ArgTypeCode.PACKED_FUNC_HANDLE] = _handle_return_func
 RETURN_SWITCH[ArgTypeCode.MODULE_HANDLE] = _return_module
 RETURN_SWITCH[ArgTypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True)
 C_TO_PY_ARG_SWITCH[ArgTypeCode.PACKED_FUNC_HANDLE] = _wrap_arg_func(
-    _handle_return_func, ArgTypeCode.PACKED_FUNC_HANDLE)
+    _handle_return_func, ArgTypeCode.PACKED_FUNC_HANDLE
+)
 C_TO_PY_ARG_SWITCH[ArgTypeCode.MODULE_HANDLE] = _wrap_arg_func(
-    _return_module, ArgTypeCode.MODULE_HANDLE)
+    _return_module, ArgTypeCode.MODULE_HANDLE
+)
 C_TO_PY_ARG_SWITCH[ArgTypeCode.DLTENSOR_HANDLE] = lambda x: _make_array(x.v_handle, True, False)
 C_TO_PY_ARG_SWITCH[ArgTypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True)
 
@@ -296,10 +319,12 @@ def _set_class_module(module_class):
     global _CLASS_MODULE
     _CLASS_MODULE = module_class
 
+
 def _set_class_packed_func(packed_func_class):
     global _CLASS_PACKED_FUNC
     _CLASS_PACKED_FUNC = packed_func_class
 
+
 def _set_class_object_generic(object_generic_class, func_convert_to_object):
     global _CLASS_OBJECT_GENERIC
     global _FUNC_CONVERT_TO_OBJECT
index d4e7b36..4b6d669 100644 (file)
@@ -21,12 +21,16 @@ import struct
 from ..base import py_str, check_call, _LIB
 from ..runtime_ctypes import TVMByteArray, ArgTypeCode, TVMContext
 
+
 class TVMValue(ctypes.Union):
     """TVMValue in C API"""
-    _fields_ = [("v_int64", ctypes.c_int64),
-                ("v_float64", ctypes.c_double),
-                ("v_handle", ctypes.c_void_p),
-                ("v_str", ctypes.c_char_p)]
+
+    _fields_ = [
+        ("v_int64", ctypes.c_int64),
+        ("v_float64", ctypes.c_double),
+        ("v_handle", ctypes.c_void_p),
+        ("v_str", ctypes.c_char_p),
+    ]
 
 
 TVMPackedCFunc = ctypes.CFUNCTYPE(
@@ -35,12 +39,11 @@ TVMPackedCFunc = ctypes.CFUNCTYPE(
     ctypes.POINTER(ctypes.c_int),
     ctypes.c_int,
     ctypes.c_void_p,
-    ctypes.c_void_p)
+    ctypes.c_void_p,
+)
 
 
-TVMCFuncFinalizer = ctypes.CFUNCTYPE(
-    None,
-    ctypes.c_void_p)
+TVMCFuncFinalizer = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
 
 
 def _return_handle(x):
@@ -50,6 +53,7 @@ def _return_handle(x):
         handle = ctypes.c_void_p(handle)
     return handle
 
+
 def _return_bytes(x):
     """return bytes"""
     handle = x.v_handle
@@ -60,9 +64,10 @@ def _return_bytes(x):
     res = bytearray(size)
     rptr = (ctypes.c_byte * size).from_buffer(res)
     if not ctypes.memmove(rptr, arr.data, size):
-        raise RuntimeError('memmove failed')
+        raise RuntimeError("memmove failed")
     return res
 
+
 def _return_context(value):
     """return TVMContext"""
     # use bit unpacking from int64 view
@@ -77,8 +82,10 @@ def _wrap_arg_func(return_f, type_code):
         tcode = ctypes.c_int(type_code)
         check_call(_LIB.TVMCbArgToReturn(ctypes.byref(x), ctypes.byref(tcode)))
         return return_f(x)
+
     return _wrap_func
 
+
 def _ctx_to_int64(ctx):
     """Pack context into int64 in native endian"""
     data = struct.pack("=ii", ctx.device_type, ctx.device_id)
@@ -92,7 +99,7 @@ RETURN_SWITCH = {
     ArgTypeCode.NULL: lambda x: None,
     ArgTypeCode.STR: lambda x: py_str(x.v_str),
     ArgTypeCode.BYTES: _return_bytes,
-    ArgTypeCode.TVM_CONTEXT: _return_context
+    ArgTypeCode.TVM_CONTEXT: _return_context,
 }
 
 C_TO_PY_ARG_SWITCH = {
@@ -102,5 +109,5 @@ C_TO_PY_ARG_SWITCH = {
     ArgTypeCode.NULL: lambda x: None,
     ArgTypeCode.STR: lambda x: py_str(x.v_str),
     ArgTypeCode.BYTES: _return_bytes,
-    ArgTypeCode.TVM_CONTEXT: _return_context
+    ArgTypeCode.TVM_CONTEXT: _return_context,
 }
index 9579acf..b661cfd 100644 (file)
@@ -18,9 +18,9 @@
 """
 import sys
 
-#----------------------------
+# ----------------------------
 # Python3 version.
-#----------------------------
+# ----------------------------
 if not (sys.version_info[0] >= 3 and sys.version_info[1] >= 6):
     PY3STATEMENT = "The minimal Python requirement is Python 3.6"
     raise Exception(PY3STATEMENT)
index 2cca014..df220ae 100644 (file)
@@ -23,9 +23,9 @@ import ctypes
 import numpy as np
 from . import libinfo
 
-#----------------------------
+# ----------------------------
 # library loading
-#----------------------------
+# ----------------------------
 string_types = (str,)
 integer_types = (int, np.int32)
 numeric_types = integer_types + (float, np.float32)
@@ -33,15 +33,17 @@ numeric_types = integer_types + (float, np.float32)
 # this function is needed for python3
 # to convert ctypes.char_p .value back to python str
 if sys.platform == "win32":
+
     def _py_str(x):
         try:
-            return x.decode('utf-8')
+            return x.decode("utf-8")
         except UnicodeDecodeError:
-            encoding = 'cp' + str(ctypes.cdll.kernel32.GetACP())
+            encoding = "cp" + str(ctypes.cdll.kernel32.GetACP())
         return x.decode(encoding)
+
     py_str = _py_str
 else:
-    py_str = lambda x: x.decode('utf-8')
+    py_str = lambda x: x.decode("utf-8")
 
 
 def _load_lib():
@@ -51,6 +53,7 @@ def _load_lib():
     lib.TVMGetLastError.restype = ctypes.c_char_p
     return lib, os.path.basename(lib_path[0])
 
+
 try:
     import readline  # pylint: disable=unused-import
 except ImportError:
@@ -67,9 +70,9 @@ _RUNTIME_ONLY = "runtime" in _LIB_NAME
 # The FFI mode of TVM
 _FFI_MODE = os.environ.get("TVM_FFI", "auto")
 
-#----------------------------
+# ----------------------------
 # helper function in ctypes.
-#----------------------------
+# ----------------------------
 def c_str(string):
     """Create ctypes char * from a python string
     Parameters
@@ -82,7 +85,7 @@ def c_str(string):
     str : c_char_p
         A char pointer that can be passed to C API
     """
-    return ctypes.c_char_p(string.encode('utf-8'))
+    return ctypes.c_char_p(string.encode("utf-8"))
 
 
 def c_array(ctype, values):
@@ -116,12 +119,13 @@ def decorate(func, fwrapped):
         The wrapped function
     """
     import decorator
+
     return decorator.decorate(func, fwrapped)
 
 
-#-----------------------------------------
+# -----------------------------------------
 # Base code for structured error handling.
-#-----------------------------------------
+# -----------------------------------------
 # Maps error type to its constructor
 ERROR_TYPE = {}
 
@@ -169,6 +173,7 @@ def register_error(func_name=None, cls=None):
         err_name = func_name if isinstance(func_name, str) else mycls.__name__
         ERROR_TYPE[err_name] = mycls
         return mycls
+
     if cls is None:
         return register
     return register(cls)
index a1483a1..b9fc8dc 100644 (file)
@@ -18,6 +18,7 @@
 import sys
 import os
 
+
 def split_env_var(env_var, split):
     """Splits environment variable string.
 
@@ -66,17 +67,17 @@ def find_lib_path(name=None, search_path=None, optional=False):
 
     dll_path = []
 
-    if os.environ.get('TVM_LIBRARY_PATH', None):
-        dll_path.append(os.environ['TVM_LIBRARY_PATH'])
+    if os.environ.get("TVM_LIBRARY_PATH", None):
+        dll_path.append(os.environ["TVM_LIBRARY_PATH"])
 
-    if sys.platform.startswith('linux'):
-        dll_path.extend(split_env_var('LD_LIBRARY_PATH', ':'))
-        dll_path.extend(split_env_var('PATH', ':'))
-    elif sys.platform.startswith('darwin'):
-        dll_path.extend(split_env_var('DYLD_LIBRARY_PATH', ':'))
-        dll_path.extend(split_env_var('PATH', ':'))
-    elif sys.platform.startswith('win32'):
-        dll_path.extend(split_env_var('PATH', ';'))
+    if sys.platform.startswith("linux"):
+        dll_path.extend(split_env_var("LD_LIBRARY_PATH", ":"))
+        dll_path.extend(split_env_var("PATH", ":"))
+    elif sys.platform.startswith("darwin"):
+        dll_path.extend(split_env_var("DYLD_LIBRARY_PATH", ":"))
+        dll_path.extend(split_env_var("PATH", ":"))
+    elif sys.platform.startswith("win32"):
+        dll_path.extend(split_env_var("PATH", ";"))
 
     # Pip lib directory
     dll_path.append(os.path.join(ffi_dir, ".."))
@@ -107,17 +108,19 @@ def find_lib_path(name=None, search_path=None, optional=False):
             lib_dll_path = [os.path.join(p, name) for p in dll_path]
         runtime_dll_path = []
     else:
-        if sys.platform.startswith('win32'):
-            lib_dll_path = [os.path.join(p, 'libtvm.dll') for p in dll_path] +\
-                           [os.path.join(p, 'tvm.dll') for p in dll_path]
-            runtime_dll_path = [os.path.join(p, 'libtvm_runtime.dll') for p in dll_path] +\
-                               [os.path.join(p, 'tvm_runtime.dll') for p in dll_path]
-        elif sys.platform.startswith('darwin'):
-            lib_dll_path = [os.path.join(p, 'libtvm.dylib') for p in dll_path]
-            runtime_dll_path = [os.path.join(p, 'libtvm_runtime.dylib') for p in dll_path]
+        if sys.platform.startswith("win32"):
+            lib_dll_path = [os.path.join(p, "libtvm.dll") for p in dll_path] + [
+                os.path.join(p, "tvm.dll") for p in dll_path
+            ]
+            runtime_dll_path = [os.path.join(p, "libtvm_runtime.dll") for p in dll_path] + [
+                os.path.join(p, "tvm_runtime.dll") for p in dll_path
+            ]
+        elif sys.platform.startswith("darwin"):
+            lib_dll_path = [os.path.join(p, "libtvm.dylib") for p in dll_path]
+            runtime_dll_path = [os.path.join(p, "libtvm_runtime.dylib") for p in dll_path]
         else:
-            lib_dll_path = [os.path.join(p, 'libtvm.so') for p in dll_path]
-            runtime_dll_path = [os.path.join(p, 'libtvm_runtime.so') for p in dll_path]
+            lib_dll_path = [os.path.join(p, "libtvm.so") for p in dll_path]
+            runtime_dll_path = [os.path.join(p, "libtvm_runtime.so") for p in dll_path]
 
     if not use_runtime:
         # try to find lib_dll_path
@@ -129,9 +132,11 @@ def find_lib_path(name=None, search_path=None, optional=False):
         lib_found = [p for p in runtime_dll_path if os.path.exists(p) and os.path.isfile(p)]
 
     if not lib_found:
-        message = ('Cannot find the files.\n' +
-                   'List of candidates:\n' +
-                   str('\n'.join(lib_dll_path + runtime_dll_path)))
+        message = (
+            "Cannot find the files.\n"
+            + "List of candidates:\n"
+            + str("\n".join(lib_dll_path + runtime_dll_path))
+        )
         if not optional:
             raise RuntimeError(message)
         return None
@@ -163,8 +168,8 @@ def find_include_path(name=None, search_path=None, optional=False):
 
     header_path = []
 
-    if os.environ.get('TVM_INCLUDE_PATH', None):
-        header_path.append(os.environ['TVM_INCLUDE_PATH'])
+    if os.environ.get("TVM_INCLUDE_PATH", None):
+        header_path.append(os.environ["TVM_INCLUDE_PATH"])
 
     header_path.append(install_include_dir)
     header_path.append(source_dir)
@@ -186,23 +191,21 @@ def find_include_path(name=None, search_path=None, optional=False):
         dlpack_include_path = []
         dmlc_include_path = []
     else:
-        tvm_include_path = [os.path.join(p, 'include') for p in header_path]
-        dlpack_include_path = [os.path.join(p, 'dlpack/include') for p in
-                               header_path]
-        dmlc_include_path = [os.path.join(p, 'dmlc-core/include') for p in
-                             header_path]
+        tvm_include_path = [os.path.join(p, "include") for p in header_path]
+        dlpack_include_path = [os.path.join(p, "dlpack/include") for p in header_path]
+        dmlc_include_path = [os.path.join(p, "dmlc-core/include") for p in header_path]
 
         # try to find include path
         include_found = [p for p in tvm_include_path if os.path.exists(p) and os.path.isdir(p)]
-        include_found += [p for p in dlpack_include_path if os.path.exists(p)
-                          and os.path.isdir(p)]
-        include_found += [p for p in dmlc_include_path if os.path.exists(p)
-                          and os.path.isdir(p)]
+        include_found += [p for p in dlpack_include_path if os.path.exists(p) and os.path.isdir(p)]
+        include_found += [p for p in dmlc_include_path if os.path.exists(p) and os.path.isdir(p)]
 
     if not include_found:
-        message = ('Cannot find the files.\n' +
-                   'List of candidates:\n' +
-                   str('\n'.join(tvm_include_path + dlpack_include_path)))
+        message = (
+            "Cannot find the files.\n"
+            + "List of candidates:\n"
+            + str("\n".join(tvm_include_path + dlpack_include_path))
+        )
         if not optional:
             raise RuntimeError(message)
         return None
index 0942ccb..b42dada 100644 (file)
@@ -64,12 +64,10 @@ def register_object(type_key=None):
         else:
             tidx = ctypes.c_uint()
             if not _RUNTIME_ONLY:
-                check_call(_LIB.TVMObjectTypeKey2Index(
-                    c_str(object_name), ctypes.byref(tidx)))
+                check_call(_LIB.TVMObjectTypeKey2Index(c_str(object_name), ctypes.byref(tidx)))
             else:
                 # directly skip unknown objects during runtime.
-                ret = _LIB.TVMObjectTypeKey2Index(
-                    c_str(object_name), ctypes.byref(tidx))
+                ret = _LIB.TVMObjectTypeKey2Index(c_str(object_name), ctypes.byref(tidx))
                 if ret != 0:
                     return cls
             tindex = tidx.value
@@ -185,13 +183,14 @@ def register_func(func_name, f=None, override=False):
         raise ValueError("expect string function name")
 
     ioverride = ctypes.c_int(override)
+
     def register(myf):
         """internal register function"""
         if not isinstance(myf, PackedFuncBase):
             myf = convert_to_tvm_func(myf)
-        check_call(_LIB.TVMFuncRegisterGlobal(
-            c_str(func_name), myf.handle, ioverride))
+        check_call(_LIB.TVMFuncRegisterGlobal(c_str(func_name), myf.handle, ioverride))
         return myf
+
     if f:
         return register(f)
     return register
@@ -227,8 +226,7 @@ def list_global_func_names():
     plist = ctypes.POINTER(ctypes.c_char_p)()
     size = ctypes.c_uint()
 
-    check_call(_LIB.TVMFuncListGlobalNames(ctypes.byref(size),
-                                           ctypes.byref(plist)))
+    check_call(_LIB.TVMFuncListGlobalNames(ctypes.byref(size), ctypes.byref(plist)))
     fnames = []
     for i in range(size.value):
         fnames.append(py_str(plist[i]))
@@ -250,8 +248,10 @@ def extract_ext_funcs(finit):
         The extracted functions
     """
     fdict = {}
+
     def _list(name, func):
         fdict[name] = func
+
     myf = convert_to_tvm_func(_list)
     ret = finit(myf.handle)
     _ = myf
@@ -275,8 +275,7 @@ def _init_api(namespace, target_module_name=None):
     target_module_name : str
        The target module name if different from namespace
     """
-    target_module_name = (
-        target_module_name if target_module_name else namespace)
+    target_module_name = target_module_name if target_module_name else namespace
     if namespace.startswith("tvm."):
         _init_api_prefix(target_module_name, namespace[4:])
     else:
@@ -290,7 +289,7 @@ def _init_api_prefix(module_name, prefix):
         if not name.startswith(prefix):
             continue
 
-        fname = name[len(prefix)+1:]
+        fname = name[len(prefix) + 1 :]
         target_module = module
 
         if fname.find(".") != -1:
@@ -298,5 +297,5 @@ def _init_api_prefix(module_name, prefix):
         f = get_global_func(name)
         ff = _get_api(f)
         ff.__name__ = fname
-        ff.__doc__ = ("TVM PackedFunc %s. " % fname)
+        ff.__doc__ = "TVM PackedFunc %s. " % fname
         setattr(target_module, ff.__name__, ff)
index dcc9528..58be070 100644 (file)
@@ -23,8 +23,10 @@ from .base import _LIB, check_call
 
 tvm_shape_index_t = ctypes.c_int64
 
+
 class ArgTypeCode(object):
     """Type code used in API calls"""
+
     INT = 0
     UINT = 1
     FLOAT = 2
@@ -42,14 +44,16 @@ class ArgTypeCode(object):
     OBJECT_RVALUE_REF_ARG = 14
     EXT_BEGIN = 15
 
+
 class TVMByteArray(ctypes.Structure):
     """Temp data structure for byte array."""
-    _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
-                ("size", ctypes.c_size_t)]
+
+    _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)), ("size", ctypes.c_size_t)]
 
 
 class DataTypeCode(object):
     """DataType code in DLTensor."""
+
     INT = 0
     UINT = 1
     FLOAT = 2
@@ -59,16 +63,16 @@ class DataTypeCode(object):
 
 class DataType(ctypes.Structure):
     """TVM datatype structure"""
-    _fields_ = [("type_code", ctypes.c_uint8),
-                ("bits", ctypes.c_uint8),
-                ("lanes", ctypes.c_uint16)]
+
+    _fields_ = [("type_code", ctypes.c_uint8), ("bits", ctypes.c_uint8), ("lanes", ctypes.c_uint16)]
     CODE2STR = {
-        DataTypeCode.INT : 'int',
-        DataTypeCode.UINT : 'uint',
-        DataTypeCode.FLOAT : 'float',
-        DataTypeCode.HANDLE : 'handle',
-        DataTypeCode.BFLOAT : 'bfloat'
+        DataTypeCode.INT: "int",
+        DataTypeCode.UINT: "uint",
+        DataTypeCode.FLOAT: "float",
+        DataTypeCode.HANDLE: "handle",
+        DataTypeCode.BFLOAT: "bfloat",
     }
+
     def __init__(self, type_str):
         super(DataType, self).__init__()
         if isinstance(type_str, np.dtype):
@@ -104,18 +108,18 @@ class DataType(ctypes.Structure):
         elif head.startswith("custom"):
             # pylint: disable=import-outside-toplevel
             import tvm.runtime._ffi_api
-            low, high = head.find('['), head.find(']')
+
+            low, high = head.find("["), head.find("]")
             if not low or not high or low >= high:
                 raise ValueError("Badly formatted custom type string %s" % type_str)
-            type_name = head[low + 1:high]
+            type_name = head[low + 1 : high]
             self.type_code = tvm.runtime._ffi_api._datatype_get_type_code(type_name)
-            head = head[high+1:]
+            head = head[high + 1 :]
         else:
             raise ValueError("Do not know how to handle type %s" % type_str)
         bits = int(head) if head else bits
         self.bits = bits
 
-
     def __repr__(self):
         # pylint: disable=import-outside-toplevel
         if self.bits == 1 and self.lanes == 1:
@@ -124,64 +128,69 @@ class DataType(ctypes.Structure):
             type_name = DataType.CODE2STR[self.type_code]
         else:
             import tvm.runtime._ffi_api
-            type_name = "custom[%s]" % \
-                        tvm.runtime._ffi_api._datatype_get_type_name(self.type_code)
+
+            type_name = "custom[%s]" % tvm.runtime._ffi_api._datatype_get_type_name(self.type_code)
         x = "%s%d" % (type_name, self.bits)
         if self.lanes != 1:
             x += "x%d" % self.lanes
         return x
 
     def __eq__(self, other):
-        return (self.bits == other.bits and
-                self.type_code == other.type_code and
-                self.lanes == other.lanes)
+        return (
+            self.bits == other.bits
+            and self.type_code == other.type_code
+            and self.lanes == other.lanes
+        )
 
     def __ne__(self, other):
         return not self.__eq__(other)
 
+
 RPC_SESS_MASK = 128
 
+
 class TVMContext(ctypes.Structure):
     """TVM context strucure."""
-    _fields_ = [("device_type", ctypes.c_int),
-                ("device_id", ctypes.c_int)]
+
+    _fields_ = [("device_type", ctypes.c_int), ("device_id", ctypes.c_int)]
     MASK2STR = {
-        1 : 'cpu',
-        2 : 'gpu',
-        4 : 'opencl',
-        5 : 'aocl',
-        6 : 'sdaccel',
-        7 : 'vulkan',
-        8 : 'metal',
-        9 : 'vpi',
-        10: 'rocm',
-        12: 'ext_dev',
-        13: 'micro_dev',
-        14: 'hexagon',
-        15: 'webgpu'
+        1: "cpu",
+        2: "gpu",
+        4: "opencl",
+        5: "aocl",
+        6: "sdaccel",
+        7: "vulkan",
+        8: "metal",
+        9: "vpi",
+        10: "rocm",
+        12: "ext_dev",
+        13: "micro_dev",
+        14: "hexagon",
+        15: "webgpu",
     }
     STR2MASK = {
-        'llvm': 1,
-        'stackvm': 1,
-        'cpu': 1,
-        'c': 1,
-        'gpu': 2,
-        'cuda': 2,
-        'nvptx': 2,
-        'cl': 4,
-        'opencl': 4,
-        'aocl' : 5,
-        'aocl_sw_emu' : 5,
-        'sdaccel': 6,
-        'vulkan': 7,
-        'metal': 8,
-        'vpi': 9,
-        'rocm': 10,
-        'ext_dev': 12,
-        'micro_dev': 13,
-        'hexagon': 14,
-        'webgpu': 15,
+        "llvm": 1,
+        "stackvm": 1,
+        "cpu": 1,
+        "c": 1,
+        "gpu": 2,
+        "cuda": 2,
+        "nvptx": 2,
+        "cl": 4,
+        "opencl": 4,
+        "aocl": 5,
+        "aocl_sw_emu": 5,
+        "sdaccel": 6,
+        "vulkan": 7,
+        "metal": 8,
+        "vpi": 9,
+        "rocm": 10,
+        "ext_dev": 12,
+        "micro_dev": 13,
+        "hexagon": 14,
+        "webgpu": 15,
     }
+
     def __init__(self, device_type, device_id):
         super(TVMContext, self).__init__()
         self.device_type = device_type
@@ -191,32 +200,28 @@ class TVMContext(ctypes.Structure):
         """Internal helper function to invoke runtime.GetDeviceAttr"""
         # pylint: disable=import-outside-toplevel
         import tvm.runtime._ffi_api
-        return tvm.runtime._ffi_api.GetDeviceAttr(
-            device_type, device_id, attr_id)
+
+        return tvm.runtime._ffi_api.GetDeviceAttr(device_type, device_id, attr_id)
 
     @property
     def exist(self):
         """Whether this device exist."""
-        return self._GetDeviceAttr(
-            self.device_type, self.device_id, 0) != 0
+        return self._GetDeviceAttr(self.device_type, self.device_id, 0) != 0
 
     @property
     def max_threads_per_block(self):
         """Maximum number of threads on each block."""
-        return self._GetDeviceAttr(
-            self.device_type, self.device_id, 1)
+        return self._GetDeviceAttr(self.device_type, self.device_id, 1)
 
     @property
     def warp_size(self):
         """Number of threads that executes in concurrent."""
-        return self._GetDeviceAttr(
-            self.device_type, self.device_id, 2)
+        return self._GetDeviceAttr(self.device_type, self.device_id, 2)
 
     @property
     def max_shared_memory_per_block(self):
         """Total amount of shared memory per block in bytes."""
-        return self._GetDeviceAttr(
-            self.device_type, self.device_id, 3)
+        return self._GetDeviceAttr(self.device_type, self.device_id, 3)
 
     @property
     def compute_version(self):
@@ -229,26 +234,22 @@ class TVMContext(ctypes.Structure):
         version : str
             The version string in `major.minor` format.
         """
-        return self._GetDeviceAttr(
-            self.device_type, self.device_id, 4)
+        return self._GetDeviceAttr(self.device_type, self.device_id, 4)
 
     @property
     def device_name(self):
         """Return the string name of device."""
-        return self._GetDeviceAttr(
-            self.device_type, self.device_id, 5)
+        return self._GetDeviceAttr(self.device_type, self.device_id, 5)
 
     @property
     def max_clock_rate(self):
         """Return the max clock frequency of device."""
-        return self._GetDeviceAttr(
-            self.device_type, self.device_id, 6)
+        return self._GetDeviceAttr(self.device_type, self.device_id, 6)
 
     @property
     def multi_processor_count(self):
         """Return the number of compute units of device."""
-        return self._GetDeviceAttr(
-            self.device_type, self.device_id, 7)
+        return self._GetDeviceAttr(self.device_type, self.device_id, 7)
 
     @property
     def max_thread_dimensions(self):
@@ -259,17 +260,18 @@ class TVMContext(ctypes.Structure):
         dims: List of int
             The maximum length of threadIdx.x, threadIdx.y, threadIdx.z
         """
-        return json.loads(self._GetDeviceAttr(
-            self.device_type, self.device_id, 8))
+        return json.loads(self._GetDeviceAttr(self.device_type, self.device_id, 8))
 
     def sync(self):
         """Synchronize until jobs finished at the context."""
         check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None))
 
     def __eq__(self, other):
-        return (isinstance(other, TVMContext) and
-                self.device_id == other.device_id and
-                self.device_type == other.device_type)
+        return (
+            isinstance(other, TVMContext)
+            and self.device_id == other.device_id
+            and self.device_type == other.device_type
+        )
 
     def __ne__(self, other):
         return not self.__eq__(other)
@@ -281,21 +283,22 @@ class TVMContext(ctypes.Structure):
         if self.device_type >= RPC_SESS_MASK:
             tbl_id = self.device_type / RPC_SESS_MASK - 1
             dev_type = self.device_type % RPC_SESS_MASK
-            return "remote[%d]:%s(%d)" % (
-                tbl_id, TVMContext.MASK2STR[dev_type], self.device_id)
-        return "%s(%d)" % (
-            TVMContext.MASK2STR[self.device_type], self.device_id)
+            return "remote[%d]:%s(%d)" % (tbl_id, TVMContext.MASK2STR[dev_type], self.device_id)
+        return "%s(%d)" % (TVMContext.MASK2STR[self.device_type], self.device_id)
 
 
 class TVMArray(ctypes.Structure):
     """TVMValue in C API"""
-    _fields_ = [("data", ctypes.c_void_p),
-                ("ctx", TVMContext),
-                ("ndim", ctypes.c_int),
-                ("dtype", DataType),
-                ("shape", ctypes.POINTER(tvm_shape_index_t)),
-                ("strides", ctypes.POINTER(tvm_shape_index_t)),
-                ("byte_offset", ctypes.c_uint64)]
+
+    _fields_ = [
+        ("data", ctypes.c_void_p),
+        ("ctx", TVMContext),
+        ("ndim", ctypes.c_int),
+        ("dtype", DataType),
+        ("shape", ctypes.POINTER(tvm_shape_index_t)),
+        ("strides", ctypes.POINTER(tvm_shape_index_t)),
+        ("byte_offset", ctypes.c_uint64),
+    ]
 
 
 class ObjectRValueRef:
@@ -306,7 +309,9 @@ class ObjectRValueRef:
     obj : tvm.runtime.Object
         The object that this value refers to
     """
+
     __slots__ = ["obj"]
+
     def __init__(self, obj):
         self.obj = obj
 
index e841de9..c3b32b5 100644 (file)
@@ -23,9 +23,9 @@ from . import _ffi_api
 @tvm._ffi.register_object("arith.ModularSet")
 class ModularSet(Object):
     """Represent range of (coeff * x + base) for x in Z """
+
     def __init__(self, coeff, base):
-        self.__init_handle_by_constructor__(
-            _ffi_api.ModularSet, coeff, base)
+        self.__init_handle_by_constructor__(_ffi_api.ModularSet, coeff, base)
 
 
 @tvm._ffi.register_object("arith.ConstIntBound")
@@ -40,12 +40,12 @@ class ConstIntBound(Object):
     max_value : int
         The maximum value of the bound.
     """
+
     POS_INF = (1 << 63) - 1
     NEG_INF = -POS_INF
 
     def __init__(self, min_value, max_value):
-        self.__init_handle_by_constructor__(
-            _ffi_api.ConstIntBound, min_value, max_value)
+        self.__init_handle_by_constructor__(_ffi_api.ConstIntBound, min_value, max_value)
 
 
 class ConstraintScope:
@@ -60,6 +60,7 @@ class ConstraintScope:
     ----
     Do not create object directly, use Analyzer.constraint_scope
     """
+
     def __init__(self, fenter):
         self._fenter = fenter
         self._fexit = None
@@ -77,6 +78,7 @@ class Analyzer:
     This is a stateful analyzer class that can
     be used to perform various symbolic integer analysis.
     """
+
     def __init__(self):
         _mod = _ffi_api.CreateAnalyzer()
         self._const_int_bound = _mod("const_int_bound")
@@ -225,8 +227,10 @@ class Analyzer:
           # constraint no longer in effect
           assert analyzer.modular_set(x).coeff != 3
         """
+
         def _fenter():
             return self._enter_constraint_context(constraint)
+
         return ConstraintScope(_fenter)
 
     def update(self, var, info, override=False):
@@ -246,5 +250,4 @@ class Analyzer:
         if isinstance(info, ConstIntBound):
             self._const_int_bound_update(var, info, override)
         else:
-            raise TypeError(
-                "Do not know how to handle type {}".format(type(info)))
+            raise TypeError("Do not know how to handle type {}".format(type(info)))
index 838e8e5..255dbfd 100644 (file)
@@ -22,6 +22,7 @@ from . import _ffi_api
 
 class IntSet(Object):
     """Represent a set of integer in one dimension."""
+
     def is_nothing(self):
         """Whether the set represent nothing"""
         return _ffi_api.IntSetIsNothing(self)
@@ -75,6 +76,6 @@ class IntervalSet(IntSet):
     max_value : PrimExpr
         The maximum value in the interval.
     """
+
     def __init__(self, min_value, max_value):
-        self.__init_handle_by_constructor__(
-            _ffi_api.IntervalSet, min_value, max_value)
+        self.__init_handle_by_constructor__(_ffi_api.IntervalSet, min_value, max_value)
index 91fa459..6e8a010 100644 (file)
@@ -39,9 +39,9 @@ class IntGroupBounds(Object):
     upper : List[tvm.ir.PrimExpr]
         the upper bounds (include)
     """
+
     def __init__(self, coef, lower, equal, upper):
-        self.__init_handle_by_constructor__(
-            _ffi_api.IntGroupBounds, coef, lower, equal, upper)
+        self.__init_handle_by_constructor__(_ffi_api.IntGroupBounds, coef, lower, equal, upper)
 
     @staticmethod
     def from_range(rng):
@@ -61,7 +61,7 @@ class IntGroupBounds(Object):
 
     def find_best_range(self):
         """Return the best range from the grouped bounds.
-           None if (-inf, +inf).
+        None if (-inf, +inf).
         """
         return _ffi_api.IntGroupBounds_FindBestRange(self)
 
@@ -80,9 +80,9 @@ class IntConstraints(Object):
     relations : List[tvm.ir.PrimExpr]
         The relations between the variables (either equations or inequalities)
     """
+
     def __init__(self, variables, ranges, relations):
-        self.__init_handle_by_constructor__(
-            _ffi_api.IntConstraints, variables, ranges, relations)
+        self.__init_handle_by_constructor__(_ffi_api.IntConstraints, variables, ranges, relations)
 
 
 @tvm._ffi.register_object("arith.IntConstraintsTransform")
@@ -112,9 +112,11 @@ class IntConstraintsTransform(Object):
         mapping from variables in the dst to the variables in the src,
         e.g., {m -> a, n -> -b}
     """
+
     def __init__(self, src, dst, src_to_dst, dst_to_src):
         self.__init_handle_by_constructor__(
-            _ffi_api.IntConstraintsTransform, src, dst, src_to_dst, dst_to_src)
+            _ffi_api.IntConstraintsTransform, src, dst, src_to_dst, dst_to_src
+        )
 
 
 def solve_linear_equations(equations, variables=None, ranges=None):
@@ -168,8 +170,9 @@ def solve_linear_inequalities(equations, variables=None, ranges=None, deskew_ran
         If deskew_range is set (=True), the result ranges will be deskewed to start from zero.
         New variables are created accordingly, and therefore an IntConstraintsTransform is returned.
     """
-    solver = _ffi_api.SolveInequalitiesDeskewRange \
-        if deskew_range else _ffi_api.SolveInequalitiesToRange
+    solver = (
+        _ffi_api.SolveInequalitiesDeskewRange if deskew_range else _ffi_api.SolveInequalitiesToRange
+    )
     if isinstance(equations, IntConstraints):
         assert variables is None
         assert ranges is None
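
A hedged sketch of driving the solver entry points in this module; the equation and range construction below is an assumption modeled on the unit tests, not taken from this patch:

import tvm
from tvm import te, arith

x, y = te.var("x"), te.var("y")
ranges = {x: tvm.ir.Range(-100, 100), y: tvm.ir.Range(-100, 100)}

# Solve x + y == 20 and x - y == 10; the result is an IntConstraintsTransform
# whose src_to_dst map carries the solved expressions for x and y.
solution = arith.solve_linear_equations(
    [tvm.tir.EQ(x + y, 20), tvm.tir.EQ(x - y, 10)], [x, y], ranges
)
print(solution.src_to_dst)
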
index 2281088..53f8eb6 100644 (file)
@@ -41,7 +41,7 @@ def detect_linear_equation(expr, var_list):
 
 
 def detect_clip_bound(expr, var_list):
-    """ Detect if expression corresponds to clip bound of the vars
+    """Detect if expression corresponds to clip bound of the vars
 
     Parameters
     ----------
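
For context, a hedged sketch of the sibling helper detect_linear_equation named in the hunk header above; the coefficient values in the comment are what detection is expected to return, not output taken from the patch:

import tvm
from tvm import te, arith

x = te.var("x")
coeffs = arith.detect_linear_equation(4 * x + 6, [x])
print(coeffs)  # [4, 6] when the expression is linear in x; an empty list otherwise
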
index 9ad526c..43e08a4 100644 (file)
@@ -26,13 +26,17 @@ from . import workload_registry
 from . import feature
 
 # Shortcut
-from .auto_schedule import SearchTask, TuningOptions, HardwareParams, \
-    auto_schedule
+from .auto_schedule import SearchTask, TuningOptions, HardwareParams, auto_schedule
 from .compute_dag import ComputeDAG
 from .cost_model import RandomModel, XGBModel
-from .measure import MeasureInput, MeasureResult, LocalBuilder, LocalRunner, RPCRunner, \
-    LocalRPCMeasureContext
-from .measure_record import RecordToFile, RecordReader, load_best, \
-    load_records, save_records
+from .measure import (
+    MeasureInput,
+    MeasureResult,
+    LocalBuilder,
+    LocalRunner,
+    RPCRunner,
+    LocalRPCMeasureContext,
+)
+from .measure_record import RecordToFile, RecordReader, load_best, load_records, save_records
 from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
 from .workload_registry import register_workload, make_workload_key
index 2942025..af257f5 100644 (file)
@@ -37,7 +37,7 @@ from . import _ffi_api
 
 @tvm._ffi.register_object("auto_scheduler.HardwareParams")
 class HardwareParams(Object):
-    """ The parameters of target hardware used to guide the search policy
+    """The parameters of target hardware used to guide the search policy
 
     TODO(jcf94): This is considered to be merged with the new Target specification:
     https://discuss.tvm.ai/t/rfc-tvm-target-specification/6844
@@ -51,14 +51,16 @@ class HardwareParams(Object):
     cache_line_bytes : int
         The size of cache line in bytes.
     """
+
     def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes):
-        self.__init_handle_by_constructor__(_ffi_api.HardwareParams, num_cores,
-                                            vector_unit_bytes, cache_line_bytes)
+        self.__init_handle_by_constructor__(
+            _ffi_api.HardwareParams, num_cores, vector_unit_bytes, cache_line_bytes
+        )
 
 
 @tvm._ffi.register_object("auto_scheduler.SearchTask")
 class SearchTask(Object):
-    """ The computation information and hardware parameters for a schedule search task.
+    """The computation information and hardware parameters for a schedule search task.
 
     Parameters
     ----------
@@ -73,16 +75,16 @@ class SearchTask(Object):
     hardware_params : Optional[HardwareParams]
         Hardware parameters used in this search task.
     """
-    def __init__(self, dag, workload_key, target, target_host=None,
-                 hardware_params=None):
-        self.__init_handle_by_constructor__(_ffi_api.SearchTask, dag,
-                                            workload_key, target, target_host,
-                                            hardware_params)
+
+    def __init__(self, dag, workload_key, target, target_host=None, hardware_params=None):
+        self.__init_handle_by_constructor__(
+            _ffi_api.SearchTask, dag, workload_key, target, target_host, hardware_params
+        )
 
 
 @tvm._ffi.register_object("auto_scheduler.TuningOptions")
 class TuningOptions(Object):
-    """ This controls the options of performance tuning.
+    """This controls the options of performance tuning.
 
     Parameters
     ----------
@@ -109,33 +111,53 @@ class TuningOptions(Object):
       Candidates:
         - auto_scheduler.RecordToFile
     """
-    def __init__(self, num_measure_trials=0, early_stopping=None, num_measures_per_round=64,
-                 verbose=1, builder='local', runner='local', measure_callbacks=None):
+
+    def __init__(
+        self,
+        num_measure_trials=0,
+        early_stopping=None,
+        num_measures_per_round=64,
+        verbose=1,
+        builder="local",
+        runner="local",
+        measure_callbacks=None,
+    ):
         if isinstance(builder, str):
-            if builder == 'local':
+            if builder == "local":
                 builder = LocalBuilder()
             else:
                 raise ValueError("Invalid builder: " + builder)
         elif not isinstance(builder, tvm.auto_scheduler.measure.ProgramBuilder):
-            raise ValueError("Invalid builder: " + builder +
-                             " . TuningOptions expects a ProgramBuilder or string.")
+            raise ValueError(
+                "Invalid builder: "
+                + builder
+                + " . TuningOptions expects a ProgramBuilder or string."
+            )
 
         if isinstance(runner, str):
-            if runner == 'local':
+            if runner == "local":
                 runner = LocalRunner()
             else:
                 raise ValueError("Invalid runner: " + runner)
         elif not isinstance(runner, tvm.auto_scheduler.measure.ProgramRunner):
-            raise ValueError("Invalid runner: " + runner +
-                             " . TuningOptions expects a ProgramRunner or string.")
+            raise ValueError(
+                "Invalid runner: " + runner + " . TuningOptions expects a ProgramRunner or string."
+            )
 
         self.__init_handle_by_constructor__(
-            _ffi_api.TuningOptions, num_measure_trials, early_stopping or -1,
-            num_measures_per_round, verbose, builder, runner, measure_callbacks)
+            _ffi_api.TuningOptions,
+            num_measure_trials,
+            early_stopping or -1,
+            num_measures_per_round,
+            verbose,
+            builder,
+            runner,
+            measure_callbacks,
+        )
 
 
 def auto_schedule(task, search_policy=None, tuning_options=TuningOptions()):
-    """ Do auto scheduling for a computation declaration.
+    """Do auto scheduling for a computation declaration.
 
     Parameters
     ----------
@@ -152,8 +174,9 @@ def auto_schedule(task, search_policy=None, tuning_options=TuningOptions()):
         A `te.schedule` and a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`.
     """
     if not isinstance(task, SearchTask):
-        raise ValueError("Invalid task: " + task +
-                         " . `auto_scheduler.auto_schedule` expects a SearchTask.")
+        raise ValueError(
+            "Invalid task: " + task + " . `auto_scheduler.auto_schedule` expects a SearchTask."
+        )
 
     sch, tensors = _ffi_api.AutoSchedule(search_policy or EmptyPolicy(task), tuning_options)
     return sch, tensors
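
A hedged end-to-end sketch of the auto_schedule API documented above; the matmul workload, shapes, and log file name are illustrative assumptions based on the auto_scheduler tutorials of this era, not part of the patch:

import tvm
from tvm import te, auto_scheduler


@auto_scheduler.register_workload
def matmul(N, M, K):
    A = te.placeholder((N, K), name="A")
    B = te.placeholder((K, M), name="B")
    k = te.reduce_axis((0, K), name="k")
    C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
    return [A, B, C]


target = tvm.target.Target("llvm")
workload_key = auto_scheduler.make_workload_key(matmul, (128, 128, 128))
dag = auto_scheduler.ComputeDAG(workload_key)
task = auto_scheduler.SearchTask(dag, workload_key, target)
tune_opt = auto_scheduler.TuningOptions(
    num_measure_trials=4,
    measure_callbacks=[auto_scheduler.RecordToFile("matmul_tuning.json")],
)
sch, args = auto_scheduler.auto_schedule(task, tuning_options=tune_opt)
print(tvm.lower(sch, args, simple_mode=True))
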
index 99ce1e7..68883a0 100755 (executable)
@@ -50,6 +50,7 @@ class ComputeDAG(Object):
     compute : Union[List[Tensor], str]
         `Tensor`s or workload key for a compute declaration.
     """
+
     def __init__(self, compute):
         if isinstance(compute, str):
             compute = workload_key_to_tensors(compute)
@@ -58,12 +59,13 @@ class ComputeDAG(Object):
                 if not isinstance(item, tvm.te.Tensor):
                     raise ValueError("The input of ComputeDAG should be a list of Tensor")
         else:
-            raise ValueError("Invalid compute: " + compute +
-                             " . ComputeDAG expects a string or list of Tensor")
+            raise ValueError(
+                "Invalid compute: " + compute + " . ComputeDAG expects a string or list of Tensor"
+            )
         self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute)
 
     def get_init_state(self):
-        """ Get the init state of this ComputeDAG.
+        """Get the init state of this ComputeDAG.
 
         Returns
         -------
@@ -145,19 +147,19 @@ class ComputeDAG(Object):
     def __hash__(self):
         # TODO(merrymercy): Implement this more carefully and move this to c++ as a member function
         # of ComputeDAG
-        str_key = ''
+        str_key = ""
         for op in self.ops:
             t = op.output(0)
             if isinstance(op, PlaceholderOp):
-                str_key += 'placeholder,'
-                str_key += str(get_const_tuple(t.shape)) + ','
-                str_key += t.dtype + ';'
+                str_key += "placeholder,"
+                str_key += str(get_const_tuple(t.shape)) + ","
+                str_key += t.dtype + ";"
             elif isinstance(op, ComputeOp):
-                str_key += str(t.op.body) + ','
-                str_key += str(get_const_tuple(t.shape)) + ','
-                str_key += t.dtype + ';'
+                str_key += str(t.op.body) + ","
+                str_key += str(get_const_tuple(t.shape)) + ","
+                str_key += t.dtype + ";"
             else:
                 raise ValueError("Invalid op: " + op)
 
-        str_key = str_key.encode(encoding='utf-8')
+        str_key = str_key.encode(encoding="utf-8")
         return hashlib.md5(str_key).hexdigest()
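
A short hedged sketch (tensor shapes are made up) of building a ComputeDAG directly from tensors, the other input form accepted by the constructor above:

import tvm
from tvm import te, auto_scheduler

A = te.placeholder((64, 64), name="A")
B = te.compute((64, 64), lambda i, j: A[i, j] * 2, name="B")

dag = auto_scheduler.ComputeDAG([A, B])  # list-of-Tensor form of the constructor
state = dag.get_init_state()             # naive schedule used as the search start point
print(dag)                               # human-readable compute declaration
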
index 80e963f..17370d6 100644 (file)
@@ -28,9 +28,11 @@ from .. import _ffi_api
 class CostModel(Object):
     """The base class for cost model"""
 
+
 @tvm._ffi.register_object("auto_scheduler.RandomModel")
 class RandomModel(CostModel):
     """A model returns random estimation for all inputs"""
+
     def __init__(self):
         self.__init_handle_by_constructor__(_ffi_api.RandomModel)
 
@@ -85,6 +87,7 @@ def random_fill_float(size, return_ptr):
 @tvm._ffi.register_object("auto_scheduler.PythonBasedModel")
 class PythonBasedModel(CostModel):
     """Base class for cost models implemented in python"""
+
     def __init__(self):
         def update_func(inputs, results):
             self.update(inputs, results)
@@ -100,8 +103,9 @@ class PythonBasedModel(CostModel):
             array_wrapper = np.ctypeslib.as_array(return_ptr, shape=ret.shape)
             array_wrapper[:] = ret
 
-        self.__init_handle_by_constructor__(_ffi_api.PythonBasedModel, update_func,
-                                            predict_func, predict_stage_func)
+        self.__init_handle_by_constructor__(
+            _ffi_api.PythonBasedModel, update_func, predict_func, predict_stage_func
+        )
 
     def update(self, inputs, results):
         """Update the cost model according to new measurement results (training data).
index 6fd8d17..8704f2a 100644 (file)
@@ -32,10 +32,12 @@ from .cost_model import PythonBasedModel
 from ..feature import get_per_store_features_from_measure_pairs, get_per_store_features_from_states
 from ..measure_record import RecordReader
 
-logger = logging.getLogger('auto_scheduler')
+logger = logging.getLogger("auto_scheduler")
+
 
 class XGBDMatrixContext:
     """A global context to hold additional attributes of xgb.DMatrix"""
+
     def __init__(self):
         self.context_dict = defaultdict(dict)
 
@@ -69,6 +71,7 @@ class XGBDMatrixContext:
         """
         self.context_dict[key][matrix.handle.value] = value
 
+
 dmatrix_context = XGBDMatrixContext()
 
 
@@ -91,19 +94,19 @@ class XGBModel(PythonBasedModel):
     It is called "pack-sum" because we combine several samples into a "pack" and sum up
     their predictions.
     """
+
     def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None):
         self.xgb_params = {
-            'max_depth': 10,
-            'gamma': 0.001,
-            'min_child_weight': 0,
-            'eta': 0.2,
+            "max_depth": 10,
+            "gamma": 0.001,
+            "min_child_weight": 0,
+            "eta": 0.2,
             # todo(merrymercy): automatically decrease learning rate when the loss is too large
-
-            'n_gpus': 0,
-            'nthread': multiprocessing.cpu_count() // 2,
-            'verbosity': 0,
-            'seed': seed or 43,
-            'disable_default_eval_metric': 1
+            "n_gpus": 0,
+            "nthread": multiprocessing.cpu_count() // 2,
+            "verbosity": 0,
+            "seed": seed or 43,
+            "disable_default_eval_metric": 1,
         }
         self.bst = None
         self.plan_size = 32
@@ -137,30 +140,38 @@ class XGBModel(PythonBasedModel):
 
         # extract feature
         n_cached = len(self.inputs_feature_cache)
-        features, normalized_throughputs, task_ids = \
-            get_per_store_features_from_measure_pairs(self.inputs, self.results,
-                                                      skip_first_n_feature_extraction=n_cached)
+        features, normalized_throughputs, task_ids = get_per_store_features_from_measure_pairs(
+            self.inputs, self.results, skip_first_n_feature_extraction=n_cached
+        )
         if n_cached > 0:
             features = list(features)
             features[:n_cached] = self.inputs_feature_cache
             features = np.array(features, dtype=object)
         self.inputs_feature_cache = features
-        dtrain = pack_sum_xgbmatrix(features, normalized_throughputs,
-                                    task_ids, normalized_throughputs)
+        dtrain = pack_sum_xgbmatrix(
+            features, normalized_throughputs, task_ids, normalized_throughputs
+        )
 
         # train xgb model
-        self.bst = xgb.train(self.xgb_params, dtrain,
-                             num_boost_round=10000,
-                             obj=pack_sum_square_error,
-                             callbacks=[custom_callback(
-                                 stopping_rounds=50,
-                                 metric='tr-p-rmse',
-                                 fevals=[
-                                     pack_sum_rmse, pack_sum_average_peak_score(self.plan_size),
-                                 ],
-                                 evals=[(dtrain, 'tr')],
-                                 maximize=False,
-                                 verbose_eval=self.verbose_eval)])
+        self.bst = xgb.train(
+            self.xgb_params,
+            dtrain,
+            num_boost_round=10000,
+            obj=pack_sum_square_error,
+            callbacks=[
+                custom_callback(
+                    stopping_rounds=50,
+                    metric="tr-p-rmse",
+                    fevals=[
+                        pack_sum_rmse,
+                        pack_sum_average_peak_score(self.plan_size),
+                    ],
+                    evals=[(dtrain, "tr")],
+                    maximize=False,
+                    verbose_eval=self.verbose_eval,
+                )
+            ],
+        )
 
     def predict(self, task, states):
         """Predict the scores of states
@@ -188,7 +199,7 @@ class XGBModel(PythonBasedModel):
         # Predict 0 for invalid states that failed to be lowered.
         for idx, feature in enumerate(features):
             if feature.min() == feature.max() == 0:
-                ret[idx] = float('-inf')
+                ret[idx] = float("-inf")
 
         return ret
 
@@ -240,12 +251,18 @@ class XGBModel(PythonBasedModel):
                 breakdown = np.concatenate((breakdown, np.array(stage_score)))
         else:
             breakdown = np.concatenate(
-                (np.random.uniform(0, 1, (len(states), )), np.zeros(len(states), )))
+                (
+                    np.random.uniform(0, 1, (len(states),)),
+                    np.zeros(
+                        len(states),
+                    ),
+                )
+            )
 
         # Predict 0 for invalid states that failed to be lowered.
         for idx, feature in enumerate(features):
             if feature.min() == feature.max() == 0:
-                breakdown[idx] = float('-inf')
+                breakdown[idx] = float("-inf")
 
         return breakdown
 
@@ -366,8 +383,8 @@ def pack_sum_xgbmatrix(xs, ys, gids=None, weights=None):
     ret = xgb.DMatrix(np.array(x_flatten), y_flatten)
     if weights is not None:
         ret.set_weight(weights_flatten)
-    dmatrix_context.set('pack_ids', ret, np.array(pack_ids))
-    dmatrix_context.set('group_sizes', ret, group_sizes)
+    dmatrix_context.set("pack_ids", ret, np.array(pack_ids))
+    dmatrix_context.set("group_sizes", ret, group_sizes)
     return ret
 
 
@@ -389,6 +406,7 @@ def predict_throughput_pack_sum(raw_preds, pack_ids):
     sum_pred = np.bincount(pack_ids, weights=raw_preds)
     return sum_pred
 
+
 def pack_sum_square_error(preds, dtrain):
     """Implement square error loss on pack-sum format as
      a custom objective function for xgboost.
@@ -420,6 +438,7 @@ def pack_sum_square_error(preds, dtrain):
 
     return gradient * weight, hessian * weight
 
+
 def pack_sum_rmse(raw_preds, labels):
     """Evaluate RMSE (rooted mean square error) in the pack-sum format
 
@@ -438,7 +457,8 @@ def pack_sum_rmse(raw_preds, labels):
     """
     pack_ids = dmatrix_context.get("pack_ids", labels)
     preds = predict_throughput_pack_sum(raw_preds, pack_ids)[pack_ids]
-    return 'p-rmse', np.sqrt(np.mean(np.square((preds - labels.get_label()))))
+    return "p-rmse", np.sqrt(np.mean(np.square((preds - labels.get_label()))))
+
 
 def pack_sum_average_peak_score(N):
     """Return the evaluation function for average-peak-score@N
@@ -469,18 +489,20 @@ def pack_sum_average_peak_score(N):
         score: float
         The name and score of this metric
         """
-        group_sizes = dmatrix_context.get('group_sizes', labels, [len(preds)])
+        group_sizes = dmatrix_context.get("group_sizes", labels, [len(preds)])
         pack_ids = dmatrix_context.get("pack_ids", labels)
 
         preds = predict_throughput_pack_sum(preds, pack_ids)
-        labels = (np.bincount(pack_ids, weights=labels.get_label())
-                  / np.unique(pack_ids, return_counts=True)[1])
+        labels = (
+            np.bincount(pack_ids, weights=labels.get_label())
+            / np.unique(pack_ids, return_counts=True)[1]
+        )
 
         scores = []
         offset = 0
         for size in group_sizes:
-            preds_group = preds[offset:offset + size]
-            labels_group = labels[offset:offset + size]
+            preds_group = preds[offset : offset + size]
+            labels_group = labels[offset : offset + size]
             offset += size
 
             trials = np.argsort(preds_group)[::-1][:N]
@@ -488,11 +510,20 @@ def pack_sum_average_peak_score(N):
             curve = max_curve(trial_scores) / np.max(labels_group)
             scores.append(np.mean(curve))
         return "a-peak@%d" % N, np.mean(scores)
+
     return feval
 
 
-def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
-                    maximize=False, verbose_eval=True, skip_every=2):
+def custom_callback(
+    stopping_rounds,
+    metric,
+    fevals,
+    evals=(),
+    log_file=None,
+    maximize=False,
+    verbose_eval=True,
+    skip_every=2,
+):
     """Callback function for xgboost to support multiple custom evaluation functions"""
     state = {}
     metric_shortname = metric.split("-")[1]
@@ -501,21 +532,21 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
         """internal function"""
         bst = env.model
 
-        state['maximize_score'] = maximize
-        state['best_iteration'] = 0
+        state["maximize_score"] = maximize
+        state["best_iteration"] = 0
         if maximize:
-            state['best_score'] = float('-inf')
+            state["best_score"] = float("-inf")
         else:
-            state['best_score'] = float('inf')
+            state["best_score"] = float("inf")
 
         if bst is not None:
-            if bst.attr('best_score') is not None:
-                state['best_score'] = float(bst.attr('best_score'))
-                state['best_iteration'] = int(bst.attr('best_iteration'))
-                state['best_msg'] = bst.attr('best_msg')
+            if bst.attr("best_score") is not None:
+                state["best_score"] = float(bst.attr("best_score"))
+                state["best_iteration"] = int(bst.attr("best_iteration"))
+                state["best_msg"] = bst.attr("best_msg")
             else:
-                bst.set_attr(best_iteration=str(state['best_iteration']))
-                bst.set_attr(best_score=str(state['best_score']))
+                bst.set_attr(best_iteration=str(state["best_iteration"]))
+                bst.set_attr(best_score=str(state["best_score"]))
         else:
             assert env.cvfolds is not None
 
@@ -542,7 +573,7 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
         else:
             for feval in fevals:
                 bst_eval = bst.eval_set(evals, i, feval)
-                res = [x.split(':') for x in bst_eval.split()]
+                res = [x.split(":") for x in bst_eval.split()]
                 for kv in res[1:]:
                     res_dict[kv[0]] = [float(kv[1])]
 
@@ -557,14 +588,14 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
         if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0:
             infos = ["XGB iter: %3d" % i]
             for item in eval_res:
-                if 'null' in item[0]:
+                if "null" in item[0]:
                     continue
                 infos.append("%s: %.6f" % (item[0], item[1]))
 
             logger.debug("\t".join(infos))
             if log_file:
                 with open(log_file, "a") as fout:
-                    fout.write("\t".join(infos) + '\n')
+                    fout.write("\t".join(infos) + "\n")
 
         ##### choose score and do early stopping #####
         score = None
@@ -574,24 +605,23 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
                 break
         assert score is not None
 
-        best_score = state['best_score']
-        best_iteration = state['best_iteration']
-        maximize_score = state['maximize_score']
-        if (maximize_score and score > best_score) or \
-                (not maximize_score and score < best_score):
-            msg = '[%d] %s' % (
-                env.iteration,
-                '\t'.join([_fmt_metric(x) for x in eval_res]))
-            state['best_msg'] = msg
-            state['best_score'] = score
-            state['best_iteration'] = env.iteration
+        best_score = state["best_score"]
+        best_iteration = state["best_iteration"]
+        maximize_score = state["maximize_score"]
+        if (maximize_score and score > best_score) or (not maximize_score and score < best_score):
+            msg = "[%d] %s" % (env.iteration, "\t".join([_fmt_metric(x) for x in eval_res]))
+            state["best_msg"] = msg
+            state["best_score"] = score
+            state["best_iteration"] = env.iteration
             # save the property to attributes, so they will occur in checkpoint.
             if env.model is not None:
-                env.model.set_attr(best_score=str(state['best_score']),
-                                   best_iteration=str(state['best_iteration']),
-                                   best_msg=state['best_msg'])
+                env.model.set_attr(
+                    best_score=str(state["best_score"]),
+                    best_iteration=str(state["best_iteration"]),
+                    best_msg=state["best_msg"],
+                )
         elif env.iteration - best_iteration >= stopping_rounds:
-            best_msg = state['best_msg']
+            best_msg = state["best_msg"]
             if verbose_eval and env.rank == 0:
                 logger.debug("XGB stopped. Best iteration: %s ", best_msg)
             raise EarlyStopException(best_iteration)
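
A small self-contained illustration (synthetic numbers, not from the patch) of the pack-sum aggregation used throughout this file: per-row predictions are summed per pack with np.bincount, exactly as predict_throughput_pack_sum does above.

import numpy as np

raw_preds = np.array([0.2, 0.3, 0.5, 0.4])  # one prediction per feature row
pack_ids = np.array([0, 0, 1, 1])           # rows 0-1 form pack 0, rows 2-3 form pack 1

sum_pred = np.bincount(pack_ids, weights=raw_preds)
print(sum_pred)  # [0.5 0.9] -> one throughput estimate per pack
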
index e531c3d..ef42dc6 100644 (file)
@@ -45,6 +45,7 @@ DEFAULT_FEATURE_VEC_LEN = 164
 SIZE_OF_INT32 = 4
 SIZE_OF_FLOAT32 = 4
 
+
 def unpack_feature(byte_arr: bytearray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
     """Unpack the flatten feature (in byte array format) from c++
 
@@ -92,8 +93,8 @@ def unpack_feature(byte_arr: bytearray) -> Tuple[np.ndarray, np.ndarray, np.ndar
     n = struct.unpack_from("1i", byte_arr, offset=offset)[0]
     offset += SIZE_OF_INT32
 
-    sizes = struct.unpack_from("%di" % (n+2), byte_arr, offset=offset)
-    offset += SIZE_OF_INT32 * (n+2)
+    sizes = struct.unpack_from("%di" % (n + 2), byte_arr, offset=offset)
+    offset += SIZE_OF_INT32 * (n + 2)
 
     # unpack features
     features = []
@@ -117,8 +118,12 @@ def unpack_feature(byte_arr: bytearray) -> Tuple[np.ndarray, np.ndarray, np.ndar
 
             n_stmts = int(n_stmts[0] + 0.5)
             tmp_vec_len = (size - 1) // n_stmts
-            assert tmp_vec_len == vec_len, "The lenght of feature vector is wrong. " \
-                                           "Expected %d but got %d." % (vec_len, tmp_vec_len)
+            assert (
+                tmp_vec_len == vec_len
+            ), "The lenght of feature vector is wrong. " "Expected %d but got %d." % (
+                vec_len,
+                tmp_vec_len,
+            )
             assert tmp_vec_len * n_stmts == size - 1
             for _ in range(n_stmts):
                 x = struct.unpack_from("%df" % vec_len, byte_arr, offset=offset)
@@ -141,10 +146,9 @@ def unpack_feature(byte_arr: bytearray) -> Tuple[np.ndarray, np.ndarray, np.ndar
     return np.array(features, dtype=object), np.array(normalized_throughputs), np.array(task_ids)
 
 
-def get_per_store_features_from_file(filename: str,
-                                     max_lines: int,
-                                     max_n_bufs: Optional[int] = None) \
-        -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+def get_per_store_features_from_file(
+    filename: str, max_lines: int, max_n_bufs: Optional[int] = None
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
     """Get per-store features from a log file
 
     Parameters
@@ -166,15 +170,17 @@ def get_per_store_features_from_file(filename: str,
         Task ids
     """
     byte_arr = _ffi_api.GetPerStoreFeaturesFromFile(
-        filename, max_lines, max_n_bufs or DEFAULT_MAX_N_BUFS)
+        filename, max_lines, max_n_bufs or DEFAULT_MAX_N_BUFS
+    )
     return unpack_feature(byte_arr)
 
 
-def get_per_store_features_from_measure_pairs(inputs: List[MeasureInput],
-                                              results: List[MeasureResult],
-                                              skip_first_n_feature_extraction: int = 0,
-                                              max_n_bufs: Optional[int] = None) \
-        -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+def get_per_store_features_from_measure_pairs(
+    inputs: List[MeasureInput],
+    results: List[MeasureResult],
+    skip_first_n_feature_extraction: int = 0,
+    max_n_bufs: Optional[int] = None,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
     """Get per-store features from measurement input/result pairs
 
     Parameters
@@ -198,13 +204,14 @@ def get_per_store_features_from_measure_pairs(inputs: List[MeasureInput],
         Task ids
     """
     byte_arr = _ffi_api.GetPerStoreFeaturesFromMeasurePairs(
-        inputs, results, skip_first_n_feature_extraction, max_n_bufs or DEFAULT_MAX_N_BUFS)
+        inputs, results, skip_first_n_feature_extraction, max_n_bufs or DEFAULT_MAX_N_BUFS
+    )
     return unpack_feature(byte_arr)
 
 
-def get_per_store_features_from_states(states: List[Union[State, StateObject]],
-                                       task: "SearchTask",
-                                       max_n_bufs: Optional[int] = None) -> List[np.ndarray]:
+def get_per_store_features_from_states(
+    states: List[Union[State, StateObject]], task: "SearchTask", max_n_bufs: Optional[int] = None
+) -> List[np.ndarray]:
     """Get per-store features from measurement input/result pairs
 
     Parameters
@@ -230,7 +237,8 @@ def get_per_store_features_from_states(states: List[Union[State, StateObject]],
     elif isinstance(states[0], StateObject):
         state_objects = states
     byte_arr = _ffi_api.GetPerStoreFeaturesFromStates(
-        state_objects, task, max_n_bufs or DEFAULT_MAX_N_BUFS)
+        state_objects, task, max_n_bufs or DEFAULT_MAX_N_BUFS
+    )
     return unpack_feature(byte_arr)[0]
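
A hedged usage sketch for the feature-extraction helpers above; "records.json" is a made-up measurement log name:

from tvm.auto_scheduler.feature import get_per_store_features_from_file

features, throughputs, task_ids = get_per_store_features_from_file(
    "records.json", max_lines=1000
)
# `features` is a 1-D object array; each element is an (n_stmts, vec_len) float matrix,
# matching the layout documented in unpack_feature above.
print(len(features), throughputs.shape, task_ids.shape)
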
 
 
index e26c20f..897a682 100644 (file)
@@ -60,15 +60,13 @@ class Stage(Object):
 
     # Static trans table for compute_at location
     # This is used to transform the compute_at location to C++ enum
-    COMPUTE_AT_TRANS_TABLE = {
-        "root": 0,
-        "inlined": 1,
-        "iter": 2
-    }
+    COMPUTE_AT_TRANS_TABLE = {"root": 0, "inlined": 1, "iter": 2}
+
 
 @tvm._ffi.register_object("auto_scheduler.State")
 class StateObject(Object):
     """ The internal State object """
+
     def __eq__(self, other):
         return _ffi_api.StateEqual(self, other)
 
@@ -106,14 +104,14 @@ class State:
         "threadIdx.y": 8,
         "blockIdx.z": 9,
         "threadIdx.z": 10,
-        "tensorize": 11
+        "tensorize": 11,
     }
 
     def __init__(self, state_object, dag):
         self.state_object = state_object
         self.compute_dag = dag
 
-        self.stage_id_map = {}    # A dict maps operation to stage id
+        self.stage_id_map = {}  # A dict maps operation to stage id
         self._update_stage_id_map()
 
     @property
@@ -172,9 +170,12 @@ class State:
         if not thread_name in State.ANNOTATION_TRANS_TABLE.keys():
             raise ValueError("Invalid thread_name: ", thread_name)
 
-        self.state_object, res = _ffi_api.StateBind(self.state_object,
-                                                    self._resolve_stage_id(stage), iterator,
-                                                    State.ANNOTATION_TRANS_TABLE[thread_name])
+        self.state_object, res = _ffi_api.StateBind(
+            self.state_object,
+            self._resolve_stage_id(stage),
+            iterator,
+            State.ANNOTATION_TRANS_TABLE[thread_name],
+        )
         return res
 
     def parallel(self, stage, iterator):
@@ -194,8 +195,9 @@ class State:
         res_it : Iterator
             The parallelized Iterator.
         """
-        self.state_object, res = _ffi_api.StateParallel(self.state_object,
-                                                        self._resolve_stage_id(stage), iterator)
+        self.state_object, res = _ffi_api.StateParallel(
+            self.state_object, self._resolve_stage_id(stage), iterator
+        )
         return res
 
     def unroll(self, stage, iterator, max_unroll=None):
@@ -217,9 +219,12 @@ class State:
         res_it : Iterator
             The unrolled Iterator.
         """
-        self.state_object, res = _ffi_api.StateUnroll(self.state_object,
-                                                      self._resolve_stage_id(stage), iterator,
-                                                      max_unroll if max_unroll else -1)
+        self.state_object, res = _ffi_api.StateUnroll(
+            self.state_object,
+            self._resolve_stage_id(stage),
+            iterator,
+            max_unroll if max_unroll else -1,
+        )
         return res
 
     def vectorize(self, stage, iterator):
@@ -239,8 +244,9 @@ class State:
         res_it : Iterator
             The vectorized Iterator.
         """
-        self.state_object, res = _ffi_api.StateVectorize(self.state_object,
-                                                         self._resolve_stage_id(stage), iterator)
+        self.state_object, res = _ffi_api.StateVectorize(
+            self.state_object, self._resolve_stage_id(stage), iterator
+        )
         return res
 
     def fuse(self, stage, iters):
@@ -265,8 +271,9 @@ class State:
         If the iterators to be fused have stages attached at them (by compute_at), the fused
         result will become the new attach point.
         """
-        self.state_object, res = _ffi_api.StateFuse(self.state_object,
-                                                    self._resolve_stage_id(stage), iters)
+        self.state_object, res = _ffi_api.StateFuse(
+            self.state_object, self._resolve_stage_id(stage), iters
+        )
         return res
 
     def pragma(self, stage, iterator, pragma_type):
@@ -283,8 +290,9 @@ class State:
         pragma_type : str
             The pragma string.
         """
-        self.state_object = _ffi_api.StatePragma(self.state_object, self._resolve_stage_id(stage),
-                                                 iterator, pragma_type)
+        self.state_object = _ffi_api.StatePragma(
+            self.state_object, self._resolve_stage_id(stage), iterator, pragma_type
+        )
 
     def reorder(self, stage, order):
         """Schedule primitive corresponding to `te.Stage.reorder`.
@@ -298,8 +306,9 @@ class State:
         order : List[Iterator]
             Iterators in the expected order.
         """
-        self.state_object = _ffi_api.StateReorder(self.state_object, self._resolve_stage_id(stage),
-                                                  order)
+        self.state_object = _ffi_api.StateReorder(
+            self.state_object, self._resolve_stage_id(stage), order
+        )
 
     def split(self, stage, iterator, lengths, inner_to_outer=True):
         """Schedule primitive corresponding to `te.Stage.split`.
@@ -330,9 +339,9 @@ class State:
         If we split an iterator which has stages attached at it (by compute_at), the innermost
         iterator of the split results will become the new attach point.
         """
-        self.state_object, res = _ffi_api.StateSplit(self.state_object,
-                                                     self._resolve_stage_id(stage),
-                                                     iterator, lengths, inner_to_outer)
+        self.state_object, res = _ffi_api.StateSplit(
+            self.state_object, self._resolve_stage_id(stage), iterator, lengths, inner_to_outer
+        )
         return res
 
     def follow_split(self, stage, iterator, src_step_id, n_split):
@@ -366,15 +375,13 @@ class State:
             The new Iterators after the split.
         """
 
-        self.state_object, res = _ffi_api.StateFollowSplit(self.state_object,
-                                                           self._resolve_stage_id(stage),
-                                                           iterator,
-                                                           src_step_id, n_split)
+        self.state_object, res = _ffi_api.StateFollowSplit(
+            self.state_object, self._resolve_stage_id(stage), iterator, src_step_id, n_split
+        )
         return res
 
-    def follow_fused_split(self, stage, iterator, src_step_ids, level,
-                           factor_or_nparts):
-        """ Schedule primitive extends to split step.
+    def follow_fused_split(self, stage, iterator, src_step_ids, level, factor_or_nparts):
+        """Schedule primitive extends to split step.
 
         This step is used to split an iterator by the same factors
         as the given list of SplitSteps and FuseSteps.
@@ -415,11 +422,14 @@ class State:
             The new Iterators after the split.
         """
 
-        self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object,
-                                                                self._resolve_stage_id(stage),
-                                                                iterator,
-                                                                src_step_ids, level,
-                                                                factor_or_nparts)
+        self.state_object, res = _ffi_api.StateFollowFusedSplit(
+            self.state_object,
+            self._resolve_stage_id(stage),
+            iterator,
+            src_step_ids,
+            level,
+            factor_or_nparts,
+        )
         return res
 
     def storage_align(self, stage, iterator, factor, offset):
@@ -438,9 +448,9 @@ class State:
         offset : int
             The offset in the alignment specification.
         """
-        self.state_object = _ffi_api.StateStorageAlign(self.state_object,
-                                                       self._resolve_stage_id(stage), iterator,
-                                                       factor, offset)
+        self.state_object = _ffi_api.StateStorageAlign(
+            self.state_object, self._resolve_stage_id(stage), iterator, factor, offset
+        )
 
     def compute_at(self, stage, target_stage, target_iter):
         """Schedule primitive corresponding to `te.Stage.compute_at`.
@@ -464,10 +474,12 @@ class State:
         as bound for the newly created iterators.
         Call ComputeDAG::InferBound on the returned state to get the complete bound information.
         """
-        self.state_object = _ffi_api.StateComputeAt(self.state_object,
-                                                    self._resolve_stage_id(stage),
-                                                    self._resolve_stage_id(target_stage),
-                                                    target_iter)
+        self.state_object = _ffi_api.StateComputeAt(
+            self.state_object,
+            self._resolve_stage_id(stage),
+            self._resolve_stage_id(target_stage),
+            target_iter,
+        )
 
     def compute_inline(self, stage):
         """Schedule primitive corresponding to `te.Stage.compute_inline`, see also the `te.Stage`
@@ -479,8 +491,9 @@ class State:
             The Stage to be marked compute inlined, which can be specified by the integer index,
             Operation, or output tensor of the stage.
         """
-        self.state_object = _ffi_api.StateComputeInline(self.state_object,
-                                                        self._resolve_stage_id(stage))
+        self.state_object = _ffi_api.StateComputeInline(
+            self.state_object, self._resolve_stage_id(stage)
+        )
 
     def compute_root(self, stage):
         """Schedule primitive corresponding to `te.Stage.compute_root`.
@@ -499,8 +512,9 @@ class State:
         as bound for the newly created iterators.
         Call ComputeDAG::InferBound on the returned state to get the complete bound information.
         """
-        self.state_object = _ffi_api.StateComputeRoot(self.state_object,
-                                                      self._resolve_stage_id(stage))
+        self.state_object = _ffi_api.StateComputeRoot(
+            self.state_object, self._resolve_stage_id(stage)
+        )
 
     def cache_read(self, stage, scope_name, reader_stages):
         """Schedule primitive corresponding to `te.Schedule.cache_read`.
@@ -528,10 +542,13 @@ class State:
         target stage).
         """
         reader_stage_ids = [self._resolve_stage_id(i) for i in reader_stages]
-        self.state_object, new_stage_id = _ffi_api.StateCacheRead(self.state_object,
-                                                                  self._resolve_stage_id(stage),
-                                                                  scope_name, reader_stage_ids,
-                                                                  self.compute_dag)
+        self.state_object, new_stage_id = _ffi_api.StateCacheRead(
+            self.state_object,
+            self._resolve_stage_id(stage),
+            scope_name,
+            reader_stage_ids,
+            self.compute_dag,
+        )
         # Adding a new stage changes all ops behind the added stage. But we still want to keep the
         # original ops map, so apply the stage id offset to stage_id_map to keep it consistent.
         self._apply_stage_id_offset(int(new_stage_id))
@@ -561,9 +578,9 @@ class State:
         target stage).
         This step will cache write all output tensors of the target stage.
         """
-        self.state_object, new_stage_id = _ffi_api.StateCacheWrite(self.state_object,
-                                                                   self._resolve_stage_id(stage),
-                                                                   scope_name, self.compute_dag)
+        self.state_object, new_stage_id = _ffi_api.StateCacheWrite(
+            self.state_object, self._resolve_stage_id(stage), scope_name, self.compute_dag
+        )
         # Adding a new stage changes all ops behind the added stage. But we still want to keep the
         # original ops map, so apply the stage id offset to stage_id_map to keep it consistent.
         self._apply_stage_id_offset(int(new_stage_id))
@@ -594,10 +611,13 @@ class State:
         Rfactor step will insert an extra stage to the original ComputeDAG (in the front of the
         target stage).
         """
-        self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object,
-                                                                self._resolve_stage_id(stage),
-                                                                iterator, factor_iter_id,
-                                                                self.compute_dag)
+        self.state_object, new_stage_id = _ffi_api.StateRfactor(
+            self.state_object,
+            self._resolve_stage_id(stage),
+            iterator,
+            factor_iter_id,
+            self.compute_dag,
+        )
         # Adding a new stage changes all ops behind the added stage. But we still want to keep the
         # original ops map, so apply the stage id offset to stage_id_map to keep it consistent.
         self._apply_stage_id_offset(int(new_stage_id))
@@ -617,8 +637,9 @@ class State:
             return self.stage_id_map[stage_id.op]
         if isinstance(stage_id, int):
             return stage_id
-        raise ValueError("Invalid stage: " + stage_id +
-                         " . Expect to be a int, Operation or Tensor")
+        raise ValueError(
+            "Invalid stage: " + stage_id + " . Expect to be a int, Operation or Tensor"
+        )
 
     def _update_stage_id_map(self):
         for index, stage in enumerate(self.stages):
@@ -634,8 +655,7 @@ class State:
             key = key.op
         if isinstance(key, Operation):
             return self.stages[self.stage_id_map[key]]
-        raise ValueError("Invalid item: " + key +
-                         " . Expect to be a Operation or Tensor")
+        raise ValueError("Invalid item: " + key + " . Expect to be a Operation or Tensor")
 
     def __str__(self):
         return str(self.state_object)
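
A hedged sketch (toy shapes, assumed split factors) of chaining the schedule primitives above on a fresh state:

import tvm
from tvm import te, auto_scheduler

A = te.placeholder((512, 512), name="A")
B = te.compute((512, 512), lambda i, j: A[i, j] + 1, name="B")
dag = auto_scheduler.ComputeDAG([A, B])
state = dag.get_init_state()

i, j = state[B].iters             # stages are indexed by Tensor/Operation (__getitem__)
io, ii = state.split(B, i, [16])  # wraps _ffi_api.StateSplit
fused = state.fuse(B, [io, j])    # the fused iterator becomes the new attach point
state.parallel(B, fused)
print(state)                      # printed through StateObject's __str__
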
index 8c6476c..fd05b2d 100644 (file)
@@ -49,8 +49,13 @@ from tvm.contrib import tar, ndk
 
 from . import _ffi_api
 from .loop_state import StateObject
-from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, \
-    check_remote
+from .utils import (
+    get_const_tuple,
+    NoDaemonPool,
+    call_func_with_timeout,
+    request_remote,
+    check_remote,
+)
 
 # The maximum length of error message
 MAX_ERROR_MSG_LEN = 512
@@ -60,6 +65,7 @@ MAX_ERROR_MSG_LEN = 512
 GLOBAL_BUILD_ARGUMENTS = None
 GLOBAL_RUN_ARGUMENTS = None
 
+
 @tvm._ffi.register_object("auto_scheduler.MeasureCallback")
 class MeasureCallback(Object):
     """ The base class of measurement callback functions. """
@@ -67,7 +73,7 @@ class MeasureCallback(Object):
 
 @tvm._ffi.register_object("auto_scheduler.MeasureInput")
 class MeasureInput(Object):
-    """ Store the input of a measurement.
+    """Store the input of a measurement.
 
     Parameters
     ----------
@@ -76,6 +82,7 @@ class MeasureInput(Object):
     state : Union[State, StateObject]
         The State to be measured.
     """
+
     def __init__(self, task, state):
         state = state if isinstance(state, StateObject) else state.state_object
         self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state)
@@ -83,7 +90,7 @@ class MeasureInput(Object):
 
 @tvm._ffi.register_object("auto_scheduler.BuildResult")
 class BuildResult(Object):
-    """ Store the result of a build.
+    """Store the result of a build.
 
     Parameters
     ----------
@@ -98,17 +105,19 @@ class BuildResult(Object):
     time_cost : float
         The time cost of build.
     """
+
     def __init__(self, filename, args, error_no, error_msg, time_cost):
         filename = filename if filename else ""
         error_msg = error_msg if error_msg else ""
 
         self.__init_handle_by_constructor__(
-            _ffi_api.BuildResult, filename, args, error_no, error_msg, time_cost)
+            _ffi_api.BuildResult, filename, args, error_no, error_msg, time_cost
+        )
 
 
 @tvm._ffi.register_object("auto_scheduler.MeasureResult")
 class MeasureResult(Object):
-    """ Store the results of a measurement.
+    """Store the results of a measurement.
 
     Parameters
     ----------
@@ -123,12 +132,13 @@ class MeasureResult(Object):
     timestamp : float
         The time stamps of this measurement.
     """
+
     def __init__(self, costs, error_no, error_msg, all_cost, timestamp):
         error_msg = error_msg if error_msg else ""
 
         self.__init_handle_by_constructor__(
-            _ffi_api.MeasureResult, costs, error_no,
-            error_msg, all_cost, timestamp)
+            _ffi_api.MeasureResult, costs, error_no, error_msg, all_cost, timestamp
+        )
 
 
 @tvm._ffi.register_object("auto_scheduler.ProgramBuilder")
@@ -136,7 +146,7 @@ class ProgramBuilder(Object):
     """ The base class of ProgramBuilders. """
 
     def build(self, measure_inputs, verbose=1):
-        """ Build programs and return results.
+        """Build programs and return results.
 
         Parameters
         ----------
@@ -157,7 +167,7 @@ class ProgramRunner(Object):
     """ The base class of ProgramRunners. """
 
     def run(self, measure_inputs, build_results, verbose=1):
-        """ Run measurement and return results.
+        """Run measurement and return results.
 
         Parameters
         ----------
@@ -177,7 +187,7 @@ class ProgramRunner(Object):
 
 @tvm._ffi.register_object("auto_scheduler.LocalBuilder")
 class LocalBuilder(ProgramBuilder):
-    """ LocalBuilder use local CPU cores to build programs in parallel.
+    """LocalBuilder use local CPU cores to build programs in parallel.
 
     Parameters
     ----------
@@ -190,17 +200,13 @@ class LocalBuilder(ProgramBuilder):
         The name of registered build function.
     """
 
-    def __init__(self,
-                 timeout=15,
-                 n_parallel=multiprocessing.cpu_count(),
-                 build_func='default'):
-        self.__init_handle_by_constructor__(
-            _ffi_api.LocalBuilder, timeout, n_parallel, build_func)
+    def __init__(self, timeout=15, n_parallel=multiprocessing.cpu_count(), build_func="default"):
+        self.__init_handle_by_constructor__(_ffi_api.LocalBuilder, timeout, n_parallel, build_func)
 
 
 @tvm._ffi.register_object("auto_scheduler.LocalRunner")
 class LocalRunner(ProgramRunner):
-    """ LocalRunner that uses local CPU/GPU to measures the time cost of programs.
+    """LocalRunner that uses local CPU/GPU to measures the time cost of programs.
 
     Parameters
     ----------
@@ -233,21 +239,29 @@ class LocalRunner(ProgramRunner):
         This only has an effect on CPU tasks.
     """
 
-    def __init__(self,
-                 timeout=10,
-                 number=3,
-                 repeat=1,
-                 min_repeat_ms=0,
-                 cooldown_interval=0.0,
-                 enable_cpu_cache_flush=False):
+    def __init__(
+        self,
+        timeout=10,
+        number=3,
+        repeat=1,
+        min_repeat_ms=0,
+        cooldown_interval=0.0,
+        enable_cpu_cache_flush=False,
+    ):
         self.__init_handle_by_constructor__(
-            _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval,
-            enable_cpu_cache_flush)
+            _ffi_api.LocalRunner,
+            timeout,
+            number,
+            repeat,
+            min_repeat_ms,
+            cooldown_interval,
+            enable_cpu_cache_flush,
+        )
 
 
 @tvm._ffi.register_object("auto_scheduler.RPCRunner")
 class RPCRunner(ProgramRunner):
-    """ RPCRunner that uses RPC call to measures the time cost of programs on remote devices.
+    """RPCRunner that uses RPC call to measures the time cost of programs on remote devices.
     Or sometime we may need to use RPC even in local running to insulate the thread environment.
     (e.g. running CUDA programs)
 
@@ -292,24 +306,48 @@ class RPCRunner(ProgramRunner):
         This only has an effect on CPU tasks.
     """
 
-    def __init__(self, key, host, port,
-                 priority=1, n_parallel=1, timeout=10, number=3, repeat=1,
-                 min_repeat_ms=0, cooldown_interval=0.0, enable_cpu_cache_flush=False):
+    def __init__(
+        self,
+        key,
+        host,
+        port,
+        priority=1,
+        n_parallel=1,
+        timeout=10,
+        number=3,
+        repeat=1,
+        min_repeat_ms=0,
+        cooldown_interval=0.0,
+        enable_cpu_cache_flush=False,
+    ):
         self.__init_handle_by_constructor__(
-            _ffi_api.RPCRunner, key, host, port, priority, n_parallel, timeout,
-            number, repeat, min_repeat_ms, cooldown_interval, enable_cpu_cache_flush)
+            _ffi_api.RPCRunner,
+            key,
+            host,
+            port,
+            priority,
+            n_parallel,
+            timeout,
+            number,
+            repeat,
+            min_repeat_ms,
+            cooldown_interval,
+            enable_cpu_cache_flush,
+        )
 
         if check_remote(key, host, port, priority, timeout):
             print("Get devices for measurement successfully!")
         else:
-            raise RuntimeError("Cannot get remote devices from the tracker. "
-                               "Please check the status of tracker by "
-                               "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
-                               "and make sure you have free devices on the queue status.")
+            raise RuntimeError(
+                "Cannot get remote devices from the tracker. "
+                "Please check the status of tracker by "
+                "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
+                "and make sure you have free devices on the queue status."
+            )
 
 
 class LocalRPCMeasureContext:
-    """ A context wrapper for running RPCRunner locally.
+    """A context wrapper for running RPCRunner locally.
     This will launch a local RPC Tracker and local RPC Server.
 
     Parameters
@@ -347,21 +385,46 @@ class LocalRPCMeasureContext:
         This only has an effect on CPU tasks.
     """
 
-    def __init__(self, priority=1, n_parallel=1, timeout=10, number=3, repeat=1,
-                 min_repeat_ms=0, cooldown_interval=0.0, enable_cpu_cache_flush=False):
+    def __init__(
+        self,
+        priority=1,
+        n_parallel=1,
+        timeout=10,
+        number=3,
+        repeat=1,
+        min_repeat_ms=0,
+        cooldown_interval=0.0,
+        enable_cpu_cache_flush=False,
+    ):
         ctx = tvm.context("cuda", 0)
         if ctx.exist:
-            cuda_arch = "sm_" + "".join(ctx.compute_version.split('.'))
+            cuda_arch = "sm_" + "".join(ctx.compute_version.split("."))
             set_cuda_target_arch(cuda_arch)
-        host = '0.0.0.0'
+        host = "0.0.0.0"
         self.tracker = Tracker(host, port=9000, port_end=10000, silent=True)
-        device_key = '$local$device$%d' % self.tracker.port
-        self.server = Server(host, port=self.tracker.port, port_end=10000,
-                             key=device_key, use_popen=True, silent=True,
-                             tracker_addr=(self.tracker.host, self.tracker.port))
-        self.runner = RPCRunner(device_key, host, self.tracker.port, priority,
-                                n_parallel, timeout, number, repeat,
-                                min_repeat_ms, cooldown_interval, enable_cpu_cache_flush)
+        device_key = "$local$device$%d" % self.tracker.port
+        self.server = Server(
+            host,
+            port=self.tracker.port,
+            port_end=10000,
+            key=device_key,
+            use_popen=True,
+            silent=True,
+            tracker_addr=(self.tracker.host, self.tracker.port),
+        )
+        self.runner = RPCRunner(
+            device_key,
+            host,
+            self.tracker.port,
+            priority,
+            n_parallel,
+            timeout,
+            number,
+            repeat,
+            min_repeat_ms,
+            cooldown_interval,
+            enable_cpu_cache_flush,
+        )
         # Wait for the processes to start
         time.sleep(0.5)
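
A hedged sketch of wiring the measurement classes above into TuningOptions; the log file name is illustrative, and using the context's runner (as its docstring suggests for CUDA-style measurement) is how this class is typically consumed:

from tvm import auto_scheduler

measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_opt = auto_scheduler.TuningOptions(
    num_measure_trials=10,
    builder=auto_scheduler.LocalBuilder(timeout=15),
    runner=measure_ctx.runner,  # the RPCRunner created in __init__ above
    measure_callbacks=[auto_scheduler.RecordToFile("measure_log.json")],
)
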
 
@@ -373,24 +436,26 @@ class LocalRPCMeasureContext:
 
 class MeasureErrorNo(object):
     """ Error type for MeasureResult. """
-    NO_ERROR = 0              # No error
-    INSTANTIATION_ERROR = 1   # Errors happen when apply transform steps from init state
-    COMPILE_HOST = 2          # Errors happen when compiling code on host (e.g., tvm.build)
-    COMPILE_DEVICE = 3        # Errors happen when compiling code on device
-                              # (e.g. OpenCL JIT on the device)
-    RUNTIME_DEVICE = 4        # Errors happen when run program on device
-    WRONG_ANSWER = 5          # Answer is wrong when compared to a reference output
-    BUILD_TIMEOUT = 6         # Timeout during compilation
-    RUN_TIMEOUT = 7           # Timeout during run
-    UNKNOWN_ERROR = 8         # Unknown error
+
+    NO_ERROR = 0  # No error
+    INSTANTIATION_ERROR = 1  # Errors happen when applying transform steps from the init state
+    COMPILE_HOST = 2  # Errors happen when compiling code on host (e.g., tvm.build)
+    COMPILE_DEVICE = 3  # Errors happen when compiling code on device
+    # (e.g. OpenCL JIT on the device)
+    RUNTIME_DEVICE = 4  # Errors happen when running the program on the device
+    WRONG_ANSWER = 5  # Answer is wrong when compared to a reference output
+    BUILD_TIMEOUT = 6  # Timeout during compilation
+    RUN_TIMEOUT = 7  # Timeout during run
+    UNKNOWN_ERROR = 8  # Unknown error
 
 
 def make_error_msg():
     """ Get the error message from traceback. """
     error_msg = str(traceback.format_exc())
     if len(error_msg) > MAX_ERROR_MSG_LEN:
-        error_msg = error_msg[:MAX_ERROR_MSG_LEN//2] + \
-            "\n...\n" + error_msg[-MAX_ERROR_MSG_LEN//2:]
+        error_msg = (
+            error_msg[: MAX_ERROR_MSG_LEN // 2] + "\n...\n" + error_msg[-MAX_ERROR_MSG_LEN // 2 :]
+        )
     return error_msg
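
A synthetic illustration (toy string, not from the patch) of the truncation rule in make_error_msg above: long tracebacks keep only the first and last MAX_ERROR_MSG_LEN // 2 characters around a "..." separator.

MAX_ERROR_MSG_LEN = 512

error_msg = "x" * 2000  # stand-in for a long traceback
if len(error_msg) > MAX_ERROR_MSG_LEN:
    error_msg = (
        error_msg[: MAX_ERROR_MSG_LEN // 2] + "\n...\n" + error_msg[-MAX_ERROR_MSG_LEN // 2 :]
    )
print(len(error_msg))  # 517: 256 + len("\n...\n") + 256
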
 
 
@@ -417,9 +482,9 @@ def local_build_worker(index):
     measure_inputs, build_func, timeout, verbose = GLOBAL_BUILD_ARGUMENTS
     assert isinstance(build_func, str)
 
-    if build_func == 'default':
+    if build_func == "default":
         build_func = tar.tar
-    elif build_func == 'ndk':
+    elif build_func == "ndk":
         build_func = ndk.create_shared
     else:
         raise ValueError("Invalid build_func" + build_func)
@@ -434,8 +499,7 @@ def local_build_worker(index):
         args = []
 
         try:
-            sch, args = task.compute_dag.apply_steps_from_state(
-                inp.state, layout_rewrite=True)
+            sch, args = task.compute_dag.apply_steps_from_state(inp.state, layout_rewrite=True)
         # pylint: disable=broad-except
         except Exception:
             error_no = MeasureErrorNo.INSTANTIATION_ERROR
@@ -443,14 +507,14 @@ def local_build_worker(index):
 
         if error_no == 0:
             dirname = tempfile.mkdtemp()
-            filename = os.path.join(
-                dirname, "tmp_func." + build_func.output_format)
+            filename = os.path.join(dirname, "tmp_func." + build_func.output_format)
 
             try:
                 # TODO(merrymercy): Port the unroll pass.
                 with transform.PassContext():
                     func = build_module.build(
-                        sch, args, target=task.target, target_host=task.target_host)
+                        sch, args, target=task.target, target_host=task.target_host
+                    )
                 func.export_library(filename, build_func)
             # pylint: disable=broad-except
             except Exception:
@@ -476,7 +540,7 @@ def local_build_worker(index):
 
 
 @tvm._ffi.register_func("auto_scheduler.local_builder.build")
-def local_builder_build(inputs, timeout, n_parallel, build_func='default', verbose=1):
+def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbose=1):
     """
     Build function of LocalBuilder to build the MeasureInputs to runnable modules.
 
@@ -519,9 +583,17 @@ def local_builder_build(inputs, timeout, n_parallel, build_func='default', verbo
 
 
 @tvm._ffi.register_func("auto_scheduler.local_runner.run")
-def local_run(inputs, build_results,
-              timeout=10, number=3, repeat=1, min_repeat_ms=0, cooldown_interval=0,
-              enable_cpu_cache_flush=False, verbose=1):
+def local_run(
+    inputs,
+    build_results,
+    timeout=10,
+    number=3,
+    repeat=1,
+    min_repeat_ms=0,
+    cooldown_interval=0,
+    enable_cpu_cache_flush=False,
+    verbose=1,
+):
     """
     Run function of LocalRunner to test the performance of the input BuildResults.
 
@@ -580,10 +652,15 @@ def local_run(inputs, build_results,
             # under the std::function. We could lift the restriction later once we fold
             # the PackedFunc as an object. Currently, we pass function name to work
             # around it.
-            f_prepare = 'cache_flush_cpu_non_first_arg' if enable_cpu_cache_flush else ''
+            f_prepare = "cache_flush_cpu_non_first_arg" if enable_cpu_cache_flush else ""
             time_f = func.time_evaluator(
-                func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms,
-                f_preproc=f_prepare)
+                func.entry_name,
+                ctx,
+                number=number,
+                repeat=repeat,
+                min_repeat_ms=min_repeat_ms,
+                f_preproc=f_prepare,
+            )
         # pylint: disable=broad-except
         except Exception:
             costs = (max_float,)
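The comment block earlier in this hunk explains why the CPU cache-flush routine is passed by name instead of as a PackedFunc. Collected into one place, the timing call looks like this (a sketch using the same local names as the surrounding code: func is the built and loaded module, ctx the target context):

f_preproc = "cache_flush_cpu_non_first_arg" if enable_cpu_cache_flush else ""
time_f = func.time_evaluator(
    func.entry_name,
    ctx,
    number=number,
    repeat=repeat,
    min_repeat_ms=min_repeat_ms,
    f_preproc=f_preproc,           # looked up by name on the C++ side before each measurement
)
costs = time_f(*args).results      # tuple of per-repeat latencies in seconds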
@@ -592,8 +669,9 @@ def local_run(inputs, build_results,
 
         if error_no == 0:
             try:
-                args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in
-                        build_res.args]
+                args = [
+                    ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args
+                ]
                 random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True)
                 assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake"
                 for arg in args:
@@ -618,20 +696,28 @@ def local_run(inputs, build_results,
         return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc
 
     measure_results = []
-    assert len(inputs) == len(build_results), \
-        "Measure input size should be equal to build results"
+    assert len(inputs) == len(build_results), "Measure input size should be equal to build results"
     for inp, build_res in zip(inputs, build_results):
         if build_res.error_no != 0:
-            res = (max_float,), build_res.error_no, build_res.error_msg, build_res.time_cost, \
-                time.time()
+            res = (
+                (max_float,),
+                build_res.error_no,
+                build_res.error_msg,
+                build_res.time_cost,
+                time.time(),
+            )
         else:
-            res = call_func_with_timeout(
-                timeout, timed_func, args=(inp, build_res))
+            res = call_func_with_timeout(timeout, timed_func, args=(inp, build_res))
             if isinstance(res, TimeoutError):
                 if verbose >= 1:
                     print("*T", end="")  # Run timeout
-                res = (max_float,), MeasureErrorNo.RUN_TIMEOUT, None, \
-                    build_res.time_cost + timeout, time.time()
+                res = (
+                    (max_float,),
+                    MeasureErrorNo.RUN_TIMEOUT,
+                    None,
+                    build_res.time_cost + timeout,
+                    time.time(),
+                )
         measure_results.append(MeasureResult(*res))
 
     if verbose >= 1:
@@ -641,7 +727,7 @@ def local_run(inputs, build_results,
 
 
 def rpc_run_worker(index):
-    """ Function to be ran in the RPCRunner thread pool.
+    """Function to be ran in the RPCRunner thread pool.
 
     Parameters
     ----------
@@ -654,17 +740,34 @@ def rpc_run_worker(index):
         The measure result of this Runner thread.
     """
     global GLOBAL_RUN_ARGUMENTS
-    inputs, build_results, key, host, port, priority, timeout, number, \
-        repeat, min_repeat_ms, cooldown_interval, enable_cpu_cache_flush, \
-        verbose = GLOBAL_RUN_ARGUMENTS
+    (
+        inputs,
+        build_results,
+        key,
+        host,
+        port,
+        priority,
+        timeout,
+        number,
+        repeat,
+        min_repeat_ms,
+        cooldown_interval,
+        enable_cpu_cache_flush,
+        verbose,
+    ) = GLOBAL_RUN_ARGUMENTS
 
     max_float = 1e10  # We use 1e10 instead of sys.float_info.max for better readability in log
     inp = inputs[index]
     build_res = build_results[index]
 
     if build_res.error_no != MeasureErrorNo.NO_ERROR:
-        return (max_float,), build_res.error_no, build_res.error_msg, build_res.time_cost, \
-            time.time()
+        return (
+            (max_float,),
+            build_res.error_no,
+            build_res.error_msg,
+            build_res.time_cost,
+            time.time(),
+        )
 
     def timed_func():
         tic = time.time()
@@ -681,10 +784,15 @@ def rpc_run_worker(index):
             # under the std::function. We could lift the restriction later once we fold
             # the PackedFunc as an object. Currently, we pass function name to work
             # around it.
-            f_prepare = 'cache_flush_cpu_non_first_arg' if enable_cpu_cache_flush else ''
+            f_prepare = "cache_flush_cpu_non_first_arg" if enable_cpu_cache_flush else ""
             time_f = func.time_evaluator(
-                func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms,
-                f_preproc=f_prepare)
+                func.entry_name,
+                ctx,
+                number=number,
+                repeat=repeat,
+                min_repeat_ms=min_repeat_ms,
+                f_preproc=f_prepare,
+            )
         # pylint: disable=broad-except
         except Exception:
             costs = (max_float,)
@@ -693,13 +801,16 @@ def rpc_run_worker(index):
 
         if error_no == 0:
             try:
-                args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in
-                        build_res.args]
+                args = [
+                    ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args
+                ]
                 try:
                     random_fill = remote.get_function("tvm.contrib.random.random_fill")
                 except AttributeError:
-                    raise AttributeError("Please make sure USE_RANDOM is ON in the config.cmake "
-                                         "on the remote devices")
+                    raise AttributeError(
+                        "Please make sure USE_RANDOM is ON in the config.cmake "
+                        "on the remote devices"
+                    )
                 for arg in args:
                     random_fill(arg)
                 ctx.sync()
@@ -707,8 +818,8 @@ def rpc_run_worker(index):
                 costs = time_f(*args).results
                 # clean up remote files
                 remote.remove(build_res.filename)
-                remote.remove(os.path.splitext(build_res.filename)[0] + '.so')
-                remote.remove('')
+                remote.remove(os.path.splitext(build_res.filename)[0] + ".so")
+                remote.remove("")
             # pylint: disable=broad-except
             except Exception:
                 costs = (max_float,)
@@ -732,16 +843,34 @@ def rpc_run_worker(index):
     if isinstance(res, TimeoutError):
         if verbose >= 1:
             print("*T", end="")  # Run timeout
-        res = (max_float,), MeasureErrorNo.RUN_TIMEOUT, None, build_res.time_cost + \
-            timeout, time.time()
+        res = (
+            (max_float,),
+            MeasureErrorNo.RUN_TIMEOUT,
+            None,
+            build_res.time_cost + timeout,
+            time.time(),
+        )
     return res
 
 
 @tvm._ffi.register_func("auto_scheduler.rpc_runner.run")
-def rpc_runner_run(inputs, build_results, key, host, port,
-                   priority=1, n_parallel=1, timeout=10, number=3, repeat=1, min_repeat_ms=0,
-                   cooldown_interval=0.0, enable_cpu_cache_flush=False, verbose=1):
-    """ Run function of RPCRunner to test the performance of the input BuildResults.
+def rpc_runner_run(
+    inputs,
+    build_results,
+    key,
+    host,
+    port,
+    priority=1,
+    n_parallel=1,
+    timeout=10,
+    number=3,
+    repeat=1,
+    min_repeat_ms=0,
+    cooldown_interval=0.0,
+    enable_cpu_cache_flush=False,
+    verbose=1,
+):
+    """Run function of RPCRunner to test the performance of the input BuildResults.
 
     Parameters
     ----------
@@ -795,12 +924,23 @@ def rpc_runner_run(inputs, build_results, key, host, port,
         The measure results of these MeasureInputs.
     """
     global GLOBAL_RUN_ARGUMENTS
-    GLOBAL_RUN_ARGUMENTS = (inputs, build_results, key, host, port, priority, timeout, number,
-                            repeat, min_repeat_ms, cooldown_interval, enable_cpu_cache_flush,
-                            verbose)
-
-    assert len(inputs) == len(build_results), \
-        "Measure input size should be equal to build results"
+    GLOBAL_RUN_ARGUMENTS = (
+        inputs,
+        build_results,
+        key,
+        host,
+        port,
+        priority,
+        timeout,
+        number,
+        repeat,
+        min_repeat_ms,
+        cooldown_interval,
+        enable_cpu_cache_flush,
+        verbose,
+    )
+
+    assert len(inputs) == len(build_results), "Measure input size should be equal to build results"
     pool = NoDaemonPool(n_parallel)
     tuple_res = pool.map(rpc_run_worker, range(len(build_results)))
     pool.terminate()
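rpc_runner_run fans work out by stashing everything in GLOBAL_RUN_ARGUMENTS and mapping worker indices over a NoDaemonPool (the helper defined further down in this commit). The same pattern in miniature, with illustrative names, relying on the fork start method so child processes see the module-level global:

GLOBAL_ARGS = None

def _worker(index):
    items, factor = GLOBAL_ARGS          # workers read the shared arguments by index
    return items[index] * factor

def run_all(items, factor, n_parallel=4):
    global GLOBAL_ARGS
    GLOBAL_ARGS = (items, factor)        # pass bulky arguments once instead of per task
    pool = NoDaemonPool(n_parallel)
    try:
        return pool.map(_worker, range(len(items)))
    finally:
        pool.terminate()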
index dd40f21..b9633d5 100644 (file)
@@ -35,6 +35,7 @@ class RecordToFile(MeasureCallback):
     filename : str
         File name for this callback to write log to.
     """
+
     def __init__(self, filename="auto_scheduler_tuning.json"):
         self.__init_handle_by_constructor__(_ffi_api.RecordToFile, filename)
 
@@ -49,11 +50,12 @@ class RecordReader(Object):
     filename : str = "auto_scheduler_tuning.json"
         File name for this reader to load log from.
     """
+
     def __init__(self, filename="auto_scheduler_tuning.json"):
         self.__init_handle_by_constructor__(_ffi_api.RecordReader, filename)
 
     def read_lines(self, max_lines=None, skip_lines=0):
-        """ Read multiple lines from the log file.
+        """Read multiple lines from the log file.
 
         Parameters
         ----------
@@ -69,8 +71,9 @@ class RecordReader(Object):
         results : List[MeasureResult]
             The MeasureResults loaded from the log file.
         """
-        inputs, results = _ffi_api.RecordReaderReadLines(self, max_lines if max_lines else -1,
-                                                         skip_lines)
+        inputs, results = _ffi_api.RecordReaderReadLines(
+            self, max_lines if max_lines else -1, skip_lines
+        )
         return inputs, results
 
     def __iter__(self):
@@ -112,8 +115,9 @@ def save_records(filename, inputs, results):
     """
     _ffi_api.SaveRecords(filename, inputs, results)
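A short usage sketch tying this file's pieces together (file names are arbitrary):

reader = RecordReader("auto_scheduler_tuning.json")
inputs, results = reader.read_lines(max_lines=100)   # at most the first 100 records
save_records("backup.json", inputs, results)         # re-serialize them to another log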
 
+
 def load_best(filename, workload_key=None, target=None):
-    """ Return the best measurement pair form a log file. This may return none results if
+    """Return the best measurement pair form a log file. This may return none results if
     there is no legal measure pair with the specified workload_key/target found from the log file.
 
     Parameters
index e2bfca3..15d84dc 100644 (file)
@@ -42,7 +42,7 @@ class SearchCallback(Object):
 
 @tvm._ffi.register_object("auto_scheduler.PreloadMeasuredStates")
 class PreloadMeasuredStates(SearchCallback):
-    """ A SearchCallback to load measured states from the log file for a search policy.
+    """A SearchCallback to load measured states from the log file for a search policy.
 
     This can resume the state of the search policy:
         - Making sure an already measured state in former searches will never be measured again.
@@ -54,6 +54,7 @@ class PreloadMeasuredStates(SearchCallback):
     filename : str
         The name of the record file.
     """
+
     def __init__(self, filename="auto_scheduler_tuning.json"):
         self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename)
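A hedged usage sketch, assuming task is an auto_scheduler SearchTask; the callback is handed to a policy through init_search_callbacks, which the constructors below accept:

callbacks = [PreloadMeasuredStates("auto_scheduler_tuning.json")]
policy = EmptyPolicy(task, init_search_callbacks=callbacks)  # resume from prior measurements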
 
@@ -65,7 +66,7 @@ class SearchPolicy(Object):
 
 @tvm._ffi.register_object("auto_scheduler.EmptyPolicy")
 class EmptyPolicy(SearchPolicy):
-    """ This is an example empty search policy which will always generate
+    """This is an example empty search policy which will always generate
     the init state of ComputeDAG.
 
     Parameters
@@ -75,13 +76,14 @@ class EmptyPolicy(SearchPolicy):
     init_search_callbacks : Optional[List[SearchCallback]]
         Callback functions called before the search process.
     """
+
     def __init__(self, task, init_search_callbacks=None):
         self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy, task, init_search_callbacks)
 
 
 @tvm._ffi.register_object("auto_scheduler.SketchPolicy")
 class SketchPolicy(SearchPolicy):
-    """  The search policy that searches in a hierarchical search space defined by sketches.
+    """The search policy that searches in a hierarchical search space defined by sketches.
     The policy randomly samples programs from the space defined by sketches and use evolutionary
     search to fine-tune them.
 
@@ -111,25 +113,28 @@ class SketchPolicy(SearchPolicy):
     DEFAULT_PARAMS = {
         "eps_greedy": 0.05,
         "retry_search_one_round_on_empty": 10,
-
-        'evolutionary_search_population': 2048,
-        'evolutionary_search_num_iters': 10,
-        'evolutionary_search_mutation_prob': 0.85,
+        "evolutionary_search_population": 2048,
+        "evolutionary_search_num_iters": 10,
+        "evolutionary_search_mutation_prob": 0.85,
         "evolutionary_search_use_measured_ratio": 0.2,
-
-        'cpu_multi_level_tiling_structure': 'SSRSRS',
-        'gpu_multi_level_tiling_structure': 'SSSRRSRS',
+        "cpu_multi_level_tiling_structure": "SSRSRS",
+        "gpu_multi_level_tiling_structure": "SSSRRSRS",
         # Notice: the default thread bind policy of GPU assumes the tiling structure to have at
         # least 3 spatial tiling levels in outermost
-
-        'max_innermost_split_factor': 16,
-        'max_vectorize_size': 16,
-
-        'disable_change_compute_location': 0,
+        "max_innermost_split_factor": 16,
+        "max_vectorize_size": 16,
+        "disable_change_compute_location": 0,
     }
 
-    def __init__(self, task, schedule_cost_model=RandomModel(), params=None, seed=None, verbose=1,
-                 init_search_callbacks=None):
+    def __init__(
+        self,
+        task,
+        schedule_cost_model=RandomModel(),
+        params=None,
+        seed=None,
+        verbose=1,
+        init_search_callbacks=None,
+    ):
         if params is None:
             params = SketchPolicy.DEFAULT_PARAMS
         else:
@@ -138,11 +143,17 @@ class SketchPolicy(SearchPolicy):
                     params[key] = value
 
         self.__init_handle_by_constructor__(
-            _ffi_api.SketchPolicy, task, schedule_cost_model, params,
-            seed or random.randint(1, 1 << 30), verbose, init_search_callbacks)
+            _ffi_api.SketchPolicy,
+            task,
+            schedule_cost_model,
+            params,
+            seed or random.randint(1, 1 << 30),
+            verbose,
+            init_search_callbacks,
+        )
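Because the constructor overlays the user-supplied params on DEFAULT_PARAMS, only the keys being changed need to be passed. A sketch, again assuming task is a SearchTask:

policy = SketchPolicy(
    task,
    params={"evolutionary_search_population": 1024, "max_vectorize_size": 8},
    seed=42,
    verbose=0,
)
sketches = policy.generate_sketches()   # debugging hook described just below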
 
     def generate_sketches(self, print_for_debug=False):
-        """ Generate the sketches.
+        """Generate the sketches.
         This python interface is mainly used for debugging and testing.
         The actual search is all done in c++.
 
index f5b53fb..bbc2d77 100644 (file)
@@ -50,7 +50,7 @@ def get_func_name(func):
     name: str
         The function name.
     """
-    return func.func_name if hasattr(func, 'func_name') else func.__qualname__
+    return func.func_name if hasattr(func, "func_name") else func.__qualname__
 
 
 def get_const_int(exp):
@@ -92,7 +92,6 @@ def get_const_tuple(in_tuple):
     return tuple(get_const_int(x) for x in in_tuple)
 
 
-
 def list_to_tuple(x):
     """ Convert a list to a tuple recursively. """
     assert isinstance(x, list)
@@ -107,7 +106,7 @@ def serialize_args(args):
     ret = []
     for t in args:
         if isinstance(t, Tensor):
-            t = ('TENSOR', get_const_tuple(t.shape), t.dtype)
+            t = ("TENSOR", get_const_tuple(t.shape), t.dtype)
         elif isinstance(t, list):
             t = list_to_tuple(t)
 
@@ -121,7 +120,7 @@ def deserialize_args(args):
     """The inverse function of :code:`serialize_args`"""
     ret = []
     for t in args:
-        if isinstance(t, (tuple, list)) and t[0] == 'TENSOR':
+        if isinstance(t, (tuple, list)) and t[0] == "TENSOR":
             ret.append(placeholder(shape=t[1], dtype=t[2]))
         else:
             ret.append(t)
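A round-trip sketch of the two helpers above; the placeholder only stands in for a real te.Tensor argument:

from tvm import te

data = te.placeholder((1, 3, 224, 224), name="data", dtype="float32")
packed = serialize_args([data, 3, "float32"])
# packed == [("TENSOR", (1, 3, 224, 224), "float32"), 3, "float32"]
restored = deserialize_args(packed)   # the tensor comes back as a fresh placeholder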
@@ -147,7 +146,7 @@ class NoDaemonPool(multiprocessing.pool.Pool):
     This allows us to start new processings inside the worker function"""
 
     def __init__(self, *args, **kwargs):
-        kwargs['context'] = NoDaemonContext()
+        kwargs["context"] = NoDaemonContext()
         super().__init__(*args, **kwargs)
 
     def __reduce__(self):
@@ -170,6 +169,7 @@ def kill_child_processes(parent_pid, sig=signal.SIGTERM):
 
 def call_func_with_timeout(timeout, func, args=(), kwargs=None):
     """Call a function with timeout"""
+
     def func_wrapper(que):
         if kwargs:
             que.put(func(*args, **kwargs))
@@ -199,7 +199,7 @@ def call_func_with_timeout(timeout, func, args=(), kwargs=None):
 
 
 def request_remote(device_key, host=None, port=None, priority=1, timeout=60):
-    """ Request a remote session.
+    """Request a remote session.
 
     Parameters
     ----------
@@ -222,12 +222,11 @@ def request_remote(device_key, host=None, port=None, priority=1, timeout=60):
         The connected remote RPCSession.
     """
     # connect to the tracker
-    host = host or os.environ['TVM_TRACKER_HOST']
-    port = port or int(os.environ['TVM_TRACKER_PORT'])
+    host = host or os.environ["TVM_TRACKER_HOST"]
+    port = port or int(os.environ["TVM_TRACKER_PORT"])
 
     tracker = rpc.connect_tracker(host, port)
-    remote = tracker.request(device_key, priority=priority,
-                             session_timeout=timeout)
+    remote = tracker.request(device_key, priority=priority, session_timeout=timeout)
     return remote
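Usage sketch ("android" is only an example device key): with TVM_TRACKER_HOST and TVM_TRACKER_PORT exported, host and port can be omitted:

remote = request_remote("android", priority=1, timeout=60)
dev = remote.cpu(0)   # any remote context can now be created from the session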
 
 
@@ -259,7 +258,9 @@ def check_remote(device_key, host=None, port=None, priority=100, timeout=10):
     def _check():
         request_remote(device_key, host, port, priority)
 
-    t = threading.Thread(target=_check, )
+    t = threading.Thread(
+        target=_check,
+    )
     t.start()
     t.join(timeout)
     return not t.is_alive()
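check_remote runs request_remote on a background thread joined with a timeout, so an unreachable tracker or missing device only costs `timeout` seconds; a typical guard (device key again illustrative):

if not check_remote("android", timeout=10):
    raise RuntimeError("Cannot reach any matching device through the RPC tracker.")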
index 045720a..6c3b4d1 100644 (file)
@@ -39,7 +39,7 @@ WORKLOAD_FUNC_REGISTRY = {}
 
 
 def register_workload(func_name, f=None, override=False):
-    """ Register a function that generates a certain workload.
+    """Register a function that generates a certain workload.
 
     The input function should take hashable and jsonable arguments
     (int, float, tuple of int, tvm.tensor.Tensor, ...) and return a list of tvm.tensor.Tensor.
@@ -74,16 +74,17 @@ def register_workload(func_name, f=None, override=False):
     def register(myf):
         """internal register function"""
         if func_name in WORKLOAD_FUNC_REGISTRY and not override:
-            raise RuntimeError('%s has been registered already' % func_name)
+            raise RuntimeError("%s has been registered already" % func_name)
         WORKLOAD_FUNC_REGISTRY[func_name] = myf
         return myf
+
     if f:
         return register(f)
     return register
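The registry is normally filled through the decorator form of register_workload defined above; a sketch with an illustrative workload name and shapes:

from tvm import te

@register_workload("example_matmul")
def example_matmul(N, M, K):
    A = te.placeholder((N, K), name="A")
    B = te.placeholder((K, M), name="B")
    k = te.reduce_axis((0, K), name="k")
    C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
    return [A, B, C]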
 
 
 def make_workload_key(func, args):
-    """ Make a workload key by function and arguments.
+    """Make a workload key by function and arguments.
 
     Parameters
     ----------
@@ -105,12 +106,17 @@ def make_workload_key(func, args):
     elif isinstance(func, str):
         func_name = func
     else:
-        raise ValueError("Invalid function: " + str(func) +
-                         " . `make_workload_key` expects a callable function or its function name")
+        raise ValueError(
+            "Invalid function: "
+            + str(func)
+            + " . `make_workload_key` expects a callable function or its function name"
+        )
 
     if not func_name in WORKLOAD_FUNC_REGISTRY:
-        raise ValueError("%s is not registered. "  % func,
-                         "Please register it with @auto_scheduler.register_workload")
+        raise ValueError(
+            "%s is not registered. " % func_name
+            + "Please register it with @auto_scheduler.register_workload"
+        )
 
     args = serialize_args(args)
 
@@ -118,7 +124,7 @@ def make_workload_key(func, args):
 
 
 def decode_workload_key_to_func_args(workload_key):
-    """ Decode a workload key to the registerd function name and its corresponding args.
+    """Decode a workload key to the registerd function name and its corresponding args.
 
     Parameters
     ----------
@@ -136,14 +142,16 @@ def decode_workload_key_to_func_args(workload_key):
 
     workload = json.loads(workload_key)
     if not workload[0] in WORKLOAD_FUNC_REGISTRY:
-        raise ValueError("%s is not registered. " % workload[0] +
-                         "Please register it with @auto_scheduler.register_workload")
+        raise ValueError(
+            "%s is not registered. " % workload[0]
+            + "Please register it with @auto_scheduler.register_workload"
+        )
     return workload[0], deserialize_args(workload[1:])
 
 
 @tvm._ffi.register_func("auto_scheduler.workload_key_to_tensors")
 def workload_key_to_tensors(workload_key):
-    """ Get the input/output tensors from the workload key.
+    """Get the input/output tensors from the workload key.
 
     This method is usually used to create a ComputeDAG by workload key.
 
@@ -166,7 +174,7 @@ def workload_key_to_tensors(workload_key):
 
 
 def save_workload_func_registry(filename):
-    """ Dump workload function registry to a pickle binary file.
+    """Dump workload function registry to a pickle binary file.
 
     Parameters
     ----------
@@ -175,11 +183,11 @@ def save_workload_func_registry(filename):
     """
     global WORKLOAD_FUNC_REGISTRY
 
-    pickle.dump(WORKLOAD_FUNC_REGISTRY, open(filename, 'wb'))
+    pickle.dump(WORKLOAD_FUNC_REGISTRY, open(filename, "wb"))
 
 
 def load_workload_func_registry(filename):
-    """ Load workload function registry from a pickle binary file.
+    """Load workload function registry from a pickle binary file.
 
     Parameters
     ----------
@@ -188,4 +196,4 @@ def load_workload_func_registry(filename):
     """
     global WORKLOAD_FUNC_REGISTRY
 
-    WORKLOAD_FUNC_REGISTRY = pickle.load(open(filename, 'rb'))
+    WORKLOAD_FUNC_REGISTRY = pickle.load(open(filename, "rb"))
index 6b5fafc..7eb1c8b 100644 (file)
@@ -38,11 +38,27 @@ from . import env
 from . import tophub
 
 # some shortcuts
-from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo, \
-    LocalBuilder, LocalRunner, RPCRunner
+from .measure import (
+    measure_option,
+    MeasureInput,
+    MeasureResult,
+    MeasureErrorNo,
+    LocalBuilder,
+    LocalRunner,
+    RPCRunner,
+)
 from .tuner import callback
-from .task import get_config, create, ConfigSpace, ConfigEntity, \
-    register_topi_compute, register_topi_schedule, template, \
-    DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best, \
-    ApplyGraphBest as apply_graph_best
+from .task import (
+    get_config,
+    create,
+    ConfigSpace,
+    ConfigEntity,
+    register_topi_compute,
+    register_topi_schedule,
+    template,
+    DispatchContext,
+    FallbackContext,
+    ApplyHistoryBest as apply_history_best,
+    ApplyGraphBest as apply_graph_best,
+)
 from .env import GLOBAL_SCOPE
index 963f7e5..6bb02e8 100644 (file)
@@ -28,6 +28,7 @@ class Database(object):
     """
     Base class for a record database object.
     """
+
     def load(self, inp, get_all=False):
         """
         Load a result based on an input's string key
@@ -92,13 +93,15 @@ def filter_inputs(db, measure_inputs, retry=False):
             partial_results.append(res)
     return partial_results, unsaved
 
+
 class RedisDatabase(Database):
     """
     Redis version of record database
     """
+
     REDIS_PROD = 15
     REDIS_LOCA = 14
-    REDIS_TEST = 13        # for unit test
+    REDIS_TEST = 13  # for unit test
     REDIS_NIGHT_TEMP = 12  # for nightly report (will be flushed after every workload)
 
     MAGIC_SPLIT = "$"
@@ -108,9 +111,9 @@ class RedisDatabase(Database):
         import redis
 
         if db_index == RedisDatabase.REDIS_TEST:
-            host = 'localhost'
+            host = "localhost"
         else:
-            host = os.environ.get('TVM_FLEET_HOST')
+            host = os.environ.get("TVM_FLEET_HOST")
         self.db = redis.StrictRedis(host=host, port=6379, db=db_index)
         self.db_index = db_index
 
@@ -134,12 +137,12 @@ class RedisDatabase(Database):
     def save(self, inp, res, extend=False):
         current = self.get(measure_str_key(inp))
         if not extend or current is None:
-            self.set(measure_str_key(inp),
-                     RedisDatabase.MAGIC_SPLIT.join([encode(inp, res)]))
+            self.set(measure_str_key(inp), RedisDatabase.MAGIC_SPLIT.join([encode(inp, res)]))
         else:
             current = current.split(RedisDatabase.MAGIC_SPLIT)
-            self.set(measure_str_key(inp),
-                     RedisDatabase.MAGIC_SPLIT.join(current + [encode(inp, res)]))
+            self.set(
+                measure_str_key(inp), RedisDatabase.MAGIC_SPLIT.join(current + [encode(inp, res)])
+            )
 
     def filter(self, func):
         """
@@ -168,7 +171,7 @@ class RedisDatabase(Database):
             try:
                 records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
                 records = [rec for rec in records if rec is not None]
-            except TypeError: # got a badly formatted/old format record
+            except TypeError:  # got a badly formatted/old format record
                 continue
 
             if not records:
@@ -184,6 +187,7 @@ class RedisDatabase(Database):
     def flush(self):
         self.db.flushdb()
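save() above packs several (MeasureInput, MeasureResult) pairs for the same key into one value joined by MAGIC_SPLIT; a sketch of that format, where the record pairs are hypothetical and encode/decode are the autotvm.record helpers this file already uses:

# inp1/res1 and inp2/res2 are hypothetical MeasureInput/MeasureResult pairs.
value = RedisDatabase.MAGIC_SPLIT.join([encode(inp1, res1), encode(inp2, res2)])
records = [decode(x) for x in value.split(RedisDatabase.MAGIC_SPLIT)]   # back to (inp, res) pairs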
 
+
 class DummyDatabase(RedisDatabase):
     """
     A database based on python dictionary for testing.
index 18674d4..ddd510c 100644 (file)
@@ -16,6 +16,7 @@
 # under the License.
 """Global configuration/variable scope for autotvm"""
 
+
 class AutotvmGlobalScope(object):
     current = None
 
@@ -27,4 +28,5 @@ class AutotvmGlobalScope(object):
         self.in_tuning = False
         self.silent = False
 
+
 GLOBAL_SCOPE = AutotvmGlobalScope()
index 8df6e53..dff0f09 100644 (file)
@@ -35,9 +35,7 @@ from tvm.te import schedule
 from tvm.driver import build_module
 
 
-def ana_lower(sch, args,
-              binds=None,
-              simple_mode=True):
+def ana_lower(sch, args, binds=None, simple_mode=True):
     """Do lower while keeping all axes in IR
     i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or inject virtual threads
     """
@@ -56,16 +54,20 @@ def ana_lower(sch, args,
 
 try:
     _get_buffer_curve_sample_flatten = tvm._ffi.get_global_func(
-        "autotvm.feature.GetCurveSampleFeatureFlatten")
-    _get_itervar_feature = tvm._ffi.get_global_func(
-        "autotvm.feature.GetItervarFeature")
+        "autotvm.feature.GetCurveSampleFeatureFlatten"
+    )
+    _get_itervar_feature = tvm._ffi.get_global_func("autotvm.feature.GetItervarFeature")
     _get_itervar_feature_flatten = tvm._ffi.get_global_func(
-        "autotvm.feature.GetItervarFeatureFlatten")
+        "autotvm.feature.GetItervarFeatureFlatten"
+    )
 except ValueError as e:
+
     def raise_error(*args, **kwargs):  # pylint: disable=unused-argument
         raise RuntimeError("Cannot load autotvm c++ API")
-    _get_buffer_curve_sample_flatten = _get_itervar_feature = _get_itervar_feature_flatten = \
-        raise_error
+
+    _get_buffer_curve_sample_flatten = (
+        _get_itervar_feature
+    ) = _get_itervar_feature_flatten = raise_error
 
 
 def get_itervar_feature(sch, args, take_log=False):
@@ -136,12 +138,12 @@ def get_itervar_feature_flatten(sch, args, take_log=True):
     """
     stmt = ana_lower(sch, args, simple_mode=True)
     feas = _get_itervar_feature_flatten(stmt, take_log)
-    feas = struct.unpack('%df' % (len(feas)//4), feas)
+    feas = struct.unpack("%df" % (len(feas) // 4), feas)
     return feas
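The struct.unpack call above exists because the C++ feature extractor hands back raw bytes of float32 values, so the element count is the byte length divided by four. A self-contained illustration:

import struct

raw = struct.pack("3f", 1.0, 2.0, 3.0)               # stand-in for the bytes returned over the FFI
feas = struct.unpack("%df" % (len(raw) // 4), raw)   # 12 bytes -> 3 float32 values
assert feas == (1.0, 2.0, 3.0)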
 
 
 def get_flatten_name(fea):
-    """ Get names of feature after flatten.
+    """Get names of feature after flatten.
 
     Parameters
     ----------
@@ -154,8 +156,8 @@ def get_flatten_name(fea):
     """
 
     feature_name = {
-        "_attr_": ["length", "nest_level", "topdown", "bottomup"] +
-                  ["ann_%d" % i for i in range(20)],
+        "_attr_": ["length", "nest_level", "topdown", "bottomup"]
+        + ["ann_%d" % i for i in range(20)],
         "_arith_": ["add", "mul", "div"],
         "buf_touch": ["stride", "mod", "count", "reuse", "T_count", "T_reuse"],
     }
@@ -163,6 +165,7 @@ def get_flatten_name(fea):
     if isinstance(fea, str):
         # pylint: disable=import-outside-toplevel
         from .record import decode
+
         # flatten line to feature
         line = fea
         ret = decode(line)
@@ -186,8 +189,7 @@ def get_flatten_name(fea):
                 name_list = feature_name["buf_touch"]
 
             for i in range(len((pair[1:]))):
-                names.append(
-                    ".".join(["f%d" % ct, var_name, key, name_list[i]]))
+                names.append(".".join(["f%d" % ct, var_name, key, name_list[i]]))
                 ct += 1
     return names
 
@@ -211,5 +213,5 @@ def get_buffer_curve_sample_flatten(sch, args, sample_n=30):
     """
     stmt = ana_lower(sch, args, simple_mode=True)
     feas = _get_buffer_curve_sample_flatten(stmt, sample_n, False)
-    feas = struct.unpack('%df' % (len(feas)//4), feas)
+    feas = struct.unpack("%df" % (len(feas) // 4), feas)
     return feas
index 5a1ef16..40945ed 100644 (file)
@@ -30,8 +30,14 @@ from tvm.autotvm.record import encode, load_from_file
 from tvm.autotvm.measure import MeasureResult, MeasureInput
 
 from ...target import Target
-from .utils import is_boundary_node, get_in_nodes, get_out_nodes, has_multiple_inputs, \
-    bind_inputs, expr2graph
+from .utils import (
+    is_boundary_node,
+    get_in_nodes,
+    get_out_nodes,
+    has_multiple_inputs,
+    bind_inputs,
+    expr2graph,
+)
 from ._base import INVALID_LAYOUT_TIME
 
 from ._base import OPT_OUT_OP
@@ -64,10 +70,20 @@ class BaseGraphTuner(object):
     graph should be provided through tensor-level tuning.
     """
 
-    def __init__(self, graph, input_shapes, records, target_ops,
-                 target, max_sch_num=20, dtype="float32", verbose=True,
-                 log_file="graph_tuner.log", log_level=logging.DEBUG,
-                 name="graph_tuner"):
+    def __init__(
+        self,
+        graph,
+        input_shapes,
+        records,
+        target_ops,
+        target,
+        max_sch_num=20,
+        dtype="float32",
+        verbose=True,
+        log_file="graph_tuner.log",
+        log_level=logging.DEBUG,
+        name="graph_tuner",
+    ):
         """Create a GlobalTuner instance. Local schedule searching for all nodes with
         target_op in the input graph and layout transformation benchmark need to be
         executed before initialization.
@@ -123,14 +139,13 @@ class BaseGraphTuner(object):
         self._logger = logging.getLogger(name + "_logger")
         need_file_handler = need_console_handler = True
         for handler in self._logger.handlers:
-            if handler.__class__.__name__ == 'FileHandler':
+            if handler.__class__.__name__ == "FileHandler":
                 need_file_handler = False
-            if handler.__class__.__name__ == 'StreamHandler':
+            if handler.__class__.__name__ == "StreamHandler":
                 need_console_handler = False
         self._log_level = log_level
         self._log_file = log_file
-        self._formatter = logging.Formatter(
-            '%(asctime)s %(levelname)s %(message)s')
+        self._formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
         self._logger.setLevel(log_level)
         if need_file_handler:
             file_handler = logging.FileHandler(log_file)
@@ -155,11 +170,12 @@ class BaseGraphTuner(object):
             raise RuntimeError("Unsupported graph type: %s" % str(type(graph)))
 
         self._graph = graph
-        self._in_nodes_dict = get_in_nodes(
-            self._node_list, self._target_ops, input_shapes.keys())
+        self._in_nodes_dict = get_in_nodes(self._node_list, self._target_ops, input_shapes.keys())
         if len(self._in_nodes_dict) == 0:
-            raise RuntimeError("Could not find any input nodes with whose "
-                               "operator is one of %s" % self._target_ops)
+            raise RuntimeError(
+                "Could not find any input nodes with whose "
+                "operator is one of %s" % self._target_ops
+            )
         self._out_nodes_dict = get_out_nodes(self._in_nodes_dict)
         self._fetch_cfg()
         self._opt_out_op = OPT_OUT_OP
@@ -185,11 +201,12 @@ class BaseGraphTuner(object):
                         input_workload = input_node["workloads"][0]
                         first_tensor = input_workload[1]
                         dtype = first_tensor[-1]
-                        new_shape = tuple(
-                            [val.value for val in node_entry["types"][0].shape])
-                        actual_workload = (input_workload[0],) + \
-                                          (("TENSOR", new_shape, dtype),) + \
-                            input_workload[2:]
+                        new_shape = tuple([val.value for val in node_entry["types"][0].shape])
+                        actual_workload = (
+                            (input_workload[0],)
+                            + (("TENSOR", new_shape, dtype),)
+                            + input_workload[2:]
+                        )
                         node_entry["workloads"].append(actual_workload)
                         if "record_candidates" not in node_entry:
                             node_entry["record_candidates"] = input_node["record_candidates"]
@@ -238,8 +255,9 @@ class BaseGraphTuner(object):
                             layout_tracking_dict[layouts] = record
                     else:
                         layout_tracking_dict[layouts] = record
-            sorted_records = sorted(layout_tracking_dict.values(),
-                                    key=lambda item: item[1].costs[0])
+            sorted_records = sorted(
+                layout_tracking_dict.values(), key=lambda item: item[1].costs[0]
+            )
             for i in range(min(self._max_sch_num, len(sorted_records))):
                 record_candidates.append(sorted_records[i])
             node_entry["record_candidates"] = record_candidates
@@ -273,8 +291,7 @@ class BaseGraphTuner(object):
 
                 if node_entry["op"] in self._target_ops:
                     o_idx = key
-                    o_infer_layout_func = get_infer_layout(
-                        node_entry["topi_op"][0])
+                    o_infer_layout_func = get_infer_layout(node_entry["topi_op"][0])
                     o_wkl = node_entry["workloads"][0]
                     i_topi_op = in_node_entry["topi_op"][0]
                     i_wkl = in_node_entry["workloads"][0]
@@ -288,11 +305,9 @@ class BaseGraphTuner(object):
                     o_idx = target_input_idx
                     if i <= target_input_pos:
                         continue
-                    o_infer_layout_func = get_infer_layout(
-                        node_entry["topi_op"][0])
+                    o_infer_layout_func = get_infer_layout(node_entry["topi_op"][0])
                     o_wkl = node_entry["workloads"][target_input_pos]
-                    i_infer_layout_func = get_infer_layout(
-                        node_entry["topi_op"][i])
+                    i_infer_layout_func = get_infer_layout(node_entry["topi_op"][i])
                     i_wkl = node_entry["workloads"][i]
 
                 if (i_idx, o_idx) in pair_tracker:
@@ -303,51 +318,63 @@ class BaseGraphTuner(object):
                     for n, o_record in enumerate(node_entry["record_candidates"]):
                         i_cfg, o_cfg = i_record[0].config, o_record[0].config
                         with self._target:
-                            i_input_info, i_output_info = i_infer_layout_func(
-                                i_wkl, i_cfg)
-                            o_input_info, o_output_info = o_infer_layout_func(
-                                o_wkl, o_cfg)
-                        if len(i_input_info) > 1 or len(i_output_info) > 1 or \
-                                len(o_input_info) > 1 or len(o_output_info) > 1:
-                            raise RuntimeError("Graph tuner only supports target operator "
-                                               "with single input and single output. "
-                                               "Please check target_ops argument.")
+                            i_input_info, i_output_info = i_infer_layout_func(i_wkl, i_cfg)
+                            o_input_info, o_output_info = o_infer_layout_func(o_wkl, o_cfg)
+                        if (
+                            len(i_input_info) > 1
+                            or len(i_output_info) > 1
+                            or len(o_input_info) > 1
+                            or len(o_output_info) > 1
+                        ):
+                            raise RuntimeError(
+                                "Graph tuner only supports target operator "
+                                "with single input and single output. "
+                                "Please check target_ops argument."
+                            )
 
                         in_shape, in_layout = i_output_info[0]
                         if node_entry["op"] in self._target_ops:
                             _, out_layout = o_input_info[0]
                         else:
                             _, out_layout = o_output_info[0]
-                        data_placeholder = te.placeholder(in_shape, name="data",
-                                                          dtype=self._dtype)
+                        data_placeholder = te.placeholder(in_shape, name="data", dtype=self._dtype)
                         args = [data_placeholder, in_layout, out_layout]
                         callback(i_idx, o_idx, m, n, args)
 
-    def _create_matrix_callback(self, from_node_idx, to_node_idx, from_sch_idx,
-                                to_sch_idx, args):
+    def _create_matrix_callback(self, from_node_idx, to_node_idx, from_sch_idx, to_sch_idx, args):
         """Create dictionary containing matrix format of layout transformation
         between nodes."""
         in_layout, out_layout = args[1], args[2]
-        ltf_workload = autotvm.task.args_to_workload(args, 'layout_transform')
+        ltf_workload = autotvm.task.args_to_workload(args, "layout_transform")
         idx_pair_key = (from_node_idx, to_node_idx)
 
         if in_layout == out_layout:
             layout_transform_time = 0
         else:
-            layout_transform_time = \
-                self._layout_transform_perf_records[ltf_workload][1].costs[0]
+            layout_transform_time = self._layout_transform_perf_records[ltf_workload][1].costs[0]
 
         if idx_pair_key not in self._layout_transform_interlayer_cost:
             self._layout_transform_interlayer_cost[idx_pair_key] = []
         if len(self._layout_transform_interlayer_cost[idx_pair_key]) <= from_sch_idx:
             self._layout_transform_interlayer_cost[idx_pair_key].append([])
-        self._layout_transform_interlayer_cost[idx_pair_key][from_sch_idx]\
-            .append(layout_transform_time)
-
-    def benchmark_layout_transform(self, min_exec_num=100, timeout=10,
-                                   use_rpc=False, device_key=None, host="localhost",
-                                   port=9190, n_parallel=1, build_func='default',
-                                   layout_records=None, target_host=None, infer_layout=False):
+        self._layout_transform_interlayer_cost[idx_pair_key][from_sch_idx].append(
+            layout_transform_time
+        )
+
+    def benchmark_layout_transform(
+        self,
+        min_exec_num=100,
+        timeout=10,
+        use_rpc=False,
+        device_key=None,
+        host="localhost",
+        port=9190,
+        n_parallel=1,
+        build_func="default",
+        layout_records=None,
+        target_host=None,
+        infer_layout=False,
+    ):
         """Benchmark all possible layout transformation in the graph,
         given a set of schedule candidates for each workload of target operator.
 
@@ -413,14 +440,12 @@ class BaseGraphTuner(object):
         """
         self._logger.info("Start to benchmark layout transformation...")
         if layout_records is None and infer_layout:
-            raise RuntimeError(
-                "Requires some records to infer layout transformation time.")
+            raise RuntimeError("Requires some records to infer layout transformation time.")
 
         if isinstance(layout_records, str):
             layout_records = load_from_file(layout_records)
             if not layout_records and infer_layout:
-                raise RuntimeError(
-                    "Records must be non-empty to infer layout transformation time.")
+                raise RuntimeError("Records must be non-empty to infer layout transformation time.")
 
         if isinstance(layout_records, str):
             layout_records = load_from_file(layout_records)
@@ -437,8 +462,7 @@ class BaseGraphTuner(object):
 
         args_list = []
 
-        def _fetch_args_callback(from_node_idx, to_node_idx, from_sch_idx,
-                                 to_sch_idx, args):
+        def _fetch_args_callback(from_node_idx, to_node_idx, from_sch_idx, to_sch_idx, args):
             """Callback function to fetch layout transform args"""
             _, in_layout, out_layout = args
             if in_layout != out_layout:
@@ -448,27 +472,31 @@ class BaseGraphTuner(object):
 
         def _log_to_list(record_list):
             """Callback to log result to a list."""
+
             def _callback(_, inputs, results):
                 """Callback implementation"""
                 record_list.append((inputs[0], results[0]))
+
             return _callback
 
-        builder = autotvm.LocalBuilder(
-            n_parallel=n_parallel, build_func=build_func)
-        runner = autotvm.LocalRunner(
-            number=min_exec_num, repeat=1, timeout=timeout)
+        builder = autotvm.LocalBuilder(n_parallel=n_parallel, build_func=build_func)
+        runner = autotvm.LocalRunner(number=min_exec_num, repeat=1, timeout=timeout)
         if use_rpc:
             if device_key is None:
-                raise RuntimeError(
-                    "device_key need to be set to use rpc tracker mode.")
-            runner = autotvm.measure.RPCRunner(device_key, host, port, n_parallel=n_parallel,
-                                               number=min_exec_num, repeat=1,
-                                               timeout=timeout)
+                raise RuntimeError("device_key need to be set to use rpc tracker mode.")
+            runner = autotvm.measure.RPCRunner(
+                device_key,
+                host,
+                port,
+                n_parallel=n_parallel,
+                number=min_exec_num,
+                repeat=1,
+                timeout=timeout,
+            )
         measure_option = autotvm.measure_option(builder=builder, runner=runner)
         for args in args_list:
             data, in_layout, out_layout = args
-            ltf_workload = autotvm.task.args_to_workload(
-                args, 'layout_transform')
+            ltf_workload = autotvm.task.args_to_workload(args, "layout_transform")
             if ltf_workload in self._layout_transform_perf_records:
                 continue
 
@@ -489,23 +517,21 @@ class BaseGraphTuner(object):
                 else:
                     inferred_time = flops * avg_time
 
-                record_input = MeasureInput(
-                    target=self._target, task=None, config=None)
-                record_output = MeasureResult(costs=(inferred_time,), error_no=0,
-                                              all_cost=-1, timestamp=-1)
-                self._layout_transform_perf_records[ltf_workload] = (
-                    record_input, record_output)
+                record_input = MeasureInput(target=self._target, task=None, config=None)
+                record_output = MeasureResult(
+                    costs=(inferred_time,), error_no=0, all_cost=-1, timestamp=-1
+                )
+                self._layout_transform_perf_records[ltf_workload] = (record_input, record_output)
                 continue
 
             records = []
-            task = autotvm.task.create("layout_transform", args=args, target=self._target,
-                                       target_host=target_host)
+            task = autotvm.task.create(
+                "layout_transform", args=args, target=self._target, target_host=target_host
+            )
             tuner = autotvm.tuner.GridSearchTuner(task)
-            tuner.tune(n_trial=1, measure_option=measure_option,
-                       callbacks=[_log_to_list(records)])
+            tuner.tune(n_trial=1, measure_option=measure_option, callbacks=[_log_to_list(records)])
             if not isinstance(records[0][1].costs[0], float):
-                records[0] = (records[0][0], records[0]
-                              [1]._replace(costs=(INVALID_LAYOUT_TIME,)))
+                records[0] = (records[0][0], records[0][1]._replace(costs=(INVALID_LAYOUT_TIME,)))
             self._layout_transform_perf_records[ltf_workload] = records[0]
 
         self._iterate_layout_transform(self._create_matrix_callback)
@@ -537,8 +563,7 @@ class BaseGraphTuner(object):
             node_entry = self._node_list[index]
             if node_entry["op"] not in self._target_ops:
                 continue
-            ret.append(node_entry["record_candidates"]
-                       [self._optimal_record_dict[index]])
+            ret.append(node_entry["record_candidates"][self._optimal_record_dict[index]])
         return ret
 
     def write_opt_sch2record_file(self, record_file="graph_opt_schedule.log"):
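For orientation, a heavily hedged sketch of how this class is normally driven through the DPTuner subclass that appears later in this commit; mod is an assumed Relay module, the shapes and file names are illustrative, and kernel-level tuning records are assumed to already exist in kernel_records.log:

from tvm import relay
from tvm.autotvm.graph_tuner import DPTuner   # import path assumed from this commit's layout

target_ops = [relay.op.get("nn.conv2d")]
tuner = DPTuner(mod["main"], {"data": (1, 3, 224, 224)}, "kernel_records.log",
                target_ops, target="llvm")
tuner.benchmark_layout_transform(min_exec_num=100)    # time the candidate layout transforms
tuner.run()                                           # dynamic-programming search over schedules
tuner.write_opt_sch2record_file("graph_opt_schedule.log")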
index fc3b4dc..2d75602 100644 (file)
@@ -29,10 +29,21 @@ class DPStage(object):
 
     In most cases, instance of this class should be created through DPTuner.
     """
-    def __init__(self, idx, input_shapes, node_list,
-                 counted_nodes_set, layout_transform_interlayer_cost,
-                 stage_dict, in_nodes_dict, out_nodes_dict,
-                 dep_dict, target_ops, dtype="float32"):
+
+    def __init__(
+        self,
+        idx,
+        input_shapes,
+        node_list,
+        counted_nodes_set,
+        layout_transform_interlayer_cost,
+        stage_dict,
+        in_nodes_dict,
+        out_nodes_dict,
+        dep_dict,
+        target_ops,
+        dtype="float32",
+    ):
         """Initialize a stage and create all states.
 
         Parameters
@@ -105,8 +116,7 @@ class DPStage(object):
         input_idx = self._global_in_nodes_dict[self._idx][0]
         input_node_entry = self._global_node_list[input_idx]
         if is_boundary_node(input_node_entry, self._global_input_names):
-            self._full_states = np.array([record[1].costs[0]
-                                          for record in self._record_list])
+            self._full_states = np.array([record[1].costs[0] for record in self._record_list])
             self._states = self._full_states
         else:
             input_stage = self._global_stage_dict[input_idx]
@@ -118,9 +128,13 @@ class DPStage(object):
             num_input_schedules = len(input_record_list)
             num_input_states = input_flatten_states.shape[0]
 
-            full_states_shape = tuple([num_schedules, num_input_schedules] +
-                                      [len(self._global_node_list[dep_idx]["record_candidates"])
-                                       for dep_idx in input_dep])
+            full_states_shape = tuple(
+                [num_schedules, num_input_schedules]
+                + [
+                    len(self._global_node_list[dep_idx]["record_candidates"])
+                    for dep_idx in input_dep
+                ]
+            )
             self._full_states = np.zeros(full_states_shape).flatten().astype("float32")
             self._full_states_idx = [self._idx, input_idx] + input_dep
             dep_multiplier = 1
@@ -132,15 +146,16 @@ class DPStage(object):
                 current_sch_time = float(self._record_list[i][1].costs[0])
                 for j in range(num_input_states):
                     input_sch_idx = j // dep_multiplier
-                    layout_transform_time = \
-                        self._global_layout_transform_interlayer_cost \
-                            [(input_idx, self._idx)][input_sch_idx][i]
+                    layout_transform_time = self._global_layout_transform_interlayer_cost[
+                        (input_idx, self._idx)
+                    ][input_sch_idx][i]
 
                     if input_node_time_counted:
                         total_time = current_sch_time + layout_transform_time
                     else:
-                        total_time = \
+                        total_time = (
                             current_sch_time + layout_transform_time + input_flatten_states[j]
+                        )
                     current_state_idx = i * num_input_states + j
                     self._full_states[current_state_idx] = total_time
 
@@ -156,7 +171,9 @@ class DPStage(object):
                 self._dep = list(input_dep)
             else:
                 self._states = self._full_states
-                self._dep = [input_idx,] + input_dep
+                self._dep = [
+                    input_idx,
+                ] + input_dep
 
         # Update global dependency dictionary.
         # This is to monitor the dependency states to decide
@@ -202,9 +219,9 @@ class DPStage(object):
                 input_index_list.append(input_idx)
 
         # Generate new states
-        states_list, aligned_node_list = DPStage.align_states(input_index_list,
-                                                              self._global_stage_dict,
-                                                              self._global_node_list)
+        states_list, aligned_node_list = DPStage.align_states(
+            input_index_list, self._global_stage_dict, self._global_node_list
+        )
         target_node_idx, target_major_axis, target_multiplier, target_states = states_list[0]
         aligned_shape = target_states.shape
         self._full_states = np.zeros(aligned_shape).astype("float32").flatten()
@@ -215,8 +232,9 @@ class DPStage(object):
         src_states_list = [states_list[i][3].flatten() for i in range(1, len(states_list))]
 
         for i in range(num_states):
-            target_sch_idx = (i % (target_multiplier *
-                                   aligned_shape[target_major_axis])) // target_multiplier
+            target_sch_idx = (
+                i % (target_multiplier * aligned_shape[target_major_axis])
+            ) // target_multiplier
             if node_time_counted[0]:
                 new_state = 0
             else:
@@ -225,11 +243,12 @@ class DPStage(object):
             for j in range(1, len(states_list)):
                 src_states = src_states_list[j - 1]
                 src_node_idx, src_major_axis, src_multiplier, _ = states_list[j]
-                src_sch_idx = (i % (src_multiplier *
-                                    aligned_shape[src_major_axis])) // src_multiplier
-                layout_transform_time = \
-                    self._global_layout_transform_interlayer_cost\
-                        [(src_node_idx, target_node_idx)][src_sch_idx][target_sch_idx]
+                src_sch_idx = (
+                    i % (src_multiplier * aligned_shape[src_major_axis])
+                ) // src_multiplier
+                layout_transform_time = self._global_layout_transform_interlayer_cost[
+                    (src_node_idx, target_node_idx)
+                ][src_sch_idx][target_sch_idx]
 
                 if node_time_counted[j]:
                     new_state += layout_transform_time
@@ -256,7 +275,7 @@ class DPStage(object):
         for i, dep in enumerate(reduced_states_dep_list):
             if dep not in self._global_dep_dict or len(self._global_dep_dict[dep]) == 1:
                 self._global_dep_dict.pop(dep, None)
-                reduced_states = np.amin(reduced_states, axis=i+1-shift)
+                reduced_states = np.amin(reduced_states, axis=i + 1 - shift)
                 shift += 1
             else:
                 self._dep.append(dep)
index b9d40c8..97253e4 100644 (file)
@@ -29,6 +29,7 @@ if sys.version_info[0] == 3:
 else:
     import Queue as queue
 
+
 class DPTuner(BaseGraphTuner):
     """Tuner which uses dynamic programming to solve MDP problem.
 
@@ -37,9 +38,9 @@ class DPTuner(BaseGraphTuner):
     models, such as networks with many element-wise sum operators. In this case, switch
     to heuristic algorithm such as PBQP tuner.
     """
+
     def __init__(self, *args, **kwargs):
-        """Create a dynamic programming tuner.
-        """
+        """Create a dynamic programming tuner."""
         super(DPTuner, self).__init__(*args, **kwargs)
         self._num_states = self._max_num_states = None
         self._stage_dict = {}
@@ -55,7 +56,7 @@ class DPTuner(BaseGraphTuner):
             "dep_dict": self._dep_dict,
             "node_list": self._node_list,
             "input_shapes": self._input_shapes,
-            "layout_transform_interlayer_cost": self._layout_transform_interlayer_cost
+            "layout_transform_interlayer_cost": self._layout_transform_interlayer_cost,
         }
 
     def _check_num_states(self, num_states):
@@ -63,24 +64,23 @@ class DPTuner(BaseGraphTuner):
         self._num_states += num_states
         if self._max_num_states is not None:
             if self._num_states > self._max_num_states:
-                raise RuntimeError("Too many states detected while running dynamic "
-                                   "programming: got %d states but upper limit is %d." %
-                                   (self._num_states, self._max_num_states))
+                raise RuntimeError(
+                    "Too many states detected while running dynamic "
+                    "programming: got %d states but upper limit is %d."
+                    % (self._num_states, self._max_num_states)
+                )
 
     def _forward(self):
-        """Forward pass in DP to generate states for all stages.
-        """
+        """Forward pass in DP to generate states for all stages."""
         self._logger.info("Start forward pass...")
         for node_idx in sorted(self._in_nodes_dict.keys()):
-            stage = DPStage(idx=node_idx, target_ops=self._target_ops,
-                            **self._global_data_dict)
+            stage = DPStage(idx=node_idx, target_ops=self._target_ops, **self._global_data_dict)
             self._check_num_states(stage.full_states.size)
             self._stage_dict[node_idx] = stage
         self._logger.info("Finished forward pass.")
 
     def _backward(self):
-        """Backward pass in DP to generate optimal solution.
-        """
+        """Backward pass in DP to generate optimal solution."""
         self._logger.info("Start backward pass...")
         input_names = self._input_shapes.keys()
         optimal_record_dict = {}
@@ -92,17 +92,20 @@ class DPTuner(BaseGraphTuner):
 
         # Restrict number of output nodes to avoid numpy reshape error
         if len(output_idx_list) > MAX_OUTPUT_NODES:
-            msg = "The number of outputs in graph is larger than upper " \
-                  "limit: %s vs %s. Usually this is caused by too many " \
-                  "LAYOUT_FIXED_OP in graph. Switch to greedily select schedule." \
-                  "No action required at this moment. We will continuously improve graph tuner" \
-                  % (len(output_idx_list), MAX_OUTPUT_NODES)
+            msg = (
+                "The number of outputs in graph is larger than upper "
+                "limit: %s vs %s. Usually this is caused by too many "
+                "LAYOUT_FIXED_OP in graph. Switch to greedily select schedule."
+                "No action required at this moment. We will continuously improve graph tuner"
+                % (len(output_idx_list), MAX_OUTPUT_NODES)
+            )
             self._logger.warning(msg)
-            self._optimal_record_dict = {key : 0 for key in self._in_nodes_dict}
+            self._optimal_record_dict = {key: 0 for key in self._in_nodes_dict}
             return
 
-        states_list, aligned_node_list = DPStage.align_states(output_idx_list, self._stage_dict,
-                                                              self._node_list)
+        states_list, aligned_node_list = DPStage.align_states(
+            output_idx_list, self._stage_dict, self._node_list
+        )
         num_states = states_list[0][3].size
         self._check_num_states(num_states * len(output_idx_list))
         aligned_node_shape = states_list[0][3].shape
@@ -120,16 +123,18 @@ class DPTuner(BaseGraphTuner):
                 min_pos = i
         for i, states in enumerate(states_list):
             current_major_axis = states[1]
-            current_sch_idx = (min_pos % (states[2] *
-                                          aligned_node_shape[current_major_axis])) // states[2]
+            current_sch_idx = (
+                min_pos % (states[2] * aligned_node_shape[current_major_axis])
+            ) // states[2]
             optimal_record_dict[aligned_node_list[i]] = current_sch_idx
         # Pick optimal schedule for dependencies of output nodes
         for i in range(len(states_list), len(aligned_node_list)):
             multiplier = 1
             for j in range(i + 1, len(aligned_node_list)):
                 multiplier *= aligned_node_shape[j]
-            optimal_record_dict[aligned_node_list[i]] = \
+            optimal_record_dict[aligned_node_list[i]] = (
                 min_pos // multiplier % aligned_node_shape[i]
+            )
 
         # Backward pass to get optimal schedules for other nodes
         bfs_q = queue.Queue()
@@ -193,8 +198,7 @@ class DPTuner(BaseGraphTuner):
         self._logger.info("Finished backward pass...")
 
     def run(self, **kwargs):
-        """Run dynamic programming solver.
-        """
+        """Run dynamic programming solver."""
         max_num_states = None if "max_num_states" not in kwargs else kwargs["max_num_states"]
         self._num_states = 0
         self._max_num_states = max_num_states
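
For reference, a minimal usage sketch of the tuner above (illustrative only, not part of this commit). mod, dshape and the two log-file names are assumed placeholders; PBQPTuner is the heuristic fallback mentioned in the DPTuner docstring.

from tvm import relay
from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner

target_ops = [relay.op.get("nn.conv2d")]
# DPTuner gives the optimal layout assignment; swap in PBQPTuner when the DP
# state space grows too large (e.g. many element-wise sum operators).
executor = DPTuner(mod["main"], {"data": dshape}, "kernel_tuned.log", target_ops, "llvm")
executor.benchmark_layout_transform(min_exec_num=2000)  # measure layout-transform costs
executor.run()                                          # accepts max_num_states=... as shown above
executor.write_opt_sch2record_file("graph_opt.log")     # dump the chosen schedules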
index d58694c..59f4ef0 100644 (file)
@@ -31,9 +31,9 @@ class PBQPTuner(BaseGraphTuner):
     Nearly optimal register allocation with PBQP. JMLC 2006.
     LNCS, vol. 4228, pp. 346-361, 2006.
     """
+
     def __init__(self, *args, **kwargs):
-        """Create a partitioned boolean quadratic programming tuner.
-        """
+        """Create a partitioned boolean quadratic programming tuner."""
         super(PBQPTuner, self).__init__(*args, **kwargs)
 
         # Remove input and ruled_out nodes
@@ -46,8 +46,9 @@ class PBQPTuner(BaseGraphTuner):
 
         self._adj_dict = {}
         for node_idx in self._in_nodes_dict:
-            self._adj_dict[node_idx] = list(self._in_nodes_dict[node_idx]) + \
-                                       list(self._out_nodes_dict[node_idx])
+            self._adj_dict[node_idx] = list(self._in_nodes_dict[node_idx]) + list(
+                self._out_nodes_dict[node_idx]
+            )
 
         self._record_cost_dict = {}
         for key in self._in_nodes_dict:
@@ -71,13 +72,11 @@ class PBQPTuner(BaseGraphTuner):
         self._is_optimal = True
 
     def _get_degree(self, node_idx):
-        """Get node degree.
-        """
+        """Get node degree."""
         return len(self._adj_dict[node_idx])
 
     def _reorder_adj_nodes(self, node_idx):
-        """Update buckets list with current adjacency list.
-        """
+        """Update buckets list with current adjacency list."""
         for adj_node in self._adj_dict[node_idx]:
             current_degree = self._get_degree(adj_node)
             prev_degree = self._node_degree_dict[adj_node]
@@ -87,36 +86,31 @@ class PBQPTuner(BaseGraphTuner):
                 self._node_degree_dict[adj_node] = current_degree
 
     def _remove_node(self, node_idx):
-        """Remove node from graph. Update adjacency list accordingly.
-        """
+        """Remove node from graph. Update adjacency list accordingly."""
         node_degree = self._get_degree(node_idx)
         self._buckets[node_degree].remove(node_idx)
         for adj_node in self._adj_dict[node_idx]:
             self._adj_dict[adj_node].remove(node_idx)
 
     def _insert_edge(self, node_x, node_y, adj_cost_matrix):
-        """Insert an edge between two nodes.
-        """
+        """Insert an edge between two nodes."""
         self._layout_transform_interlayer_cost[(node_x, node_y)] = adj_cost_matrix
         self._layout_transform_interlayer_cost[(node_y, node_x)] = []
         for i in range(len(adj_cost_matrix[0])):
             self._layout_transform_interlayer_cost[(node_y, node_x)].append([])
             for cost_vec in adj_cost_matrix:
-                self._layout_transform_interlayer_cost[(node_y, node_x)][i] \
-                    .append(cost_vec[i])
+                self._layout_transform_interlayer_cost[(node_y, node_x)][i].append(cost_vec[i])
 
         self._adj_dict[node_x].append(node_y)
         self._adj_dict[node_y].append(node_x)
 
     def _backward_insert_node(self, node_idx):
-        """Reinsert node in backward pass.
-        """
+        """Reinsert node in backward pass."""
         for adj_node in self._adj_dict[node_idx]:
             self._adj_dict[adj_node].append(node_idx)
 
     def _RI_reduction(self, node_idx):
-        """Reduce nodes with degree 1.
-        """
+        """Reduce nodes with degree 1."""
         adj_node = self._adj_dict[node_idx][0]
         ltf_matrix = self._layout_transform_interlayer_cost[(adj_node, node_idx)]
         for i, cost_vec in enumerate(ltf_matrix):
@@ -129,8 +123,7 @@ class PBQPTuner(BaseGraphTuner):
         self._stack.append(node_idx)
 
     def _RII_reduction(self, node_idx):
-        """Reduce nodes with degree 2.
-        """
+        """Reduce nodes with degree 2."""
         adj_node_x, adj_node_y = self._adj_dict[node_idx]
         ltf_matrix_x = self._layout_transform_interlayer_cost[(adj_node_x, node_idx)]
         ltf_matrix_y = self._layout_transform_interlayer_cost[(adj_node_y, node_idx)]
@@ -139,8 +132,10 @@ class PBQPTuner(BaseGraphTuner):
             for j, cost_vec_y in enumerate(ltf_matrix_y):
                 min_cost = INVALID_LAYOUT_TIME
                 for k in range(len(self._record_cost_dict[node_idx])):
-                    min_cost = min(min_cost, cost_vec_x[k] + cost_vec_y[k]
-                                   + self._record_cost_dict[node_idx][k])
+                    min_cost = min(
+                        min_cost,
+                        cost_vec_x[k] + cost_vec_y[k] + self._record_cost_dict[node_idx][k],
+                    )
                 delta_matrix[i].append(min_cost)
 
         if adj_node_x == adj_node_y:
@@ -149,10 +144,8 @@ class PBQPTuner(BaseGraphTuner):
         elif adj_node_x in self._adj_dict[adj_node_y]:
             for i, _ in enumerate(delta_matrix):
                 for j, delta in enumerate(delta_matrix[i]):
-                    self._layout_transform_interlayer_cost[(adj_node_x, adj_node_y)][i][j] \
-                        += delta
-                    self._layout_transform_interlayer_cost[(adj_node_y, adj_node_x)][j][i] \
-                        += delta
+                    self._layout_transform_interlayer_cost[(adj_node_x, adj_node_y)][i][j] += delta
+                    self._layout_transform_interlayer_cost[(adj_node_y, adj_node_x)][j][i] += delta
         else:
             self._insert_edge(adj_node_x, adj_node_y, delta_matrix)
 
@@ -161,8 +154,7 @@ class PBQPTuner(BaseGraphTuner):
         self._stack.append(node_idx)
 
     def _RN_reduction(self, node_idx):
-        """Reduce nodes with degree greater than 2.
-        """
+        """Reduce nodes with degree greater than 2."""
         min_cost = INVALID_LAYOUT_TIME
         record_idx = -1
 
@@ -179,8 +171,9 @@ class PBQPTuner(BaseGraphTuner):
                 record_idx = i
 
         if record_idx < 0:
-            raise RuntimeError("Can't find a soltuion for node %d when "
-                               "applying RN reduction" % node_idx)
+            raise RuntimeError(
+                "Can't find a soltuion for node %d when " "applying RN reduction" % node_idx
+            )
         self._optimal_record_dict[node_idx] = record_idx
         self._is_optimal = False
 
@@ -194,8 +187,7 @@ class PBQPTuner(BaseGraphTuner):
         self._stack.append(node_idx)
 
     def _forward(self):
-        """Forward pass in PBQP to reduce nodes.
-        """
+        """Forward pass in PBQP to reduce nodes."""
         while True:
             if self._buckets[1]:
                 node_idx = self._buckets[1][0]
@@ -216,8 +208,7 @@ class PBQPTuner(BaseGraphTuner):
                 break
 
     def _backward(self):
-        """Backward pass in PBQP to generate optimal solution.
-        """
+        """Backward pass in PBQP to generate optimal solution."""
         # Solve nodes left in the forward graph
         for node_idx in self._buckets[0]:
             record_costs = self._record_cost_dict[node_idx]
@@ -232,15 +223,14 @@ class PBQPTuner(BaseGraphTuner):
                 for adj_node in self._adj_dict[node_idx]:
                     adj_optimal_idx = self._optimal_record_dict[adj_node]
                     for i, _ in enumerate(record_costs):
-                        record_costs[i] += \
-                            self._layout_transform_interlayer_cost \
-                                [(node_idx, adj_node)][i][adj_optimal_idx]
+                        record_costs[i] += self._layout_transform_interlayer_cost[
+                            (node_idx, adj_node)
+                        ][i][adj_optimal_idx]
                 min_cost = min(record_costs)
                 self._optimal_record_dict[node_idx] = record_costs.index(min_cost)
 
     def run(self, **kwargs):
-        """Run partitioned boolean quadratic programming tuner.
-        """
+        """Run partitioned boolean quadratic programming tuner."""
         self._logger.info("Start to run PBQP algorithm...")
         # Define virtual record lists and layout transformation matrices
         # for multi-input nodes.
@@ -266,16 +256,18 @@ class PBQPTuner(BaseGraphTuner):
                 for j in range(len(record_candidates)):
                     temp[(target_input_idx, key)].append([])
                     for k in range(len(record_candidates)):
-                        temp[(target_input_idx, key)][j].append(0 if j == k
-                                                                else INVALID_LAYOUT_TIME)
+                        temp[(target_input_idx, key)][j].append(
+                            0 if j == k else INVALID_LAYOUT_TIME
+                        )
 
                 for j in range(target_input_pos + 1, len(val)):
                     input_idx = val[j]
                     input_node = self._node_list[input_idx]
                     if is_boundary_node(input_node, input_names):
                         continue
-                    temp[(input_idx, key)] = \
-                        self._layout_transform_interlayer_cost[(input_idx, target_input_idx)]
+                    temp[(input_idx, key)] = self._layout_transform_interlayer_cost[
+                        (input_idx, target_input_idx)
+                    ]
         self._layout_transform_interlayer_cost.update(temp)
 
         # Create reverse layout transformation matrices
index 53659a9..21a16b8 100644 (file)
@@ -21,6 +21,5 @@ from __future__ import absolute_import
 from . import traverse_graph
 from . import utils
 
-from .traverse_graph import expr2graph, get_direct_ancestor, get_in_nodes, \
-    get_out_nodes
+from .traverse_graph import expr2graph, get_direct_ancestor, get_in_nodes, get_out_nodes
 from .utils import has_multiple_inputs, is_boundary_node, bind_inputs
index b85c562..6e29474 100644 (file)
@@ -29,6 +29,7 @@ from tvm.autotvm.task import TaskExtractEnv
 from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node
 from .._base import OPT_OUT_OP
 
+
 def expr2graph(expr, target_ops, node_dict, node_list):
     """Convert relay expr to graph data structure
     and fetch workloads of target operators.
@@ -62,9 +63,7 @@ def expr2graph(expr, target_ops, node_dict, node_list):
         for node_entry in node_list:
             if node_entry["op"] in target_ops:
                 task_name, args = env.task_collection[task_pos]
-                task = autotvm.task.create(task_name, args,
-                                           target="llvm",
-                                           target_host=None)
+                task = autotvm.task.create(task_name, args, target="llvm", target_host=None)
                 node_entry["workloads"] = [task.workload]
                 node_entry["topi_op"] = [task_name]
                 task_pos += 1
@@ -79,14 +78,13 @@ def _infer_type(node):
 
 
 def _expr2graph_impl(expr, target_ops, node_dict, node_list):
-    """Implementation to convert relay expr to graph data structure
-    """
+    """Implementation to convert relay expr to graph data structure"""
+
     def _traverse_expr(node):
         if node in node_dict:
             return
         node_index = len(node_list)
-        node_entry = {"node": node, "inputs": [], "types": [],
-                      "op": None, "name": None}
+        node_entry = {"node": node, "inputs": [], "types": [], "op": None, "name": None}
 
         if isinstance(node, Call):
             op = node.op
@@ -105,8 +103,9 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
                 for tupe_type in out_type.fields:
                     node_entry["types"].append(tupe_type)
             else:
-                raise RuntimeError("Unsupported output type %s in operator %s"
-                                   % (type(out_type), op.name))
+                raise RuntimeError(
+                    "Unsupported output type %s in operator %s" % (type(out_type), op.name)
+                )
 
             # Utilize tracing target to fetch workload with topo-order.
             # Since we only need workload, dummy target can be used to
@@ -117,21 +116,21 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
                     input_node_entry = node_list[input_idx[0]]
                     input_type = input_node_entry["types"][input_idx[1]]
                     if not isinstance(input_node_entry["node"], (Var, Constant, Call)):
-                        raise RuntimeError("Graph tuner can only tune target "
-                                           "operators with input node of type "
-                                           "relay.expr.Var/Constant/Call. Now "
-                                           "find a target op %s with input type %s"
-                                           % (op, str(type(input_node_entry["node"]))))
+                        raise RuntimeError(
+                            "Graph tuner can only tune target "
+                            "operators with input node of type "
+                            "relay.expr.Var/Constant/Call. Now "
+                            "find a target op %s with input type %s"
+                            % (op, str(type(input_node_entry["node"])))
+                        )
                     free_var = relay.Var("var_%d" % i, input_type)
                     params.append(free_var)
                 call = relay.Call(node.op, params, node.attrs)
                 mod = tvm.IRModule.from_expr(relay.Function(params, call))
                 relay.backend.compile_engine.get().clear()
-                build_thread = threading.Thread(target=relay.build,
-                                                args=(mod,
-                                                      "llvm -device=tracing",
-                                                      None,
-                                                      None))
+                build_thread = threading.Thread(
+                    target=relay.build, args=(mod, "llvm -device=tracing", None, None)
+                )
                 build_thread.start()
                 build_thread.join()
         elif isinstance(node, Var):
@@ -160,8 +159,9 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
         elif isinstance(node, tvm.ir.Op):
             return
         else:
-            raise RuntimeError("Not supported relay node type in graph tuning: %s"
-                               % str(type(node)))
+            raise RuntimeError(
+                "Not supported relay node type in graph tuning: %s" % str(type(node))
+            )
         node_dict[node] = node_index
         node_list.append(node_entry)
 
@@ -205,13 +205,11 @@ def get_direct_ancestor(node_list, visited_dict, target_ops, node_idx, input_nam
     node_direct_ancestor = []
     for item_idx in node["inputs"]:
         item = node_list[item_idx[0]]
-        is_multiple_inputs = has_multiple_inputs(node_list, item_idx[0], \
-                input_names, OPT_OUT_OP)
+        is_multiple_inputs = has_multiple_inputs(node_list, item_idx[0], input_names, OPT_OUT_OP)
         if item["op"] in target_ops or is_multiple_inputs:
             node_direct_ancestor.append(item_idx[0])
         else:
-            tmp = get_direct_ancestor(node_list, visited_dict, target_ops,
-                                      item_idx[0], input_names)
+            tmp = get_direct_ancestor(node_list, visited_dict, target_ops, item_idx[0], input_names)
             for tmp_item in tmp:
                 node_direct_ancestor.append(tmp_item)
     visited_dict[node_idx] = node_direct_ancestor
@@ -247,8 +245,7 @@ def get_in_nodes(node_list, target_ops, input_names):
         get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names)
     for key, val in visited_dict.items():
         node = node_list[key]
-        is_multiple_inputs = has_multiple_inputs(node_list, key, \
-                input_names, OPT_OUT_OP)
+        is_multiple_inputs = has_multiple_inputs(node_list, key, input_names, OPT_OUT_OP)
         if node["op"] in target_ops or is_multiple_inputs:
             in_node_dict[key] = val
 
@@ -264,8 +261,7 @@ def get_in_nodes(node_list, target_ops, input_names):
             if node["op"] not in target_ops:
                 for input_idx in val:
                     in_node = node_list[input_idx]
-                    if not is_boundary_node(in_node, input_names) and \
-                            input_idx in in_node_dict:
+                    if not is_boundary_node(in_node, input_names) and input_idx in in_node_dict:
                         is_boundary = False
                     else:
                         val.remove(input_idx)
@@ -278,7 +274,6 @@ def get_in_nodes(node_list, target_ops, input_names):
         else:
             has_reduced_node = False
 
-
     # Remove empty nodes to ignore pre-computed sub-graph
     has_empty_node = True
     while has_empty_node:
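
A small, hypothetical sketch of how the helpers in this file are typically driven (illustrative only; mod is an assumed relay.IRModule whose input variable is named "data"):

from tvm import relay
from tvm.autotvm.graph_tuner.utils import expr2graph, get_in_nodes, get_out_nodes

target_ops = [relay.op.get("nn.conv2d")]
node_dict, node_list = {}, []
expr2graph(mod["main"], target_ops, node_dict, node_list)  # topo-ordered node entries + workloads
in_nodes = get_in_nodes(node_list, target_ops, ["data"])   # target/multi-input nodes -> ancestors
out_nodes = get_out_nodes(in_nodes)                        # reversed adjacency for the backward pass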
index 70e95c9..54e0d1c 100644 (file)
@@ -20,6 +20,7 @@ import tvm
 from tvm import relay
 from tvm.relay import transform
 
+
 def has_multiple_inputs(node_list, node_idx, input_names, opt_out_op):
     """Check whether a node has multiple input nodes
     except variable nodes.
@@ -46,15 +47,13 @@ def has_multiple_inputs(node_list, node_idx, input_names, opt_out_op):
         in_idx = in_idx[0]
         in_node = node_list[in_idx]
         # Exclude parameter nodes
-        if(in_node["op"] is not None and in_node["op"].name in opt_out_op):
+        if in_node["op"] is not None and in_node["op"].name in opt_out_op:
             increase = False
             for t_idx in in_node["inputs"]:
-                increase = has_multiple_inputs(node_list, t_idx[0], \
-                        input_names, opt_out_op)
+                increase = has_multiple_inputs(node_list, t_idx[0], input_names, opt_out_op)
             if increase:
                 num_inputs += 1
-        elif in_node["op"] is not None or \
-                ("name" in in_node and in_node["name"] in input_names):
+        elif in_node["op"] is not None or ("name" in in_node and in_node["name"] in input_names):
             num_inputs += 1
     return num_inputs > 1
 
@@ -78,13 +77,23 @@ def is_boundary_node(node_entry, input_names):
         whether node is a boundary node.
     """
     # Operators dependent on original layouts.
-    _LAYOUT_FIXED_OP = [relay.op.get(name) for name in (
-        "nn.batch_flatten", "transpose", "reshape", "vision.multibox_prior",
-        "vision.multibox_transform_loc", "where", "vision.non_max_suppression",
-        "strided_slice")]
-
-    out = node_entry["op"] in _LAYOUT_FIXED_OP or \
-          ("name" in node_entry and node_entry["name"] in input_names)
+    _LAYOUT_FIXED_OP = [
+        relay.op.get(name)
+        for name in (
+            "nn.batch_flatten",
+            "transpose",
+            "reshape",
+            "vision.multibox_prior",
+            "vision.multibox_transform_loc",
+            "where",
+            "vision.non_max_suppression",
+            "strided_slice",
+        )
+    ]
+
+    out = node_entry["op"] in _LAYOUT_FIXED_OP or (
+        "name" in node_entry and node_entry["name"] in input_names
+    )
     return out
 
 
@@ -128,12 +137,13 @@ def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
     if input_shapes is None:
         return expr
     if isinstance(input_dtypes, str):
-        input_dtypes = {key : input_dtypes for key in input_shapes.keys()}
+        input_dtypes = {key: input_dtypes for key in input_shapes.keys()}
 
     updated_input_dict = {}
     for input_name in input_shapes.keys():
-        updated_input = relay.var(input_name, shape=input_shapes[input_name],
-                                  dtype=input_dtypes[input_name])
+        updated_input = relay.var(
+            input_name, shape=input_shapes[input_name], dtype=input_dtypes[input_name]
+        )
         updated_input_dict[input_name] = updated_input
 
     rebind_dict = {}
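
A short, hypothetical sketch of bind_inputs as shown above (expr is an assumed Relay expression with a free variable named "data"):

from tvm.autotvm.graph_tuner.utils import bind_inputs

# Rebind the free variable "data" to a concrete shape and dtype before tuning.
bound_expr = bind_inputs(expr, input_shapes={"data": (1, 3, 224, 224)}, input_dtypes="float32")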
index 2c7cca0..0c32ae0 100644 (file)
 # under the License.
 """Distributed executor infrastructure to scale up the tuning"""
 
-from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option, \
-    create_measure_batch
+from .measure import (
+    MeasureInput,
+    MeasureResult,
+    MeasureErrorNo,
+    measure_option,
+    create_measure_batch,
+)
 from .measure_methods import LocalBuilder, LocalRunner, RPCRunner, request_remote
 from .executor import Executor
 from .local_executor import LocalExecutor
index bcfdf39..f8eca72 100644 (file)
 # under the License.
 """ Abstraction for asynchronous job execution """
 
+
 class Executor(object):
     """
     Base abstract executor interface for asynchronous job submission.
     Allows submitting asynchronous jobs and returns a Future object.
     """
+
     # timeout for jobs that may hang
     DEFAULT_TIMEOUT = 120
 
@@ -56,6 +58,7 @@ class Future(object):
     Future objects store the state of tasks; they can be polled for a
     result, or a blocking call can be used to retrieve the result.
     """
+
     def done(self):
         """
         Return True if job was successfully cancelled or finished running.
@@ -83,6 +86,7 @@ class Future(object):
         """
         raise NotImplementedError()
 
+
 class FutureError(RuntimeError):
     """Base error class of all future events"""
 
index a0a826a..5dd5cba 100644 (file)
@@ -19,6 +19,7 @@
 import signal
 
 from multiprocessing import Process, Queue
+
 try:
     from queue import Empty
 except ImportError:
@@ -45,6 +46,7 @@ def kill_child_processes(parent_pid, sig=signal.SIGTERM):
         except psutil.NoSuchProcess:
             return
 
+
 def _execute_func(func, queue, args, kwargs):
     """execute function and return the result or exception to a queue"""
     try:
@@ -79,6 +81,7 @@ class LocalFuture(executor.Future):
     queue: multiprocessing.Queue
         queue for receiving the result of this task
     """
+
     def __init__(self, process, queue):
         self._done = False
         self._process = process
@@ -110,6 +113,7 @@ class LocalFutureNoFork(executor.Future):
     This is a non-fork version of LocalFuture.
     Use this for runtimes that do not support fork (like cudnn).
     """
+
     def __init__(self, result):
         self._result = result
 
@@ -132,21 +136,22 @@ class LocalExecutor(executor.Executor):
         (e.g. cuda runtime, cudnn). Set this to False if you have used these runtimes
         before submitting jobs.
     """
+
     def __init__(self, timeout=None, do_fork=True):
         self.timeout = timeout or executor.Executor.DEFAULT_TIMEOUT
         self.do_fork = do_fork
 
         if self.do_fork:
             if not psutil:
-                raise RuntimeError("Python package psutil is missing. "
-                                   "please try `pip install psutil`")
+                raise RuntimeError(
+                    "Python package psutil is missing. " "please try `pip install psutil`"
+                )
 
     def submit(self, func, *args, **kwargs):
         if not self.do_fork:
             return LocalFutureNoFork(func(*args, **kwargs))
 
         queue = Queue(2)  # Size of 2 to avoid a race condition with size 1.
-        process = Process(target=call_with_timeout,
-                          args=(queue, self.timeout, func, args, kwargs))
+        process = Process(target=call_with_timeout, args=(queue, self.timeout, func, args, kwargs))
         process.start()
         return LocalFuture(process, queue)
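
A minimal sketch of the executor/future pattern above (illustrative only; it assumes psutil is installed, since do_fork=True requires it, and that the submitted function lives at module level so it can be pickled):

from tvm.autotvm.measure import LocalExecutor

def square(x):
    return x * x

executor = LocalExecutor(timeout=10)  # forks a worker process per submitted job
future = executor.submit(square, 7)   # returns a LocalFuture immediately
print(future.get())                   # blocks until the result (49) or the timeout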
index d77e737..8438b80 100644 (file)
@@ -19,6 +19,7 @@
 import multiprocessing
 from collections import namedtuple
 
+
 class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])):
     """
     Stores all the necessary inputs for a measurement.
@@ -54,15 +55,16 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost"
 
 class MeasureErrorNo(object):
     """Error type for MeasureResult"""
-    NO_ERROR = 0              # no error
-    INSTANTIATION_ERROR = 1   # actively detected error in instantiating a template with a config
-    COMPILE_HOST = 2          # error when compiling code on host (e.g. tvm.build)
-    COMPILE_DEVICE = 3        # error when compiling code on device (e.g. OpenCL JIT on the device)
-    RUNTIME_DEVICE = 4        # error when run program on device
-    WRONG_ANSWER = 5          # answer is wrong when compared to a golden output
-    BUILD_TIMEOUT = 6         # timeout during compilation
-    RUN_TIMEOUT = 7           # timeout during run
-    UNKNOWN_ERROR = 8         # unknown error
+
+    NO_ERROR = 0  # no error
+    INSTANTIATION_ERROR = 1  # actively detected error in instantiating a template with a config
+    COMPILE_HOST = 2  # error when compiling code on host (e.g. tvm.build)
+    COMPILE_DEVICE = 3  # error when compiling code on device (e.g. OpenCL JIT on the device)
+    RUNTIME_DEVICE = 4  # error when running the program on device
+    WRONG_ANSWER = 5  # answer is wrong when compared to a golden output
+    BUILD_TIMEOUT = 6  # timeout during compilation
+    RUN_TIMEOUT = 7  # timeout during run
+    UNKNOWN_ERROR = 8  # unknown error
 
 
 class Builder(object):
@@ -76,6 +78,7 @@ class Builder(object):
         The number of tasks submitted in parallel.
         By default it will use all cpu cores.
     """
+
     def __init__(self, timeout=10, n_parallel=None):
         self.timeout = timeout
         self.n_parallel = n_parallel or multiprocessing.cpu_count()
@@ -123,6 +126,7 @@ class Runner(object):
         The number of tasks submitted in parallel.
         By default it will use all cpu cores.
     """
+
     def __init__(self, timeout=5, n_parallel=None):
         self.timeout = timeout
         self.n_parallel = n_parallel or multiprocessing.cpu_count()
@@ -212,20 +216,20 @@ def measure_option(builder, runner):
     from .measure_methods import LocalBuilder, LocalRunner
 
     if isinstance(builder, str):
-        if builder == 'local':
+        if builder == "local":
             builder = LocalBuilder()
         else:
             raise ValueError("Invalid builder: " + builder)
 
     if isinstance(runner, str):
-        if runner == 'local':
+        if runner == "local":
             runner = LocalRunner()
         else:
             raise ValueError("Invalid runner: " + runner)
 
     opt = {
-        'builder': builder,
-        'runner': runner,
+        "builder": builder,
+        "runner": runner,
     }
 
     return opt
@@ -247,8 +251,8 @@ def create_measure_batch(task, option):
     measure_batch: callable
         a callback function to measure a batch of configs
     """
-    builder = option['builder']
-    runner = option['runner']
+    builder = option["builder"]
+    runner = option["runner"]
 
     attach_objects = runner.set_task(task)
 
index 9c22b64..7032db6 100644 (file)
@@ -48,10 +48,10 @@ from ..task.space import InstantiationError
 from .measure import MeasureResult, MeasureErrorNo, Builder, Runner
 from .local_executor import LocalExecutor
 
-logger = logging.getLogger('autotvm')
+logger = logging.getLogger("autotvm")
 
 
-class BuildResult(namedtuple("BuildResult", ('filename', 'arg_info', 'error', 'time_cost'))):
+class BuildResult(namedtuple("BuildResult", ("filename", "arg_info", "error", "time_cost"))):
     """
     Stores all the necessary inputs for a measurement.
 
@@ -83,13 +83,13 @@ class LocalBuilder(Builder):
         If it is callable, use it as a custom build function; a lib_format field is expected.
     """
 
-    def __init__(self, timeout=10, n_parallel=None, build_func='default'):
+    def __init__(self, timeout=10, n_parallel=None, build_func="default"):
         super(LocalBuilder, self).__init__(timeout, n_parallel)
 
         if isinstance(build_func, str):
-            if build_func == 'default':
+            if build_func == "default":
                 build_func = tar.tar
-            elif build_func == 'ndk':
+            elif build_func == "ndk":
                 build_func = ndk.create_shared
             else:
                 raise ValueError("Invalid build_func" + build_func)
@@ -105,11 +105,8 @@ class LocalBuilder(Builder):
 
         for i in range(0, len(measure_inputs), self.n_parallel):
             futures = []
-            for inp in measure_inputs[i:i + self.n_parallel]:
-                ret = self.executor.submit(self.build_func,
-                                           inp,
-                                           self.tmp_dir,
-                                           **self.build_kwargs)
+            for inp in measure_inputs[i : i + self.n_parallel]:
+                ret = self.executor.submit(self.build_func, inp, self.tmp_dir, **self.build_kwargs)
                 futures.append(ret)
 
             for future in futures:
@@ -117,28 +114,46 @@ class LocalBuilder(Builder):
 
                 if isinstance(res, Exception):
                     # timeout or fleet error, return MeasureResult directly
-                    results.append(MeasureResult((res,), MeasureErrorNo.BUILD_TIMEOUT,
-                                                 self.timeout, time.time()))
+                    results.append(
+                        MeasureResult(
+                            (res,), MeasureErrorNo.BUILD_TIMEOUT, self.timeout, time.time()
+                        )
+                    )
                 elif res.error is not None:
                     # instantiation error
                     if isinstance(res.error, InstantiationError):
-                        results.append(MeasureResult((res.error,),
-                                                     MeasureErrorNo.INSTANTIATION_ERROR,
-                                                     res.time_cost, time.time()))
+                        results.append(
+                            MeasureResult(
+                                (res.error,),
+                                MeasureErrorNo.INSTANTIATION_ERROR,
+                                res.time_cost,
+                                time.time(),
+                            )
+                        )
                     else:
                         if "InstantiationError" in str(res.error):
                             msg = str(res.error)
                             try:
-                                msg = msg.split('\n')[-2].split(": ")[1]
+                                msg = msg.split("\n")[-2].split(": ")[1]
                             except Exception:  # pylint: disable=broad-except
                                 pass
-                            results.append(MeasureResult((InstantiationError(msg),),
-                                                         MeasureErrorNo.INSTANTIATION_ERROR,
-                                                         res.time_cost, time.time()))
+                            results.append(
+                                MeasureResult(
+                                    (InstantiationError(msg),),
+                                    MeasureErrorNo.INSTANTIATION_ERROR,
+                                    res.time_cost,
+                                    time.time(),
+                                )
+                            )
                         else:  # tvm error
-                            results.append(MeasureResult((res.error,),
-                                                         MeasureErrorNo.COMPILE_HOST,
-                                                         res.time_cost, time.time()))
+                            results.append(
+                                MeasureResult(
+                                    (res.error,),
+                                    MeasureErrorNo.COMPILE_HOST,
+                                    res.time_cost,
+                                    time.time(),
+                                )
+                            )
                 else:
                     # return BuildResult
                     results.append(res)
@@ -192,11 +207,21 @@ class RPCRunner(Runner):
         This only has an effect on CPU tasks.
     """
 
-    def __init__(self,
-                 key, host, port, priority=1,
-                 timeout=10, n_parallel=None,
-                 number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1,
-                 check_correctness=False, enable_cpu_cache_flush=False):
+    def __init__(
+        self,
+        key,
+        host,
+        port,
+        priority=1,
+        timeout=10,
+        n_parallel=None,
+        number=4,
+        repeat=3,
+        min_repeat_ms=0,
+        cooldown_interval=0.1,
+        check_correctness=False,
+        enable_cpu_cache_flush=False,
+    ):
         super(RPCRunner, self).__init__(timeout, n_parallel)
 
         self.key = key
@@ -223,18 +248,21 @@ class RPCRunner(Runner):
         if check_remote(task.target, self.key, self.host, self.port):
             logger.info("Get devices for measurement successfully!")
         else:
-            raise RuntimeError("Cannot get remote devices from the tracker. "
-                               "Please check the status of tracker by "
-                               "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
-                               "and make sure you have free devices on the queue status.")
+            raise RuntimeError(
+                "Cannot get remote devices from the tracker. "
+                "Please check the status of tracker by "
+                "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
+                "and make sure you have free devices on the queue status."
+            )
 
         if self.check_correctness:
             # use llvm cpu to generate a reference input/output
             # this option works for tuning topi, but might not work for you custom op
             with Target("llvm"):
                 s, arg_bufs = task.instantiate(task.config_space.get(0))
-            self.ref_input = [np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype)
-                              for x in arg_bufs]
+            self.ref_input = [
+                np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype) for x in arg_bufs
+            ]
             func = build(s, arg_bufs, "llvm")
             tvm_buf = [nd.array(x) for x in self.ref_input]
             func(*tvm_buf)
@@ -242,55 +270,62 @@ class RPCRunner(Runner):
 
     def get_build_kwargs(self):
         kwargs = {}
-        if 'cuda' in self.task.target.keys or 'opencl' in self.task.target.keys or \
-           'rocm' in self.task.target.keys or 'vulkan' in self.task.target.keys:
+        if (
+            "cuda" in self.task.target.keys
+            or "opencl" in self.task.target.keys
+            or "rocm" in self.task.target.keys
+            or "vulkan" in self.task.target.keys
+        ):
             remote = request_remote(self.key, self.host, self.port)
             ctx = remote.context(str(self.task.target), 0)
             max_dims = ctx.max_thread_dimensions
-            kwargs['check_gpu'] = {
-                'max_shared_memory_per_block': ctx.max_shared_memory_per_block,
-                'max_threads_per_block': ctx.max_threads_per_block,
-                'max_thread_x': max_dims[0],
-                'max_thread_y': max_dims[1],
-                'max_thread_z': max_dims[2],
+            kwargs["check_gpu"] = {
+                "max_shared_memory_per_block": ctx.max_shared_memory_per_block,
+                "max_threads_per_block": ctx.max_threads_per_block,
+                "max_thread_x": max_dims[0],
+                "max_thread_y": max_dims[1],
+                "max_thread_z": max_dims[2],
             }
 
-            if 'cuda' in self.task.target.keys:
-                kwargs["cuda_arch"] = "sm_" + \
-                    "".join(ctx.compute_version.split('.'))
-        if self.task.target.device_name == 'micro_dev':
-            kwargs.setdefault('build_option', {})[
-                'tir.disable_vectorize'] = True
+            if "cuda" in self.task.target.keys:
+                kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split("."))
+        if self.task.target.device_name == "micro_dev":
+            kwargs.setdefault("build_option", {})["tir.disable_vectorize"] = True
 
         return kwargs
 
     def run(self, measure_inputs, build_results):
         results = []
-        remote_args = (self.key, self.host, self.port,
-                       self.priority, self.timeout)
+        remote_args = (self.key, self.host, self.port, self.priority, self.timeout)
 
         for i in range(0, len(measure_inputs), self.n_parallel):
             futures = []
-            for measure_inp, build_res in zip(measure_inputs[i:i+self.n_parallel],
-                                              build_results[i:i+self.n_parallel]):
-                ret = self.executor.submit(run_through_rpc,
-                                           measure_inp,
-                                           build_res,
-                                           self.number,
-                                           self.repeat,
-                                           self.min_repeat_ms,
-                                           self.cooldown_interval,
-                                           remote_args,
-                                           self.ref_input,
-                                           self.ref_output,
-                                           self.enable_cpu_cache_flush)
+            for measure_inp, build_res in zip(
+                measure_inputs[i : i + self.n_parallel], build_results[i : i + self.n_parallel]
+            ):
+                ret = self.executor.submit(
+                    run_through_rpc,
+                    measure_inp,
+                    build_res,
+                    self.number,
+                    self.repeat,
+                    self.min_repeat_ms,
+                    self.cooldown_interval,
+                    remote_args,
+                    self.ref_input,
+                    self.ref_output,
+                    self.enable_cpu_cache_flush,
+                )
                 futures.append(ret)
 
             for future in futures:
                 res = future.get()
-                if isinstance(res, Exception):   # executor error or timeout
-                    results.append(MeasureResult((str(res),), MeasureErrorNo.RUN_TIMEOUT,
-                                                 self.timeout, time.time()))
+                if isinstance(res, Exception):  # executor error or timeout
+                    results.append(
+                        MeasureResult(
+                            (str(res),), MeasureErrorNo.RUN_TIMEOUT, self.timeout, time.time()
+                        )
+                    )
                 else:
                     results.append(res)
 
@@ -338,17 +373,30 @@ class LocalRunner(RPCRunner):
     for the user. In this way we reuse timeout/isolation mechanism in RPC infrastructure.
     """
 
-    def __init__(self,
-                 timeout=10,
-                 number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1,
-                 check_correctness=False, enable_cpu_cache_flush=False):
-        super(LocalRunner, self).__init__('', None, None, 0,
-                                          timeout=timeout, n_parallel=1,
-                                          number=number, repeat=repeat,
-                                          min_repeat_ms=min_repeat_ms,
-                                          cooldown_interval=cooldown_interval,
-                                          check_correctness=check_correctness,
-                                          enable_cpu_cache_flush=enable_cpu_cache_flush)
+    def __init__(
+        self,
+        timeout=10,
+        number=4,
+        repeat=3,
+        min_repeat_ms=0,
+        cooldown_interval=0.1,
+        check_correctness=False,
+        enable_cpu_cache_flush=False,
+    ):
+        super(LocalRunner, self).__init__(
+            "",
+            None,
+            None,
+            0,
+            timeout=timeout,
+            n_parallel=1,
+            number=number,
+            repeat=repeat,
+            min_repeat_ms=min_repeat_ms,
+            cooldown_interval=cooldown_interval,
+            check_correctness=check_correctness,
+            enable_cpu_cache_flush=enable_cpu_cache_flush,
+        )
         self.tracker = None
         self.server = None
 
@@ -358,12 +406,17 @@ class LocalRunner(RPCRunner):
         from ...rpc.server import Server
 
         self.task = task
-        tracker = Tracker('0.0.0.0', port=9000, port_end=10000, silent=True)
-        device_key = '$local$device$%d' % tracker.port
-        server = Server('0.0.0.0', port=9000, port_end=10000,
-                        key=device_key,
-                        use_popen=True, silent=True,
-                        tracker_addr=(tracker.host, tracker.port))
+        tracker = Tracker("0.0.0.0", port=9000, port_end=10000, silent=True)
+        device_key = "$local$device$%d" % tracker.port
+        server = Server(
+            "0.0.0.0",
+            port=9000,
+            port_end=10000,
+            key=device_key,
+            use_popen=True,
+            silent=True,
+            tracker_addr=(tracker.host, tracker.port),
+        )
         self.key = device_key
         self.host = tracker.host
         self.port = tracker.port
@@ -389,10 +442,13 @@ def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_opti
             set_cuda_target_arch(cuda_arch)
 
         # if target is vta, we need to use vta build
-        if hasattr(measure_input.target, 'device_name') and \
-                measure_input.target.device_name == 'vta':
+        if (
+            hasattr(measure_input.target, "device_name")
+            and measure_input.target.device_name == "vta"
+        ):
             # pylint: disable=import-outside-toplevel
             import vta
+
             func = vta.build(s, args, target_host=task.target_host)
         else:
             with tvm.ir.transform.PassContext(config=opts):
@@ -400,7 +456,7 @@ def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_opti
     return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)
 
 
-class _WrappedBuildFunc():
+class _WrappedBuildFunc:
     """
     Wrap build_func to a function that can be used in measure.
 
@@ -420,8 +476,7 @@ class _WrappedBuildFunc():
 
     def __init__(self, build_func):
         if not hasattr(build_func, "output_format"):
-            raise AttributeError(
-                "Expect build_func to have the attribute output_format.")
+            raise AttributeError("Expect build_func to have the attribute output_format.")
         self.build_func = build_func
 
     def __call__(self, measure_input, tmp_dir, **kwargs):
@@ -438,8 +493,9 @@ class _WrappedBuildFunc():
         """
         tic = time.time()
         try:
-            filename = os.path.join(tmp_dir, "tmp_func_%0x.%s" % (
-                getrandbits(64), self.build_func.output_format))
+            filename = os.path.join(
+                tmp_dir, "tmp_func_%0x.%s" % (getrandbits(64), self.build_func.output_format)
+            )
             # TODO(tvm-team) consider inlining _build_func_common
             func, arg_info = _build_func_common(measure_input, **kwargs)
             func.export_library(filename, self.build_func)
@@ -448,10 +504,18 @@ class _WrappedBuildFunc():
         return BuildResult(filename, arg_info, None, time.time() - tic)
 
 
-def run_through_rpc(measure_input, build_result,
-                    number, repeat, min_repeat_ms, cooldown_interval,
-                    remote_args, ref_input=None, ref_output=None,
-                    enable_cpu_cache_flush=False):
+def run_through_rpc(
+    measure_input,
+    build_result,
+    number,
+    repeat,
+    min_repeat_ms,
+    cooldown_interval,
+    remote_args,
+    ref_input=None,
+    ref_output=None,
+    enable_cpu_cache_flush=False,
+):
     """Run a generated library through rpc
 
     Parameters
@@ -500,10 +564,13 @@ def run_through_rpc(measure_input, build_result,
         # upload built module
         remote = request_remote(*remote_args)
         # Program the FPGA every single time when targeting VTA
-        if hasattr(measure_input.target, 'device_name') and \
-                measure_input.target.device_name == 'vta':
+        if (
+            hasattr(measure_input.target, "device_name")
+            and measure_input.target.device_name == "vta"
+        ):
             # pylint: disable=import-outside-toplevel
             from vta import program_fpga, reconfig_runtime
+
             program_fpga(remote, None)
             reconfig_runtime(remote)
         remote.upload(build_result.filename)
@@ -515,10 +582,15 @@ def run_through_rpc(measure_input, build_result,
         # under the std::function. We could lift the restriction later once we fold
         # the PackedFunc as an object. Currently, we pass function name to work
         # around it.
-        f_prepare = 'cache_flush_cpu_non_first_arg' if enable_cpu_cache_flush else ''
+        f_prepare = "cache_flush_cpu_non_first_arg" if enable_cpu_cache_flush else ""
         time_f = func.time_evaluator(
-            func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms,
-            f_preproc=f_prepare)
+            func.entry_name,
+            ctx,
+            number=number,
+            repeat=repeat,
+            min_repeat_ms=min_repeat_ms,
+            f_preproc=f_prepare,
+        )
 
         # set input
         if ref_input:
@@ -527,8 +599,9 @@ def run_through_rpc(measure_input, build_result,
             try:
                 random_fill = remote.get_function("tvm.contrib.random.random_fill")
             except AttributeError:
-                raise AttributeError("Please make sure USE_RANDOM is ON in the config.cmake "
-                                     "on the remote devices")
+                raise AttributeError(
+                    "Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices"
+                )
             args = [nd.empty(x[0], dtype=x[1], ctx=ctx) for x in build_result.arg_info]
             for arg in args:
                 random_fill(arg)
@@ -538,8 +611,8 @@ def run_through_rpc(measure_input, build_result,
 
         # clean up remote files
         remote.remove(build_result.filename)
-        remote.remove(os.path.splitext(build_result.filename)[0] + '.so')
-        remote.remove('')
+        remote.remove(os.path.splitext(build_result.filename)[0] + ".so")
+        remote.remove("")
 
         if len(costs) > 2:  # remove largest and smallest value to reduce variance
             costs = list(costs)
@@ -555,9 +628,9 @@ def run_through_rpc(measure_input, build_result,
     except TVMError as exc:
         msg = str(exc)
         if "Stack trace returned" in msg:
-            msg = msg[:msg.index("Stack trace returned")]
+            msg = msg[: msg.index("Stack trace returned")]
         if "CUDA Source" in msg:
-            msg = msg[:msg.index("CUDA Source")]
+            msg = msg[: msg.index("CUDA Source")]
         costs = (RuntimeError(msg[:1024]),)
         errno = MeasureErrorNo.RUNTIME_DEVICE
     tstamp = time.time()
@@ -588,12 +661,11 @@ def request_remote(device_key, host=None, port=None, priority=1, timeout=60):
     session: RPCSession
     """
     # connect to the tracker
-    host = host or os.environ['TVM_TRACKER_HOST']
-    port = port or int(os.environ['TVM_TRACKER_PORT'])
+    host = host or os.environ["TVM_TRACKER_HOST"]
+    port = port or int(os.environ["TVM_TRACKER_PORT"])
 
     tracker = _rpc.connect_tracker(host, port)
-    remote = tracker.request(device_key, priority=priority,
-                             session_timeout=timeout)
+    remote = tracker.request(device_key, priority=priority, session_timeout=timeout)
     return remote
 
 
@@ -623,12 +695,16 @@ def check_remote(target, device_key, host=None, port=None, priority=100, timeout
     available: bool
         True if an available device can be found
     """
+
     def _check():
         remote = request_remote(device_key, host, port, priority)
         ctx = remote.context(str(target))
         while not ctx.exist:  # wait until we get an available device
             pass
-    t = threading.Thread(target=_check,)
+
+    t = threading.Thread(
+        target=_check,
+    )
     t.start()
     t.join(timeout)
     return not t.is_alive()
@@ -643,8 +719,7 @@ def tvm_callback_cuda_compile(code):
     #   "-gencode", "arch=compute_70,code=sm_70"
     # ]
     target = "fatbin" if isinstance(curr_cuda_target_arch, list) else "ptx"
-    ptx = nvcc.compile_cuda(code, target=target,
-                            arch=AutotvmGlobalScope.current.cuda_target_arch)
+    ptx = nvcc.compile_cuda(code, target=target, arch=AutotvmGlobalScope.current.cuda_target_arch)
     return ptx
 
 
@@ -665,9 +740,11 @@ def gpu_verify_pass(**kwargs):
     """Verify the validity of a gpu kernel.
     This pass will check memory usage and number of threads per block.
     """
+
     def verify_pass(f, *_):
         valid = tvm.tir.analysis.verify_gpu_code(f, kwargs)
         if not valid:
             raise InstantiationError("Skipped because of invalid gpu kernel")
         return f
+
     return tvm.tir.transform.prim_func_pass(verify_pass, opt_level=0)
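
A hypothetical sketch of the RPC measurement path defined in this file (illustrative only; "android" is an assumed tracker key and 127.0.0.1:9190 an assumed tracker address; as request_remote shows above, the tracker can also come from TVM_TRACKER_HOST/TVM_TRACKER_PORT):

from tvm import autotvm

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(build_func="ndk"),  # cross-compile shared libs with the Android NDK
    runner=autotvm.RPCRunner(
        "android",        # device key registered with the RPC tracker
        host="127.0.0.1",
        port=9190,
        number=5,
        timeout=10,
    ),
)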
index b96e9bd..af3540e 100644 (file)
@@ -39,7 +39,7 @@ from .measure import MeasureInput, MeasureResult
 
 AUTOTVM_LOG_VERSION = 0.2
 _old_version_warning = True
-logger = logging.getLogger('autotvm')
+logger = logging.getLogger("autotvm")
 
 try:  # convert unicode to str for python2
     _unicode = unicode
@@ -53,7 +53,7 @@ except NameError:
 
 
 def measure_str_key(inp, include_config=True):
-    """ get unique str key for MeasureInput
+    """get unique str key for MeasureInput
 
     Parameters
     ----------
@@ -68,11 +68,12 @@ def measure_str_key(inp, include_config=True):
         The str representation of key
     """
     config_str = str(inp.config) if include_config else ""
-    return "".join([str(inp.target), inp.task.name, str(inp.task.args),
-                    str(inp.task.kwargs), config_str])
+    return "".join(
+        [str(inp.target), inp.task.name, str(inp.task.args), str(inp.task.kwargs), config_str]
+    )
 
 
-def encode(inp, result, protocol='json'):
+def encode(inp, result, protocol="json"):
     """encode (MeasureInput, MeasureResult) pair to a string
 
     Parameters
@@ -89,38 +90,39 @@ def encode(inp, result, protocol='json'):
         a row in the logger file
     """
 
-    if protocol == 'json':
+    if protocol == "json":
         json_dict = {
-            "input": (str(inp.target),
-                      inp.task.name, inp.task.args, inp.task.kwargs),
-
+            "input": (str(inp.target), inp.task.name, inp.task.args, inp.task.kwargs),
             "config": inp.config.to_json_dict(),
-
-            "result": (result.costs if result.error_no == 0 else (1e9,),
-                       result.error_no,
-                       result.all_cost,
-                       result.timestamp),
-
+            "result": (
+                result.costs if result.error_no == 0 else (1e9,),
+                result.error_no,
+                result.all_cost,
+                result.timestamp,
+            ),
             "version": AUTOTVM_LOG_VERSION,
-
-            "tvm_version": __version__
+            "tvm_version": __version__,
         }
         return json.dumps(json_dict)
-    if protocol == 'pickle':
-        row = (str(inp.target),
-               str(base64.b64encode(pickle.dumps([inp.task.name,
-                                                  inp.task.args,
-                                                  inp.task.kwargs])).decode()),
-               str(base64.b64encode(pickle.dumps(inp.config)).decode()),
-               str(base64.b64encode(pickle.dumps(tuple(result))).decode()),
-               str(AUTOTVM_LOG_VERSION),
-               str(__version__))
-        return '\t'.join(row)
+    if protocol == "pickle":
+        row = (
+            str(inp.target),
+            str(
+                base64.b64encode(
+                    pickle.dumps([inp.task.name, inp.task.args, inp.task.kwargs])
+                ).decode()
+            ),
+            str(base64.b64encode(pickle.dumps(inp.config)).decode()),
+            str(base64.b64encode(pickle.dumps(tuple(result))).decode()),
+            str(AUTOTVM_LOG_VERSION),
+            str(__version__),
+        )
+        return "\t".join(row)
 
     raise RuntimeError("Invalid log protocol: " + protocol)
 
 
-def decode(row, protocol='json'):
+def decode(row, protocol="json"):
     """Decode encoded record string to python object
 
     Parameters
@@ -139,26 +141,24 @@ def decode(row, protocol='json'):
     # pylint: disable=unused-variable
     global _old_version_warning
 
-    if protocol == 'json':
+    if protocol == "json":
         row = json.loads(row)
-        if 'v' in row and row['v'] == 0.1:
+        if "v" in row and row["v"] == 0.1:
             if _old_version_warning:
-                logger.warning(
-                    "AutoTVM log version 0.1 is no longer supported.")
+                logger.warning("AutoTVM log version 0.1 is no longer supported.")
                 _old_version_warning = False
             return None
 
         tgt, task_name, task_args, task_kwargs = row["input"]
         tgt = str(tgt)
         if "-target" in tgt:
-            logger.warning(
-                "\"-target\" is deprecated, use \"-mtriple\" instead.")
+            logger.warning('"-target" is deprecated, use "-mtriple" instead.')
             tgt = tgt.replace("-target", "-mtriple")
         tgt = Target(str(tgt))
 
         def clean_json_to_python(x):
             """1. Convert all list in x to tuple (hashable)
-               2. Convert unicode to str for python2
+            2. Convert unicode to str for python2
             """
             if isinstance(x, list):
                 return tuple([clean_json_to_python(a) for a in x])
@@ -168,28 +168,24 @@ def decode(row, protocol='json'):
                 return int(x)
             return x
 
-        tsk = task.Task(clean_json_to_python(task_name),
-                        clean_json_to_python(task_args))
+        tsk = task.Task(clean_json_to_python(task_name), clean_json_to_python(task_args))
         config = ConfigEntity.from_json_dict(row["config"])
         inp = MeasureInput(tgt, tsk, config)
-        result = MeasureResult(
-            *[tuple(x) if isinstance(x, list) else x for x in row["result"]])
+        result = MeasureResult(*[tuple(x) if isinstance(x, list) else x for x in row["result"]])
         config.cost = np.mean(result.costs)
 
         return inp, result
-    if protocol == 'pickle':
+    if protocol == "pickle":
         items = row.split("\t")
         if len(items) == 4:
             if _old_version_warning:
-                logger.warning(
-                    "AutoTVM log version 0.1 is no longer supported.")
+                logger.warning("AutoTVM log version 0.1 is no longer supported.")
                 _old_version_warning = False
             return None
         tgt = Target(items[0])
         task_tuple = pickle.loads(base64.b64decode(items[1].encode()))
         config = pickle.loads(base64.b64decode(items[2].encode()))
-        result = MeasureResult(
-            *pickle.loads(base64.b64decode(items[3].encode())))
+        result = MeasureResult(*pickle.loads(base64.b64decode(items[3].encode())))
         config.cost = np.mean(result.costs)
 
         tsk = task.Task(task_tuple[0], task_tuple[1])
@@ -212,7 +208,7 @@ def load_from_file(filename):
     result: autotvm.tuner.MeasureResult
     """
     for row in open(filename):
-        if row and not row.startswith('#'):
+        if row and not row.startswith("#"):
             ret = decode(row)
             if ret is None:
                 continue
@@ -258,17 +254,16 @@ def split_workload(in_file, clean=True):
                 cleaned.append([inp, res])
 
             # write to file
-            logger.info("Key: %s\tValid: %d\tDup: %d\t", k,
-                        len(cleaned), len(v) - len(cleaned))
-            with open(args.i + ".%03d.wkl" % i, 'w') as fout:
+            logger.info("Key: %s\tValid: %d\tDup: %d\t", k, len(cleaned), len(v) - len(cleaned))
+            with open(args.i + ".%03d.wkl" % i, "w") as fout:
                 for inp, res in cleaned:
-                    fout.write(encode(inp, res) + '\n')
+                    fout.write(encode(inp, res) + "\n")
     else:
         for i, (k, v) in enumerate(wkl_dict.items()):
             logger.info("Key: %s\tNum: %d", k, len(v))
-            with open(args.i + ".%03d.wkl" % i, 'w') as fout:
+            with open(args.i + ".%03d.wkl" % i, "w") as fout:
                 for inp, res in v:
-                    fout.write(encode(inp, res) + '\n')
+                    fout.write(encode(inp, res) + "\n")
 
 
 def pick_best(in_file, out_file):
@@ -300,7 +295,7 @@ def pick_best(in_file, out_file):
         best_set.add(measure_str_key(v[0]))
 
     logger.info("Extract %d best records from the %s", len(best_set), in_file)
-    fout = open(out_file, 'w') if isinstance(out_file, str) else out_file
+    fout = open(out_file, "w") if isinstance(out_file, str) else out_file
 
     for inp, res in context_clone:
         if measure_str_key(inp) in best_set:
@@ -321,24 +316,23 @@ e.g. python -m tvm.autotvm.record --mode pick --i collect.log
 * Split a log file into separate files, each of which contains only a single wkl
 e.g. python -m tvm.autotvm.record --mode split --i collect.log
 """
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--mode", choices=['read', 'pick', 'split'], default='read')
+    parser.add_argument("--mode", choices=["read", "pick", "split"], default="read")
     parser.add_argument("--i", type=str, help="input file")
-    parser.add_argument("--o", type=str, default=None, help='output file')
+    parser.add_argument("--o", type=str, default=None, help="output file")
     parser.add_argument("--begin", type=int, default=0)
     parser.add_argument("--end", type=int, default=5)
-    parser.add_argument("--ir", action='store_true')
-    parser.add_argument("--code", action='store_true')
+    parser.add_argument("--ir", action="store_true")
+    parser.add_argument("--code", action="store_true")
 
     args = parser.parse_args()
     logging.basicConfig(level=logging.INFO)
 
-    if args.mode == 'pick':
+    if args.mode == "pick":
         args.o = args.o or args.i + ".best.log"
         pick_best(args.i, args.o)
-    elif args.mode == 'read':
+    elif args.mode == "read":
         for i, (inp, result) in enumerate(load_from_file(args.i)):
             if args.begin <= i < args.end:
                 with inp.target:
@@ -356,5 +350,5 @@ if __name__ == '__main__':
                     with inp.target:
                         func = build(s, arg_bufs)
                         print(func.imported_modules[0].get_source())
-    elif args.mode == 'split':
+    elif args.mode == "split":
         split_workload(args.i)
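
For readers skimming the reformatted record helpers, a minimal usage sketch (not part of this commit) follows; `tuning.log` and `tuning.best.log` are hypothetical file names, everything else is the `tvm.autotvm.record` API referenced in the CLI help text above.

    from tvm.autotvm.record import encode, decode, load_from_file, pick_best

    # "tuning.log" is a hypothetical log produced by an earlier tuning run.
    # Every valid row decodes to a (MeasureInput, MeasureResult) pair and
    # re-encodes to a single JSON line in the v0.2 log format.
    for inp, res in load_from_file("tuning.log"):
        row = encode(inp, res, protocol="json")
        assert decode(row) is not None  # only v0.1 rows decode to None

    # Same effect as `python -m tvm.autotvm.record --mode pick --i tuning.log`.
    pick_best("tuning.log", "tuning.best.log")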
index be50af7..6eea622 100644
@@ -22,13 +22,30 @@ This module defines the task data structure, as well as a collection(zoo)
 of typical tasks of interest.
 """
 
-from .task import Task, create, get_config, args_to_workload, template, \
-    serialize_args, deserialize_args
+from .task import (
+    Task,
+    create,
+    get_config,
+    args_to_workload,
+    template,
+    serialize_args,
+    deserialize_args,
+)
 from .space import ConfigSpace, ConfigEntity
 from .code_hash import attach_code_hash, attach_code_hash_to_arg
-from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, \
-    FallbackContext, clear_fallback_cache, ApplyGraphBest
+from .dispatcher import (
+    DispatchContext,
+    ApplyConfig,
+    ApplyHistoryBest,
+    FallbackContext,
+    clear_fallback_cache,
+    ApplyGraphBest,
+)
 
-from .topi_integration import register_topi_compute, register_topi_schedule, \
-    TaskExtractEnv, get_workload
+from .topi_integration import (
+    register_topi_compute,
+    register_topi_schedule,
+    TaskExtractEnv,
+    get_workload,
+)
 from .relay_integration import extract_from_program, extract_from_multiple_program
index 3076970..3331fc1 100644
@@ -24,6 +24,7 @@ import zlib
 
 from tvm.te import schedule
 
+
 def attach_code_hash(s):
     """Decorator for attaching a code hash to a schedule
 
@@ -32,14 +33,18 @@ def attach_code_hash(s):
     s: Schedule
         tvm.te.schedule.Schedule to attach the hash to
     """
+
     def decorator(func):
         def wrapper(*args, **kwargs):
             func(*args, **kwargs)
-            raw_hash = zlib.crc32(''.join(inspect.getsourcelines(func)[0]).encode())
+            raw_hash = zlib.crc32("".join(inspect.getsourcelines(func)[0]).encode())
             s.code_hash = hex(raw_hash)[2:]
+
         return wrapper
+
     return decorator
 
+
 def attach_code_hash_to_arg(arg_idx=1):
     """Decorator for attaching a code hash to a schedule
 
@@ -49,11 +54,14 @@ def attach_code_hash_to_arg(arg_idx=1):
         index of the argument (expected to be a Schedule) to attach the code
         hash to
     """
+
     def decorator(func):
         def wrapper(*args, **kwargs):
             func(*args, **kwargs)
             assert isinstance(args[arg_idx], schedule.Schedule)
-            raw_hash = zlib.crc32(''.join(inspect.getsourcelines(func)[0]).encode())
+            raw_hash = zlib.crc32("".join(inspect.getsourcelines(func)[0]).encode())
             args[arg_idx].code_hash = hex(raw_hash)[2:]
+
         return wrapper
+
     return decorator
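
The two decorators reformatted above attach a source hash to a schedule; a minimal sketch (not part of this commit) is below. The placeholder compute, `my_schedule`, and the `None` config argument are illustrative only.

    from tvm import te
    from tvm.autotvm.task import attach_code_hash_to_arg

    A = te.placeholder((1024,), name="A")
    B = te.compute((1024,), lambda i: A[i] * 2.0, name="B")
    s = te.create_schedule(B.op)

    @attach_code_hash_to_arg(arg_idx=1)
    def my_schedule(cfg, sch):
        # schedule transformations would go here; after the call returns, the
        # wrapper stores a crc32 of this function's source on sch.code_hash
        pass

    my_schedule(None, s)
    print(s.code_hash)  # hex digest attached by the decorator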
index ebb19b0..bfc49d5 100644
@@ -37,7 +37,7 @@ import numpy as np
 from .space import FallbackConfigEntity
 from .. import env as _env
 
-logger = logging.getLogger('autotvm')
+logger = logging.getLogger("autotvm")
 
 
 class DispatchContext(object):
@@ -47,6 +47,7 @@ class DispatchContext(object):
     DispatchContext enables the target and workload
     specific dispatch mechanism for templates.
     """
+
     current = None
     # a set to prevent print duplicated message
     warning_messages = set()
@@ -243,7 +244,7 @@ class ApplyHistoryBest(DispatchContext):
             # use model as key to build best map
             key = (inp.target.model, inp.task.workload)
             if key not in best_by_model:
-                if inp.target.model != 'unknown':
+                if inp.target.model != "unknown":
                     best_by_model[key] = (inp, res)
             else:
                 _, other_res = best_by_model[key]
@@ -254,9 +255,11 @@ class ApplyHistoryBest(DispatchContext):
 
     def _query_inside(self, target, workload):
         if target is None:
-            raise RuntimeError("Need a target context to find the history best. "
-                               "Hint: If your target is llvm, use `with tvm.target.Target('llvm'):`"
-                               " above the dispatcher call. So does other target. ")
+            raise RuntimeError(
+                "Need a target context to find the history best. "
+                "Hint: If your target is llvm, use `with tvm.target.Target('llvm'):`"
+                " above the dispatcher call. So does other target. "
+            )
 
         # first try matching by model
         key = (target.model, workload)
@@ -307,9 +310,10 @@ class FallbackContext(DispatchContext):
             return self.memory[key]
 
         if not _env.GLOBAL_SCOPE.silent:
-            msg = "Cannot find config for target=%s, workload=%s. A fallback configuration "\
-                  "is used, which may bring great performance regression." % (
-                      target, workload)
+            msg = (
+                "Cannot find config for target=%s, workload=%s. A fallback configuration "
+                "is used, which may bring great performance regression." % (target, workload)
+            )
             if msg not in DispatchContext.warning_messages:
                 DispatchContext.warning_messages.add(msg)
                 logger.warning(msg)
@@ -421,9 +425,11 @@ class ApplyGraphBest(DispatchContext):
             return cfg
         key = (str(target), workload)
         if key not in self._global_cfg_dict:
-            msg = "Config for target=%s, workload=%s is missing in ApplyGraphBest context. " \
-                  "A fallback configuration is used, which may bring great performance " \
-                  "regression." % (target, workload)
+            msg = (
+                "Config for target=%s, workload=%s is missing in ApplyGraphBest context. "
+                "A fallback configuration is used, which may bring great performance "
+                "regression." % (target, workload)
+            )
             logger.warning(msg)
             cfg = FallbackConfigEntity()
             self._global_cfg_dict[key] = cfg
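
As a usage sketch (not part of this commit), the dispatch contexts above are usually entered like this; `tuning.log` is a hypothetical record file and the `llvm` target is arbitrary.

    import tvm
    from tvm import autotvm

    # Queries made inside this block return the best recorded config for the
    # current target and workload; anything missing falls through to the
    # FallbackContext warning path shown above.
    with autotvm.apply_history_best("tuning.log"):
        with tvm.target.Target("llvm"):
            pass  # relay.build(...) or a template call would be dispatched here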
index 15d4534..fe88d17 100644
@@ -28,21 +28,19 @@ from tvm.autotvm.task.dispatcher import DispatchContext, FallbackContext
 from .task import create
 from .topi_integration import TaskExtractEnv
 
-logger = logging.getLogger('autotvm')
+logger = logging.getLogger("autotvm")
 
 
 # TODO(moreau89) find a more elegant way to lower for VTAs
-def _lower(mod,
-           target,
-           params):
-    """ Helper to lower VTA properly.
-    """
+def _lower(mod, target, params):
+    """Helper to lower VTA properly."""
     # pylint: disable=import-outside-toplevel
     from tvm import relay
     from tvm.relay.backend import graph_runtime_codegen
 
-    if hasattr(target, 'device_name') and target.device_name == "vta":
+    if hasattr(target, "device_name") and target.device_name == "vta":
         import vta
+
         with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
             mod, _ = relay.optimize(mod, target, params)
             grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
@@ -58,8 +56,10 @@ def _lower(mod,
         grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
         grc.codegen(opt_mod["main"])
     except tvm.TVMError as e:
-        print("Get errors with GraphRuntimeCodegen for task extraction. "
-              "Fallback to VMCompiler. Error details:\n%s" % str(e))
+        print(
+            "Get errors with GraphRuntimeCodegen for task extraction. "
+            "Fallback to VMCompiler. Error details:\n%s" % str(e)
+        )
         compiler = relay.vm.VMCompiler()
         if params:
             compiler.set_params(params)
@@ -67,7 +67,7 @@ def _lower(mod,
 
 
 def extract_from_program(mod, params, target, target_host=None, ops=None):
-    """ Extract tuning tasks from a relay program.
+    """Extract tuning tasks from a relay program.
 
     This function is the single program version of extract_from_multiple_program.
 
@@ -93,7 +93,7 @@ def extract_from_program(mod, params, target, target_host=None, ops=None):
 
 
 def extract_from_multiple_program(mods, params, target, target_host=None, ops=None):
-    """ Extract tuning tasks from multiple relay programs.
+    """Extract tuning tasks from multiple relay programs.
 
     This function collects tuning tasks by building a list of programs
     with a "tracing" target and tracing all the calls to topi.
@@ -132,12 +132,12 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
         for mod, param in zip(mods, params):
             if isinstance(mod, relay.function.Function):
                 mod = tvm.IRModule.from_expr(mod)
-            assert isinstance(mod, tvm.IRModule), \
-                "only support relay Module or Function to be tuned"
+            assert isinstance(
+                mod, tvm.IRModule
+            ), "only support relay Module or Function to be tuned"
             relay.backend.compile_engine.get().clear()
             # wrap build call in thread to avoid multiprocessing problems
-            build_thread = threading.Thread(target=_lower,
-                                            args=(mod, target, param))
+            build_thread = threading.Thread(target=_lower, args=(mod, target, param))
             build_thread.start()
             build_thread.join()
             relay.backend.compile_engine.get().clear()
@@ -152,8 +152,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
     tasks = []
     for task_name, args in env.get_tasks():
         try:
-            tsk = create(task_name, args,
-                         target=target, target_host=target_host)
+            tsk = create(task_name, args, target=target, target_host=target_host)
             tasks.append(tsk)
         except topi.InvalidShapeError:
             logger.warning("Invalid shape during AutoTVM task creation")
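
A minimal sketch (not part of this commit) of the extraction entry point reformatted above; the conv2d workload, shapes, and empty params dict are illustrative.

    import tvm
    from tvm import relay, autotvm

    data = relay.var("data", shape=(1, 3, 224, 224))
    weight = relay.var("weight", shape=(16, 3, 3, 3))
    net = relay.nn.conv2d(data, weight, padding=(1, 1))
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], net))

    # Builds once with a tracing target and returns one Task per matched op.
    tasks = autotvm.task.extract_from_program(mod, params={}, target="llvm")
    for tsk in tasks:
        print(tsk.name, tsk.args)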
index 4937661..d700b64 100644
@@ -36,7 +36,7 @@ from tvm.te import schedule, thread_axis
 from tvm.tir import expr
 from tvm.autotvm.util import get_const_int
 
-Axis = namedtuple('Axis', ['space', 'index'])
+Axis = namedtuple("Axis", ["space", "index"])
 
 try:
     _long = long
@@ -46,8 +46,8 @@ except NameError:
 
 class InstantiationError(ValueError):
     """Actively detected error in instantiating a template with a config,
-     raised by cfg.raise_error
-     e.g. too many unrolling, too many threads in a block
+    raised by cfg.raise_error
+    e.g. too many unrolling, too many threads in a block
     """
 
 
@@ -69,6 +69,7 @@ class TransformSpace(object):
     We call the set of all possible values as XXXSpace. (XXX can be Split, Reorder, Config ...)
     We call a specific entity in a space as XXXEntity.
     """
+
     def __init__(self):
         self.ins = []
         self.num_output = 0
@@ -114,6 +115,7 @@ class VirtualAxis(TransformSpace):
 
     name: str
     """
+
     name_ct = 0
 
     def __init__(self, var, name=None):
@@ -121,7 +123,7 @@ class VirtualAxis(TransformSpace):
         self.num_output = 1
 
         if name is None:
-            name = 'axis_%d' % VirtualAxis.name_ct
+            name = "axis_%d" % VirtualAxis.name_ct
             VirtualAxis.name_ct += 1
 
         self.name = name
@@ -160,13 +162,18 @@ def get_factors(n):
         List of all factors
     """
     step = 2 if n % 2 else 1
-    ret = list(set(
-        functools.reduce(
-            list.__add__, ([i, n//i] for i in range(1, int(math.sqrt(n)) + 1, step)
-                           if n % i == 0))))
+    ret = list(
+        set(
+            functools.reduce(
+                list.__add__,
+                ([i, n // i] for i in range(1, int(math.sqrt(n)) + 1, step) if n % i == 0),
+            )
+        )
+    )
     ret.sort()
     return ret
 
+
 def get_pow2s(n):
     """return all power-of-two numbers that are less or equal than the integer
 
@@ -180,10 +187,12 @@ def get_pow2s(n):
     factors: list
         List of all power-of-two numbers
     """
-    return [2**x for x in range(math.floor(math.log2(n)) + 1)]
+    return [2 ** x for x in range(math.floor(math.log2(n)) + 1)]
+
 
 class SplitSpace(TransformSpace):
     """Split an axis for several times"""
+
     def __init__(self, axes, policy, **kwargs):
         super(SplitSpace, self).__init__()
         axis = axes[0]
@@ -197,27 +206,27 @@ class SplitSpace(TransformSpace):
         self.num_output = kwargs.get("num_outputs", 0)
         assert self.num_output > 0
 
-        if policy == 'candidate':
+        if policy == "candidate":
             for size in kwargs["candidate"]:
                 assert len(size) == self.num_output
                 self.entities.append(SplitEntity(size))
         else:
-            if policy == 'verbose':
+            if policy == "verbose":
                 # Include factors and power-of-twos. May generate tails.
                 divisibles = get_factors(self.product)
                 pow2s = get_pow2s(self.product)
                 factors = [x for x in list(set(divisibles) | set(pow2s)) if x <= max_factor]
-            elif policy == 'factors':
+            elif policy == "factors":
                 # Include divisible factors. Guarantee no tails.
                 factors = [x for x in get_factors(self.product) if x <= max_factor]
-            elif policy == 'power2':
+            elif policy == "power2":
                 # Include less, equal, and round-up power-of-two numbers. May generate tails.
                 factors = [x for x in get_pow2s(self.product) if x <= max_factor]
             else:
                 raise RuntimeError("Invalid policy: %s" % policy)
 
             # Enforce the product of all split factors equals to the axis length
-            no_tail = kwargs.get("no_tail", policy == 'factors')
+            no_tail = kwargs.get("no_tail", policy == "factors")
 
             # Generate split entity by enumerating candidate factors.
             self.factors = factors
@@ -243,8 +252,12 @@ class SplitSpace(TransformSpace):
         return kwargs["num_outputs"]
 
     def __repr__(self):
-        return ("Split(policy=%s, product=%d, num_outputs=%d) len=%d" %
-                (self.policy, self.product, self.num_output, len(self)))
+        return "Split(policy=%s, product=%d, num_outputs=%d) len=%d" % (
+            self.policy,
+            self.product,
+            self.num_output,
+            len(self),
+        )
 
 
 class SplitEntity(object):
@@ -259,6 +272,7 @@ class SplitEntity(object):
         e.g. an axis of extent 128, we split it into 3 axes, a possible
         size is [4, 4, 8] (4x4x8 = 128).
     """
+
     def __init__(self, size):
         self.size = size
 
@@ -292,29 +306,29 @@ class SplitEntity(object):
 
 class ReorderSpace(TransformSpace):
     """The parameter space for ordering an array of axes"""
+
     def __init__(self, axes, policy, **kwargs):
         super(ReorderSpace, self).__init__()
         self.ins = axes
         self.policy = policy
         self.num_output = len(axes)
 
-        if policy == 'identity':
+        if policy == "identity":
             self.entities = [ReorderEntity(range(len(axes)))]
-        elif policy == 'all':
-            self.entities = [
-                ReorderEntity(x) for x in itertools.permutations(range(len(axes)))]
-        elif policy == 'interval_all':
-            begin, end = kwargs['interval']
+        elif policy == "all":
+            self.entities = [ReorderEntity(x) for x in itertools.permutations(range(len(axes)))]
+        elif policy == "interval_all":
+            begin, end = kwargs["interval"]
             sub_space = list(itertools.permutations(range(begin, end)))
             prefix, suffix = tuple(range(begin)), tuple(range(end, len(axes)))
             self.entities = [ReorderEntity(prefix + x + suffix) for x in sub_space]
-        elif policy == 'candidate':
+        elif policy == "candidate":
             candidate = kwargs["candidate"]
             for can in candidate:
                 perm = [axes.index(x) for x in can]
                 self.entities.append(ReorderEntity(perm))
-        elif policy == 'interleave':
-            spatial, reduce = kwargs['spatial'], kwargs['reduce']
+        elif policy == "interleave":
+            spatial, reduce = kwargs["spatial"], kwargs["reduce"]
 
             spatial = [[axes.index(x) for x in ch] for ch in spatial]
             reduce = [[axes.index(x) for x in ch] for ch in reduce]
@@ -325,8 +339,8 @@ class ReorderSpace(TransformSpace):
             for o in outer_merged:
                 for i in inner_merged:
                     self.entities.append(ReorderEntity(o + i))
-        elif policy == 'interleave_cuda':
-            spatial, reduce = kwargs['spatial'], kwargs['reduce']
+        elif policy == "interleave_cuda":
+            spatial, reduce = kwargs["spatial"], kwargs["reduce"]
 
             spatial = [[axes.index(x) for x in ch] for ch in spatial]
             reduce = [[axes.index(x) for x in ch] for ch in reduce]
@@ -366,8 +380,9 @@ class ReorderSpace(TransformSpace):
         for i in range(len(chains)):
             # use i == np.argmax(....) here to take spatial order into consideration
             # if we don't want to consider spatial order, we can use tmp_pt[i] == np.max(....)
-            if (tmp_pt[i] < len(chains[i]) and
-                    (i == np.argmax([len(chains[x]) - tmp_pt[x] for x in range(len(chains))]))):
+            if tmp_pt[i] < len(chains[i]) and (
+                i == np.argmax([len(chains[x]) - tmp_pt[x] for x in range(len(chains))])
+            ):
                 tmp_stack.append(chains[i][tmp_pt[i]])
                 tmp_pt[i] += 1
                 self._merge_dfs(chains, size, tmp_pt, tmp_stack, merged)
@@ -383,6 +398,7 @@ class ReorderEntity(object):
     perm: Array of int
         define the permutation
     """
+
     def __init__(self, perm):
         self.perm = perm
 
@@ -416,6 +432,7 @@ class ReorderEntity(object):
 
 class AnnotateSpace(TransformSpace):
     """The parameter space for annotating an array of axes"""
+
     def __init__(self, axes, policy, **kwargs):
         super(AnnotateSpace, self).__init__()
 
@@ -423,54 +440,86 @@ class AnnotateSpace(TransformSpace):
         self.policy = policy
         self.num_output = len(axes)
 
-        if policy == 'bind_gpu':
+        if policy == "bind_gpu":
             self.num_axis = len(axes)
             if self.num_axis >= 6:
-                self.entities.append(AnnotateEntity(
-                    ['fuse'] * (self.num_axis - 6) +
-                    ['blockIdx.z', 'blockIdx.y', 'blockIdx.x',
-                     'threadIdx.z', 'threadIdx.y', 'threadIdx.x']))
+                self.entities.append(
+                    AnnotateEntity(
+                        ["fuse"] * (self.num_axis - 6)
+                        + [
+                            "blockIdx.z",
+                            "blockIdx.y",
+                            "blockIdx.x",
+                            "threadIdx.z",
+                            "threadIdx.y",
+                            "threadIdx.x",
+                        ]
+                    )
+                )
             elif self.num_axis >= 4:
-                self.entities.append(AnnotateEntity(
-                    ['fuse'] * (self.num_axis - 4) +
-                    ['blockIdx.y', 'blockIdx.x',
-                     'threadIdx.y', 'threadIdx.x']))
+                self.entities.append(
+                    AnnotateEntity(
+                        ["fuse"] * (self.num_axis - 4)
+                        + ["blockIdx.y", "blockIdx.x", "threadIdx.y", "threadIdx.x"]
+                    )
+                )
             elif self.num_axis >= 2:
-                self.entities.append(AnnotateEntity(
-                    ['fuse'] * (self.num_axis - 2) +
-                    ['blockIdx.x', 'threadIdx.x']))
+                self.entities.append(
+                    AnnotateEntity(["fuse"] * (self.num_axis - 2) + ["blockIdx.x", "threadIdx.x"])
+                )
             else:
                 raise RuntimeError("Unhandled case in bind_gpu")
-        elif policy == 'bind_gpu_virtual':
+        elif policy == "bind_gpu_virtual":
             self.num_axis = len(axes)
             if self.num_axis >= 9:
-                self.entities.append(AnnotateEntity(
-                    ['fuse'] * (self.num_axis - 9) +
-                    ['blockIdx.z', 'blockIdx.y', 'blockIdx.x',
-                     'vthread', 'vthread', 'vthread',
-                     'threadIdx.z', 'threadIdx.y', 'threadIdx.x']))
+                self.entities.append(
+                    AnnotateEntity(
+                        ["fuse"] * (self.num_axis - 9)
+                        + [
+                            "blockIdx.z",
+                            "blockIdx.y",
+                            "blockIdx.x",
+                            "vthread",
+                            "vthread",
+                            "vthread",
+                            "threadIdx.z",
+                            "threadIdx.y",
+                            "threadIdx.x",
+                        ]
+                    )
+                )
             elif self.num_axis >= 6:
-                self.entities.append(AnnotateEntity(
-                    ['fuse'] * (self.num_axis - 6) +
-                    ['blockIdx.y', 'blockIdx.x',
-                     'vthread', 'vthread',
-                     'threadIdx.y', 'threadIdx.x']))
+                self.entities.append(
+                    AnnotateEntity(
+                        ["fuse"] * (self.num_axis - 6)
+                        + [
+                            "blockIdx.y",
+                            "blockIdx.x",
+                            "vthread",
+                            "vthread",
+                            "threadIdx.y",
+                            "threadIdx.x",
+                        ]
+                    )
+                )
             elif self.num_axis >= 3:
-                self.entities.append(AnnotateEntity(
-                    ['fuse'] * (self.num_axis - 3) +
-                    ['blockIdx.x', 'vthread', 'threadIdx.x']))
+                self.entities.append(
+                    AnnotateEntity(
+                        ["fuse"] * (self.num_axis - 3) + ["blockIdx.x", "vthread", "threadIdx.x"]
+                    )
+                )
             else:
                 raise RuntimeError("Unhandled case in bind_gpu")
-        elif policy == 'locate_cache':
+        elif policy == "locate_cache":
             self.num_axis = len(axes)
             num_anchor = kwargs["num_anchor"]
             self.anns = list(itertools.combinations(range(self.num_axis), num_anchor))
             self.entities = [AnnotateEntity(x) for x in self.anns]
         else:  # none, vec, unroll, try_vec, try_unroll, try_vec_unroll, ...
-            anns = policy.replace('try', 'none').split('_')
+            anns = policy.replace("try", "none").split("_")
 
             for ann in anns:
-                if ann not in ['none', 'unroll', 'vec']:
+                if ann not in ["none", "unroll", "vec"]:
                     raise RuntimeError("Invalid policy: " + policy)
 
             self.num_axis = len(axes)
@@ -481,7 +530,7 @@ class AnnotateSpace(TransformSpace):
         """Generate space by DFS"""
         if now == self.num_axis:
             # only vectorize inner most dimension
-            vec_ct = tmp_stack.count('vec')
+            vec_ct = tmp_stack.count("vec")
             if vec_ct in (0, 1):
                 self.entities.append(AnnotateEntity(list(tmp_stack)))
         else:
@@ -505,11 +554,13 @@ class AnnotateEntity(object):
     anns: Array of string
         The annotations of axes
     """
+
     def __init__(self, anns):
         self.anns = anns
 
-    def apply(self, sch, op, axes, axis_lens=None,
-              max_unroll=None, vec_size=None, cfg=None, source=None):
+    def apply(
+        self, sch, op, axes, axis_lens=None, max_unroll=None, vec_size=None, cfg=None, source=None
+    ):
         """Apply annotation to an array of axes
 
         Parameters
@@ -542,33 +593,33 @@ class AnnotateEntity(object):
                     sch[t].compute_at(sch[op], axes[to])
         else:  # other cases
             for i, ann in enumerate(self.anns):
-                if ann == 'none':
+                if ann == "none":
                     pass
-                elif ann == 'unroll':
+                elif ann == "unroll":
                     if max_unroll and axis_lens[i] > max_unroll:
                         cfg.raise_error("Too large factor for unrolling")
                     sch[op].unroll(axes[i])
-                elif ann == 'vec':
+                elif ann == "vec":
                     if vec_size and axis_lens[i] not in vec_size:
                         cfg.raise_error("Wrong size of lanes in vectorization")
                     sch[op].vectorize(axes[i])
-                elif ann == 'blockIdx.x':
-                    sch[op].bind(axes[i], thread_axis('blockIdx.x'))
-                elif ann == 'blockIdx.y':
-                    sch[op].bind(axes[i], thread_axis('blockIdx.y'))
-                elif ann == 'blockIdx.z':
-                    sch[op].bind(axes[i], thread_axis('blockIdx.z'))
-                elif ann == 'threadIdx.x':
-                    sch[op].bind(axes[i], thread_axis('threadIdx.x'))
-                elif ann == 'threadIdx.y':
-                    sch[op].bind(axes[i], thread_axis('threadIdx.y'))
-                elif ann == 'threadIdx.z':
-                    sch[op].bind(axes[i], thread_axis('threadIdx.z'))
-                elif ann == 'vthread':
+                elif ann == "blockIdx.x":
+                    sch[op].bind(axes[i], thread_axis("blockIdx.x"))
+                elif ann == "blockIdx.y":
+                    sch[op].bind(axes[i], thread_axis("blockIdx.y"))
+                elif ann == "blockIdx.z":
+                    sch[op].bind(axes[i], thread_axis("blockIdx.z"))
+                elif ann == "threadIdx.x":
+                    sch[op].bind(axes[i], thread_axis("threadIdx.x"))
+                elif ann == "threadIdx.y":
+                    sch[op].bind(axes[i], thread_axis("threadIdx.y"))
+                elif ann == "threadIdx.z":
+                    sch[op].bind(axes[i], thread_axis("threadIdx.z"))
+                elif ann == "vthread":
                     sch[op].bind(axes[i], thread_axis("vthread"))
-                elif ann == 'fuse':
+                elif ann == "fuse":
                     assert i < len(axes) - 1
-                    axes[i+1] = sch[op].fuse(axes[i], axes[i+1])
+                    axes[i + 1] = sch[op].fuse(axes[i], axes[i + 1])
                 else:
                     raise RuntimeError("Invalid annotation " + ann)
         return axes
@@ -579,6 +630,7 @@ class AnnotateEntity(object):
 
 class OtherOptionSpace(TransformSpace):
     """The parameter space for general option"""
+
     def __init__(self, axes, policy, **kwargs):
         super(OtherOptionSpace, self).__init__()
 
@@ -595,6 +647,7 @@ class OtherOptionSpace(TransformSpace):
 
 class OtherOptionEntity(object):
     """The parameter entity for general option, with a detailed value"""
+
     def __init__(self, val):
         self.val = val
 
@@ -604,11 +657,12 @@ class OtherOptionEntity(object):
 
 class ConfigSpace(object):
     """The configuration space of a schedule. Pass it as config in template to
-       collect transformation space and build transform graph of axes
+    collect transformation space and build transform graph of axes
     """
+
     def __init__(self):
         # private dict to provide sugar
-        self.space_map = OrderedDict()    # name -> space
+        self.space_map = OrderedDict()  # name -> space
         self._collect = True
         self._length = None
         self._entity_map = OrderedDict()  # name -> entity
@@ -634,7 +688,7 @@ class ConfigSpace(object):
 
     reduce_axis = axis
 
-    def define_split(self, name, axis, policy='factors', **kwargs):
+    def define_split(self, name, axis, policy="factors", **kwargs):
         """Define a new tunable knob which splits an axis into a list of axes
 
         Parameters
@@ -824,12 +878,20 @@ class ConfigSpace(object):
 
 
 _ann_to_number = {
-    'none': 0, 'vec': 1, 'unroll': 2,
-    'blockIdx.x': 3, 'blockIdx.y': 4, 'blockIdx.z': 5,
-    'threadIdx.x': 6, 'threadIdx.y': 7, 'threadIdx.z': 8,
-    'vthread': 9, 'fuse': 10
+    "none": 0,
+    "vec": 1,
+    "unroll": 2,
+    "blockIdx.x": 3,
+    "blockIdx.y": 4,
+    "blockIdx.z": 5,
+    "threadIdx.x": 6,
+    "threadIdx.y": 7,
+    "threadIdx.z": 8,
+    "vthread": 9,
+    "fuse": 10,
 }
 
+
 class ConfigEntity(ConfigSpace):
     """A configuration with detailed parameters
 
@@ -844,6 +906,7 @@ class ConfigEntity(ConfigSpace):
     constraints : list
         List of constraints
     """
+
     def __init__(self, index, code_hash, entity_map, constraints):
         super(ConfigEntity, self).__init__()
         self.index = index
@@ -854,7 +917,7 @@ class ConfigEntity(ConfigSpace):
         self.code_hash = code_hash
 
     def get_flatten_feature(self):
-        """ flatten entities to a numerical one-dimensional feature vector
+        """flatten entities to a numerical one-dimensional feature vector
 
         Returns
         -------
@@ -896,21 +959,21 @@ class ConfigEntity(ConfigSpace):
             a json serializable dictionary
         """
         ret = {}
-        ret['index'] = int(self.index)
-        ret['code_hash'] = self.code_hash
+        ret["index"] = int(self.index)
+        ret["code_hash"] = self.code_hash
         entity_map = []
         for k, v in self._entity_map.items():
             if isinstance(v, SplitEntity):
-                entity_map.append((k, 'sp', v.size))
+                entity_map.append((k, "sp", v.size))
             elif isinstance(v, ReorderEntity):
-                entity_map.append((k, 're', v.perm))
+                entity_map.append((k, "re", v.perm))
             elif isinstance(v, AnnotateEntity):
-                entity_map.append((k, 'an', v.anns))
+                entity_map.append((k, "an", v.anns))
             elif isinstance(v, OtherOptionEntity):
-                entity_map.append((k, 'ot', v.val))
+                entity_map.append((k, "ot", v.val))
             else:
                 raise RuntimeError("Invalid entity instance: " + v)
-        ret['entity'] = entity_map
+        ret["entity"] = entity_map
         return ret
 
     @staticmethod
@@ -936,13 +999,13 @@ class ConfigEntity(ConfigSpace):
 
         for item in json_dict["entity"]:
             key, knob_type, knob_args = item
-            if knob_type == 'sp':
+            if knob_type == "sp":
                 entity = SplitEntity(knob_args)
-            elif knob_type == 're':
+            elif knob_type == "re":
                 entity = ReorderEntity(knob_args)
-            elif knob_type == 'an':
+            elif knob_type == "an":
                 entity = AnnotateEntity(knob_args)
-            elif knob_type == 'ot':
+            elif knob_type == "ot":
                 entity = OtherOptionEntity(knob_args)
             else:
                 raise RuntimeError("Invalid config knob type: " + knob_type)
@@ -1018,8 +1081,7 @@ class FallbackConfigEntity(ConfigSpace):
         ref_log: List of (MeasureInput, MeasureResult)
             The reference log
         """
-        knob_names = [x for x in self.space_map.keys() if
-                      isinstance(self.space_map[x], SplitSpace)]
+        knob_names = [x for x in self.space_map.keys() if isinstance(self.space_map[x], SplitSpace)]
 
         # find best match config in reference data by matching tiling factors
         factor_list = []
@@ -1032,8 +1094,9 @@ class FallbackConfigEntity(ConfigSpace):
             match_score = 0
             for i, knob_name in enumerate(knob_names):
                 factors = get_factors(int(np.prod(inp.config[knob_name].size)))
-                match_score += (float(len(set(factor_list[i]).intersection(factors))) /
-                                len(factor_list[i]))
+                match_score += float(len(set(factor_list[i]).intersection(factors))) / len(
+                    factor_list[i]
+                )
 
                 if match_score > best_match_score:
                     best_match_score, best_match_cfg = match_score, inp.config
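
A minimal sketch (not part of this commit) that ties the spaces above to a template; the task name `example/vector_scale`, the shapes, and the knob names are made up.

    from tvm import te, autotvm

    @autotvm.template("example/vector_scale")
    def vector_scale(n):
        A = te.placeholder((n,), name="A")
        B = te.compute((n,), lambda i: A[i] * 2.0, name="B")
        s = te.create_schedule(B.op)

        cfg = autotvm.get_config()
        (x,) = s[B].op.axis
        cfg.define_split("tile_x", x, num_outputs=2)  # SplitSpace, policy="factors"
        cfg.define_knob("unroll", [0, 1])             # OtherOptionSpace
        xo, xi = cfg["tile_x"].apply(s, B, x)
        if cfg["unroll"].val:
            s[B].unroll(xi)
        return s, [A, B]

    task = autotvm.task.create("example/vector_scale", args=(1024,), target="llvm")
    print(task.config_space)  # lists the two knobs declared above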
index c4b1d34..4231532 100644
@@ -36,9 +36,11 @@ from .space import ConfigSpace
 
 
 def _raise_error(*args, **kwargs):  # pylint: disable=unused-argument
-    raise RuntimeError("The function of this task is not found. Possibly the function "
-                       "of this task is registered in another python file "
-                       "which is not imported in this run")
+    raise RuntimeError(
+        "The function of this task is not found. Possibly the function "
+        "of this task is registered in another python file "
+        "which is not imported in this run"
+    )
 
 
 def serialize_args(args):
@@ -48,9 +50,10 @@ def serialize_args(args):
     ----------
     args: list of hashable or Tensor
     """
+
     def _encode(x):
         if isinstance(x, tensor.Tensor):
-            return ('TENSOR', get_const_tuple(x.shape), x.dtype)
+            return ("TENSOR", get_const_tuple(x.shape), x.dtype)
         if isinstance(x, (tuple, list, container.Array)):
             return tuple([_encode(a) for a in x])
         if isinstance(x, (str, int, float, np.int, np.float, expr.Var)):
@@ -61,8 +64,11 @@ def serialize_args(args):
             return str(x)
         if x is None:
             return None
-        raise RuntimeError('Do not support type "%s" in argument. Consider to use'
-                           'primitive types or tvm.tir.Var only' % type(x))
+        raise RuntimeError(
+            'Do not support type "%s" in argument. Consider to use'
+            "primitive types or tvm.tir.Var only" % type(x)
+        )
+
     ret = []
     for t in args:
         ret.append(_encode(t))
@@ -78,7 +84,7 @@ def deserialize_args(args):
     """
     ret = []
     for t in args:
-        if isinstance(t, tuple) and t[0] == 'TENSOR':
+        if isinstance(t, tuple) and t[0] == "TENSOR":
             ret.append(placeholder(shape=t[1], dtype=t[2]))
         else:
             ret.append(t)
@@ -171,7 +177,7 @@ class Task(object):
             "config_space": self.config_space,
             "flop": self.flop,
             "target": self.target,
-            "target_host": self.target_host
+            "target_host": self.target_host,
         }
 
     def __setstate__(self, state):
@@ -186,7 +192,10 @@ class Task(object):
 
     def __repr__(self):
         return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % (
-            self.name, self.args, self.kwargs, self.workload
+            self.name,
+            self.args,
+            self.kwargs,
+            self.workload,
         )
 
 
@@ -235,8 +244,7 @@ class TaskTemplate(object):
             if isinstance(t.op, tensor.PlaceholderOp):
                 inputs.append(t)
             else:
-                input_tensors = [
-                    t for t in t.op.input_tensors if t not in hash_set]
+                input_tensors = [t for t in t.op.input_tensors if t not in hash_set]
                 queue.extend(input_tensors)
                 hash_set.update(input_tensors)
         return inputs
@@ -259,15 +267,16 @@ def _register_task_compute(name, func=None):
     decorator: callable
         A decorator
     """
+
     def _do_reg(f):
         if name not in TASK_TABLE:
             TASK_TABLE[name] = TaskTemplate()
         tmpl = TASK_TABLE[name]
         if tmpl.fcompute is not None:
-            raise ValueError(
-                "Compute is already registered in autoTVM task %s" % name)
+            raise ValueError("Compute is already registered in autoTVM task %s" % name)
         tmpl.fcompute = f
         return f
+
     if func:
         return _do_reg(func)
     return _do_reg
@@ -290,15 +299,16 @@ def _register_task_schedule(name, func=None):
     decorator: callable
         A decorator
     """
+
     def _do_reg(f):
         if name not in TASK_TABLE:
             TASK_TABLE[name] = TaskTemplate()
         tmpl = TASK_TABLE[name]
         if tmpl.fschedule is not None:
-            raise ValueError(
-                "Schedule is already registered in autoTVM task %s" % name)
+            raise ValueError("Schedule is already registered in autoTVM task %s" % name)
         tmpl.fschedule = f
         return f
+
     if func:
         return _do_reg(func)
     return _do_reg
@@ -321,15 +331,16 @@ def _register_customized_task(name, func=None):
     decorator: callable
         A decorator
     """
+
     def _do_reg(f):
         if name not in TASK_TABLE:
             TASK_TABLE[name] = TaskTemplate()
         tmpl = TASK_TABLE[name]
         if tmpl.fcustomized is not None:
-            raise ValueError(
-                "Customized func is already registered in autoTVM task %s" % name)
+            raise ValueError("Customized func is already registered in autoTVM task %s" % name)
         tmpl.fcustomized = f
         return f
+
     if func:
         return _do_reg(func)
     return _do_reg
@@ -386,6 +397,7 @@ def template(task_name, func=None):
 
             return s, [A, B, C]
     """
+
     def _decorate(f):
         def wrapper(*args, **kwargs):
             assert not kwargs, "Do not support kwargs in template function call"
@@ -435,7 +447,7 @@ def create(task_name, args, target, target_host=None):
     with ctx:
         with target:
             sch, _ = ret.func(*args)
-            ret.config_space.code_hash = getattr(sch, 'code_hash', None)
+            ret.config_space.code_hash = getattr(sch, "code_hash", None)
 
     ret.flop = ret.config_space.flop or compute_flop(sch)
     ret.target = target
@@ -473,11 +485,11 @@ def compute_flop(sch):
     flop: int
         number of FLOP in this schedule
     """
+
     def _prod_length(axes):
         """compute product of the lengths of a list of axes"""
         try:
-            num_iter = int(
-                np.prod([get_const_int(axis.dom.extent) for axis in axes]))
+            num_iter = int(np.prod([get_const_int(axis.dom.extent) for axis in axes]))
         except ValueError:
             raise FlopCalculationError("The length of axis is not constant. ")
         return num_iter
@@ -489,11 +501,9 @@ def compute_flop(sch):
             combiner = exp.combiner.result
             source = exp.source
             if len(combiner) != 1:
-                raise FlopCalculationError(
-                    "Found multiple output in the combiner of reduce op")
+                raise FlopCalculationError("Found multiple output in the combiner of reduce op")
             if len(source) != 1:
-                raise FlopCalculationError(
-                    "Found multiple output in the source of reduce op")
+                raise FlopCalculationError("Found multiple output in the source of reduce op")
             return num_iter * (_count_flop(combiner[0]) + _count_flop(source[0]))
         if isinstance(exp, (expr.FloatImm, expr.IntImm)):
             return 0
@@ -501,12 +511,29 @@ def compute_flop(sch):
             return _count_flop(exp.value)
         if isinstance(exp, expr.Var):
             return 0
-        if isinstance(exp, (expr.Add, expr.Sub, expr.Mul,
-                            expr.Div, expr.Mod,
-                            expr.FloorDiv, expr.FloorMod,
-                            expr.Max, expr.Min,
-                            expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE,
-                            expr.And, expr.Or, expr.Not)):
+        if isinstance(
+            exp,
+            (
+                expr.Add,
+                expr.Sub,
+                expr.Mul,
+                expr.Div,
+                expr.Mod,
+                expr.FloorDiv,
+                expr.FloorMod,
+                expr.Max,
+                expr.Min,
+                expr.EQ,
+                expr.NE,
+                expr.LT,
+                expr.LE,
+                expr.GT,
+                expr.GE,
+                expr.And,
+                expr.Or,
+                expr.Not,
+            ),
+        ):
             base = 1
 
             if isinstance(exp, expr.Not):  # unary
@@ -514,8 +541,9 @@ def compute_flop(sch):
 
             return base + _count_flop(exp.a) + _count_flop(exp.b)
         if isinstance(exp, expr.Select):
-            return _count_flop(exp.condition) + max(_count_flop(exp.true_value),
-                                                    _count_flop(exp.false_value))
+            return _count_flop(exp.condition) + max(
+                _count_flop(exp.true_value), _count_flop(exp.false_value)
+            )
         if isinstance(exp, expr.ProducerLoad):
             # Ignore flops from indexing expressions.
             return 0
@@ -523,8 +551,7 @@ def compute_flop(sch):
         if isinstance(exp, expr.Call):
             return sum([_count_flop(x) for x in exp.args])
 
-        raise FlopCalculationError(
-            "Found unsupported operator in the compute expr")
+        raise FlopCalculationError("Found unsupported operator in the compute expr")
 
     def traverse(ops):
         """accumulate flops"""
@@ -535,8 +562,7 @@ def compute_flop(sch):
 
                 body = op.body
                 if len(body) != 1:
-                    raise FlopCalculationError(
-                        "Found multiple output in the compute")
+                    raise FlopCalculationError("Found multiple output in the compute")
                 exp = body[0]
 
                 ret += num_element * _count_flop(exp)
@@ -545,20 +571,26 @@ def compute_flop(sch):
             elif isinstance(op, tensor.PlaceholderOp):
                 pass
             else:
-                raise FlopCalculationError("Only support te.compute currently. "
-                                           "Other ops like tvm.te.scan/te.extern is not supported")
+                raise FlopCalculationError(
+                    "Only support te.compute currently. "
+                    "Other ops like tvm.te.scan/te.extern is not supported"
+                )
         return ret
 
     try:
         ret = traverse(sch.outputs)
     except FlopCalculationError as exc:
-        raise RuntimeError("FLOP estimator fails for this operator. Error msg: "
-                           + str(exc) +
-                           ". Please use `cfg.add_flop` to manually set "
-                           "FLOP for this operator")
+        raise RuntimeError(
+            "FLOP estimator fails for this operator. Error msg: "
+            + str(exc)
+            + ". Please use `cfg.add_flop` to manually set "
+            "FLOP for this operator"
+        )
 
     if ret == 0:
-        raise RuntimeError("Cannot find float number operation in this operator. "
-                           "Please use `cfg.add_flop` to manually set "
-                           "FLOP for this operator")
+        raise RuntimeError(
+            "Cannot find float number operation in this operator. "
+            "Please use `cfg.add_flop` to manually set "
+            "FLOP for this operator"
+        )
     return ret
index cc0170f..f6ca3b1 100644
@@ -30,13 +30,19 @@ import tvm.te._ffi_api
 from tvm.target import Target
 from tvm.te import tensor
 
-from .task import args_to_workload, serialize_args, DispatchContext, \
-    _register_task_compute, _register_task_schedule
+from .task import (
+    args_to_workload,
+    serialize_args,
+    DispatchContext,
+    _register_task_compute,
+    _register_task_schedule,
+)
 
 
 # Task extractor for relay program
 class TaskExtractEnv:
     """Global environment for extracting tuning tasks from graph"""
+
     current = None
     registered = None
 
@@ -141,6 +147,7 @@ def register_topi_compute(task_name, func=None):
     --------
     See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
     """
+
     def _decorate(topi_compute):
         @_register_task_compute(task_name)
         def wrapper(*args, **kwargs):
@@ -159,15 +166,19 @@ def register_topi_compute(task_name, func=None):
             attrs = {}
             for k, v in node.op.attrs.items():
                 attrs[k] = v
-            attrs['workload'] = workload
+            attrs["workload"] = workload
             if isinstance(op, tensor.ComputeOp):
-                op = tvm.te._ffi_api.ComputeOp(
-                    op.name, op.tag, attrs, op.axis, op.body)
+                op = tvm.te._ffi_api.ComputeOp(op.name, op.tag, attrs, op.axis, op.body)
             elif isinstance(op, tensor.ExternOp):
                 op = tvm.te._ffi_api.ExternOp(
-                    op.name, op.tag, attrs,
-                    op.inputs, op.input_placeholders,
-                    op.output_placeholders, op.body)
+                    op.name,
+                    op.tag,
+                    attrs,
+                    op.inputs,
+                    op.input_placeholders,
+                    op.output_placeholders,
+                    op.body,
+                )
             else:
                 raise RuntimeError("Unsupported op type: " + str(type(op)))
 
@@ -211,18 +222,20 @@ def register_topi_schedule(task_name, func=None):
     --------
     See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
     """
+
     def _decorate(topi_schedule):
         @_register_task_schedule(task_name)
         def wrapper(outs, *args, **kwargs):
             """wrapper function for topi schedule"""
             workload = get_workload(outs)
             if workload is None:
-                raise RuntimeError(
-                    "Cannot find workload in attribute of this schedule")
+                raise RuntimeError("Cannot find workload in attribute of this schedule")
             tgt = Target.current()
             cfg = DispatchContext.current.query(tgt, workload)
             return topi_schedule(cfg, outs, *args, **kwargs)
+
         return wrapper
+
     if func:
         return _decorate(func)
     return _decorate
@@ -230,15 +243,17 @@ def register_topi_schedule(task_name, func=None):
 
 def get_workload(outs):
     """Retrieve the workload from outputs"""
+
     def traverse(tensors):
         """traverse all ops to find attached workload"""
         for t in tensors:
             op = t.op
-            if 'workload' in op.attrs:
-                return args_to_workload(op.attrs['workload'])
+            if "workload" in op.attrs:
+                return args_to_workload(op.attrs["workload"])
             wkl = traverse(op.input_tensors)
             if wkl:
                 return wkl
         return None
+
     outs = [outs] if isinstance(outs, tensor.Tensor) else outs
     return traverse(outs)
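
For orientation (not part of this commit), the registration pair above is used roughly as follows; the task name `scale.example` and the trivial compute/schedule bodies are placeholders.

    from tvm import te
    from tvm.te import tensor
    from tvm.autotvm.task import register_topi_compute, register_topi_schedule

    @register_topi_compute("scale.example")
    def scale_compute(cfg, data, factor):
        # cfg is injected by the wrapper above; the workload gets attached to
        # the resulting op's attrs for later lookup by get_workload()
        return te.compute(data.shape, lambda *i: data(*i) * factor, name="scale")

    @register_topi_schedule("scale.example")
    def scale_schedule(cfg, outs):
        outs = [outs] if isinstance(outs, tensor.Tensor) else outs
        return te.create_schedule([o.op for o in outs])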
index e637e29..2076ee7 100644
@@ -42,36 +42,33 @@ AUTOTVM_TOPHUB_DEFAULT_LOC = "https://raw.githubusercontent.com/uwsampl/tvm-dist
 AUTOTVM_TOPHUB_NONE_LOC = "NONE"
 
 # root path to store TopHub files
-AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(
-    os.path.expanduser('~'), ".tvm", "tophub")
+AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser("~"), ".tvm", "tophub")
 
 # the version of each package
 PACKAGE_VERSION = {
-    'arm_cpu':          "v0.07",
-    'llvm':             "v0.04",
-
-    'cuda':             "v0.09",
-    'rocm':             "v0.05",
-    'opencl':           "v0.04",
-    'mali':             "v0.06",
-    'intel_graphics':   "v0.02",
-    'vta':              "v0.09",
-    'amd_apu':          "v0.01",
+    "arm_cpu": "v0.07",
+    "llvm": "v0.04",
+    "cuda": "v0.09",
+    "rocm": "v0.05",
+    "opencl": "v0.04",
+    "mali": "v0.06",
+    "intel_graphics": "v0.02",
+    "vta": "v0.09",
+    "amd_apu": "v0.01",
 }
 
-logger = logging.getLogger('autotvm')
+logger = logging.getLogger("autotvm")
 
 
 def _alias(name):
     """convert alias for some packages"""
     table = {
-        'vtacpu': 'vta',
-
-        'metal': 'opencl',
-        'webgpu': 'opencl',
-        'vulkan': 'opencl',
-        'nvptx': 'cuda',
-        'amd_apu': 'amd_apu'
+        "vtacpu": "vta",
+        "metal": "opencl",
+        "webgpu": "opencl",
+        "vulkan": "opencl",
+        "nvptx": "cuda",
+        "amd_apu": "amd_apu",
     }
     return table.get(name, name)
 
@@ -120,9 +117,8 @@ def context(target, extra_files=None):
                     continue
 
                 filename = "%s_%s.log" % (name, PACKAGE_VERSION[name])
-                best_context.load(os.path.join(
-                    AUTOTVM_TOPHUB_ROOT_PATH, filename))
-                break   # only load one file to avoid some fallback template mismatch problem
+                best_context.load(os.path.join(AUTOTVM_TOPHUB_ROOT_PATH, filename))
+                break  # only load one file to avoid some fallback template mismatch problem
 
     if extra_files:
         for filename in extra_files:
@@ -162,8 +158,7 @@ def check_backend(tophub_location, backend):
         download_package(tophub_location, package_name)
         return True
     except urllib2.URLError as e:
-        logging.warning(
-            "Failed to download tophub package for %s: %s", backend, e)
+        logging.warning("Failed to download tophub package for %s: %s", backend, e)
         return False
 
 
@@ -183,15 +178,14 @@ def download_package(tophub_location, package_name):
     if not os.path.isdir(rootpath):
         # make directory
         splits = os.path.split(rootpath)
-        for j in range(1, len(splits)+1):
+        for j in range(1, len(splits) + 1):
             path = os.path.join(*splits[:j])
             if not os.path.isdir(path):
                 os.mkdir(path)
 
     download_url = "{0}/{1}".format(tophub_location, package_name)
     logger.info("Download pre-tuned parameters package from %s", download_url)
-    download(download_url, os.path.join(
-        rootpath, package_name), True, verbose=0)
+    download(download_url, os.path.join(rootpath, package_name), True, verbose=0)
 
 
 # global cache for load_reference_log
@@ -199,7 +193,7 @@ REFERENCE_LOG_CACHE = {}
 
 
 def load_reference_log(backend, model, workload_name):
-    """ Load reference log from TopHub to support fallback in template.
+    """Load reference log from TopHub to support fallback in template.
     Template will use these reference logs to choose fallback config.
 
     Parameters
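For reference, the context() logic above loads at most one pre-tuned log per target key, and the file it looks for is derived from the package name plus its pinned version. A minimal illustration of that path construction, reusing only the constants defined in this file (the printed path is for a hypothetical user):

import os

AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser("~"), ".tvm", "tophub")
PACKAGE_VERSION = {"arm_cpu": "v0.07", "llvm": "v0.04"}

name = "arm_cpu"
filename = "%s_%s.log" % (name, PACKAGE_VERSION[name])  # -> "arm_cpu_v0.07.log"
print(os.path.join(AUTOTVM_TOPHUB_ROOT_PATH, filename))
# e.g. /home/user/.tvm/tophub/arm_cpu_v0.07.log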
index cfc1b2c..bb9dafa 100644 (file)
@@ -25,10 +25,10 @@ import numpy as np
 from .. import record
 from ..util import format_si_prefix
 
-logger = logging.getLogger('autotvm')
+logger = logging.getLogger("autotvm")
 
 
-def log_to_file(file_out, protocol='json'):
+def log_to_file(file_out, protocol="json"):
     """Log the tuning records into file.
     The rows of the log are stored in the format of autotvm.record.encode.
 
@@ -44,6 +44,7 @@ def log_to_file(file_out, protocol='json'):
     callback : callable
         Callback function to do the logging.
     """
+
     def _callback(_, inputs, results):
         """Callback implementation"""
         if isinstance(file_out, str):
@@ -56,6 +57,7 @@ def log_to_file(file_out, protocol='json'):
 
     # pylint: disable=import-outside-toplevel
     from pathlib import Path
+
     if isinstance(file_out, Path):
         file_out = str(file_out)
 
@@ -70,15 +72,18 @@ def log_to_database(db):
     db: Database
         The database
     """
+
     def _callback(_, inputs, results):
         """Callback implementation"""
         for inp, result in zip(inputs, results):
             db.save(inp, result)
+
     return _callback
 
 
 class Monitor(object):
     """A monitor to collect statistic during tuning"""
+
     def __init__(self):
         self.scores = []
         self.timestamps = []
@@ -106,7 +111,7 @@ class Monitor(object):
         return np.array(self.timestamps)
 
 
-def progress_bar(total, prefix='', si_prefix='G'):
+def progress_bar(total, prefix="", si_prefix="G"):
     """Display progress bar for tuning
 
     Parameters
@@ -118,8 +123,10 @@ def progress_bar(total, prefix='', si_prefix='G'):
     si_prefix: str
         SI prefix for flops
     """
+
     class _Context(object):
         """Context to store local variables"""
+
         def __init__(self):
             self.best_flops = 0
             self.cur_flops = 0
@@ -128,7 +135,7 @@ def progress_bar(total, prefix='', si_prefix='G'):
 
         def __del__(self):
             if logger.level < logging.DEBUG:  # only print progress bar in non-debug mode
-                sys.stdout.write(' Done.\n')
+                sys.stdout.write(" Done.\n")
 
     ctx = _Context()
     tic = time.time()
@@ -137,8 +144,10 @@ def progress_bar(total, prefix='', si_prefix='G'):
     format_si_prefix(0, si_prefix)
 
     if logger.level < logging.DEBUG:  # only print progress bar in non-debug mode
-        sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) '
-                         '| %.2f s' % (prefix, 0, 0, 0, total, time.time() - tic))
+        sys.stdout.write(
+            "\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) "
+            "| %.2f s" % (prefix, 0, 0, 0, total, time.time() - tic)
+        )
         sys.stdout.flush()
 
     def _callback(tuner, inputs, results):
@@ -153,11 +162,19 @@ def progress_bar(total, prefix='', si_prefix='G'):
             ctx.cur_flops = flops
             ctx.best_flops = tuner.best_flops
 
-            sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f %sFLOPS | Progress: (%d/%d) '
-                             '| %.2f s' %
-                             (prefix, format_si_prefix(ctx.cur_flops, si_prefix),
-                              format_si_prefix(ctx.best_flops, si_prefix), si_prefix,
-                              ctx.ct, ctx.total, time.time() - tic))
+            sys.stdout.write(
+                "\r%s Current/Best: %7.2f/%7.2f %sFLOPS | Progress: (%d/%d) "
+                "| %.2f s"
+                % (
+                    prefix,
+                    format_si_prefix(ctx.cur_flops, si_prefix),
+                    format_si_prefix(ctx.best_flops, si_prefix),
+                    si_prefix,
+                    ctx.ct,
+                    ctx.total,
+                    time.time() - tic,
+                )
+            )
             sys.stdout.flush()
 
     return _callback
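For context, these callbacks are consumed by Tuner.tune; a sketch of the usual tutorial-style driver code (illustrative only; `task` and `measure_option` are assumed to be defined elsewhere, as in the AutoTVM tutorials):

from tvm import autotvm

n_trial = 100
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(
    n_trial=n_trial,
    measure_option=measure_option,
    callbacks=[
        autotvm.callback.progress_bar(n_trial),      # the progress bar defined above
        autotvm.callback.log_to_file("tuning.log"),  # append records in autotvm.record format
    ],
)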
index da10f73..165d5d3 100644 (file)
@@ -38,6 +38,7 @@ class GATuner(Tuner):
     mutation_prob: float
         probability of mutation of a knob in a gene
     """
+
     def __init__(self, task, pop_size=100, elite_num=3, mutation_prob=0.1):
         super(GATuner, self).__init__(task)
 
@@ -95,11 +96,11 @@ class GATuner(Tuner):
 
         if len(self.scores) >= len(self.genes) and len(self.visited) < len(self.space):
             genes = self.genes + self.elites
-            scores = np.array(self.scores[:len(self.genes)] + self.elite_scores)
+            scores = np.array(self.scores[: len(self.genes)] + self.elite_scores)
 
             # reserve elite
             self.elites, self.elite_scores = [], []
-            elite_indexes = np.argpartition(scores, -self.elite_num)[-self.elite_num:]
+            elite_indexes = np.argpartition(scores, -self.elite_num)[-self.elite_num :]
             for ind in elite_indexes:
                 self.elites.append(genes[ind])
                 self.elite_scores.append(scores[ind])
@@ -127,7 +128,9 @@ class GATuner(Tuner):
                 if len(self.visited) < len(self.space):
                     while knob2point(tmp_gene, self.dims) in self.visited:
                         j = np.random.randint(len(self.dims))
-                        tmp_gene[j] = np.random.randint(self.dims[j])  # pylint: disable=invalid-sequence-index
+                        tmp_gene[j] = np.random.randint(
+                            self.dims[j]
+                        )  # pylint: disable=invalid-sequence-index
                     next_genes.append(tmp_gene)
                     self.visited.add(knob2point(tmp_gene, self.dims))
                 else:
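The mutation loop above treats each gene as a vector of knob indices and maps it to a flat config index with knob2point (and back with point2knob). A simplified, self-contained sketch of that mixed-radix encoding, for orientation only (the real helpers live in model_based_tuner.py):

def point2knob_sketch(point, dims):
    """Decompose a flat config index into one index per tuning knob."""
    knob = []
    for dim in dims:
        knob.append(point % dim)
        point //= dim
    return knob

def knob2point_sketch(knob, dims):
    """Inverse of point2knob_sketch."""
    point, stride = 0, 1
    for k, dim in zip(knob, dims):
        point += k * stride
        stride *= dim
    return point

dims = [4, 2, 8]  # sizes of three tuning knobs
assert knob2point_sketch(point2knob_sketch(37, dims), dims) == 37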
index 99fc9f2..945bcfd 100644 (file)
@@ -21,6 +21,7 @@ import numpy as np
 
 from .tuner import Tuner
 
+
 class IndexBaseTuner(Tuner):
     """Base class for index based tuner
     This type of tuner determines the next batch of configs based on config indices.
@@ -33,10 +34,12 @@ class IndexBaseTuner(Tuner):
     range_idx: Optional[Tuple[int, int]]
         A tuple of index range that this tuner can select from
     """
+
     def __init__(self, task, range_idx=None):
         super(IndexBaseTuner, self).__init__(task)
-        assert range_idx is None or isinstance(range_idx, tuple), \
-            "range_idx must be None or (int, int)"
+        assert range_idx is None or isinstance(
+            range_idx, tuple
+        ), "range_idx must be None or (int, int)"
 
         self.range_length = len(self.task.config_space)
         self.index_offset = 0
@@ -79,6 +82,7 @@ class RandomTuner(IndexBaseTuner):
     range_idx: Optional[Tuple[int, int]]
         A tuple of index range to random
     """
+
     def __init__(self, task, range_idx=None):
         super(RandomTuner, self).__init__(task, range_idx)
 
index 5497ff4..1ed04ab 100644 (file)
@@ -21,8 +21,9 @@ import numpy as np
 
 from ..util import get_rank
 
+
 def max_curve(trial_scores):
-    """ f(n) = max([s[i] fo i < n])
+    """f(n) = max([s[i] fo i < n])
 
     Parameters
     ----------
@@ -41,8 +42,9 @@ def max_curve(trial_scores):
         ret[i] = keep
     return ret
 
+
 def mean_curve(trial_scores):
-    """ f(n) = mean([s[i] fo i < n])
+    """f(n) = mean([s[i] fo i < n])
 
     Parameters
     ----------
@@ -58,9 +60,10 @@ def mean_curve(trial_scores):
     keep = 0
     for i, score in enumerate(trial_scores):
         keep += score
-        ret[i] = keep / (i+1)
+        ret[i] = keep / (i + 1)
     return ret
 
+
 def recall_curve(trial_ranks, top=None):
     """
     if top is None, f(n) = sum([I(rank[i] < n) for i < n]) / n
@@ -84,12 +87,13 @@ def recall_curve(trial_ranks, top=None):
     ret = np.zeros(len(trial_ranks))
     if top is None:
         for i in range(len(trial_ranks)):
-            ret[i] = np.sum(trial_ranks[:i] <= i) / (i+1)
+            ret[i] = np.sum(trial_ranks[:i] <= i) / (i + 1)
     else:
         for i in range(len(trial_ranks)):
             ret[i] = 1.0 * np.sum(trial_ranks[:i] < top) / top
     return ret
 
+
 def cover_curve(trial_ranks):
     """
     f(n) = max k s.t. {1,2,...,k} is a subset of {ranks[i] for i < n}
@@ -109,11 +113,12 @@ def cover_curve(trial_ranks):
     cover = set()
     for i, rank in enumerate(trial_ranks):
         cover.add(rank)
-        while keep+1 in cover:
+        while keep + 1 in cover:
             keep += 1
         ret[i] = keep + 1
     return ret / len(trial_ranks)
 
+
 def average_recall(preds, labels, N):
     """evaluate average recall-n for predictions and labels"""
     trials = np.argsort(preds)[::-1]
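As a quick numeric check of the recall_curve formula above (the top=None branch), using only numpy:

import numpy as np

# ret[i] = |{j < i : ranks[j] <= i}| / (i + 1), mirroring the loop above
ranks = np.array([2, 0, 1])
ret = np.zeros(len(ranks))
for i in range(len(ranks)):
    ret[i] = np.sum(ranks[:i] <= i) / (i + 1)
print(ret)  # [0.         0.         0.66666667]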
index 432f707..4d16339 100644 (file)
@@ -26,13 +26,15 @@ import numpy as np
 from .tuner import Tuner
 from ..env import GLOBAL_SCOPE
 
+
 class FeatureCache(object):
     """Feature cache manager for cache sharing between different cost models"""
+
     def __init__(self):
         self.feature_cache = {}
 
     def get(self, key):
-        """ Get feature cache dictionary for a key
+        """Get feature cache dictionary for a key
 
         Parameters
         ----------
@@ -50,7 +52,7 @@ class FeatureCache(object):
         return self.feature_cache[key]
 
     def size(self, key):
-        """" Get the size of a feature cache dictionary
+        """ " Get the size of a feature cache dictionary
 
         Parameters
         ----------
@@ -78,6 +80,7 @@ class FeatureCache(object):
 
 class CostModel(object):
     """Cost model to predict the speed of a config"""
+
     def __init__(self):
         pass
 
@@ -149,6 +152,7 @@ class CostModel(object):
 
 class ModelOptimizer(object):
     """Optimizer used to find optimal points of cost model"""
+
     def __init__(self):
         pass
 
@@ -207,8 +211,9 @@ class ModelBasedTuner(Tuner):
         self.diversity_filter_ratio = diversity_filter_ratio
 
         if self.diversity_filter_ratio:
-            assert self.diversity_filter_ratio >= 1, "Diversity filter ratio " \
-                                                     "must be larger than one"
+            assert self.diversity_filter_ratio >= 1, (
+                "Diversity filter ratio " "must be larger than one"
+            )
 
         # trial plan
         self.trials = []
@@ -261,19 +266,20 @@ class ModelBasedTuner(Tuner):
                 self.ys.append(0.0)
 
         # if we have enough new training samples
-        if len(self.xs) >= self.plan_size * (self.train_ct + 1) \
-                and self.flops_max > 1e-6:
+        if len(self.xs) >= self.plan_size * (self.train_ct + 1) and self.flops_max > 1e-6:
             self.cost_model.fit(self.xs, self.ys, self.plan_size)
             if self.diversity_filter_ratio:
                 candidate = self.model_optimizer.find_maximums(
-                    self.cost_model, self.plan_size * self.diversity_filter_ratio, self.visited)
+                    self.cost_model, self.plan_size * self.diversity_filter_ratio, self.visited
+                )
                 scores = self.cost_model.predict(candidate)
                 knobs = [point2knob(x, self.dims) for x in candidate]
                 pick_index = submodular_pick(0 * scores, knobs, self.plan_size, knob_weight=1)
                 maximums = np.array(candidate)[pick_index]
             else:
                 maximums = self.model_optimizer.find_maximums(
-                    self.cost_model, self.plan_size, self.visited)
+                    self.cost_model, self.plan_size, self.visited
+                )
 
             self.trials = maximums
             self.trial_pt = 0
index 5812033..5535246 100644 (file)
@@ -28,7 +28,8 @@ import numpy as np
 from ..util import sample_ints
 from .model_based_tuner import ModelOptimizer, knob2point, point2knob
 
-logger = logging.getLogger('autotvm')
+logger = logging.getLogger("autotvm")
+
 
 class SimulatedAnnealingOptimizer(ModelOptimizer):
     """parallel simulated annealing optimization algorithm
@@ -47,8 +48,17 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
     log_interval: int, optional
         Print log every `log_interval` iterations
     """
-    def __init__(self, task, n_iter=500, temp=(1, 0), persistent=True, parallel_size=128,
-                 early_stop=50, log_interval=50):
+
+    def __init__(
+        self,
+        task,
+        n_iter=500,
+        temp=(1, 0),
+        persistent=True,
+        parallel_size=128,
+        early_stop=50,
+        log_interval=50,
+    ):
         super(SimulatedAnnealingOptimizer, self).__init__()
 
         self.task = task
@@ -64,8 +74,12 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
 
     def find_maximums(self, model, num, exclusive):
         tic = time.time()
-        temp, n_iter, early_stop, log_interval = \
-                self.temp, self.n_iter, self.early_stop, self.log_interval
+        temp, n_iter, early_stop, log_interval = (
+            self.temp,
+            self.n_iter,
+            self.early_stop,
+            self.log_interval,
+        )
 
         if self.persistent and self.points is not None:
             points = self.points
@@ -75,7 +89,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
         scores = model.predict(points)
 
         # build heap and insert initial points
-        heap_items = [(float('-inf'), - 1 - i) for i in range(num)]
+        heap_items = [(float("-inf"), -1 - i) for i in range(num)]
         heapq.heapify(heap_items)
         in_heap = set(exclusive)
         in_heap.update([x[1] for x in heap_items])
@@ -121,16 +135,22 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
 
             if log_interval and k % log_interval == 0:
                 t_str = "%.2f" % t
-                logger.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\ttemp: %s\t"
-                             "elapsed: %.2f",
-                             k, k_last_modify, heap_items[0][0],
-                             np.max([v for v, _ in heap_items]), t_str,
-                             time.time() - tic)
+                logger.debug(
+                    "SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\ttemp: %s\t"
+                    "elapsed: %.2f",
+                    k,
+                    k_last_modify,
+                    heap_items[0][0],
+                    np.max([v for v, _ in heap_items]),
+                    t_str,
+                    time.time() - tic,
+                )
 
         heap_items.sort(key=lambda item: -item[0])
         heap_items = [x for x in heap_items if x[0] >= 0]
-        logger.debug("SA iter: %d\tlast_update: %d\telapsed: %.2f",
-                     k, k_last_modify, time.time() - tic)
+        logger.debug(
+            "SA iter: %d\tlast_update: %d\telapsed: %.2f", k, k_last_modify, time.time() - tic
+        )
         logger.debug("SA Maximums: %s", heap_items)
 
         if self.persistent:
@@ -138,6 +158,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
 
         return [x[1] for x in heap_items]
 
+
 def random_walk(p, dims):
     """random walk as local transition
 
index 2441a4a..cbfe973 100644 (file)
@@ -25,7 +25,8 @@ from ..util import format_si_prefix
 
 from ..env import GLOBAL_SCOPE
 
-logger = logging.getLogger('autotvm')
+logger = logging.getLogger("autotvm")
+
 
 class Tuner(object):
     """Base class for tuners
@@ -87,8 +88,7 @@ class Tuner(object):
             result for measurement
         """
 
-
-    def tune(self, n_trial, measure_option, early_stopping=None, callbacks=(), si_prefix='G'):
+    def tune(self, n_trial, measure_option, early_stopping=None, callbacks=(), si_prefix="G"):
         """Begin tuning
 
         Parameters
@@ -109,7 +109,7 @@ class Tuner(object):
             One of tvm.autotvm.util.SI_PREFIXES. The SI prefix to use when reporting FLOPS.
         """
         measure_batch = create_measure_batch(self.task, measure_option)
-        n_parallel = getattr(measure_batch, 'n_parallel', 1)
+        n_parallel = getattr(measure_batch, "n_parallel", 1)
         early_stopping = early_stopping or 1e9
         self.n_trial = n_trial
         self.early_stopping = early_stopping
@@ -146,9 +146,15 @@ class Tuner(object):
                     self.best_measure_pair = (inp, res)
                     self.best_iter = i + k
 
-                logger.debug("No: %d\t%sFLOPS: %.2f/%.2f\tresult: %s\t%s",
-                             i + k + 1, si_prefix, format_si_prefix(flops, si_prefix),
-                             format_si_prefix(self.best_flops, si_prefix), res, config)
+                logger.debug(
+                    "No: %d\t%sFLOPS: %.2f/%.2f\tresult: %s\t%s",
+                    i + k + 1,
+                    si_prefix,
+                    format_si_prefix(flops, si_prefix),
+                    format_si_prefix(self.best_flops, si_prefix),
+                    res,
+                    config,
+                )
 
             i += len(results)
             self.ttl = min(early_stopping + self.best_iter, n_trial) - i
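For context, a sketch of the driver code that usually feeds this tuning loop (tutorial-style usage; `task` and `tuner` are assumed to exist, e.g. one of the tuners touched in this commit):

from tvm import autotvm

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(number=5, repeat=1),
)
tuner.tune(
    n_trial=200,
    early_stopping=100,           # give up if no improvement for 100 trials
    measure_option=measure_option,
    si_prefix="G",                # report GFLOPS, matching the default above
)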
index 15a3390..7b9df1c 100644 (file)
@@ -22,6 +22,7 @@ import logging
 import time
 
 import numpy as np
+
 try:
     import xgboost as xgb
 except ImportError:
@@ -32,7 +33,8 @@ from ..util import get_rank
 from .metric import max_curve, recall_curve, cover_curve
 from .model_based_tuner import CostModel, FeatureCache
 
-logger = logging.getLogger('autotvm')
+logger = logging.getLogger("autotvm")
+
 
 class XGBoostCostModel(CostModel):
     """XGBoost as cost model
@@ -69,14 +71,18 @@ class XGBoostCostModel(CostModel):
     upper_model: XGBoostCostModel, optional
         The upper model used in transfer learning
     """
-    def __init__(self, task, feature_type, loss_type, num_threads=None, log_interval=25,
-                 upper_model=None):
+
+    def __init__(
+        self, task, feature_type, loss_type, num_threads=None, log_interval=25, upper_model=None
+    ):
         super(XGBoostCostModel, self).__init__()
 
         if xgb is None:
-            raise RuntimeError("XGBoost is required for XGBoostCostModel. "
-                               "Please install its python package first. "
-                               "Help: (https://xgboost.readthedocs.io/en/latest/) ")
+            raise RuntimeError(
+                "XGBoost is required for XGBoostCostModel. "
+                "Please install its python package first. "
+                "Help: (https://xgboost.readthedocs.io/en/latest/) "
+            )
 
         self.task = task
         self.target = task.target
@@ -87,47 +93,41 @@ class XGBoostCostModel(CostModel):
         self.num_threads = num_threads
         self.log_interval = log_interval
 
-        if loss_type == 'reg':
+        if loss_type == "reg":
             self.xgb_params = {
-                'max_depth': 3,
-                'gamma': 0.0001,
-                'min_child_weight': 1,
-
-                'subsample': 1.0,
-
-                'eta': 0.3,
-                'lambda': 1.00,
-                'alpha': 0,
-
-                'objective': 'reg:linear',
+                "max_depth": 3,
+                "gamma": 0.0001,
+                "min_child_weight": 1,
+                "subsample": 1.0,
+                "eta": 0.3,
+                "lambda": 1.00,
+                "alpha": 0,
+                "objective": "reg:linear",
             }
-        elif loss_type == 'rank':
+        elif loss_type == "rank":
             self.xgb_params = {
-                'max_depth': 3,
-                'gamma': 0.0001,
-                'min_child_weight': 1,
-
-                'subsample': 1.0,
-
-                'eta': 0.3,
-                'lambda': 1.00,
-                'alpha': 0,
-
-                'objective': 'rank:pairwise',
+                "max_depth": 3,
+                "gamma": 0.0001,
+                "min_child_weight": 1,
+                "subsample": 1.0,
+                "eta": 0.3,
+                "lambda": 1.00,
+                "alpha": 0,
+                "objective": "rank:pairwise",
             }
         else:
             raise RuntimeError("Invalid loss type: " + loss_type)
 
-        self.xgb_params['verbosity'] = 0
+        self.xgb_params["verbosity"] = 0
         if num_threads:
-            self.xgb_params['nthread'] = num_threads
+            self.xgb_params["nthread"] = num_threads
         self.bst = None
 
-        if feature_type == 'itervar':
+        if feature_type == "itervar":
             self.feature_extract_func = _extract_itervar_feature_index
-        elif feature_type == 'knob':
+        elif feature_type == "knob":
             self.feature_extract_func = _extract_knob_feature_index
-        elif feature_type == 'curve':
+        elif feature_type == "curve":
             self.feature_extract_func = _extract_curve_feature_index
         else:
             raise RuntimeError("Invalid feature type " + feature_type)
@@ -196,22 +196,31 @@ class XGBoostCostModel(CostModel):
             else:
                 dtrain.set_base_margin(discount * self.base_model.predict(xs, output_margin=True))
 
-        self.bst = xgb.train(self.xgb_params, dtrain,
-                             num_boost_round=8000,
-                             callbacks=[custom_callback(
-                                 stopping_rounds=20,
-                                 metric='tr-a-recall@%d' % plan_size,
-                                 evals=[(dtrain, 'tr')],
-                                 maximize=True,
-                                 fevals=[
-                                     xgb_average_recalln_curve_score(plan_size),
-                                 ],
-                                 verbose_eval=self.log_interval)])
-
-        logger.debug("XGB train: %.2f\tobs: %d\terror: %d\tn_cache: %d",
-                     time.time() - tic, len(xs),
-                     len(xs) - np.sum(valid_index),
-                     self.feature_cache.size(self.fea_type))
+        self.bst = xgb.train(
+            self.xgb_params,
+            dtrain,
+            num_boost_round=8000,
+            callbacks=[
+                custom_callback(
+                    stopping_rounds=20,
+                    metric="tr-a-recall@%d" % plan_size,
+                    evals=[(dtrain, "tr")],
+                    maximize=True,
+                    fevals=[
+                        xgb_average_recalln_curve_score(plan_size),
+                    ],
+                    verbose_eval=self.log_interval,
+                )
+            ],
+        )
+
+        logger.debug(
+            "XGB train: %.2f\tobs: %d\terror: %d\tn_cache: %d",
+            time.time() - tic,
+            len(xs),
+            len(xs) - np.sum(valid_index),
+            self.feature_cache.size(self.fea_type),
+        )
 
     def fit_log(self, records, plan_size):
         tic = time.time()
@@ -227,11 +236,11 @@ class XGBoostCostModel(CostModel):
         # extract feature
         self._reset_pool(self.space, self.target, self.task)
         pool = self._get_pool()
-        if self.fea_type == 'itervar':
+        if self.fea_type == "itervar":
             feature_extract_func = _extract_itervar_feature_log
-        elif self.fea_type == 'knob':
+        elif self.fea_type == "knob":
             feature_extract_func = _extract_knob_feature_log
-        elif self.fea_type == 'curve':
+        elif self.fea_type == "curve":
             feature_extract_func = _extract_curve_feature_log
         else:
             raise RuntimeError("Invalid feature type: " + self.fea_type)
@@ -259,17 +268,23 @@ class XGBoostCostModel(CostModel):
         dtrain = xgb.DMatrix(x_train[index], y_train[index])
 
         plan_size *= 2
-        self.bst = xgb.train(self.xgb_params, dtrain,
-                             num_boost_round=400,
-                             callbacks=[custom_callback(
-                                 stopping_rounds=100,
-                                 metric='tr-a-recall@%d' % plan_size,
-                                 evals=[(dtrain, 'tr')],
-                                 maximize=True,
-                                 fevals=[
-                                     xgb_average_recalln_curve_score(plan_size),
-                                 ],
-                                 verbose_eval=self.log_interval)])
+        self.bst = xgb.train(
+            self.xgb_params,
+            dtrain,
+            num_boost_round=400,
+            callbacks=[
+                custom_callback(
+                    stopping_rounds=100,
+                    metric="tr-a-recall@%d" % plan_size,
+                    evals=[(dtrain, "tr")],
+                    maximize=True,
+                    fevals=[
+                        xgb_average_recalln_curve_score(plan_size),
+                    ],
+                    verbose_eval=self.log_interval,
+                )
+            ],
+        )
 
         logger.debug("XGB train: %.2f\tobs: %d", time.time() - tic, len(xs))
 
@@ -280,8 +295,9 @@ class XGBoostCostModel(CostModel):
         dtest = xgb.DMatrix(feas)
 
         if self.base_model:
-            dtest.set_base_margin(self._base_model_discount() *
-                                  self.base_model.predict(xs, output_margin=True))
+            dtest.set_base_margin(
+                self._base_model_discount() * self.base_model.predict(xs, output_margin=True)
+            )
 
         return self.bst.predict(dtest, output_margin=output_margin)
 
@@ -291,8 +307,9 @@ class XGBoostCostModel(CostModel):
         self.base_model.upper_model = self
 
     def spawn_base_model(self):
-        return XGBoostCostModel(self.task, self.fea_type, self.loss_type,
-                                self.num_threads, self.log_interval, self)
+        return XGBoostCostModel(
+            self.task, self.fea_type, self.loss_type, self.num_threads, self.log_interval, self
+        )
 
     def _get_feature(self, indexes):
         """get features for indexes, run extraction if we do not have cache for them"""
@@ -331,6 +348,7 @@ _extract_space = None
 _extract_target = None
 _extract_task = None
 
+
 def _extract_itervar_feature_index(index):
     """extract iteration var feature for an index in extract_space"""
     try:
@@ -343,6 +361,7 @@ def _extract_itervar_feature_index(index):
     except Exception:  # pylint: disable=broad-except
         return None
 
+
 def _extract_itervar_feature_log(arg):
     """extract iteration var feature for log items"""
     try:
@@ -361,6 +380,7 @@ def _extract_itervar_feature_log(arg):
     except Exception:  # pylint: disable=broad-except
         return None
 
+
 def _extract_knob_feature_index(index):
     """extract knob feature for an index in extract_space"""
     try:
@@ -369,6 +389,7 @@ def _extract_knob_feature_index(index):
     except Exception:  # pylint: disable=broad-except
         return None
 
+
 def _extract_knob_feature_log(arg):
     """extract knob feature for log items"""
     try:
@@ -386,6 +407,7 @@ def _extract_knob_feature_log(arg):
     except Exception:  # pylint: disable=broad-except
         return None
 
+
 def _extract_curve_feature_index(index):
     """extract sampled curve feature for an index in extract_space"""
     try:
@@ -398,6 +420,7 @@ def _extract_curve_feature_index(index):
     except Exception:  # pylint: disable=broad-except
         return None
 
+
 def _extract_curve_feature_log(arg):
     """extract sampled curve feature for log items"""
     try:
@@ -416,8 +439,10 @@ def _extract_curve_feature_log(arg):
     except Exception:  # pylint: disable=broad-except
         return None
 
-def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
-                    maximize=False, verbose_eval=True):
+
+def custom_callback(
+    stopping_rounds, metric, fevals, evals=(), log_file=None, maximize=False, verbose_eval=True
+):
     """callback function for xgboost to support multiple custom evaluation functions"""
     # pylint: disable=import-outside-toplevel
     from xgboost.core import EarlyStopException
@@ -431,21 +456,21 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
         """internal function"""
         bst = env.model
 
-        state['maximize_score'] = maximize
-        state['best_iteration'] = 0
+        state["maximize_score"] = maximize
+        state["best_iteration"] = 0
         if maximize:
-            state['best_score'] = float('-inf')
+            state["best_score"] = float("-inf")
         else:
-            state['best_score'] = float('inf')
+            state["best_score"] = float("inf")
 
         if bst is not None:
-            if bst.attr('best_score') is not None:
-                state['best_score'] = float(bst.attr('best_score'))
-                state['best_iteration'] = int(bst.attr('best_iteration'))
-                state['best_msg'] = bst.attr('best_msg')
+            if bst.attr("best_score") is not None:
+                state["best_score"] = float(bst.attr("best_score"))
+                state["best_iteration"] = int(bst.attr("best_iteration"))
+                state["best_msg"] = bst.attr("best_msg")
             else:
-                bst.set_attr(best_iteration=str(state['best_iteration']))
-                bst.set_attr(best_score=str(state['best_score']))
+                bst.set_attr(best_iteration=str(state["best_iteration"]))
+                bst.set_attr(best_score=str(state["best_score"]))
         else:
             assert env.cvfolds is not None
 
@@ -469,7 +494,7 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
         else:
             for feval in fevals:
                 bst_eval = bst.eval_set(evals, i, feval)
-                res = [x.split(':') for x in bst_eval.split()]
+                res = [x.split(":") for x in bst_eval.split()]
                 for kv in res[1:]:
                     res_dict[kv[0]] = [float(kv[1])]
 
@@ -483,7 +508,7 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
         ##### print eval result #####
         infos = ["XGB iter: %3d" % i]
         for item in eval_res:
-            if 'null' in item[0]:
+            if "null" in item[0]:
                 continue
             infos.append("%s: %.6f" % (item[0], item[1]))
 
@@ -491,7 +516,7 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
             logger.debug("\t".join(infos))
         if log_file:
             with open(log_file, "a") as fout:
-                fout.write("\t".join(infos) + '\n')
+                fout.write("\t".join(infos) + "\n")
 
         ##### choose score and do early stopping #####
         score = None
@@ -501,24 +526,23 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
                 break
         assert score is not None
 
-        best_score = state['best_score']
-        best_iteration = state['best_iteration']
-        maximize_score = state['maximize_score']
-        if (maximize_score and score > best_score) or \
-                (not maximize_score and score < best_score):
-            msg = '[%d] %s' % (
-                env.iteration,
-                '\t'.join([_fmt_metric(x) for x in eval_res]))
-            state['best_msg'] = msg
-            state['best_score'] = score
-            state['best_iteration'] = env.iteration
+        best_score = state["best_score"]
+        best_iteration = state["best_iteration"]
+        maximize_score = state["maximize_score"]
+        if (maximize_score and score > best_score) or (not maximize_score and score < best_score):
+            msg = "[%d] %s" % (env.iteration, "\t".join([_fmt_metric(x) for x in eval_res]))
+            state["best_msg"] = msg
+            state["best_score"] = score
+            state["best_iteration"] = env.iteration
             # save the property to attributes, so they will occur in checkpoint.
             if env.model is not None:
-                env.model.set_attr(best_score=str(state['best_score']),
-                                   best_iteration=str(state['best_iteration']),
-                                   best_msg=state['best_msg'])
+                env.model.set_attr(
+                    best_score=str(state["best_score"]),
+                    best_iteration=str(state["best_iteration"]),
+                    best_msg=state["best_msg"],
+                )
         elif env.iteration - best_iteration >= stopping_rounds:
-            best_msg = state['best_msg']
+            best_msg = state["best_msg"]
             if verbose_eval and env.rank == 0:
                 logger.debug("XGB stopped. Best iteration: %s ", best_msg)
             raise EarlyStopException(best_iteration)
@@ -529,56 +553,73 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
 # feval wrapper for xgboost
 def xgb_max_curve_score(N):
     """evaluate max curve score for xgb"""
+
     def feval(preds, labels):
         labels = labels.get_label()
         trials = np.argsort(preds)[::-1]
         scores = labels[trials]
         curve = max_curve(scores)
         return "Smax@%d" % N, curve[N] / np.max(labels)
+
     return feval
 
+
 def xgb_recalln_curve_score(N):
     """evaluate recall-n curve score for xgb"""
+
     def feval(preds, labels):
         labels = labels.get_label()
         trials = np.argsort(preds)[::-1]
         ranks = get_rank(labels[trials])
         curve = recall_curve(ranks)
         return "recall@%d" % N, curve[N]
+
     return feval
 
+
 def xgb_average_recalln_curve_score(N):
     """evaluate average recall-n curve score for xgb"""
+
     def feval(preds, labels):
         labels = labels.get_label()
         trials = np.argsort(preds)[::-1]
         ranks = get_rank(labels[trials])
         curve = recall_curve(ranks)
         return "a-recall@%d" % N, np.sum(curve[:N]) / N
+
     return feval
 
+
 def xgb_recallk_curve_score(N, topk):
     """evaluate recall-k curve score for xgb"""
+
     def feval(preds, labels):
         labels = labels.get_label()
         trials = np.argsort(preds)[::-1]
         ranks = get_rank(labels[trials])
         curve = recall_curve(ranks, topk)
         return "recall@%d" % topk, curve[N]
+
     return feval
 
+
 def xgb_cover_curve_score(N):
     """evaluate cover curve score for xgb"""
+
     def feval(preds, labels):
         labels = labels.get_label()
         trials = np.argsort(preds)[::-1]
         ranks = get_rank(labels[trials])
         curve = cover_curve(ranks)
         return "cover@%d" % N, curve[N]
+
     return feval
 
+
 def xgb_null_score(_):
     """empty score function for xgb"""
+
     def feval(__, ___):
         return "null", 0
+
     return feval
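These wrappers follow xgboost's feval convention: they receive raw predictions and the evaluation DMatrix. A small illustrative check on random data (assumes xgboost is installed and xgb_average_recalln_curve_score is imported from this module):

import numpy as np
import xgboost as xgb

preds = np.random.rand(32)
dtrain = xgb.DMatrix(np.random.rand(32, 4), label=np.random.rand(32))
name, score = xgb_average_recalln_curve_score(8)(preds, dtrain)
print(name, score)  # "a-recall@8" and a value in [0, 1]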
index a7ddf85..8f8ddfe 100644 (file)
@@ -20,6 +20,7 @@ from .model_based_tuner import ModelBasedTuner, ModelOptimizer
 from .xgboost_cost_model import XGBoostCostModel
 from .sa_model_optimizer import SimulatedAnnealingOptimizer
 
+
 class XGBTuner(ModelBasedTuner):
     """Tuner that uses xgboost as cost model
 
@@ -68,23 +69,35 @@ class XGBTuner(ModelBasedTuner):
         If is 0, output nothing.
         Otherwise, output debug information every `verbose` iterations.
     """
-    def __init__(self, task, plan_size=64,
-                 feature_type='itervar', loss_type='rank', num_threads=None,
-                 optimizer='sa', diversity_filter_ratio=None, log_interval=50):
-        cost_model = XGBoostCostModel(task,
-                                      feature_type=feature_type,
-                                      loss_type=loss_type,
-                                      num_threads=num_threads,
-                                      log_interval=log_interval // 2)
-        if optimizer == 'sa':
+
+    def __init__(
+        self,
+        task,
+        plan_size=64,
+        feature_type="itervar",
+        loss_type="rank",
+        num_threads=None,
+        optimizer="sa",
+        diversity_filter_ratio=None,
+        log_interval=50,
+    ):
+        cost_model = XGBoostCostModel(
+            task,
+            feature_type=feature_type,
+            loss_type=loss_type,
+            num_threads=num_threads,
+            log_interval=log_interval // 2,
+        )
+        if optimizer == "sa":
             optimizer = SimulatedAnnealingOptimizer(task, log_interval=log_interval)
         else:
-            assert isinstance(optimizer, ModelOptimizer), "Optimizer must be " \
-                                                          "a supported name string" \
-                                                          "or a ModelOptimizer object."
+            assert isinstance(optimizer, ModelOptimizer), (
+                "Optimizer must be " "a supported name string" "or a ModelOptimizer object."
+            )
 
-        super(XGBTuner, self).__init__(task, cost_model, optimizer,
-                                       plan_size, diversity_filter_ratio)
+        super(XGBTuner, self).__init__(
+            task, cost_model, optimizer, plan_size, diversity_filter_ratio
+        )
 
     def tune(self, *args, **kwargs):  # pylint: disable=arguments-differ
         super(XGBTuner, self).tune(*args, **kwargs)
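A construction sketch for reference (tutorial-style; `task` is assumed to be an autotvm task created elsewhere):

from tvm import autotvm

tuner = autotvm.tuner.XGBTuner(
    task,
    plan_size=64,
    feature_type="knob",  # "knob" uses the raw config knobs; "itervar" (the default) extracts loop features
    loss_type="rank",
    log_interval=50,
)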
index 0d81c12..fa1dcfd 100644 (file)
@@ -26,10 +26,12 @@ import numpy as np
 import tvm.arith
 from tvm.tir import expr
 
-logger = logging.getLogger('autotvm')
+logger = logging.getLogger("autotvm")
+
 
 class EmptyContext(object):
     """An empty context"""
+
     def __enter__(self):
         pass
 
@@ -114,9 +116,8 @@ def pool_map(func, args, batch_size, verbose=False, pool=None):
         logger.info("mapping begin")
     for i in range(0, len(args), batch_size):
         if verbose:
-            logger.info("mapping %d/%d elapsed %.2f", i, len(args),
-                        time.time() - tic)
-        tmp = np.array(local_pool.map(func, args[i:i+batch_size]))
+            logger.info("mapping %d/%d elapsed %.2f", i, len(args), time.time() - tic)
+        tmp = np.array(local_pool.map(func, args[i : i + batch_size]))
         ret = tmp if ret is None else np.concatenate((ret, tmp))
     if verbose:
         logger.info("mapping done")
@@ -124,6 +125,7 @@ def pool_map(func, args, batch_size, verbose=False, pool=None):
         local_pool.close()
     return ret
 
+
 def get_func_name(func):
     """Get name of a function
 
@@ -137,7 +139,7 @@ def get_func_name(func):
         The name
     """
 
-    return func.func_name if hasattr(func, 'func_name') else func.__name__
+    return func.func_name if hasattr(func, "func_name") else func.__name__
 
 
 def get_const_int(exp):
@@ -190,7 +192,7 @@ def get_const_tuple(in_tuple):
     return tuple(ret)
 
 
-SI_PREFIXES = 'yzafpn\xb5m kMGTPEZY'
+SI_PREFIXES = "yzafpn\xb5m kMGTPEZY"
 YOCTO_EXP10 = -24
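For reference, each character in SI_PREFIXES sits three decades above the previous one, starting at 1e-24 for 'y'. A minimal sketch of an SI scaling helper built on these two constants (the real format_si_prefix is defined elsewhere in this file; this is only an illustration of the encoding):

SI_PREFIXES = "yzafpn\xb5m kMGTPEZY"
YOCTO_EXP10 = -24

def si_scale_sketch(value, prefix):
    """Scale value into the given SI prefix, e.g. si_scale_sketch(2.5e9, 'G') -> 2.5."""
    exp10 = SI_PREFIXES.index(prefix) * 3 + YOCTO_EXP10
    return value / 10 ** exp10

print(si_scale_sketch(2.5e9, "G"))  # 2.5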
 
 
index d784b7b..b0f36c8 100644 (file)
@@ -72,6 +72,7 @@ SECTIONS
 }}
 """
 
+
 def run_cmd(cmd):
     """Runs `cmd` in a subprocess and awaits its completion.
 
@@ -85,15 +86,12 @@ def run_cmd(cmd):
     output : str
         resulting stdout capture from the subprocess
     """
-    proc = subprocess.Popen(
-        cmd,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT)
+    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
     (output, _) = proc.communicate()
     output = output.decode("utf-8")
     if proc.returncode != 0:
         cmd_str = " ".join(cmd)
-        msg = f"error while running command \"{cmd_str}\":\n{output}"
+        msg = f'error while running command "{cmd_str}":\n{output}'
         raise RuntimeError(msg)
     return output
 
@@ -161,14 +159,15 @@ def tvm_callback_get_section_size(binary_path, section_name, toolchain_prefix):
 
 @tvm._ffi.register_func("tvm_callback_relocate_binary")
 def tvm_callback_relocate_binary(
-        binary_path,
-        word_size,
-        text_start,
-        rodata_start,
-        data_start,
-        bss_start,
-        stack_end,
-        toolchain_prefix):
+    binary_path,
+    word_size,
+    text_start,
+    rodata_start,
+    data_start,
+    bss_start,
+    stack_end,
+    toolchain_prefix,
+):
     """Relocates sections in the binary to new addresses
 
     Parameters
@@ -215,18 +214,17 @@ def tvm_callback_relocate_binary(
         rodata_start=rodata_start,
         data_start=data_start,
         bss_start=bss_start,
-        stack_pointer_init=stack_pointer_init)
+        stack_pointer_init=stack_pointer_init,
+    )
 
     tmp_dir = util.tempdir()
     rel_obj_path = tmp_dir.relpath("relocated.obj")
     rel_ld_script_path = tmp_dir.relpath("relocate.lds")
     with open(rel_ld_script_path, "w") as f:
         f.write(ld_script_contents)
-    run_cmd([
-        "{}ld".format(toolchain_prefix),
-        binary_path,
-        "-T", rel_ld_script_path,
-        "-o", rel_obj_path])
+    run_cmd(
+        ["{}ld".format(toolchain_prefix), binary_path, "-T", rel_ld_script_path, "-o", rel_obj_path]
+    )
 
     with open(rel_obj_path, "rb") as f:
         rel_bin = bytearray(f.read())
@@ -272,11 +270,14 @@ def tvm_callback_read_binary_section(binary, section, toolchain_prefix):
     tmp_section = tmp_dir.relpath("tmp_section.bin")
     with open(tmp_bin, "wb") as out_file:
         out_file.write(bytes(binary))
-    run_cmd([
-        "{}objcopy".format(toolchain_prefix),
-        "--dump-section",
-        ".{}={}".format(section, tmp_section),
-        tmp_bin])
+    run_cmd(
+        [
+            "{}objcopy".format(toolchain_prefix),
+            "--dump-section",
+            ".{}={}".format(section, tmp_section),
+            tmp_bin,
+        ]
+    )
     if os.path.isfile(tmp_section):
         # Get section content if it exists.
         with open(tmp_section, "rb") as f:
@@ -309,11 +310,7 @@ def tvm_callback_get_symbol_map(binary, toolchain_prefix):
     tmp_obj = tmp_dir.relpath("tmp_obj.bin")
     with open(tmp_obj, "wb") as out_file:
         out_file.write(bytes(binary))
-    nm_output = run_cmd([
-        "{}nm".format(toolchain_prefix),
-        "-C",
-        "--defined-only",
-        tmp_obj])
+    nm_output = run_cmd(["{}nm".format(toolchain_prefix), "-C", "--defined-only", tmp_obj])
     nm_output = nm_output.splitlines()
     map_str = ""
     for line in nm_output:
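The loop being assembled here walks `nm` output line by line. For orientation, a self-contained sketch of parsing that output format into a symbol map (illustrative only; the exact layout the runtime expects is defined by the rest of this function, which falls outside this hunk):

def parse_nm_output_sketch(nm_output):
    """Parse lines of the form '<hex-addr> <type> <name>' into {name: addr}."""
    symbol_map = {}
    for line in nm_output.splitlines():
        tokens = line.strip().split()
        if len(tokens) != 3:
            continue
        addr, _sym_type, name = tokens
        symbol_map[name] = int(addr, 16)
    return symbol_map

sample = "00000000 T UTVMMain\n00000010 T UTVMDone"  # sample nm lines, names arbitrary
print(parse_nm_output_sketch(sample))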
index e1a4a8a..58bf933 100644 (file)
@@ -48,7 +48,7 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
             "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], transa, transb
         ),
         name="C",
-        **kwargs
+        **kwargs,
     )
 
 
@@ -89,5 +89,5 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs
             transb,
         ),
         name="C",
-        **kwargs
+        **kwargs,
     )
index cdde0cb..7b1a8c9 100644 (file)
@@ -24,10 +24,8 @@ import os
 from .._ffi.base import py_str
 from .util import tempdir
 
-def create_shared(output,
-                  objects,
-                  options=None,
-                  cc="g++"):
+
+def create_shared(output, objects, options=None, cc="g++"):
     """Create shared library.
 
     Parameters
@@ -52,10 +50,7 @@ def create_shared(output,
         raise ValueError("Unsupported platform")
 
 
-def create_executable(output,
-                      objects,
-                      options=None,
-                      cc="g++"):
+def create_executable(output, objects, options=None, cc="g++"):
     """Create executable binary.
 
     Parameters
@@ -79,7 +74,7 @@ def create_executable(output,
 
 
 def get_target_by_dump_machine(compiler):
-    """ Functor of get_target_triple that can get the target triple using compiler.
+    """Functor of get_target_triple that can get the target triple using compiler.
 
     Parameters
     ----------
@@ -91,12 +86,12 @@ def get_target_by_dump_machine(compiler):
     out: Callable
         A function that can get target triple according to dumpmachine option of compiler.
     """
+
     def get_target_triple():
         """ Get target triple according to dumpmachine option of compiler."""
         if compiler:
             cmd = [compiler, "-dumpmachine"]
-            proc = subprocess.Popen(
-                cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+            proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
             (out, _) = proc.communicate()
             if proc.returncode != 0:
                 msg = "dumpmachine error:\n"
@@ -111,14 +106,13 @@ def get_target_by_dump_machine(compiler):
 # assign so as default output format
 create_shared.output_format = "so" if sys.platform != "win32" else "dll"
 create_shared.get_target_triple = get_target_by_dump_machine(
-    "g++" if sys.platform == "darwin" or sys.platform.startswith("linux") else None)
+    "g++" if sys.platform == "darwin" or sys.platform.startswith("linux") else None
+)
 
 
-def cross_compiler(compile_func,
-                   options=None,
-                   output_format=None,
-                   get_target_triple=None,
-                   add_files=None):
+def cross_compiler(
+    compile_func, options=None, output_format=None, get_target_triple=None, add_files=None
+):
     """Create a cross compiler function by specializing compile_func with options.
 
     This function can be used to construct compile functions that
@@ -169,10 +163,9 @@ def cross_compiler(compile_func,
 
     # handle case where compile_func is the name of the cc
     if isinstance(compile_func, str):
-        kwargs = {"cc" : compile_func}
+        kwargs = {"cc": compile_func}
         compile_func = create_shared
 
-
     def _fcompile(outputs, objects, options=None):
         all_options = base_options
         if options is not None:
@@ -191,8 +184,7 @@ def cross_compiler(compile_func,
     return _fcompile
 
 
-def _linux_compile(output, objects, options,
-                   compile_cmd="g++", compile_shared=False):
+def _linux_compile(output, objects, options, compile_cmd="g++", compile_shared=False):
     cmd = [compile_cmd]
     if compile_shared or output.endswith(".so") or output.endswith(".dylib"):
         cmd += ["-shared", "-fPIC"]
@@ -207,8 +199,7 @@ def _linux_compile(output, objects, options,
         cmd += objects
     if options:
         cmd += options
-    proc = subprocess.Popen(
-        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
     (out, _) = proc.communicate()
     if proc.returncode != 0:
         msg = "Compilation error:\n"
@@ -229,23 +220,25 @@ def _windows_shared(output, objects, options):
     temp = tempdir()
     dllmain_path = temp.relpath("dllmain.cc")
     with open(dllmain_path, "w") as dllmain_obj:
-        dllmain_obj.write('#include <windows.h>\
+        dllmain_obj.write(
+            "#include <windows.h>\
 BOOL APIENTRY DllMain( HMODULE hModule,\
                        DWORD  ul_reason_for_call,\
                        LPVOID lpReserved)\
-{return TRUE;}')
+{return TRUE;}"
+        )
 
     cl_cmd += [dllmain_path]
 
     temp_path = dllmain_path.replace("dllmain.cc", "")
     cl_cmd += ["-Fo:" + temp_path]
     try:
-        proc = subprocess.Popen(
-            cl_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+        proc = subprocess.Popen(cl_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
         (out, _) = proc.communicate()
     except FileNotFoundError:
-        raise RuntimeError("Can not find cl.exe,"
-                           "please run this in Vistual Studio Command Prompt.")
+        raise RuntimeError(
+            "Can not find cl.exe," "please run this in Vistual Studio Command Prompt."
+        )
     if proc.returncode != 0:
         msg = "Compilation error:\n"
         msg += py_str(out)
@@ -266,15 +259,16 @@ BOOL APIENTRY DllMain( HMODULE hModule,\
     link_cmd += ["-out:" + output]
 
     try:
-        proc = subprocess.Popen(
-            link_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+        proc = subprocess.Popen(link_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
         (out, _) = proc.communicate()
     except FileNotFoundError:
-        raise RuntimeError("Can not find the LLVM linker for Windows (lld-link.exe)."
-                           "Make sure it's installed"
-                           " and the installation directory is in the %PATH% environment "
-                           "variable. Prebuilt binaries can be found at: https://llvm.org/"
-                           "For building the linker on your own see: https://lld.llvm.org/#build")
+        raise RuntimeError(
+            "Can not find the LLVM linker for Windows (lld-link.exe)."
+            "Make sure it's installed"
+            " and the installation directory is in the %PATH% environment "
+            "variable. Prebuilt binaries can be found at: https://llvm.org/"
+            "For building the linker on your own see: https://lld.llvm.org/#build"
+        )
     if proc.returncode != 0:
         msg = "Compilation error:\n"
         msg += py_str(out)
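For context, the usual consumer of cross_compiler is Module.export_library; a hedged usage sketch (the toolchain name below is an assumption, substitute your own cross compiler, and `lib` is a tvm.runtime.Module built elsewhere):

from tvm.contrib import cc

fcompile = cc.cross_compiler("aarch64-linux-gnu-g++", options=["-O2"])
lib.export_library("deploy_arm64.so", fcompile=fcompile)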
index cb7bdcc..edc1200 100644 (file)
@@ -52,15 +52,11 @@ def find_clang(required=True):
     valid_list = [util.which(x) for x in cc_list]
     valid_list = [x for x in valid_list if x]
     if not valid_list and required:
-        raise RuntimeError(
-            "cannot find clang, candidates are: " + str(cc_list))
+        raise RuntimeError("cannot find clang, candidates are: " + str(cc_list))
     return valid_list
 
 
-def create_llvm(inputs,
-                output=None,
-                options=None,
-                cc=None):
+def create_llvm(inputs, output=None, options=None, cc=None):
     """Create llvm text ir.
 
     Parameters
@@ -103,8 +99,7 @@ def create_llvm(inputs,
         cmd += options
     cmd += ["-o", output]
     cmd += input_files
-    proc = subprocess.Popen(
-        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
     (out, _) = proc.communicate()
     if proc.returncode != 0:
         msg = "Compilation error:\n"
index 6a82681..4ef3593 100644 (file)
@@ -18,6 +18,7 @@
 import tvm._ffi
 from ..rpc import base as rpc_base
 
+
 def create(symbol, compiled_model_path, ctx):
     """Create a runtime executor module given a coreml model and context.
     Parameters
index 7b42bec..9a36fa5 100644 (file)
@@ -42,10 +42,15 @@ def matmul(lhs, rhs, transa=False, transb=False, dtype=None):
     m = rhs.shape[0] if transb else rhs.shape[1]
     dtype = dtype if dtype is not None else lhs.dtype
     return te.extern(
-        (n, m), [lhs, rhs],
+        (n, m),
+        [lhs, rhs],
         lambda ins, outs: tvm.tir.call_packed(
-            "tvm.contrib.cublas.matmul",
-            ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C")
+            "tvm.contrib.cublas.matmul", ins[0], ins[1], outs[0], transa, transb
+        ),
+        dtype=dtype,
+        name="C",
+    )
+
 
 def batch_matmul(lhs, rhs, transa=False, transb=False, dtype=None):
     """Create an extern op that compute batch matrix mult of A and rhs with cuBLAS
@@ -71,7 +76,11 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, dtype=None):
     m = rhs.shape[1] if transb else rhs.shape[2]
     dtype = dtype if dtype is not None else lhs.dtype
     return te.extern(
-        (b, n, m), [lhs, rhs],
+        (b, n, m),
+        [lhs, rhs],
         lambda ins, outs: tvm.tir.call_packed(
-            "tvm.contrib.cublas.batch_matmul",
-            ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C")
+            "tvm.contrib.cublas.batch_matmul", ins[0], ins[1], outs[0], transa, transb
+        ),
+        dtype=dtype,
+        name="C",
+    )
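A usage sketch for reference (shapes are arbitrary; actually building and running this requires a CUDA-enabled TVM with USE_CUBLAS switched on):

import tvm
from tvm import te
from tvm.contrib import cublas

A = te.placeholder((1024, 64), name="A", dtype="float32")
B = te.placeholder((64, 512), name="B", dtype="float32")
C = cublas.matmul(A, B)                # builds on the te.extern wrapper defined above
s = te.create_schedule(C.op)
# f = tvm.build(s, [A, B, C], "cuda")  # needs CUDA + cuBLAS at build/run time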
index 3b36f47..1c9fe7c 100644 (file)
@@ -44,7 +44,11 @@ def matmul(lhs, rhs, transa=False, transb=False, n=0, m=0, dtype=None):
         m = rhs.shape[0] if transb else rhs.shape[1]
     dtype = dtype if dtype is not None else lhs.dtype
     return te.extern(
-        (n, m), [lhs, rhs],
+        (n, m),
+        [lhs, rhs],
         lambda ins, outs: tvm.tir.call_packed(
-            "tvm.contrib.cublaslt.matmul",
-            ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C")
+            "tvm.contrib.cublaslt.matmul", ins[0], ins[1], outs[0], transa, transb
+        ),
+        dtype=dtype,
+        name="C",
+    )
index 0650b93..6dc04c9 100644 (file)
@@ -61,11 +61,7 @@ _BWD_DATA_ALGOS = [
     "CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT",
 ]
 
-_ALGO_TYPE = [
-    "fwd",
-    "bwd_filter",
-    "bwd_data"
-]
+_ALGO_TYPE = ["fwd", "bwd_filter", "bwd_data"]
 
 
 def algo_to_index(algo_type, algo_name):
@@ -148,12 +144,8 @@ def _get_np_int32_array_handle(arr):
     ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
     return ctypes.cast(ptr, ctypes.c_void_p)
 
-def _prepare_global_func_params(dims,
-                                pad,
-                                stride,
-                                dilation,
-                                x_shape=None,
-                                w_shape=None):
+
+def _prepare_global_func_params(dims, pad, stride, dilation, x_shape=None, w_shape=None):
     full_dims = dims + 2
     if x_shape:
         assert isinstance(x_shape, list)
@@ -162,12 +154,21 @@ def _prepare_global_func_params(dims,
         assert isinstance(w_shape, list)
         assert len(w_shape) == full_dims
 
-    pad = np.full(dims, pad, dtype=np.int32) if isinstance(pad, int) \
+    pad = (
+        np.full(dims, pad, dtype=np.int32)
+        if isinstance(pad, int)
         else np.array(pad, dtype=np.int32)
-    stride = np.full(dims, stride, dtype=np.int32) if isinstance(stride, int) \
+    )
+    stride = (
+        np.full(dims, stride, dtype=np.int32)
+        if isinstance(stride, int)
         else np.array(stride, dtype=np.int32)
-    dilation = np.full(dims, dilation, dtype=np.int32) if isinstance(dilation, int) \
+    )
+    dilation = (
+        np.full(dims, dilation, dtype=np.int32)
+        if isinstance(dilation, int)
         else np.array(dilation, dtype=np.int32)
+    )
 
     xshape = np.array(x_shape, dtype=np.int32) if x_shape else None
     wshape = np.array(w_shape, dtype=np.int32) if x_shape else None
@@ -175,15 +176,9 @@ def _prepare_global_func_params(dims,
     return pad, stride, dilation, xshape, wshape
 
 
-def conv_output_shape(tensor_format,
-                      pad,
-                      stride,
-                      dilation,
-                      x_shape,
-                      w_shape,
-                      data_dtype,
-                      conv_dtype,
-                      groups=1):
+def conv_output_shape(
+    tensor_format, pad, stride, dilation, x_shape, w_shape, data_dtype, conv_dtype, groups=1
+):
     """Get output shape of 2D or 3D convolution
 
     Parameters
@@ -217,35 +212,40 @@ def conv_output_shape(tensor_format,
     dims = len(x_shape)
     assert dims in (4, 5)
 
-    pad, stride, dilation, xshape, wshape = \
-        _prepare_global_func_params(dims - 2, pad, stride, dilation, x_shape, w_shape)
+    pad, stride, dilation, xshape, wshape = _prepare_global_func_params(
+        dims - 2, pad, stride, dilation, x_shape, w_shape
+    )
     oshape = np.zeros((dims), dtype=np.int32)
 
     func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.output_shape")
-    func(tensor_format,
-         dims - 2,
-         _get_np_int32_array_handle(pad),
-         _get_np_int32_array_handle(stride),
-         _get_np_int32_array_handle(dilation),
-         _get_np_int32_array_handle(xshape),
-         _get_np_int32_array_handle(wshape),
-         _get_np_int32_array_handle(oshape),
-         data_dtype,
-         conv_dtype,
-         groups)
+    func(
+        tensor_format,
+        dims - 2,
+        _get_np_int32_array_handle(pad),
+        _get_np_int32_array_handle(stride),
+        _get_np_int32_array_handle(dilation),
+        _get_np_int32_array_handle(xshape),
+        _get_np_int32_array_handle(wshape),
+        _get_np_int32_array_handle(oshape),
+        data_dtype,
+        conv_dtype,
+        groups,
+    )
     return list(oshape)
 
 
-def conv_find_algo(tensor_format,
-                   pad,
-                   stride,
-                   dilation,
-                   x_shape,
-                   w_shape,
-                   y_shape,
-                   data_dtype,
-                   conv_dtype,
-                   groups=1):
+def conv_find_algo(
+    tensor_format,
+    pad,
+    stride,
+    dilation,
+    x_shape,
+    w_shape,
+    y_shape,
+    data_dtype,
+    conv_dtype,
+    groups=1,
+):
     """Choose the best algo for the given input.
 
     Parameters
@@ -281,33 +281,27 @@ def conv_find_algo(tensor_format,
     dims = len(x_shape)
     assert dims in (4, 5)
 
-    pad, stride, dilation, xshape, wshape = \
-        _prepare_global_func_params(dims - 2, pad, stride, dilation, x_shape, w_shape)
+    pad, stride, dilation, xshape, wshape = _prepare_global_func_params(
+        dims - 2, pad, stride, dilation, x_shape, w_shape
+    )
     yshape = np.array(y_shape, dtype=np.int32)
     func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.find_algo")
-    return func(tensor_format,
-                dims - 2,
-                _get_np_int32_array_handle(pad),
-                _get_np_int32_array_handle(stride),
-                _get_np_int32_array_handle(dilation),
-                _get_np_int32_array_handle(xshape),
-                _get_np_int32_array_handle(wshape),
-                _get_np_int32_array_handle(yshape),
-                data_dtype,
-                conv_dtype,
-                groups)
-
-
-def conv_forward(x,
-                 w,
-                 pad,
-                 stride,
-                 dilation,
-                 conv_mode,
-                 tensor_format,
-                 algo,
-                 conv_dtype,
-                 groups=1):
+    return func(
+        tensor_format,
+        dims - 2,
+        _get_np_int32_array_handle(pad),
+        _get_np_int32_array_handle(stride),
+        _get_np_int32_array_handle(dilation),
+        _get_np_int32_array_handle(xshape),
+        _get_np_int32_array_handle(wshape),
+        _get_np_int32_array_handle(yshape),
+        data_dtype,
+        conv_dtype,
+        groups,
+    )
+
+
+def conv_forward(x, w, pad, stride, dilation, conv_mode, tensor_format, algo, conv_dtype, groups=1):
     """Create an extern op that compute 2D or 3D convolution with CuDNN
 
     Parameters
@@ -348,15 +342,17 @@ def conv_forward(x,
     conv_dtype = x.dtype if conv_dtype is None else conv_dtype
     pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation)
 
-    oshape = conv_output_shape(tensor_format,
-                               pad,
-                               stride,
-                               dilation,
-                               list(x.shape),
-                               list(w.shape),
-                               x.dtype,
-                               conv_dtype,
-                               groups)
+    oshape = conv_output_shape(
+        tensor_format,
+        pad,
+        stride,
+        dilation,
+        list(x.shape),
+        list(w.shape),
+        x.dtype,
+        conv_dtype,
+        groups,
+    )
     if algo == -1:
         # For now if we try to call `cudnnFindConvolutionForwardAlgorithm` when
         # using the INT8 data type, cuDNN will crash.
@@ -364,20 +360,23 @@ def conv_forward(x,
         if tensor_format == 1 and conv_dtype == "int32":
             algo = 1
         else:
-            algo = conv_find_algo(tensor_format,
-                                  pad,
-                                  stride,
-                                  dilation,
-                                  list(x.shape),
-                                  list(w.shape),
-                                  oshape,
-                                  x.dtype,
-                                  conv_dtype,
-                                  groups)
+            algo = conv_find_algo(
+                tensor_format,
+                pad,
+                stride,
+                dilation,
+                list(x.shape),
+                list(w.shape),
+                oshape,
+                x.dtype,
+                conv_dtype,
+                groups,
+            )
 
     if dims == 4:
         return te.extern(
-            oshape, [x, w],
+            oshape,
+            [x, w],
             lambda ins, outs: tvm.tir.call_packed(
                 "tvm.contrib.cudnn.conv2d.forward",
                 conv_mode,
@@ -393,10 +392,14 @@ def conv_forward(x,
                 ins[1],
                 outs[0],
                 conv_dtype,
-                groups), name="y")
+                groups,
+            ),
+            name="y",
+        )
 
     return te.extern(
-        oshape, [x, w],
+        oshape,
+        [x, w],
         lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.cudnn.conv3d.forward",
             conv_mode,
@@ -415,7 +418,11 @@ def conv_forward(x,
             ins[1],
             outs[0],
             conv_dtype,
-            groups), name="y")
+            groups,
+        ),
+        name="y",
+    )
+
 
 def softmax(x, axis=-1):
     """Compute softmax using CuDNN
@@ -434,9 +441,10 @@ def softmax(x, axis=-1):
         The result tensor
     """
     return te.extern(
-        x.shape, [x],
+        x.shape,
+        [x],
         lambda ins, outs: tvm.tir.call_packed(
-            "tvm.contrib.cudnn.softmax.forward",
-            ins[0],
-            outs[0],
-            axis), name="y")
+            "tvm.contrib.cudnn.softmax.forward", ins[0], outs[0], axis
+        ),
+        name="y",
+    )
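
A minimal usage sketch for the cuDNN wrapper above (shapes, NCHW tensor_format=0, cross-correlation conv_mode=1, and algo=-1 for auto-selection are illustrative; even constructing the op needs a TVM build with cuDNN, since the output shape is queried from cuDNN):

    from tvm import te
    from tvm.contrib import cudnn

    # NCHW activation and OIHW filter for a 2D convolution
    x = te.placeholder((1, 3, 224, 224), name="x", dtype="float32")
    w = te.placeholder((32, 3, 3, 3), name="w", dtype="float32")
    # pad/stride/dilation are given per spatial dimension; algo=-1 defers to conv_find_algo
    y = cudnn.conv_forward(x, w, [1, 1], [1, 1], [1, 1],
                           conv_mode=1, tensor_format=0, algo=-1,
                           conv_dtype="float32")
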
index b1fe1b6..0b9810e 100644 (file)
@@ -22,13 +22,10 @@ import numpy as np
 import tvm
 
 
-GRAPH_DUMP_FILE_NAME = '_tvmdbg_graph_dump.json'
+GRAPH_DUMP_FILE_NAME = "_tvmdbg_graph_dump.json"
 CHROME_TRACE_FILE_NAME = "_tvmdbg_execution_trace.json"
 
-ChromeTraceEvent = collections.namedtuple(
-    'ChromeTraceEvent',
-    ['ts', 'tid', 'pid', 'name', 'ph']
-)
+ChromeTraceEvent = collections.namedtuple("ChromeTraceEvent", ["ts", "tid", "pid", "name", "ph"])
 
 
 class DebugResult(object):
@@ -66,9 +63,9 @@ class DebugResult(object):
            The graph to be deployed in json format output by JSON graph.
         """
         json_obj = json.loads(graph_json)
-        self._nodes_list = json_obj['nodes']
-        self._shapes_list = json_obj['attrs']['shape']
-        self._dtype_list = json_obj['attrs']['dltype']
+        self._nodes_list = json_obj["nodes"]
+        self._shapes_list = json_obj["attrs"]["shape"]
+        self._dtype_list = json_obj["attrs"]["dltype"]
         self._update_graph_json()
         return json_obj
 
@@ -80,48 +77,42 @@ class DebugResult(object):
         for i in range(nodes_len):
             node = self._nodes_list[i]
             input_list = []
-            for input_node in node['inputs']:
-                input_list.append(self._nodes_list[input_node[0]]['name'])
-            node['inputs'] = input_list
+            for input_node in node["inputs"]:
+                input_list.append(self._nodes_list[input_node[0]]["name"])
+            node["inputs"] = input_list
             dtype = str("type: " + self._dtype_list[1][i])
-            if 'attrs' not in node:
-                node['attrs'] = {}
-                node['op'] = "param"
+            if "attrs" not in node:
+                node["attrs"] = {}
+                node["op"] = "param"
             else:
-                node['op'] = node['attrs']['func_name']
-            node['attrs'].update({"T": dtype})
-            node['shape'] = self._shapes_list[1][i]
+                node["op"] = node["attrs"]["func_name"]
+            node["attrs"].update({"T": dtype})
+            node["shape"] = self._shapes_list[1][i]
 
     def _cleanup_tensors(self):
-        """Remove the tensor dump file (graph wont be removed)
-        """
+        """Remove the tensor dump file (graph wont be removed)"""
         for filename in os.listdir(self._dump_path):
             if os.path.isfile(filename) and not filename.endswith(".json"):
                 os.remove(filename)
 
     def get_graph_nodes(self):
-        """Return the nodes list
-        """
+        """Return the nodes list"""
         return self._nodes_list
 
     def get_graph_node_shapes(self):
-        """Return the nodes shapes list
-        """
+        """Return the nodes shapes list"""
         return self._shapes_list
 
     def get_graph_node_output_num(self, node):
-        """Return the number of outputs of a node
-        """
-        return 1 if node['op'] == 'param' else int(node['attrs']['num_outputs'])
+        """Return the number of outputs of a node"""
+        return 1 if node["op"] == "param" else int(node["attrs"]["num_outputs"])
 
     def get_graph_node_dtypes(self):
-        """Return the nodes dtype list
-        """
+        """Return the nodes dtype list"""
         return self._dtype_list
 
     def get_output_tensors(self):
-        """Dump the outputs to a temporary folder, the tensors are in numpy format
-        """
+        """Dump the outputs to a temporary folder, the tensors are in numpy format"""
         eid = 0
         order = 0
         output_tensors = {}
@@ -129,15 +120,14 @@ class DebugResult(object):
             num_outputs = self.get_graph_node_output_num(node)
             for j in range(num_outputs):
                 order += time[0]
-                key = node['name'] + "_" + str(j)
+                key = node["name"] + "_" + str(j)
                 output_tensors[key] = self._output_tensor_list[eid]
                 eid += 1
         return output_tensors
 
     def dump_output_tensor(self):
-        """Dump the outputs to a temporary folder, the tensors are in numpy format
-        """
-        #cleanup existing tensors before dumping
+        """Dump the outputs to a temporary folder, the tensors are in numpy format"""
+        # cleanup existing tensors before dumping
         self._cleanup_tensors()
         eid = 0
         order = 0
@@ -146,7 +136,7 @@ class DebugResult(object):
             num_outputs = self.get_graph_node_output_num(node)
             for j in range(num_outputs):
                 order += time[0]
-                key = node['name'] + "_" + str(j) + "__" + str(order)
+                key = node["name"] + "_" + str(j) + "__" + str(order)
                 output_tensors[key] = self._output_tensor_list[eid]
                 eid += 1
 
@@ -154,8 +144,8 @@ class DebugResult(object):
             param_f.write(save_tensors(output_tensors))
 
     def dump_chrome_trace(self):
-        """Dump the trace to the Chrome trace.json format.
-        """
+        """Dump the trace to the Chrome trace.json format."""
+
         def s_to_us(t):
             return t * 10 ** 6
 
@@ -168,26 +158,27 @@ class DebugResult(object):
                     ts=s_to_us(starting_time),
                     tid=1,
                     pid=1,
-                    ph='B',
-                    name=node['name'],
+                    ph="B",
+                    name=node["name"],
                 ),
                 ChromeTraceEvent(
                     # Use start + duration instead of end to ensure precise timings.
                     ts=s_to_us(times[0] + starting_time),
                     tid=1,
                     pid=1,
-                    ph='E',
-                    name=node['name'],
+                    ph="E",
+                    name=node["name"],
                 ),
             ]
+
         events = [
-            e for (node, times, starting_time) in zip(
-                self._nodes_list, self._time_list, starting_times)
-            for e in node_to_events(node, times, starting_time)]
-        result = dict(
-            displayTimeUnit='ns',
-            traceEvents=[e._asdict() for e in events]
-        )
+            e
+            for (node, times, starting_time) in zip(
+                self._nodes_list, self._time_list, starting_times
+            )
+            for e in node_to_events(node, times, starting_time)
+        ]
+        result = dict(displayTimeUnit="ns", traceEvents=[e._asdict() for e in events])
 
         with open(os.path.join(self._dump_path, CHROME_TRACE_FILE_NAME), "w") as trace_f:
             json.dump(result, trace_f)
@@ -202,7 +193,7 @@ class DebugResult(object):
             name, shape and type.
         """
         graph_dump_file_name = GRAPH_DUMP_FILE_NAME
-        with open(os.path.join(self._dump_path, graph_dump_file_name), 'w') as outfile:
+        with open(os.path.join(self._dump_path, graph_dump_file_name), "w") as outfile:
             json.dump(graph, outfile, indent=4, sort_keys=False)
 
     def get_debug_result(self, sort_by_time=True):
@@ -215,16 +206,16 @@ class DebugResult(object):
         for node, time in zip(self._nodes_list, self._time_list):
             num_outputs = self.get_graph_node_output_num(node)
             for j in range(num_outputs):
-                op = node['op']
-                if node['op'] == 'param':
+                op = node["op"]
+                if node["op"] == "param":
                     eid += 1
                     continue
-                name = node['name']
+                name = node["name"]
                 shape = str(self._output_tensor_list[eid].shape)
                 time_us = round(time[0] * 1000000, 3)
                 time_percent = round(((time[0] / total_time) * 100), 3)
-                inputs = str(node['attrs']['num_inputs'])
-                outputs = str(node['attrs']['num_outputs'])
+                inputs = str(node["attrs"]["num_inputs"])
+                outputs = str(node["attrs"]["num_outputs"])
                 node_data = [name, op, time_us, time_percent, shape, inputs, outputs]
                 data.append(node_data)
                 eid += 1
@@ -248,12 +239,13 @@ class DebugResult(object):
         log.append(fmt.format(*lines))
         for row in data:
             log.append(fmt.format(*row))
-        return '\n'.join(log)
+        return "\n".join(log)
 
     def display_debug_result(self, sort_by_time=True):
         """Displays the debugger result"""
         print(self.get_debug_result(sort_by_time))
 
+
 def save_tensors(params):
     """Save parameter dictionary to binary bytes.
 
index 1f96a86..4d2fab4 100644 (file)
@@ -59,8 +59,7 @@ def create(graph_json_str, libmod, ctx, dump_root=None):
     try:
         ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx)
         if num_rpc_ctx == len(ctx):
-            fcreate = ctx[0]._rpc_sess.get_function(
-                "tvm.graph_runtime_debug.create")
+            fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime_debug.create")
         else:
             fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_debug.create")
     except ValueError:
@@ -176,9 +175,7 @@ class GraphModuleDebug(graph_runtime.GraphModule):
         Time consumed for each execution will be set as debug output.
 
         """
-        self.debug_datum._time_list = [
-            [float(t) * 1e-6] for t in self.run_individual(10, 1, 1)
-        ]
+        self.debug_datum._time_list = [[float(t) * 1e-6] for t in self.run_individual(10, 1, 1)]
         for i, node in enumerate(self.debug_datum.get_graph_nodes()):
             num_outputs = self.debug_datum.get_graph_node_output_num(node)
             for j in range(num_outputs):
@@ -204,11 +201,7 @@ class GraphModuleDebug(graph_runtime.GraphModule):
             except KeyError:
                 node_list = output_tensors.keys()
                 raise RuntimeError(
-                    "Node "
-                    + node
-                    + " not found, available nodes are: "
-                    + str(node_list)
-                    + "."
+                    "Node " + node + " not found, available nodes are: " + str(node_list) + "."
                 )
         elif isinstance(node, int):
             output_tensors = self.debug_datum._output_tensor_list
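
For orientation, the debug runtime touched here is normally constructed from a built graph module; a rough sketch (graph_json, lib, and params are assumed to come from a prior build step, and the dump_root path is illustrative):

    import tvm
    from tvm.contrib.debugger import debug_runtime

    ctx = tvm.cpu(0)
    m = debug_runtime.create(graph_json, lib, ctx, dump_root="/tmp/tvmdbg")
    m.set_input(**params)
    m.run()  # records per-node timings and writes the _tvmdbg_* dump files
    print(m.debug_datum.get_debug_result())  # the per-op table assembled above
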
index 7d006a1..75b37ce 100644 (file)
@@ -17,6 +17,7 @@
 """Wrapping functions to bridge frameworks with DLPack support to TVM"""
 from tvm.runtime import ndarray
 
+
 def convert_func(tvm_func, tensor_type, to_dlpack_func):
     """Convert a tvm function into one that accepts a tensor from another
        framework, provided the other framework supports DLPACK
@@ -35,12 +36,15 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func):
     assert callable(tvm_func)
 
     def _wrapper(*args):
-        args = tuple(ndarray.from_dlpack(to_dlpack_func(arg))\
-            if isinstance(arg, tensor_type) else arg for arg in args)
+        args = tuple(
+            ndarray.from_dlpack(to_dlpack_func(arg)) if isinstance(arg, tensor_type) else arg
+            for arg in args
+        )
         return tvm_func(*args)
 
     return _wrapper
 
+
 def to_pytorch_func(tvm_func):
     """Convert a tvm function into one that accepts PyTorch tensors
 
@@ -57,4 +61,5 @@ def to_pytorch_func(tvm_func):
     # pylint: disable=import-outside-toplevel
     import torch
     import torch.utils.dlpack
+
     return convert_func(tvm_func, torch.Tensor, torch.utils.dlpack.to_dlpack)
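
A sketch of how the DLPack bridge is used from PyTorch (fadd is a hypothetical compiled PackedFunc taking three 1-D buffers; requires a torch build with DLPack support):

    import torch
    from tvm.contrib.dlpack import to_pytorch_func

    fadd_torch = to_pytorch_func(fadd)
    a, b, out = torch.rand(64), torch.rand(64), torch.empty(64)
    fadd_torch(a, b, out)  # tensors cross the boundary zero-copy via DLPack
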
index cdb8101..9603024 100644 (file)
@@ -21,6 +21,7 @@ import time
 import uuid
 import shutil
 
+
 def download(url, path, overwrite=False, size_compare=False, verbose=1, retries=3):
     """Downloads the file from the internet.
     Set the input options correctly to overwrite or do the size comparison
@@ -51,21 +52,22 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1, retries=
     if os.path.isfile(path) and not overwrite:
         if size_compare:
             import requests
+
             file_size = os.path.getsize(path)
             res_head = requests.head(url)
             res_get = requests.get(url, stream=True)
-            if 'Content-Length' not in res_head.headers:
+            if "Content-Length" not in res_head.headers:
                 res_get = urllib2.urlopen(url)
-            url_file_size = int(res_get.headers['Content-Length'])
+            url_file_size = int(res_get.headers["Content-Length"])
             if url_file_size != file_size:
                 print("exist file got corrupted, downloading %s file freshly..." % path)
                 download(url, path, True, False)
                 return
-        print('File {} exists, skip.'.format(path))
+        print("File {} exists, skip.".format(path))
         return
 
     if verbose >= 1:
-        print('Downloading from url {} to {}'.format(url, path))
+        print("Downloading from url {} to {}".format(url, path))
 
     # Stateful start time
     start_time = time.time()
@@ -76,17 +78,18 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1, retries=
     tempfile = os.path.join(dirpath, random_uuid)
 
     def _download_progress(count, block_size, total_size):
-        #pylint: disable=unused-argument
-        """Show the download progress.
-        """
+        # pylint: disable=unused-argument
+        """Show the download progress."""
         if count == 0:
             return
         duration = time.time() - start_time
         progress_size = int(count * block_size)
         speed = int(progress_size / (1024 * duration))
         percent = min(int(count * block_size * 100 / total_size), 100)
-        sys.stdout.write("\r...%d%%, %.2f MB, %d KB/s, %d seconds passed" %
-                         (percent, progress_size / (1024.0 * 1024), speed, duration))
+        sys.stdout.write(
+            "\r...%d%%, %.2f MB, %d KB/s, %d seconds passed"
+            % (percent, progress_size / (1024.0 * 1024), speed, duration)
+        )
         sys.stdout.flush()
 
     while retries >= 0:
@@ -109,14 +112,17 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1, retries=
                 if os.path.exists(tempfile):
                     os.remove(tempfile)
                 raise err
-            print("download failed due to {}, retrying, {} attempt{} left"
-                  .format(repr(err), retries, 's' if retries > 1 else ''))
+            print(
+                "download failed due to {}, retrying, {} attempt{} left".format(
+                    repr(err), retries, "s" if retries > 1 else ""
+                )
+            )
 
 
 if "TEST_DATA_ROOT_PATH" in os.environ:
     TEST_DATA_ROOT_PATH = os.environ.get("TEST_DATA_ROOT_PATH")
 else:
-    TEST_DATA_ROOT_PATH = os.path.join(os.path.expanduser('~'), '.tvm_test_data')
+    TEST_DATA_ROOT_PATH = os.path.join(os.path.expanduser("~"), ".tvm_test_data")
 os.makedirs(TEST_DATA_ROOT_PATH, exist_ok=True)
 
 
@@ -141,7 +147,7 @@ def download_testdata(url, relpath, module=None):
     """
     global TEST_DATA_ROOT_PATH
     if module is None:
-        module_path = ''
+        module_path = ""
     elif isinstance(module, str):
         module_path = module
     elif isinstance(module, (list, tuple)):
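
Typical use of the test-data helper (the URL and file names are placeholders, not part of the patch):

    from tvm.contrib.download import download_testdata

    path = download_testdata("https://example.com/cat.png", "cat.png", module="data")
    # files are cached under $TEST_DATA_ROOT_PATH (default ~/.tvm_test_data/<module>/)
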
index 6e7e997..0cecc66 100644 (file)
@@ -21,10 +21,7 @@ from tvm._ffi.base import py_str
 from tvm._ffi.libinfo import find_lib_path
 
 
-def create_tvmjs_wasm(output,
-                      objects,
-                      options=None,
-                      cc="emcc"):
+def create_tvmjs_wasm(output, objects, options=None, cc="emcc"):
     """Create wasm that is supposed to run with the tvmjs.
 
     Parameters
@@ -49,7 +46,6 @@ def create_tvmjs_wasm(output,
     cmd += ["-s", "STANDALONE_WASM=1"]
     cmd += ["-s", "ALLOW_MEMORY_GROWTH=1"]
 
-
     objects = [objects] if isinstance(objects, str) else objects
 
     with_runtime = False
@@ -69,10 +65,7 @@ def create_tvmjs_wasm(output,
     if options:
         cmd += options
 
-    proc = subprocess.Popen(
-        cmd,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT)
+    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
     (out, _) = proc.communicate()
 
     if proc.returncode != 0:
@@ -80,4 +73,5 @@ def create_tvmjs_wasm(output,
         msg += py_str(out)
         raise RuntimeError(msg)
 
+
 create_tvmjs_wasm.object_format = "bc"
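
The emcc helper is normally handed to export_library as the fcompile hook; a sketch assuming lib was built for a wasm32 system-lib target and emcc is on PATH:

    from tvm.contrib import emcc

    lib.export_library("module.wasm", emcc.create_tvmjs_wasm)
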
index 17c331d..de941af 100644 (file)
@@ -80,12 +80,10 @@ def get_device_ctx(libmod, ctx):
     if isinstance(ctx, TVMContext):
         ctx = [ctx]
     elif not isinstance(ctx, (list, tuple)):
-        raise ValueError("ctx has to be the type of TVMContext or a list of "
-                         "TVMContext")
+        raise ValueError("ctx has to be the type of TVMContext or a list of " "TVMContext")
     for cur_ctx in ctx:
         if not isinstance(cur_ctx, TVMContext):
-            raise ValueError("ctx has to be the type of TVMContext or a list "
-                             "of TVMContext")
+            raise ValueError("ctx has to be the type of TVMContext or a list " "of TVMContext")
 
     # device_type_id[0], device_type_id[1] are used as the primary/fallback
     # context type and id. All other ones are used as device context for
@@ -96,8 +94,7 @@ def get_device_ctx(libmod, ctx):
         device_type = cur_ctx.device_type
         if device_type >= rpc_base.RPC_SESS_MASK:
             assert libmod.type_key == "rpc"
-            assert _rpc_ffi_api.SessTableIndex(
-                libmod) == cur_ctx._rpc_sess._tbl_index
+            assert _rpc_ffi_api.SessTableIndex(libmod) == cur_ctx._rpc_sess._tbl_index
             num_rpc_ctx += 1
             device_type = cur_ctx.device_type % rpc_base.RPC_SESS_MASK
         device_type_id.append(device_type)
@@ -246,8 +243,7 @@ class GraphModule(object):
         out : NDArray
             The output array container
         """
-        raise NotImplementedError(
-            "Please use debugger.debug_runtime as graph_runtime instead.")
+        raise NotImplementedError("Please use debugger.debug_runtime as graph_runtime instead.")
 
     def load_params(self, params_bytes):
         """Load parameters from serialized byte array of parameter dict.
index 6870e2a..ed938ba 100644 (file)
@@ -14,8 +14,8 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
-'''Utility for Hexagon backend'''
+# pylint: disable=invalid-name
+"""Utility for Hexagon backend"""
 
 import functools as ft
 import os
@@ -39,20 +39,24 @@ from .._ffi.registry import register_func
 #
 # Subsequent calls to 'link_shared' will use the newly registered linker.
 
-hexagon_toolchain_root = os.environ.get('HEXAGON_TOOLCHAIN') or ''  # pylint: disable=invalid-name
-hexagon_link_master = os.path.join(                                 # pylint: disable=invalid-name
-    hexagon_toolchain_root, 'bin', 'hexagon-link')
+hexagon_toolchain_root = os.environ.get("HEXAGON_TOOLCHAIN") or ""  # pylint: disable=invalid-name
+hexagon_link_master = os.path.join(  # pylint: disable=invalid-name
+    hexagon_toolchain_root, "bin", "hexagon-link"
+)
+
 
 def register_linker(f):
     """Register a function that will return the path to the Hexagon linker."""
-    return register_func('tvm.contrib.hexagon.hexagon_link', f, True)
+    return register_func("tvm.contrib.hexagon.hexagon_link", f, True)
+
 
-@register_func('tvm.contrib.hexagon.hexagon_link')
+@register_func("tvm.contrib.hexagon.hexagon_link")
 def hexagon_link():
     """Return path to the Hexagon linker."""
     return hexagon_link_master
 
-@register_func('tvm.contrib.hexagon.link_shared')
+
+@register_func("tvm.contrib.hexagon.link_shared")
 def link_shared(so_name, objs, **kwargs):
     """Link shared library on Hexagon using the registered Hexagon linker.
 
@@ -76,49 +80,67 @@ def link_shared(so_name, objs, **kwargs):
             return s.value
         assert isinstance(s, str), 'argument "' + str(s) + '" should be a string or StrImm'
         return s
+
     objs = [to_str(s) for s in objs]
 
-    linker = tvm.get_global_func('tvm.contrib.hexagon.hexagon_link')()
-    if kwargs.get('verbose'):
-        print('tvm.contrib.hexagon.link_shared:')
-        print('  Using linker:', linker)
-        print('  Library name:', so_name)
-        print('  Object files:', objs)
+    linker = tvm.get_global_func("tvm.contrib.hexagon.hexagon_link")()
+    if kwargs.get("verbose"):
+        print("tvm.contrib.hexagon.link_shared:")
+        print("  Using linker:", linker)
+        print("  Library name:", so_name)
+        print("  Object files:", objs)
     if not os.access(linker, os.X_OK):
         message = 'The linker "' + linker + '" does not exist or is not executable.'
-        if not os.environ.get('HEXAGON_TOOLCHAIN'):
-            message += ' The environment variable HEXAGON_TOOLCHAIN is unset. Please export ' + \
-                'HEXAGON_TOOLCHAIN in your environment, so that ${HEXAGON_TOOLCHAIN}/bin/' + \
-                'hexagon-link exists.'
+        if not os.environ.get("HEXAGON_TOOLCHAIN"):
+            message += (
+                " The environment variable HEXAGON_TOOLCHAIN is unset. Please export "
+                + "HEXAGON_TOOLCHAIN in your environment, so that ${HEXAGON_TOOLCHAIN}/bin/"
+                + "hexagon-link exists."
+            )
         else:
-            message += ' Please verify the value of the HEXAGON_LINKER environment variable ' + \
-                '(currently set to "' + hexagon_toolchain_root + '").'
+            message += (
+                " Please verify the value of the HEXAGON_LINKER environment variable "
+                + '(currently set to "'
+                + hexagon_toolchain_root
+                + '").'
+            )
         raise Exception(message)
 
-    libpath = os.path.join(
-        hexagon_toolchain_root, 'target', 'hexagon', 'lib', 'v66', 'G0')
+    libpath = os.path.join(hexagon_toolchain_root, "target", "hexagon", "lib", "v66", "G0")
     cc.create_shared(
-        so_name, objs,
+        so_name,
+        objs,
         # pylint: disable=bad-whitespace
-        options = ['-Bdynamic', '-shared', '-export-dynamic',
-                   os.path.join(libpath, 'pic', 'libgcc.so')],
-        cc = linker)
+        options=[
+            "-Bdynamic",
+            "-shared",
+            "-export-dynamic",
+            os.path.join(libpath, "pic", "libgcc.so"),
+        ],
+        cc=linker,
+    )
     return 0
 
 
 ### VTCM
 
-vtcm_size = 4*1024*1024  # pylint: disable=invalid-name
-@register_func('tvm.info.mem.local.vtcm')
+vtcm_size = 4 * 1024 * 1024  # pylint: disable=invalid-name
+
+
+@register_func("tvm.info.mem.local.vtcm")
 def mem_info_vtcm():
     # pylint: disable=bad-whitespace
-    return tvm.ir.make_node('MemoryInfo',
-                            unit_bits = 8,
-                            max_num_bits = vtcm_size*8,
-                            max_simd_bits = 128*8,
-                            head_address = tvm.runtime.const(100, 'uint32'))
+    return tvm.ir.make_node(
+        "MemoryInfo",
+        unit_bits=8,
+        max_num_bits=vtcm_size * 8,
+        max_simd_bits=128 * 8,
+        head_address=tvm.runtime.const(100, "uint32"),
+    )
+
 
 def lower_vtcm_(get_alloc, get_free, def_align, func, mod, ctx):  # pylint: disable=unused-argument
+
     """Generic VTCM allocation
 
     Parameters
@@ -154,9 +176,9 @@ def lower_vtcm_(get_alloc, get_free, def_align, func, mod, ctx):  # pylint: disa
     def visit(stmt):
         """Collect information about VTCM buffers and their alignments."""
         if isinstance(stmt, tvm.tir.AttrStmt):
-            if stmt.attr_key == 'storage_scope' and stmt.value == 'local.vtcm':
+            if stmt.attr_key == "storage_scope" and stmt.value == "local.vtcm":
                 vtcm_buffers.append(stmt.node)
-            elif stmt.attr_key == 'storage_alignment':
+            elif stmt.attr_key == "storage_alignment":
                 if not stmt.node in alignments:
                     alignments[stmt.node] = []
                 alignments[stmt.node].append(stmt.value)
@@ -164,27 +186,33 @@ def lower_vtcm_(get_alloc, get_free, def_align, func, mod, ctx):  # pylint: disa
     def mutate(stmt):
         """Insert calls to VTCM allocation and deallocation routines."""
         if isinstance(stmt, tvm.tir.AttrStmt):
-            if stmt.attr_key == 'storage_scope' and stmt.value == 'local.vtcm':
+            if stmt.attr_key == "storage_scope" and stmt.value == "local.vtcm":
                 vtcm_buffers.pop()
-            elif stmt.attr_key == 'storage_alignment':
+            elif stmt.attr_key == "storage_alignment":
                 alignments[stmt.node].pop()
             return stmt
         if isinstance(stmt, tvm.tir.Allocate):
             var = stmt.buffer_var
             if var in vtcm_buffers:
-                is_null = tvm.tir.call_intrin('bool', tvm.ir.Op.get('tir.isnullptr'), var)
-                throw_error = \
-                    tvm.tir.call_intrin('int32', tvm.ir.Op.get('tir.tvm_throw_last_error'))
+                is_null = tvm.tir.call_intrin("bool", tvm.ir.Op.get("tir.isnullptr"), var)
+                throw_error = tvm.tir.call_intrin(
+                    "int32", tvm.ir.Op.get("tir.tvm_throw_last_error")
+                )
                 body_w_free = tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(get_free(var))])
-                body_w_check = \
-                    tvm.tir.IfThenElse(is_null, tvm.tir.Evaluate(throw_error), body_w_free)
-                return tvm.tir.LetStmt(stmt.buffer_var, get_alloc(stmt, buf_align(var)),
-                                       body_w_check)
+                body_w_check = tvm.tir.IfThenElse(
+                    is_null, tvm.tir.Evaluate(throw_error), body_w_free
+                )
+                return tvm.tir.LetStmt(
+                    stmt.buffer_var, get_alloc(stmt, buf_align(var)), body_w_check
+                )
             return stmt
         raise ValueError("Wrong argument type (" + type(stmt) + ") to 'mutate'")
 
-    f = func.with_body(tvm.tir.stmt_functor.ir_transform(func.body, visit, mutate,
-                                                         ['tir.Allocate', 'tir.AttrStmt']))
+    f = func.with_body(
+        tvm.tir.stmt_functor.ir_transform(
+            func.body, visit, mutate, ["tir.Allocate", "tir.AttrStmt"]
+        )
+    )
     return f
 
 
@@ -193,19 +221,26 @@ def ir_lower_vtcm():
 
     VTCM memory has to be allocated using special functions.
     """
+
     def get_alloc(stmt, align):
         assert isinstance(stmt, tvm.tir.Allocate)
-        return tvm.tir.call_extern('handle', 'HexagonBackendAllocateVTCM',
-                                   ft.reduce(lambda x, y: x*y, stmt.extents, 1), align)
+        return tvm.tir.call_extern(
+            "handle",
+            "HexagonBackendAllocateVTCM",
+            ft.reduce(lambda x, y: x * y, stmt.extents, 1),
+            align,
+        )
+
     def get_free(var):
-        return tvm.tir.call_extern('handle', 'HexagonBackendFreeVTCM', var)
+        return tvm.tir.call_extern("handle", "HexagonBackendFreeVTCM", var)
 
     # pylint: disable=bad-whitespace
-    @tvm.tir.transform.prim_func_pass(opt_level = 0, name = "Lower VTCM pass")
+    @tvm.tir.transform.prim_func_pass(opt_level=0, name="Lower VTCM pass")
     def transform(func, mod, ctx):
         return lower_vtcm_(get_alloc, get_free, 2048, func, mod, ctx)
 
     return transform
 
+
 def ir_lower_vtcm_pass():
     return [(3, ir_lower_vtcm())]
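
As the module comments describe, the default Hexagon linker can be replaced through the registration hook; a sketch with a hypothetical toolchain path:

    from tvm.contrib import hexagon

    @hexagon.register_linker
    def custom_hexagon_link():
        # overrides the ${HEXAGON_TOOLCHAIN}/bin/hexagon-link default
        return "/opt/hexagon-toolchain/bin/hexagon-link"
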
index 04e35de..112fc32 100644 (file)
@@ -42,17 +42,19 @@ def _get_np_int32_array_handle(arr):
     return ctypes.cast(ptr, ctypes.c_void_p)
 
 
-def conv2d_forward(x,
-                   w,
-                   stride_h=1,
-                   stride_w=1,
-                   pad_h=0,
-                   pad_w=0,
-                   dilation_h=1,
-                   dilation_w=1,
-                   conv_mode=0,
-                   data_type=1,
-                   group_count=1):
+def conv2d_forward(
+    x,
+    w,
+    stride_h=1,
+    stride_w=1,
+    pad_h=0,
+    pad_w=0,
+    dilation_h=1,
+    dilation_w=1,
+    conv_mode=0,
+    data_type=1,
+    group_count=1,
+):
     """Create an extern op that compute 2D convolution with MIOpen
 
     Parameters
@@ -86,34 +88,37 @@ def conv2d_forward(x,
     y: Tensor
         The result tensor
     """
-    assert (0 <= conv_mode <= 2), "0: miopenConvolution / 1: miopenTranspose / 2: miopenGroupConv"
+    assert 0 <= conv_mode <= 2, "0: miopenConvolution / 1: miopenTranspose / 2: miopenGroupConv"
     if group_count > 1:
         conv_mode = 2
     oshape = np.zeros((len(x.shape)), dtype=np.int32)
     xshape = x.shape
     wshape = w.shape
     setup_func = tvm._ffi.get_global_func("tvm.contrib.miopen.conv2d.setup")
-    algo = setup_func(conv_mode,
-                      data_type,
-                      pad_h,
-                      pad_w,
-                      stride_h,
-                      stride_w,
-                      dilation_h,
-                      dilation_w,
-                      xshape[0].value,
-                      xshape[1].value,
-                      xshape[2].value,
-                      xshape[3].value,
-                      wshape[0].value,
-                      wshape[1].value,
-                      wshape[2].value,
-                      wshape[3].value,
-                      group_count,
-                      _get_np_int32_array_handle(oshape))
+    algo = setup_func(
+        conv_mode,
+        data_type,
+        pad_h,
+        pad_w,
+        stride_h,
+        stride_w,
+        dilation_h,
+        dilation_w,
+        xshape[0].value,
+        xshape[1].value,
+        xshape[2].value,
+        xshape[3].value,
+        wshape[0].value,
+        wshape[1].value,
+        wshape[2].value,
+        wshape[3].value,
+        group_count,
+        _get_np_int32_array_handle(oshape),
+    )
 
     return te.extern(
-        list(oshape), [x, w],
+        list(oshape),
+        [x, w],
         lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.miopen.conv2d.forward",
             conv_mode,
@@ -127,4 +132,7 @@ def conv2d_forward(x,
             algo,
             ins[0],
             ins[1],
-            outs[0]), name="y")
+            outs[0],
+        ),
+        name="y",
+    )
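
A minimal sketch of the MIOpen wrapper (shapes are illustrative; constructing the op already calls the setup packed function, so a ROCm build of TVM with MIOpen is required):

    from tvm import te
    from tvm.contrib import miopen

    # NCHW data and OIHW weight
    x = te.placeholder((1, 3, 32, 32), name="x")
    w = te.placeholder((8, 3, 3, 3), name="w")
    y = miopen.conv2d_forward(x, w, stride_h=1, stride_w=1, pad_h=1, pad_w=1)
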
index 175db44..c6e3406 100644 (file)
@@ -48,7 +48,7 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
             "tvm.contrib.mkl.matmul", ins[0], ins[1], outs[0], transa, transb
         ),
         name="C",
-        **kwargs
+        **kwargs,
     )
 
 
@@ -81,7 +81,7 @@ def matmul_u8s8s32(lhs, rhs, transa=False, transb=False, **kwargs):
             "tvm.contrib.mkl.matmul_u8s8s32", ins[0], ins[1], outs[0], transa, transb
         ),
         name="C",
-        **kwargs
+        **kwargs,
     )
 
 
@@ -122,5 +122,5 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs
             transb,
         ),
         name="C",
-        **kwargs
+        **kwargs,
     )
index 48ba14c..04af300 100644 (file)
@@ -48,5 +48,5 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
             "tvm.contrib.mkl.matmul", ins[0], ins[1], outs[0], transa, transb
         ),
         name="C",
-        **kwargs
+        **kwargs,
     )
index 8f310b0..eb8ad77 100644 (file)
@@ -21,6 +21,7 @@ from tvm import te
 
 # pylint: disable=C0103,W0612
 
+
 def matmul(lhs, rhs, transa=False, transb=False):
     """Create an extern op that compute matrix mult of A and rhs with CrhsLAS
 
@@ -49,12 +50,16 @@ def matmul(lhs, rhs, transa=False, transb=False):
     if transb:
         n = c
     return te.extern(
-        (m, n), [lhs, rhs],
+        (m, n),
+        [lhs, rhs],
         lambda ins, outs: tvm.tir.call_packed(
-            "tvm.contrib.mps.matmul", ins[0], ins[1], outs[0], transa, transb),
-        name="C")
+            "tvm.contrib.mps.matmul", ins[0], ins[1], outs[0], transa, transb
+        ),
+        name="C",
+    )
+
 
-def conv2d(data, weight, pad='SAME', stride=1):
+def conv2d(data, weight, pad="SAME", stride=1):
     """
     Create an extern op that computes data * weight and returns the result in output
 
@@ -76,12 +81,15 @@ def conv2d(data, weight, pad='SAME', stride=1):
     """
     n, hi, wi, ci = data.shape
     co, kh, kw, ciw = weight.shape
-    padding = 0 if pad == 'SAME' else 1
+    padding = 0 if pad == "SAME" else 1
     ho = hi // stride
     wo = wi // stride
 
     return te.extern(
-        (n, ho, wo, co), [data, weight],
+        (n, ho, wo, co),
+        [data, weight],
         lambda ins, outs: tvm.tir.call_packed(
-            "tvm.contrib.mps.conv2d", ins[0], ins[1], outs[0], padding, stride),
-        name="C")
+            "tvm.contrib.mps.conv2d", ins[0], ins[1], outs[0], padding, stride
+        ),
+        name="C",
+    )
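
A usage sketch for the MPS wrappers (shapes are illustrative; execution only works on macOS builds with Metal Performance Shaders enabled):

    from tvm import te
    from tvm.contrib import mps

    A = te.placeholder((1024, 256), name="A")
    B = te.placeholder((256, 512), name="B")
    C = mps.matmul(A, B)  # (1024, 512) result produced by the MPS packed func
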
index 3f05b70..6e551df 100644 (file)
@@ -51,6 +51,7 @@ def to_mxnet_func(func, const_loc=None):
     # only import mxnet when wrap get called.
     # pylint: disable=import-self, import-outside-toplevel
     import mxnet
+
     if isinstance(func, Module):
         func = func.entry_func
 
@@ -58,13 +59,14 @@ def to_mxnet_func(func, const_loc=None):
         """Get MXNet bridge function"""
         if not mxnet.base._LIB.MXTVMBridge:
             raise RuntimeError(
-                "MXTVMBridge not exist in mxnet package,"
-                " please update to latest version")
+                "MXTVMBridge not exist in mxnet package," " please update to latest version"
+            )
 
         fdict = tvm._ffi.registry.extract_ext_funcs(mxnet.base._LIB.MXTVMBridge)
         ret = fdict["WrapAsyncCall"]
         ret.is_global = True
         return ret
+
     global _wrap_async
 
     if _wrap_async is None:
@@ -73,5 +75,4 @@ def to_mxnet_func(func, const_loc=None):
         tvm._ffi.registry.register_extension(mxnet.nd.NDArray)
 
     const_loc = const_loc if const_loc else []
-    return _wrap_async(func, tvm.runtime._ffi_api.TVMSetStream,
-                       len(const_loc), *const_loc)
+    return _wrap_async(func, tvm.runtime._ffi_api.TVMSetStream, len(const_loc), *const_loc)
index 66facae..275d40f 100644 (file)
@@ -23,9 +23,8 @@ import os
 from .._ffi.base import py_str
 from .cc import get_target_by_dump_machine
 
-def create_shared(output,
-                  objects,
-                  options=None):
+
+def create_shared(output, objects, options=None):
     """Create shared library.
 
     Parameters
@@ -40,8 +39,9 @@ def create_shared(output,
         The additional options.
     """
     if "TVM_NDK_CC" not in os.environ:
-        raise RuntimeError("Require environment variable TVM_NDK_CC"
-                           " to be the NDK standalone compiler")
+        raise RuntimeError(
+            "Require environment variable TVM_NDK_CC" " to be the NDK standalone compiler"
+        )
     compiler = os.environ["TVM_NDK_CC"]
     cmd = [compiler]
     cmd += ["-o", output]
@@ -54,10 +54,7 @@ def create_shared(output,
     options = options if options else ["-shared", "-fPIC", "-lm"]
     cmd += options
 
-    proc = subprocess.Popen(
-        cmd,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT)
+    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
     (out, _) = proc.communicate()
 
     if proc.returncode != 0:
@@ -68,5 +65,6 @@ def create_shared(output,
 
 # assign output format
 create_shared.output_format = "so"
-create_shared.get_target_triple = get_target_by_dump_machine(
-    os.environ["TVM_NDK_CC"]) if "TVM_NDK_CC" in os.environ else None
+create_shared.get_target_triple = (
+    get_target_by_dump_machine(os.environ["TVM_NDK_CC"]) if "TVM_NDK_CC" in os.environ else None
+)
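
The NDK helper is normally passed to export_library; a sketch assuming lib was built for an Android target and TVM_NDK_CC points at an NDK standalone compiler:

    import os
    from tvm.contrib import ndk

    assert "TVM_NDK_CC" in os.environ
    lib.export_library("deploy.so", ndk.create_shared)
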
index 1ce1dcc..010bef5 100644 (file)
@@ -26,6 +26,7 @@ def is_available():
     """
     return _initialize() == 0
 
+
 def fully_connected_inference(lhs, rhs, nthreads=1):
     """Create an extern op that compute fully connected of 1D tensor lhs and
     2D tensor rhs with nnpack.
@@ -44,10 +45,13 @@ def fully_connected_inference(lhs, rhs, nthreads=1):
     """
     m = rhs.shape[0]
     return te.extern(
-        (m, ), [lhs, rhs],
+        (m,),
+        [lhs, rhs],
         lambda ins, outs: tvm.tir.call_packed(
-            "tvm.contrib.nnpack.fully_connected_inference",
-            ins[0], ins[1], outs[0], nthreads), name="C")
+            "tvm.contrib.nnpack.fully_connected_inference", ins[0], ins[1], outs[0], nthreads
+        ),
+        name="C",
+    )
 
 
 class ConvolutionAlgorithm:
@@ -66,8 +70,8 @@ class ConvolutionTransformStrategy:
 
 
 def convolution_inference(
-        data, kernel, bias, padding, stride, nthreads=1,
-        algorithm=ConvolutionAlgorithm.AUTO):
+    data, kernel, bias, padding, stride, nthreads=1, algorithm=ConvolutionAlgorithm.AUTO
+):
     """Create an extern op to do inference convolution of 4D tensor data and
     4D tensor kernel and 1D tensor bias with nnpack.
 
@@ -101,10 +105,8 @@ def convolution_inference(
     batch, _, input_height, input_width = data.shape
     output_channels, _, kernel_height, kernel_width = kernel.shape
     idxdiv = te.indexdiv
-    output_height = idxdiv(
-        input_height + padding[0] + padding[1] - kernel_height, stride[0]) + 1
-    output_width = idxdiv(
-        input_width + padding[0] + padding[1] - kernel_width, stride[1]) + 1
+    output_height = idxdiv(input_height + padding[0] + padding[1] - kernel_height, stride[0]) + 1
+    output_width = idxdiv(input_width + padding[0] + padding[1] - kernel_width, stride[1]) + 1
 
     return te.extern(
         (batch, output_channels, output_height, output_width),
@@ -114,12 +116,23 @@ def convolution_inference(
             ins[0],
             ins[1],
             ins[2] if bias is not None else 0,
-            outs[0], padding[0], padding[1], padding[2], padding[3],
-            stride[0], stride[1], nthreads, algorithm), name="C")
+            outs[0],
+            padding[0],
+            padding[1],
+            padding[2],
+            padding[3],
+            stride[0],
+            stride[1],
+            nthreads,
+            algorithm,
+        ),
+        name="C",
+    )
+
 
 def convolution_inference_without_weight_transform(
-        data, transformed_kernel, bias, padding, stride, nthreads=1,
-        algorithm=ConvolutionAlgorithm.AUTO):
+    data, transformed_kernel, bias, padding, stride, nthreads=1, algorithm=ConvolutionAlgorithm.AUTO
+):
     """Create an extern op to do inference convolution of 4D tensor data and
     4D pre-transformed tensor kernel and 1D tensor bias with nnpack.
 
@@ -148,8 +161,7 @@ def convolution_inference_without_weight_transform(
         of FP32 elements.
     """
 
-    assert algorithm in (ConvolutionAlgorithm.WT_8x8,
-                         ConvolutionAlgorithm.WT_8x8_FP16)
+    assert algorithm in (ConvolutionAlgorithm.WT_8x8, ConvolutionAlgorithm.WT_8x8_FP16)
     assert isinstance(padding, list) and len(padding) == 4
     assert isinstance(stride, list) and len(stride) == 2
     batch, _, input_height, input_width = data.shape
@@ -167,13 +179,24 @@ def convolution_inference_without_weight_transform(
             ins[0],
             ins[1],
             ins[2] if bias is not None else 0,
-            outs[0], padding[0], padding[1], padding[2], padding[3],
-            stride[0], stride[1], nthreads, algorithm), name="C", dtype='float32')
+            outs[0],
+            padding[0],
+            padding[1],
+            padding[2],
+            padding[3],
+            stride[0],
+            stride[1],
+            nthreads,
+            algorithm,
+        ),
+        name="C",
+        dtype="float32",
+    )
+
 
 def convolution_inference_weight_transform(
-        kernel, nthreads=1,
-        algorithm=ConvolutionAlgorithm.AUTO,
-        dtype='float32'):
+    kernel, nthreads=1, algorithm=ConvolutionAlgorithm.AUTO, dtype="float32"
+):
     """Create an extern op to do inference convolution of 3D tensor data and
     4D tensor kernel and 1D tensor bias with nnpack.
 
@@ -199,6 +222,14 @@ def convolution_inference_weight_transform(
         [kernel],
         lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.nnpack.convolution_inference_weight_transform",
-            ins[0], outs[0], nthreads, algorithm), name="transform_kernel", dtype=dtype)
+            ins[0],
+            outs[0],
+            nthreads,
+            algorithm,
+        ),
+        name="transform_kernel",
+        dtype=dtype,
+    )
+
 
 tvm._ffi._init_api("tvm.contrib.nnpack")
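
A sketch of the NNPACK convolution wrapper (shapes are illustrative and TVM must be built with NNPACK support for is_available to succeed):

    from tvm import te
    from tvm.contrib import nnpack

    if nnpack.is_available():
        data = te.placeholder((1, 3, 32, 32), name="data")
        kernel = te.placeholder((8, 3, 3, 3), name="kernel")
        bias = te.placeholder((8,), name="bias")
        # padding is [top, bottom, left, right], stride is [h, w]
        out = nnpack.convolution_inference(data, kernel, bias,
                                           padding=[1, 1, 1, 1], stride=[1, 1])
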
index 755493b..f958c1f 100644 (file)
@@ -28,11 +28,8 @@ from tvm.runtime import ndarray as nd
 from . import util
 from .._ffi.base import py_str
 
-def compile_cuda(code,
-                 target="ptx",
-                 arch=None,
-                 options=None,
-                 path_target=None):
+
+def compile_cuda(code, target="ptx", arch=None, options=None, path_target=None):
     """Compile cuda code with NVCC from env.
 
     Parameters
@@ -69,7 +66,7 @@ def compile_cuda(code,
     if arch is None:
         if nd.gpu(0).exist:
             # auto detect the compute arch argument
-            arch = "sm_" + "".join(nd.gpu(0).compute_version.split('.'))
+            arch = "sm_" + "".join(nd.gpu(0).compute_version.split("."))
         else:
             raise ValueError("arch(sm_xy) is not passed, and we cannot detect it from env")
 
@@ -92,8 +89,7 @@ def compile_cuda(code,
     cmd += ["-o", file_target]
     cmd += [temp_code]
 
-    proc = subprocess.Popen(
-        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
 
     (out, _) = proc.communicate()
 
@@ -105,10 +101,10 @@ def compile_cuda(code,
 
     data = bytearray(open(file_target, "rb").read())
     if not data:
-        raise RuntimeError(
-            "Compilation error: empty result is generated")
+        raise RuntimeError("Compilation error: empty result is generated")
     return data
 
+
 def find_cuda_path():
     """Utility function to find cuda path
 
@@ -120,8 +116,7 @@ def find_cuda_path():
     if "CUDA_PATH" in os.environ:
         return os.environ["CUDA_PATH"]
     cmd = ["which", "nvcc"]
-    proc = subprocess.Popen(
-        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
     (out, _) = proc.communicate()
     out = py_str(out)
     if proc.returncode == 0:
@@ -151,7 +146,7 @@ def get_cuda_version(cuda_path):
         version_file_path = os.path.join(cuda_path, "lib", "cuda", "version.txt")
     try:
         with open(version_file_path) as f:
-            version_str = f.readline().replace('\n', '').replace('\r', '')
+            version_str = f.readline().replace("\n", "").replace("\r", "")
             return float(version_str.split(" ")[2][:2])
     except:
         raise RuntimeError("Cannot read cuda version file")
@@ -218,7 +213,7 @@ def parse_compute_version(compute_version):
     minor : int
         minor version number
     """
-    split_ver = compute_version.split('.')
+    split_ver = compute_version.split(".")
     try:
         major = int(split_ver[0])
         minor = int(split_ver[1])
@@ -245,6 +240,7 @@ def have_fp16(compute_version):
 
     return False
 
+
 def have_int8(compute_version):
     """Either int8 support is provided in the compute capability or not
 
@@ -259,6 +255,7 @@ def have_int8(compute_version):
 
     return False
 
+
 def have_tensorcore(compute_version):
     """Either TensorCore support is provided in the compute capability or not
 
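
A short sketch of the NVCC helpers (the capability string and arch are illustrative; compile_cuda requires nvcc on PATH):

    from tvm.contrib import nvcc

    print(nvcc.have_tensorcore("7.0"))  # capability helpers take a "major.minor" string
    ptx = nvcc.compile_cuda('extern "C" __global__ void dummy() {}',
                            target="ptx", arch="sm_70")
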
index 98a51b8..d9dd536 100644 (file)
@@ -35,10 +35,20 @@ def _convert_to_remote(func, remote):
     return func
 
 
-def measure_bandwidth_sum(total_item, item_per_thread, stride,
-                          base_type, bits, lanes,
-                          target, target_host, remote, ctx, n_times):
-    """ measure memory bandwidth of gpu by product reduction for a given type
+def measure_bandwidth_sum(
+    total_item,
+    item_per_thread,
+    stride,
+    base_type,
+    bits,
+    lanes,
+    target,
+    target_host,
+    remote,
+    ctx,
+    n_times,
+):
+    """measure memory bandwidth of gpu by product reduction for a given type
 
     The IR for measurement is
 
@@ -85,10 +95,10 @@ def measure_bandwidth_sum(total_item, item_per_thread, stride,
     k = te.reduce_axis((0, m), name="k")
 
     x = te.placeholder((n,), dtype=dtype, name="x")
-    op = te.comm_reducer(
-        lambda x, y: x*y, lambda t: tvm.tir.const(1, dtype=t), name="sum")
-    y = te.compute((n // m,),
-                   lambda i: op(x[i // stride * stride * m + i % stride + k * stride], axis=k))
+    op = te.comm_reducer(lambda x, y: x * y, lambda t: tvm.tir.const(1, dtype=t), name="sum")
+    y = te.compute(
+        (n // m,), lambda i: op(x[i // stride * stride * m + i % stride + k * stride], axis=k)
+    )
     s = te.create_schedule(y.op)
 
     yo, yi = s[y].split(y.op.axis[0], target.max_num_threads)
@@ -112,9 +122,10 @@ def measure_bandwidth_sum(total_item, item_per_thread, stride,
     return 1.0 * (total_item * bits / 8) / 1e9 / time
 
 
-def measure_bandwidth_all_types(total_item, item_per_thread, n_times,
-                                target, target_host, remote, ctx, verbose=True):
-    """ measure memory bandwidth for all types
+def measure_bandwidth_all_types(
+    total_item, item_per_thread, n_times, target, target_host, remote, ctx, verbose=True
+):
+    """measure memory bandwidth for all types
 
     Parameters
     ----------
@@ -149,21 +160,31 @@ def measure_bandwidth_all_types(total_item, item_per_thread, n_times,
                 max_speed = -1e9
                 # try different strides
                 for stride in [max_threads, total_item // (lanes * item_per_thread)]:
-                    speed = measure_bandwidth_sum(total_item, item_per_thread, stride,
-                                                  base_type, bits, lanes, target,
-                                                  target_host, remote, ctx, n_times)
+                    speed = measure_bandwidth_sum(
+                        total_item,
+                        item_per_thread,
+                        stride,
+                        base_type,
+                        bits,
+                        lanes,
+                        target,
+                        target_host,
+                        remote,
+                        ctx,
+                        n_times,
+                    )
                     max_speed = max(max_speed, speed)
                 type_name = base_type + str(bits)
                 result.append(["%sx%d" % (type_name, lanes), max_speed])
                 if verbose:
-                    logging.info("\t%-10s %.2f GBPS",
-                                 result[-1][0], result[-1][1])
+                    logging.info("\t%-10s %.2f GBPS", result[-1][0], result[-1][1])
     return result
 
 
-def measure_compute_mad(total_item, item_per_thread, base_type, bits, lanes,
-                        target, target_host, remote, ctx, n_times):
-    """ measure peak compute speed by computing mad for a type
+def measure_compute_mad(
+    total_item, item_per_thread, base_type, bits, lanes, target, target_host, remote, ctx, n_times
+):
+    """measure peak compute speed by computing mad for a type
 
     The IR for measurement is
 
@@ -224,16 +245,19 @@ def measure_compute_mad(total_item, item_per_thread, base_type, bits, lanes,
 
         idx = bx.var * max_threads + tx.var
 
-        a = ib.allocate(dtype, (1), name='a', scope='local')
-        b = ib.allocate(dtype, (1), name='b', scope='local')
+        a = ib.allocate(dtype, (1), name="a", scope="local")
+        b = ib.allocate(dtype, (1), name="b", scope="local")
 
         a[0] = outs[0].vload(idx, dtype)
         b[0] = outs[0].vload(idx, dtype)
 
-        if base_type.find('float') != -1:
+        if base_type.find("float") != -1:
+
             def mad_func(x, y):
                 return x * x + y
+
         else:
+
             def mad_func(x, y):
                 return y * y + x
 
@@ -260,9 +284,10 @@ def measure_compute_mad(total_item, item_per_thread, base_type, bits, lanes,
     return 1.0 * (n * item_per_thread) / 1e9 / time
 
 
-def measure_compute_all_types(total_item, item_per_thread, n_times,
-                              target, target_host, remote, ctx, verbose=True):
-    """ measure peak flops for all types
+def measure_compute_all_types(
+    total_item, item_per_thread, n_times, target, target_host, remote, ctx, verbose=True
+):
+    """measure peak flops for all types
 
     Parameters
     ----------
@@ -292,14 +317,23 @@ def measure_compute_all_types(total_item, item_per_thread, n_times,
     for base_type in ["float", "int"]:
         for bits in [16, 32, 64]:
             for lanes in [1, 2, 4, 8, 16]:
-                if base_type == 'int' and bits != 32:  # only measure int32
+                if base_type == "int" and bits != 32:  # only measure int32
                     continue
 
                 max_speed = -1e9
-                for per_thread in [item_per_thread//2, item_per_thread, item_per_thread*2]:
-                    speed = measure_compute_mad(total_item, per_thread,
-                                                base_type, bits, lanes, target,
-                                                target_host, remote, ctx, n_times)
+                for per_thread in [item_per_thread // 2, item_per_thread, item_per_thread * 2]:
+                    speed = measure_compute_mad(
+                        total_item,
+                        per_thread,
+                        base_type,
+                        bits,
+                        lanes,
+                        target,
+                        target_host,
+                        remote,
+                        ctx,
+                        n_times,
+                    )
                     max_speed = max(max_speed, speed)
                 type_name = base_type + str(bits)
                 result.append(["%sx%d" % (type_name, lanes), max_speed])
@@ -307,8 +341,7 @@ def measure_compute_all_types(total_item, item_per_thread, n_times,
                 unit = "GFLOPS" if base_type == "float" else "GIOPS"
 
                 if verbose:
-                    logging.info("\t%-10s %.2f %s",
-                                 result[-1][0], result[-1][1], unit)
+                    logging.info("\t%-10s %.2f %s", result[-1][0], result[-1][1], unit)
 
     return result
 
@@ -344,9 +377,11 @@ def measure_peak_all(target, target_host, host, port):
         raise RuntimeError("Unsupported target")
 
     logging.info("========== measure memory bandwidth ==========")
-    measure_bandwidth_all_types(bandwidth_total_item, bandwidth_item_per_thread,
-                                n_times, target, target_host, remote, ctx)
+    measure_bandwidth_all_types(
+        bandwidth_total_item, bandwidth_item_per_thread, n_times, target, target_host, remote, ctx
+    )
 
     logging.info("========== measure peak compute ==========")
-    measure_compute_all_types(compute_total_item, compute_item_per_thread,
-                              n_times, target, target_host, remote, ctx)
+    measure_compute_all_types(
+        compute_total_item, compute_item_per_thread, n_times, target, target_host, remote, ctx
+    )
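
A minimal sketch of how the peak-measurement entry point above might be invoked, assuming this file is the tvm.contrib.peak module and that an RPC server for the device is already listening on the given address (the target strings, host, and port below are placeholder values):

    import logging
    from tvm.contrib import peak  # assumed module path for the file above

    logging.basicConfig(level=logging.INFO)
    # Runs the bandwidth sweep and the compute sweep shown above and logs the
    # measured bandwidth and GFLOPS/GIOPS per dtype/lane combination.
    peak.measure_peak_all(
        target="opencl",   # placeholder; unsupported targets raise RuntimeError
        target_host="llvm",
        host="127.0.0.1",  # placeholder RPC server address
        port=9090,
    )
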
index 6e72aba..d875046 100644
@@ -27,6 +27,7 @@ try:
 except ImportError:
     import pickle
 
+
 class Cache(object):
     """A cache object for result cache.
 
@@ -37,7 +38,9 @@ class Cache(object):
     save_at_exit: bool
         Whether to save the cache to a file when the program exits
     """
+
     cache_by_key = {}
+
     def __init__(self, key, save_at_exit):
         cache_dir = ".pkl_memoize_py{0}".format(sys.version_info[0])
         try:
@@ -63,6 +66,7 @@ class Cache(object):
             with open(self.path, "wb") as out_file:
                 pickle.dump(self.cache, out_file, pickle.HIGHEST_PROTOCOL)
 
+
 @atexit.register
 def _atexit():
     """Save handler."""
@@ -86,6 +90,7 @@ def memoize(key, save_at_exit=False):
     fmemoize : function
         The decorator function to perform memoization.
     """
+
     def _register(f):
         """Registration function"""
         allow_types = (string_types, int, float, tuple)
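
The memoize decorator shown here caches a function's return value on disk, keyed by the given string plus the call arguments (which must be strings, ints, floats, or tuples per the allow_types check). A small usage sketch, assuming this is the pickle-based helper at tvm.contrib.pickle_memoize:

    from tvm.contrib.pickle_memoize import memoize  # assumed module path

    @memoize("example.squares", save_at_exit=True)
    def squares(n):
        # Expensive work goes here; the result is pickled under .pkl_memoize_py3/
        # and reloaded on later runs with the same argument.
        return [i * i for i in range(n)]

    first = squares(10000)   # computed, then stored in the cache
    second = squares(10000)  # served from the in-memory/on-disk cache
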
index 727b68b..bbc74fc 100644
@@ -20,7 +20,7 @@ from tvm import te
 import tvm._ffi
 
 
-def randint(low, high, size, dtype='int32'):
+def randint(low, high, size, dtype="int32"):
     """Return random integers from low (inclusive) to high (exclusive).
     Return random integers from the "discrete uniform" distribution of the
     specified dtype in the "half-open" interval [low, high).
@@ -37,9 +37,15 @@ def randint(low, high, size, dtype='int32'):
     out : Tensor
         A tensor with specified size and dtype
     """
-    assert 'int' in dtype, "the type of randint output must be int or uint"
-    return te.extern(size, [], lambda ins, outs: tvm.tir.call_packed(
-        "tvm.contrib.random.randint", int(low), int(high), outs[0]), dtype=dtype)
+    assert "int" in dtype, "the type of randint output must be int or uint"
+    return te.extern(
+        size,
+        [],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.random.randint", int(low), int(high), outs[0]
+        ),
+        dtype=dtype,
+    )
 
 
 def uniform(low, high, size):
@@ -66,8 +72,14 @@ def uniform(low, high, size):
     out : Tensor
         A tensor with specified size and dtype.
     """
-    return te.extern(size, [], lambda ins, outs: tvm.tir.call_packed(
-        "tvm.contrib.random.uniform", float(low), float(high), outs[0]), dtype='float32')
+    return te.extern(
+        size,
+        [],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.random.uniform", float(low), float(high), outs[0]
+        ),
+        dtype="float32",
+    )
 
 
 def normal(loc, scale, size):
@@ -90,8 +102,14 @@ def normal(loc, scale, size):
     out : Tensor
         A tensor with specified size and dtype
     """
-    return te.extern(size, [], lambda ins, outs: tvm.tir.call_packed(
-        "tvm.contrib.random.normal", float(loc), float(scale), outs[0]), dtype='float32')
+    return te.extern(
+        size,
+        [],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.random.normal", float(loc), float(scale), outs[0]
+        ),
+        dtype="float32",
+    )
 
 
 tvm._ffi._init_api("tvm.contrib.random")
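
Each wrapper above returns a te.extern tensor whose body is a packed call into the random runtime, so it can be scheduled and built like any other operator. A minimal sketch (assumes TVM was built with the random contrib runtime enabled; the shape and bounds are arbitrary example values):

    import numpy as np
    import tvm
    from tvm import te
    from tvm.contrib import random

    m, n = 256, 256
    A = random.randint(-127, 128, size=(m, n), dtype="int32")
    s = te.create_schedule(A.op)
    f = tvm.build(s, [A], target="llvm")

    a = tvm.nd.array(np.zeros((m, n), dtype=A.dtype), tvm.cpu(0))
    f(a)  # fills `a` with integers drawn from [-127, 128)
    assert a.asnumpy().min() >= -127 and a.asnumpy().max() < 128
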
index 86ffaea..03ea2b5 100644
@@ -41,7 +41,10 @@ def matmul(lhs, rhs, transa=False, transb=False):
     n = lhs.shape[1] if transa else lhs.shape[0]
     m = rhs.shape[0] if transb else rhs.shape[1]
     return te.extern(
-        (n, m), [lhs, rhs],
+        (n, m),
+        [lhs, rhs],
         lambda ins, outs: tvm.tir.call_packed(
-            "tvm.contrib.rocblas.matmul",
-            ins[0], ins[1], outs[0], transa, transb), name="C")
+            "tvm.contrib.rocblas.matmul", ins[0], ins[1], outs[0], transa, transb
+        ),
+        name="C",
+    )
index 7d4b4a2..7b222f3 100644
@@ -54,8 +54,7 @@ def find_lld(required=True):
     valid_list = [util.which(x) for x in lld_list]
     valid_list = [x for x in valid_list if x]
     if not valid_list and required:
-        raise RuntimeError(
-            "cannot find ld.lld, candidates are: " + str(lld_list))
+        raise RuntimeError("cannot find ld.lld, candidates are: " + str(lld_list))
     return valid_list
 
 
@@ -75,10 +74,7 @@ def rocm_link(in_file, out_file, lld=None):
         we will try to guess the matched clang version.
     """
     args = [lld if lld is not None else find_lld()[0], "-shared", in_file, "-o", out_file]
-    proc = subprocess.Popen(
-        args,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT)
+    proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
     (out, _) = proc.communicate()
 
     if proc.returncode != 0:
@@ -110,6 +106,7 @@ def callback_rocm_link(obj_bin):
     cobj_bin = bytearray(open(tmp_cobj, "rb").read())
     return cobj_bin
 
+
 @tvm._ffi.register_func("tvm_callback_rocm_bitcode_path")
 def callback_rocm_bitcode_path(rocdl_dir="/opt/rocm/lib/"):
     """Utility function to find ROCm device library bitcodes
@@ -137,7 +134,7 @@ def callback_rocm_bitcode_path(rocdl_dir="/opt/rocm/lib/"):
         "oclc_isa_version_906.amdgcn.bc",
         "oclc_unsafe_math_off.amdgcn.bc",
         "oclc_unsafe_math_on.amdgcn.bc",
-        "oclc_wavefrontsize64_on.amdgcn.bc"
+        "oclc_wavefrontsize64_on.amdgcn.bc",
     ]
     paths = [join(rocdl_dir, bitcode) for bitcode in bitcode_files]
     return tvm.runtime.convert([path for path in paths if exists(path)])
index c57efb6..acdd3df 100644
@@ -22,4 +22,5 @@ from ..rpc import Server, RPCSession, LocalSession, TrackerSession, connect, con
 
 warnings.warn(
     "Please use tvm.rpc instead of tvm.conrtib.rpc. tvm.contrib.rpc is going to be removed in 0.5",
-    DeprecationWarning)
+    DeprecationWarning,
+)
index 3f9bf43..b88fa4c 100644
@@ -44,10 +44,15 @@ def compile_vhls(kernel_info, device_name):
 
     sdk = os.environ.get("XILINX_SDX", None)
     xocc = os.path.join(sdk, "bin/xocc") if sdk else "xocc"
-    target = os.environ.get("XCL_TARGET",
-                            "sw_emu" if os.environ.get("XCL_EMULATION_MODE") else "hw")
-    advanced_params = ["--xp", "param:compiler.preserveHlsOutput=1",
-                       "--xp", "param:compiler.generateExtraRunData=true"]
+    target = os.environ.get(
+        "XCL_TARGET", "sw_emu" if os.environ.get("XCL_EMULATION_MODE") else "hw"
+    )
+    advanced_params = [
+        "--xp",
+        "param:compiler.preserveHlsOutput=1",
+        "--xp",
+        "param:compiler.generateExtraRunData=true",
+    ]
     platform = device_name
     if not platform:
         platform = os.environ.get("XCL_PLATFORM", os.environ.get("AWS_PLATFORM"))
@@ -56,7 +61,7 @@ def compile_vhls(kernel_info, device_name):
         raise RuntimeError("No Xilinx device specified.")
 
     tmp_xo_files = []
-    for funcname, code  in kernel_info:
+    for funcname, code in kernel_info:
         funcname = funcname.value
         code = code.value
 
@@ -67,8 +72,11 @@ def compile_vhls(kernel_info, device_name):
             out_file.write(bytes(code))
 
         # build xo
-        args = [xocc, "-c", "-t", target, "--platform", platform, "-o", tmp_xo, "-k", funcname] + \
-               advanced_params + [tmp_cpp]
+        args = (
+            [xocc, "-c", "-t", target, "--platform", platform, "-o", tmp_xo, "-k", funcname]
+            + advanced_params
+            + [tmp_cpp]
+        )
         returncode = subprocess.call(args)
         if returncode != 0:
             raise RuntimeError("Compile error")
@@ -77,8 +85,11 @@ def compile_vhls(kernel_info, device_name):
 
     # build xclbin
     tmp_xclbin = tmp_dir.relpath("output.xclbin")
-    args = [xocc, "-l", "-t", target, "--platform", platform, "-o", tmp_xclbin] + tmp_xo_files + \
-           advanced_params
+    args = (
+        [xocc, "-l", "-t", target, "--platform", platform, "-o", tmp_xclbin]
+        + tmp_xo_files
+        + advanced_params
+    )
     returncode = subprocess.call(args)
     if returncode != 0:
         raise RuntimeError("Link error")
index 77f84b1..c1263c4 100644
@@ -24,10 +24,12 @@ from tvm.te import tensor as _tensor
 
 
 float32 = "float32"
-itype = 'int32'
+itype = "int32"
+
 
 class CSRNDArray(object):
     """Sparse tensor object in CSR format."""
+
     def __init__(self, arg1, ctx=None, shape=None):
         """Construct a sparse matrix in CSR format.
 
@@ -54,43 +56,52 @@ class CSRNDArray(object):
             self.data = _nd.array(data, ctx)
             indices = _np.nonzero(source_array)[1].astype(itype)
             self.indices = _nd.array(indices, ctx)
-            indptr = [0]+_np.apply_along_axis(_np.count_nonzero, axis=1, arr=source_array).tolist()
+            indptr = [0] + _np.apply_along_axis(
+                _np.count_nonzero, axis=1, arr=source_array
+            ).tolist()
             indptr = _np.cumsum(_np.array(indptr, itype)).astype(itype)
             self.indptr = _nd.array(indptr, ctx)
             self.shape = source_array.shape
         else:
-            raise RuntimeError("Construct CSRNDArray with either a tuple (data, indices, indptr) "
-                               "or a numpy.array, can't handle type %s." % (type(arg1),))
-        self.stype = 'csr'
+            raise RuntimeError(
+                "Construct CSRNDArray with either a tuple (data, indices, indptr) "
+                "or a numpy.array, can't handle type %s." % (type(arg1),)
+            )
+        self.stype = "csr"
         self.dtype = self.data.dtype
         assert self.shape is not None
         assert isinstance(self.data, _nd.NDArray)
         assert isinstance(self.indices, _nd.NDArray)
-        assert str(self.indices.dtype) == 'int32' or \
-            str(self.indices.dtype) == 'int64', str(self.indices.dtype)
+        assert str(self.indices.dtype) == "int32" or str(self.indices.dtype) == "int64", str(
+            self.indices.dtype
+        )
         assert isinstance(self.indptr, _nd.NDArray)
-        assert str(self.indptr.dtype) == 'int32' or \
-            str(self.indptr.dtype) == 'int64', str(self.indptr.dtype)
+        assert str(self.indptr.dtype) == "int32" or str(self.indptr.dtype) == "int64", str(
+            self.indptr.dtype
+        )
 
     def asnumpy(self):
         """Construct a full matrix and convert it to numpy array."""
         full = _np.zeros(self.shape, self.dtype)
         ridx = _np.diff(self.indptr.asnumpy())
-        ridx = _np.hstack((_np.ones((v,), itype)*i for i, v in enumerate(ridx)))
+        ridx = _np.hstack((_np.ones((v,), itype) * i for i, v in enumerate(ridx)))
         full[ridx, self.indices.asnumpy().astype(itype)] = self.data.asnumpy()
         return full
 
-def array(source_array, ctx=None, shape=None, stype='csr'):
+
+def array(source_array, ctx=None, shape=None, stype="csr"):
     """Construct a sparse NDArray from numpy.ndarray"""
     ret = None
-    if stype == 'csr':
+    if stype == "csr":
         ret = CSRNDArray(source_array, shape=shape, ctx=ctx)
     else:
-        raise NotImplementedError('stype=%s is not supported yet.' % (stype,))
+        raise NotImplementedError("stype=%s is not supported yet." % (stype,))
     return ret
 
+
 class SparsePlaceholderOp(object):
     """Placeholder class for sparse tensor representations."""
+
     def __init__(self, shape, nonzeros, dtype, name):
         # pylint: disable=unused-argument
         """Contructing a bare bone structure for a sparse matrix
@@ -112,10 +123,12 @@ class SparsePlaceholderOp(object):
         self.shape = shape
         self.dtype = dtype
         self.name = name
-        self.stype = 'unknown'
+        self.stype = "unknown"
+
 
 class CSRPlaceholderOp(SparsePlaceholderOp):
     """Placeholder class for CSR based sparse tensor representation."""
+
     def __init__(self, shape, nonzeros, dtype, name):
         """Contructing a bare bone structure for a csr_matrix
 
@@ -134,14 +147,15 @@ class CSRPlaceholderOp(SparsePlaceholderOp):
             The name hint of the tensor
         """
         SparsePlaceholderOp.__init__(self, shape, nonzeros, dtype, name)
-        self.stype = 'csr'
-        self.data = te.placeholder((nonzeros,), dtype=dtype, name=self.name+'_data')
-        self.indices = te.placeholder((nonzeros,), dtype=itype, name=self.name+'_indices')
-        self.indptr = te.placeholder((self.shape[0]+1,), dtype=itype, name=self.name+'_indptr')
+        self.stype = "csr"
+        self.data = te.placeholder((nonzeros,), dtype=dtype, name=self.name + "_data")
+        self.indices = te.placeholder((nonzeros,), dtype=itype, name=self.name + "_indices")
+        self.indptr = te.placeholder((self.shape[0] + 1,), dtype=itype, name=self.name + "_indptr")
         assert isinstance(self.data, _tensor.Tensor)
         assert isinstance(self.indices, _tensor.Tensor)
         assert isinstance(self.indptr, _tensor.Tensor)
 
+
 def placeholder(shape, nonzeros=None, dtype=None, name="placeholder", stype=None):
     """Construct an empty sparse tensor object.
 
@@ -170,10 +184,10 @@ def placeholder(shape, nonzeros=None, dtype=None, name="placeholder", stype=None
     shape = (shape,) if isinstance(shape, _expr.PrimExpr) else shape
     nonzeros = 0 if nonzeros is None else nonzeros
     dtype = float32 if dtype is None else dtype
-    stype = 'csr' if stype is None else stype
+    stype = "csr" if stype is None else stype
     ret = None
-    if stype == 'csr':
+    if stype == "csr":
         ret = CSRPlaceholderOp(shape=shape, nonzeros=nonzeros, dtype=dtype, name=name)
     else:
-        raise NotImplementedError('stype=%s is not supported yet.' % (stype,))
+        raise NotImplementedError("stype=%s is not supported yet." % (stype,))
     return ret
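
The CSRNDArray above stores a matrix as three dense arrays (data, indices, indptr), and array()/asnumpy() convert between that representation and a dense numpy matrix. A small round-trip sketch, assuming these helpers live at tvm.contrib.sparse:

    import numpy as np
    import tvm
    from tvm.contrib import sparse  # assumed module path for the file above

    dense = np.zeros((4, 6), dtype="float32")
    dense[0, 1] = 3.0
    dense[2, 4] = 7.0

    csr = sparse.array(dense, ctx=tvm.cpu(0))
    print(csr.data.asnumpy())     # non-zero values: [3. 7.]
    print(csr.indices.asnumpy())  # their column indices: [1 4]
    print(csr.indptr.asnumpy())   # row offsets: [0 1 1 2 2]
    np.testing.assert_allclose(csr.asnumpy(), dense)
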
index 38228a9..a5d8471 100644
@@ -20,6 +20,7 @@ import os
 from . import util
 from .._ffi.base import py_str
 
+
 def optimize(spv_bin):
     """Optimize SPIRV using spirv-opt via CLI
 
@@ -45,10 +46,7 @@ def optimize(spv_bin):
     sdk = os.environ.get("VULKAN_SDK", None)
     cmd = os.path.join(sdk, "bin/spirv-opt") if sdk else "spirv-opt"
     args = [cmd, "-O", tmp_in, "-o", tmp_out]
-    proc = subprocess.Popen(
-        args,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT)
+    proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
     (out, _) = proc.communicate()
 
     if proc.returncode != 0:
index f0de616..bcc34a1 100644
@@ -24,6 +24,7 @@ import subprocess
 from . import util
 from .._ffi.base import py_str
 
+
 def tar(output, files):
     """Create tarball containing all files in root.
 
@@ -48,9 +49,7 @@ def tar(output, files):
     cmd += [output]
     cmd += ["-C", temp.temp_dir]
     cmd += temp.listdir()
-    proc = subprocess.Popen(cmd,
-                            stdout=subprocess.PIPE,
-                            stderr=subprocess.STDOUT)
+    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
     (out, _) = proc.communicate()
 
     if proc.returncode != 0:
@@ -58,6 +57,7 @@ def tar(output, files):
         msg += py_str(out)
         raise RuntimeError(msg)
 
+
 # assign output format
 tar.output_format = "tar"
 
@@ -77,9 +77,7 @@ def untar(tar_file, directory):
     cmd += ["-xf"]
     cmd += [tar_file]
     cmd += ["-C", directory]
-    proc = subprocess.Popen(cmd,
-                            stdout=subprocess.PIPE,
-                            stderr=subprocess.STDOUT)
+    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
     (out, _) = proc.communicate()
 
     if proc.returncode != 0:
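
tar() and untar() simply shell out to the system tar binary with the argument lists built above. A tiny usage sketch, assuming the helpers live at tvm.contrib.tar and that a tar executable is on PATH:

    import os
    from tvm.contrib import tar, util  # assumed module paths

    temp = util.tempdir()
    for name in ["a.txt", "b.txt"]:  # create two placeholder files to pack
        with open(temp.relpath(name), "w") as f:
            f.write("hello\n")

    archive = temp.relpath("files.tar")
    tar.tar(archive, [temp.relpath("a.txt"), temp.relpath("b.txt")])

    out_dir = temp.relpath("unpacked")
    os.makedirs(out_dir)
    tar.untar(archive, out_dir)
    print(os.listdir(out_dir))  # ['a.txt', 'b.txt']
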
index 474c966..0f4bb66 100644
@@ -24,21 +24,14 @@ import tvm._ffi
 from ...relay.expr_functor import ExprVisitor
 from .. import xcode, coreml_runtime
 
+
 def _convert_add(builder, name, inputs, outputs, args, attrs):
-    builder.add_elementwise(
-        name=name,
-        input_names=inputs,
-        output_name=outputs[0],
-        mode='ADD'
-    )
+    builder.add_elementwise(name=name, input_names=inputs, output_name=outputs[0], mode="ADD")
+
 
 def _convert_multiply(builder, name, inputs, outputs, args, attrs):
-    builder.add_elementwise(
-        name=name,
-        input_names=inputs,
-        output_name=outputs[0],
-        mode='MULTIPLY'
-    )
+    builder.add_elementwise(name=name, input_names=inputs, output_name=outputs[0], mode="MULTIPLY")
+
 
 def _convert_clip(builder, name, inputs, outputs, args, attrs):
     builder.add_clip(
@@ -46,48 +39,38 @@ def _convert_clip(builder, name, inputs, outputs, args, attrs):
         input_name=inputs[0],
         output_name=outputs[0],
         min_value=attrs.a_min,
-        max_value=attrs.a_max
+        max_value=attrs.a_max,
     )
 
+
 def _convert_batch_flatten(builder, name, inputs, outputs, args, attrs):
-    builder.add_flatten_to_2d(
-        name=name,
-        input_name=inputs[0],
-        output_name=outputs[0]
-    )
+    builder.add_flatten_to_2d(name=name, input_name=inputs[0], output_name=outputs[0])
+
 
 def _convert_expand_dims(builder, name, inputs, outputs, args, attrs):
     if attrs.axis >= 0:
-        axes = list(range(attrs.axis, attrs.axis+attrs.num_newaxis))
+        axes = list(range(attrs.axis, attrs.axis + attrs.num_newaxis))
     else:
-        axes = list(range(attrs.axis-attrs.num_newaxis+1, attrs.axis+1))
+        axes = list(range(attrs.axis - attrs.num_newaxis + 1, attrs.axis + 1))
+
+    builder.add_expand_dims(name=name, input_name=inputs[0], output_name=outputs[0], axes=axes)
 
-    builder.add_expand_dims(
-        name=name,
-        input_name=inputs[0],
-        output_name=outputs[0],
-        axes=axes
-    )
 
 def _convert_relu(builder, name, inputs, outputs, args, attrs):
     builder.add_activation(
-        name=name,
-        non_linearity='RELU',
-        input_name=inputs[0],
-        output_name=outputs[0]
+        name=name, non_linearity="RELU", input_name=inputs[0], output_name=outputs[0]
     )
 
+
 def _convert_softmax(builder, name, inputs, outputs, args, attrs):
     builder.add_softmax_nd(
-        name=name,
-        input_name=inputs[0],
-        output_name=outputs[0],
-        axis=int(attrs['axis'])
+        name=name, input_name=inputs[0], output_name=outputs[0], axis=int(attrs["axis"])
     )
 
+
 def _convert_conv2d(builder, name, inputs, outputs, args, attrs):
     weight = args[1].data.asnumpy()
-    if attrs['kernel_layout'] == 'OIHW':
+    if attrs["kernel_layout"] == "OIHW":
         # convert to 'HWIO'
         weight = weight.transpose([2, 3, 1, 0])
     kh, kw, kc, oc = weight.shape
@@ -98,22 +81,23 @@ def _convert_conv2d(builder, name, inputs, outputs, args, attrs):
         output_channels=oc,
         height=kh,
         width=kw,
-        stride_height=int(attrs['strides'][0]),
-        stride_width=int(attrs['strides'][0]),
+        stride_height=int(attrs["strides"][0]),
+        stride_width=int(attrs["strides"][0]),
         border_mode="valid",
-        groups=int(attrs['groups']),
+        groups=int(attrs["groups"]),
         W=weight,
         b=None,
         has_bias=False,
         input_name=inputs[0],
         output_name=outputs[0],
-        dilation_factors=[int(v) for v in attrs['dilation']],
-        padding_top=int(attrs['padding'][0]),
-        padding_bottom=int(attrs['padding'][2]),
-        padding_left=int(attrs['padding'][1]),
-        padding_right=int(attrs['padding'][3])
+        dilation_factors=[int(v) for v in attrs["dilation"]],
+        padding_top=int(attrs["padding"][0]),
+        padding_bottom=int(attrs["padding"][2]),
+        padding_left=int(attrs["padding"][1]),
+        padding_right=int(attrs["padding"][3]),
     )
 
+
 def _convert_global_avg_pool2d(builder, name, inputs, outputs, args, attrs):
     builder.add_pooling(
         name=name,
@@ -121,29 +105,32 @@ def _convert_global_avg_pool2d(builder, name, inputs, outputs, args, attrs):
         width=1,
         stride_height=1,
         stride_width=1,
-        layer_type='AVERAGE',
-        padding_type='VALID',
+        layer_type="AVERAGE",
+        padding_type="VALID",
         input_name=inputs[0],
         output_name=outputs[0],
-        is_global=True
+        is_global=True,
     )
 
+
 _convert_map = {
-    'add'                       : _convert_add,
-    'multiply'                  : _convert_multiply,
-    'clip'                      : _convert_clip,
-    'expand_dims'               : _convert_expand_dims,
-    'nn.relu'                   : _convert_relu,
-    'nn.batch_flatten'          : _convert_batch_flatten,
-    'nn.softmax'                : _convert_softmax,
-    'nn.conv2d'                 : _convert_conv2d,
-    'nn.global_avg_pool2d'      : _convert_global_avg_pool2d,
+    "add": _convert_add,
+    "multiply": _convert_multiply,
+    "clip": _convert_clip,
+    "expand_dims": _convert_expand_dims,
+    "nn.relu": _convert_relu,
+    "nn.batch_flatten": _convert_batch_flatten,
+    "nn.softmax": _convert_softmax,
+    "nn.conv2d": _convert_conv2d,
+    "nn.global_avg_pool2d": _convert_global_avg_pool2d,
 }
 
+
 class CodegenCoreML(ExprVisitor):
     """
     A visitor to traverse subgraphs and build Core ML models.
     """
+
     def __init__(self, model_name, function):
         import coremltools
         from coremltools.models.neural_network import NeuralNetworkBuilder
@@ -158,10 +145,24 @@ class CodegenCoreML(ExprVisitor):
         # Update inputs and outputs after we visit all the nodes.
         # Set dummy values for now.
         # TODO: support multiple outputs
-        inputs = [('', coremltools.models.datatypes.Array(1,)) for _ in self.function.params]
-        outputs = [('', coremltools.models.datatypes.Array(1,))]
-        self.builder = NeuralNetworkBuilder(inputs, outputs,
-                                            disable_rank5_shape_mapping=True)
+        inputs = [
+            (
+                "",
+                coremltools.models.datatypes.Array(
+                    1,
+                ),
+            )
+            for _ in self.function.params
+        ]
+        outputs = [
+            (
+                "",
+                coremltools.models.datatypes.Array(
+                    1,
+                ),
+            )
+        ]
+        self.builder = NeuralNetworkBuilder(inputs, outputs, disable_rank5_shape_mapping=True)
 
     def visit_constant(self, const):
         output = "buf_" + str(self.buf_idx_)
@@ -169,7 +170,7 @@ class CodegenCoreML(ExprVisitor):
             name=output,
             output_name=output,
             constant_value=const.data.asnumpy(),
-            shape=const.data.shape
+            shape=const.data.shape,
         )
         self.buf_idx_ = self.buf_idx_ + 1
         self.out_map[const] = [output]
@@ -192,8 +193,7 @@ class CodegenCoreML(ExprVisitor):
         layer_name = op_name + "_" + str(self.buf_idx_)
 
         assert op_name in _convert_map, "{} is not supported".format(op_name)
-        _convert_map[op_name](self.builder, layer_name, inputs, outputs,
-                              call.args, call.attrs)
+        _convert_map[op_name](self.builder, layer_name, inputs, outputs, call.args, call.attrs)
 
         self.buf_idx_ = self.buf_idx_ + 1
         self.out_map[call] = outputs
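
Every entry in _convert_map above is a plain function with the signature (builder, name, inputs, outputs, args, attrs) that emits one or more coremltools layers. A hypothetical sketch of adding support for one more Relay op in the same style (sigmoid is not in the table above; it is shown purely to illustrate the pattern):

    def _convert_sigmoid(builder, name, inputs, outputs, args, attrs):
        # "SIGMOID" is one of coremltools' built-in activation non-linearities.
        builder.add_activation(
            name=name, non_linearity="SIGMOID", input_name=inputs[0], output_name=outputs[0]
        )

    _convert_map["sigmoid"] = _convert_sigmoid
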
index 25e9fd4..f9141a6 100644
@@ -58,14 +58,15 @@ def call_node_infer_type(node):
     elif isinstance(out_type, TupleType):
         types = list(out_type.fields)
     else:
-        raise RuntimeError("Unsupported output type %s in operator %s"
-                           % (type(out_type), node.op.nae))
+        raise RuntimeError(
+            "Unsupported output type %s in operator %s" % (type(out_type), node.op.nae)
+        )
 
     return types
 
 
 def add_input(data, name, prefix, model_container):
-    input_name = '{}_{}'.format(prefix, name)
+    input_name = "{}_{}".format(prefix, name)
     dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[data.dtype]
     tensor_value_info = onnx.helper.make_tensor_value_info(input_name, dtype, shape=data.shape)
     model_container.add_inputs([tensor_value_info])
@@ -75,258 +76,250 @@ def add_input(data, name, prefix, model_container):
 
 
 class OpConverter(object):
-    """ Operator converter Base Class.
-    """
+    """Operator converter Base Class."""
 
     @classmethod
     def convert_attributes(cls, attrs):
         """convert Relay attributes to ONNX attributes.
-           The derived classes should implement this method
-           if attributes are required by the operator
-           otherwise by default no attributes are passed
+        Derived classes should implement this method
+        if the operator requires attributes;
+        otherwise, no attributes are passed by default.
         """
         return {}
 
     @classmethod
     def convert(cls, node_entry, model_container, node_dict):
-        attrs = cls.convert_attributes(node_entry['relay_node'].attrs)
-        onnx_node = onnx.helper.make_node(cls.__name__,
-                                          node_entry['input_names'],
-                                          node_entry['output_names'],
-                                          **attrs)
+        attrs = cls.convert_attributes(node_entry["relay_node"].attrs)
+        onnx_node = onnx.helper.make_node(
+            cls.__name__, node_entry["input_names"], node_entry["output_names"], **attrs
+        )
         model_container.add_nodes([onnx_node])
 
 
 def rename(op_name):
-    """ This method creates dynamic operator of name op_name with empty attributes
-    """
+    """This method creates dynamic operator of name op_name with empty attributes"""
     return type(op_name, (OpConverter,), {})
 
 
 class Reshape(object):
-    """ Operator converter for Reshape.
-    """
+    """Operator converter for Reshape."""
 
     @classmethod
     def convert(cls, node_entry, model_container, node_dict):
         """Converts Relay operator Reshape to ONNX operator.
-           Relay operator accepts shape as attribute but ONNX operator
-           accepts it as a input.
+        The Relay operator accepts the shape as an attribute, but the ONNX operator
+        accepts it as an input.
         """
-        name = node_entry['name']
-        shape = numpy.asarray([a.value for a in node_entry['relay_node'].attrs.newshape],
-                              dtype=numpy.int64)
+        name = node_entry["name"]
+        shape = numpy.asarray(
+            [a.value for a in node_entry["relay_node"].attrs.newshape], dtype=numpy.int64
+        )
 
-        input_names = [node_entry['input_names'][0],
-                       add_input(shape, name, 'shape', model_container)]
+        input_names = [
+            node_entry["input_names"][0],
+            add_input(shape, name, "shape", model_container),
+        ]
 
-        node = onnx.helper.make_node(cls.__name__, input_names,
-                                     node_entry['output_names'])
+        node = onnx.helper.make_node(cls.__name__, input_names, node_entry["output_names"])
         model_container.add_nodes([node])
 
 
 class Conv(OpConverter):
-    """ Operator converter for Conv.
-    """
+    """Operator converter for Conv."""
 
     @classmethod
     def convert_attributes(cls, attrs):
         return {
-            'group': attrs.get_int("groups"),
-            'pads': attrs.get_int_tuple("padding"),
-            'strides': attrs.get_int_tuple("strides"),
-            'dilations': attrs.get_int_tuple("dilation"),
-            'kernel_shape': attrs.get_int_tuple("kernel_size"),
+            "group": attrs.get_int("groups"),
+            "pads": attrs.get_int_tuple("padding"),
+            "strides": attrs.get_int_tuple("strides"),
+            "dilations": attrs.get_int_tuple("dilation"),
+            "kernel_shape": attrs.get_int_tuple("kernel_size"),
         }
 
 
 class MaxPool(OpConverter):
-    """ Operator converter for MaxPool.
-    """
+    """Operator converter for MaxPool."""
 
     @classmethod
     def convert_attributes(cls, attrs):
         return {
-            'pads': attrs.get_int_tuple("padding"),
-            'strides': attrs.get_int_tuple("strides"),
-            'kernel_shape': attrs.get_int_tuple("pool_size"),
+            "pads": attrs.get_int_tuple("padding"),
+            "strides": attrs.get_int_tuple("strides"),
+            "kernel_shape": attrs.get_int_tuple("pool_size"),
         }
 
 
 class Transpose(OpConverter):
-    """ Operator converter for Transpose.
-    """
+    """Operator converter for Transpose."""
 
     @classmethod
     def convert_attributes(cls, attrs):
-        return {'perm': attrs.get_int_tuple("axes")} if attrs["axes"] else {}
+        return {"perm": attrs.get_int_tuple("axes")} if attrs["axes"] else {}
 
 
 class MatMul(OpConverter):
-    """ Operator converter for MatMul.
-    """
+    """Operator converter for MatMul."""
 
     @classmethod
     def convert(cls, node_entry, model_container, node_dict):
-        inter_output_name = 'inter{}'.format(node_entry['name'])
-        transpose_node = onnx.helper.make_node(Transpose.__name__,
-                                               [node_entry['input_names'][1]],
-                                               [inter_output_name],
-                                               perm=(1, 0))
+        inter_output_name = "inter{}".format(node_entry["name"])
+        transpose_node = onnx.helper.make_node(
+            Transpose.__name__, [node_entry["input_names"][1]], [inter_output_name], perm=(1, 0)
+        )
         model_container.add_nodes([transpose_node])
 
-        inputs = [node_entry['input_names'][0], inter_output_name]
-        matmul_node = onnx.helper.make_node(cls.__name__, inputs, node_entry['output_names'])
+        inputs = [node_entry["input_names"][0], inter_output_name]
+        matmul_node = onnx.helper.make_node(cls.__name__, inputs, node_entry["output_names"])
         model_container.add_nodes([matmul_node])
 
 
 class Flatten(OpConverter):
-    """ Operator converter for Flatten.
-    """
+    """Operator converter for Flatten."""
 
     @classmethod
     def convert_attributes(cls, attrs):
         return {
-            'axis': 1,
+            "axis": 1,
         }
 
 
 class BatchNormalization(OpConverter):
-    """ Operator converter for BatchNormalization.
-    """
+    """Operator converter for BatchNormalization."""
 
     @classmethod
     def convert_attributes(cls, attrs):
         return {
-            'epsilon': float(attrs.get_str('epsilon')),
-            'axis': float(attrs.get_int('axis')),
+            "epsilon": float(attrs.get_str("epsilon")),
+            "axis": float(attrs.get_int("axis")),
         }
 
     @classmethod
     def convert(cls, node_entry, model_container, node_dict):
         """Converts Relay operator batch_norm to ONNX operator.
-           Relay operator has property axis to handle data in NHWC format.
+        The Relay operator has an axis property to handle data in NHWC format.
         """
-        attrs = cls.convert_attributes(node_entry['relay_node'].attrs)
-        transpose_out_name = node_entry['input_names'][0]
-        inter_output_names = [node_entry['output_names'][0]]
+        attrs = cls.convert_attributes(node_entry["relay_node"].attrs)
+        transpose_out_name = node_entry["input_names"][0]
+        inter_output_names = [node_entry["output_names"][0]]
         # axis==3 means channel is specified along the 3rd axis
-        if attrs['axis'] == 3:
-            transpose_out_name = 'transpose_{}'.format(node_entry['name'])
-            node_transposed = onnx.helper.make_node(Transpose.__name__,
-                                                    [node_entry['input_names'][0]],
-                                                    [transpose_out_name],
-                                                    perm=[0, 3, 1, 2])
+        if attrs["axis"] == 3:
+            transpose_out_name = "transpose_{}".format(node_entry["name"])
+            node_transposed = onnx.helper.make_node(
+                Transpose.__name__,
+                [node_entry["input_names"][0]],
+                [transpose_out_name],
+                perm=[0, 3, 1, 2],
+            )
             model_container.add_nodes([node_transposed])
-            inter_output_names = ['batch_norm_{}'.format(node_entry['name'])]
+            inter_output_names = ["batch_norm_{}".format(node_entry["name"])]
 
-        input_names = [transpose_out_name] + node_entry['input_names'][1:]
-        batch_norm_node = onnx.helper.make_node(cls.__name__,
-                                                input_names,
-                                                inter_output_names,
-                                                epsilon=attrs['epsilon'])
+        input_names = [transpose_out_name] + node_entry["input_names"][1:]
+        batch_norm_node = onnx.helper.make_node(
+            cls.__name__, input_names, inter_output_names, epsilon=attrs["epsilon"]
+        )
         model_container.add_nodes([batch_norm_node])
 
-        if attrs['axis'] == 3:
-            node_transposed = onnx.helper.make_node(Transpose.__name__,
-                                                    inter_output_names,
-                                                    [node_entry['output_names'][0]],
-                                                    perm=[0, 2, 3, 1])
+        if attrs["axis"] == 3:
+            node_transposed = onnx.helper.make_node(
+                Transpose.__name__,
+                inter_output_names,
+                [node_entry["output_names"][0]],
+                perm=[0, 2, 3, 1],
+            )
             model_container.add_nodes([node_transposed])
 
 
 class Dropout(OpConverter):
-    """ Operator converter for Dropout.
-    """
+    """Operator converter for Dropout."""
 
     @classmethod
     def convert_attributes(cls, attrs):
         return {
-            'ratio': float(attrs.get_str('rate')),
+            "ratio": float(attrs.get_str("rate")),
         }
 
 
 class AveragePool(MaxPool):
-    """ Operator converter for AveragePool.
-    """
+    """Operator converter for AveragePool."""
 
 
 class Concat(OpConverter):
-    """ Operator converter for Concat.
-    """
+    """Operator converter for Concat."""
 
     @classmethod
     def convert_attributes(cls, attrs):
         return {
-            'axis': attrs.get_int("axis"),
+            "axis": attrs.get_int("axis"),
         }
 
 
 class BiasAdd(OpConverter):
-    """ Operator converter for BiasAdd.
-    """
+    """Operator converter for BiasAdd."""
 
     @classmethod
     def convert(cls, node_entry, model_container, node_dict):
-        input_node = node_dict[node_entry['inputs'][0]]
+        input_node = node_dict[node_entry["inputs"][0]]
         assert len(input_node) == 1, "input node_entry can not be a Tuple"
         input_node = input_node[0]
-        data_ndim = len(input_node['types'][0].shape)
-        axis = node_entry['relay_node'].attrs.get_int("axis")
+        data_ndim = len(input_node["types"][0].shape)
+        axis = node_entry["relay_node"].attrs.get_int("axis")
         if axis < 0:
             axis = axis + data_ndim
         new_axes = data_ndim - axis - 1
         if new_axes:
-            inter_output_name = 'inter{}'.format(node_entry['name'])
-            unsqueeze_node = onnx.helper.make_node('Unsqueeze',
-                                                   [node_entry['input_names'][1]],
-                                                   [inter_output_name],
-                                                   axes=tuple(range(1, new_axes + 1)))
+            inter_output_name = "inter{}".format(node_entry["name"])
+            unsqueeze_node = onnx.helper.make_node(
+                "Unsqueeze",
+                [node_entry["input_names"][1]],
+                [inter_output_name],
+                axes=tuple(range(1, new_axes + 1)),
+            )
             model_container.add_nodes([unsqueeze_node])
         else:
-            inter_output_name = node_entry['input_names'][1]
+            inter_output_name = node_entry["input_names"][1]
 
-        inputs = [node_entry['input_names'][0], inter_output_name]
-        matmul_node = onnx.helper.make_node('Add', inputs, node_entry['output_names'])
+        inputs = [node_entry["input_names"][0], inter_output_name]
+        matmul_node = onnx.helper.make_node("Add", inputs, node_entry["output_names"])
         model_container.add_nodes([matmul_node])
 
 
 class ReduceMean(OpConverter):
-    """ Operator converter for ReduceMean.
-    """
+    """Operator converter for ReduceMean."""
 
     @classmethod
     def convert_attributes(cls, attrs):
         return {
-            'axes': attrs.axis,
-            'keepdims': 0 if bool(attrs.get_int("keepdims", 0)) is False else 1
+            "axes": attrs.axis,
+            "keepdims": 0 if bool(attrs.get_int("keepdims", 0)) is False else 1,
         }
 
     @classmethod
     def convert(cls, node_entry, model_container, node_dict):
-        input_node = node_dict[node_entry['inputs'][0]]
+        input_node = node_dict[node_entry["inputs"][0]]
         assert len(input_node) == 1, "input node can not be a Tuple"
         input_node = input_node[0]
-        shape = input_node['types'][0].shape
-        axis = node_entry['relay_node'].attrs.axis
+        shape = input_node["types"][0].shape
+        axis = node_entry["relay_node"].attrs.axis
         axis = list(range(shape.size())) if not axis else tvm_array_to_list(axis)
-        exclude = 0 if not bool(node_entry['relay_node'].attrs.exclude) else 1
-        keepdims = 0 if not bool(node_entry['relay_node'].attrs.keepdims) else 1
+        exclude = 0 if not bool(node_entry["relay_node"].attrs.exclude) else 1
+        keepdims = 0 if not bool(node_entry["relay_node"].attrs.keepdims) else 1
         if exclude:
             all_axis = list(range(len(shape)))
             axis = set(all_axis) - set(axis)
 
-        node = onnx.helper.make_node(cls.__name__,
-                                     node_entry['input_names'],
-                                     node_entry['output_names'],
-                                     axes=axis,
-                                     keepdims=keepdims)
+        node = onnx.helper.make_node(
+            cls.__name__,
+            node_entry["input_names"],
+            node_entry["output_names"],
+            axes=axis,
+            keepdims=keepdims,
+        )
         model_container.add_nodes([node])
 
 
 class Pad(OpConverter):
-    """ Operator converter for Pad.
-    """
+    """Operator converter for Pad."""
 
     @classmethod
     def convert_attributes(cls, attrs):
@@ -337,108 +330,103 @@ class Pad(OpConverter):
             after.append(axis_pads[1])
         pads = before + after
         pads = numpy.asarray(pads, dtype=pads[0].dtype)
-        return {
-            'pads': pads,
-            'mode': attrs.get_str('pad_mode'),
-            'constant_value': attrs.pad_value
-        }
+        return {"pads": pads, "mode": attrs.get_str("pad_mode"), "constant_value": attrs.pad_value}
 
     @classmethod
     def convert(cls, node_entry, model_container, node_dict):
         """Converts Relay operator Pad to ONNX operator.
-           Relay operator accepts pads as attribute but ONNX operator
-           accepts it as a input.
+        The Relay operator accepts pads as an attribute, but the ONNX operator
+        accepts it as an input.
         """
-        attrs = cls.convert_attributes(node_entry['relay_node'].attrs)
+        attrs = cls.convert_attributes(node_entry["relay_node"].attrs)
 
-        name = node_entry['name']
-        data = numpy.asarray(attrs['pads'], dtype=attrs['pads'][0].dtype).astype(numpy.int64)
-        value = numpy.dtype(node_entry['types'][0].dtype).type(attrs['constant_value'])
+        name = node_entry["name"]
+        data = numpy.asarray(attrs["pads"], dtype=attrs["pads"][0].dtype).astype(numpy.int64)
+        value = numpy.dtype(node_entry["types"][0].dtype).type(attrs["constant_value"])
 
-        input_names = [node_entry['input_names'][0],
-                       add_input(data, name, 'pads', model_container),
-                       add_input(value, name, 'value', model_container)]
+        input_names = [
+            node_entry["input_names"][0],
+            add_input(data, name, "pads", model_container),
+            add_input(value, name, "value", model_container),
+        ]
 
-        node = onnx.helper.make_node(cls.__name__, input_names, node_entry['output_names'])
+        node = onnx.helper.make_node(cls.__name__, input_names, node_entry["output_names"])
         model_container.add_nodes([node])
 
 
 class Softmax(OpConverter):
-    """ Operator converter for SoftMax.
-    """
+    """Operator converter for SoftMax."""
 
     @classmethod
     def convert_attributes(cls, attrs):
         return {
-            'axis': attrs.axis,
+            "axis": attrs.axis,
         }
 
 
 class Squeeze(OpConverter):
-    """ Operator converter for Squeeze.
-    """
+    """Operator converter for Squeeze."""
 
     @classmethod
     def convert_attributes(cls, attrs):
         return {
-            'axes': attrs.axis,
+            "axes": attrs.axis,
         }
 
     @classmethod
     def convert(cls, node_entry, model_container, node_dict):
-        input_node = node_dict[node_entry['inputs'][0]]
+        input_node = node_dict[node_entry["inputs"][0]]
         assert len(input_node) == 1, "input node can not be a Tuple"
         input_node = input_node[0]
-        shape = input_node['types'][0].shape
-        axis = node_entry['relay_node'].attrs.get_int("axis")
+        shape = input_node["types"][0].shape
+        axis = node_entry["relay_node"].attrs.get_int("axis")
         if not axis:
             axis = []
             for axis_idx, val in enumerate(shape):
                 if val.value == 1:
                     axis.append(axis_idx)
         else:
-            axis = node_entry['relay_node'].attrs.get_int_tuple("axis")
+            axis = node_entry["relay_node"].attrs.get_int_tuple("axis")
 
-        node = onnx.helper.make_node(cls.__name__,
-                                     node_entry['input_names'],
-                                     node_entry['output_names'],
-                                     axes=axis)
+        node = onnx.helper.make_node(
+            cls.__name__, node_entry["input_names"], node_entry["output_names"], axes=axis
+        )
         model_container.add_nodes([node])
 
 
 class Slice(OpConverter):
-    """ Operator converter for Slice.
-    """
+    """Operator converter for Slice."""
 
     @classmethod
     def convert_attributes(cls, attrs):
         return {
-            'starts': attrs.get_int_tuple('begin'),
-            'ends': attrs.get_int_tuple('end'),
-            'steps': attrs.get_int_tuple('strides'),
-            'slice_mode': attrs.get_str('slice_mode')
+            "starts": attrs.get_int_tuple("begin"),
+            "ends": attrs.get_int_tuple("end"),
+            "steps": attrs.get_int_tuple("strides"),
+            "slice_mode": attrs.get_str("slice_mode"),
         }
 
     @classmethod
     def convert(cls, node_entry, model_container, node_dict):
-        attrs = cls.convert_attributes(node_entry['relay_node'].attrs)
+        attrs = cls.convert_attributes(node_entry["relay_node"].attrs)
 
-        name = node_entry['name']
-        input_node = node_dict[node_entry['inputs'][0]]
+        name = node_entry["name"]
+        input_node = node_dict[node_entry["inputs"][0]]
         assert len(input_node) == 1, "input node can not be a Tuple"
         input_node = input_node[0]
-        shape = input_node['types'][0].shape
+        shape = input_node["types"][0].shape
 
-        starts = list(attrs['starts'])
-        ends = list(attrs['ends'])
-        steps = list(attrs['steps'])
+        starts = list(attrs["starts"])
+        ends = list(attrs["ends"])
+        steps = list(attrs["steps"])
         starts += [0] * (len(shape) - len(starts))
         ends += [shape[i] + 1 for i in range(len(ends), len(shape))]
         axes = list(range(len(shape)))
 
-        if attrs['slice_mode'] == 'size':
-            ends = [starts[i] + (shape[i] + 1 if ends[i] < 0 else ends[i])
-                    for i in range(len(shape))]
+        if attrs["slice_mode"] == "size":
+            ends = [
+                starts[i] + (shape[i] + 1 if ends[i] < 0 else ends[i]) for i in range(len(shape))
+            ]
             steps = [1] * len(shape)
         else:
             steps += [1] * (len(shape) - len(steps))
@@ -449,45 +437,42 @@ class Slice(OpConverter):
         steps = numpy.asarray(steps).astype(numpy.int64)
 
         input_names = []
-        input_names.append(add_input(starts, name, 'starts', model_container))
-        input_names.append(add_input(ends, name, 'ends', model_container))
-        input_names.append(add_input(axes, name, 'axes', model_container))
-        input_names.append(add_input(steps, name, 'steps', model_container))
+        input_names.append(add_input(starts, name, "starts", model_container))
+        input_names.append(add_input(ends, name, "ends", model_container))
+        input_names.append(add_input(axes, name, "axes", model_container))
+        input_names.append(add_input(steps, name, "steps", model_container))
 
-        input_names = [node_entry['input_names'][0]] + input_names
+        input_names = [node_entry["input_names"][0]] + input_names
 
-        slice_node = onnx.helper.make_node(cls.__name__,
-                                           input_names,
-                                           node_entry['output_names'])
+        slice_node = onnx.helper.make_node(cls.__name__, input_names, node_entry["output_names"])
         model_container.add_nodes([slice_node])
 
 
 class Split(OpConverter):
-    """ Operator converter for Split.
-    """
+    """Operator converter for Split."""
 
     @classmethod
     def convert_attributes(cls, attrs):
-        indices_or_sections = attrs['indices_or_sections']
+        indices_or_sections = attrs["indices_or_sections"]
 
         if isinstance(indices_or_sections, (list, tvm.ir.container.Array)):
-            indices_or_sections = attrs.get_int_tuple('indices_or_sections')
+            indices_or_sections = attrs.get_int_tuple("indices_or_sections")
         if isinstance(indices_or_sections, tvm.ir.PrimExpr):
             indices_or_sections = indices_or_sections.value
 
         return {
-            'indices_or_section': indices_or_sections,
-            'axis': attrs.get_int('axis'),
+            "indices_or_section": indices_or_sections,
+            "axis": attrs.get_int("axis"),
         }
 
     @classmethod
     def convert(cls, node_entry, model_container, node_dict):
-        attrs = cls.convert_attributes(node_entry['relay_node'].attrs)
+        attrs = cls.convert_attributes(node_entry["relay_node"].attrs)
 
-        input_node = node_dict[node_entry['inputs'][0]]
+        input_node = node_dict[node_entry["inputs"][0]]
         assert len(input_node) == 1, "input node can not be a Tuple"
         input_node = input_node[0]
-        shape = input_node['types'][0].concrete_shape
+        shape = input_node["types"][0].concrete_shape
 
         indices_or_sect = attrs["indices_or_section"]
         axis = attrs["axis"]
@@ -505,17 +490,18 @@ class Split(OpConverter):
                 else:
                     split.append(indices_or_sect[i] - indices_or_sect[i - 1])
 
-        slice_node = onnx.helper.make_node(cls.__name__,
-                                           node_entry['input_names'],
-                                           node_entry['output_names'],
-                                           split=split,
-                                           axis=axis)
+        slice_node = onnx.helper.make_node(
+            cls.__name__,
+            node_entry["input_names"],
+            node_entry["output_names"],
+            split=split,
+            axis=axis,
+        )
         model_container.add_nodes([slice_node])
 
 
 class LayoutTransform(OpConverter):
-    """ Operator converter for Layouttransform
-    """
+    """Operator converter for Layouttransform"""
 
     @classmethod
     def convert_attributes(cls, attrs):
@@ -523,170 +509,152 @@ class LayoutTransform(OpConverter):
         dst_layout = attrs.get_str("dst_layout")
 
         perm = [src_layout.index(c) for c in dst_layout]
-        return {'perm': tuple(perm)}
+        return {"perm": tuple(perm)}
 
     @classmethod
     def convert(cls, node_entry, model_container, node_dict):
-        attrs = cls.convert_attributes(node_entry['relay_node'].attrs)
-        onnx_node = onnx.helper.make_node("Transpose",
-                                          node_entry['input_names'],
-                                          node_entry['output_names'],
-                                          **attrs)
+        attrs = cls.convert_attributes(node_entry["relay_node"].attrs)
+        onnx_node = onnx.helper.make_node(
+            "Transpose", node_entry["input_names"], node_entry["output_names"], **attrs
+        )
         model_container.add_nodes([onnx_node])
 
 
 class Clip(OpConverter):
-    """ Operator converter for Clip.
-    """
+    """Operator converter for Clip."""
 
     @classmethod
     def convert_attributes(cls, attrs):
-        return {
-            'min': attrs.a_min,
-            'max': attrs.a_max
-        }
+        return {"min": attrs.a_min, "max": attrs.a_max}
 
     @classmethod
     def convert(cls, node_entry, model_container, node_dict):
-        attrs = cls.convert_attributes(node_entry['relay_node'].attrs)
+        attrs = cls.convert_attributes(node_entry["relay_node"].attrs)
 
-        name = node_entry['name']
+        name = node_entry["name"]
 
-        min_val = numpy.asarray(attrs['min']).astype(numpy.float32)
-        max_val = numpy.asarray(attrs['max']).astype(numpy.float32)
+        min_val = numpy.asarray(attrs["min"]).astype(numpy.float32)
+        max_val = numpy.asarray(attrs["max"]).astype(numpy.float32)
 
         input_names = []
-        input_names.append(add_input(min_val, name, 'min', model_container))
-        input_names.append(add_input(max_val, name, 'max', model_container))
+        input_names.append(add_input(min_val, name, "min", model_container))
+        input_names.append(add_input(max_val, name, "max", model_container))
 
-        input_names = [node_entry['input_names'][0]] + input_names
+        input_names = [node_entry["input_names"][0]] + input_names
 
-        node = onnx.helper.make_node(cls.__name__, input_names, node_entry['output_names'])
+        node = onnx.helper.make_node(cls.__name__, input_names, node_entry["output_names"])
         model_container.add_nodes([node])
 
 
 class Expand(OpConverter):
-    """ Operator converter for Expand_dims.
-    """
+    """Operator converter for Expand_dims."""
 
     @classmethod
     def convert_attributes(cls, attrs):
-        return {
-            'axis': attrs.axis,
-            'num_newaxis': attrs.num_newaxis
-        }
+        return {"axis": attrs.axis, "num_newaxis": attrs.num_newaxis}
 
     @classmethod
     def convert(cls, node_entry, model_container, node_dict):
-        attrs = cls.convert_attributes(node_entry['relay_node'].attrs)
+        attrs = cls.convert_attributes(node_entry["relay_node"].attrs)
 
-        name = node_entry['name']
+        name = node_entry["name"]
 
-        input_node = node_dict[node_entry['inputs'][0]]
+        input_node = node_dict[node_entry["inputs"][0]]
         assert len(input_node) == 1, "input node_entry can not be a Tuple"
         input_node = input_node[0]
-        data_shape = input_node['types'][0].shape
+        data_shape = input_node["types"][0].shape
         new_shape = list(data_shape)
 
-        for _ in range(attrs['num_newaxis']):
-            new_shape.insert(attrs['axis'], 1)
+        for _ in range(attrs["num_newaxis"]):
+            new_shape.insert(attrs["axis"], 1)
 
         new_shape = numpy.asarray(new_shape).astype(numpy.int64)
         input_names = []
-        input_names.append(add_input(new_shape, name, 'shape', model_container))
+        input_names.append(add_input(new_shape, name, "shape", model_container))
 
-        input_names = [node_entry['input_names'][0]] + input_names
+        input_names = [node_entry["input_names"][0]] + input_names
 
-        node = onnx.helper.make_node(cls.__name__, input_names, node_entry['output_names'])
+        node = onnx.helper.make_node(cls.__name__, input_names, node_entry["output_names"])
         model_container.add_nodes([node])
 
 
 class ConstantOfShapeZeros(OpConverter):
-    """ Operator converter for ConstantOfShape.
-    """
+    """Operator converter for ConstantOfShape."""
 
     @classmethod
     def convert_attributes(cls, attrs):
-        return {
-            'value': 0
-        }
+        return {"value": 0}
 
     @classmethod
     def convert(cls, node_entry, model_container, node_dict):
-        attrs = cls.convert_attributes(node_entry['relay_node'].attrs)
-        input_node = node_dict[node_entry['inputs'][0]]
+        attrs = cls.convert_attributes(node_entry["relay_node"].attrs)
+        input_node = node_dict[node_entry["inputs"][0]]
         assert len(input_node) == 1, "input node can not be a Tuple"
         input_node = input_node[0]
-        dtype = input_node['types'][0].dtype
+        dtype = input_node["types"][0].dtype
 
-        name = node_entry['name']
-        shape = [val.value for val in input_node['types'][0].shape]
+        name = node_entry["name"]
+        shape = [val.value for val in input_node["types"][0].shape]
         shape = numpy.asarray(shape).astype(numpy.int64)
 
         input_names = []
-        input_names.append(add_input(shape, name, 'shape', model_container))
+        input_names.append(add_input(shape, name, "shape", model_container))
 
         dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(dtype)]
-        tensor_value = onnx.helper.make_tensor("value", dtype,
-                                               [1], [attrs['value']])
+        tensor_value = onnx.helper.make_tensor("value", dtype, [1], [attrs["value"]])
 
-        node = onnx.helper.make_node('ConstantOfShape',
-                                     input_names,
-                                     node_entry['output_names'],
-                                     value=tensor_value)
+        node = onnx.helper.make_node(
+            "ConstantOfShape", input_names, node_entry["output_names"], value=tensor_value
+        )
         model_container.add_nodes([node])
 
 
 class ConstantOfShapeOnes(ConstantOfShapeZeros):
-    """ Operator converter for ConstantOfShape.
-    """
+    """Operator converter for ConstantOfShape."""
 
     @classmethod
     def convert_attributes(cls, attrs):
-        return {
-            'value': 1
-        }
+        return {"value": 1}
 
 
 relay_to_onnx_op_mapping = {
-    'reshape': Reshape,
-    'nn.conv2d': Conv,
-    'add': rename('Add'),
-    'nn.relu': rename('Relu'),
-    'transpose': Transpose,
-    'nn.dense': MatMul,
-    'nn.max_pool2d': MaxPool,
-    'nn.batch_flatten': Flatten,
-    'multiply': rename('Mul'),
-    'nn.bias_add': BiasAdd,
-    'nn.batch_norm': BatchNormalization,
-    'nn.global_avg_pool2d': rename('GlobalAveragePool'),
-    'concatenate': Concat,
-    'nn.dropout': Dropout,
-    'nn.avg_pool2d': AveragePool,
-    'divide': rename('Div'),
-    'mean': ReduceMean,
-    'nn.pad': Pad,
-    'nn.softmax': Softmax,
-    'squeeze': Squeeze,
-    'strided_slice': Slice,
-    'greater': rename('Greater'),
-    'less': rename('Less'),
-    'equal': rename('Equal'),
-    'zeros_like': ConstantOfShapeZeros,
-    'ones_like': ConstantOfShapeOnes,
-    'subtract': rename('Sub'),
-    'split': Split,
-    'exp': rename('Exp'),
-    'layout_transform': LayoutTransform,
-    'clip': Clip,
-    'expand_dims': Expand
+    "reshape": Reshape,
+    "nn.conv2d": Conv,
+    "add": rename("Add"),
+    "nn.relu": rename("Relu"),
+    "transpose": Transpose,
+    "nn.dense": MatMul,
+    "nn.max_pool2d": MaxPool,
+    "nn.batch_flatten": Flatten,
+    "multiply": rename("Mul"),
+    "nn.bias_add": BiasAdd,
+    "nn.batch_norm": BatchNormalization,
+    "nn.global_avg_pool2d": rename("GlobalAveragePool"),
+    "concatenate": Concat,
+    "nn.dropout": Dropout,
+    "nn.avg_pool2d": AveragePool,
+    "divide": rename("Div"),
+    "mean": ReduceMean,
+    "nn.pad": Pad,
+    "nn.softmax": Softmax,
+    "squeeze": Squeeze,
+    "strided_slice": Slice,
+    "greater": rename("Greater"),
+    "less": rename("Less"),
+    "equal": rename("Equal"),
+    "zeros_like": ConstantOfShapeZeros,
+    "ones_like": ConstantOfShapeOnes,
+    "subtract": rename("Sub"),
+    "split": Split,
+    "exp": rename("Exp"),
+    "layout_transform": LayoutTransform,
+    "clip": Clip,
+    "expand_dims": Expand,
 }
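As an editorial aside (not part of the patch), a minimal sketch of how a lookup table like this is typically consulted; ``lookup_converter`` is a hypothetical helper name, while the real dispatch happens in ``_add_node`` further below.

.. code-block:: python

    # Hypothetical dispatch helper mirroring _add_node's lookup logic.
    def lookup_converter(op_name):
        if op_name not in relay_to_onnx_op_mapping:
            raise NotImplementedError("Operator '{0}' is not supported.".format(op_name))
        return relay_to_onnx_op_mapping[op_name]()  # instantiate the converter class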
 
 
 class ModelContainer(object):
-    """ A container class to hold  different attributes of ONNX model graph
-    """
+    """A container class to hold  different attributes of ONNX model graph"""
 
     def __init__(self, name, opset_version):
         self._name = name
@@ -718,15 +686,11 @@ class ModelContainer(object):
     def make_model(self):
         """ Creates the onnx model from the graph """
         onnx_graph = onnx.helper.make_graph(
-            self._nodes,
-            self._name,
-            self._inputs,
-            self._outputs,
-            self._initializers
+            self._nodes, self._name, self._inputs, self._outputs, self._initializers
         )
         kwargs = {}
         kwargs["opset_imports"] = self._get_opsets()
-        kwargs["producer_name"] = 'TVM Relay'
+        kwargs["producer_name"] = "TVM Relay"
         kwargs["producer_version"] = tvm.__version__
 
         return onnx.helper.make_model(onnx_graph, **kwargs)
@@ -759,14 +723,15 @@ class RelayToONNXConverter(ExprVisitor):
 
     @classmethod
     def _get_node_entry(cls, relay_node, name):
-        return {"relay_node": relay_node,
-                "inputs": [relay_node],  # inputs in the form of relay nodes
-                "types": [],  # output types in case of call nodes else self type
-                "name": name,  # name of the node
-                "input_names": [name],  # input names in case of call nodes else self name
-                "output_names": [name],  # output names in case of call nodes else self name
-                "op": None,  # op name in case of call node else None
-               }
+        return {
+            "relay_node": relay_node,
+            "inputs": [relay_node],  # inputs in the form of relay nodes
+            "types": [],  # output types in case of call nodes else self type
+            "name": name,  # name of the node
+            "input_names": [name],  # input names in case of call nodes else self name
+            "output_names": [name],  # output names in case of call nodes else self name
+            "op": None,  # op name in case of call node else None
+        }
 
     def convert_to_onnx(self, func):
         """ Traverse Relay graph and generate a ONNX model"""
@@ -836,9 +801,9 @@ class RelayToONNXConverter(ExprVisitor):
             node_entry["input_names"].extend(input_names)
             node_entry["inputs"].extend([input_arg])
 
-        node_entry['types'] = call_node_infer_type(call)
+        node_entry["types"] = call_node_infer_type(call)
         node_entry["output_names"] = []
-        for i in range(len(node_entry['types'])):
+        for i in range(len(node_entry["types"])):
             node_entry["output_names"].append(name + str(i))
         self.last_node = call
         self._add_node(node_entry, node_index)
@@ -846,58 +811,58 @@ class RelayToONNXConverter(ExprVisitor):
 
     def _add_node(self, node_entry, idx):
         """Convert Relay operator node to ONNX operator and add it to container nodes list"""
-        if node_entry['op'].name not in relay_to_onnx_op_mapping:
-            raise NotImplementedError("Currently the operator '{0}' is "
-                                      "not supported.".format(node_entry['op'].name))
-        converter = relay_to_onnx_op_mapping[node_entry['op'].name]()
+        if node_entry["op"].name not in relay_to_onnx_op_mapping:
+            raise NotImplementedError(
+                "Currently the operator '{0}' is " "not supported.".format(node_entry["op"].name)
+            )
+        converter = relay_to_onnx_op_mapping[node_entry["op"].name]()
 
         return converter.convert(node_entry, self._mc, self._node_dict)
 
     def _add_params(self, node_entry, idx):
         """Add param value to initializer and name to inputs"""
-        param_name = node_entry['name']
-        assert param_name in self._params, "The parameter {0} is not present" \
-                                           "in params dict provided.".format(param_name)
+        param_name = node_entry["name"]
+        assert (
+            param_name in self._params
+        ), "The parameter {0} is not present" "in params dict provided.".format(param_name)
         value = self._params[param_name]
         numpy_array = value.asnumpy()
         tensor = numpy_helper.from_array(numpy_array, param_name)
         self._mc.add_initializers([tensor])
         dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy_array.dtype]
-        input = onnx.helper.make_tensor_value_info(param_name,
-                                                   dtype,
-                                                   shape=numpy_array.shape)
+        input = onnx.helper.make_tensor_value_info(param_name, dtype, shape=numpy_array.shape)
         self._mc.add_inputs([input])
 
     def _add_constant_input(self, node_entry, idx):
         """Create named input for constant and add it to container inputs.
         If input is a parameter then add to param
         """
-        node = node_entry['relay_node']
-        param_name = node_entry['name']
+        node = node_entry["relay_node"]
+        param_name = node_entry["name"]
         self._params[param_name] = node.data
         self._add_params(node_entry, idx)
 
     def _add_input(self, node_entry, idx):
         """Add input node to container inputs. If input is a parameter then add to param"""
-        if node_entry['name'] in self._params:
+        if node_entry["name"] in self._params:
             self._add_params(node_entry, idx)
         else:
-            node_type = node_entry['types'][0]
+            node_type = node_entry["types"][0]
             dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(node_type.dtype)]
-            input = onnx.helper.make_tensor_value_info(node_entry['name'],
-                                                       dtype,
-                                                       shape=node_type.concrete_shape)
+            input = onnx.helper.make_tensor_value_info(
+                node_entry["name"], dtype, shape=node_type.concrete_shape
+            )
             self._mc.add_inputs([input])
 
     def _add_output(self, node_entries):
         """Add output node to container outputs."""
 
         for node_entry in node_entries:
-            for node_type, output_name in zip(node_entry['types'], node_entry['output_names']):
+            for node_type, output_name in zip(node_entry["types"], node_entry["output_names"]):
                 dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(node_type.dtype)]
-                output = onnx.helper.make_tensor_value_info(output_name,
-                                                            dtype,
-                                                            shape=node_type.concrete_shape)
+                output = onnx.helper.make_tensor_value_info(
+                    output_name, dtype, shape=node_type.concrete_shape
+                )
                 self._mc.add_outputs([output])
 
 
@@ -932,9 +897,12 @@ def to_onnx(relay_ir, params, name, opset_version=11, path=None):
         raise NotImplementedError("Currently only opset version 11 is supported.")
 
     if opset_version > defs.onnx_opset_version():
-        raise Exception("The ONNX package installed of version {} does not support the opset "
-                        "version {}. Upgrade the ONNX package to latest version.".format(
-                            get_onnx_version(), opset_version))
+        raise Exception(
+            "The ONNX package installed of version {} does not support the opset "
+            "version {}. Upgrade the ONNX package to latest version.".format(
+                get_onnx_version(), opset_version
+            )
+        )
 
     func = relay_ir["main"] if isinstance(relay_ir, tvm.ir.IRModule) else relay_ir
     converter = RelayToONNXConverter(name, params, opset_version)
@@ -957,11 +925,11 @@ def onnx_compiler(func):
     name = str(func.attrs.global_symbol)
     model = to_onnx(func, {}, name)
     const_vars = [const.name for const in model.graph.initializer]
-    name_bytes = bytes(name, 'utf-8')
-    name_size = struct.pack('I', len(name_bytes))
+    name_bytes = bytes(name, "utf-8")
+    name_size = struct.pack("I", len(name_bytes))
     model_serialized = model.SerializeToString()
-    model_size = struct.pack('I', model.ByteSize())
-    data = b'' + name_size + name_bytes + model_size + model_serialized
+    model_size = struct.pack("I", model.ByteSize())
+    data = b"" + name_size + name_bytes + model_size + model_serialized
 
     runtime_func = "runtime.ONNXModuleCreate"
     fcreate = tvm._ffi.get_global_func(runtime_func)
@@ -970,7 +938,7 @@ def onnx_compiler(func):
 
 @tvm._ffi.register_func("relay.ext.onnx.save_to_file")
 def save_to_file(hex_str, path=None, fmt="onnx"):
-    """ Store the ONNX subgraphs in the path folder
+    """Store the ONNX subgraphs in the path folder
 
     :param hex_str: Subgraph names and corresponding serialized onnx hex string
     :param path: path where the ONNX files are to be stored
@@ -982,12 +950,12 @@ def save_to_file(hex_str, path=None, fmt="onnx"):
     offset = 0
     while offset < len(onnx_ir):
         stop = offset + 4
-        (name_size,) = struct.unpack('I', onnx_ir[offset:stop])
-        name = onnx_ir[stop:stop + name_size].decode("utf-8")
+        (name_size,) = struct.unpack("I", onnx_ir[offset:stop])
+        name = onnx_ir[stop : stop + name_size].decode("utf-8")
         stop = stop + name_size
-        (model_size,) = struct.unpack('I', onnx_ir[stop:stop + 4])
+        (model_size,) = struct.unpack("I", onnx_ir[stop : stop + 4])
         stop = stop + 4
-        model_serialized = onnx_ir[stop:stop + model_size]
+        model_serialized = onnx_ir[stop : stop + model_size]
         offset = stop + model_size
 
         model_onnx = onnx.load_model_from_string(model_serialized)
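As a side note, here is a self-contained sketch of the length-prefixed layout that ``onnx_compiler`` writes and ``save_to_file`` parses above ([4-byte name size][name][4-byte model size][model]). The names ``pack_entry``/``unpack_entry`` are illustrative and not part of the module.

.. code-block:: python

    import struct

    def pack_entry(name_bytes, model_bytes):
        # [size][name][size][payload], matching the reader loop above
        return (
            struct.pack("I", len(name_bytes)) + name_bytes
            + struct.pack("I", len(model_bytes)) + model_bytes
        )

    def unpack_entry(blob, offset=0):
        (name_size,) = struct.unpack("I", blob[offset : offset + 4])
        name = blob[offset + 4 : offset + 4 + name_size].decode("utf-8")
        start = offset + 4 + name_size
        (model_size,) = struct.unpack("I", blob[start : start + 4])
        payload = blob[start + 4 : start + 4 + model_size]
        return name, payload, start + 4 + model_size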
index ae57211..10598e2 100644 (file)
@@ -26,42 +26,43 @@ import tvm
 TVMDD_TABLE_BODY_WIDTH = 30
 # Must match enum IterVarType defined in include/tvm/expr.h
 ITERVAR_TYPE_STRING_MAP = {
-    0: ('kDataPar', '#FFFFFF'),
-    1: ('kThreadIndex', '#2980B9'),
-    2: ('kCommReduce', '#FAD7A0'),
-    3: ('kOrdered', '#D35400'),
-    4: ('kOpaque', '#ABB2B9'),
-    5: ('kUnrolled', '#D2B4DE'),
-    6: ('kVectorized', '#AED6F1'),
-    7: ('kParallelized', '#F5B7B1'),
-    8: ('kTensorized', '#A9DFBF'),
+    0: ("kDataPar", "#FFFFFF"),
+    1: ("kThreadIndex", "#2980B9"),
+    2: ("kCommReduce", "#FAD7A0"),
+    3: ("kOrdered", "#D35400"),
+    4: ("kOpaque", "#ABB2B9"),
+    5: ("kUnrolled", "#D2B4DE"),
+    6: ("kVectorized", "#AED6F1"),
+    7: ("kParallelized", "#F5B7B1"),
+    8: ("kTensorized", "#A9DFBF"),
 }
 
 PALETTE = {
-    0: '#000000',
-    1: '#922B21',
-    2: '#76448A',
-    3: '#1F618D',
-    4: '#148F77',
-    5: '#B7950B',
-    6: '#AF601A',
-    7: '#F5B7B1',
-    8: '#A9DFBF',
+    0: "#000000",
+    1: "#922B21",
+    2: "#76448A",
+    3: "#1F618D",
+    4: "#148F77",
+    5: "#B7950B",
+    6: "#AF601A",
+    7: "#F5B7B1",
+    8: "#A9DFBF",
 }
 
 PALETTE_SIZE = 9
 
+
 def dom_path_to_string(dom_path, prefix=""):
     path_string = prefix
     for index in dom_path:
-        path_string = path_string + '_' + str(index)
+        path_string = path_string + "_" + str(index)
     return path_string
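A quick illustration of the ID scheme (values chosen arbitrarily for the example):

.. code-block:: python

    assert dom_path_to_string([2, 0, 1], "Stage") == "Stage_2_0_1"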
 
 
 def insert_dot_id(sch):
     """Insert unique ID for each node in the DOM tree.
-       They are used as Dot node ID.
-       """
+    They are used as Dot node ID.
+    """
     for stage_idx, stage in enumerate(sch["stages"]):
         dom_path = [stage_idx]
         stage["id"] = dom_path_to_string(dom_path, stage["type"])
@@ -79,7 +80,8 @@ def insert_dot_id(sch):
 
 class ObjectManager:
     """A helper class tracking schedule objects, e.g. stage, IterVar,
-       relationship, and tensor, to their DOM path."""
+    relationship, and tensor, to their DOM path."""
+
     def __init__(self, sch):
         self.dict = {}
         for stage_idx, stage in enumerate(sch.stages):
@@ -89,102 +91,111 @@ class ObjectManager:
             for rel_idx, rel in enumerate(stage.relations):
                 self.dict[rel] = [stage_idx, rel_idx]
             for tensor_idx in range(stage.op.num_outputs):
-                self.dict[frozenset({stage.op.name,
-                                     tensor_idx})] = [stage_idx, tensor_idx]
+                self.dict[frozenset({stage.op.name, tensor_idx})] = [stage_idx, tensor_idx]
 
     def get_dom_path(self, obj):
         if obj is None:
             return None
-        assert obj in self.dict, 'Node is no found.'
+        assert obj in self.dict, "Node is not found."
         return self.dict[obj]
 
 
 def get_or_create_dot_id(obj, prefix="", assert_on_missing=False):
     """If obj's ID has been registered, return it.
-       If not, either assert or create a unique and legal ID, register and
-       return it, according to assert_on_missing.
-       ID must be a unique and legal Dotty ID.
+    If not, either assert or create a unique and legal ID, register and
+    return it, according to assert_on_missing.
+    ID must be a unique and legal Dotty ID.
 
-        Parameters
-        ----------
-        obj : objet
-                    Serve as the key to the ID.
+     Parameters
+     ----------
+     obj : object
+                 Serves as the key to the ID.
 
-        prefix : string
-                    Prefix to attach to the ID.  Usually use obj's non-unique
-                    name as prefix.
+     prefix : string
+                 Prefix to attach to the ID.  Usually use obj's non-unique
+                 name as prefix.
 
-        assert_on_missing : bool
-                    Assert or not if object doesn't have a registered ID.
+     assert_on_missing : bool
+                 Assert or not if object doesn't have a registered ID.
     """
-    prefix = prefix.replace('.', '_')
+    prefix = prefix.replace(".", "_")
     if not hasattr(get_or_create_dot_id, "obj_id_dict"):
         get_or_create_dot_id.obj_id_dict = {}
     if obj not in get_or_create_dot_id.obj_id_dict:
         if assert_on_missing:
-            assert False, 'dot_id ' + str(obj) + ' has not been registered.'
+            assert False, "dot_id " + str(obj) + " has not been registered."
         else:
             get_or_create_dot_id.obj_id_dict[obj] = prefix + hex(id(obj))
     return get_or_create_dot_id.obj_id_dict[obj]
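The function attribute ``obj_id_dict`` acts as a simple memo table; a small illustration follows (the ``_Tag`` class is made up for the example):

.. code-block:: python

    class _Tag:
        pass

    tag = _Tag()
    first_id = get_or_create_dot_id(tag, prefix="stage_")
    # Subsequent lookups return the same registered ID.
    assert get_or_create_dot_id(tag, assert_on_missing=True) == first_id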
 
 
 def get_port_id(is_input, index):
-    return 'I_' + str(index) if is_input else 'O_' + str(index)
+    return "I_" + str(index) if is_input else "O_" + str(index)
 
 
 def get_itervar_type_info(iter_type):
-    assert iter_type < len(
-        ITERVAR_TYPE_STRING_MAP), 'Unknown IterVar type: ' + str(iter_type)
+    assert iter_type < len(ITERVAR_TYPE_STRING_MAP), "Unknown IterVar type: " + str(iter_type)
     return ITERVAR_TYPE_STRING_MAP[iter_type]
 
 
 def get_itervar_label_color(itervar, iv_type):
     type_info = get_itervar_type_info(iv_type)
-    return linebrk(
-        str(itervar["name"]) + '(' + type_info[0] + ')',
-        TVMDD_TABLE_BODY_WIDTH), type_info[1]
+    return (
+        linebrk(str(itervar["name"]) + "(" + type_info[0] + ")", TVMDD_TABLE_BODY_WIDTH),
+        type_info[1],
+    )
 
 
 def linebrk(s, n):
     """ Break input string s with <br/> for every n charactors."""
-    result = ''
+    result = ""
     j = 0
     for i, c in enumerate(s):
         if j == n and i != len(s) - 1:
-            result = result + '\n'
+            result = result + "\n"
             j = 0
         j = j + 1
         result = result + c
     result = html.escape(str(result), quote=True)
-    result = result.replace('\n', '<br/>')
+    result = result.replace("\n", "<br/>")
     return result
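For example, with a width of three characters (illustrative only):

.. code-block:: python

    assert linebrk("abcdef", 3) == "abc<br/>def"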
 
 
-def create_graph(name="", rankdir='BT'):
+def create_graph(name="", rankdir="BT"):
     graph = Digraph(name=name)
-    graph.graph_attr['rankdir'] = rankdir
+    graph.graph_attr["rankdir"] = rankdir
     return graph
 
 
 def itervar_label(itervar, index, index_color, label):
-    return '<TR><TD PORT="' + itervar[
-        "id"] + '" BGCOLOR="' + index_color + '">' + str(
-            index
-        ) + '</TD><TD BGCOLOR="white" PORT="itervar">' + label + '<br/>' + str(
-            itervar["properties"]["range"]) + '</TD></TR>'
+    return (
+        '<TR><TD PORT="'
+        + itervar["id"]
+        + '" BGCOLOR="'
+        + index_color
+        + '">'
+        + str(index)
+        + '</TD><TD BGCOLOR="white" PORT="itervar">'
+        + label
+        + "<br/>"
+        + str(itervar["properties"]["range"])
+        + "</TD></TR>"
+    )
 
 
 def stage_label(stage):
-    return stage['name'] + '<br/>Scope: ' + stage['properties']['scope']
+    return stage["name"] + "<br/>Scope: " + stage["properties"]["scope"]
 
 
 def legend_label():
+    """Generate legend labels."""
     label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="4">'
     for iter_type in ITERVAR_TYPE_STRING_MAP:
         name, color = ITERVAR_TYPE_STRING_MAP[iter_type]
-        label += '<TR><TD BGCOLOR="' + color + '"></TD>' \
-            + '<TD BGCOLOR="white">' + name + '</TD></TR>'
-    label += '</TABLE>>'
+        label += (
+            '<TR><TD BGCOLOR="' + color + '"></TD>' + '<TD BGCOLOR="white">' + name + "</TD></TR>"
+        )
+    label += "</TABLE>>"
     return label
 
 
@@ -194,10 +205,10 @@ def leaf_itervars(stage):
 
 
 def legend_dot(g):
-    with g.subgraph(name='cluster_legend') as subgraph:
-        subgraph.attr(label='Legend')
+    with g.subgraph(name="cluster_legend") as subgraph:
+        subgraph.attr(label="Legend")
         label = legend_label()
-        subgraph.node('legend', label, shape='none', margin='0')
+        subgraph.node("legend", label, shape="none", margin="0")
 
 
 def extract_dom_for_viz(sch, need_range=True):
@@ -207,10 +218,7 @@ def extract_dom_for_viz(sch, need_range=True):
     return s
 
 
-def dump_graph(dot_string,
-               show_svg=True,
-               dot_file_path='',
-               output_dot_string=False):
+def dump_graph(dot_string, show_svg=True, dot_file_path="", output_dot_string=False):
     """Output dot_string in various formats."""
     if dot_file_path:
         try:
@@ -218,12 +226,13 @@ def dump_graph(dot_string,
             dot_file.write(dot_string)
             dot_file.close()
         except IOError:
-            print('Cannot open file: ' + dot_file_path)
+            print("Cannot open file: " + dot_file_path)
     if show_svg:
         from IPython.display import display
         from IPython.display import SVG
+
         src = Source(dot_string)
-        display(SVG(src.pipe(format='svg')))
+        display(SVG(src.pipe(format="svg")))
     if output_dot_string:
         return dot_string
     return None
@@ -232,34 +241,32 @@ def dump_graph(dot_string,
 def dump_json(sch, need_range):
     """Serialize data for visualization from a schedule in JSON format.
 
-        Parameters
-        ----------
-        sch : schedule
-                    The schedule object to serialize
+    Parameters
+    ----------
+    sch : schedule
+                The schedule object to serialize
 
-        Returns
-        -------
-        json : string
-            Serialized JSON string
+    Returns
+    -------
+    json : string
+        Serialized JSON string
     """
+
     def encode_itervar(itervar, stage, index, range_map):
         """Extract and encode IterVar visualization data to a dictionary"""
-        ivrange = range_map[
-            itervar] if range_map is not None and itervar in range_map else None
+        ivrange = range_map[itervar] if range_map is not None and itervar in range_map else None
         bind_thread = None
         tensor_intrin = None
         if itervar in stage.iter_var_attrs:
             attr = stage.iter_var_attrs[itervar]
             iv_type = attr.iter_type
             # binding
-            bind_thread = str(
-                attr.bind_thread.var) if attr.bind_thread is not None else None
+            bind_thread = str(attr.bind_thread.var) if attr.bind_thread is not None else None
             # tensorization
             if attr.tensor_intrin is not None:
                 tensor_intrin = str(attr.tensor_intrin.body)
                 # remove the final \n
-                tensor_intrin = tensor_intrin[0:-1] if tensor_intrin[
-                    -1] == "\n" else tensor_intrin
+                tensor_intrin = tensor_intrin[0:-1] if tensor_intrin[-1] == "\n" else tensor_intrin
             else:
                 tensor_intrin = None
         else:
@@ -272,13 +279,14 @@ def dump_json(sch, need_range):
             "properties": {
                 "thread": bind_thread,
                 "intrin": tensor_intrin,
-                "range": str(ivrange) if ivrange is not None else 'range(N/A)',
-            }
+                "range": str(ivrange) if ivrange is not None else "range(N/A)",
+            },
         }
         return itervar_dict
 
     def encode_itervars(stage, range_map):
         """Extract and encode IterVars visualization data from a stage to a dictionary"""
+
         def get_leaf_itervar_index(itervar, leaf_iv):
             for leaf_index, ivar in enumerate(leaf_iv):
                 if ivar == itervar:
@@ -288,15 +296,14 @@ def dump_json(sch, need_range):
         itervars = []
         for itervar in stage.all_iter_vars:
             leaf_index = get_leaf_itervar_index(itervar, stage.leaf_iter_vars)
-            itervars.append(
-                encode_itervar(itervar, stage, leaf_index, range_map))
+            itervars.append(encode_itervar(itervar, stage, leaf_index, range_map))
         return itervars
 
     def encode_itervar_relation(obj_manager, rel):
         """Extract and encode IterVar Relationship visualization data to a dictionary"""
         rel_type = type(rel)
         if rel_type is tvm.te.schedule.Split:
-            node_type = 'Split_Relation'
+            node_type = "Split_Relation"
             rel_dict = {
                 "type": node_type,
                 "parent": obj_manager.get_dom_path(rel.parent),
@@ -304,7 +311,7 @@ def dump_json(sch, need_range):
                 "inner": obj_manager.get_dom_path(rel.inner),
             }
         elif rel_type is tvm.te.schedule.Fuse:
-            node_type = 'Fuse_Relation'
+            node_type = "Fuse_Relation"
             rel_dict = {
                 "type": node_type,
                 "fused": obj_manager.get_dom_path(rel.fused),
@@ -312,7 +319,7 @@ def dump_json(sch, need_range):
                 "inner": obj_manager.get_dom_path(rel.inner),
             }
         elif rel_type is tvm.te.schedule.Singleton:
-            node_type = 'Singleton_Relation'
+            node_type = "Singleton_Relation"
             rel_dict = {
                 "type": node_type,
                 "iter": obj_manager.get_dom_path(rel.iter),
@@ -351,28 +358,20 @@ def dump_json(sch, need_range):
     def encode_stage(obj_manager, stage, range_map):
         """Extract and encode stage visualization data to a dictionary"""
         stage_dict = {
-            "type":
-            "Stage",
-            "name":
-            stage.op.name,
-            "attaching_to":
-            obj_manager.get_dom_path(stage.attach_ivar),
-            "compute":
-            str(stage.op.body) if hasattr(stage.op, 'body') else None,
+            "type": "Stage",
+            "name": stage.op.name,
+            "attaching_to": obj_manager.get_dom_path(stage.attach_ivar),
+            "compute": str(stage.op.body) if hasattr(stage.op, "body") else None,
             "properties": {
                 "scope": stage.scope,
             },
-            "all_itervars":
-            encode_itervars(stage, range_map),
-            "relations":
-            encode_itervar_relations(obj_manager, stage),
+            "all_itervars": encode_itervars(stage, range_map),
+            "relations": encode_itervar_relations(obj_manager, stage),
             "input_tensors": [
-                obj_manager.get_dom_path(
-                    frozenset({tensor.op.name, tensor.value_index}))
+                obj_manager.get_dom_path(frozenset({tensor.op.name, tensor.value_index}))
                 for tensor in stage.op.input_tensors
             ],
-            "output_tensors":
-            encode_tensors(obj_manager, stage),
+            "output_tensors": encode_tensors(obj_manager, stage),
         }
         return stage_dict
 
@@ -390,16 +389,18 @@ def dump_json(sch, need_range):
             dict : dictionary
                 A nested dictionary
         """
-        assert isinstance(sch, tvm.te.schedule.Schedule
-                          ), 'Input is not a tvm.te.schedule.Schedule object.'
+        assert isinstance(
+            sch, tvm.te.schedule.Schedule
+        ), "Input is not a tvm.te.schedule.Schedule object."
         range_map = None
         if need_range:
             try:
                 range_map = tvm.te.schedule.InferBound(sch)
             except tvm._ffi.base.TVMError as expt:
                 warnings.warn(
-                    'Ranges are not available, because InferBound fails with the following error:\n'
-                    + str(expt))
+                    "Ranges are not available, because InferBound fails with the following error:\n"
+                    + str(expt)
+                )
 
         obj_manager = ObjectManager(sch)
         stages = []
@@ -413,78 +414,87 @@ def dump_json(sch, need_range):
     return json.dumps(sch, default=lambda s: encode_schedule(s, need_range))
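One possible way to consume this output, assuming ``sch`` is an existing ``tvm.te.schedule.Schedule`` (not created here):

.. code-block:: python

    import json

    viz = json.loads(dump_json(sch, need_range=True))
    print([stage["name"] for stage in viz["stages"]])  # stage names in schedule order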
 
 
-def viz_schedule_tree(sch,
-                      show_svg=False,
-                      dot_file_path='',
-                      output_dot_string=False):
+def viz_schedule_tree(sch, show_svg=False, dot_file_path="", output_dot_string=False):
     """Top level API to render schedule tree
 
-        Parameters
-        ----------
-        sch : schedule
-                    The schedule object to visualize
+    Parameters
+    ----------
+    sch : schedule
+                The schedule object to visualize
 
-        show_svg : bool
-                    Display graph as SVG, useful for Jupyter notebooks.
+    show_svg : bool
+                Display graph as SVG, useful for Jupyter notebooks.
 
-        dot_file_path : string
-                    Dot file to save the graph.
+    dot_file_path : string
+                Dot file to save the graph.
 
-        output_dot_string : bool
-                    Return dot file content or an empty string.
+    output_dot_string : bool
+                Return dot file content or an empty string.
 
-        Returns
-        -------
-        dot_string : string
-            Dot file content or an empty string according to output_dot_string
+    Returns
+    -------
+    dot_string : string
+        Dot file content or an empty string according to output_dot_string
 
-        Examples
-        --------
-        The following code writes a schedule tree to a dot file.
+    Examples
+    --------
+    The following code writes a schedule tree to a dot file.
 
-        .. code-block:: python
-            tedd.viz_schedule_tree(s, dot_file_path = '/tmp/example.dot')
+    .. code-block:: python
+        tedd.viz_schedule_tree(s, dot_file_path = '/tmp/example.dot')
 
-        Use the following code to render a SVG graph in a Jupyter notebook.
+    Use the following code to render a SVG graph in a Jupyter notebook.
 
-        .. code-block:: python
-            tedd.viz_schedule_tree(s, show_svg = True)
+    .. code-block:: python
+        tedd.viz_schedule_tree(s, show_svg = True)
     """
+
     def create_schedule_tree_graph(name=""):
-        return create_graph(name=name, rankdir='BT')
+        return create_graph(name=name, rankdir="BT")
 
     def root_dot(g):
-        g.node('ROOT', 'ROOT', shape='oval', margin='0')
+        g.node("ROOT", "ROOT", shape="oval", margin="0")
 
     def stage_node_dot(g, stage):
         node_label = stage_node_label(stage)
-        g.node(stage['id'], node_label, shape='none', margin='0')
+        g.node(stage["id"], node_label, shape="none", margin="0")
 
     def stage_node_label(stage):
         """Return a html format label for the given stage."""
-        label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" ' \
-            'CELLPADDING="4"> <TR><TD BGCOLOR="lightgrey" ' \
-            'COLSPAN="2" PORT="stage">' + stage_label(stage) + '</TD></TR>'
+        label = (
+            '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" '
+            'CELLPADDING="4"> <TR><TD BGCOLOR="lightgrey" '
+            'COLSPAN="2" PORT="stage">' + stage_label(stage) + "</TD></TR>"
+        )
 
         for leafiv in leaf_itervars(stage):
             iv_type = leafiv["itervar_type"]
-            var_attr_label = ''
-            if "thread" in leafiv["properties"] and \
-                    leafiv["properties"]["thread"] is not None:
-                var_attr_label = var_attr_label + "<br/><font color=\"#2980B9\">(" + str(
-                    leafiv["properties"]["thread"]) + ")</font>"
-            if "intrin" in leafiv["properties"] and \
-                    leafiv["properties"]["intrin"] is not None:
-                var_attr_label = var_attr_label + "<br/>" + \
-                    linebrk("(tensor_intrin:" + str(
-                        leafiv["properties"]["intrin"]) + ")", TVMDD_TABLE_BODY_WIDTH)
+            var_attr_label = ""
+            if "thread" in leafiv["properties"] and leafiv["properties"]["thread"] is not None:
+                var_attr_label = (
+                    var_attr_label
+                    + '<br/><font color="#2980B9">('
+                    + str(leafiv["properties"]["thread"])
+                    + ")</font>"
+                )
+            if "intrin" in leafiv["properties"] and leafiv["properties"]["intrin"] is not None:
+                var_attr_label = (
+                    var_attr_label
+                    + "<br/>"
+                    + linebrk(
+                        "(tensor_intrin:" + str(leafiv["properties"]["intrin"]) + ")",
+                        TVMDD_TABLE_BODY_WIDTH,
+                    )
+                )
             var_label, color = get_itervar_label_color(leafiv, iv_type)
-            label += itervar_label(leafiv, leafiv["index"], color,
-                                   var_label + var_attr_label)
+            label += itervar_label(leafiv, leafiv["index"], color, var_label + var_attr_label)
         if stage["compute"] is not None:
-            label += '<TR><TD COLSPAN="2">' + linebrk(str(
-                stage["compute"]), TVMDD_TABLE_BODY_WIDTH) + '</TD></TR>'
-        label += '</TABLE>>'
+            label += (
+                '<TR><TD COLSPAN="2">'
+                + linebrk(str(stage["compute"]), TVMDD_TABLE_BODY_WIDTH)
+                + "</TD></TR>"
+            )
+        label += "</TABLE>>"
         return label
 
     def compute_at_dot(g, stage):
@@ -492,260 +502,268 @@ def viz_schedule_tree(sch,
         stage to its attach point; otherwise, create an edge to the ROOT.
         """
         src = stage["id"]
-        dst = dom_path_to_string(
-            [stage["attaching_to"][0]], "Stage") + ":" + dom_path_to_string(
-                stage["attaching_to"],
-                "IterVar") if stage["attaching_to"] is not None else "ROOT"
-        color = PALETTE[
-            stage["attaching_to"][1] +
-            1] if stage["attaching_to"] is not None and stage["attaching_to"][
-                1] < PALETTE_SIZE - 1 else PALETTE[0]
+        dst = (
+            dom_path_to_string([stage["attaching_to"][0]], "Stage")
+            + ":"
+            + dom_path_to_string(stage["attaching_to"], "IterVar")
+            if stage["attaching_to"] is not None
+            else "ROOT"
+        )
+        color = (
+            PALETTE[stage["attaching_to"][1] + 1]
+            if stage["attaching_to"] is not None and stage["attaching_to"][1] < PALETTE_SIZE - 1
+            else PALETTE[0]
+        )
         g.edge(src, dst, color=color)
 
     graph = create_schedule_tree_graph("Schedule Tree")
     s = extract_dom_for_viz(sch)
     legend_dot(graph)
-    for stage in s['stages']:
+    for stage in s["stages"]:
         stage_node_dot(graph, stage)
-    for stage in s['stages']:
+    for stage in s["stages"]:
         compute_at_dot(graph, stage)
     root_dot(graph)
     return dump_graph(graph.source, show_svg, dot_file_path, output_dot_string)
 
 
-def viz_itervar_relationship_graph(sch,
-                                   show_svg=False,
-                                   dot_file_path='',
-                                   output_dot_string=False):
+def viz_itervar_relationship_graph(sch, show_svg=False, dot_file_path="", output_dot_string=False):
     """Top level API to render IterVar relationship graph
 
-        Parameters
-        ----------
-        sch : schedule
-                    The schedule object to visualize
+    Parameters
+    ----------
+    sch : schedule
+                The schedule object to visualize
 
-        show_svg : bool
-                    Display graph as SVG, useful for Jupyter notebooks.
+    show_svg : bool
+                Display graph as SVG, useful for Jupyter notebooks.
 
-        dot_file_path : string
-                    Dot file to save the graph.
+    dot_file_path : string
+                Dot file to save the graph.
 
-        output_dot_string : bool
-                    Return dot file content or an empty string.
+    output_dot_string : bool
+                Return dot file content or an empty string.
 
-        Examples
-        --------
-        The following code writes Ian tervar relationship graph to a dot file.
+    Examples
+    --------
+    The following code writes an itervar relationship graph to a dot file.
 
-        .. code-block:: python
-            tedd.viz_def viz_itervar_relationship_graph(sch,
-                (s, dot_file_path = '/tmp/example.dot')
+    .. code-block:: python
+        tedd.viz_itervar_relationship_graph(s, dot_file_path = '/tmp/example.dot')
 
-        Use the following code to render a SVG graph in a Jupyter notebook.
+    Use the following code to render a SVG graph in a Jupyter notebook.
 
-        .. code-block:: python
-            tedd.viz_def viz_itervar_relationship_graph(sch,
-                (s, show_svg = True)
+    .. code-block:: python
+        tedd.viz_itervar_relationship_graph(s, show_svg = True)
     """
+
     def create_itervar_relation_graph(name=""):
-        return create_graph(name=name, rankdir='TB')
+        return create_graph(name=name, rankdir="TB")
 
     def itervar_node_dot(g, itervar, iv_type, index):
         label = itervar_node_label(itervar, iv_type, index)
-        g.node(itervar["id"], label, shape='none', margin='0')
+        g.node(itervar["id"], label, shape="none", margin="0")
 
     def itervar_node_label(itervar, iv_type, index):
-        label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" ' \
-            'CELLPADDING="4">' + itervar_label(
-                itervar, index,
+        label = (
+            '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" '
+            'CELLPADDING="4">'
+            + itervar_label(
+                itervar,
+                index,
                 get_itervar_label_color(itervar, iv_type)[1],
-                get_itervar_label_color(itervar, iv_type)[0]) + '</TABLE>>'
+                get_itervar_label_color(itervar, iv_type)[0],
+            )
+            + "</TABLE>>"
+        )
         return label
 
-    def itervar_relation_node_dot(g, node_id, node_label, input_ports,
-                                  output_ports):
-        label = itervar_relation_node_label(node_label, input_ports,
-                                            output_ports)
-        g.node(node_id, label, shape='none', margin='0')
+    def itervar_relation_node_dot(g, node_id, node_label, input_ports, output_ports):
+        label = itervar_relation_node_label(node_label, input_ports, output_ports)
+        g.node(node_id, label, shape="none", margin="0")
 
     def itervar_relation_node_label(node_label, input_ports, output_ports):
         """Return a html format label for an itervar relationship node
         including node_label and input/output ports.
         """
-        label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" ' \
-            'CELLPADDING="4">' + '<TR>'
+        label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" ' 'CELLPADDING="4">' + "<TR>"
         max_port_num = max(len(input_ports), len(output_ports))
         for i in range(max_port_num):
             if i < len(input_ports):
                 input_port = input_ports[i]
-                label += '<TD BGCOLOR="lightgrey" PORT="' + input_port + '">' \
-                    + input_port + '</TD>'
+                label += '<TD BGCOLOR="lightgrey" PORT="' + input_port + '">' + input_port + "</TD>"
             else:
                 label += '<TD BGCOLOR="white"></TD>'
-        label += '</TR>'
-        label += '<TR><TD BGCOLOR="white" COLSPAN="' + str(
-            max_port_num) + '" PORT="relation">' + node_label + '</TD></TR>'
-        label += '<TR>'
+        label += "</TR>"
+        label += (
+            '<TR><TD BGCOLOR="white" COLSPAN="'
+            + str(max_port_num)
+            + '" PORT="relation">'
+            + node_label
+            + "</TD></TR>"
+        )
+        label += "<TR>"
         for i in range(max_port_num):
             if i < len(output_ports):
                 output_port = output_ports[i]
-                label += '<TD BGCOLOR="lightgrey" PORT="' + output_port + '">' \
-                    + output_port + '</TD>'
+                label += (
+                    '<TD BGCOLOR="lightgrey" PORT="' + output_port + '">' + output_port + "</TD>"
+                )
             else:
                 label += '<TD BGCOLOR="white"></TD>'
-        label += '</TR>'
-        label += '</TABLE>>'
+        label += "</TR>"
+        label += "</TABLE>>"
         return label
 
     def itervar_relation_dot(g, node, node_id):
         """Create an itervar relationship node."""
         node_type = node["type"]
         if node_type == "Split_Relation":
-            node_type = 'Split'
-            itervar_relation_node_dot(g, node_id, node_type, ['Input'],
-                                      ['Outer', 'Inner'])
+            node_type = "Split"
+            itervar_relation_node_dot(g, node_id, node_type, ["Input"], ["Outer", "Inner"])
             parent = dom_path_to_string(node["parent"], "IterVar")
             outer = dom_path_to_string(node["outer"], "IterVar")
             inner = dom_path_to_string(node["inner"], "IterVar")
-            g.edge(parent + ':itervar', node_id + ':Input')
-            g.edge(node_id + ':Outer', outer + ':itervar')
-            g.edge(node_id + ':Inner', inner + ':itervar')
+            g.edge(parent + ":itervar", node_id + ":Input")
+            g.edge(node_id + ":Outer", outer + ":itervar")
+            g.edge(node_id + ":Inner", inner + ":itervar")
         elif node_type == "Fuse_Relation":
-            node_type = 'Fuse'
-            itervar_relation_node_dot(g, node_id, node_type,
-                                      ['Outer', 'Inner'], ['Fused'])
+            node_type = "Fuse"
+            itervar_relation_node_dot(g, node_id, node_type, ["Outer", "Inner"], ["Fused"])
             fused = dom_path_to_string(node["fused"], "IterVar")
             outer = dom_path_to_string(node["outer"], "IterVar")
             inner = dom_path_to_string(node["inner"], "IterVar")
-            g.edge(outer + ':itervar', node_id + ':Outer')
-            g.edge(inner + ':itervar', node_id + ':Inner')
-            g.edge(node_id + ':Fused', fused + ':itervar')
+            g.edge(outer + ":itervar", node_id + ":Outer")
+            g.edge(inner + ":itervar", node_id + ":Inner")
+            g.edge(node_id + ":Fused", fused + ":itervar")
         elif node_type == "Singleton_Relation":
-            node_type = 'Singleton'
-            itervar_relation_node_dot(g, node_id, node_type, [], ['Iter'])
+            node_type = "Singleton"
+            itervar_relation_node_dot(g, node_id, node_type, [], ["Iter"])
             itervar = dom_path_to_string(node["inner"], "IterVar")
-            g.edge(node_id + ':Iter', itervar + ':itervar')
+            g.edge(node_id + ":Iter", itervar + ":itervar")
         else:
-            assert False, 'Unknown IterVarRelationNode: ' + node_type
+            assert False, "Unknown IterVarRelationNode: " + node_type
 
     def stage_node_dot(g, stage):
         """Create a stage node."""
-        with g.subgraph(name='cluster_' + stage["id"]) as subgraph:
+        with g.subgraph(name="cluster_" + stage["id"]) as subgraph:
             subgraph.attr(label=stage["name"])
             if stage["all_itervars"]:
                 for itervar in stage["all_itervars"]:
                     iv_type = itervar["itervar_type"]
-                    itervar_node_dot(subgraph, itervar, iv_type,
-                                     itervar["index"])
+                    itervar_node_dot(subgraph, itervar, iv_type, itervar["index"])
                 for rel in stage["relations"]:
                     node_id = rel["id"]
                     itervar_relation_dot(subgraph, rel, node_id)
             else:
-                subgraph.node(stage["name"] + '_placeholder', style='invis')
+                subgraph.node(stage["name"] + "_placeholder", style="invis")
 
     graph = create_itervar_relation_graph("IterVar Relationship Graph")
     s = extract_dom_for_viz(sch)
     legend_dot(graph)
-    for stage in s['stages']:
+    for stage in s["stages"]:
         stage_node_dot(graph, stage)
 
     return dump_graph(graph.source, show_svg, dot_file_path, output_dot_string)
 
 
-def viz_dataflow_graph(sch,
-                       show_svg=False,
-                       dot_file_path='',
-                       output_dot_string=False):
+def viz_dataflow_graph(sch, show_svg=False, dot_file_path="", output_dot_string=False):
     """Top level API to render dataflow graph
 
-        Parameters
-        ----------
-        sch : schedule
-                    The schedule object to visualize
+    Parameters
+    ----------
+    sch : schedule
+                The schedule object to visualize
+
+    show_svg : bool
+                Display graph as SVG, useful for Jupyter notebooks.
 
-        show_svg : bool
-                    Display graph as SVG, useful for Jupyter notebooks.
+    dot_file_path : string
+                Dot file to save the graph.
 
-        dot_file_path : string
-                    Dot file to save the graph.
+    output_dot_string : bool
+                Return dot file content or an empty string.
 
-        output_dot_string : bool
-                    Return dot file content or an empty string.
+    Examples
+    --------
+    The following code writes a dataflow graph to a dot file.
 
-        Examples
-        --------
-        The following code writes a dataflow graph to a dot file.
+    .. code-block:: python
+        tedd.viz_dataflow_graph(s, dot_file_path = '/tmp/example.dot')
 
-        .. code-block:: python
-            tedd.viz_dataflow_graph(s, dot_file_path = '/tmp/example.dot')
+    Use the following code to render a SVG graph in a Jupyter notebook.
 
-        Use the following code to render a SVG graph in a Jupyter notebook.
+    .. code-block:: python
+        tedd.viz_dataflow_graph(s, show_svg = True)"""
 
-        .. code-block:: python
-            tedd.viz_dataflow_graph(s, show_svg = True)    """
     def create_dataflow_graph(name=""):
-        return create_graph(name=name, rankdir='LR')
+        return create_graph(name=name, rankdir="LR")
 
     def tensor_node_dot(g, tensor):
         """Create a tensor node."""
         label = tensor_node_label(tensor)
-        g.node(tensor["id"], label, shape='oval', margin='0')
+        g.node(tensor["id"], label, shape="oval", margin="0")
 
     def tensor_node_label(tensor):
         """Return a html format label for the given tensor."""
-        label = str(tensor["shape"]) + '\n' + str(tensor["data_type"])
+        label = str(tensor["shape"]) + "\n" + str(tensor["data_type"])
         return label
 
     def stage_node_dot(g, stage):
         """Create a stage node."""
         label = stage_node_label(stage)
-        g.node(stage["id"], label, shape='none', margin='0')
+        g.node(stage["id"], label, shape="none", margin="0")
 
     def stage_node_label(stage):
         """Return a html format label for the given stage."""
-        rows = max(
-            1, max(len(stage["output_tensors"]), len(stage["input_tensors"])))
-        label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" ' \
-            'CELLPADDING="4">'
+        rows = max(1, max(len(stage["output_tensors"]), len(stage["input_tensors"])))
+        label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" ' 'CELLPADDING="4">'
         for i in range(rows):
-            label += '<TR>'
+            label += "<TR>"
             if i < len(stage["input_tensors"]):
                 port_id = get_port_id(True, i)
-                label += '<TD BGCOLOR="lightgrey" COLSPAN="2" PORT="' \
-                    + port_id + '">' + str(
-                        i) + '</TD>'
+                label += (
+                    '<TD BGCOLOR="lightgrey" COLSPAN="2" PORT="' + port_id + '">' + str(i) + "</TD>"
+                )
             else:
                 label += '<TD BGCOLOR="white" COLSPAN="2"></TD>'
             if i == 0:
-                label += '<TD BGCOLOR="white" COLSPAN="2" ROWSPAN="' + str(
-                    rows) + '">' + stage_label(stage) + '</TD>'
+                label += (
+                    '<TD BGCOLOR="white" COLSPAN="2" ROWSPAN="'
+                    + str(rows)
+                    + '">'
+                    + stage_label(stage)
+                    + "</TD>"
+                )
             if i < len(stage["output_tensors"]):
                 port_id = get_port_id(False, i)
-                label += '<TD BGCOLOR="lightgrey" COLSPAN="2" PORT="' \
-                    + port_id + '">' + str(
-                        i) + '</TD>'
+                label += (
+                    '<TD BGCOLOR="lightgrey" COLSPAN="2" PORT="' + port_id + '">' + str(i) + "</TD>"
+                )
             else:
                 label += '<TD BGCOLOR="white" COLSPAN="2"></TD>'
-            label += '</TR>'
-        label += '</TABLE>>'
+            label += "</TR>"
+        label += "</TABLE>>"
         return label
 
     def dfg_dot(g, sch):
         """Create edges among stages."""
-        stages = sch['stages']
+        stages = sch["stages"]
         for stage in stages:
             for i in range(len(stage["input_tensors"])):
                 src = dom_path_to_string(stage["input_tensors"][i], "Tensor")
-                dst = stage["id"] + ':' + get_port_id(True, i)
+                dst = stage["id"] + ":" + get_port_id(True, i)
                 g.edge(src, dst)
             for i in range(len(stage["output_tensors"])):
-                src = stage["id"] + ':' + get_port_id(False, i)
+                src = stage["id"] + ":" + get_port_id(False, i)
                 dst = stage["output_tensors"][i]["id"]
                 g.edge(src, dst)
 
     graph = create_dataflow_graph("Dataflow Graph")
     s = extract_dom_for_viz(sch, need_range=False)
-    for stage in s['stages']:
+    for stage in s["stages"]:
         stage_node_dot(graph, stage)
         for tensor in stage["output_tensors"]:
             tensor_node_dot(graph, tensor)
index 7daf45f..2572d5b 100644 (file)
@@ -72,13 +72,15 @@ class TensorFunc:
         self.tvm_dso_op = self.module.tvm_dso_op
 
     def apply(self, *params):
-        return self.tvm_dso_op(params,
-                               dynamic_output_shape=self.dynamic_output_shape,
-                               static_output_shape=self.static_output_shape,
-                               has_static_output_shape=self.has_static_output_shape,
-                               lib_path=self.lib_path,
-                               func_name=self.func_name,
-                               output_dtype=self.output_dtype)
+        return self.tvm_dso_op(
+            params,
+            dynamic_output_shape=self.dynamic_output_shape,
+            static_output_shape=self.static_output_shape,
+            has_static_output_shape=self.has_static_output_shape,
+            lib_path=self.lib_path,
+            func_name=self.func_name,
+            output_dtype=self.output_dtype,
+        )
 
     def __call__(self, *params):
         return self.apply(*params)
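The class above follows a capture-then-forward pattern. A generic, hypothetical sketch of that pattern (independent of TensorFlow and of the actual ``tvm_dso_op`` signature):

.. code-block:: python

    class ForwardingFunc:
        """Capture keyword configuration once, forward it on every call."""

        def __init__(self, op, **config):
            self._op = op
            self._config = config

        def __call__(self, *params):
            return self._op(params, **self._config)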
index 985c747..92501f9 100644 (file)
@@ -18,7 +18,8 @@
 import tvm._ffi
 from ..rpc import base as rpc_base
 
-def create(tflite_model_bytes, ctx, runtime_target='cpu'):
+
+def create(tflite_model_bytes, ctx, runtime_target="cpu"):
     """Create a runtime executor module given a tflite model and context.
     Parameters
     ----------
@@ -36,7 +37,7 @@ def create(tflite_model_bytes, ctx, runtime_target='cpu'):
     """
     device_type = ctx.device_type
 
-    if runtime_target == 'edge_tpu':
+    if runtime_target == "edge_tpu":
         runtime_func = "tvm.edgetpu_runtime.create"
     else:
         runtime_func = "tvm.tflite_runtime.create"
index 8f6dfc7..f3397ce 100644 (file)
@@ -22,6 +22,7 @@ import os
 import tempfile
 import threading
 import shutil
+
 try:
     import fcntl
 except ImportError:
@@ -31,6 +32,7 @@ except ImportError:
 class DirectoryCreatedPastAtExit(Exception):
     """Raised when a TempDirectory is created after the atexit hook runs."""
 
+
 class TempDirectory(object):
     """Helper object to manage temp directory during testing.
 
@@ -44,6 +46,7 @@ class TempDirectory(object):
     # In debug mode, each tempdir is named after the sequence
     _NUM_TEMPDIR_CREATED = 0
     _NUM_TEMPDIR_CREATED_LOCK = threading.Lock()
+
     @classmethod
     def _increment_num_tempdir_created(cls):
         with cls._NUM_TEMPDIR_CREATED_LOCK:
@@ -53,20 +56,23 @@ class TempDirectory(object):
         return to_return
 
     _DEBUG_PARENT_DIR = None
+
     @classmethod
     def _get_debug_parent_dir(cls):
         if cls._DEBUG_PARENT_DIR is None:
-            all_parents = f'{tempfile.gettempdir()}/tvm-debug-mode-tempdirs'
+            all_parents = f"{tempfile.gettempdir()}/tvm-debug-mode-tempdirs"
             if not os.path.isdir(all_parents):
                 os.makedirs(all_parents)
             cls._DEBUG_PARENT_DIR = tempfile.mkdtemp(
-                prefix=datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S___'), dir=all_parents)
+                prefix=datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S___"), dir=all_parents
+            )
         return cls._DEBUG_PARENT_DIR
 
     TEMPDIRS = set()
+
     @classmethod
     def remove_tempdirs(cls):
-        temp_dirs = getattr(cls, 'TEMPDIRS', None)
+        temp_dirs = getattr(cls, "TEMPDIRS", None)
         if temp_dirs is None:
             return
 
@@ -97,7 +103,7 @@ class TempDirectory(object):
         else:
             if self._created_with_keep_for_debug:
                 parent_dir = self._get_debug_parent_dir()
-                self.temp_dir = f'{parent_dir}/{self._increment_num_tempdir_created():05d}'
+                self.temp_dir = f"{parent_dir}/{self._increment_num_tempdir_created():05d}"
                 os.mkdir(self.temp_dir)
             else:
                 self.temp_dir = tempfile.mkdtemp()
@@ -114,7 +120,7 @@ class TempDirectory(object):
             self.temp_dir = None
 
     def __del__(self):
-        temp_dirs = getattr(self, 'TEMPDIRS', None)
+        temp_dirs = getattr(self, "TEMPDIRS", None)
         if temp_dirs is None:
             # Do nothing if the atexit hook has already run.
             return
@@ -174,12 +180,12 @@ class FileLock(object):
     path : str
         The path to the lock
     """
+
     def __init__(self, path):
         self.lock_file = open(path, "w")
         if fcntl:
             fcntl.lockf(self.lock_file, fcntl.LOCK_EX)
 
-
     def release(self):
         """Release the lock"""
         if self.lock_file:
index dd067c3..13bd747 100644 (file)
@@ -25,6 +25,7 @@ import json
 from .._ffi.base import py_str
 from . import util
 
+
 def xcrun(cmd):
     """Run xcrun and return the output.
 
@@ -39,9 +40,7 @@ def xcrun(cmd):
         The output string.
     """
     cmd = ["xcrun"] + cmd
-    proc = subprocess.Popen(cmd,
-                            stdout=subprocess.PIPE,
-                            stderr=subprocess.STDOUT)
+    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
     (out, _) = proc.communicate()
     return out.strip()
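For instance, the helper can be used to locate the active SDK (a standard ``xcrun`` query, shown here only as an illustration):

.. code-block:: python

    sdk_path = xcrun(["--sdk", "macosx", "--show-sdk-path"])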
 
@@ -57,14 +56,11 @@ def codesign(lib):
     lib : The path to the library.
     """
     if "TVM_IOS_CODESIGN" not in os.environ:
-        raise RuntimeError("Require environment variable TVM_IOS_CODESIGN "
-                           " to be the signature")
+        raise RuntimeError("Require environment variable TVM_IOS_CODESIGN " " to be the signature")
     signature = os.environ["TVM_IOS_CODESIGN"]
     cmd = ["codesign", "--force", "--sign", signature]
     cmd += [lib]
-    proc = subprocess.Popen(cmd,
-                            stdout=subprocess.PIPE,
-                            stderr=subprocess.STDOUT)
+    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
     (out, _) = proc.communicate()
     if proc.returncode != 0:
         msg = "Codesign error:\n"
@@ -104,9 +100,7 @@ def create_dylib(output, objects, arch, sdk="macosx"):
     else:
         cmd += objects
 
-    proc = subprocess.Popen(
-        cmd, stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT)
+    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
     (out, _) = proc.communicate()
 
     if proc.returncode != 0:
@@ -118,6 +112,7 @@ def create_dylib(output, objects, arch, sdk="macosx"):
 # assign so as default output format
 create_dylib.output_format = "dylib"
 
+
 def compile_metal(code, path_target=None, sdk="macosx"):
     """Compile metal with CLI tool from env.
 
@@ -156,10 +151,11 @@ def compile_metal(code, path_target=None, sdk="macosx"):
     cmd2 = ["xcrun", "-sdk", sdk, "metallib"]
     cmd2 += [temp_ir, "-o", file_target]
     proc = subprocess.Popen(
-        ' '.join(cmd1) + ";" + ' '.join(cmd2),
+        " ".join(cmd1) + ";" + " ".join(cmd2),
         shell=True,
         stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT)
+        stderr=subprocess.STDOUT,
+    )
     (out, _) = proc.communicate()
     if proc.returncode != 0:
         sys.stderr.write("Compilation error:\n")
@@ -172,14 +168,10 @@ def compile_metal(code, path_target=None, sdk="macosx"):
 
 
 def compile_coreml(model, model_name="main", out_dir="."):
-    """Compile coreml model and return the compiled model path.
-    """
+    """Compile coreml model and return the compiled model path."""
     mlmodel_path = os.path.join(out_dir, model_name + ".mlmodel")
     mlmodelc_path = os.path.join(out_dir, model_name + ".mlmodelc")
-    metadata = {
-        "inputs": list(model.input_description),
-        "outputs": list(model.output_description)
-    }
+    metadata = {"inputs": list(model.input_description), "outputs": list(model.output_description)}
     # Use the description field to send info to CoreML runtime
     model.short_description = json.dumps(metadata)
     model.save(mlmodel_path)
@@ -202,23 +194,18 @@ class XCodeRPCServer(object):
     lock: FileLock
        Lock on the path
     """
+
     def __init__(self, cmd, lock):
         self.proc = subprocess.Popen(cmd)
         self.lock = lock
 
     def join(self):
-        """Wait server to finish and release its resource
-        """
+        """Wait server to finish and release its resource"""
         self.proc.wait()
         self.lock.release()
 
 
-def popen_test_rpc(host,
-                   port,
-                   key,
-                   destination,
-                   libs=None,
-                   options=None):
+def popen_test_rpc(host, port, key, destination, libs=None, options=None):
     """Launch rpc server via xcodebuild test through another process.
 
     Parameters
@@ -255,8 +242,10 @@ def popen_test_rpc(host,
         rpc_root = os.path.join(curr_path, "../../../apps/ios_rpc")
     proj_path = os.path.realpath(os.path.join(rpc_root, "tvmrpc.xcodeproj"))
     if not os.path.exists(proj_path):
-        raise RuntimeError("Cannot find tvmrpc.xcodeproj in %s," +
-                           (" please set env TVM_IOS_RPC_ROOT correctly" % rpc_root))
+        raise RuntimeError(
+            "Cannot find tvmrpc.xcodeproj in %s,"
+            " please set env TVM_IOS_RPC_ROOT correctly" % rpc_root
+        )
 
     # Lock the path so only one file can run
     lock = util.filelock(os.path.join(rpc_root, "ios_rpc.lock"))
@@ -267,10 +256,16 @@ def popen_test_rpc(host,
         for file_name in libs:
             fo.write("%s\n" % file_name)
 
-    cmd = ["xcrun", "xcodebuild",
-           "-scheme", "tvmrpc",
-           "-project", proj_path,
-           "-destination", destination]
+    cmd = [
+        "xcrun",
+        "xcodebuild",
+        "-scheme",
+        "tvmrpc",
+        "-project",
+        proj_path,
+        "-destination",
+        destination,
+    ]
     if options:
         cmd += options
     cmd += ["test"]
index e24e799..1c11a6c 100644 (file)
@@ -64,10 +64,8 @@ def get_binds(args, compact=False, binds=None):
             buffer_type = "auto_broadcast" if any_dim and not compact else ""
             if x not in binds:
                 buf = tvm.tir.decl_buffer(
-                    x.shape,
-                    dtype=x.dtype,
-                    name=x.name,
-                    buffer_type=buffer_type)
+                    x.shape, dtype=x.dtype, name=x.name, buffer_type=buffer_type
+                )
                 binds[x] = buf
                 arg_list.append(buf)
             else:
@@ -121,11 +119,7 @@ def form_irmodule(sch, args, name, binds):
     return tvm.IRModule({name: func})
 
 
-def lower(sch,
-          args,
-          name="main",
-          binds=None,
-          simple_mode=False):
+def lower(sch, args, name="main", binds=None, simple_mode=False):
     """Lowering step before build into target.
 
     Parameters
@@ -156,10 +150,8 @@ def lower(sch,
     """
     # config setup
     pass_ctx = PassContext.current()
-    instrument_bound_checkers = bool(pass_ctx.config.get(
-        "tir.instrument_bound_checkers", False))
-    disable_vectorize = bool(pass_ctx.config.get(
-        "tir.disable_vectorize", False))
+    instrument_bound_checkers = bool(pass_ctx.config.get("tir.instrument_bound_checkers", False))
+    disable_vectorize = bool(pass_ctx.config.get("tir.disable_vectorize", False))
     add_lower_pass = pass_ctx.config.get("tir.add_lower_pass", [])
 
     lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
@@ -193,7 +185,7 @@ def lower(sch,
         tvm.tir.transform.InjectVirtualThread(),
         tvm.tir.transform.InjectDoubleBuffer(),
         tvm.tir.transform.StorageRewrite(),
-        tvm.tir.transform.UnrollLoop()
+        tvm.tir.transform.UnrollLoop(),
     ]
     pass_list += lower_phase2
 
@@ -244,65 +236,67 @@ def _build_for_device(input_mod, target, target_host):
     device_type = ndarray.context(target.kind.name, 0).device_type
 
     mod_mixed = input_mod
-    mod_mixed = tvm.tir.transform.Apply(
-        lambda f: f.with_attr("target", target))(mod_mixed)
+    mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed)
 
     opt_mixed = [tvm.tir.transform.VerifyMemory()]
     if len(mod_mixed.functions) == 1:
-        opt_mixed += [tvm.tir.transform.Apply(
-            lambda f: f.with_attr("tir.is_entry_func", True))]
+        opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))]
 
     if PassContext.current().config.get("tir.detect_global_barrier", False):
         opt_mixed += [tvm.tir.transform.ThreadSync("global")]
-    opt_mixed += [tvm.tir.transform.ThreadSync("shared"),
-                  tvm.tir.transform.ThreadSync("warp"),
-                  tvm.tir.transform.InferFragment(),
-                  tvm.tir.transform.LowerThreadAllreduce(),
-                  tvm.tir.transform.MakePackedAPI(),
-                  tvm.tir.transform.SplitHostDevice()]
+    opt_mixed += [
+        tvm.tir.transform.ThreadSync("shared"),
+        tvm.tir.transform.ThreadSync("warp"),
+        tvm.tir.transform.InferFragment(),
+        tvm.tir.transform.LowerThreadAllreduce(),
+        tvm.tir.transform.MakePackedAPI(),
+        tvm.tir.transform.SplitHostDevice(),
+    ]
     mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed)
 
     # device optimizations
     opt_device = tvm.transform.Sequential(
-        [tvm.tir.transform.Filter(
-            lambda f: "calling_conv" in f.attrs and
-            f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH),
-         tvm.tir.transform.LowerWarpMemory(),
-         tvm.tir.transform.Simplify(),
-         tvm.tir.transform.LowerDeviceStorageAccessInfo(),
-         tvm.tir.transform.LowerIntrin()])
+        [
+            tvm.tir.transform.Filter(
+                lambda f: "calling_conv" in f.attrs
+                and f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH
+            ),
+            tvm.tir.transform.LowerWarpMemory(),
+            tvm.tir.transform.Simplify(),
+            tvm.tir.transform.LowerDeviceStorageAccessInfo(),
+            tvm.tir.transform.LowerIntrin(),
+        ]
+    )
     mod_dev = opt_device(mod_mixed)
 
     # host optimizations
     opt_host = tvm.transform.Sequential(
-        [tvm.tir.transform.Filter(
-            lambda f: "calling_conv" not in f.attrs or
-            f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH),
-         tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
-         tvm.tir.transform.LowerTVMBuiltin(),
-         tvm.tir.transform.LowerDeviceStorageAccessInfo(),
-         tvm.tir.transform.LowerIntrin(),
-         tvm.tir.transform.CombineContextCall()])
+        [
+            tvm.tir.transform.Filter(
+                lambda f: "calling_conv" not in f.attrs
+                or f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH
+            ),
+            tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
+            tvm.tir.transform.LowerTVMBuiltin(),
+            tvm.tir.transform.LowerDeviceStorageAccessInfo(),
+            tvm.tir.transform.LowerIntrin(),
+            tvm.tir.transform.CombineContextCall(),
+        ]
+    )
     mod_host = opt_host(mod_mixed)
 
     if device_type == ndarray.cpu(0).device_type and target_host == target:
         assert len(mod_dev.functions) == 0
     if "gpu" in target.keys and len(mod_dev.functions) == 0:
         warnings.warn(
-            "Specified target %s, but cannot find device code, did you do "
-            "bind?" % target)
+            "Specified target %s, but cannot find device code, did you do " "bind?" % target
+        )
 
-    rt_mod_dev = codegen.build_module(mod_dev, target) if len(
-        mod_dev.functions) != 0 else None
+    rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None
     return mod_host, rt_mod_dev
 
 
-def build(inputs,
-          args=None,
-          target=None,
-          target_host=None,
-          name="default_function",
-          binds=None):
+def build(inputs, args=None, target=None, target_host=None, name="default_function", binds=None):
     """Build a function with arguments as signature. Code will be generated
     for devices coupled with target information.
 
@@ -376,9 +370,7 @@ def build(inputs,
     if isinstance(inputs, schedule.Schedule):
         if args is None:
             raise ValueError("args must be given for build from schedule")
-        input_mod = lower(inputs, args,
-                          name=name,
-                          binds=binds)
+        input_mod = lower(inputs, args, name=name, binds=binds)
     elif isinstance(inputs, (list, tuple, container.Array)):
         merged_mod = tvm.IRModule({})
         for x in inputs:
@@ -387,8 +379,7 @@ def build(inputs,
     elif isinstance(inputs, tvm.IRModule):
         input_mod = inputs
     elif not isinstance(inputs, (dict, container.Map)):
-        raise ValueError(
-            "inputs must be Schedule, IRModule or dict of target to IRModule")
+        raise ValueError("inputs must be Schedule, IRModule or dict of target to IRModule")
 
     if not isinstance(inputs, (dict, container.Map)):
         target = Target.current() if target is None else target
@@ -399,11 +390,9 @@ def build(inputs,
 
     for tar, mod in target_input_mod.items():
         if not isinstance(tar, (str, Target)):
-            raise ValueError("The key of inputs must be str or "
-                             "Target when inputs is dict.")
+            raise ValueError("The key of inputs must be str or " "Target when inputs is dict.")
         if not isinstance(mod, tvm.IRModule):
-            raise ValueError("inputs must be Schedule, IRModule,"
-                             "or dict of str to IRModule.")
+            raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.")
 
     if not target_host:
         for tar, _ in target_input_mod.items():
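
As a quick reference for the lower()/build() entry points reformatted above, a standard vector-add sketch (illustrative only):

    import tvm
    from tvm import te

    n = 1024
    A = te.placeholder((n,), name="A")
    B = te.placeholder((n,), name="B")
    C = te.compute((n,), lambda i: A[i] + B[i], name="C")
    s = te.create_schedule(C.op)

    # Inspect the lowered TIR, then build a module for the host CPU.
    print(tvm.lower(s, [A, B, C], simple_mode=True))
    fadd = tvm.build(s, [A, B, C], target="llvm", name="vector_add")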
index c9353d4..aa53ce7 100644 (file)
@@ -18,5 +18,6 @@
 Common utility functions shared by TVMC modules.
 """
 
+
 class TVMCException(Exception):
     """TVMC Exception"""
index d8083e3..531dd4f 100644 (file)
@@ -56,17 +56,13 @@ def _main(argv):
     """ TVM command line interface. """
 
     parser = argparse.ArgumentParser(
-        prog='tvmc',
+        prog="tvmc",
         formatter_class=argparse.RawDescriptionHelpFormatter,
         description="TVM compiler driver",
         epilog=__doc__,
     )
-    parser.add_argument(
-        "-v", "--verbose", action="count", default=0, help="increase verbosity"
-    )
-    parser.add_argument(
-        "--version", action="store_true", help="print the version and exit"
-    )
+    parser.add_argument("-v", "--verbose", action="count", default=0, help="increase verbosity")
+    parser.add_argument("--version", action="store_true", help="print the version and exit")
 
     subparser = parser.add_subparsers(title="commands")
     for make_subparser in REGISTERED_PARSER:
@@ -91,8 +87,10 @@ def _main(argv):
         sys.stderr.write("Error: %s\n" % err)
         return 4
 
+
 def main():
     sys.exit(_main(sys.argv[1:]))
 
+
 if __name__ == "__main__":
     main()
index 9125448..d7628a7 100644 (file)
@@ -27,6 +27,7 @@ copy the examples and raise errors with the same message convention.
 """
 from tvm._ffi.base import register_error, TVMError
 
+
 @register_error
 class InternalError(TVMError):
     """Internal error in the system.
@@ -43,11 +44,14 @@ class InternalError(TVMError):
         # Example code in python
         raise InternalError("internal error detail")
     """
+
     def __init__(self, msg):
         # Patch up additional hint message.
         if "TVM hint:" not in msg:
-            msg += ("\nTVM hint: You hit an internal error. " +
-                    "Please open a thread on https://discuss.tvm.ai/ to report it.")
+            msg += (
+                "\nTVM hint: You hit an internal error. "
+                + "Please open a thread on https://discuss.tvm.ai/ to report it."
+            )
         super(InternalError, self).__init__(msg)
 
 
@@ -122,6 +126,7 @@ class OpAttributeUnImplemented(OpError, NotImplementedError):
                 attr_name, op_name))
     """
 
+
 @register_error
 class DiagnosticError(TVMError):
     """Error diagnostics were reported during the execution of a pass.
index 5d53054..04d6aa6 100644 (file)
@@ -24,17 +24,16 @@ import warnings
 
 from .. import autotvm
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--act", type=str, choices=['pick-best'], required=True,
-                        help="The action")
+    parser.add_argument("--act", type=str, choices=["pick-best"], required=True, help="The action")
     parser.add_argument("--i", type=str, help="The input file or directory", required=True)
     parser.add_argument("--o", type=str, help="The output file")
 
     args = parser.parse_args()
     logging.basicConfig(level=logging.INFO)
 
-    if args.act == 'pick-best':
+    if args.act == "pick-best":
         if os.path.isfile(args.i):
             args.o = args.o or args.i + ".best.log"
             autotvm.record.pick_best(args.i, args.o)
@@ -42,7 +41,7 @@ if __name__ == '__main__':
             args.o = args.o or "best.log"
             tmp_filename = args.o + ".tmp"
 
-            with open(tmp_filename, 'w') as tmp_fout:
+            with open(tmp_filename, "w") as tmp_fout:
                 for filename in os.listdir(args.i):
                     if filename.endswith(".log"):
                         try:
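
The same consolidation is available programmatically; a minimal sketch (log file names are placeholders):

    from tvm import autotvm

    # Keep only the best record per workload, equivalent to `--act pick-best` above.
    autotvm.record.pick_best("tuning.log", "tuning_best.log")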
index 2d42ed9..3b502a9 100644 (file)
@@ -27,22 +27,24 @@ import logging
 
 from ..contrib.peak import measure_peak_all
 
+
 def main():
     """Main funciton"""
     parser = argparse.ArgumentParser()
-    parser.add_argument('--target', type=str, default="llvm",
-                        help='The build target')
-    parser.add_argument('--target-host', type=str, default=None,
-                        help='The host code compilation target')
-    parser.add_argument('--rpc-host', type=str, default="0.0.0.0",
-                        help='the hostname of the server')
-    parser.add_argument('--rpc-port', type=int, default=9090,
-                        help='The port of the RPC')
+    parser.add_argument("--target", type=str, default="llvm", help="The build target")
+    parser.add_argument(
+        "--target-host", type=str, default=None, help="The host code compilation target"
+    )
+    parser.add_argument(
+        "--rpc-host", type=str, default="0.0.0.0", help="the hostname of the server"
+    )
+    parser.add_argument("--rpc-port", type=int, default=9090, help="The port of the RPC")
 
     args = parser.parse_args()
     logging.basicConfig(level=logging.INFO)
 
     measure_peak_all(args.target, args.target_host, args.rpc_host, args.rpc_port)
 
+
 if __name__ == "__main__":
     main()
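
For reference, the underlying helper can also be called directly; a sketch assuming an RPC server is already listening on the given host/port:

    from tvm.contrib.peak import measure_peak_all

    # Prints rough peak bandwidth and compute throughput measured through the RPC server.
    measure_peak_all("llvm", None, "127.0.0.1", 9090)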
index e610923..3603251 100644 (file)
@@ -22,13 +22,12 @@ import argparse
 import os
 from .. import rpc
 
+
 def main():
     """Main funciton"""
     parser = argparse.ArgumentParser()
-    parser.add_argument('--host', type=str, default="",
-                        help='the hostname of the tracker')
-    parser.add_argument('--port', type=int, default=None,
-                        help='The port of the RPC')
+    parser.add_argument("--host", type=str, default="", help="the hostname of the tracker")
+    parser.add_argument("--port", type=int, default=None, help="The port of the RPC")
     args = parser.parse_args()
     logging.basicConfig(level=logging.INFO)
 
@@ -44,5 +43,6 @@ def main():
     print("Tracker address %s:%d\n" % (args.host, args.port))
     print("%s" % conn.text_summary())
 
+
 if __name__ == "__main__":
     main()
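
The same summary can be fetched from Python, assuming a tracker is already running on the given address:

    from tvm import rpc

    conn = rpc.connect_tracker("127.0.0.1", 9190)
    print(conn.text_summary())  # same output as the query script above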
index eb80286..26625fb 100644 (file)
@@ -33,7 +33,7 @@ def find_example_resource():
     index_page = os.path.join(base_path, "web", "apps", "browser", "rpc_server.html")
     resource_files = [
         os.path.join(base_path, "web", "dist", "tvmjs.bundle.js"),
-        os.path.join(base_path, "web", "dist", "wasm", "tvmjs_runtime.wasi.js")
+        os.path.join(base_path, "web", "dist", "wasm", "tvmjs_runtime.wasi.js"),
     ]
     resource_base = os.path.join(base_path, "web", "dist", "www")
     if os.path.isdir(resource_base):
@@ -58,46 +58,48 @@ def main(args):
 
     if args.example_rpc:
         index, js_files = find_example_resource()
-        prox = Proxy(args.host,
-                     port=args.port,
-                     web_port=args.web_port,
-                     index_page=index,
-                     resource_files=js_files,
-                     tracker_addr=tracker_addr)
+        prox = Proxy(
+            args.host,
+            port=args.port,
+            web_port=args.web_port,
+            index_page=index,
+            resource_files=js_files,
+            tracker_addr=tracker_addr,
+        )
     else:
-        prox = Proxy(args.host,
-                     port=args.port,
-                     web_port=args.web_port,
-                     tracker_addr=tracker_addr)
+        prox = Proxy(args.host, port=args.port, web_port=args.web_port, tracker_addr=tracker_addr)
     prox.proc.join()
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--host', type=str, default="localhost",
-                        help='the hostname of the server')
-    parser.add_argument('--port', type=int, default=9090,
-                        help='The port of the RPC')
-    parser.add_argument('--web-port', type=int, default=8888,
-                        help='The port of the http/websocket server')
-    parser.add_argument('--example-rpc', type=bool, default=False,
-                        help='Whether to switch on example rpc mode')
-    parser.add_argument('--tracker', type=str, default="",
-                        help="Report to RPC tracker")
-    parser.add_argument('--no-fork', dest='fork', action='store_false',
-                        help="Use spawn mode to avoid fork. This option \
+    parser.add_argument("--host", type=str, default="localhost", help="the hostname of the server")
+    parser.add_argument("--port", type=int, default=9090, help="The port of the RPC")
+    parser.add_argument(
+        "--web-port", type=int, default=8888, help="The port of the http/websocket server"
+    )
+    parser.add_argument(
+        "--example-rpc", type=bool, default=False, help="Whether to switch on example rpc mode"
+    )
+    parser.add_argument("--tracker", type=str, default="", help="Report to RPC tracker")
+    parser.add_argument(
+        "--no-fork",
+        dest="fork",
+        action="store_false",
+        help="Use spawn mode to avoid fork. This option \
                          is able to avoid potential fork problems with Metal, OpenCL \
-                         and ROCM compilers.")
+                         and ROCM compilers.",
+    )
     parser.set_defaults(fork=True)
     args = parser.parse_args()
     logging.basicConfig(level=logging.INFO)
     if args.fork is False:
         if sys.version_info[0] < 3:
-            raise RuntimeError(
-                "Python3 is required for spawn mode."
-            )
-        multiprocessing.set_start_method('spawn')
+            raise RuntimeError("Python3 is required for spawn mode.")
+        multiprocessing.set_start_method("spawn")
     else:
-        logging.info("If you are running ROCM/Metal, \
-        fork with cause compiler internal error. Try to launch with arg ```--no-fork```")
+        logging.info(
+            "If you are running ROCM/Metal, fork will cause "
+            "compiler internal error. Try to launch with arg ```--no-fork```"
+        )
     main(args)
index e281e58..7233d71 100644 (file)
@@ -28,6 +28,7 @@ import tvm
 from tvm import micro
 from .. import rpc
 
+
 def main(args):
     """Main function
 
@@ -41,22 +42,23 @@ def main(args):
         port = int(port)
         tracker_addr = (url, port)
         if not args.key:
-            raise RuntimeError(
-                'Need key to present type of resource when tracker is available')
+            raise RuntimeError("Need key to present type of resource when tracker is available")
     else:
         tracker_addr = None
 
     if args.utvm_dev_config or args.utvm_dev_id:
         init_utvm(args)
 
-    server = rpc.Server(args.host,
-                        args.port,
-                        args.port_end,
-                        key=args.key,
-                        tracker_addr=tracker_addr,
-                        load_library=args.load_library,
-                        custom_addr=args.custom_addr,
-                        silent=args.silent)
+    server = rpc.Server(
+        args.host,
+        args.port,
+        args.port_end,
+        key=args.key,
+        tracker_addr=tracker_addr,
+        load_library=args.load_library,
+        custom_addr=args.custom_addr,
+        silent=args.silent,
+    )
     server.proc.join()
 
 
@@ -69,78 +71,96 @@ def init_utvm(args):
         parsed args from command-line invocation
     """
     if args.utvm_dev_config and args.utvm_dev_id:
-        raise RuntimeError('only one of --utvm-dev-config and --utvm-dev-id allowed')
+        raise RuntimeError("only one of --utvm-dev-config and --utvm-dev-id allowed")
 
     if args.utvm_dev_config:
-        with open(args.utvm_dev_config, 'r') as dev_conf_file:
+        with open(args.utvm_dev_config, "r") as dev_conf_file:
             dev_config = json.load(dev_conf_file)
     else:
         dev_config_args = ast.literal_eval(args.utvm_dev_config_args)
-        generate_config_func = micro.device.get_device_funcs(args.utvm_dev_id)['generate_config']
+        generate_config_func = micro.device.get_device_funcs(args.utvm_dev_id)["generate_config"]
         dev_config = generate_config_func(*dev_config_args)
 
     if args.utvm_dev_config or args.utvm_dev_id:
         # add MicroTVM overrides
-        @tvm.register_func('tvm.rpc.server.start', override=True)
+        @tvm.register_func("tvm.rpc.server.start", override=True)
         def server_start():
             # pylint: disable=unused-variable
             session = micro.Session(dev_config)
             session._enter()
 
-            @tvm.register_func('tvm.rpc.server.shutdown', override=True)
+            @tvm.register_func("tvm.rpc.server.shutdown", override=True)
             def server_shutdown():
                 session._exit()
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--host', type=str, default="0.0.0.0",
-                        help='the hostname of the server')
-    parser.add_argument('--port', type=int, default=9090,
-                        help='The port of the RPC')
-    parser.add_argument('--port-end', type=int, default=9199,
-                        help='The end search port of the RPC')
-    parser.add_argument('--tracker', type=str,
-                        help=("The address of RPC tracker in host:port format. "
-                              "e.g. (10.77.1.234:9190)"))
-    parser.add_argument('--key', type=str, default="",
-                        help="The key used to identify the device type in tracker.")
-    parser.add_argument('--silent', action='store_true',
-                        help="Whether run in silent mode.")
-    parser.add_argument('--load-library', type=str,
-                        help="Additional library to load")
-    parser.add_argument('--no-fork', dest='fork', action='store_false',
-                        help="Use spawn mode to avoid fork. This option \
+    parser.add_argument("--host", type=str, default="0.0.0.0", help="the hostname of the server")
+    parser.add_argument("--port", type=int, default=9090, help="The port of the RPC")
+    parser.add_argument("--port-end", type=int, default=9199, help="The end search port of the RPC")
+    parser.add_argument(
+        "--tracker",
+        type=str,
+        help=("The address of RPC tracker in host:port format. " "e.g. (10.77.1.234:9190)"),
+    )
+    parser.add_argument(
+        "--key", type=str, default="", help="The key used to identify the device type in tracker."
+    )
+    parser.add_argument("--silent", action="store_true", help="Whether run in silent mode.")
+    parser.add_argument("--load-library", type=str, help="Additional library to load")
+    parser.add_argument(
+        "--no-fork",
+        dest="fork",
+        action="store_false",
+        help="Use spawn mode to avoid fork. This option \
                         is able to avoid potential fork problems with Metal, OpenCL \
-                        and ROCM compilers.")
-    parser.add_argument('--custom-addr', type=str,
-                        help="Custom IP Address to Report to RPC Tracker")
-    parser.add_argument('--utvm-dev-config', type=str,
-                        help=('JSON config file for the target device (if using MicroTVM). '
-                              'This file should contain serialized output similar to that returned '
-                              "from the device module's generate_config. Can't be specified when "
-                              '--utvm-dev-config-args is specified.'))
-    parser.add_argument('--utvm-dev-config-args', type=str,
-                        help=("Arguments to the device module's generate_config function. "
-                              'Must be a python literal parseable by literal_eval. If specified, '
-                              "the device configuration is generated using the device module's "
-                              "generate_config. Can't be specified when --utvm-dev-config is "
-                              "specified."))
-    parser.add_argument('--utvm-dev-id', type=str,
-                        help=('Unique ID for the target device (if using MicroTVM). Should '
-                              'match the name of a module underneath tvm.micro.device).'))
+                        and ROCM compilers.",
+    )
+    parser.add_argument(
+        "--custom-addr", type=str, help="Custom IP Address to Report to RPC Tracker"
+    )
+    parser.add_argument(
+        "--utvm-dev-config",
+        type=str,
+        help=(
+            "JSON config file for the target device (if using MicroTVM). "
+            "This file should contain serialized output similar to that returned "
+            "from the device module's generate_config. Can't be specified when "
+            "--utvm-dev-config-args is specified."
+        ),
+    )
+    parser.add_argument(
+        "--utvm-dev-config-args",
+        type=str,
+        help=(
+            "Arguments to the device module's generate_config function. "
+            "Must be a python literal parseable by literal_eval. If specified, "
+            "the device configuration is generated using the device module's "
+            "generate_config. Can't be specified when --utvm-dev-config is "
+            "specified."
+        ),
+    )
+    parser.add_argument(
+        "--utvm-dev-id",
+        type=str,
+        help=(
+            "Unique ID for the target device (if using MicroTVM). Should "
+            "match the name of a module underneath tvm.micro.device)."
+        ),
+    )
 
     parser.set_defaults(fork=True)
     args = parser.parse_args()
     logging.basicConfig(level=logging.INFO)
     if args.fork is False:
         if sys.version_info[0] < 3:
-            raise RuntimeError(
-                "Python3 is required for spawn mode."
-            )
-        multiprocessing.set_start_method('spawn')
+            raise RuntimeError("Python3 is required for spawn mode.")
+        multiprocessing.set_start_method("spawn")
     else:
         if not args.silent:
-            logging.info("If you are running ROCM/Metal, fork will cause "
-                         "compiler internal error. Try to launch with arg ```--no-fork```")
+            logging.info(
+                "If you are running ROCM/Metal, fork will cause "
+                "compiler internal error. Try to launch with arg ```--no-fork```"
+            )
     main(args)
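
An RPC server can also be started programmatically rather than through this exec script; a minimal local sketch (the key name is arbitrary):

    from tvm import rpc

    server = rpc.Server("127.0.0.1", port=9090, port_end=9199, key="local")
    remote = rpc.connect("127.0.0.1", server.port, key="local")  # sanity-check the link
    print(remote.cpu(0))  # remote CPU context
    server.terminate()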
index 8f5bd1d..4a1a964 100644 (file)
@@ -24,39 +24,39 @@ import multiprocessing
 import sys
 from ..rpc.tracker import Tracker
 
+
 def main(args):
     """Main funciton"""
-    tracker = Tracker(args.host, port=args.port, port_end=args.port_end,
-                      silent=args.silent)
+    tracker = Tracker(args.host, port=args.port, port_end=args.port_end, silent=args.silent)
     tracker.proc.join()
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--host', type=str, default="0.0.0.0",
-                        help='the hostname of the tracker')
-    parser.add_argument('--port', type=int, default=9190,
-                        help='The port of the RPC')
-    parser.add_argument('--port-end', type=int, default=9199,
-                        help='The end search port of the RPC')
-    parser.add_argument('--no-fork', dest='fork', action='store_false',
-                        help="Use spawn mode to avoid fork. This option \
+    parser.add_argument("--host", type=str, default="0.0.0.0", help="the hostname of the tracker")
+    parser.add_argument("--port", type=int, default=9190, help="The port of the RPC")
+    parser.add_argument("--port-end", type=int, default=9199, help="The end search port of the RPC")
+    parser.add_argument(
+        "--no-fork",
+        dest="fork",
+        action="store_false",
+        help="Use spawn mode to avoid fork. This option \
                          is able to avoid potential fork problems with Metal, OpenCL \
-                         and ROCM compilers.")
-    parser.add_argument('--silent', action='store_true',
-                        help="Whether run in silent mode.")
+                         and ROCM compilers.",
+    )
+    parser.add_argument("--silent", action="store_true", help="Whether run in silent mode.")
 
     parser.set_defaults(fork=True)
     args = parser.parse_args()
     logging.basicConfig(level=logging.INFO)
     if args.fork is False:
         if sys.version_info[0] < 3:
-            raise RuntimeError(
-                "Python3 is required for spawn mode."
-            )
-        multiprocessing.set_start_method('spawn')
+            raise RuntimeError("Python3 is required for spawn mode.")
+        multiprocessing.set_start_method("spawn")
     else:
         if not args.silent:
-            logging.info("If you are running ROCM/Metal, fork will cause "
-                         "compiler internal error. Try to launch with arg ```--no-fork```")
+            logging.info(
+                "If you are running ROCM/Metal, fork will cause "
+                "compiler internal error. Try to launch with arg ```--no-fork```"
+            )
     main(args)
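
Likewise, a tracker can be launched from Python; a minimal sketch:

    from tvm.rpc.tracker import Tracker

    tracker = Tracker("0.0.0.0", port=9190, port_end=9199)
    print("Tracker listening on port %d" % tracker.port)
    tracker.terminate()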
index bf8466f..b3b042a 100644 (file)
@@ -76,11 +76,7 @@ class HybridParser(ast.NodeVisitor):
         ast.Or: tir.Or,
     }
 
-    _unaryop_maker = {
-        ast.USub: operator.neg,
-        ast.Invert: operator.invert,
-        ast.Not: tir.Not
-    }
+    _unaryop_maker = {ast.USub: operator.neg, ast.Invert: operator.invert, ast.Not: tir.Not}
 
     def __init__(self, src, base_lienno):
         self.params = None
@@ -88,7 +84,7 @@ class HybridParser(ast.NodeVisitor):
         self.dict_attr = None
         self.scope_emitter = None
 
-        self.src = src.split('\n')
+        self.src = src.split("\n")
         self.base_lineno = base_lienno
         self.current_lineno = 0
         self.current_col_offset = 0
@@ -109,8 +105,12 @@ class HybridParser(ast.NodeVisitor):
     @staticmethod
     def is_meta(node):
         """Judge whether an AST node is META"""
-        return isinstance(node, ast.Assign) and len(node.targets) == 1 \
-               and isinstance(node.targets[0], ast.Name) and node.targets[0].id == "__tvm_meta__"
+        return (
+            isinstance(node, ast.Assign)
+            and len(node.targets) == 1
+            and isinstance(node.targets[0], ast.Name)
+            and node.targets[0].id == "__tvm_meta__"
+        )
 
     def init_meta(self, meta_dict):
         if meta_dict is not None:
@@ -125,7 +125,7 @@ class HybridParser(ast.NodeVisitor):
         if hasattr(node, "col_offset"):
             self.current_col_offset = node.col_offset
 
-        method = 'visit_' + node.__class__.__name__
+        method = "visit_" + node.__class__.__name__
         visitor = getattr(self, method, self.generic_visit)
         visit_res = visitor(node)
 
@@ -136,14 +136,23 @@ class HybridParser(ast.NodeVisitor):
     def wrap_line_col(self, message, lineno, col_offset):
         """Wrap the message with line number and column offset"""
         src_line = self.src[lineno - self.base_lineno]
-        leading_space = len(src_line) - len(src_line.lstrip(' '))
+        leading_space = len(src_line) - len(src_line.lstrip(" "))
         col_offset = col_offset - leading_space
         src_line = src_line[leading_space:]
-        return "\n  " + src_line + "\n  " + " " * col_offset + "^\n" + "ParserError in line " \
-               + str(lineno) + " : " + message
+        return (
+            "\n  "
+            + src_line
+            + "\n  "
+            + " " * col_offset
+            + "^\n"
+            + "ParserError in line "
+            + str(lineno)
+            + " : "
+            + message
+        )
 
     def report_error(self, message, lineno=None, col_offset=None):
-        """ Report an error occur in line lineno and column col_offset
+        """Report an error occur in line lineno and column col_offset
         Parameters
         ----------
         message : str
@@ -161,8 +170,11 @@ class HybridParser(ast.NodeVisitor):
         raise HybridParserError(self.wrap_line_col(message, lineno, col_offset))
 
     def get_type_name(self, vtype):
-        if isinstance(vtype, ast.Attribute) \
-                and isinstance(vtype.value, ast.Name) and vtype.value.id == 'ty':
+        if (
+            isinstance(vtype, ast.Attribute)
+            and isinstance(vtype.value, ast.Name)
+            and vtype.value.id == "ty"
+        ):
             return vtype.attr
         self.report_error("invalid type annotation")
 
@@ -194,14 +206,14 @@ class HybridParser(ast.NodeVisitor):
         self.report_error("invalid type annotation")
 
     def generic_visit(self, node):
-        """ Override method in ast.NodeVisitor.
+        """Override method in ast.NodeVisitor.
         To directly filter out invalid statement types.
         """
 
         self.report_error(type(node).__name__ + " stmt is not supported now")
 
     def visit_Module(self, node):
-        """ Module visitor
+        """Module visitor
         AST abstract grammar:
             Module(stmt* body, type_ignore* type_ignore)
         By now we support two format of hybrid script shown below.
@@ -250,10 +262,11 @@ class HybridParser(ast.NodeVisitor):
                 self.init_meta(MetaUnparser().visit(node.body[1].value))
                 return self.visit(node.body[0])
         self.report_error(
-            "Only one-function, one-class or function-with-meta source code is allowed")
+            "Only one-function, one-class or function-with-meta source code is allowed"
+        )
 
     def visit_ClassDef(self, node):
-        """ ClassDef visitor
+        """ClassDef visitor
         AST abstract grammar:
             ClassDef(identifier name, expr* bases, keyword* keywords, stmt* body,
                      expr* decorator_list)
@@ -275,10 +288,11 @@ class HybridParser(ast.NodeVisitor):
             if isinstance(body_element, ast.FunctionDef):
                 self.visit(body_element)
         from .utils import create_module
+
         return create_module(self.functions)
 
     def visit_FunctionDef(self, node):
-        """ FunctionDef visitor
+        """FunctionDef visitor
         AST abstract grammar:
             FunctionDef(identifier name, arguments args, stmt* body, expr* decorator_list,
                         expr? returns, string? type_comment)
@@ -298,15 +312,18 @@ class HybridParser(ast.NodeVisitor):
         self.scope_emitter.node_stack[-1].extend(reversed(node.body))
 
         # fetch the body and return a tir.PrimFunc
-        func = tvm.tir.PrimFunc(self.params, self.get_body(),
-                                ret_type=self.parse_type(node.returns),
-                                buffer_map=self.buffer_map,
-                                attrs=tvm.ir.make_node("DictAttrs", **self.dict_attr))
+        func = tvm.tir.PrimFunc(
+            self.params,
+            self.get_body(),
+            ret_type=self.parse_type(node.returns),
+            buffer_map=self.buffer_map,
+            attrs=tvm.ir.make_node("DictAttrs", **self.dict_attr),
+        )
         self.functions[GlobalVar(node.name)] = func
         return func
 
     def visit_Assign(self, node):
-        """ Assign visitor
+        """Assign visitor
         AST abstract grammar:
             Assign(expr* targets, expr value, string? type_comment)
         By now only 2 types of Assign is supported:
@@ -338,13 +355,14 @@ class HybridParser(ast.NodeVisitor):
             else:
                 if len(indexes) != 1:
                     self.report_error("Invalid Store stmt")
-                return tvm.tir.Store(symbol, tvm.runtime.convert(rhs), indexes[0],
-                                     tvm.runtime.convert(True))
+                return tvm.tir.Store(
+                    symbol, tvm.runtime.convert(rhs), indexes[0], tvm.runtime.convert(True)
+                )
         else:
             self.report_error("Unsupported Assign stmt")
 
     def visit_AnnAssign(self, node):
-        """ AnnAssign visitor
+        """AnnAssign visitor
         AST abstract grammar:
             AnnAssign(expr target, expr annotation, expr? value, int simple)
         Corresponds to concise mode of with tir.let()
@@ -359,7 +377,7 @@ class HybridParser(ast.NodeVisitor):
             self.report_error("Unsupported AnnAssign stmt")
 
     def visit_Assert(self, node):
-        """ Assert visitor
+        """Assert visitor
         AST abstract grammar:
             Assert(expr test, expr? msg)
         Corresponds to concise mode of with tir.assert()
@@ -372,7 +390,7 @@ class HybridParser(ast.NodeVisitor):
         return tvm.tir.AssertStmt(condition, tvm.runtime.convert(message), self.get_body())
 
     def visit_For(self, node):
-        """ For visitor
+        """For visitor
         AST abstract grammar:
             For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment)
         By now only 1 type of For is supported:
@@ -384,9 +402,11 @@ class HybridParser(ast.NodeVisitor):
         # check node.iter, which is a tir Call
         if not isinstance(node.iter, ast.Call):
             self.report_error("The loop iter should be a Call")
-        if not isinstance(node.iter.func, ast.Attribute) \
-                or not isinstance(node.iter.func.value, ast.Name) \
-                or node.iter.func.value.id != "tir":
+        if (
+            not isinstance(node.iter.func, ast.Attribute)
+            or not isinstance(node.iter.func.value, ast.Name)
+            or node.iter.func.value.id != "tir"
+        ):
             self.report_error("The loop iter Call should be tir.name()")
 
         func_name = node.iter.func.attr
@@ -396,19 +416,23 @@ class HybridParser(ast.NodeVisitor):
         kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args}
         # All the functions supported in For stmt are registered in scope_handler.ForScope
         if func_name not in Registry.for_scope:
-            self.report_error("Function " + func_name + " used in For stmt is not supported now",
-                              self.current_lineno,
-                              node.iter.col_offset)
+            self.report_error(
+                "Function " + func_name + " used in For stmt is not supported now",
+                self.current_lineno,
+                node.iter.col_offset,
+            )
 
         old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
-        self.current_lineno, self.current_col_offset = \
-            self.base_lineno + node.iter.lineno - 1, node.iter.col_offset
+        self.current_lineno, self.current_col_offset = (
+            self.base_lineno + node.iter.lineno - 1,
+            node.iter.col_offset,
+        )
         res = Registry.for_scope.get(func_name)(self, node, args, kw_args)
         self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
         return res
 
     def visit_With(self, node):
-        """ With visitor
+        """With visitor
         AST abstract grammar:
             With(withitem* items, stmt* body, string? type_comment)
             withitem = (expr context_expr, expr? optional_vars)
@@ -421,9 +445,11 @@ class HybridParser(ast.NodeVisitor):
         if not isinstance(node.items[0].context_expr, ast.Call):
             self.report_error("The context expression of with should be a Call")
         func_call = node.items[0].context_expr
-        if not isinstance(func_call.func, ast.Attribute) \
-                or not isinstance(func_call.func.value, ast.Name) \
-                or func_call.func.value.id != "tir":
+        if (
+            not isinstance(func_call.func, ast.Attribute)
+            or not isinstance(func_call.func.value, ast.Name)
+            or func_call.func.value.id != "tir"
+        ):
             self.report_error("The context expression of with should be tir.name()")
 
         func_name = func_call.func.attr
@@ -436,14 +462,16 @@ class HybridParser(ast.NodeVisitor):
 
         # All the functions supported in With stmt are registered in scope_handler.WithScope
         old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
-        self.current_lineno, self.current_col_offset = \
-            self.base_lineno + func_call.lineno - 1, func_call.col_offset
+        self.current_lineno, self.current_col_offset = (
+            self.base_lineno + func_call.lineno - 1,
+            func_call.col_offset,
+        )
         res = Registry.with_scope.get(func_name)(self, node, args, kw_args)
         self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
         return res
 
     def visit_If(self, node):
-        """ If visitor
+        """If visitor
         AST abstract grammar:
             If(expr test, stmt* body, stmt* orelse)
         """
@@ -466,7 +494,7 @@ class HybridParser(ast.NodeVisitor):
         return tvm.tir.IfThenElse(condition, then_body, else_body)
 
     def visit_Call(self, node):
-        """ Call visitor
+        """Call visitor
         AST abstract grammar:
             Call(expr func, expr* args, keyword* keywords)
             keyword = (identifier? arg, expr value)
@@ -499,7 +527,7 @@ class HybridParser(ast.NodeVisitor):
         self.report_error("Function " + func_name + " is not supported now")
 
     def visit_Expr(self, node):
-        """ Expr visitor
+        """Expr visitor
         AST abstract grammar:
             Expr(expr value)
 
@@ -515,7 +543,7 @@ class HybridParser(ast.NodeVisitor):
         return self.visit(node.value)
 
     def visit_BinOp(self, node):
-        """ BinOp visitor
+        """BinOp visitor
         AST abstract grammar:
             BinOp(expr left, operator op, expr right)
         """
@@ -527,7 +555,7 @@ class HybridParser(ast.NodeVisitor):
         return HybridParser._binop_maker[type(node.op)](lhs, rhs)
 
     def visit_Compare(self, node):
-        """ Compare visitor
+        """Compare visitor
         AST abstract grammar:
             Compare(expr left, expr right, ops=)
         """
@@ -542,7 +570,7 @@ class HybridParser(ast.NodeVisitor):
         return _all(*res)
 
     def visit_BoolOp(self, node):
-        """ BoolOp visitor
+        """BoolOp visitor
         AST abstract grammar:
             BoolOp(boolop op, expr* values)
         """
@@ -551,7 +579,7 @@ class HybridParser(ast.NodeVisitor):
         return HybridParser._binop_maker[type(node.op)](*values)
 
     def visit_UnaryOp(self, node):
-        """ UnaryOp visitor
+        """UnaryOp visitor
         AST abstract grammar:
             UnaryOp(unaryop op, expr operand)
         """
@@ -562,7 +590,7 @@ class HybridParser(ast.NodeVisitor):
         return HybridParser._unaryop_maker[type(node.op)](operand)
 
     def visit_Subscript(self, node):
-        """ Subscript visitor
+        """Subscript visitor
         AST abstract grammar:
             Subscript(expr value, slice slice, expr_context ctx)
             slice = Slice(expr? lower, expr? upper, expr? step)
@@ -616,12 +644,18 @@ class HybridParser(ast.NodeVisitor):
                     doms.append(tvm.ir.Range.from_min_extent(lower, extent))
                 return symbol, doms
 
-        elif isinstance(node.value, ast.Subscript) and isinstance(node.value.value, ast.Name) \
-                and node.value.value.id == 'meta':
+        elif (
+            isinstance(node.value, ast.Subscript)
+            and isinstance(node.value.value, ast.Name)
+            and node.value.value.id == "meta"
+        ):
             # meta[type_key][index]
-            if not (isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Num)) \
-                    or not (isinstance(node.value.slice, ast.Index) \
-                            and isinstance(node.value.slice.value, ast.Name)):
+            if not (
+                isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Num)
+            ) or not (
+                isinstance(node.value.slice, ast.Index)
+                and isinstance(node.value.slice.value, ast.Name)
+            ):
                 self.report_error("The meta access format ought to be meta[type_key][index]")
             type_key = node.value.slice.value.id
             index = node.slice.value.n
@@ -635,7 +669,7 @@ class HybridParser(ast.NodeVisitor):
             self.report_error("Only buffer variable and meta can be subscriptable")
 
     def visit_Name(self, node):
-        """ Name visitor
+        """Name visitor
         AST abstract grammar:
             Name(identifier id, expr_context ctx)
         """
@@ -647,7 +681,7 @@ class HybridParser(ast.NodeVisitor):
         return symbol
 
     def visit_Attribute(self, node):
-        """ Attribute visitor
+        """Attribute visitor
         AST abstract grammar:
             Attribute(expr value, identifier attr, expr_context ctx)
         """
@@ -663,7 +697,7 @@ class HybridParser(ast.NodeVisitor):
         return getattr(symbol, node.attr)
 
     def visit_Dict(self, node):
-        """ Dict visitor
+        """Dict visitor
         AST abstract grammar:
             Dict(expr* keys, expr* values)
         """
@@ -674,7 +708,7 @@ class HybridParser(ast.NodeVisitor):
         return {key: value for key, value in zip(keys, values)}
 
     def visit_Tuple(self, node):
-        """ Tuple visitor
+        """Tuple visitor
         AST abstract grammar:
             Tuple(expr* elts, expr_context ctx)
         """
@@ -682,7 +716,7 @@ class HybridParser(ast.NodeVisitor):
         return tuple(self.visit(element) for element in node.elts)
 
     def visit_List(self, node):
-        """ List visitor
+        """List visitor
         AST abstract grammar:
             List(expr* elts, expr_context ctx)
         """
@@ -690,7 +724,7 @@ class HybridParser(ast.NodeVisitor):
         return [self.visit(element) for element in node.elts]
 
     def visit_keyword(self, node):
-        """ Keyword visitor
+        """Keyword visitor
         AST abstract grammar:
             keyword = (identifier? arg, expr value)
         """
@@ -717,7 +751,7 @@ class HybridParser(ast.NodeVisitor):
 
 
 def from_source(src, func_lineno=0):
-    """ Parse the src into TIR
+    """Parse the src into TIR
 
     Parameters
     ----------
@@ -740,13 +774,14 @@ def from_source(src, func_lineno=0):
         raise e
     except TVMError as e:
         # TVM internal c++ error, we have to process the error message and inject line info
-        inject_e = str(e).split('\n')
-        msg = inject_e[-1].split(':', maxsplit=1)[1].strip()
+        inject_e = str(e).split("\n")
+        msg = inject_e[-1].split(":", maxsplit=1)[1].strip()
         inject_e = inject_e[:-1]
         inject_e.extend(
-            parser.wrap_line_col(msg, parser.current_lineno, parser.current_col_offset).split('\n'))
+            parser.wrap_line_col(msg, parser.current_lineno, parser.current_col_offset).split("\n")
+        )
         inject_e[-1] = "TVM" + inject_e[-1][6:]
-        raise TVMError('\n'.join(inject_e))
+        raise TVMError("\n".join(inject_e))
     except Exception as e:
         inject_e = parser.wrap_line_col(str(e), parser.current_lineno, parser.current_col_offset)
         raise HybridParserError(inject_e)
index f33e03d..9f5c391 100644 (file)
@@ -24,6 +24,7 @@ class Registry(object):
     """Registration map
     All these maps are static
     """
+
     intrin = dict()
     with_scope = dict()
     for_scope = dict()
@@ -93,7 +94,7 @@ def func_wrapper(func_name, func_to_register, arg_list, need_parser_and_node, ne
 
         for i, arg_info in enumerate(arg_list):
             if len(arg_info) == 1:
-                arg_name, = arg_info
+                (arg_name,) = arg_info
                 if need_body and arg_name == "body":
                     internal_args.append(body)
                 else:
@@ -130,25 +131,28 @@ def get_arg_list(origin_func, need_parser_and_node):
 
     if full_arg_spec.varargs is not None:
         raise RuntimeError(
-            "TVM Hybrid Script register error : variable argument is not supported now")
+            "TVM Hybrid Script register error : variable argument is not supported now"
+        )
     if full_arg_spec.varkw is not None:
         raise RuntimeError(
-            "TVM Hybrid Script register error : variable keyword argument is not supported now")
+            "TVM Hybrid Script register error : variable keyword argument is not supported now"
+        )
     if not len(full_arg_spec.kwonlyargs) == 0:
         raise RuntimeError(
-            "TVM Hybrid Script register error : keyword only argument is not supported now")
+            "TVM Hybrid Script register error : keyword only argument is not supported now"
+        )
 
     arg_list = list()
     for arg in args[: len(args) - len(defaults)]:
         arg_list.append((arg,))
-    for default, arg in zip(defaults, args[len(args) - len(defaults):]):
+    for default, arg in zip(defaults, args[len(args) - len(defaults) :]):
         arg_list.append((arg, default))
 
     return arg_list
 
 
 def register_intrin(origin_func):
-    """ Decorator to register function under category intrin
+    """Decorator to register function under category intrin
 
     Example
     ------
@@ -160,9 +164,14 @@ def register_intrin(origin_func):
         return tvm.tir.Broadcast(value, lanes)
     """
     func_name = origin_func.__qualname__
-    Registry.intrin[func_name] = \
-        func_wrapper(func_name, origin_func, get_arg_list(origin_func, False),
-                     need_parser_and_node=False, need_body=False, concise=False)
+    Registry.intrin[func_name] = func_wrapper(
+        func_name,
+        origin_func,
+        get_arg_list(origin_func, False),
+        need_parser_and_node=False,
+        need_body=False,
+        concise=False,
+    )
     return origin_func
 
 
@@ -187,9 +196,14 @@ def register_with_scope(concise=False):
     def decorate(origin_func):
         """Register function under category with_scope"""
         func_name = origin_func.__qualname__
-        Registry.with_scope[func_name] = \
-            func_wrapper(func_name, origin_func, get_arg_list(origin_func, True),
-                         need_parser_and_node=True, need_body=True, concise=concise)
+        Registry.with_scope[func_name] = func_wrapper(
+            func_name,
+            origin_func,
+            get_arg_list(origin_func, True),
+            need_parser_and_node=True,
+            need_body=True,
+            concise=concise,
+        )
         return origin_func
 
     return decorate
@@ -197,19 +211,25 @@ def register_with_scope(concise=False):
 
 def register_for_scope():
     """Decorator to register function under for scope handler"""
+
     def decorate(origin_func):
         """Register function under category for_scope"""
         func_name = origin_func.__qualname__
-        Registry.for_scope[func_name] = \
-            func_wrapper(func_name, origin_func, get_arg_list(origin_func, True),
-                         need_parser_and_node=True, need_body=True, concise=False)
+        Registry.for_scope[func_name] = func_wrapper(
+            func_name,
+            origin_func,
+            get_arg_list(origin_func, True),
+            need_parser_and_node=True,
+            need_body=True,
+            concise=False,
+        )
         return origin_func
 
     return decorate
 
 
 def register_special_stmt(origin_func):
-    """ Decorator to register function under category special_stmt
+    """Decorator to register function under category special_stmt
 
     Example
     -------
@@ -225,7 +245,12 @@ def register_special_stmt(origin_func):
     """
 
     func_name = origin_func.__qualname__
-    Registry.special_stmt[func_name] = \
-        func_wrapper(func_name, origin_func, get_arg_list(origin_func, True),
-                     need_parser_and_node=True, need_body=False, concise=False)
+    Registry.special_stmt[func_name] = func_wrapper(
+        func_name,
+        origin_func,
+        get_arg_list(origin_func, True),
+        need_parser_and_node=True,
+        need_body=False,
+        concise=False,
+    )
     return origin_func
index 03b3cca..129354d 100644 (file)
@@ -35,9 +35,21 @@ from .registry import register_special_stmt
 
 
 @register_special_stmt
-def buffer_bind(parser, node, param, shape, dtype="float32", data=None, strides=None,
-                elem_offset=None, scope="global", align=-1, offset_factor=0, buffer_type="default"):
-    """ Special function buffer_bind(var, shape, dtype, data, strides, elem_offset, scope, align,
+def buffer_bind(
+    parser,
+    node,
+    param,
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="global",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+):
+    """Special function buffer_bind(var, shape, dtype, data, strides, elem_offset, scope, align,
                                      offset_factor, buffer_type)
 
     Example
@@ -54,16 +66,37 @@ def buffer_bind(parser, node, param, shape, dtype="float32", data=None, strides=
         strides = []
     align = align.value if not isinstance(align, int) else align
     offset_factor = offset_factor.value if not isinstance(offset_factor, int) else offset_factor
-    buffer = tvm.tir.decl_buffer(shape, dtype, parser._assign_target, data, strides, elem_offset,
-                                 scope, align, offset_factor, buffer_type)
+    buffer = tvm.tir.decl_buffer(
+        shape,
+        dtype,
+        parser._assign_target,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+    )
     parser.buffer_map[param] = buffer
     return buffer
 
 
 @register_special_stmt
-def buffer_decl(parser, node, shape, dtype="float32", data=None, strides=None, elem_offset=None,
-                scope="global", align=-1, offset_factor=0, buffer_type="default"):
-    """ Special function buffer_decl(shape, dtype, data, strides, elem_offset, scope, align,
+def buffer_decl(
+    parser,
+    node,
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="global",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+):
+    """Special function buffer_decl(shape, dtype, data, strides, elem_offset, scope, align,
                                          offset_factor, buffer_type)
 
     Example
@@ -77,8 +110,18 @@ def buffer_decl(parser, node, shape, dtype="float32", data=None, strides=None, e
         strides = []
     align = align.value if not isinstance(align, int) else align
     offset_factor = offset_factor.value if not isinstance(offset_factor, int) else offset_factor
-    buffer = tvm.tir.decl_buffer(shape, dtype, parser._assign_target, data, strides, elem_offset,
-                                 scope, align, offset_factor, buffer_type)
+    buffer = tvm.tir.decl_buffer(
+        shape,
+        dtype,
+        parser._assign_target,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+    )
     return buffer
 
 
@@ -90,7 +133,7 @@ def var(parser, node, dtype):
 
 @register_special_stmt
 def func_attr(parser, node, dict_attr):
-    """ Special function for declaring the DictAttr of PrimFunc
+    """Special function for declaring the DictAttr of PrimFunc
 
     Example
     -------
index ee33805..a331947 100644 (file)
@@ -25,12 +25,14 @@ import tvm
 
 class TypeGeneric:
     """Base class for all the hybrid script typing class"""
+
     def evaluate(self):
         raise TypeError("Cannot get tvm.Type from a generic type")
 
 
 class ConcreteType(TypeGeneric):
     """Hybrid script typing class for uniform Type objects"""
+
     def __init__(self, vtype):
         self.type = vtype
 
@@ -43,6 +45,7 @@ class GenericPtrType(TypeGeneric):
 
     [] operator is overloaded, accepts a ConcreteType and returns a ConcreteType wrapping PtrType
     """
+
     def __getitem__(self, vtype):
         return ConcreteType(tvm.ir.PointerType(vtype.evaluate()))
 
@@ -53,6 +56,7 @@ class GenericTupleType(TypeGeneric):
     [] operator is overloaded, accepts a list of ConcreteType and returns a ConcreteType
     wrapping TupleType
     """
+
     def __getitem__(self, vtypes):
         return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes]))
 
index 76a4c33..bb01b55 100644 (file)
@@ -1,4 +1,3 @@
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
index d126f28..5ca026f 100644 (file)
@@ -38,9 +38,9 @@ class Constructor(RelayExpr):
     belong_to : GlobalTypeVar
         Denotes which ADT the constructor belongs to.
     """
+
     def __init__(self, name_hint, inputs, belong_to):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Constructor, name_hint, inputs, belong_to)
+        self.__init_handle_by_constructor__(_ffi_api.Constructor, name_hint, inputs, belong_to)
 
     def __call__(self, *args):
         """Call the constructor.
@@ -57,6 +57,7 @@ class Constructor(RelayExpr):
         """
         # pylint: disable=import-outside-toplevel
         from tvm import relay
+
         return relay.Call(self, args)
 
 
@@ -82,6 +83,6 @@ class TypeData(Type):
     constructors: List[Constructor]
         The constructors for the ADT.
     """
+
     def __init__(self, header, type_vars, constructors):
-        self.__init_handle_by_constructor__(
-            _ffi_api.TypeData, header, type_vars, constructors)
+        self.__init_handle_by_constructor__(_ffi_api.TypeData, header, type_vars, constructors)
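For orientation, a minimal sketch of how these ADT nodes fit together on the Python side (the `Box`/`MkBox` names and the type variable are illustrative, not taken from this diff):

import tvm
from tvm import relay

# Declare an ADT handle, one type parameter, and one constructor, then
# register the TypeData definition on an IRModule.
mod = tvm.IRModule()
box = relay.GlobalTypeVar("Box")
a = relay.TypeVar("a")
mk_box = relay.Constructor("MkBox", [a], box)
mod[box] = relay.TypeData(box, [a], [mk_box])

# Calling the constructor produces a relay.Call node, as Constructor.__call__ shows.
boxed = mk_box(relay.const(1.0, "float32"))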
index 3c656fc..170f64e 100644 (file)
@@ -29,8 +29,9 @@ class Attrs(Object):
     Used by function registered in python side, such as compute, schedule and alter_layout.
     Attrs is passed as the first argument to these functions.
     """
+
     def list_field_info(self):
-        """ Get fields information
+        """Get fields information
 
         Returns
         -------
@@ -94,8 +95,8 @@ class Attrs(Object):
 
 @tvm._ffi.register_object
 class DictAttrs(Attrs):
-    """Dictionary attributes.
-    """
+    """Dictionary attributes."""
+
     def _dict(self):
         """Get internal dict"""
         return _ffi_api.DictAttrsGetDict(self)
index b505a2e..00514b4 100644 (file)
@@ -24,8 +24,10 @@ from tvm.runtime import Object
 from . import _ffi_api
 from . import json_compact
 
+
 class Node(Object):
     """Base class of all IR Nodes, implements astext function."""
+
     def astext(self, show_meta_data=True, annotate=None):
         """Get the text format of the expression.
 
@@ -65,6 +67,7 @@ class SourceName(Object):
     name : str
         The name of the source.
     """
+
     def __init__(self, name):
         self.__init_handle_by_constructor__(_ffi_api.SourceName, name)
 
@@ -84,9 +87,11 @@ class Span(Object):
     col_offset : int
         The column offset of the location.
     """
+
     def __init__(self, source_name, line, end_line, column, end_column):
         self.__init_handle_by_constructor__(
-            _ffi_api.Span, source_name, line, end_line, column, end_column)
+            _ffi_api.Span, source_name, line, end_line, column, end_column
+        )
 
 
 @tvm._ffi.register_object
@@ -95,6 +100,7 @@ class EnvFunc(Object):
 
     This is a global function object that can be serialized by its name.
     """
+
     def __call__(self, *args):
         return _ffi_api.EnvFuncCall(self, *args)
 
@@ -200,8 +206,7 @@ def structural_equal(lhs, rhs, map_free_vars=False):
     """
     lhs = tvm.runtime.convert(lhs)
     rhs = tvm.runtime.convert(rhs)
-    return bool(tvm.runtime._ffi_node_api.StructuralEqual(
-        lhs, rhs, False, map_free_vars))
+    return bool(tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars))
 
 
 def assert_structural_equal(lhs, rhs, map_free_vars=False):
@@ -229,8 +234,7 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False):
     """
     lhs = tvm.runtime.convert(lhs)
     rhs = tvm.runtime.convert(rhs)
-    tvm.runtime._ffi_node_api.StructuralEqual(
-        lhs, rhs, True, map_free_vars)
+    tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars)
 
 
 def structural_hash(node, map_free_vars=False):
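As a quick illustration of the structural-equality helpers reflowed above (a sketch; the variables are arbitrary TIR expressions):

import tvm

x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")

# Same structure over the same free variable compares equal.
assert tvm.ir.structural_equal(x + 1, x + 1)
# Different free variables only match when mapping them is allowed.
assert not tvm.ir.structural_equal(x + 1, y + 1)
assert tvm.ir.structural_equal(x + 1, y + 1, map_free_vars=True)
# structural_hash agrees with structural_equal on identical structure.
assert tvm.ir.structural_hash(x + 1) == tvm.ir.structural_hash(x + 1)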
index e7374de..a87d679 100644 (file)
@@ -31,9 +31,9 @@ class Array(Object):
     to Array during tvm function call.
     You may get Array in return values of TVM function call.
     """
+
     def __getitem__(self, idx):
-        return getitem_helper(
-            self, _ffi_node_api.ArrayGetItem, len(self), idx)
+        return getitem_helper(self, _ffi_node_api.ArrayGetItem, len(self), idx)
 
     def __len__(self):
         return _ffi_node_api.ArraySize(self)
@@ -47,6 +47,7 @@ class Map(Object):
     Normally python dict will be converted automatically to Map during tvm function call.
     Normally python dict will be converted automatically to Map during tvm function call.
     You can use convert to create a dict[Object-> Object] into a Map
     """
+
     def __getitem__(self, k):
         return _ffi_node_api.MapGetItem(self, k)
 
@@ -56,7 +57,7 @@ class Map(Object):
     def items(self):
         """Get the items from the map"""
         akvs = _ffi_node_api.MapItems(self)
-        return [(akvs[i], akvs[i+1]) for i in range(0, len(akvs), 2)]
+        return [(akvs[i], akvs[i + 1]) for i in range(0, len(akvs), 2)]
 
     def __len__(self):
         return _ffi_node_api.MapSize(self)
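A small sketch of the Map behaviour touched here, assuming the usual conversion path from a Python dict:

import tvm

# tvm.runtime.convert produces an ir.container.Map from a dict; items()
# reassembles the flattened [k0, v0, k1, v1, ...] array returned by MapItems.
m = tvm.runtime.convert({"a": 1, "b": 2})
assert len(m) == 2
for key, value in m.items():
    print(key, int(value))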
index 0a3f205..f6bf975 100644 (file)
@@ -35,6 +35,7 @@ class PrimExpr(BaseExpr):
 
 class RelayExpr(BaseExpr):
     """Base class of all non-primitive expressions."""
+
     @property
     def checked_type(self):
         """Get the checked type of tvm.relay.Expr.
@@ -46,8 +47,7 @@ class RelayExpr(BaseExpr):
         """
         ret = self._checked_type_
         if ret is None:
-            raise ValueError("The type checker has not populated"
-                             " the checked_type for this node")
+            raise ValueError("The type checker has not populated" " the checked_type for this node")
         return ret
 
 
@@ -63,6 +63,7 @@ class GlobalVar(RelayExpr):
     name_hint: str
         The name of the variable.
     """
+
     def __init__(self, name_hint):
         self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint)
 
@@ -82,10 +83,12 @@ class GlobalVar(RelayExpr):
         # pylint: disable=import-outside-toplevel
         if all(isinstance(x, RelayExpr) for x in args):
             from tvm import relay
+
             return relay.Call(self, args)
         arg_types = [type(x) for x in args]
         raise RuntimeError(
-            "Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types))
+            "Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types)
+        )
 
 
 @tvm._ffi.register_object
@@ -109,13 +112,12 @@ class Range(Node):
     The constructor creates the range `[begin, end)`
     if the end argument is not None. Otherwise, it creates `[0, begin)`.
     """
+
     def __init__(self, begin, end=None):
         if end is None:
-            self.__init_handle_by_constructor__(
-                _ffi_api.Range, 0, begin)
+            self.__init_handle_by_constructor__(_ffi_api.Range, 0, begin)
         else:
-            self.__init_handle_by_constructor__(
-                _ffi_api.Range, begin, end)
+            self.__init_handle_by_constructor__(_ffi_api.Range, begin, end)
 
     @staticmethod
     def from_min_extent(min_value, extent):
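For reference, the two constructor forms described in the docstring, plus the helper below (a sketch):

import tvm

r_half_open = tvm.ir.Range(2, 10)                   # the range [2, 10)
r_from_zero = tvm.ir.Range(10)                      # the range [0, 10)
r_min_extent = tvm.ir.Range.from_min_extent(2, 8)   # also [2, 10), from min and extent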
index d28ffa6..c3f1bf5 100644 (file)
@@ -24,6 +24,7 @@ from . import _ffi_api
 
 class CallingConv(IntEnum):
     """Possible kinds of calling conventions."""
+
     DEFAULT = 0
     C_PACKED_FUNC = 1
     DEVICE_KERNEL_LAUNCH = 2
@@ -31,10 +32,10 @@ class CallingConv(IntEnum):
 
 class BaseFunc(RelayExpr):
     """Base class of all functions."""
+
     @property
     def attrs(self):
-        """Return the attrs member of the function.
-        """
+        """Return the attrs member of the function."""
         return _ffi_api.BaseFunc_Attrs(self)
 
     def with_attr(self, attr_key_or_dict, attr_value=None):
@@ -59,9 +60,9 @@ class BaseFunc(RelayExpr):
 
         if isinstance(attr_key_or_dict, dict):
             for key, val in attr_key_or_dict.items():
-                res = _ffi_api.BaseFuncWithAttr(
-                    res._move(), key, tvm.runtime.convert(val))
+                res = _ffi_api.BaseFuncWithAttr(res._move(), key, tvm.runtime.convert(val))
             return res
 
         return _ffi_api.BaseFuncWithAttr(
-            res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value))
+            res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value)
+        )
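The reformatted with_attr above accepts either a key/value pair or a dict; a hedged usage sketch on a Relay function (attribute values are illustrative):

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 8), dtype="float32")
func = relay.Function([x], relay.nn.relu(x))

# Attach one attribute, then several at once via a dict.
func = func.with_attr("Compiler", "example_backend")
func = func.with_attr({"global_symbol": "main", "Primitive": 1})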
index 8b75685..a22d7d3 100644 (file)
@@ -39,6 +39,7 @@ def create_updater(node_map, from_ver, to_ver):
     fupdater : function
         The updater function
     """
+
     def _updater(data):
         assert data["attrs"]["tvm_version"].startswith(from_ver)
         nodes = data["nodes"]
@@ -52,6 +53,7 @@ def create_updater(node_map, from_ver, to_ver):
             nodes[idx] = item
         data["attrs"]["tvm_version"] = to_ver
         return data
+
     return _updater
 
 
@@ -63,6 +65,7 @@ def create_updater_06_to_07():
     fupdater : function
         The updater function
     """
+
     def _ftype_var(item, nodes):
         vindex = int(item["attrs"]["var"])
         item["attrs"]["name_hint"] = nodes[vindex]["attrs"]["name"]
@@ -70,13 +73,14 @@ def create_updater_06_to_07():
         nodes[vindex]["type_key"] = ""
         del item["attrs"]["var"]
         assert item["type_key"].startswith("relay.")
-        item["type_key"] = item["type_key"][len("relay."):]
+        item["type_key"] = item["type_key"][len("relay.") :]
         return item
 
     def _rename(new_name):
         def _convert(item, _):
             item["type_key"] = new_name
             return item
+
         return _convert
 
     def _update_tir_var(new_name):
@@ -84,6 +88,7 @@ def create_updater_06_to_07():
             item["type_key"] = new_name
             item["attrs"]["type_annotation"] = "0"
             return item
+
         return _convert
 
     def _update_global_key(item, _):
@@ -100,12 +105,11 @@ def create_updater_06_to_07():
             val = jdata["nodes"][root_idx]
             sidx = len(nodes)
             nodes.append(val)
-            item["attrs"][key] = '%d' % sidx
+            item["attrs"][key] = "%d" % sidx
             return item
 
         return _convert
 
-
     node_map = {
         # Base IR
         "SourceName": _update_global_key,
@@ -179,7 +183,10 @@ def create_updater_06_to_07():
         "AttrStmt": [_rename("tir.AttrStmt"), _update_from_std_str("attr_key")],
         "Layout": [_rename("tir.Layout"), _update_from_std_str("name")],
         "Buffer": [
-            _rename("tir.Buffer"), _update_from_std_str("name"), _update_from_std_str("scope")],
+            _rename("tir.Buffer"),
+            _update_from_std_str("name"),
+            _update_from_std_str("scope"),
+        ],
     }
     return create_updater(node_map, "0.6", "0.7")
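A sketch of how such an updater is driven (old_json is assumed to hold JSON produced by tvm.ir.save_json under TVM 0.6):

import json
from tvm.ir import json_compact

updater = json_compact.create_updater_06_to_07()
data = json.loads(old_json)     # assumed: serialized node graph from TVM 0.6
data = updater(data)            # rewrites type keys and bumps attrs["tvm_version"]
new_json = json.dumps(data)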
 
index 2f6fd20..378991a 100644 (file)
@@ -35,6 +35,7 @@ class IRModule(Node):
     functions: Optional[dict].
         Map of global var to BaseFunc
     """
+
     def __init__(self, functions=None, type_definitions=None):
         if functions is None:
             functions = {}
index da546ce..7b06c3d 100644 (file)
 # pylint: disable=invalid-name
 """Primitive operators in the TVM IR."""
 import tvm._ffi
-from . expr import RelayExpr
+from .expr import RelayExpr
 from . import _ffi_api
 
 
 @tvm._ffi.register_object("Op")
 class Op(RelayExpr):
     """Primitive operator in the IR."""
+
     def __init__(self):
         raise RuntimeError("Cannot create op, use get instead")
 
@@ -107,8 +108,10 @@ def register_op_attr(op_name, attr_key, value=None, level=10):
     fregister : function
         Register function if value is not specified.
     """
+
     def _register(v):
         """internal register function"""
         _ffi_api.RegisterOpAttr(op_name, attr_key, v, level)
         return v
+
     return _register(value) if value is not None else _register
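A hedged sketch of both calling conventions of register_op_attr (the attribute keys here are made up for illustration):

import tvm

# Direct form: a value is supplied, so the attribute is registered immediately.
tvm.ir.register_op_attr("nn.relu", "TExampleFlag", True)

# Decorator form: no value, so the decorated function becomes the attribute value.
@tvm.ir.register_op_attr("nn.relu", "FExampleCallback")
def _relu_example(*args):
    return None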
index 99286ed..22b15a3 100644 (file)
@@ -36,9 +36,9 @@ class TensorType(Type):
     dtype : Optional[str]
         The content data type.
     """
+
     def __init__(self, shape, dtype="float32"):
-        self.__init_handle_by_constructor__(
-            _ffi_api.TensorType, shape, dtype)
+        self.__init_handle_by_constructor__(_ffi_api.TensorType, shape, dtype)
 
     @property
     def concrete_shape(self):
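For example (a sketch):

import tvm

t = tvm.ir.TensorType([1, 3, 224, 224], "float32")
print(t.concrete_shape)   # (1, 3, 224, 224) as plain Python ints
print(t.dtype)            # "float32"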
index 358ad19..bb230ca 100644 (file)
@@ -25,6 +25,7 @@ import tvm.runtime
 
 from . import _ffi_transform_api
 
+
 @tvm._ffi.register_object("transform.PassInfo")
 class PassInfo(tvm.runtime.Object):
     """The class contains the meta data required by a pass. It is the
@@ -45,8 +46,7 @@ class PassInfo(tvm.runtime.Object):
     """
 
     def __init__(self, opt_level, name, required=None):
-        self.__init_handle_by_constructor__(
-            _ffi_transform_api.PassInfo, opt_level, name, required)
+        self.__init_handle_by_constructor__(_ffi_transform_api.PassInfo, opt_level, name, required)
 
 
 @tvm._ffi.register_object("transform.PassContext")
@@ -68,25 +68,22 @@ class PassContext(tvm.runtime.Object):
     config : Optional[Dict[str, Object]]
         Additional configurations for specific passes.
     """
-    def __init__(self,
-                 opt_level=2,
-                 required_pass=None,
-                 disabled_pass=None,
-                 trace=None,
-                 config=None):
+
+    def __init__(
+        self, opt_level=2, required_pass=None, disabled_pass=None, trace=None, config=None
+    ):
         required = list(required_pass) if required_pass else []
         if not isinstance(required, (list, tuple)):
-            raise TypeError("required_pass is expected to be the type of " +
-                            "list/tuple/set.")
+            raise TypeError("required_pass is expected to be the type of " + "list/tuple/set.")
 
         disabled = list(disabled_pass) if disabled_pass else []
         if not isinstance(disabled, (list, tuple)):
-            raise TypeError("disabled_pass is expected to be the type of " +
-                            "list/tuple/set.")
+            raise TypeError("disabled_pass is expected to be the type of " + "list/tuple/set.")
 
         config = config if config else None
-        self.__init_handle_by_constructor__(_ffi_transform_api.PassContext, opt_level,
-                                            required, disabled, trace, config)
+        self.__init_handle_by_constructor__(
+            _ffi_transform_api.PassContext, opt_level, required, disabled, trace, config
+        )
 
     def __enter__(self):
         _ffi_transform_api.EnterPassContext(self)
@@ -167,11 +164,8 @@ class Sequential(Pass):
     required : Optional[List[str]]
         The list of passes that the sequential pass is dependent on.
     """
-    def __init__(self,
-                 passes=None,
-                 opt_level=2,
-                 name="sequential",
-                 required=None):
+
+    def __init__(self, passes=None, opt_level=2, name="sequential", required=None):
         passes = passes if passes else []
         if not isinstance(passes, (list, tuple)):
             raise TypeError("passes must be a list of Pass objects.")
@@ -180,14 +174,17 @@ class Sequential(Pass):
         if not isinstance(required, (list, tuple)):
             raise TypeError("Required is expected to be the type of list/tuple.")
 
-        self.__init_handle_by_constructor__(_ffi_transform_api.Sequential,
-                                            passes, opt_level, name, required)
+        self.__init_handle_by_constructor__(
+            _ffi_transform_api.Sequential, passes, opt_level, name, required
+        )
 
 
 def _wrap_class_module_pass(pass_cls, pass_info):
     """Wrap a python class as function pass"""
+
     class PyModulePass(ModulePass):
         """Internal wrapper class to create a class instance."""
+
         def __init__(self, *args, **kwargs):
+            # initialize handle in case pass_cls creation failed.
             self.handle = None
@@ -196,8 +193,10 @@ def _wrap_class_module_pass(pass_cls, pass_info):
             # avoid a cyclic dependency
             def _pass_func(mod, ctx):
                 return inst.transform_module(mod, ctx)
+
             self.__init_handle_by_constructor__(
-                _ffi_transform_api.MakeModulePass, _pass_func, pass_info)
+                _ffi_transform_api.MakeModulePass, _pass_func, pass_info
+            )
             self._inst = inst
 
         def __getattr__(self, name):
@@ -298,8 +297,7 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None):
 
     required = required if required else []
     if not isinstance(required, (list, tuple)):
-        raise TypeError("Required is expected to be the type of " +
-                        "list/tuple.")
+        raise TypeError("Required is expected to be the type of " + "list/tuple.")
 
     def create_module_pass(pass_arg):
         """Internal function that creates a module pass"""
index e980011..06f6009 100644 (file)
@@ -25,6 +25,7 @@ from . import _ffi_api
 
 class Type(Node):
     """The base class of all types."""
+
     def __eq__(self, other):
         """Compare two types for structural equivalence."""
         return bool(tvm.ir.structural_equal(self, other))
@@ -39,6 +40,7 @@ class Type(Node):
 
 class TypeKind(IntEnum):
     """Possible kinds of TypeVars."""
+
     Type = 0
     ShapeVar = 1
     BaseType = 2
@@ -56,9 +58,9 @@ class PrimType(Type):
     dtype : str
         The runtime data type relates to the primtype.
     """
+
     def __init__(self, dtype):
-        self.__init_handle_by_constructor__(
-            _ffi_api.PrimType, dtype)
+        self.__init_handle_by_constructor__(_ffi_api.PrimType, dtype)
 
 
 @tvm._ffi.register_object("PointerType")
@@ -70,9 +72,9 @@ class PointerType(Type):
     element_type : tvm.ir.Type
         The type of pointer's element.
     """
+
     def __init__(self, element_type):
-        self.__init_handle_by_constructor__(
-            _ffi_api.PointerType, element_type)
+        self.__init_handle_by_constructor__(_ffi_api.PointerType, element_type)
 
 
 @tvm._ffi.register_object("TypeVar")
@@ -92,9 +94,9 @@ class TypeVar(Type):
     kind : Optional[TypeKind]
         The kind of the type parameter.
     """
+
     def __init__(self, name_hint, kind=TypeKind.Type):
-        self.__init_handle_by_constructor__(
-            _ffi_api.TypeVar, name_hint, kind)
+        self.__init_handle_by_constructor__(_ffi_api.TypeVar, name_hint, kind)
 
     def __call__(self, *args):
         """Create a type call from this type.
@@ -111,6 +113,7 @@ class TypeVar(Type):
         """
         # pylint: disable=import-outside-toplevel
         from .type_relation import TypeCall
+
         return TypeCall(self, args)
 
 
@@ -127,9 +130,9 @@ class GlobalTypeVar(Type):
     kind : Optional[TypeKind]
         The kind of the type parameter.
     """
+
     def __init__(self, name_hint, kind=TypeKind.AdtHandle):
-        self.__init_handle_by_constructor__(
-            _ffi_api.GlobalTypeVar, name_hint, kind)
+        self.__init_handle_by_constructor__(_ffi_api.GlobalTypeVar, name_hint, kind)
 
     def __call__(self, *args):
         """Create a type call from this type.
@@ -146,6 +149,7 @@ class GlobalTypeVar(Type):
         """
         # pylint: disable=import-outside-toplevel
         from .type_relation import TypeCall
+
         return TypeCall(self, args)
 
 
@@ -160,8 +164,7 @@ class TupleType(Type):
     """
 
     def __init__(self, fields):
-        self.__init_handle_by_constructor__(
-            _ffi_api.TupleType, fields)
+        self.__init_handle_by_constructor__(_ffi_api.TupleType, fields)
 
 
 @tvm._ffi.register_object("TypeConstraint")
@@ -195,17 +198,15 @@ class FuncType(Type):
     type_constraints : Optional[List[tvm.relay.TypeConstraint]]
         The type constraints.
     """
-    def __init__(self,
-                 arg_types,
-                 ret_type,
-                 type_params=None,
-                 type_constraints=None):
+
+    def __init__(self, arg_types, ret_type, type_params=None, type_constraints=None):
         if type_params is None:
             type_params = []
         if type_constraints is None:
             type_constraints = []
         self.__init_handle_by_constructor__(
-            _ffi_api.FuncType, arg_types, ret_type, type_params, type_constraints)
+            _ffi_api.FuncType, arg_types, ret_type, type_params, type_constraints
+        )
 
 
 @tvm._ffi.register_object("IncompleteType")
@@ -215,9 +216,9 @@ class IncompleteType(Type):
     kind : Optional[TypeKind]
         The kind of the incomplete type.
     """
+
     def __init__(self, kind=TypeKind.Type):
-        self.__init_handle_by_constructor__(
-            _ffi_api.IncompleteType, kind)
+        self.__init_handle_by_constructor__(_ffi_api.IncompleteType, kind)
 
 
 @tvm._ffi.register_object("relay.RefType")
@@ -229,5 +230,6 @@ class RelayRefType(Type):
     value: Type
         The value type.
     """
+
     def __init__(self, value):
         self.__init_handle_by_constructor__(_ffi_api.RelayRefType, value)
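A brief sketch composing a few of the type nodes reformatted above:

import tvm

i32 = tvm.ir.PrimType("int32")
f32 = tvm.ir.PrimType("float32")

ptr_ty = tvm.ir.PointerType(i32)            # pointer to int32
pair_ty = tvm.ir.TupleType([i32, f32])      # (int32, float32)
fn_ty = tvm.ir.FuncType([i32, i32], i32)    # fn(int32, int32) -> int32

# Type.__eq__ is structural, so separately built nodes compare equal.
assert fn_ty == tvm.ir.FuncType([i32, i32], i32)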
index bacb2c2..dba42db 100644 (file)
@@ -38,6 +38,7 @@ class TypeCall(Type):
     type_call: TypeCall
         The type function application.
     """
+
     def __init__(self, func, args):
         self.__init_handle_by_constructor__(_ffi_api.TypeCall, func, args)
 
@@ -69,6 +70,6 @@ class TypeRelation(TypeConstraint):
     type_relation : tvm.ir.TypeRelation
         The type relation.
     """
+
     def __init__(self, func, args, num_inputs, attrs):
-        self.__init_handle_by_constructor__(
-            _ffi_api.TypeRelation, func, args, num_inputs, attrs)
+        self.__init_handle_by_constructor__(_ffi_api.TypeRelation, func, args, num_inputs, attrs)
index cb3c843..57e1756 100644 (file)
@@ -42,8 +42,10 @@ DEVICE_SECTIONS = [
     "stack",
 ]
 
+
 class LibType(Enum):
     """Enumeration of library types that can be compiled and loaded onto a device"""
+
     # library to be used as a MicroTVM runtime
     RUNTIME = 0
     # library to be used as an operator
@@ -88,7 +90,8 @@ class Session:
         runtime_obj_path = tmp_dir.relpath("utvm_runtime.obj")
         options = ["-I{}".format(get_micro_host_driven_dir())]
         dev_funcs["create_micro_lib"](
-            runtime_obj_path, runtime_src_path, LibType.RUNTIME, options=options)
+            runtime_obj_path, runtime_src_path, LibType.RUNTIME, options=options
+        )
 
         comms_method = config["comms_method"]
         if comms_method == "openocd":
@@ -100,8 +103,9 @@ class Session:
         else:
             raise RuntimeError(f"unknown communication method: f{self.comms_method}")
 
-        assert all(map(lambda sec: sec in self.mem_layout, DEVICE_SECTIONS)), \
-            "not all sections have an assigned memory layout"
+        assert all(
+            map(lambda sec: sec in self.mem_layout, DEVICE_SECTIONS)
+        ), "not all sections have an assigned memory layout"
         self.module = _CreateSession(
             comms_method,
             runtime_obj_path,
@@ -127,7 +131,8 @@ class Session:
             self.use_device_timer,
             server_addr,
             server_port,
-            config.get("debug_func"))
+            config.get("debug_func"),
+        )
         self._enter = self.module["enter"]
         self._exit = self.module["exit"]
         self.get_last_batch_time = self.module["get_last_batch_time"]
@@ -142,7 +147,7 @@ class Session:
             raise RuntimeError("MicroTVM is currently only supported on Linux")
         # TODO(weberlo): Add 32-bit support.
         # It's primarily the compilation pipeline that isn't compatible.
-        if sys.maxsize <= 2**32:
+        if sys.maxsize <= 2 ** 32:
             raise RuntimeError("MicroTVM is currently only supported on 64-bit host platforms")
 
     def __enter__(self):
@@ -156,8 +161,9 @@ class Session:
 def _calc_max_workspace_usage(src):
     # TODO factor in alignment to the calculation (alloc sizes will be aligned up to the word size)
     alloc_re = re.compile(
-        r'.*\* ?(.+) = (\(.+\))? TVMBackendAllocWorkspace\(.+, .+, \(uint64_t\)(.+), .+, .+\).*')
-    free_re = re.compile(r'.*if \(TVMBackendFreeWorkspace\(.+, .+, (\(void\*\))? (.+)\) != 0\) {.*')
+        r".*\* ?(.+) = (\(.+\))? TVMBackendAllocWorkspace\(.+, .+, \(uint64_t\)(.+), .+, .+\).*"
+    )
+    free_re = re.compile(r".*if \(TVMBackendFreeWorkspace\(.+, .+, (\(void\*\))? (.+)\) != 0\) {.*")
     max_usage = 0
     alloc_map = {}
     for line in src.split("\n"):
@@ -175,8 +181,9 @@ def _calc_max_workspace_usage(src):
     return max_usage
 
 
-def create_micro_mod(c_mod, dev_config, lib_src_paths=None, lib_headers=None,
-                     lib_include_paths=None):
+def create_micro_mod(
+    c_mod, dev_config, lib_src_paths=None, lib_headers=None, lib_include_paths=None
+):
     """Produces a micro module from a given module.
 
     Parameters
@@ -209,13 +216,16 @@ def create_micro_mod(c_mod, dev_config, lib_src_paths=None, lib_headers=None,
             LibType.OPERATOR,
             lib_src_paths=lib_src_paths,
             lib_headers=lib_headers,
-            lib_include_paths=lib_include_paths))
+            lib_include_paths=lib_include_paths,
+        ),
+    )
     micro_mod = tvm.runtime.load_module(lib_obj_path)
     return micro_mod
 
 
-def cross_compiler(dev_config, lib_type, lib_src_paths=None, lib_headers=None,
-                   lib_include_paths=None):
+def cross_compiler(
+    dev_config, lib_type, lib_src_paths=None, lib_headers=None, lib_include_paths=None
+):
     """Create a cross compile function that wraps `create_lib` for a `Binutil` instance.
 
     For use in `tvm.runtime.Module.export_library`.
@@ -252,8 +262,9 @@ def cross_compiler(dev_config, lib_type, lib_src_paths=None, lib_headers=None,
       fcompile = tvm.micro.cross_compiler(dev_config, LibType.OPERATOR)
       c_mod.export_library('dev_lib.obj', fcompile=fcompile)
     """
-    assert (lib_headers is None) == (lib_include_paths is None), \
-        "must specify both `lib_headers` and `lib_include_paths` or neither"
+    assert (lib_headers is None) == (
+        lib_include_paths is None
+    ), "must specify both `lib_headers` and `lib_include_paths` or neither"
 
     if lib_src_paths is None:
         lib_src_paths = []
@@ -263,8 +274,9 @@ def cross_compiler(dev_config, lib_type, lib_src_paths=None, lib_headers=None,
     for include_path in lib_include_paths:
         include_options.append("-I")
         include_options.append(include_path)
-    create_micro_lib = tvm.micro.device.get_device_funcs(
-        dev_config["device_id"])["create_micro_lib"]
+    create_micro_lib = tvm.micro.device.get_device_funcs(dev_config["device_id"])[
+        "create_micro_lib"
+    ]
     mem_layout = dev_config["mem_layout"]
 
     def compile_func(obj_path, src_path, **kwargs):
@@ -281,8 +293,10 @@ def cross_compiler(dev_config, lib_type, lib_src_paths=None, lib_headers=None,
             max_ws_usage = _calc_max_workspace_usage(src_contents)
             available_mem = mem_layout["workspace"]["size"]
             if max_ws_usage > available_mem:
-                raise RuntimeError(f"workspace allocations in library ({max_ws_usage}) "
-                                   f"exceed available memory ({available_mem})")
+                raise RuntimeError(
+                    f"workspace allocations in library ({max_ws_usage}) "
+                    f"exceed available memory ({available_mem})"
+                )
         # inject headers into new source path, if requested
         if lib_headers:
             headers_to_inject = "\n".join(map(lambda s: f"#include <{s}>", lib_headers)) + "\n"
@@ -293,6 +307,7 @@ def cross_compiler(dev_config, lib_type, lib_src_paths=None, lib_headers=None,
                 f.write(new_src_contents)
 
         create_micro_lib(obj_path, src_path, lib_type, options, lib_src_paths=lib_src_paths)
+
     return _cc.cross_compiler(compile_func, output_format="obj")
 
 
@@ -305,8 +320,9 @@ def get_micro_host_driven_dir():
         directory path
     """
     micro_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
-    micro_host_driven_dir = os.path.join(micro_dir, "..", "..", "..",
-                                         "src", "runtime", "micro", "host_driven")
+    micro_host_driven_dir = os.path.join(
+        micro_dir, "..", "..", "..", "src", "runtime", "micro", "host_driven"
+    )
     return micro_host_driven_dir
 
 
@@ -319,8 +335,9 @@ def get_micro_device_dir():
         directory path
     """
     micro_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
-    micro_device_dir = os.path.join(micro_dir, "..", "..", "..",
-                                    "src", "runtime", "micro", "device")
+    micro_device_dir = os.path.join(
+        micro_dir, "..", "..", "..", "src", "runtime", "micro", "device"
+    )
     return micro_device_dir
 
 
index 3f4efff..bd66601 100644 (file)
@@ -39,6 +39,7 @@ DEFAULT_SECTION_CONSTRAINTS = {
     "stack": (32, MemConstraint.ABSOLUTE_BYTES),
 }
 
+
 def create_micro_lib(obj_path, src_path, lib_type, options=None, lib_src_paths=None):
     """Wrapper over `create_micro_lib_base` to add device-specific options
 
@@ -80,11 +81,17 @@ def create_micro_lib(obj_path, src_path, lib_type, options=None, lib_src_paths=N
         "-Wno-unused-variable",
         "-Wno-unused-parameter",
         "-I{}".format(os.environ["CMSIS_ST_PATH"]),
-        "-I{}/Core/Include".format(os.environ["CMSIS_ST_PATH"])
-        ]
+        "-I{}/Core/Include".format(os.environ["CMSIS_ST_PATH"]),
+    ]
     create_micro_lib_base(
-        obj_path, src_path, TOOLCHAIN_PREFIX, DEVICE_ID, lib_type, options=options,
-        lib_src_paths=lib_src_paths)
+        obj_path,
+        src_path,
+        TOOLCHAIN_PREFIX,
+        DEVICE_ID,
+        lib_type,
+        options=options,
+        lib_src_paths=lib_src_paths,
+    )
 
 
 def generate_config(server_addr, server_port, section_constraints=None):
@@ -121,7 +128,10 @@ def generate_config(server_addr, server_port, section_constraints=None):
     }
 
 
-register_device(DEVICE_ID, {
-    "create_micro_lib": create_micro_lib,
-    "generate_config": generate_config,
-})
+register_device(
+    DEVICE_ID,
+    {
+        "create_micro_lib": create_micro_lib,
+        "generate_config": generate_config,
+    },
+)
index 767284c..fef0d11 100644 (file)
@@ -27,6 +27,7 @@ from tvm.micro import DEVICE_SECTIONS, LibType, get_micro_host_driven_dir, get_m
 
 _DEVICE_REGISTRY = {}
 
+
 def register_device(device_id, device_funcs):
     """Register a device and associated compilation/config functions
 
@@ -63,14 +64,14 @@ def get_device_funcs(device_id):
 
 
 def create_micro_lib_base(
-        out_obj_path,
-        in_src_path,
-        toolchain_prefix,
-        device_id,
-        lib_type,
-        options=None,
-        lib_src_paths=None,
-        ):
+    out_obj_path,
+    in_src_path,
+    toolchain_prefix,
+    device_id,
+    lib_type,
+    options=None,
+    lib_src_paths=None,
+):
     """Compiles code into a binary for the target micro device.
 
     Parameters
@@ -114,7 +115,7 @@ def create_micro_lib_base(
         "-nostdlib",
         "-fdata-sections",
         "-ffunction-sections",
-        ]
+    ]
     if options is not None:
         base_compile_cmd += options
 
@@ -174,6 +175,7 @@ def create_micro_lib_base(
 # TODO we shouldn't need an enum for this. too much bureaucracy.
 class MemConstraint(enum.Enum):
     """Represents a constraint on the device's memory layout"""
+
     ABSOLUTE_BYTES = 0
     WEIGHT = 1
 
@@ -197,12 +199,10 @@ def gen_mem_layout(base_addr, available_mem, word_size_bits, section_constraints
     """
     assert word_size_bits in (32, 64), "only 32- or 64-bit devices are supported now"
     word_size_bytes = word_size_bits // 8
-    byte_sum = sum(x[0]
-                   for x in section_constraints.values()
-                   if x[1] == MemConstraint.ABSOLUTE_BYTES)
-    weight_sum = sum(x[0]
-                     for x in section_constraints.values()
-                     if x[1] == MemConstraint.WEIGHT)
+    byte_sum = sum(
+        x[0] for x in section_constraints.values() if x[1] == MemConstraint.ABSOLUTE_BYTES
+    )
+    weight_sum = sum(x[0] for x in section_constraints.values() if x[1] == MemConstraint.WEIGHT)
     assert byte_sum <= available_mem
     available_weight_mem = available_mem - byte_sum
 
@@ -211,8 +211,9 @@ def gen_mem_layout(base_addr, available_mem, word_size_bits, section_constraints
     for section in DEVICE_SECTIONS:
         (val, cons_type) = section_constraints[section]
         if cons_type == MemConstraint.ABSOLUTE_BYTES:
-            assert val % word_size_bytes == 0, \
-                f"constraint {val} for {section} section is not word-aligned"
+            assert (
+                val % word_size_bytes == 0
+            ), f"constraint {val} for {section} section is not word-aligned"
             size = val
             res[section] = {
                 "start": curr_addr,
index cad65b9..c5f0e15 100644 (file)
@@ -21,7 +21,7 @@ from . import create_micro_lib_base, register_device, gen_mem_layout, MemConstra
 
 DEVICE_ID = "host"
 TOOLCHAIN_PREFIX = ""
-WORD_SIZE_BITS = 64 if sys.maxsize > 2**32 else 32
+WORD_SIZE_BITS = 64 if sys.maxsize > 2 ** 32 else 32
 
 # we pretend we only have 320kb in the default case, so we can use `gen_mem_layout`
 DEFAULT_AVAILABLE_MEM = 3200000
@@ -36,6 +36,7 @@ DEFAULT_SECTION_CONSTRAINTS = {
     "stack": (80, MemConstraint.ABSOLUTE_BYTES),
 }
 
+
 def create_micro_lib(obj_path, src_path, lib_type, options=None, lib_src_paths=None):
     """Wrapper over `create_micro_lib_base` to add device-specific options
 
@@ -62,12 +63,18 @@ def create_micro_lib(obj_path, src_path, lib_type, options=None, lib_src_paths=N
         options = list(options)
     # Cannot increase optimization level on host due to code loading method.
     options.append("-O0")
-    if sys.maxsize > 2**32 and sys.platform.startswith("linux"):
+    if sys.maxsize > 2 ** 32 and sys.platform.startswith("linux"):
         options += ["-mcmodel=large"]
-    options.append('-DUTVM_TARGET_HOST')
+    options.append("-DUTVM_TARGET_HOST")
     create_micro_lib_base(
-        obj_path, src_path, TOOLCHAIN_PREFIX, DEVICE_ID, lib_type, options=options,
-        lib_src_paths=lib_src_paths)
+        obj_path,
+        src_path,
+        TOOLCHAIN_PREFIX,
+        DEVICE_ID,
+        lib_type,
+        options=options,
+        lib_src_paths=lib_src_paths,
+    )
 
 
 def generate_config(available_mem=None, section_constraints=None):
@@ -111,7 +118,10 @@ def generate_config(available_mem=None, section_constraints=None):
     }
 
 
-register_device(DEVICE_ID, {
-    "create_micro_lib": create_micro_lib,
-    "generate_config": generate_config,
-})
+register_device(
+    DEVICE_ID,
+    {
+        "create_micro_lib": create_micro_lib,
+        "generate_config": generate_config,
+    },
+)
index 32881ca..2781566 100644 (file)
@@ -33,6 +33,7 @@ DEFAULT_SECTION_CONSTRAINTS = {
     "stack": (32, MemConstraint.ABSOLUTE_BYTES),
 }
 
+
 def create_micro_lib(obj_path, src_path, lib_type, options=None, lib_src_paths=None):
     """Wrapper over `create_micro_lib_base` to add device-specific options
 
@@ -60,8 +61,8 @@ def create_micro_lib(obj_path, src_path, lib_type, options=None, lib_src_paths=N
         DEVICE_ID,
         lib_type,
         options=options,
-        lib_src_paths=lib_src_paths
-        )
+        lib_src_paths=lib_src_paths,
+    )
 
 
 def generate_config(base_addr, available_mem, server_addr, server_port, section_constraints=None):
@@ -102,7 +103,10 @@ def generate_config(base_addr, available_mem, server_addr, server_port, section_
     }
 
 
-register_device(DEVICE_ID, {
-    "create_micro_lib": create_micro_lib,
-    "generate_config": generate_config,
-})
+register_device(
+    DEVICE_ID,
+    {
+        "create_micro_lib": create_micro_lib,
+        "generate_config": generate_config,
+    },
+)
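Each of these device modules ends with the same registration pattern; a hedged sketch of what a hypothetical third-party device might register (the device id, toolchain prefix, and helper bodies are made up):

from tvm.micro.device import MemConstraint, create_micro_lib_base, register_device

DEVICE_ID = "example_device"
TOOLCHAIN_PREFIX = "arm-none-eabi-"

def create_micro_lib(obj_path, src_path, lib_type, options=None, lib_src_paths=None):
    # Delegate to the shared helper, adding device-specific flags as needed.
    create_micro_lib_base(
        obj_path,
        src_path,
        TOOLCHAIN_PREFIX,
        DEVICE_ID,
        lib_type,
        options=options,
        lib_src_paths=lib_src_paths,
    )

def generate_config(server_addr, server_port, section_constraints=None):
    # A real implementation would call gen_mem_layout and return the
    # device_id/toolchain_prefix/mem_layout/... dict that micro.Session expects.
    raise NotImplementedError

register_device(
    DEVICE_ID,
    {
        "create_micro_lib": create_micro_lib,
        "generate_config": generate_config,
    },
)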
index c13a28e..e19f4af 100644 (file)
@@ -19,6 +19,7 @@
 
 import json
 
+
 def graph_json_to_c_func_registry(graph_path, func_registry_path):
     """Convert a graph json file to a CRT-compatible FuncRegistry.
 
@@ -34,43 +35,45 @@ def graph_json_to_c_func_registry(graph_path, func_registry_path):
         graph = json.load(json_f)
 
     funcs = []
-    for n in graph['nodes']:
-        if n['op'] != 'tvm_op':
+    for n in graph["nodes"]:
+        if n["op"] != "tvm_op":
             continue
 
-        funcs.append(n['attrs']['func_name'])
+        funcs.append(n["attrs"]["func_name"])
 
-    encoded_funcs = f'\\{len(funcs):03o}' + '\\0'.join(funcs)
+    encoded_funcs = f"\\{len(funcs):03o}" + "\\0".join(funcs)
     lines = [
-        '#include <tvm/runtime/c_runtime_api.h>',
-        '#include <tvm/runtime/crt/module.h>',
-        '#include <stdio.h>',
-        '',
+        "#include <tvm/runtime/c_runtime_api.h>",
+        "#include <tvm/runtime/crt/module.h>",
+        "#include <stdio.h>",
+        "",
     ]
 
     for f in funcs:
-        lines.append(f'extern int {f}(TVMValue* args, int* type_codes, int num_args, '
-                     'TVMValue* out_ret_value, int* out_ret_tcode, void* resource_handle);')
+        lines.append(
+            f"extern int {f}(TVMValue* args, int* type_codes, int num_args, "
+            "TVMValue* out_ret_value, int* out_ret_tcode, void* resource_handle);"
+        )
 
-    lines.append('static TVMBackendPackedCFunc funcs[] = {')
+    lines.append("static TVMBackendPackedCFunc funcs[] = {")
 
     for f in funcs:
-        lines.append(f'    &{f},')
+        lines.append(f"    &{f},")
 
     lines += [
-        '};',
-        'static const TVMFuncRegistry system_lib_registry = {',
+        "};",
+        "static const TVMFuncRegistry system_lib_registry = {",
         f'       "{encoded_funcs}\\0",',
-        '        funcs,',
-        '};',
-        'static const TVMModule system_lib = {',
-        '    &system_lib_registry,',
-        '};',
-        '',
-        'const TVMModule* TVMSystemLibEntryPoint(void) {',
-        '    return &system_lib;',
-        '}',
-        '',   # blank line to end the file
+        "        funcs,",
+        "};",
+        "static const TVMModule system_lib = {",
+        "    &system_lib_registry,",
+        "};",
+        "",
+        "const TVMModule* TVMSystemLibEntryPoint(void) {",
+        "    return &system_lib;",
+        "}",
+        "",  # blank line to end the file
     ]
-    with open(func_registry_path, 'w') as wrapper_f:
-        wrapper_f.write('\n'.join(lines))
+    with open(func_registry_path, "w") as wrapper_f:
+        wrapper_f.write("\n".join(lines))
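A usage sketch for the registry generator above; the file names are placeholders and the import path is an assumption (this diff does not show the module path):

# Assumed import location; adjust to wherever graph_json_to_c_func_registry lives.
from tvm.micro.func_registry import graph_json_to_c_func_registry

# Emit a C source file declaring a TVMFuncRegistry entry for every tvm_op
# found in the graph runtime JSON.
graph_json_to_c_func_registry("deploy_graph.json", "func_registry.c")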
index 8001cd4..1189233 100644 (file)
 """The under development unified IR parsing infrastructure."""
 from . import _ffi_api
 
+
 def parse(source, source_name="from_string"):
     return _ffi_api.ParseModule(source_name, source)
 
+
 def parse_expr(source):
     return _ffi_api.ParseExpr("string", source)
 
+
 def fromtext(source, source_name="from_string"):
     return parse(source, source_name)
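A hedged sketch of the parser entry points (the version pragma mirrors what the Relay text tests of this era prepend):

import tvm

mod = tvm.parser.parse(
    '#[version = "0.0.5"]\n'
    "def @main(%x: Tensor[(2, 2), float32]) { %x }"
)
# fromtext is a thin alias of parse with the same defaults.
same = tvm.parser.fromtext('#[version = "0.0.5"]\ndef @main(%x: Tensor[(2, 2), float32]) { %x }')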
index d417c2b..7e49461 100644 (file)
@@ -121,6 +121,7 @@ def check_constant(expr):
     """
     return _ffi_api.check_constant(expr)
 
+
 def check_basic_block_normal_form(expr):
     """Check whether an expression is in the basic block form
 
@@ -440,8 +441,10 @@ def get_calibration_data(mod, data):
         offset = int(indices[0])
         in_len = int(indices[1])
         out_len = int(indices[2])
-        value = {"inputs": ref_res[offset:offset + in_len],
-                 "outputs": ref_res[offset + in_len:offset + in_len + out_len]}
+        value = {
+            "inputs": ref_res[offset : offset + in_len],
+            "outputs": ref_res[offset + in_len : offset + in_len + out_len],
+        }
         calib_data[gvar] = value
 
     return calib_data
index f29b726..437b97b 100644 (file)
@@ -37,10 +37,9 @@ class AnnotatedRegionSet(Object):
             The region end annotation.
 
         """
-        self.__init_handle_by_constructor__(_ffi_api.AnnotatedRegionSet,
-                                            expr,
-                                            region_begin_op,
-                                            region_end_op)
+        self.__init_handle_by_constructor__(
+            _ffi_api.AnnotatedRegionSet, expr, region_begin_op, region_end_op
+        )
 
     def __len__(self):
         return len(self.regions)
index 6850267..99e2cdc 100644 (file)
 """The type nodes of the Relay language."""
 from enum import IntEnum
 
+
 class Feature(IntEnum):
     """ The features a program might contain. """
+
     fVar = 0
     fGlobalVar = 1
     fConstant = 2
index 7e8f434..d521748 100644 (file)
@@ -27,10 +27,14 @@ import tvm
 from . import _ffi_api
 
 
-SparseAnalysisResult = namedtuple("SparseAnalysisResult", [
-    "weight_name",
-    "weight_shape",
-])
+SparseAnalysisResult = namedtuple(
+    "SparseAnalysisResult",
+    [
+        "weight_name",
+        "weight_shape",
+    ],
+)
+
 
 def _search_dense_op_weight(expr):
     """Search name of weight in all ```nn.dense``` operator
@@ -80,14 +84,16 @@ def process_params(expr, params, block_size, sparsity_threshold):
             # remove dense weight
             del params[name]
             memo.weight_name.append(name)
-            memo.weight_shape.append(list(sparse_weight.data.shape) +
-                                     list(sparse_weight.indices.shape) +
-                                     list(sparse_weight.indptr.shape))
+            memo.weight_shape.append(
+                list(sparse_weight.data.shape)
+                + list(sparse_weight.indices.shape)
+                + list(sparse_weight.indptr.shape)
+            )
             params[name + ".data"] = tvm.nd.array(sparse_weight.data)
             params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
             params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)
     ret = SparseAnalysisResult(
         weight_name=tvm.runtime.convert(memo.weight_name),
-        weight_shape=tvm.runtime.convert(memo.weight_shape)
+        weight_shape=tvm.runtime.convert(memo.weight_shape),
     )
     return ret
index 41be47b..8d4a331 100644 (file)
@@ -30,8 +30,8 @@ from .. import function as _function
 from .. import ty as _ty
 from . import _backend
 
-logger = logging.getLogger('compile_engine')
-autotvm_logger = logging.getLogger('autotvm')
+logger = logging.getLogger("compile_engine")
+autotvm_logger = logging.getLogger("autotvm")
 
 
 @tvm._ffi.register_object("relay.LoweredOutput")
@@ -39,8 +39,7 @@ class LoweredOutput(Object):
     """Lowered output"""
 
     def __init__(self, outputs, implement):
-        self.__init_handle_by_constructor__(
-            _backend._make_LoweredOutput, outputs, implement)
+        self.__init_handle_by_constructor__(_backend._make_LoweredOutput, outputs, implement)
 
 
 @tvm._ffi.register_object("relay.CCacheKey")
@@ -57,14 +56,12 @@ class CCacheKey(Object):
     """
 
     def __init__(self, source_func, target):
-        self.__init_handle_by_constructor__(
-            _backend._make_CCacheKey, source_func, target)
+        self.__init_handle_by_constructor__(_backend._make_CCacheKey, source_func, target)
 
 
 @tvm._ffi.register_object("relay.CCacheValue")
 class CCacheValue(Object):
-    """Value in the CompileEngine, including usage statistics.
-    """
+    """Value in the CompileEngine, including usage statistics."""
 
 
 def _get_cache_key(source_func, target):
@@ -217,9 +214,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
         if cfg.is_fallback:
             # Skip fallback config
             continue
-        logger.info(
-            "Implementation %s for %s has cost %.2e", impl.name, op.name, cfg.cost
-        )
+        logger.info("Implementation %s for %s has cost %.2e", impl.name, op.name, cfg.cost)
         if best_cfg is None or best_cfg.cost > cfg.cost:
             best_autotvm_impl = impl
             best_cfg = cfg
@@ -235,9 +230,11 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
         return best_autotvm_impl, outputs[best_autotvm_impl]
     # Use the implementation with highest plevel
     if workloads[best_plevel_impl] is not None:
-        msg = "Cannot find config for target=%s, workload=%s. A fallback configuration "\
-              "is used, which may bring great performance regression." \
-              % (target, workloads[best_plevel_impl])
+        msg = (
+            "Cannot find config for target=%s, workload=%s. A fallback configuration "
+            "is used, which may bring great performance regression."
+            % (target, workloads[best_plevel_impl])
+        )
         if msg not in autotvm.task.DispatchContext.warning_messages:
             autotvm.task.DispatchContext.warning_messages.add(msg)
             autotvm_logger.warning(msg)
@@ -266,8 +263,7 @@ def lower_call(call, inputs, target):
         new_fields = []
         for field in ret_type.fields:
             if isinstance(field, _ty.TensorType):
-                new_fields.append(_ty.TensorType(
-                    get_shape(field.shape), field.dtype))
+                new_fields.append(_ty.TensorType(get_shape(field.shape), field.dtype))
             else:
                 new_fields.append(field)
         ret_type = _ty.TupleType(new_fields)
@@ -285,13 +281,13 @@ def lower_call(call, inputs, target):
             reenable_tracing = True
 
     if not is_dyn:
-        best_impl, outputs = select_implementation(
-            op, call.attrs, inputs, ret_type, target)
+        best_impl, outputs = select_implementation(op, call.attrs, inputs, ret_type, target)
     else:
         # TODO(@icemelon9): Allow tvm to generate multiple kernels for dynamic shapes.
         #   Currently, we just use the implementation with highest plevel
         best_impl, outputs = select_implementation(
-            op, call.attrs, inputs, ret_type, target, use_autotvm=False)
+            op, call.attrs, inputs, ret_type, target, use_autotvm=False
+        )
 
     # re-enable AutoTVM tracing
     if reenable_tracing:
@@ -301,8 +297,7 @@ def lower_call(call, inputs, target):
 
 @tvm._ffi.register_object("relay.CompileEngine")
 class CompileEngine(Object):
-    """CompileEngine to get lowered code.
-    """
+    """CompileEngine to get lowered code."""
 
     def __init__(self):
         raise RuntimeError("Cannot construct a CompileEngine")
@@ -329,6 +324,7 @@ class CompileEngine(Object):
             return _backend._CompileEngineLower(self, key)
         except Exception:
             import traceback
+
             msg = traceback.format_exc()
             msg += "Error during compile func\n"
             msg += "--------------------------\n"
@@ -373,7 +369,7 @@ class CompileEngine(Object):
         """
         res = _backend._CompileEngineListItems(self)
         assert len(res) % 2 == 0
-        return [(res[2*i], res[2*i+1]) for i in range(len(res) // 2)]
+        return [(res[2 * i], res[2 * i + 1]) for i in range(len(res) // 2)]
 
     def dump(self):
         """Return a string representation of engine dump.
index 03170ea..a21a4a8 100644 (file)
@@ -20,6 +20,7 @@ from tvm._ffi.base import string_types
 from tvm._ffi.registry import get_global_func
 from tvm.runtime import ndarray
 
+
 class GraphRuntimeFactoryModule(object):
     """Graph runtime factory module.
     This is a module of graph runtime factory
@@ -74,7 +75,9 @@ class GraphRuntimeFactoryModule(object):
         warnings.warn(
             "legacy graph runtime behaviour of producing json / lib / params will be "
             "removed in the next release ",
-            DeprecationWarning, 2)
+            DeprecationWarning,
+            2,
+        )
         return self
 
     def __next__(self):
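A hedged sketch of how the factory module produced by relay.build is consumed (mod, params, and input_array are assumed to come from an earlier import step):

import tvm
from tvm import relay
from tvm.contrib import graph_runtime

# Indexing the factory by module name ("default" unless mod_name was overridden)
# yields a packed function that instantiates the runtime module on a context.
lib = relay.build(mod, target="llvm", params=params)
gmod = graph_runtime.GraphModule(lib["default"](tvm.cpu()))
gmod.set_input("data", input_array)   # "data" / input_array are placeholders
gmod.run()
out = gmod.get_output(0)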
index 213a6c6..218bc9f 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=no-else-return
+# pylint: disable=no-else-return
 """The Python interface to the Relay reference interpreter."""
 from __future__ import absolute_import
 
@@ -35,15 +35,13 @@ from ..scope_builder import ScopeBuilder
 @tvm._ffi.register_object("relay.ConstructorValue")
 class ConstructorValue(Object):
     def __init__(self, tag, fields, constructor):
-        self.__init_handle_by_constructor__(
-            _make.ConstructorValue, tag, fields, constructor)
+        self.__init_handle_by_constructor__(_make.ConstructorValue, tag, fields, constructor)
 
 
 @tvm._ffi.register_object("relay.RefValue")
 class RefValue(Object):
     def __init__(self, value):
-        self.__init_handle_by_constructor__(
-            _make.RefValue, value)
+        self.__init_handle_by_constructor__(_make.RefValue, value)
 
 
 def _arg_to_ast(mod, arg):
@@ -56,8 +54,7 @@ def _arg_to_ast(mod, arg):
     elif isinstance(arg, RefValue):
         return RefCreate(_arg_to_ast(mod, arg.value))
     elif isinstance(arg, ConstructorValue):
-        return Call(mod.get_constructor(arg.tag),
-                    [_arg_to_ast(mod, field) for field in arg.fields])
+        return Call(mod.get_constructor(arg.tag), [_arg_to_ast(mod, field) for field in arg.fields])
     elif isinstance(arg, np.ndarray):
         return Constant(nd.array(arg))
     elif isinstance(arg, Constant):
@@ -102,8 +99,9 @@ class Executor(object):
             return args
 
         if kwargs and not isinstance(expr, Function):
-            raise Exception("can only supply keyword parameters for a "
-                            "relay.Function, found {0}".format(expr))
+            raise Exception(
+                "can only supply keyword parameters for a " "relay.Function, found {0}".format(expr)
+            )
 
         params = expr.params
         param_names = [p.name_hint for p in params]
@@ -116,14 +114,16 @@ class Executor(object):
                     raise Exception(
                         "duplicate argument supplied in "
                         "both positional args (at position: {0}), "
-                        "and keyword argument (with name: {1})".format(i, name))
+                        "and keyword argument (with name: {1})".format(i, name)
+                    )
             else:
                 cargs.append(kwargs[name])
 
         if len(cargs) != len(params):
             raise Exception(
                 "insufficient arguments, expected "
-                "{0}, provided {1}".format(len(cargs), len(params)))
+                "{0}, provided {1}".format(len(cargs), len(params))
+            )
 
         return tuple(cargs)
 
@@ -197,6 +197,7 @@ class Interpreter(Executor):
     target : tvm.Target
         The target option to build the function.
     """
+
     def __init__(self, mod, ctx, target):
         self.mod = mod
         self.ctx = ctx
@@ -210,15 +211,20 @@ class Interpreter(Executor):
         opt_mod : tvm.IRModule
             The optimized module.
         """
-        seq = tvm.transform.Sequential([transform.SimplifyInference(),
-                                        transform.FuseOps(0),
-                                        transform.ToANormalForm(),
-                                        transform.InferType()])
+        seq = tvm.transform.Sequential(
+            [
+                transform.SimplifyInference(),
+                transform.FuseOps(0),
+                transform.ToANormalForm(),
+                transform.InferType(),
+            ]
+        )
         return seq(self.mod)
 
     def _make_executor(self, expr=None):
         if expr is None or isinstance(expr, GlobalVar):
             assert self.mod is not None
+
         def _interp_wrapper(*args, **kwargs):
             if expr is None:
                 args = self._convert_args(self.mod["main"], args, kwargs)
@@ -247,4 +253,5 @@ class Interpreter(Executor):
             opt_expr = Call(mod["main"], relay_args)
             _intrp = _backend.CreateInterpreter(mod, self.ctx, self.target)
             return _intrp(opt_expr)
+
         return _interp_wrapper
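A small end-to-end sketch of the interpreter path reformatted above:

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(2,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], x + relay.const(1.0)))

# "debug" selects the Interpreter-backed executor.
intrp = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
result = intrp.evaluate()(np.ones((2,), dtype="float32"))
print(result.asnumpy())   # [2. 2.]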
index e1de326..b0a5e98 100644 (file)
@@ -188,18 +188,18 @@ class VMCompiler(object):
             raise ValueError("Target is not set in env or passed as argument.")
         tgts = {}
         if isinstance(target, (str, tvm.target.Target)):
-            dev_type = tvm.tir.IntImm(
-                "int32", tvm.nd.context(str(target)).device_type)
+            dev_type = tvm.tir.IntImm("int32", tvm.nd.context(str(target)).device_type)
             tgts[dev_type] = tvm.target.Target(target)
         elif isinstance(target, dict):
             for dev, tgt in target.items():
-                dev_type = tvm.tir.IntImm(
-                    "int32", tvm.nd.context(dev).device_type)
+                dev_type = tvm.tir.IntImm("int32", tvm.nd.context(dev).device_type)
                 tgts[dev_type] = tvm.target.Target(tgt)
         else:
-            raise TypeError("target is expected to be str, tvm.target.Target, " +
-                            "or dict of str to str/tvm.target.Target, but received " +
-                            "{}".format(type(target)))
+            raise TypeError(
+                "target is expected to be str, tvm.target.Target, "
+                + "or dict of str to str/tvm.target.Target, but received "
+                + "{}".format(type(target))
+            )
         return tgts
 
     def _update_target_host(self, target, target_host):
index 2c35681..323a8f6 100644 (file)
@@ -34,7 +34,8 @@ def _std_path():
 @tvm._ffi.register_object("relay.Id")
 class Id(Object):
     """Unique identifier(name) used in Var.
-       Guaranteed to be stable across all passes.
+    Guaranteed to be stable across all passes.
     """
+
     def __init__(self):
         raise RuntimeError("Cannot directly construct Id")
index f77988e..0b68c8e 100644 (file)
@@ -43,17 +43,18 @@ def _update_target(target):
 
     tgts = {}
     if isinstance(target, (str, Target)):
-        dev_type = tvm_expr.IntImm(
-            "int32", _nd.context(str(target)).device_type)
+        dev_type = tvm_expr.IntImm("int32", _nd.context(str(target)).device_type)
         tgts[dev_type] = Target(target)
     elif isinstance(target, dict):
         for dev, tgt in target.items():
             dev_type = tvm_expr.IntImm("int32", _nd.context(dev).device_type)
             tgts[dev_type] = Target(tgt)
     else:
-        raise TypeError("target is expected to be str or " +
-                        "tvm.target.Target, but received " +
-                        "{}".format(type(target)))
+        raise TypeError(
+            "target is expected to be str or "
+            + "tvm.target.Target, but received "
+            + "{}".format(type(target))
+        )
     return tgts
 
 
@@ -185,7 +186,7 @@ class BuildModule(object):
         return ret
 
 
-def build(mod, target=None, target_host=None, params=None, mod_name='default'):
+def build(mod, target=None, target_host=None, params=None, mod_name="default"):
     """Helper function that builds a Relay function to run on TVM graph
     runtime.
 
@@ -236,15 +237,15 @@ def build(mod, target=None, target_host=None, params=None, mod_name='default'):
         warnings.warn(
             "Please use input parameter mod (tvm.IRModule) "
             "instead of deprecated parameter mod (tvm.relay.function.Function)",
-            DeprecationWarning)
+            DeprecationWarning,
+        )
 
     target = _update_target(target)
 
     if isinstance(target_host, (str, Target)):
         target_host = Target(target_host)
     elif target_host:
-        raise ValueError("target host must be the type of str, " +
-                         "tvm.target.Target, or None")
+        raise ValueError("target host must be the type of str, " + "tvm.target.Target, or None")
 
     # If current dispatch context is fallback context (the default root context),
     # then load pre-tuned parameters from TopHub
@@ -255,10 +256,8 @@ def build(mod, target=None, target_host=None, params=None, mod_name='default'):
 
     with tophub_context:
         bld_mod = BuildModule()
-        graph_json, mod, params = bld_mod.build(
-            mod, target, target_host, params)
-        mod = _graph_runtime_factory.GraphRuntimeFactoryModule(
-            graph_json, mod, mod_name, params)
+        graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
+        mod = _graph_runtime_factory.GraphRuntimeFactoryModule(graph_json, mod, mod_name, params)
         return mod
 
 
@@ -297,7 +296,8 @@ def optimize(mod, target=None, params=None):
         warnings.warn(
             "Please use input parameter mod (tvm.IRModule) "
             "instead of deprecated parameter func (tvm.relay.function.Function)",
-            DeprecationWarning)
+            DeprecationWarning,
+        )
 
     target = _update_target(target)
 
@@ -365,12 +365,10 @@ class GraphExecutor(_interpreter.Executor):
             self.mod["main"] = expr
         ret_type = self.mod["main"].checked_type.ret_type
         if _ty.is_dynamic(ret_type):
-            raise ValueError("Graph Runtime only supports static graphs, got output type",
-                             ret_type)
-        num_outputs = len(ret_type.fields) if isinstance(
-            ret_type, _ty.TupleType) else 1
+            raise ValueError("Graph Runtime only supports static graphs, got output type", ret_type)
+        num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
         mod = build(self.mod, target=self.target)
-        gmodule = _graph_rt.GraphModule(mod['default'](self.ctx))
+        gmodule = _graph_rt.GraphModule(mod["default"](self.ctx))
 
         def _graph_wrapper(*args, **kwargs):
             args = self._convert_args(self.mod["main"], args, kwargs)
@@ -390,10 +388,7 @@ class GraphExecutor(_interpreter.Executor):
         return _graph_wrapper
 
 
-def create_executor(kind="debug",
-                    mod=None,
-                    ctx=None,
-                    target="llvm"):
+def create_executor(kind="debug", mod=None, ctx=None, target="llvm"):
     """Factory function to create an executor.
 
     Parameters
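As a reference for the build_module hunks above, here is a minimal sketch of the relay.build and create_executor entry points; the toy module, input shape, and targets are illustrative only.

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

x = relay.var("x", shape=(1, 3), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.add(x, relay.const(1.0))))

# graph-runtime build path: returns a GraphRuntimeFactoryModule keyed by mod_name
factory = relay.build(mod, target="llvm", params=None, mod_name="default")
gmod = graph_runtime.GraphModule(factory["default"](tvm.cpu(0)))
gmod.set_input("x", np.ones((1, 3), dtype="float32"))
gmod.run()
out = gmod.get_output(0)

# executor path, matching create_executor(kind, mod, ctx, target) above
exe = relay.create_executor(kind="graph", mod=mod, ctx=tvm.cpu(0), target="llvm")
out = exe.evaluate()(np.ones((1, 3), dtype="float32"))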
index ab0caa2..8feb452 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=unused-argument, not-context-manager
+# pylint: disable=unused-argument, not-context-manager
 """Optimizations involves changing of paramters"""
 
 from . import bsr_dense
index cc3e5de..5f5875e 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=unused-argument, not-context-manager
+# pylint: disable=unused-argument, not-context-manager
 """Automatic convert model from dense to block sparse"""
 
 from tvm import relay
@@ -22,6 +22,7 @@ from tvm.relay.analysis.sparse_dense import process_params
 
 from .utils import _run_opt_pass
 
+
 def convert(func, params, blocksize, sparsity_threshold):
     """Convert a dense func and according parameters to block sparse
 
@@ -48,10 +49,6 @@ def convert(func, params, blocksize, sparsity_threshold):
     """
     weight_info = process_params(func, params, blocksize, sparsity_threshold)
     new_func = _run_opt_pass(
-        func,
-        relay.transform.DenseToSparse(
-            weight_info.weight_name,
-            weight_info.weight_shape
-        )
+        func, relay.transform.DenseToSparse(weight_info.weight_name, weight_info.weight_shape)
     )
     return new_func, params
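A short sketch of how this convert helper would be invoked; the package path (tvm.relay.data_dep_optimization), the block size, and the sparsity threshold are assumptions for illustration, and func/params stand in for an existing Relay function and its parameter dict.

# func: a relay.Function containing nn.dense; params: its weight dict (both placeholders)
from tvm.relay.data_dep_optimization import bsr_dense  # assumed package path

sparse_func, params = bsr_dense.convert(func, params, blocksize=(16, 1), sparsity_threshold=0.75)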
index 345c579..2892c6c 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=unused-argument, not-context-manager
+# pylint: disable=unused-argument, not-context-manager
 """Automatic optimize fc tranpose"""
 import numpy as np
 
@@ -55,6 +55,6 @@ def convert(func, params):
         func,
         relay.transform.SimplifyFCTranspose(
             weight_info,
-        )
+        ),
     )
     return new_func, params
index 6b46f81..2b58fdc 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=unused-argument, not-context-manager
+# pylint: disable=unused-argument, not-context-manager
 """Utils functions for optimizations"""
 
 import tvm
 
+
 def _run_opt_pass(expr, opt_pass):
     """Helper function to run pass
 
index 03bdd19..19ad595 100644 (file)
@@ -39,14 +39,12 @@ def register_df_node(type_key=None):
         The type key of the node.
     """
     if not isinstance(type_key, str):
-        return tvm._ffi.register_object(
-            "relay.dataflow_pattern." + type_key.__name__)(type_key)
+        return tvm._ffi.register_object("relay.dataflow_pattern." + type_key.__name__)(type_key)
     return tvm._ffi.register_object(type_key)
 
 
 class DFPattern(Node):
-    """Base class of all Patterns.
-    """
+    """Base class of all Patterns."""
 
     def __call__(self, *args):
         return CallPattern(self, list(args))
@@ -146,10 +144,12 @@ class DFPattern(Node):
         """
         return match(self, expr)
 
-    def partition(self,
-                  expr: Expr,
-                  attrs: Optional[Dict[str, Object]] = None,
-                  check: Callable[[Expr], bool] = lambda x: True) -> Expr:
+    def partition(
+        self,
+        expr: Expr,
+        attrs: Optional[Dict[str, Object]] = None,
+        check: Callable[[Expr], bool] = lambda x: True,
+    ) -> Expr:
         """
         Partition the expression into functions defined by this pattern
 
@@ -486,8 +486,8 @@ class VarPattern(DFPattern):
 
 @register_df_node
 class ConstantPattern(DFPattern):
-    """A pattern matching a Relay Constant.
-    """
+    """A pattern matching a Relay Constant."""
+
     def __init__(self):
         self.__init_handle_by_constructor__(ffi.ConstantPattern)
 
@@ -512,11 +512,13 @@ class CallPattern(DFPattern):
         used in advanced use cases of template functions.
     """
 
-    def __init__(self,
-                 op: "DFPattern",
-                 args: List["DFPattern"],
-                 attrs: Optional[tvm.ir.attrs.Attrs] = None,
-                 type_args: Optional[List[tvm.ir.type.Type]] = None):
+    def __init__(
+        self,
+        op: "DFPattern",
+        args: List["DFPattern"],
+        attrs: Optional[tvm.ir.attrs.Attrs] = None,
+        type_args: Optional[List[tvm.ir.type.Type]] = None,
+    ):
         if not type_args:
             type_args = []
         self.__init_handle_by_constructor__(ffi.CallPattern, op, args, attrs, type_args)
@@ -583,8 +585,7 @@ class AltPattern(DFPattern):
 
 @register_df_node
 class WildcardPattern(DFPattern):
-    """A pattern which matches anything.
-    """
+    """A pattern which matches anything."""
 
     def __init__(self):
         self.__init_handle_by_constructor__(ffi.WildcardPattern)
@@ -694,6 +695,7 @@ class DFPatternCallback:
     require_type: bool
         Whether InferType is required to be run before the callback.
     """
+
     def __init__(self, require_type=False):
         self.pattern = None
         self.require_type = require_type
@@ -734,8 +736,10 @@ class DFPatternCallback:
         """
         raise "Unimplemented"
 
+
 class _DFPatternCallback(Object):
     """C++ implemenation"""
+
     def __init__(self, pattern, callback, require_type):
         self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback, require_type)
 
@@ -769,10 +773,12 @@ def rewrite(callbacks, expr: Expr, mod: Optional[_ir.IRModule] = None) -> Expr:
     return ffi.rewrite(tmp, expr, mod)
 
 
-def partition(pattern: "DFPattern",
-              expr: Expr,
-              attrs: Optional[Dict[str, Object]] = None,
-              check: Callable[[Expr], bool] = lambda x: True) -> Expr:
+def partition(
+    pattern: "DFPattern",
+    expr: Expr,
+    attrs: Optional[Dict[str, Object]] = None,
+    check: Callable[[Expr], bool] = lambda x: True,
+) -> Expr:
     """
     Partition the expression into a series of functions that match the pattern
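For context on the pattern API reformatted above, a small sketch of match and partition; is_op and wildcard are the usual constructors from the same relay.dataflow_pattern module (they are not shown in these hunks).

from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard

pat = is_op("nn.relu")(wildcard())   # a CallPattern: relu applied to anything
x = relay.var("x", shape=(1, 8))
expr = relay.nn.relu(relay.abs(x))

assert pat.match(expr)               # DFPattern.match, defined above
partitioned = pat.partition(expr)    # wraps the matched region in a fresh function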
 
index 838eab5..87c8755 100644 (file)
@@ -21,13 +21,17 @@ import tvm._ffi
 # pylint: disable=unused-argument, import-outside-toplevel
 def _debugger_init(expr, stack):
     import pdb
+
     pdb.set_trace()
 
+
 @tvm._ffi.register_func("relay.debug")
 def _debug(*args):
     import pdb
+
     pdb.set_trace()
 
+
 # pylint: disable=unused-argument
 @tvm._ffi.register_func("relay.debug_interp")
 def _debug_interp(*args):
index 106edc2..6d30464 100644 (file)
@@ -35,8 +35,10 @@ Expr = RelayExpr
 # will be registered afterwards
 _op_make = None
 
+
 class ExprWithOp(RelayExpr):
     """Basetype of all relay expressions that defines op overloading."""
+
     def astype(self, dtype):
         """Cast the content type of the current data to dtype.
 
@@ -160,6 +162,7 @@ class ExprWithOp(RelayExpr):
         """
         return Call(self, args)
 
+
 @tvm._ffi.register_object("relay.Constant")
 class Constant(ExprWithOp):
     """A constant expression in Relay.
@@ -169,6 +172,7 @@ class Constant(ExprWithOp):
     data : tvm.nd.NDArray
         The data content of the constant expression.
     """
+
     def __init__(self, data):
         self.__init_handle_by_constructor__(_ffi_api.Constant, data)
 
@@ -182,6 +186,7 @@ class Tuple(ExprWithOp):
     fields : List[tvm.relay.Expr]
         The fields in the tuple.
     """
+
     def __init__(self, fields):
         self.__init_handle_by_constructor__(_ffi_api.Tuple, fields)
 
@@ -214,9 +219,9 @@ class Var(ExprWithOp):
     type_annotation: tvm.relay.Type, optional
         The type annotation on the variable.
     """
+
     def __init__(self, name_hint, type_annotation=None):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Var, name_hint, type_annotation)
+        self.__init_handle_by_constructor__(_ffi_api.Var, name_hint, type_annotation)
 
     @property
     def name_hint(self):
@@ -247,11 +252,11 @@ class Call(ExprWithOp):
         The additional type arguments; these are only
         used in advanced use cases of template functions.
     """
+
     def __init__(self, op, args, attrs=None, type_args=None):
         if not type_args:
             type_args = []
-        self.__init_handle_by_constructor__(
-            _ffi_api.Call, op, args, attrs, type_args)
+        self.__init_handle_by_constructor__(_ffi_api.Call, op, args, attrs, type_args)
 
 
 @tvm._ffi.register_object("relay.Let")
@@ -269,9 +274,9 @@ class Let(ExprWithOp):
     body: tvm.relay.Expr
         The body of the let binding.
     """
+
     def __init__(self, variable, value, body):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Let, variable, value, body)
+        self.__init_handle_by_constructor__(_ffi_api.Let, variable, value, body)
 
 
 @tvm._ffi.register_object("relay.If")
@@ -289,9 +294,9 @@ class If(ExprWithOp):
     false_branch: tvm.relay.Expr
         The expression evaluated when condition is false.
     """
+
     def __init__(self, cond, true_branch, false_branch):
-        self.__init_handle_by_constructor__(
-            _ffi_api.If, cond, true_branch, false_branch)
+        self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch, false_branch)
 
 
 @tvm._ffi.register_object("relay.TupleGetItem")
@@ -306,9 +311,9 @@ class TupleGetItem(ExprWithOp):
     index: int
         The index.
     """
+
     def __init__(self, tuple_value, index):
-        self.__init_handle_by_constructor__(
-            _ffi_api.TupleGetItem, tuple_value, index)
+        self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index)
 
 
 @tvm._ffi.register_object("relay.RefCreate")
@@ -319,6 +324,7 @@ class RefCreate(ExprWithOp):
     value: tvm.relay.Expr
        The initial value.
     """
+
     def __init__(self, value):
         self.__init_handle_by_constructor__(_ffi_api.RefCreate, value)
 
@@ -331,6 +337,7 @@ class RefRead(ExprWithOp):
     ref: tvm.relay.Expr
          The reference.
     """
+
     def __init__(self, ref):
         self.__init_handle_by_constructor__(_ffi_api.RefRead, ref)
 
@@ -347,6 +354,7 @@ class RefWrite(ExprWithOp):
     value: tvm.relay.Expr
         The new value.
     """
+
     def __init__(self, ref, value):
         self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value)
 
@@ -358,6 +366,7 @@ class TempExpr(ExprWithOp):
     useful to define intermediate results in the
     rewriting pass such as layout or type transformation.
     """
+
     def realize(self):
         """Convert the expression to a normal(non-temp) Expr.
 
@@ -383,6 +392,7 @@ class TupleWrapper(object):
     size: int
         The size of the tuple.
     """
+
     def __init__(self, tuple_value, size):
         self.tuple_value = tuple_value
         self.size = size
@@ -411,17 +421,13 @@ class TupleWrapper(object):
         return self.size
 
     def __repr__(self):
-        return ("TupleWrapper(" + self.tuple_value.__repr__() +
-                ", " + str(self.size) + ")")
+        return "TupleWrapper(" + self.tuple_value.__repr__() + ", " + str(self.size) + ")"
 
     def astype(self, _):
         raise TypeError("astype cannot be used on tuple")
 
 
-def var(name_hint,
-        type_annotation=None,
-        shape=None,
-        dtype="float32"):
+def var(name_hint, type_annotation=None, shape=None, dtype="float32"):
     """Create a new tvm.relay.Var.
 
     This is a simple wrapper function that allows specifying
@@ -492,10 +498,9 @@ def const(value, dtype=None):
 
     if not dtype:
         # when dtype is None: int maps to "int32", float maps to "float32"
-        map_dtype = {
-            _np.dtype('int64'): _np.int32,
-            _np.dtype('float64'): _np.float32
-            }.get(value.dtype, None)
+        map_dtype = {_np.dtype("int64"): _np.int32, _np.dtype("float64"): _np.float32}.get(
+            value.dtype, None
+        )
         if map_dtype:
             value = value.astype(map_dtype)
 
index fd9b253..0a37e4d 100644 (file)
@@ -24,6 +24,7 @@ from .expr import If, Tuple, TupleGetItem, Constant
 from .expr import RefCreate, RefRead, RefWrite
 from .adt import Constructor, Match, Clause
 
+
 class ExprFunctor:
     """
     An abstract visitor defined over Expr.
@@ -31,6 +32,7 @@ class ExprFunctor:
     Defines the default dispatch over expressions, and
     implements memoization.
     """
+
     def __init__(self):
         self.memo_map = {}
 
@@ -132,6 +134,7 @@ class ExprVisitor(ExprFunctor):
 
     The default behavior recursively traverses the AST.
     """
+
     def visit_tuple(self, tup):
         for x in tup.fields:
             self.visit(x)
@@ -195,15 +198,11 @@ class ExprMutator(ExprFunctor):
     The default behavior recursively traverses the AST
     and reconstructs the AST.
     """
+
     def visit_function(self, fn):
         new_params = [self.visit(x) for x in fn.params]
         new_body = self.visit(fn.body)
-        return Function(
-            list(new_params),
-            new_body,
-            fn.ret_type,
-            fn.type_params,
-            fn.attrs)
+        return Function(list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs)
 
     def visit_let(self, let):
         new_var = self.visit(let.var)
@@ -223,10 +222,7 @@ class ExprMutator(ExprFunctor):
         return global_var
 
     def visit_if(self, ite):
-        return If(
-            self.visit(ite.cond),
-            self.visit(ite.true_branch),
-            self.visit(ite.false_branch))
+        return If(self.visit(ite.cond), self.visit(ite.true_branch), self.visit(ite.false_branch))
 
     def visit_tuple(self, tup):
         return Tuple([self.visit(field) for field in tup.fields])
@@ -253,7 +249,8 @@ class ExprMutator(ExprFunctor):
         return Match(
             self.visit(m.data),
             [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses],
-            complete=m.complete)
+            complete=m.complete,
+        )
 
     def visit_ref_create(self, r):
         return RefCreate(self.visit(r.value))
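To illustrate the ExprMutator pattern reformatted above, here is a small hypothetical pass that rewrites add calls into multiply calls; the class name and the op choice are purely illustrative.

import tvm
from tvm import relay
from tvm.relay.expr_functor import ExprMutator

class AddToMul(ExprMutator):
    def visit_call(self, call):
        new_args = [self.visit(a) for a in call.args]
        if isinstance(call.op, tvm.ir.Op) and call.op.name == "add":
            return relay.multiply(*new_args)
        return relay.Call(call.op, new_args, call.attrs, call.type_args)

x = relay.var("x", shape=(2,))
rewritten = AddToMul().visit(relay.add(x, x))   # add(x, x) becomes multiply(x, x)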
index b7bcbde..caf4f1a 100644 (file)
@@ -29,11 +29,12 @@ from ... import nd as _nd
 from .common import ExprTable
 from .common import infer_shape as _infer_shape
 
-__all__ = ['from_caffe']
+__all__ = ["from_caffe"]
 
 
 class OperatorConverter(object):
     """ Operator Converted for converting Caffe ops to Relay ops """
+
     def __init__(self, init_layer_dict, predict_layer, exp_tab):
         self.init_layer_dict = init_layer_dict
         self.predict_layer = predict_layer
@@ -42,26 +43,26 @@ class OperatorConverter(object):
         self.changed_layers = None
 
         self.convert_map = {
-            'BatchNorm': self.convert_batch_norm,
-            'Concat': self.convert_concat,
-            'Convolution': self.convert_conv,
-            'Crop': self.convert_crop,
-            'Deconvolution': self.convert_deconv,
-            'Dropout': self.convert_dropout,
-            'Eltwise': self.convert_eltwise,
-            'Flatten': self.convert_flatten,
-            'InnerProduct': self.convert_innerproduct,
-            'Input': None,
-            'LRN': self.convert_lrn,
-            'Pooling': self.convert_pooling,
-            'PReLU': self.convert_prelu,
-            'ReLU': self.convert_relu,
-            'Reshape': self.convert_reshape,
-            'Scale': self.convert_scale,
-            'Sigmoid': self.convert_sigmoid,
-            'Slice': self.convert_slice,
-            'Softmax': self.convert_softmax,
-            'TanH': self.convert_tanh,
+            "BatchNorm": self.convert_batch_norm,
+            "Concat": self.convert_concat,
+            "Convolution": self.convert_conv,
+            "Crop": self.convert_crop,
+            "Deconvolution": self.convert_deconv,
+            "Dropout": self.convert_dropout,
+            "Eltwise": self.convert_eltwise,
+            "Flatten": self.convert_flatten,
+            "InnerProduct": self.convert_innerproduct,
+            "Input": None,
+            "LRN": self.convert_lrn,
+            "Pooling": self.convert_pooling,
+            "PReLU": self.convert_prelu,
+            "ReLU": self.convert_relu,
+            "Reshape": self.convert_reshape,
+            "Scale": self.convert_scale,
+            "Sigmoid": self.convert_sigmoid,
+            "Slice": self.convert_slice,
+            "Softmax": self.convert_softmax,
+            "TanH": self.convert_tanh,
         }
 
     def convert_flatten(self, op):
@@ -89,29 +90,27 @@ class OperatorConverter(object):
         assert lhs_shape == rhs_shape, "input tensors shape should be equal"
 
         eltwise_params = op.eltwise_param
-        eltwise_type_dict = ['PROD', 'SUM', 'MAX']
+        eltwise_type_dict = ["PROD", "SUM", "MAX"]
         eltwise_type = eltwise_params.operation
         coeff = list(eltwise_params.coeff)
 
-        if eltwise_type_dict[eltwise_type] == 'PROD':
+        if eltwise_type_dict[eltwise_type] == "PROD":
             out = _op.multiply(lhs_expr, rhs_expr)
-        elif eltwise_type_dict[eltwise_type] == 'SUM':
+        elif eltwise_type_dict[eltwise_type] == "SUM":
             if coeff:
-                left_coeff_expr = self.exp_tab.new_const(
-                    np.asarray(coeff[0], np.float32))
-                right_coeff_expr = self.exp_tab.new_const(
-                    np.asarray(coeff[1], np.float32))
+                left_coeff_expr = self.exp_tab.new_const(np.asarray(coeff[0], np.float32))
+                right_coeff_expr = self.exp_tab.new_const(np.asarray(coeff[1], np.float32))
                 lhs_expr_scale = _op.multiply(lhs_expr, left_coeff_expr)
                 rhs_expr_scale = _op.multiply(rhs_expr, right_coeff_expr)
                 out = _op.add(lhs_expr_scale, rhs_expr_scale)
             else:
                 out = _op.add(lhs_expr, rhs_expr)
-        elif eltwise_type_dict[eltwise_type] == 'MAX':
+        elif eltwise_type_dict[eltwise_type] == "MAX":
             out = _op.maximum(lhs_expr, rhs_expr)
         else:
             raise tvm.error.OpNotImplemented(
-                "eltwise_type {} is not supported for frontend Caffe.".format(
-                    eltwise_type))
+                "eltwise_type {} is not supported for frontend Caffe.".format(eltwise_type)
+            )
 
         return out
 
@@ -124,41 +123,40 @@ class OperatorConverter(object):
         params = dict()
         # parse kernel size
         if conv_params.kernel_h > 0 or conv_params.kernel_w > 0:
-            params['kernel_size'] = (conv_params.kernel_h,
-                                     conv_params.kernel_w)
+            params["kernel_size"] = (conv_params.kernel_h, conv_params.kernel_w)
         else:
             ksize_h = nonzone(conv_params.kernel_size, 0, 1)
             ksize_w = nonzone(conv_params.kernel_size, 1, ksize_h)
-            params['kernel_size'] = (ksize_h, ksize_w)
+            params["kernel_size"] = (ksize_h, ksize_w)
 
         # parse padding size
         if conv_params.pad_h > 0 or conv_params.pad_w > 0:
-            params['padding'] = (conv_params.pad_h, conv_params.pad_w)
+            params["padding"] = (conv_params.pad_h, conv_params.pad_w)
         else:
             pad_h = nonzone(conv_params.pad, 0, 0)
             pad_w = nonzone(conv_params.pad, 1, pad_h)
-            params['padding'] = (pad_h, pad_w)
+            params["padding"] = (pad_h, pad_w)
 
         # parse stride size
         if conv_params.stride_h > 0 or conv_params.stride_w > 0:
-            params['strides'] = (conv_params.stride_h, conv_params.stride_w)
+            params["strides"] = (conv_params.stride_h, conv_params.stride_w)
         else:
             stride_h = nonzone(conv_params.stride, 0, 1)
             stride_w = nonzone(conv_params.stride, 1, stride_h)
-            params['strides'] = (stride_h, stride_w)
+            params["strides"] = (stride_h, stride_w)
 
         # parse dilation size
-        if hasattr(conv_params, 'dilation') and len(conv_params.dilation) > 0:
-            dilation = ' '.join(str(d) for d in conv_params.dilation)
-            dilation = tuple(map(int, dilation.split(' ')))
-            params['dilation'] = dilation
+        if hasattr(conv_params, "dilation") and len(conv_params.dilation) > 0:
+            dilation = " ".join(str(d) for d in conv_params.dilation)
+            dilation = tuple(map(int, dilation.split(" ")))
+            params["dilation"] = dilation
             if len(dilation) == 1:
-                params['dilation'] = (dilation[0], dilation[0])
+                params["dilation"] = (dilation[0], dilation[0])
 
-        params['kernel_layout'] = 'OIHW'
-        params['data_layout'] = 'NCHW'
-        params['groups'] = conv_params.group
-        params['channels'] = conv_params.num_output
+        params["kernel_layout"] = "OIHW"
+        params["data_layout"] = "NCHW"
+        params["groups"] = conv_params.group
+        params["channels"] = conv_params.num_output
         return params
 
     def convert_batch_norm(self, op):
@@ -169,17 +167,13 @@ class OperatorConverter(object):
 
         if op.name in self.new_bn:
             mean, var, eps, gamma, beta = self.new_bn[op.name]
-            mean_expr = self.exp_tab.new_const(mean, dtype='float32')
-            var_expr = self.exp_tab.new_const(var, dtype='float32')
-            gamma_expr = self.exp_tab.new_const(gamma, dtype='float32')
-            beta_expr = self.exp_tab.new_const(beta, dtype='float32')
-            out = _op.nn.batch_norm(in_expr,
-                                    gamma_expr,
-                                    beta_expr,
-                                    mean_expr,
-                                    var_expr,
-                                    epsilon=eps,
-                                    scale=True)
+            mean_expr = self.exp_tab.new_const(mean, dtype="float32")
+            var_expr = self.exp_tab.new_const(var, dtype="float32")
+            gamma_expr = self.exp_tab.new_const(gamma, dtype="float32")
+            beta_expr = self.exp_tab.new_const(beta, dtype="float32")
+            out = _op.nn.batch_norm(
+                in_expr, gamma_expr, beta_expr, mean_expr, var_expr, epsilon=eps, scale=True
+            )
 
         else:
             weight_bias_blobs = self.init_layer_dict[op.name].blobs
@@ -188,11 +182,11 @@ class OperatorConverter(object):
             if len(weight_bias_blobs) == 2:
                 mean = np.repeat(mean, h * w).reshape((c, h, w))
                 mean = np.expand_dims(mean, 0).repeat(n, axis=0)
-                mean_expr = self.exp_tab.new_const(mean, dtype='float32')
+                mean_expr = self.exp_tab.new_const(mean, dtype="float32")
 
                 var = np.repeat(var, h * w).reshape((c, h, w))
                 var = np.expand_dims(var, 0).repeat(n, axis=0)
-                var_expr = self.exp_tab.new_const(var, dtype='float32')
+                var_expr = self.exp_tab.new_const(var, dtype="float32")
 
                 tmp_out = _op.multiply(in_expr, mean_expr)
                 out = _op.add(tmp_out, var_expr)
@@ -202,25 +196,21 @@ class OperatorConverter(object):
                 scale = np.asarray(weight_bias_blobs[2].data, np.float32)
                 if scale:
                     scale = 1 / scale
-            mean_expr = self.exp_tab.new_const(mean * scale, dtype='float32')
-            var_expr = self.exp_tab.new_const(var * scale, dtype='float32')
+            mean_expr = self.exp_tab.new_const(mean * scale, dtype="float32")
+            var_expr = self.exp_tab.new_const(var * scale, dtype="float32")
 
-            #caffe bn layer not support scale
-            gamma_expr = self.exp_tab.new_const(np.ones(mean.shape,
-                                                        dtype=np.float32),
-                                                dtype='float32')
-            beta_expr = self.exp_tab.new_const(np.zeros(mean.shape,
-                                                        dtype=np.float32),
-                                               dtype='float32')
+            # caffe bn layer does not support scale
+            gamma_expr = self.exp_tab.new_const(
+                np.ones(mean.shape, dtype=np.float32), dtype="float32"
+            )
+            beta_expr = self.exp_tab.new_const(
+                np.zeros(mean.shape, dtype=np.float32), dtype="float32"
+            )
 
             bn_params = op.batch_norm_param.eps
-            out = _op.nn.batch_norm(in_expr,
-                                    gamma_expr,
-                                    beta_expr,
-                                    mean_expr,
-                                    var_expr,
-                                    epsilon=bn_params,
-                                    scale=False)
+            out = _op.nn.batch_norm(
+                in_expr, gamma_expr, beta_expr, mean_expr, var_expr, epsilon=bn_params, scale=False
+            )
 
         return out[0]
 
@@ -231,18 +221,18 @@ class OperatorConverter(object):
         weight_bias_blobs = self.init_layer_dict[op.name].blobs
 
         params = dict()
-        params['bias'] = op.scale_param.bias_term
-        params['axis'] = op.scale_param.axis
+        params["bias"] = op.scale_param.bias_term
+        params["axis"] = op.scale_param.axis
 
         gamma = np.asarray(weight_bias_blobs[0].data, np.float32)
-        gamma_expr = self.exp_tab.new_const(gamma, dtype='float32')
-        if params['bias']:
+        gamma_expr = self.exp_tab.new_const(gamma, dtype="float32")
+        if params["bias"]:
             beta = np.asarray(weight_bias_blobs[1].data, np.float32)
-            beta_expr = self.exp_tab.new_const(beta, dtype='float32')
+            beta_expr = self.exp_tab.new_const(beta, dtype="float32")
         else:
-            beta_expr = self.exp_tab.new_const(np.zeros(gamma.shape,
-                                                        dtype=np.float32),
-                                               dtype='float32')
+            beta_expr = self.exp_tab.new_const(
+                np.zeros(gamma.shape, dtype=np.float32), dtype="float32"
+            )
 
         _, c, _, _ = _infer_shape(in_expr)
         gamma_expr = _op.reshape(gamma_expr, newshape=(1, c, 1, 1))
@@ -255,12 +245,11 @@ class OperatorConverter(object):
     def convert_concat(self, op):
         """ Convert Concat layer """
         inputs = op.bottom
-        in_expr = (self.exp_tab.get_expr(inputs[i])
-                   for i in range(len(inputs)))
+        in_expr = (self.exp_tab.get_expr(inputs[i]) for i in range(len(inputs)))
 
         c_params = dict()
-        c_params['axis'] = op.concat_param.axis
-        out = _op.concatenate(in_expr, axis=c_params['axis'])
+        c_params["axis"] = op.concat_param.axis
+        out = _op.concatenate(in_expr, axis=c_params["axis"])
 
         return out
 
@@ -313,7 +302,7 @@ class OperatorConverter(object):
         in_expr = self.exp_tab.get_expr(input_name)
 
         softmax_param = op.softmax_param
-        parmas = {'axis': softmax_param.axis}
+        parmas = {"axis": softmax_param.axis}
 
         out = _op.nn.softmax(in_expr, **parmas)
 
@@ -333,20 +322,19 @@ class OperatorConverter(object):
         else:
             weight = weight_bias_blobs[0]
         if weight:
-            kh, kw = params['kernel_size']
+            kh, kw = params["kernel_size"]
             weight_shape = [conv_params.num_output, -1, kh, kw]
             weight_value = np.asarray(weight.data, np.float32)
             weight_value = np.reshape(weight_value, weight_shape)
         else:
-            raise Exception('No weight value of layer {} in caffemodel'.format(
-                op.name))
+            raise Exception("No weight value of layer {} in caffemodel".format(op.name))
 
-        weight_expr = self.exp_tab.new_const(weight_value, dtype='float32')
+        weight_expr = self.exp_tab.new_const(weight_value, dtype="float32")
         in_expr = self.exp_tab.get_expr(inputs[0])
         out = _op.nn.conv2d(data=in_expr, weight=weight_expr, **params)
         if bias:
             bias_value = np.asarray(bias.data, np.float32)
-            bias_expr = self.exp_tab.new_const(bias_value, dtype='float32')
+            bias_expr = self.exp_tab.new_const(bias_value, dtype="float32")
             out = _op.nn.bias_add(out, bias_expr)
         return out
 
@@ -356,37 +344,36 @@ class OperatorConverter(object):
         input_name = inputs[0]
 
         pool_params = op.pooling_param
-        pool_type_dict = ['MAX', 'AVE', 'STOCHASTIC']
+        pool_type_dict = ["MAX", "AVE", "STOCHASTIC"]
 
         params = dict()
         # parse pool type: 0: MAX, 1: AVE, 2: STOCHASTIC
         pool_type = pool_params.pool
         # parse kernel size
         if pool_params.kernel_h > 0 or pool_params.kernel_w > 0:
-            params['pool_size'] = (pool_params.kernel_h, pool_params.kernel_w)
+            params["pool_size"] = (pool_params.kernel_h, pool_params.kernel_w)
         else:
-            params['pool_size'] = (pool_params.kernel_size,
-                                   pool_params.kernel_size)
+            params["pool_size"] = (pool_params.kernel_size, pool_params.kernel_size)
 
         # parse padding size
         if pool_params.pad_h > 0 or pool_params.pad_w > 0:
-            params['padding'] = (pool_params.pad_h, pool_params.pad_w)
+            params["padding"] = (pool_params.pad_h, pool_params.pad_w)
         else:
-            params['padding'] = (pool_params.pad, pool_params.pad)
+            params["padding"] = (pool_params.pad, pool_params.pad)
 
         # parse stride size
         if pool_params.stride_h > 0 or pool_params.stride_w > 0:
-            params['strides'] = (pool_params.stride_h, pool_params.stride_w)
+            params["strides"] = (pool_params.stride_h, pool_params.stride_w)
         else:
-            params['strides'] = (pool_params.stride, pool_params.stride)
+            params["strides"] = (pool_params.stride, pool_params.stride)
 
-        params['ceil_mode'] = True
-        if hasattr(pool_params, 'ceil_mode'):
-            params['ceil_mode'] = pool_params.ceil_mode
+        params["ceil_mode"] = True
+        if hasattr(pool_params, "ceil_mode"):
+            params["ceil_mode"] = pool_params.ceil_mode
 
         in_expr = self.exp_tab.get_expr(input_name)
 
-        if pool_type_dict[pool_type] == 'MAX':
+        if pool_type_dict[pool_type] == "MAX":
             if pool_params.global_pooling:
                 out = _op.nn.global_max_pool2d(in_expr)
             else:
@@ -397,16 +384,18 @@ class OperatorConverter(object):
                     out2 = _op.vision.max_pool2d_location(in_expr, **params)
                     return _expr.Tuple((out1, out2))
 
-        elif pool_type_dict[pool_type] == 'AVE':  # AVE
+        elif pool_type_dict[pool_type] == "AVE":  # AVE
             if pool_params.global_pooling:
                 out = _op.nn.global_avg_pool2d(in_expr)
             else:
-                params['count_include_pad'] = True
+                params["count_include_pad"] = True
                 out = _op.nn.avg_pool2d(in_expr, **params)
         else:
             raise tvm.error.OpNotImplemented(
                 "Operator {} is not supported for frontend Caffe.".format(
-                    pool_type_dict[pool_type] + ' pool'))
+                    pool_type_dict[pool_type] + " pool"
+                )
+            )
 
         return out
 
@@ -417,10 +406,10 @@ class OperatorConverter(object):
 
         params = dict()
         lrn_params = op.lrn_param
-        params['size'] = lrn_params.local_size
-        params['bias'] = lrn_params.k
-        params['alpha'] = lrn_params.alpha
-        params['beta'] = lrn_params.beta
+        params["size"] = lrn_params.local_size
+        params["bias"] = lrn_params.k
+        params["alpha"] = lrn_params.alpha
+        params["beta"] = lrn_params.beta
 
         in_expr = self.exp_tab.get_expr(input_name)
         out = _op.nn.lrn(in_expr, **params)
@@ -452,10 +441,9 @@ class OperatorConverter(object):
             weight_value = np.reshape(weight_value, (params["num_output"], -1))
             weight_shape = weight_value.shape
         else:
-            raise Exception('No weight value of layer {} in caffemodel'.format(
-                op.name))
+            raise Exception("No weight value of layer {} in caffemodel".format(op.name))
 
-        weight_expr = self.exp_tab.new_const(weight_value, dtype='float32')
+        weight_expr = self.exp_tab.new_const(weight_value, dtype="float32")
 
         in_expr = self.exp_tab.get_expr(inputs[0])
         in_reshape = _op.reshape(data=in_expr, newshape=(-1, weight_shape[-1]))
@@ -464,7 +452,7 @@ class OperatorConverter(object):
 
         if bias:
             bias_value = np.asarray(bias.data, np.float32)
-            bias_expr = self.exp_tab.new_const(bias_value, dtype='float32')
+            bias_expr = self.exp_tab.new_const(bias_value, dtype="float32")
             out = _op.nn.bias_add(out, bias_expr, axis=params["axis"])
         return out
 
@@ -476,7 +464,7 @@ class OperatorConverter(object):
         params = dict()
         dropout_params = op.dropout_param
 
-        params['rate'] = dropout_params.dropout_ratio
+        params["rate"] = dropout_params.dropout_ratio
 
         in_expr = self.exp_tab.get_expr(input_name)
         out = _op.nn.dropout(in_expr, **params)
@@ -501,7 +489,7 @@ class OperatorConverter(object):
 
         alpha = self.init_layer_dict[op.name].blobs[0].data
         alpha = np.asarray(alpha, np.float32)
-        alpha = self.exp_tab.new_const(alpha, dtype='float32')
+        alpha = self.exp_tab.new_const(alpha, dtype="float32")
         axis = 1
         out = _op.nn.prelu(in_expr, alpha, axis=axis)
         return out
@@ -521,23 +509,20 @@ class OperatorConverter(object):
         else:
             weight = weight_bias_blobs[0]
         if weight:
-            kh, kw = params['kernel_size']
+            kh, kw = params["kernel_size"]
             weight_shape = [-1, conv_params.num_output, kh, kw]
             weight_value = np.asarray(weight.data, np.float32)
             weight_value = np.reshape(weight_value, weight_shape)
         else:
-            raise Exception('No weight value of layer {} in caffemodel'.format(
-                op.name))
+            raise Exception("No weight value of layer {} in caffemodel".format(op.name))
 
-        weight_expr = self.exp_tab.new_const(weight_value, dtype='float32')
+        weight_expr = self.exp_tab.new_const(weight_value, dtype="float32")
         in_expr = self.exp_tab.get_expr(inputs[0])
-        out = _op.nn.conv2d_transpose(data=in_expr,
-                                      weight=weight_expr,
-                                      **params)
+        out = _op.nn.conv2d_transpose(data=in_expr, weight=weight_expr, **params)
         if bias:
 
             bias_value = np.asarray(bias.data, np.float32)
-            bias_expr = self.exp_tab.new_const(bias_value, dtype='float32')
+            bias_expr = self.exp_tab.new_const(bias_value, dtype="float32")
             out = _op.nn.bias_add(out, bias_expr)
         return out
 
@@ -556,9 +541,7 @@ class OperatorConverter(object):
         else:
             indices_or_sections = sorted(indices_or_sections)
 
-        out = _op.split(in_expr,
-                        indices_or_sections=indices_or_sections,
-                        axis=axis)
+        out = _op.split(in_expr, indices_or_sections=indices_or_sections, axis=axis)
         return out
 
     def convert_sigmoid(self, op):
@@ -584,8 +567,8 @@ class OperatorConverter(object):
 
         # parse crop params
         crop_params = op.crop_param
-        axis = int(getattr(crop_params, 'axis', 2))
-        offset = list(getattr(crop_params, 'offset', 0))
+        axis = int(getattr(crop_params, "axis", 2))
+        offset = list(getattr(crop_params, "offset", 0))
 
         # expand offset to (offset1, offset2, ...)
         in_a_shape = _infer_shape(in_expr_a)
@@ -610,7 +593,6 @@ class OperatorConverter(object):
         out = _op.slice_like(in_expr_a_stride, in_expr_b, axes=to_crop_axis)
         return out
 
-
     def check_unsupported_ops(self):
         """Check unsupported Caffe ops in our converter."""
         unsupported_ops_set = set()
@@ -628,9 +610,8 @@ class OperatorConverter(object):
                 unsupported_ops_set.add(op_name)
 
         if unsupported_ops_set:
-            msg = 'The following operators are not supported in frontend ' \
-                'Caffe: {}'
-            ops = str(list(unsupported_ops_set)).strip('[,]')
+            msg = "The following operators are not supported in frontend " "Caffe: {}"
+            ops = str(list(unsupported_ops_set)).strip("[,]")
             raise tvm.error.OpNotImplemented(msg.format(ops))
 
     def fuse_op(self, layers):
@@ -642,10 +623,8 @@ class OperatorConverter(object):
         bn_scale = np.asarray(bn_weight_bias_blobs[2].data, np.float32)
         if bn_scale:
             bn_scale = 1 / bn_scale
-        bn_mean = np.asarray(bn_weight_bias_blobs[0].data,
-                             np.float32) * bn_scale
-        bn_var = np.asarray(bn_weight_bias_blobs[1].data,
-                            np.float32) * bn_scale
+        bn_mean = np.asarray(bn_weight_bias_blobs[0].data, np.float32) * bn_scale
+        bn_var = np.asarray(bn_weight_bias_blobs[1].data, np.float32) * bn_scale
         bn_eps = bn.batch_norm_param.eps
 
         # scale params
@@ -653,15 +632,12 @@ class OperatorConverter(object):
         scale_gamma = np.asarray(scale_weight_bias_blobs[0].data, np.float32)
         scale_bias = scale.scale_param.bias_term
         if scale_bias:
-            scale_beta = np.asarray(scale_weight_bias_blobs[1].data,
-                                    np.float32)
+            scale_beta = np.asarray(scale_weight_bias_blobs[1].data, np.float32)
         else:
             scale_beta = np.zeros(scale_gamma.shape, dtype=np.float32)
 
         # new params
-        self.new_bn[bn.name] = [
-            bn_mean, bn_var, bn_eps, scale_gamma, scale_beta
-        ]
+        self.new_bn[bn.name] = [bn_mean, bn_var, bn_eps, scale_gamma, scale_beta]
         return bn
 
     def op_fuse(self):
@@ -677,7 +653,8 @@ class OperatorConverter(object):
                 continue
             elif op_type == "BatchNorm":
                 if (index != len(self.predict_layer) - 1) and (
-                        self.predict_layer[index + 1].type == "Scale"):
+                    self.predict_layer[index + 1].type == "Scale"
+                ):
                     temp_layers["bn"] = pl
                     continue
                 else:
@@ -695,14 +672,13 @@ class OperatorConverter(object):
             if len(temp_layers) == 2:
                 layer = self.fuse_op(temp_layers)
                 new_layers.append(layer)
-                changed_layers[
-                    temp_layers["scale"].name] = temp_layers['bn'].name
+                changed_layers[temp_layers["scale"].name] = temp_layers["bn"].name
 
             for idx, plt in enumerate(pl.bottom):
                 if plt in changed_layers:
                     pl.bottom[idx] = changed_layers[plt]
 
-            if op_type not in ['BatchNorm', 'Scale']:
+            if op_type not in ["BatchNorm", "Scale"]:
                 new_layers.append(pl)
 
         self.predict_layer = new_layers
@@ -737,7 +713,7 @@ def _rebuild_layers(predict_layer):
             continue
         # if the current layer has a single input and output, and the input equals the output,
         # it means that the layer operates "in-place"
-        if (len(pl.top) == 1 and len(pl.bottom) == 1):
+        if len(pl.top) == 1 and len(pl.bottom) == 1:
             if pl.top[0] == pl.bottom[0]:
                 # change current layer's input firstly
                 if pl.bottom[0] in changed_top_dict:
@@ -766,9 +742,7 @@ def _get_inputs_outputs(predict_layer):
     not_outputs = set()
     for pl in predict_layer:
         if pl.type == "Input":
-            assert len(
-                pl.top
-            ) == 1, "The number of Input layer's output is more than 1."
+            assert len(pl.top) == 1, "The number of Input layer's output is more than 1."
             model_inputs.append(pl.top[0])
         for i in pl.bottom:
             not_outputs.add(i)
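For reference, a sketch of driving this Caffe converter end to end; the file paths and the input name/shape are placeholders, and the four-argument from_caffe signature (init_net, predict_net, shape_dict, dtype_dict) is assumed from this frontend.

from google.protobuf import text_format
from caffe.proto import caffe_pb2
from tvm import relay

init_net = caffe_pb2.NetParameter()
predict_net = caffe_pb2.NetParameter()
with open("model.caffemodel", "rb") as f:   # placeholder path
    init_net.ParseFromString(f.read())
with open("deploy.prototxt", "r") as f:     # placeholder path
    text_format.Merge(f.read(), predict_net)

shape_dict = {"data": (1, 3, 224, 224)}
dtype_dict = {"data": "float32"}
mod, params = relay.frontend.from_caffe(init_net, predict_net, shape_dict, dtype_dict)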
index 8a5803f..c85c4a6 100644 (file)
@@ -27,15 +27,17 @@ from ... import nd as _nd
 from .common import AttrCvt, Renamer
 from .common import get_relay_op, new_var, infer_channels
 
-__all__ = ['from_caffe2']
+__all__ = ["from_caffe2"]
 
-def dimension_picker(prefix, surfix=''):
+
+def dimension_picker(prefix, surfix=""):
     def _impl(attr):
-        kernel = attr['kernel_shape']
+        kernel = attr["kernel_shape"]
         if len(kernel) == 2:
-            return prefix + '2d' + surfix
+            return prefix + "2d" + surfix
         raise tvm.error.OpAttributeUnImplemented(
-            'Non-2D kernels are not supported for operator {}2d'.format(prefix))
+            "Non-2D kernels are not supported for operator {}2d".format(prefix)
+        )
 
     return _impl
 
@@ -47,14 +49,13 @@ def revert_caffe2_pad(pads):
     elif len(pads) == 2:
         pass
     else:
-        raise tvm.error.OpAttributeInvalid(
-            'Number of pads must equal 2 or 4.')
+        raise tvm.error.OpAttributeInvalid("Number of pads must equal 2 or 4.")
     return pads
 
 
 def dimension_constraint():
     def _dim_check(args):
-        if len(args['kernel_shape']) == 2:
+        if len(args["kernel_shape"]) == 2:
             return True
         return False
 
@@ -62,177 +63,176 @@ def dimension_constraint():
 
 
 def _clean_up_pool_args(args):
-    """ A helper function to clean up common arguments in conv and pooling ops.
-    """
+    """A helper function to clean up common arguments in conv and pooling ops."""
     assert isinstance(args, dict)
 
-    if 'stride_h' in args and 'stride_w' in args:
-        assert 'stride' not in args and 'strides' not in args
-        args['strides'] = [args['stride_h'], args['stride_w']]
-        args.pop('stride_h')
-        args.pop('stride_w')
-    elif 'stride' in args:
-        args['strides'] = [args['stride'], args['stride']]
-        args.pop('stride')
+    if "stride_h" in args and "stride_w" in args:
+        assert "stride" not in args and "strides" not in args
+        args["strides"] = [args["stride_h"], args["stride_w"]]
+        args.pop("stride_h")
+        args.pop("stride_w")
+    elif "stride" in args:
+        args["strides"] = [args["stride"], args["stride"]]
+        args.pop("stride")
 
     # rename 'kernel', 'kernels', to 'kernel_shape'
-    if 'kernel_h' in args and 'kernel_w' in args:
-        assert 'kernel' not in args and 'kernels' not in args
-        args['kernel_shape'] = [args['kernel_h'], args['kernel_w']]
-        args.pop('kernel_h')
-        args.pop('kernel_w')
-    elif 'kernel' in args:
-        args['kernel_shape'] = [args['kernel'], args['kernel']]
-        args.pop('kernel')
-    elif 'kernels' in args:
-        args['kernel_shape'] = args['kernels']
-        args.pop('kernels')
-
-    if 'pad_t' in args and 'pad_l' in args and 'pad_b' in args and 'pad_r' in args:
-        assert 'pad' not in args and 'pads' not in args
-        args['pads'] = [
-            args['pad_t'], args['pad_l'], args['pad_b'], args['pad_r']
-        ]
-        for pad in ['pad_t', 'pad_l', 'pad_b', 'pad_r']:
+    if "kernel_h" in args and "kernel_w" in args:
+        assert "kernel" not in args and "kernels" not in args
+        args["kernel_shape"] = [args["kernel_h"], args["kernel_w"]]
+        args.pop("kernel_h")
+        args.pop("kernel_w")
+    elif "kernel" in args:
+        args["kernel_shape"] = [args["kernel"], args["kernel"]]
+        args.pop("kernel")
+    elif "kernels" in args:
+        args["kernel_shape"] = args["kernels"]
+        args.pop("kernels")
+
+    if "pad_t" in args and "pad_l" in args and "pad_b" in args and "pad_r" in args:
+        assert "pad" not in args and "pads" not in args
+        args["pads"] = [args["pad_t"], args["pad_l"], args["pad_b"], args["pad_r"]]
+        for pad in ["pad_t", "pad_l", "pad_b", "pad_r"]:
             args.pop(pad)
-    elif 'pad' in args:
-        args['pads'] = [args['pad'], args['pad']]
-        args.pop('pad')
-
-    if 'dilation_h' in args and 'dilation_w' in args:
-        assert 'dilation' not in args and 'dilations' not in args
-        args['dilations'] = [args['dilation_h'], args['dilation_w']]
-        args.pop('dilation_h')
-        args.pop('dilation_w')
-    elif 'dilation' in args:
-        args['dilations'] = [args['dilation'], args['dilation']]
-        args.pop('dilation')
+    elif "pad" in args:
+        args["pads"] = [args["pad"], args["pad"]]
+        args.pop("pad")
+
+    if "dilation_h" in args and "dilation_w" in args:
+        assert "dilation" not in args and "dilations" not in args
+        args["dilations"] = [args["dilation_h"], args["dilation_w"]]
+        args.pop("dilation_h")
+        args.pop("dilation_w")
+    elif "dilation" in args:
+        args["dilations"] = [args["dilation"], args["dilation"]]
+        args.pop("dilation")
 
     return args
 
 
 class Caffe2OpConverter(object):
-    """ A helper class for holding Caffe2 op converters.
-    """
+    """A helper class for holding Caffe2 op converters."""
 
     @classmethod
     def get_converter(cls):
-        """ Get converter.
+        """Get converter.
 
         :return: converter, which should be `_impl`.
         """
 
-        if hasattr(cls, '_impl'):
-            return getattr(cls, '_impl')
+        if hasattr(cls, "_impl"):
+            return getattr(cls, "_impl")
         raise tvm.error.OpNotImplemented(
-            'Operator {} is not supported in frontend Caffe2.'.format(cls.__name__))
+            "Operator {} is not supported in frontend Caffe2.".format(cls.__name__)
+        )
 
 
 _caffe2_internal_args = [
     # nnpack args
-    'algo',
-    'convolution_transform_strategy',
-    'float16_compute',
-    'shared_buffer',
-
+    "algo",
+    "convolution_transform_strategy",
+    "float16_compute",
+    "shared_buffer",
     # training args
-    'init_params',
-    'cudnn_exhaustive_search',
-    'exhaustive_search',
-
+    "init_params",
+    "cudnn_exhaustive_search",
+    "exhaustive_search",
     # training args
-    'adj',
-    'hwgq',
-
+    "adj",
+    "hwgq",
     # args that we don't care
-    'legacy_pad',
+    "legacy_pad",
 ]
 
 
 class Elemwise(Caffe2OpConverter):
-    """ A helper class for elemwise op converters.
-    """
-    name = ''
+    """A helper class for elemwise op converters."""
+
+    name = ""
+
     @classmethod
     def _impl(cls, inputs, args, params):
-        assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(
-            len(inputs))
+        assert len(inputs) == 2, "Math op takes 2 inputs, {} given".format(len(inputs))
         op_name = cls.name
         conv_ops = ["conv2d", "conv2d_transpose"]
-        if args.get('broadcast', 0) and any(x in str(inputs[0]) for x in conv_ops):
+        if args.get("broadcast", 0) and any(x in str(inputs[0]) for x in conv_ops):
             # TODO(zhreshold): remove hard coded infershape
-            axis = int(args.get('axis', 0))
+            axis = int(args.get("axis", 0))
             inputs[1] = _op.expand_dims(inputs[1], axis=axis, num_newaxis=2)
         return get_relay_op(op_name)(*inputs)
 
 
 class Add(Elemwise):
-    """ Operator converter for Add.
-    """
-    name = 'add'
+    """Operator converter for Add."""
+
+    name = "add"
 
 
 class Mul(Elemwise):
-    """ Operator converter for Mul.
-    """
-    name = 'multiply'
+    """Operator converter for Mul."""
+
+    name = "multiply"
 
 
 class Pool(Caffe2OpConverter):
-    """ A helper class for pool op converters.
-    """
+    """A helper class for pool op converters."""
+
+    name = ""
 
-    name = ''
     @classmethod
     def _impl(cls, inputs, args, params):
         _clean_up_pool_args(args)
-        if 'global_pooling' in args and args['global_pooling'] == 1:
-            op_name = dimension_picker('global_' + cls.name)
+        if "global_pooling" in args and args["global_pooling"] == 1:
+            op_name = dimension_picker("global_" + cls.name)
             return get_relay_op(op_name(args))(*inputs)
 
         return AttrCvt(
             op_name=dimension_picker(cls.name),
             transforms={
-                'kernel_shape': 'pool_size',
-                'pads': ('padding', (0, 0), revert_caffe2_pad),
-                'strides': 'strides',
+                "kernel_shape": "pool_size",
+                "pads": ("padding", (0, 0), revert_caffe2_pad),
+                "strides": "strides",
             },
-            ignores=['dilations', 'order', 'legacy_pad', 'global_pooling'],
-            extras={'ceil_mode': False},
-            custom_check=dimension_constraint())(inputs, args, params)
+            ignores=["dilations", "order", "legacy_pad", "global_pooling"],
+            extras={"ceil_mode": False},
+            custom_check=dimension_constraint(),
+        )(inputs, args, params)
 
 
 class AveragePool(Pool):
-    name = 'avg_pool'
+    name = "avg_pool"
 
 
 class MaxPool(Pool):
-    name = 'max_pool'
+    name = "max_pool"
 
 
 class Conv(Caffe2OpConverter):
-    """ Operator converter for Conv.
-    """
+    """Operator converter for Conv."""
 
     @classmethod
     def _impl(cls, inputs, args, params):
         # get number of channels
         channels = infer_channels(inputs[1])
-        args['channels'] = channels
+        args["channels"] = channels
         _clean_up_pool_args(args)
         out = AttrCvt(
-            op_name=dimension_picker('conv'),
+            op_name=dimension_picker("conv"),
             transforms={
-                'group': ('groups', 1),
-                'kernel_shape': 'kernel_size',
-                'pads': ('padding', (0, 0), revert_caffe2_pad),
-                'strides': 'strides',
-                'dilations': ('dilation', (1, 1)),
-                'order': ('data_layout', ("NCHW"), lambda x: x if isinstance(x, str) else x.decode('UTF-8')),
+                "group": ("groups", 1),
+                "kernel_shape": "kernel_size",
+                "pads": ("padding", (0, 0), revert_caffe2_pad),
+                "strides": "strides",
+                "dilations": ("dilation", (1, 1)),
+                "order": (
+                    "data_layout",
+                    ("NCHW"),
+                    lambda x: x if isinstance(x, str) else x.decode("UTF-8"),
+                ),
             },
             excludes=[],
             ignores=_caffe2_internal_args,
-            custom_check=dimension_constraint())(inputs[:2], args, params)
+            custom_check=dimension_constraint(),
+        )(inputs[:2], args, params)
         use_bias = len(inputs) == 3
         if use_bias:
             out = _op.nn.bias_add(out, inputs[2])
@@ -240,26 +240,30 @@ class Conv(Caffe2OpConverter):
 
 
 class ConvTranspose(Caffe2OpConverter):
-    """ Operator converter for ConvTranspose.
-    """
+    """Operator converter for ConvTranspose."""
 
     @classmethod
     def _impl(cls, inputs, args, params):
         # get number of channels
         channels = infer_channels(inputs[1], True)
-        args['channels'] = channels
+        args["channels"] = channels
         _clean_up_pool_args(args)
         out = AttrCvt(
-            op_name=dimension_picker('conv', '_transpose'),
+            op_name=dimension_picker("conv", "_transpose"),
             transforms={
-                'kernel_shape': 'kernel_size',
-                'pads': ('padding', (0, 0), revert_caffe2_pad),
-                'dilations': ('dilation', (1, 1)),
-                'order': ('data_layout', ("NCHW"), lambda x: x if isinstance(x, str) else x.decode('UTF-8')),
+                "kernel_shape": "kernel_size",
+                "pads": ("padding", (0, 0), revert_caffe2_pad),
+                "dilations": ("dilation", (1, 1)),
+                "order": (
+                    "data_layout",
+                    ("NCHW"),
+                    lambda x: x if isinstance(x, str) else x.decode("UTF-8"),
+                ),
             },
             excludes=[],
             ignores=_caffe2_internal_args,
-            custom_check=dimension_constraint())(inputs[:2], args, params)
+            custom_check=dimension_constraint(),
+        )(inputs[:2], args, params)
         use_bias = len(inputs) == 3
         if use_bias:
             out = _op.nn.bias_add(out, inputs[2])
@@ -267,30 +271,31 @@ class ConvTranspose(Caffe2OpConverter):
 
 
 class Concat(Caffe2OpConverter):
-    """ Operator converter for Concat.
-    """
+    """Operator converter for Concat."""
 
     @classmethod
     def _impl(cls, inputs, args, params):
         def _get_axis_from_order_str(order):
-            order = order if isinstance(order, str) else order.decode('UTF-8')
-            if order == 'NCHW':
+            order = order if isinstance(order, str) else order.decode("UTF-8")
+            if order == "NCHW":
                 return 1
-            if order == 'NHWC':
+            if order == "NHWC":
                 return 3
             raise tvm.error.OpAttributeUnImplemented(
-                'Order {} is not supported in operator Concat.'.format(order))
+                "Order {} is not supported in operator Concat.".format(order)
+            )
 
         return AttrCvt(
-            op_name='concatenate',
+            op_name="concatenate",
             transforms={
-                'order': ('axis', (1), _get_axis_from_order_str),
+                "order": ("axis", (1), _get_axis_from_order_str),
             },
-            excludes=['add_axis'])((inputs,), args, params)
+            excludes=["add_axis"],
+        )((inputs,), args, params)
 
 
 class NormalizePlanarYUV(Caffe2OpConverter):
-    """ Operator converter for NormalizePlanarYUV.
+    """Operator converter for NormalizePlanarYUV.
     caffe2 definition: https://github.com/pytorch/pytorch/blob/master/caffe2/operators/norm_planar_yuv_op.cc
     """
 
@@ -304,22 +309,21 @@ class NormalizePlanarYUV(Caffe2OpConverter):
 
 
 class ResizeNearest(Caffe2OpConverter):
-    """ Operator converter for Upsample (nearest mode).
-    """
+    """Operator converter for Upsample (nearest mode)."""
 
     @classmethod
     def _impl(cls, inputs, args, params):
-        width_scale = args['width_scale'] if 'width_scale' in args else 1
-        height_scale = args['height_scale'] if 'height_scale' in args else 1
+        width_scale = args["width_scale"] if "width_scale" in args else 1
+        height_scale = args["height_scale"] if "height_scale" in args else 1
         assert width_scale == height_scale
 
         return _op.nn.upsampling(
-            inputs[0], scale_h=int(width_scale), scale_w=int(width_scale), method="NEAREST_NEIGHBOR")
+            inputs[0], scale_h=int(width_scale), scale_w=int(width_scale), method="NEAREST_NEIGHBOR"
+        )
 
 
 class Sum(Caffe2OpConverter):
-    """ Operator converter for Sum.
-    """
+    """Operator converter for Sum."""
 
     @classmethod
     def _impl(cls, inputs, args, params):
@@ -331,20 +335,18 @@ class Sum(Caffe2OpConverter):
 
 
 class Softmax(Caffe2OpConverter):
-    """ Operator converter for Softmax.
-    """
+    """Operator converter for Softmax."""
 
     @classmethod
     def _impl(cls, inputs, args, params):
         # set default value when axis is not set in the model
-        if 'axis' not in args:
-            args['axis'] = 1
-        return AttrCvt('softmax', transforms={'axis': ('axis', args['axis'])})(inputs, args, params)
+        if "axis" not in args:
+            args["axis"] = 1
+        return AttrCvt("softmax", transforms={"axis": ("axis", args["axis"])})(inputs, args, params)
 
 
 class FC(Caffe2OpConverter):
-    """ Operator converter for FC.
-    """
+    """Operator converter for FC."""
 
     @classmethod
     def _impl(cls, inputs, args, params):
@@ -358,17 +360,15 @@ class FC(Caffe2OpConverter):
 
 
 class SpatialBN(Caffe2OpConverter):
-    """ Operator converter for SpatialBN.
-    """
+    """Operator converter for SpatialBN."""
 
     @classmethod
     def _impl(cls, inputs, args, params):
         return AttrCvt(
-            op_name='batch_norm',
-            disables=['momentum'],
-            ignores=[
-                'order', 'spatial', 'is_test', 'consumed_inputs', 'num_batches'
-            ])(inputs, args, params)
+            op_name="batch_norm",
+            disables=["momentum"],
+            ignores=["order", "spatial", "is_test", "consumed_inputs", "num_batches"],
+        )(inputs, args, params)
 
 
 # compatible operators that do NOT require any conversion.
@@ -384,26 +384,24 @@ _identity_list = []
 def _get_convert_map():
     return {
         # caffe2 common operators
-        'Add': Add.get_converter(),
-        'Sum': Sum.get_converter(),
-        'Mul': Mul.get_converter(),
-        'Softmax': Softmax.get_converter(),
-
+        "Add": Add.get_converter(),
+        "Sum": Sum.get_converter(),
+        "Mul": Mul.get_converter(),
+        "Softmax": Softmax.get_converter(),
         # nn
-        'AveragePool': AveragePool.get_converter(),
-        'MaxPool': MaxPool.get_converter(),
-        'Conv': Conv.get_converter(),
-        'ConvTranspose': ConvTranspose.get_converter(),
-        'Concat': Concat.get_converter(),
-        'FC': FC.get_converter(),
-        'SpatialBN': SpatialBN.get_converter(),
-        'ResizeNearest': ResizeNearest.get_converter(),
-        'Relu': AttrCvt('relu', {}, ignores=['order']),
-        'Sigmoid': Renamer('sigmoid'),
-        'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
-
+        "AveragePool": AveragePool.get_converter(),
+        "MaxPool": MaxPool.get_converter(),
+        "Conv": Conv.get_converter(),
+        "ConvTranspose": ConvTranspose.get_converter(),
+        "Concat": Concat.get_converter(),
+        "FC": FC.get_converter(),
+        "SpatialBN": SpatialBN.get_converter(),
+        "ResizeNearest": ResizeNearest.get_converter(),
+        "Relu": AttrCvt("relu", {}, ignores=["order"]),
+        "Sigmoid": Renamer("sigmoid"),
+        "Dropout": AttrCvt("dropout", {"ratio": "rate"}, ignores=["is_test"]),
         # c2 image preprocessing ops
-        'NormalizePlanarYUV': NormalizePlanarYUV.get_converter(),
+        "NormalizePlanarYUV": NormalizePlanarYUV.get_converter(),
     }
 
 
@@ -439,6 +437,7 @@ class Caffe2NetDef(object):
         """
         # pylint: disable=import-outside-toplevel
         from caffe2.python import workspace
+
         workspace.RunNetOnce(init_net)
 
         # Input
@@ -458,7 +457,9 @@ class Caffe2NetDef(object):
         self._nodes = {}
         for blob in predict_net.external_input:
             if blob in self._params:
-                self._nodes[blob] = new_var(blob, shape=self._params[blob].shape, dtype=self._params[blob].dtype)
+                self._nodes[blob] = new_var(
+                    blob, shape=self._params[blob].shape, dtype=self._params[blob].dtype
+                )
             else:
                 shape = self._shape[blob] if blob in self._shape else ()
                 if isinstance(self._dtype, dict) and blob in self._dtype:
@@ -497,8 +498,9 @@ class Caffe2NetDef(object):
         if blob in self._nodes:
             return self._nodes[blob]
 
-        assert blob not in self._visited_nodes, 'Cyclic dependency in the graph (in {})'.format(
-            blob)
+        assert blob not in self._visited_nodes, "Cyclic dependency in the graph (in {})".format(
+            blob
+        )
         self._visited_nodes.add(blob)
 
         self._process_op(self._ops[blob])
@@ -520,31 +522,24 @@ class Caffe2NetDef(object):
         """Convert a list of Argument to a dict, with names as keys."""
         args = {}
         for a in arg:
-            for f in ['f', 'i', 's']:
+            for f in ["f", "i", "s"]:
                 if a.HasField(f):
                     args[a.name] = getattr(a, f)
-            for f in ['floats', 'ints', 'strings']:
+            for f in ["floats", "ints", "strings"]:
                 if list(getattr(a, f)):
                     assert a.name not in args, "Only one type of attr is allowed"
                     args[a.name] = tuple(getattr(a, f))
-            for f in ['n']:
+            for f in ["n"]:
                 if a.HasField(f):
-                    raise NotImplementedError(
-                        "Field {} is not supported in relay.".format(f))
-            for f in ['nets']:
+                    raise NotImplementedError("Field {} is not supported in relay.".format(f))
+            for f in ["nets"]:
                 if list(getattr(a, f)):
-                    raise NotImplementedError(
-                        "Field {} is not supported in relay.".format(f))
+                    raise NotImplementedError("Field {} is not supported in relay.".format(f))
             if a.name not in args:
                 raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
         return args
 
-    def _convert_operator(self,
-                          op_type,
-                          inputs,
-                          args,
-                          identity_list=None,
-                          convert_map=None):
+    def _convert_operator(self, op_type, inputs, args, identity_list=None, convert_map=None):
         """Convert from Caffe2 operator to Relay operator.
         The converter must specify conversions explicitly for incompatible name, and
         apply handlers to operator attributes.
@@ -578,7 +573,8 @@ class Caffe2NetDef(object):
             func = convert_map[op_type](inputs, args, self._params)
         else:
             raise tvm.error.OpNotImplemented(
-                'Operator {} is not supported in frontend Caffe2.'.format(op_type))
+                "Operator {} is not supported in frontend Caffe2.".format(op_type)
+            )
         return func
 
 
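
For reference, a minimal sketch of how the converter table above is consumed by _convert_operator(): the Caffe2 op type is looked up in the dict returned by _get_convert_map() and the matching converter is called with the node inputs, its attribute dict, and the model params. The import path and the op type below are assumptions for illustration; the commented-out call stands in for a real Caffe2 NetDef node.

from tvm.relay.frontend.caffe2 import _get_convert_map  # import path assumed for illustration

convert_map = _get_convert_map()
converter = convert_map["Relu"]            # an AttrCvt("relu", {}, ignores=["order"]) instance
# out = converter(inputs, args, params)    # inputs/args/params come from a real NetDef node
# Unknown op types raise tvm.error.OpNotImplemented, as _convert_operator does above.
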
index c86d176..b46cd27 100644
@@ -43,6 +43,7 @@ class StrAttrsDict(object):
     attrs : Dict[str, str]
         The attributes to be used.
     """
+
     def __init__(self, attrs):
         self.attrs = attrs
 
@@ -143,8 +144,11 @@ class StrAttrsDict(object):
         """
         if key in self.attrs:
             tshape = self.attrs[key]
-            return tuple(int(x) if x.strip("- ").isdigit() else None
-                         for x in tshape.strip('()[]').split(',') if x)
+            return tuple(
+                int(x) if x.strip("- ").isdigit() else None
+                for x in tshape.strip("()[]").split(",")
+                if x
+            )
         if isinstance(default, RequiredAttr):
             raise AttributeError("Required attribute {} not found.".format(key))
         return default
@@ -167,8 +171,7 @@ class StrAttrsDict(object):
 
         if key in self.attrs:
             tshape = self.attrs[key]
-            return tuple(float(x.strip()) for x in
-                         tshape.strip('()[]').split(','))
+            return tuple(float(x.strip()) for x in tshape.strip("()[]").split(","))
         if isinstance(default, RequiredAttr):
             raise AttributeError("Required attribute {} not found.".format(key))
         return default
@@ -191,9 +194,9 @@ class StrAttrsDict(object):
         if key in self.attrs:
             value = self.attrs[key]
             seq = []
-            for tup in value.strip('()').split('),'):
-                tup = tup.strip('[]()')
-                els = [int(x.strip('( ')) for x in tup.split(',')]
+            for tup in value.strip("()").split("),"):
+                tup = tup.strip("[]()")
+                els = [int(x.strip("( ")) for x in tup.split(",")]
                 seq.append(tuple(els))
 
             return tuple(seq)
@@ -219,7 +222,7 @@ class StrAttrsDict(object):
         """
         if key in self.attrs:
             tshape = self.attrs[key]
-            return tuple(int(x.strip()) for x in tshape.strip('[]()').split(','))
+            return tuple(int(x.strip()) for x in tshape.strip("[]()").split(","))
         if isinstance(default, RequiredAttr):
             raise AttributeError("Required attribute {} not found.".format(key))
         return default
@@ -241,7 +244,7 @@ class StrAttrsDict(object):
         """
         if key in self.attrs:
             val = self.attrs[key]
-            return val.strip().lower() in ['true', '1', 't', 'y', 'yes']
+            return val.strip().lower() in ["true", "1", "t", "y", "yes"]
         if isinstance(default, RequiredAttr):
             raise AttributeError("Required attribute {} not found.".format(key))
         return default
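
For reference, the attribute getters reformatted above all share one parsing pattern for stringified values; a small self-contained illustration with made-up sample strings:

shape_str = "(1, 3, 224, 224)"  # made-up attribute value
shape = tuple(int(x.strip()) for x in shape_str.strip("[]()").split(","))
assert shape == (1, 3, 224, 224)

flag_str = "True"  # the boolean pattern from the last hunk above
flag = flag_str.strip().lower() in ["true", "1", "t", "y", "yes"]
assert flag is True
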
@@ -254,11 +257,11 @@ def get_relay_op(op_name):
     op_name : str
         The Relay operator name.
     """
-    if '.' in op_name:
+    if "." in op_name:
         # explicit hierarchical modules
         op = _op
         try:
-            for opn in op_name.split('.'):
+            for opn in op_name.split("."):
                 op = getattr(op, opn)
         except AttributeError:
             op = None
@@ -275,6 +278,7 @@ def get_relay_op(op_name):
 
 class ExprTable(object):
     """Table storing Relay expressions by names."""
+
     def __init__(self):
         self.exprs = {}
         self.params = {}
@@ -353,9 +357,17 @@ class AttrCvt(object):
         A custom function takes attribute, and return True/False.
         Raise RuntimeError if not bool(True) returned.
     """
-    def __init__(self, op_name, transforms=None,
-                 excludes=None, disables=None, ignores=None,
-                 extras=None, custom_check=None):
+
+    def __init__(
+        self,
+        op_name,
+        transforms=None,
+        excludes=None,
+        disables=None,
+        ignores=None,
+        extras=None,
+        custom_check=None,
+    ):
         self._op_name = op_name
         self._transforms = transforms if transforms else {}
         self._excludes = excludes if excludes else []
@@ -365,13 +377,13 @@ class AttrCvt(object):
         self._custom_check = custom_check
 
     def __call__(self, inputs, attrs, *args):
-        self._ignores.append('_output_shapes')
-        self._ignores.append('_input_shapes')
-        self._ignores.append('T')
-        self._ignores.append('use_cudnn_on_gpu')
-        self._ignores.append('_node_name')
-        self._ignores.append('is_training')
-        self._ignores.append('_target_layout')
+        self._ignores.append("_output_shapes")
+        self._ignores.append("_input_shapes")
+        self._ignores.append("T")
+        self._ignores.append("use_cudnn_on_gpu")
+        self._ignores.append("_node_name")
+        self._ignores.append("is_training")
+        self._ignores.append("_target_layout")
 
         # apply custom check
         if self._custom_check:
@@ -386,18 +398,19 @@ class AttrCvt(object):
             op_name = self._op_name(attrs)
 
         # ignore 'tvm_custom' always
-        self._ignores.append('tvm_custom')
+        self._ignores.append("tvm_custom")
 
         # convert attributes
         new_attrs = {}
         for k in attrs.keys():
             if k in self._excludes:
-                raise NotImplementedError('Attribute %s in operator %s is not' +
-                                          ' supported.', k, op_name)
+                raise NotImplementedError(
+                    "Attribute %s in operator %s is not" + " supported.", k, op_name
+                )
             if k in self._disables:
                 logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name)
             elif k in self._ignores:
-                if k != 'tvm_custom':
+                if k != "tvm_custom":
                     logging.warning("Attribute %s is ignored in relay.sym.%s", k, op_name)
             elif k in self._transforms:
                 new_name, defaults, transform = self._parse_default(self._transforms[k])
@@ -436,7 +449,7 @@ class AttrCvt(object):
     def _parse_bool(self, value):
         """Helper function to parse default boolean values."""
         if isinstance(value, str):
-            return value.strip().lower() in ['true', '1', 't', 'y', 'yes']
+            return value.strip().lower() in ["true", "1", "t", "y", "yes"]
         return bool(value)
 
     def _required_attr(self, attr, key):
@@ -448,7 +461,7 @@ class AttrCvt(object):
 
 
 def get_name(node):
-    name = ''
+    name = ""
     if hasattr(node, "name_hint"):
         name = node.name_hint
     return name
@@ -471,6 +484,7 @@ def infer_type(node, mod=None):
 
     return ret
 
+
 def infer_channels(inputs, transpose=False):
     """A hack for getting 'channels' or 'units' since caffe2 does not provide
     these attributes. We check the shape of weights provided to get the number.
@@ -485,7 +499,7 @@ def infer_shape(inputs, mod=None):
     """A method to get the output type of an intermediate node in the graph."""
     out_type = infer_type(inputs, mod=mod)
     checked_type = out_type.checked_type
-    if hasattr(checked_type, 'shape'):
+    if hasattr(checked_type, "shape"):
         # Regular operator that outputs tensors
         return get_const_tuple(checked_type.shape)
     # The return type is not a tensor, for example List
@@ -498,12 +512,14 @@ def infer_value(input_val, params, mod=None):
     whose output shape depends on the value of a tensor.
     """
     # Check that all free variables have associated parameters.
-    assert all(var.name_hint in params.keys() for var in analysis.free_vars(
-        input_val)), "All inputs to infer must be available in params."
+    assert all(
+        var.name_hint in params.keys() for var in analysis.free_vars(input_val)
+    ), "All inputs to infer must be available in params."
     try:
         # TODO(kevinthesun): Use VM for all cases.
         # pylint: disable=import-outside-toplevel
         from tvm.contrib import graph_runtime
+
         func = _function.Function(analysis.free_vars(input_val), input_val)
         with tvm.transform.PassContext(opt_level=0):
             graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
@@ -519,7 +535,7 @@ def infer_value(input_val, params, mod=None):
             mod = IRModule.from_expr(input_val)
         exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
         inputs = []
-        for param in mod['main'].params:
+        for param in mod["main"].params:
             inputs.append(params[param.name_hint])
         result = exc.evaluate()(*inputs)
         return result
@@ -539,9 +555,7 @@ def infer_value_simulated(input_val, params):
             fp_dtype = free_param.type_annotation.dtype
             fp_shape = [s.value for s in free_param.type_annotation.shape]
             fake_params.append(free_param)
-            params[free_param.name_hint] = tvm.nd.array(
-                np.random.rand(*fp_shape).astype(fp_dtype)
-            )
+            params[free_param.name_hint] = tvm.nd.array(np.random.rand(*fp_shape).astype(fp_dtype))
     # Now infer the value.
     output_value = infer_value(input_val, params)
     # Clean fake params out of param dictionary.
@@ -550,10 +564,7 @@ def infer_value_simulated(input_val, params):
     return output_value
 
 
-def new_var(name_hint,
-            type_annotation=None,
-            shape=None,
-            dtype="float32"):
+def new_var(name_hint, type_annotation=None, shape=None, dtype="float32"):
     return _expr.var(name_hint, type_annotation, shape, dtype)
 
 
@@ -565,10 +576,11 @@ class Renamer(object):
     new_name : str
         The new name for the operator
     """
+
     def __init__(self, new_name):
         self._new_name = new_name
 
     def __call__(self, inputs, attrs, *args):
-        if 'tvm_custom' in attrs:
-            attrs.pop('tvm_custom')
+        if "tvm_custom" in attrs:
+            attrs.pop("tvm_custom")
         return get_relay_op(self._new_name)(*inputs, **attrs)
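
For reference, a minimal sketch of how Renamer and AttrCvt (reformatted above) are instantiated by the frontends; each transforms entry is either a plain rename or a (new_name, default, transform_fn) tuple, as in the Caffe2 hunks earlier in this diff. The import path is assumed, and the commented-out call stands in for real graph inputs.

from tvm.relay.frontend.common import AttrCvt, Renamer  # import path assumed for illustration

sigmoid = Renamer("sigmoid")  # plain rename; tvm_custom is dropped in __call__
dropout = AttrCvt("dropout", {"ratio": "rate"}, ignores=["is_test"])
# out = dropout(inputs, attrs)  # inputs/attrs come from the source framework's graph
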
index 65f1c2a..e510d6a 100644
@@ -31,7 +31,7 @@ from ..._ffi import base as _base
 from .common import ExprTable
 from .common import infer_shape as _infer_shape
 
-__all__ = ['from_coreml']
+__all__ = ["from_coreml"]
 
 
 def _NeuralNetworkImageScaler(op, inexpr, etab):
@@ -39,36 +39,44 @@ def _NeuralNetworkImageScaler(op, inexpr, etab):
     # this changes the symbol
     biases = np.array([op.blueBias, op.greenBias, op.redBias]).reshape([3, 1, 1])
     bias = etab.new_const(biases)
-    ret = _op.multiply(inexpr, _expr.const(op.channelScale, dtype='float32'))
+    ret = _op.multiply(inexpr, _expr.const(op.channelScale, dtype="float32"))
     ret = _op.add(ret, bias)
     return ret
 
 
 def _NeuralNetworkMeanImage(op, inexpr, etab):
     # this changes the symbol
-    ret = _op.subtract(inexpr, _expr.const(op.meanImage, dtype='float32'))
+    ret = _op.subtract(inexpr, _expr.const(op.meanImage, dtype="float32"))
     return ret
 
 
 def _ConvolutionLayerParams(op, inexpr, etab):
     """Convolution layer params."""
     if op.isDeconvolution:
-        weights = etab.new_const(np.array(list(op.weights.floatValue)).reshape(
-            tuple([op.kernelChannels, op.outputChannels] + list(op.kernelSize))))
+        weights = etab.new_const(
+            np.array(list(op.weights.floatValue)).reshape(
+                tuple([op.kernelChannels, op.outputChannels] + list(op.kernelSize))
+            )
+        )
     else:
-        weights = etab.new_const(np.array(list(op.weights.floatValue)).reshape(
-            tuple([op.outputChannels, op.kernelChannels] + list(op.kernelSize))))
+        weights = etab.new_const(
+            np.array(list(op.weights.floatValue)).reshape(
+                tuple([op.outputChannels, op.kernelChannels] + list(op.kernelSize))
+            )
+        )
     dilation = list(op.dilationFactor)
     if not dilation:
         dilation = [1, 1]
     N, C, H, W = _infer_shape(inexpr)
-    params = {'channels':op.outputChannels,
-              'kernel_size':list(op.kernelSize),
-              'strides':list(op.stride),
-              'dilation': dilation,
-              'groups':op.nGroups}
-
-    if op.WhichOneof('ConvolutionPaddingType') == 'valid':
+    params = {
+        "channels": op.outputChannels,
+        "kernel_size": list(op.kernelSize),
+        "strides": list(op.stride),
+        "dilation": dilation,
+        "groups": op.nGroups,
+    }
+
+    if op.WhichOneof("ConvolutionPaddingType") == "valid":
         valid = op.valid
         if valid.paddingAmounts.borderAmounts:
             assert len(valid.paddingAmounts.borderAmounts) == 2
@@ -77,15 +85,16 @@ def _ConvolutionLayerParams(op, inexpr, etab):
             pad_b = valid.paddingAmounts.borderAmounts[0].endEdgeSize
             pad_r = valid.paddingAmounts.borderAmounts[1].endEdgeSize
             if not all(v == 0 for v in (pad_t, pad_l, pad_b, pad_r)):
-                params['padding'] = (pad_t, pad_l, pad_b, pad_r)
-    elif op.WhichOneof('ConvolutionPaddingType') == 'same':
-        assert op.same.asymmetryMode == 0, "Only support BOTTOM_RIGHT_HEAVY mode, " \
-                                           "which is used by tf/caffe and so on"
-        kernel = params['kernel_size']
-        strides = params['strides']
+                params["padding"] = (pad_t, pad_l, pad_b, pad_r)
+    elif op.WhichOneof("ConvolutionPaddingType") == "same":
+        assert op.same.asymmetryMode == 0, (
+            "Only support BOTTOM_RIGHT_HEAVY mode, " "which is used by tf/caffe and so on"
+        )
+        kernel = params["kernel_size"]
+        strides = params["strides"]
         pad_t, pad_b = get_pad_value(H, kernel[0], strides[0])
         pad_l, pad_r = get_pad_value(W, kernel[1], strides[1])
-        params['padding'] = (pad_t, pad_l, pad_b, pad_r)
+        params["padding"] = (pad_t, pad_l, pad_b, pad_r)
     else:
         raise NotImplementedError("Valid/Same convolution padding implemented")
 
@@ -105,78 +114,89 @@ def _BatchnormLayerParams(op, inexpr, etab):
     # this changes the symbol
     if op.instanceNormalization:
         raise tvm.error.OpNotImplemented(
-            'Operator "instance normalization" is not supported in frontend CoreML.')
-    params = {'gamma':etab.new_const(list(op.gamma.floatValue)),
-              'beta':etab.new_const(list(op.beta.floatValue)),
-              'moving_mean':etab.new_const(list(op.mean.floatValue)),
-              'moving_var': etab.new_const(list(op.variance.floatValue)),
-              'epsilon': op.epsilon}
+            'Operator "instance normalization" is not supported in frontend CoreML.'
+        )
+    params = {
+        "gamma": etab.new_const(list(op.gamma.floatValue)),
+        "beta": etab.new_const(list(op.beta.floatValue)),
+        "moving_mean": etab.new_const(list(op.mean.floatValue)),
+        "moving_var": etab.new_const(list(op.variance.floatValue)),
+        "epsilon": op.epsilon,
+    }
     result, moving_mean, moving_var = _op.nn.batch_norm(data=inexpr, **params)
     return result
 
 
 def _ActivationParams(op, inexpr, etab):
     """Get activation parameters"""
-    whichActivation = op.WhichOneof('NonlinearityType')
+    whichActivation = op.WhichOneof("NonlinearityType")
     par = getattr(op, whichActivation)
-    if whichActivation == 'linear':
-        alpha = _expr.const(par.alpha, dtype='float32')
-        beta = _expr.const(par.beta, dtype='float32')
+    if whichActivation == "linear":
+        alpha = _expr.const(par.alpha, dtype="float32")
+        beta = _expr.const(par.beta, dtype="float32")
         return _op.add(_op.multiply(inexpr, alpha), beta)
-    if whichActivation == 'ReLU':
+    if whichActivation == "ReLU":
         return _op.nn.relu(inexpr)
-    if whichActivation == 'leakyReLU':
-        _op.nn.leaky_relu(inexpr, alpha=_expr.const(par.alpha, dtype='float32'))
-    elif whichActivation == 'thresholdedReLU':
-        alpha_tensor = _op.full_like(inexpr, fill_value=_expr.const(par.alpha, dtype='float32'))
-        return _op.multiply(inexpr, _op.greater(inexpr, alpha_tensor).as_type('float32'))
-    if whichActivation == 'PReLU':
-        return _op.nn.prelu(inexpr, alpha=_expr.const(par.alpha, dtype='float32'))
-    if whichActivation == 'tanh':
+    if whichActivation == "leakyReLU":
+        _op.nn.leaky_relu(inexpr, alpha=_expr.const(par.alpha, dtype="float32"))
+    elif whichActivation == "thresholdedReLU":
+        alpha_tensor = _op.full_like(inexpr, fill_value=_expr.const(par.alpha, dtype="float32"))
+        return _op.multiply(inexpr, _op.greater(inexpr, alpha_tensor).as_type("float32"))
+    if whichActivation == "PReLU":
+        return _op.nn.prelu(inexpr, alpha=_expr.const(par.alpha, dtype="float32"))
+    if whichActivation == "tanh":
         return _op.tanh(inexpr)
-    if whichActivation == 'scaledTanh':
-        alpha = _expr.const(par.alpha, dtype='float32')
-        beta = _expr.const(par.beta, dtype='float32')
+    if whichActivation == "scaledTanh":
+        alpha = _expr.const(par.alpha, dtype="float32")
+        beta = _expr.const(par.beta, dtype="float32")
         return _op.multiply(_op.tanh(_op.multiply(inexpr, beta)), alpha)
-    if whichActivation == 'sigmoid':
+    if whichActivation == "sigmoid":
         return _op.sigmoid(inexpr)
-    if whichActivation == 'sigmoidHard':
-        alpha = _expr.const(par.alpha, dtype='float32')
-        beta = _expr.const(par.beta, dtype='float32')
+    if whichActivation == "sigmoidHard":
+        alpha = _expr.const(par.alpha, dtype="float32")
+        beta = _expr.const(par.beta, dtype="float32")
         transformX = (alpha * inexpr) + beta
-        return _op.clip(transformX, a_min=0., a_max=1.)
-    if whichActivation == 'ELU':
-        return _op.multiply(_op.add(_op.exp(inexpr), _expr.const(-1, dtype='float32')),
-                            _expr.const(par.alpha, dtype='float32'))
-    if whichActivation == 'softsign':
-        return inexpr / (_expr.const(1, dtype='float32') + (
-            op.nn.relu(inexpr) + _op.nn.relu(_op.negative(inexpr))))
-    if whichActivation == 'softplus':
-        return _op.log(_op.add(_op.exp(inexpr), _expr.const(1, dtype='float32')))
-    if whichActivation == 'parametricSoftplus':
+        return _op.clip(transformX, a_min=0.0, a_max=1.0)
+    if whichActivation == "ELU":
+        return _op.multiply(
+            _op.add(_op.exp(inexpr), _expr.const(-1, dtype="float32")),
+            _expr.const(par.alpha, dtype="float32"),
+        )
+    if whichActivation == "softsign":
+        return inexpr / (
+            _expr.const(1, dtype="float32")
+            + (op.nn.relu(inexpr) + _op.nn.relu(_op.negative(inexpr)))
+        )
+    if whichActivation == "softplus":
+        return _op.log(_op.add(_op.exp(inexpr), _expr.const(1, dtype="float32")))
+    if whichActivation == "parametricSoftplus":
         alpha = list(par.alpha.floatValue)
         beta = list(par.alpha.floatValue)
         if len(alpha) == 1:
-            return _op.multiply(_op.log(_op.add(_op.exp(inexpr),
-                                                _expr.const(beta[0], dtype='float32'))),
-                                _expr.const(alpha[0], dtype='float32'))
+            return _op.multiply(
+                _op.log(_op.add(_op.exp(inexpr), _expr.const(beta[0], dtype="float32"))),
+                _expr.const(alpha[0], dtype="float32"),
+            )
         alpha = np.array(alpha).reshape((len(alpha), 1, 1))
         beta = np.array(beta).reshape((len(beta), 1, 1))
         alpha_expr = etab.new_const(alpha)
         beta_expr = etab.new_const(beta)
         return _op.multiply(_op.log(_op.add(_op.exp(inexpr), beta_expr)), alpha_expr)
     raise tvm.error.OpNotImplemented(
-        'Operator {} is not supported in frontend CoreML.'.format(whichActivation))
+        "Operator {} is not supported in frontend CoreML.".format(whichActivation)
+    )
 
 
 def _ScaleLayerParams(op, inexpr, etab):
     """Scale layer params."""
-    scale = etab.new_const(np.array(list(op.scale.floatValue)).reshape(
-        tuple(list(op.shapeScale) + [1, 1])))
+    scale = etab.new_const(
+        np.array(list(op.scale.floatValue)).reshape(tuple(list(op.shapeScale) + [1, 1]))
+    )
     ret = _op.multiply(inexpr, scale)
     if op.hasBias:
-        bias = etab.new_const(np.array(list(op.bias.floatValue)).reshape(
-            tuple(list(op.shapeBias) + [1, 1])))
+        bias = etab.new_const(
+            np.array(list(op.bias.floatValue)).reshape(tuple(list(op.shapeBias) + [1, 1]))
+        )
         ret = _op.add(ret, bias)
     return ret
 
@@ -189,12 +209,12 @@ def _PoolingLayerParams(op, inexpr, etab):
         if op.type == 1:
             return _op.nn.global_avg_pool2d(inexpr)
         raise tvm.error.OpNotImplemented(
-            'Only Max and Average Pooling are supported in frontend CoreML.')
+            "Only Max and Average Pooling are supported in frontend CoreML."
+        )
 
-    params = {'pool_size':list(op.kernelSize),
-              'strides':list(op.stride)}
+    params = {"pool_size": list(op.kernelSize), "strides": list(op.stride)}
 
-    if op.WhichOneof('PoolingPaddingType') == 'valid':
+    if op.WhichOneof("PoolingPaddingType") == "valid":
         valid = op.valid
         if valid.paddingAmounts.borderAmounts:
             assert len(valid.paddingAmounts.borderAmounts) == 2
@@ -203,24 +223,23 @@ def _PoolingLayerParams(op, inexpr, etab):
             pad_b = valid.paddingAmounts.borderAmounts[0].endEdgeSize
             pad_r = valid.paddingAmounts.borderAmounts[1].endEdgeSize
             if not all(v == 0 for v in (pad_t, pad_l, pad_b, pad_r)):
-                params['padding'] = [pad_t, pad_l, pad_b, pad_r]
-    elif op.WhichOneof('PoolingPaddingType') == 'includeLastPixel':
+                params["padding"] = [pad_t, pad_l, pad_b, pad_r]
+    elif op.WhichOneof("PoolingPaddingType") == "includeLastPixel":
         # I don't know if this is correct
         valid = op.includeLastPixel
         padding = list(valid.paddingAmounts)
-        params['padding'] = padding
-        params['ceil_mode'] = True
+        params["padding"] = padding
+        params["ceil_mode"] = True
     else:
-        msg = 'PoolingPaddingType {} is not supported in operator Pooling.'
-        op_name = op.WhichOneof('PoolingPaddingType')
+        msg = "PoolingPaddingType {} is not supported in operator Pooling."
+        op_name = op.WhichOneof("PoolingPaddingType")
         raise tvm.error.OpAttributeUnImplemented(msg.format(op_name))
 
     if op.type == 0:
         return _op.nn.max_pool2d(inexpr, **params)
     if op.type == 1:
         return _op.nn.avg_pool2d(inexpr, **params)
-    raise tvm.error.OpNotImplemented(
-        'Only Max and Average Pooling are supported in CoreML.')
+    raise tvm.error.OpNotImplemented("Only Max and Average Pooling are supported in CoreML.")
 
 
 def _SoftmaxLayerParams(op, inexpr, etab):
@@ -228,8 +247,9 @@ def _SoftmaxLayerParams(op, inexpr, etab):
 
 
 def _InnerProductLayerParams(op, inexpr, etab):
-    weights = etab.new_const(np.array(op.weights.floatValue).reshape(
-        (op.outputChannels, op.inputChannels)))
+    weights = etab.new_const(
+        np.array(op.weights.floatValue).reshape((op.outputChannels, op.inputChannels))
+    )
     out = _op.nn.dense(data=inexpr, weight=weights, units=op.outputChannels)
     if op.hasBias:
         bias = etab.new_const(np.array(op.bias.floatValue))
@@ -244,7 +264,7 @@ def _AddLayerParams(op, inexpr, etab):
     for i in range(1, len(inexpr)):
         ret = _op.add(ret, inexpr[i])
     if op.alpha > 0:
-        ret = _op.add(ret, _expr.const(op.alpha, dtype='float32'))
+        ret = _op.add(ret, _expr.const(op.alpha, dtype="float32"))
     return ret
 
 
@@ -255,7 +275,7 @@ def _MultiplyLayerParams(op, inexpr, etab):
     for i in range(1, len(inexpr)):
         ret = _op.multiply(ret, inexpr[i])
     if op.alpha != 1:
-        ret = _op.multiply(ret, _expr.const(op.alpha, dtype='float32'))
+        ret = _op.multiply(ret, _expr.const(op.alpha, dtype="float32"))
     return ret
 
 
@@ -264,7 +284,8 @@ def _ConcatLayerParams(op, inexpr, etab):
         inexpr = [inexpr]
     if op.sequenceConcat:
         raise tvm.error.OpNotImplemented(
-            'Operator Sequence Concat is not supported in frontend CoreML.')
+            "Operator Sequence Concat is not supported in frontend CoreML."
+        )
     ret = _op.concatenate(inexpr, axis=1)
     return ret
 
@@ -277,21 +298,18 @@ def _FlattenLayerParams(op, inexpr, etab):
 
 def _PaddingLayerParams(op, inexpr, etab):
     """Padding layer params."""
-    if op.WhichOneof('PaddingType') == 'constant':
+    if op.WhichOneof("PaddingType") == "constant":
         constant = op.constant
         if constant.value != 0:
             raise tvm.error.OpAttributeUnImplemented(
-                '{} is not supported in operator Padding.'.format(constant.value))
+                "{} is not supported in operator Padding.".format(constant.value)
+            )
         pad_t = op.paddingAmounts.borderAmounts[0].startEdgeSize
         pad_l = op.paddingAmounts.borderAmounts[1].startEdgeSize
         pad_b = op.paddingAmounts.borderAmounts[0].endEdgeSize
         pad_r = op.paddingAmounts.borderAmounts[1].endEdgeSize
-        return _op.nn.pad(data=inexpr, pad_width=((0, 0),
-                                                  (0, 0),
-                                                  (pad_t, pad_b),
-                                                  (pad_l, pad_r)))
-    raise tvm.error.OpNotImplemented(
-        'Non-constant padding is not supported in frontend CoreML.')
+        return _op.nn.pad(data=inexpr, pad_width=((0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r)))
+    raise tvm.error.OpNotImplemented("Non-constant padding is not supported in frontend CoreML.")
 
 
 def _PermuteLayerParams(op, inexpr, etab):
@@ -301,11 +319,11 @@ def _PermuteLayerParams(op, inexpr, etab):
 
 def _UpsampleLayerParams(op, inexpr, etab):
     if op.scalingFactor[0] != op.scalingFactor[1]:
-        raise tvm.error.OpAttributeUnimplemented(
-            'Upsample height and width must be equal.')
-    interpolationMode = 'nearest_neighbor' if op.mode == 0 else 'bilinear'
-    return _op.nn.upsampling(inexpr, scale_h=op.scalingFactor[0],
-                             scale_w=op.scalingFactor[1], method=interpolationMode)
+        raise tvm.error.OpAttributeUnimplemented("Upsample height and width must be equal.")
+    interpolationMode = "nearest_neighbor" if op.mode == 0 else "bilinear"
+    return _op.nn.upsampling(
+        inexpr, scale_h=op.scalingFactor[0], scale_w=op.scalingFactor[1], method=interpolationMode
+    )
 
 
 def _L2NormalizeLayerParams(op, inexpr, etab):
@@ -314,11 +332,11 @@ def _L2NormalizeLayerParams(op, inexpr, etab):
 
 def _LRNLayerParams(op, inexpr, etab):
     par = {}
-    par['size'] = op.localSize
-    par['bias'] = op.k
-    par['alpha'] = op.alpha
-    par['beta'] = op.beta
-    par['axis'] = 1 # default layout is nchw
+    par["size"] = op.localSize
+    par["bias"] = op.k
+    par["alpha"] = op.alpha
+    par["beta"] = op.beta
+    par["axis"] = 1  # default layout is nchw
     return _op.nn.lrn(data=inexpr, **par)
 
 
@@ -329,7 +347,7 @@ def _AverageLayerParams(op, inexpr, etab):
     _sum = inexpr[0]
     for i in range(1, count):
         _sum = _op.add(_sum, inexpr[i])
-    return _sum / _expr.const(count, dtype='float32')
+    return _sum / _expr.const(count, dtype="float32")
 
 
 def _MaxLayerParams(op, inexpr, etab):
@@ -373,7 +391,7 @@ def _UnaryFunctionLayerParams(op, inexpr, etab):
         alpha = _expr.const(op.alpha)
         return _op.maximum(inexpr, alpha)
     else:
-        msg = 'Unary Op type value {} is not supported in frontend CoreML.'
+        msg = "Unary Op type value {} is not supported in frontend CoreML."
         raise tvm.error.OpAttributeUnImplemented(msg.format(op_type))
 
 
@@ -390,7 +408,7 @@ def _ReduceLayerParams(op, inexpr, etab):
     elif axis == op.W:
         axis = -1
     else:
-        msg = 'Reduce axis value {} is not supported in frontend CoreML.'
+        msg = "Reduce axis value {} is not supported in frontend CoreML."
         raise tvm.error.OpAttributeUnImplemented(msg.format(axis))
 
     mode = op.mode
@@ -407,7 +425,7 @@ def _ReduceLayerParams(op, inexpr, etab):
     elif mode == op.ARGMAX:
         return _op.argmax(inexpr, axis=axis, keepdims=True)
     else:
-        msg = 'Reduce mode value {} is not supported in frontend CoreML.'
+        msg = "Reduce mode value {} is not supported in frontend CoreML."
         raise tvm.error.OpAttributeUnImplemented(msg.format(mode))
 
 
@@ -420,31 +438,31 @@ def _SplitLayerParams(op, inexpr, etab):
 
 
 _convert_map = {
-    'NeuralNetworkMeanImage': _NeuralNetworkMeanImage,
-    'NeuralNetworkImageScaler': _NeuralNetworkImageScaler,
-    'ConvolutionLayerParams': _ConvolutionLayerParams,
-    'BatchnormLayerParams': _BatchnormLayerParams,
-    'ActivationParams': _ActivationParams,
-    'ScaleLayerParams': _ScaleLayerParams,
-    'PoolingLayerParams': _PoolingLayerParams,
-    'SoftmaxLayerParams': _SoftmaxLayerParams,
-    'InnerProductLayerParams': _InnerProductLayerParams,
-    'AddLayerParams': _AddLayerParams,
-    'MultiplyLayerParams': _MultiplyLayerParams,
-    'FlattenLayerParams': _FlattenLayerParams,
-    'ConcatLayerParams': _ConcatLayerParams,
-    'PaddingLayerParams': _PaddingLayerParams,
-    'PermuteLayerParams': _PermuteLayerParams,
-    'UpsampleLayerParams': _UpsampleLayerParams,
-    'L2NormalizeLayerParams': _L2NormalizeLayerParams,
-    'LRNLayerParams': _LRNLayerParams,
-    'AverageLayerParams': _AverageLayerParams,
-    'MaxLayerParams': _MaxLayerParams,
-    'MinLayerParams': _MinLayerParams,
-    'UnaryFunctionLayerParams': _UnaryFunctionLayerParams,
-    'ReduceLayerParams': _ReduceLayerParams,
-    'ReshapeLayerParams': _ReshapeLayerParams,
-    'SplitLayerParams': _SplitLayerParams,
+    "NeuralNetworkMeanImage": _NeuralNetworkMeanImage,
+    "NeuralNetworkImageScaler": _NeuralNetworkImageScaler,
+    "ConvolutionLayerParams": _ConvolutionLayerParams,
+    "BatchnormLayerParams": _BatchnormLayerParams,
+    "ActivationParams": _ActivationParams,
+    "ScaleLayerParams": _ScaleLayerParams,
+    "PoolingLayerParams": _PoolingLayerParams,
+    "SoftmaxLayerParams": _SoftmaxLayerParams,
+    "InnerProductLayerParams": _InnerProductLayerParams,
+    "AddLayerParams": _AddLayerParams,
+    "MultiplyLayerParams": _MultiplyLayerParams,
+    "FlattenLayerParams": _FlattenLayerParams,
+    "ConcatLayerParams": _ConcatLayerParams,
+    "PaddingLayerParams": _PaddingLayerParams,
+    "PermuteLayerParams": _PermuteLayerParams,
+    "UpsampleLayerParams": _UpsampleLayerParams,
+    "L2NormalizeLayerParams": _L2NormalizeLayerParams,
+    "LRNLayerParams": _LRNLayerParams,
+    "AverageLayerParams": _AverageLayerParams,
+    "MaxLayerParams": _MaxLayerParams,
+    "MinLayerParams": _MinLayerParams,
+    "UnaryFunctionLayerParams": _UnaryFunctionLayerParams,
+    "ReduceLayerParams": _ReduceLayerParams,
+    "ReshapeLayerParams": _ReshapeLayerParams,
+    "SplitLayerParams": _SplitLayerParams,
 }
 
 # SAME padding: https://www.tensorflow.org/api_guides/python/nn
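
The comment above points at TensorFlow's SAME-padding rule, which the get_pad_value helper called by the convolution and pooling converters computes. A standalone sketch of that rule (illustrative, not the TVM helper itself):

import math

def same_pad(in_size, kernel, stride):
    """TensorFlow-style SAME padding, returned as (pad_before, pad_after)."""
    out_size = int(math.ceil(float(in_size) / float(stride)))
    pad = max(0, (out_size - 1) * stride + kernel - in_size)
    return pad // 2, pad - pad // 2

print(same_pad(224, 3, 2))  # -> (0, 1)
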
@@ -493,7 +511,8 @@ def coreml_op_to_relay(op, inname, outnames, etab):
     classname = type(op).__name__
     if classname not in _convert_map:
         raise tvm.error.OpNotImplemented(
-            'Operator {} is not supported in frontend CoreML.'.format(classname))
+            "Operator {} is not supported in frontend CoreML.".format(classname)
+        )
     if isinstance(inname, _base.string_types):
         insym = etab.get_expr(inname)
     else:
@@ -533,12 +552,12 @@ def from_coreml(model, shape=None):
     try:
         import coremltools as cm
     except ImportError:
-        raise ImportError('The coremltools package must be installed')
+        raise ImportError("The coremltools package must be installed")
 
     assert isinstance(model, cm.models.MLModel)
     spec = model.get_spec()
-    modeltype = spec.WhichOneof('Type')
-    assert modeltype in ['neuralNetworkClassifier', 'neuralNetwork', 'neuralNetworkRegressor']
+    modeltype = spec.WhichOneof("Type")
+    assert modeltype in ["neuralNetworkClassifier", "neuralNetwork", "neuralNetworkRegressor"]
     cc = getattr(spec, modeltype)
 
     etab = ExprTable()
@@ -547,39 +566,41 @@ def from_coreml(model, shape=None):
         etab.set_expr(i.name, _expr.var(i.name, shape=input_shape))
 
     for pp in cc.preprocessing:
-        whichpp = pp.WhichOneof('preprocessor')
+        whichpp = pp.WhichOneof("preprocessor")
         ppmethod = getattr(pp, whichpp)
-        if whichpp == 'scaler':
+        if whichpp == "scaler":
             # Be careful: we may preprocess only one input when there are multiple inputs;
             # the one to preprocess is stored in pp.featureName. See unit test verify_image_scaler
             # in test_forward.py for CoreML.
             for i in spec.description.input:
                 # we have multi inputs
                 if len(spec.description.input) > 1:
-                    assert pp.featureName != ''
+                    assert pp.featureName != ""
                     if i.name == pp.featureName:
                         coreml_op_to_relay(ppmethod, i.name, i.name, etab)
                 else:
-                    assert pp.featureName == ''
+                    assert pp.featureName == ""
                     coreml_op_to_relay(ppmethod, i.name, i.name, etab)
         else:
             coreml_op_to_relay(ppmethod, pp.featureName, pp.featureName, etab)
 
     for l in cc.layers:
-        layertype = l.WhichOneof('layer')
+        layertype = l.WhichOneof("layer")
         layerop = getattr(l, layertype)
         if len(l.input) == 1:
             coreml_op_to_relay(layerop, l.input[0], l.output, etab)
         else:
             coreml_op_to_relay(layerop, list(l.input), l.output, etab)
 
-    outexpr = [etab.get_expr(o.name) if o.name in etab.exprs else _expr.var(o.name)
-               for o in spec.description.output]
+    outexpr = [
+        etab.get_expr(o.name) if o.name in etab.exprs else _expr.var(o.name)
+        for o in spec.description.output
+    ]
 
     # check there are multiple outputs in the model and all are there in etab
     multi_out = all([bool(o.name in etab.exprs) for o in spec.description.output])
     outexpr = _expr.Tuple(outexpr) if multi_out else outexpr[0]
 
     func = _function.Function(analysis.free_vars(outexpr), outexpr)
-    params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
+    params = {k: _nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
     return IRModule.from_expr(func), params
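
For reference, a minimal usage sketch of the CoreML frontend reformatted above; the model file name and the input name are hypothetical.

import coremltools as cm
from tvm import relay

mlmodel = cm.models.MLModel("mobilenet.mlmodel")  # hypothetical .mlmodel file
mod, params = relay.frontend.from_coreml(mlmodel, shape={"image": (1, 3, 224, 224)})
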
index 62a3207..87e5559 100644
@@ -29,20 +29,24 @@ from .. import expr as _expr
 from .. import function as _function
 from .common import get_relay_op, new_var
 
-__all__ = ['from_darknet']
+__all__ = ["from_darknet"]
 
-def _darknet_not_support(attr, op='relay'):
+
+def _darknet_not_support(attr, op="relay"):
     """Raise error if any operation is not supported."""
     err = "{} is not supported in {}.".format(attr, op)
     raise NotImplementedError(err)
 
+
 def _get_params_prefix(opname, layer_num):
     """Makes the params prefix name from opname and layer number."""
     return str(opname) + str(layer_num)
 
+
 def _get_params_name(prefix, item):
     """Makes the params name for the k,v pair."""
-    return prefix + '_'+ item
+    return prefix + "_" + item
+
 
 def _get_param_var(params, prefix, item):
     name = _get_params_name(prefix, item)
@@ -50,228 +54,243 @@ def _get_param_var(params, prefix, item):
         raise AttributeError("{} not found in params dict.".format(name))
     return new_var(name, shape=params[name].shape, dtype=params[name].dtype)
 
+
 def _darknet_maxpooling(inputs, params, attrs, prefix):
     """Process the max pool 2d operation."""
     new_attrs = {}
-    kernel = attrs.get('kernel')
-    strides = attrs.get('stride', 1)
-    pads = attrs.get('pad', 1)
-    new_attrs['pool_size'] = (kernel, kernel)
-    new_attrs['strides'] = (strides, strides)
-    new_attrs['padding'] = (pads, pads)
-    extra_pad_size = attrs.get('extra_pad_size', 0)
+    kernel = attrs.get("kernel")
+    strides = attrs.get("stride", 1)
+    pads = attrs.get("pad", 1)
+    new_attrs["pool_size"] = (kernel, kernel)
+    new_attrs["strides"] = (strides, strides)
+    new_attrs["padding"] = (pads, pads)
+    extra_pad_size = attrs.get("extra_pad_size", 0)
     if extra_pad_size:
         pad_width = ((0, 0), (0, 0), (0, extra_pad_size), (0, extra_pad_size))
-        inputs = [get_relay_op('pad')(*inputs,
-                                      pad_width=pad_width,
-                                      pad_value=np.finfo(np.float32).min)]
-    return get_relay_op('max_pool2d')(*inputs, **new_attrs)
+        inputs = [
+            get_relay_op("pad")(*inputs, pad_width=pad_width, pad_value=np.finfo(np.float32).min)
+        ]
+    return get_relay_op("max_pool2d")(*inputs, **new_attrs)
+
 
 def _darknet_avgpooling(inputs, params, attrs, prefix):
     """Process the average pool 2d operation."""
     new_attrs = {}
-    kernel = attrs.get('kernel')
-    strides = attrs.get('stride', 1)
-    pads = attrs.get('pad', 0)
+    kernel = attrs.get("kernel")
+    strides = attrs.get("stride", 1)
+    pads = attrs.get("pad", 0)
+
+    new_attrs["pool_size"] = (kernel, kernel)
+    new_attrs["strides"] = (strides, strides)
+    new_attrs["padding"] = (pads, pads)
+    return get_relay_op("avg_pool2d")(*inputs, **new_attrs)
 
-    new_attrs['pool_size'] = (kernel, kernel)
-    new_attrs['strides'] = (strides, strides)
-    new_attrs['padding'] = (pads, pads)
-    return get_relay_op('avg_pool2d')(*inputs, **new_attrs)
 
 def _darknet_conv2d(inputs, params, attrs, prefix):
     """Process the convolution 2d operation."""
     new_attrs = {}
-    kernel = attrs.get('kernel')
-    strides = attrs.get('stride', 1)
-    pads = attrs.get('pad', 0)
+    kernel = attrs.get("kernel")
+    strides = attrs.get("stride", 1)
+    pads = attrs.get("pad", 0)
 
-    new_attrs['channels'] = attrs.get('num_filter')
-    new_attrs['kernel_size'] = (kernel, kernel)
-    new_attrs['strides'] = (strides, strides)
-    new_attrs['padding'] = (pads, pads)
-    new_attrs['dilation'] = attrs.get('dilate', (1, 1))
-    new_attrs['groups'] = attrs.get('num_group', 1)
+    new_attrs["channels"] = attrs.get("num_filter")
+    new_attrs["kernel_size"] = (kernel, kernel)
+    new_attrs["strides"] = (strides, strides)
+    new_attrs["padding"] = (pads, pads)
+    new_attrs["dilation"] = attrs.get("dilate", (1, 1))
+    new_attrs["groups"] = attrs.get("num_group", 1)
 
-    weight = _get_param_var(params, prefix, 'weight')
-    out = get_relay_op('conv2d')(*inputs, weight=weight, **new_attrs)
+    weight = _get_param_var(params, prefix, "weight")
+    out = get_relay_op("conv2d")(*inputs, weight=weight, **new_attrs)
 
-    use_bias = not attrs.get('use_batchNorm', False)
+    use_bias = not attrs.get("use_batchNorm", False)
     if use_bias:
         new_attrs = {}
-        new_attrs['axis'] = 1
-        bias = _get_param_var(params, prefix, 'bias')
-        out = get_relay_op('bias_add')(out, bias=bias, **new_attrs)
+        new_attrs["axis"] = 1
+        bias = _get_param_var(params, prefix, "bias")
+        out = get_relay_op("bias_add")(out, bias=bias, **new_attrs)
     else:
         new_attrs = {}
-        new_attrs['epsilon'] = 0.000001
-        gamma = _get_param_var(params, prefix, 'gamma')
-        beta = _get_param_var(params, prefix, 'beta')
-        moving_mean = _get_param_var(params, prefix, 'moving_mean')
-        moving_var = _get_param_var(params, prefix, 'moving_var')
-        out = get_relay_op('batch_norm')(out, gamma, beta, moving_mean, moving_var, **new_attrs)
-
-    if 'activation' in attrs:
+        new_attrs["epsilon"] = 0.000001
+        gamma = _get_param_var(params, prefix, "gamma")
+        beta = _get_param_var(params, prefix, "beta")
+        moving_mean = _get_param_var(params, prefix, "moving_mean")
+        moving_var = _get_param_var(params, prefix, "moving_var")
+        out = get_relay_op("batch_norm")(out, gamma, beta, moving_mean, moving_var, **new_attrs)
+
+    if "activation" in attrs:
         new_attrs = {}
-        new_attrs['activation'] = attrs['activation']
-        new_attrs['slope'] = 0.1
+        new_attrs["activation"] = attrs["activation"]
+        new_attrs["slope"] = 0.1
         out = _darknet_activations(out, None, new_attrs)
     return out
 
+
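
For reference, a small sketch of how the module-level helpers above stitch a Darknet layer together: _get_params_prefix and _get_params_name build the parameter keys that _get_param_var resolves, and get_relay_op looks the Relay operator up by name, just as _darknet_conv2d does. The layer name, index, and commented-out call are illustrative.

prefix = _get_params_prefix("convolution", 0)  # -> "convolution0"
weight_key = _get_params_name(prefix, "weight")  # -> "convolution0_weight"
conv2d = get_relay_op("conv2d")  # the lookup _darknet_conv2d performs
# weight = _get_param_var(params, prefix, "weight")  # needs the loaded Darknet params dict
# out = conv2d(data, weight=weight, channels=32, kernel_size=(3, 3))
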
 def _darknet_shortcut(inputs, params, attrs, prefix):
     """Process the shortcut operation."""
     input_0 = inputs[0]
     input_1 = inputs[1]
 
-    input_0_channel = int(attrs['out_channel'])
-    input_1_channel = int(attrs['add_out_channel'])
-    input_0_size = int(attrs['out_size'])
-    input_1_size = int(attrs['add_out_size'])
+    input_0_channel = int(attrs["out_channel"])
+    input_1_channel = int(attrs["add_out_channel"])
+    input_0_size = int(attrs["out_size"])
+    input_1_size = int(attrs["add_out_size"])
 
     if input_0_size > input_1_size:
-        scale = int(input_0_size/input_1_size)
-        input_1 = get_relay_op('upsampling')(input_1, scale_h=scale, scale_w=scale)
+        scale = int(input_0_size / input_1_size)
+        input_1 = get_relay_op("upsampling")(input_1, scale_h=scale, scale_w=scale)
 
     elif input_0_size < input_1_size:
-        stride = int(input_1_size/input_0_size)
-        input_1 = get_relay_op('avg_pool2d')(input_1,
-                                             pool_size=(1, 1),
-                                             strides=(stride, stride),
-                                             padding=(0, 0))
+        stride = int(input_1_size / input_0_size)
+        input_1 = get_relay_op("avg_pool2d")(
+            input_1, pool_size=(1, 1), strides=(stride, stride), padding=(0, 0)
+        )
 
     if input_0_channel != input_1_channel:
         pad_channel = input_0_channel - input_1_channel
-        input_1 = get_relay_op('pad')(input_1,
-                                      pad_width=((0, 0), (0, pad_channel), (0, 0), (0, 0)),
-                                      pad_value=0.)
+        input_1 = get_relay_op("pad")(
+            input_1, pad_width=((0, 0), (0, pad_channel), (0, 0), (0, 0)), pad_value=0.0
+        )
     sym = input_0 + input_1
-    if 'activation' in attrs:
+    if "activation" in attrs:
         new_attrs = {}
-        new_attrs['activation'] = attrs['activation']
+        new_attrs["activation"] = attrs["activation"]
         sym = _darknet_activations(sym, None, new_attrs)
     return sym
 
+
 def _darknet_dense(inputs, params, attrs, prefix):
     """Process the dense operation."""
     new_attrs = {}
-    new_attrs['units'] = attrs.get('num_hidden')
+    new_attrs["units"] = attrs.get("num_hidden")
     data = inputs[0]
 
-    if attrs.get('use_flatten', False) is True:
-        data = get_relay_op('batch_flatten')(data)
+    if attrs.get("use_flatten", False) is True:
+        data = get_relay_op("batch_flatten")(data)
 
-    weight = _get_param_var(params, prefix, 'weight')
-    data = get_relay_op('dense')(data, weight, **new_attrs)
+    weight = _get_param_var(params, prefix, "weight")
+    data = get_relay_op("dense")(data, weight, **new_attrs)
 
-    use_bias = attrs.get('use_bias', False)
+    use_bias = attrs.get("use_bias", False)
     if use_bias:
-        bias = _get_param_var(params, prefix, 'bias')
-        data = get_relay_op('bias_add')(data, bias, axis=1)
+        bias = _get_param_var(params, prefix, "bias")
+        data = get_relay_op("bias_add")(data, bias, axis=1)
 
-    if 'use_batchNorm' in attrs:
+    if "use_batchNorm" in attrs:
         new_attrs = {}
-        new_attrs['epsilon'] = 0.000001
-        gamma = _get_param_var(params, prefix, 'gamma')
-        beta = _get_param_var(params, prefix, 'beta')
-        moving_mean = _get_param_var(params, prefix, 'moving_mean')
-        moving_var = _get_param_var(params, prefix, 'moving_var')
-        data = get_relay_op('batch_norm')(data, gamma, beta, moving_mean, moving_var, **new_attrs)
-    if 'activation' in attrs:
+        new_attrs["epsilon"] = 0.000001
+        gamma = _get_param_var(params, prefix, "gamma")
+        beta = _get_param_var(params, prefix, "beta")
+        moving_mean = _get_param_var(params, prefix, "moving_mean")
+        moving_var = _get_param_var(params, prefix, "moving_var")
+        data = get_relay_op("batch_norm")(data, gamma, beta, moving_mean, moving_var, **new_attrs)
+    if "activation" in attrs:
         new_attrs = {}
-        new_attrs['activation'] = attrs['activation']
+        new_attrs["activation"] = attrs["activation"]
         data = _darknet_activations(data, None, new_attrs)
     return data
 
+
 def _darknet_dropout(inputs, params, attrs, prefix):
     """Process the dropout operation, its a blank operation."""
     new_attrs = {}
-    new_attrs['rate'] = attrs.get('p', 0.5)
-    return get_relay_op('dropout')(*inputs, **new_attrs)
+    new_attrs["rate"] = attrs.get("p", 0.5)
+    return get_relay_op("dropout")(*inputs, **new_attrs)
+
 
 def _darknet_reshape(inputs, params, attrs, prefix):
     """Process the reshape operation."""
     new_attrs = {}
-    new_attrs['shape'] = attrs.get('shape')
-    return get_relay_op('reshape')(*inputs, **new_attrs)
+    new_attrs["shape"] = attrs.get("shape")
+    return get_relay_op("reshape")(*inputs, **new_attrs)
+
 
 def _darknet_upsampling(inputs, params, attrs, prefix):
     """Process the upsampling operation."""
     new_attrs = {}
-    new_attrs['scale_h'] = attrs.get('scale', 1)
-    new_attrs['scale_w'] = attrs.get('scale', 1)
-    return get_relay_op('upsampling')(*inputs, **new_attrs)
+    new_attrs["scale_h"] = attrs.get("scale", 1)
+    new_attrs["scale_w"] = attrs.get("scale", 1)
+    return get_relay_op("upsampling")(*inputs, **new_attrs)
+
 
 def _darknet_l2normalize(inputs, params, attrs, prefix):
     """Process the l2 normalization operation."""
     new_attrs = {}
-    new_attrs['eps'] = attrs.get('eps', 0.0)
-    new_attrs['axis'] = [attrs.get('axis', 1)]
-    return get_relay_op('l2_normalize')(*inputs, **new_attrs)
+    new_attrs["eps"] = attrs.get("eps", 0.0)
+    new_attrs["axis"] = [attrs.get("axis", 1)]
+    return get_relay_op("l2_normalize")(*inputs, **new_attrs)
+
 
 def _darknet_softmax_output(inputs, params, attrs, prefix):
     """Process the softmax operation."""
-    temperature = attrs.get('temperature', 1)
+    temperature = attrs.get("temperature", 1)
     data = inputs[0]
     if temperature != 1:
         data = data / _expr.const(float(temperature))
 
-    if attrs.get('use_flatten', False) is True:
-        data = get_relay_op('batch_flatten')(data)
+    if attrs.get("use_flatten", False) is True:
+        data = get_relay_op("batch_flatten")(data)
 
     new_attrs = {}
-    if attrs.get('multi_output', False):
-        new_attrs['axis'] = 1
-    return get_relay_op('softmax')(data, **new_attrs)
+    if attrs.get("multi_output", False):
+        new_attrs["axis"] = 1
+    return get_relay_op("softmax")(data, **new_attrs)
+
 
 def _darknet_route(inputs, params, attrs, prefix):
     """Process the route operation, which is equivalent to concat."""
-    new_attrs = {'axis': attrs.get('dim', 1)}
-    return get_relay_op('concatenate')((inputs[0], inputs[1]), **new_attrs)
+    new_attrs = {"axis": attrs.get("dim", 1)}
+    return get_relay_op("concatenate")((inputs[0], inputs[1]), **new_attrs)
+
 
 def _darknet_reorg(inputs, params, attrs, prefix):
     """Process the reorg operation."""
     new_attrs = {}
-    if 'stride' in attrs:
-        new_attrs = {'stride': attrs.get('stride', 1)}
-    return get_relay_op('yolo_reorg')(*inputs, **new_attrs)
+    if "stride" in attrs:
+        new_attrs = {"stride": attrs.get("stride", 1)}
+    return get_relay_op("yolo_reorg")(*inputs, **new_attrs)
+
 
 def _darknet_region(inputs, params, attrs, prefix):
     """Process the region operation."""
-    num = attrs.get('n', 1)
-    classes = attrs.get('classes', 1)
-    coords = attrs.get('coords', 0)
-    background = attrs.get('background', 0)
-    softmax = attrs.get('softmax', True)
-    input_shape = attrs.get('shape')
+    num = attrs.get("n", 1)
+    classes = attrs.get("classes", 1)
+    coords = attrs.get("coords", 0)
+    background = attrs.get("background", 0)
+    softmax = attrs.get("softmax", True)
+    input_shape = attrs.get("shape")
 
     split_size = classes + coords + 1
     intermediate_shape = (input_shape[0], num, split_size, input_shape[2], input_shape[3])
-    data_block = get_relay_op('reshape')(inputs[0], newshape=intermediate_shape)
+    data_block = get_relay_op("reshape")(inputs[0], newshape=intermediate_shape)
     split_indices = (2, 4, 5)
-    split_res = get_relay_op('split')(data_block, indices_or_sections=split_indices, axis=2)
-    split_res0 = get_relay_op('sigmoid')(split_res[0])
-    split_res2 = split_res[2] if background else get_relay_op('sigmoid')(split_res[2])
-    split_res3 = get_relay_op('softmax')(split_res[3], axis=2) if softmax else split_res[3]
-    out = get_relay_op('concatenate')((split_res0, split_res[1], split_res2, split_res3), axis=2)
-    return get_relay_op('reshape')(out, newshape=input_shape)
+    split_res = get_relay_op("split")(data_block, indices_or_sections=split_indices, axis=2)
+    split_res0 = get_relay_op("sigmoid")(split_res[0])
+    split_res2 = split_res[2] if background else get_relay_op("sigmoid")(split_res[2])
+    split_res3 = get_relay_op("softmax")(split_res[3], axis=2) if softmax else split_res[3]
+    out = get_relay_op("concatenate")((split_res0, split_res[1], split_res2, split_res3), axis=2)
+    return get_relay_op("reshape")(out, newshape=input_shape)
+
 
 def _darknet_yolo(inputs, params, attrs, prefix):
     """Process the yolo operation."""
-    num = attrs.get('n', 1)
-    classes = attrs.get('classes', 1)
-    input_shape = attrs.get('shape')
+    num = attrs.get("n", 1)
+    classes = attrs.get("classes", 1)
+    input_shape = attrs.get("shape")
     split_size = classes + 5
     intermediate_shape = (input_shape[0], num, split_size, input_shape[2], input_shape[3])
-    data_block = get_relay_op('reshape')(inputs[0], newshape=intermediate_shape)
+    data_block = get_relay_op("reshape")(inputs[0], newshape=intermediate_shape)
     split_indices = (2, 4)
-    split_res = get_relay_op('split')(data_block, indices_or_sections=split_indices, axis=2)
-    split_res0 = get_relay_op('sigmoid')(split_res[0])
-    split_res2 = get_relay_op('sigmoid')(split_res[2])
-    out = get_relay_op('concatenate')((split_res0, split_res[1], split_res2), axis=2)
-    return get_relay_op('reshape')(out, newshape=input_shape)
+    split_res = get_relay_op("split")(data_block, indices_or_sections=split_indices, axis=2)
+    split_res0 = get_relay_op("sigmoid")(split_res[0])
+    split_res2 = get_relay_op("sigmoid")(split_res[2])
+    out = get_relay_op("concatenate")((split_res0, split_res[1], split_res2), axis=2)
+    return get_relay_op("reshape")(out, newshape=input_shape)
+
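
Purely as an illustration (a sketch with made-up sizes, not code from this commit), the region/yolo conversions above reshape an NCHW tensor into (batch, anchors, per-anchor channels, H, W), split along axis 2, transform selected chunks, and concatenate back before restoring the original layout. The same bookkeeping in plain NumPy:

import numpy as np

# Hypothetical sizes: batch=1, 3 anchors, 4 classes, a 13x13 grid (YOLO-style layer).
batch, num, classes, h, w = 1, 3, 4, 13, 13
x = np.zeros((batch, num * (classes + 5), h, w), dtype="float32")

# Same grouping as _darknet_yolo: break the flat channel axis out per anchor.
blk = x.reshape(batch, num, classes + 5, h, w)

# Splitting at indices (2, 4) along axis=2 yields chunks of size 2, 2 and 1 + classes
# (box xy, box wh, objectness + class scores).
xy, wh, conf_cls = np.split(blk, (2, 4), axis=2)
assert xy.shape[2] == 2 and wh.shape[2] == 2 and conf_cls.shape[2] == 1 + classes

# Concatenating the (transformed) chunks and reshaping restores the input layout,
# mirroring the final get_relay_op("reshape")(out, newshape=input_shape).
out = np.concatenate((xy, wh, conf_cls), axis=2).reshape(x.shape)
assert out.shape == x.shape
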
 
 class ACTIVATION(object):
     """Darknet ACTIVATION Class constant."""
+
     LOGISTIC = 0
     RELU = 1
     RELIE = 2
@@ -286,25 +305,26 @@ class ACTIVATION(object):
     HARDTAN = 11
     LHTAN = 12
 
+
 def _darknet_activations(inputs, params, attrs):
     """Process the activation function."""
-    act = attrs.get('activation')
+    act = attrs.get("activation")
     data = inputs[0] if isinstance(inputs, _expr.TupleWrapper) else inputs
 
     def _const(val):
         return _expr.const(val)
 
     def _relu(data):
-        return get_relay_op('relu')(data)
+        return get_relay_op("relu")(data)
 
     def _exp(data):
-        return get_relay_op('exp')(data)
+        return get_relay_op("exp")(data)
 
     def _tanh(data):
-        return get_relay_op('tanh')(data)
+        return get_relay_op("tanh")(data)
 
     def _sigmoid(data):
-        return get_relay_op('sigmoid')(data)
+        return get_relay_op("sigmoid")(data)
 
     def _elu(data):
         alpha = _const(-1.0)
@@ -312,8 +332,8 @@ def _darknet_activations(inputs, params, attrs):
 
     def _leaky_relu(data, slope):
         new_attrs = {}
-        new_attrs['alpha'] = slope
-        return get_relay_op('leaky_relu')(data, **new_attrs)
+        new_attrs["alpha"] = slope
+        return get_relay_op("leaky_relu")(data, **new_attrs)
 
     if ACTIVATION.LOGISTIC == act:
         data = _sigmoid(data)
@@ -324,15 +344,17 @@ def _darknet_activations(inputs, params, attrs):
     elif ACTIVATION.LINEAR == act:
         return data
     elif ACTIVATION.LEAKY == act:
-        data = _leaky_relu(data, attrs.get('slope', 0.1))
+        data = _leaky_relu(data, attrs.get("slope", 0.1))
     elif ACTIVATION.ELU == act:
         data = _elu(data)
     else:
-        _darknet_not_support('act: ' + attrs)
+        _darknet_not_support("act: " + attrs)
     return data
 
+
 class LAYERTYPE(Enum):
     """Darknet LAYERTYPE Class constant."""
+
     CONVOLUTIONAL = 0
     DECONVOLUTIONAL = 1
     CONNECTED = 2
@@ -363,37 +385,39 @@ class LAYERTYPE(Enum):
     L2NORM = 27
     BLANK = 28
 
+
 _DARKNET_CONVERT_MAP = {
-    LAYERTYPE.CONVOLUTIONAL   : _darknet_conv2d,
-    LAYERTYPE.CONNECTED       : _darknet_dense,
-    LAYERTYPE.MAXPOOL         : _darknet_maxpooling,
-    LAYERTYPE.SOFTMAX         : _darknet_softmax_output,
-    LAYERTYPE.DROPOUT         : _darknet_dropout,
-    LAYERTYPE.AVGPOOL         : _darknet_avgpooling,
-    LAYERTYPE.ROUTE           : _darknet_route,
-    LAYERTYPE.REORG           : _darknet_reorg,
-    LAYERTYPE.REGION          : _darknet_region,
-    LAYERTYPE.SHORTCUT        : _darknet_shortcut,
-    LAYERTYPE.UPSAMPLE        : _darknet_upsampling,
-    LAYERTYPE.L2NORM          : _darknet_l2normalize,
-    LAYERTYPE.YOLO            : _darknet_yolo,
-    LAYERTYPE.DECONVOLUTIONAL : _darknet_not_support,
-    LAYERTYPE.BATCHNORM       : _darknet_not_support,
-    LAYERTYPE.DETECTION       : _darknet_not_support,
-    LAYERTYPE.CROP            : _darknet_not_support,
-    LAYERTYPE.COST            : _darknet_not_support,
-    LAYERTYPE.NORMALIZATION   : _darknet_not_support,
-    LAYERTYPE.LOCAL           : _darknet_not_support,
-    LAYERTYPE.ACTIVE          : _darknet_not_support,
-    LAYERTYPE.RNN             : _darknet_not_support,
-    LAYERTYPE.GRU             : _darknet_not_support,
-    LAYERTYPE.LSTM            : _darknet_not_support,
-    LAYERTYPE.CRNN            : _darknet_not_support,
-    LAYERTYPE.NETWORK         : _darknet_not_support,
-    LAYERTYPE.XNOR            : _darknet_not_support,
-    LAYERTYPE.BLANK           : _darknet_not_support,
+    LAYERTYPE.CONVOLUTIONAL: _darknet_conv2d,
+    LAYERTYPE.CONNECTED: _darknet_dense,
+    LAYERTYPE.MAXPOOL: _darknet_maxpooling,
+    LAYERTYPE.SOFTMAX: _darknet_softmax_output,
+    LAYERTYPE.DROPOUT: _darknet_dropout,
+    LAYERTYPE.AVGPOOL: _darknet_avgpooling,
+    LAYERTYPE.ROUTE: _darknet_route,
+    LAYERTYPE.REORG: _darknet_reorg,
+    LAYERTYPE.REGION: _darknet_region,
+    LAYERTYPE.SHORTCUT: _darknet_shortcut,
+    LAYERTYPE.UPSAMPLE: _darknet_upsampling,
+    LAYERTYPE.L2NORM: _darknet_l2normalize,
+    LAYERTYPE.YOLO: _darknet_yolo,
+    LAYERTYPE.DECONVOLUTIONAL: _darknet_not_support,
+    LAYERTYPE.BATCHNORM: _darknet_not_support,
+    LAYERTYPE.DETECTION: _darknet_not_support,
+    LAYERTYPE.CROP: _darknet_not_support,
+    LAYERTYPE.COST: _darknet_not_support,
+    LAYERTYPE.NORMALIZATION: _darknet_not_support,
+    LAYERTYPE.LOCAL: _darknet_not_support,
+    LAYERTYPE.ACTIVE: _darknet_not_support,
+    LAYERTYPE.RNN: _darknet_not_support,
+    LAYERTYPE.GRU: _darknet_not_support,
+    LAYERTYPE.LSTM: _darknet_not_support,
+    LAYERTYPE.CRNN: _darknet_not_support,
+    LAYERTYPE.NETWORK: _darknet_not_support,
+    LAYERTYPE.XNOR: _darknet_not_support,
+    LAYERTYPE.BLANK: _darknet_not_support,
 }
 
+
 def _darknet_convert_symbol(op_name, inputs, params, attrs, params_prefix):
     """Convert from darknet op to relay op.
     Parameters
@@ -417,20 +441,21 @@ def _darknet_convert_symbol(op_name, inputs, params, attrs, params_prefix):
     if op_name in _DARKNET_CONVERT_MAP:
         sym = _DARKNET_CONVERT_MAP[op_name](inputs, params, attrs, params_prefix)
     else:
-        _darknet_not_support('Operator type ' + str(op_name))
+        _darknet_not_support("Operator type " + str(op_name))
     return sym
 
+
 def _as_list(arr):
     """Force being a list, ignore if already is."""
     if isinstance(arr, list):
         return arr
     return [arr]
 
+
 class GraphProto(object):
-    """A helper class for handling relay functions from darknet model.
-    """
+    """A helper class for handling relay functions from darknet model."""
 
-    def __init__(self, net, shape, dtype='float32'):
+    def __init__(self, net, shape, dtype="float32"):
         self._net = net
         self._shape = shape
         self._dtype = dtype
@@ -438,11 +463,11 @@ class GraphProto(object):
         self._tvmparams = {}
         self._outs = []
         self._state_ctr = {}
-        self._state_ctr['rnn'] = 0
-        self._state_ctr['crnn'] = 0
-        self._state_ctr['lstm'] = 0
-        self._state_ctr['cell_state'] = 0
-        self._state_ctr['gru'] = 0
+        self._state_ctr["rnn"] = 0
+        self._state_ctr["crnn"] = 0
+        self._state_ctr["lstm"] = 0
+        self._state_ctr["cell_state"] = 0
+        self._state_ctr["gru"] = 0
 
     def _read_memory_buffer(self, shape, data, dtype=None):
         if dtype is None:
@@ -467,17 +492,17 @@ class GraphProto(object):
         shape = (layer.n, layer.c // layer.groups, layer.size, layer.size)
         weights = self._read_memory_buffer(shape, layer.weights)
 
-        biases = self._read_memory_buffer((layer.n, ), layer.biases)
+        biases = self._read_memory_buffer((layer.n,), layer.biases)
 
-        k = _get_params_name(opname, 'weight')
+        k = _get_params_name(opname, "weight")
         params[k] = tvm.nd.array(weights)
 
         if layer.batch_normalize == 1 and layer.dontloadscales != 1:
             params.update(self._get_batchnorm_weights(layer, opname, layer.n))
-            k = _get_params_name(opname, 'beta')
+            k = _get_params_name(opname, "beta")
             params[k] = tvm.nd.array(biases)
         else:
-            k = _get_params_name(opname, 'bias')
+            k = _get_params_name(opname, "bias")
             params[k] = tvm.nd.array(biases)
         return params
 
@@ -488,63 +513,73 @@ class GraphProto(object):
             return None
 
         weights = self._read_memory_buffer((layer.outputs, layer.inputs), layer.weights)
-        biases = self._read_memory_buffer((layer.outputs, ), layer.biases)
+        biases = self._read_memory_buffer((layer.outputs,), layer.biases)
 
         params = {}
-        k = _get_params_name(opname, 'weight')
+        k = _get_params_name(opname, "weight")
         params[k] = tvm.nd.array(weights)
 
         if layer.batch_normalize == 1 and layer.dontloadscales != 1:
             params.update(self._get_batchnorm_weights(layer, opname, layer.outputs))
-            k = _get_params_name(opname, 'beta')
+            k = _get_params_name(opname, "beta")
             params[k] = tvm.nd.array(biases)
         else:
-            k = _get_params_name(opname, 'bias')
+            k = _get_params_name(opname, "bias")
             params[k] = tvm.nd.array(biases)
         return params
 
     def _get_region_weights(self, layer, opname):
         """Parse the biases for region layer."""
-        biases = self._read_memory_buffer((layer.n*2, ), layer.biases)
-        attributes = np.array([layer.n, layer.out_c, layer.out_h, layer.out_w,
-                               layer.classes, layer.coords, layer.background],
-                              dtype=np.int32)
+        biases = self._read_memory_buffer((layer.n * 2,), layer.biases)
+        attributes = np.array(
+            [
+                layer.n,
+                layer.out_c,
+                layer.out_h,
+                layer.out_w,
+                layer.classes,
+                layer.coords,
+                layer.background,
+            ],
+            dtype=np.int32,
+        )
         params = {}
-        k = _get_params_name(opname, 'bias')
+        k = _get_params_name(opname, "bias")
         params[k] = tvm.nd.array(biases)
-        k = _get_params_name(opname, 'attr')
+        k = _get_params_name(opname, "attr")
         params[k] = tvm.nd.array(attributes)
         return params
 
     def _get_yolo_weights(self, layer, opname):
         """Parse the biases and mask for yolo layer."""
-        biases = self._read_memory_buffer((layer.total*2, ), layer.biases)
-        mask = self._read_memory_buffer((layer.n, ), layer.mask, dtype='int32')
-        attributes = np.array([layer.n, layer.out_c, layer.out_h, layer.out_w,
-                               layer.classes, layer.total],
-                              dtype=np.int32)
+        biases = self._read_memory_buffer((layer.total * 2,), layer.biases)
+        mask = self._read_memory_buffer((layer.n,), layer.mask, dtype="int32")
+        attributes = np.array(
+            [layer.n, layer.out_c, layer.out_h, layer.out_w, layer.classes, layer.total],
+            dtype=np.int32,
+        )
         params = {}
-        k = _get_params_name(opname, 'bias')
+        k = _get_params_name(opname, "bias")
         params[k] = tvm.nd.array(biases)
-        k = _get_params_name(opname, 'mask')
+        k = _get_params_name(opname, "mask")
         params[k] = tvm.nd.array(mask)
-        k = _get_params_name(opname, 'attr')
+        k = _get_params_name(opname, "attr")
         params[k] = tvm.nd.array(attributes)
         return params
 
     def _get_batchnorm_weights(self, layer, opname, size):
         """Parse the weights for batchnorm, which includes, scales, moving mean
         and moving variances."""
-        scales = self._read_memory_buffer((size, ), layer.scales)
-        rolling_mean = self._read_memory_buffer((size, ), layer.rolling_mean)
-        rolling_variance = self._read_memory_buffer((size, ), layer.rolling_variance)
+        scales = self._read_memory_buffer((size,), layer.scales)
+        rolling_mean = self._read_memory_buffer((size,), layer.rolling_mean)
+        rolling_variance = self._read_memory_buffer((size,), layer.rolling_variance)
 
         params = {}
-        k = _get_params_name(opname, 'moving_mean')
+        k = _get_params_name(opname, "moving_mean")
         params[k] = tvm.nd.array(rolling_mean)
-        k = _get_params_name(opname, 'moving_var')
+        k = _get_params_name(opname, "moving_var")
         params[k] = tvm.nd.array(rolling_variance)
-        k = _get_params_name(opname, 'gamma')
+        k = _get_params_name(opname, "gamma")
         params[k] = tvm.nd.array(scales)
         return params
 
@@ -554,73 +589,75 @@ class GraphProto(object):
         use_flatten = True
         layer_type = LAYERTYPE(layer.type)
         if LAYERTYPE.CONVOLUTIONAL == layer_type:
-            attr.update({'pad' : layer.pad})
-            attr.update({'num_group' : layer.groups})
-            attr.update({'num_filter' : layer.n})
-            attr.update({'stride' : layer.stride})
-            attr.update({'kernel' : layer.size})
-            attr.update({'activation' : (layer.activation)})
+            attr.update({"pad": layer.pad})
+            attr.update({"num_group": layer.groups})
+            attr.update({"num_filter": layer.n})
+            attr.update({"stride": layer.stride})
+            attr.update({"kernel": layer.size})
+            attr.update({"activation": (layer.activation)})
 
             if layer.nbiases == 0:
-                attr.update({'use_bias' : False})
+                attr.update({"use_bias": False})
             else:
-                attr.update({'use_bias' : True})
+                attr.update({"use_bias": True})
 
             if layer.batch_normalize == 1 and layer.dontloadscales != 1:
-                attr.update({'use_batchNorm' : True})
-                attr.update({'use_scales' : True})
+                attr.update({"use_batchNorm": True})
+                attr.update({"use_scales": True})
 
         elif LAYERTYPE.CONNECTED == layer_type:
-            attr.update({'num_hidden' : layer.outputs})
-            attr.update({'activation' : (layer.activation)})
+            attr.update({"num_hidden": layer.outputs})
+            attr.update({"activation": (layer.activation)})
             if layer_num != 0:
                 layer_prev = self._net.layers[layer_num - 1]
-                if (layer_prev.out_h == layer.h and
-                        layer_prev.out_w == layer.w and
-                        layer_prev.out_c == layer.c):
+                if (
+                    layer_prev.out_h == layer.h
+                    and layer_prev.out_w == layer.w
+                    and layer_prev.out_c == layer.c
+                ):
                     use_flatten = False
-            attr.update({'use_flatten' : use_flatten})
-            attr.update({'use_bias' : True})
+            attr.update({"use_flatten": use_flatten})
+            attr.update({"use_bias": True})
             if layer.batch_normalize == 1 and layer.dontloadscales != 1:
-                attr.update({'use_batchNorm' : True})
-                attr.update({'use_scales' : True})
-                attr.update({'use_bias' : False})
+                attr.update({"use_batchNorm": True})
+                attr.update({"use_scales": True})
+                attr.update({"use_bias": False})
 
         elif LAYERTYPE.MAXPOOL == layer_type:
-            attr.update({'pad' : layer.pad})
-            attr.update({'stride' : layer.stride})
-            attr.update({'kernel' : layer.size})
-            max_output = (layer.w - layer.size + 2 * layer.pad)/float(layer.stride) + 1
+            attr.update({"pad": layer.pad})
+            attr.update({"stride": layer.stride})
+            attr.update({"kernel": layer.size})
+            max_output = (layer.w - layer.size + 2 * layer.pad) / float(layer.stride) + 1
             if max_output < layer.out_w:
-                extra_pad = (layer.out_w - max_output)*layer.stride
-                attr.update({'extra_pad_size' : int(extra_pad)})
+                extra_pad = (layer.out_w - max_output) * layer.stride
+                attr.update({"extra_pad_size": int(extra_pad)})
         elif LAYERTYPE.AVGPOOL == layer_type:
-            attr.update({'pad' : layer.pad})
+            attr.update({"pad": layer.pad})
             if layer.stride == 0:
-                attr.update({'stride' : 1})
+                attr.update({"stride": 1})
             else:
-                attr.update({'stride' : layer.stride})
+                attr.update({"stride": layer.stride})
             if layer.size == 0 and layer.h == layer.w:
-                attr.update({'kernel' : layer.h})
+                attr.update({"kernel": layer.h})
             else:
-                attr.update({'kernel' : layer.size})
+                attr.update({"kernel": layer.size})
 
         elif LAYERTYPE.DROPOUT == layer_type:
-            attr.update({'p' : layer.probability})
+            attr.update({"p": layer.probability})
 
         elif LAYERTYPE.SOFTMAX == layer_type:
-            attr.update({'axis' : 1})
-            attr.update({'use_flatten' : True})
+            attr.update({"axis": 1})
+            attr.update({"use_flatten": True})
             if layer.temperature:
-                attr.update({'temperature' : str(layer.temperature)})
+                attr.update({"temperature": str(layer.temperature)})
 
         elif LAYERTYPE.SHORTCUT == layer_type:
             add_layer = self._net.layers[layer.index]
-            attr.update({'activation' : layer.activation})
-            attr.update({'out_channel' : layer.out_c})
-            attr.update({'out_size' : layer.out_h})
-            attr.update({'add_out_channel' : add_layer.out_c})
-            attr.update({'add_out_size' : add_layer.out_h})
+            attr.update({"activation": layer.activation})
+            attr.update({"out_channel": layer.out_c})
+            attr.update({"out_size": layer.out_h})
+            attr.update({"add_out_channel": add_layer.out_c})
+            attr.update({"add_out_size": add_layer.out_h})
 
         elif LAYERTYPE.ROUTE == layer_type:
             pass
@@ -629,23 +666,23 @@ class GraphProto(object):
             pass
 
         elif LAYERTYPE.REORG == layer_type:
-            attr.update({'stride' : layer.stride})
+            attr.update({"stride": layer.stride})
 
         elif LAYERTYPE.REGION == layer_type:
-            attr.update({'n' : layer.n})
-            attr.update({'classes' : layer.classes})
-            attr.update({'coords' : layer.coords})
-            attr.update({'background' : layer.background})
-            attr.update({'softmax' : layer.softmax})
-            attr.update({'shape' : (-1, layer.c, layer.h, layer.w)})
+            attr.update({"n": layer.n})
+            attr.update({"classes": layer.classes})
+            attr.update({"coords": layer.coords})
+            attr.update({"background": layer.background})
+            attr.update({"softmax": layer.softmax})
+            attr.update({"shape": (-1, layer.c, layer.h, layer.w)})
 
         elif LAYERTYPE.YOLO == layer_type:
-            attr.update({'n' : layer.n})
-            attr.update({'classes' : layer.classes})
-            attr.update({'shape' : (-1, layer.c, layer.h, layer.w)})
+            attr.update({"n": layer.n})
+            attr.update({"classes": layer.classes})
+            attr.update({"shape": (-1, layer.c, layer.h, layer.w)})
 
         elif LAYERTYPE.UPSAMPLE == layer_type:
-            attr.update({'scale' : layer.stride})
+            attr.update({"scale": layer.stride})
 
         elif LAYERTYPE.L2NORM == layer_type:
             pass
@@ -673,7 +710,7 @@ class GraphProto(object):
     def _preproc_layer(self, layer, layer_num):
         """To preprocess each darknet layer, some layer doesnt need processing."""
         if layer_num == 0:
-            name = 'data'
+            name = "data"
             sym = new_var(name, shape=self._shape, dtype=self._dtype)
         else:
             sym = self._sym_array[layer_num - 1]
@@ -704,7 +741,7 @@ class GraphProto(object):
         """Returs the layer name."""
         return LAYERTYPE(layer.type)
 
-    def _new_rnn_state_var(self, state=None, name='rnn'):
+    def _new_rnn_state_var(self, state=None, name="rnn"):
         """Returs a symbol for state"""
         sym_name = name + "%d_state" % self._state_ctr[name]
         self._state_ctr[name] += 1
@@ -734,10 +771,10 @@ class GraphProto(object):
 
         layer_type = LAYERTYPE(layer.type)
         if LAYERTYPE.RNN == layer_type:
-            attr.update({'n' : layer.n})
-            attr.update({'batch' : layer.batch})
-            attr.update({'num_hidden' : str(layer.outputs)})
-            state = self._get_rnn_state_buffer(layer, 'rnn')
+            attr.update({"n": layer.n})
+            attr.update({"batch": layer.batch})
+            attr.update({"num_hidden": str(layer.outputs)})
+            state = self._get_rnn_state_buffer(layer, "rnn")
             for _ in range(layer.steps):
                 input_layer = layer.input_layer
                 prefix = "_input_" + str(layer_num)
@@ -761,40 +798,40 @@ class GraphProto(object):
     def _make_outlist(self, sym, op_name, layer, layer_num):
         layer_type = LAYERTYPE(layer.type)
         if layer_type == LAYERTYPE.REGION:
-            #Add attributes
-            k = _get_params_name(op_name, 'attr')
+            # Add attributes
+            k = _get_params_name(op_name, "attr")
             dshape = self._tvmparams[k].shape
             dtype = self._tvmparams[k].dtype
             self._outs.insert(0, new_var(k, shape=dshape, dtype=dtype))
 
-            #Add bias
-            k = _get_params_name(op_name, 'bias')
+            # Add bias
+            k = _get_params_name(op_name, "bias")
             dshape = self._tvmparams[k].shape
             dtype = self._tvmparams[k].dtype
             self._outs.insert(0, new_var(k, shape=dshape, dtype=dtype))
-            if layer_num != self._net.n-1:
+            if layer_num != self._net.n - 1:
                 self._outs.insert(0, sym)
 
         elif layer_type == LAYERTYPE.YOLO:
-            #Add attributes
-            k = _get_params_name(op_name, 'attr')
+            # Add attributes
+            k = _get_params_name(op_name, "attr")
             dshape = self._tvmparams[k].shape
             dtype = self._tvmparams[k].dtype
             self._outs.insert(0, new_var(k, shape=dshape, dtype=dtype))
 
-            #Add bias
-            k = _get_params_name(op_name, 'bias')
+            # Add bias
+            k = _get_params_name(op_name, "bias")
             dshape = self._tvmparams[k].shape
             dtype = self._tvmparams[k].dtype
             self._outs.insert(0, new_var(k, shape=dshape, dtype=dtype))
 
-            #Add mask
-            k = _get_params_name(op_name, 'mask')
+            # Add mask
+            k = _get_params_name(op_name, "mask")
             dshape = self._tvmparams[k].shape
             dtype = self._tvmparams[k].dtype
             self._outs.insert(0, new_var(k, shape=dshape, dtype=dtype))
 
-            if layer_num != self._net.n-1:
+            if layer_num != self._net.n - 1:
                 self._outs.insert(0, sym)
 
     def from_darknet(self):
@@ -825,9 +862,8 @@ class GraphProto(object):
         sym = _function.Function(analysis.free_vars(outputs), outputs)
         return IRModule.from_expr(sym), self._tvmparams
 
-def from_darknet(net,
-                 shape=None,
-                 dtype="float32"):
+
+def from_darknet(net, shape=None, dtype="float32"):
     """Convert from Darknet's model into compatible relay Function.
 
     Parameters
index d8bff8c..b5085d3 100644 (file)
@@ -28,12 +28,12 @@ from .. import op as _op
 from ... import nd as _nd
 from .common import ExprTable, new_var
 
-__all__ = ['from_keras']
+__all__ = ["from_keras"]
 
 
 def _check_data_format(keras_layer):
-    if hasattr(keras_layer, ('data_format')):
-        if keras_layer.data_format != 'channels_last':
+    if hasattr(keras_layer, ("data_format")):
+        if keras_layer.data_format != "channels_last":
             raise ValueError("Keras frontend currently supports data_format = channels_last only.")
 
 
@@ -47,8 +47,9 @@ def _get_pad_pair(input1d, kernel1d, stride1d):
 
 def _get_elu(inexpr, alpha):
     """A helper method for elu."""
-    return _op.negative(alpha) * _op.nn.relu(_expr.const(1., dtype='float32') - \
-        _op.exp(inexpr)) + _op.nn.relu(inexpr)
+    return _op.negative(alpha) * _op.nn.relu(
+        _expr.const(1.0, dtype="float32") - _op.exp(inexpr)
+    ) + _op.nn.relu(inexpr)
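
The relu/exp formulation above works because, for alpha > 0, relu(1 - exp(x)) equals 1 - exp(x) exactly when x < 0 and 0 otherwise, so -alpha * relu(1 - exp(x)) + relu(x) reproduces ELU (x for x >= 0, alpha * (exp(x) - 1) for x < 0). A small numerical check of that identity in NumPy (a sketch, not code from this commit):

import numpy as np

def elu_reference(x, alpha=1.0):
    # Textbook ELU: x for x >= 0, alpha * (exp(x) - 1) for x < 0.
    return np.where(x >= 0, x, alpha * (np.exp(x) - 1))

def elu_via_relu(x, alpha=1.0):
    # The relu/exp formulation used by _get_elu above.
    relu = lambda v: np.maximum(v, 0.0)
    return -alpha * relu(1.0 - np.exp(x)) + relu(x)

x = np.linspace(-3, 3, 101)
assert np.allclose(elu_reference(x), elu_via_relu(x))
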
 
 
 def _as_list(arr):
@@ -71,140 +72,155 @@ def _convert_activation(inexpr, keras_layer, etab):
             act_type = keras_layer.activation.func_name
         else:
             act_type = keras_layer.activation.__name__
-    if act_type == 'linear':
+    if act_type == "linear":
         if isinstance(keras_layer, str):
             return inexpr
-        alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') else 1.
-        beta = keras_layer.beta if hasattr(keras_layer, 'beta') else 0.
-        alpha = _expr.const(alpha, dtype='float32')
-        beta = _expr.const(beta, dtype='float32')
+        alpha = keras_layer.alpha if hasattr(keras_layer, "alpha") else 1.0
+        beta = keras_layer.beta if hasattr(keras_layer, "beta") else 0.0
+        alpha = _expr.const(alpha, dtype="float32")
+        beta = _expr.const(beta, dtype="float32")
         return _op.add(_op.multiply(inexpr, alpha), beta)
-    if act_type == 'softmax':
-        axis = 1 if etab.data_layout == 'NCHW' else -1
+    if act_type == "softmax":
+        axis = 1 if etab.data_layout == "NCHW" else -1
         return _op.nn.softmax(inexpr, axis)
-    if act_type == 'sigmoid':
+    if act_type == "sigmoid":
         return _op.sigmoid(inexpr)
-    if act_type == 'tanh':
+    if act_type == "tanh":
         return _op.tanh(inexpr)
-    if act_type == 'relu':
+    if act_type == "relu":
         return _op.nn.relu(inexpr)
-    if act_type == 'softplus':
-        return _op.log(_op.add(_op.exp(inexpr), _expr.const(1., dtype='float32')))
-    if act_type == 'elu':
-        alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') else 1.
-        alpha = _expr.const(alpha, dtype='float32')
+    if act_type == "softplus":
+        return _op.log(_op.add(_op.exp(inexpr), _expr.const(1.0, dtype="float32")))
+    if act_type == "elu":
+        alpha = keras_layer.alpha if hasattr(keras_layer, "alpha") else 1.0
+        alpha = _expr.const(alpha, dtype="float32")
         return _get_elu(inexpr, alpha)
-    if act_type == 'selu':
+    if act_type == "selu":
         # Alpha, Gamma values obtained from https://arxiv.org/abs/1706.02515
-        alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') \
+        alpha = (
+            keras_layer.alpha
+            if hasattr(keras_layer, "alpha")
             else 1.6732632423543772848170429916717
-        gamma = keras_layer.gamma if hasattr(keras_layer, 'gamma') \
+        )
+        gamma = (
+            keras_layer.gamma
+            if hasattr(keras_layer, "gamma")
             else 1.0507009873554804934193349852946
-        alpha = _expr.const(alpha, dtype='float32')
-        gamma = _expr.const(gamma, dtype='float32')
+        )
+        alpha = _expr.const(alpha, dtype="float32")
+        gamma = _expr.const(gamma, dtype="float32")
         return gamma * _get_elu(inexpr, alpha)
-    if act_type == 'relu6':
-        return _op.clip(inexpr, a_min=0., a_max=6.)
-    if act_type == 'softsign':
-        return inexpr / (_expr.const(1., dtype='float32') + _op.abs(inexpr))
-    if act_type == 'hard_sigmoid':
-        x = (_expr.const(0.2, dtype='float32') * inexpr) + _expr.const(0.5, dtype='float32')
-        return _op.clip(x, a_min=0., a_max=1.)
+    if act_type == "relu6":
+        return _op.clip(inexpr, a_min=0.0, a_max=6.0)
+    if act_type == "softsign":
+        return inexpr / (_expr.const(1.0, dtype="float32") + _op.abs(inexpr))
+    if act_type == "hard_sigmoid":
+        x = (_expr.const(0.2, dtype="float32") * inexpr) + _expr.const(0.5, dtype="float32")
+        return _op.clip(x, a_min=0.0, a_max=1.0)
 
     raise tvm.error.OpNotImplemented(
-        'Operator {} is not supported in frontend Keras.'.format(act_type))
+        "Operator {} is not supported in frontend Keras.".format(act_type)
+    )
 
 
 def _convert_advanced_activation(inexpr, keras_layer, etab):
     act_type = type(keras_layer).__name__
 
-    if act_type == 'Softmax':
+    if act_type == "Softmax":
         axis = keras_layer.axis
         dims = len(keras_layer.input_shape)
         if isinstance(axis, list):
             raise tvm.error.OpAttributeUnImplemented(
-                'Softmax with axes {} is not supported.'.format(axis))
-        if etab.data_layout == 'NCHW':
+                "Softmax with axes {} is not supported.".format(axis)
+            )
+        if etab.data_layout == "NCHW":
             if axis == -1:
                 axis = 1
             else:
                 axis = axis + 1 if axis < dims - 1 else 1
         return _op.nn.softmax(inexpr, axis=axis)
-    if act_type == 'ReLU':
-        threshold = _expr.const(keras_layer.threshold, dtype='float32')
+    if act_type == "ReLU":
+        threshold = _expr.const(keras_layer.threshold, dtype="float32")
         if keras_layer.max_value and float(keras_layer.threshold) == 0:
             # f(x) = max_value, for x >= max_value
             # f(x) = x,         for threshold <= x < max_value
-            return _op.clip(inexpr, a_min=0., a_max=float(keras_layer.max_value))
-        if keras_layer.max_value and _op.greater(threshold, inexpr).astype('float32'):
+            return _op.clip(inexpr, a_min=0.0, a_max=float(keras_layer.max_value))
+        if keras_layer.max_value and _op.greater(threshold, inexpr).astype("float32"):
             # f(x) = negative_slope * (inexpr - threshold)
-            negative_slope = _expr.const(keras_layer.negative_slope, dtype='float32')
+            negative_slope = _expr.const(keras_layer.negative_slope, dtype="float32")
             return _op.multiply(negative_slope, _op.subtract(inexpr, threshold))
         return _op.nn.relu(inexpr)
-    if act_type == 'LeakyReLU':
+    if act_type == "LeakyReLU":
         return _op.nn.leaky_relu(inexpr, alpha=float(keras_layer.alpha))
-    if act_type == 'ELU':
-        alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') else 1.
-        alpha = _expr.const(alpha, dtype='float32')
+    if act_type == "ELU":
+        alpha = keras_layer.alpha if hasattr(keras_layer, "alpha") else 1.0
+        alpha = _expr.const(alpha, dtype="float32")
         return _get_elu(inexpr, alpha)
-    if act_type == 'PReLU':
-        assert hasattr(keras_layer, 'alpha'), "alpha required for PReLU."
+    if act_type == "PReLU":
+        assert hasattr(keras_layer, "alpha"), "alpha required for PReLU."
         _check_data_format(keras_layer)
         size = len(keras_layer.alpha.shape)
-        if etab.data_layout == 'NCHW':
-            alpha = etab.new_const(keras_layer.get_weights()[0]
-                                   .transpose(np.roll(range(size), 1)))
+        if etab.data_layout == "NCHW":
+            alpha = etab.new_const(keras_layer.get_weights()[0].transpose(np.roll(range(size), 1)))
         else:
             alpha = etab.new_const(keras_layer.get_weights()[0])
         return _op.negative(alpha) * _op.nn.relu(_op.negative(inexpr)) + _op.nn.relu(inexpr)
-    if act_type == 'ThresholdedReLU':
-        theta = keras_layer.theta if hasattr(keras_layer, 'theta') else 1.
-        return _op.multiply(inexpr, _op.greater(inexpr, \
-            _expr.const(theta, dtype='float32')).astype('float32'))
+    if act_type == "ThresholdedReLU":
+        theta = keras_layer.theta if hasattr(keras_layer, "theta") else 1.0
+        return _op.multiply(
+            inexpr, _op.greater(inexpr, _expr.const(theta, dtype="float32")).astype("float32")
+        )
 
     raise tvm.error.OpNotImplemented(
-        'Operator {} is not supported in frontend Keras.'.format(act_type))
+        "Operator {} is not supported in frontend Keras.".format(act_type)
+    )
 
 
 def _convert_merge(inexpr, keras_layer, _):
     merge_type = type(keras_layer).__name__
     ret = inexpr[0]
-    if merge_type == 'Dot':
+    if merge_type == "Dot":
         axes = keras_layer.axes
         if isinstance(keras_layer.axes, int):
             axes = [keras_layer.axes, keras_layer.axes]
         if isinstance(axes, list):
             if len(axes) != 2:
                 raise tvm.error.OpAttributeUnImplemented(
-                    'Dot with axes {} is not supported.'.format(keras_layer.axes))
+                    "Dot with axes {} is not supported.".format(keras_layer.axes)
+                )
             for i, axis in enumerate(axes):
                 if axis not in [1, 2]:
                     raise tvm.error.OpAttributeUnImplemented(
-                        'Dot with axes {} is not supported.'.format(keras_layer.axes))
+                        "Dot with axes {} is not supported.".format(keras_layer.axes)
+                    )
                 if axes[i] == 2:
                     inexpr[i] = _op.transpose(inexpr[i], axes=[0, 2, 1])
         else:
             raise tvm.error.OpAttributeUnImplemented(
-                'Dot with axes {} is not supported.'.format(keras_layer.axes))
+                "Dot with axes {} is not supported.".format(keras_layer.axes)
+            )
         ret_dot = _op.nn.batch_matmul(inexpr[0], inexpr[1])
         ret = _op.transpose(ret_dot, axes=[0, 2, 1])
-    elif merge_type == 'Subtract':
+    elif merge_type == "Subtract":
         assert len(inexpr) == 2, "Subtract merge takes 2 inputs."
         ret = _op.subtract(ret, inexpr[1])
-    elif merge_type in ['Add', 'Multiply', 'Minimum', 'Maximum']:
-        op_map = {'Add': _op.add,
-                  'Multiply': _op.multiply,
-                  'Minimum': _op.minimum,
-                  'Maximum': _op.maximum}
+    elif merge_type in ["Add", "Multiply", "Minimum", "Maximum"]:
+        op_map = {
+            "Add": _op.add,
+            "Multiply": _op.multiply,
+            "Minimum": _op.minimum,
+            "Maximum": _op.maximum,
+        }
         for i in range(1, len(inexpr)):
             ret = op_map[merge_type](ret, inexpr[i])
-    elif merge_type == 'Average':
+    elif merge_type == "Average":
         for i in range(1, len(inexpr)):
             ret = _op.add(ret, inexpr[i])
-        ret = ret / _expr.const(len(inexpr), dtype='float32')
+        ret = ret / _expr.const(len(inexpr), dtype="float32")
     else:
         raise tvm.error.OpNotImplemented(
-            'Operator {} is not supported in frontend Keras.'.format(merge_type))
+            "Operator {} is not supported in frontend Keras.".format(merge_type)
+        )
     return ret
 
 
@@ -216,14 +232,15 @@ def _convert_embedding(inexpr, keras_layer, etab):
     indices = inexpr
     weightList = keras_layer.get_weights()
     weight = etab.new_const(weightList[0])
-    out = _op.take(weight, indices.astype('int32'), axis=0)
+    out = _op.take(weight, indices.astype("int32"), axis=0)
 
     return out
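
For context (an illustrative sketch with hypothetical sizes, not code from this commit), the take above is a plain embedding lookup: each integer index selects one row of the weight table along axis 0.

import numpy as np

vocab, dim = 10, 4                  # hypothetical embedding table size
weight = np.arange(vocab * dim, dtype="float32").reshape(vocab, dim)
indices = np.array([[1, 3, 3, 7]])  # a batch of token ids

# take(weight, indices, axis=0) gathers rows of `weight`,
# producing shape indices.shape + (dim,) == (1, 4, 4).
out = np.take(weight, indices, axis=0)
assert out.shape == (1, 4, dim)
assert np.array_equal(out[0, 0], weight[1])
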
 
+
 def _convert_dense(inexpr, keras_layer, etab):
     weightList = keras_layer.get_weights()
     weight = etab.new_const(weightList[0].transpose([1, 0]))
-    params = {'weight': weight, 'units': weightList[0].shape[1]}
+    params = {"weight": weight, "units": weightList[0].shape[1]}
     input_shape = keras_layer.input_shape
     input_dim = len(input_shape)
     # In case of RNN dense, input shape will be (1, 1, n)
@@ -231,7 +248,8 @@ def _convert_dense(inexpr, keras_layer, etab):
         input_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0])
         if input_dim != 3 or input_shape[0] != 1 or input_shape[1] != 1:
             raise tvm.error.OpAttributeInvalid(
-                'Input shape {} is not valid for operator Dense.'.format(input_shape))
+                "Input shape {} is not valid for operator Dense.".format(input_shape)
+            )
         inexpr = _op.squeeze(inexpr, axis=0)
     out = _op.nn.dense(data=inexpr, **params)
     if keras_layer.use_bias:
@@ -242,7 +260,7 @@ def _convert_dense(inexpr, keras_layer, etab):
         act_type = keras_layer.activation.func_name
     else:
         act_type = keras_layer.activation.__name__
-    if act_type != 'linear':
+    if act_type != "linear":
         out = _convert_activation(out, act_type, etab)
     if input_dim > 2:
         out = _op.expand_dims(out, axis=0)
@@ -251,27 +269,27 @@ def _convert_dense(inexpr, keras_layer, etab):
 
 def _convert_convolution(inexpr, keras_layer, etab):
     _check_data_format(keras_layer)
-    is_deconv = type(keras_layer).__name__ == 'Conv2DTranspose'
-    is_depthconv = type(keras_layer).__name__ == 'DepthwiseConv2D'
+    is_deconv = type(keras_layer).__name__ == "Conv2DTranspose"
+    is_depthconv = type(keras_layer).__name__ == "DepthwiseConv2D"
     weightList = keras_layer.get_weights()
     weight = weightList[0]
-    if etab.data_layout == 'NHWC':
+    if etab.data_layout == "NHWC":
         if is_depthconv:
-            kernel_layout = 'HWOI'
+            kernel_layout = "HWOI"
         else:
-            kernel_layout = 'HWIO'
+            kernel_layout = "HWIO"
     else:
-        kernel_layout = 'OIHW'
+        kernel_layout = "OIHW"
 
     if is_deconv:
         kernel_h, kernel_w, n_filters, in_channels = weight.shape
-        if kernel_layout == 'OIHW':
+        if kernel_layout == "OIHW":
             weight = weight.transpose([3, 2, 0, 1])
     elif is_depthconv:
         kernel_h, kernel_w, in_channels, depth_mult = weight.shape
-        if kernel_layout == 'OIHW':
+        if kernel_layout == "OIHW":
             weight = weight.transpose([2, 3, 0, 1])
-    elif etab.data_layout == 'NCHW':
+    elif etab.data_layout == "NCHW":
         kernel_h, kernel_w, in_channels, n_filters = weight.shape
         weight = weight.transpose([3, 2, 0, 1])
     else:
@@ -283,30 +301,31 @@ def _convert_convolution(inexpr, keras_layer, etab):
     dilated_kernel_h = (kernel_h - 1) * dilation[0] + 1
     dilated_kernel_w = (kernel_w - 1) * dilation[1] + 1
     stride_h, stride_w = keras_layer.strides
-    params = {'weight': etab.new_const(weight),
-              'kernel_size': [kernel_h, kernel_w],
-              'strides': [stride_h, stride_w],
-              'dilation': dilation,
-              'padding': [0, 0],
-              'data_layout': etab.data_layout,
-              'kernel_layout': kernel_layout}
+    params = {
+        "weight": etab.new_const(weight),
+        "kernel_size": [kernel_h, kernel_w],
+        "strides": [stride_h, stride_w],
+        "dilation": dilation,
+        "padding": [0, 0],
+        "data_layout": etab.data_layout,
+        "kernel_layout": kernel_layout,
+    }
     if is_depthconv:
-        params['channels'] = in_channels * depth_mult
-        params['groups'] = in_channels
+        params["channels"] = in_channels * depth_mult
+        params["groups"] = in_channels
     else:
-        params['channels'] = n_filters
-    if keras_layer.padding == 'valid':
+        params["channels"] = n_filters
+    if keras_layer.padding == "valid":
         pass
     # we insert a separate pad operator
-    elif keras_layer.padding == 'same':
+    elif keras_layer.padding == "same":
         in_h = keras_layer.input_shape[1]
         in_w = keras_layer.input_shape[2]
         pad_t, pad_b = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
         pad_l, pad_r = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
-        params['padding'] = (pad_t, pad_l, pad_b, pad_r)
+        params["padding"] = (pad_t, pad_l, pad_b, pad_r)
     else:
-        msg = 'Padding with {} is not supported for operator Convolution ' \
-              'in frontend Keras.'
+        msg = "Padding with {} is not supported for operator Convolution " "in frontend Keras."
         raise tvm.error.OpAttributeUnImplemented(msg.format(keras_layer.padding))
     if is_deconv:
         out = _op.nn.conv2d_transpose(data=inexpr, **params)
@@ -315,7 +334,7 @@ def _convert_convolution(inexpr, keras_layer, etab):
 
     if keras_layer.use_bias:
         bias = etab.new_const(weightList[1])
-        if etab.data_layout == 'NCHW':
+        if etab.data_layout == "NCHW":
             out = _op.nn.bias_add(out, bias)
         else:
             out = _op.nn.bias_add(out, bias, axis=-1)
@@ -324,28 +343,31 @@ def _convert_convolution(inexpr, keras_layer, etab):
         act_type = keras_layer.activation.func_name
     else:
         act_type = keras_layer.activation.__name__
-    if act_type != 'linear':
+    if act_type != "linear":
         out = _convert_activation(out, act_type, etab)
     return out
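
The "same" branch above delegates the arithmetic to _get_pad_pair, whose body falls outside this hunk. Assuming it follows the usual TensorFlow/Keras SAME rule (output length = ceil(input / stride), with any odd remainder padded on the bottom/right), the computation looks roughly like the sketch below; the helper name here is hypothetical, so treat this as an assumption rather than the file's actual implementation.

import math

def same_pad_pair(input1d, kernel1d, stride1d):
    # Assumed SAME-padding rule: pick the total pad that yields
    # ceil(input / stride) output positions, splitting low/high.
    out1d = math.ceil(input1d / stride1d)
    total = max((out1d - 1) * stride1d + kernel1d - input1d, 0)
    return total // 2, total - total // 2

# e.g. a 224x224 input, 3x3 kernel, stride 2, dilation 1 (so the dilated kernel
# is still 3) -> (0, 1) per spatial axis, i.e. params["padding"] == (0, 0, 1, 1).
assert same_pad_pair(224, 3, 2) == (0, 1)
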
 
+
 def _convert_convolution3d(inexpr, keras_layer, etab):
     _check_data_format(keras_layer)
     weightList = keras_layer.get_weights()
     weight = weightList[0]
 
-    if etab.data_layout == 'NDHWC':
-        kernel_layout = 'DHWIO'
+    if etab.data_layout == "NDHWC":
+        kernel_layout = "DHWIO"
     else:
-        kernel_layout = 'OIDHW'
-        msg = 'Kernel layout with {} is not supported for operator Convolution3D ' \
-              'in frontend Keras.'
+        kernel_layout = "OIDHW"
+        msg = (
+            "Kernel layout with {} is not supported for operator Convolution3D "
+            "in frontend Keras."
+        )
         raise tvm.error.OpAttributeUnImplemented(msg.format(etab.data_layout))
 
-    is_deconv = type(keras_layer).__name__ == 'Conv3DTranspose'
+    is_deconv = type(keras_layer).__name__ == "Conv3DTranspose"
 
     if is_deconv:
         kernel_d, kernel_h, kernel_w, n_filters, _ = weight.shape
-        if kernel_layout == 'OIDHW':
+        if kernel_layout == "OIDHW":
             weight = weight.transpose([4, 3, 2, 0, 1])
     else:
         kernel_d, kernel_h, kernel_w, _, n_filters = weight.shape
@@ -360,29 +382,30 @@ def _convert_convolution3d(inexpr, keras_layer, etab):
     dilated_kernel_h = (kernel_h - 1) * dilation[1] + 1
     dilated_kernel_w = (kernel_w - 1) * dilation[2] + 1
     stride_d, stride_h, stride_w = keras_layer.strides
-    params = {'weight': etab.new_const(weight),
-              'kernel_size': [kernel_d, kernel_h, kernel_w],
-              'strides': [stride_d, stride_h, stride_w],
-              'dilation': dilation,
-              'padding': [0, 0, 0],
-              'data_layout': etab.data_layout,
-              'kernel_layout': kernel_layout}
-    params['channels'] = n_filters
-
-    if keras_layer.padding == 'valid':
+    params = {
+        "weight": etab.new_const(weight),
+        "kernel_size": [kernel_d, kernel_h, kernel_w],
+        "strides": [stride_d, stride_h, stride_w],
+        "dilation": dilation,
+        "padding": [0, 0, 0],
+        "data_layout": etab.data_layout,
+        "kernel_layout": kernel_layout,
+    }
+    params["channels"] = n_filters
+
+    if keras_layer.padding == "valid":
         pass
     # calculate the padding values
-    elif keras_layer.padding == 'same':
+    elif keras_layer.padding == "same":
         in_d = keras_layer.input_shape[1]
         in_h = keras_layer.input_shape[2]
         in_w = keras_layer.input_shape[3]
         pad_d = _get_pad_pair(in_d, dilated_kernel_d, stride_d)
         pad_h = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
         pad_w = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
-        params['padding'] = [pad_d[0], pad_h[0], pad_w[0], pad_d[1], pad_h[1], pad_w[1]]
+        params["padding"] = [pad_d[0], pad_h[0], pad_w[0], pad_d[1], pad_h[1], pad_w[1]]
     else:
-        msg = 'Padding with {} is not supported for operator Convolution3D ' \
-              'in frontend Keras.'
+        msg = "Padding with {} is not supported for operator Convolution3D " "in frontend Keras."
         raise tvm.error.OpAttributeUnImplemented(msg.format(keras_layer.padding))
     if is_deconv:
         out = _op.nn.conv3d_transpose(data=inexpr, **params)
@@ -399,66 +422,73 @@ def _convert_convolution3d(inexpr, keras_layer, etab):
         act_type = keras_layer.activation.func_name
     else:
         act_type = keras_layer.activation.__name__
-    if act_type != 'linear':
+    if act_type != "linear":
         out = _convert_activation(out, act_type, etab)
 
     return out
 
+
 def _convert_separable_convolution(inexpr, keras_layer, etab):
     _check_data_format(keras_layer)
-    if etab.data_layout == 'NHWC':
-        kernel_layout = 'HWOI'
+    if etab.data_layout == "NHWC":
+        kernel_layout = "HWOI"
     else:
-        kernel_layout = 'OIHW'
+        kernel_layout = "OIHW"
     weightList = keras_layer.get_weights()
     # depthwise conv
     kernel_h, kernel_w, in_channels, depth_mult = weightList[0].shape
     stride_h, stride_w = keras_layer.strides
-    if kernel_layout == 'OIHW':
+    if kernel_layout == "OIHW":
         weight0 = weightList[0].transpose([2, 3, 0, 1])
     else:
         weight0 = weightList[0]
-    params0 = {'weight': etab.new_const(weight0),
-               'channels': in_channels * depth_mult,
-               'groups': in_channels,
-               'kernel_size': [kernel_h, kernel_w],
-               'strides': [stride_h, stride_w],
-               'dilation': [1, 1],
-               'padding': [0, 0],
-               'data_layout': etab.data_layout,
-               'kernel_layout': kernel_layout}
-    if keras_layer.padding == 'valid':
+    params0 = {
+        "weight": etab.new_const(weight0),
+        "channels": in_channels * depth_mult,
+        "groups": in_channels,
+        "kernel_size": [kernel_h, kernel_w],
+        "strides": [stride_h, stride_w],
+        "dilation": [1, 1],
+        "padding": [0, 0],
+        "data_layout": etab.data_layout,
+        "kernel_layout": kernel_layout,
+    }
+    if keras_layer.padding == "valid":
         pass
     # we insert a separate pad operator
-    elif keras_layer.padding == 'same':
+    elif keras_layer.padding == "same":
         in_h = keras_layer.input_shape[1]
         in_w = keras_layer.input_shape[2]
         pad_t, pad_b = _get_pad_pair(in_h, kernel_h, stride_h)
         pad_l, pad_r = _get_pad_pair(in_w, kernel_w, stride_w)
-        params0['padding'] = (pad_t, pad_l, pad_b, pad_r)
+        params0["padding"] = (pad_t, pad_l, pad_b, pad_r)
     else:
-        msg = 'Padding with {} is not supported for operator Separable ' \
-              'Convolution in frontend Keras.'
+        msg = (
+            "Padding with {} is not supported for operator Separable "
+            "Convolution in frontend Keras."
+        )
         raise tvm.error.OpAttributeUnImplemented(msg.format(keras_layer.padding))
     depthconv = _op.nn.conv2d(data=inexpr, **params0)
     # pointwise conv
-    if kernel_layout == 'OIHW':
+    if kernel_layout == "OIHW":
         weight1 = weightList[1].transpose([3, 2, 0, 1])
     else:
         weight1 = weightList[1]
         kernel_layout = "HWIO"
-    params1 = {'weight': etab.new_const(weight1),
-               'channels': weightList[1].shape[3],
-               'groups': 1,
-               'kernel_size': [1, 1],
-               'strides': [1, 1],
-               'dilation': [1, 1],
-               'data_layout': etab.data_layout,
-               'kernel_layout': kernel_layout}
+    params1 = {
+        "weight": etab.new_const(weight1),
+        "channels": weightList[1].shape[3],
+        "groups": 1,
+        "kernel_size": [1, 1],
+        "strides": [1, 1],
+        "dilation": [1, 1],
+        "data_layout": etab.data_layout,
+        "kernel_layout": kernel_layout,
+    }
     out = _op.nn.conv2d(data=depthconv, **params1)
     if keras_layer.use_bias:
         bias = etab.new_const(weightList[2])
-        if etab.data_layout == 'NCHW':
+        if etab.data_layout == "NCHW":
             out = _op.nn.bias_add(out, bias)
         else:
             out = _op.nn.bias_add(out, bias, axis=-1)
@@ -467,7 +497,7 @@ def _convert_separable_convolution(inexpr, keras_layer, etab):
         act_type = keras_layer.activation.func_name
     else:
         act_type = keras_layer.activation.__name__
-    if act_type != 'linear':
+    if act_type != "linear":
         out = _convert_activation(out, act_type, etab)
     return out
 
@@ -475,7 +505,7 @@ def _convert_separable_convolution(inexpr, keras_layer, etab):
 def _convert_flatten(inexpr, keras_layer, etab):
     _check_data_format(keras_layer)
     # NCHW -> NHWC so that dense can be correctly converted
-    if etab.data_layout == 'NCHW':
+    if etab.data_layout == "NCHW":
         inexpr = _op.transpose(inexpr, axes=[0, 2, 3, 1])
     return _op.nn.batch_flatten(inexpr)
 
@@ -484,72 +514,83 @@ def _convert_pooling(inexpr, keras_layer, etab):
     _check_data_format(keras_layer)
     pool_type = type(keras_layer).__name__
     # global pool in keras = global pool + flatten in relay
-    global_pool_params = {'layout': etab.data_layout}
-    if pool_type == 'GlobalMaxPooling2D':
+    global_pool_params = {"layout": etab.data_layout}
+    if pool_type == "GlobalMaxPooling2D":
         return _convert_flatten(
-            _op.nn.global_max_pool2d(inexpr, **global_pool_params), keras_layer, etab)
-    if pool_type == 'GlobalAveragePooling2D':
+            _op.nn.global_max_pool2d(inexpr, **global_pool_params), keras_layer, etab
+        )
+    if pool_type == "GlobalAveragePooling2D":
         return _convert_flatten(
-            _op.nn.global_avg_pool2d(inexpr, **global_pool_params), keras_layer, etab)
+            _op.nn.global_avg_pool2d(inexpr, **global_pool_params), keras_layer, etab
+        )
     pool_h, pool_w = keras_layer.pool_size
     stride_h, stride_w = keras_layer.strides
-    params = {'pool_size': [pool_h, pool_w],
-              'strides': [stride_h, stride_w],
-              'padding': [0, 0],
-              'layout': etab.data_layout}
-    if keras_layer.padding == 'valid':
+    params = {
+        "pool_size": [pool_h, pool_w],
+        "strides": [stride_h, stride_w],
+        "padding": [0, 0],
+        "layout": etab.data_layout,
+    }
+    if keras_layer.padding == "valid":
         pass
-    elif keras_layer.padding == 'same':
+    elif keras_layer.padding == "same":
         in_h = keras_layer.input_shape[1]
         in_w = keras_layer.input_shape[2]
         pad_t, pad_b = _get_pad_pair(in_h, pool_h, stride_h)
         pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w)
-        params['padding'] = [pad_t, pad_l, pad_b, pad_r]
+        params["padding"] = [pad_t, pad_l, pad_b, pad_r]
     else:
         raise tvm.error.OpAttributeUnImplemented(
-            'Padding with {} is not supported in operator Pooling.'.format(keras_layer.padding))
-    if pool_type == 'MaxPooling2D':
+            "Padding with {} is not supported in operator Pooling.".format(keras_layer.padding)
+        )
+    if pool_type == "MaxPooling2D":
         return _op.nn.max_pool2d(inexpr, **params)
-    if pool_type == 'AveragePooling2D':
-        params['count_include_pad'] = False
+    if pool_type == "AveragePooling2D":
+        params["count_include_pad"] = False
         return _op.nn.avg_pool2d(inexpr, **params)
     raise tvm.error.OpNotImplemented(
-        'Operator {} is not supported for frontend Keras.'.format(keras_layer))
+        "Operator {} is not supported for frontend Keras.".format(keras_layer)
+    )
+
 
 def _convert_pooling3d(inexpr, keras_layer, etab):
     _check_data_format(keras_layer)
     pool_type = type(keras_layer).__name__
 
-    if pool_type not in ['MaxPooling3D', 'AveragePooling3D']:
+    if pool_type not in ["MaxPooling3D", "AveragePooling3D"]:
         raise tvm.error.OpNotImplemented(
-            'Operator {} is not supported for frontend Keras.'.format(keras_layer))
+            "Operator {} is not supported for frontend Keras.".format(keras_layer)
+        )
 
     pool_d1, pool_d2, pool_d3 = keras_layer.pool_size
     stride_d1, stride_d2, stride_d3 = keras_layer.strides
-    params = {'pool_size': [pool_d1, pool_d2, pool_d3],
-              'strides': [stride_d1, stride_d2, stride_d3],
-              'padding': [0, 0, 0],
-              'layout': etab.data_layout}
-
-    if keras_layer.padding == 'valid':
+    params = {
+        "pool_size": [pool_d1, pool_d2, pool_d3],
+        "strides": [stride_d1, stride_d2, stride_d3],
+        "padding": [0, 0, 0],
+        "layout": etab.data_layout,
+    }
+
+    if keras_layer.padding == "valid":
         pass
-    elif keras_layer.padding == 'same':
+    elif keras_layer.padding == "same":
         in_d1 = keras_layer.input_shape[1]
         in_d2 = keras_layer.input_shape[2]
         in_d3 = keras_layer.input_shape[3]
         pad_d1 = _get_pad_pair(in_d1, pool_d1, stride_d1)
         pad_d2 = _get_pad_pair(in_d2, pool_d2, stride_d2)
         pad_d3 = _get_pad_pair(in_d3, pool_d3, stride_d3)
-        params['padding'] = [pad_d1[0], pad_d2[0], pad_d3[0], pad_d1[1], pad_d2[1], pad_d3[1]]
+        params["padding"] = [pad_d1[0], pad_d2[0], pad_d3[0], pad_d1[1], pad_d2[1], pad_d3[1]]
     else:
         raise tvm.error.OpAttributeUnImplemented(
-            'Padding with {} is not supported in operator Pooling3D.'.format(keras_layer.padding))
+            "Padding with {} is not supported in operator Pooling3D.".format(keras_layer.padding)
+        )
 
     out = _op.transpose(inexpr, axes=(0, 4, 1, 2, 3))
-    params['layout'] = "NCDHW"
-    if pool_type == 'MaxPooling3D':
+    params["layout"] = "NCDHW"
+    if pool_type == "MaxPooling3D":
         out = _op.nn.max_pool3d(out, **params)
-    elif pool_type == 'AveragePooling3D':
+    elif pool_type == "AveragePooling3D":
         out = _op.nn.avg_pool3d(out, **params)
 
     return _op.transpose(out, axes=(0, 2, 3, 4, 1))
@@ -559,14 +600,15 @@ def _convert_global_pooling3d(inexpr, keras_layer, etab):
     _check_data_format(keras_layer)
     pool_type = type(keras_layer).__name__
 
-    global_pool_params = {'layout': etab.data_layout}
-    if pool_type == 'GlobalMaxPooling3D':
+    global_pool_params = {"layout": etab.data_layout}
+    if pool_type == "GlobalMaxPooling3D":
         out = _op.nn.global_max_pool3d(inexpr, **global_pool_params)
-    elif pool_type == 'GlobalAveragePooling3D':
+    elif pool_type == "GlobalAveragePooling3D":
         out = _op.nn.global_avg_pool3d(inexpr, **global_pool_params)
     else:
         raise tvm.error.OpNotImplemented(
-            'Operator {} is not supported for frontend Keras.'.format(keras_layer))
+            "Operator {} is not supported for frontend Keras.".format(keras_layer)
+        )
 
     return _convert_flatten(out, keras_layer, etab)
 
@@ -575,27 +617,27 @@ def _convert_upsample(inexpr, keras_layer, etab):
     _check_data_format(keras_layer)
     upsample_type = type(keras_layer).__name__
     params = {}
-    if upsample_type == 'UpSampling1D':
+    if upsample_type == "UpSampling1D":
         h = keras_layer.size
-        params['scale_h'] = h
-    elif upsample_type == 'UpSampling2D':
+        params["scale_h"] = h
+    elif upsample_type == "UpSampling2D":
         h, w = keras_layer.size
         if h != w:
-            raise tvm.error.OpAttributeInvalid(
-                'Height must equal width for operator Upsample.')
-        params['scale_h'] = h
-        params['scale_w'] = h
+            raise tvm.error.OpAttributeInvalid("Height must equal width for operator Upsample.")
+        params["scale_h"] = h
+        params["scale_w"] = h
 
-        if hasattr(keras_layer, 'interpolation'):
+        if hasattr(keras_layer, "interpolation"):
             interpolation = keras_layer.interpolation
-            if interpolation == 'nearest':
-                params['method'] = 'nearest_neighbor'
+            if interpolation == "nearest":
+                params["method"] = "nearest_neighbor"
             else:
-                params['method'] = 'bilinear'
+                params["method"] = "bilinear"
     else:
         raise tvm.error.OpNotImplemented(
-            'Operator {} is not supported for frontend Keras.'.format(upsample_type))
-    params['layout'] = etab.data_layout
+            "Operator {} is not supported for frontend Keras.".format(upsample_type)
+        )
+    params["layout"] = etab.data_layout
     out = _op.nn.upsampling(inexpr, **params)
     return out
 
@@ -604,10 +646,10 @@ def _convert_upsample3d(inexpr, keras_layer, etab):
     _check_data_format(keras_layer)
     params = {}
     d, h, w = keras_layer.size
-    params['scale_d'] = d
-    params['scale_h'] = h
-    params['scale_w'] = w
-    params['layout'] = etab.data_layout
+    params["scale_d"] = d
+    params["scale_h"] = h
+    params["scale_w"] = w
+    params["layout"] = etab.data_layout
     out = _op.nn.upsampling3d(inexpr, **params)
     return out
 
@@ -615,47 +657,50 @@ def _convert_upsample3d(inexpr, keras_layer, etab):
 def _convert_cropping(inexpr, keras_layer, _):
     _check_data_format(keras_layer)
     crop_type = type(keras_layer).__name__
-    if crop_type == 'Cropping2D':
+    if crop_type == "Cropping2D":
         (_, in_h, in_w, _) = keras_layer.input_shape
         ((crop_t, crop_b), (crop_l, crop_r)) = keras_layer.cropping
     else:
         raise tvm.error.OpNotImplemented(
-            'Operator {} is not supported for frontend Keras.'.format(crop_type))
+            "Operator {} is not supported for frontend Keras.".format(crop_type)
+        )
     int32_max = np.iinfo(np.int32).max
-    return _op.strided_slice(inexpr, begin=[0, 0, crop_t, crop_l], \
-        end=[int32_max, int32_max, in_h-crop_b, in_w-crop_r])
+    return _op.strided_slice(
+        inexpr,
+        begin=[0, 0, crop_t, crop_l],
+        end=[int32_max, int32_max, in_h - crop_b, in_w - crop_r],
+    )
 
 
 def _convert_batchnorm(inexpr, keras_layer, etab):
-    if etab.data_layout == 'NCHW' or len(keras_layer.input_shape) < 4:
+    if etab.data_layout == "NCHW" or len(keras_layer.input_shape) < 4:
         axis = 1
     else:
         axis = 3
 
-    params = {'scale': False,
-              'center': False,
-              'epsilon': keras_layer.epsilon,
-              'axis': axis}
+    params = {"scale": False, "center": False, "epsilon": keras_layer.epsilon, "axis": axis}
     idx = 0
     if keras_layer.scale:
-        params['scale'] = True
+        params["scale"] = True
         gamma = keras_layer.get_weights()[idx]
-        params['gamma'] = etab.new_const(gamma)
+        params["gamma"] = etab.new_const(gamma)
         idx += 1
     if keras_layer.center:
-        params['center'] = True
+        params["center"] = True
         beta = keras_layer.get_weights()[idx]
-        params['beta'] = etab.new_const(beta)
+        params["beta"] = etab.new_const(beta)
         idx += 1
     moving_mean = keras_layer.get_weights()[idx]
     moving_var = keras_layer.get_weights()[idx + 1]
-    params['moving_mean'] = etab.new_const(moving_mean)
-    params['moving_var'] = etab.new_const(moving_var)
+    params["moving_mean"] = etab.new_const(moving_mean)
+    params["moving_var"] = etab.new_const(moving_var)
     # in case beta or gamma is not defined
-    params['beta'] = etab.new_const(np.zeros(moving_mean.shape)) if \
-                     'beta' not in params else params['beta']
-    params['gamma'] = etab.new_const(np.ones(moving_mean.shape)) if \
-                      'gamma' not in params else params['gamma']
+    params["beta"] = (
+        etab.new_const(np.zeros(moving_mean.shape)) if "beta" not in params else params["beta"]
+    )
+    params["gamma"] = (
+        etab.new_const(np.ones(moving_mean.shape)) if "gamma" not in params else params["gamma"]
+    )
     result, moving_mean, moving_var = _op.nn.batch_norm(inexpr, **params)
     return result
 
@@ -665,7 +710,7 @@ def _convert_padding(inexpr, keras_layer, etab):
     padding_type = type(keras_layer).__name__
     padding = keras_layer.padding
     top = left = bottom = right = 0
-    if padding_type == 'ZeroPadding2D':
+    if padding_type == "ZeroPadding2D":
         if isinstance(padding, int):
             top = left = bottom = right = padding
         elif isinstance(padding, tuple):
@@ -676,20 +721,19 @@ def _convert_padding(inexpr, keras_layer, etab):
                 top, bottom = padding[0]
                 left, right = padding[1]
             else:
-                msg = 'Value {} in attribute "padding" of operator Padding ' \
-                      'is not valid.'
+                msg = 'Value {} in attribute "padding" of operator Padding ' "is not valid."
                 raise tvm.error.OpAttributeInvalid(msg.format(str(padding)))
         else:
-            msg = 'Value {} in attribute "padding" of operator Padding is ' \
-                  'not valid.'
+            msg = 'Value {} in attribute "padding" of operator Padding is ' "not valid."
             raise tvm.error.OpAttributeInvalid(msg.format(str(padding)))
     else:
-        msg = 'Operator {} is not supported in frontend Keras.'
+        msg = "Operator {} is not supported in frontend Keras."
         raise tvm.error.OpNotImplemented(msg.format(padding_type))
-    if etab.data_layout == 'NCHW':
+    if etab.data_layout == "NCHW":
         return _op.nn.pad(data=inexpr, pad_width=((0, 0), (0, 0), (top, bottom), (left, right)))
     return _op.nn.pad(data=inexpr, pad_width=((0, 0), (top, bottom), (left, right), (0, 0)))
 
+
 def _convert_padding3d(inexpr, keras_layer, etab):
     _check_data_format(keras_layer)
     padding = keras_layer.padding
@@ -704,26 +748,37 @@ def _convert_padding3d(inexpr, keras_layer, etab):
         h_pad = padding[1]
         w_pad = padding[2]
     else:
-        msg = 'Value {} in attribute "padding" of operator ZeroPadding3D is ' \
-              'not valid.'
+        msg = 'Value {} in attribute "padding" of operator ZeroPadding3D is ' "not valid."
         raise tvm.error.OpAttributeInvalid(msg.format(str(padding)))
 
-    if etab.data_layout == 'NCDHW':
-        out = _op.nn.pad(data=inexpr, pad_width=((0, 0), (0, 0),
-                                                 (d_pad[0], d_pad[1]),
-                                                 (h_pad[0], h_pad[1]),
-                                                 (w_pad[0], w_pad[1])))
+    if etab.data_layout == "NCDHW":
+        out = _op.nn.pad(
+            data=inexpr,
+            pad_width=(
+                (0, 0),
+                (0, 0),
+                (d_pad[0], d_pad[1]),
+                (h_pad[0], h_pad[1]),
+                (w_pad[0], w_pad[1]),
+            ),
+        )
     else:
-        out = _op.nn.pad(data=inexpr, pad_width=((0, 0),
-                                                 (d_pad[0], d_pad[1]),
-                                                 (h_pad[0], h_pad[1]),
-                                                 (w_pad[0], w_pad[1]),
-                                                 (0, 0)))
+        out = _op.nn.pad(
+            data=inexpr,
+            pad_width=(
+                (0, 0),
+                (d_pad[0], d_pad[1]),
+                (h_pad[0], h_pad[1]),
+                (w_pad[0], w_pad[1]),
+                (0, 0),
+            ),
+        )
     return out
 
+
 def _convert_concat(inexpr, keras_layer, etab):
     _check_data_format(keras_layer)
-    if etab.data_layout == 'NHWC' or len(keras_layer.input_shape[0]) < 4:
+    if etab.data_layout == "NHWC" or len(keras_layer.input_shape[0]) < 4:
         axis = -1
     else:
         axis = 1
@@ -732,26 +787,28 @@ def _convert_concat(inexpr, keras_layer, etab):
 
 def _convert_reshape(inexpr, keras_layer, etab):
     _check_data_format(keras_layer)
-    inshape = keras_layer.input_shape # includes batch
-    tshape = keras_layer.target_shape # no batch
+    inshape = keras_layer.input_shape  # includes batch
+    tshape = keras_layer.target_shape  # no batch
     if len(inshape) == 3 and len(tshape) == 1:
         # (?, a, b) -> (-1, ab)
         shape = (-1, tshape[0])
     elif len(inshape) in [2, 3] and len(tshape) == 2:
         # (?, cc) -> (-1, c, c)
         # (?, a, b) -> (-1, c, c)
-        assert tshape[0] == tshape[1], \
-            "Only supports square target shapes, but got {}".format(tshape)
-        shape = (-1, ) + tshape
+        assert tshape[0] == tshape[1], "Only supports square target shapes, but got {}".format(
+            tshape
+        )
+        shape = (-1,) + tshape
     else:
         # (?, h, w, c) -> (-1, c, H, W)
         # (?, h, w, c) -> (-1, c, hw)
         # (?, hw, c) -> (-1, c, h, w)
         ch = inshape[-1]
-        assert ch == tshape[-1], \
-            "Only supports last dimension in target shape being equal to " \
+        assert ch == tshape[-1], (
+            "Only supports last dimension in target shape being equal to "
             "the channel number of input tensor."
-        if etab.data_layout == 'NCHW':
+        )
+        if etab.data_layout == "NCHW":
             shape = (-1, ch) + tshape[:-1]
         else:
             shape = (-1,) + tshape[:-1] + (ch,)
@@ -761,7 +818,7 @@ def _convert_reshape(inexpr, keras_layer, etab):
 def _convert_lstm(inexpr, keras_layer, etab):
     _check_data_format(keras_layer)
     if not isinstance(inexpr, list):
-        buf = np.zeros((1, keras_layer.units), 'float32')
+        buf = np.zeros((1, keras_layer.units), "float32")
         c_op = etab.new_const(buf)
         h_op = etab.new_const(buf)
         inexpr = [inexpr, h_op, c_op]
@@ -796,7 +853,7 @@ def _convert_lstm(inexpr, keras_layer, etab):
 def _convert_simple_rnn(inexpr, keras_layer, etab):
     _check_data_format(keras_layer)
     if not isinstance(inexpr, list):
-        buf = np.zeros((1, keras_layer.units), 'float32')
+        buf = np.zeros((1, keras_layer.units), "float32")
         prev_op = etab.new_const(buf)
         inexpr = [inexpr, prev_op]
     in_data = inexpr[0]
@@ -820,7 +877,7 @@ def _convert_simple_rnn(inexpr, keras_layer, etab):
 def _convert_gru(inexpr, keras_layer, etab):
     _check_data_format(keras_layer)
     if not isinstance(inexpr, list):
-        buf = np.zeros((1, keras_layer.units), 'float32')
+        buf = np.zeros((1, keras_layer.units), "float32")
         h_tm1 = etab.new_const(buf)
         inexpr = [inexpr, h_tm1]
     in_data = inexpr[0]
@@ -854,7 +911,7 @@ def _convert_gru(inexpr, keras_layer, etab):
     recurrent_h = _op.nn.dense(rec_act_r * h_tm1_op, rec_weights[1], units=units)
     act_hh = _convert_activation(x_h + recurrent_h, keras_layer, None)
     # previous and candidate state mixed by update gate
-    output = rec_act_z * h_tm1_op + (_expr.const(1., dtype='float32') - rec_act_z) * act_hh
+    output = rec_act_z * h_tm1_op + (_expr.const(1.0, dtype="float32") - rec_act_z) * act_hh
     out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
     output = _op.reshape(output, newshape=out_shape)
     return [output, output]
@@ -870,45 +927,40 @@ def _convert_repeat_vector(inexpr, keras_layer, _):
     return out
 
 
-def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument
+def _default_skip(inexpr, keras_layer, _):  # pylint: disable=unused-argument
     """Layers that can be skipped because they are train time only."""
     return inexpr
 
 
 _convert_map = {
-    'Dense'                    : _convert_dense,
-    'Activation'               : _convert_activation,
-    'Softmax'                  : _convert_advanced_activation,
-    'ReLU'                     : _convert_advanced_activation,
-    'LeakyReLU'                : _convert_advanced_activation,
-    'PReLU'                    : _convert_advanced_activation,
-    'ELU'                      : _convert_advanced_activation,
-    'ThresholdedReLU'          : _convert_advanced_activation,
-
-    'AveragePooling2D'         : _convert_pooling,
-    'MaxPooling2D'             : _convert_pooling,
-    'GlobalAveragePooling2D'   : _convert_pooling,
-    'GlobalMaxPooling2D'       : _convert_pooling,
-    'Conv2D'                   : _convert_convolution,
-    'Conv2DTranspose'          : _convert_convolution,
-    'DepthwiseConv2D'          : _convert_convolution,
-    'SeparableConv2D'          : _convert_separable_convolution,
-
-    'Flatten'                  : _convert_flatten,
-    'Reshape'                  : _convert_reshape,
-    'Concatenate'              : _convert_concat,
-    'BatchNormalization'       : _convert_batchnorm,
-
+    "Dense": _convert_dense,
+    "Activation": _convert_activation,
+    "Softmax": _convert_advanced_activation,
+    "ReLU": _convert_advanced_activation,
+    "LeakyReLU": _convert_advanced_activation,
+    "PReLU": _convert_advanced_activation,
+    "ELU": _convert_advanced_activation,
+    "ThresholdedReLU": _convert_advanced_activation,
+    "AveragePooling2D": _convert_pooling,
+    "MaxPooling2D": _convert_pooling,
+    "GlobalAveragePooling2D": _convert_pooling,
+    "GlobalMaxPooling2D": _convert_pooling,
+    "Conv2D": _convert_convolution,
+    "Conv2DTranspose": _convert_convolution,
+    "DepthwiseConv2D": _convert_convolution,
+    "SeparableConv2D": _convert_separable_convolution,
+    "Flatten": _convert_flatten,
+    "Reshape": _convert_reshape,
+    "Concatenate": _convert_concat,
+    "BatchNormalization": _convert_batchnorm,
     # Specific tf.Keras terminology for batch normalization
-    'BatchNormalizationV1'     : _convert_batchnorm,
-
-    'Add'                      : _convert_merge,
-    'Subtract'                 : _convert_merge,
-    'Multiply'                 : _convert_merge,
-    'ZeroPadding2D'            : _convert_padding,
-    'UpSampling2D'             : _convert_upsample,
-    'Cropping2D'               : _convert_cropping,
-
+    "BatchNormalizationV1": _convert_batchnorm,
+    "Add": _convert_merge,
+    "Subtract": _convert_merge,
+    "Multiply": _convert_merge,
+    "ZeroPadding2D": _convert_padding,
+    "UpSampling2D": _convert_upsample,
+    "Cropping2D": _convert_cropping,
     # 'ZeroPadding1D'          : _convert_padding,
     # 'AveragePooling1D'       : _convert_pooling,
     # 'MaxPooling1D'           : _convert_pooling,
@@ -917,38 +969,34 @@ _convert_map = {
     # 'Cropping1D'             : _convert_cropping,
     # 'UpSampling1D'           : _convert_upsample,
     # 'Conv1D'                 : _convert_convolution1d,
-
-    'Conv3D'                   : _convert_convolution3d,
-    'Conv3DTranspose'          : _convert_convolution3d,
+    "Conv3D": _convert_convolution3d,
+    "Conv3DTranspose": _convert_convolution3d,
     # 'SeparableConv3D'        : _convert_convolution3d,
-    'MaxPooling3D'             : _convert_pooling3d,
-    'AveragePooling3D'         : _convert_pooling3d,
-    'GlobalMaxPooling3D'       : _convert_global_pooling3d,
-    'GlobalAveragePooling3D'   : _convert_global_pooling3d,
-    'UpSampling3D'             : _convert_upsample3d,
-    'ZeroPadding3D'            : _convert_padding3d,
-
-    'SimpleRNN'                : _convert_simple_rnn,
-    'LSTM'                     : _convert_lstm,
-    'GRU'                      : _convert_gru,
+    "MaxPooling3D": _convert_pooling3d,
+    "AveragePooling3D": _convert_pooling3d,
+    "GlobalMaxPooling3D": _convert_global_pooling3d,
+    "GlobalAveragePooling3D": _convert_global_pooling3d,
+    "UpSampling3D": _convert_upsample3d,
+    "ZeroPadding3D": _convert_padding3d,
+    "SimpleRNN": _convert_simple_rnn,
+    "LSTM": _convert_lstm,
+    "GRU": _convert_gru,
     # 'Bidirectional'          : _convert_bidirectional,
     # 'TimeDistributed'        : _default_skip,
-
-    'Average'                  : _convert_merge,
-    'Minimum'                  : _convert_merge,
-    'Maximum'                  : _convert_merge,
-    'Dot'                      : _convert_merge,
-    'Permute'                  : _convert_permute,
-    'Embedding'                : _convert_embedding,
-    'RepeatVector'             : _convert_repeat_vector,
-
-    'InputLayer'               : _default_skip,
-    'Dropout'                  : _default_skip,
-    'AlphaDropout'             : _default_skip,
-    'SpatialDropout2D'         : _default_skip,
-    'SpatialDropout1D'         : _default_skip,
-    'GaussianDropout'          : _default_skip,
-    'GaussianNoise'            : _default_skip,
+    "Average": _convert_merge,
+    "Minimum": _convert_merge,
+    "Maximum": _convert_merge,
+    "Dot": _convert_merge,
+    "Permute": _convert_permute,
+    "Embedding": _convert_embedding,
+    "RepeatVector": _convert_repeat_vector,
+    "InputLayer": _default_skip,
+    "Dropout": _default_skip,
+    "AlphaDropout": _default_skip,
+    "SpatialDropout2D": _default_skip,
+    "SpatialDropout1D": _default_skip,
+    "GaussianDropout": _default_skip,
+    "GaussianNoise": _default_skip,
 }
 
 
@@ -960,8 +1008,9 @@ def _check_unsupported_layers(model):
             missing_ops.add(op_name)
 
     if missing_ops:
-        raise NotImplementedError( \
-            "The following operators are not implemented: {}".format(missing_ops))
+        raise NotImplementedError(
+            "The following operators are not implemented: {}".format(missing_ops)
+        )
 
 
 def keras_op_to_relay(inexpr, keras_layer, outname, etab):
@@ -984,7 +1033,8 @@ def keras_op_to_relay(inexpr, keras_layer, outname, etab):
     op_name = type(keras_layer).__name__
     if op_name not in _convert_map:
         raise tvm.error.OpNotImplemented(
-            'Operator {} is not supported for frontend Keras.'.format(op_name))
+            "Operator {} is not supported for frontend Keras.".format(op_name)
+        )
     outs = _convert_map[op_name](inexpr, keras_layer, etab)
     outs = _as_list(outs)
     for t_idx, out in enumerate(outs):
@@ -992,7 +1042,7 @@ def keras_op_to_relay(inexpr, keras_layer, outname, etab):
         etab.set_expr(name, out)
 
 
-def from_keras(model, shape=None, layout='NCHW'):
+def from_keras(model, shape=None, layout="NCHW"):
     """Convert keras model to relay Function.
 
     Parameters
@@ -1016,6 +1066,7 @@ def from_keras(model, shape=None, layout='NCHW'):
     params : dict of str to tvm.nd.NDArray
         The parameter dict to be used by Relay.
     """
+
     def _check_model_is_tf_keras():
         return type(model).__module__.startswith("tensorflow.python.keras")
 
@@ -1032,9 +1083,9 @@ def from_keras(model, shape=None, layout='NCHW'):
             import keras
         except ImportError:
             raise ImportError("Keras must be installed")
-        if keras.backend.backend() != 'tensorflow':
+        if keras.backend.backend() != "tensorflow":
             raise ValueError("Keras frontend currently supports tensorflow backend only.")
-        if keras.backend.image_data_format() != 'channels_last':
+        if keras.backend.image_data_format() != "channels_last":
             raise ValueError("Keras frontend currently supports data_format = channels_last only.")
         expected_model_class = keras.engine.training.Model
         input_layer_class = keras.engine.InputLayer
@@ -1051,26 +1102,33 @@ def from_keras(model, shape=None, layout='NCHW'):
 
     etab = ExprTable()
     # Set global data format.
-    assert layout in ['NCHW', 'NHWC', 'NDHWC'], "Layout must be one of 'NCHW', NHWC or NDHWC"
+    assert layout in ["NCHW", "NHWC", "NDHWC"], "Layout must be one of 'NCHW', NHWC or NDHWC"
     etab.data_layout = layout
     for keras_layer in model.layers:
         if isinstance(keras_layer, input_layer_class):
             _convert_input_layer(keras_layer)
         else:
-            inbound_nodes = keras_layer.inbound_nodes if hasattr(keras_layer, 'inbound_nodes') \
-                       else keras_layer._inbound_nodes if hasattr(keras_layer, '_inbound_nodes') \
-                       else None
+            inbound_nodes = (
+                keras_layer.inbound_nodes
+                if hasattr(keras_layer, "inbound_nodes")
+                else keras_layer._inbound_nodes
+                if hasattr(keras_layer, "_inbound_nodes")
+                else None
+            )
             if inbound_nodes is None:
-                raise TypeError("Unknown layer type or unsupported Keras version : {}"
-                                .format(keras_layer))
+                raise TypeError(
+                    "Unknown layer type or unsupported Keras version : {}".format(keras_layer)
+                )
             for node_idx, node in enumerate(inbound_nodes):
                 # If some nodes in imported model are not relevant to the current model,
                 # skip such layers.
                 # - In Keras, model._network_nodes contains keys of all nodes relevant to the
                 #   current model;
                 # - In tf.Keras, this is already done as part of tensorflow.keras.network.get_config
-                if not is_tf_keras and \
-                   not model._node_key(keras_layer, node_idx) in model._network_nodes:
+                if (
+                    not is_tf_keras
+                    and not model._node_key(keras_layer, node_idx) in model._network_nodes
+                ):
                     continue
                 inexpr = []
                 # Since Keras allows creating multiple layers from the same name instance,
@@ -1080,12 +1138,13 @@ def from_keras(model, shape=None, layout='NCHW'):
                 # they are named uniquely to input_1, input_2, input_3... by default.
                 # node_indices attribute removed in tensorflow 2.3, however iterate_inbound() can
                 # be used
-                if hasattr(node, 'node_indices'):
+                if hasattr(node, "node_indices"):
                     zip_node = zip(
                         _as_list(node.inbound_layers),
                         _as_list(node.node_indices),
                         _as_list(node.tensor_indices),
-                        _as_list(node.input_tensors))
+                        _as_list(node.input_tensors),
+                    )
                     node_attributes = zip_node
                 else:
                     node_attributes = node.iterate_inbound()
@@ -1094,18 +1153,20 @@ def from_keras(model, shape=None, layout='NCHW'):
                         expr_name = inbound_layer.name
                         _convert_input_layer(inbound_layer)
                     else:
-                        expr_name = inbound_layer.name + ':' + str(n_idx) + ':' + str(t_idx)
+                        expr_name = inbound_layer.name + ":" + str(n_idx) + ":" + str(t_idx)
                     expr = etab.get_expr(expr_name)
                     inexpr.append(expr)
                 if len(inexpr) == 1:
                     inexpr = inexpr[0]
-                keras_op_to_relay(inexpr, keras_layer, keras_layer.name + ':' + str(node_idx), etab)
+                keras_op_to_relay(inexpr, keras_layer, keras_layer.name + ":" + str(node_idx), etab)
     # model._output_coordinates contains out_node(oc[0]), node_index(oc[1]) and tensor_index(oc[2])
     # Get all output nodes in etab using the name made from above values.
     # The out exprs were added to etab in keras_op_to_relay using this name.
-    outexpr = [etab.get_expr(oc[0].name + ":" + str(oc[1]) + ":" + str(oc[2])) \
-               for oc in model._output_coordinates]
+    outexpr = [
+        etab.get_expr(oc[0].name + ":" + str(oc[1]) + ":" + str(oc[2]))
+        for oc in model._output_coordinates
+    ]
     outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr)
     func = _function.Function(analysis.free_vars(outexpr), outexpr)
-    params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
+    params = {k: _nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
     return IRModule.from_expr(func), params
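
A minimal usage sketch of the Keras converter reformatted above (not part of this patch); the model choice, input-name lookup, and input shape below are illustrative assumptions, not taken from this commit:

    import keras
    from tvm import relay

    # Hypothetical model; weights=None avoids downloading pretrained weights.
    model = keras.applications.MobileNet(include_top=True, weights=None, input_shape=(224, 224, 3))
    # Map the model's input name to its shape in the requested layout (NCHW here).
    shape_dict = {model.input_names[0]: (1, 3, 224, 224)}
    mod, params = relay.frontend.from_keras(model, shape=shape_dict, layout="NCHW")
    # mod is a tvm.IRModule; params maps parameter names to tvm.nd.NDArray, as returned above.
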
index faa62e1..712c025 100644 (file)
@@ -41,27 +41,26 @@ from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast
 from .nnvm_common import _clip, _transpose, _upsampling
 from .nnvm_common import _elemwise_sum, _reshape
 from .nnvm_common import _warn_not_used
-from .mxnet_qnn_op_utils import quantize_mxnet_min_max, \
-                                quantize_conv_weights_bias_channel_mkldnn_from_var, \
-                                quantize_conv_bias_mkldnn_from_var, \
-                                get_conv_mkldnn_requantized_scale_outDtype, \
-                                dequantize_mxnet_min_max, \
-                                get_mkldnn_int8_scale, \
-                                get_mkldnn_uint8_scale, \
-                                get_mkldnn_requantize_scale_outDtype
-
-
-__all__ = ['from_mxnet']
-
-_activation_map = {
-    "sigmoid": _op.sigmoid,
-    "tanh"   : _op.tanh,
-    "relu"   : _op.nn.relu
-}
+from .mxnet_qnn_op_utils import (
+    quantize_mxnet_min_max,
+    quantize_conv_weights_bias_channel_mkldnn_from_var,
+    quantize_conv_bias_mkldnn_from_var,
+    get_conv_mkldnn_requantized_scale_outDtype,
+    dequantize_mxnet_min_max,
+    get_mkldnn_int8_scale,
+    get_mkldnn_uint8_scale,
+    get_mkldnn_requantize_scale_outDtype,
+)
+
+
+__all__ = ["from_mxnet"]
+
+_activation_map = {"sigmoid": _op.sigmoid, "tanh": _op.tanh, "relu": _op.nn.relu}
 
 
 def _mx_fully_connected(inputs, attrs):
-    import mxnet as mx #pylint: disable=import-outside-toplevel
+    import mxnet as mx  # pylint: disable=import-outside-toplevel
+
     units = attrs.get_int("num_hidden")
     use_bias = not attrs.get_bool("no_bias", False)
     try:
@@ -95,23 +94,26 @@ def _get_channel_axis(layout, op_name):
     if layout == "NDHWC":
         return 4
     raise tvm.error.OpAttributeInvalid(
-        'Value {} in attribute "layout" of operator {} is not valid.'.format(layout, op_name))
+        'Value {} in attribute "layout" of operator {} is not valid.'.format(layout, op_name)
+    )
 
 
 def _mx_activations(inputs, attrs):
     act_type = attrs.get_str("act_type")
     assert len(inputs) == 1
     if act_type == "softrelu":
+
         def _stable_softrelu(x):
             # log(1 + exp(-abs(x))) + relu(x)
             one = _expr.const(1, dtype="float32")
             exp_neg_abs_x = _op.exp(_op.negative(_op.abs(x)))
-            return _op.add(_op.log(_op.add(one, exp_neg_abs_x)),
-                           _op.nn.relu(x))
+            return _op.add(_op.log(_op.add(one, exp_neg_abs_x)), _op.nn.relu(x))
+
         return _stable_softrelu(inputs[0])
     if act_type not in _activation_map:
         raise tvm.error.OpNotImplemented(
-            'Operator {} is not supported for frontend MXNet.'.format(act_type))
+            "Operator {} is not supported for frontend MXNet.".format(act_type)
+        )
     return _activation_map[act_type](inputs[0])
 
 
@@ -120,6 +122,7 @@ def _mx_compare(new_op, wrapper):
         expr = _infer_type(inputs[0])
         dtype = expr.checked_type.dtype
         return wrapper(new_op)(inputs, attrs).astype(dtype)
+
     return impl
 
 
@@ -132,8 +135,8 @@ def _mx_unravel_index(inputs, attrs):
 
 def _mx_swap_axis(inputs, attrs):
     assert len(inputs) == 1
-    dim1 = attrs.get_int('dim1')
-    dim2 = attrs.get_int('dim2')
+    dim1 = attrs.get_int("dim1")
+    dim2 = attrs.get_int("dim2")
     shape = _infer_type(inputs[0]).checked_type.shape
     axes = list(range(len(shape)))
     axes[dim1] = dim2
@@ -160,18 +163,20 @@ def _mx_conv(inputs, attrs):
         return _mx_conv1d(inputs, attrs)
     else:
         raise tvm.error.OpAttributeInvalid(
-            '1D, 2D or 3D kernels only are supported for operator Convolution')
+            "1D, 2D or 3D kernels only are supported for operator Convolution"
+        )
+
 
 def _mx_conv1d(inputs, attrs):
     kernel_size = attrs.get_int_tuple("kernel")
     if len(kernel_size) != 1:
         raise tvm.error.OpAttributeInvalid(
-            'Non 1D or 2D kernels are not supported for operator Convolution')
+            "Non 1D or 2D kernels are not supported for operator Convolution"
+        )
     data_layout = attrs.get_str("layout", "NCW")
     # MXNet Conv1D only supports ‘NCW’ layout for now.
     if data_layout != "NCW":
-        raise tvm.error.OpAttributeInvalid(
-            'Only "NCW" data layout is supported for 1D Convolution')
+        raise tvm.error.OpAttributeInvalid('Only "NCW" data layout is supported for 1D Convolution')
     data_layout = "NCHW"
     channel_axis = 1
     kernel_layout = "OIHW"
@@ -181,7 +186,7 @@ def _mx_conv1d(inputs, attrs):
     new_attrs["kernel_size"] = (1,) + kernel_size
     new_attrs["strides"] = (1,) + attrs.get_int_tuple("stride", (1,))
     new_attrs["padding"] = (0,) + attrs.get_int_tuple("pad", (0,))
-    new_attrs["dilation"] = (1,) +  attrs.get_int_tuple("dilate", (1,))
+    new_attrs["dilation"] = (1,) + attrs.get_int_tuple("dilate", (1,))
     new_attrs["groups"] = attrs.get_int("num_group", 1)
     new_attrs["data_layout"] = data_layout
     new_attrs["kernel_layout"] = kernel_layout
@@ -214,12 +219,12 @@ def _get_mx_conv2d_attrs(attrs):
     new_attrs["kernel_layout"] = kernel_layout
     return new_attrs
 
+
 def _mx_conv2d(inputs, attrs):
     kernel_size = attrs.get_int_tuple("kernel")
     data_layout = attrs.get_str("layout", "NCHW")
     if len(kernel_size) != 2:
-        raise tvm.error.OpAttributeInvalid(
-            'Only 2D kernels are supported for operator Convolution')
+        raise tvm.error.OpAttributeInvalid("Only 2D kernels are supported for operator Convolution")
 
     new_attrs = _get_mx_conv2d_attrs(attrs)
     channel_axis = _get_channel_axis(data_layout, "conv2d")
@@ -254,8 +259,7 @@ def _mx_conv3d(inputs, attrs):
     kernel_size = attrs.get_int_tuple("kernel")
     data_layout = attrs.get_str("layout", "NCDHW")
     if len(kernel_size) != 3:
-        raise tvm.error.OpAttributeInvalid(
-            'Only 3D kernels are supported for operator Convolution')
+        raise tvm.error.OpAttributeInvalid("Only 3D kernels are supported for operator Convolution")
 
     new_attrs = _get_mx_conv3d_attrs(attrs)
     channel_axis = _get_channel_axis(data_layout, "conv3d")
@@ -277,17 +281,18 @@ def _mx_conv_transpose(inputs, attrs):
         return _mx_conv1d_transpose(inputs, attrs)
     else:
         raise tvm.error.OpAttributeInvalid(
-            '1D, 2D or 3D kernels only are supported for operator Convolution')
+            "1D, 2D or 3D kernels only are supported for operator Convolution"
+        )
 
 
 def _mx_conv1d_transpose(inputs, attrs):
     if "target_shape" in attrs.attrs:
         raise tvm.error.OpAttributeUnImplemented(
-            'Attribute "target_shape" is not supported for operator Conv2D-transpose.')
+            'Attribute "target_shape" is not supported for operator Conv2D-transpose.'
+        )
     data_layout = attrs.get_str("layout", "NCW")
     if data_layout != "NCW":
-        raise tvm.error.OpAttributeInvalid(
-            'Only "NCW" data layout is supported for 1D Convolution')
+        raise tvm.error.OpAttributeInvalid('Only "NCW" data layout is supported for 1D Convolution')
     channel_axis = 1
     kernel_layout = "OIW"
     new_attrs = {}
@@ -311,11 +316,13 @@ def _mx_conv1d_transpose(inputs, attrs):
 def _mx_conv2d_transpose(inputs, attrs):
     if "target_shape" in attrs.attrs:
         raise tvm.error.OpAttributeUnImplemented(
-            'Attribute "target_shape" is not supported for operator Conv2D-transpose.')
+            'Attribute "target_shape" is not supported for operator Conv2D-transpose.'
+        )
     kernel_size = attrs.get_int_tuple("kernel")
     if len(kernel_size) != 2:
         raise tvm.error.OpAttributeInvalid(
-            'Non-2D kernels are not supported for operator Conv2D-transpose.')
+            "Non-2D kernels are not supported for operator Conv2D-transpose."
+        )
     data_layout = attrs.get_str("layout", "NCHW")
     channel_axis = _get_channel_axis(data_layout, "conv2d_transpose")
 
@@ -346,11 +353,13 @@ def _mx_conv2d_transpose(inputs, attrs):
 def _mx_conv3d_transpose(inputs, attrs):
     if "target_shape" in attrs.attrs:
         raise tvm.error.OpAttributeUnImplemented(
-            'Attribute "target_shape" is not supported for operator Conv3D-transpose.')
+            'Attribute "target_shape" is not supported for operator Conv3D-transpose.'
+        )
     kernel_size = attrs.get_int_tuple("kernel")
     if len(kernel_size) != 3:
         raise tvm.error.OpAttributeInvalid(
-            'Non-3D kernels are not supported for operator Conv3D-transpose.')
+            "Non-3D kernels are not supported for operator Conv3D-transpose."
+        )
     data_layout = attrs.get_str("layout", "NCDHW")
     channel_axis = _get_channel_axis(data_layout, "conv3d_transpose")
 
@@ -385,13 +394,12 @@ def _mx_pooling(inputs, attrs):
     def _pool2d(new_op, is_avg):
         kernel_size = attrs.get_int_tuple("kernel")
         if len(kernel_size) != 2:
-            raise tvm.error.OpAttributeInvalid(
-                'Only 2D kernels are supported for operator Pool2D.')
+            raise tvm.error.OpAttributeInvalid("Only 2D kernels are supported for operator Pool2D.")
         new_attrs = {}
         new_attrs["pool_size"] = kernel_size
         new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1))
         new_attrs["padding"] = attrs.get_int_tuple("pad", (0, 0))
-        new_attrs["ceil_mode"] = (attrs.get_str("pooling_convention", "valid") == "full")
+        new_attrs["ceil_mode"] = attrs.get_str("pooling_convention", "valid") == "full"
         if is_avg:
             new_attrs["count_include_pad"] = attrs.get_bool("count_include_pad", True)
         return new_op(inputs[0], **new_attrs)
@@ -399,18 +407,17 @@ def _mx_pooling(inputs, attrs):
     def _pool3d(new_op, is_avg):
         kernel_size = attrs.get_int_tuple("kernel")
         if len(kernel_size) != 3:
-            raise tvm.error.OpAttributeInvalid(
-                'Only 3D kernels are supported for operator Pool3D.')
+            raise tvm.error.OpAttributeInvalid("Only 3D kernels are supported for operator Pool3D.")
         new_attrs = {}
         new_attrs["pool_size"] = kernel_size
         new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1, 1))
         new_attrs["padding"] = attrs.get_int_tuple("pad", (0, 0, 0))
-        new_attrs["ceil_mode"] = (attrs.get_str("pooling_convention", "valid") == "full")
+        new_attrs["ceil_mode"] = attrs.get_str("pooling_convention", "valid") == "full"
         if is_avg:
             new_attrs["count_include_pad"] = attrs.get_bool("count_include_pad", True)
         return new_op(inputs[0], **new_attrs)
 
-    #3D pooling
+    # 3D pooling
     if len(_infer_shape(inputs[0])) == 5:
         if pool_type == "max":
             if global_pool:
@@ -421,9 +428,11 @@ def _mx_pooling(inputs, attrs):
                 return _op.nn.global_avg_pool3d(inputs[0])
             return _pool3d(_op.nn.avg_pool3d, True)
         raise tvm.error.OpNotImplemented(
-            'Operator {} Pooling is not supported for frontend MXNet.' \
-                .format(pool_type.capitalize()))
-    #2D Pooling
+            "Operator {} Pooling is not supported for frontend MXNet.".format(
+                pool_type.capitalize()
+            )
+        )
+    # 2D Pooling
     if pool_type == "max":
         if global_pool:
             return _op.nn.global_max_pool2d(inputs[0])
@@ -433,8 +442,8 @@ def _mx_pooling(inputs, attrs):
             return _op.nn.global_avg_pool2d(inputs[0])
         return _pool2d(_op.nn.avg_pool2d, True)
     raise tvm.error.OpNotImplemented(
-        'Operator {} Pooling is not supported for frontend MXNet.' \
-            .format(pool_type.capitalize()))
+        "Operator {} Pooling is not supported for frontend MXNet.".format(pool_type.capitalize())
+    )
 
 
 def _mx_adaptive_avg_pooling(inputs, attrs):
@@ -447,14 +456,15 @@ def _mx_dropout(inputs, attrs):
     return _op.nn.dropout(inputs[0], rate=rate)
 
 
-def _mx_BlockGrad(inputs, attrs): #pylint: disable=unused-argument
+def _mx_BlockGrad(inputs, attrs):  # pylint: disable=unused-argument
     return inputs
 
 
 def _mx_batch_norm(inputs, attrs):
     if attrs.get_bool("output_mean_var", False):
         raise tvm.error.OpAttributeUnImplemented(
-            'Attribute "output_mean_var" is not supported for operator Batch Norm.')
+            'Attribute "output_mean_var" is not supported for operator Batch Norm.'
+        )
     if attrs.get_bool("use_global_stats", False):
         _warn_not_used("use_global_stats", "batch_norm")
     new_attrs = {}
@@ -477,7 +487,8 @@ def _mx_layer_norm(inputs, attrs):
     assert len(inputs) == 3
     if attrs.get_bool("output_mean_var", False):
         raise tvm.error.OpAttributeUnimplemented(
-            'Attribute "output_mean_var" is not supported for operator Layer Norm.')
+            'Attribute "output_mean_var" is not supported for operator Layer Norm.'
+        )
     new_attrs = {}
     new_attrs["axis"] = attrs.get_int("axis", -1)
     new_attrs["epsilon"] = attrs.get_float("eps", 1e-5)
@@ -486,25 +497,22 @@ def _mx_layer_norm(inputs, attrs):
 
 def _mx_slice(inputs, attrs):
     new_attrs = {}
-    begin = list(attrs.get_int_tuple('begin', None))
-    end = list(attrs.get_int_tuple('end', None))
-    stride = attrs.get_int_tuple('step', None)
+    begin = list(attrs.get_int_tuple("begin", None))
+    end = list(attrs.get_int_tuple("end", None))
+    stride = attrs.get_int_tuple("step", None)
     input_shape = _infer_type(inputs[0]).checked_type.shape
     if begin is None:
-        raise tvm.error.OpAttributeRequired(
-            'Attribute "begin" not found in operator Slice.')
+        raise tvm.error.OpAttributeRequired('Attribute "begin" not found in operator Slice.')
     if end is None:
-        raise tvm.error.OpAttributeRequired(
-            'Attribute "end" not found in operator Slice.')
+        raise tvm.error.OpAttributeRequired('Attribute "end" not found in operator Slice.')
     begin = (x if x is not None else 0 for x in begin)
     for i, ed in enumerate(end):
         if ed is None:
             end[i] = input_shape[i]
-    new_attrs = {'begin': list(begin),
-                 'end': list(end)}
+    new_attrs = {"begin": list(begin), "end": list(end)}
     if stride is not None:
         stride = (x if x is not None else 1 for x in stride)
-        new_attrs['strides'] = list(stride)
+        new_attrs["strides"] = list(stride)
     return _op.strided_slice(inputs[0], **new_attrs)
 
 
@@ -550,13 +558,12 @@ def _mx_slice_axis(inputs, attrs):
 def _mx_crop_like(inputs, attrs):
     if len(inputs) < 2:
         raise tvm.error.OpAttributeUnimplemented(
-            "Only support crop_like pattern for operator Crop.")
+            "Only support crop_like pattern for operator Crop."
+        )
     if attrs.get_bool("center_crop", False):
-        raise tvm.error.OpAttributeUnimplemented(
-            "Center crop is not supported in operator Crop.")
+        raise tvm.error.OpAttributeUnimplemented("Center crop is not supported in operator Crop.")
     if attrs.get_int_tuple("h_w", (0, 0)) != (0, 0):
-        raise tvm.error.OpAttributeUnimplemented(
-            "Doesn't support h_w in operator Crop.")
+        raise tvm.error.OpAttributeUnimplemented("Doesn't support h_w in operator Crop.")
     offset = attrs.get_int_tuple("offset", (0, 0))
     new_attrs = {}
     if offset == (0, 0):
@@ -564,9 +571,13 @@ def _mx_crop_like(inputs, attrs):
         return _op.slice_like(*inputs, **new_attrs)
     expr = _infer_type(inputs[1])
     like_shape = expr.checked_type.shape
-    new_attrs['begin'] = [0, 0, offset[0], offset[1]]
-    new_attrs['end'] = [like_shape[0], like_shape[1], offset[0]+like_shape[2],
-                        offset[1]+like_shape[3]]
+    new_attrs["begin"] = [0, 0, offset[0], offset[1]]
+    new_attrs["end"] = [
+        like_shape[0],
+        like_shape[1],
+        offset[0] + like_shape[2],
+        offset[1] + like_shape[3],
+    ]
     return _op.strided_slice(inputs[0], **new_attrs)
 
 
@@ -611,27 +622,26 @@ def _mx_expand_dims(inputs, attrs):
     axis = attrs.get_int("axis")
     return _op.expand_dims(inputs[0], axis=axis)
 
+
 def _mx_pad(inputs, attrs):
-    pad_mode = attrs.get_str('mode', None)
+    pad_mode = attrs.get_str("mode", None)
     if pad_mode is None:
-        raise tvm.error.OpAttributeRequired(
-            'Attribute "mode" not found in operator pad.')
-    if pad_mode not in ['constant', 'edge', 'reflect']:
-        raise tvm.error.OpAttributeInvalid(
-            'Value ' + mode + ' in attribute "mode" is not valid')
-    pad_width = attrs.get_int_tuple('pad_width', None)
+        raise tvm.error.OpAttributeRequired('Attribute "mode" not found in operator pad.')
+    if pad_mode not in ["constant", "edge", "reflect"]:
+        raise tvm.error.OpAttributeInvalid("Value " + pad_mode + ' in attribute "mode" is not valid')
+    pad_width = attrs.get_int_tuple("pad_width", None)
     if pad_width is None:
-        raise tvm.error.OpAttributeRequired(
-            'Attribute "pad_width" not found in operator pad.')
+        raise tvm.error.OpAttributeRequired('Attribute "pad_width" not found in operator pad.')
     if None in pad_width:
         raise tvm.error.OpAttributeInvalid(
-            'Value None in attribute "pad_width" of operator Slice is not valid.')
-    constant_value = attrs.get_float('constant_value', 0.0)
+            'Value None in attribute "pad_width" of operator Slice is not valid.'
+        )
+    constant_value = attrs.get_float("constant_value", 0.0)
     padding = tuple(tuple((b, a)) for b, a in zip(pad_width[::2], pad_width[1::2]))
-    return _op.nn.pad(data=inputs[0],
-                      pad_width=padding,
-                      pad_value=constant_value,
-                      pad_mode=pad_mode)
+    return _op.nn.pad(
+        data=inputs[0], pad_width=padding, pad_value=constant_value, pad_mode=pad_mode
+    )
+
 
 def _mx_leaky_relu(inputs, attrs):
     act_type = attrs.get_str("act_type", "leaky")
@@ -664,7 +674,8 @@ def _mx_leaky_relu(inputs, attrs):
         half_x = _op.multiply(inputs[0], half)
         return _op.multiply(half_x, erf_plus_one)
     raise tvm.error.OpNotImplemented(
-        'Operator {} is not supported for frontend MXNet.'.format(act_type))
+        "Operator {} is not supported for frontend MXNet.".format(act_type)
+    )
 
 
 def _mx_make_power(power):
@@ -673,6 +684,7 @@ def _mx_make_power(power):
         scalar = _expr.const(power, dtype=None)
         # Note: int maps to "int32", float maps to "float32"
         return _op.power(inputs[0], scalar)
+
     return _impl
 
 
@@ -682,6 +694,7 @@ def _mx_make_exponent(base):
         assert len(inputs) == 1
         scalar = _op.exp(_expr.const(base, dtype="float32"))
         return _op.multiply(inputs[0], scalar)
+
     return _impl
 
 
@@ -691,6 +704,7 @@ def _mx_make_logarithm(base):
         assert len(inputs) == 1
         scalar = _op.log(_expr.const(base, dtype="float32"))
         return _op.divide(inputs[0], scalar)
+
     return _impl
 
 
@@ -700,6 +714,7 @@ def _mx_expm1():
         assert len(inputs) == 1
         one = _expr.const(1, dtype="float32")
         return _op.subtract(_op.exp(inputs[0]), one)
+
     return _impl
 
 
@@ -709,6 +724,7 @@ def _mx_log1p():
         assert len(inputs) == 1
         one = _expr.const(1, dtype="float32")
         return _op.log(_op.add(inputs[0], one))
+
     return _impl
 
 
@@ -726,10 +742,10 @@ def _mx_lrn(inputs, attrs):
 
 def _mx_multibox_prior(inputs, attrs):
     new_attrs = {}
-    new_attrs["sizes"] = attrs.get_float_tuple("sizes", (1.0, ))
+    new_attrs["sizes"] = attrs.get_float_tuple("sizes", (1.0,))
     new_attrs["steps"] = attrs.get_float_tuple("steps", (-1.0, -1.0))
     new_attrs["offsets"] = attrs.get_float_tuple("offsets", (0.5, 0.5))
-    new_attrs["ratios"] = attrs.get_float_tuple("ratios", (1.0, ))
+    new_attrs["ratios"] = attrs.get_float_tuple("ratios", (1.0,))
     new_attrs["clip"] = attrs.get_bool("clip", False)
     return _op.vision.multibox_prior(inputs[0], **new_attrs)
 
@@ -738,8 +754,7 @@ def _mx_multibox_detection(inputs, attrs):
     new_attrs0 = {}
     new_attrs0["clip"] = attrs.get_bool("clip", True)
     new_attrs0["threshold"] = attrs.get_float("threshold", 0.01)
-    new_attrs0["variances"] = attrs.get_float_tuple("variances", (0.1, 0.1,
-                                                                  0.2, 0.2))
+    new_attrs0["variances"] = attrs.get_float_tuple("variances", (0.1, 0.1, 0.2, 0.2))
 
     new_attrs1 = {}
     new_attrs1["return_indices"] = False
@@ -747,8 +762,7 @@ def _mx_multibox_detection(inputs, attrs):
     new_attrs1["force_suppress"] = attrs.get_bool("force_suppress", False)
     new_attrs1["top_k"] = attrs.get_int("nms_topk", -1)
 
-    ret = _op.vision.multibox_transform_loc(inputs[0], inputs[1],
-                                            inputs[2], **new_attrs0)
+    ret = _op.vision.multibox_transform_loc(inputs[0], inputs[1], inputs[2], **new_attrs0)
     return _op.vision.non_max_suppression(ret[0], ret[1], ret[1], **new_attrs1)
 
 
@@ -758,8 +772,7 @@ def _mx_batch_dot(inputs, attrs):
     transpose_a = attrs.get_bool("transpose_a", False)
     transpose_b = attrs.get_bool("transpose_b", False)
     if transpose_a is True:
-        msg = 'Value {} in attribute "transpose_a" of operator batch_dot ' \
-              'is not valid.'
+        msg = 'Value {} in attribute "transpose_a" of operator batch_dot ' "is not valid."
         raise tvm.error.OpAttributeInvalid(msg.format(transpose_a))
     if transpose_b is False:
         b = _op.transpose(b, axes=[0, 2, 1])
@@ -770,7 +783,8 @@ def _mx_arange(inputs, attrs):
     assert len(inputs) == 0
     if attrs.get_int("repeat", 1) != 1:
         raise tvm.error.OpAttributeUnimplemented(
-            'Attribute "repeat" is not supported in operator arange.')
+            'Attribute "repeat" is not supported in operator arange.'
+        )
     dtype = attrs.get_str("dtype", "float32")
     stop = attrs.get_str("stop", "None")
     if stop == "None":
@@ -796,7 +810,8 @@ def _mx_contrib_arange_like(inputs, attrs):
     assert len(inputs) == 1
     if attrs.get_int("repeat", 1) != 1:
         raise tvm.error.OpAttributeUnimplemented(
-            'Attribute "repeat" is not supported in operator arange_like.')
+            'Attribute "repeat" is not supported in operator arange_like.'
+        )
     ty = _infer_type(inputs[0]).checked_type
     assert ty
     shape, dtype = get_const_tuple(ty.shape), ty.dtype
@@ -814,8 +829,8 @@ def _mx_contrib_arange_like(inputs, attrs):
         if not isinstance(n_elems, int):
             raise tvm.error.OpError("Don't support arange_like with symbolic input shape.")
         shape = (n_elems,)
-    start = attrs.get_float("start", 0.)
-    step = attrs.get_float("step", 1.)
+    start = attrs.get_float("start", 0.0)
+    step = attrs.get_float("step", 1.0)
     stop = start + step * n_elems
     new_attrs = {}
     new_attrs["start"] = _expr.const(start, dtype=dtype)
@@ -851,11 +866,13 @@ def _mx_take(inputs, attrs):
     axis = attrs.get_int("axis", 0)
     return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode)
 
+
 def _mx_gather_nd(inputs, attrs):
     assert len(inputs) == 2
     assert len(_infer_shape(inputs[1])) > 1, "index tensor to have at least 2 dimensions"
     return _op.gather_nd(inputs[0], inputs[1])
 
+
 def _mx_reverse(inputs, attrs):
     assert len(inputs) == 1
     new_attrs = {}
@@ -886,6 +903,7 @@ def _mx_roi_align(inputs, attrs):
     new_attrs["layout"] = "NCHW"
     return _op.vision.roi_align(inputs[0], inputs[1], **new_attrs)
 
+
 def _mx_resize(inputs, attrs):
     scale_height = attrs.get_float("scale_height", None)
     scale_width = attrs.get_float("scale_width", None)
@@ -898,30 +916,32 @@ def _mx_resize(inputs, attrs):
     if scale_width is not None:
         width = (scale_width * shape[3]).astype("int32")
     size = (height, width)
-    return _op.image.resize(inputs[0], size,
-                            coordinate_transformation_mode="align_corners")
+    return _op.image.resize(inputs[0], size, coordinate_transformation_mode="align_corners")
+
 
 def _mx_amp_multicast(inputs, attrs):
     cast_narrow = attrs.get_bool("cast_narrow", False)
     dtypes = [_infer_type(x).checked_type.dtype for x in inputs]
-    supported_dtypes = ['float16', 'float32']
-    assert all([x in supported_dtypes for x in dtypes]), \
-            "amp_multicast support is limited to float16 and float32 inputs only."
+    supported_dtypes = ["float16", "float32"]
+    assert all(
+        [x in supported_dtypes for x in dtypes]
+    ), "amp_multicast support is limited to float16 and float32 inputs only."
     has_float16 = any(x == "float16" for x in dtypes)
     has_float32 = any(x == "float32" for x in dtypes)
     dtype = dtypes[0]
     if cast_narrow and has_float16:
-        dtype = 'float16'
+        dtype = "float16"
     if not cast_narrow and has_float32:
-        dtype = 'float32'
+        dtype = "float32"
     return [_op.cast(x, dtype) for x in inputs]
 
+
 def _mx_grid_generator(inputs, attrs):
     transform_type = attrs.get_str("transform_type")
-    if transform_type == 'affine':
+    if transform_type == "affine":
         target_shape = attrs.get_int_tuple("target_shape")
         return _op.image.affine_grid(_op.reshape(inputs[0], (0, 2, 3)), target_shape)
-    if transform_type == 'warp':
+    if transform_type == "warp":
         checked_type = _infer_type(inputs[0]).checked_type
         batch, _, height, width = get_const_tuple(checked_type.shape)
         dtype = checked_type.dtype
@@ -933,8 +953,10 @@ def _mx_grid_generator(inputs, attrs):
         return grid + normalized_flow
     raise ValueError("unknown transform type" + transform_type)
 
+
 def _mx_bilinear_sampler(inputs, attrs):
-    return _op.image.grid_sample(inputs[0], inputs[1], 'bilinear', 'NCHW')
+    return _op.image.grid_sample(inputs[0], inputs[1], "bilinear", "NCHW")
+
 
 def _mx_roi_pooling(inputs, attrs):
     new_attrs = {}
@@ -960,46 +982,51 @@ def _mx_proposal(inputs, attrs):
 
 def _mx_box_nms(inputs, attrs):
     force_suppress = attrs.get_bool("force_suppress", False)
-    iou_thresh = attrs.get_float('overlap_thresh', 0.5)
-    top_k = attrs.get_int('topk', -1)
-    valid_thresh = attrs.get_float('valid_thresh', 0)
-    coord_start = attrs.get_int('coord_start', 2)
-    score_index = attrs.get_int('score_index', 1)
-    id_index = attrs.get_int('id_index', -1)
-    in_format = attrs.get_str('in_format', 'corner')
-    out_format = attrs.get_str('out_format', 'corner')
-    if in_format != 'corner':
+    iou_thresh = attrs.get_float("overlap_thresh", 0.5)
+    top_k = attrs.get_int("topk", -1)
+    valid_thresh = attrs.get_float("valid_thresh", 0)
+    coord_start = attrs.get_int("coord_start", 2)
+    score_index = attrs.get_int("score_index", 1)
+    id_index = attrs.get_int("id_index", -1)
+    in_format = attrs.get_str("in_format", "corner")
+    out_format = attrs.get_str("out_format", "corner")
+    if in_format != "corner":
         raise tvm.error.OpAttributeInvalid(
-            'Value of attribute "in_format" must equal "corner" for operator box_nms.')
-    if out_format != 'corner':
+            'Value of attribute "in_format" must equal "corner" for operator box_nms.'
+        )
+    if out_format != "corner":
         raise tvm.error.OpAttributeInvalid(
-            'Value of attribute "out_format" must equal "corner" for operator box_nms.')
-
-    ret = _op.vision.get_valid_counts(inputs[0], score_threshold=valid_thresh,
-                                      id_index=id_index, score_index=score_index)
-    nms_out = _op.vision.non_max_suppression(ret[1],
-                                             ret[0],
-                                             ret[2],
-                                             iou_threshold=iou_thresh,
-                                             force_suppress=force_suppress,
-                                             top_k=top_k,
-                                             coord_start=coord_start,
-                                             score_index=score_index,
-                                             id_index=id_index,
-                                             return_indices=False,
-                                             invalid_to_bottom=True)
+            'Value of attribute "out_format" must equal "corner" for operator box_nms.'
+        )
+
+    ret = _op.vision.get_valid_counts(
+        inputs[0], score_threshold=valid_thresh, id_index=id_index, score_index=score_index
+    )
+    nms_out = _op.vision.non_max_suppression(
+        ret[1],
+        ret[0],
+        ret[2],
+        iou_threshold=iou_thresh,
+        force_suppress=force_suppress,
+        top_k=top_k,
+        coord_start=coord_start,
+        score_index=score_index,
+        id_index=id_index,
+        return_indices=False,
+        invalid_to_bottom=True,
+    )
     return nms_out
 
 
 def _mx_box_decode(inputs, attrs):
-    std0 = relay.const(attrs.get_float('std0', 1), "float32")
-    std1 = relay.const(attrs.get_float('std1', 1), "float32")
-    std2 = relay.const(attrs.get_float('std2', 1), "float32")
-    std3 = relay.const(attrs.get_float('std3', 1), "float32")
-    clip = attrs.get_float('clip', -1)
-    in_format = attrs.get_str('format', 'corner')
-
-    anchors = inputs[1] # (1, N, 4) encoded in corner or center
+    std0 = relay.const(attrs.get_float("std0", 1), "float32")
+    std1 = relay.const(attrs.get_float("std1", 1), "float32")
+    std2 = relay.const(attrs.get_float("std2", 1), "float32")
+    std3 = relay.const(attrs.get_float("std3", 1), "float32")
+    clip = attrs.get_float("clip", -1)
+    in_format = attrs.get_str("format", "corner")
+
+    anchors = inputs[1]  # (1, N, 4) encoded in corner or center
     a = _op.split(anchors, indices_or_sections=4, axis=-1)
     # Convert to format "center".
     if in_format == "corner":
@@ -1009,7 +1036,7 @@ def _mx_box_decode(inputs, attrs):
         a_y = a[1] + a_height * relay.const(0.5, "float32")
     else:
         a_x, a_y, a_width, a_height = a
-    data = inputs[0] # (B, N, 4) predicted bbox offset
+    data = inputs[0]  # (B, N, 4) predicted bbox offset
     p = _op.split(data, indices_or_sections=4, axis=-1)
     ox = p[0] * std0 * a_width + a_x
     oy = p[1] * std1 * a_height + a_y
@@ -1029,12 +1056,13 @@ def _mx_box_decode(inputs, attrs):
 
 def _mx_l2_normalize(inputs, attrs):
     new_attrs = {}
-    mode = attrs.get_str('mode', 'instance')
-    if mode != 'channel':
+    mode = attrs.get_str("mode", "instance")
+    if mode != "channel":
         raise tvm.error.OpAttributeInvalid(
-            'Value of attribute "mode" must equal "channel" for operator l2_normalize.')
-    new_attrs['eps'] = attrs.get_float('eps', 1e-10)
-    new_attrs['axis'] = [1]
+            'Value of attribute "mode" must equal "channel" for operator l2_normalize.'
+        )
+    new_attrs["eps"] = attrs.get_float("eps", 1e-10)
+    new_attrs["axis"] = [1]
     return _op.nn.l2_normalize(inputs[0], **new_attrs)
 
 
@@ -1053,7 +1081,7 @@ def _mx_hard_sigmoid(inputs, attrs):
 
 
 def _mx_reciprocal(inputs, attrs):
-    return _expr.const(1.0) /inputs[0]
+    return _expr.const(1.0) / inputs[0]
 
 
 def _mx_shape_array(inputs, attrs):
@@ -1066,7 +1094,7 @@ def _mx_shape_array(inputs, attrs):
         raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support rhs_begin")
     if attrs.get_int("rhs_end", None) is not None:
         raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support rhs_end")
-    return _op.shape_of(inputs[0], dtype='int64')
+    return _op.shape_of(inputs[0], dtype="int64")
 
 
 def _mx_full(inputs, attrs):
@@ -1106,16 +1134,18 @@ def _mx_broadcast_axis(inputs, attrs):
 def _mx_embedding(inputs, _):
     assert len(inputs) == 2
     indices, weight = inputs
-    return _op.take(weight, indices.astype('int32'), axis=0)
+    return _op.take(weight, indices.astype("int32"), axis=0)
 
 
 def _mx_smooth_l1(inputs, attrs):
     scalar = attrs.get_float("scalar", 1.0)
     scalar_sq = scalar * scalar
-    mask = _op.less(inputs[0], _expr.const(1.0 / scalar_sq, dtype='float32'))
-    return _op.where(mask,
-                     _expr.const(scalar_sq / 2.0, dtype='float32') * inputs[0] * inputs[0],
-                     _op.abs(inputs[0]) - _expr.const(0.5 / scalar_sq))
+    mask = _op.less(inputs[0], _expr.const(1.0 / scalar_sq, dtype="float32"))
+    return _op.where(
+        mask,
+        _expr.const(scalar_sq / 2.0, dtype="float32") * inputs[0] * inputs[0],
+        _op.abs(inputs[0]) - _expr.const(0.5 / scalar_sq),
+    )
 
 
 def _mx_deformable_convolution(inputs, attrs):
@@ -1154,7 +1184,8 @@ def _mx_topk(inputs, attrs):
     ret_type = attrs.get_str("ret_typ", "indices")
     if ret_type == "mask":
         raise tvm.error.OpAttributeUnimplemented(
-            "Attribute ret_type=mask is not supported in topk operator")
+            "Attribute ret_type=mask is not supported in topk operator"
+        )
     new_attrs["ret_type"] = "values" if ret_type == "value" else ret_type
     new_attrs["dtype"] = attrs.get_str("dtype", "float32")
     return _op.topk(inputs[0], **new_attrs)
@@ -1163,9 +1194,9 @@ def _mx_topk(inputs, attrs):
 def _mx_sequence_mask(inputs, attrs):
     assert len(inputs) == 1 or len(inputs) == 2
     new_attrs = {}
-    use_sequence_length = attrs.get_bool('use_sequence_length', False)
-    new_attrs['mask_value'] = attrs.get_float('value', 0.0)
-    new_attrs['axis'] = attrs.get_int('axis', 0)
+    use_sequence_length = attrs.get_bool("use_sequence_length", False)
+    new_attrs["mask_value"] = attrs.get_float("value", 0.0)
+    new_attrs["axis"] = attrs.get_int("axis", 0)
     if use_sequence_length:
         return _op.sequence_mask(*inputs, **new_attrs)
     else:
@@ -1175,7 +1206,7 @@ def _mx_sequence_mask(inputs, attrs):
 def _mx_contrib_div_sqrt_dim(inputs, _):
     assert len(inputs) == 1
     ndim = len(_infer_type(inputs[0]).checked_type.shape)
-    dim = _op.take(_op.shape_of(inputs[0]), _expr.const(ndim-1, dtype="int32"))
+    dim = _op.take(_op.shape_of(inputs[0]), _expr.const(ndim - 1, dtype="int32"))
     dtype = _infer_type(inputs[0]).checked_type.dtype
     sqrt_dim = _op.sqrt(dim.astype(dtype))
     out = inputs[0] / sqrt_dim
@@ -1224,15 +1255,16 @@ def _mx_rnn_layer(inputs, attrs):
     mode = attrs.get_str("mode")
     output_states = attrs.get_bool("state_outputs", False)
     if mode.startswith("rnn"):
-        mode, activation = mode.split('_')
+        mode, activation = mode.split("_")
     assert mode in ["rnn", "gru", "lstm"]
     bidirectional = attrs.get_bool("bidirectional", False)
     direct = 2 if bidirectional else 1
     layout = attrs.get_str("layout", "TNC")
     if layout != "TNC":
         raise tvm.error.OpAttributeUnimplemented(
-            "RNN with layout other than TNC is not supported yet")
-    num_states = 2 if mode == 'lstm' else 1
+            "RNN with layout other than TNC is not supported yet"
+        )
+    num_states = 2 if mode == "lstm" else 1
     assert len(inputs) == num_states + 2
 
     seq_data = inputs[0]
@@ -1268,22 +1300,35 @@ def _mx_rnn_layer(inputs, attrs):
     back_bias = []
     back_states = []
     for i in range(num_layers):
-        weights.append([concat_weight[i*2*direct].args[0],
-                        concat_weight[i*2*direct + 1].args[0]])
-        bias.append([concat_weight[(num_layers+i)*2*direct].args[0],
-                     concat_weight[(num_layers+i)*2*direct + 1].args[0]])
+        weights.append(
+            [concat_weight[i * 2 * direct].args[0], concat_weight[i * 2 * direct + 1].args[0]]
+        )
+        bias.append(
+            [
+                concat_weight[(num_layers + i) * 2 * direct].args[0],
+                concat_weight[(num_layers + i) * 2 * direct + 1].args[0],
+            ]
+        )
         s = []
         for state in init_states:
-            s.append(_op.take(state, _expr.const(i*direct, "int32"), axis=0))
+            s.append(_op.take(state, _expr.const(i * direct, "int32"), axis=0))
         states.append(s)
         if bidirectional:
-            back_weights.append([concat_weight[i*2*direct + 2].args[0],
-                                 concat_weight[i*2*direct + 3].args[0]])
-            back_bias.append([concat_weight[(num_layers+i)*2*direct + 2].args[0],
-                              concat_weight[(num_layers+i)*2*direct + 3].args[0]])
+            back_weights.append(
+                [
+                    concat_weight[i * 2 * direct + 2].args[0],
+                    concat_weight[i * 2 * direct + 3].args[0],
+                ]
+            )
+            back_bias.append(
+                [
+                    concat_weight[(num_layers + i) * 2 * direct + 2].args[0],
+                    concat_weight[(num_layers + i) * 2 * direct + 3].args[0],
+                ]
+            )
             s = []
             for state in init_states:
-                s.append(_op.take(state, _expr.const(i*direct+1, "int32"), axis=0))
+                s.append(_op.take(state, _expr.const(i * direct + 1, "int32"), axis=0))
             back_states.append(s)
 
     xs = [_op.take(seq_data, _expr.const(t, "int32"), axis=0) for t in range(seq_len)]
@@ -1295,7 +1340,7 @@ def _mx_rnn_layer(inputs, attrs):
                 out, new_states = _rnn_cell(x, states[l], *weights[l], *bias[l], activation)
             elif mode == "gru":
                 out, new_states = _gru_cell(x, states[l], *weights[l], *bias[l])
-            else: # mode == "lstm"
+            else:  # mode == "lstm"
                 out, new_states = _lstm_cell(x, states[l], *weights[l], *bias[l])
             states[l] = new_states
             outputs.append(out)
@@ -1303,13 +1348,12 @@ def _mx_rnn_layer(inputs, attrs):
             for x in reversed(xs):
                 if mode == "rnn":
                     out, new_states = _rnn_cell(
-                        x, back_states[l], *back_weights[l], *back_bias[l], activation)
+                        x, back_states[l], *back_weights[l], *back_bias[l], activation
+                    )
                 elif mode == "gru":
-                    out, new_states = _gru_cell(
-                        x, back_states[l], *back_weights[l], *back_bias[l])
-                else: # mode == "lstm"
-                    out, new_states = _lstm_cell(
-                        x, back_states[l], *back_weights[l], *back_bias[l])
+                    out, new_states = _gru_cell(x, back_states[l], *back_weights[l], *back_bias[l])
+                else:  # mode == "lstm"
+                    out, new_states = _lstm_cell(x, back_states[l], *back_weights[l], *back_bias[l])
                 back_states[l] = new_states
                 back_outputs.append(out)
             back_outputs.reverse()
@@ -1331,12 +1375,13 @@ def _mx_rnn_layer(inputs, attrs):
             ret.append(_op.stack(inputs, axis=0))
     return ret
 
+
 def _mx_one_hot(inputs, attrs):
-    indices = inputs[0].astype('int32')
-    depth = attrs.get_int('depth', 0)
-    dtype = attrs.get_str('dtype', 'int32')
-    on_value = tvm.relay.const(attrs.get_float('on_value', 1.0), dtype)
-    off_value = tvm.relay.const(attrs.get_float('off_value', 0.0), dtype)
+    indices = inputs[0].astype("int32")
+    depth = attrs.get_int("depth", 0)
+    dtype = attrs.get_str("dtype", "int32")
+    on_value = tvm.relay.const(attrs.get_float("on_value", 1.0), dtype)
+    off_value = tvm.relay.const(attrs.get_float("off_value", 0.0), dtype)
     return _op.one_hot(indices, on_value, off_value, depth, -1, dtype)
 
 
@@ -1369,7 +1414,7 @@ def _mx_correlation(inputs, attrs):
 
 def _mx_contrib_fifo_buffer(inputs, attrs):
     new_attrs = {}
-    new_attrs['axis'] = attrs.get_int('axis')
+    new_attrs["axis"] = attrs.get_int("axis")
     return _op.nn.fifo_buffer(*inputs, **new_attrs)
 
 
@@ -1385,7 +1430,7 @@ def _mx_contrib_interleaved_matmul_selfatt_qk(inputs, attrs):
     """
     assert len(inputs) == 1
     qkv = inputs[0]
-    num_heads = attrs.get_int('heads')
+    num_heads = attrs.get_int("heads")
     qkv = _op.reshape(qkv, newshape=(0, 0, num_heads, 3, -1))
     q_proj = _op.take(qkv, _expr.const(0, "int32"), axis=3)
     q_proj = _op.transpose(q_proj, axes=[1, 2, 0, 3])
@@ -1493,8 +1538,9 @@ def _qnn_contrib_concat(inputs, attrs):
     else:
         # Get all dtypes. Find input and output scales, call concatenate.
         dtypes = [_infer_type(x).checked_type.dtype for x in input_exprs]
-        assert all([x == 'uint8' for x in dtypes]), \
-                "Current support is limited to uint8 inputs only."
+        assert all(
+            [x == "uint8" for x in dtypes]
+        ), "Current support is limited to uint8 inputs only."
         new_min = min(mins)
         new_max = max(maxs)
         assert new_min == 0
@@ -1505,37 +1551,41 @@ def _qnn_contrib_concat(inputs, attrs):
         input_zeros = [0] * len(input_scales)
         output_zero = 0
 
-        input_scales_expr = [relay.const(x, 'float32') for x in input_scales]
-        input_zeros_expr = [relay.const(x, 'int32') for x in input_zeros]
+        input_scales_expr = [relay.const(x, "float32") for x in input_scales]
+        input_zeros_expr = [relay.const(x, "int32") for x in input_zeros]
 
-        output_scale_expr = relay.const(output_scale, 'float32')
-        output_zero_expr = relay.const(output_zero, 'int32')
+        output_scale_expr = relay.const(output_scale, "float32")
+        output_zero_expr = relay.const(output_zero, "int32")
 
-        res = relay.qnn.op.concatenate(input_exprs, input_scales_expr, input_zeros_expr,
-                                       output_scale_expr, output_zero_expr, axis=axis)
+        res = relay.qnn.op.concatenate(
+            input_exprs,
+            input_scales_expr,
+            input_zeros_expr,
+            output_scale_expr,
+            output_zero_expr,
+            axis=axis,
+        )
         return res, new_min, new_max
 
 
 def _qnn_quantize(inputs, attrs):
-    out_dtype = 'int8'
-    out_type = attrs.get_str('out_type')
-    if out_type == 'auto':
-        if attrs.has_attr('min_calib_range') and attrs.has_attr('max_calib_range'):
-            if attrs.get_float('min_calib_range') >= 0:
-                out_dtype = 'uint8'
+    out_dtype = "int8"
+    out_type = attrs.get_str("out_type")
+    if out_type == "auto":
+        if attrs.has_attr("min_calib_range") and attrs.has_attr("max_calib_range"):
+            if attrs.get_float("min_calib_range") >= 0:
+                out_dtype = "uint8"
             else:
-                out_dtype = 'int8'
+                out_dtype = "int8"
     else:
         out_dtype = out_type
-    if out_dtype not in {'int8', 'uint8'}:
-        raise ValueError('Unsupported out_dtype: %s' % out_dtype)
-    min_calib_range = attrs.get_float('min_calib_range', 0.0)
-    max_calib_range = attrs.get_float('max_calib_range', 0.0)
-    quantized_output, _, _ = \
-        quantize_mxnet_min_max(inputs[0],
-                               min_range=min_calib_range,
-                               max_range=max_calib_range,
-                               out_dtype=out_dtype)
+    if out_dtype not in {"int8", "uint8"}:
+        raise ValueError("Unsupported out_dtype: %s" % out_dtype)
+    min_calib_range = attrs.get_float("min_calib_range", 0.0)
+    max_calib_range = attrs.get_float("max_calib_range", 0.0)
+    quantized_output, _, _ = quantize_mxnet_min_max(
+        inputs[0], min_range=min_calib_range, max_range=max_calib_range, out_dtype=out_dtype
+    )
     return quantized_output, min_calib_range, max_calib_range
 
 
@@ -1550,18 +1600,17 @@ def _qnn_contrib_quantized_fifo_buffer(inputs, attrs, params):
     params[buffer_name] = _nd.array(np.zeros(buffer_shape).astype(data_dtype))
     new_buffer = relay.var(buffer_name, relay.TensorType(buffer_shape, data_dtype))
     inputs[1] = new_buffer
-    res = _op.nn.fifo_buffer(data=data, buffer=new_buffer, axis=attrs.get_int('axis'))
+    res = _op.nn.fifo_buffer(data=data, buffer=new_buffer, axis=attrs.get_int("axis"))
     return res, min_calib_range, max_calib_range
 
 
 def _get_subgraph_op(subgraphs, op_name):
-    assert len(subgraphs) == 1, \
-        "Subgraph should have 1 node but has {}".format(len(subgraphs))
+    assert len(subgraphs) == 1, "Subgraph should have 1 node but has {}".format(len(subgraphs))
     subgraph = subgraphs[0]
-    nodes = subgraph['nodes']
+    nodes = subgraph["nodes"]
     assert nodes is not None
     for node in nodes:
-        if node['op'] == op_name:
+        if node["op"] == op_name:
             return node
     raise ValueError("Op {} was not found in the subgraph".format(op_name))
 
@@ -1569,39 +1618,41 @@ def _get_subgraph_op(subgraphs, op_name):
 def _qnn_conv(inputs, attrs, subgraphs, params):
     def _has_fused_activation(_attrs, _supported_activations):
         has_fused_activation = False
-        if attrs.get_bool('with_act', False) or attrs.get_bool('with_postsum_act', False):
-            subgraph_activation_attrs = _get_subgraph_op(subgraphs, 'Activation')['attrs']
-            act_type = subgraph_activation_attrs['act_type']
+        if attrs.get_bool("with_act", False) or attrs.get_bool("with_postsum_act", False):
+            subgraph_activation_attrs = _get_subgraph_op(subgraphs, "Activation")["attrs"]
+            act_type = subgraph_activation_attrs["act_type"]
             if act_type not in _supported_activations:
-                raise ValueError('Fused activation {} is not supported at '
-                                 'this time'.format(act_type))
+                raise ValueError(
+                    "Fused activation {} is not supported at " "this time".format(act_type)
+                )
             has_fused_activation = True
         return has_fused_activation
 
-    def _get_data_scale_and_zp(_data, _inputs,
-                               _data_min_idx, _data_max_idx):
+    def _get_data_scale_and_zp(_data, _inputs, _data_min_idx, _data_max_idx):
         """ Finds the Qnn params for the data expr. """
         data_min = _inputs[_data_min_idx]
         data_max = _inputs[_data_max_idx]
         assert data_min <= data_max
         data_dtype = _infer_type(_data).checked_type.dtype
-        assert data_dtype in {'int8', 'uint8'}
+        assert data_dtype in {"int8", "uint8"}
         if data_min < 0.0:
-            assert data_dtype == 'int8', \
-                "Expect int8 when data_min < 0.0, consider quantize model with int8."
-        _data_scale = get_mkldnn_uint8_scale(data_min, data_max)\
-            if data_dtype == 'uint8' \
+            assert (
+                data_dtype == "int8"
+            ), "Expect int8 when data_min < 0.0, consider quantize model with int8."
+        _data_scale = (
+            get_mkldnn_uint8_scale(data_min, data_max)
+            if data_dtype == "uint8"
             else get_mkldnn_int8_scale(data_min, data_max)
+        )
         _data_zero_point = 0
         return _data_scale, _data_zero_point
 
-    def _get_bn_alpha_coeff(_bn_gamma_idx, _bn_beta_idx,
-                            _bn_running_mean_idx, _bn_running_var_idx):
+    def _get_bn_alpha_coeff(_bn_gamma_idx, _bn_beta_idx, _bn_running_mean_idx, _bn_running_var_idx):
         """ Extract the BN coeff. These will be use later for BN folding into convolution. """
         # Extract relevant attrs from bn.
-        bn_attrs = _get_subgraph_op(subgraphs, 'BatchNorm')['attrs']
-        bn_epsilon_param = float(bn_attrs['eps'])
-        bn_scale_param = bn_attrs['fix_gamma'] == "False"
+        bn_attrs = _get_subgraph_op(subgraphs, "BatchNorm")["attrs"]
+        bn_epsilon_param = float(bn_attrs["eps"])
+        bn_scale_param = bn_attrs["fix_gamma"] == "False"
         bn_center_param = True
 
         # Extract the relevant relay expressions.
@@ -1628,7 +1679,7 @@ def _qnn_conv(inputs, attrs, subgraphs, params):
         """ Fold BN into kernel and bias. Get new kernel and bias. """
         _kernel = inputs[1]
         if _bn_scale:
-            assert attrs.get_bool('with_bn', False)
+            assert attrs.get_bool("with_bn", False)
             # Weights are on OIHW, and _bn_scale is in O.
             exp_bn_scale = relay.expand_dims(_bn_scale, axis=1, num_newaxis=3)
             _kernel = relay.multiply(exp_bn_scale, _kernel)
@@ -1655,41 +1706,47 @@ def _qnn_conv(inputs, attrs, subgraphs, params):
         np_bias = None
         if _bias is not None:
             np_bias = _infer_value(_bias, params).asnumpy()
-        return quantize_conv_weights_bias_channel_mkldnn_from_var(_kernel,
-                                                                  np_bias,
-                                                                  kernel_channel_min,
-                                                                  kernel_channel_max,
-                                                                  _data_scale)
-
-    def _get_qnn_conv2d(_data, _kernel, _data_zero_point,
-                        _kernel_zero_point, _data_scale,
-                        _kernel_vector_scale, _conv2d_attrs):
+        return quantize_conv_weights_bias_channel_mkldnn_from_var(
+            _kernel, np_bias, kernel_channel_min, kernel_channel_max, _data_scale
+        )
+
+    def _get_qnn_conv2d(
+        _data,
+        _kernel,
+        _data_zero_point,
+        _kernel_zero_point,
+        _data_scale,
+        _kernel_vector_scale,
+        _conv2d_attrs,
+    ):
         return relay.qnn.op.conv2d(
             _data,
             _kernel,
-            input_zero_point=relay.const(_data_zero_point, 'int32'),
-            kernel_zero_point=relay.const(_kernel_zero_point, 'int32'),
-            input_scale=relay.const(_data_scale, 'float32'),
+            input_zero_point=relay.const(_data_zero_point, "int32"),
+            kernel_zero_point=relay.const(_kernel_zero_point, "int32"),
+            input_scale=relay.const(_data_scale, "float32"),
             kernel_scale=relay.const(_kernel_vector_scale),
-            channels=_conv2d_attrs['channels'],
-            groups=_conv2d_attrs['groups'],
-            kernel_size=_conv2d_attrs['kernel_size'],
-            strides=_conv2d_attrs['strides'],
-            dilation=_conv2d_attrs['dilation'],
-            padding=_conv2d_attrs['padding'],
-            data_layout=_conv2d_attrs['data_layout'],
-            kernel_layout=_conv2d_attrs['kernel_layout'])
+            channels=_conv2d_attrs["channels"],
+            groups=_conv2d_attrs["groups"],
+            kernel_size=_conv2d_attrs["kernel_size"],
+            strides=_conv2d_attrs["strides"],
+            dilation=_conv2d_attrs["dilation"],
+            padding=_conv2d_attrs["padding"],
+            data_layout=_conv2d_attrs["data_layout"],
+            kernel_layout=_conv2d_attrs["kernel_layout"],
+        )
 
     def _get_requantized_op(_res, _input_scale, _output_scale, _out_dtype):
         # Requantize to get the output back
         return relay.qnn.op.requantize(
             _res,
             input_scale=relay.const(_input_scale),
-            input_zero_point=relay.const(0, 'int32'),
-            output_scale=relay.const(_output_scale, 'float32'),
-            output_zero_point=relay.const(0, 'int32'),
+            input_zero_point=relay.const(0, "int32"),
+            output_scale=relay.const(_output_scale, "float32"),
+            output_zero_point=relay.const(0, "int32"),
             axis=1,
-            out_dtype=_out_dtype)
+            out_dtype=_out_dtype,
+        )
 
     def _get_sum(_res, _output_scale, out_dtype):
         """ Handles sum of the second quantized tensor. """
@@ -1699,53 +1756,61 @@ def _qnn_conv(inputs, attrs, subgraphs, params):
         #   2) Call normal add
         #   3) Depending on final out_dtype, clip and cast (basically requantize).
 
-        _output_scale = relay.const(_output_scale, 'float32')
+        _output_scale = relay.const(_output_scale, "float32")
         data_sum = inputs[-5]
         data_sum_min = inputs[-2]
         data_sum_max = inputs[-1]
 
         data_sum_dtype = _infer_type(data_sum).checked_type.dtype
-        data_sum_scale = \
-            get_mkldnn_uint8_scale(data_sum_min, data_sum_max) if data_sum_dtype == 'uint8' \
+        data_sum_scale = (
+            get_mkldnn_uint8_scale(data_sum_min, data_sum_max)
+            if data_sum_dtype == "uint8"
             else get_mkldnn_int8_scale(data_sum_min, data_sum_max)
-        data_sum_scale = relay.const(data_sum_scale, 'float32')
-        zero_point = relay.const(0, 'int32')
+        )
+        data_sum_scale = relay.const(data_sum_scale, "float32")
+        zero_point = relay.const(0, "int32")
 
         # Save one requantize if the previous expr already has a requantize node. This also
         # improves accuracy slightly.
         if isinstance(data_sum, _expr.Call) and data_sum.op.name == "qnn.requantize":
             prev_input, prev_scale, prev_zero_point = data_sum.args[0:3]
             prev_axis = data_sum.attrs.axis
-            data_sum = relay.qnn.op.requantize(prev_input,
-                                               input_scale=prev_scale,
-                                               input_zero_point=prev_zero_point,
-                                               output_scale=_output_scale,
-                                               output_zero_point=zero_point,
-                                               axis=prev_axis,
-                                               out_dtype='int32')
+            data_sum = relay.qnn.op.requantize(
+                prev_input,
+                input_scale=prev_scale,
+                input_zero_point=prev_zero_point,
+                output_scale=_output_scale,
+                output_zero_point=zero_point,
+                axis=prev_axis,
+                out_dtype="int32",
+            )
         else:
-            data_sum = relay.qnn.op.requantize(data_sum,
-                                               input_scale=data_sum_scale,
-                                               input_zero_point=zero_point,
-                                               output_scale=_output_scale,
-                                               output_zero_point=zero_point,
-                                               out_dtype='int32')
+            data_sum = relay.qnn.op.requantize(
+                data_sum,
+                input_scale=data_sum_scale,
+                input_zero_point=zero_point,
+                output_scale=_output_scale,
+                output_zero_point=zero_point,
+                out_dtype="int32",
+            )
 
         # 2) Add two int32 tensors.
         _res = relay.add(_res, data_sum)
 
         # 3) Clip/cast to change the out dtype.
-        _res = relay.clip(_res,
-                          a_min=float(tvm.tir.op.min_value(out_dtype).value),
-                          a_max=float(tvm.tir.op.max_value(out_dtype).value))
+        _res = relay.clip(
+            _res,
+            a_min=float(tvm.tir.op.min_value(out_dtype).value),
+            a_max=float(tvm.tir.op.max_value(out_dtype).value),
+        )
         _res = relay.cast(_res, out_dtype)
         return _res
 
     def _parse():
         assert len(subgraphs) == 1
-        subgraph_conv_attrs = StrAttrsDict(_get_subgraph_op(subgraphs, 'Convolution')['attrs'])
+        subgraph_conv_attrs = StrAttrsDict(_get_subgraph_op(subgraphs, "Convolution")["attrs"])
 
-        is_quantized = attrs.get_bool('quantized', False)
+        is_quantized = attrs.get_bool("quantized", False)
         if is_quantized:
             # The MKLDNN has a quantized convolution subgraph. There are many different arguments
             # that are taken into account to parse the subgraph.
@@ -1767,8 +1832,8 @@ def _qnn_conv(inputs, attrs, subgraphs, params):
             #   6) Handle sum of quantized tensors if needed. Or just Requantize.
 
             has_bias = not subgraph_conv_attrs.get_bool("no_bias", False)
-            has_sum = attrs.get_bool('with_sum', False)
-            has_bn = attrs.get_bool('with_bn', False)
+            has_sum = attrs.get_bool("with_sum", False)
+            has_bn = attrs.get_bool("with_bn", False)
 
             ###############################################
             #   1) Get the input data scale and zero point.
@@ -1782,9 +1847,9 @@ def _qnn_conv(inputs, attrs, subgraphs, params):
                 data_max_idx = -3
 
             data = inputs[0]
-            data_scale, data_zero_point = \
-                _get_data_scale_and_zp(data, inputs, data_min_idx, data_max_idx)
-
+            data_scale, data_zero_point = _get_data_scale_and_zp(
+                data, inputs, data_min_idx, data_max_idx
+            )
 
             #############################
             #   2) Extract the BN params.
@@ -1802,10 +1867,9 @@ def _qnn_conv(inputs, attrs, subgraphs, params):
                 bn_running_mean_idx = bn_start_idx + 2
                 bn_running_var_idx = bn_start_idx + 3
 
-                bn_scale, bn_shift = _get_bn_alpha_coeff(bn_gamma_idx,
-                                                         bn_beta_idx,
-                                                         bn_running_mean_idx,
-                                                         bn_running_var_idx)
+                bn_scale, bn_shift = _get_bn_alpha_coeff(
+                    bn_gamma_idx, bn_beta_idx, bn_running_mean_idx, bn_running_var_idx
+                )
 
             ########################################
             #   3) Fold the BN into kernel and bias.
@@ -1815,15 +1879,23 @@ def _qnn_conv(inputs, attrs, subgraphs, params):
             #######################################################################
             #   4) Fold BN params into kernel. Get quantized kernel and QNN params.
             #######################################################################
-            kernel, kernel_vector_scale, kernel_zero_point = _get_quantized_kernel(kernel, bias,
-                                                                                   data_scale)
+            kernel, kernel_vector_scale, kernel_zero_point = _get_quantized_kernel(
+                kernel, bias, data_scale
+            )
 
             ##########################
             #   5) Call QNN conv2d op.
             ##########################
             conv2d_attrs = _get_mx_conv2d_attrs(subgraph_conv_attrs)
-            res = _get_qnn_conv2d(data, kernel, data_zero_point, kernel_zero_point, data_scale,
-                                  kernel_vector_scale, conv2d_attrs)
+            res = _get_qnn_conv2d(
+                data,
+                kernel,
+                data_zero_point,
+                kernel_zero_point,
+                data_scale,
+                kernel_vector_scale,
+                conv2d_attrs,
+            )
 
             ###############################################
             #   6) Fold BN params into bias. Call bias_add.
@@ -1836,19 +1908,20 @@ def _qnn_conv(inputs, attrs, subgraphs, params):
             #####################################################################
             #   7) Handle sum of quantized tensors if needed. Or just Requantize.
             #####################################################################
-            min_output_range = attrs.get_float('min_calib_range')
-            max_output_range = attrs.get_float('max_calib_range')
-            output_scale, out_dtype = get_conv_mkldnn_requantized_scale_outDtype(min_output_range,
-                                                                                 max_output_range)
+            min_output_range = attrs.get_float("min_calib_range")
+            max_output_range = attrs.get_float("max_calib_range")
+            output_scale, out_dtype = get_conv_mkldnn_requantized_scale_outDtype(
+                min_output_range, max_output_range
+            )
 
             # QNN conv2d output scale is the product of data_scale and kernel_vector_scale
             input_scale = data_scale * kernel_vector_scale
-            if attrs.get_bool('with_sum', False):
+            if attrs.get_bool("with_sum", False):
                 # There is a second tensor that has to be added to the QNN conv2d output. Therefore,
                 # the QNN conv2d is first requantized to output scale with int32 precision. The
                 # second tensor will also be requantized to output scale with int32 precision,
                 # followed by an add operator.
-                res = _get_requantized_op(res, input_scale, output_scale, 'int32')
+                res = _get_requantized_op(res, input_scale, output_scale, "int32")
                 res = _get_sum(res, output_scale, out_dtype)
             else:
                 # Get the requantized conv output
@@ -1857,7 +1930,7 @@ def _qnn_conv(inputs, attrs, subgraphs, params):
             return res, min_output_range, max_output_range
         else:
             res = _mx_conv(inputs, subgraph_conv_attrs)
-            has_fused_relu = _has_fused_activation(attrs, ['relu'])
+            has_fused_relu = _has_fused_activation(attrs, ["relu"])
             if has_fused_relu:
                 res = _op.nn.relu(res)
             return res
@@ -1866,7 +1939,7 @@ def _qnn_conv(inputs, attrs, subgraphs, params):
 
 
 def _qnn_flatten(inputs, attrs):
-    #pylint: disable=unused-argument
+    # pylint: disable=unused-argument
     data = inputs[0]
     output_min = inputs[1]
     output_max = inputs[2]
@@ -1875,7 +1948,7 @@ def _qnn_flatten(inputs, attrs):
 
 
 def _qnn_dequantize(inputs, attrs):
-    #pylint: disable=unused-argument
+    # pylint: disable=unused-argument
     data = inputs[0]
     input_min = inputs[1]
     input_max = inputs[2]
@@ -1901,10 +1974,10 @@ def _qnn_pooling(inputs, attrs):
     data = inputs[0]
     data_dtype = _infer_type(data).checked_type.dtype
     pool_type = attrs.get_str("pool_type")
-    if data_dtype in ('int8', 'uint8') and pool_type != 'max':
-        data = _op.cast(data, 'int32')
+    if data_dtype in ("int8", "uint8") and pool_type != "max":
+        data = _op.cast(data, "int32")
     res = _mx_pooling([data, input_min, input_max], attrs)
-    if data_dtype in ('int8', 'uint8') and pool_type != 'max':
+    if data_dtype in ("int8", "uint8") and pool_type != "max":
         res = _op.cast(res, data_dtype)
     return res, input_min, input_max
 
@@ -1917,37 +1990,41 @@ def _qnn_batch_norm(inputs, attrs):
     data_min_idx, data_max_idx = (-2, -1)
     data_min, data_max = inputs[data_min_idx], inputs[data_max_idx]
     data_dtype = _infer_type(data).checked_type.dtype
-    data_scale = get_mkldnn_uint8_scale(data_min, data_max) if data_dtype == 'uint8' \
+    data_scale = (
+        get_mkldnn_uint8_scale(data_min, data_max)
+        if data_dtype == "uint8"
         else get_mkldnn_int8_scale(data_min, data_max)
+    )
     data_zp = 0
-    data = relay.qnn.op.dequantize(data,
-                                   relay.const(data_scale, 'float32'),
-                                   relay.const(data_zp, 'int32'))
+    data = relay.qnn.op.dequantize(
+        data, relay.const(data_scale, "float32"), relay.const(data_zp, "int32")
+    )
 
     # Run BN. The last 4 inputs are same as before.
     new_inputs = [data, *inputs[1:5]]
     res = _mx_batch_norm(new_inputs, attrs)
 
     # Quantize the result
-    min_output_range = attrs.get_float('min_calib_range')
-    max_output_range = attrs.get_float('max_calib_range')
-    output_scale, out_dtype = get_conv_mkldnn_requantized_scale_outDtype(min_output_range,
-                                                                         max_output_range)
-    res = relay.qnn.op.quantize(res[0],
-                                relay.const(output_scale, 'float32'),
-                                relay.const(0, 'int32'),
-                                out_dtype=out_dtype)
+    min_output_range = attrs.get_float("min_calib_range")
+    max_output_range = attrs.get_float("max_calib_range")
+    output_scale, out_dtype = get_conv_mkldnn_requantized_scale_outDtype(
+        min_output_range, max_output_range
+    )
+    res = relay.qnn.op.quantize(
+        res[0], relay.const(output_scale, "float32"), relay.const(0, "int32"), out_dtype=out_dtype
+    )
     return res, min_output_range, max_output_range
 
 
 def _qnn_fully_connected(inputs, attrs, subgraphs, params):
-
     def _get_input_scale_zp(_data_dtype, _inputs, _has_bias):
         data_min_idx, data_max_idx = (3, 4) if _has_bias else (2, 3)
         data_min, data_max = _inputs[data_min_idx], _inputs[data_max_idx]
-        _data_scale = get_mkldnn_uint8_scale(data_min, data_max) \
-            if _data_dtype == 'uint8' \
+        _data_scale = (
+            get_mkldnn_uint8_scale(data_min, data_max)
+            if _data_dtype == "uint8"
             else get_mkldnn_int8_scale(data_min, data_max)
+        )
         _data_zp = 0
         return _data_scale, _data_zp
 
@@ -1955,8 +2032,9 @@ def _qnn_fully_connected(inputs, attrs, subgraphs, params):
         kernel_dtype = _infer_type(_kernel).checked_type.dtype
 
         if kernel_dtype != "int8":
-            raise tvm.error.OpNotImplemented(\
-                "Tensor wise quantized expects weights in int8 data type")
+            raise tvm.error.OpNotImplemented(
+                "Tensor wise quantized expects weights in int8 data type"
+            )
 
         if isinstance(_kernel, tvm.relay.Call) and _kernel.op.name == "qnn.quantize":
             _kernel_scale = _kernel.args[1].data.asnumpy()
@@ -1968,37 +2046,38 @@ def _qnn_fully_connected(inputs, attrs, subgraphs, params):
         kernel_min = params[kernel_min_name].asnumpy()[0]
         kernel_max_name = _get_name(_inputs[kernel_max_idx])
         kernel_max = params[kernel_max_name].asnumpy()[0]
-        _kernel_scale = get_mkldnn_uint8_scale(kernel_min, kernel_max) \
-            if kernel_dtype == 'uint8' \
+        _kernel_scale = (
+            get_mkldnn_uint8_scale(kernel_min, kernel_max)
+            if kernel_dtype == "uint8"
             else get_mkldnn_int8_scale(kernel_min, kernel_max)
+        )
         _kernel_zp = 0
         return _kernel_scale, _kernel_zp
 
     def _get_kernel_scale_zp_channel_quantized(_kernel, _bias, _data_scale):
         kernel_dtype = _infer_type(_kernel).checked_type.dtype
         if kernel_dtype != "float32":
-            raise tvm.error.OpNotImplemented(\
-                "Channel wise quantized expects weights in float32 data type")
+            raise tvm.error.OpNotImplemented(
+                "Channel wise quantized expects weights in float32 data type"
+            )
 
         # Get the FP32 values, calculate min/max and then channel quantize them
         np_kernel = _infer_value(_kernel, params).asnumpy()
-        kernel_channel_min = np.amin(np_kernel, axis=(1, ))
-        kernel_channel_max = np.amax(np_kernel, axis=(1, ))
+        kernel_channel_min = np.amin(np_kernel, axis=(1,))
+        kernel_channel_max = np.amax(np_kernel, axis=(1,))
 
         np_bias = None
         if _bias is not None:
             np_bias = _infer_value(_bias, params).asnumpy()
-        return quantize_conv_weights_bias_channel_mkldnn_from_var(_kernel,
-                                                                  np_bias,
-                                                                  kernel_channel_min,
-                                                                  kernel_channel_max,
-                                                                  _data_scale)
+        return quantize_conv_weights_bias_channel_mkldnn_from_var(
+            _kernel, np_bias, kernel_channel_min, kernel_channel_max, _data_scale
+        )
 
     def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
         _bias = _inputs[2]
         if isinstance(_bias, tvm.relay.Call) and _bias.op.name == "qnn.quantize":
             _bias_scale = _bias.args[1].data.asnumpy()
-            _bias_requantize_scale = _bias_scale/(_data_scale * _kernel_scale)
+            _bias_requantize_scale = _bias_scale / (_data_scale * _kernel_scale)
             _bias_requantize_scale = _expr.const(_bias_requantize_scale, dtype="float32")
             return _bias_requantize_scale
 
@@ -2007,13 +2086,13 @@ def _qnn_fully_connected(inputs, attrs, subgraphs, params):
         bias_max_name = _get_name(_inputs[8])
         bias_max = params[bias_max_name].asnumpy()[0]
         bias_scale = get_mkldnn_int8_scale(bias_min, bias_max)
-        _bias_requantize_scale = bias_scale/(_data_scale * _kernel_scale)
+        _bias_requantize_scale = bias_scale / (_data_scale * _kernel_scale)
         _bias_requantize_scale = _expr.const(_bias_requantize_scale, dtype="float32")
         return _bias_requantize_scale
 
-    is_quantized = attrs.get_bool('quantized', False)
-    with_relu = attrs.get_bool('with_relu', False)
-    subgraph_dense_attrs = StrAttrsDict(_get_subgraph_op(subgraphs, "FullyConnected")['attrs'])
+    is_quantized = attrs.get_bool("quantized", False)
+    with_relu = attrs.get_bool("with_relu", False)
+    subgraph_dense_attrs = StrAttrsDict(_get_subgraph_op(subgraphs, "FullyConnected")["attrs"])
     if not is_quantized:
         res = _mx_fully_connected(inputs, subgraph_dense_attrs)
         if with_relu:
@@ -2023,8 +2102,8 @@ def _qnn_fully_connected(inputs, attrs, subgraphs, params):
         has_bias = not subgraph_dense_attrs.get_bool("no_bias", False)
         units = subgraph_dense_attrs.get_int("num_hidden")
         is_flatten = subgraph_dense_attrs.get_bool("flatten", True)
-        enable_float_output = attrs.get_bool('enable_float_output', False)
-        is_channel_quantized = attrs.get_bool('channel_wise_quantize', False)
+        enable_float_output = attrs.get_bool("enable_float_output", False)
+        is_channel_quantized = attrs.get_bool("channel_wise_quantize", False)
 
         ########################
         # Get data, kernel, bias
@@ -2053,23 +2132,26 @@ def _qnn_fully_connected(inputs, attrs, subgraphs, params):
         # Get weight scale and zero point
         #################################
         if is_channel_quantized:
-            kernel, kernel_scale, kernel_zp = _get_kernel_scale_zp_channel_quantized(kernel,
-                                                                                     bias,
-                                                                                     data_scale)
+            kernel, kernel_scale, kernel_zp = _get_kernel_scale_zp_channel_quantized(
+                kernel, bias, data_scale
+            )
         else:
-            kernel_scale, kernel_zp = _get_kernel_scale_zp_tensor_quantized(kernel, inputs,
-                                                                            has_bias)
+            kernel_scale, kernel_zp = _get_kernel_scale_zp_tensor_quantized(
+                kernel, inputs, has_bias
+            )
 
         ################
         # Call QNN dense
         ################
-        res = relay.qnn.op.dense(data,
-                                 kernel,
-                                 input_zero_point=relay.const(data_zp, 'int32'),
-                                 kernel_zero_point=relay.const(kernel_zp, 'int32'),
-                                 input_scale=relay.const(data_scale, 'float32'),
-                                 kernel_scale=relay.const(kernel_scale, 'float32'),
-                                 units=units)
+        res = relay.qnn.op.dense(
+            data,
+            kernel,
+            input_zero_point=relay.const(data_zp, "int32"),
+            kernel_zero_point=relay.const(kernel_zp, "int32"),
+            input_scale=relay.const(data_scale, "float32"),
+            kernel_scale=relay.const(kernel_scale, "float32"),
+            units=units,
+        )
 
         #################
         # Handle bias add
@@ -2081,15 +2163,17 @@ def _qnn_fully_connected(inputs, attrs, subgraphs, params):
                 res = _op.nn.bias_add(res, int32_bias, axis=-1)
             else:
                 bias_data = inputs[2]
-                bias_requantize_scale = \
-                    _get_bias_requantize_scale(inputs, data_scale, kernel_scale)
-                multiplied_bias = \
-                    _op.multiply(_op.cast(bias_data, 'float32'), bias_requantize_scale)
+                bias_requantize_scale = _get_bias_requantize_scale(inputs, data_scale, kernel_scale)
+                multiplied_bias = _op.multiply(
+                    _op.cast(bias_data, "float32"), bias_requantize_scale
+                )
                 rounded_bias = _op.round(multiplied_bias)
-                clipped_bias = _op.clip(rounded_bias,
-                                        a_min=tvm.tir.op.min_value('int32').value,
-                                        a_max=tvm.tir.op.max_value('int32').value)
-                requantized_bias = _op.cast(clipped_bias, 'int32')
+                clipped_bias = _op.clip(
+                    rounded_bias,
+                    a_min=tvm.tir.op.min_value("int32").value,
+                    a_max=tvm.tir.op.max_value("int32").value,
+                )
+                requantized_bias = _op.cast(clipped_bias, "int32")
                 res = _op.nn.bias_add(res, requantized_bias, axis=-1)
 
         ##############################################
@@ -2097,35 +2181,35 @@ def _qnn_fully_connected(inputs, attrs, subgraphs, params):
         ##############################################
         if enable_float_output:
             output_scale = np.float32(data_scale * kernel_scale)
-            res = relay.qnn.op.dequantize(res,
-                                          relay.const(output_scale),
-                                          input_zero_point=relay.const(0, 'int32'),
-                                          axis=1)
+            res = relay.qnn.op.dequantize(
+                res, relay.const(output_scale), input_zero_point=relay.const(0, "int32"), axis=1
+            )
             if with_relu:
                 res = _op.nn.relu(res)
         else:
 
             if is_channel_quantized:
-                raise tvm.error.OpNotImplemented(\
-                    "Channel wise quantized dense with non float output is not supported yet")
-            out_dtype = 'uint8' if attrs.get_bool('with_relu', False) else 'int8'
+                raise tvm.error.OpNotImplemented(
+                    "Channel wise quantized dense with non float output is not supported yet"
+                )
+            out_dtype = "uint8" if attrs.get_bool("with_relu", False) else "int8"
             input_scale = np.float32(data_scale * kernel_scale)
-            min_output_range = attrs.get_float('min_calib_range')
-            max_output_range = attrs.get_float('max_calib_range')
-            output_scale = get_mkldnn_requantize_scale_outDtype(min_output_range,
-                                                                max_output_range,
-                                                                out_dtype)
+            min_output_range = attrs.get_float("min_calib_range")
+            max_output_range = attrs.get_float("max_calib_range")
+            output_scale = get_mkldnn_requantize_scale_outDtype(
+                min_output_range, max_output_range, out_dtype
+            )
             res = relay.qnn.op.requantize(
                 res,
-                input_scale=relay.const(input_scale, 'float32'),
-                input_zero_point=relay.const(0, 'int32'),
-                output_scale=relay.const(output_scale, 'float32'),
-                output_zero_point=relay.const(0, 'int32'),
-                out_dtype=out_dtype)
+                input_scale=relay.const(input_scale, "float32"),
+                input_zero_point=relay.const(0, "int32"),
+                output_scale=relay.const(output_scale, "float32"),
+                output_zero_point=relay.const(0, "int32"),
+                out_dtype=out_dtype,
+            )
             if with_relu:
                 res = _op.nn.relu(res)
 
-
         ##############################
         # Handle for shape of data > 2
         ##############################
@@ -2138,6 +2222,7 @@ def _qnn_fully_connected(inputs, attrs, subgraphs, params):
             return res
         return res, min_output_range, max_output_range
 
+
 def _mx_broadcast_to(inputs, attrs):
     data = inputs[0]
     tgt_shape = attrs.get_int_tuple("shape", [])
@@ -2161,6 +2246,7 @@ def _mx_broadcast_logical(logical_op):
         rhs = _op.cast(inputs[1], "bool") if rhs_type != "bool" else inputs[1]
 
         return _op.cast(logical_op(lhs, rhs), lhs_type)
+
     return impl
 
 
@@ -2172,27 +2258,24 @@ def _mx_npi_transpose(inputs, attrs):
 
 
 def _mx_npi_pad(inputs, attrs):
-    pad_mode = attrs.get_str('mode', None)
+    pad_mode = attrs.get_str("mode", None)
     if pad_mode is None:
-        raise tvm.error.OpAttributeRequired(
-            'Attribute "mode" not found in operator pad.')
-    if pad_mode not in ['constant', 'edge', 'reflect']:
-        raise tvm.error.OpAttributeInvalid(
-            'Value ' + mode + ' in attribute "mode" is not valid')
-    pad_width = attrs.get_int_tuple('pad_width', None)
+        raise tvm.error.OpAttributeRequired('Attribute "mode" not found in operator pad.')
+    if pad_mode not in ["constant", "edge", "reflect"]:
+        raise tvm.error.OpAttributeInvalid("Value " + mode + ' in attribute "mode" is not valid')
+    pad_width = attrs.get_int_tuple("pad_width", None)
     if pad_width is None:
-        raise tvm.error.OpAttributeRequired(
-            'Attribute "pad_width" not found in operator pad.')
+        raise tvm.error.OpAttributeRequired('Attribute "pad_width" not found in operator pad.')
     if None in pad_width:
         raise tvm.error.OpAttributeInvalid(
-            'Value None in attribute "pad_width" of operator Slice is not valid.')
-    constant_values = attrs.get_float('constant_values', 0.0)
+            'Value None in attribute "pad_width" of operator pad is not valid.'
+        )
+    constant_values = attrs.get_float("constant_values", 0.0)
     padding = tuple(tuple((b, a)) for b, a in zip(pad_width[::2], pad_width[1::2]))
 
-    return _op.nn.pad(data=inputs[0],
-                      pad_width=padding,
-                      pad_value=constant_values,
-                      pad_mode=pad_mode)
+    return _op.nn.pad(
+        data=inputs[0], pad_width=padding, pad_value=constant_values, pad_mode=pad_mode
+    )
 
 
 def _mx_npi_concatenate(inputs, attrs):
@@ -2220,8 +2303,7 @@ def _mx_npx_reshape(inputs, attrs):
         elif num == -6:
             new_shape_list.append(-4)
         else:
-            raise tvm.error.OpAttributeInvalid(
-                'Shape dimension %d is not supported' % num)
+            raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % num)
     shape = tuple(new_shape_list)
     if reverse:
         return _op.reverse_reshape(inputs[0], newshape=shape)
@@ -2279,196 +2361,196 @@ _identity_list = [
 ]
 
 _convert_map = {
-    "_copy"                  : _rename(_op.copy),
-    "relu"                   : _rename(_op.nn.relu),
-    "broadcast_add"          : _rename(_op.add),
-    "broadcast_plus"         : _rename(_op.add),
-    "broadcast_sub"          : _rename(_op.subtract),
-    "broadcast_minus"        : _rename(_op.subtract),
-    "broadcast_mul"          : _rename(_op.multiply),
-    "broadcast_div"          : _rename(_op.divide),
-    "broadcast_mod"          : _rename(_op.mod),
-    "broadcast_maximum"      : _rename(_op.maximum),
-    "broadcast_minimum"      : _rename(_op.minimum),
-    "broadcast_power"        : _rename(_op.power),
-    "arccos"                 : _rename(_op.acos),
-    "arcsin"                 : _rename(_op.asin),
-    "arctan"                 : _rename(_op.atan),
-    "arccosh"                : _rename(_op.acosh),
-    "arcsinh"                : _rename(_op.asinh),
-    "arctanh"                : _rename(_op.atanh),
-    "broadcast_equal"        : _mx_compare(_op.equal, _rename),
-    "broadcast_not_equal"    : _mx_compare(_op.not_equal, _rename),
-    "broadcast_greater"      : _mx_compare(_op.greater, _rename),
+    "_copy": _rename(_op.copy),
+    "relu": _rename(_op.nn.relu),
+    "broadcast_add": _rename(_op.add),
+    "broadcast_plus": _rename(_op.add),
+    "broadcast_sub": _rename(_op.subtract),
+    "broadcast_minus": _rename(_op.subtract),
+    "broadcast_mul": _rename(_op.multiply),
+    "broadcast_div": _rename(_op.divide),
+    "broadcast_mod": _rename(_op.mod),
+    "broadcast_maximum": _rename(_op.maximum),
+    "broadcast_minimum": _rename(_op.minimum),
+    "broadcast_power": _rename(_op.power),
+    "arccos": _rename(_op.acos),
+    "arcsin": _rename(_op.asin),
+    "arctan": _rename(_op.atan),
+    "arccosh": _rename(_op.acosh),
+    "arcsinh": _rename(_op.asinh),
+    "arctanh": _rename(_op.atanh),
+    "broadcast_equal": _mx_compare(_op.equal, _rename),
+    "broadcast_not_equal": _mx_compare(_op.not_equal, _rename),
+    "broadcast_greater": _mx_compare(_op.greater, _rename),
     "broadcast_greater_equal": _mx_compare(_op.greater_equal, _rename),
-    "broadcast_lesser"       : _mx_compare(_op.less, _rename),
-    "broadcast_lesser_equal" : _mx_compare(_op.less_equal, _rename),
-    "broadcast_logical_or"   : _mx_broadcast_logical(_op.logical_or),
-    "broadcast_logical_and"  : _mx_broadcast_logical(_op.logical_and),
-    "broadcast_logical_xor"  : _mx_broadcast_logical(_op.logical_xor),
-    "broadcast_to"           : _mx_broadcast_to,
-    "logical_not"            : _mx_logical_not,
-    "_equal"                 : _mx_compare(_op.equal, _rename),
-    "_not_equal"             : _mx_compare(_op.not_equal, _rename),
-    "_greater"               : _mx_compare(_op.greater, _rename),
-    "_greater_equal"         : _mx_compare(_op.greater_equal, _rename),
-    "_lesser"                : _mx_compare(_op.less, _rename),
-    "_lesser_equal"          : _mx_compare(_op.less_equal, _rename),
-    "elemwise_add"           : _rename(_op.add),
-    "elemwise_sub"           : _rename(_op.subtract),
-    "elemwise_mul"           : _rename(_op.multiply),
-    "elemwise_div"           : _rename(_op.divide),
-    "_maximum"               : _rename(_op.maximum),
-    "_minimum"               : _rename(_op.minimum),
-    "flatten"                : _rename(_op.nn.batch_flatten),
-    "Flatten"                : _rename(_op.nn.batch_flatten),
+    "broadcast_lesser": _mx_compare(_op.less, _rename),
+    "broadcast_lesser_equal": _mx_compare(_op.less_equal, _rename),
+    "broadcast_logical_or": _mx_broadcast_logical(_op.logical_or),
+    "broadcast_logical_and": _mx_broadcast_logical(_op.logical_and),
+    "broadcast_logical_xor": _mx_broadcast_logical(_op.logical_xor),
+    "broadcast_to": _mx_broadcast_to,
+    "logical_not": _mx_logical_not,
+    "_equal": _mx_compare(_op.equal, _rename),
+    "_not_equal": _mx_compare(_op.not_equal, _rename),
+    "_greater": _mx_compare(_op.greater, _rename),
+    "_greater_equal": _mx_compare(_op.greater_equal, _rename),
+    "_lesser": _mx_compare(_op.less, _rename),
+    "_lesser_equal": _mx_compare(_op.less_equal, _rename),
+    "elemwise_add": _rename(_op.add),
+    "elemwise_sub": _rename(_op.subtract),
+    "elemwise_mul": _rename(_op.multiply),
+    "elemwise_div": _rename(_op.divide),
+    "_maximum": _rename(_op.maximum),
+    "_minimum": _rename(_op.minimum),
+    "flatten": _rename(_op.nn.batch_flatten),
+    "Flatten": _rename(_op.nn.batch_flatten),
     # scalar power
-    "square"                 : _mx_make_power(2),
-    "rsqrt"                  : _mx_make_power(-1/2),
-    "cbrt"                   : _mx_make_power(1/3),
-    "rcbrt"                  : _mx_make_power(-1/3),
-    "__pow_scalar__"         : _binop_scalar(_op.power),
-    "_power_scalar"          : _binop_scalar(_op.power),
-    "__rsub_scalar__"        : _rbinop_scalar(_op.subtract),
-    "_rminus_scalar"         : _rbinop_scalar(_op.subtract),
-    "__rdiv_scalar__"        : _rbinop_scalar(_op.divide),
-    "_rdiv_scalar"           : _rbinop_scalar(_op.divide),
-    "__rpow_scalar__"        : _rbinop_scalar(_op.power),
+    "square": _mx_make_power(2),
+    "rsqrt": _mx_make_power(-1 / 2),
+    "cbrt": _mx_make_power(1 / 3),
+    "rcbrt": _mx_make_power(-1 / 3),
+    "__pow_scalar__": _binop_scalar(_op.power),
+    "_power_scalar": _binop_scalar(_op.power),
+    "__rsub_scalar__": _rbinop_scalar(_op.subtract),
+    "_rminus_scalar": _rbinop_scalar(_op.subtract),
+    "__rdiv_scalar__": _rbinop_scalar(_op.divide),
+    "_rdiv_scalar": _rbinop_scalar(_op.divide),
+    "__rpow_scalar__": _rbinop_scalar(_op.power),
     # scalar op
-    "__add_scalar__"         : _binop_scalar(_op.add),
-    "_plus_scalar"           : _binop_scalar(_op.add),
-    "__sub_scalar__"         : _binop_scalar(_op.subtract),
-    "_minus_scalar"          : _binop_scalar(_op.subtract),
-    "__mul_scalar__"         : _binop_scalar(_op.multiply),
-    "_mul_scalar"            : _binop_scalar(_op.multiply),
-    "__div_scalar__"         : _binop_scalar(_op.divide),
-    "_div_scalar"            : _binop_scalar(_op.divide),
-    "log2"                   : _mx_make_logarithm(2),
-    "log10"                  : _mx_make_logarithm(10),
-    "log1p"                  : _mx_log1p,
-    "expm1"                  : _mx_expm1,
-    "_equal_scalar"          : _mx_compare(_op.equal, _binop_scalar),
-    "_not_equal_scalar"      : _mx_compare(_op.not_equal, _binop_scalar),
-    "_greater_scalar"        : _mx_compare(_op.greater, _binop_scalar),
-    "_greater_equal_scalar"  : _mx_compare(_op.greater_equal, _binop_scalar),
-    "_lesser_scalar"         : _mx_compare(_op.less, _binop_scalar),
-    "_lesser_equal_scalar"   : _mx_compare(_op.less_equal, _binop_scalar),
-    "_maximum_scalar"        : _binop_scalar(_op.maximum),
-    "_minimum_scalar"        : _binop_scalar(_op.minimum),
+    "__add_scalar__": _binop_scalar(_op.add),
+    "_plus_scalar": _binop_scalar(_op.add),
+    "__sub_scalar__": _binop_scalar(_op.subtract),
+    "_minus_scalar": _binop_scalar(_op.subtract),
+    "__mul_scalar__": _binop_scalar(_op.multiply),
+    "_mul_scalar": _binop_scalar(_op.multiply),
+    "__div_scalar__": _binop_scalar(_op.divide),
+    "_div_scalar": _binop_scalar(_op.divide),
+    "log2": _mx_make_logarithm(2),
+    "log10": _mx_make_logarithm(10),
+    "log1p": _mx_log1p,
+    "expm1": _mx_expm1,
+    "_equal_scalar": _mx_compare(_op.equal, _binop_scalar),
+    "_not_equal_scalar": _mx_compare(_op.not_equal, _binop_scalar),
+    "_greater_scalar": _mx_compare(_op.greater, _binop_scalar),
+    "_greater_equal_scalar": _mx_compare(_op.greater_equal, _binop_scalar),
+    "_lesser_scalar": _mx_compare(_op.less, _binop_scalar),
+    "_lesser_equal_scalar": _mx_compare(_op.less_equal, _binop_scalar),
+    "_maximum_scalar": _binop_scalar(_op.maximum),
+    "_minimum_scalar": _binop_scalar(_op.minimum),
     # reduction ops
-    "mean"          : _reduce(_op.mean),
-    "max"           : _reduce(_op.max),
-    "min"           : _reduce(_op.min),
-    "sum"           : _reduce(_op.sum),
-    "max_axis"      : _reduce(_op.max),
-    "min_axis"      : _reduce(_op.min),
-    "sum_axis"      : _reduce(_op.sum),
-    "argmax"        : _arg_reduce(_op.argmax),
-    "argmin"        : _arg_reduce(_op.argmin),
+    "mean": _reduce(_op.mean),
+    "max": _reduce(_op.max),
+    "min": _reduce(_op.min),
+    "sum": _reduce(_op.sum),
+    "max_axis": _reduce(_op.max),
+    "min_axis": _reduce(_op.min),
+    "sum_axis": _reduce(_op.sum),
+    "argmax": _arg_reduce(_op.argmax),
+    "argmin": _arg_reduce(_op.argmin),
     # init ops
-    "_ones"         : _init_op(_op.ones),
+    "_ones": _init_op(_op.ones),
     # softmax
-    "softmax"       : _softmax_op(_op.nn.softmax),
-    "log_softmax"   : _softmax_op(_op.nn.log_softmax),
-    "Softmax"       : _softmax_op(_op.nn.softmax),
-    "softsign"      : _mx_softsign,
-    "softmin"       : _mx_softmin,
-    "hard_sigmoid"  : _mx_hard_sigmoid,
-    "reciprocal"    : _mx_reciprocal,
+    "softmax": _softmax_op(_op.nn.softmax),
+    "log_softmax": _softmax_op(_op.nn.log_softmax),
+    "Softmax": _softmax_op(_op.nn.softmax),
+    "softsign": _mx_softsign,
+    "softmin": _mx_softmin,
+    "hard_sigmoid": _mx_hard_sigmoid,
+    "reciprocal": _mx_reciprocal,
     # per op specialization
-    "Reshape"       : _reshape,
-    "reshape"       : _reshape,
-    "Cast"          : _cast,
-    "amp_cast"      : _cast,
-    "amp_multicast" : _mx_amp_multicast,
-    "clip"          : _clip,
-    "transpose"     : _transpose,
-    "UpSampling"    : _upsampling,
-    "add_n"         : _elemwise_sum,
+    "Reshape": _reshape,
+    "reshape": _reshape,
+    "Cast": _cast,
+    "amp_cast": _cast,
+    "amp_multicast": _mx_amp_multicast,
+    "clip": _clip,
+    "transpose": _transpose,
+    "UpSampling": _upsampling,
+    "add_n": _elemwise_sum,
     # MXNet specific implementations
-    "_zeros"        : _mx_zeros,
+    "_zeros": _mx_zeros,
     "FullyConnected": _mx_fully_connected,
-    "Activation"    : _mx_activations,
-    "Convolution"   : _mx_conv,
+    "Activation": _mx_activations,
+    "Convolution": _mx_conv,
     "Convolution_v1": _mx_conv2d,
-    "Deconvolution" : _mx_conv_transpose,
-    "Pooling"       : _mx_pooling,
-    "Pooling_v1"    : _mx_pooling,
-    "Dropout"       : _mx_dropout,
-    "BatchNorm"     : _mx_batch_norm,
-    "BatchNorm_v1"  : _mx_batch_norm,
-    "_contrib_SyncBatchNorm" : _mx_batch_norm,
-    "InstanceNorm"  : _mx_instance_norm,
-    "LayerNorm"     : _mx_layer_norm,
-    "LRN"           : _mx_lrn,
-    "L2Normalization"  : _mx_l2_normalize,
-    "slice"         : _mx_slice,
-    "slice_like"    : _mx_slice_like,
-    "slice_axis"    : _mx_slice_axis,
-    "SliceChannel"  : _mx_split,
-    "split"         : _mx_split,
-    "_split_v2"     : _mx_split_v2,
-    "SwapAxis"      : _mx_swap_axis,
-    "expand_dims"   : _mx_expand_dims,
-    "Concat"        : _mx_concat,
-    "concat"        : _mx_concat,
-    "stack"         : _mx_stack,
-    "batch_dot"     : _mx_batch_dot,
-    "LeakyReLU"     : _mx_leaky_relu,
-    "_arange"       : _mx_arange,
-    "_full"         : _mx_full,
-    "repeat"        : _mx_repeat,
-    "tile"          : _mx_tile,
-    "pad"           : _mx_pad,
-    "Pad"           : _mx_pad,
-    "take"          : _mx_take,
-    "gather_nd"     : _mx_gather_nd,
-    "reverse"       : _mx_reverse,
-    "SequenceReverse"  : _mx_sequence_reverse,
-    "squeeze"       : _mx_squeeze,
+    "Deconvolution": _mx_conv_transpose,
+    "Pooling": _mx_pooling,
+    "Pooling_v1": _mx_pooling,
+    "Dropout": _mx_dropout,
+    "BatchNorm": _mx_batch_norm,
+    "BatchNorm_v1": _mx_batch_norm,
+    "_contrib_SyncBatchNorm": _mx_batch_norm,
+    "InstanceNorm": _mx_instance_norm,
+    "LayerNorm": _mx_layer_norm,
+    "LRN": _mx_lrn,
+    "L2Normalization": _mx_l2_normalize,
+    "slice": _mx_slice,
+    "slice_like": _mx_slice_like,
+    "slice_axis": _mx_slice_axis,
+    "SliceChannel": _mx_split,
+    "split": _mx_split,
+    "_split_v2": _mx_split_v2,
+    "SwapAxis": _mx_swap_axis,
+    "expand_dims": _mx_expand_dims,
+    "Concat": _mx_concat,
+    "concat": _mx_concat,
+    "stack": _mx_stack,
+    "batch_dot": _mx_batch_dot,
+    "LeakyReLU": _mx_leaky_relu,
+    "_arange": _mx_arange,
+    "_full": _mx_full,
+    "repeat": _mx_repeat,
+    "tile": _mx_tile,
+    "pad": _mx_pad,
+    "Pad": _mx_pad,
+    "take": _mx_take,
+    "gather_nd": _mx_gather_nd,
+    "reverse": _mx_reverse,
+    "SequenceReverse": _mx_sequence_reverse,
+    "squeeze": _mx_squeeze,
     "broadcast_axis": _mx_broadcast_axis,
     "broadcast_axes": _mx_broadcast_axis,
-    "BlockGrad"     : _mx_BlockGrad,
-    "shape_array"   : _mx_shape_array,
-    "Embedding"     : _mx_embedding,
-    "argsort"       : _mx_argsort,
-    "topk"          : _mx_topk,
+    "BlockGrad": _mx_BlockGrad,
+    "shape_array": _mx_shape_array,
+    "Embedding": _mx_embedding,
+    "argsort": _mx_argsort,
+    "topk": _mx_topk,
     "_unravel_index": _mx_unravel_index,
-    "SequenceMask"  : _mx_sequence_mask,
-    "SoftmaxOutput" : _mx_softmax_output,
-    "SoftmaxActivation" : _mx_softmax_activation,
-    "LinearRegressionOutput" : _mx_linear_regression_output,
-    "smooth_l1"     : _mx_smooth_l1,
-    "make_loss"     : _mx_make_loss,
+    "SequenceMask": _mx_sequence_mask,
+    "SoftmaxOutput": _mx_softmax_output,
+    "SoftmaxActivation": _mx_softmax_activation,
+    "LinearRegressionOutput": _mx_linear_regression_output,
+    "smooth_l1": _mx_smooth_l1,
+    "make_loss": _mx_make_loss,
     "_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim,
     "_contrib_arange_like": _mx_contrib_arange_like,
-    "one_hot"           : _mx_one_hot,
-    "depth_to_space"    : _mx_depth_to_space,
-    "space_to_depth"    : _mx_space_to_depth,
-    "Correlation"       : _mx_correlation,
+    "one_hot": _mx_one_hot,
+    "depth_to_space": _mx_depth_to_space,
+    "space_to_depth": _mx_space_to_depth,
+    "Correlation": _mx_correlation,
     # vision
-    "_contrib_BilinearResize2D" : _mx_resize,
-    "_contrib_MultiBoxPrior" : _mx_multibox_prior,
-    "_contrib_MultiBoxDetection" : _mx_multibox_detection,
-    "_contrib_ROIAlign" : _mx_roi_align,
-    "ROIPooling"        : _mx_roi_pooling,
-    "_contrib_Proposal" : _mx_proposal,
-    "_contrib_MultiProposal" : _mx_proposal,
-    "_contrib_box_nms" : _mx_box_nms,
-    "_contrib_box_decode" : _mx_box_decode,
-    "_contrib_DeformableConvolution" : _mx_deformable_convolution,
-    "_contrib_AdaptiveAvgPooling2D" : _mx_adaptive_avg_pooling,
-    "GridGenerator"                 : _mx_grid_generator,
-    "BilinearSampler"               : _mx_bilinear_sampler,
+    "_contrib_BilinearResize2D": _mx_resize,
+    "_contrib_MultiBoxPrior": _mx_multibox_prior,
+    "_contrib_MultiBoxDetection": _mx_multibox_detection,
+    "_contrib_ROIAlign": _mx_roi_align,
+    "ROIPooling": _mx_roi_pooling,
+    "_contrib_Proposal": _mx_proposal,
+    "_contrib_MultiProposal": _mx_proposal,
+    "_contrib_box_nms": _mx_box_nms,
+    "_contrib_box_decode": _mx_box_decode,
+    "_contrib_DeformableConvolution": _mx_deformable_convolution,
+    "_contrib_AdaptiveAvgPooling2D": _mx_adaptive_avg_pooling,
+    "GridGenerator": _mx_grid_generator,
+    "BilinearSampler": _mx_bilinear_sampler,
     # NLP
-    "RNN"               : _mx_rnn_layer,
-    "_rnn_param_concat" : _mx_rnn_param_concat,
-    "_contrib_interleaved_matmul_selfatt_qk" : _mx_contrib_interleaved_matmul_selfatt_qk,
-    "_contrib_interleaved_matmul_selfatt_valatt" : _mx_contrib_interleaved_matmul_selfatt_valatt,
+    "RNN": _mx_rnn_layer,
+    "_rnn_param_concat": _mx_rnn_param_concat,
+    "_contrib_interleaved_matmul_selfatt_qk": _mx_contrib_interleaved_matmul_selfatt_qk,
+    "_contrib_interleaved_matmul_selfatt_valatt": _mx_contrib_interleaved_matmul_selfatt_valatt,
     # control flow
-    "_cond"             : _mx_cond,
+    "_cond": _mx_cond,
     # Deprecated:
-    "Crop"              : _mx_crop_like,
+    "Crop": _mx_crop_like,
     # List of missing operators that are present in NNVMv1
     # TODO(tvm-tvm): support all operators.
     #
@@ -2476,7 +2558,7 @@ _convert_map = {
     "ring_buffer": _mx_contrib_fifo_buffer,
     # Qnn ops
     "_contrib_quantize_v2": _qnn_quantize,
-    "_contrib_quantized_concat" : _qnn_contrib_concat,
+    "_contrib_quantized_concat": _qnn_contrib_concat,
     # "_contrib_quantized_fifo_buffer": _qnn_contrib_quantized_fifo_buffer,
     "_contrib_quantized_ring_buffer": _qnn_contrib_quantized_fifo_buffer,
     "_sg_mkldnn_conv": _qnn_conv,
@@ -2484,40 +2566,40 @@ _convert_map = {
     "_contrib_dequantize": _qnn_dequantize,
     "_contrib_quantized_act": _qnn_activation,
     "_contrib_quantized_pooling": _qnn_pooling,
-    "_contrib_quantized_batch_norm" : _qnn_batch_norm,
+    "_contrib_quantized_batch_norm": _qnn_batch_norm,
     "_sg_mkldnn_fully_connected": _qnn_fully_connected,
     # numpy
-    "_np_transpose"     : _mx_npi_transpose,
-    "_npi_transpose"    : _mx_npi_transpose,
-    "_npi_pad"          : _mx_npi_pad,
-    "_npi_concatenate"  : _mx_npi_concatenate,
-    "_npx_reshape"      : _mx_npx_reshape,
-    "_np_copy"          : _rename(_op.copy),
-    "_npi_power"              : _rename(_op.power),
-    "_npi_power_scalar"       : _binop_scalar(_op.power),
-    "_npi_multiply"           : _rename(_op.multiply),
-    "_npi_multiply_scalar"    : _binop_scalar(_op.multiply),
-    "_npi_add"                : _rename(_op.add),
-    "_npi_add_scalar"         : _binop_scalar(_op.add),
-    "_npi_where_rscalar"      : _mx_npi_where_rscalar,
-    "_npi_less"               : _rename(_op.less),
-    "_npi_tanh"               : _rename(_op.tanh),
-    "_npi_true_divide_scalar" : _binop_scalar(_op.divide),
+    "_np_transpose": _mx_npi_transpose,
+    "_npi_transpose": _mx_npi_transpose,
+    "_npi_pad": _mx_npi_pad,
+    "_npi_concatenate": _mx_npi_concatenate,
+    "_npx_reshape": _mx_npx_reshape,
+    "_np_copy": _rename(_op.copy),
+    "_npi_power": _rename(_op.power),
+    "_npi_power_scalar": _binop_scalar(_op.power),
+    "_npi_multiply": _rename(_op.multiply),
+    "_npi_multiply_scalar": _binop_scalar(_op.multiply),
+    "_npi_add": _rename(_op.add),
+    "_npi_add_scalar": _binop_scalar(_op.add),
+    "_npi_where_rscalar": _mx_npi_where_rscalar,
+    "_npi_less": _rename(_op.less),
+    "_npi_tanh": _rename(_op.tanh),
+    "_npi_true_divide_scalar": _binop_scalar(_op.divide),
 }
 
 # set identity list
 _convert_map.update({k: _rename(k) for k in _identity_list})
 
-_control_flow_ops = ['_cond', '_foreach', '_while_loop']
-_qnn_subgraph_ops = ['_sg_mkldnn_conv', '_sg_mkldnn_fully_connected']
+_control_flow_ops = ["_cond", "_foreach", "_while_loop"]
+_qnn_subgraph_ops = ["_sg_mkldnn_conv", "_sg_mkldnn_fully_connected"]
 _subgraph_ops = _control_flow_ops + _qnn_subgraph_ops
-_params_ops = ['_contrib_quantized_ring_buffer']
+_params_ops = ["_contrib_quantized_ring_buffer"]
 
 
 def _get_op_params(children, attrs, op_name, node, params):
     op_params = [children, attrs]
     if op_name in _subgraph_ops:
-        subgraphs = node['subgraphs']
+        subgraphs = node["subgraphs"]
         op_params.append(subgraphs)
         if op_name in _qnn_subgraph_ops:
             op_params.append(params)
@@ -2527,7 +2609,7 @@ def _get_op_params(children, attrs, op_name, node, params):
 
 
 def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None):
-    #pylint: disable=unused-argument
+    # pylint: disable=unused-argument
     """Convert mxnet symbol to compatible relay Function.
 
     Reconstruct a relay Function by traversing the mxnet symbol.
@@ -2572,9 +2654,10 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None):
             unsupported[op_name] += 1
 
     if unsupported:
-        msg = '\n'.join(['{}: {}'.format(op_name, cnt) for op_name, cnt in unsupported.items()])
+        msg = "\n".join(["{}: {}".format(op_name, cnt) for op_name, cnt in unsupported.items()])
         raise tvm.error.OpNotImplemented(
-            'One or more operators are not supported in frontend MXNet:\n{}'.format(msg))
+            "One or more operators are not supported in frontend MXNet:\n{}".format(msg)
+        )
 
     for nid, node in enumerate(jnodes):
         children = [node_map[e[0]][e[1]] for e in node["inputs"]]
@@ -2599,8 +2682,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None):
             node_map[nid] = [_expr.var(node_name, shape=shape, dtype=dtype)]
         else:
             assert op_name in _convert_map
-            op_params = _get_op_params(children, attrs, op_name,
-                                       node, params)
+            op_params = _get_op_params(children, attrs, op_name, node, params)
             res = _convert_map[op_name](*op_params)
             if res is None:
                 # defer conversion, used in RNN state initialization
@@ -2625,23 +2707,18 @@ def _update_shape_dtype(shape, dtype, params):
     if not params:
         return shape, dtype
     shape = shape.copy()
-    shape.update({k : v.shape for k, v in params.items()})
+    shape.update({k: v.shape for k, v in params.items()})
     if isinstance(dtype, str):
         for k, v in params.items():
             if v.dtype != dtype:
-                raise ValueError(
-                    "%s: dtype not expected %s vs %s" % (k, dtype, v.dtype))
+                raise ValueError("%s: dtype not expected %s vs %s" % (k, dtype, v.dtype))
     else:
         dtype = dtype.copy()
-        dtype.update({k : str(v.dtype) for k, v in params.items()})
+        dtype.update({k: str(v.dtype) for k, v in params.items()})
     return shape, dtype
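
A small sketch of what `_update_shape_dtype` does: merge the shapes and dtypes of bound parameters into the user-supplied dictionaries and check dtype consistency. The parameter name and array below are made up:

import numpy as np

shape = {"data": (1, 3, 224, 224)}
params = {"fc_weight": np.zeros((10, 512), dtype="float32")}  # hypothetical parameter

merged_shape = dict(shape)
merged_shape.update({k: v.shape for k, v in params.items()})
# {'data': (1, 3, 224, 224), 'fc_weight': (10, 512)}

dtype = "float32"
for k, v in params.items():
    assert str(v.dtype) == dtype, "%s: dtype not expected %s vs %s" % (k, dtype, v.dtype)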
 
 
-def from_mxnet(symbol,
-               shape=None,
-               dtype="float32",
-               arg_params=None,
-               aux_params=None):
+def from_mxnet(symbol, shape=None, dtype="float32", arg_params=None, aux_params=None):
     """Convert from MXNet"s model into compatible relay Function.
 
     Parameters
@@ -2670,7 +2747,7 @@ def from_mxnet(symbol,
         The parameter dict to be used by nnvm
     """
     try:
-        import mxnet as mx #pylint: disable=import-outside-toplevel
+        import mxnet as mx  # pylint: disable=import-outside-toplevel
     except ImportError as e:
         raise ImportError("{}. MXNet is required to parse symbols.".format(e))
 
index a8836ff..fd0d4c1 100644 (file)
@@ -29,9 +29,7 @@ zero_centered_uint8_quantized_range = np.float32(255.5)
 zero_centered_int8_quantized_range = np.float32(127.5)
 
 
-def _get_mkldnn_scale(data_min,
-                      data_max,
-                      quantized_range):
+def _get_mkldnn_scale(data_min, data_max, quantized_range):
     """Computes the scale as per MKLDNN specification mentioned here -
     https://intel.github.io/mkl-dnn/ex_int8_simplenet.html
 
@@ -48,29 +46,20 @@ def _get_mkldnn_scale(data_min,
     -------
     scale : A floating point number which acts as the scale for quantization.
     """
-    real_range = np.max([np.abs(np.float32(data_min)),
-                         np.abs(np.float32(data_max))])
+    real_range = np.max([np.abs(np.float32(data_min)), np.abs(np.float32(data_max))])
     scale = np.divide(quantized_range, real_range)
     scale_inverse = np.divide(1.0, scale)
     return scale_inverse
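
A worked numeric example of the scale computation above, with made-up ranges; 127.5 matches `zero_centered_int8_quantized_range`:

import numpy as np

data_min, data_max = -2.0, 3.0
quantized_range = np.float32(127.5)  # int8 case

real_range = np.max([np.abs(np.float32(data_min)), np.abs(np.float32(data_max))])  # 3.0
scale = np.divide(quantized_range, real_range)  # 42.5
scale_inverse = np.divide(1.0, scale)  # ~0.02353, which is what _get_mkldnn_scale returns
print(real_range, scale, scale_inverse)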
 
 
-def _quantize_scale_with_zero_centered(data,
-                                       scale,
-                                       zero_point,
-                                       out_dtype):
-    quantized_output = quantize(data,
-                                relay.const(scale, 'float32'),
-                                relay.const(zero_point, 'int32'),
-                                out_dtype=out_dtype)
+def _quantize_scale_with_zero_centered(data, scale, zero_point, out_dtype):
+    quantized_output = quantize(
+        data, relay.const(scale, "float32"), relay.const(zero_point, "int32"), out_dtype=out_dtype
+    )
     return quantized_output, scale, zero_point
 
 
-def _quantize_with_zero_centered(data,
-                                 data_min,
-                                 data_max,
-                                 quantized_range,
-                                 out_dtype):
+def _quantize_with_zero_centered(data, data_min, data_max, quantized_range, out_dtype):
     """Quantizes the given data tensor by calculating the scale
     using the MKLDNN formula `quantized_range / max(abs(data_min, data_max))`.
     Where quantized_range is 255 for uint8 and 127 for int8. The `data_min`
@@ -95,19 +84,12 @@ def _quantize_with_zero_centered(data,
         The computed result.
     """
 
-    scale = _get_mkldnn_scale(data_min,
-                              data_max,
-                              quantized_range)
+    scale = _get_mkldnn_scale(data_min, data_max, quantized_range)
     zero_point = 0
-    return _quantize_scale_with_zero_centered(data,
-                                              scale,
-                                              zero_point,
-                                              out_dtype)
+    return _quantize_scale_with_zero_centered(data, scale, zero_point, out_dtype)
 
 
-def _quantize_mkldnn_min_max_uint8(data,
-                                   data_min,
-                                   data_max):
+def _quantize_mkldnn_min_max_uint8(data, data_min, data_max):
     """Quantizes the given `data` in float32 and the given
     min and max ranges and the output data type is `uint8`.
     The method of quantizing is described here - https://tinyurl.com/y5k6fz5w.
@@ -130,16 +112,12 @@ def _quantize_mkldnn_min_max_uint8(data,
     result : tvm.relay.Expr
         The computed result.
     """
-    return _quantize_with_zero_centered(data,
-                                        data_min,
-                                        data_max,
-                                        zero_centered_uint8_quantized_range,
-                                        'uint8')
+    return _quantize_with_zero_centered(
+        data, data_min, data_max, zero_centered_uint8_quantized_range, "uint8"
+    )
 
 
-def _quantize_mkldnn_min_max_int8(data,
-                                  data_min,
-                                  data_max):
+def _quantize_mkldnn_min_max_int8(data, data_min, data_max):
     """Quantizes the given `data` in float32 and the given
     min and max ranges and the output data type is `int8`.
     The method of quantizing is described here - https://tinyurl.com/y5k6fz5w.
@@ -163,15 +141,12 @@ def _quantize_mkldnn_min_max_int8(data,
         The computed result.
     """
 
-    return _quantize_with_zero_centered(data,
-                                        data_min,
-                                        data_max,
-                                        zero_centered_int8_quantized_range,
-                                        'int8')
+    return _quantize_with_zero_centered(
+        data, data_min, data_max, zero_centered_int8_quantized_range, "int8"
+    )
 
 
-def get_mkldnn_int8_scale(range_min,
-                          range_max):
+def get_mkldnn_int8_scale(range_min, range_max):
     """Computes the quantization scale using MKLDNN specifications
     with the given range. The output datatype of tensor to be quantized should be
     int8.
@@ -188,14 +163,11 @@ def get_mkldnn_int8_scale(range_min,
     scale : A float32 number which acts as the scale for quantization.
     """
 
-    scale = _get_mkldnn_scale(range_min,
-                              range_max,
-                              zero_centered_int8_quantized_range)
+    scale = _get_mkldnn_scale(range_min, range_max, zero_centered_int8_quantized_range)
     return np.float32(scale)
 
 
-def get_mkldnn_uint8_scale(range_min,
-                           range_max):
+def get_mkldnn_uint8_scale(range_min, range_max):
     """Computes the quantization scale using MKLDNN specifications
     with the given range. The output datatype of tensor to be quantized should be
     uint8.
@@ -212,17 +184,13 @@ def get_mkldnn_uint8_scale(range_min,
     scale : A float32 number which acts as the scale for quantization.
     """
 
-    scale = _get_mkldnn_scale(range_min,
-                              range_max,
-                              zero_centered_uint8_quantized_range)
+    scale = _get_mkldnn_scale(range_min, range_max, zero_centered_uint8_quantized_range)
     return np.float32(scale)
 
 
-def quantize_conv_weights_bias_channel_mkldnn_from_var(weights_var,
-                                                       bias,
-                                                       min_vector_range,
-                                                       max_vector_range,
-                                                       data_scale):
+def quantize_conv_weights_bias_channel_mkldnn_from_var(
+    weights_var, bias, min_vector_range, max_vector_range, data_scale
+):
     """Helper method to quantize the convolution kernel in prequantized model
     in MXNet with MKLDNN. The kernel is always quantized to int8 output datatype.
     The inputs are the raw weights which are floating point numbers. The min and
@@ -249,67 +217,70 @@ def quantize_conv_weights_bias_channel_mkldnn_from_var(weights_var,
     """
 
     quantized_range = zero_centered_int8_quantized_range
-    real_vector_range = np.maximum(np.absolute(min_vector_range),
-                                   np.absolute(max_vector_range))
+    real_vector_range = np.maximum(np.absolute(min_vector_range), np.absolute(max_vector_range))
     # If real_vector_range is 0, then to avoid division by 0 in scaling,
     # make real_vector INT32_max
-    vector_scale = np.where(real_vector_range == 0,
-                            1./float(np.iinfo(np.int32).max),
-                            np.divide(real_vector_range, quantized_range))
+    vector_scale = np.where(
+        real_vector_range == 0,
+        1.0 / float(np.iinfo(np.int32).max),
+        np.divide(real_vector_range, quantized_range),
+    )
 
     # Handle bias impact on scales as done by MxNet-MKLDNN.
     if bias is not None:
-        common = 2.0 * bias.astype('float32') * (1/data_scale)
-        vector_scale_min = np.where(bias > 0,
-                                    common/float(np.iinfo(np.int32).max),
-                                    common/float(np.iinfo(np.int32).min))
+        common = 2.0 * bias.astype("float32") * (1 / data_scale)
+        vector_scale_min = np.where(
+            bias > 0, common / float(np.iinfo(np.int32).max), common / float(np.iinfo(np.int32).min)
+        )
         vector_scale = np.maximum(vector_scale, vector_scale_min)
 
     zero_point = 0
-    quantized_output = quantize(weights_var,
-                                relay.const(vector_scale),
-                                relay.const(zero_point, 'int32'),
-                                axis=0,
-                                out_dtype='int8')
+    quantized_output = quantize(
+        weights_var,
+        relay.const(vector_scale),
+        relay.const(zero_point, "int32"),
+        axis=0,
+        out_dtype="int8",
+    )
     return quantized_output, vector_scale, zero_point
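
A NumPy-only sketch of the per-channel scale logic above, including the division-by-zero guard for channels whose range is empty; the channel ranges are invented and the bias handling is omitted:

import numpy as np

quantized_range = np.float32(127.5)
min_vector_range = np.array([-1.0, 0.0, -0.5], dtype="float32")  # made-up per-channel minima
max_vector_range = np.array([2.0, 0.0, 0.25], dtype="float32")   # made-up per-channel maxima

real_vector_range = np.maximum(np.absolute(min_vector_range), np.absolute(max_vector_range))
vector_scale = np.where(
    real_vector_range == 0,                    # channel 1 would otherwise divide by zero
    1.0 / float(np.iinfo(np.int32).max),
    np.divide(real_vector_range, quantized_range),
)
print(vector_scale)  # one scale per output channel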
 
 
-def get_mkldnn_requantize_scale_outDtype(min_output_range,
-                                         max_output_range,
-                                         out_dtype):
-    quantized_out_range = zero_centered_int8_quantized_range if out_dtype == 'int8' \
+def get_mkldnn_requantize_scale_outDtype(min_output_range, max_output_range, out_dtype):
+    """Get the MKLDNN requantized scale."""
+    quantized_out_range = (
+        zero_centered_int8_quantized_range
+        if out_dtype == "int8"
         else zero_centered_uint8_quantized_range
-    out_range = np.max([np.abs(np.float32(min_output_range)),
-                        np.abs(np.float32(max_output_range))])
+    )
+    out_range = np.max([np.abs(np.float32(min_output_range)), np.abs(np.float32(max_output_range))])
     output_scale = quantized_out_range / out_range
-    requantize_scale = np.float32(1/output_scale)
+    requantize_scale = np.float32(1 / output_scale)
     return requantize_scale
 
 
 def get_conv_mkldnn_requantized_scale_outDtype(min_output_range, max_output_range):
-    out_dtype = 'uint8' if min_output_range >= 0.0 else 'int8'
-    requantize_scale = get_mkldnn_requantize_scale_outDtype(min_output_range,
-                                                            max_output_range,
-                                                            out_dtype)
+    out_dtype = "uint8" if min_output_range >= 0.0 else "int8"
+    requantize_scale = get_mkldnn_requantize_scale_outDtype(
+        min_output_range, max_output_range, out_dtype
+    )
     return requantize_scale, out_dtype
 
 
-def quantize_conv_bias_mkldnn_from_var(bias_var,
-                                       bias_scale):
+def quantize_conv_bias_mkldnn_from_var(bias_var, bias_scale):
+    """Quantized conv2d bias"""
     zero_point = 0
-    quantized_bias = quantize(data=bias_var,
-                              output_scale=relay.const(bias_scale),
-                              output_zero_point=relay.const(zero_point, 'int32'),
-                              axis=0,
-                              out_dtype='int32')
+    quantized_bias = quantize(
+        data=bias_var,
+        output_scale=relay.const(bias_scale),
+        output_zero_point=relay.const(zero_point, "int32"),
+        axis=0,
+        out_dtype="int32",
+    )
 
     return quantized_bias
 
 
-def quantize_mxnet_min_max(data,
-                           min_range,
-                           max_range,
-                           out_dtype='int8'):
+def quantize_mxnet_min_max(data, min_range, max_range, out_dtype="int8"):
     """Quantizes the given `data` in float32 and the given
     min and max ranges and the output data type.
     Only `int8` and `uint8` is supported as output data types.
@@ -337,23 +308,15 @@ def quantize_mxnet_min_max(data,
         The computed result.
     """
 
-    if out_dtype == 'uint8':
-        return _quantize_mkldnn_min_max_uint8(data,
-                                              min_range,
-                                              max_range)
-    elif out_dtype == 'int8':
-        return _quantize_mkldnn_min_max_int8(data,
-                                             min_range,
-                                             max_range)
+    if out_dtype == "uint8":
+        return _quantize_mkldnn_min_max_uint8(data, min_range, max_range)
+    elif out_dtype == "int8":
+        return _quantize_mkldnn_min_max_int8(data, min_range, max_range)
     else:
-        raise ValueError(
-            "Expected out_dtype to be int8 or uint8 but was  %s" % out_dtype)
+        raise ValueError("Expected out_dtype to be int8 or uint8 but was  %s" % out_dtype)
 
 
-def _dequantize_zero_centered(data,
-                              data_min,
-                              data_max,
-                              quantized_range):
+def _dequantize_zero_centered(data, data_min, data_max, quantized_range):
     """Dequantizes the given data tensor by calculating the scale
     using the MKLDNN formula `max(abs(data_min, data_max))/quantized_range`.
     Where quantized_range is 255 for uint8 and 127 for int8. The `data_min`
@@ -376,16 +339,13 @@ def _dequantize_zero_centered(data,
         The computed result.
     """
 
-    real_range = np.max([np.abs(np.float32(data_min)),
-                         np.abs(np.float32(data_max))])
-    scale = relay.const(np.divide(real_range, quantized_range), 'float32')
-    zero_point = relay.const(0, 'int32')
+    real_range = np.max([np.abs(np.float32(data_min)), np.abs(np.float32(data_max))])
+    scale = relay.const(np.divide(real_range, quantized_range), "float32")
+    zero_point = relay.const(0, "int32")
     return dequantize(data, scale, zero_point)
 
 
-def _dequantize_mkldnn_min_max_int8(data,
-                                    imin_range,
-                                    imax_range):
+def _dequantize_mkldnn_min_max_int8(data, imin_range, imax_range):
     """Dequantizes the given `data` in {int8 or uint8} and the given
     min and max ranges and the output data type is `float32`.
     The method of dequantizing is described here - https://tinyurl.com/y5k6fz5w.
@@ -409,15 +369,15 @@ def _dequantize_mkldnn_min_max_int8(data,
         The computed result.
     """
 
-    return _dequantize_zero_centered(data,
-                                     data_min=imin_range,
-                                     data_max=imax_range,
-                                     quantized_range=zero_centered_int8_quantized_range)
+    return _dequantize_zero_centered(
+        data,
+        data_min=imin_range,
+        data_max=imax_range,
+        quantized_range=zero_centered_int8_quantized_range,
+    )
 
 
-def _dequantize_mkldnn_min_max_uint8(data,
-                                     imin_range,
-                                     imax_range):
+def _dequantize_mkldnn_min_max_uint8(data, imin_range, imax_range):
     """Dequantizes the given `data` in {int8 or uint8} and the given
     min and max ranges and the output data type is `float32`.
     The method of dequantize is described here - https://tinyurl.com/y5k6fz5w.
@@ -441,16 +401,15 @@ def _dequantize_mkldnn_min_max_uint8(data,
         The computed result.
     """
 
-    return _dequantize_zero_centered(data,
-                                     data_min=imin_range,
-                                     data_max=imax_range,
-                                     quantized_range=zero_centered_uint8_quantized_range)
+    return _dequantize_zero_centered(
+        data,
+        data_min=imin_range,
+        data_max=imax_range,
+        quantized_range=zero_centered_uint8_quantized_range,
+    )
 
 
-def dequantize_mxnet_min_max(data,
-                             min_range,
-                             max_range,
-                             in_dtype='int8'):
+def dequantize_mxnet_min_max(data, min_range, max_range, in_dtype="int8"):
     """Dequantizes the given `data` in {int8 or uint8} and the given
     min and max ranges. The output data type is float32.
     Only `float32` is supported as output data types.
@@ -478,14 +437,9 @@ def dequantize_mxnet_min_max(data,
         The computed result.
     """
 
-    if in_dtype == 'uint8':
-        return _dequantize_mkldnn_min_max_uint8(data,
-                                                min_range,
-                                                max_range)
-    elif in_dtype == 'int8':
-        return _dequantize_mkldnn_min_max_int8(data,
-                                               min_range,
-                                               max_range)
+    if in_dtype == "uint8":
+        return _dequantize_mkldnn_min_max_uint8(data, min_range, max_range)
+    elif in_dtype == "int8":
+        return _dequantize_mkldnn_min_max_int8(data, min_range, max_range)
     else:
-        raise ValueError(
-            "Expected out_dtype to be int8 or uint8 but was  %s" % in_dtype)
+        raise ValueError("Expected out_dtype to be int8 or uint8 but was  %s" % in_dtype)
index 7dd9c02..b2537af 100644 (file)
@@ -25,7 +25,8 @@ from .common import get_relay_op
 from .common import infer_type as _infer_type
 from .common import infer_shape as _infer_shape
 
-def _warn_not_used(attr, op='nnvm'):
+
+def _warn_not_used(attr, op="nnvm"):
     err = "{} is ignored in {}.".format(attr, op)
     warnings.warn(err)
 
@@ -34,8 +35,9 @@ def _rename(new_op):
     if isinstance(new_op, str):
         new_op = get_relay_op(new_op)
     # attrs are ignored.
-    def impl(inputs, _, _dtype='float32'):
+    def impl(inputs, _, _dtype="float32"):
         return new_op(*inputs)
+
     return impl
 
 
@@ -49,17 +51,20 @@ def _reshape(inputs, attrs):
 
 def _init_op(new_op):
     """Init ops like zeros/ones"""
+
     def _impl(inputs, attrs):
         assert len(inputs) == 0
         shape = attrs.get_int_tuple("shape")
         dtype = attrs.get_str("dtype", "float32")
         return new_op(shape=shape, dtype=dtype)
+
     return _impl
 
 
 def _softmax_op(new_op):
     """softmax/log_softmax"""
-    def _impl(inputs, attrs, _dtype='float32'):
+
+    def _impl(inputs, attrs, _dtype="float32"):
         axis = attrs.get_int("axis", -1)
         use_length = attrs.get_bool("use_length", False)
         if use_length:
@@ -93,13 +98,16 @@ def _softmax_op(new_op):
                 # Input data is now 2D, we can set the axis = 1
                 axis = 1
             elif data_ndims > 2:
-                raise error.OpNotImplemented(\
-                        "Operator softmax with use_length=True is supported only for axis -1")
+                raise error.OpNotImplemented(
+                    "Operator softmax with use_length=True is supported only for axis -1"
+                )
 
-            res = _op.sequence_mask(data=data,
-                                    valid_length=length,
-                                    mask_value=float(min_value(data_dtype).value),
-                                    axis=axis)
+            res = _op.sequence_mask(
+                data=data,
+                valid_length=length,
+                mask_value=float(min_value(data_dtype).value),
+                axis=axis,
+            )
 
             # Apply softmax
             res = new_op(res, axis=axis)
@@ -109,12 +117,14 @@ def _softmax_op(new_op):
                 return _op.reshape(res, newshape=data_shape)
             return res
         return new_op(inputs[0], axis=axis)
+
     return _impl
 
 
 def _reduce(new_op):
     """Reduction ops like sum/min/max"""
-    def _impl(inputs, attrs, _dtype='float32'):
+
+    def _impl(inputs, attrs, _dtype="float32"):
         assert len(inputs) == 1
         axis = attrs.get_int_tuple("axis", [])
         keepdims = attrs.get_bool("keepdims", False)
@@ -122,11 +132,13 @@ def _reduce(new_op):
         # use None for reduce over all axis.
         axis = None if len(axis) == 0 else axis
         return new_op(inputs[0], axis=axis, keepdims=keepdims, exclude=exclude)
+
     return _impl
 
 
 def _arg_reduce(new_op):
     """Arg Reduction ops like argmin/argmax"""
+
     def _impl(inputs, attrs):
         assert len(inputs) == 1
         axis = attrs.get_int("axis", None)
@@ -135,6 +147,7 @@ def _arg_reduce(new_op):
         # cast to dtype.
         res = res.astype("float32")
         return res
+
     return _impl
 
 
@@ -162,7 +175,7 @@ def _upsampling(inputs, attrs):
     return _op.nn.upsampling(inputs[0], scale_h=scale, scale_w=scale)
 
 
-def _elemwise_sum(inputs, _, _dtype='float32'):
+def _elemwise_sum(inputs, _, _dtype="float32"):
     assert len(inputs) > 0
     res = inputs[0]
     for x in inputs[1:]:
@@ -178,6 +191,7 @@ def _binop_scalar(new_op):
             odtype = _infer_type(inputs[0]).checked_type.dtype
         scalar = _expr.const(scalar, dtype=odtype)
         return new_op(inputs[0], scalar)
+
     return _impl
 
 
@@ -189,12 +203,15 @@ def _rbinop_scalar(new_op):
             odtype = _infer_type(inputs[0]).checked_type.dtype
         scalar = _expr.const(scalar, dtype=odtype)
         return new_op(scalar, inputs[0])
+
     return _impl
 
 
 def _compare(new_op):
     """Compare ops like greater/less"""
-    def _impl(inputs, _, odtype='float32'):
+
+    def _impl(inputs, _, odtype="float32"):
         assert len(inputs) == 2
         return new_op(inputs[0], inputs[1]).astype(odtype)
+
     return _impl
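
Most converters in this file follow the same closure-factory pattern as `_binop_scalar`, `_rbinop_scalar`, and `_compare`: a small function takes the target Relay op and returns the `_impl(inputs, attrs, ...)` that the frontend will call later. A standalone sketch of the pattern, with a plain Python callable standing in for a Relay op:

def make_compare(op):
    def _impl(inputs, _attrs, odtype="float32"):
        assert len(inputs) == 2
        return float(op(inputs[0], inputs[1]))  # stand-in for .astype(odtype)
    return _impl

greater = make_compare(lambda a, b: a > b)
print(greater([3.0, 2.0], None))  # 1.0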
index ea39010..8269ebf 100644 (file)
@@ -42,17 +42,20 @@ from .common import infer_type, get_name
 from .common import infer_value as _infer_value
 from .common import infer_value_simulated as _infer_value_simulated
 
-__all__ = ['from_onnx']
+__all__ = ["from_onnx"]
 
 g = None
 
+
 def infer_value(input_val, params, mod=None):
     return g.infer_value(input_val, params, mod)
 
+
 def infer_value_simulated(input_val, params):
     return g.infer_value_simulated(input_val, params)
 
-class onnx_input():
+
+class onnx_input:
     """ Dual purpose list or dictionary access object."""
 
     def __init__(self):
@@ -107,23 +110,23 @@ def get_numpy(tensor_proto):
     try:
         from onnx.numpy_helper import to_array
     except ImportError as e:
-        raise ImportError(
-            "Unable to import onnx which is required {}".format(e))
+        raise ImportError("Unable to import onnx which is required {}".format(e))
     return to_array(tensor_proto)
 
 
-def dimension_picker(prefix, suffix=''):
+def dimension_picker(prefix, suffix=""):
     """Check that dimensions are supported."""
+
     def _impl(attr):
-        kernel = attr['kernel_shape']
+        kernel = attr["kernel_shape"]
         if len(kernel) == 1:
-            return prefix + '1d' + suffix
+            return prefix + "1d" + suffix
         if len(kernel) == 2:
-            return prefix + '2d' + suffix
+            return prefix + "2d" + suffix
         if len(kernel) == 3:
-            return prefix + '3d' + suffix
-        msg = 'Only 1D, 2D, and 3D kernels are supported for operator {}.'
-        op_name = prefix + '1d/2d/3d'
+            return prefix + "3d" + suffix
+        msg = "Only 1D, 2D, and 3D kernels are supported for operator {}."
+        op_name = prefix + "1d/2d/3d"
         raise tvm.error.OpAttributeInvalid(msg.format(op_name))
 
     return _impl
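
A standalone sketch of what `dimension_picker` yields for different kernel ranks; the attribute dicts are invented, the real converter passes the ONNX node attributes:

def pick(prefix, attr, suffix=""):
    kernel = attr["kernel_shape"]
    assert len(kernel) in (1, 2, 3), "only 1D/2D/3D kernels are handled"
    return prefix + "{}d".format(len(kernel)) + suffix

print(pick("max_pool", {"kernel_shape": [3, 3]}))               # max_pool2d
print(pick("conv", {"kernel_shape": [3, 3, 3]}, "_transpose"))  # conv3d_transpose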
@@ -136,8 +139,7 @@ def revert_caffe2_pad(pads):
     elif len(pads) == 2:
         pass
     else:
-        raise tvm.error.OpAttributeInvalid(
-            'Number of pads must be either 2 or 4.')
+        raise tvm.error.OpAttributeInvalid("Number of pads must be either 2 or 4.")
     return pads
 
 
@@ -154,11 +156,11 @@ def get_pad_pair(input1d, kernel1d, stride1d):
 
 def onnx_default_layout(dims):
     if dims == 1:
-        return 'NCW'
+        return "NCW"
     if dims == 2:
-        return 'NCHW'
+        return "NCHW"
     if dims == 3:
-        return 'NCDHW'
+        return "NCDHW"
 
     msg = "Only 1D, 2D and 3D layouts are currently supported"
     raise tvm.error.OpAttributeInvalid(msg)
@@ -167,14 +169,14 @@ def onnx_default_layout(dims):
 def onnx_storage_order2layout(storage_order, dims=2):
     """converter of onnx storage order parameter to tvm storage order format"""
     if storage_order not in (0, 1):
-        raise tvm.error.OpAttributeInvalid('Mode of storage_order must be either 0 or 1')
+        raise tvm.error.OpAttributeInvalid("Mode of storage_order must be either 0 or 1")
 
     if dims == 1:
-        return 'NCW' if storage_order == 0 else 'NWC'
+        return "NCW" if storage_order == 0 else "NWC"
     if dims == 2:
-        return 'NCHW' if storage_order == 0 else 'NHWC'
+        return "NCHW" if storage_order == 0 else "NHWC"
     if dims == 3:
-        return 'NCDHW' if storage_order == 0 else 'NDHWC'
+        return "NCDHW" if storage_order == 0 else "NDHWC"
 
     msg = "Only 1D, 2D and 3D layouts are currently supported"
     raise tvm.error.OpAttributeInvalid(msg)
@@ -182,7 +184,7 @@ def onnx_storage_order2layout(storage_order, dims=2):
 
 def dimension_constraint():
     def _dim_check(attrs):
-        if len(attrs['kernel_shape']) in [1, 2, 3]:
+        if len(attrs["kernel_shape"]) in [1, 2, 3]:
             return True
         return False
 
@@ -190,12 +192,11 @@ def dimension_constraint():
 
 
 class OnnxOpConverter(object):
-    """ A helper class for holding onnx op converters.
-    """
+    """A helper class for holding onnx op converters."""
 
     @classmethod
     def get_converter(cls, opset):
-        """ Get converter matches given opset.
+        """Get converter matches given opset.
 
         Parameters
         ----------
@@ -207,187 +208,180 @@ class OnnxOpConverter(object):
         converter, which should be `_impl_vx`. Number x is the biggest
             number smaller than or equal to opset that belongs to all supported versions.
         """
-        versions = [
-            int(d.replace('_impl_v', '')) for d in dir(cls) if '_impl_v' in d
-        ]
+        versions = [int(d.replace("_impl_v", "")) for d in dir(cls) if "_impl_v" in d]
         versions = sorted(versions + [opset])
-        version = versions[
-            max([i for i, v in enumerate(versions) if v == opset]) - 1]
-        if hasattr(cls, '_impl_v{}'.format(version)):
-            return getattr(cls, '_impl_v{}'.format(version))
+        version = versions[max([i for i, v in enumerate(versions) if v == opset]) - 1]
+        if hasattr(cls, "_impl_v{}".format(version)):
+            return getattr(cls, "_impl_v{}".format(version))
         raise NotImplementedError(
-            'opset version {} of {} not implemented'.format(
-                version, cls.__name__))
+            "opset version {} of {} not implemented".format(version, cls.__name__)
+        )
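
The version-resolution rule in `get_converter`: collect the `_impl_vN` methods a converter defines, then pick the newest one that does not exceed the requested opset. A self-contained sketch with a hypothetical converter class:

class FakeConverter:  # hypothetical converter defining two opset implementations
    @classmethod
    def _impl_v1(cls):
        return "v1"

    @classmethod
    def _impl_v7(cls):
        return "v7"

def pick_impl(cls, opset):
    versions = [int(d.replace("_impl_v", "")) for d in dir(cls) if "_impl_v" in d]
    versions = sorted(versions + [opset])
    version = versions[max(i for i, v in enumerate(versions) if v == opset) - 1]
    return getattr(cls, "_impl_v{}".format(version))

print(pick_impl(FakeConverter, 9)())  # "v7": newest implementation not newer than opset 9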
 
 
 class Unary(OnnxOpConverter):
-    """ A helper class for unary op converters.
-    """
-    name = ''
+    """A helper class for unary op converters."""
+
+    name = ""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         assert len(inputs) == 1, "Unary math op {} takes 1 input, {} given".format(
-            cls.name, len(inputs))
+            cls.name, len(inputs)
+        )
         op_name = cls.name
         return get_relay_op(op_name)(*inputs)
 
 
 class Elemwise(OnnxOpConverter):
-    """ A helper class for elemwise op converters.
-    """
-    name = ''
+    """A helper class for elemwise op converters."""
+
+    name = ""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        assert len(inputs) == 2, "Math op {} take 2 inputs, {} given".format(
-            cls.name, len(inputs))
+        assert len(inputs) == 2, "Math op {} takes 2 inputs, {} given".format(cls.name, len(inputs))
         op_name = cls.name
         conv_ops = ["conv2d", "conv2d_transpose"]
-        if attr.get('broadcast', 0) and any(x in str(inputs[0]) for x in conv_ops):
+        if attr.get("broadcast", 0) and any(x in str(inputs[0]) for x in conv_ops):
             # TODO(zhreshold): remove hard coded infershape
-            axis = int(attr.get('axis', 0))
+            axis = int(attr.get("axis", 0))
             inputs[1] = _op.expand_dims(inputs[1], axis=axis, num_newaxis=2)
         return get_relay_op(op_name)(*inputs)
 
 
 class Pool(OnnxOpConverter):
-    """ A helper class for pool op converters.
-    """
-    name = ''
+    """A helper class for pool op converters."""
+
+    name = ""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         input_shape = infer_shape(inputs[0])
-        if 'auto_pad' in attr:
-            attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
-            if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
+        if "auto_pad" in attr:
+            attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
+            if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):
                 pad_tuple = []
                 for axis in range(len(input_shape) - 2):
                     axis_shape = input_shape[2 + axis]
-                    stride = attr['strides'][axis]
-                    kernel = attr['kernel_shape'][axis]
+                    stride = attr["strides"][axis]
+                    kernel = attr["kernel_shape"][axis]
                     pad = get_pad_pair(axis_shape, kernel, stride)
                     pad_tuple.append(pad)
                 pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair])
-                attr['pads'] = pad_tuple
-            elif attr['auto_pad'] == 'VALID':
-                attr['pads'] = 0
-            elif attr['auto_pad'] == 'NOTSET':
+                attr["pads"] = pad_tuple
+            elif attr["auto_pad"] == "VALID":
+                attr["pads"] = 0
+            elif attr["auto_pad"] == "NOTSET":
                 pass
             else:
                 msg = 'Value {} in attribute "auto_pad" of operator {} is invalid.'
-                raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad'], cls.name))
+                raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"], cls.name))
             attr.pop("auto_pad")
 
-        if 'storage_order' in attr:
-            attr['layout'] = onnx_storage_order2layout(attr['storage_order'],
-                                                       dims=(len(input_shape) - 2))
+        if "storage_order" in attr:
+            attr["layout"] = onnx_storage_order2layout(
+                attr["storage_order"], dims=(len(input_shape) - 2)
+            )
         else:
-            attr['layout'] = onnx_default_layout(dims=(len(input_shape) - 2))
+            attr["layout"] = onnx_default_layout(dims=(len(input_shape) - 2))
 
         return AttrCvt(
             op_name=dimension_picker(cls.name),
-            transforms={
-                'kernel_shape': 'pool_size',
-                'pads': ('padding', 0)
-            },
-            ignores=['dilations', 'storage_order'],
-            custom_check=dimension_constraint())(inputs, attr, params)
+            transforms={"kernel_shape": "pool_size", "pads": ("padding", 0)},
+            ignores=["dilations", "storage_order"],
+            custom_check=dimension_constraint(),
+        )(inputs, attr, params)
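
One detail in the SAME-padding branch above: `pad_tuple` is built as per-axis (before, after) pairs and then interleaved so that all the "before" values come first, followed by all the "after" values, which is the ordering the padding attribute uses downstream. A tiny demonstration of that zip/flatten idiom with made-up pads:

pad_tuple = [(1, 2), (0, 1)]  # (before, after) for H and W
pads = tuple(val for pair in zip(*pad_tuple) for val in pair)
print(pads)  # (1, 0, 2, 1)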
 
 
 class Absolute(Unary):
-    """ Operator converter for Absolute.
-    """
-    name = 'abs'
+    """Operator converter for Absolute."""
+
+    name = "abs"
 
 
 class Add(Elemwise):
-    """ Operator converter for Add.
-    """
-    name = 'add'
+    """Operator converter for Add."""
+
+    name = "add"
 
 
 class AveragePool(Pool):
-    """ Operator converter for AveragePool.
-    """
-    name = 'avg_pool'
+    """Operator converter for AveragePool."""
+
+    name = "avg_pool"
 
 
 class BatchNorm(OnnxOpConverter):
-    """ Operator converter for BatchNorm.
-    """
+    """Operator converter for BatchNorm."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         # TODO(zhreshold): 'spatial' is not properly handled here.
         out = AttrCvt(
-            op_name='batch_norm',
-            ignores=['spatial', 'is_test', 'consumed_inputs', 'momentum'])(inputs, attr,
-                                                                           params)
+            op_name="batch_norm", ignores=["spatial", "is_test", "consumed_inputs", "momentum"]
+        )(inputs, attr, params)
         return out[0]
 
 
 class InstanceNorm(OnnxOpConverter):
-    """ Operator converter for BatchNorm.
-    """
+    """Operator converter for BatchNorm."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        return AttrCvt(op_name='instance_norm')(inputs, attr, params)
+        return AttrCvt(op_name="instance_norm")(inputs, attr, params)
 
 
 class Conv(OnnxOpConverter):
-    """ Operator converter for Conv.
-    """
+    """Operator converter for Conv."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         # Use shape of input to determine convolution type.
         input_shape = infer_shape(inputs[0])
-        if 'auto_pad' in attr:
-            attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
-            if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
+        if "auto_pad" in attr:
+            attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
+            if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):
                 pad_tuple = []
                 for axis in range(len(input_shape) - 2):
                     axis_shape = input_shape[2 + axis]
-                    stride = attr['strides'][axis]
-                    kernel = attr['kernel_shape'][axis]
-                    dilation = attr['dilations'][axis]
+                    stride = attr["strides"][axis]
+                    kernel = attr["kernel_shape"][axis]
+                    dilation = attr["dilations"][axis]
                     dilated_kernel = (kernel - 1) * dilation + 1
                     pad = get_pad_pair(axis_shape, dilated_kernel, stride)
                     pad_tuple.append(pad)
                 pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair])
-                attr['pads'] = pad_tuple
-            elif attr['auto_pad'] == 'VALID':
-                attr['pads'] = tuple([0 for i in range(len(input_shape) - 2)])
-            elif attr['auto_pad'] == 'NOTSET':
+                attr["pads"] = pad_tuple
+            elif attr["auto_pad"] == "VALID":
+                attr["pads"] = tuple([0 for i in range(len(input_shape) - 2)])
+            elif attr["auto_pad"] == "NOTSET":
                 pass
             else:
                 msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.'
-                raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad']))
-            attr.pop('auto_pad')
-        elif len(attr['kernel_shape']) == 2:
+                raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"]))
+            attr.pop("auto_pad")
+        elif len(attr["kernel_shape"]) == 2:
             sym_pad = True
-            if 'pads' in attr:
-                padding = attr['pads']
+            if "pads" in attr:
+                padding = attr["pads"]
             else:
                 padding = [0, 0, 0, 0]
             for i in range(0, len(padding), 2):
                 sym_pad = sym_pad and padding[i] == padding[i + 1]
 
             if sym_pad:
-                attr['pads'] = padding[0::2]
+                attr["pads"] = padding[0::2]
 
         out = AttrCvt(
-            op_name=dimension_picker('conv'),
+            op_name=dimension_picker("conv"),
             transforms={
-                'kernel_shape': 'kernel_size',
-                'dilations': ('dilation', 1),
-                'pads': ('padding', 0),
-                'group': ('groups', 1)
+                "kernel_shape": "kernel_size",
+                "dilations": ("dilation", 1),
+                "pads": ("padding", 0),
+                "group": ("groups", 1),
             },
-            custom_check=dimension_constraint())(inputs[:2], attr, params)
+            custom_check=dimension_constraint(),
+        )(inputs[:2], attr, params)
 
         use_bias = len(inputs) == 3
         if use_bias:
@@ -396,47 +390,48 @@ class Conv(OnnxOpConverter):
 
 
 class ConvTranspose(OnnxOpConverter):
-    """ Operator converter for ConvTranspose.
-    """
+    """Operator converter for ConvTranspose."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         # get number of channels
         channels = infer_channels(inputs[1], True)
-        attr['channels'] = channels
-        groups = attr.pop('group')
-        attr['groups'] = groups
+        attr["channels"] = channels
+        groups = attr.pop("group")
+        attr["groups"] = groups
         # infer pads for auto_pad
-        if 'auto_pad' in attr:
-            attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
-            if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
+        if "auto_pad" in attr:
+            attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
+            if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):
                 input_shape = infer_shape(inputs[0])
                 in_h, in_w = input_shape[2], input_shape[3]
-                stride_h, stride_w = attr['strides']
-                kernel_h, kernel_w = attr['kernel_shape']
-                dilation_h, dilation_w = attr['dilations']
+                stride_h, stride_w = attr["strides"]
+                kernel_h, kernel_w = attr["kernel_shape"]
+                dilation_h, dilation_w = attr["dilations"]
                 dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
                 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
                 pad_v = get_pad_pair(in_h, dilated_kernel_h, stride_h)
                 pad_h = get_pad_pair(in_w, dilated_kernel_w, stride_w)
-                attr['pads'] = (pad_v[0], pad_h[0], pad_v[1], pad_h[1])
-            elif attr['auto_pad'] == 'VALID':
-                attr['pads'] = (0, 0)
-            elif attr['auto_pad'] == 'NOTSET':
+                attr["pads"] = (pad_v[0], pad_h[0], pad_v[1], pad_h[1])
+            elif attr["auto_pad"] == "VALID":
+                attr["pads"] = (0, 0)
+            elif attr["auto_pad"] == "NOTSET":
                 pass
             else:
                 msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.'
-                raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad']))
-            attr.pop('auto_pad')
+                raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"]))
+            attr.pop("auto_pad")
 
         out = AttrCvt(
-            op_name=dimension_picker('conv', '_transpose'),
+            op_name=dimension_picker("conv", "_transpose"),
             transforms={
-                'kernel_shape': 'kernel_size',
-                'dilations': ('dilation', (0, 0)),
-                'pads': ('padding', (0, 0), revert_caffe2_pad)
+                "kernel_shape": "kernel_size",
+                "dilations": ("dilation", (0, 0)),
+                "pads": ("padding", (0, 0), revert_caffe2_pad),
             },
-            disables=['output_shape'],
-            custom_check=dimension_constraint())(inputs[:2], attr, params)
+            disables=["output_shape"],
+            custom_check=dimension_constraint(),
+        )(inputs[:2], attr, params)
         use_bias = len(inputs) == 3
         if use_bias:
             out = _op.nn.bias_add(out, inputs[2])
@@ -444,35 +439,33 @@ class ConvTranspose(OnnxOpConverter):
 
 
 class Div(Elemwise):
-    """ Operator converter for Divide.
-    """
-    name = 'divide'
+    """Operator converter for Divide."""
+
+    name = "divide"
 
 
 class Elu(OnnxOpConverter):
-    """ Operator converter for Elu.
-    """
+    """Operator converter for Elu."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        alpha = float(attr.get('alpha', 1.0))
-        return _expr.const(-alpha) * _op.nn.relu(_expr.const(1.) - _op.exp(inputs[0])) + \
-                                     _op.nn.relu(inputs[0])
+        alpha = float(attr.get("alpha", 1.0))
+        return _expr.const(-alpha) * _op.nn.relu(
+            _expr.const(1.0) - _op.exp(inputs[0])
+        ) + _op.nn.relu(inputs[0])
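
The expression above is an algebraic rewrite of ELU: for x > 0 it reduces to x, and for x <= 0 it reduces to alpha * (exp(x) - 1). A NumPy check of that identity on a few sample points:

import numpy as np

def elu_reference(x, alpha=1.0):
    return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))

def elu_rewrite(x, alpha=1.0):
    relu = lambda v: np.maximum(v, 0.0)
    return -alpha * relu(1.0 - np.exp(x)) + relu(x)

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
assert np.allclose(elu_reference(x), elu_rewrite(x))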
 
 
 class Gemm(OnnxOpConverter):
-    """ Operator converter for Gemm.
-    """
+    """Operator converter for Gemm."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        assert len(inputs) == 3, "Gemm op take 3 inputs, {} given".format(
-            len(inputs))
+        assert len(inputs) == 3, "Gemm op takes 3 inputs, {} given".format(len(inputs))
         # Y = alpha * A * B + beta * C
-        alpha = float(attr.get('alpha', 1.0))
-        beta = float(attr.get('beta', 1.0))
-        transA = int(attr.get('transA', 0))
-        transB = int(attr.get('transB', 0))
+        alpha = float(attr.get("alpha", 1.0))
+        beta = float(attr.get("beta", 1.0))
+        transA = int(attr.get("transA", 0))
+        transB = int(attr.get("transB", 0))
         # get number of channels
         channels = infer_channels(inputs[1], not transB)
         if transA:
@@ -493,8 +486,7 @@ class Gemm(OnnxOpConverter):
 
 
 class MatMul(OnnxOpConverter):
-    """ Operator converter for MatMul.
-    """
+    """Operator converter for MatMul."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
@@ -525,8 +517,7 @@ class MatMul(OnnxOpConverter):
 
 
 class Mod(OnnxOpConverter):
-    """ Operator converter for Mod.
-    """
+    """Operator converter for Mod."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
@@ -535,7 +526,7 @@ class Mod(OnnxOpConverter):
         # Note: attr['fmod'] determines whether the operator should behave like np.fmod or np.mod.
         # attr['fmod'] == 0 will behave as np.mod and attr['fmod'] == 1 will force fmod treatment.
         # The relay equivalent of np.fmod is relay.mod and np.mod is relay.floor_mod
-        if attr['fmod'] == 0:
+        if attr["fmod"] == 0:
             op_name = "floor_mod"
         else:
             op_name = "mod"
@@ -544,117 +535,119 @@ class Mod(OnnxOpConverter):
 
 
 class MaxPool(Pool):
-    """ Operator converter for MaxPool
-    """
-    name = 'max_pool'
+    """Operator converter for MaxPool"""
+
+    name = "max_pool"
+
 
 class LpPool(OnnxOpConverter):
-    """ A helper class for lppool op converters.
-    """
+    """A helper class for lppool op converters."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         input_shape = infer_shape(inputs[0])
         dtype = infer_type(inputs[0]).checked_type.dtype
 
-        if 'auto_pad' in attr:
-            attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
-            if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
+        if "auto_pad" in attr:
+            attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
+            if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):
                 pad_tuple = []
                 for axis in range(len(input_shape) - 2):
                     axis_shape = input_shape[2 + axis]
-                    stride = attr['strides'][axis]
-                    kernel = attr['kernel_shape'][axis]
+                    stride = attr["strides"][axis]
+                    kernel = attr["kernel_shape"][axis]
                     pad = get_pad_pair(axis_shape, kernel, stride)
                     pad_tuple.append(pad)
                 pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair])
-                attr['pads'] = pad_tuple
-            elif attr['auto_pad'] == 'VALID':
-                attr['pads'] = 0
-            elif attr['auto_pad'] == 'NOTSET':
+                attr["pads"] = pad_tuple
+            elif attr["auto_pad"] == "VALID":
+                attr["pads"] = 0
+            elif attr["auto_pad"] == "NOTSET":
                 pass
             else:
                 msg = 'Value {} in attribute "auto_pad" of operator {} is invalid.'
-                raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad'], "LpPool"))
+                raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"], "LpPool"))
             attr.pop("auto_pad")
 
-        if 'storage_order' in attr:
-            attr['layout'] = onnx_storage_order2layout(attr['storage_order'],
-                                                       dims=(len(input_shape) - 2))
+        if "storage_order" in attr:
+            attr["layout"] = onnx_storage_order2layout(
+                attr["storage_order"], dims=(len(input_shape) - 2)
+            )
         else:
-            attr['layout'] = onnx_default_layout(dims=(len(input_shape) - 2))
+            attr["layout"] = onnx_default_layout(dims=(len(input_shape) - 2))
 
-        p = _expr.const(attr['p'], dtype)
-        reci_p = _expr.const(1.0 / attr['p'], dtype)
+        p = _expr.const(attr["p"], dtype)
+        reci_p = _expr.const(1.0 / attr["p"], dtype)
         inputs[0] = _op.power(inputs[0], p)
 
-        out = AttrCvt(op_name=dimension_picker("avg_pool"),
-                      transforms={
-                          'kernel_shape': 'pool_size',
-                          'pads': ('padding', 0)
-                      },
-                      extras={'count_include_pad': True},
-                      ignores=['p'],
-                      custom_check=dimension_constraint())(inputs, attr, params)
-        kernels = attr['kernel_shape']
+        out = AttrCvt(
+            op_name=dimension_picker("avg_pool"),
+            transforms={"kernel_shape": "pool_size", "pads": ("padding", 0)},
+            extras={"count_include_pad": True},
+            ignores=["p"],
+            custom_check=dimension_constraint(),
+        )(inputs, attr, params)
+        kernels = attr["kernel_shape"]
         out = _op.abs(out) * _expr.const(np.prod(kernels).astype(dtype))
         return _op.power(out, reci_p)
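
The avg_pool-based lowering above relies on the identity (sum_i |x_i|^p)^(1/p) == (N * mean_i x_i^p)^(1/p); a standalone NumPy sketch for a single window with the default p = 2 (an even p is assumed here):

    import numpy as np

    x = np.array([1.0, -2.0, 3.0], dtype="float32")  # one pooling window, N = 3
    p = 2.0
    lp_norm = np.sum(np.abs(x) ** p) ** (1.0 / p)
    via_avg_pool = (np.abs(np.mean(x ** p)) * x.size) ** (1.0 / p)
    assert np.isclose(lp_norm, via_avg_pool)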
 
 
 class Mul(Elemwise):
-    """ Operator converter for Multiply.
-    """
-    name = 'multiply'
+    """Operator converter for Multiply."""
+
+    name = "multiply"
 
 
 class Pad(OnnxOpConverter):
-    """ Operator converter for Pad.
-    """
+    """Operator converter for Pad."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         pad_width = []
-        pads = attr.pop('paddings')
+        pads = attr.pop("paddings")
         dims = int(len(pads) / 2)
         for i in range(dims):
-            pad_width.append((pads[i], pads[i+dims]))
-        attr['pad_width'] = pad_width
-        pad_mode = attr.get('mode', b'constant').decode('utf-8')
-        if pad_mode in ['constant', 'edge', 'reflect']:
-            attr['pad_mode'] = pad_mode
-            attr.pop('mode', None)
+            pad_width.append((pads[i], pads[i + dims]))
+        attr["pad_width"] = pad_width
+        pad_mode = attr.get("mode", b"constant").decode("utf-8")
+        if pad_mode in ["constant", "edge", "reflect"]:
+            attr["pad_mode"] = pad_mode
+            attr.pop("mode", None)
         else:
             raise tvm.error.OpAttributeInvalid(
-                'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.')
+                "Value " + pad_mode + ' in attribute "mode" is invalid for operator Pad.'
+            )
 
         return AttrCvt(
             _op.nn.pad,
             transforms={
-                'value': 'pad_value',
+                "value": "pad_value",
             },
-            )(inputs, attr, params)
+        )(inputs, attr, params)
 
     @classmethod
     def _impl_v2(cls, inputs, attr, params):
         pad_width = []
-        pads = attr.pop('pads')
+        pads = attr.pop("pads")
         dims = int(len(pads) / 2)
         for i in range(dims):
-            pad_width.append((pads[i], pads[i+dims]))
-        attr['pad_width'] = pad_width
-        pad_mode = attr.get('mode', b'constant').decode('utf-8')
-        if pad_mode in ['constant', 'edge', 'reflect']:
-            attr['pad_mode'] = pad_mode
-            attr.pop('mode', None)
+            pad_width.append((pads[i], pads[i + dims]))
+        attr["pad_width"] = pad_width
+        pad_mode = attr.get("mode", b"constant").decode("utf-8")
+        if pad_mode in ["constant", "edge", "reflect"]:
+            attr["pad_mode"] = pad_mode
+            attr.pop("mode", None)
         else:
             raise tvm.error.OpAttributeInvalid(
-                'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.')
+                "Value " + pad_mode + ' in attribute "mode" is invalid for operator Pad.'
+            )
 
         return AttrCvt(
-            'pad',
+            "pad",
             transforms={
-                'value': 'pad_value',
+                "value": "pad_value",
             },
-            )(inputs, attr, params)
+        )(inputs, attr, params)
 
     @classmethod
     def _impl_v11(cls, inputs, attr, params):
@@ -667,35 +660,32 @@ class Pad(OnnxOpConverter):
         attr["pad_value"] = value
         dims = int(len(pads) / 2)
         for i in range(dims):
-            pad_width.append((pads[i], pads[i+dims]))
-        attr['pad_width'] = pad_width
-        pad_mode = attr.get('mode', b'constant').decode('utf-8')
-        if pad_mode in ['constant', 'edge', 'reflect']:
-            attr['pad_mode'] = pad_mode
-            attr.pop('mode', None)
+            pad_width.append((pads[i], pads[i + dims]))
+        attr["pad_width"] = pad_width
+        pad_mode = attr.get("mode", b"constant").decode("utf-8")
+        if pad_mode in ["constant", "edge", "reflect"]:
+            attr["pad_mode"] = pad_mode
+            attr.pop("mode", None)
         else:
             raise tvm.error.OpAttributeInvalid(
-                'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.')
-
-        return AttrCvt('pad')(inputs[:1], attr, params)
-
+                "Value " + pad_mode + ' in attribute "mode" is invalid for operator Pad.'
+            )
 
+        return AttrCvt("pad")(inputs[:1], attr, params)
 
 
 class ParametricSoftPlus(OnnxOpConverter):
-    """ Operator converter for ParametricSoftPlus.
-    """
+    """Operator converter for ParametricSoftPlus."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        alpha = _expr.const(float(attr.get('alpha', 1.0)))
-        beta = _expr.const(float(attr.get('beta', 1.0)))
-        return _op.log(_op.exp(beta * inputs[0]) + _expr.const(1.)) * alpha
+        alpha = _expr.const(float(attr.get("alpha", 1.0)))
+        beta = _expr.const(float(attr.get("beta", 1.0)))
+        return _op.log(_op.exp(beta * inputs[0]) + _expr.const(1.0)) * alpha
 
 
 class Prelu(OnnxOpConverter):
-    """ Operator converter for Prelu.
-    """
+    """Operator converter for Prelu."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
@@ -709,8 +699,7 @@ class Prelu(OnnxOpConverter):
 
 
 class Reciprocal(OnnxOpConverter):
-    """ Operator converter for Reciprocal.
-    """
+    """Operator converter for Reciprocal."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
@@ -718,12 +707,11 @@ class Reciprocal(OnnxOpConverter):
 
 
 class Flatten(OnnxOpConverter):
-    """ Operator converter for Flatten.
-    """
+    """Operator converter for Flatten."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        axis = attr.get('axis', 1)
+        axis = attr.get("axis", 1)
         if axis == 1:
             out = _op.nn.batch_flatten(inputs[0])
         else:
@@ -734,12 +722,11 @@ class Flatten(OnnxOpConverter):
 
 
 class Reshape(OnnxOpConverter):
-    """ Operator converter for Reshape.
-    """
+    """Operator converter for Reshape."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        return _op.reshape(inputs[0], attr['shape'])
+        return _op.reshape(inputs[0], attr["shape"])
 
     @classmethod
     def _impl_v5(cls, inputs, attr, params):
@@ -750,103 +737,95 @@ class Reshape(OnnxOpConverter):
         else:
             data, shape = inputs
             static_shape = infer_value_simulated(shape, params)
-            out = _op.reshape(data, newshape=tuple(
-                static_shape.asnumpy().astype('int32')))
+            out = _op.reshape(data, newshape=tuple(static_shape.asnumpy().astype("int32")))
         return out
 
 
 class DepthToSpace(OnnxOpConverter):
-    """ Operator converter for DepthToSpace.
-    """
+    """Operator converter for DepthToSpace."""
 
     @classmethod
     def _impl_v11(cls, inputs, attr, params):
 
-        block_size = int(attr['blocksize'])
-        mode = attr.get('mode', b'DCR').decode('utf-8')
+        block_size = int(attr["blocksize"])
+        mode = attr.get("mode", b"DCR").decode("utf-8")
         return _op.nn.depth_to_space(inputs[0], block_size, mode=mode)
 
 
 class SpaceToDepth(OnnxOpConverter):
-    """ Operator converter for SpaceToDepth.
-    """
+    """Operator converter for SpaceToDepth."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
 
-        block_size = int(attr['blocksize'])
+        block_size = int(attr["blocksize"])
         return _op.nn.space_to_depth(inputs[0], block_size)
 
 
 class Concat(OnnxOpConverter):
-    """ Operator converter for Concat.
-    """
+    """Operator converter for Concat."""
 
     @classmethod
     def _impl_v1(cls, inputs, args, params):
-        return AttrCvt(op_name='concatenate')((inputs,), args)
+        return AttrCvt(op_name="concatenate")((inputs,), args)
+
 
 class Scale(OnnxOpConverter):
-    """ Operator converter for Scale.
-    """
+    """Operator converter for Scale."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        scale = float(attr.get('scale', 1.0))
+        scale = float(attr.get("scale", 1.0))
         return inputs[0] * _expr.const(scale)
 
 
 class Selu(OnnxOpConverter):
-    """ Operator converter for Selu.
-    """
+    """Operator converter for Selu."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        alpha = float(attr.get('alpha', 1.6732))
-        gamma = float(attr.get('gamma', 1.0507))
-        return _expr.const(gamma) * (_expr.const(-alpha) *
-                                     _op.nn.relu(_expr.const(1.) - _op.exp(inputs[0])) +
-                                     _op.nn.relu(inputs[0]))
+        alpha = float(attr.get("alpha", 1.6732))
+        gamma = float(attr.get("gamma", 1.0507))
+        return _expr.const(gamma) * (
+            _expr.const(-alpha) * _op.nn.relu(_expr.const(1.0) - _op.exp(inputs[0]))
+            + _op.nn.relu(inputs[0])
+        )
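
A quick NumPy sanity check (with a local relu helper, not part of the converter) that the relay expression above matches the usual SELU definition gamma * (x if x > 0 else alpha * (exp(x) - 1)):

    import numpy as np

    def relu(v):
        return np.maximum(v, 0.0)

    alpha, gamma = 1.6732, 1.0507
    x = np.array([-1.0, 0.0, 2.0])
    converter_form = gamma * (-alpha * relu(1.0 - np.exp(x)) + relu(x))
    reference = gamma * np.where(x > 0, x, alpha * (np.exp(x) - 1.0))
    assert np.allclose(converter_form, reference)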
 
 
 class ScaledTanh(OnnxOpConverter):
-    """ Operator converter for ScaledTanh.
-    """
+    """Operator converter for ScaledTanh."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        alpha = float(attr.get('alpha', 1.0))
-        beta = float(attr.get('beta', 1.0))
+        alpha = float(attr.get("alpha", 1.0))
+        beta = float(attr.get("beta", 1.0))
         return _op.tanh(_expr.const(beta) * inputs[0]) * _expr.const(alpha)
 
 
 class SoftPlus(OnnxOpConverter):
-    """ Operator converter for SoftPlus.
-    """
+    """Operator converter for SoftPlus."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        return _op.log(_op.exp(inputs[0]) + _expr.const(1.))
+        return _op.log(_op.exp(inputs[0]) + _expr.const(1.0))
 
 
 class Softsign(OnnxOpConverter):
-    """ Operator converter for Softsign.
-    """
+    """Operator converter for Softsign."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        return inputs[0] / (_expr.const(1.) + Absolute.get_converter(1)(inputs, attr, params))
+        return inputs[0] / (_expr.const(1.0) + Absolute.get_converter(1)(inputs, attr, params))
 
 
 class Sub(Elemwise):
-    """ Operator converter for Subtract.
-    """
-    name = 'subtract'
+    """Operator converter for Subtract."""
+
+    name = "subtract"
 
 
 class Sum(OnnxOpConverter):
-    """ Operator converter for Sum.
-    """
+    """Operator converter for Sum."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
@@ -858,32 +837,29 @@ class Sum(OnnxOpConverter):
 
 
 class Affine(OnnxOpConverter):
-    """ Operator converter for Affine transformation.
-    """
+    """Operator converter for Affine transformation."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        alpha = _expr.const(attr.get('alpha', 1.0))
-        beta = _expr.const(attr.get('beta', 0.0))
+        alpha = _expr.const(attr.get("alpha", 1.0))
+        beta = _expr.const(attr.get("beta", 0.0))
         return (alpha * inputs[0]) + beta
 
 
 class ThresholdedRelu(OnnxOpConverter):
-    """ Operator converter for ThresholdedRelu.
-    """
+    """Operator converter for ThresholdedRelu."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        alpha = float(attr.get('alpha', 1.0))
+        alpha = float(attr.get("alpha", 1.0))
         alpha_tensor = _op.full_like(inputs[0], fill_value=_expr.const(alpha))
         mask = _op.greater(inputs[0], alpha_tensor).astype("float32")
         return inputs[0] * mask
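
The mask trick above keeps x where x > alpha and zeroes everything else; a minimal NumPy analogue with illustrative values:

    import numpy as np

    alpha = 1.0
    x = np.array([-2.0, 0.5, 1.0, 3.0], dtype="float32")
    mask = (x > alpha).astype("float32")
    assert np.array_equal(x * mask, np.where(x > alpha, x, 0.0))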
 
 
 def _broadcast_constraint():
-
     def _broadcast_check(attrs):
-        if attrs.get('axis', None):
+        if attrs.get("axis", None):
             return False
         return True
 
@@ -891,25 +867,23 @@ def _broadcast_constraint():
 
 
 def _fully_connected(opset):
-
     def _impl(inputs, attr, params):
         # get number of channels
         channels = infer_channels(inputs[1], params)
-        attr['units'] = channels
-        return AttrCvt('dense', ignores=['axis', 'axis_w'])(inputs, attr)
+        attr["units"] = channels
+        return AttrCvt("dense", ignores=["axis", "axis_w"])(inputs, attr)
 
     return _impl
 
 
 class Upsample(OnnxOpConverter):
-    """ Operator converter for Upsample (nearest mode).
-    """
+    """Operator converter for Upsample (nearest mode)."""
 
     @classmethod
     def _impl_v9(cls, inputs, attr, params):
-        scales = attr.get('scales')
+        scales = attr.get("scales")
         if not scales:
-            #Here we are going to higher OPSET version.
+            # Here the model follows the newer opset convention where scales come as an input.
             assert len(inputs) == 2, "Upsample op takes 2 inputs, {} given".format(len(inputs))
             if get_name(inputs[1]) in params:
                 scales = params[inputs[1].name_hint].asnumpy()
@@ -919,95 +893,89 @@ class Upsample(OnnxOpConverter):
         assert scales[0] == 1.0 and scales[1] == 1.0
         input_shape = infer_shape(inputs[0])
         dims = len(input_shape)
-        mode = attr.get('mode')
-        if mode == b'nearest':
+        mode = attr.get("mode")
+        if mode == b"nearest":
             method = "nearest_neighbor"
-        elif mode == b'linear':
+        elif mode == b"linear":
             method = "trilinear" if dims == 5 else "bilinear"
         else:
             raise tvm.error.OpAttributeInvalid(
-                'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode))
-        attr = {'scale_h': scales[-2],
-                'scale_w': scales[-1],
-                'method': method}
+                'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode)
+            )
+        attr = {"scale_h": scales[-2], "scale_w": scales[-1], "method": method}
         if dims == 5:
             assert len(scales) == 5
-            attr['scale_d'] = scales[-3]
-            attr['layout'] = 'NCDHW'
-            op_name = 'upsampling3d'
+            attr["scale_d"] = scales[-3]
+            attr["layout"] = "NCDHW"
+            op_name = "upsampling3d"
         else:
             assert len(scales) == 4
-            attr['layout'] = 'NCHW'
-            if method == 'nearest_neighbor':
-                attr['align_corners'] = False
+            attr["layout"] = "NCHW"
+            if method == "nearest_neighbor":
+                attr["align_corners"] = False
             else:
-                attr['align_corners'] = True
-            op_name = 'upsampling'
+                attr["align_corners"] = True
+            op_name = "upsampling"
         return AttrCvt(op_name)(inputs, attr)
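
For the 4-D case the scales handling above forwards only the spatial factors; a tiny sketch under the NCHW assumption the converter makes:

    scales = [1.0, 1.0, 2.0, 2.0]                 # (N, C, H, W) upsampling factors
    assert scales[0] == 1.0 and scales[1] == 1.0  # batch and channel must stay untouched
    attr = {"scale_h": scales[-2], "scale_w": scales[-1], "method": "nearest_neighbor"}
    assert attr["scale_h"] == 2.0 and attr["scale_w"] == 2.0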
 
+
 class Shape(OnnxOpConverter):
-    """ Operator converter for Shape.
-    """
+    """Operator converter for Shape."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         return _op.shape_of(inputs[0], "int64")
 
+
 class Cast(OnnxOpConverter):
-    """ Operator converter for Cast.
-    """
+    """Operator converter for Cast."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr)
+        return AttrCvt(op_name="cast", transforms={"to": "dtype"})(inputs, attr)
 
     @classmethod
     def _impl_v5(cls, inputs, attr, params):
         try:
             from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
-            attr['to'] = str(TENSOR_TYPE_TO_NP_TYPE[attr['to']])
+
+            attr["to"] = str(TENSOR_TYPE_TO_NP_TYPE[attr["to"]])
         except ImportError as e:
-            raise ImportError(
-                "Unable to import onnx.mapping which is required {}".format(e))
-        return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr)
+            raise ImportError("Unable to import onnx.mapping which is required {}".format(e))
+        return AttrCvt(op_name="cast", transforms={"to": "dtype"})(inputs, attr)
 
 
 class Unsqueeze(OnnxOpConverter):
-    """ Operator converter for Unsqueeze.
-    """
+    """Operator converter for Unsqueeze."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        for axes in attr['axes']:
+        for axes in attr["axes"]:
             inputs[0] = _op.expand_dims(inputs[0], axis=axes, num_newaxis=1)
         return inputs[0]
 
 
 class Split(OnnxOpConverter):
-    """ Operator converter for Split.
-    """
+    """Operator converter for Split."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        splits = attr.get('split', False)
+        splits = attr.get("split", False)
         if splits:
-            attr['indices_or_sections'] = []
+            attr["indices_or_sections"] = []
             index = 0
             for i in splits[:-1]:
                 index += i
-                attr['indices_or_sections'].append(index)
+                attr["indices_or_sections"].append(index)
         # When splits isn't specified, divide evenly over the axis.
         else:
             in_shape = infer_shape(inputs[0])
-            attr['indices_or_sections'] = in_shape[attr['axis']]
-        return AttrCvt(
-            'split',
-            ignores=['split'])(inputs, attr, params)
+            attr["indices_or_sections"] = in_shape[attr["axis"]]
+        return AttrCvt("split", ignores=["split"])(inputs, attr, params)
 
 
 class Slice(OnnxOpConverter):
-    """ Operator converter for Slice.
-    """
+    """Operator converter for Slice."""
 
     @classmethod
     def _common(cls, starts, ends, axes):
@@ -1029,138 +997,138 @@ class Slice(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        if isinstance(attr['starts'], int):
-            attr['starts'] = (attr['starts'],)
-            attr['ends'] = (attr['ends'],)
+        if isinstance(attr["starts"], int):
+            attr["starts"] = (attr["starts"],)
+            attr["ends"] = (attr["ends"],)
 
         try:
             # Update the starts and ends according to axes if required.
-            if isinstance(attr['axes'], int):
-                attr['axes'] = (attr['axes'],)
-            if (max(attr['axes']) + 1) != len(attr['axes']):
+            if isinstance(attr["axes"], int):
+                attr["axes"] = (attr["axes"],)
+            if (max(attr["axes"]) + 1) != len(attr["axes"]):
                 new_starts, new_ends, new_axes = cls._common(
-                    attr['starts'], attr['ends'], attr['axes'])
-                attr['axes'] = new_axes
-                attr['starts'] = new_starts
-                attr['ends'] = new_ends
+                    attr["starts"], attr["ends"], attr["axes"]
+                )
+                attr["axes"] = new_axes
+                attr["starts"] = new_starts
+                attr["ends"] = new_ends
         except KeyError:
             pass
-        begin = list(attr['starts'])
-        end = list(attr['ends'])
+        begin = list(attr["starts"])
+        end = list(attr["ends"])
 
-        return _op.strided_slice(inputs[0],
-                                 begin=begin,
-                                 end=end)
+        return _op.strided_slice(inputs[0], begin=begin, end=end)
 
     @classmethod
     def _impl_v10(cls, inputs, attr, params):
-        attrs = {'starts' : inputs[1], 'ends' : inputs[2]}
+        attrs = {"starts": inputs[1], "ends": inputs[2]}
         if len(inputs) >= 4:
-            attrs['axes'] = inputs[3]
-        attrs = {k : (v, get_name(v)) for (k, v) in attrs.items()}
-        attrs = {k : params[v[1]].asnumpy() if v[1] in params else
-                     infer_value_simulated(v[0], params).asnumpy()
-                 for (k, v) in attrs.items()}
+            attrs["axes"] = inputs[3]
+        attrs = {k: (v, get_name(v)) for (k, v) in attrs.items()}
+        attrs = {
+            k: params[v[1]].asnumpy()
+            if v[1] in params
+            else infer_value_simulated(v[0], params).asnumpy()
+            for (k, v) in attrs.items()
+        }
 
         # Update the starts and ends according to axes if required.
-        if 'axes' in attrs:
-            if max(attrs['axes'] + 1) != len(attrs['axes']):
-                new_starts, new_ends, _ = cls._common(
-                    attrs['starts'], attrs['ends'], attrs['axes'])
-                attrs['starts'] = new_starts
-                attrs['ends'] = new_ends
-        return _op.strided_slice(inputs[0],
-                                 begin=list(attrs['starts']),
-                                 end=list(attrs['ends']))
+        if "axes" in attrs:
+            if max(attrs["axes"] + 1) != len(attrs["axes"]):
+                new_starts, new_ends, _ = cls._common(attrs["starts"], attrs["ends"], attrs["axes"])
+                attrs["starts"] = new_starts
+                attrs["ends"] = new_ends
+        return _op.strided_slice(inputs[0], begin=list(attrs["starts"]), end=list(attrs["ends"]))
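
The axes fix-up above exists because relay's strided_slice is positional; a hedged illustration (not the _common helper itself, all names hypothetical) of padding begin/end for axes the ONNX Slice did not list:

    data_rank = 3
    starts, ends, axes = [1], [4], [2]       # slice only the last axis
    full_starts = [0] * data_rank
    full_ends = [2 ** 31 - 1] * data_rank    # "to the end" sentinel, an assumption here
    for s, e, a in zip(starts, ends, axes):
        full_starts[a], full_ends[a] = s, e
    assert full_starts == [0, 0, 1] and full_ends == [2 ** 31 - 1, 2 ** 31 - 1, 4]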
 
 
 class Gather(OnnxOpConverter):
-    """ Operator converter for Gather.
-    """
+    """Operator converter for Gather."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        axis = attr.get('axis', 0)
-        return AttrCvt('take',
-                       extras={'axis': axis})(inputs, {})
+        axis = attr.get("axis", 0)
+        return AttrCvt("take", extras={"axis": axis})(inputs, {})
 
 
 class GatherND(OnnxOpConverter):
-    """ Operator converter for GatherND.
-    """
+    """Operator converter for GatherND."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         return _op.gather_nd(inputs[0], inputs[1])
 
 
 class Scatter(OnnxOpConverter):
-    """ Operator converter for Scatter.
-    """
+    """Operator converter for Scatter."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        axis = attr.get('axis', 0)
+        axis = attr.get("axis", 0)
         return _op.scatter(inputs[0], inputs[1], inputs[2], axis)
 
 
 class Greater(OnnxOpConverter):
-    """ Operator logical greater.
-    """
+    """Operator logical greater."""
+
     @classmethod
     def _impl_v7(cls, inputs, attr, params):
         return _op.greater(inputs[0], inputs[1])
 
 
 class Less(OnnxOpConverter):
-    """ Operator logical less than.
-    """
+    """Operator logical less than."""
+
     @classmethod
     def _impl_v7(cls, inputs, attr, params):
         return _op.less(inputs[0], inputs[1])
 
 
 class LRN(OnnxOpConverter):
-    """ Operator converter for Local Response Normalization.
-    """
+    """Operator converter for Local Response Normalization."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         """LRN support only NCHW format
         https://github.com/onnx/onnx/blob/master/docs/Operators.md#LRN
         """
         axis = 1
-        alpha = attr.get('alpha', 0.0001)
-        beta = attr.get('beta', 0.75)
-        bias = attr.get('bias', 1.0)
-        nsize = attr.get('size')
-        attr = {'size': nsize, 'axis': axis, 'alpha': alpha, 'beta': beta, 'bias': bias}
-        return AttrCvt('lrn')(inputs, attr)
+        alpha = attr.get("alpha", 0.0001)
+        beta = attr.get("beta", 0.75)
+        bias = attr.get("bias", 1.0)
+        nsize = attr.get("size")
+        attr = {"size": nsize, "axis": axis, "alpha": alpha, "beta": beta, "bias": bias}
+        return AttrCvt("lrn")(inputs, attr)
+
 
 class Maximum(OnnxOpConverter):
-    """ Operator converter for Maximum.
-    """
+    """Operator converter for Maximum."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
             raise ValueError("Expect minimum 2 inputs")
         _max = inputs[0]
         for i in range(1, len(inputs)):
-            _max = AttrCvt('maximum')([_max, inputs[i]], {})
+            _max = AttrCvt("maximum")([_max, inputs[i]], {})
         return _max
 
+
 class Minimum(OnnxOpConverter):
-    """ Operator converter for Minimum.
-    """
+    """Operator converter for Minimum."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
             raise ValueError("Expect minimum 2 inputs")
         _min = inputs[0]
         for i in range(1, len(inputs)):
-            _min = AttrCvt('minimum')([_min, inputs[i]], {})
+            _min = AttrCvt("minimum")([_min, inputs[i]], {})
         return _min
 
+
 class Mean(OnnxOpConverter):
-    """ Operator converter for Mean.
-    """
+    """Operator converter for Mean."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
@@ -1169,105 +1137,114 @@ class Mean(OnnxOpConverter):
         concat = _op.concatenate([_op.expand_dims(x, axis=0) for x in inputs], axis=0)
         return _op.mean(concat, axis=0, keepdims=False)
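
The Mean lowering stacks the operands along a new leading axis and reduces it; a NumPy sketch of the same idea:

    import numpy as np

    a, b, c = (np.random.rand(2, 3) for _ in range(3))
    stacked = np.concatenate([t[np.newaxis] for t in (a, b, c)], axis=0)
    assert np.allclose(stacked.mean(axis=0), (a + b + c) / 3)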
 
+
 class HardSigmoid(OnnxOpConverter):
-    """ Operator converter for HardSigmoid.
-    """
+    """Operator converter for HardSigmoid."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        alpha = attr.get('alpha', 0.2)
-        beta = attr.get('beta', 0.5)
+        alpha = attr.get("alpha", 0.2)
+        beta = attr.get("beta", 0.5)
         transformX = (inputs[0] * _expr.const(alpha)) + _expr.const(beta)
-        attr = {'a_min': 0, 'a_max': 1}
-        return AttrCvt('clip')([transformX], attr)
+        attr = {"a_min": 0, "a_max": 1}
+        return AttrCvt("clip")([transformX], attr)
+
 
 class Reduce(OnnxOpConverter):
-    """ Operator converter for reduce ops.
-    """
-    name = ''
+    """Operator converter for reduce ops."""
+
+    name = ""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        if 'axes' in attr:
-            axis = attr.get('axes', 0)
+        if "axes" in attr:
+            axis = attr.get("axes", 0)
         else:
             axis_len = len(infer_shape(inputs[0]))
             axis = list(range(axis_len))
-        attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)}
+        attr = {"axis": axis, "keepdims": attr.get("keepdims", True)}
         return AttrCvt(cls.name)(inputs, attr)
 
+
 class ReduceMax(Reduce):
-    """ Operator converter for ReduceMax.
-    """
-    name = 'max'
+    """Operator converter for ReduceMax."""
+
+    name = "max"
+
 
 class ReduceMin(Reduce):
-    """ Operator converter for ReduceMin.
-    """
-    name = 'min'
+    """Operator converter for ReduceMin."""
+
+    name = "min"
+
 
 class ReduceSum(Reduce):
-    """ Operator converter for ReduceSum.
-    """
-    name = 'sum'
+    """Operator converter for ReduceSum."""
+
+    name = "sum"
+
 
 class ReduceMean(Reduce):
-    """ Operator converter for ReduceMean.
-    """
-    name = 'mean'
+    """Operator converter for ReduceMean."""
+
+    name = "mean"
+
 
 class ReduceProd(Reduce):
-    """ Operator converter for ReduceProd.
-    """
-    name = 'prod'
+    """Operator converter for ReduceProd."""
+
+    name = "prod"
+
 
 class ReduceLogSumExp(Reduce):
-    """ Operator converter for ReduceLogSumExp.
-    """
-    name = 'logsumexp'
+    """Operator converter for ReduceLogSumExp."""
+
+    name = "logsumexp"
 
 
 class ReduceSumSquare(OnnxOpConverter):
-    """ Operator converter for ReduceSumSquare.
-    """
+    """Operator converter for ReduceSumSquare."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        if 'axes' in attr:
-            axis = attr.get('axes', 0)
+        if "axes" in attr:
+            axis = attr.get("axes", 0)
         else:
             axis_len = len(infer_shape(inputs[0]))
             axis = list(range(axis_len))
-        attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)}
+        attr = {"axis": axis, "keepdims": attr.get("keepdims", True)}
         inputs[0] = inputs[0] * inputs[0]
 
         return AttrCvt("sum")(inputs, attr)
 
 
 class ReduceL1(OnnxOpConverter):
-    """ Operator converter for ReduceL1.
-    """
+    """Operator converter for ReduceL1."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        if 'axes' in attr:
-            axis = attr.get('axes', 0)
+        if "axes" in attr:
+            axis = attr.get("axes", 0)
         else:
             axis_len = len(infer_shape(inputs[0]))
             axis = list(range(axis_len))
-        attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)}
+        attr = {"axis": axis, "keepdims": attr.get("keepdims", True)}
         inputs[0] = _op.abs(inputs[0])
 
         return AttrCvt("sum")(inputs, attr)
 
 
 class ReduceL2(OnnxOpConverter):
-    """ Operator converter for ReduceL2.
-    """
+    """Operator converter for ReduceL2."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        if 'axes' in attr:
-            axis = attr.get('axes', 0)
+        if "axes" in attr:
+            axis = attr.get("axes", 0)
         else:
             axis_len = len(infer_shape(inputs[0]))
             axis = list(range(axis_len))
-        attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)}
+        attr = {"axis": axis, "keepdims": attr.get("keepdims", True)}
         inputs[0] = inputs[0] * inputs[0]
         out = AttrCvt("sum")(inputs, attr)
 
@@ -1275,151 +1252,149 @@ class ReduceL2(OnnxOpConverter):
 
 
 class ReduceLogSum(OnnxOpConverter):
-    """ Operator converter for ReduceLogSum.
-    """
+    """Operator converter for ReduceLogSum."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        if 'axes' in attr:
-            axis = attr.get('axes', 0)
+        if "axes" in attr:
+            axis = attr.get("axes", 0)
         else:
             axis_len = len(infer_shape(inputs[0]))
             axis = list(range(axis_len))
-        attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)}
+        attr = {"axis": axis, "keepdims": attr.get("keepdims", True)}
         out = AttrCvt("sum")(inputs, attr)
 
         return _op.log(out)
 
 
 class ArgMax(OnnxOpConverter):
-    """ Operator converter for ArgMax.
-    """
+    """Operator converter for ArgMax."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        axis = attr.get('axis', 0)
-        keepdims = attr.get('keepdims', True)
-        attr = {'axis': axis, 'keepdims': keepdims}
-        return AttrCvt('argmax')(inputs, attr)
+        axis = attr.get("axis", 0)
+        keepdims = attr.get("keepdims", True)
+        attr = {"axis": axis, "keepdims": keepdims}
+        return AttrCvt("argmax")(inputs, attr)
+
 
 class ArgMin(OnnxOpConverter):
-    """ Operator converter for ArgMin.
-    """
+    """Operator converter for ArgMin."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        axis = attr.get('axis', 0)
-        keepdims = attr.get('keepdims', True)
-        attr = {'axis': axis, 'keepdims': keepdims}
-        return AttrCvt('argmin')(inputs, attr)
+        axis = attr.get("axis", 0)
+        keepdims = attr.get("keepdims", True)
+        attr = {"axis": axis, "keepdims": keepdims}
+        return AttrCvt("argmin")(inputs, attr)
+
 
 class Softmax(OnnxOpConverter):
-    """ Operator converter for Softmax.
-    """
+    """Operator converter for Softmax."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         # set default value when axis is not set in the model
-        if 'axis' not in attr:
-            attr['axis'] = 1
-        return AttrCvt('softmax', transforms={'axis': ('axis', 1)})(inputs, attr, params)
+        if "axis" not in attr:
+            attr["axis"] = 1
+        return AttrCvt("softmax", transforms={"axis": ("axis", 1)})(inputs, attr, params)
 
 
 class OneHot(OnnxOpConverter):
-    """ Operator converter for OneHot.
-    """
+    """Operator converter for OneHot."""
+
     @classmethod
     def _impl_v9(cls, inputs, attr, params):
         # Extract relay one_hot inputs.
         indices, depth, values = inputs
         # Split onnx on off values into two separate expressions.
-        off_value, on_value = _op.take(
-            values, _op.const(0)), _op.take(values, _op.const(1))
+        off_value, on_value = _op.take(values, _op.const(0)), _op.take(values, _op.const(1))
         # Extract the datatype of the output from on_value.
         dtype = infer_type(on_value).checked_type.dtype
         # Convert depth into an integer.
         depth = int(infer_value(depth, params).asnumpy()[0])
         # set default value when axis is not set in the model
-        if 'axis' not in attr:
-            attr['axis'] = -1
-        return _op.one_hot(indices,
-                           on_value,
-                           off_value,
-                           depth,
-                           int(attr['axis']),
-                           dtype=dtype)
+        if "axis" not in attr:
+            attr["axis"] = -1
+        return _op.one_hot(indices, on_value, off_value, depth, int(attr["axis"]), dtype=dtype)
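
ONNX packs [off_value, on_value] into one tensor, which the converter splits with two takes; a NumPy sketch of the resulting one-hot semantics (hypothetical inputs):

    import numpy as np

    values = np.array([0.0, 1.0], dtype="float32")  # [off, on]
    off_value, on_value = values[0], values[1]
    indices, depth = np.array([0, 2]), 3
    expected = np.full((indices.size, depth), off_value)
    expected[np.arange(indices.size), indices] = on_value
    print(expected)  # [[1. 0. 0.]
                     #  [0. 0. 1.]]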
 
 
 class ConstantOfShape(OnnxOpConverter):
-    """ Operator converter for ConstantOfShape.
-    """
+    """Operator converter for ConstantOfShape."""
+
     @classmethod
     def _impl_v9(cls, inputs, attr, params):
-        if 'value' in attr:
-            np_value = get_numpy(attr.pop('value'))[0]
+        if "value" in attr:
+            np_value = get_numpy(attr.pop("value"))[0]
             value = _expr.const(np_value)
             dtype = np_value.dtype.name
         else:
             value = _expr.const(0)
-            dtype = 'float32'
+            dtype = "float32"
         static_shape = infer_value_simulated(inputs[0], params)
-        output = _op.full(
-            value, shape=tuple(static_shape.asnumpy().astype('int32')), dtype=dtype)
+        output = _op.full(value, shape=tuple(static_shape.asnumpy().astype("int32")), dtype=dtype)
         return output
 
 
 class Sign(OnnxOpConverter):
-    """ Operator converter for Sign.
-    """
+    """Operator converter for Sign."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         return _op.sign(inputs[0])
 
+
 class Equal(Elemwise):
-    """ Operator converter for Equal.
-    """
-    name = 'equal'
+    """Operator converter for Equal."""
+
+    name = "equal"
 
 
 class Not(Elemwise):
-    """ Operator converter for Not.
-    """
+    """Operator converter for Not."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         return _op.logical_not(inputs[0])
 
 
 class And(Elemwise):
-    """ Operator converter for And.
-    """
+    """Operator converter for And."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         return _op.logical_and(inputs[0], inputs[1])
 
 
 class Tile(Elemwise):
-    """Operator converter for Tile
-    """
+    """Operator converter for Tile"""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        if 'repeats' not in attr:
-            raise tvm.error.OpAttributeInvalid('Attribute "repeats" should be set '
-                                               'for operator Tile.')
-        reps = attr.pop('repeats')  # The number of times repeating the tensor data.
+        if "repeats" not in attr:
+            raise tvm.error.OpAttributeInvalid(
+                'Attribute "repeats" should be set ' "for operator Tile."
+            )
+        reps = attr.pop("repeats")  # The number of times repeating the tensor data.
         return _op.tile(inputs[0], reps)
 
     @classmethod
     def _impl_v6(cls, inputs, attr, params):
-        reps = tuple(infer_value_simulated(
-            inputs[1], params).asnumpy().astype('int32'))
+        reps = tuple(infer_value_simulated(inputs[1], params).asnumpy().astype("int32"))
         return _op.tile(inputs[0], reps)
 
+
 class Erf(OnnxOpConverter):
-    """Operator converter for Erf
-    """
+    """Operator converter for Erf"""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         return _op.erf(inputs[0])
 
+
 class Where(OnnxOpConverter):
-    """Operator converter for Where
-    """
+    """Operator converter for Where"""
+
     @classmethod
     def _impl_v9(cls, inputs, attr, params):
         condition_shape = infer_shape(inputs[0])
@@ -1451,24 +1426,25 @@ class Where(OnnxOpConverter):
             inputs[2] = _op.broadcast_to(inputs[2], broadcast_shape)
         return _op.where(inputs[0], inputs[1], inputs[2])
 
+
 class Or(Elemwise):
-    """ Operator converter for Or.
-    """
+    """Operator converter for Or."""
+
     @classmethod
     def _impl_v7(cls, inputs, attr, params):
         return _op.logical_or(inputs[0], inputs[1])
 
 
 class Expand(OnnxOpConverter):
-    """ Operator converter for Expand.
-    """
+    """Operator converter for Expand."""
+
     @classmethod
     def _impl_v8(cls, inputs, attr, params):
-        in_shape = np.array(infer_shape(inputs[0])).astype('int32')
+        in_shape = np.array(infer_shape(inputs[0])).astype("int32")
         if get_name(inputs[1]) in params:
-            shape = params[inputs[1].name_hint].asnumpy().astype('int32')
+            shape = params[inputs[1].name_hint].asnumpy().astype("int32")
         else:
-            shape = infer_value_simulated(inputs[1], params).asnumpy().astype('int32')
+            shape = infer_value_simulated(inputs[1], params).asnumpy().astype("int32")
 
         # Currently 'op.broadcast_to' expects the rank of the given 'shape'
         # (the 2nd input) to always be higher than that of the given 'input' (the 1st input)
@@ -1479,7 +1455,7 @@ class Expand(OnnxOpConverter):
         # In the above cases, we cannot directly apply 'op.broadcast_to' instead of 'expand',
         # so here we solve this problem by expanding the given 'shape' itself.
         def expand_shape(in_shape, shape):
-            """ A function expands the shape when the rank is lower than that of the given
+            """A function expands the shape when the rank is lower than that of the given
             intput. Also it replaces the extent of the shape with the corresponding extent
             of the intput when it is 1.
             """
@@ -1508,17 +1484,16 @@ class Expand(OnnxOpConverter):
 
 
 class RNN(OnnxOpConverter):
-    """ Operator converter for RNNs such as LSTM and GRU.
-    """
+    """Operator converter for RNNs such as LSTM and GRU."""
 
     @classmethod
     def _activation_helper(cls, activation, alpha, beta):
         convert_map = _get_convert_map(1)
         attrs = {}
         if alpha is not None:
-            attrs['alpha'] = alpha
+            attrs["alpha"] = alpha
         if beta is not None:
-            attrs['beta'] = beta
+            attrs["beta"] = beta
         return lambda x: convert_map[activation.decode("utf-8")]([x], attrs, {})
 
     @classmethod
@@ -1544,8 +1519,7 @@ class RNN(OnnxOpConverter):
 
 
 class LSTM(RNN):
-    """Operator converter for LSTM
-    """
+    """Operator converter for LSTM"""
 
     @classmethod
     def _impl_v7(cls, inputs, attr, params):
@@ -1555,7 +1529,7 @@ class LSTM(RNN):
         R = inputs[2]
         B = inputs[3]
         # Sequence length currently unused as it can be inferred from shapes.
-        #sequence_lens = inputs['sequence_lens']
+        # sequence_lens = inputs['sequence_lens']
         h_0 = inputs[5]
         c_0 = inputs[6]
         P = inputs[7]
@@ -1593,17 +1567,16 @@ class LSTM(RNN):
         C_t = c_0
         h_list = []
 
-        if 'activations' in attr:
-            activations = attr['activations']
+        if "activations" in attr:
+            activations = attr["activations"]
             if len(activations) != 3:
-                raise NotImplementedError(
-                    "LSTM assumes 3 activation functions are provided")
+                raise NotImplementedError("LSTM assumes 3 activation functions are provided")
             alpha_loc = 0
-            alphas = attr.get('activation_alpha', [])
+            alphas = attr.get("activation_alpha", [])
             if isinstance(alphas, float):
                 alphas = [alphas]
             beta_loc = 0
-            betas = attr.get('activation_beta', [])
+            betas = attr.get("activation_beta", [])
             if isinstance(betas, float):
                 betas = [betas]
             acts = []
@@ -1611,12 +1584,10 @@ class LSTM(RNN):
                 alpha = None
                 beta = None
                 activation = activations[i]
-                if cls._activation_needs_alpha(
-                        activation) and len(alphas) > alpha_loc:
+                if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc:
                     alpha = alphas[alpha_loc]
                     alpha_loc += 1
-                if cls._activation_needs_beta(
-                        activation) and len(betas) > beta_loc:
+                if cls._activation_needs_beta(activation) and len(betas) > beta_loc:
                     beta = betas[beta_loc]
                     beta_loc += 1
                 acts.append(cls._activation_helper(activation, alpha, beta))
@@ -1661,8 +1632,7 @@ class LSTM(RNN):
 
 
 class GRU(RNN):
-    """Operator convert for GRU
-    """
+    """Operator convert for GRU"""
 
     @classmethod
     def _impl_v7(cls, inputs, attr, params):
@@ -1672,9 +1642,9 @@ class GRU(RNN):
         R = inputs[2]
         B = inputs[3]
         # Sequence length currently unused as it can be inferred from shapes.
-        #sequence_lens = inputs['sequence_lens']
+        # sequence_lens = inputs['sequence_lens']
         h_0 = inputs[5]
-        linear_before_reset = attr.get('linear_before_reset', 0)
+        linear_before_reset = attr.get("linear_before_reset", 0)
 
         num_directions = infer_shape(W)[0]
         W_dtype = infer_type(W).type_annotation.dtype
@@ -1701,17 +1671,16 @@ class GRU(RNN):
         H_t = h_0
         h_list = []
 
-        if 'activations' in attr:
-            activations = attr['activations']
+        if "activations" in attr:
+            activations = attr["activations"]
             if len(activations) != 2:
-                raise NotImplementedError(
-                    "GRU assumes 2 activation functions are provided")
+                raise NotImplementedError("GRU assumes 2 activation functions are provided")
             alpha_loc = 0
-            alphas = attr.get('activation_alpha', [])
+            alphas = attr.get("activation_alpha", [])
             if isinstance(alphas, float):
                 alphas = [alphas]
             beta_loc = 0
-            betas = attr.get('activation_beta', [])
+            betas = attr.get("activation_beta", [])
             if isinstance(betas, float):
                 betas = [betas]
             acts = []
@@ -1719,12 +1688,10 @@ class GRU(RNN):
                 alpha = None
                 beta = None
                 activation = activations[i]
-                if cls._activation_needs_alpha(
-                        activation) and len(alphas) > alpha_loc:
+                if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc:
                     alpha = alphas[alpha_loc]
                     alpha_loc += 1
-                if cls._activation_needs_beta(
-                        activation) and len(betas) > beta_loc:
+                if cls._activation_needs_beta(activation) and len(betas) > beta_loc:
                     beta = betas[beta_loc]
                     beta_loc += 1
                 acts.append(cls._activation_helper(activation, alpha, beta))
@@ -1772,18 +1739,19 @@ class GRU(RNN):
 
 
 class Resize(OnnxOpConverter):
-    """Operator converter for Resize
-    """
+    """Operator converter for Resize"""
+
     @classmethod
     def _impl_v11(cls, inputs, attr, params):
-        mode = attr.get('mode')
-        if mode == b'nearest':
+        mode = attr.get("mode")
+        if mode == b"nearest":
             method = "nearest_neighbor"
-        elif mode == b'linear':
+        elif mode == b"linear":
             method = "bilinear"
         else:
             raise tvm.error.OpAttributeInvalid(
-                'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode))
+                'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode)
+            )
 
         in_size = np.array(infer_shape(inputs[0]))
         scale = infer_value_simulated(inputs[2], params).asnumpy()
@@ -1794,37 +1762,39 @@ class Resize(OnnxOpConverter):
             assert len(scale) != 0, "One of scale or size should be passed."
             size = (in_size * scale).astype(np.int32)
 
-        coord_trans = attr.get('coordinate_transformation_mode')
-        if coord_trans in [b'pytorch_half_pixel', b'half_pixel']:
+        coord_trans = attr.get("coordinate_transformation_mode")
+        if coord_trans in [b"pytorch_half_pixel", b"half_pixel"]:
             coord_trans = "half_pixel"
-        elif coord_trans == b'align_corners':
+        elif coord_trans == b"align_corners":
             coord_trans = "align_corners"
-        elif coord_trans == b'asymmetric' or method == "nearest_neighbor":
+        elif coord_trans == b"asymmetric" or method == "nearest_neighbor":
             coord_trans = "asymmetric"
         else:
             raise tvm.error.OpAttributeInvalid(
-                'Unsupported coordinate_transformation_mode: {}'.format(coord_trans))
+                "Unsupported coordinate_transformation_mode: {}".format(coord_trans)
+            )
         layout = "NCHW"  # ONNX assumes NCHW layout
         out_size = (size[2], size[3])
         return _op.image.resize(inputs[0], out_size, layout, method, coord_trans)
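
When ONNX Resize provides scales rather than sizes, the output size above is the truncated product of the two; a minimal sketch with made-up numbers:

    import numpy as np

    in_size = np.array([1, 3, 10, 20])         # NCHW input shape
    scale = np.array([1.0, 1.0, 2.0, 1.5])
    size = (in_size * scale).astype(np.int32)
    assert list(size[2:]) == [20, 30]          # spatial output size passed to resize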
 
 
 class NonZero(OnnxOpConverter):
-    """Operator converter for NonZero
-    """
+    """Operator converter for NonZero"""
+
     @classmethod
     def _impl_v9(cls, inputs, attr, params):
         if len(inputs) > 1:
             raise ValueError("Expect 1 input only")
 
-        output = AttrCvt(op_name='argwhere')(inputs, attr, params)
+        output = AttrCvt(op_name="argwhere")(inputs, attr, params)
         # ONNX NonZero always outputs int64
         output = _op.cast(output, "int64")
         return _op.transpose(output, axes=(1, 0))
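
The argwhere-plus-transpose combination above matches ONNX's (rank, num_nonzero) output layout; a NumPy analogue:

    import numpy as np

    x = np.array([[0, 1], [2, 0]])
    assert np.array_equal(np.argwhere(x).T, np.stack(np.nonzero(x)))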
 
+
 class TopK(OnnxOpConverter):
-    """Operator converter for TopK
-    """
+    """Operator converter for TopK"""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         if len(inputs) != 2:
@@ -1841,8 +1811,8 @@ class TopK(OnnxOpConverter):
 
 
 class MaxRoiPool(OnnxOpConverter):
-    """Operator converter for MaxRoiPool.
-    """
+    """Operator converter for MaxRoiPool."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         assert len(inputs) == 2, "MMaxRoiPool op take 2 inputs, {} given".format(len(inputs))
@@ -1856,8 +1826,8 @@ class MaxRoiPool(OnnxOpConverter):
 
 
 class RoiAlign(OnnxOpConverter):
-    """Operator converter for RoiAlign.
-    """
+    """Operator converter for RoiAlign."""
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         if len(inputs) != 3:
@@ -1866,7 +1836,7 @@ class RoiAlign(OnnxOpConverter):
         rois = inputs[1]
         batch_indices = inputs[2]
         mode = attr.get("mode", "avg")
-        if mode != b'avg':
+        if mode != b"avg":
             raise ValueError("RoiAlign in Relay only uses avg mode")
         output_height = attr.get("output_height", 1)
         output_width = attr.get("output_width", 1)
@@ -1875,19 +1845,20 @@ class RoiAlign(OnnxOpConverter):
         spatial_scale = attr.get("spatial_scale", 1.0)
 
         batch_indices = _op.expand_dims(batch_indices, axis=1, num_newaxis=1)
-        batch_indices = _op.cast(
-            batch_indices, infer_type(rois).type_annotation.dtype)
+        batch_indices = _op.cast(batch_indices, infer_type(rois).type_annotation.dtype)
         rois = _op.concatenate([batch_indices, rois], 1)
 
-        return _vision.roi_align(x, rois, [output_height, output_width],
-                                 spatial_scale, sampling_ratio)
+        return _vision.roi_align(
+            x, rois, [output_height, output_width], spatial_scale, sampling_ratio
+        )
+
 
 class Clip(OnnxOpConverter):
-    """Operator converter for Clip.
-    """
+    """Operator converter for Clip."""
+
     @staticmethod
     def convert_attributes(inputs, attr, params):
-        convert = AttrCvt('clip', transforms={'min': 'a_min', 'max': 'a_max'})
+        convert = AttrCvt("clip", transforms={"min": "a_min", "max": "a_max"})
         return convert(inputs, attr, params)
 
     @classmethod
@@ -1896,16 +1867,17 @@ class Clip(OnnxOpConverter):
 
     @classmethod
     def _impl_v11(cls, inputs, attr, params):
-        if 'min' in attr and 'max' in attr:
+        if "min" in attr and "max" in attr:
             return Clip.convert_attributes(inputs, attr, params)
 
         assert len(inputs) <= 3, "Clip-11 takes up to 3 inputs, input, min, max"
         result = inputs[0]
         for i, op in enumerate([_maximum, _minimum]):
             if i < len(inputs) - 1:
-                result = op(result, inputs[i+1])
+                result = op(result, inputs[i + 1])
         return result
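
When min/max arrive as inputs in Clip-11, the loop above chains maximum and minimum; a standalone NumPy check that this reproduces clip:

    import numpy as np

    x = np.array([-5.0, 0.0, 5.0])
    lo, hi = -1.0, 1.0
    assert np.allclose(np.minimum(np.maximum(x, lo), hi), np.clip(x, lo, hi))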
 
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -1918,157 +1890,150 @@ _identity_list = []
 def _get_convert_map(opset):
     return {
         # defs/experimental
-        'Identity': Renamer('copy'),
-        'Affine': Affine.get_converter(opset),
-        'ThresholdedRelu': ThresholdedRelu.get_converter(opset),
-        'ScaledTanh': ScaledTanh.get_converter(opset),
-        'ParametricSoftplus': ParametricSoftPlus.get_converter(opset),
-        'ConstantOfShape': ConstantOfShape.get_converter(opset),
+        "Identity": Renamer("copy"),
+        "Affine": Affine.get_converter(opset),
+        "ThresholdedRelu": ThresholdedRelu.get_converter(opset),
+        "ScaledTanh": ScaledTanh.get_converter(opset),
+        "ParametricSoftplus": ParametricSoftPlus.get_converter(opset),
+        "ConstantOfShape": ConstantOfShape.get_converter(opset),
         # 'GivenTensorFill'
-        'FC': AttrCvt('dense', ignores=['axis', 'axis_w']),
-        'Scale': Scale.get_converter(opset),
+        "FC": AttrCvt("dense", ignores=["axis", "axis_w"]),
+        "Scale": Scale.get_converter(opset),
         # 'GRUUnit'
         # 'ATen'
         # 'ImageScaler'
         # 'MeanVarianceNormalization'
         # 'Crop'
         # 'Embedding'
-        'Upsample': Upsample.get_converter(opset),
-        'SpatialBN': BatchNorm.get_converter(opset),
-
+        "Upsample": Upsample.get_converter(opset),
+        "SpatialBN": BatchNorm.get_converter(opset),
         # defs/generator
         # 'Constant' # Implemented
         # 'RandomUniform'
         # 'RandomNormal'
         # 'RandomUniformLike'
         # 'RandomNormalLike'
-
         # defs/logical
-
         # defs/math
-        'Add': Add.get_converter(opset),
-        'Sub': Sub.get_converter(opset),
-        'Mul': Mul.get_converter(opset),
-        'Div': Div.get_converter(opset),
-        'Neg': Renamer('negative'),
-        'Abs': Absolute.get_converter(opset),
-        'Reciprocal': Reciprocal.get_converter(opset),
-        'Floor': Renamer('floor'),
-        'Ceil': Renamer('ceil'),
-        'Round': Renamer('round'),
-        'IsInf': Renamer('isinf'),
-        'IsNaN': Renamer('isnan'),
-        'Sqrt': Renamer('sqrt'),
-        'Relu': Renamer('relu'),
-        'LeakyRelu': Renamer('leaky_relu'),
-        'Selu': Selu.get_converter(opset),
-        'Elu': Elu.get_converter(opset),
-        'Exp': Renamer('exp'),
-        'Greater': Greater.get_converter(opset),
-        'Less': Less.get_converter(opset),
-        'Log': Renamer('log'),
-        'ACos': Renamer('acos'),
-        'ACosh': Renamer('acosh'),
-        'ASin': Renamer('asin'),
-        'ASinh': Renamer('asinh'),
-        'ATan': Renamer('atan'),
-        'ATanh': Renamer('atanh'),
-        'Cos': Renamer('cos'),
-        'Cosh': Renamer('cosh'),
-        'Sin': Renamer('sin'),
-        'Sinh': Renamer('sinh'),
-        'Tan': Renamer('tan'),
-        'Tanh': Renamer('tanh'),
-        'Pow': Renamer('power'),
-        'PRelu': Prelu.get_converter(opset),
-        'Sigmoid': Renamer('sigmoid'),
-        'HardSigmoid': HardSigmoid.get_converter(opset),
-        'Max': Maximum.get_converter(opset),
-        'Min': Minimum.get_converter(opset),
-        'Sum': Sum.get_converter(opset),
-        'Mean': Mean.get_converter(opset),
-        'Clip': Clip.get_converter(opset),
+        "Add": Add.get_converter(opset),
+        "Sub": Sub.get_converter(opset),
+        "Mul": Mul.get_converter(opset),
+        "Div": Div.get_converter(opset),
+        "Neg": Renamer("negative"),
+        "Abs": Absolute.get_converter(opset),
+        "Reciprocal": Reciprocal.get_converter(opset),
+        "Floor": Renamer("floor"),
+        "Ceil": Renamer("ceil"),
+        "Round": Renamer("round"),
+        "IsInf": Renamer("isinf"),
+        "IsNaN": Renamer("isnan"),
+        "Sqrt": Renamer("sqrt"),
+        "Relu": Renamer("relu"),
+        "LeakyRelu": Renamer("leaky_relu"),
+        "Selu": Selu.get_converter(opset),
+        "Elu": Elu.get_converter(opset),
+        "Exp": Renamer("exp"),
+        "Greater": Greater.get_converter(opset),
+        "Less": Less.get_converter(opset),
+        "Log": Renamer("log"),
+        "ACos": Renamer("acos"),
+        "ACosh": Renamer("acosh"),
+        "ASin": Renamer("asin"),
+        "ASinh": Renamer("asinh"),
+        "ATan": Renamer("atan"),
+        "ATanh": Renamer("atanh"),
+        "Cos": Renamer("cos"),
+        "Cosh": Renamer("cosh"),
+        "Sin": Renamer("sin"),
+        "Sinh": Renamer("sinh"),
+        "Tan": Renamer("tan"),
+        "Tanh": Renamer("tanh"),
+        "Pow": Renamer("power"),
+        "PRelu": Prelu.get_converter(opset),
+        "Sigmoid": Renamer("sigmoid"),
+        "HardSigmoid": HardSigmoid.get_converter(opset),
+        "Max": Maximum.get_converter(opset),
+        "Min": Minimum.get_converter(opset),
+        "Sum": Sum.get_converter(opset),
+        "Mean": Mean.get_converter(opset),
+        "Clip": Clip.get_converter(opset),
         # softmax default axis is different in onnx
-        'Softmax': Softmax.get_converter(opset),
-        'LogSoftmax': AttrCvt('log_softmax', {'axis': ('axis', 1)}),
-        'OneHot': OneHot.get_converter(opset),
+        "Softmax": Softmax.get_converter(opset),
+        "LogSoftmax": AttrCvt("log_softmax", {"axis": ("axis", 1)}),
+        "OneHot": OneHot.get_converter(opset),
         # 'Hardmax'
-        'Softsign': Softsign.get_converter(opset),
-        'SoftPlus': SoftPlus.get_converter(opset),
-        'Gemm': Gemm.get_converter(opset),
-        'MatMul': MatMul.get_converter(opset),
-        'Mod': Mod.get_converter(opset),
-        'Xor': Renamer('logical_xor'),
-
+        "Softsign": Softsign.get_converter(opset),
+        "SoftPlus": SoftPlus.get_converter(opset),
+        "Gemm": Gemm.get_converter(opset),
+        "MatMul": MatMul.get_converter(opset),
+        "Mod": Mod.get_converter(opset),
+        "Xor": Renamer("logical_xor"),
         # defs/nn
-        'AveragePool': AveragePool.get_converter(opset),
-        'LpPool': LpPool.get_converter(opset),
-        'MaxPool': MaxPool.get_converter(opset),
-        'Conv': Conv.get_converter(opset),
-        'ConvTranspose': ConvTranspose.get_converter(opset),
-        'GlobalAveragePool': Renamer('global_avg_pool2d'),
-        'GlobalMaxPool': Renamer('global_max_pool2d'),
-        'BatchNormalization': BatchNorm.get_converter(opset),
-        'InstanceNormalization': InstanceNorm.get_converter(opset),
+        "AveragePool": AveragePool.get_converter(opset),
+        "LpPool": LpPool.get_converter(opset),
+        "MaxPool": MaxPool.get_converter(opset),
+        "Conv": Conv.get_converter(opset),
+        "ConvTranspose": ConvTranspose.get_converter(opset),
+        "GlobalAveragePool": Renamer("global_avg_pool2d"),
+        "GlobalMaxPool": Renamer("global_max_pool2d"),
+        "BatchNormalization": BatchNorm.get_converter(opset),
+        "InstanceNormalization": InstanceNorm.get_converter(opset),
         # 'LpNormalization'
-        'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
-        'Flatten': Flatten.get_converter(opset),
-        'LRN': LRN.get_converter(opset),
+        "Dropout": AttrCvt("dropout", {"ratio": "rate"}, ignores=["is_test"]),
+        "Flatten": Flatten.get_converter(opset),
+        "LRN": LRN.get_converter(opset),
         # Recurrent Layers
-        'LSTM': LSTM.get_converter(opset),
-        'GRU': GRU.get_converter(opset),
-
+        "LSTM": LSTM.get_converter(opset),
+        "GRU": GRU.get_converter(opset),
         # defs/vision
-        'MaxRoiPool': MaxRoiPool.get_converter(opset),
-        'RoiAlign': RoiAlign.get_converter(opset),
-
+        "MaxRoiPool": MaxRoiPool.get_converter(opset),
+        "RoiAlign": RoiAlign.get_converter(opset),
         # defs/reduction
-        'ReduceMax': ReduceMax.get_converter(opset),
-        'ReduceMin': ReduceMin.get_converter(opset),
-        'ReduceSum': ReduceSum.get_converter(opset),
-        'ReduceMean': ReduceMean.get_converter(opset),
-        'ReduceProd': ReduceProd.get_converter(opset),
-        'ReduceLogSumExp': ReduceLogSumExp.get_converter(opset),
-        'ReduceLogSum': ReduceLogSum.get_converter(opset),
-        'ReduceSumSquare': ReduceSumSquare.get_converter(opset),
-        'ReduceL1': ReduceL1.get_converter(opset),
-        'ReduceL2': ReduceL2.get_converter(opset),
-
-        #defs/sorting
-        'ArgMax': ArgMax.get_converter(opset),
-        'ArgMin': ArgMin.get_converter(opset),
-        'TopK': TopK.get_converter(opset),
-
+        "ReduceMax": ReduceMax.get_converter(opset),
+        "ReduceMin": ReduceMin.get_converter(opset),
+        "ReduceSum": ReduceSum.get_converter(opset),
+        "ReduceMean": ReduceMean.get_converter(opset),
+        "ReduceProd": ReduceProd.get_converter(opset),
+        "ReduceLogSumExp": ReduceLogSumExp.get_converter(opset),
+        "ReduceLogSum": ReduceLogSum.get_converter(opset),
+        "ReduceSumSquare": ReduceSumSquare.get_converter(opset),
+        "ReduceL1": ReduceL1.get_converter(opset),
+        "ReduceL2": ReduceL2.get_converter(opset),
+        # defs/sorting
+        "ArgMax": ArgMax.get_converter(opset),
+        "ArgMin": ArgMin.get_converter(opset),
+        "TopK": TopK.get_converter(opset),
         # defs/tensor
-        'Cast': Cast.get_converter(opset),
-        'Reshape': Reshape.get_converter(opset),
-        'Expand': Expand.get_converter(opset),
-        'Concat': Concat.get_converter(opset),
-        'Split': Split.get_converter(opset),
-        'Slice': Slice.get_converter(opset),
-        'Transpose': AttrCvt('transpose', {'perm': 'axes'}),
-        'DepthToSpace': DepthToSpace.get_converter(opset),
-        'SpaceToDepth': SpaceToDepth.get_converter(opset),
-        'Gather': Gather.get_converter(opset),
-        'GatherND': GatherND.get_converter(opset),
-        'Scatter': Scatter.get_converter(opset),
-        'ScatterElements': Scatter.get_converter(opset),
-        'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
-        'Unsqueeze': Unsqueeze.get_converter(opset),
-        'Pad': Pad.get_converter(opset),
-        'Shape': Shape.get_converter(opset),
-        'Sign': Sign.get_converter(opset),
-        'Equal': Equal.get_converter(opset),
-        'Not': Not.get_converter(opset),
-        'And': And.get_converter(opset),
-        'Tile': Tile.get_converter(opset),
-        'Erf': Erf.get_converter(opset),
-        'Where': Where.get_converter(opset),
-        'Or': Or.get_converter(opset),
-        'Resize': Resize.get_converter(opset),
-        'NonZero': NonZero.get_converter(opset),
+        "Cast": Cast.get_converter(opset),
+        "Reshape": Reshape.get_converter(opset),
+        "Expand": Expand.get_converter(opset),
+        "Concat": Concat.get_converter(opset),
+        "Split": Split.get_converter(opset),
+        "Slice": Slice.get_converter(opset),
+        "Transpose": AttrCvt("transpose", {"perm": "axes"}),
+        "DepthToSpace": DepthToSpace.get_converter(opset),
+        "SpaceToDepth": SpaceToDepth.get_converter(opset),
+        "Gather": Gather.get_converter(opset),
+        "GatherND": GatherND.get_converter(opset),
+        "Scatter": Scatter.get_converter(opset),
+        "ScatterElements": Scatter.get_converter(opset),
+        "Squeeze": AttrCvt("squeeze", {"axes": "axis"}),
+        "Unsqueeze": Unsqueeze.get_converter(opset),
+        "Pad": Pad.get_converter(opset),
+        "Shape": Shape.get_converter(opset),
+        "Sign": Sign.get_converter(opset),
+        "Equal": Equal.get_converter(opset),
+        "Not": Not.get_converter(opset),
+        "And": And.get_converter(opset),
+        "Tile": Tile.get_converter(opset),
+        "Erf": Erf.get_converter(opset),
+        "Where": Where.get_converter(opset),
+        "Or": Or.get_converter(opset),
+        "Resize": Resize.get_converter(opset),
+        "NonZero": NonZero.get_converter(opset),
     }
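The table above maps each ONNX op name to a converter callable. A minimal, self-contained sketch of that dispatch pattern is shown below; RenamerDemo is an illustrative stand-in, not the TVM Renamer class, and a real converter would emit Relay calls rather than strings.

# Illustrative stand-in for the Renamer entries above, not the TVM class itself.
class RenamerDemo:
    def __init__(self, new_name):
        self.new_name = new_name

    def __call__(self, inputs, attrs, params):
        # A real converter would build a Relay call node; here we just format a string.
        return "{}({})".format(self.new_name, ", ".join(inputs))


demo_map = {"Neg": RenamerDemo("negative"), "Pow": RenamerDemo("power")}
print(demo_map["Pow"](["x", "y"], {}, {}))  # -> power(x, y)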
 
+
 class GraphProto(ExprFunctor):
     """A helper class for handling Relay expression copying from pb2.GraphProto.
     Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
@@ -2091,7 +2056,7 @@ class GraphProto(ExprFunctor):
         self._shape = shape if shape else {}
         self._dtype = dtype
 
-        #For infering Values
+        # For infering Values
+        # For inferring Values
         self._tmp_params = {}
         self._infer_simulated = True
         self._mod = None
@@ -2118,12 +2083,9 @@ class GraphProto(ExprFunctor):
     def visit_function(self, fn):
         new_params = [self.visit(x) for x in fn.params]
         new_body = self.visit(fn.body)
-        return self.infer(Function(
-            list(new_params),
-            new_body,
-            fn.ret_type,
-            fn.type_params,
-            fn.attrs))
+        return self.infer(
+            Function(list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs)
+        )
 
     def visit_let(self, let):
         newvar = self.visit(let.var)
@@ -2146,10 +2108,9 @@ class GraphProto(ExprFunctor):
         return self.infer(global_var)
 
     def visit_if(self, ite):
-        return self.infer(If(
-            self.visit(ite.cond),
-            self.visit(ite.true_branch),
-            self.visit(ite.false_branch)))
+        return self.infer(
+            If(self.visit(ite.cond), self.visit(ite.true_branch), self.visit(ite.false_branch))
+        )
 
     def visit_tuple(self, tup):
         return Tuple([self.visit(field) for field in tup.fields])
@@ -2173,10 +2134,13 @@ class GraphProto(ExprFunctor):
         return con
 
     def visit_match(self, m):
-        return self.infer(Match(
-            self.visit(m.data),
-            [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses],
-            complete=m.complete))
+        return self.infer(
+            Match(
+                self.visit(m.data),
+                [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses],
+                complete=m.complete,
+            )
+        )
 
     def visit_ref_create(self, r):
         return RefCreate(self.visit(r.value))
@@ -2217,21 +2181,23 @@ class GraphProto(ExprFunctor):
             if not init_tensor.name.strip():
                 raise ValueError("Tensor's name is required.")
             self._params[init_tensor.name] = self._parse_array(init_tensor)
-            self._nodes[init_tensor.name] = new_var(init_tensor.name,
-                                                    shape=self._params[init_tensor.name].shape,
-                                                    dtype=self._params[init_tensor.name].dtype)
+            self._nodes[init_tensor.name] = new_var(
+                init_tensor.name,
+                shape=self._params[init_tensor.name].shape,
+                dtype=self._params[init_tensor.name].dtype,
+            )
         for i in graph.input:
             # from onnx v0.2, GraphProto.input has type ValueInfoProto,
             #  and the name is 'i.name'
             i_name = self._parse_value_proto(i)
-            d_type = self._parse_dtype(i, 'float32')
+            d_type = self._parse_dtype(i, "float32")
             if i_name in self._params:
                 # i is a param instead of input
                 self._num_param += 1
                 self._params[i_name] = self._params.pop(i_name)
-                self._nodes[i_name] = new_var(i_name,
-                                              shape=self._params[i_name].shape,
-                                              dtype=self._params[i_name].dtype)
+                self._nodes[i_name] = new_var(
+                    i_name, shape=self._params[i_name].shape, dtype=self._params[i_name].dtype
+                )
             else:
                 self._num_input += 1
                 if i_name in self._shape:
@@ -2248,13 +2214,15 @@ class GraphProto(ExprFunctor):
         unsupported_ops = set()
         for node in graph.node:
             op_name = node.op_type
-            if op_name not in convert_map and \
-               op_name != 'Constant' and \
-               op_name not in _identity_list:
+            if (
+                op_name not in convert_map
+                and op_name != "Constant"
+                and op_name not in _identity_list
+            ):
                 unsupported_ops.add(op_name)
         if unsupported_ops:
-            msg = 'The following operators are not supported for frontend ONNX: '
-            msg += ', '.join(unsupported_ops)
+            msg = "The following operators are not supported for frontend ONNX: "
+            msg += ", ".join(unsupported_ops)
             raise tvm.error.OpNotImplemented(msg)
         # construct nodes, nodes are stored as directed acyclic graph
         for node in graph.node:
@@ -2263,7 +2231,7 @@ class GraphProto(ExprFunctor):
             # Create and populate onnx input object.
             inputs = onnx_input()
             for i in node.input:
-                if i != '':
+                if i != "":
                     inputs[i] = self._nodes[self._renames.get(i, i)]
             if op_name == "Constant":
                 t_proto = self._parse_attr(node.attribute)["value"]
@@ -2272,13 +2240,12 @@ class GraphProto(ExprFunctor):
                 array = self._parse_array(t_proto)
                 self._params[node.output[0]] = array
                 self._nodes[node.output[0]] = new_var(
-                    node.output[0],
-                    shape=list(t_proto.dims),
-                    dtype=array.dtype)
+                    node.output[0], shape=list(t_proto.dims), dtype=array.dtype
+                )
             else:
                 i_name = self._parse_value_proto(node)
-                attr['tvm_custom'] = {}
-                attr['tvm_custom']['name'] = i_name
+                attr["tvm_custom"] = {}
+                attr["tvm_custom"]["name"] = i_name
 
                 op = self._convert_operator(op_name, inputs, attr, opset)
                 node_output = self._fix_outputs(op_name, node.output)
@@ -2286,9 +2253,11 @@ class GraphProto(ExprFunctor):
                     outputs_num = 1
                 else:
                     outputs_num = len(op)
-                assert len(node_output) == outputs_num, (
-                    "Number of output mismatch {} vs {} in {}.".format(
-                        len(node_output), outputs_num, op_name))
+                assert (
+                    len(node_output) == outputs_num
+                ), "Number of output mismatch {} vs {} in {}.".format(
+                    len(node_output), outputs_num, op_name
+                )
                 if outputs_num == 1:
                     self._nodes[node_output[0]] = op
                 else:
@@ -2313,6 +2282,7 @@ class GraphProto(ExprFunctor):
         """Parse dtype."""
         try:
             from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
+
             return TENSOR_TYPE_TO_NP_TYPE[value_proto.type.tensor_type.elem_type].name
         except AttributeError:
             return dtype
@@ -2325,37 +2295,31 @@ class GraphProto(ExprFunctor):
         """Convert a list of AttributeProto to a dict, with names as keys."""
         attrs = {}
         for a in attr_proto:
-            for f in ['f', 'i', 's']:
+            for f in ["f", "i", "s"]:
                 if a.HasField(f):
                     attrs[a.name] = getattr(a, f)
-            for f in ['floats', 'ints', 'strings']:
+            for f in ["floats", "ints", "strings"]:
                 if list(getattr(a, f)):
                     assert a.name not in attrs, "Only one type of attr is allowed"
                     attrs[a.name] = tuple(getattr(a, f))
-            for f in ['t']:
+            for f in ["t"]:
                 if a.HasField(f):
                     attrs[a.name] = getattr(a, f)
-            for f in ['tensors']:
+            for f in ["tensors"]:
                 if list(getattr(a, f)):
                     assert a.name not in attrs, "Only one type of attr is allowed"
                     attrs[a.name] = tuple(getattr(a, f))
-            for f in ['g']:
+            for f in ["g"]:
                 if a.HasField(f):
-                    raise NotImplementedError(
-                        "Filed {} is not supported in relay.".format(f))
-            for f in ['graphs']:
+                    raise NotImplementedError("Filed {} is not supported in relay.".format(f))
+            for f in ["graphs"]:
                 if list(getattr(a, f)):
-                    raise NotImplementedError(
-                        "Filed {} is not supported in relay.".format(f))
+                    raise NotImplementedError("Filed {} is not supported in relay.".format(f))
             if a.name not in attrs:
                 raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
         return attrs
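For reference, a small sketch of the AttributeProto fields that _parse_attr inspects above (requires the onnx package; the attribute names "axis" and "scales" are made up for illustration).

from onnx import helper

a = helper.make_attribute("axis", 1)             # integer -> stored in the "i" field
b = helper.make_attribute("scales", [1.0, 2.0])  # float list -> stored in the "floats" field
print(a.HasField("i"), a.i)   # True 1
print(list(b.floats))         # [1.0, 2.0]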
 
-    def _convert_operator(self,
-                          op_name,
-                          inputs,
-                          attrs,
-                          opset):
+    def _convert_operator(self, op_name, inputs, attrs, opset):
         """Convert ONNX operator into a Relay operator.
         The converter must specify conversions explicitly for incompatible names, and
         apply handlers to operator attributes.
@@ -2382,25 +2346,22 @@ class GraphProto(ExprFunctor):
         elif op_name in convert_map:
             sym = convert_map[op_name](inputs, attrs, self._params)
         else:
-            raise NotImplementedError(
-                "Operator {} not implemented.".format(op_name))
+            raise NotImplementedError("Operator {} not implemented.".format(op_name))
         return sym
 
     def _fix_outputs(self, op_name, outputs):
         """A hack to handle dropout or similar operator that have more than one out
         in ONNX.
         """
-        if op_name == 'Dropout':
+        if op_name == "Dropout":
             if len(outputs) == 1:
                 return outputs
             # TODO(zhreshold): support dropout mask?
             outputs = outputs[:-1]
         return outputs
 
-def from_onnx(model,
-              shape=None,
-              dtype="float32",
-              opset=None):
+
+def from_onnx(model, shape=None, dtype="float32", opset=None):
     """Convert a ONNX model into an equivalent Relay Function.
 
     ONNX graphs are represented as Python Protobuf objects.
@@ -2435,12 +2396,14 @@ def from_onnx(model,
     """
     try:
         import onnx
-        if hasattr(onnx.checker, 'check_model'):
+
+        if hasattr(onnx.checker, "check_model"):
             # try use onnx's own model checker before converting any model
             try:
                 onnx.checker.check_model(model)
             except onnx.onnx_cpp2py_export.checker.ValidationError as e:
                 import warnings
+
                 # the checker is a bit violent about errors, so simply print warnings here
                 warnings.warn(str(e))
     except ImportError:
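End to end, the reformatted entry point is typically driven as in the hedged sketch below; the file name and the input name/shape are placeholders, not part of this patch, and onnx plus tvm are assumed to be installed.

import onnx
from tvm import relay

onnx_model = onnx.load("model.onnx")          # placeholder path
shape_dict = {"input_0": (1, 3, 224, 224)}    # placeholder input name and shape
mod, params = relay.frontend.from_onnx(onnx_model, shape=shape_dict)
print(mod)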
index 19cbf75..33ce58f 100644 (file)
@@ -64,7 +64,7 @@ def _convert_to_list_adt(py_lst, prelude):
 def _map_tensor_array_constructor(adt_lst, prelude, shape):
     static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape)
     static_tensor_array_ops.register()
-    tensor_create = prelude.get_var_static('tensor_constructor', "float32", shape)
+    tensor_create = prelude.get_var_static("tensor_constructor", "float32", shape)
     return prelude.map(tensor_create, adt_lst)
 
 
@@ -129,6 +129,7 @@ def _elemwise(name):
     def _impl(inputs, input_types):
         data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
         return get_relay_op(name)(data0, data1)
+
     return _impl
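The _elemwise(name) helper is a converter factory: it captures name and returns an _impl closure. A self-contained sketch of the same pattern follows, with Python's operator module standing in for get_relay_op and type promotion omitted.

import operator

def make_elemwise(name):
    ops = {"add": operator.add, "multiply": operator.mul}

    def _impl(inputs, input_types):
        # promotion is handled by _pytorch_promote_types in the real converter
        data0, data1 = inputs[:2]
        return ops[name](data0, data1)

    return _impl

print(make_elemwise("add")([2, 3], ["int32", "int32"]))  # -> 5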
 
 
@@ -146,6 +147,7 @@ def _min_max_common(name_elemwise, name_reduce):
         else:
             data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
             return get_relay_op(name_elemwise)(data0, data1)
+
     return _impl
 
 
@@ -161,17 +163,19 @@ def _unary(name):
     def _impl(inputs, input_types):
         input_type = input_types[0]
         # this is just to ensure tensor input
-        data, = _pytorch_promote_types(inputs[:1], input_types[:1])
+        (data,) = _pytorch_promote_types(inputs[:1], input_types[:1])
         return get_relay_op(name)(data)
+
     return _impl
 
 
 def _log1p():
     def _impl(inputs, input_types):
         # 1_plus_log x = log(x + 1)
-        dtype, = input_types
+        (dtype,) = input_types
         one = _expr.const(1, dtype=dtype)
         return _op.log(inputs[0] + one)
+
     return _impl
 
 
@@ -219,12 +223,11 @@ def _arange():
             msg = "Unknown number of arguments (%d) to parse." % (len(inputs))
             raise AssertionError(msg)
 
-        return _op.transform.arange(start=start,
-                                    stop=stop,
-                                    step=step,
-                                    dtype=dtype)
+        return _op.transform.arange(start=start, stop=stop, step=step, dtype=dtype)
+
     return _impl
 
+
 def _squeeze():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -235,14 +238,17 @@ def _squeeze():
             axis = [int(inputs[1])]
 
         return _op.transform.squeeze(data, axis)
+
     return _impl
 
+
 def _unsqueeze():
     def _impl(inputs, input_types):
         data = inputs[0]
         axis = inputs[1]
 
         return _op.transform.expand_dims(data, int(axis), 1)
+
     return _impl
 
 
@@ -251,12 +257,12 @@ def _concatenate(prelude):
         assert axis == 0, "Tensor array concat supported only for axis 0"
         tensor_array, shape = _convert_to_tensor_array(lst, prelude)
         concat_shape = (Any(),) + shape[1:]
-        concat = prelude.get_var_static('tensor_array_concat', "float32", shape)
+        concat = prelude.get_var_static("tensor_array_concat", "float32", shape)
         concatenated = concat(tensor_array)
 
         static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", concat_shape)
         static_tensor_array_ops.register()
-        get_tensor = prelude.get_var_static('tensor_get_data', "float32", concat_shape)
+        get_tensor = prelude.get_var_static("tensor_get_data", "float32", concat_shape)
         return get_tensor(concatenated)
 
     def _impl(inputs, input_types):
@@ -270,8 +276,10 @@ def _concatenate(prelude):
             data = [data]
 
         return _op.tensor.concatenate(data, int(axis))
+
     return _impl
 
+
 def _slice():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -308,13 +316,13 @@ def _slice():
         strides = [1] * len(end)
         strides[dim] = int(inputs[4])
 
-        return _op.transform.strided_slice(data,
-                                           begin=begin,
-                                           end=end,
-                                           strides=strides,
-                                           slice_mode="end")
+        return _op.transform.strided_slice(
+            data, begin=begin, end=end, strides=strides, slice_mode="end"
+        )
+
     return _impl
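The begin/end/strides values built above follow ordinary slicing semantics; a NumPy analogue of the single-axis case with purely illustrative values:

import numpy as np

data = np.arange(10)
begin, end, strides = 2, 8, 3
print(data[begin:end:strides])  # [2 5]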
 
+
 def _split():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -328,8 +336,10 @@ def _split():
             split_index += split_size
 
         return _op.split(data, indices, dim)
+
     return _impl
 
+
 def _split_with_sizes():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -343,24 +353,30 @@ def _split_with_sizes():
             indices.append(split_index)
 
         return _op.split(data, indices, dim)
+
     return _impl
 
+
 def _select():
     def _impl(inputs, input_types):
         data = inputs[0]
         dim = int(inputs[1])
         index = _wrap_const(inputs[2])
         return _op.transform.take(data, index, axis=dim)
+
     return _impl
 
+
 def _take():
     def _impl(inputs, input_types):
         data = inputs[0]
         indices = _op.cast(inputs[1], "int32")
 
         return _op.transform.take(data, indices=indices)
+
     return _impl
 
+
 def _topk():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -376,21 +392,27 @@ def _topk():
         outs = _op.topk(data, k=k, axis=axis, is_ascend=is_ascend, ret_type="both")
 
         return outs[0], outs[1]
+
     return _impl
 
+
 def _reciprocal():
     def _impl(inputs, input_types):
         data = inputs[0]
         return _expr.const(1.0, dtype=input_types[0]) / data
+
     return _impl
 
+
 def _repeat():
     def _impl(inputs, input_types):
         data = inputs[0]
         reps = _get_dims(inputs[1])
         return _op.transform.tile(data, reps=reps)
+
     return _impl
 
+
 def _repeat_interleave():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -400,10 +422,11 @@ def _repeat_interleave():
         else:
             msg = "Only repeat with one value as repeat is currently supported."
             raise AssertionError(msg)
-        if axis is None: # Flatten the data if no axis is given from torch
+        if axis is None:  # Flatten the data if no axis is given from torch
             data = _op.transform.reshape(data, [-1])
             axis = 0
         return _op.transform.repeat(data, repeats=repeats, axis=axis)
+
     return _impl
 
 
@@ -411,6 +434,7 @@ def _addcdiv():
     def _impl(inputs, input_types):
         data, t1, t2, c = _pytorch_promote_types(inputs[:4], input_types[:4])
         return data + (c * (t1 / t2))
+
     return _impl
 
 
@@ -418,6 +442,7 @@ def _addcmul():
     def _impl(inputs, input_types):
         data, t1, t2, c = _pytorch_promote_types(inputs[:4], input_types[:4])
         return data + (c * (t1 * t2))
+
     return _impl
 
 
@@ -435,6 +460,7 @@ def _ones():
         data = inputs[0]
 
         import torch
+
         if isinstance(data, _expr.Expr):
             shape = _infer_shape(data)
         elif isinstance(data, list):
@@ -448,8 +474,10 @@ def _ones():
         dtype = _convert_dtype_value(inputs[1])
 
         return _op.full(_expr.const(1), shape, dtype=dtype)
+
     return _impl
 
+
 def _ones_like():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -461,6 +489,7 @@ def _ones_like():
             out = _op.cast(out, dtype)
 
         return out
+
     return _impl
 
 
@@ -469,6 +498,7 @@ def _zeros():
         data = inputs[0]
 
         import torch
+
         if isinstance(data, _expr.Expr):
             shape = _infer_shape(data)
         elif isinstance(data, list):
@@ -482,6 +512,7 @@ def _zeros():
         dtype = _convert_dtype_value(inputs[1])
 
         return _op.full(_expr.const(0), shape, dtype=dtype)
+
     return _impl
 
 
@@ -496,6 +527,7 @@ def _zeros_like():
             out = _op.cast(out, dtype)
 
         return out
+
     return _impl
 
 
@@ -505,6 +537,7 @@ def _full(default_dtype):
 
         fill_value = inputs[1]
         import torch
+
         if isinstance(data, _expr.Expr):
             shape = _infer_shape(data)
         elif isinstance(data, list):
@@ -515,15 +548,17 @@ def _full(default_dtype):
             msg = "Data type %s could not be parsed in zeros op" % (type(data))
             raise AssertionError(msg)
 
-        if inputs[2] is not None: # dtype given
+        if inputs[2] is not None:  # dtype given
             dtype = _convert_dtype_value(inputs[2])
         else:
             # if dtype is None, torch uses a global default set by torch.set_default_tensor_type()
             dtype = default_dtype
 
         return _op.full(_expr.const(fill_value), shape, dtype=dtype)
+
     return _impl
 
+
 def _full_like():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -537,6 +572,7 @@ def _full_like():
             out = _op.cast(out, dtype)
 
         return out
+
     return _impl
 
 
@@ -553,16 +589,13 @@ def _linspace():
         else:
             stop = start + step
 
-        dtype = ("float32" if inputs[3] is not None
-                 else _convert_dtype_value(inputs[3]))
+        dtype = "float32" if inputs[3] is not None else _convert_dtype_value(inputs[3])
         start = _create_typed_const(start, dtype)
         stop = _create_typed_const(stop, dtype)
         step = _create_typed_const(step, dtype)
 
-        return _op.transform.arange(start=start,
-                                    stop=stop,
-                                    step=step,
-                                    dtype=dtype)
+        return _op.transform.arange(start=start, stop=stop, step=step, dtype=dtype)
+
     return _impl
 
 
@@ -574,39 +607,50 @@ def _relu(prelude):
             input_zero_point = _expr.const(inputs[2], dtype="int32")
             return qnn_torch.quantized_relu(data, input_zero_point)
         return _op.nn.relu(data)
+
     return _impl
 
+
 def _prelu():
     def _impl(inputs, input_types):
         data = inputs[0]
         alpha = inputs[1]
         return _op.nn.prelu(data, alpha)
+
     return _impl
 
+
 def _leaky_relu():
     def _impl(inputs, input_types):
         data = inputs[0]
         alpha = float(inputs[1])
         return _op.nn.leaky_relu(data, alpha)
+
     return _impl
 
+
 def _elu():
     def _impl(inputs, input_types):
         data = inputs[0]
         dtype = input_types[0]
         alpha = _expr.const(float(inputs[1]), dtype=dtype)
         return alpha * _op.nn.relu(_expr.const(1, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data)
+
     return _impl
 
+
 def _celu():
     def _impl(inputs, input_types):
         data = inputs[0]
         dtype = input_types[0]
         alpha = _expr.const(float(inputs[1]), dtype=dtype)
-        return alpha * _op.nn.relu(_expr.const(1, dtype=dtype)
-                                   - _op.exp(data / alpha)) + _op.nn.relu(data)
+        return alpha * _op.nn.relu(
+            _expr.const(1, dtype=dtype) - _op.exp(data / alpha)
+        ) + _op.nn.relu(data)
+
     return _impl
 
+
 def _gelu():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -615,11 +659,14 @@ def _gelu():
         # normcdf expressed as erf because we don't currently have that intrinsic
         # note that there is also a fastgelu variant approximating normcdf
         # with tanh and third order polynomials, but this is "true" gelu
-        return data * (_expr.const(0.5, dtype=dtype) +
-                       _op.erf(data * _expr.const(0.5**0.5, dtype=dtype))
-                       * _expr.const(0.5, dtype=dtype))
+        return data * (
+            _expr.const(0.5, dtype=dtype)
+            + _op.erf(data * _expr.const(0.5 ** 0.5, dtype=dtype)) * _expr.const(0.5, dtype=dtype)
+        )
+
     return _impl
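The expression above is the exact ("true") GELU, x * Phi(x), with the normal CDF written via erf. A scalar reference sketch using only the standard library:

import math

def gelu_ref(x):
    # x * Phi(x), with Phi expressed through erf as in the converter above
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

print(round(gelu_ref(1.0), 6))  # ~0.841345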
 
+
 def _selu():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -627,16 +674,21 @@ def _selu():
         dtype = input_types[0]
         alpha = _expr.const(-1.6732632423543772848170429916717, dtype=dtype)
         gamma = _expr.const(1.0507009873554804934193349852946, dtype=dtype)
-        return gamma * (alpha * _op.nn.relu(_expr.const(1.0, dtype=dtype)
-                                            - _op.exp(data)) + _op.nn.relu(data))
+        return gamma * (
+            alpha * _op.nn.relu(_expr.const(1.0, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data)
+        )
+
     return _impl
 
+
 def _log_sigmoid():
     def _impl(inputs, input_types):
         data = inputs[0]
         return _op.log(_op.tensor.sigmoid(data))
+
     return _impl
 
+
 def _adaptive_avg_pool_2d(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -652,17 +704,18 @@ def _adaptive_avg_pool_2d(prelude):
 
     return _impl
 
+
 def _adaptive_max_pool_2d():
     def _impl(inputs, input_types):
         data = inputs[0]
         output_size = inputs[1]
 
         # returns dummy indices too
-        return _op.nn.adaptive_max_pool2d(
-            data,
-            output_size=output_size), None
+        return _op.nn.adaptive_max_pool2d(data, output_size=output_size), None
+
     return _impl
 
+
 def _adaptive_max_pool_3d():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -672,6 +725,7 @@ def _adaptive_max_pool_3d():
 
     return _impl
 
+
 def _adaptive_avg_pool_3d():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -680,6 +734,7 @@ def _adaptive_avg_pool_3d():
 
     return _impl
 
+
 def _maxpool_2d():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -695,14 +750,18 @@ def _maxpool_2d():
             raise NotImplementedError(msg)
 
         return _op.nn.max_pool2d(data, pool_size, strides, padding, "NCHW", ceil_mode)
+
     return _impl
 
+
 def _maxpool_2d_with_indices():
     def _impl(inputs, input_types):
         # returns dummy indices too
         return _maxpool_2d()(inputs, input_types), None
+
     return _impl
 
+
 def _maxpool_1d():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -718,8 +777,10 @@ def _maxpool_1d():
             raise NotImplementedError(msg)
 
         return _op.nn.max_pool1d(data, pool_size, strides, padding, "NCW", ceil_mode)
+
     return _impl
 
+
 def _maxpool_3d():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -733,21 +794,23 @@ def _maxpool_3d():
             msg = "MaxPool3d with dilation %s is not implemented" % (str(dilation))
             raise NotImplementedError(msg)
 
-        return _op.nn.max_pool3d(data,
-                                 pool_size=pool_size,
-                                 strides=strides,
-                                 padding=padding,
-                                 ceil_mode=ceil_mode)
+        return _op.nn.max_pool3d(
+            data, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode
+        )
+
     return _impl
 
+
 def _hardtanh():
     def _impl(inputs, input_types):
         a = inputs[0]
         tanh_min = float(inputs[1])
         tanh_max = float(inputs[2])
         return _op.tensor.clip(a, tanh_min, tanh_max)
+
     return _impl
 
+
 def _convolution():
     def _impl(inputs, input_types):
         # Use transpose or normal
@@ -790,9 +853,9 @@ def _convolution():
         use_bias = isinstance(bias, _expr.Expr)
 
         if len(kernel_size) == 1:
-            strides = (1, ) + strides
-            padding = (0, ) + padding
-            dilation = (1, ) + dilation
+            strides = (1,) + strides
+            padding = (0,) + padding
+            dilation = (1,) + dilation
 
         if use_transpose:
             if len(kernel_size) == 3:
@@ -816,20 +879,20 @@ def _convolution():
             data = _op.expand_dims(data, axis=2)
             weight = _op.expand_dims(weight, axis=2)
 
-        conv_out = conv_op(data,
-                           weight,
-                           strides=strides,
-                           padding=padding,
-                           dilation=dilation,
-                           groups=groups,
-                           channels=channels,
-                           kernel_size=[1] + kernel_size \
-                                        if len(kernel_size) == 1 \
-                                        else kernel_size,
-                           data_layout=data_layout,
-                           kernel_layout=kernel_layout,
-                           out_layout="",
-                           out_dtype="")
+        conv_out = conv_op(
+            data,
+            weight,
+            strides=strides,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            channels=channels,
+            kernel_size=[1] + kernel_size if len(kernel_size) == 1 else kernel_size,
+            data_layout=data_layout,
+            kernel_layout=kernel_layout,
+            out_layout="",
+            out_dtype="",
+        )
         if use_bias:
             res = _op.nn.bias_add(conv_out, bias)
         else:
@@ -840,6 +903,7 @@ def _convolution():
 
     return _impl
 
+
 def _softmax():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -848,20 +912,26 @@ def _softmax():
             axis = int(axis)
 
         return _op.nn.softmax(data, axis=axis)
+
     return _impl
 
+
 def _threshold():
     def _impl(inputs, input_types):
         data = inputs[0]
         return _op.nn.relu(data)
+
     return _impl
 
+
 def _contiguous():
     def _impl(inputs, input_types):
         data = inputs[0]
         return _op.tensor.copy(data)
+
     return _impl
 
+
 def _batch_norm():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -887,17 +957,21 @@ def _batch_norm():
         moving_var = inputs[4]
         epsilon = float(inputs[7])
 
-        return _op.nn.batch_norm(data,
-                                 gamma,
-                                 beta,
-                                 moving_mean,
-                                 moving_var,
-                                 axis=1,
-                                 epsilon=epsilon,
-                                 center=center,
-                                 scale=scale)[0]
+        return _op.nn.batch_norm(
+            data,
+            gamma,
+            beta,
+            moving_mean,
+            moving_var,
+            axis=1,
+            epsilon=epsilon,
+            center=center,
+            scale=scale,
+        )[0]
+
     return _impl
 
+
 def _instance_norm():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -919,17 +993,16 @@ def _instance_norm():
             beta = _create_typed_const(np.zeros([int(channels[1])]), data_type)
 
         epsilon = float(inputs[7])
-        return _op.nn.instance_norm(data,
-                                    gamma,
-                                    beta,
-                                    axis=1,
-                                    epsilon=epsilon,
-                                    center=center,
-                                    scale=scale)
+        return _op.nn.instance_norm(
+            data, gamma, beta, axis=1, epsilon=epsilon, center=center, scale=scale
+        )
+
     return _impl
 
+
 def _get_dims(data):
     import torch
+
     if isinstance(data, _expr.Expr):
         dims = _infer_shape(data)
     elif isinstance(data, list):
@@ -941,19 +1014,23 @@ def _get_dims(data):
         raise AssertionError(msg)
     return dims
 
+
 def _layer_norm():
     def _impl(inputs, input_types):
         data = inputs[0]
         ndims = len(_get_dims(inputs[1]))
         assert ndims == 1, "Support only normalization over last one dimension."
 
-        return _op.nn.layer_norm(data,
-                                 gamma=inputs[2],
-                                 beta=inputs[3],
-                                 axis=-1,
-                                 epsilon=float(inputs[4]),
-                                 center=True,
-                                 scale=True)
+        return _op.nn.layer_norm(
+            data,
+            gamma=inputs[2],
+            beta=inputs[3],
+            axis=-1,
+            epsilon=float(inputs[4]),
+            center=True,
+            scale=True,
+        )
+
     return _impl
 
 
@@ -965,14 +1042,17 @@ def _group_norm():
         num_groups = inputs[1]
         epsilon = float(inputs[4])
 
-        return _op.nn.group_norm(data,
-                                 gamma=gamma,
-                                 beta=beta,
-                                 num_groups=num_groups,
-                                 axis=1,
-                                 epsilon=epsilon,
-                                 center=True,
-                                 scale=True)
+        return _op.nn.group_norm(
+            data,
+            gamma=gamma,
+            beta=beta,
+            num_groups=num_groups,
+            axis=1,
+            epsilon=epsilon,
+            center=True,
+            scale=True,
+        )
+
     return _impl
 
 
@@ -981,6 +1061,7 @@ def _transpose(prelude):
         data = inputs[0]
 
         import torch
+
         if isinstance(data, _expr.Expr):
             ndims = len(_infer_shape(data, prelude.mod))
         elif isinstance(data, list):
@@ -1012,6 +1093,7 @@ def _transpose(prelude):
         else:
             axes = inputs[1]
         return _op.transform.transpose(data, axes)
+
     return _impl
 
 
@@ -1060,6 +1142,7 @@ def _dense():
             return _op.nn.bias_add(dense_out, bias)
         else:
             return dense_out
+
     return _impl
 
 
@@ -1083,6 +1166,7 @@ def _size(prelude):
         if axis is not None:
             return shape[axis]
         return shape
+
     return _impl
 
 
@@ -1100,12 +1184,14 @@ def _numtotensor():
 
         arr = val * np.ones([]).astype(dtype)
         return arr
+
     return _impl
 
 
 def _tensortonum():
     def _impl(inputs, input_types):
         return inputs[0]
+
     return _impl
 
 
@@ -1127,6 +1213,7 @@ def _view():
                 new_shape[i] = np.asscalar(val.asnumpy())
 
         return _op.transform.reshape(data, new_shape)
+
     return _impl
 
 
@@ -1138,38 +1225,47 @@ def _reshape():
         else:
             assert isinstance(inputs[1], list)
             infer_res = [_infer_value(_wrap_const(size), {}) for size in inputs[1]]
-            new_shape = [np.asscalar(res.asnumpy().astype(np.int))
-                         for res in infer_res]
+            new_shape = [np.asscalar(res.asnumpy().astype(np.int)) for res in infer_res]
         return _op.transform.reshape(data, new_shape)
+
     return _impl
 
+
 def _clone():
     def _impl(inputs, input_types):
         data = inputs[0]
         return _op.tensor.copy(data)
+
     return _impl
 
+
 def _log_softmax():
     def _impl(inputs, input_types):
         data = inputs[0]
         axis = int(inputs[1])
         return _op.nn.log_softmax(data, axis)
+
     return _impl
 
+
 def _sigmoid():
     def _impl(inputs, input_types):
         data = inputs[0]
         return _op.tensor.sigmoid(data)
+
     return _impl
 
+
 def _softplus():
     def _impl(inputs, input_types):
         data = inputs[0]
         dtype = input_types[0]
         beta = _expr.const(float(inputs[1]), dtype=dtype)
-        return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1., dtype=dtype)) / beta
+        return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.0, dtype=dtype)) / beta
+
     return _impl
 
+
 def _avg_pool2d(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1181,12 +1277,14 @@ def _avg_pool2d(prelude):
         count_include_pad = int(inputs[5])
 
         def func(x):
-            return _op.nn.avg_pool2d(x,
-                                     pool_size=pool_size,
-                                     strides=strides,
-                                     padding=padding,
-                                     ceil_mode=ceil_mode,
-                                     count_include_pad=count_include_pad)
+            return _op.nn.avg_pool2d(
+                x,
+                pool_size=pool_size,
+                strides=strides,
+                padding=padding,
+                ceil_mode=ceil_mode,
+                count_include_pad=count_include_pad,
+            )
 
         if _is_quantized_tensor(data, prelude):
             return qnn_torch.apply_with_upcast(data, func)
@@ -1195,6 +1293,7 @@ def _avg_pool2d(prelude):
 
     return _impl
 
+
 def _avg_pool3d():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1205,29 +1304,35 @@ def _avg_pool3d():
         ceil_mode = int(inputs[4])
         count_include_pad = int(inputs[5])
 
-        return _op.nn.avg_pool3d(data,
-                                 pool_size=pool_size,
-                                 strides=strides,
-                                 padding=padding,
-                                 ceil_mode=ceil_mode,
-                                 count_include_pad=count_include_pad)
+        return _op.nn.avg_pool3d(
+            data,
+            pool_size=pool_size,
+            strides=strides,
+            padding=padding,
+            ceil_mode=ceil_mode,
+            count_include_pad=count_include_pad,
+        )
+
     return _impl
 
+
 def _dropout():
     def _impl(inputs, input_types):
         data = inputs[0]
         rate = float(inputs[1])
 
         return _op.nn.dropout(data, rate)
+
     return _impl
 
+
 def _reduce(name):
     def _impl(inputs, input_types):
         data = inputs[0]
         axis = None
         keepdims = False
 
-        if len(inputs) > 2: # default, torch have only data, axis=None, keepdims=False
+        if len(inputs) > 2:  # default, torch have only data, axis=None, keepdims=False
             if isinstance(inputs[1], int):
                 axis = int(inputs[1])
             elif _is_int_seq(inputs[1]):
@@ -1240,6 +1345,7 @@ def _reduce(name):
 
     return _impl
 
+
 def _norm():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1258,10 +1364,11 @@ def _norm():
         else:
             reci_order = _expr.const(1.0 / order, dtype=dtype)
             order = _expr.const(order)
-            return _op.power(_op.reduce.sum(_op.power(_op.abs(data), order),
-                                            axis=axis,
-                                            keepdims=keepdims),
-                             reci_order)
+            return _op.power(
+                _op.reduce.sum(_op.power(_op.abs(data), order), axis=axis, keepdims=keepdims),
+                reci_order,
+            )
+
     return _impl
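For a non-Frobenius order, the branch above computes (sum |x|^p)^(1/p); a NumPy cross-check with illustrative values:

import numpy as np

x = np.array([3.0, -4.0])
order = 3
print(np.power(np.sum(np.power(np.abs(x), order)), 1.0 / order))  # (27 + 64) ** (1/3) ≈ 4.498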
 
 
@@ -1295,6 +1402,7 @@ def _std():
 
     return _impl
 
+
 def _variance():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1311,6 +1419,7 @@ def _variance():
 
     return _impl
 
+
 def _mean(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1336,13 +1445,13 @@ def _mean(prelude):
             assert len(inputs) == 6, "Input quant param not found in op inputs"
             input_scale = _expr.const(inputs[4])
             input_zero_point = _expr.const(inputs[5])
-            return qnn_torch.quantized_mean(data, input_scale,
-                                            input_zero_point, func)
+            return qnn_torch.quantized_mean(data, input_scale, input_zero_point, func)
 
         return func(data)
 
     return _impl
 
+
 def _chunk(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1372,10 +1481,7 @@ def _chunk(prelude):
             end[axis] = i + unif_size
             stride = [1] * len(shape)
 
-            chunk_out = _op.transform.strided_slice(data,
-                                                    begin=begin,
-                                                    end=end,
-                                                    strides=stride)
+            chunk_out = _op.transform.strided_slice(data, begin=begin, end=end, strides=stride)
             chunks.append(chunk_out)
 
         if dim % num_chunks:
@@ -1385,15 +1491,14 @@ def _chunk(prelude):
             end[axis] = dim
             stride = [1] * len(shape)
 
-            chunk_out = _op.transform.strided_slice(data,
-                                                    begin=begin,
-                                                    end=end,
-                                                    strides=stride)
+            chunk_out = _op.transform.strided_slice(data, begin=begin, end=end, strides=stride)
             chunks.append(chunk_out)
 
         return chunks
+
     return _impl
 
+
 def _matmul(prelude):
     def _impl(inputs, input_types):
 
@@ -1464,6 +1569,7 @@ def _expand():
             out = _op.tensor.concatenate(data, i)
 
         return out
+
     return _impl
 
 
@@ -1472,18 +1578,24 @@ def _int():
         if isinstance(inputs[0], _expr.Expr):
             return inputs[0]
         return int(inputs[0])
+
     return _impl
 
+
 def _identity():
     def _impl(inputs, input_types):
         return inputs[0]
+
     return _impl
 
+
 def _none():
     def _impl(inputs, input_types):
         return None
+
     return _impl
 
+
 def _pad(mode):
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1507,7 +1619,7 @@ def _pad(mode):
             paddings[-6] = pad_list[4]
 
         # group into tuple of 2 ints
-        paddings = [paddings[i:i + 2] for i in range(0, len(paddings), 2)]
+        paddings = [paddings[i : i + 2] for i in range(0, len(paddings), 2)]
 
         if mode == "constant":
             return _op.nn.pad(data, paddings, pad_value=inputs[2], pad_mode=mode)
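The slice expression reformatted above groups a flat PyTorch pad list into per-axis (before, after) pairs; a tiny standalone illustration of the same idiom:

pad_list = [1, 2, 3, 4]
pairs = [pad_list[i : i + 2] for i in range(0, len(pad_list), 2)]
print(pairs)  # [[1, 2], [3, 4]]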
@@ -1523,6 +1635,7 @@ def _clamp():
         amin = inputs[1] if inputs[1] else np.finfo(np.float32).min
         amax = inputs[2] if inputs[2] else np.finfo(np.float32).max
         return _op.clip(data, amin, amax)
+
     return _impl
 
 
@@ -1534,12 +1647,7 @@ def _to():
         # special handling for aten::to(data, 6, _, _, _) case
         # 6 means dtype = float
         # this happens when converting upsampling with scale factor
-        cast_func = {
-            6: float,
-            7: float,
-            3: int,
-            4: int
-        }
+        cast_func = {6: float, 7: float, 3: int, 4: int}
         cast_func_expr = {
             6: lambda x: _op.cast(x, "float32"),
             7: lambda x: _op.cast(x, "float64"),
@@ -1554,6 +1662,7 @@ def _to():
 
     return _impl
 
+
 def _upsample(method, prelude):
     def _impl(inputs, input_types):
         if isinstance(inputs[1], _expr.Var):
@@ -1562,8 +1671,7 @@ def _upsample(method, prelude):
             out_size = inputs[1]
         elif isinstance(inputs[1], list):
             infer_res = [_infer_value(size, {}) for size in inputs[1]]
-            out_size = [np.asscalar(res.asnumpy().astype(np.int))
-                        for res in infer_res]
+            out_size = [np.asscalar(res.asnumpy().astype(np.int)) for res in infer_res]
 
         data = inputs[0]
 
@@ -1596,8 +1704,7 @@ def _upsample(method, prelude):
 
             input_scale = _expr.const(inputs[-2])
             input_zero_point = _expr.const(inputs[-1])
-            return qnn_torch.quantized_upsample(data, input_scale,
-                                                input_zero_point, func)
+            return qnn_torch.quantized_upsample(data, input_scale, input_zero_point, func)
         return func(data)
 
     return _impl
@@ -1611,8 +1718,7 @@ def _upsample3d(method):
             out_size = inputs[1]
         elif isinstance(inputs[1], list):
             infer_res = [_infer_value(size, {}) for size in inputs[1]]
-            out_size = [np.asscalar(res.asnumpy().astype(np.int))
-                        for res in infer_res]
+            out_size = [np.asscalar(res.asnumpy().astype(np.int)) for res in infer_res]
 
         data = inputs[0]
 
@@ -1629,6 +1735,7 @@ def _upsample3d(method):
             coord_trans = "half_pixel"
 
         return _op.image.resize3d(data, out_size, "NCDHW", method, coord_trans)
+
     return _impl
 
 
@@ -1639,6 +1746,7 @@ def _expand_as():
         msg = "aten::expand_as(...) found, assume it is part of broadcast op"
         logging.warning(msg)
         return inputs[0]
+
     return _impl
 
 
@@ -1646,18 +1754,22 @@ def _Bool():
     def _impl(inputs, input_types):
         assert len(inputs) == 1
         return inputs[0]
+
     return _impl
 
+
 def _Float():
     def _impl(inputs, input_types):
         assert len(inputs) == 1
         return _op.cast(inputs[0], "float32")
+
     return _impl
 
 
 def _mm():
     def _impl(inputs, input_types):
         return _op.nn.dense(inputs[0], inputs[1])
+
     return _impl
 
 
@@ -1672,6 +1784,7 @@ def _bitwise_not():
             out = _op.bitwise_not(_op.cast(data, "int"))
 
         return out
+
     return _impl
 
 
@@ -1683,6 +1796,7 @@ def _bitwise_xor():
         rhs = _op.cast(rhs, "bool") if input_types[1] == "bool" else _op.cast(rhs, "int")
 
         return _op.bitwise_xor(lhs, rhs)
+
     return _impl
 
 
@@ -1691,6 +1805,7 @@ def _logical_not():
         data = inputs[0]
 
         return _op.logical_not(_op.cast(data, "bool"))
+
     return _impl
 
 
@@ -1700,18 +1815,21 @@ def _logical_xor():
         rhs = _op.cast(inputs[1], "bool")
 
         return _op.logical_xor(lhs, rhs)
+
     return _impl
 
 
 def _list_getitem(prelude):
     def _impl(inputs, input_types):
         return prelude.nth(inputs[0], _wrap_const(inputs[1]))
+
     return _impl
 
 
 def _list_len(prelude):
     def _impl(inputs, input_types):
         return prelude.length(inputs[0])
+
     return _impl
 
 
@@ -1720,6 +1838,7 @@ def _type_as():
         assert len(inputs) == 2
         assert len(input_types) == 2
         return _op.cast(inputs[0], input_types[1])
+
     return _impl
 
 
@@ -1730,6 +1849,7 @@ def _gather():
         indices = inputs[2]
 
         return _op.gather(data, axis, indices)
+
     return _impl
 
 
@@ -1739,6 +1859,7 @@ def _add(prelude):
         if input_types[0] == "ListType":
             return prelude.concat(inputs[0], inputs[1])
         return _elemwise("add")(inputs, input_types)
+
     return _impl
 
 
@@ -1749,13 +1870,14 @@ def _tensor_array_stack(prelude):
         tensor_array, shape = _convert_to_tensor_array(inputs[0], prelude)
 
         stacked_shape = (Any(),) + shape
-        stack = prelude.get_var_static('tensor_array_stack', "float32", shape)
+        stack = prelude.get_var_static("tensor_array_stack", "float32", shape)
         stacked = stack(tensor_array)
 
         static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", stacked_shape)
         static_tensor_array_ops.register()
-        get_tensor = prelude.get_var_static('tensor_get_data', "float32", stacked_shape)
+        get_tensor = prelude.get_var_static("tensor_get_data", "float32", stacked_shape)
         return get_tensor(stacked)
+
     return _impl
 
 
@@ -1773,6 +1895,7 @@ def _stack(prelude):
             msg = "The input list is expected to be List ADT"
             assert isinstance(ty, tvm.ir.TypeCall) and ty.func == list_ty, msg
             return _tensor_array_stack(prelude)(inputs, input_types)
+
     return _impl
 
 
@@ -1785,6 +1908,7 @@ def _rsub():
 
         # note: rsub means data0 and data1 swap places
         return get_relay_op("subtract")(data1, alpha * data0)
+
     return _impl
 
 
@@ -1793,23 +1917,25 @@ def _embedding():
         weight = inputs[0]
         indices = inputs[1]
 
-        return _op.take(weight, indices.astype('int32'), axis=0)
+        return _op.take(weight, indices.astype("int32"), axis=0)
+
     return _impl
 
 
 def _one_hot():
     def _impl(inputs, input_types):
-        indices = inputs[0].astype('int32')
+        indices = inputs[0].astype("int32")
         num_classes = inputs[1]
         if num_classes == -1:
             msg = "Inferring the number of classes is not yet supported."
             raise NotImplementedError(msg)
 
-        dtype = 'int32'
+        dtype = "int32"
         on_value = tvm.relay.const(1.0, dtype)
         off_value = tvm.relay.const(0.0, dtype)
 
         return _op.one_hot(indices, on_value, off_value, num_classes, -1, dtype)
+
     return _impl
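A NumPy sketch of the one-hot encoding produced by the converter above, with illustrative indices; the depth must be known up front, hence the NotImplementedError when num_classes is -1.

import numpy as np

indices = np.array([0, 2])
num_classes = 3
print(np.eye(num_classes, dtype="int32")[indices])
# [[1 0 0]
#  [0 0 1]]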
 
 
@@ -1818,6 +1944,7 @@ def _index():
         data = inputs[0]
         indices = inputs[1]
         return _op.adv_index([data] + indices)
+
     return _impl
 
 
@@ -1825,6 +1952,7 @@ def _meshgrid():
     def _impl(inputs, input_types):
         data = inputs[0]
         return _op.meshgrid(data, indexing="ij")
+
     return _impl
 
 
@@ -1835,42 +1963,44 @@ def _nms(prelude):
         iou_threshold = inputs[2]
 
         # Generate data with shape (1, num_anchors, 5)
-        scores = AttrCvt(op_name="expand_dims",
-                         extras={'axis': -1, 'num_newaxis': 1})([scores], {})
+        scores = AttrCvt(op_name="expand_dims", extras={"axis": -1, "num_newaxis": 1})([scores], {})
 
         # Prepare input data for get_valid_counts
         data = _op.concatenate([scores, boxes], -1)
         data = _op.expand_dims(data, 0, 1)
         # Leverage get_valid_counts to sort the data and clear invalid boxes
-        ct, data, indices = get_relay_op('get_valid_counts')(data,
-                                                             score_threshold=-1.0,
-                                                             id_index=-1,
-                                                             score_index=0)
+        ct, data, indices = get_relay_op("get_valid_counts")(
+            data, score_threshold=-1.0, id_index=-1, score_index=0
+        )
 
         # Perform Non-Maximum Suppression,
         # PyTorch NMS doesn't have parameter top_k and max_output_size
         score_index = 0
         top_k = max_out_size = -1
-        nms_ret = get_relay_op('non_max_suppression')(data=data,
-                                                      valid_count=ct,
-                                                      indices=indices,
-                                                      max_output_size=max_out_size,
-                                                      iou_threshold=iou_threshold,
-                                                      force_suppress=True,
-                                                      top_k=top_k,
-                                                      coord_start=1,
-                                                      score_index=score_index,
-                                                      id_index=-1,
-                                                      return_indices=True,
-                                                      invalid_to_bottom=False)
+        nms_ret = get_relay_op("non_max_suppression")(
+            data=data,
+            valid_count=ct,
+            indices=indices,
+            max_output_size=max_out_size,
+            iou_threshold=iou_threshold,
+            force_suppress=True,
+            top_k=top_k,
+            coord_start=1,
+            score_index=score_index,
+            id_index=-1,
+            return_indices=True,
+            invalid_to_bottom=False,
+        )
 
         # squeeze the two outputs of nms for strided_slice
         size = get_relay_op("squeeze")(nms_ret[1], axis=[1])
         data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0])
 
         # strided slice to get the dynamic result
-        return get_relay_op("strided_slice")(data_slice, begin=_expr.const([0]),
-                                             end=size, slice_mode="size")
+        return get_relay_op("strided_slice")(
+            data_slice, begin=_expr.const([0]), end=size, slice_mode="size"
+        )
+
     return _impl
 
 
@@ -1882,12 +2012,14 @@ def _logsumexp():
         # dim is output of prim::ListConstruct, even if it is int in python code
         assert isinstance(dim_list, list), "dim is expected to be a list"
         return _op.logsumexp(data[0], axis=dim_list, keepdims=keepdim)
+
     return _impl
 
 
 def _pytorch_result_type(dtypes, non_tensor_inputs):
     """This promotes TVM dtypes like PyTorch would"""
     import torch
+
     dtype_map = {
         "float64": torch.float64,
         "float32": torch.float32,
@@ -1898,24 +2030,30 @@ def _pytorch_result_type(dtypes, non_tensor_inputs):
         "int16": torch.int16,
         "int8": torch.int8,
         "uint8": torch.uint8,
-        "bool": torch.bool
-        }
+        "bool": torch.bool,
+    }
     if len(dtypes) > 0:
         result_type = dtypes[0]
         for dt in dtypes[1:]:
-            if dt != result_type: # we don't want to work with same types as we
-                                  # don't do quantized here (which cannot be promoted?)
-                result_type = _convert_data_type(str(torch.result_type(
-                    torch.zeros((), dtype=dtype_map[result_type]),
-                    torch.zeros((), dtype=dtype_map[dt]))))
+            if dt != result_type:  # we don't want to work with same types as we
+                # don't do quantized here (which cannot be promoted?)
+                result_type = _convert_data_type(
+                    str(
+                        torch.result_type(
+                            torch.zeros((), dtype=dtype_map[result_type]),
+                            torch.zeros((), dtype=dtype_map[dt]),
+                        )
+                    )
+                )
     else:
         result_type = "bool"  # this is the smallest type...
     for inp in non_tensor_inputs:
         result_type = _convert_data_type(
-            str(torch.result_type(torch.zeros((), dtype=dtype_map[result_type]),
-                                  inp)))
+            str(torch.result_type(torch.zeros((), dtype=dtype_map[result_type]), inp))
+        )
     return result_type
 
+
 def _pytorch_promote_types(inputs, dtypes):
     """This promotes TVM inputs with TVM dtypes passed like PyTorch would"""
     tensor_dtypes = [dt for inp, dt in zip(inputs, dtypes) if not np.isscalar(inp)]
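
The promotion in _pytorch_result_type above mirrors PyTorch's own rules by building zero-dimensional tensors and asking torch.result_type for the combined dtype. A small sanity check, assuming a working PyTorch install (illustrative only, not part of the change):

    import torch

    # int32 combined with float32 promotes to float32.
    assert torch.result_type(torch.zeros((), dtype=torch.int32),
                             torch.zeros((), dtype=torch.float32)) == torch.float32

    # A plain Python scalar participates with weaker "wrapped number" rules,
    # so a uint8 tensor combined with an int scalar stays uint8.
    assert torch.result_type(torch.zeros((), dtype=torch.uint8), 1) == torch.uint8
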
@@ -1931,28 +2069,32 @@ def _pytorch_promote_types(inputs, dtypes):
             results.append(_op.cast(inp, result_type))
     return results
 
+
 # Helper functions for operator implementation
 def _convert_dtype_value(val):
     """converts a PyTorch numeric type id to a torch scalar type."""
-    convert_torch_dtype_map = {7:"torch.float64",
-                               6:"torch.float32",
-                               5:"torch.float16",
-                               4:"torch.int64",
-                               3:"torch.int32",
-                               2:"torch.int16",
-                               1:"torch.int8",
-                               0:"torch.unit8",
-                               None:"torch.int64"} # Default is torch.int64
+    convert_torch_dtype_map = {
+        7: "torch.float64",
+        6: "torch.float32",
+        5: "torch.float16",
+        4: "torch.int64",
+        3: "torch.int32",
+        2: "torch.int16",
+        1: "torch.int8",
+        0: "torch.uint8",
+        None: "torch.int64",
+    }  # Default is torch.int64
     if val in convert_torch_dtype_map:
         return _convert_data_type(convert_torch_dtype_map[val])
     else:
         msg = "Torch data type value %d is not handled yet." % (val)
         raise NotImplementedError(msg)
 
+
 def _convert_data_type(input_type, default_dtype=None):
     """converts the PyTorch scalar type input_type to a TVM dtype.
-       optionally, default_dtype can be a TVM dtype that is used
-       if input_type is None (but not when it is unknown)"""
+    optionally, default_dtype can be a TVM dtype that is used
+    if input_type is None (but not when it is unknown)"""
     if input_type is None and default_dtype is not None:
         return default_dtype
 
@@ -1985,9 +2127,10 @@ def _convert_data_type(input_type, default_dtype=None):
         raise NotImplementedError("input_type {} is not handled yet".format(input_type))
     return "float32"  # Never reached
 
+
 def _create_typed_const(data, dtype):
     """create a (scalar) constant of given value and dtype.
-       dtype should be a TVM dtype"""
+    dtype should be a TVM dtype"""
 
     if dtype == "float64":
         typed_data = _expr.const(np.float64(data), dtype=dtype)
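
_create_typed_const above wraps a Python value in a Relay constant with an explicit TVM dtype. The same idea through the public API, with an arbitrary example value:

    import numpy as np
    from tvm import relay

    # A scalar Relay constant carries both its value and an explicit TVM dtype.
    c = relay.const(np.float64(1.5), dtype="float64")
    print(c.data.dtype)  # float64
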
@@ -2009,184 +2152,186 @@ def _create_typed_const(data, dtype):
         raise NotImplementedError("input_type {} is not handled yet".format(dtype))
     return typed_data
 
+
 def _wrap_const(c):
     if not isinstance(c, (_expr.Expr, list, tvm.tir.expr.Any)):
         return _expr.const(c)
     return c
 
+
 # Operator mappings
 def _get_convert_map(prelude, default_dtype):
     convert_map = {
-        "aten::device"                          : _none(),
-        "prim::device"                          : _none(),
-        "aten::sub"                             : _elemwise("subtract"),
-        "aten::sub_"                            : _elemwise("subtract"),
-        "aten::max"                             : _max(),
-        "aten::min"                             : _min(),
-        "aten::mul"                             : _elemwise("multiply"),
-        "aten::mul_"                            : _elemwise("multiply"),
-        "aten::pow"                             : _elemwise("power"),
-        "aten::arange"                          : _arange(),
-        "aten::meshgrid"                        : _meshgrid(),
-        "aten::div"                             : _elemwise("divide"),
-        "aten::div_"                            : _elemwise("divide"),
-        "aten::floor_divide"                    : _elemwise("floor_divide"),
-        "aten::addcdiv"                         : _addcdiv(),
-        "aten::addcmul"                         : _addcmul(),
-        "aten::ones"                            : _ones(),
-        "aten::ones_like"                       : _ones_like(),
-        "aten::zeros"                           : _zeros(),
-        "aten::zeros_like"                      : _zeros_like(),
-        "aten::full"                            : _full(default_dtype),
-        "aten::full_like"                       : _full_like(),
-        "aten::linspace"                        : _linspace(),
-        "aten::reciprocal"                      : _reciprocal(),
-        "aten::repeat"                          : _repeat(),
-        "aten::repeat_interleave"               : _repeat_interleave(),
-        "aten::to"                              : _to(),
-        "aten::squeeze"                         : _squeeze(),
-        "aten::unsqueeze"                       : _unsqueeze(),
-        "aten::cat"                             : _concatenate(prelude),
-        "aten::slice"                           : _slice(),
-        "aten::split"                           : _split(),
-        "aten::split_with_sizes"                : _split_with_sizes(),
-        "aten::select"                          : _select(),
-        "aten::take"                            : _take(),
-        "aten::where"                           : _where(),
-        "aten::topk"                            : _topk(),
-        "aten::relu"                            : _relu(prelude),
-        "aten::relu_"                           : _relu(prelude),
-        "aten::prelu"                           : _prelu(),
-        "aten::leaky_relu"                      : _leaky_relu(),
-        "aten::leaky_relu_"                     : _leaky_relu(),
-        "aten::elu"                             : _elu(),
-        "aten::elu_"                            : _elu(),
-        "aten::celu"                            : _celu(),
-        "aten::gelu"                            : _gelu(),
-        "aten::selu"                            : _selu(),
-        "aten::log_sigmoid"                     : _log_sigmoid(),
-        "aten::adaptive_avg_pool2d"             : _adaptive_avg_pool_2d(prelude),
-        "aten::adaptive_max_pool2d"             : _adaptive_max_pool_2d(),
-        "aten::max_pool2d"                      : _maxpool_2d(),
-        "aten::max_pool2d_with_indices"         : _maxpool_2d_with_indices(),
-        "aten::max_pool1d"                      : _maxpool_1d(),
-        "aten::max_pool3d"                      : _maxpool_3d(),
-        "aten::hardtanh"                        : _hardtanh(),
-        "aten::hardtanh_"                       : _hardtanh(),
-        "aten::_convolution"                    : _convolution(),
-        "aten::softmax"                         : _softmax(),
-        "aten::threshold"                       : _threshold(),
-        "aten::threshold_"                      : _threshold(),
-        "aten::contiguous"                      : _contiguous(),
-        "aten::batch_norm"                      : _batch_norm(),
-        "aten::instance_norm"                   : _instance_norm(),
-        "aten::layer_norm"                      : _layer_norm(),
-        "aten::group_norm"                      : _group_norm(),
-        "aten::transpose"                       : _transpose(prelude),
-        "aten::transpose_"                      : _transpose(prelude),
-        "aten::t"                               : _transpose(prelude),
-        "aten::flatten"                         : _flatten(),
-        "aten::addmm"                           : _dense(),
-        "aten::size"                            : _size(prelude),
-        "aten::view"                            : _view(),
-        "aten::reshape"                         : _reshape(),
-        "aten::clone"                           : _clone(),
-        "aten::log_softmax"                     : _log_softmax(),
-        "aten::sigmoid"                         : _sigmoid(),
-        "aten::softplus"                        : _softplus(),
-        "aten::avg_pool2d"                      : _avg_pool2d(prelude),
-        "aten::avg_pool3d"                      : _avg_pool3d(),
-        "aten::dropout"                         : _dropout(),
-        "aten::dropout_"                        : _dropout(),
-        "aten::feature_dropout"                 : _dropout(),
-        "aten::alpha_dropout"                   : _dropout(),
-        "aten::mean"                            : _mean(prelude),
-        "aten::chunk"                           : _chunk(prelude),
-        "aten::matmul"                          : _matmul(prelude),
-        "aten::bmm"                             : _matmul(prelude),
-        "aten::expand"                          : _expand(),
-        "aten::Int"                             : _int(),
-        "prim::NumToTensor"                     : _numtotensor(),
-        "prim::ImplicitTensorToNum"             : _tensortonum(),
-        "aten::ScalarImplicit"                  : _tensortonum(),
-        "aten::constant_pad_nd"                 : _pad("constant"),
-        "aten::reflection_pad1d"                : _pad("reflect"),
-        "aten::reflection_pad2d"                : _pad("reflect"),
-        "aten::replication_pad1d"               : _pad("edge"),
-        "aten::replication_pad2d"               : _pad("edge"),
-        "aten::replication_pad3d"               : _pad("edge"),
-        "aten::permute"                         : _transpose(prelude),
-        "aten::sum"                             : _reduce("sum"),
-        "aten::prod"                            : _reduce("prod"),
-        "aten::argmin"                          : _reduce("argmin"),
-        "aten::argmax"                          : _reduce("argmax"),
-        "aten::norm"                            : _norm(),
-        "aten::frobenius_norm"                  : _frobenius_norm(),
-        "aten::std"                             : _std(),
-        "aten::var"                             : _variance(),
-        "aten::abs"                             : _unary("abs"),
-        "aten::neg"                             : _unary("negative"),
-        "aten::cos"                             : _unary("cos"),
-        "aten::cosh"                            : _unary("cosh"),
-        "aten::sin"                             : _unary("sin"),
-        "aten::sinh"                            : _unary("sinh"),
-        "aten::tan"                             : _unary("tan"),
-        "aten::tanh"                            : _unary("tanh"),
-        "aten::acos"                            : _unary("acos"),
-        "aten::asin"                            : _unary("asin"),
-        "aten::atan"                            : _unary("atan"),
-        "aten::log"                             : _unary("log"),
-        "aten::log2"                            : _unary("log2"),
-        "aten::log10"                           : _unary("log10"),
-        "aten::log1p"                           : _log1p(),
-        "aten::exp"                             : _unary("exp"),
-        "aten::erf"                             : _unary("erf"),
-        "aten::trunc"                           : _unary("trunc"),
-        "aten::sign"                            : _unary("sign"),
-        "aten::sqrt"                            : _unary("sqrt"),
-        "aten::rsqrt"                           : _unary("rsqrt"),
-        "aten::ceil"                            : _unary("ceil"),
-        "aten::floor"                           : _unary("floor"),
-        "aten::round"                           : _unary("round"),
-        "aten::isfinite"                        : _unary("isfinite"),
-        "aten::isinf"                           : _unary("isinf"),
-        "aten::isnan"                           : _unary("isnan"),
-        "aten::clamp"                           : _clamp(),
-        "aten::detach"                          : _identity(),
-        "aten::upsample_bilinear2d"             : _upsample("bilinear", prelude),
-        "aten::upsample_nearest2d"              : _upsample("nearest_neighbor", prelude),
-        "aten::upsample_trilinear3d"            : _upsample3d("trilinear"),
-        "aten::upsample_nearest3d"              : _upsample3d("nearest_neighbor"),
-        "aten::expand_as"                       : _expand_as(),
-        "aten::lt"                              : _elemwise("less"),
-        "aten::gt"                              : _elemwise("greater"),
-        "aten::le"                              : _elemwise("less_equal"),
-        "aten::ge"                              : _elemwise("greater_equal"),
-        "aten::ne"                              : _elemwise("not_equal"),
-        "aten::eq"                              : _elemwise("equal"),
-        "aten::logical_not"                     : _logical_not(),
-        "aten::logical_xor"                     : _logical_xor(),
-        "aten::bitwise_not"                     : _bitwise_not(),
-        "aten::bitwise_xor"                     : _bitwise_xor(),
-        "aten::Bool"                            : _Bool(),
-        "aten::Float"                           : _Float(),
-        "aten::adaptive_avg_pool3d"             : _adaptive_avg_pool_3d(),
-        "aten::adaptive_max_pool3d"             : _adaptive_max_pool_3d(),
-        "aten::rsub"                            : _rsub(),
-        "aten::embedding"                       : _embedding(),
-        "aten::one_hot"                         : _one_hot(),
-        "aten::mm"                              : _matmul(prelude),
-        "aten::add"                             : _add(prelude),
-        "aten::add_"                            : _add(prelude),
-        "aten::stack"                           : _stack(prelude),
-        "aten::__getitem__"                     : _list_getitem(prelude),
-        "aten::len"                             : _list_len(prelude),
-        "aten::type_as"                         : _type_as(),
-        "aten::gather"                          : _gather(),
-        "aten::index_select"                    : _select(),
-        "aten::index"                           : _index(),
-        "torchvision::nms"                      : _nms(prelude),
-        "aten::logsumexp"                       : _logsumexp()
+        "aten::device": _none(),
+        "prim::device": _none(),
+        "aten::sub": _elemwise("subtract"),
+        "aten::sub_": _elemwise("subtract"),
+        "aten::max": _max(),
+        "aten::min": _min(),
+        "aten::mul": _elemwise("multiply"),
+        "aten::mul_": _elemwise("multiply"),
+        "aten::pow": _elemwise("power"),
+        "aten::arange": _arange(),
+        "aten::meshgrid": _meshgrid(),
+        "aten::div": _elemwise("divide"),
+        "aten::div_": _elemwise("divide"),
+        "aten::floor_divide": _elemwise("floor_divide"),
+        "aten::addcdiv": _addcdiv(),
+        "aten::addcmul": _addcmul(),
+        "aten::ones": _ones(),
+        "aten::ones_like": _ones_like(),
+        "aten::zeros": _zeros(),
+        "aten::zeros_like": _zeros_like(),
+        "aten::full": _full(default_dtype),
+        "aten::full_like": _full_like(),
+        "aten::linspace": _linspace(),
+        "aten::reciprocal": _reciprocal(),
+        "aten::repeat": _repeat(),
+        "aten::repeat_interleave": _repeat_interleave(),
+        "aten::to": _to(),
+        "aten::squeeze": _squeeze(),
+        "aten::unsqueeze": _unsqueeze(),
+        "aten::cat": _concatenate(prelude),
+        "aten::slice": _slice(),
+        "aten::split": _split(),
+        "aten::split_with_sizes": _split_with_sizes(),
+        "aten::select": _select(),
+        "aten::take": _take(),
+        "aten::where": _where(),
+        "aten::topk": _topk(),
+        "aten::relu": _relu(prelude),
+        "aten::relu_": _relu(prelude),
+        "aten::prelu": _prelu(),
+        "aten::leaky_relu": _leaky_relu(),
+        "aten::leaky_relu_": _leaky_relu(),
+        "aten::elu": _elu(),
+        "aten::elu_": _elu(),
+        "aten::celu": _celu(),
+        "aten::gelu": _gelu(),
+        "aten::selu": _selu(),
+        "aten::log_sigmoid": _log_sigmoid(),
+        "aten::adaptive_avg_pool2d": _adaptive_avg_pool_2d(prelude),
+        "aten::adaptive_max_pool2d": _adaptive_max_pool_2d(),
+        "aten::max_pool2d": _maxpool_2d(),
+        "aten::max_pool2d_with_indices": _maxpool_2d_with_indices(),
+        "aten::max_pool1d": _maxpool_1d(),
+        "aten::max_pool3d": _maxpool_3d(),
+        "aten::hardtanh": _hardtanh(),
+        "aten::hardtanh_": _hardtanh(),
+        "aten::_convolution": _convolution(),
+        "aten::softmax": _softmax(),
+        "aten::threshold": _threshold(),
+        "aten::threshold_": _threshold(),
+        "aten::contiguous": _contiguous(),
+        "aten::batch_norm": _batch_norm(),
+        "aten::instance_norm": _instance_norm(),
+        "aten::layer_norm": _layer_norm(),
+        "aten::group_norm": _group_norm(),
+        "aten::transpose": _transpose(prelude),
+        "aten::transpose_": _transpose(prelude),
+        "aten::t": _transpose(prelude),
+        "aten::flatten": _flatten(),
+        "aten::addmm": _dense(),
+        "aten::size": _size(prelude),
+        "aten::view": _view(),
+        "aten::reshape": _reshape(),
+        "aten::clone": _clone(),
+        "aten::log_softmax": _log_softmax(),
+        "aten::sigmoid": _sigmoid(),
+        "aten::softplus": _softplus(),
+        "aten::avg_pool2d": _avg_pool2d(prelude),
+        "aten::avg_pool3d": _avg_pool3d(),
+        "aten::dropout": _dropout(),
+        "aten::dropout_": _dropout(),
+        "aten::feature_dropout": _dropout(),
+        "aten::alpha_dropout": _dropout(),
+        "aten::mean": _mean(prelude),
+        "aten::chunk": _chunk(prelude),
+        "aten::matmul": _matmul(prelude),
+        "aten::bmm": _matmul(prelude),
+        "aten::expand": _expand(),
+        "aten::Int": _int(),
+        "prim::NumToTensor": _numtotensor(),
+        "prim::ImplicitTensorToNum": _tensortonum(),
+        "aten::ScalarImplicit": _tensortonum(),
+        "aten::constant_pad_nd": _pad("constant"),
+        "aten::reflection_pad1d": _pad("reflect"),
+        "aten::reflection_pad2d": _pad("reflect"),
+        "aten::replication_pad1d": _pad("edge"),
+        "aten::replication_pad2d": _pad("edge"),
+        "aten::replication_pad3d": _pad("edge"),
+        "aten::permute": _transpose(prelude),
+        "aten::sum": _reduce("sum"),
+        "aten::prod": _reduce("prod"),
+        "aten::argmin": _reduce("argmin"),
+        "aten::argmax": _reduce("argmax"),
+        "aten::norm": _norm(),
+        "aten::frobenius_norm": _frobenius_norm(),
+        "aten::std": _std(),
+        "aten::var": _variance(),
+        "aten::abs": _unary("abs"),
+        "aten::neg": _unary("negative"),
+        "aten::cos": _unary("cos"),
+        "aten::cosh": _unary("cosh"),
+        "aten::sin": _unary("sin"),
+        "aten::sinh": _unary("sinh"),
+        "aten::tan": _unary("tan"),
+        "aten::tanh": _unary("tanh"),
+        "aten::acos": _unary("acos"),
+        "aten::asin": _unary("asin"),
+        "aten::atan": _unary("atan"),
+        "aten::log": _unary("log"),
+        "aten::log2": _unary("log2"),
+        "aten::log10": _unary("log10"),
+        "aten::log1p": _log1p(),
+        "aten::exp": _unary("exp"),
+        "aten::erf": _unary("erf"),
+        "aten::trunc": _unary("trunc"),
+        "aten::sign": _unary("sign"),
+        "aten::sqrt": _unary("sqrt"),
+        "aten::rsqrt": _unary("rsqrt"),
+        "aten::ceil": _unary("ceil"),
+        "aten::floor": _unary("floor"),
+        "aten::round": _unary("round"),
+        "aten::isfinite": _unary("isfinite"),
+        "aten::isinf": _unary("isinf"),
+        "aten::isnan": _unary("isnan"),
+        "aten::clamp": _clamp(),
+        "aten::detach": _identity(),
+        "aten::upsample_bilinear2d": _upsample("bilinear", prelude),
+        "aten::upsample_nearest2d": _upsample("nearest_neighbor", prelude),
+        "aten::upsample_trilinear3d": _upsample3d("trilinear"),
+        "aten::upsample_nearest3d": _upsample3d("nearest_neighbor"),
+        "aten::expand_as": _expand_as(),
+        "aten::lt": _elemwise("less"),
+        "aten::gt": _elemwise("greater"),
+        "aten::le": _elemwise("less_equal"),
+        "aten::ge": _elemwise("greater_equal"),
+        "aten::ne": _elemwise("not_equal"),
+        "aten::eq": _elemwise("equal"),
+        "aten::logical_not": _logical_not(),
+        "aten::logical_xor": _logical_xor(),
+        "aten::bitwise_not": _bitwise_not(),
+        "aten::bitwise_xor": _bitwise_xor(),
+        "aten::Bool": _Bool(),
+        "aten::Float": _Float(),
+        "aten::adaptive_avg_pool3d": _adaptive_avg_pool_3d(),
+        "aten::adaptive_max_pool3d": _adaptive_max_pool_3d(),
+        "aten::rsub": _rsub(),
+        "aten::embedding": _embedding(),
+        "aten::one_hot": _one_hot(),
+        "aten::mm": _matmul(prelude),
+        "aten::add": _add(prelude),
+        "aten::add_": _add(prelude),
+        "aten::stack": _stack(prelude),
+        "aten::__getitem__": _list_getitem(prelude),
+        "aten::len": _list_len(prelude),
+        "aten::type_as": _type_as(),
+        "aten::gather": _gather(),
+        "aten::index_select": _select(),
+        "aten::index": _index(),
+        "torchvision::nms": _nms(prelude),
+        "aten::logsumexp": _logsumexp(),
     }
     return convert_map
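
The convert map above is keyed by Torch JIT operator kinds. A quick way to see which keys a given model needs, sketched with a throwaway module (the module and shapes here are hypothetical):

    import torch

    class TinyNet(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1

    graph = torch.jit.trace(TinyNet().eval(), torch.randn(1, 3)).graph
    print({node.kind() for node in graph.nodes()})
    # e.g. {'prim::Constant', 'aten::relu', 'aten::add'} -- each aten::*/prim::* kind
    # must appear either in convert_map or among the known prim ops handled separately.
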
 
@@ -2194,6 +2339,7 @@ def _get_convert_map(prelude, default_dtype):
 def _run_jit_passes(graph):
     """ The inline pass is necessary to unwrap prim::CallMethod """
     import torch
+
     torch._C._jit_pass_inline(graph)
 
 
@@ -2238,15 +2384,20 @@ def _get_users(node):
 
 def _report_missing_conversion(op_names, convert_map):
     """ Check if all ops in an input graph are supported by TVM """
-    known_ops = ["prim::Constant", "prim::GetAttr",
-                 "prim::ListConstruct", "prim::ListUnpack",
-                 "prim::TupleConstruct", "prim::TupleUnpack",
-                 "prim::If", "prim::Loop"]
+    known_ops = [
+        "prim::Constant",
+        "prim::GetAttr",
+        "prim::ListConstruct",
+        "prim::ListUnpack",
+        "prim::TupleConstruct",
+        "prim::TupleUnpack",
+        "prim::If",
+        "prim::Loop",
+    ]
     known_ops += list(convert_map.keys())
     known_ops += list(qnn_torch.convert_map.keys())
 
-    missing = [op_name for op_name in op_names
-               if op_name not in known_ops]
+    missing = [op_name for op_name in op_names if op_name not in known_ops]
 
     if missing:
         msg = "The following operators are not implemented: {}".format(missing)
@@ -2266,7 +2417,7 @@ def _getattr_full_name(getattrs):
 
 def _get_pytorch_value_type(typ, default_dtype="float32"):
     kind = typ.kind()
-    if kind == 'TensorType':
+    if kind == "TensorType":
         if typ.scalarType() is None:
             # Tensor's type can be unknown if we use torch.jit.script(...)
             # Defaults can be passed in, if not it is float32
@@ -2275,15 +2426,14 @@ def _get_pytorch_value_type(typ, default_dtype="float32"):
         else:
             return _convert_data_type(typ.scalarType())
 
-    elif kind == 'ListType':
+    elif kind == "ListType":
         return "ListType"
-    elif kind in ['IntType', 'FloatType', 'BoolType',
-                  'StringType', 'OptionalType']:
+    elif kind in ["IntType", "FloatType", "BoolType", "StringType", "OptionalType"]:
         pt_dtype = str(typ).lower()
-        dtype = pt_dtype if pt_dtype == 'OptionalType' else _convert_data_type(pt_dtype)
+        dtype = pt_dtype if pt_dtype == "OptionalType" else _convert_data_type(pt_dtype)
         return dtype
     else:
-        return 'UnsupportedType'
+        return "UnsupportedType"
 
 
 def _get_input_types(op_node, outputs, default_dtype="float32"):
@@ -2374,29 +2524,32 @@ def _get_relay_input_vars(graph, input_shapes, prelude, is_module=True, default_
 
     if len(graph_inputs) != len(input_shapes):
         msg = "PyTorch has {} inputs and input_shapes lists {}.".format(
-            len(graph_inputs), len(input_shapes))
+            len(graph_inputs), len(input_shapes)
+        )
         raise RuntimeError(msg)
 
     def get_relay_ty(ishape, pt_type):
-        if pt_type.kind() == 'TensorType':
+        if pt_type.kind() == "TensorType":
             if not (_is_int_seq(ishape) or len(ishape) == 0):
                 msg = "Shapes for Tensors must be lists of ints"
                 raise RuntimeError(msg)
-            if ((pt_type.dim() is not None and pt_type.dim() != len(ishape)) or
-                    (pt_type.sizes() is not None
-                     and any([s1 != s2 for s1, s2 in zip(pt_type.sizes(), ishape)]))):
+            if (pt_type.dim() is not None and pt_type.dim() != len(ishape)) or (
+                pt_type.sizes() is not None
+                and any([s1 != s2 for s1, s2 in zip(pt_type.sizes(), ishape)])
+            ):
                 msg = "Shapes of input list and information in the graph do not match"
                 raise RuntimeError(msg)
             pt_dtype = pt_type.scalarType()
             dtype = _convert_data_type(pt_dtype, default_dtype=default_dtype)
             return TensorType(ishape, dtype)
-        elif pt_type.kind() == 'TupleType':
+        elif pt_type.kind() == "TupleType":
             if not isinstance(ishape, tuple):
                 msg = "Shapes for tuples must be tuples"
                 raise RuntimeError(msg)
-            return TupleType([get_relay_ty(elem, pt_t)
-                              for elem, pt_t in zip(ishape, pt_type.elements())])
-        elif pt_type.kind() == 'ListType':
+            return TupleType(
+                [get_relay_ty(elem, pt_t) for elem, pt_t in zip(ishape, pt_type.elements())]
+            )
+        elif pt_type.kind() == "ListType":
             if not isinstance(ishape, list):
                 msg = "Shapes for lists must be lists"
                 raise RuntimeError(msg)
@@ -2406,7 +2559,7 @@ def _get_relay_input_vars(graph, input_shapes, prelude, is_module=True, default_
                 msg = "List elements need to have identical types"
                 raise RuntimeError(msg)
             return prelude.l(elem_tys[0])
-        elif pt_type.kind() == 'OptionalType':
+        elif pt_type.kind() == "OptionalType":
             # we do not support None yet, so we fill in the type
             return get_relay_ty(ishape, pt_type.getElementType())
         # TODO: scalar inputs
@@ -2418,12 +2571,14 @@ def _get_relay_input_vars(graph, input_shapes, prelude, is_module=True, default_
         if not isinstance(inp, tuple):
             msg = "Graph input {} is not a tuple".format(num)
             raise RuntimeError(msg)
-        if (len(inp) != 2 or not isinstance(inp[0], str)):
+        if len(inp) != 2 or not isinstance(inp[0], str):
             msg = "Graph input {} is not valid, expected ('name', shape)".format(inp)
             raise RuntimeError(msg)
 
-    input_types = [(name, get_relay_ty(shape, gi.type()))
-                   for (name, shape), gi in zip(input_shapes, graph_inputs)]
+    input_types = [
+        (name, get_relay_ty(shape, gi.type()))
+        for (name, shape), gi in zip(input_shapes, graph_inputs)
+    ]
 
     ir_inputs = [i.debugName() for i in graph_inputs]
     for ir_input, (name, itype) in zip(ir_inputs, input_types):
@@ -2465,6 +2620,7 @@ def get_use_chains(root_node, terminate=lambda _: False):
     Track a chain of users of this node forward, returning a list of chains
     See get_attr_chains below for its usage
     """
+
     def concat_lists(lists):
         return itertools.chain.from_iterable(lists)
 
@@ -2480,7 +2636,7 @@ def get_use_chains(root_node, terminate=lambda _: False):
 
 
 def get_attr_chains(root_getattr_node):
-    """ Returns chains of attribute access starting from root_getattr_node
+    """Returns chains of attribute access starting from root_getattr_node
 
     For example, given attribute "block", as in "self.block" when "self" points
     to the top level torch.nn.Module, it returns lists of attribute "chains",
@@ -2491,6 +2647,7 @@ def get_attr_chains(root_getattr_node):
     and "self.block.0._packed_params" will return the parameters of the first
     submodule.
     """
+
     def terminate(users):
         next_attrs = [user for user in users if user.kind() == "prim::GetAttr"]
         return len(next_attrs) == 0
@@ -2529,8 +2686,7 @@ def convert_params(graph, state_dict):
                     var = vars_by_name[full_attr]
                 else:
                     torch_tensor = state_dict[full_attr]
-                    tensor, var = _get_tensor_and_var(torch_tensor,
-                                                      full_attr)
+                    tensor, var = _get_tensor_and_var(torch_tensor, full_attr)
                     param_tensors[full_attr] = tensor
                     vars_by_name[full_attr] = var
                 params[full_attr_node_name] = var
@@ -2542,24 +2698,28 @@ def convert_block(block, outputs, convert_map, prelude, default_dtype="float32")
     """ Translate Torch "Block", used for prim::If and prim::Loop """
     ops = _get_operator_nodes(block.nodes())
     ret_names = _get_input_names(block.returnNode())
-    return convert_operators(ops, outputs, ret_names, convert_map, prelude,
-                             default_dtype=default_dtype)
+    return convert_operators(
+        ops, outputs, ret_names, convert_map, prelude, default_dtype=default_dtype
+    )
 
 
 def convert_if(if_node, outputs, convert_map, prelude, default_dtype="float32"):
     """ Translate Torch prim::If to Relay If """
     cond = outputs[if_node.inputsAt(0).debugName()]
     blocks = list(if_node.blocks())
-    true_branch = convert_block(blocks[0], outputs, convert_map, prelude,
-                                default_dtype=default_dtype)
-    false_branch = convert_block(blocks[1], outputs, convert_map, prelude,
-                                 default_dtype=default_dtype)
+    true_branch = convert_block(
+        blocks[0], outputs, convert_map, prelude, default_dtype=default_dtype
+    )
+    false_branch = convert_block(
+        blocks[1], outputs, convert_map, prelude, default_dtype=default_dtype
+    )
     assert len(true_branch) == 1 and len(false_branch) == 1
     return _expr.If(cond, true_branch[0], false_branch[0])
 
 
 def convert_loop(loop_node, outputs, convert_map, prelude):
     """ Translate Torch prim::Loop to Relay while_loop """
+
     def get_input(index):
         ivalue = loop_node.inputsAt(index)
         inode = ivalue.node()
@@ -2581,8 +2741,10 @@ def convert_loop(loop_node, outputs, convert_map, prelude):
 
     # A while loop always has max_loop_count equal to int64 max.
     # max_loop_count.data (tvm.runtime.NDArray) is -1, so _get_constant is called again
-    is_while_loop = (isinstance(max_loop_count, _expr.Constant) and
-                     _get_constant(loop_node.inputsAt(0).node()) == sys.maxsize)
+    is_while_loop = (
+        isinstance(max_loop_count, _expr.Constant)
+        and _get_constant(loop_node.inputsAt(0).node()) == sys.maxsize
+    )
 
     if is_while_loop:
         loop_iter_dtype = "bool"
@@ -2599,8 +2761,7 @@ def convert_loop(loop_node, outputs, convert_map, prelude):
     body_block = list(loop_node.blocks())[0]
     block_input_names = _get_input_names(body_block)
     num_block_inputs = len(block_input_names)
-    name_val_pairs = list(zip(block_input_names,
-                              [init_loop_iter_val] + init_vals))
+    name_val_pairs = list(zip(block_input_names, [init_loop_iter_val] + init_vals))
     outputs.update(name_val_pairs)
 
     def get_var(name, val):
@@ -2609,8 +2770,7 @@ def convert_loop(loop_node, outputs, convert_map, prelude):
             return _expr.var(name, type_annotation=checked_type)
         return _expr.var(name)
 
-    loop_iter_var = _expr.var(block_input_names[0], shape=(),
-                              dtype=loop_iter_dtype)
+    loop_iter_var = _expr.var(block_input_names[0], shape=(), dtype=loop_iter_dtype)
     loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]]
 
     # Add non-constant free variables to loop variables to prevent code blow-up
@@ -2620,9 +2780,13 @@ def convert_loop(loop_node, outputs, convert_map, prelude):
     # This issue was found when converting the stacked LSTM test. Torch does not add the output
     # of the earlier loop into loop variables of the next loop.
     # So the variable corresponding to the first loop output appears free in the second loop body.
-    free_vars = [var for var in _get_free_vars_from_block(body_block)
-                 if var in outputs and not isinstance(outputs[var], (_expr.Constant, int, float))
-                 and outputs[var]]
+    free_vars = [
+        var
+        for var in _get_free_vars_from_block(body_block)
+        if var in outputs
+        and not isinstance(outputs[var], (_expr.Constant, int, float))
+        and outputs[var]
+    ]
 
     prev_outputs = {}
     for name in free_vars:
@@ -2637,7 +2801,7 @@ def convert_loop(loop_node, outputs, convert_map, prelude):
         i = current_vals[0]
 
         if is_while_loop:
-            return _op.equal(i, _expr.const(True, 'bool'))
+            return _op.equal(i, _expr.const(True, "bool"))
 
         return _op.less(i, max_loop_count)
 
@@ -2649,7 +2813,7 @@ def convert_loop(loop_node, outputs, convert_map, prelude):
             if i < num_block_inputs:
                 outputs[block_input_names[i]] = val
             else:
-                outputs[free_vars[i-num_block_inputs]] = val
+                outputs[free_vars[i - num_block_inputs]] = val
 
         block_outputs = convert_block(body_block, outputs, convert_map, prelude)
         block_outputs += [outputs[name] for name in free_vars]
@@ -2670,7 +2834,7 @@ def convert_loop(loop_node, outputs, convert_map, prelude):
     outputs.update(prev_outputs)
 
     # The first element is a loop counter or boolean condition, ignore it
-    return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)]
+    return [_expr.TupleGetItem(loop_val, i + 1) for i in range(num_loop_var)]
 
 
 def convert_operators(operators, outputs, ret_names, convert_map, prelude, default_dtype="float32"):
@@ -2706,8 +2870,9 @@ def convert_operators(operators, outputs, ret_names, convert_map, prelude, defau
             outputs.update(zip(unpacked_names, loop_out))
         else:
             relay_op = convert_map[operator]
-            relay_out = relay_op(inputs, _get_input_types(op_node, outputs,
-                                                          default_dtype=default_dtype))
+            relay_out = relay_op(
+                inputs, _get_input_types(op_node, outputs, default_dtype=default_dtype)
+            )
 
             if isinstance(relay_out, tuple):
                 # This is for torch operators that return multiple outputs
@@ -2718,8 +2883,7 @@ def convert_operators(operators, outputs, ret_names, convert_map, prelude, defau
                 assert op_node.outputsSize() == 1
                 outputs[node_name] = relay_out
 
-    return [_wrap_const(outputs[ret_name])
-            for ret_name in ret_names]
+    return [_wrap_const(outputs[ret_name]) for ret_name in ret_names]
 
 
 def get_all_op_names(graph):
@@ -2735,7 +2899,7 @@ def get_all_op_names(graph):
 
 
 def from_pytorch(script_module, input_shapes, custom_convert_map=None, default_dtype="float32"):
-    """ Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
+    """Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
     The companion parameters will be handled automatically.
 
     Parameters
@@ -2778,9 +2942,9 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None, default_d
 
     is_module = isinstance(script_module, torch.jit.ScriptModule)
     params = script_module.state_dict() if is_module else {}
-    outputs = _get_relay_input_vars(graph, input_shapes, prelude,
-                                    default_dtype=default_dtype,
-                                    is_module=is_module)
+    outputs = _get_relay_input_vars(
+        graph, input_shapes, prelude, default_dtype=default_dtype, is_module=is_module
+    )
     param_vars, tensors, packed_param_map = convert_params(graph, params)
     tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
 
@@ -2791,15 +2955,18 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None, default_d
     if "aten::quantize_per_tensor" in op_names:
         weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
         qnn_torch.add_input_quant_params_to_op_inputs(graph)
-        qnn_torch.add_quant_params_to_outputs(outputs,
-                                              packed_param_map,
-                                              weight_quant_params)
+        qnn_torch.add_quant_params_to_outputs(outputs, packed_param_map, weight_quant_params)
         qnn_torch.add_quant_params(tvm_params, weight_quant_params)
         convert_map.update(qnn_torch.convert_map)
 
-    ret = convert_operators(_get_operator_nodes(graph.nodes()),
-                            outputs, ret_name, convert_map, prelude,
-                            default_dtype=default_dtype)
+    ret = convert_operators(
+        _get_operator_nodes(graph.nodes()),
+        outputs,
+        ret_name,
+        convert_map,
+        prelude,
+        default_dtype=default_dtype,
+    )
 
     mod["main"] = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
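
End to end, the converter is reached through relay.frontend.from_pytorch. A minimal usage sketch, assuming torchvision is available; the model, input name, and shape are examples only:

    import torch
    import torchvision
    from tvm import relay

    # Trace a model to obtain a ScriptModule, then convert it to a Relay module.
    model = torchvision.models.resnet18(pretrained=True).eval()
    inp = torch.randn(1, 3, 224, 224)
    scripted = torch.jit.trace(model, inp)

    # input_shapes is a list of (input name, shape) tuples; tuple- or list-typed
    # inputs take a tuple or list of shapes instead of a flat shape.
    mod, params = relay.frontend.from_pytorch(scripted, [("input0", (1, 3, 224, 224))])
    print(mod["main"])
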
 
index b3f78d1..1213073 100644 (file)
@@ -31,14 +31,12 @@ class QNNParam:
     """ A placeholder for weight quantization parameters """
 
     def __init__(self, weight, bias, scale, zero_point, param_key):
-        param_prefix = param_key[:-len("._packed_params")]
-        self.weight_var = _expr.var(param_prefix + "_weight",
-                                    shape=weight.shape)
+        param_prefix = param_key[: -len("._packed_params")]
+        self.weight_var = _expr.var(param_prefix + "_weight", shape=weight.shape)
         self.weight = weight
 
         if bias is not None:
-            self.bias_var = _expr.var(param_prefix + "_bias",
-                                      shape=bias.shape)
+            self.bias_var = _expr.var(param_prefix + "_bias", shape=bias.shape)
             self.bias = bias.detach().numpy()
         else:
             self.bias_var = None
@@ -55,9 +53,11 @@ def _unpack_quant_params(param_name, packed_params, unpack_func):
     weight_np = qweight.dequantize().numpy()
 
     import torch
+
     if qweight.qscheme() == torch.per_tensor_affine:
-        param = QNNParam(weight_np, bias, qweight.q_scale(),
-                         int(qweight.q_zero_point()), param_name)
+        param = QNNParam(
+            weight_np, bias, qweight.q_scale(), int(qweight.q_zero_point()), param_name
+        )
     else:
         scales = qweight.q_per_channel_scales().numpy()
         zero_points = qweight.q_per_channel_zero_points().numpy()
@@ -75,6 +75,7 @@ def get_weight_quant_params(script_module):
     linear_packed_params = []
 
     import torch
+
     # conv and linear require different unpacking functions
     # extract all conv and linear parameters separately to distinguish them
     for name, m in script_module.named_modules():
@@ -84,8 +85,10 @@ def get_weight_quant_params(script_module):
             elif m.original_name == "LinearPackedParams":
                 linear_packed_params.append((name, m.state_dict()))
 
-    pairs = [(torch.ops.quantized.conv2d_unpack, conv_packed_params),
-             (torch.ops.quantized.linear_unpack, linear_packed_params)]
+    pairs = [
+        (torch.ops.quantized.conv2d_unpack, conv_packed_params),
+        (torch.ops.quantized.linear_unpack, linear_packed_params),
+    ]
 
     quant_params = {}
     param_name = "_packed_params"
@@ -95,23 +98,21 @@ def get_weight_quant_params(script_module):
             assert param_name in state_dict
             key = name + "." + param_name
             packed_param = state_dict[param_name]
-            quant_params[key] = _unpack_quant_params(key, packed_param,
-                                                     unpack_func)
+            quant_params[key] = _unpack_quant_params(key, packed_param, unpack_func)
 
     return quant_params
 
 
-def add_quant_params_to_outputs(outputs, packed_param_map,
-                                quant_params):
+def add_quant_params_to_outputs(outputs, packed_param_map, quant_params):
     """
     Add quant params to outputs so that they can be referenced by other
     ops later. Weights are quantized here.
     """
     for node_name, packed_param_name in packed_param_map.items():
         qparam = quant_params[packed_param_name]
-        qweight = relay.qnn.op.quantize(qparam.weight_var, qparam.scale,
-                                        qparam.zero_point, out_dtype="int8",
-                                        axis=0)
+        qweight = relay.qnn.op.quantize(
+            qparam.weight_var, qparam.scale, qparam.zero_point, out_dtype="int8", axis=0
+        )
         param_tup = (qweight, qparam.scale, qparam.zero_point, qparam.bias_var)
         outputs[node_name] = param_tup
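
add_quant_params_to_outputs above quantizes each packed weight per output channel with qnn.quantize on axis 0. A standalone sketch of that call; the shape, scales, and zero points below are made up:

    import numpy as np
    from tvm import relay

    w = relay.var("w", shape=(8, 3, 3, 3), dtype="float32")
    scale = relay.const(np.full((8,), 0.05, dtype="float32"))
    zero_point = relay.const(np.zeros((8,), dtype="int32"))
    qweight = relay.qnn.op.quantize(w, scale, zero_point, out_dtype="int8", axis=0)
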
 
@@ -140,7 +141,7 @@ def _get_quant_param_for_input(input_value):
         "quantized::mul": (2, 3),
         "quantized::cat": (2, 3),
         "quantized::mul_scalar": (2, 3),
-        "quantized::add_scalar": (2, 3)
+        "quantized::add_scalar": (2, 3),
     }
 
     def dfs(current_node):
@@ -163,8 +164,7 @@ def _get_quant_param_for_input(input_value):
     return dfs(input_value.node())
 
 
-def _get_add_scalar_output_quant_param(input_scale, input_zero_point,
-                                       scalar):
+def _get_add_scalar_output_quant_param(input_scale, input_zero_point, scalar):
     """
     Determine the output scale and zp of quantized::add_scalar op
     This is used for mobilenet v3
@@ -191,8 +191,7 @@ def _get_add_scalar_output_quant_param(input_scale, input_zero_point,
     return s_prime, z_prime
 
 
-def _get_mul_scalar_output_quant_param(input_scale, input_zero_point,
-                                       scalar):
+def _get_mul_scalar_output_quant_param(input_scale, input_zero_point, scalar):
     """
     Determine the output scale and zp of quantized::mul_scalar op
     This is used for mobilenet v3
@@ -218,9 +217,7 @@ def _get_mul_scalar_output_quant_param(input_scale, input_zero_point,
     return s_prime, z_prime
 
 
-def _add_output_quant_params_to_scalar_op(node, graph,
-                                          input_scale, input_zero_point,
-                                          scalar):
+def _add_output_quant_params_to_scalar_op(node, graph, input_scale, input_zero_point, scalar):
     """
     The output scale and zp of {add,mul}_scalar op are not explicit in the IR
     They are required for _get_quant_param_for_input above to work correctly
@@ -240,16 +237,17 @@ def _add_output_quant_params_to_scalar_op(node, graph,
     %7 and %8 are newly created output scale and zp constant nodes
     """
     import torch
+
     operator = node.kind()
 
     if operator == "quantized::mul_scalar":
-        out_scale, out_zero_point = \
-          _get_mul_scalar_output_quant_param(input_scale, input_zero_point,
-                                             scalar)
+        out_scale, out_zero_point = _get_mul_scalar_output_quant_param(
+            input_scale, input_zero_point, scalar
+        )
     elif operator == "quantized::add_scalar":
-        out_scale, out_zero_point = \
-          _get_add_scalar_output_quant_param(input_scale, input_zero_point,
-                                             scalar)
+        out_scale, out_zero_point = _get_add_scalar_output_quant_param(
+            input_scale, input_zero_point, scalar
+        )
     else:
         raise NotImplementedError("unsupported scalar op: %s" % operator)
 
@@ -293,22 +291,24 @@ def add_input_quant_params_to_op_inputs(graph):
     # How many quantized tensors each op takes as inputs?
     # A pair of (scale, zp) for each input quantized tensor will be added
     # to the input nodes
-    num_quantized_inputs = {"quantized::conv2d": 1,
-                            "quantized::conv2d_relu": 1,
-                            "quantized::linear": 1,
-                            "quantized::linear_relu": 1,
-                            "quantized::add_relu": 2,
-                            "quantized::add": 2,
-                            "quantized::mul_relu": 2,
-                            "quantized::mul": 2,
-                            "aten::dequantize": 1,
-                            "aten::mean": 1,
-                            "aten::upsample_bilinear2d": 1,
-                            "aten::relu_": 1,
-                            "aten::relu": 1,
-                            "quantized::add_scalar": 1,
-                            "quantized::mul_scalar": 1,
-                            'quantized::relu6': 1}
+    num_quantized_inputs = {
+        "quantized::conv2d": 1,
+        "quantized::conv2d_relu": 1,
+        "quantized::linear": 1,
+        "quantized::linear_relu": 1,
+        "quantized::add_relu": 2,
+        "quantized::add": 2,
+        "quantized::mul_relu": 2,
+        "quantized::mul": 2,
+        "aten::dequantize": 1,
+        "aten::mean": 1,
+        "aten::upsample_bilinear2d": 1,
+        "aten::relu_": 1,
+        "aten::relu": 1,
+        "quantized::add_scalar": 1,
+        "quantized::mul_scalar": 1,
+        "quantized::relu6": 1,
+    }
 
     need_input_quant_param = set(num_quantized_inputs.keys())
     need_input_quant_param.add("quantized::cat")
@@ -341,9 +341,7 @@ def add_input_quant_params_to_op_inputs(graph):
             inp_zero_point = input_zero_points[0].node().i("value")
 
             # see the comments in this function above
-            _add_output_quant_params_to_scalar_op(node, graph,
-                                                  inp_scale, inp_zero_point,
-                                                  scalar)
+            _add_output_quant_params_to_scalar_op(node, graph, inp_scale, inp_zero_point, scalar)
 
         for scale, zp in zip(input_scales, input_zero_points):
             node.addInput(scale)
@@ -368,16 +366,14 @@ def quantized_mean(data, input_scale, input_zero_point, func_fp32):
     # refer to aten/src/ATen/native/quantized/cpu/qreduction.cpp
     dequantized = relay.qnn.op.dequantize(data, input_scale, input_zero_point)
     out = func_fp32(dequantized)
-    return relay.qnn.op.quantize(out, input_scale, input_zero_point,
-                                 out_dtype="uint8", axis=1)
+    return relay.qnn.op.quantize(out, input_scale, input_zero_point, out_dtype="uint8", axis=1)
 
 
 def quantized_upsample(data, input_scale, input_zero_point, func_fp32):
     # currently piggybacks to fp32; it gets output identical to torch
     data = relay.qnn.op.dequantize(data, input_scale, input_zero_point)
     out = func_fp32(data)
-    return relay.qnn.op.quantize(out, input_scale, input_zero_point,
-                                 out_dtype="uint8", axis=1)
+    return relay.qnn.op.quantize(out, input_scale, input_zero_point, out_dtype="uint8", axis=1)
 
 
 def quantized_relu(data, input_zero_point):
@@ -388,9 +384,10 @@ def quantized_relu(data, input_zero_point):
 
 def _quantize_per_tensor():
     def _impl(inputs, _):
-        return relay.qnn.op.quantize(inputs[0], _expr.const(inputs[1]),
-                                     _expr.const(inputs[2]), out_dtype="uint8",
-                                     axis=1)
+        return relay.qnn.op.quantize(
+            inputs[0], _expr.const(inputs[1]), _expr.const(inputs[2]), out_dtype="uint8", axis=1
+        )
+
     return _impl
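
The _quantize_per_tensor converter above handles aten::quantize_per_tensor nodes; on the PyTorch side those come from calls like the following (scale and zero point are illustrative):

    import torch

    x = torch.randn(1, 3, 4, 4)
    qx = torch.quantize_per_tensor(x, scale=0.05, zero_point=128, dtype=torch.quint8)
    # qx.int_repr() holds the uint8 values that the qnn.quantize call above reproduces.
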
 
 
@@ -400,6 +397,7 @@ def _dequantize():
         inp_scale = _expr.const(inputs[1])
         inp_zero_point = _expr.const(inputs[2])
         return relay.qnn.op.dequantize(inputs[0], inp_scale, inp_zero_point)
+
     return _impl
 
 
@@ -411,13 +409,12 @@ def _get_scalar(relay_const_scalar):
     return np.asscalar(_get_numpy(relay_const_scalar))
 
 
-def _do_bias_and_requantize(output, bias, input_scale, weight_scale,
-                            output_scale, output_zero_point,
-                            with_relu):
+def _do_bias_and_requantize(
+    output, bias, input_scale, weight_scale, output_scale, output_zero_point, with_relu
+):
     """ Output processing for conv and linear """
     # this is a vector for per channel case
-    requant_input_scale = _expr.const(_get_numpy(input_scale) *
-                                      _get_numpy(weight_scale))
+    requant_input_scale = _expr.const(_get_numpy(input_scale) * _get_numpy(weight_scale))
     # Torch does bias add and requantize scale in fp32
     # refer to third_party/fbgemm/include/fbgemm/OutputProcessing-inl.h
     # Instead, we do bias add in int32 and use qnn requantize, which needs
@@ -427,23 +424,27 @@ def _do_bias_and_requantize(output, bias, input_scale, weight_scale,
     # Instead, the torch way requires rounding of activation at runtime
 
     if bias is not None:
-        qbias = relay.qnn.op.quantize(bias, requant_input_scale,
-                                      _expr.const(0, "int32"),
-                                      out_dtype="int32", axis=0)
+        qbias = relay.qnn.op.quantize(
+            bias, requant_input_scale, _expr.const(0, "int32"), out_dtype="int32", axis=0
+        )
         requantize_input = _op.nn.bias_add(output, qbias)
     else:
         requantize_input = output
 
-    requantized = relay.qnn.op.requantize(requantize_input,
-                                          requant_input_scale,
-                                          relay.const(0, 'int32'),
-                                          output_scale, output_zero_point,
-                                          out_dtype="int32", axis=1)
+    requantized = relay.qnn.op.requantize(
+        requantize_input,
+        requant_input_scale,
+        relay.const(0, "int32"),
+        output_scale,
+        output_zero_point,
+        out_dtype="int32",
+        axis=1,
+    )
     clip_min = 0
     if with_relu:
         clip_min = _get_scalar(output_zero_point)
 
-    clip = _op.tensor.clip(requantized, clip_min, 255.)
+    clip = _op.tensor.clip(requantized, clip_min, 255.0)
     return _op.cast(clip, dtype="uint8")
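
The hunk above keeps the bias add in int32 and folds the two input scales into a single requantize. A toy scalar sketch of the arithmetic, with made-up scales and an arbitrary accumulator value:

    import numpy as np

    input_scale, weight_scale = 0.02, 0.005
    output_scale, output_zero_point = 0.1, 128

    # Bias is quantized with scale = input_scale * weight_scale and zero point 0,
    # so adding it to the int32 conv/dense accumulator is exact.
    requant_input_scale = input_scale * weight_scale             # 1e-4
    acc_int32 = 12345                                            # hypothetical accumulator
    real_value = acc_int32 * requant_input_scale                 # dequantized view
    requantized = int(np.round(real_value / output_scale)) + output_zero_point
    clipped = min(max(requantized, 0), 255)                      # clip_min becomes the zero point when ReLU is fused
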
 
 
@@ -482,28 +483,41 @@ def _quantized_conv2d(with_relu=False):
 
         if padding[0] != 0 or padding[1] != 0:
             pad_val = _get_scalar(input_zero_point)
-            inp = _op.nn.pad(inputs[0], pad_width=((0, 0),
-                                                   (0, 0),
-                                                   (padding[0], padding[0]),
-                                                   (padding[1], padding[1])),
-                             pad_value=float(pad_val))
+            inp = _op.nn.pad(
+                inputs[0],
+                pad_width=((0, 0), (0, 0), (padding[0], padding[0]), (padding[1], padding[1])),
+                pad_value=float(pad_val),
+            )
         else:
             inp = inputs[0]
 
         # padding is (0, 0) because we did explicit pad op with
         # pad value being zero point above
-        conv_out = relay.qnn.op.conv2d(inp, weight,
-                                       input_zero_point, weight_zero_point,
-                                       input_scale, weight_scale,
-                                       kernel_size=kernel_size,
-                                       dilation=dilation, strides=strides,
-                                       padding=(0, 0), groups=groups,
-                                       channels=out_channels)
+        conv_out = relay.qnn.op.conv2d(
+            inp,
+            weight,
+            input_zero_point,
+            weight_zero_point,
+            input_scale,
+            weight_scale,
+            kernel_size=kernel_size,
+            dilation=dilation,
+            strides=strides,
+            padding=(0, 0),
+            groups=groups,
+            channels=out_channels,
+        )
         bias_var = inputs[1][3]
 
-        return _do_bias_and_requantize(conv_out, bias_var, input_scale,
-                                       weight_scale, output_scale,
-                                       output_zero_point, with_relu)
+        return _do_bias_and_requantize(
+            conv_out,
+            bias_var,
+            input_scale,
+            weight_scale,
+            output_scale,
+            output_zero_point,
+            with_relu,
+        )
 
     return _impl
 
@@ -522,26 +536,45 @@ def _linear(with_relu=False):
         input_zero_point = _expr.const(inputs[5])
 
         weight_shape = infer_shape(weight)
-        dense = relay.qnn.op.dense(inputs[0], weight,
-                                   input_zero_point, weight_zero_point,
-                                   input_scale, weight_scale,
-                                   units=weight_shape[0])
+        dense = relay.qnn.op.dense(
+            inputs[0],
+            weight,
+            input_zero_point,
+            weight_zero_point,
+            input_scale,
+            weight_scale,
+            units=weight_shape[0],
+        )
         bias_var = inputs[1][3]
 
-        return _do_bias_and_requantize(dense, bias_var, input_scale,
-                                       weight_scale, output_scale,
-                                       output_zero_point, with_relu)
+        return _do_bias_and_requantize(
+            dense, bias_var, input_scale, weight_scale, output_scale, output_zero_point, with_relu
+        )
 
     return _impl
 
 
 def _binop(relay_op, with_relu=False, fp32_piggy_back=False):
-    def qnn_impl(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
-                 input_scale_rhs, input_zero_point_rhs,
-                 output_scale, output_zero_point):
-        qnn_out = relay_op(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
-                           input_scale_rhs, input_zero_point_rhs,
-                           output_scale, output_zero_point)
+    def qnn_impl(
+        lhs,
+        rhs,
+        input_scale_lhs,
+        input_zero_point_lhs,
+        input_scale_rhs,
+        input_zero_point_rhs,
+        output_scale,
+        output_zero_point,
+    ):
+        qnn_out = relay_op(
+            lhs,
+            rhs,
+            input_scale_lhs,
+            input_zero_point_lhs,
+            input_scale_rhs,
+            input_zero_point_rhs,
+            output_scale,
+            output_zero_point,
+        )
         if with_relu:
             clip_min = _get_scalar(output_zero_point)
             return _op.tensor.clip(qnn_out, clip_min, 255)
@@ -549,32 +582,33 @@ def _binop(relay_op, with_relu=False, fp32_piggy_back=False):
 
     # refer to aten/src/ATen/native/quantized/cpu/{qadd, qmul}.cpp
     # they piggyback to fp32 math by dequantize -> fp32 math -> quantize
-    def torch_impl(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
-                   input_scale_rhs, input_zero_point_rhs,
-                   output_scale, output_zero_point):
-        if isinstance(lhs, _expr.Call) and lhs.op.name == 'qnn.quantize':
+    def torch_impl(
+        lhs,
+        rhs,
+        input_scale_lhs,
+        input_zero_point_lhs,
+        input_scale_rhs,
+        input_zero_point_rhs,
+        output_scale,
+        output_zero_point,
+    ):
+        if isinstance(lhs, _expr.Call) and lhs.op.name == "qnn.quantize":
             lhs = lhs.args[0]
         else:
-            lhs = relay.qnn.op.dequantize(lhs,
-                                          input_scale_lhs,
-                                          input_zero_point_lhs)
+            lhs = relay.qnn.op.dequantize(lhs, input_scale_lhs, input_zero_point_lhs)
 
-        if isinstance(rhs, _expr.Call) and rhs.op.name == 'qnn.quantize':
+        if isinstance(rhs, _expr.Call) and rhs.op.name == "qnn.quantize":
             rhs = rhs.args[0]
         else:
-            rhs = relay.qnn.op.dequantize(rhs,
-                                          input_scale_rhs,
-                                          input_zero_point_rhs)
+            rhs = relay.qnn.op.dequantize(rhs, input_scale_rhs, input_zero_point_rhs)
         fp32_out = relay_op(lhs, rhs)
 
         if with_relu:
             fp32_out = _op.nn.relu(fp32_out)
 
-        return relay.qnn.op.quantize(fp32_out,
-                                     output_scale,
-                                     output_zero_point,
-                                     axis=-1,
-                                     out_dtype="uint8")
+        return relay.qnn.op.quantize(
+            fp32_out, output_scale, output_zero_point, axis=-1, out_dtype="uint8"
+        )
 
     def _impl(inputs, _):
         lhs = inputs[0]
@@ -590,13 +624,27 @@ def _binop(relay_op, with_relu=False, fp32_piggy_back=False):
 
         if fp32_piggy_back:
             logging.info("Piggy backing to FP32 op (PyTorch way)")
-            return torch_impl(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
-                              input_scale_rhs, input_zero_point_rhs,
-                              output_scale, output_zero_point)
-
-        return qnn_impl(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
-                        input_scale_rhs, input_zero_point_rhs,
-                        output_scale, output_zero_point)
+            return torch_impl(
+                lhs,
+                rhs,
+                input_scale_lhs,
+                input_zero_point_lhs,
+                input_scale_rhs,
+                input_zero_point_rhs,
+                output_scale,
+                output_zero_point,
+            )
+
+        return qnn_impl(
+            lhs,
+            rhs,
+            input_scale_lhs,
+            input_zero_point_lhs,
+            input_scale_rhs,
+            input_zero_point_rhs,
+            output_scale,
+            output_zero_point,
+        )
 
     return _impl
 
@@ -605,16 +653,15 @@ def _cat(fp32_piggy_back=False):
     # refer to aten/src/ATen/native/quantized/cpu/qconcat.cpp
     # for concat they also piggy-back on fp32(!):
     # dequantize -> fp32 math -> quantize
-    def torch_impl(inputs, input_scales, input_zero_points,
-                   output_scale, output_zero_point, axis):
+    def torch_impl(inputs, input_scales, input_zero_points, output_scale, output_zero_point, axis):
         dequantized = []
-        for inp, inp_scale, inp_zp in zip(inputs, input_scales,
-                                          input_zero_points):
+        for inp, inp_scale, inp_zp in zip(inputs, input_scales, input_zero_points):
             dequantized.append(relay.qnn.op.dequantize(inp, inp_scale, inp_zp))
 
         concat = _op.tensor.concatenate(dequantized, axis=axis)
-        return relay.qnn.op.quantize(concat, output_scale, output_zero_point,
-                                     axis=axis, out_dtype="uint8")
+        return relay.qnn.op.quantize(
+            concat, output_scale, output_zero_point, axis=axis, out_dtype="uint8"
+        )
 
     def _impl(inputs, _):
         axis = inputs[1]
@@ -626,17 +673,17 @@ def _cat(fp32_piggy_back=False):
         input_zero_points = []
 
         for i in range(0, num_inputs):
-            input_scales.append(_expr.const(inputs[4+i*2]))
-            input_zero_points.append(_expr.const(inputs[4+i*2+1]))
+            input_scales.append(_expr.const(inputs[4 + i * 2]))
+            input_zero_points.append(_expr.const(inputs[4 + i * 2 + 1]))
 
         if fp32_piggy_back:
-            return torch_impl(inputs[0], input_scales, input_zero_points,
-                              output_scale, output_zero_point, axis)
+            return torch_impl(
+                inputs[0], input_scales, input_zero_points, output_scale, output_zero_point, axis
+            )
 
-        return relay.qnn.op.concatenate(inputs[0],
-                                        input_scales, input_zero_points,
-                                        output_scale, output_zero_point,
-                                        axis)
+        return relay.qnn.op.concatenate(
+            inputs[0], input_scales, input_zero_points, output_scale, output_zero_point, axis
+        )
 
     return _impl
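
As the comments in _binop and _cat above note, PyTorch's quantized add/mul/cat kernels piggy-back on fp32 math: dequantize, run the fp32 op, then requantize. A minimal NumPy sketch of that pattern with illustrative scales and zero points (not the TVM or PyTorch API):

import numpy as np

def dequantize(q, scale, zero_point):
    # affine dequantization: real = (q - zero_point) * scale
    return (q.astype(np.float32) - zero_point) * scale

def quantize(x, scale, zero_point, qmin=0, qmax=255):
    # affine quantization back to uint8
    q = np.round(x / scale) + zero_point
    return np.clip(q, qmin, qmax).astype(np.uint8)

def qadd_piggyback(lhs, rhs, s_lhs, zp_lhs, s_rhs, zp_rhs, s_out, zp_out, with_relu=False):
    # dequantize both operands, add in fp32, optionally relu, then requantize
    out = dequantize(lhs, s_lhs, zp_lhs) + dequantize(rhs, s_rhs, zp_rhs)
    if with_relu:
        out = np.maximum(out, 0.0)
    return quantize(out, s_out, zp_out)

# illustrative values only
lhs = np.array([0, 128, 255], dtype=np.uint8)
rhs = np.array([10, 20, 30], dtype=np.uint8)
print(qadd_piggyback(lhs, rhs, 0.1, 0, 0.05, 10, 0.2, 5))
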
 
@@ -659,11 +706,11 @@ def _add_scalar():
         out_zp = _expr.const(inputs[3])
 
         if q_min > z - c_q or q_max < z - c_q:
-            dequant = relay.qnn.op.dequantize(inputs[0],
-                                              _expr.const(s), _expr.const(z))
+            dequant = relay.qnn.op.dequantize(inputs[0], _expr.const(s), _expr.const(z))
             dequantized_add = _op.tensor.add(dequant, _expr.const(c_q * s))
-            return relay.qnn.op.quantize(dequantized_add, out_scale, out_zp,
-                                         axis=1, out_dtype="uint8")
+            return relay.qnn.op.quantize(
+                dequantized_add, out_scale, out_zp, axis=1, out_dtype="uint8"
+            )
         # only scale change
         return inputs[0]
 
@@ -682,8 +729,9 @@ def _relu6():
         assert len(inputs) == 4, "Input quant params not found in op inputs"
         input_scale = inputs[2]
         input_zero_point = inputs[3]
-        six = quantize_scalar(6., input_scale, input_zero_point)
+        six = quantize_scalar(6.0, input_scale, input_zero_point)
         return _op.tensor.clip(inputs[0], input_zero_point, six)
+
     return _impl
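
_relu6 above clips the quantized input between its zero point and the quantized value of 6.0. A small sketch of the affine quantization that the quantize_scalar helper (defined elsewhere in this file, not shown here) is assumed to perform:

def quantize_scalar_sketch(value, scale, zero_point, qmin=0, qmax=255):
    # affine quantization of a Python float into the uint8 domain:
    # q = clip(round(value / scale) + zero_point, qmin, qmax)
    q = int(round(value / scale)) + zero_point
    return max(qmin, min(qmax, q))

# e.g. scale = 0.05, zero_point = 10: 6.0 maps to 130,
# so relu6 becomes clip(x, 10, 130) in the quantized domain
print(quantize_scalar_sketch(6.0, 0.05, 10))  # 130
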
 
 
@@ -714,18 +762,18 @@ def _mul_scalar():
 
 
 convert_map = {
-    'aten::quantize_per_tensor': _quantize_per_tensor(),
-    'quantized::conv2d_relu': _quantized_conv2d(with_relu=True),
-    'aten::dequantize': _dequantize(),
-    'quantized::conv2d': _quantized_conv2d(),
-    'quantized::add_relu': _binop(relay.qnn.op.add, with_relu=True),
-    'quantized::add': _binop(relay.qnn.op.add),
-    'quantized::mul_relu': _binop(relay.qnn.op.mul, with_relu=True),
-    'quantized::mul': _binop(relay.qnn.op.mul),
-    'quantized::linear': _linear(),
-    'quantized::linear_relu': _linear(with_relu=True),
-    'quantized::cat': _cat(),
-    'quantized::add_scalar': _add_scalar(),
-    'quantized::mul_scalar': _mul_scalar(),
-    'quantized::relu6': _relu6()
+    "aten::quantize_per_tensor": _quantize_per_tensor(),
+    "quantized::conv2d_relu": _quantized_conv2d(with_relu=True),
+    "aten::dequantize": _dequantize(),
+    "quantized::conv2d": _quantized_conv2d(),
+    "quantized::add_relu": _binop(relay.qnn.op.add, with_relu=True),
+    "quantized::add": _binop(relay.qnn.op.add),
+    "quantized::mul_relu": _binop(relay.qnn.op.mul, with_relu=True),
+    "quantized::mul": _binop(relay.qnn.op.mul),
+    "quantized::linear": _linear(),
+    "quantized::linear_relu": _linear(with_relu=True),
+    "quantized::cat": _cat(),
+    "quantized::add_scalar": _add_scalar(),
+    "quantized::mul_scalar": _mul_scalar(),
+    "quantized::relu6": _relu6(),
 }
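
convert_map maps each TorchScript operator name to a converter closure produced by one of the factory functions above (_binop, _linear, _cat, ...). A stripped-down sketch of that dispatch pattern, using a toy IR of nested tuples in place of Relay expressions:

# toy IR: nested tuples stand in for Relay calls (illustration only)
def _binop(op_name, with_relu=False):
    def _impl(inputs, _):
        out = (op_name,) + tuple(inputs)
        return ("relu", out) if with_relu else out
    return _impl

convert_map = {
    "quantized::add": _binop("qnn.add"),
    "quantized::add_relu": _binop("qnn.add", with_relu=True),
}

def convert(op_name, inputs):
    return convert_map[op_name](inputs, None)

print(convert("quantized::add_relu", ["a", "b"]))  # ('relu', ('qnn.add', 'a', 'b'))
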
index 02c8204..1cd14c3 100644 (file)
@@ -1,4 +1,3 @@
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -41,7 +40,7 @@ from .common import infer_shape as _infer_shape
 from .common import infer_channels as _infer_channels
 from .common import infer_value as _infer_value
 
-__all__ = ['from_tensorflow']
+__all__ = ["from_tensorflow"]
 
 
 def _get_pad_pair(input1d, kernel1d, stride1d):
@@ -55,46 +54,59 @@ def _get_pad_pair(input1d, kernel1d, stride1d):
 
     return [pad_before, pad_after]
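
The pooling and convolution converters below turn TensorFlow's "SAME" padding into an explicit before/after pair from the input size, the (dilated) kernel size, and the stride. A worked sketch of that arithmetic, assuming standard TF SAME semantics (the body of _get_pad_pair is not shown here):

import math

def same_pad_pair(in_size, kernel, stride, dilation=1):
    # effective kernel size under dilation, as computed in the conv converters
    dilated_kernel = (kernel - 1) * dilation + 1
    out_size = math.ceil(in_size / stride)
    pad_total = max((out_size - 1) * stride + dilated_kernel - in_size, 0)
    pad_before = pad_total // 2
    pad_after = pad_total - pad_before
    return [pad_before, pad_after]

# e.g. a 224-wide input, 3x3 kernel, stride 2 -> [0, 1] on that axis
print(same_pad_pair(224, 3, 2))  # [0, 1]
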
 
+
 def _math_name_picker(surfix):
     def _impl(attr):
-        return 'broadcast_' + surfix
+        return "broadcast_" + surfix
+
     return _impl
 
-def _dimension_picker(prefix, surfix=''):
+
+def _dimension_picker(prefix, surfix=""):
     def _impl(attr):
-        kernel = attr['kernel_shape']
+        kernel = attr["kernel_shape"]
         if len(kernel) == 2:
-            return prefix + '2d' + surfix
+            return prefix + "2d" + surfix
         if len(kernel) == 3:
-            return prefix + '3d' + surfix
+            return prefix + "3d" + surfix
         raise tvm.error.OpAttributeInvalid(
-            'Only 2D or 3D kernels are supported for operator {}'.format(prefix + '2d or 3d'))
+            "Only 2D or 3D kernels are supported for operator {}".format(prefix + "2d or 3d")
+        )
+
     return _impl
 
+
 def _dimension_constraint():
     def _dim_check(attrs):
-        if len(attrs['kernel_shape']) in (2, 3):
+        if len(attrs["kernel_shape"]) in (2, 3):
             return True
         return False
+
     return _dim_check, "Only 2d or 3d kernel supported."
 
+
 def _get_param(params, input_node):
     if isinstance(input_node, _expr.Constant):
         return np.atleast_1d(input_node.data.asnumpy())
     return params[input_node.name_hint].asnumpy()
 
+
 def _get_num_param(params, input_node):
     return _get_param(params, input_node).item()
 
+
 def _get_list_param(params, input_node):
     return _get_param(params, input_node).tolist()
 
+
 def _get_tuple_param(params, input_node):
     return tuple(_get_param(params, input_node))
 
+
 def _need_prelude_for_shape_inference(op):
     return "TensorArray" in op
 
+
 def _get_more_static_shape(shape0, shape1):
     """Compare two shapes with the same rank,
     and return the one with fewer symbolic dimension.
@@ -112,14 +124,18 @@ def _get_more_static_shape(shape0, shape1):
         return shape0
     return shape1
 
+
 def _rsqrt():
     def _impl(inputs, attr, params, mod):
-        inputs.append(tvm.relay.const(-0.5, attr['T'].name))
+        inputs.append(tvm.relay.const(-0.5, attr["T"].name))
         return AttrCvt(op_name="power")(inputs, attr)
+
     return _impl
 
+
 def _argx(func, func_name):
     """ A common wrapper for argmin and argmax operations """
+
     def _impl(inputs, attr, params, mod):
         try:
             # In Tensorflow, `axis` argument is a Tensor, not attribute. We
@@ -127,47 +143,51 @@ def _argx(func, func_name):
             axis_input_value = [_get_num_param(params, inputs[1])]
         except (IndexError, KeyError):
             raise TypeError(
-                "Unsupported argument for `{}` : `axis` should be a constant".format(func_name))
+                "Unsupported argument for `{}` : `axis` should be a constant".format(func_name)
+            )
         return func(inputs[0], axis=axis_input_value, keepdims=False)
+
     return _impl
 
+
 def _elemwise(name):
     def _impl(inputs, attr, params, mod):
         assert len(inputs) == 2, "{} takes 2 inputs, {} given".format(name, len(inputs))
         return get_relay_op(name)(*inputs)
+
     return _impl
 
+
 def _pool3d(name):
     def _impl(inputs, attr, params, mod):
-        attr['data_format'] = attr['data_format'].decode("utf-8")
+        attr["data_format"] = attr["data_format"].decode("utf-8")
         flip_layout = False
 
         input_shape = _infer_shape(inputs[0], mod)
 
-        if attr['data_format'] == 'NDHWC':
-            attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2], attr['ksize'][3])
-            attr['strides'] = (attr['strides'][1], attr['strides'][2], attr['strides'][3])
-        elif attr['data_format'] == 'NCDHW':
-            attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3], attr['ksize'][4])
-            attr['strides'] = (attr['strides'][2], attr['strides'][3], attr['strides'][4])
+        if attr["data_format"] == "NDHWC":
+            attr["kernel_shape"] = (attr["ksize"][1], attr["ksize"][2], attr["ksize"][3])
+            attr["strides"] = (attr["strides"][1], attr["strides"][2], attr["strides"][3])
+        elif attr["data_format"] == "NCDHW":
+            attr["kernel_shape"] = (attr["ksize"][2], attr["ksize"][3], attr["ksize"][4])
+            attr["strides"] = (attr["strides"][2], attr["strides"][3], attr["strides"][4])
         else:
-            msg = 'Value {} of attribute "data_format" of operator Pooling ' \
-                  'is not valid.'
-            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
-        if attr['data_format'] == "NDHWC":
+            msg = 'Value {} of attribute "data_format" of operator Pooling ' "is not valid."
+            raise tvm.error.OpAttributeInvalid(msg.format(attr["data_format"]))
+        if attr["data_format"] == "NDHWC":
             input_shape = [_infer_shape(inputs[0], mod)[i] for i in (0, 4, 1, 2, 3)]
             inputs[0] = _op.transpose(inputs[0], axes=(0, 4, 1, 2, 3))
-            attr['data_format'] = "NCDHW"
+            attr["data_format"] = "NCDHW"
             flip_layout = True
 
-        attr['padding'] = attr['padding'].decode("utf-8")
+        attr["padding"] = attr["padding"].decode("utf-8")
 
-        if attr['padding'] == 'VALID':
-            attr['padding'] = [0, 0, 0, 0, 0, 0]
-        elif attr['padding'] == 'SAME':
-            stride_d, stride_h, stride_w = attr['strides']
-            kernel_d, kernel_h, kernel_w = attr['kernel_shape']
-            if attr['data_format'] == 'NDHWC':
+        if attr["padding"] == "VALID":
+            attr["padding"] = [0, 0, 0, 0, 0, 0]
+        elif attr["padding"] == "SAME":
+            stride_d, stride_h, stride_w = attr["strides"]
+            kernel_d, kernel_h, kernel_w = attr["kernel_shape"]
+            if attr["data_format"] == "NDHWC":
                 in_d = input_shape[1]
                 in_h = input_shape[2]
                 in_w = input_shape[3]
@@ -179,62 +199,60 @@ def _pool3d(name):
             pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
             pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
 
-            attr['padding'] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]]
+            attr["padding"] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]]
         else:
-            msg = 'Value {} in attribute "padding" of operator Pooling is ' \
-                  'not valid.'
-            raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))
+            msg = 'Value {} in attribute "padding" of operator Pooling is ' "not valid."
+            raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"]))
 
         if name == "avg_pool":
-            attr['count_include_pad'] = False
-        attr['ceil_mode'] = False
+            attr["count_include_pad"] = False
+        attr["ceil_mode"] = False
         out = AttrCvt(
             op_name=name,
-            transforms={
-                'kernel_shape': 'pool_size',
-                'data_format': 'layout'},
-            ignores=['ksize'])(inputs, attr)
+            transforms={"kernel_shape": "pool_size", "data_format": "layout"},
+            ignores=["ksize"],
+        )(inputs, attr)
         if flip_layout:
             out = _op.transpose(out, axes=(0, 2, 3, 4, 1))
         return out
 
     return _impl
 
+
 def _pooling(name):
     def _impl(inputs, attr, params, mod):
 
-        attr['data_format'] = attr['data_format'].decode("utf-8")
+        attr["data_format"] = attr["data_format"].decode("utf-8")
         flip_layout = False
 
         input_shape = _infer_shape(inputs[0], mod)
 
-        if attr['data_format'] == 'NHWC':
-            attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
-            attr['strides'] = (attr['strides'][1], attr['strides'][2])
-        elif attr['data_format'] == 'NCHW':
-            attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3])
-            attr['strides'] = (attr['strides'][2], attr['strides'][3])
+        if attr["data_format"] == "NHWC":
+            attr["kernel_shape"] = (attr["ksize"][1], attr["ksize"][2])
+            attr["strides"] = (attr["strides"][1], attr["strides"][2])
+        elif attr["data_format"] == "NCHW":
+            attr["kernel_shape"] = (attr["ksize"][2], attr["ksize"][3])
+            attr["strides"] = (attr["strides"][2], attr["strides"][3])
         else:
-            msg = 'Value {} of attribute "data_format" of operator Pooling ' \
-                  'is not valid.'
-            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
+            msg = 'Value {} of attribute "data_format" of operator Pooling ' "is not valid."
+            raise tvm.error.OpAttributeInvalid(msg.format(attr["data_format"]))
 
-        if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
+        if attr["_target_layout"] == "NCHW" and attr["data_format"] == "NHWC":
             tmp_shape = _infer_shape(inputs[0], mod)
             input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
             inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2))
-            attr['data_format'] = "NCHW"
+            attr["data_format"] = "NCHW"
             flip_layout = True
 
         # Fix padding
-        attr['padding'] = attr['padding'].decode("utf-8")
-
-        if attr['padding'] == 'VALID':
-            attr['padding'] = [0, 0]
-        elif attr['padding'] == 'SAME':
-            stride_h, stride_w = attr['strides']
-            kernel_h, kernel_w = attr['kernel_shape']
-            if attr['data_format'] == 'NHWC':
+        attr["padding"] = attr["padding"].decode("utf-8")
+
+        if attr["padding"] == "VALID":
+            attr["padding"] = [0, 0]
+        elif attr["padding"] == "SAME":
+            stride_h, stride_w = attr["strides"]
+            kernel_h, kernel_w = attr["kernel_shape"]
+            if attr["data_format"] == "NHWC":
                 in_h = input_shape[1]
                 in_w = input_shape[2]
             else:
@@ -244,57 +262,60 @@ def _pooling(name):
             pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
             pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
 
-            attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
+            attr["padding"] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
         else:
-            msg = 'Value {} in attribute "padding" of operator Pooling is ' \
-                  'not valid.'
-            raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))
+            msg = 'Value {} in attribute "padding" of operator Pooling is ' "not valid."
+            raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"]))
 
         if name == "avg_pool":
-            attr['count_include_pad'] = False
+            attr["count_include_pad"] = False
 
         out = AttrCvt(
             op_name=_dimension_picker(name),
-            transforms={
-                'kernel_shape':'pool_size',
-                'data_format':'layout'},
-            ignores=['ksize'],
-            extras={'ceil_mode': False},
-            custom_check=_dimension_constraint())(inputs, attr)
+            transforms={"kernel_shape": "pool_size", "data_format": "layout"},
+            ignores=["ksize"],
+            extras={"ceil_mode": False},
+            custom_check=_dimension_constraint(),
+        )(inputs, attr)
 
         if flip_layout:
             out = _op.transpose(out, axes=(0, 2, 3, 1))
 
         return out
+
     return _impl
 
+
 def _conv(opname):
     def _impl(inputs, attr, params, mod):
-        attr['data_format'] = attr['data_format'].decode("utf-8")
+        attr["data_format"] = attr["data_format"].decode("utf-8")
         flip_layout = False
 
-        if opname == 'conv_transpose' and attr['data_format'] == 'NHWC':
+        if opname == "conv_transpose" and attr["data_format"] == "NHWC":
             # transform to NCHW for TVM backend compatibility and set 'flip_layout'
             # so the output is flipped back to NHWC
             inputs[2] = _op.transpose(inputs[2], axes=(0, 3, 1, 2))
-            attr['strides'][1], attr['strides'][2], attr['strides'][3] = \
-                attr['strides'][3], attr['strides'][1], attr['strides'][2]
-            attr['data_format'] = 'NCHW'
-
-            if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0:
-                tmp_shape = attr['_output_shapes'][0]
+            attr["strides"][1], attr["strides"][2], attr["strides"][3] = (
+                attr["strides"][3],
+                attr["strides"][1],
+                attr["strides"][2],
+            )
+            attr["data_format"] = "NCHW"
+
+            if opname == "conv_transpose" and len(attr["_output_shapes"]) > 0:
+                tmp_shape = attr["_output_shapes"][0]
                 tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
-                attr['_output_shapes'][0] = tmp_shape
+                attr["_output_shapes"][0] = tmp_shape
 
             flip_layout = True
 
-        inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2]
+        inputs_data = inputs[0] if opname != "conv_transpose" else inputs[2]
 
         # NCHW Layout require weights transpose
         weights_shape = _infer_shape(inputs[1], mod)
-        if attr['data_format'] == 'NCHW':
+        if attr["data_format"] == "NCHW":
             tmp_shape = weights_shape
-            if opname in ['conv', 'conv_transpose']:
+            if opname in ["conv", "conv_transpose"]:
                 tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)]
                 inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1))
             else:
@@ -302,196 +323,194 @@ def _conv(opname):
                 inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1))
             weights_shape = tmp_shape
 
-
         input_shape = _infer_shape(inputs_data, mod)
-        if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
+        if attr["_target_layout"] == "NCHW" and attr["data_format"] == "NHWC":
             input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
             inputs_data = _op.transpose(inputs_data, axes=(0, 3, 1, 2))
-            if opname in ['conv', 'conv_transpose']:
+            if opname in ["conv", "conv_transpose"]:
                 weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)]
                 inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1))
             else:
                 weights_shape = [weights_shape[ii] for ii in (2, 3, 0, 1)]
                 inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1))
 
-            attr['data_format'] = "NCHW"
-            attr['strides'] = [attr['strides'][ii] for ii in (0, 3, 1, 2)]
+            attr["data_format"] = "NCHW"
+            attr["strides"] = [attr["strides"][ii] for ii in (0, 3, 1, 2)]
             flip_layout = True
 
-        if attr['data_format'] == 'NHWC':
+        if attr["data_format"] == "NHWC":
             in_channels = input_shape[3]
             kernel_h, kernel_w, _, depth_mult = weights_shape
-            attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
-            if opname == 'conv':
-                attr['channels'] = weights_shape[3]
-            elif opname == 'conv_transpose':
-                attr['channels'] = weights_shape[2]
+            attr["kernel_shape"] = (weights_shape[0], weights_shape[1])
+            if opname == "conv":
+                attr["channels"] = weights_shape[3]
+            elif opname == "conv_transpose":
+                attr["channels"] = weights_shape[2]
             else:
-                attr['channels'] = input_shape[3] * depth_mult
+                attr["channels"] = input_shape[3] * depth_mult
 
-            if 'dilations' in attr:
-                attr['dilations'] = (attr['dilations'][1], attr['dilations'][2])
-            attr['strides'] = (attr['strides'][1], attr['strides'][2])
-        elif attr['data_format'] == 'NCHW':
+            if "dilations" in attr:
+                attr["dilations"] = (attr["dilations"][1], attr["dilations"][2])
+            attr["strides"] = (attr["strides"][1], attr["strides"][2])
+        elif attr["data_format"] == "NCHW":
             in_channels = input_shape[1]
             _, depth_mult, kernel_h, kernel_w = weights_shape
-            attr['kernel_shape'] = (weights_shape[2], weights_shape[3])
-            if opname == 'conv':
-                attr['channels'] = weights_shape[0]
-            elif opname == 'conv_transpose':
-                attr['channels'] = weights_shape[1]
+            attr["kernel_shape"] = (weights_shape[2], weights_shape[3])
+            if opname == "conv":
+                attr["channels"] = weights_shape[0]
+            elif opname == "conv_transpose":
+                attr["channels"] = weights_shape[1]
             else:
-                attr['channels'] = input_shape[1] * depth_mult
-                if attr['channels'] < 0:
-                    attr['channels'] *= -1
+                attr["channels"] = input_shape[1] * depth_mult
+                if attr["channels"] < 0:
+                    attr["channels"] *= -1
 
-            if 'dilations' in attr:
-                attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
-            attr['strides'] = (attr['strides'][2], attr['strides'][3])
+            if "dilations" in attr:
+                attr["dilations"] = (attr["dilations"][2], attr["dilations"][3])
+            attr["strides"] = (attr["strides"][2], attr["strides"][3])
         else:
-            msg = 'Value {} in attribute "data_format" of operator Conv is ' \
-                  'not valid.'
-            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
+            msg = 'Value {} in attribute "data_format" of operator Conv is ' "not valid."
+            raise tvm.error.OpAttributeInvalid(msg.format(attr["data_format"]))
 
-        if opname == 'depthwise':
-            attr['groups'] = in_channels
+        if opname == "depthwise":
+            attr["groups"] = in_channels
 
         # Fix padding
-        attr['padding'] = attr['padding'].decode("utf-8")
+        attr["padding"] = attr["padding"].decode("utf-8")
 
-        if attr['padding'] == 'VALID':
-            attr['padding'] = [0, 0]
-        elif attr['padding'] == 'SAME':
-            stride_h, stride_w = attr['strides']
-            kernel_h, kernel_w = attr['kernel_shape']
+        if attr["padding"] == "VALID":
+            attr["padding"] = [0, 0]
+        elif attr["padding"] == "SAME":
+            stride_h, stride_w = attr["strides"]
+            kernel_h, kernel_w = attr["kernel_shape"]
 
             pdata_shape = input_shape
-            if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0:
-                pdata_shape = attr['_output_shapes'][0]
+            if opname == "conv_transpose" and len(attr["_output_shapes"]) > 0:
+                pdata_shape = attr["_output_shapes"][0]
 
-            if attr['data_format'] == 'NHWC':
+            if attr["data_format"] == "NHWC":
                 in_h = pdata_shape[1]
                 in_w = pdata_shape[2]
             else:
                 in_h = pdata_shape[2]
                 in_w = pdata_shape[3]
 
-            dilation_h = attr['dilations'][0]
-            dilation_w = attr['dilations'][1]
+            dilation_h = attr["dilations"][0]
+            dilation_w = attr["dilations"][1]
             dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
             dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
             pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
             pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
 
-            attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
+            attr["padding"] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
         else:
-            msg = 'Value {} in attribute "padding" of operator Conv is not ' \
-                  'valid.'
-            raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))
+            msg = 'Value {} in attribute "padding" of operator Conv is not ' "valid."
+            raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"]))
 
-        if 'kernel_layout' not in attr:
-            if opname in ['conv', 'conv_transpose']:
-                attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'
+        if "kernel_layout" not in attr:
+            if opname in ["conv", "conv_transpose"]:
+                attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "OIHW"
             else:
-                attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
+                attr["kernel_layout"] = "HWOI" if attr["data_format"] == "NHWC" else "OIHW"
 
         # Ignore the new attributes from TF2.0, for now.
         out = AttrCvt(
-            op_name=_dimension_picker('conv',
-                                      surfix="_transpose" if opname == 'conv_transpose' else ""),
-            ignores=['explicit_paddings'],
+            op_name=_dimension_picker(
+                "conv", surfix="_transpose" if opname == "conv_transpose" else ""
+            ),
+            ignores=["explicit_paddings"],
             transforms={
-                'kernel_shape': 'kernel_size',
-                'data_format': 'data_layout',
-                'dilations': ('dilation', (0, 0)),
-                'group': ('groups', 1)},
-            custom_check=_dimension_constraint())([inputs_data, inputs[1]], attr)
+                "kernel_shape": "kernel_size",
+                "data_format": "data_layout",
+                "dilations": ("dilation", (0, 0)),
+                "group": ("groups", 1),
+            },
+            custom_check=_dimension_constraint(),
+        )([inputs_data, inputs[1]], attr)
 
         if flip_layout:
             out = _op.transpose(out, axes=(0, 2, 3, 1))
 
         return out
+
     return _impl
 
 
 # Dilation2d
 def _dilation2d():
     def _impl(inputs, attr, params, mod):
-        if 'data_format' not in attr:
-            attr['data_format'] = 'NHWC'
+        if "data_format" not in attr:
+            attr["data_format"] = "NHWC"
 
         input_shape = _infer_shape(inputs[0], mod)
         weights_shape = _infer_shape(inputs[1], mod)
 
-        if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
+        if attr["_target_layout"] == "NCHW" and attr["data_format"] == "NHWC":
             input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
             inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2))
             weights_shape = [weights_shape[ii] for ii in (2, 0, 1)]
             inputs[1] = _op.transpose(inputs[1], axes=(2, 0, 1))
-            attr['data_format'] = "NCHW"
-
-        if attr['data_format'] in ['NHWC', 'NCHW']:
-            if 'rates' in attr:
-                attr['dilations'] = attr['rates']
-            if 'dilations' in attr:
-                attr['dilations'] = (attr['dilations'][1], attr['dilations'][2])
-            attr['strides'] = (attr['strides'][1], attr['strides'][2])
+            attr["data_format"] = "NCHW"
+
+        if attr["data_format"] in ["NHWC", "NCHW"]:
+            if "rates" in attr:
+                attr["dilations"] = attr["rates"]
+            if "dilations" in attr:
+                attr["dilations"] = (attr["dilations"][1], attr["dilations"][2])
+            attr["strides"] = (attr["strides"][1], attr["strides"][2])
         else:
-            msg = 'Value {} in attribute "data_format" of operator Dilation2D is ' \
-                  'not valid.'
-            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
-
-        attr['padding'] = attr['padding'].decode("utf-8")
-        if attr['padding'] == 'VALID':
-            attr['padding'] = [0, 0]
-        elif attr['padding'] == 'SAME':
-            stride_h, stride_w = attr['strides']
-            if attr['data_format'] == 'NHWC':
+            msg = 'Value {} in attribute "data_format" of operator Dilation2D is ' "not valid."
+            raise tvm.error.OpAttributeInvalid(msg.format(attr["data_format"]))
+
+        attr["padding"] = attr["padding"].decode("utf-8")
+        if attr["padding"] == "VALID":
+            attr["padding"] = [0, 0]
+        elif attr["padding"] == "SAME":
+            stride_h, stride_w = attr["strides"]
+            if attr["data_format"] == "NHWC":
                 kernel_h, kernel_w = weights_shape[0], weights_shape[1]
             else:
                 kernel_h, kernel_w = weights_shape[1], weights_shape[2]
-            if attr['data_format'] == 'NHWC':
+            if attr["data_format"] == "NHWC":
                 in_h = input_shape[1]
                 in_w = input_shape[2]
             else:
                 in_h = input_shape[2]
                 in_w = input_shape[3]
 
-            dilation_h = attr['dilations'][0]
-            dilation_w = attr['dilations'][1]
+            dilation_h = attr["dilations"][0]
+            dilation_w = attr["dilations"][1]
             dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
             dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
             pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
             pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
 
-            if attr['data_format'] == 'NHWC':
-                inputs[0] = _op.nn.pad(data=inputs[0],
-                                       pad_width=((0, 0),
-                                                  (pad_v[0], pad_v[1]),
-                                                  (pad_h[0], pad_h[1]),
-                                                  (0, 0)))
+            if attr["data_format"] == "NHWC":
+                inputs[0] = _op.nn.pad(
+                    data=inputs[0],
+                    pad_width=((0, 0), (pad_v[0], pad_v[1]), (pad_h[0], pad_h[1]), (0, 0)),
+                )
             else:
-                inputs[0] = _op.nn.pad(data=inputs[0],
-                                       pad_width=((0, 0),
-                                                  (0, 0),
-                                                  (pad_v[0], pad_v[1]),
-                                                  (pad_h[0], pad_h[1])))
+                inputs[0] = _op.nn.pad(
+                    data=inputs[0],
+                    pad_width=((0, 0), (0, 0), (pad_v[0], pad_v[1]), (pad_h[0], pad_h[1])),
+                )
 
-            attr['padding'] = [0, 0]
+            attr["padding"] = [0, 0]
 
         else:
-            msg = 'Value {} in attribute "padding" of operator Dilation2d is not ' \
-                  'valid.'
-            raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))
+            msg = 'Value {} in attribute "padding" of operator Dilation2d is not ' "valid."
+            raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"]))
 
-        attr['kernel_layout'] = 'HWI' if attr['data_format'] == 'NHWC' else 'IHW'
+        attr["kernel_layout"] = "HWI" if attr["data_format"] == "NHWC" else "IHW"
         out = AttrCvt(
-            op_name='dilation2d',
-            ignores=['explicit_paddings', 'rates'],
+            op_name="dilation2d",
+            ignores=["explicit_paddings", "rates"],
             transforms={
-                'data_format': 'data_layout',
-            })([inputs[0], inputs[1]], attr)
-        if attr['_target_layout'] == "NCHW":
+                "data_format": "data_layout",
+            },
+        )([inputs[0], inputs[1]], attr)
+        if attr["_target_layout"] == "NCHW":
             out = _op.transpose(out, axes=(0, 2, 3, 1))
         return out
 
@@ -500,14 +519,14 @@ def _dilation2d():
 
 def _conv3d(opname):
     def _impl(inputs, attr, params, mod):
-        attr['data_format'] = attr['data_format'].decode("utf-8")
+        attr["data_format"] = attr["data_format"].decode("utf-8")
         flip_layout = False
 
-        inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2]
+        inputs_data = inputs[0] if opname != "conv_transpose" else inputs[2]
 
         # NCDHW Layout require weights transpose
         weights_shape = _infer_shape(inputs[1], mod)
-        if attr['data_format'] == 'NCDHW':
+        if attr["data_format"] == "NCDHW":
             tmp_shape = weights_shape
             tmp_shape = [tmp_shape[ii] for ii in (4, 3, 0, 1, 2)]
             inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2))
@@ -515,59 +534,64 @@ def _conv3d(opname):
 
         input_shape = _infer_shape(inputs_data, mod)
 
-        if attr['_target_layout'] == "NCDHW" and attr['data_format'] == "NDHWC":
+        if attr["_target_layout"] == "NCDHW" and attr["data_format"] == "NDHWC":
             input_shape = [input_shape[ii] for ii in (0, 4, 1, 2, 3)]
             inputs_data = _op.transpose(inputs_data, axes=(0, 4, 1, 2, 3))
             weights_shape = [weights_shape[ii] for ii in (4, 3, 0, 1, 2)]
             inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2))
 
-            attr['data_format'] = "NCDHW"
-            attr['strides'] = [attr['strides'][ii] for ii in (0, 4, 1, 2, 3)]
+            attr["data_format"] = "NCDHW"
+            attr["strides"] = [attr["strides"][ii] for ii in (0, 4, 1, 2, 3)]
             flip_layout = True
 
-        if attr['data_format'] == 'NDHWC':
+        if attr["data_format"] == "NDHWC":
             kernel_d, kernel_h, kernel_w, _, _ = weights_shape
-            attr['kernel_shape'] = (kernel_d, kernel_h, kernel_w)
-            if opname == 'conv':
-                attr['channels'] = weights_shape[4]
-            elif opname == 'conv_transpose':
-                attr['channels'] = weights_shape[3]
-
-            if 'dilations' in attr:
-                attr['dilations'] = \
-                    (attr['dilations'][1], attr['dilations'][2], attr['dilations'][3])
-            attr['strides'] = (attr['strides'][1], attr['strides'][2], attr['strides'][3])
-        elif attr['data_format'] == 'NCDHW':
+            attr["kernel_shape"] = (kernel_d, kernel_h, kernel_w)
+            if opname == "conv":
+                attr["channels"] = weights_shape[4]
+            elif opname == "conv_transpose":
+                attr["channels"] = weights_shape[3]
+
+            if "dilations" in attr:
+                attr["dilations"] = (
+                    attr["dilations"][1],
+                    attr["dilations"][2],
+                    attr["dilations"][3],
+                )
+            attr["strides"] = (attr["strides"][1], attr["strides"][2], attr["strides"][3])
+        elif attr["data_format"] == "NCDHW":
             _, _, kernel_d, kernel_h, kernel_w = weights_shape
-            attr['kernel_shape'] = (kernel_d, kernel_h, kernel_w)
-            if opname == 'conv':
-                attr['channels'] = weights_shape[0]
-            elif opname == 'conv_transpose':
-                attr['channels'] = weights_shape[1]
-
-            if 'dilations' in attr:
-                attr['dilations'] = \
-                    (attr['dilations'][2], attr['dilations'][3], attr['dilations'][4])
-            attr['strides'] = (attr['strides'][2], attr['strides'][3], attr['strides'][4])
+            attr["kernel_shape"] = (kernel_d, kernel_h, kernel_w)
+            if opname == "conv":
+                attr["channels"] = weights_shape[0]
+            elif opname == "conv_transpose":
+                attr["channels"] = weights_shape[1]
+
+            if "dilations" in attr:
+                attr["dilations"] = (
+                    attr["dilations"][2],
+                    attr["dilations"][3],
+                    attr["dilations"][4],
+                )
+            attr["strides"] = (attr["strides"][2], attr["strides"][3], attr["strides"][4])
         else:
-            msg = 'Value {} in attribute "data_format" of operator Conv is ' \
-                  'not valid.'
-            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
+            msg = 'Value {} in attribute "data_format" of operator Conv is ' "not valid."
+            raise tvm.error.OpAttributeInvalid(msg.format(attr["data_format"]))
 
         # Fix padding
-        attr['padding'] = attr['padding'].decode("utf-8")
+        attr["padding"] = attr["padding"].decode("utf-8")
 
-        if attr['padding'] == 'VALID':
-            attr['padding'] = [0, 0, 0]
-        elif attr['padding'] == 'SAME':
-            stride_d, stride_h, stride_w = attr['strides']
-            kernel_d, kernel_h, kernel_w = attr['kernel_shape']
+        if attr["padding"] == "VALID":
+            attr["padding"] = [0, 0, 0]
+        elif attr["padding"] == "SAME":
+            stride_d, stride_h, stride_w = attr["strides"]
+            kernel_d, kernel_h, kernel_w = attr["kernel_shape"]
 
             pdata_shape = input_shape
-            if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0:
-                pdata_shape = attr['_output_shapes'][0]
+            if opname == "conv_transpose" and len(attr["_output_shapes"]) > 0:
+                pdata_shape = attr["_output_shapes"][0]
 
-            if attr['data_format'] == 'NDHWC':
+            if attr["data_format"] == "NDHWC":
                 in_d = pdata_shape[1]
                 in_h = pdata_shape[2]
                 in_w = pdata_shape[3]
@@ -576,9 +600,9 @@ def _conv3d(opname):
                 in_h = pdata_shape[3]
                 in_w = pdata_shape[4]
 
-            dilation_d = attr['dilations'][0]
-            dilation_h = attr['dilations'][1]
-            dilation_w = attr['dilations'][2]
+            dilation_d = attr["dilations"][0]
+            dilation_h = attr["dilations"][1]
+            dilation_w = attr["dilations"][2]
             dilated_kernel_d = (kernel_d - 1) * dilation_d + 1
             dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
             dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
@@ -586,87 +610,94 @@ def _conv3d(opname):
             pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
             pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
 
-            attr['padding'] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]]
+            attr["padding"] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]]
 
         else:
-            msg = 'Value {} in attribute "padding" of operator Conv is not ' \
-                  'valid.'
-            raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))
+            msg = 'Value {} in attribute "padding" of operator Conv is not ' "valid."
+            raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"]))
 
-        if 'kernel_layout' not in attr:
-            attr['kernel_layout'] = 'DHWIO' if attr['data_format'] == 'NDHWC' else 'OIDHW'
+        if "kernel_layout" not in attr:
+            attr["kernel_layout"] = "DHWIO" if attr["data_format"] == "NDHWC" else "OIDHW"
 
-        use_bias = len(inputs) == (3 if opname != 'conv_transpose' else 4)
-        channel_axis = 1 if attr['data_format'] == "NCDHW" else 4
+        use_bias = len(inputs) == (3 if opname != "conv_transpose" else 4)
+        channel_axis = 1 if attr["data_format"] == "NCDHW" else 4
 
         # Ignore the new attributes from TF2.0, for now.
         out = AttrCvt(
-            op_name=_dimension_picker('conv',
-                                      surfix="_transpose" if opname == 'conv_transpose' else ""),
-            ignores=['explicit_paddings', 'Tshape'],
+            op_name=_dimension_picker(
+                "conv", surfix="_transpose" if opname == "conv_transpose" else ""
+            ),
+            ignores=["explicit_paddings", "Tshape"],
             transforms={
-                'kernel_shape': 'kernel_size',
-                'data_format': 'data_layout',
-                'dilations': ('dilation', (0, 0)),
-                'group': ('groups', 1)},
-            custom_check=_dimension_constraint())([inputs_data, inputs[1]], attr)
+                "kernel_shape": "kernel_size",
+                "data_format": "data_layout",
+                "dilations": ("dilation", (0, 0)),
+                "group": ("groups", 1),
+            },
+            custom_check=_dimension_constraint(),
+        )([inputs_data, inputs[1]], attr)
 
         if use_bias:
-            out = _op.nn.bias_add(out,
-                                  inputs[2] if opname != 'conv_transpose' else inputs[3],
-                                  axis=channel_axis)
+            out = _op.nn.bias_add(
+                out, inputs[2] if opname != "conv_transpose" else inputs[3], axis=channel_axis
+            )
 
         if flip_layout:
             out = _op.transpose(out, axes=(0, 2, 3, 4, 1))
 
         return out
+
     return _impl
 
+
 def _nms():
     def _impl(inputs, attr, params, mod):
         # Get parameter values
         try:
-            max_output_size = int(np.atleast_1d(inputs[2].data.asnumpy()
-                                                .astype("int64"))[0])
+            max_output_size = int(np.atleast_1d(inputs[2].data.asnumpy().astype("int64"))[0])
         except Exception:
             try:
-                max_output_size = _infer_value(inputs[2], params,
-                                               mod).asnumpy().astype("int64").tolist()[0]
+                max_output_size = (
+                    _infer_value(inputs[2], params, mod).asnumpy().astype("int64").tolist()[0]
+                )
             except Exception:
                 max_output_size = inputs[2]
         iou_threshold = np.atleast_1d(inputs[3].data.asnumpy())[0]
         # score_threshold was introduced from V3
         score_threshold = np.atleast_1d(inputs[4].data.asnumpy())[0] if len(inputs) > 4 else 0.0
-        pad_output = 'pad_to_max_output_size'
+        pad_output = "pad_to_max_output_size"
 
         # Generate data with shape (1, num_anchors, 5)
-        scores = AttrCvt(op_name="expand_dims",
-                         ignores=['T_threshold', pad_output],
-                         extras={'axis': -1, 'num_newaxis': 1})([inputs[1]], attr)
-        data = get_relay_op('concatenate')([scores, inputs[0]], -1)
-        data = get_relay_op('expand_dims')(data, 0, 1)
+        scores = AttrCvt(
+            op_name="expand_dims",
+            ignores=["T_threshold", pad_output],
+            extras={"axis": -1, "num_newaxis": 1},
+        )([inputs[1]], attr)
+        data = get_relay_op("concatenate")([scores, inputs[0]], -1)
+        data = get_relay_op("expand_dims")(data, 0, 1)
 
         # get_valid_counts is used here for inference performance
-        ct, data, indices = get_relay_op('get_valid_counts')(data,
-                                                             score_threshold=score_threshold,
-                                                             id_index=-1,
-                                                             score_index=0)
+        ct, data, indices = get_relay_op("get_valid_counts")(
+            data, score_threshold=score_threshold, id_index=-1, score_index=0
+        )
         # TensorFlow NMS doesn't have parameter top_k
         top_k = -1
         # TF doesn't have class id for nms input
         score_index = 0
-        nms_ret = get_relay_op('non_max_suppression')(data=data,
-                                                      valid_count=ct,
-                                                      indices=indices,
-                                                      max_output_size=max_output_size,
-                                                      iou_threshold=iou_threshold,
-                                                      force_suppress=True,
-                                                      top_k=top_k,
-                                                      coord_start=1,
-                                                      score_index=score_index,
-                                                      id_index=-1,
-                                                      return_indices=True,
-                                                      invalid_to_bottom=False)
+        nms_ret = get_relay_op("non_max_suppression")(
+            data=data,
+            valid_count=ct,
+            indices=indices,
+            max_output_size=max_output_size,
+            iou_threshold=iou_threshold,
+            force_suppress=True,
+            top_k=top_k,
+            coord_start=1,
+            score_index=score_index,
+            id_index=-1,
+            return_indices=True,
+            invalid_to_bottom=False,
+        )
 
         if pad_output in attr and attr[pad_output]:
             return nms_ret
@@ -675,23 +706,30 @@ def _nms():
         data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0])
 
         # slice to get the dynamic result
-        ret = get_relay_op("strided_slice")(data_slice, begin=_expr.const([0]),
-                                            end=size, slice_mode="size")
+        ret = get_relay_op("strided_slice")(
+            data_slice, begin=_expr.const([0]), end=size, slice_mode="size"
+        )
         return ret
+
     return _impl
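
As the comments in _nms note, the scores and boxes are packed into a single (1, num_anchors, 5) tensor, with the score in column 0, before calling get_valid_counts and non_max_suppression. A NumPy sketch of that packing step with illustrative shapes:

import numpy as np

num_anchors = 4
boxes = np.random.rand(num_anchors, 4).astype(np.float32)   # one box per anchor
scores = np.random.rand(num_anchors).astype(np.float32)

# expand scores to (num_anchors, 1), concatenate on the last axis,
# then add a batch axis -> (1, num_anchors, 5)
data = np.concatenate([scores[:, None], boxes], axis=-1)[None, :, :]
print(data.shape)  # (1, 4, 5)
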
 
+
 def _decode_image():
     def _impl(inputs, attr, params, mod):
         # Image decode wrapper: expecting the user to feed decoded input to the next layer; drop this layer.
         warnings.warn("DecodeJpeg: It's a pass through, please handle preprocessing before input")
         return inputs[0]
+
     return _impl
 
+
 def _unravel_index():
     def _impl(inputs, attr, params, mod):
         return _op.unravel_index(inputs[0], inputs[1])
+
     return _impl
 
+
 def _crop_and_resize():
     def _impl(inputs, attr, params, mod):
         # input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
@@ -701,35 +739,44 @@ def _crop_and_resize():
         except (IndexError, KeyError):
             crop_size = _infer_value(inputs[3], params, mod).asnumpy().tolist()
 
-        method = attr['method'].decode()
-        method = 'nearest_neighbor' if method == 'nearest' else method
-        if method not in ['bilinear', 'nearest_neighbor']:
-            raise tvm.error.OpAttributeUnImplemented(
-                'Method {} is not supported'.format(method))
-        layout = attr['layout'] if 'layout' in attr else 'NHWC'
-        extrapolation_value = attr['extrapolation_value']
+        method = attr["method"].decode()
+        method = "nearest_neighbor" if method == "nearest" else method
+        if method not in ["bilinear", "nearest_neighbor"]:
+            raise tvm.error.OpAttributeUnImplemented("Method {} is not supported".format(method))
+        layout = attr["layout"] if "layout" in attr else "NHWC"
+        extrapolation_value = attr["extrapolation_value"]
+
+        return get_relay_op("crop_and_resize")(
+            inputs[0], inputs[1], inputs[2], crop_size, layout, method, extrapolation_value
+        )
 
-        return get_relay_op("crop_and_resize")(inputs[0], inputs[1], inputs[2], crop_size,
-                                               layout, method, extrapolation_value)
     return _impl
 
+
 def _cast():
     def _impl(inputs, attr, params, mod):
-        return inputs[0].astype(attr['DstT'].name)
+        return inputs[0].astype(attr["DstT"].name)
+
     return _impl
 
+
 def _expand_dims():
     def _impl(inputs, attr, params, mod):
         dim_input = inputs.pop(1)
         axis = _get_num_param(params, dim_input)
-        return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
-                       extras={'axis': int(axis), 'num_newaxis': 1})(inputs, attr)
+        return AttrCvt(
+            op_name="expand_dims",
+            ignores=["Tdim", "N"],
+            extras={"axis": int(axis), "num_newaxis": 1},
+        )(inputs, attr)
+
     return _impl
 
+
 def _resize(method):
     def _impl(inputs, attr, params, mod):
-        if attr['_output_shapes'][0] is not None:
-            size = attr['_output_shapes'][0][1:3]
+        if attr["_output_shapes"][0] is not None:
+            size = attr["_output_shapes"][0][1:3]
             # Important that the size is defined. If an axis is not, we need to infer what
             # the shape should be.
             if -1 in size:
@@ -737,27 +784,31 @@ def _resize(method):
         else:
             size = _infer_value(inputs[1], params, mod).asnumpy().reshape([-1]).tolist()
 
-        attr['size'] = size
+        attr["size"] = size
         inputs.pop(1)
         # NHWC
-        attr['layout'] = 'NHWC'
-        if attr.pop('align_corners') is True:
-            attr['coordinate_transformation_mode'] = 'align_corners'
+        attr["layout"] = "NHWC"
+        if attr.pop("align_corners") is True:
+            attr["coordinate_transformation_mode"] = "align_corners"
         else:
-            attr['coordinate_transformation_mode'] = 'asymmetric'
+            attr["coordinate_transformation_mode"] = "asymmetric"
 
         # Ignore the new attributes from TF2.0, for now.
-        return AttrCvt(op_name='resize',
-                       ignores=['Tdim', 'half_pixel_centers'],
-                       extras={'method': method})(inputs, attr)
+        return AttrCvt(
+            op_name="resize", ignores=["Tdim", "half_pixel_centers"], extras={"method": method}
+        )(inputs, attr)
+
     return _impl
 
+
 def _check_numerics():
     def _impl(inputs, attr, params, mod):
         # Making a copy node assuming no need to verify
-        return AttrCvt(op_name="copy", ignores=['message'])(inputs, attr)
+        return AttrCvt(op_name="copy", ignores=["message"])(inputs, attr)
+
     return _impl
 
+
 def _assert():
     # ToDo: In general people want asserts to be gone from TensorFlow graphs
     # when they are optimizing them, so converting it to a no-op is
@@ -765,6 +816,7 @@ def _assert():
     # once Relay gets a Halt or Assert op.
     return _no_op()
 
+
 def _no_op():
     def _impl(inputs, attr, params, mod):
         # ToDo: This should really be an op that returns nothing, which could
@@ -775,21 +827,24 @@ def _no_op():
         # improved. In the meantime, it is hard to imagine a case where it
         # matters in any real way that a no-op is converted to a constant 0.
         return tvm.relay.const(0)
+
     return _impl
 
+
 def _matmul():
     def _impl(inputs, attr, params, mod):
-        channels = _infer_channels(inputs[1], not attr['transpose_b'])
-        if attr['transpose_a']:
+        channels = _infer_channels(inputs[1], not attr["transpose_b"])
+        if attr["transpose_a"]:
             inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
-        if not attr['transpose_b']:
+        if not attr["transpose_b"]:
             inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
-        return AttrCvt(op_name="dense",
-                       extras={'units': channels},
-                       ignores=['transpose_a', 'transpose_b', 'T'])(inputs, attr)
+        return AttrCvt(
+            op_name="dense", extras={"units": channels}, ignores=["transpose_a", "transpose_b", "T"]
+        )(inputs, attr)
 
     return _impl
 
+
 def _batch_matmul():
     def _impl(inputs, attr, params, mod):
         input_x = inputs[0]
@@ -806,11 +861,11 @@ def _batch_matmul():
             input_x = _op.reshape(input_x, newshape=new_shape_x)
             input_y = _op.reshape(input_y, newshape=new_shape_y)
 
-        adj_x = attr['adj_x']
-        adj_y = attr['adj_y']
+        adj_x = attr["adj_x"]
+        adj_y = attr["adj_y"]
         input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x
         input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y
-        ret = get_relay_op('batch_matmul')(input_x, input_y)
+        ret = get_relay_op("batch_matmul")(input_x, input_y)
 
         # reshape result back to n-dimensional
         if len(orig_shape_x) > 3:
@@ -820,63 +875,68 @@ def _batch_matmul():
             ret = _op.reshape(ret, newshape=final_shape)
 
         return ret
+
     return _impl
 
+
 def _identity():
     def _impl(inputs, attr, params, mod):
         return inputs[0]
+
     return _impl
 
+
 def _concatV2():
     def _impl(inputs, attr, params, mod):
-        pop_node = inputs.pop(len(inputs)-1)
+        pop_node = inputs.pop(len(inputs) - 1)
         axis = int(_get_num_param(params, pop_node))
-        return AttrCvt(
-            op_name="concatenate", ignores=['T', 'N', 'Tidx'],
-            extras={'axis': axis})([inputs], attr)
+        return AttrCvt(op_name="concatenate", ignores=["T", "N", "Tidx"], extras={"axis": axis})(
+            [inputs], attr
+        )
+
     return _impl
 
+
 def _concat():
     def _impl(inputs, attr, params, mod):
         pop_node = inputs.pop(0)
         axis = int(_get_num_param(params, pop_node))
-        return AttrCvt(
-            op_name="concatenate", ignores=['N'],
-            extras={'axis': axis})([inputs], attr)
+        return AttrCvt(op_name="concatenate", ignores=["N"], extras={"axis": axis})([inputs], attr)
+
     return _impl
 
+
 def _pack():
     def _impl(inputs, attr, params, mod):
         axis = int(attr["axis"])
         inputs_reshaped = [_op.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs]
         return _op.concatenate(inputs_reshaped, axis)
+
     return _impl
 
+
 def _tensor_array():
     def _impl(inputs, attr, params, prelude):
-        dtype_str = attr.get('dtype').name
-        assert not attr["dynamic_size"], "Dynamic size tensor array is " \
-                                         "not supported in TVM yet."
+        dtype_str = attr.get("dtype").name
+        assert not attr["dynamic_size"], "Dynamic size tensor array is " "not supported in TVM yet."
 
         if "shape" in attr:
             shape = attr["shape"]
-            static_tensor_array_ops = StaticTensorArrayOps(prelude,
-                                                           dtype_str,
-                                                           shape)
+            static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, shape)
             static_tensor_array_ops.register()
-            tensor_array_constructor = prelude.get_var_static('tensor_array',
-                                                              dtype_str,
-                                                              shape)
+            tensor_array_constructor = prelude.get_var_static("tensor_array", dtype_str, shape)
             tensor_array = tensor_array_constructor(inputs[0])
         else:
-            tensor_array_constructor = prelude.get_var('tensor_array', dtype_str)
+            tensor_array_constructor = prelude.get_var("tensor_array", dtype_str)
             tensor_array = tensor_array_constructor(inputs[0])
         return tensor_array
+
     return _impl
 
+
 def _tensor_array_scatter():
     def _impl(inputs, attr, params, prelude):
-        dtype_str = attr.get('T').name
+        dtype_str = attr.get("T").name
         input_ta = inputs[0]
         input_shape = get_tensor_array_shape(input_ta, dtype_str, prelude)
         values_shape = _infer_shape(inputs[2], prelude.mod)
@@ -888,72 +948,60 @@ def _tensor_array_scatter():
             unstack_name = "tensor_array_unstack_tensor{}".format(values_rank)
             unstack_function = prelude.get_var(unstack_name, dtype_str)
             values = unstack_function(inputs[2])
-            tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str)
+            tensor_array_scatter_func = prelude.get_var("tensor_array_scatter", dtype_str)
         else:
             input_t_shape = _get_more_static_shape(input_t_shape, input_shape)
             values_shape = (values_shape[0],) + input_t_shape
-            static_tensor_array_ops = StaticTensorArrayOps(prelude,
-                                                           dtype_str,
-                                                           input_t_shape)
+            static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_t_shape)
             static_tensor_array_ops.register()
             # Register static indices shape
             if isinstance(indices_shape[0], int):
                 static_tensor_array_ops.define_tensor_array_scatter(indices_shape, True)
-            tensor_array_scatter_func = prelude.get_var_static('tensor_array_scatter',
-                                                               dtype_str,
-                                                               input_t_shape)
+            tensor_array_scatter_func = prelude.get_var_static(
+                "tensor_array_scatter", dtype_str, input_t_shape
+            )
 
-            static_tensor_array_ops = StaticTensorArrayOps(prelude,
-                                                           dtype_str,
-                                                           values_shape)
+            static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, values_shape)
             static_tensor_array_ops.register()
-            unstack_function = prelude.get_var_static('tensor_array_unstack',
-                                                      dtype_str,
-                                                      values_shape)
+            unstack_function = prelude.get_var_static(
+                "tensor_array_unstack", dtype_str, values_shape
+            )
             values = unstack_function(inputs[2])
         ret = tensor_array_scatter_func(input_ta, inputs[1], values)
         return ret
+
     return _impl
 
+
 def _tensor_array_gather():
     def _impl(inputs, attr, params, prelude):
-        dtype_str = attr.get('dtype').name
+        dtype_str = attr.get("dtype").name
         input_shape = get_tensor_array_shape(inputs[2], dtype_str, prelude)
         indices_shape = _infer_shape(inputs[1], prelude.mod)
 
         if input_shape is None:
-            gather_func = prelude.get_var('tensor_array_gather', dtype_str)
+            gather_func = prelude.get_var("tensor_array_gather", dtype_str)
             out = gather_func(inputs[2], inputs[1])
         else:
-            static_tensor_array_ops = StaticTensorArrayOps(prelude,
-                                                           dtype_str,
-                                                           input_shape)
+            static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_shape)
             static_tensor_array_ops.register()
 
             if not isinstance(indices_shape[0], int):
-                gather_function = prelude.get_var_static('tensor_array_gather',
-                                                         dtype_str,
-                                                         input_shape)
+                gather_function = prelude.get_var_static(
+                    "tensor_array_gather", dtype_str, input_shape
+                )
                 out_tensor_t = gather_function(inputs[2], inputs[1])
                 out_shape = (indices_shape[0],) + input_shape
-                static_tensor_array_ops = StaticTensorArrayOps(prelude,
-                                                               dtype_str,
-                                                               out_shape)
+                static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, out_shape)
                 static_tensor_array_ops.register()
 
                 # Output shape is (indices_shape[0],) + input_shape
-                get_data_func = prelude.get_var_static('tensor_get_data',
-                                                       dtype_str,
-                                                       out_shape)
+                get_data_func = prelude.get_var_static("tensor_get_data", dtype_str, out_shape)
                 out = get_data_func(out_tensor_t)
             else:
                 # For fixed length indices, directly generate static shape output
-                read_func = prelude.get_var_static('tensor_array_read',
-                                                   dtype_str,
-                                                   input_shape)
-                get_data_func = prelude.get_var_static('tensor_get_data',
-                                                       dtype_str,
-                                                       input_shape)
+                read_func = prelude.get_var_static("tensor_array_read", dtype_str, input_shape)
+                get_data_func = prelude.get_var_static("tensor_get_data", dtype_str, input_shape)
                 tensor_list = []
                 for i in range(indices_shape[0]):
                     index = _op.take(inputs[1], tvm.relay.const(i))
@@ -966,38 +1014,39 @@ def _tensor_array_gather():
                     out = tensor_list[0]
 
         return out
+
     return _impl
 
+
 def _tensor_array_size():
     def _impl(inputs, attr, params, prelude):
         return prelude.length(inputs[0])
+
     return _impl
 
+
 def _tensor_array_write():
     def _impl(inputs, attr, params, prelude):
-        dtype_str = attr.get('T').name
+        dtype_str = attr.get("T").name
         input_ta = inputs[3]
         input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude)
         input_t_shape = _infer_shape(inputs[2], prelude.mod)
         input_rank = len(input_t_shape)
 
         if input_ta_shape is None:
-            tensor_name = 'tensor{}'.format(input_rank)
+            tensor_name = "tensor{}".format(input_rank)
             tensor_func = prelude.get_var(tensor_name, dtype_str)
             v = tensor_func(inputs[2])
-            write_func = prelude.get_var('tensor_array_write', dtype_str)
+            write_func = prelude.get_var("tensor_array_write", dtype_str)
         else:
             input_ta_rank = len(input_ta_shape)
-            assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}". \
-                format(input_ta_rank, input_rank)
-            static_tensor_array_ops = StaticTensorArrayOps(prelude,
-                                                           dtype_str,
-                                                           input_ta_shape)
+            assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}".format(
+                input_ta_rank, input_rank
+            )
+            static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_ta_shape)
             static_tensor_array_ops.register()
 
-            tensor_func = prelude.get_var_static("tensor_constructor",
-                                                 dtype_str,
-                                                 input_ta_shape)
+            tensor_func = prelude.get_var_static("tensor_constructor", dtype_str, input_ta_shape)
             v = tensor_func(inputs[2])
             # Write tensor with more static shape
             actual_shape = _get_more_static_shape(input_t_shape, input_ta_shape)
@@ -1011,111 +1060,96 @@ def _tensor_array_write():
                 if num_any_dim <= 1:
                     v = tensor_func(_op.reshape(inputs[2], new_shape))
 
-            write_func = prelude.get_var_static('tensor_array_write',
-                                                dtype_str,
-                                                input_ta_shape)
+            write_func = prelude.get_var_static("tensor_array_write", dtype_str, input_ta_shape)
 
         return write_func(input_ta, _op.take(inputs[1], tvm.relay.const(0)), v)
+
     return _impl
 
+
 def _tensor_array_read():
     def _impl(inputs, attr, params, prelude):
-        dtype_str = attr['dtype'].name
+        dtype_str = attr["dtype"].name
         input_shape = get_tensor_array_shape(inputs[2], dtype_str, prelude)
 
         if input_shape is None:
-            read_func = prelude.get_var('tensor_array_read', dtype_str)
+            read_func = prelude.get_var("tensor_array_read", dtype_str)
             out = read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
         else:
-            static_tensor_array_ops = StaticTensorArrayOps(prelude,
-                                                           dtype_str,
-                                                           input_shape)
+            static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_shape)
             static_tensor_array_ops.register()
             read_func = prelude.get_var_static("tensor_array_read", dtype_str, input_shape)
             out_tensor = read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
-            get_data_func = prelude.get_var_static('tensor_get_data',
-                                                   dtype_str,
-                                                   input_shape)
+            get_data_func = prelude.get_var_static("tensor_get_data", dtype_str, input_shape)
             out = get_data_func(out_tensor)
 
         return out
+
     return _impl
 
+
 def _tensor_array_split():
     def _impl(inputs, attr, params, prelude):
-        dtype_str = attr.get('T').name
+        dtype_str = attr.get("T").name
         input_ta = inputs[0]
         input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude)
-        lengths = _op.cast(inputs[2], 'int32')
+        lengths = _op.cast(inputs[2], "int32")
         lengths_shape = _infer_shape(lengths, prelude.mod)
         value_shape = _infer_shape(inputs[1], prelude.mod)
         input_rank = len(value_shape)
 
         if input_ta_shape is None:
             v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1])
-            split_func = prelude.get_var('tensor_array_split', dtype_str)
+            split_func = prelude.get_var("tensor_array_split", dtype_str)
         else:
             input_ta_rank = len(input_ta_shape)
-            assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}". \
-                format(input_ta_rank, input_rank)
-            static_tensor_array_ops = StaticTensorArrayOps(prelude,
-                                                           dtype_str,
-                                                           input_ta_shape)
+            assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}".format(
+                input_ta_rank, input_rank
+            )
+            static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_ta_shape)
             static_tensor_array_ops.register()
 
             # Check static value/indices shape
             if isinstance(value_shape[0], int) or isinstance(lengths_shape[0], int):
-                static_tensor_array_ops.define_tensor_array_split(value_shape,
-                                                                  lengths_shape,
-                                                                  True)
+                static_tensor_array_ops.define_tensor_array_split(value_shape, lengths_shape, True)
 
-            tensor_func_name = prelude.get_name_static("tensor_constructor",
-                                                       dtype_str,
-                                                       value_shape)
+            tensor_func_name = prelude.get_name_static("tensor_constructor", dtype_str, value_shape)
             if not hasattr(prelude, tensor_func_name):
-                static_tensor_array_ops = StaticTensorArrayOps(prelude,
-                                                               dtype_str,
-                                                               value_shape)
+                static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, value_shape)
                 static_tensor_array_ops.register()
-            tensor_func = prelude.get_var_static("tensor_constructor",
-                                                 dtype_str,
-                                                 value_shape)
+            tensor_func = prelude.get_var_static("tensor_constructor", dtype_str, value_shape)
             v = tensor_func(inputs[1])
-            split_func = prelude.get_var_static('tensor_array_split',
-                                                dtype_str,
-                                                input_ta_shape)
+            split_func = prelude.get_var_static("tensor_array_split", dtype_str, input_ta_shape)
 
         return split_func(input_ta, v, lengths)
+
     return _impl
 
+
 def _tensor_array_concat():
     def _impl(inputs, attr, params, prelude):
-        dtype_str = attr['dtype'].name
+        dtype_str = attr["dtype"].name
         input_shape = get_tensor_array_shape(inputs[1], dtype_str, prelude)
 
         if input_shape is None:
-            concat_func = prelude.get_var('tensor_array_concat', dtype_str)
+            concat_func = prelude.get_var("tensor_array_concat", dtype_str)
             out = concat_func(inputs[1])
         else:
-            static_tensor_array_ops = StaticTensorArrayOps(prelude,
-                                                           dtype_str,
-                                                           input_shape)
+            static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_shape)
             static_tensor_array_ops.register()
             concat_func = prelude.get_var_static("tensor_array_concat", dtype_str, input_shape)
             out_tensor = concat_func(inputs[1])
             out_shape = (Any(),) + input_shape[1:]
-            static_tensor_array_ops = StaticTensorArrayOps(prelude,
-                                                           dtype_str,
-                                                           out_shape)
+            static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, out_shape)
             static_tensor_array_ops.register()
-            get_data_func = prelude.get_var_static('tensor_get_data',
-                                                   dtype_str,
-                                                   out_shape)
+            get_data_func = prelude.get_var_static("tensor_get_data", dtype_str, out_shape)
             out = get_data_func(out_tensor)
 
         return out
+
     return _impl
 
+
 def _tile():
     def _impl(inputs, attr, params, mod):
         reps_input = inputs.pop()
@@ -1126,12 +1160,13 @@ def _tile():
             reps = _get_list_param(params, reps_input)
         new_input = [inputs.pop(0)]
 
-        return AttrCvt(
-            op_name='tile',
-            extras={'reps': tuple(reps)},
-            ignores=['Tmultiples'])(new_input, attr)
+        return AttrCvt(op_name="tile", extras={"reps": tuple(reps)}, ignores=["Tmultiples"])(
+            new_input, attr
+        )
+
     return _impl
 
+
 def _slice():
     def _impl(inputs, attr, params, mod):
         try:
@@ -1160,8 +1195,10 @@ def _slice():
         elif not isinstance(size, (_expr.Call, _expr.Var)):
             for _ in range(len(size), data_dim):
                 size.append(-1)
-        return _op.strided_slice(inputs[0], begin=begin, end=size,
-                                 strides=strides, slice_mode="size")
+        return _op.strided_slice(
+            inputs[0], begin=begin, end=size, strides=strides, slice_mode="size"
+        )
+
     return _impl
 
 
@@ -1176,27 +1213,25 @@ def _reshape():
             # try to infer shape by precompute prune if possible.
             try:
                 params_new = _infer_value(pop_node, params, mod)
-                shape_arg = tuple(params_new.asnumpy().astype('int32').flatten())
+                shape_arg = tuple(params_new.asnumpy().astype("int32").flatten())
             except Exception:
                 # Deal with symbolic shape case.
-                if isinstance(pop_node, _expr.Call) and \
-                        "shape_of" in str(pop_node.op):
+                if isinstance(pop_node, _expr.Call) and "shape_of" in str(pop_node.op):
                     # shape_of is the direct ancestor.
                     return _op.reshape_like(inputs[0], pop_node.args[0])
                 shape_arg = pop_node
 
-        return AttrCvt(
-            op_name="reshape",
-            extras={'newshape': shape_arg},
-            ignores=['Tshape'])(inputs, attr)
-    return _impl
+        return AttrCvt(op_name="reshape", extras={"newshape": shape_arg}, ignores=["Tshape"])(
+            inputs, attr
+        )
 
+    return _impl
 
 
 def _depth_to_space():
     def _impl(inputs, attr, params, mod):
-        block_size = int(attr['block_size'])
-        layout = attr['data_format'].decode("utf-8")
+        block_size = int(attr["block_size"])
+        layout = attr["data_format"].decode("utf-8")
         return _op.nn.depth_to_space(inputs[0], block_size, layout)
 
     return _impl
@@ -1204,8 +1239,8 @@ def _depth_to_space():
 
 def _space_to_depth():
     def _impl(inputs, attr, params, mod):
-        block_size = int(attr['block_size'])
-        layout = attr['data_format'].decode("utf-8")
+        block_size = int(attr["block_size"])
+        layout = attr["data_format"].decode("utf-8")
         return _op.nn.space_to_depth(inputs[0], block_size, layout)
 
     return _impl
@@ -1214,14 +1249,15 @@ def _space_to_depth():
 def _bias_add():
     def _impl(inputs, attr, params, mod):
         # Must expand for proper broadcasting in NCHW.
-        if 'data_format' in attr and \
-                attr['data_format'].decode("utf-8") == 'NCHW':
+        if "data_format" in attr and attr["data_format"].decode("utf-8") == "NCHW":
             bias = _op.reshape(inputs[1], newshape=(1, -1, 1, 1))
         else:
             bias = inputs[1]
         return _op.add(inputs[0], bias)
+
     return _impl
 
+
 def _broadcast_to():
     def _impl(inputs, attr, params, mod):
         if isinstance(inputs[1], _expr.Var):
@@ -1230,18 +1266,21 @@ def _broadcast_to():
             shape = _infer_value(inputs[1], params, mod)
         shape = list(shape.asnumpy().reshape([-1]))
         return _op.broadcast_to(inputs[0], shape)
+
     return _impl
 
+
 def _squeeze():
     def _impl(inputs, attr, params, mod):
-        if len(attr['squeeze_dims']) == 0:
-            attr['squeeze_dims'] = None
-        return AttrCvt(
-            op_name="squeeze",
-            transforms={'squeeze_dims':'axis'},
-            ignores=['T'])(inputs, attr)
+        if len(attr["squeeze_dims"]) == 0:
+            attr["squeeze_dims"] = None
+        return AttrCvt(op_name="squeeze", transforms={"squeeze_dims": "axis"}, ignores=["T"])(
+            inputs, attr
+        )
+
     return _impl
 
+
 def _fused_batch_norm():
     def _impl(inputs, attr, params, mod):
         # Tensorflow: (data, gamma, beta, moving_mean, moving_variance)
@@ -1250,13 +1289,13 @@ def _fused_batch_norm():
         axis = 3
         need_cast = False
 
-        if 'data_format' in attr:
-            attr['data_format'] = attr['data_format'].decode("utf-8")
-            if attr['data_format'] == 'NCHW':
+        if "data_format" in attr:
+            attr["data_format"] = attr["data_format"].decode("utf-8")
+            if attr["data_format"] == "NCHW":
                 axis = 1
-        if 'U' in attr and attr['U'].name != attr['T'].name:
+        if "U" in attr and attr["U"].name != attr["T"].name:
             need_cast = True
-            inputs[0] = _op.cast(inputs[0], dtype=attr['U'].name)
+            inputs[0] = _op.cast(inputs[0], dtype=attr["U"].name)
         # Check if mean and variance are empty
         # If so, replace them with Mean and Variance Ops
         # For run-time calculation
@@ -1265,19 +1304,22 @@ def _fused_batch_norm():
         if moving_mean_shape[0] == 0 and moving_variance_shape[0] == 0:
             inputs[3] = _op.mean(inputs[0], axis=axis, keepdims=False, exclude=True)
             inputs[4] = _op.variance(inputs[0], axis=axis, keepdims=False, exclude=True)
-        out = AttrCvt(op_name='batch_norm',
-                      transforms={'scale_after_normalization':'scale',
-                                  'variance_epsilon':'epsilon'},
-                      extras={'axis': axis},
-                      ignores=['data_format', 'U'],
-                      disables=['momentum'])(inputs, attr)
+        out = AttrCvt(
+            op_name="batch_norm",
+            transforms={"scale_after_normalization": "scale", "variance_epsilon": "epsilon"},
+            extras={"axis": axis},
+            ignores=["data_format", "U"],
+            disables=["momentum"],
+        )(inputs, attr)
 
         if need_cast:
             out = _expr.TupleGetItem(out.astuple(), 0)
-            out = _op.cast(out, dtype=attr['T'].name)
+            out = _op.cast(out, dtype=attr["T"].name)
         return out
+
     return _impl
 
+
 def _batch_norm():
     def _impl(inputs, attr, params, mod):
         # Rearrange inputs from
@@ -1287,24 +1329,29 @@ def _batch_norm():
         new_inputs = [inputs[0], inputs[4], inputs[3], inputs[1], inputs[2]]
 
         axis = 3
-        if 'data_format' in attr:
-            attr['data_format'] = attr['data_format'].decode("utf-8")
-            if attr['data_format'] == 'NCHW':
+        if "data_format" in attr:
+            attr["data_format"] = attr["data_format"].decode("utf-8")
+            if attr["data_format"] == "NCHW":
                 axis = 1
 
         return AttrCvt(
-            op_name='batch_norm',
-            transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'},
-            extras={'axis': axis},
-            ignores=['data_format'],
-            disables=['momentum'])(new_inputs, attr)
+            op_name="batch_norm",
+            transforms={"scale_after_normalization": "scale", "variance_epsilon": "epsilon"},
+            extras={"axis": axis},
+            ignores=["data_format"],
+            disables=["momentum"],
+        )(new_inputs, attr)
+
     return _impl
 
+
 def _relu6():
     def _impl(inputs, attr, params, mod):
         return _op.clip(inputs[0], a_min=0, a_max=6)
+
     return _impl
 
+
 def _shape():
     def _impl(inputs, attr, params, mod):
         is_symbolic_shape = False
@@ -1315,13 +1362,14 @@ def _shape():
                 break
 
         if is_symbolic_shape:
-            ret = _op.shape_of(inputs[0], dtype='int32')
+            ret = _op.shape_of(inputs[0], dtype="int32")
         else:
-            ret = np.array(input_shape, dtype='int32')
+            ret = np.array(input_shape, dtype="int32")
         return ret
 
     return _impl
 
+
 def _fill():
     def _impl(inputs, attr, params, mod):
         try:
@@ -1329,32 +1377,39 @@ def _fill():
         except Exception:
             output_shape = inputs[0]
 
-        return _op.full(inputs[1], output_shape, attr['T'].name)
+        return _op.full(inputs[1], output_shape, attr["T"].name)
+
     return _impl
 
+
 def _lrn():
     def _impl(inputs, attr, params, mod):
         attr_new = {}
-        depth_radius = attr.get('depth_radius', 5)
+        depth_radius = attr.get("depth_radius", 5)
         size = (depth_radius * 2) + 1
-        attr_new['axis'] = 3 # Fix axis, NHWC format
-        attr_new['size'] = size
-        attr_new['bias'] = attr.get('bias', 1)
-        attr_new['alpha'] = attr.get('alpha', 1) * size
-        attr_new['beta'] = attr.get('beta', 0.5)
-        return AttrCvt(op_name='lrn')(inputs, attr_new)
+        attr_new["axis"] = 3  # Fix axis, NHWC format
+        attr_new["size"] = size
+        attr_new["bias"] = attr.get("bias", 1)
+        attr_new["alpha"] = attr.get("alpha", 1) * size
+        attr_new["beta"] = attr.get("beta", 0.5)
+        return AttrCvt(op_name="lrn")(inputs, attr_new)
+
     return _impl
 
+
 def _sum():
     def _impl(inputs, attr, params, mod):
         axis = _get_tuple_param(params, inputs[1])
         return AttrCvt(
-            op_name='sum',
-            extras={'axis': axis},
-            transforms={'keep_dims':'keepdims'},
-            ignores=['name', 'Tidx'])([inputs[0]], attr)
+            op_name="sum",
+            extras={"axis": axis},
+            transforms={"keep_dims": "keepdims"},
+            ignores=["name", "Tidx"],
+        )([inputs[0]], attr)
+
     return _impl
 
+
 def _reduce(op):
     def _impl(inputs, attr, params, mod):
         axis = _get_list_param(params, inputs[1])
@@ -1363,51 +1418,65 @@ def _reduce(op):
             axis = None
         return AttrCvt(
             op_name=op,
-            extras={'axis': axis},
-            transforms={'keep_dims':'keepdims'},
-            ignores=['name', 'Tidx'])([inputs[0]], attr)
+            extras={"axis": axis},
+            transforms={"keep_dims": "keepdims"},
+            ignores=["name", "Tidx"],
+        )([inputs[0]], attr)
+
     return _impl
 
+
 def _euclidean_norm():
     def _impl(inputs, attr, params, mod):
         axis = tuple(_get_list_param(params, inputs[1]))
-        keep_dims = bool(attr.get('keep_dims', False))
-        return _op.sqrt(_op.cast(_op.reduce.sum(_op.multiply(inputs[0], inputs[0]),
-                                                axis, keep_dims), "float32"))
+        keep_dims = bool(attr.get("keep_dims", False))
+        return _op.sqrt(
+            _op.cast(_op.reduce.sum(_op.multiply(inputs[0], inputs[0]), axis, keep_dims), "float32")
+        )
+
     return _impl
 
+
 def _square():
     def _impl(inputs, attr, params, mod):
         return _op.multiply(inputs[0], inputs[0])
+
     return _impl
 
+
 def _gather():
     "GatherV2, Gather"
+
     def _impl(inputs, attr, params, mod):
         if len(inputs) > 2:
             axis = _get_num_param(params, inputs.pop(2))
         else:
             axis = 0
-        if int(attr.get('batch_dims', 0)) != 0:
-            raise tvm.error.OpAttributeUnImplemented(
-                'Attribute batch_dims is not supported')
+        if int(attr.get("batch_dims", 0)) != 0:
+            raise tvm.error.OpAttributeUnImplemented("Attribute batch_dims is not supported")
         new_input = inputs[0:2]
-        return AttrCvt(op_name="take",
-                       extras={'axis': tvm.tir.const(axis, 'int32')},
-                       ignores=['Tindices', 'Tparams', 'validate_indices',
-                                'Taxis', '_class', 'batch_dims'])(new_input, attr)
+        return AttrCvt(
+            op_name="take",
+            extras={"axis": tvm.tir.const(axis, "int32")},
+            ignores=["Tindices", "Tparams", "validate_indices", "Taxis", "_class", "batch_dims"],
+        )(new_input, attr)
+
     return _impl
 
+
 def _gather_nd():
     """GatherNd"""
+
     def _impl(inputs, attr, params, mod):
         indices_dims = len(_infer_shape(inputs[1], mod))
-        indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims-1)))
-        return AttrCvt(op_name="gather_nd",
-                       ignores=['Tindices', 'Tparams',\
-                                'Taxis', '_class'])([inputs[0], indices], attr)
+        indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims - 1)))
+        return AttrCvt(op_name="gather_nd", ignores=["Tindices", "Tparams", "Taxis", "_class"])(
+            [inputs[0], indices], attr
+        )
+
     return _impl
 
+
 def _stridedSlice():
     def _impl(inputs, attr, params, mod):
         """Strided Slice.
@@ -1419,11 +1488,11 @@ def _stridedSlice():
         end = _get_list_param(params, inputs[2])
         stride = _get_list_param(params, inputs[3])
 
-        begin_mask = int(attr.get('begin_mask', 0))
-        end_mask = int(attr.get('end_mask', 0))
-        ellipsis_mask = int(attr.get('ellipsis_mask', 0))
-        new_axis_mask = int(attr.get('new_axis_mask', 0))
-        shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
+        begin_mask = int(attr.get("begin_mask", 0))
+        end_mask = int(attr.get("end_mask", 0))
+        ellipsis_mask = int(attr.get("ellipsis_mask", 0))
+        new_axis_mask = int(attr.get("new_axis_mask", 0))
+        shrink_axis_mask = int(attr.get("shrink_axis_mask", 0))
         in_type = _infer_type(inputs[0], mod)
         data_shape = get_const_tuple(in_type.checked_type.shape)
         data_dim = len(data_shape)
@@ -1467,7 +1536,7 @@ def _stridedSlice():
             m_end = [0] * data_dim
             m_stride = [0] * data_dim
             fshape_indices = []
-            #Count new axis after ellipsis_mask, consider while applying ellipsis_mask.
+            # Count new axes after ellipsis_mask; taken into account while applying ellipsis_mask.
             ellipsis_seen = False
             new_axes_after_ellipsis = 0
             for i in range(stride_dim):
@@ -1477,42 +1546,44 @@ def _stridedSlice():
                 if (mask & ellipsis_mask) != 0:
                     ellipsis_seen = True
             if not ellipsis_seen:
-                #Used later for extending the stride attributes in the below loop.
-                ellipsis_mask |= (1 << stride_dim)
+                # Used later for extending the stride attributes in the below loop.
+                ellipsis_mask |= 1 << stride_dim
                 stride_dim += 1
             final_index = 0
             for index in range(stride_dim):
                 mask = 1 << index
                 if mask & ellipsis_mask:
-                    #Identify the end index for applying ellipsis_mask
-                    to_index = min(((data_dim - (stride_dim-index)) + 1
-                                    + new_axes_after_ellipsis), data_dim)
+                    # Identify the end index for applying ellipsis_mask
+                    to_index = min(
+                        ((data_dim - (stride_dim - index)) + 1 + new_axes_after_ellipsis), data_dim
+                    )
                     for i in range(final_index, to_index):
                         m_begin[final_index] = 0
                         m_end[final_index] = data_shape[final_index]
                         m_stride[final_index] = 1
                         fshape_indices.append(final_index)
                         final_index += 1
-                elif mask &new_axis_mask:
+                elif mask & new_axis_mask:
                     fshape_indices.append(-1)
                 elif not mask & new_axis_mask:
                     if final_index == len(m_begin):
                         break
                     if mask & begin_mask:
-                        m_begin[final_index] = data_shape[final_index] \
-                            if stride[index] < 0 else 0
+                        m_begin[final_index] = data_shape[final_index] if stride[index] < 0 else 0
                     elif begin[index]:
                         m_begin[final_index] = begin[index]
                     if mask & end_mask:
-                        m_end[final_index] = 0 if stride[index] < 0 \
-                            else data_shape[final_index]
+                        m_end[final_index] = 0 if stride[index] < 0 else data_shape[final_index]
                     elif end[index]:
                         m_end[final_index] = end[index]
                     m_stride[final_index] = stride[index]
                     if mask & shrink_axis_mask:
-                        #Tensorflow make axis with shrink_axis_mask as dimension 1
-                        m_begin[final_index] = data_shape[final_index] + begin[index] \
-                            if begin[index] < 0 else begin[index]
+                        # TensorFlow makes the axis selected by shrink_axis_mask a dimension of size 1
+                        m_begin[final_index] = (
+                            data_shape[final_index] + begin[index]
+                            if begin[index] < 0
+                            else begin[index]
+                        )
                         m_end[final_index] = begin[index] + 1
                         m_stride[final_index] = 1
                         fshape_indices.append(-2)
@@ -1525,15 +1596,12 @@ def _stridedSlice():
         fshape_indices = None
         if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
             begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
-        out = _op.strided_slice(inputs[0],
-                                begin=begin,
-                                end=end,
-                                strides=stride)
+        out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
         out_shape = _infer_shape(out, mod=mod)
         if not fshape_indices:
             fshape_indices = range(len(out_shape))
 
-        #Create final output shape.
+        # Create final output shape.
         final_output = []
         for gather_index in fshape_indices:
             if gather_index == -1:
@@ -1559,36 +1627,44 @@ def _stridedSlice():
         else:
             ret = _op.reshape(out, newshape=tuple(final_output))
         return ret
+
     return _impl
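For readers decoding the mask handling above: the five masks are per-dimension bit fields that map onto ordinary NumPy slicing. A small illustrative example under assumed shapes (nothing here comes from the frontend itself):

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)
    # begin_mask bit i       -> ignore begin[i] for that axis (open start)
    # end_mask bit i         -> ignore end[i] for that axis (open stop)
    # ellipsis_mask bit i    -> that slice entry stands for "..." over the untouched middle axes
    # new_axis_mask bit i    -> insert a length-1 axis at that position
    # shrink_axis_mask bit i -> take a single index and drop that axis
    y = x[0, ..., np.newaxis]  # shrink axis 0, ellipsis over the middle, append a new axis
    assert y.shape == (3, 4, 1)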
 
+
 def _pad(name):
     def _impl(inputs, attr, params, mod):
         padlist = _get_param(params, inputs[1])
         paddings = tuple(tuple(l) for l in padlist)
-        attr['pad_width'] = paddings
-        attr['pad_value'] = 0
+        attr["pad_width"] = paddings
+        attr["pad_value"] = 0
         new_inputs = [inputs[0]]
-        if name == 'PadV2':
+        if name == "PadV2":
             constant_values = _get_num_param(params, inputs[2])
-            attr['pad_value'] = constant_values
+            attr["pad_value"] = constant_values
         return AttrCvt(
-            op_name='pad',
-            ignores=['Tpaddings'],)(new_inputs, attr)
+            op_name="pad",
+            ignores=["Tpaddings"],
+        )(new_inputs, attr)
+
     return _impl
 
+
 def _mirror_pad():
     def _impl(inputs, attr, params, mod):
         padlist = _get_param(params, inputs[1])
         paddings = tuple(tuple(l) for l in padlist)
-        attr['pad_width'] = paddings
-        mode = attr['mode'].decode('utf-8')
-        attr['mode'] = mode
+        attr["pad_width"] = paddings
+        mode = attr["mode"].decode("utf-8")
+        attr["mode"] = mode
         new_inputs = [inputs[0]]
         return AttrCvt(
-            op_name='mirror_pad',
-            ignores=['Tpaddings'],)(new_inputs, attr)
+            op_name="mirror_pad",
+            ignores=["Tpaddings"],
+        )(new_inputs, attr)
+
     return _impl
 
+
 def _transpose():
     def _impl(inputs, attr, params, mod):
         # If perm is not specified, axes is left empty,
@@ -1598,44 +1674,49 @@ def _transpose():
         except (IndexError, KeyError, AttributeError):
             axes = _infer_value(inputs[1], params, mod).asnumpy().tolist()
         return _op.transpose(inputs[0], axes=axes)
+
     return _impl
 
+
 def _where():
     def _impl(inputs, attr, params, mod):
         if len(inputs) == 1:
             return AttrCvt(op_name="argwhere")(inputs, attr)
         return AttrCvt(op_name="where")(inputs, attr)
+
     return _impl
 
+
 def _clip_by_value():
     def _impl(inputs, attr, params, mod):
         a_min = _get_num_param(params, inputs[1])
         a_max = _get_num_param(params, inputs[2])
         return _op.clip(inputs[0], a_min=a_min, a_max=a_max)
+
     return _impl
 
+
 def _reverse_v2():
     def _impl(inputs, attr, params, mod):
         axis = _get_num_param(params, inputs[1])
-        return AttrCvt(
-            op_name="reverse",
-            ignores=['Tidx'],
-            extras={'axis': int(axis)})([inputs[0]], attr)
+        return AttrCvt(op_name="reverse", ignores=["Tidx"], extras={"axis": int(axis)})(
+            [inputs[0]], attr
+        )
+
     return _impl
 
+
 def _rank():
     def _impl(inputs, attr, params, mod):
         input_shape = _infer_shape(inputs[0], mod)
 
         name = attr["_node_name"]
-        params[name] = tvm.nd.array(np.array([len(input_shape)])
-                                    .astype("int32"))
-        return [_expr.var(name,
-                          shape=params[name].shape,
-                          dtype='int32')]
+        params[name] = tvm.nd.array(np.array([len(input_shape)]).astype("int32"))
+        return [_expr.var(name, shape=params[name].shape, dtype="int32")]
 
     return _impl
 
+
 def _range():
     def _impl(inputs, attr, params, mod):
         try:
@@ -1649,9 +1730,11 @@ def _range():
                 start = inputs[0]
 
         try:
-            limit = _get_param(params, inputs[1])[0] \
-                if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant) \
-                else params.pop('Rank').asnumpy()[0]
+            limit = (
+                _get_param(params, inputs[1])[0]
+                if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant)
+                else params.pop("Rank").asnumpy()[0]
+            )
         except (IndexError, KeyError, AttributeError):
             try:
                 limit = _infer_value(inputs[1], params, mod).asnumpy().tolist()
@@ -1669,8 +1752,7 @@ def _range():
                 # Symbolic delta
                 delta = inputs[2]
 
-
-        dtype = attr['Tidx'].name if 'Tidx' in attr else str(start.dtype)
+        dtype = attr["Tidx"].name if "Tidx" in attr else str(start.dtype)
         if isinstance(start, (np.int32, np.int64, int, np.float32, np.float64, float)):
             start = _expr.const(start)
         if isinstance(limit, (np.int32, np.int64, int, np.float32, np.float64, float)):
@@ -1680,46 +1762,59 @@ def _range():
 
         return AttrCvt(
             op_name="arange",
-            ignores=['Tidx', '_class'],
-            extras={'start': start,
-                    'stop': limit,
-                    'step': delta,
-                    'dtype': dtype})([], attr)
+            ignores=["Tidx", "_class"],
+            extras={"start": start, "stop": limit, "step": delta, "dtype": dtype},
+        )([], attr)
+
     return _impl
 
+
 def _elu():
     def _impl(inputs, attr, params, mod):
-        dtype = attr['T'].name
+        dtype = attr["T"].name
         alpha = tvm.relay.const(-1.0, dtype)
-        return alpha * _op.nn.relu(tvm.relay.const(1, dtype) \
-                                   - _op.exp(inputs[0])) + _op.nn.relu(inputs[0])
+        return alpha * _op.nn.relu(tvm.relay.const(1, dtype) - _op.exp(inputs[0])) + _op.nn.relu(
+            inputs[0]
+        )
+
     return _impl
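A quick numerical check (NumPy only, helper names illustrative) that the ReLU composition above reproduces ELU with the standard alpha of 1; the -1.0 constant in the converter folds the sign into the multiplier.

    import numpy as np

    def elu_ref(x):
        # Reference ELU with alpha = 1.
        return np.where(x > 0, x, np.exp(x) - 1)

    def elu_via_relu(x):
        # The same composition the converter emits: -relu(1 - exp(x)) + relu(x).
        relu = lambda v: np.maximum(v, 0)
        return -1.0 * relu(1.0 - np.exp(x)) + relu(x)

    x = np.linspace(-3.0, 3.0, 7)
    assert np.allclose(elu_ref(x), elu_via_relu(x))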
 
+
 def _selu():
     def _impl(inputs, attr, params, mod):
-        dtype = attr['T'].name
+        dtype = attr["T"].name
         alpha = tvm.relay.const(-1.6732632423543772848170429916717, dtype)
         gamma = tvm.relay.const(1.0507009873554804934193349852946, dtype)
-        return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, dtype)
-                                            - _op.exp(inputs[0])) + _op.nn.relu(inputs[0]))
+        return gamma * (
+            alpha * _op.nn.relu(tvm.relay.const(1, dtype) - _op.exp(inputs[0]))
+            + _op.nn.relu(inputs[0])
+        )
+
     return _impl
 
+
 def _mean():
     def _impl(inputs, attr, params, mod):
         axis = _get_tuple_param(params, inputs[1])
-        return AttrCvt(op_name="mean", ignores=['Tdim', 'Tidx'],
-                       transforms={'keep_dims': 'keepdims'},
-                       extras={'axis': axis})([inputs[0]], attr)
+        return AttrCvt(
+            op_name="mean",
+            ignores=["Tdim", "Tidx"],
+            transforms={"keep_dims": "keepdims"},
+            extras={"axis": axis},
+        )([inputs[0]], attr)
+
     return _impl
 
+
 def _broadcast(name):
     def _impl(inputs, attr, params, mod):
-        return AttrCvt(
-            op_name=name,
-            ignores=['name', 'incompatible_shape_error', 'Tidx']
-        )(inputs, attr)
+        return AttrCvt(op_name=name, ignores=["name", "incompatible_shape_error", "Tidx"])(
+            inputs, attr
+        )
+
     return _impl
 
+
 def _split(has_size_vector):
     # TF documentation https://www.tensorflow.org/api_docs/python/tf/split
     def _impl(inputs, attr, params, mod):
@@ -1741,51 +1836,58 @@ def _split(has_size_vector):
             else:
                 input_node_index = 1
                 input_axis_index = 0
-                indices_or_sections = attr['num_split']
+                indices_or_sections = attr["num_split"]
             input_node = inputs[input_node_index]
             axis_input_value = _get_num_param(params, inputs[input_axis_index])
         except (IndexError, KeyError, AttributeError):
             raise TypeError(
                 "Unsupported argument for split: `axis` and `num_or_size_splits` "
-                "should be constants")
-        return _op.split(input_node,
-                         indices_or_sections=indices_or_sections,
-                         axis=int(axis_input_value))
+                "should be constants"
+            )
+        return _op.split(
+            input_node, indices_or_sections=indices_or_sections, axis=int(axis_input_value)
+        )
+
     return _impl
 
+
 def _unpack():
     def _impl(inputs, attr, params, mod):
         input_node = inputs[0]
-        axis = attr['axis']
+        axis = attr["axis"]
         input_shape = _infer_shape(input_node, mod)
         axis_length = input_shape[axis]
         if axis_length < 0:
             raise TypeError("Unstack with unknown axis length")
-        splitted = _op.split(input_node,
-                             indices_or_sections=axis_length,
-                             axis=axis)
+        splitted = _op.split(input_node, indices_or_sections=axis_length, axis=axis)
         axis = [axis]
         return _expr.TupleWrapper(
-            _expr.Tuple([_op.squeeze(split_item, axis=axis) \
-            for split_item in splitted]), len(splitted))
+            _expr.Tuple([_op.squeeze(split_item, axis=axis) for split_item in splitted]),
+            len(splitted),
+        )
+
     return _impl
 
+
 def _softmax():
     def _impl(inputs, attr, params, mod):
-        return AttrCvt(op_name='softmax',
-                       transforms={'axis': ('axis', 1)})([inputs[0]], attr)
+        return AttrCvt(op_name="softmax", transforms={"axis": ("axis", 1)})([inputs[0]], attr)
+
     return _impl
 
+
 def _softplus():
     # op description: https://www.tensorflow.org/api_docs/python/tf/math/softplus
     def _impl(inputs, attr, params, mod):
-        exp_out = AttrCvt('exp')(inputs, attr)
-        inputs.append(tvm.relay.const(1, attr['T'].name))
-        rh = tvm.relay.const(1, attr['T'].name)
-        add_out = get_relay_op('add')(exp_out, rh)
-        return get_relay_op('log')(add_out)
+        exp_out = AttrCvt("exp")(inputs, attr)
+        inputs.append(tvm.relay.const(1, attr["T"].name))
+        rh = tvm.relay.const(1, attr["T"].name)
+        add_out = get_relay_op("add")(exp_out, rh)
+        return get_relay_op("log")(add_out)
+
     return _impl
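The chain above is the plain softplus identity, softplus(x) = log(exp(x) + 1); a tiny check under assumed inputs (the helper name is illustrative). Composing exp, add and log directly is numerically naive for large x, which is an inherent trade-off of expressing softplus this way rather than a quirk of the conversion.

    import numpy as np

    def softplus_ref(x):
        # Same exp -> add 1 -> log chain the converter emits.
        return np.log(np.exp(x) + 1.0)

    assert np.isclose(softplus_ref(0.0), np.log(2.0))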
 
+
 def _topk():
     def _impl(inputs, attr, params, mod):
         k_input = inputs.pop(1)
@@ -1799,33 +1901,45 @@ def _topk():
         if isinstance(k, int):
             if k < 1:
                 raise tvm.error.OpAttributeInvalid(
-                    'Attribute k must be positive in operator TopKV2')
+                    "Attribute k must be positive in operator TopKV2"
+                )
             k = _expr.const(k)
-        if attr['sorted'] is False:
+        if attr["sorted"] is False:
             raise tvm.error.OpAttributeUnImplemented(
-                'Attribute sorted=False is not supported in operator TopKV2')
-        return AttrCvt(op_name='topk',
-                       ignores=['sorted'],
-                       extras={'k': k, 'is_ascend': False, 'dtype': 'int32'})([inputs[0]], attr)
+                "Attribute sorted=False is not supported in operator TopKV2"
+            )
+        return AttrCvt(
+            op_name="topk",
+            ignores=["sorted"],
+            extras={"k": k, "is_ascend": False, "dtype": "int32"},
+        )([inputs[0]], attr)
+
     return _impl
 
+
 def _floordiv():
     def _impl(inputs, attr, params, mod):
         assert len(inputs) == 2
-        return AttrCvt('floor_divide')(inputs, attr)
+        return AttrCvt("floor_divide")(inputs, attr)
+
     return _impl
 
+
 def _floormod():
     def _impl(inputs, attr, params, mod):
         assert len(inputs) == 2
-        return AttrCvt('floor_mod')(inputs, attr)
+        return AttrCvt("floor_mod")(inputs, attr)
+
     return _impl
 
+
 def _logical(name):
     def _impl(inputs, attr, params, mod):
         return AttrCvt(op_name=name)(inputs, attr)
+
     return _impl
 
+
 def _space_to_batch_nd():
     def _impl(inputs, attr, params, mod):
         input_node = inputs[0]
@@ -1860,17 +1974,22 @@ def _space_to_batch_nd():
         # Permute dimensions of reshaped_padded to produce permuted_reshaped_padded of shape:
         # block_shape + [batch] + [padded_shape[1] / block_shape[0], ...,
         # padded_shape[M] / block_shape[M-1]] + remaining_shape
-        axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \
-               list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length))
+        axes = (
+            [2 * i + 2 for i in range(M)]
+            + [0]
+            + [2 * i + 1 for i in range(M)]
+            + list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length))
+        )
         permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes)
         permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded, mod)
         # Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
         # producing an output tensor of shape:
         # [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ...,
         # padded_shape[M] / block_shape[M-1]] + remaining_shape
-        shape2 = [batch * np.prod(block_shape)] + list(permuted_reshaped_padded_shape)[M + 1:]
-        reshaped_permuted_reshaped_padded = tvm.relay.reshape(permuted_reshaped_padded,
-                                                              newshape=shape2)
+        shape2 = [batch * np.prod(block_shape)] + list(permuted_reshaped_padded_shape)[M + 1 :]
+        reshaped_permuted_reshaped_padded = tvm.relay.reshape(
+            permuted_reshaped_padded, newshape=shape2
+        )
         return reshaped_permuted_reshaped_padded
 
     return _impl
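As a concreteness check on the pad / reshape / transpose / reshape pipeline described in the comments above, a shape walk-through under assumed values (batch = 1, 4x4 spatial input, 3 channels, block_shape = [2, 2], zero padding):

    padded            : (1, 4, 4, 3)
    reshaped_padded   : (1, 2, 2, 2, 2, 3)   each spatial dim split into (dim / block, block)
    permuted          : (2, 2, 1, 2, 2, 3)   block factors moved ahead of the batch axis
    output (shape2)   : (4, 2, 2, 3)         batch * prod(block_shape) = 4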
@@ -1904,8 +2023,11 @@ def _batch_to_space_nd():
         # Permute dimensions of reshaped to produce permuted of shape
         # [batch / prod(block_shape), input_shape[1], block_shape[0], ...,
         # input_shape[M], block_shape[M-1], input_shape[M+1], ..., input_shape[N-1]]
-        axes = [M] + [axis for i in range(M) for axis in [M + i + 1, i]] + \
-            list(range(2 * M + 1, len(shape1)))
+        axes = (
+            [M]
+            + [axis for i in range(M) for axis in [M + i + 1, i]]
+            + list(range(2 * M + 1, len(shape1)))
+        )
         permuted = tvm.relay.transpose(reshaped, axes=axes)
         # Reshape permuted to produce reshaped_permuted of shape
         # [batch / prod(block_shape), input_shape[1] * block_shape[0], ...,
@@ -1919,13 +2041,13 @@ def _batch_to_space_nd():
         #  input_shape[M+1], ..., input_shape[N-1]]
         reshaped_permuted_shape = _infer_shape(reshaped_permuted, mod)
         cropped = reshaped_permuted
-        for axis in range(1, M+1):
+        for axis in range(1, M + 1):
             crop = crops[axis - 1]
             if crop != [0, 0]:
                 indices = tvm.relay.arange(
                     _expr.const(crop[0]),
                     _expr.const(reshaped_permuted_shape[axis] - crop[1]),
-                    dtype='int32'
+                    dtype="int32",
                 )
                 cropped = tvm.relay.take(cropped, indices=indices, axis=axis)
 
@@ -1933,55 +2055,70 @@ def _batch_to_space_nd():
 
     return _impl
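And the inverse bookkeeping for the converter above, under the same assumed values (batch = 4, block_shape = [2, 2], no crops):

    input             : (4, 2, 2, 3)
    reshaped          : (2, 2, 1, 2, 2, 3)   block factors peeled off the batch axis
    permuted          : (1, 2, 2, 2, 2, 3)   block factors interleaved next to their spatial dims
    reshaped_permuted : (1, 4, 4, 3)         spatial dims recombined; crops would then trim edges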
 
+
 def _atan2():
     def _impl(inputs, attr, params, mod):
         divide = _elemwise("divide")(inputs, attr, params, mod)
         return get_relay_op("atan")(divide)
+
     return _impl
 
+
 def _prod():
     def _impl(inputs, attr, params, mod):
         axis = _get_num_param(params, inputs[1])
-        keepdims = attr['keep_dims']
+        keepdims = attr["keep_dims"]
         return _op.prod(inputs[0], int(axis), keepdims=keepdims)
+
     return _impl
 
+
 def _log1p():
     # op description: https://www.tensorflow.org/api_docs/python/tf/math/log1p
     def _impl(inputs, attr, params, mod):
-        one = tvm.relay.const(1, attr['T'].name)
-        add_out = get_relay_op('add')(inputs[0], one)
-        return get_relay_op('log')(add_out)
+        one = tvm.relay.const(1, attr["T"].name)
+        add_out = get_relay_op("add")(inputs[0], one)
+        return get_relay_op("log")(add_out)
+
     return _impl
 
+
 def _one_hot():
     def _impl(inputs, attr, params, mod):
         depth = int(_get_num_param(params, inputs[1]))
-        dtype = attr['T'].name
+        dtype = attr["T"].name
 
         on_value = _get_num_param(params, inputs[2])
         off_value = _get_num_param(params, inputs[3])
-        new_inputs = [inputs[0],
-                      tvm.relay.const(on_value, dtype),
-                      tvm.relay.const(off_value, dtype)]
-        return AttrCvt('one_hot',
-                       ignores=['TI'],
-                       extras={'depth' : depth, 'dtype' : dtype})(new_inputs, attr)
+        new_inputs = [
+            inputs[0],
+            tvm.relay.const(on_value, dtype),
+            tvm.relay.const(off_value, dtype),
+        ]
+        return AttrCvt("one_hot", ignores=["TI"], extras={"depth": depth, "dtype": dtype})(
+            new_inputs, attr
+        )
+
     return _impl
 
+
 def _squared_difference():
     def _impl(inputs, attr, params, mod):
         difference = _op.subtract(inputs[0], inputs[1])
         return _op.multiply(difference, difference)
+
     return _impl
 
+
 def _size():
     def _impl(inputs, attr, params, mod):
         new_attr = attr
-        new_attr['out_type'] = attr['out_type'].name
-        return AttrCvt('ndarray_size', transforms={'out_type' : 'dtype'})(inputs, new_attr)
+        new_attr["out_type"] = attr["out_type"].name
+        return AttrCvt("ndarray_size", transforms={"out_type": "dtype"})(inputs, new_attr)
+
     return _impl
 
+
 def _add_n():
     def _impl(inputs, attr, params, mod):
         if not isinstance(inputs, tuple):
@@ -1990,9 +2127,11 @@ def _add_n():
         _res = inputs[0]
         for each in inputs[1:]:
             _res = _op.add(_res, each)
-        return  _res
+        return _res
+
     return _impl
 
+
 def _LSTMBlockCell():
     def _impl(inputs, attr, params, mod):
         """LSTM Block cell.
@@ -2023,42 +2162,48 @@ def _LSTMBlockCell():
         in_state_h = inputs[2]
         in_weight = inputs[3]
         in_bias = inputs[7]
-        forget_bias = attr.pop('forget_bias')
+        forget_bias = attr.pop("forget_bias")
         input_shape = _infer_shape(inputs[0], mod)
         weight_shape = _infer_shape(inputs[3], mod)
         batch_size, input_size = input_shape[0], input_shape[1]
         num_hidden_layers = weight_shape[1]
 
-        in_data = _op.reshape(in_data,
-                              newshape=(batch_size, input_size))
+        in_data = _op.reshape(in_data, newshape=(batch_size, input_size))
         ixh = _op.concatenate([in_data, in_state_h], axis=1)
         in_weight = _op.transpose(in_weight, axes=None)
-        gates = _op.nn.dense(ixh, in_weight,
-                             units=num_hidden_layers)
+        gates = _op.nn.dense(ixh, in_weight, units=num_hidden_layers)
         gates_bias = _op.add(gates, in_bias)
         gate_list = _op.split(gates_bias, indices_or_sections=4, axis=1)
         in_gate = _op.sigmoid(gate_list[0])
         in_transform = _op.tanh(gate_list[1])
-        forget_gate = _op.add(gate_list[2], tvm.relay.const(forget_bias, attr['T'].name))
+        forget_gate = _op.add(gate_list[2], tvm.relay.const(forget_bias, attr["T"].name))
         forget_gate = _op.sigmoid(forget_gate)
         out_gate = _op.sigmoid(gate_list[3])
-        next_c = _op.add(_op.multiply(forget_gate, in_state_c),
-                         _op.multiply(in_gate, in_transform))
+        next_c = _op.add(_op.multiply(forget_gate, in_state_c), _op.multiply(in_gate, in_transform))
         co = _op.tanh(next_c)
         next_h = out_gate * co
 
         return tvm.relay.TupleWrapper(
-            tvm.relay.Tuple([in_gate, next_c, forget_gate, out_gate, in_transform, co, next_h]), 7)
+            tvm.relay.Tuple([in_gate, next_c, forget_gate, out_gate, in_transform, co, next_h]), 7
+        )
 
     return _impl
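For reference, the gate arithmetic above matches the standard fused LSTM block cell; a compact NumPy restatement under assumed shapes (the function and its parameters are illustrative, with W packing the four gates in the same i, c', f, o order used above):

    import numpy as np

    def lstm_block_cell_ref(x, c_prev, h_prev, W, b, forget_bias=1.0):
        # x: (batch, input), h_prev / c_prev: (batch, hidden)
        # W: (input + hidden, 4 * hidden), b: (4 * hidden,)
        sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
        gates = np.concatenate([x, h_prev], axis=1) @ W + b
        i, cs, f, o = np.split(gates, 4, axis=1)
        c = sigmoid(f + forget_bias) * c_prev + sigmoid(i) * np.tanh(cs)
        h = sigmoid(o) * np.tanh(c)
        return c, h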
 
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
 # Operators that get pruned away when the complete graph is frozen.
 # These operators are not needed for inference.
-_freezed_graph_pruned_op_list = ['ReadVariableOp', 'ResourceGather', 'Variable',
-                                 'VariableV2', 'VarHandleOp', 'Assign', 'AssignVariableOp']
+_freezed_graph_pruned_op_list = [
+    "ReadVariableOp",
+    "ResourceGather",
+    "Variable",
+    "VariableV2",
+    "VarHandleOp",
+    "Assign",
+    "AssignVariableOp",
+]
 
 
 # _convert_map defines maps of name to converter functor(callable)
@@ -2067,171 +2212,172 @@ _freezed_graph_pruned_op_list = ['ReadVariableOp', 'ResourceGather', 'Variable',
 # for 1 to N mapping(composed), use custom callable functions
 # for N to 1 mapping, currently not supported(?)
 _convert_map = {
-    'Abs'                               : AttrCvt('abs'),
-    'Acos'                              : AttrCvt('acos'),
-    'Acosh'                             : AttrCvt('acosh'),
-    'Add'                               : _elemwise('add'),
-    'AddN'                              : _add_n(),
-    'AddV2'                             : _elemwise('add'),
-    'All'                               : _reduce('all'),
-    'Any'                               : _reduce('any'),
-    'ArgMax'                            : _argx(_op.argmax, 'argmax'),
-    'ArgMin'                            : _argx(_op.argmin, 'argmin'),
-    'Asin'                              : AttrCvt('asin'),
-    'Asinh'                             : AttrCvt('asinh'),
-    'Assert'                            : _assert(),
-    'Atan'                              : AttrCvt('atan'),
-    'Atanh'                             : AttrCvt('atanh'),
-    'Atan2'                             : _atan2(),
-    'AvgPool'                           : _pooling('avg_pool'),
-    'AvgPool3D'                         : _pool3d('avg_pool3d'),
-    'BatchMatMul'                       : _batch_matmul(),
-    'BatchMatMulV2'                     : _batch_matmul(),
-    'BatchNormWithGlobalNormalization'  : _batch_norm(),
-    'BatchToSpaceND'                    : _batch_to_space_nd(),
-    'BiasAdd'                           : _bias_add(),
-    'BroadcastTo'                       : _broadcast_to(),
-    'Cast'                              : _cast(),
-    'Ceil'                              : AttrCvt('ceil'),
-    'CheckNumerics'                     : _check_numerics(),
-    'ClipByValue'                       : _clip_by_value(),
-    'Concat'                            : _concat(),
-    'ConcatV2'                          : _concatV2(),
-    'Conv2D'                            : _conv('conv'),
-    'Conv2DBackpropInput'               : _conv('conv_transpose'),
-    'Conv3D'                            : _conv3d('conv'),
-    'Conv3DBackpropInputV2'             : _conv3d('conv_transpose'),
-    'Cos'                               : AttrCvt('cos'),
-    'Cosh'                              : AttrCvt('cosh'),
-    'CropAndResize'                     : _crop_and_resize(),
-    'DecodeJpeg'                        : _decode_image(),
-    'DepthToSpace'                      : _depth_to_space(),
-    'DepthwiseConv2dNative'             : _conv('depthwise'),
-    'Dilation2D'                        : _dilation2d(),
-    'Elu'                               : _elu(),
-    'Equal'                             : _broadcast('equal'),
-    'Erf'                               : AttrCvt('erf'),
-    'EuclideanNorm'                     : _euclidean_norm(),
-    'Exp'                               : AttrCvt('exp'),
-    'ExpandDims'                        : _expand_dims(),
-    'Fill'                              : _fill(),
-    'Floor'                             : AttrCvt('floor'),
-    'FloorDiv'                          : _floordiv(),
-    'FloorMod'                          : _floormod(),
-    'FusedBatchNorm'                    : _fused_batch_norm(),
-    'FusedBatchNormV2'                  : _fused_batch_norm(),
-    'FusedBatchNormV3'                  : _fused_batch_norm(),
-    'Gather'                            : _gather(),
-    'GatherNd'                          : _gather_nd(),
-    'GatherV2'                          : _gather(),
-    'Greater'                           : _broadcast('greater'),
-    'GreaterEqual'                      : _broadcast('greater_equal'),
-    'Identity'                          : _identity(),
-    'IsFinite'                          : AttrCvt('isfinite'),
-    'IsInf'                             : AttrCvt('isinf'),
-    'LeakyRelu'                         : AttrCvt('leaky_relu'),
-    'LeftShift'                         : AttrCvt('left_shift'),
-    'Less'                              : _broadcast('less'),
-    'LessEqual'                         : _broadcast('less_equal'),
-    'Log'                               : AttrCvt('log'),
-    'Log1p'                             : _log1p(),
-    'LogicalAnd'                        : _logical('logical_and'),
-    'LogicalNot'                        : _logical('logical_not'),
-    'LogicalOr'                         : _logical('logical_or'),
-    'LogSoftmax'                        : AttrCvt('log_softmax'),
-    'LRN'                               : _lrn(),
-    'LSTMBlockCell'                     : _LSTMBlockCell(),
-    'MatMul'                            : _matmul(),
-    'Max'                               : _reduce('max'),
-    'Maximum'                           : _elemwise('maximum'),
-    'MaxPool'                           : _pooling('max_pool'),
-    'MaxPool3D'                         : _pool3d('max_pool3d'),
-    'Mean'                              : _mean(),
-    'Min'                               : _reduce('min'),
-    'Minimum'                           : _elemwise('minimum'),
-    'MirrorPad'                         : _mirror_pad(),
-    'Mod'                               : _elemwise('mod'),
-    'Mul'                               : _elemwise('multiply'),
-    'Neg'                               : AttrCvt('negative'),
-    'NonMaxSuppressionV2'               : _nms(),
-    'NonMaxSuppressionV3'               : _nms(),
-    'NonMaxSuppressionV4'               : _nms(),
-    'NoOp'                              : _no_op(),
-    'NotEqual'                          : _broadcast('not_equal'),
-    'OneHot'                            : _one_hot(),
-    'Pack'                              : _pack(),
-    'Pad'                               : _pad('Pad'),
-    'PadV2'                             : _pad('PadV2'),
-    'Pow'                               : _elemwise('power'),
-    'Prod'                              : _prod(),
-    'Range'                             : _range(),
-    'Rank'                              : _rank(),
-    'RealDiv'                           : _elemwise('divide'),
-    'Relu'                              : AttrCvt('relu'),
-    'Relu6'                             : _relu6(),
-    'Reshape'                           : _reshape(),
-    'ResizeBicubic'                     : _resize('bilinear'),
-    'ResizeBilinear'                    : _resize('bilinear'),
-    'ResizeNearestNeighbor'             : _resize('nearest_neighbor'),
-    'ReverseV2'                         : _reverse_v2(),
-    'RightShift'                        : AttrCvt('right_shift'),
-    'Round'                             : AttrCvt('round'),
-    'Rsqrt'                             : _rsqrt(),
-    'Select'                            : _where(),
-    'Selu'                              : _selu(),
-    'Shape'                             : _shape(),
-    'Sigmoid'                           : AttrCvt('sigmoid'),
-    'Sign'                              : AttrCvt('sign'),
-    'Sin'                               : AttrCvt('sin'),
-    'Sinh'                              : AttrCvt('sinh'),
-    'Size'                              : _size(),
-    'Slice'                             : _slice(),
-    'Softmax'                           : _softmax(),
-    'Softplus'                          : _softplus(),
-    'SpaceToBatchND'                    : _space_to_batch_nd(),
-    'SpaceToDepth'                      : _space_to_depth(),
-    'Split'                             : _split(False),
-    'SplitV'                            : _split(True),
-    'Sqrt'                              : AttrCvt('sqrt'),
-    'Square'                            : _square(),
-    'SquaredDifference'                 : _squared_difference(),
-    'Squeeze'                           : _squeeze(),
-    'StopGradient'                      : _identity(),
-    'StridedSlice'                      : _stridedSlice(),
-    'Sub'                               : _elemwise('subtract'),
-    'Sum'                               : _sum(),
-    'Tan'                               : AttrCvt('tan'),
-    'Tanh'                              : AttrCvt('tanh'),
-    'TensorArrayConcatV3'               : _tensor_array_concat(),
-    'TensorArrayGatherV3'               : _tensor_array_gather(),
-    'TensorArrayReadV3'                 : _tensor_array_read(),
-    'TensorArrayScatterV3'              : _tensor_array_scatter(),
-    'TensorArraySizeV3'                 : _tensor_array_size(),
-    'TensorArraySplitV3'                : _tensor_array_split(),
-    'TensorArrayV3'                     : _tensor_array(),
-    'TensorArrayWriteV3'                : _tensor_array_write(),
-    'Tile'                              : _tile(),
-    'TopKV2'                            : _topk(),
-    'Transpose'                         : _transpose(),
-    'TruncateMod'                       : _elemwise('mod'),
-    'Unpack'                            : _unpack(),
-    'UnravelIndex'                      : _unravel_index(),
-    'Where'                             : _where(),
-    'ZerosLike'                         : AttrCvt('zeros_like'),
+    "Abs": AttrCvt("abs"),
+    "Acos": AttrCvt("acos"),
+    "Acosh": AttrCvt("acosh"),
+    "Add": _elemwise("add"),
+    "AddN": _add_n(),
+    "AddV2": _elemwise("add"),
+    "All": _reduce("all"),
+    "Any": _reduce("any"),
+    "ArgMax": _argx(_op.argmax, "argmax"),
+    "ArgMin": _argx(_op.argmin, "argmin"),
+    "Asin": AttrCvt("asin"),
+    "Asinh": AttrCvt("asinh"),
+    "Assert": _assert(),
+    "Atan": AttrCvt("atan"),
+    "Atanh": AttrCvt("atanh"),
+    "Atan2": _atan2(),
+    "AvgPool": _pooling("avg_pool"),
+    "AvgPool3D": _pool3d("avg_pool3d"),
+    "BatchMatMul": _batch_matmul(),
+    "BatchMatMulV2": _batch_matmul(),
+    "BatchNormWithGlobalNormalization": _batch_norm(),
+    "BatchToSpaceND": _batch_to_space_nd(),
+    "BiasAdd": _bias_add(),
+    "BroadcastTo": _broadcast_to(),
+    "Cast": _cast(),
+    "Ceil": AttrCvt("ceil"),
+    "CheckNumerics": _check_numerics(),
+    "ClipByValue": _clip_by_value(),
+    "Concat": _concat(),
+    "ConcatV2": _concatV2(),
+    "Conv2D": _conv("conv"),
+    "Conv2DBackpropInput": _conv("conv_transpose"),
+    "Conv3D": _conv3d("conv"),
+    "Conv3DBackpropInputV2": _conv3d("conv_transpose"),
+    "Cos": AttrCvt("cos"),
+    "Cosh": AttrCvt("cosh"),
+    "CropAndResize": _crop_and_resize(),
+    "DecodeJpeg": _decode_image(),
+    "DepthToSpace": _depth_to_space(),
+    "DepthwiseConv2dNative": _conv("depthwise"),
+    "Dilation2D": _dilation2d(),
+    "Elu": _elu(),
+    "Equal": _broadcast("equal"),
+    "Erf": AttrCvt("erf"),
+    "EuclideanNorm": _euclidean_norm(),
+    "Exp": AttrCvt("exp"),
+    "ExpandDims": _expand_dims(),
+    "Fill": _fill(),
+    "Floor": AttrCvt("floor"),
+    "FloorDiv": _floordiv(),
+    "FloorMod": _floormod(),
+    "FusedBatchNorm": _fused_batch_norm(),
+    "FusedBatchNormV2": _fused_batch_norm(),
+    "FusedBatchNormV3": _fused_batch_norm(),
+    "Gather": _gather(),
+    "GatherNd": _gather_nd(),
+    "GatherV2": _gather(),
+    "Greater": _broadcast("greater"),
+    "GreaterEqual": _broadcast("greater_equal"),
+    "Identity": _identity(),
+    "IsFinite": AttrCvt("isfinite"),
+    "IsInf": AttrCvt("isinf"),
+    "LeakyRelu": AttrCvt("leaky_relu"),
+    "LeftShift": AttrCvt("left_shift"),
+    "Less": _broadcast("less"),
+    "LessEqual": _broadcast("less_equal"),
+    "Log": AttrCvt("log"),
+    "Log1p": _log1p(),
+    "LogicalAnd": _logical("logical_and"),
+    "LogicalNot": _logical("logical_not"),
+    "LogicalOr": _logical("logical_or"),
+    "LogSoftmax": AttrCvt("log_softmax"),
+    "LRN": _lrn(),
+    "LSTMBlockCell": _LSTMBlockCell(),
+    "MatMul": _matmul(),
+    "Max": _reduce("max"),
+    "Maximum": _elemwise("maximum"),
+    "MaxPool": _pooling("max_pool"),
+    "MaxPool3D": _pool3d("max_pool3d"),
+    "Mean": _mean(),
+    "Min": _reduce("min"),
+    "Minimum": _elemwise("minimum"),
+    "MirrorPad": _mirror_pad(),
+    "Mod": _elemwise("mod"),
+    "Mul": _elemwise("multiply"),
+    "Neg": AttrCvt("negative"),
+    "NonMaxSuppressionV2": _nms(),
+    "NonMaxSuppressionV3": _nms(),
+    "NonMaxSuppressionV4": _nms(),
+    "NoOp": _no_op(),
+    "NotEqual": _broadcast("not_equal"),
+    "OneHot": _one_hot(),
+    "Pack": _pack(),
+    "Pad": _pad("Pad"),
+    "PadV2": _pad("PadV2"),
+    "Pow": _elemwise("power"),
+    "Prod": _prod(),
+    "Range": _range(),
+    "Rank": _rank(),
+    "RealDiv": _elemwise("divide"),
+    "Relu": AttrCvt("relu"),
+    "Relu6": _relu6(),
+    "Reshape": _reshape(),
+    "ResizeBicubic": _resize("bilinear"),
+    "ResizeBilinear": _resize("bilinear"),
+    "ResizeNearestNeighbor": _resize("nearest_neighbor"),
+    "ReverseV2": _reverse_v2(),
+    "RightShift": AttrCvt("right_shift"),
+    "Round": AttrCvt("round"),
+    "Rsqrt": _rsqrt(),
+    "Select": _where(),
+    "Selu": _selu(),
+    "Shape": _shape(),
+    "Sigmoid": AttrCvt("sigmoid"),
+    "Sign": AttrCvt("sign"),
+    "Sin": AttrCvt("sin"),
+    "Sinh": AttrCvt("sinh"),
+    "Size": _size(),
+    "Slice": _slice(),
+    "Softmax": _softmax(),
+    "Softplus": _softplus(),
+    "SpaceToBatchND": _space_to_batch_nd(),
+    "SpaceToDepth": _space_to_depth(),
+    "Split": _split(False),
+    "SplitV": _split(True),
+    "Sqrt": AttrCvt("sqrt"),
+    "Square": _square(),
+    "SquaredDifference": _squared_difference(),
+    "Squeeze": _squeeze(),
+    "StopGradient": _identity(),
+    "StridedSlice": _stridedSlice(),
+    "Sub": _elemwise("subtract"),
+    "Sum": _sum(),
+    "Tan": AttrCvt("tan"),
+    "Tanh": AttrCvt("tanh"),
+    "TensorArrayConcatV3": _tensor_array_concat(),
+    "TensorArrayGatherV3": _tensor_array_gather(),
+    "TensorArrayReadV3": _tensor_array_read(),
+    "TensorArrayScatterV3": _tensor_array_scatter(),
+    "TensorArraySizeV3": _tensor_array_size(),
+    "TensorArraySplitV3": _tensor_array_split(),
+    "TensorArrayV3": _tensor_array(),
+    "TensorArrayWriteV3": _tensor_array_write(),
+    "Tile": _tile(),
+    "TopKV2": _topk(),
+    "Transpose": _transpose(),
+    "TruncateMod": _elemwise("mod"),
+    "Unpack": _unpack(),
+    "UnravelIndex": _unravel_index(),
+    "Where": _where(),
+    "ZerosLike": AttrCvt("zeros_like"),
 }
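
The `_convert_map` just above keys TensorFlow op names to converter functors: an `AttrCvt` instance for 1-to-1 renames, or a custom callable for composed conversions. A minimal, hypothetical sketch of how such a table is consumed; `inputs` and `attrs` stand in for whatever the real frontend gathers per node, and this is not the converter's actual dispatch code:

def convert_node(op_name, inputs, attrs, convert_map=_convert_map):
    """Hypothetical dispatch: look up the functor for op_name and apply it."""
    if op_name not in convert_map:
        raise NotImplementedError("Operator {} is not supported.".format(op_name))
    # Each value is a callable returning a Relay expression, e.g.
    # AttrCvt("abs") or the closure produced by _elemwise("add").
    return convert_map[op_name](inputs, attrs)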
 
 # An internal list to contain all the control flow primitives used in Tensorflow
 # 1.x.
-_control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']
+_control_flow_nodes = ["Merge", "Switch", "NextIteration", "Exit", "Enter", "LoopCond"]
 
 # A map to record tensor array write ops and input ta/tensor indices
 # Value is (index of tensor array, index of written node)
 _tensor_array_write_ops = {
-    "TensorArrayWrite"   : (3, 2),
-    "TensorArrayScatter" : (0, 2),
-    "TensorArraySplit"   : (0, 1),
+    "TensorArrayWrite": (3, 2),
+    "TensorArrayScatter": (0, 2),
+    "TensorArraySplit": (0, 1),
 }
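
The `_tensor_array_write_ops` table records, per write-style op, which input index carries the tensor-array handle and which carries the written node. A hypothetical helper (not part of the frontend) showing how those index pairs could be applied to a NodeDef-like object exposing `.op` and `.input`:

def split_ta_write_inputs(tf_node, write_ops=_tensor_array_write_ops):
    """Hypothetical helper: return (tensor-array input, written-value input)."""
    # Keys are op-name prefixes, e.g. "TensorArrayWriteV3" matches "TensorArrayWrite".
    for op_prefix, (ta_idx, value_idx) in write_ops.items():
        if tf_node.op.startswith(op_prefix):
            return tf_node.input[ta_idx], tf_node.input[value_idx]
    return None, None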
 
+
 def is_tensor_array_constuctor(tf_node):
     """Check whether is tensor array constructor node."""
     is_ta = False
@@ -2240,10 +2386,11 @@ def is_tensor_array_constuctor(tf_node):
         is_ta = tf_node.op[len(ta_start)].isnumeric()
     return is_ta
 
+
 def find_parent_loop_name(node_name, while_loop_name_set):
     """Find name of direct parent while loop."""
     ploop_name = ""
-    name_prefix = node_name.rsplit('/', 1)[0]
+    name_prefix = node_name.rsplit("/", 1)[0]
     if name_prefix.startswith("^"):
         name_prefix = name_prefix[1:]
     for lname in while_loop_name_set:
@@ -2255,6 +2402,7 @@ def find_parent_loop_name(node_name, while_loop_name_set):
 
     return ploop_name
 
+
 def _in_while_loop(control_flow_node_map, op_name):
     """
     Check if a given control flow operator is part of a while loop execution
@@ -2277,8 +2425,8 @@ def _in_while_loop(control_flow_node_map, op_name):
         Return true if the operator is in a while loop execution frame,
     otherwise, return false.
     """
-    return op_name in control_flow_node_map and \
-            "LoopCond" in control_flow_node_map[op_name]
+    return op_name in control_flow_node_map and "LoopCond" in control_flow_node_map[op_name]
+
 
 class RewriteSubgraph(ExprMutator):
     """
@@ -2289,6 +2437,7 @@ class RewriteSubgraph(ExprMutator):
     rewrite_map : Dict[expr, expr]
         A dictionary containing a set of expr to var mappings.
     """
+
     def __init__(self, rewrite_map):
         ExprMutator.__init__(self)
         self.rewrite_map = rewrite_map
@@ -2298,10 +2447,12 @@ class RewriteSubgraph(ExprMutator):
             return self.rewrite_map[expr]
         return super().visit(expr)
 
+
 def rewrite_subgraph(expr, rewrites):
     """Rewrite loop body."""
     return RewriteSubgraph(rewrites).visit(expr)
 
+
 class Branch:
     """A class contains the components that are used to build up a Relay if
     node.
@@ -2360,6 +2511,7 @@ class Branch:
           }
         }
     """
+
     def __init__(self):
         self._if = None
         self.cond = None
@@ -2382,6 +2534,7 @@ class Branch:
             self._if = self._if_node()
         return self._if
 
+
 class VarChecker(ExprVisitor):
     """Check whether a Variable is used in loop body.
 
@@ -2390,6 +2543,7 @@ class VarChecker(ExprVisitor):
     var : relay.expr.Var
         Relay Variable to be checked.
     """
+
     def __init__(self, var):
         ExprVisitor.__init__(self)
         self._var = var
@@ -2400,6 +2554,7 @@ class VarChecker(ExprVisitor):
             self.used = True
         super().visit(expr)
 
+
 class Loop:
     """
     A class that contains the components used to build up a Relay
@@ -2447,6 +2602,7 @@ class Loop:
           %6
         }
     """
+
     def __init__(self, mod, loop_name, lvar2expr):
         self.cond = None
         self.body = []
@@ -2463,7 +2619,7 @@ class Loop:
         `while_loop` construct.
         """
         bind_map = {}
-        wl = tvm.relay.var('while_loop')
+        wl = tvm.relay.var("while_loop")
         sb = tvm.relay.scope_builder.ScopeBuilder()
 
         lv_list = []
@@ -2529,10 +2685,11 @@ class Loop:
 
 
 class GraphProto(object):
-    """ A helper class for handling relay graph copying from Tensorflow GraphDef.
+    """A helper class for handling relay graph copying from Tensorflow GraphDef.
     Definition:
         https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto
     """
+
     def __init__(self):
         self._nodes = {}
         self._tf_node_map = {}
@@ -2597,8 +2754,7 @@ class GraphProto(object):
         try:
             from tensorflow.python.framework import tensor_util
         except ImportError as e:
-            raise ImportError(
-                "Unable to import tensorflow which is required {}".format(e))
+            raise ImportError("Unable to import tensorflow which is required {}".format(e))
 
         missing_operators = self._parse_import_prerequisites(graph)
         control_flow_nodes = []
@@ -2612,53 +2768,63 @@ class GraphProto(object):
         if missing_operators:
             freezed_ops = [op for op in missing_operators if op in _freezed_graph_pruned_op_list]
             if freezed_ops:
-                raise Exception("Graph is not frozen. Provide a frozen graph. "
-                                "Found operators {}".format(freezed_ops))
+                raise Exception(
+                    "Graph is not frozen. Provide a frozen graph. "
+                    "Found operators {}".format(freezed_ops)
+                )
 
             raise NotImplementedError(
-                "The following operators are not implemented: {}".format(missing_operators))
+                "The following operators are not implemented: {}".format(missing_operators)
+            )
 
         for node in graph.node:
-            node_name_prefix = node.name.rsplit('/', 1)[0]
+            node_name_prefix = node.name.rsplit("/", 1)[0]
             self._control_flow_node_map[node_name_prefix].add(node.op)
             self._tf_node_map[node.name] = node
 
             # Parse output_shapes attribute
             parsed_attr = self._parse_attr(node.attr)
-            if '_output_shapes' in parsed_attr:
-                self._output_shapes[node.name] = \
-                    [tensor_util.TensorShapeProtoToList(tshape) \
-                     for tshape in parsed_attr['_output_shapes']]
+            if "_output_shapes" in parsed_attr:
+                self._output_shapes[node.name] = [
+                    tensor_util.TensorShapeProtoToList(tshape)
+                    for tshape in parsed_attr["_output_shapes"]
+                ]
             else:
                 self._output_shapes[node.name] = [None]
 
             # Parse placeholder and const here since input shape info is required.
-            if node.op == 'Placeholder' or node.op == 'PlaceholderWithDefault':
+            if node.op == "Placeholder" or node.op == "PlaceholderWithDefault":
                 # Give priority to user argument.
                 if shape and node.name in shape:
                     self._input_shapes[node.name] = list(shape[node.name])
                 else:
-                    self._input_shapes[node.name] = \
-                        tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
+                    self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(
+                        node.attr["shape"].shape
+                    )
                     for idx, dim in enumerate(self._input_shapes[node.name]):
                         if dim < 0:
                             self._input_shapes[node.name][idx] = Any()
 
                 self._output_shapes[node.name] = [self._input_shapes[node.name]]
                 attr = self._parse_attr(node.attr)
-                self._nodes[node.name] = [_expr.var(node.name,
-                                                    shape=self._input_shapes[node.name],
-                                                    dtype=attr['dtype'].name)]
+                self._nodes[node.name] = [
+                    _expr.var(
+                        node.name, shape=self._input_shapes[node.name], dtype=attr["dtype"].name
+                    )
+                ]
 
                 # Ignore user's input shape for Non placeholder
-            elif node.op == 'Const':
-                tensor_value = node.attr['value'].tensor
-                self._input_shapes[node.name] = \
-                    tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
+            elif node.op == "Const":
+                tensor_value = node.attr["value"].tensor
+                self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(
+                    tensor_value.tensor_shape
+                )
                 self._output_shapes[node.name] = [self._input_shapes[node.name]]
                 if shape and node.name in shape:
-                    warnings.warn("Ignore the passed shape. Shape in graphdef "
-                                  "will be used for operator %s." % node.name)
+                    warnings.warn(
+                        "Ignore the passed shape. Shape in graphdef "
+                        "will be used for operator %s." % node.name
+                    )
                 for key, value in node.attr.items():
                     self._parse_param(key, value, node.name, self._in_shape)
             elif node.op in _control_flow_nodes:
@@ -2778,7 +2944,7 @@ class GraphProto(object):
         return func
 
     def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
-        """ Wrapper to _get_relay_func which converts Tensorflow graph to Relay function
+        """Wrapper to _get_relay_func which converts Tensorflow graph to Relay function
         which is used as main function for the Relay module
         """
         func = self._get_relay_func(graph, layout=layout, shape=shape, outputs=outputs)
@@ -2786,26 +2952,29 @@ class GraphProto(object):
         return self._mod, self._params
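
Given the `from_tensorflow` wrapper above, which returns the accumulated Relay module and params, a minimal usage sketch; the file path and the "input" shape entry are placeholders, not values from this commit:

from tensorflow.core.framework import graph_pb2

# Load a frozen GraphDef from a placeholder path, then convert it.
graph_def = graph_pb2.GraphDef()
with open("frozen_model.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

g = GraphProto()
mod, params = g.from_tensorflow(
    graph_def, layout="NHWC", shape={"input": (1, 224, 224, 3)}
)
print(mod)  # Relay module whose "main" is the converted function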
 
     def _parse_import_prerequisites(self, graph):
-        """ Calculate the named preconditions from TensorFlow `graph`.
-            Return prerequisites for parsing:
-            a. Set of operator names which don't have their mapping in TVM, i.e.
-                which are not supported
+        """Calculate the named preconditions from TensorFlow `graph`.
+        Return prerequisites for parsing:
+        a. Set of operator names which don't have their mapping in TVM, i.e.
+            which are not supported
         """
         missing_operators = set()
         from tensorflow.python.framework import op_def_registry
+
         for node in graph.node:
-            getOpDef = op_def_registry._registered_ops.get if hasattr(op_def_registry,\
-                        "_registered_ops") else op_def_registry.get
+            getOpDef = (
+                op_def_registry._registered_ops.get
+                if hasattr(op_def_registry, "_registered_ops")
+                else op_def_registry.get
+            )
             op_def = getOpDef(node.op)
-            if node.op == "Placeholder" or node.op == 'PlaceholderWithDefault':
+            if node.op == "Placeholder" or node.op == "PlaceholderWithDefault":
                 pass
             elif node.op == "Const":
                 pass
             elif node.op in ["PartitionedCall", "StatefulPartitionedCall"]:
                 pass
             else:
-                if any([node.op in t for t in [_identity_list, _convert_map,
-                                               _control_flow_nodes]]):
+                if any([node.op in t for t in [_identity_list, _convert_map, _control_flow_nodes]]):
                     pass
                 elif op_def is not None and op_def.is_stateful:
                     missing_operators.add(node.op)
@@ -2818,10 +2987,9 @@ class GraphProto(object):
         try:
             from tensorflow.python.framework import tensor_util
         except ImportError as e:
-            raise ImportError(
-                "Unable to import tensorflow which is required {}".format(e))
+            raise ImportError("Unable to import tensorflow which is required {}".format(e))
 
-        if key == 'value':
+        if key == "value":
             np_array = tensor_util.MakeNdarray(value.tensor)
 
             if np_array.dtype == np.dtype(object):
@@ -2831,7 +2999,7 @@ class GraphProto(object):
                     var_shape = shape[name]
                 else:
                     var_shape = tensor_util.TensorShapeProtoToList(value.tensor.tensor_shape)
-                self._nodes[name] = [_expr.var(name, shape=var_shape, dtype='uint8')]
+                self._nodes[name] = [_expr.var(name, shape=var_shape, dtype="uint8")]
                 return
 
             array_ndim = len(np_array.shape)
@@ -2839,13 +3007,14 @@ class GraphProto(object):
                 self._nodes[name] = [tvm.relay.const(np_array, np_array.dtype)]
             else:
                 self._params[name] = tvm.nd.array(np_array)
-                self._nodes[name] = [_expr.var(name,
-                                               shape=self._params[name].shape,
-                                               dtype=self._params[name].dtype)]
+                self._nodes[name] = [
+                    _expr.var(name, shape=self._params[name].shape, dtype=self._params[name].dtype)
+                ]
         else:
-            if key not in ('dtype', '_output_shapes', '_class'):
-                raise NotImplementedError \
-                    ("Other attributes for a Const(param) Node {} ? .".format(key))
+            if key not in ("dtype", "_output_shapes", "_class"):
+                raise NotImplementedError(
+                    "Other attributes for a Const(param) Node {} ? .".format(key)
+                )
 
     def _get_attr(self, buf):
         """Returns the value of the attr of this buf with the given `name`.
@@ -2868,8 +3037,7 @@ class GraphProto(object):
         try:
             from tensorflow.python.framework import dtypes
         except ImportError as e:
-            raise ImportError(
-                "Unable to import tensorflow which is required {}".format(e))
+            raise ImportError("Unable to import tensorflow which is required {}".format(e))
 
         # Treat an empty oneof value as an empty list.
         if not x.WhichOneof("value"):
@@ -2924,15 +3092,13 @@ class GraphProto(object):
         op : tvm.relay.Expr
             Converted relay expression.
         """
-        node_name_prefix = node.name.rsplit('/', 1)[0]
+        node_name_prefix = node.name.rsplit("/", 1)[0]
         plname = find_parent_loop_name(node.name, self._while_loop_name_set)
         if node.op == "Merge":
             if _in_while_loop(self._control_flow_node_map, node_name_prefix):
                 op = self._licm_construct(plname, node.input[0])
                 if node_name_prefix not in self._loops:
-                    self._loops[node_name_prefix] = Loop(self._mod,
-                                                         plname,
-                                                         self._lvar2expr)
+                    self._loops[node_name_prefix] = Loop(self._mod, plname, self._lvar2expr)
             else:
                 if node_name_prefix not in self._branches:
                     switch_prefix = node_name_prefix + "/Switch"
@@ -2951,8 +3117,9 @@ class GraphProto(object):
                 op = branch.if_node()
                 if node_name_prefix not in self._while_loop_name_set:
                     try:
-                        cond_val = np.all(_infer_value(branch.cond, self._params,
-                                                       self._mod).asnumpy())
+                        cond_val = np.all(
+                            _infer_value(branch.cond, self._params, self._mod).asnumpy()
+                        )
                         if cond_val:
                             op = branch.true_branch
                         else:
@@ -2973,8 +3140,8 @@ class GraphProto(object):
                             loop_vars.append(loop.loop_vars[j])
                 loop.loop_vars = loop_vars
                 loop.aligned = True
-            exit_name = node.name.split('/')[-1]
-            if '_' in exit_name:
+            exit_name = node.name.split("/")[-1]
+            if "_" in exit_name:
                 exit_number = int(exit_name[5:])
             else:
                 exit_number = 0
@@ -2999,8 +3166,9 @@ class GraphProto(object):
                 if node.name.endswith("Switch"):
                     self._loop_var_order[node_name_prefix].append(0)
                 else:
-                    self._loop_var_order[node_name_prefix].\
-                        append(int(node.name.split("Switch_")[-1]))
+                    self._loop_var_order[node_name_prefix].append(
+                        int(node.name.split("Switch_")[-1])
+                    )
                 self._loops[node_name_prefix].loop_vars.append(op)
             else:
                 if node_name_prefix not in self._branches:
@@ -3012,13 +3180,13 @@ class GraphProto(object):
             if node.name.endswith("NextIteration"):
                 self._loop_body_order[node_name_prefix].append(0)
             else:
-                self._loop_body_order[node_name_prefix].\
-                    append(int(node.name.split("NextIteration_")[-1]))
+                self._loop_body_order[node_name_prefix].append(
+                    int(node.name.split("NextIteration_")[-1])
+                )
             op = self._licm_construct(plname, node.input[0])
             self._loops[node_name_prefix].body.append(op)
         else:
-            raise Exception("Cannot identify control flow operator: " +
-                            "{}".format(node.op))
+            raise Exception("Cannot identify control flow operator: " + "{}".format(node.op))
 
         return op
 
@@ -3048,25 +3216,27 @@ class GraphProto(object):
         try:
             from tensorflow.python.framework import function_def_to_graph
         except ImportError as e:
-            raise ImportError(
-                "Unable to import tensorflow which is required {}".format(e))
+            raise ImportError("Unable to import tensorflow which is required {}".format(e))
 
         main_graph_proto = self._main_graph_proto
         outer_graph_def = main_graph_proto._graph
 
-        node_func_name = attr.get('f').name
-        func = next((f for f in outer_graph_def.library.function
-                     if f.signature.name == node_func_name), None)
+        node_func_name = attr.get("f").name
+        func = next(
+            (f for f in outer_graph_def.library.function if f.signature.name == node_func_name),
+            None,
+        )
         if func:
             devices = set(node.device for node in func.node_def)
             if len(devices) > 1:
-                raise Exception("Found inconsistent Device assignment in the "\
-                                "Stateful Partitioned SubGraph. Rejecting "\
-                                "the subgraph ")
+                raise Exception(
+                    "Found inconsistent Device assignment in the "
+                    "Stateful Partitioned SubGraph. Rejecting "
+                    "the subgraph "
+                )
             # Convert function definition to graph
             func_input_shapes = func.attr["_input_shapes"].list.shape
-            subgraph, _ = function_def_to_graph.\
-                function_def_to_graph_def(func, func_input_shapes)
+            subgraph, _ = function_def_to_graph.function_def_to_graph_def(func, func_input_shapes)
 
             # Computing subgraph's input shape dictionary
             subgraph_shape_dict, input_expr_dict = {}, {}
@@ -3074,7 +3244,7 @@ class GraphProto(object):
                 input_expr_dict[f_arg.name] = input
                 subgraph_shape_dict[f_arg.name] = _infer_shape(input, main_graph_proto._mod)
 
-            func_name = 'func_{}'.format(func.signature.name)
+            func_name = "func_{}".format(func.signature.name)
             try:
                 global_func = main_graph_proto._mod[func_name]
                 sub_func = global_func
@@ -3107,8 +3277,9 @@ class GraphProto(object):
             raise Exception("Function not found - {}".format(node_func_name))
         return ret
 
-    def _convert_operator(self, op_name, inputs, attrs,
-                          graph, identity_list=None, convert_map=None):
+    def _convert_operator(
+        self, op_name, inputs, attrs, graph, identity_list=None, convert_map=None
+    ):
         """Convert from Tensorflow operator to relay operator.
         The converter must specify conversions explicitly for incompatible name, and
         apply handlers to operator attributes.
@@ -3168,7 +3339,7 @@ class GraphProto(object):
             Converted relay expression or loop var.
         """
         actual_expr = self._backtrack_construct(node_name)
-        tn = node_name.split(':')
+        tn = node_name.split(":")
         node_name = tn[0].split("^")[-1]
         cloop_name = find_parent_loop_name(node_name, self._while_loop_name_set)
 
@@ -3219,10 +3390,9 @@ class GraphProto(object):
         try:
             from tensorflow.python.framework import tensor_util
         except ImportError as e:
-            raise ImportError(
-                "Unable to import tensorflow which is required {}".format(e))
+            raise ImportError("Unable to import tensorflow which is required {}".format(e))
 
-        input_op_name = node_name.split(':')[0].split("^")[-1]
+        input_op_name = node_name.split(":")[0].split("^")[-1]
 
         if input_op_name not in self._nodes:
             node = self._tf_node_map[input_op_name]
@@ -3230,9 +3400,9 @@ class GraphProto(object):
 
             if node.op in _control_flow_nodes:
                 attr = self._parse_attr(node.attr)
-                op = self._convert_control_flow_operator(node, [],
-                                                         attr,
-                                                         self._control_flow_node_map)
+                op = self._convert_control_flow_operator(
+                    node, [], attr, self._control_flow_node_map
+                )
             else:
                 attr["_output_shapes"] = self._output_shapes[input_op_name]
                 attr["_node_name"] = node.name
@@ -3244,7 +3414,7 @@ class GraphProto(object):
 
                 # For TensorArrayV3 op, we need to infer shape first
                 if is_tensor_array_constuctor(node):
-                    raw_elem_shape = tensor_util.TensorShapeProtoToList(attr['element_shape'])
+                    raw_elem_shape = tensor_util.TensorShapeProtoToList(attr["element_shape"])
                     elem_shape = []
                     for dim in raw_elem_shape:
                         if dim < 0:
@@ -3254,9 +3424,10 @@ class GraphProto(object):
 
                     if elem_shape:
                         attr["shape"] = elem_shape
-                    if attr['identical_element_shapes'] or elem_shape:
-                        shape_node, wnode_op, output_index = \
-                            self._tensor_array_shape_nodes[node.name]
+                    if attr["identical_element_shapes"] or elem_shape:
+                        shape_node, wnode_op, output_index = self._tensor_array_shape_nodes[
+                            node.name
+                        ]
                         name = shape_node.name
                         if output_index > 0:
                             name += ":" + str(output_index)
@@ -3286,9 +3457,13 @@ class GraphProto(object):
 
             if isinstance(op, np.ndarray):
                 self._params[node.name] = tvm.nd.array(op)
-                op = [_expr.var(node.name,
-                                shape=self._params[node.name].shape,
-                                dtype=self._params[node.name].dtype)]
+                op = [
+                    _expr.var(
+                        node.name,
+                        shape=self._params[node.name].shape,
+                        dtype=self._params[node.name].dtype,
+                    )
+                ]
 
             elif isinstance(op, (_expr.Expr, _expr.TupleGetItem)):
                 op = [op]
@@ -3298,7 +3473,7 @@ class GraphProto(object):
         out = self._nodes[input_op_name]
 
         if isinstance(out, _expr.TupleWrapper):
-            tn = node_name.split(':')
+            tn = node_name.split(":")
             tensor_slot = int(tn[1]) if len(tn) > 1 else 0
             return out[tensor_slot]
 
@@ -3306,14 +3481,14 @@ class GraphProto(object):
 
 
 class SubGraphProto(GraphProto):
-    """ A helper class for handling relay subgraph copying from Tensorflow GraphDef.
-    """
+    """A helper class for handling relay subgraph copying from Tensorflow GraphDef."""
+
     def __init__(self, main_graph_proto):
         super().__init__()
         self._main_graph_proto = main_graph_proto  # holds main graph proto object
 
     def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
-        """ Wrapper to _get_relay_func which converts Tensorflow graph to Relay function.
+        """Wrapper to _get_relay_func which converts Tensorflow graph to Relay function.
         Return Relay function and params
         """
         func = self._get_relay_func(graph, layout=layout, shape=shape, outputs=outputs)
index 22c6f94..a176e12 100644
@@ -44,6 +44,7 @@ class TFParser(object):
 
     def __init__(self, model_dir, outputs=None):
         from tensorflow.core.framework import graph_pb2
+
         self._tmp_dir = util.tempdir()
         self._model_dir = model_dir
         self._graph = graph_pb2.GraphDef()
@@ -67,15 +68,17 @@ class TFParser(object):
     def _get_tag_set(self):
         """Return the tag set of saved model, multiple metagraphs are not supported"""
         try:
-            from tensorflow.contrib.saved_model.python.saved_model.reader \
-                import get_saved_model_tag_sets
+            from tensorflow.contrib.saved_model.python.saved_model.reader import (
+                get_saved_model_tag_sets,
+            )
         except ImportError:
             try:
                 from tensorflow.python.tools.saved_model_utils import get_saved_model_tag_sets
             except ImportError:
                 raise ImportError(
                     "InputConfiguration: Unable to import get_saved_model_tag_sets which is "
-                    "required to get tag set from saved model.")
+                    "required to get tag set from saved model."
+                )
         tag_sets = get_saved_model_tag_sets(self._model_dir)
         return tag_sets[0]
 
@@ -86,13 +89,12 @@ class TFParser(object):
         except ImportError:
             raise ImportError(
                 "InputConfiguration: Unable to import tensorflow which is "
-                "required to restore from saved model.")
+                "required to restore from saved model."
+            )
         tags = self._get_tag_set()
         output_names = set()
         with tf.Session() as sess:
-            meta_graph_def = tf.saved_model.loader.load(sess,
-                                                        tags,
-                                                        self._model_dir)
+            meta_graph_def = tf.saved_model.loader.load(sess, tags, self._model_dir)
             for sig_def in meta_graph_def.signature_def.values():
                 for output_tensor in sig_def.outputs.values():
                     output_names.add(output_tensor.name.replace(":0", ""))
@@ -109,7 +111,8 @@ class TFParser(object):
         except ImportError:
             raise ImportError(
                 "InputConfiguration: Unable to import tensorflow which is "
-                "required to restore from saved model.")
+                "required to restore from saved model."
+            )
 
         saved_model_dir = self._model_dir
         output_graph_filename = self._tmp_dir.relpath("tf_frozen_model.pb")
@@ -126,25 +129,38 @@ class TFParser(object):
         input_graph_filename = None
         saved_model_tags = ",".join(self._get_tag_set())
 
-        freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
-                                  input_binary, checkpoint_path, output_node_names,
-                                  restore_op_name, filename_tensor_name,
-                                  output_graph_filename, clear_devices, "", "", "",
-                                  input_meta_graph, input_saved_model_dir,
-                                  saved_model_tags)
+        freeze_graph.freeze_graph(
+            input_graph_filename,
+            input_saver_def_path,
+            input_binary,
+            checkpoint_path,
+            output_node_names,
+            restore_op_name,
+            filename_tensor_name,
+            output_graph_filename,
+            clear_devices,
+            "",
+            "",
+            "",
+            input_meta_graph,
+            input_saved_model_dir,
+            saved_model_tags,
+        )
 
         with ops.Graph().as_default():
             output_graph_def = graph_pb2.GraphDef()
             with open(output_graph_filename, "rb") as f:
                 output_graph_def.ParseFromString(f.read())
-            output_graph_def = graph_util.remove_training_nodes(output_graph_def,
-                                                                protected_nodes=self._outputs)
+            output_graph_def = graph_util.remove_training_nodes(
+                output_graph_def, protected_nodes=self._outputs
+            )
             return output_graph_def
 
     def _load_ckpt(self):
         """TODO: Load checkpoint model."""
-        raise RuntimeError("InputConfiguration: Loading tf checkpoint model is "
-                           "not supported yet.")
+        raise RuntimeError(
+            "InputConfiguration: Loading tf checkpoint model is " "not supported yet."
+        )
 
     def parse(self):
         """
@@ -167,8 +183,7 @@ class TFParser(object):
                 graph = self._load_ckpt()
         elif os.path.isfile(self._model_dir):
             # Only .pb or .pbtxt is a valid suffix name.
-            if self._model_dir.endswith(".pb") or \
-               self._model_dir.endswith(".pbtxt"):
+            if self._model_dir.endswith(".pb") or self._model_dir.endswith(".pbtxt"):
                 cur_dir = os.path.dirname(self._model_dir)
             else:
                 raise RuntimeError("InputConfiguration: Invalid model format.")
@@ -181,8 +196,7 @@ class TFParser(object):
             else:
                 graph = self._load_pb_file()
         else:
-            raise RuntimeError("InputConfiguration: Unrecognized model "
-                               "file or path.")
+            raise RuntimeError("InputConfiguration: Unrecognized model " "file or path.")
 
         self._set_graph(graph)
         return graph
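
The `parse` method above dispatches on whether the path is a saved-model directory, a checkpoint, or a frozen .pb/.pbtxt file, and returns a GraphDef. A minimal usage sketch with a placeholder path:

# Illustrative only: the path is a placeholder for a saved-model directory
# or a frozen .pb/.pbtxt file.
parser = TFParser("/path/to/frozen_model.pb")
graph_def = parser.parse()
print(len(graph_def.node), "nodes in the parsed GraphDef")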
index 59ba9f4..1b09cf3 100644
@@ -34,18 +34,22 @@ from .common import infer_shape as _infer_shape
 from .tflite_flexbuffer import FlexBufferDecoder
 
 
-__all__ = ['from_tflite']
+__all__ = ["from_tflite"]
+
 
 class TensorWrapper(object):
     """Tensor wrapper for TFLite Tensor"""
+
     def __init__(self, tensor_idx, tensor, buffer, qnn_params=None):
         self.tensor_idx = tensor_idx
         self.tensor = tensor
         self.buffer = buffer
         self.qnn_params = qnn_params
 
+
 class OperatorConverter(object):
     """Operator Converted for converting TFLite ops to Relay ops"""
+
     def __init__(self, model, subgraph, exp_tab):
 
         try:
@@ -64,107 +68,107 @@ class OperatorConverter(object):
 
         # Add more operators
         self.convert_map = {
-            'ABS': self.convert_abs,
-            'ADD': self.convert_add,
-            'ADD_N': self.convert_add_n,
-            'ARG_MAX': self.convert_arg_max,
-            'ARG_MIN': self.convert_arg_min,
-            'AVERAGE_POOL_2D': self.convert_average_pool2d,
-            'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd,
-            'CAST': self.convert_cast,
-            'CEIL': self.convert_ceil,
-            'CONCATENATION': self.convert_concatenation,
-            'CONV_2D': self.convert_conv2d,
-            'COS': self.convert_cos,
-            'DEPTH_TO_SPACE': self.convert_depth_to_space,
-            'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d,
-            'DEQUANTIZE': self.convert_dequantize,
-            'DETECTION_POSTPROCESS': self.convert_detection_postprocess,
-            'DIV': self.convert_div,
-            'ELU': self.convert_elu,
-            'EQUAL': self.convert_equal,
-            'EXP': self.convert_exp,
-            'EXPAND_DIMS': self.convert_expand_dims,
-            'FILL': self.convert_fill,
-            'FLOOR_DIV': self.convert_floor_div,
-            'FLOOR_MOD': self.convert_floor_mod,
-            'FLOOR': self.convert_floor,
-            'FULLY_CONNECTED': self.convert_fully_connected,
-            'GATHER': self.convert_gather,
-            'GATHER_ND' : self.convert_gather_nd,
-            'GREATER_EQUAL': self.convert_greater_equal,
-            'GREATER': self.convert_greater,
-            'HARD_SWISH': self.convert_hard_swish,
-            'L2_NORMALIZATION': self.convert_l2_normalization,
-            'L2_POOL_2D': self.convert_l2_pool2d,
-            'LEAKY_RELU': self.convert_leaky_relu,
-            'LESS_EQUAL': self.convert_less_equal,
-            'LESS': self.convert_less,
-            'LOCAL_RESPONSE_NORMALIZATION': self.convert_lrn,
-            'LOG': self.convert_log,
-            'LOG_SOFTMAX': self.convert_log_softmax,
-            'LOGICAL_AND': self.convert_logical_and,
-            'LOGICAL_NOT': self.convert_logical_not,
-            'LOGICAL_OR': self.convert_logical_or,
-            'LOGISTIC': self.convert_logistic,
-            'MATRIX_DIAG': self.convert_matrix_diag,
-            'MATRIX_SET_DIAG': self.convert_matrix_set_diag,
-            'MAX_POOL_2D': self.convert_max_pool2d,
-            'MAXIMUM': self.convert_maximum,
-            'MEAN': self.convert_reduce_mean,
-            'MINIMUM': self.convert_minimum,
-            'MIRROR_PAD': self.convert_mirror_pad,
-            'MUL': self.convert_mul,
-            'NEG': self.convert_neg,
-            'NOT_EQUAL': self.convert_not_equal,
-            'ONE_HOT': self.convert_one_hot,
-            'PACK': self.convert_pack,
-            'PAD': self.convert_pad,
-            'PADV2': self.convert_pad,
-            'POW': self.convert_pow,
-            'PRELU': self.convert_prelu,
-            'RANGE': self.convert_range,
-            'QUANTIZE': self.convert_quantize,
-            'REDUCE_ANY': self.convert_reduce_any,
-            'REDUCE_MAX': self.convert_reduce_max,
-            'REDUCE_MIN': self.convert_reduce_min,
-            'REDUCE_PROD': self.convert_reduce_prod,
-            'RELU':self.convert_relu,
-            'RELU6': self.convert_relu6,
-            'RELU_N1_TO_1': self.convert_relu_n1_to_1,
-            'RESHAPE': self.convert_reshape,
-            'RESIZE_BILINEAR': self.convert_resize_bilinear,
-            'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor,
-            'ROUND': self.convert_round,
-            'RSQRT': self.convert_rsqrt,
-            'REVERSE_SEQUENCE': self.convert_reverse_sequence,
-            'REVERSE_V2': self.convert_reverse_v2,
-            'SELECT': self.convert_select,
-            'SHAPE': self.convert_shape,
-            'SIN': self.convert_sin,
-            'SLICE': self.convert_slice,
-            'SOFTMAX': self.convert_softmax,
-            'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
-            'SPACE_TO_DEPTH': self.convert_space_to_depth,
-            'SPARSE_TO_DENSE': self.convert_sparse_to_dense,
-            'SPLIT': self.convert_split,
-            'SPLIT_V': self.convert_split_v,
-            'SQRT': self.convert_sqrt,
-            'SQUARE': self.convert_square,
-            'SQUARED_DIFFERENCE': self.convert_squared_difference,
-            'SQUEEZE': self.convert_squeeze,
-            'STRIDED_SLICE': self.convert_strided_slice,
-            'SUB': self.convert_sub,
-            'SUM': self.convert_reduce_sum,
-            'TAN': self.convert_tan,
-            'TANH':self.convert_tanh,
-            'TILE': self.convert_tile,
-            'TOPK_V2': self.convert_topk_v2,
-            'TRANSPOSE_CONV': self.convert_transpose_conv,
-            'TRANSPOSE': self.convert_transpose,
-            'UNPACK': self.convert_unpack,
-            'WHERE': self.convert_select,
-            'ZEROS_LIKE': self.convert_zeros_like,
+            "ABS": self.convert_abs,
+            "ADD": self.convert_add,
+            "ADD_N": self.convert_add_n,
+            "ARG_MAX": self.convert_arg_max,
+            "ARG_MIN": self.convert_arg_min,
+            "AVERAGE_POOL_2D": self.convert_average_pool2d,
+            "BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd,
+            "CAST": self.convert_cast,
+            "CEIL": self.convert_ceil,
+            "CONCATENATION": self.convert_concatenation,
+            "CONV_2D": self.convert_conv2d,
+            "COS": self.convert_cos,
+            "DEPTH_TO_SPACE": self.convert_depth_to_space,
+            "DEPTHWISE_CONV_2D": self.convert_depthwise_conv2d,
+            "DEQUANTIZE": self.convert_dequantize,
+            "DETECTION_POSTPROCESS": self.convert_detection_postprocess,
+            "DIV": self.convert_div,
+            "ELU": self.convert_elu,
+            "EQUAL": self.convert_equal,
+            "EXP": self.convert_exp,
+            "EXPAND_DIMS": self.convert_expand_dims,
+            "FILL": self.convert_fill,
+            "FLOOR_DIV": self.convert_floor_div,
+            "FLOOR_MOD": self.convert_floor_mod,
+            "FLOOR": self.convert_floor,
+            "FULLY_CONNECTED": self.convert_fully_connected,
+            "GATHER": self.convert_gather,
+            "GATHER_ND": self.convert_gather_nd,
+            "GREATER_EQUAL": self.convert_greater_equal,
+            "GREATER": self.convert_greater,
+            "HARD_SWISH": self.convert_hard_swish,
+            "L2_NORMALIZATION": self.convert_l2_normalization,
+            "L2_POOL_2D": self.convert_l2_pool2d,
+            "LEAKY_RELU": self.convert_leaky_relu,
+            "LESS_EQUAL": self.convert_less_equal,
+            "LESS": self.convert_less,
+            "LOCAL_RESPONSE_NORMALIZATION": self.convert_lrn,
+            "LOG": self.convert_log,
+            "LOG_SOFTMAX": self.convert_log_softmax,
+            "LOGICAL_AND": self.convert_logical_and,
+            "LOGICAL_NOT": self.convert_logical_not,
+            "LOGICAL_OR": self.convert_logical_or,
+            "LOGISTIC": self.convert_logistic,
+            "MATRIX_DIAG": self.convert_matrix_diag,
+            "MATRIX_SET_DIAG": self.convert_matrix_set_diag,
+            "MAX_POOL_2D": self.convert_max_pool2d,
+            "MAXIMUM": self.convert_maximum,
+            "MEAN": self.convert_reduce_mean,
+            "MINIMUM": self.convert_minimum,
+            "MIRROR_PAD": self.convert_mirror_pad,
+            "MUL": self.convert_mul,
+            "NEG": self.convert_neg,
+            "NOT_EQUAL": self.convert_not_equal,
+            "ONE_HOT": self.convert_one_hot,
+            "PACK": self.convert_pack,
+            "PAD": self.convert_pad,
+            "PADV2": self.convert_pad,
+            "POW": self.convert_pow,
+            "PRELU": self.convert_prelu,
+            "RANGE": self.convert_range,
+            "QUANTIZE": self.convert_quantize,
+            "REDUCE_ANY": self.convert_reduce_any,
+            "REDUCE_MAX": self.convert_reduce_max,
+            "REDUCE_MIN": self.convert_reduce_min,
+            "REDUCE_PROD": self.convert_reduce_prod,
+            "RELU": self.convert_relu,
+            "RELU6": self.convert_relu6,
+            "RELU_N1_TO_1": self.convert_relu_n1_to_1,
+            "RESHAPE": self.convert_reshape,
+            "RESIZE_BILINEAR": self.convert_resize_bilinear,
+            "RESIZE_NEAREST_NEIGHBOR": self.convert_resize_nearest_neighbor,
+            "ROUND": self.convert_round,
+            "RSQRT": self.convert_rsqrt,
+            "REVERSE_SEQUENCE": self.convert_reverse_sequence,
+            "REVERSE_V2": self.convert_reverse_v2,
+            "SELECT": self.convert_select,
+            "SHAPE": self.convert_shape,
+            "SIN": self.convert_sin,
+            "SLICE": self.convert_slice,
+            "SOFTMAX": self.convert_softmax,
+            "SPACE_TO_BATCH_ND": self.convert_space_to_batch_nd,
+            "SPACE_TO_DEPTH": self.convert_space_to_depth,
+            "SPARSE_TO_DENSE": self.convert_sparse_to_dense,
+            "SPLIT": self.convert_split,
+            "SPLIT_V": self.convert_split_v,
+            "SQRT": self.convert_sqrt,
+            "SQUARE": self.convert_square,
+            "SQUARED_DIFFERENCE": self.convert_squared_difference,
+            "SQUEEZE": self.convert_squeeze,
+            "STRIDED_SLICE": self.convert_strided_slice,
+            "SUB": self.convert_sub,
+            "SUM": self.convert_reduce_sum,
+            "TAN": self.convert_tan,
+            "TANH": self.convert_tanh,
+            "TILE": self.convert_tile,
+            "TOPK_V2": self.convert_topk_v2,
+            "TRANSPOSE_CONV": self.convert_transpose_conv,
+            "TRANSPOSE": self.convert_transpose,
+            "UNPACK": self.convert_unpack,
+            "WHERE": self.convert_select,
+            "ZEROS_LIKE": self.convert_zeros_like,
         }
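
As with the TensorFlow table earlier, `self.convert_map` keys the TFLite builtin op name to a bound conversion method. A hedged, standalone sketch of the lookup step (this mirrors, but is not, the converter's own code; `converter` is an OperatorConverter and `op` a TFLite Operator):

import tvm

def dispatch_tflite_op(converter, op):
    """Hypothetical standalone view of the per-op dispatch."""
    op_code_str = converter.get_op_code_str(op)
    if op_code_str not in converter.convert_map:
        raise tvm.error.OpNotImplemented(
            "Operator {} is not supported by frontend TFLite.".format(op_code_str)
        )
    # convert_map values are bound methods, e.g. converter.convert_conv2d.
    return converter.convert_map[op_code_str](op)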
 
     def check_unsupported_ops(self):
@@ -178,9 +182,8 @@ class OperatorConverter(object):
                 unsupported_ops_set.add(op_code_str)
 
         if unsupported_ops_set:
-            msg = 'The following operators are not supported in frontend ' \
-                  'TFLite: {}'
-            ops = str(list(unsupported_ops_set)).strip('[,]')
+            msg = "The following operators are not supported in frontend " "TFLite: {}"
+            ops = str(list(unsupported_ops_set)).strip("[,]")
             raise tvm.error.OpNotImplemented(msg.format(ops))
 
     def convert_op_to_relay(self):
@@ -202,8 +205,9 @@ class OperatorConverter(object):
                 self.exp_tab.set_expr(get_tensor_name(self.subgraph, tensor_idx), ret)
             else:
                 for idx, output_tensor in enumerate(output_tensors):
-                    self.exp_tab.set_expr(get_tensor_name(self.subgraph, output_tensor.tensor_idx),
-                                          ret[idx])
+                    self.exp_tab.set_expr(
+                        get_tensor_name(self.subgraph, output_tensor.tensor_idx), ret[idx]
+                    )
 
     def get_op_code_str(self, op):
         """Get TFLite ops string representation"""
@@ -217,12 +221,15 @@ class OperatorConverter(object):
         try:
             op_code_str = self.builtin_op_code[op_code_id]
         except KeyError:
-            raise NotImplementedError('TFLite operator with code ' + str(op_code_id) + \
-                                      ' is not supported by this version of the fbs schema.')
+            raise NotImplementedError(
+                "TFLite operator with code "
+                + str(op_code_id)
+                + " is not supported by this version of the fbs schema."
+            )
         if op_code_id == BuiltinOperator.CUSTOM:
             # Custom operator
             custom_op_code_str = self.model.OperatorCodes(op_code_list_idx).CustomCode()
-            if custom_op_code_str == b'TFLite_Detection_PostProcess':
+            if custom_op_code_str == b"TFLite_Detection_PostProcess":
                 return "DETECTION_POSTPROCESS"
 
             raise NotImplementedError("Custom operators are currently not supported")
@@ -274,9 +281,10 @@ class OperatorConverter(object):
                         # Ensure that all zero points are zeros
                         zero_point = tflite_zero_point
                         if not np.all(zero_point == 0):
-                            raise tvm.error.OpAttributeInvalid(\
-                                    "TFLite per-axis quantization restricts all zero points to be"
-                                    + " 0, but a non-zero value is observed")
+                            raise tvm.error.OpAttributeInvalid(
+                                "TFLite per-axis quantization restricts all zero points to be"
+                                + " 0, but a non-zero value is observed"
+                            )
                         zero_point = int(zero_point[0])
 
                     # Scalar - Per-tensor quantization
@@ -285,44 +293,49 @@ class OperatorConverter(object):
                         zero_point = int(tflite_zero_point[0])
 
                     else:
-                        raise NotImplementedError(\
-                                "Quantized type {} (scale) and  {} (zero point) not supported"
-                                .format(type(tflite_scale), type(tflite_zero_point)))
+                        raise NotImplementedError(
+                            "Quantized type {} (scale) and {} (zero point) not supported".format(
+                                type(tflite_scale), type(tflite_zero_point)
+                            )
+                        )
                 elif tflite_scale == 0 and tflite_zero_point == 0:
                     # Handle corner case for ops like quantized reshape whose second operand (shape)
                     # has zero scale and zero zero point. This is not used.
                     is_qnn_params_valid = False
                 else:
-                    raise NotImplementedError("Quantized type {} not supported"
-                                              .format(type(tflite_scale)))
+                    raise NotImplementedError(
+                        "Quantized type {} not supported".format(type(tflite_scale))
+                    )
 
                 # Check that the scale and zero points are valid.
                 if is_qnn_params_valid:
                     qnn_params = dict()
-                    qnn_params['scale'] = relay.const(scale, 'float32')
-                    qnn_params['zero_point'] = relay.const(zero_point, 'int32')
+                    qnn_params["scale"] = relay.const(scale, "float32")
+                    qnn_params["zero_point"] = relay.const(zero_point, "int32")
             return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params))
         return return_list
 
-
     def get_tensor_type_as_numpy(self, tensor_wrapper):
         """Returns np.dtype out of TensorType"""
         assert isinstance(tensor_wrapper, TensorWrapper)
 
         try:
             from tflite.TensorType import TensorType
-            return {TensorType.UINT8: np.uint8,
-                    TensorType.INT8: np.int8,
-                    TensorType.FLOAT32: np.float32,
-                    TensorType.INT32: np.int32,
-                    TensorType.INT64: np.int64,
-                    TensorType.BOOL: np.bool_}[tensor_wrapper.tensor.Type()]
+
+            return {
+                TensorType.UINT8: np.uint8,
+                TensorType.INT8: np.int8,
+                TensorType.FLOAT32: np.float32,
+                TensorType.INT32: np.int32,
+                TensorType.INT64: np.int64,
+                TensorType.BOOL: np.bool_,
+            }[tensor_wrapper.tensor.Type()]
         except ImportError:
             raise ImportError("The tflite package must be installed")
         except KeyError:
-            raise NotImplementedError("Tensor type '{}' currently not supported"
-                                      .format(tensor_wrapper.tensor.Type()))
-
+            raise NotImplementedError(
+                "Tensor type '{}' currently not supported".format(tensor_wrapper.tensor.Type())
+            )
 
     def get_tensor_value(self, tensor_wrapper):
         """Get tensor buffer value from given tensor wrapper"""
@@ -338,7 +351,6 @@ class OperatorConverter(object):
 
         return np.frombuffer(data, dtype=dtype).reshape(shape)
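The return above reinterprets the tensor's raw TFLite buffer bytes with the numpy dtype resolved by get_tensor_type_as_numpy and reshapes them to the tensor's shape. The same numpy pattern in isolation, with a fabricated buffer:

import numpy as np

raw = np.arange(6, dtype=np.int32).tobytes()          # stand-in for a TFLite buffer payload
value = np.frombuffer(raw, dtype=np.int32).reshape((2, 3))
print(value.shape)                                    # (2, 3)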
 
-
     def get_tensor_type_str(self, tensor_type):
         """Get tensor type string representation when given TFLite tensor type"""
         try:
@@ -358,24 +370,21 @@ class OperatorConverter(object):
             return "int64"
         if tensor_type == TensorType.BOOL:
             return "bool"
-        raise NotImplementedError("Tensor type {} is currently not supported"
-                                  .format(str(tensor_type)))
+        raise NotImplementedError(
+            "Tensor type {} is currently not supported".format(str(tensor_type))
+        )
 
     def has_same_qnn_params(self, lhs_tensor, rhs_tensor):
-        lhs_scale = lhs_tensor.qnn_params['scale']
-        rhs_scale = rhs_tensor.qnn_params['scale']
-        lhs_zero_point = lhs_tensor.qnn_params['zero_point']
-        rhs_zero_point = rhs_tensor.qnn_params['zero_point']
+        lhs_scale = lhs_tensor.qnn_params["scale"]
+        rhs_scale = rhs_tensor.qnn_params["scale"]
+        lhs_zero_point = lhs_tensor.qnn_params["zero_point"]
+        rhs_zero_point = rhs_tensor.qnn_params["zero_point"]
         # 0.1 + 0.2 != 0.3
-        return np.allclose(lhs_scale.data.asnumpy(),
-                           rhs_scale.data.asnumpy(),
-                           rtol=1e-5,
-                           atol=1e-5) \
-               and \
-               np.allclose(lhs_zero_point.data.asnumpy(),
-                           rhs_zero_point.data.asnumpy(),
-                           rtol=1e-5,
-                           atol=1e-5)
+        return np.allclose(
+            lhs_scale.data.asnumpy(), rhs_scale.data.asnumpy(), rtol=1e-5, atol=1e-5
+        ) and np.allclose(
+            lhs_zero_point.data.asnumpy(), rhs_zero_point.data.asnumpy(), rtol=1e-5, atol=1e-5
+        )
 
     def is_quantized(self, op):
         """Check if an input tensor is quantized."""
@@ -387,24 +396,28 @@ class OperatorConverter(object):
         """ Helper function to quantize a tensor with Relay """
         tensor_type = tensor_to_quantize.tensor.Type()
         tensor_type_str = self.get_tensor_type_str(tensor_type)
-        quantized = _qnn.op.quantize(data=expr,
-                                     output_scale=tensor_to_quantize.qnn_params['scale'],
-                                     output_zero_point=tensor_to_quantize.qnn_params['zero_point'],
-                                     out_dtype=tensor_type_str)
+        quantized = _qnn.op.quantize(
+            data=expr,
+            output_scale=tensor_to_quantize.qnn_params["scale"],
+            output_zero_point=tensor_to_quantize.qnn_params["zero_point"],
+            out_dtype=tensor_type_str,
+        )
         return quantized
 
     def dequantize(self, expr, tensor):
         """ Helper function to dequantize a tensor with Relay """
-        dequantized = _qnn.op.dequantize(data=expr,
-                                         input_scale=tensor.qnn_params['scale'],
-                                         input_zero_point=tensor.qnn_params['zero_point'])
+        dequantized = _qnn.op.dequantize(
+            data=expr,
+            input_scale=tensor.qnn_params["scale"],
+            input_zero_point=tensor.qnn_params["zero_point"],
+        )
         return dequantized
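Both helpers above wrap TVM's QNN quantize/dequantize ops, which follow the usual affine convention real_value ~ scale * (quantized_value - zero_point). A numpy sketch of that relationship with made-up parameters:

import numpy as np

scale, zero_point = 0.05, 128                              # hypothetical uint8 qnn params
real = np.array([-1.0, 0.0, 2.5], dtype=np.float32)
q = np.clip(np.round(real / scale) + zero_point, 0, 255).astype(np.uint8)   # quantize
back = (q.astype(np.float32) - zero_point) * scale         # dequantize recovers real up to rounding
print(q, back)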
 
-
-    def convert_qnn_fused_activation_function(self, expr, fused_activation_fn,
-                                              scale, zero_point, dtype):
+    def convert_qnn_fused_activation_function(
+        self, expr, fused_activation_fn, scale, zero_point, dtype
+    ):
         """Convert TFLite fused activation function. The expr is an input quantized tensor with
-        scale and zero point """
+        scale and zero point"""
         try:
             from tflite.ActivationFunctionType import ActivationFunctionType
         except ImportError:
@@ -423,21 +436,16 @@ class OperatorConverter(object):
         if fused_activation_fn == ActivationFunctionType.NONE:
             return expr
         if fused_activation_fn == ActivationFunctionType.RELU6:
-            return _op.clip(expr,
-                            a_min=max(qmin, quantize(0)),
-                            a_max=min(qmax, quantize(6.0)))
+            return _op.clip(expr, a_min=max(qmin, quantize(0)), a_max=min(qmax, quantize(6.0)))
         if fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1:
-            return _op.clip(expr,
-                            a_min=max(qmin, quantize(-1.0)),
-                            a_max=min(qmax, quantize(1.0)))
+            return _op.clip(expr, a_min=max(qmin, quantize(-1.0)), a_max=min(qmax, quantize(1.0)))
         if fused_activation_fn == ActivationFunctionType.RELU:
-            return _op.clip(expr,
-                            a_min=max(qmin, quantize(0.0)),
-                            a_max=qmax)
+            return _op.clip(expr, a_min=max(qmin, quantize(0.0)), a_max=qmax)
 
         fused_activation_fn_str = self.activation_fn_type[fused_activation_fn]
         raise tvm.error.OpNotImplemented(
-            'Quantized activation {} is not supported yet.'.format(fused_activation_fn_str))
+            "Quantized activation {} is not supported yet.".format(fused_activation_fn_str)
+        )
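To make the clipping above concrete: assuming hypothetical uint8 params scale=0.1 and zero_point=10, quantize(0)=10 and quantize(6.0)=70, so a fused RELU6 becomes a clip of the quantized tensor to [10, 70], further bounded by the dtype range:

scale, zero_point = 0.1, 10              # hypothetical qnn params of the input tensor
quantize = lambda x: float(int(round(x / scale)) + zero_point)
qmin, qmax = 0.0, 255.0                  # uint8 range
a_min, a_max = max(qmin, quantize(0)), min(qmax, quantize(6.0))
print(a_min, a_max)                      # 10.0 70.0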
 
     def convert_conv2d(self, op):
         """Convert TFLite conv2d"""
@@ -484,6 +492,7 @@ class OperatorConverter(object):
                 target_shape = self.get_tensor_value(shape_tensor)
                 # convert to flattened list
                 from itertools import chain
+
                 try:
                     target_shape = list(chain(*target_shape))
                 except TypeError:
@@ -501,8 +510,9 @@ class OperatorConverter(object):
         # If the tensors are quantized, ensure that input/output qnn params are same.
         if input_tensor.qnn_params:
             output_tensor = output_tensors[0]
-            assert self.has_same_qnn_params(input_tensor, output_tensor), \
-                    "TFLite reshape requires input and output scale and zero points to be equal"
+            assert self.has_same_qnn_params(
+                input_tensor, output_tensor
+            ), "TFLite reshape requires input and output scale and zero points to be equal"
 
         out = _op.reshape(in_expr, newshape=target_shape)
         return out
@@ -512,10 +522,12 @@ class OperatorConverter(object):
         try:
             from tflite.BuiltinOptions import BuiltinOptions
             from tflite.ResizeBilinearOptions import ResizeBilinearOptions
+
             # ResizeNearestNeighborOptions was added in tflite v1.13
             tflite_ver = 1120
-            if 'ResizeNearestNeighborOptions' in dir(BuiltinOptions):
+            if "ResizeNearestNeighborOptions" in dir(BuiltinOptions):
                 from tflite.ResizeNearestNeighborOptions import ResizeNearestNeighborOptions
+
                 tflite_ver = 1130
         except ImportError:
             raise ImportError("The tflite package must be installed")
@@ -547,8 +559,9 @@ class OperatorConverter(object):
 
         # Use layout NHWC
         coord_trans = "align_corners" if align_corners else "asymmetric"
-        out = _op.image.resize(in_expr, target_size, "NHWC", method,
-                               coordinate_transformation_mode=coord_trans)
+        out = _op.image.resize(
+            in_expr, target_size, "NHWC", method, coordinate_transformation_mode=coord_trans
+        )
         return out
 
     def convert_resize_bilinear(self, op):
@@ -587,7 +600,8 @@ class OperatorConverter(object):
 
         if self.is_quantized(op):
             raise tvm.error.OpNotImplemented(
-                'TFLite quantized L2_NORMALIZATION operator is not supported yet.')
+                "TFLite quantized L2_NORMALIZATION operator is not supported yet."
+            )
 
         # TFL uses only the default epsilon value
         out = _op.nn.l2_normalize(in_expr, eps=1e-12, axis=[input_tensor_rank - 1])
@@ -595,7 +609,8 @@ class OperatorConverter(object):
         # if we have fused activation fn
         if output_tensor.qnn_params:
             raise tvm.error.OpNotImplemented(
-                'TFLite quantized L2_NORMALIZATION operator is not supported yet.')
+                "TFLite quantized L2_NORMALIZATION operator is not supported yet."
+            )
         out = self.convert_fused_activation_function(out, fused_activation_fn)
 
         return out
@@ -609,8 +624,7 @@ class OperatorConverter(object):
             raise ImportError("The tflite package must be installed")
 
         if self.is_quantized(op):
-            raise tvm.error.OpNotImplemented(
-                'TFlite quantized LRN operator is not supported yet.')
+            raise tvm.error.OpNotImplemented("TFlite quantized LRN operator is not supported yet.")
 
         input_tensors = self.get_input_tensors(op)
         assert len(input_tensors) == 1, "input tensors length should be 1"
@@ -630,7 +644,7 @@ class OperatorConverter(object):
         beta = lrn_options.Beta()
         size = (radius * 2) + 1
         alpha = alpha * size
-        axis = 3 # NHWC format
+        axis = 3  # NHWC format
         out = _op.nn.lrn(in_expr, size=size, axis=axis, bias=bias, alpha=alpha, beta=beta)
 
         return out
@@ -667,7 +681,7 @@ class OperatorConverter(object):
         assert len(output_tensors) == 1, "output tensors length should be 1"
         output_tensor = output_tensors[0]
 
-        params = {'axis': 1}  # 1 is channel
+        params = {"axis": 1}  # 1 is channel
         in_expr = self.get_expr(input_tensor_idx)
 
         # TODO - Naive softmax int8 implementation leads to bad accuracy. Currently, we can
@@ -746,27 +760,30 @@ class OperatorConverter(object):
 
         if input_tensor.qnn_params:
             # Quantize a float value to a quantized integer value
-            scale_val = get_scalar_from_constant(input_tensor.qnn_params['scale'])
-            zero_point_val = get_scalar_from_constant(input_tensor.qnn_params['zero_point'])
+            scale_val = get_scalar_from_constant(input_tensor.qnn_params["scale"])
+            zero_point_val = get_scalar_from_constant(input_tensor.qnn_params["zero_point"])
 
             output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
-            out = self.convert_qnn_fused_activation_function(\
-                    expr=in_expr,
-                    fused_activation_fn=ActivationFunctionType.RELU,
-                    scale=scale_val,
-                    zero_point=zero_point_val,
-                    dtype=output_tensor_type_str)
+            out = self.convert_qnn_fused_activation_function(
+                expr=in_expr,
+                fused_activation_fn=ActivationFunctionType.RELU,
+                scale=scale_val,
+                zero_point=zero_point_val,
+                dtype=output_tensor_type_str,
+            )
         else:
             out = _op.nn.relu(in_expr)
 
         if output_tensor.qnn_params:
             output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
-            out = _qnn.op.requantize(out,
-                                     input_scale=input_tensor.qnn_params['scale'],
-                                     input_zero_point=input_tensor.qnn_params['zero_point'],
-                                     output_scale=output_tensor.qnn_params['scale'],
-                                     output_zero_point=output_tensor.qnn_params['zero_point'],
-                                     out_dtype=output_tensor_type_str)
+            out = _qnn.op.requantize(
+                out,
+                input_scale=input_tensor.qnn_params["scale"],
+                input_zero_point=input_tensor.qnn_params["zero_point"],
+                output_scale=output_tensor.qnn_params["scale"],
+                output_zero_point=output_tensor.qnn_params["zero_point"],
+                out_dtype=output_tensor_type_str,
+            )
 
         return out
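The requantize at the end of convert_relu moves the result from the input tensor's (scale, zero_point) pair to the output tensor's. Up to the rounding mode, it amounts to q_out = zero_point_out + round(scale_in / scale_out * (q_in - zero_point_in)); a numpy sketch with invented parameters:

import numpy as np

s_in, zp_in = 0.05, 10       # hypothetical input qnn params
s_out, zp_out = 0.1, 0       # hypothetical output qnn params
q_in = np.array([10, 30, 130], dtype=np.int32)
q_out = zp_out + np.round(s_in / s_out * (q_in - zp_in)).astype(np.int32)
print(q_out)                 # [ 0 10 60] -- same real values in the new quantization params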
 
@@ -818,27 +835,30 @@ class OperatorConverter(object):
 
         if input_tensor.qnn_params:
             # Quantize a float value to a quantized integer value
-            scale_val = get_scalar_from_constant(input_tensor.qnn_params['scale'])
-            zero_point_val = get_scalar_from_constant(input_tensor.qnn_params['zero_point'])
+            scale_val = get_scalar_from_constant(input_tensor.qnn_params["scale"])
+            zero_point_val = get_scalar_from_constant(input_tensor.qnn_params["zero_point"])
 
             output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
-            out = self.convert_qnn_fused_activation_function(\
-                    expr=in_expr,
-                    fused_activation_fn=ActivationFunctionType.RELU6,
-                    scale=scale_val,
-                    zero_point=zero_point_val,
-                    dtype=output_tensor_type_str)
+            out = self.convert_qnn_fused_activation_function(
+                expr=in_expr,
+                fused_activation_fn=ActivationFunctionType.RELU6,
+                scale=scale_val,
+                zero_point=zero_point_val,
+                dtype=output_tensor_type_str,
+            )
         else:
             out = _op.clip(in_expr, a_min=0, a_max=6)
 
         if output_tensor.qnn_params:
             output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
-            out = _qnn.op.requantize(out,
-                                     input_scale=input_tensor.qnn_params['scale'],
-                                     input_zero_point=input_tensor.qnn_params['zero_point'],
-                                     output_scale=output_tensor.qnn_params['scale'],
-                                     output_zero_point=output_tensor.qnn_params['zero_point'],
-                                     out_dtype=output_tensor_type_str)
+            out = _qnn.op.requantize(
+                out,
+                input_scale=input_tensor.qnn_params["scale"],
+                input_zero_point=input_tensor.qnn_params["zero_point"],
+                output_scale=output_tensor.qnn_params["scale"],
+                output_zero_point=output_tensor.qnn_params["zero_point"],
+                out_dtype=output_tensor_type_str,
+            )
 
         return out
 
@@ -886,8 +906,8 @@ class OperatorConverter(object):
 
         if input_tensor.qnn_params:
             # Quantize a float value to a quantized integer value
-            scale_val = get_scalar_from_constant(input_tensor.qnn_params['scale'])
-            zero_point_val = get_scalar_from_constant(input_tensor.qnn_params['zero_point'])
+            scale_val = get_scalar_from_constant(input_tensor.qnn_params["scale"])
+            zero_point_val = get_scalar_from_constant(input_tensor.qnn_params["zero_point"])
             quantize = lambda x: float(int(round(x / scale_val)) + zero_point_val)
 
             # Get min/max of the input dtype. This will be used to ensure that
@@ -896,20 +916,20 @@ class OperatorConverter(object):
             qmin = float(tvm.tir.op.min_value(input_tensor_type_str).value)
             qmax = float(tvm.tir.op.max_value(input_tensor_type_str).value)
 
-            out = _op.clip(in_expr,
-                           a_min=max(qmin, quantize(-1.0)),
-                           a_max=min(qmax, quantize(1.0)))
+            out = _op.clip(in_expr, a_min=max(qmin, quantize(-1.0)), a_max=min(qmax, quantize(1.0)))
         else:
             out = _op.clip(in_expr, a_min=-1, a_max=1)
 
         if output_tensor.qnn_params:
             output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
-            out = _qnn.op.requantize(out,
-                                     input_scale=input_tensor.qnn_params['scale'],
-                                     input_zero_point=input_tensor.qnn_params['zero_point'],
-                                     output_scale=output_tensor.qnn_params['scale'],
-                                     output_zero_point=output_tensor.qnn_params['zero_point'],
-                                     out_dtype=output_tensor_type_str)
+            out = _qnn.op.requantize(
+                out,
+                input_scale=input_tensor.qnn_params["scale"],
+                input_zero_point=input_tensor.qnn_params["zero_point"],
+                output_scale=output_tensor.qnn_params["scale"],
+                output_zero_point=output_tensor.qnn_params["zero_point"],
+                out_dtype=output_tensor_type_str,
+            )
 
         return out
 
@@ -958,27 +978,31 @@ class OperatorConverter(object):
         if not input_tensors[0].qnn_params:
             out = _op.concatenate(in_exprs, axis=concatenation_axis)
         else:
-            input_scales = [input_tensor.qnn_params['scale'] for input_tensor in input_tensors]
-            input_zero_points = \
-                    [input_tensor.qnn_params['zero_point'] for input_tensor in input_tensors]
-            out = _qnn.op.concatenate(in_exprs,
-                                      input_scales=input_scales,
-                                      input_zero_points=input_zero_points,
-                                      output_scale=output_tensor.qnn_params['scale'],
-                                      output_zero_point=output_tensor.qnn_params['zero_point'],
-                                      axis=concatenation_axis)
+            input_scales = [input_tensor.qnn_params["scale"] for input_tensor in input_tensors]
+            input_zero_points = [
+                input_tensor.qnn_params["zero_point"] for input_tensor in input_tensors
+            ]
+            out = _qnn.op.concatenate(
+                in_exprs,
+                input_scales=input_scales,
+                input_zero_points=input_zero_points,
+                output_scale=output_tensor.qnn_params["scale"],
+                output_zero_point=output_tensor.qnn_params["zero_point"],
+                axis=concatenation_axis,
+            )
 
         # Handle fused activations
         if output_tensor.qnn_params:
-            scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
-            zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point'])
+            scale_val = get_scalar_from_constant(output_tensor.qnn_params["scale"])
+            zero_point_val = get_scalar_from_constant(output_tensor.qnn_params["zero_point"])
             output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
-            out = self.convert_qnn_fused_activation_function(\
-                    expr=out,
-                    fused_activation_fn=fused_activation_fn,
-                    scale=scale_val,
-                    zero_point=zero_point_val,
-                    dtype=output_tensor_type_str)
+            out = self.convert_qnn_fused_activation_function(
+                expr=out,
+                fused_activation_fn=fused_activation_fn,
+                scale=scale_val,
+                zero_point=zero_point_val,
+                dtype=output_tensor_type_str,
+            )
         else:
             out = self.convert_fused_activation_function(out, fused_activation_fn)
 
@@ -998,101 +1022,94 @@ class OperatorConverter(object):
     def convert_abs(self, op):
         """Convert TFLite ABS"""
         if self.is_quantized(op):
-            raise tvm.error.OpNotImplemented(
-                'TFlite quantized ABS operator is not supported yet.')
+            raise tvm.error.OpNotImplemented("TFlite quantized ABS operator is not supported yet.")
         return self._convert_unary_elemwise(_op.abs, op)
 
     def convert_ceil(self, op):
         """Convert TFLite CEIL"""
         if self.is_quantized(op):
-            raise tvm.error.OpNotImplemented(
-                'TFlite quantized CEIL operator is not supported yet.')
+            raise tvm.error.OpNotImplemented("TFlite quantized CEIL operator is not supported yet.")
         return self._convert_unary_elemwise(_op.ceil, op)
 
     def convert_floor(self, op):
         """Convert TFLite FLOOR"""
         if self.is_quantized(op):
             raise tvm.error.OpNotImplemented(
-                'TFlite quantized FLOOR operator is not supported yet.')
+                "TFlite quantized FLOOR operator is not supported yet."
+            )
         return self._convert_unary_elemwise(_op.floor, op)
 
     def convert_round(self, op):
         """Convert TFLite ROUND"""
         if self.is_quantized(op):
             raise tvm.error.OpNotImplemented(
-                'TFlite quantized ROUND operator is not supported yet.')
+                "TFlite quantized ROUND operator is not supported yet."
+            )
         return self._convert_unary_elemwise(_op.round, op)
 
     def convert_exp(self, op):
         """Convert TFLite EXP"""
         if self.is_quantized(op):
-            raise tvm.error.OpNotImplemented(
-                'TFlite quantized EXP operator is not supported yet.')
+            raise tvm.error.OpNotImplemented("TFlite quantized EXP operator is not supported yet.")
         return self._convert_unary_elemwise(_op.exp, op)
 
     def convert_log(self, op):
         """Convert TFLite LOG"""
         if self.is_quantized(op):
-            raise tvm.error.OpNotImplemented(
-                'TFlite quantized LOG operator is not supported yet.')
+            raise tvm.error.OpNotImplemented("TFlite quantized LOG operator is not supported yet.")
         return self._convert_unary_elemwise(_op.log, op)
 
     def convert_sin(self, op):
         """Convert TFLite SIN"""
         if self.is_quantized(op):
-            raise tvm.error.OpNotImplemented(
-                'TFlite quantized SIN operator is not supported yet.')
+            raise tvm.error.OpNotImplemented("TFlite quantized SIN operator is not supported yet.")
         return self._convert_unary_elemwise(_op.sin, op)
 
     def convert_tan(self, op):
         """Convert TFLite TAN"""
         if self.is_quantized(op):
-            raise tvm.error.OpNotImplemented(
-                'TFlite quantized TAN operator is not supported yet.')
+            raise tvm.error.OpNotImplemented("TFlite quantized TAN operator is not supported yet.")
         return self._convert_unary_elemwise(_op.tan, op)
 
     def convert_cos(self, op):
         """Convert TFLite COS"""
         if self.is_quantized(op):
-            raise tvm.error.OpNotImplemented(
-                'TFlite quantized COS operator is not supported yet.')
+            raise tvm.error.OpNotImplemented("TFlite quantized COS operator is not supported yet.")
         return self._convert_unary_elemwise(_op.cos, op)
 
     def convert_sqrt(self, op):
         """Convert TFLite SQRT"""
         if self.is_quantized(op):
-            raise tvm.error.OpNotImplemented(
-                'TFlite quantized SQRT operator is not supported yet.')
+            raise tvm.error.OpNotImplemented("TFlite quantized SQRT operator is not supported yet.")
         return self._convert_unary_elemwise(_op.sqrt, op)
 
     def convert_rsqrt(self, op):
         """Convert TFLite RSQRT"""
         if self.is_quantized(op):
             raise tvm.error.OpNotImplemented(
-                'TFlite quantized RSQRT operator is not supported yet.')
+                "TFlite quantized RSQRT operator is not supported yet."
+            )
         return self._convert_unary_elemwise(_op.rsqrt, op)
 
     def convert_neg(self, op):
         """Convert TFLite NEG"""
         if self.is_quantized(op):
-            raise tvm.error.OpNotImplemented(
-                'TFlite quantized NEG operator is not supported yet.')
+            raise tvm.error.OpNotImplemented("TFlite quantized NEG operator is not supported yet.")
         return self._convert_unary_elemwise(_op.negative, op)
 
     def convert_elu(self, op):
         """Convert TFLite ELU"""
         if self.is_quantized(op):
-            raise tvm.error.OpNotImplemented(
-                'TFlite quantized ELU operator is not supported yet.')
+            raise tvm.error.OpNotImplemented("TFlite quantized ELU operator is not supported yet.")
         input_tensors = self.get_input_tensors(op)
         assert len(input_tensors) == 1, "input tensors length should be 1"
 
         input_tensor = input_tensors[0]
         in_expr = self.get_expr(input_tensor.tensor_idx)
         exp_type = self.get_tensor_type_str(input_tensor.tensor.Type())
-        out = relay.const(-1.0, exp_type) * \
-              _op.nn.relu(relay.const(1., exp_type) - _op.exp(in_expr)) + \
-              _op.nn.relu(in_expr)
+        out = relay.const(-1.0, exp_type) * _op.nn.relu(
+            relay.const(1.0, exp_type) - _op.exp(in_expr)
+        ) + _op.nn.relu(in_expr)
 
         return out
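The expression above relies on the identity elu(x) = x for x >= 0 and exp(x) - 1 for x < 0, rewritten as -relu(1 - exp(x)) + relu(x) so it can be built from ops the frontend already uses. A quick numpy check of the identity:

import numpy as np

x = np.array([-2.0, -0.5, 0.0, 1.5])
relu = lambda v: np.maximum(v, 0.0)
decomposed = -1.0 * relu(1.0 - np.exp(x)) + relu(x)
reference = np.where(x > 0, x, np.exp(x) - 1.0)
print(np.allclose(decomposed, reference))   # True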
 
@@ -1109,7 +1126,8 @@ class OperatorConverter(object):
 
         if self.is_quantized(op):
             raise tvm.error.OpNotImplemented(
-                'TFlite quantized SQUARE operator is not supported yet.')
+                "TFlite quantized SQUARE operator is not supported yet."
+            )
 
         exp_type = self.get_tensor_type_str(output_tensor.tensor.Type())
         out = _op.power(in_expr, relay.const(2, exp_type))
@@ -1142,23 +1160,26 @@ class OperatorConverter(object):
         # TFLite format demands equal scale and zero_point tuple parameters for some operations
         # to allow us to use non-quantized operation instead of quantized if ignore_qnn_params=True
         if ignore_qnn_params:
-            assert  lhs_tensor.qnn_params \
-                and self.has_same_qnn_params(lhs_tensor, output_tensor) \
-                and self.has_same_qnn_params(rhs_tensor, output_tensor), \
-                "All tensors should be quantized with the same (scale,zero-point) tuple parameters"
+            assert (
+                lhs_tensor.qnn_params
+                and self.has_same_qnn_params(lhs_tensor, output_tensor)
+                and self.has_same_qnn_params(rhs_tensor, output_tensor)
+            ), "All tensors should be quantized with the same (scale,zero-point) tuple parameters"
 
         # If quantized, extracts qnn params and call QNN add operator.
         if not ignore_qnn_params and lhs_tensor.qnn_params:
             assert rhs_tensor.qnn_params, "Both tensors should be quantized."
             assert output_tensor.qnn_params, "Output tensor should be quantized."
-            out = relay_op(lhs=lhs_expr,
-                           rhs=rhs_expr,
-                           lhs_scale=lhs_tensor.qnn_params['scale'],
-                           lhs_zero_point=lhs_tensor.qnn_params['zero_point'],
-                           rhs_scale=rhs_tensor.qnn_params['scale'],
-                           rhs_zero_point=rhs_tensor.qnn_params['zero_point'],
-                           output_scale=output_tensor.qnn_params['scale'],
-                           output_zero_point=output_tensor.qnn_params['zero_point'])
+            out = relay_op(
+                lhs=lhs_expr,
+                rhs=rhs_expr,
+                lhs_scale=lhs_tensor.qnn_params["scale"],
+                lhs_zero_point=lhs_tensor.qnn_params["zero_point"],
+                rhs_scale=rhs_tensor.qnn_params["scale"],
+                rhs_zero_point=rhs_tensor.qnn_params["zero_point"],
+                output_scale=output_tensor.qnn_params["scale"],
+                output_zero_point=output_tensor.qnn_params["zero_point"],
+            )
         else:
             out = relay_op(lhs_expr, rhs_expr)
 
@@ -1180,15 +1201,16 @@ class OperatorConverter(object):
 
             # Handle fused activations
             if not ignore_qnn_params and output_tensor.qnn_params:
-                scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
-                zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point'])
+                scale_val = get_scalar_from_constant(output_tensor.qnn_params["scale"])
+                zero_point_val = get_scalar_from_constant(output_tensor.qnn_params["zero_point"])
                 output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
-                out = self.convert_qnn_fused_activation_function(\
-                        expr=out,
-                        fused_activation_fn=fused_activation_fn,
-                        scale=scale_val,
-                        zero_point=zero_point_val,
-                        dtype=output_tensor_type_str)
+                out = self.convert_qnn_fused_activation_function(
+                    expr=out,
+                    fused_activation_fn=fused_activation_fn,
+                    scale=scale_val,
+                    zero_point=zero_point_val,
+                    dtype=output_tensor_type_str,
+                )
             else:
                 out = self.convert_fused_activation_function(out, fused_activation_fn)
         return out
@@ -1232,16 +1254,14 @@ class OperatorConverter(object):
         """Convert TFLite DIV"""
         # Check if the input tensor is quantized, call QNN op
         if self.is_quantized(op):
-            raise tvm.error.OpNotImplemented(
-                'TFlite quantized DIV operator is not supported yet.')
+            raise tvm.error.OpNotImplemented("TFlite quantized DIV operator is not supported yet.")
         return self._convert_elemwise(_op.divide, op)
 
     def convert_pow(self, op):
         """Convert TFLite POW"""
         # Check if the input tensor is quantized, call QNN op
         if self.is_quantized(op):
-            raise tvm.error.OpNotImplemented(
-                'TFlite quantized POW operator is not supported yet.')
+            raise tvm.error.OpNotImplemented("TFlite quantized POW operator is not supported yet.")
         return self._convert_elemwise(_op.power, op)
 
     def convert_maximum(self, op):
@@ -1257,7 +1277,8 @@ class OperatorConverter(object):
         # Check if the input tensor is quantized, call QNN op
         if self.is_quantized(op):
             raise tvm.error.OpNotImplemented(
-                'TFlite quantized GREATER operator is not supported yet.')
+                "TFlite quantized GREATER operator is not supported yet."
+            )
         return self._convert_elemwise(_op.greater, op)
 
     def convert_squared_difference(self, op):
@@ -1265,7 +1286,8 @@ class OperatorConverter(object):
         # Check if the input tensor is quantized, call QNN op
         if self.is_quantized(op):
             raise tvm.error.OpNotImplemented(
-                'TFlite quantized squared difference operator is not supported yet.')
+                "TFlite quantized squared difference operator is not supported yet."
+            )
         difference = self._convert_elemwise(_op.subtract, op)
         # _convert_elemwise has guaranteed only have one output tensor
         exp_type = self.get_tensor_type_str(self.get_output_tensors(op)[0].tensor.Type())
@@ -1276,35 +1298,38 @@ class OperatorConverter(object):
         """Convert TFLite GREATER_EQUAL"""
         if self.is_quantized(op):
             raise tvm.error.OpNotImplemented(
-                'TFlite quantized GREATER_EQUAL operator is not supported yet.')
+                "TFlite quantized GREATER_EQUAL operator is not supported yet."
+            )
         return self._convert_elemwise(_op.greater_equal, op)
 
     def convert_less(self, op):
         """Convert TFLite LESS"""
         if self.is_quantized(op):
-            raise tvm.error.OpNotImplemented(
-                'TFlite quantized LESS operator is not supported yet.')
+            raise tvm.error.OpNotImplemented("TFlite quantized LESS operator is not supported yet.")
         return self._convert_elemwise(_op.less, op)
 
     def convert_less_equal(self, op):
         """Convert TFLite LESS_EQUAL"""
         if self.is_quantized(op):
             raise tvm.error.OpNotImplemented(
-                'TFlite quantized LESS_EQUAL operator is not supported yet.')
+                "TFlite quantized LESS_EQUAL operator is not supported yet."
+            )
         return self._convert_elemwise(_op.less_equal, op)
 
     def convert_equal(self, op):
         """Convert TFLite EQUAL"""
         if self.is_quantized(op):
             raise tvm.error.OpNotImplemented(
-                'TFlite quantized EQUAL operator is not supported yet.')
+                "TFlite quantized EQUAL operator is not supported yet."
+            )
         return self._convert_elemwise(_op.equal, op)
 
     def convert_not_equal(self, op):
         """Convert TFLite NOT_EQUAL"""
         if self.is_quantized(op):
             raise tvm.error.OpNotImplemented(
-                'TFlite quantized NOT_EQUAL operator is not supported yet.')
+                "TFlite quantized NOT_EQUAL operator is not supported yet."
+            )
         return self._convert_elemwise(_op.not_equal, op)
 
     def _convert_logical_binary(self, relay_op, op):
@@ -1373,18 +1398,21 @@ class OperatorConverter(object):
             indices_expr = self.get_expr(indices.tensor_idx)
         else:
             indices_val = self.get_tensor_value(indices)
-            indices_expr = self.exp_tab.new_const(indices_val,
-                                                  dtype=self.get_tensor_type_str(indices_type))
+            indices_expr = self.exp_tab.new_const(
+                indices_val, dtype=self.get_tensor_type_str(indices_type)
+            )
             indices_shape = list(indices_val.shape)
             indices_len = len(indices_shape)
 
-            out_shape = data_shape[:axis] + indices_shape[:] + data_shape[axis+1:]
+            out_shape = data_shape[:axis] + indices_shape[:] + data_shape[axis + 1 :]
 
             loopover = [range(s) for s in out_shape]
             for idx in list(itertools.product(*loopover)):
-                real_indices = list(idx[:axis]) \
-                    + [indices_val[idx[axis: axis + indices_len]]] \
-                    + list(idx[axis + indices_len:])
+                real_indices = (
+                    list(idx[:axis])
+                    + [indices_val[idx[axis : axis + indices_len]]]
+                    + list(idx[axis + indices_len :])
+                )
                 if np.any(np.subtract(data_shape, real_indices) < 0):
                     raise ValueError("TFLite out of bound indices are not supported.")
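The constant-indices path above enumerates every output coordinate, and the output shape replaces the gathered axis of data_shape with the indices shape. For example, with arbitrarily chosen shapes:

data_shape, indices_shape, axis = [2, 3, 4], [5], 1
out_shape = data_shape[:axis] + indices_shape[:] + data_shape[axis + 1 :]
print(out_shape)   # [2, 5, 4] -- axis 1 of the data is replaced by the 5 gathered indices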
 
@@ -1412,45 +1440,45 @@ class OperatorConverter(object):
         assert indices_type in (TensorType.INT32, TensorType.INT64)
 
         indices_dims = len(_infer_shape(indices))
-        indices_t = _op.transpose(indices, axes=[-1] + list(range(indices_dims-1)))
+        indices_t = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
 
         out = _op.gather_nd(data, indices_t)
         return out
 
     def convert_strided_slice(self, op):
         """Method to Convert TFLite STRIDED_SLICE operator.
-           NOTE: Eventhough tensorflow supports begin_mask, end_mask, ellipsis_mask, new_axis_mask
-           and shrink_axis_mask, tflite doesn't support these and expect these values to be zero.
-           But in future, they may open up the mask implementation, so kept the implementation
-           same as tensorflow.
+        NOTE: Even though TensorFlow supports begin_mask, end_mask, ellipsis_mask, new_axis_mask
+        and shrink_axis_mask, TFLite does not support these and expects them to be zero.
+        They may be supported in the future, so the implementation is kept the same as
+        TensorFlow's.
 
-           This op extracts a slice of size (end - begin) / stride from the given input tensor.
-           Starting at the location specified by begin the slice continues by adding stride to the
-           index until all dimensions are not less than end. Note that a stride can be negative,
-           which causes a reverse slice.
+        This op extracts a slice of size (end - begin) / stride from the given input tensor.
+        Starting at the location specified by begin the slice continues by adding stride to the
+        index until all dimensions are not less than end. Note that a stride can be negative,
+        which causes a reverse slice.
 
-           For slice input[val0, val1, ..., valn], begin/end/strides will be vectors of length n.
+        For slice input[val0, val1, ..., valn], begin/end/strides will be vectors of length n.
 
-           In each mask field(begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)
-           the ith bit will correspond to the ith val.
+        In each mask field (begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask),
+        the ith bit corresponds to the ith value.
 
-           If the ith bit of begin_mask is set, begin[i] is ignored and the fullest possible range
-           in that dimension is used instead.
+        If the ith bit of begin_mask is set, begin[i] is ignored and the fullest possible range
+        in that dimension is used instead.
 
-           If the ith bit of ellipsis_mask is set, as many unspecified dimensions as needed will be
-           inserted between other dimensions. Only one non-zero bit is allowed in ellipsis_mask.
+        If the ith bit of ellipsis_mask is set, as many unspecified dimensions as needed will be
+        inserted between other dimensions. Only one non-zero bit is allowed in ellipsis_mask.
 
-           If the ith bit of new_axis_mask is set, then begin, end, and stride are ignored and a
-           new length 1 dimension is added at this point in the output tensor.
+        If the ith bit of new_axis_mask is set, then begin, end, and stride are ignored and a
+        new length 1 dimension is added at this point in the output tensor.
 
-           If the ith bit of shrink_axis_mask is set, it implies that the ith specification shrinks
-           the dimensionality by 1, taking on the value at index begin[i]. end[i] and strides[i]
-           are ignored in this case.
-           begin and end are zero-indexed. strides entries must be non-zero.
+        If the ith bit of shrink_axis_mask is set, it implies that the ith specification shrinks
+        the dimensionality by 1, taking on the value at index begin[i]. end[i] and strides[i]
+        are ignored in this case.
+        begin and end are zero-indexed. strides entries must be non-zero.
 
-           TVM Relay implementation of doesn't support mask, so the mask values are processed in
-           this function and begin/end/strides are updated accordingly. If any mask is present, and
-           since tvm doesn't support mask computation directly, the output need a final reshape.
+        The TVM Relay implementation doesn't support masks, so the mask values are processed in
+        this function and begin/end/strides are updated accordingly. If any mask is present, the
+        output needs a final reshape, since TVM doesn't support mask computation directly.
         """
         try:
             from tflite.BuiltinOptions import BuiltinOptions
@@ -1480,13 +1508,14 @@ class OperatorConverter(object):
         data_shape = list(input_tensors[0].tensor.ShapeAsNumpy())
         data_dim = len(data_shape)
         stride_dim = len(stride)
+
         def _transform_mask(stride_dim, ellipsis_mask):
             """Handle mask inputs to create new begin, end, stride and output shape"""
             m_begin = [0] * data_dim
             m_end = [0] * data_dim
             m_stride = [0] * data_dim
             fshape_indices = []
-            #Count new axis after ellipsis_mask, consider while applying ellipsis_mask.
+            # Count new axes after ellipsis_mask; they are considered while applying ellipsis_mask.
             ellipsis_seen = False
             new_axes_after_ellipsis = 0
             for i in range(stride_dim):
@@ -1496,42 +1525,44 @@ class OperatorConverter(object):
                 if (mask & ellipsis_mask) != 0:
                     ellipsis_seen = True
             if not ellipsis_seen:
-                #Used later for extending the stride attributes in the below loop.
-                ellipsis_mask |= (1 << stride_dim)
+                # Used later for extending the stride attributes in the below loop.
+                ellipsis_mask |= 1 << stride_dim
                 stride_dim += 1
             final_index = 0
             for index in range(stride_dim):
                 mask = 1 << index
                 if mask & ellipsis_mask:
-                    #Identify the end index for applying ellipsis_mask
-                    to_index = min(((data_dim - (stride_dim-index)) + 1 \
-                                     + new_axes_after_ellipsis), data_dim)
+                    # Identify the end index for applying ellipsis_mask
+                    to_index = min(
+                        ((data_dim - (stride_dim - index)) + 1 + new_axes_after_ellipsis), data_dim
+                    )
                     for i in range(final_index, to_index):
                         m_begin[final_index] = 0
                         m_end[final_index] = data_shape[final_index]
                         m_stride[final_index] = 1
                         fshape_indices.append(final_index)
                         final_index += 1
-                elif mask &new_axis_mask:
+                elif mask & new_axis_mask:
                     fshape_indices.append(-1)
                 elif not mask & new_axis_mask:
                     if final_index == len(m_begin):
                         break
                     if mask & begin_mask:
-                        m_begin[final_index] = data_shape[final_index] \
-                                                     if stride[index] < 0 else 0
+                        m_begin[final_index] = data_shape[final_index] if stride[index] < 0 else 0
                     elif begin[index]:
                         m_begin[final_index] = begin[index]
                     if mask & end_mask:
-                        m_end[final_index] = 0 if stride[index] < 0 \
-                                                 else data_shape[final_index]
+                        m_end[final_index] = 0 if stride[index] < 0 else data_shape[final_index]
                     elif end[index]:
                         m_end[final_index] = end[index]
                     m_stride[final_index] = stride[index]
                     if mask & shrink_axis_mask:
-                        #Tensorflow make axis with shrink_axis_mask as dimension 1
-                        m_begin[final_index] = data_shape[final_index] + begin[index] \
-                                                 if begin[index] < 0 else begin[index]
+                        # TensorFlow treats an axis with shrink_axis_mask set as a dimension of size 1
+                        m_begin[final_index] = (
+                            data_shape[final_index] + begin[index]
+                            if begin[index] < 0
+                            else begin[index]
+                        )
                         m_end[final_index] = begin[index] + 1
                         m_stride[final_index] = 1
                         fshape_indices.append(-2)
@@ -1550,7 +1581,7 @@ class OperatorConverter(object):
         if not fshape_indices:
             fshape_indices = range(len(out_shape))
 
-        #Create final output shape.
+        # Create final output shape.
         final_output = []
         for gather_index in fshape_indices:
             if gather_index == -1:
@@ -1581,8 +1612,9 @@ class OperatorConverter(object):
         assert len(input_tensors) == 2, "input tensors length should be 2"
 
         if self.has_expr(input_tensors[0].tensor_idx):
-            raise tvm.error.OpNotImplemented("For dims parameter of Fill operator,"
-                                             " only constant values are supported.")
+            raise tvm.error.OpNotImplemented(
+                "For dims parameter of Fill operator," " only constant values are supported."
+            )
 
         in_dims = list(self.get_tensor_value(input_tensors[0]))
         in_value_expr = self.get_expr(input_tensors[1].tensor_idx)
@@ -1626,12 +1658,14 @@ class OperatorConverter(object):
         output_tensor = output_tensors[0]
         output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
         if output_tensor.qnn_params:
-            out = _qnn.op.requantize(out,
-                                     input_scale=input_tensor.qnn_params['scale'],
-                                     input_zero_point=input_tensor.qnn_params['zero_point'],
-                                     output_scale=output_tensor.qnn_params['scale'],
-                                     output_zero_point=output_tensor.qnn_params['zero_point'],
-                                     out_dtype=output_tensor_type_str)
+            out = _qnn.op.requantize(
+                out,
+                input_scale=input_tensor.qnn_params["scale"],
+                input_zero_point=input_tensor.qnn_params["zero_point"],
+                output_scale=output_tensor.qnn_params["scale"],
+                output_zero_point=output_tensor.qnn_params["zero_point"],
+                out_dtype=output_tensor_type_str,
+            )
 
         return out
 
@@ -1694,7 +1728,8 @@ class OperatorConverter(object):
         """Convert TFLite ARG_MIN"""
         if self.is_quantized(op):
             raise tvm.error.OpNotImplemented(
-                'TFlite quantized ARG_MIN operator is not supported yet.')
+                "TFlite quantized ARG_MIN operator is not supported yet."
+            )
         return self._convert_arg_min_max(_op.argmin, op)
 
     def convert_arg_max(self, op):
@@ -1738,7 +1773,7 @@ class OperatorConverter(object):
         in_expr = self.get_tensor_expr(input_tensor)
         in_expr = _op.reshape(in_expr, target_shape)
 
-        #TODO: Change the output shape calculation based on keep_dim option
+        # TODO: Change the output shape calculation based on keep_dim option
         assert op.BuiltinOptionsType() == BuiltinOptions.FullyConnectedOptions
         op_options = op.BuiltinOptions()
         fully_connected_options = FullyConnectedOptions()
@@ -1758,13 +1793,16 @@ class OperatorConverter(object):
         weight_shape = _infer_shape(weight_expr)
 
         if input_tensor.qnn_params:
-            out = _qnn.op.dense(in_expr, weight_expr,
-                                input_zero_point=input_tensor.qnn_params['zero_point'],
-                                kernel_zero_point=weight_tensor.qnn_params['zero_point'],
-                                input_scale=input_tensor.qnn_params['scale'],
-                                kernel_scale=weight_tensor.qnn_params['scale'],
-                                units=weight_shape[0],
-                                out_dtype='int32')
+            out = _qnn.op.dense(
+                in_expr,
+                weight_expr,
+                input_zero_point=input_tensor.qnn_params["zero_point"],
+                kernel_zero_point=weight_tensor.qnn_params["zero_point"],
+                input_scale=input_tensor.qnn_params["scale"],
+                kernel_scale=weight_tensor.qnn_params["scale"],
+                units=weight_shape[0],
+                out_dtype="int32",
+            )
         else:
             out = _op.nn.dense(in_expr, weight_expr)
 
@@ -1775,37 +1813,41 @@ class OperatorConverter(object):
             # bias tensor type should be INT32 (quantization) or FLOAT32
             assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
             bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
-            bias_expr = self.exp_tab.new_const(self.get_tensor_value(bias_tensor),
-                                               dtype=bias_tensor_type_str)
+            bias_expr = self.exp_tab.new_const(
+                self.get_tensor_value(bias_tensor), dtype=bias_tensor_type_str
+            )
             out = _op.nn.bias_add(out, bias_expr)
 
         # Finally if the dense is quantized. Add a requantize at the end.
         if output_tensor.qnn_params:
-            data_scale = input_tensor.qnn_params['scale']
-            weight_scale = weight_tensor.qnn_params['scale']
+            data_scale = input_tensor.qnn_params["scale"]
+            weight_scale = weight_tensor.qnn_params["scale"]
             data_scale_val = get_scalar_from_constant(data_scale)
             weight_scale_val = get_scalar_from_constant(weight_scale)
             new_input_scale_val = data_scale_val * weight_scale_val
-            new_input_scale = relay.const(new_input_scale_val, 'float32')
-            new_input_zero_point = relay.const(0, 'int32')
+            new_input_scale = relay.const(new_input_scale_val, "float32")
+            new_input_zero_point = relay.const(0, "int32")
 
             # Requantize
-            out = _qnn.op.requantize(out,
-                                     input_scale=new_input_scale,
-                                     input_zero_point=new_input_zero_point,
-                                     output_scale=output_tensor.qnn_params['scale'],
-                                     output_zero_point=output_tensor.qnn_params['zero_point'],
-                                     out_dtype=output_tensor_type_str)
+            out = _qnn.op.requantize(
+                out,
+                input_scale=new_input_scale,
+                input_zero_point=new_input_zero_point,
+                output_scale=output_tensor.qnn_params["scale"],
+                output_zero_point=output_tensor.qnn_params["zero_point"],
+                out_dtype=output_tensor_type_str,
+            )
 
             # Call activation function
-            output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
-            output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point'])
-            out = self.convert_qnn_fused_activation_function(\
-                    expr=out,
-                    fused_activation_fn=fused_activation_fn,
-                    scale=output_scale_val,
-                    zero_point=output_zero_point_val,
-                    dtype=output_tensor_type_str)
+            output_scale_val = get_scalar_from_constant(output_tensor.qnn_params["scale"])
+            output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params["zero_point"])
+            out = self.convert_qnn_fused_activation_function(
+                expr=out,
+                fused_activation_fn=fused_activation_fn,
+                scale=output_scale_val,
+                zero_point=output_zero_point_val,
+                dtype=output_tensor_type_str,
+            )
 
         else:
             out = self.convert_fused_activation_function(out, fused_activation_fn)
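The requantize above uses new_input_scale = data_scale * weight_scale with a zero zero point because the int32 accumulator of a quantized dense holds sums of (q_x - zp_x) * (q_w - zp_w) terms, i.e. real values expressed at scale s_x * s_w around zero (standard affine-quantization algebra). A numpy sketch with invented per-tensor parameters:

import numpy as np

s_x, zp_x = 0.02, 128                              # hypothetical input qnn params
s_w, zp_w = 0.005, 0                               # hypothetical weight qnn params
q_x = np.array([[130, 120]], dtype=np.int32)       # one row of quantized activations
q_w = np.array([[4, -6]], dtype=np.int32)          # one output unit's quantized weights
acc = (q_x - zp_x) @ (q_w - zp_w).T                # int32 accumulator
real = (s_x * s_w) * acc                           # accumulator scale s_x * s_w, zero point 0
ref = ((q_x - zp_x) * s_x) @ ((q_w - zp_w) * s_w).T
print(np.allclose(real, ref))                      # True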
@@ -1857,7 +1899,8 @@ class OperatorConverter(object):
             return _op.tanh(in_expr)
         fused_activation_fn_str = self.activation_fn_type[fused_activation_fn]
         raise tvm.error.OpNotImplemented(
-            'Fused activation {} is not supported yet.'.format(fused_activation_fn_str))
+            "Fused activation {} is not supported yet.".format(fused_activation_fn_str)
+        )
 
     def convert_conv(self, op, conv_type):
         """convolution implementation."""
@@ -1884,12 +1927,12 @@ class OperatorConverter(object):
         output_tensor_type_str = self.get_tensor_type_str(output_tensor_type)
 
         is_depthwise_conv = False
-        if conv_type == 'conv2d':
+        if conv_type == "conv2d":
             assert op.BuiltinOptionsType() == BuiltinOptions.Conv2DOptions
             op_options = op.BuiltinOptions()
             conv_options = Conv2DOptions()
             conv_options.Init(op_options.Bytes, op_options.Pos)
-        elif conv_type == 'depthwise':
+        elif conv_type == "depthwise":
             is_depthwise_conv = True
             assert op.BuiltinOptionsType() == BuiltinOptions.DepthwiseConv2DOptions
             op_options = op.BuiltinOptions()
@@ -1898,7 +1941,8 @@ class OperatorConverter(object):
             depth_multiplier = conv_options.DepthMultiplier()
         else:
             raise tvm.error.OpNotImplemented(
-                'Operator {} is not supported for frontend TFLite.'.format(conv_type))
+                "Operator {} is not supported for frontend TFLite.".format(conv_type)
+            )
 
         stride_h = conv_options.StrideH()
         stride_w = conv_options.StrideW()
@@ -1920,21 +1964,23 @@ class OperatorConverter(object):
         dilated_kernel_h = dilation_h * (kernel_h - 1) + 1
         dilated_kernel_w = dilation_w * (kernel_w - 1) + 1
 
-        params = {'kernel_size': [kernel_h, kernel_w],
-                  'strides': [stride_h, stride_w],
-                  'dilation': [dilation_h, dilation_w],
-                  'padding': [0, 0],
-                  'data_layout': 'NHWC'}
+        params = {
+            "kernel_size": [kernel_h, kernel_w],
+            "strides": [stride_h, stride_w],
+            "dilation": [dilation_h, dilation_w],
+            "padding": [0, 0],
+            "data_layout": "NHWC",
+        }
 
         if is_depthwise_conv:
-            params['channels'] = int(in_channels)
-            params['groups'] = int(input_c)
+            params["channels"] = int(in_channels)
+            params["groups"] = int(input_c)
             # If number of input channels is 1, treat as normal
             # convolution.
-            params['kernel_layout'] = 'HWIO' if input_c == 1 else 'HWOI'
+            params["kernel_layout"] = "HWIO" if input_c == 1 else "HWOI"
         else:
-            params['channels'] = int(output_channels)
-            params['kernel_layout'] = 'HWIO'
+            params["channels"] = int(output_channels)
+            params["kernel_layout"] = "HWIO"
 
         # weight tensor type should be INT8/UINT8 (quantization) or FLOAT32
         weight_tensor_type = weight_tensor.tensor.Type()
@@ -1964,19 +2010,20 @@ class OperatorConverter(object):
             pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w)
             do_pad = not (pad_top == 0 and pad_bottom == 0 and pad_left == 0 and pad_right == 0)
             if do_pad:
-                params['padding'] = [pad_top, pad_left, pad_bottom, pad_right]
+                params["padding"] = [pad_top, pad_left, pad_bottom, pad_right]
 
         else:
             raise tvm.error.OpAttributeUnImplemented(
-                'Padding format {} is not supported for operator Conv.'.format(padding))
+                "Padding format {} is not supported for operator Conv.".format(padding)
+            )
 
         if input_tensor.qnn_params:
             qnn_conv2d_params = dict(params)
-            qnn_conv2d_params['input_zero_point'] = input_tensor.qnn_params['zero_point']
-            qnn_conv2d_params['kernel_zero_point'] = weight_tensor.qnn_params['zero_point']
-            qnn_conv2d_params['out_dtype'] = 'int32'
-            qnn_conv2d_params['input_scale'] = input_tensor.qnn_params['scale']
-            qnn_conv2d_params['kernel_scale'] = weight_tensor.qnn_params['scale']
+            qnn_conv2d_params["input_zero_point"] = input_tensor.qnn_params["zero_point"]
+            qnn_conv2d_params["kernel_zero_point"] = weight_tensor.qnn_params["zero_point"]
+            qnn_conv2d_params["out_dtype"] = "int32"
+            qnn_conv2d_params["input_scale"] = input_tensor.qnn_params["scale"]
+            qnn_conv2d_params["kernel_scale"] = weight_tensor.qnn_params["scale"]
             out = _qnn.op.conv2d(in_expr, weight_expr, **qnn_conv2d_params)
         else:
             out = _op.nn.conv2d(in_expr, weight_expr, **params)
@@ -1988,18 +2035,19 @@ class OperatorConverter(object):
             # bias tensor type should be INT32 (quantization) or FLOAT32
             assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
             bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
-            bias_expr = self.exp_tab.new_const(self.get_tensor_value(bias_tensor),
-                                               dtype=bias_tensor_type_str)
+            bias_expr = self.exp_tab.new_const(
+                self.get_tensor_value(bias_tensor), dtype=bias_tensor_type_str
+            )
             channel_axis = 3
             out = _op.nn.bias_add(out, bias_expr, axis=channel_axis)
 
         # Handle fused activation.
         if output_tensor.qnn_params:
             # Calculate the intermediate scale and zero point of the int32 output.
-            data_scale = input_tensor.qnn_params['scale']
+            data_scale = input_tensor.qnn_params["scale"]
             data_scale_val = get_scalar_from_constant(data_scale)
 
-            weight_scale = weight_tensor.qnn_params['scale']
+            weight_scale = weight_tensor.qnn_params["scale"]
             # If weight scale is scalar, it is per-tensor quantization
             if isinstance(weight_scale, float):
                 weight_scale_val = get_scalar_from_constant(weight_scale)
@@ -2007,27 +2055,30 @@ class OperatorConverter(object):
                 weight_scale_val = get_tensor_from_constant(weight_scale)
 
             new_input_scale_val = data_scale_val * weight_scale_val
-            new_input_scale = relay.const(new_input_scale_val, 'float32')
-            new_input_zero_point = relay.const(0, 'int32')
+            new_input_scale = relay.const(new_input_scale_val, "float32")
+            new_input_zero_point = relay.const(0, "int32")
 
             # Finally requantize
-            out = _qnn.op.requantize(out,
-                                     input_scale=new_input_scale,
-                                     input_zero_point=new_input_zero_point,
-                                     output_scale=output_tensor.qnn_params['scale'],
-                                     output_zero_point=output_tensor.qnn_params['zero_point'],
-                                     out_dtype=output_tensor_type_str,
-                                     axis=3)
+            out = _qnn.op.requantize(
+                out,
+                input_scale=new_input_scale,
+                input_zero_point=new_input_zero_point,
+                output_scale=output_tensor.qnn_params["scale"],
+                output_zero_point=output_tensor.qnn_params["zero_point"],
+                out_dtype=output_tensor_type_str,
+                axis=3,
+            )
 
             # Call activation function
-            output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
-            output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point'])
-            out = self.convert_qnn_fused_activation_function(\
-                    expr=out,
-                    fused_activation_fn=fused_activation_fn,
-                    scale=output_scale_val,
-                    zero_point=output_zero_point_val,
-                    dtype=output_tensor_type_str)
+            output_scale_val = get_scalar_from_constant(output_tensor.qnn_params["scale"])
+            output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params["zero_point"])
+            out = self.convert_qnn_fused_activation_function(
+                expr=out,
+                fused_activation_fn=fused_activation_fn,
+                scale=output_scale_val,
+                zero_point=output_zero_point_val,
+                dtype=output_tensor_type_str,
+            )
         else:
             out = self.convert_fused_activation_function(out, fused_activation_fn)
         return out
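
The convolution path differs from the dense path only in that the weight scale may be per output channel, in which case new_input_scale becomes a vector and qnn.op.requantize is told it varies along axis=3 (the NHWC channel axis). A small sketch of the scalar versus per-channel case, with invented numbers:

    import numpy as np

    data_scale = 0.02

    # Per-tensor quantization: a single weight scale.
    print(data_scale * 0.005)                           # scalar requantize input scale

    # Per-channel quantization: one weight scale per output channel.
    weight_scales = np.array([0.004, 0.005, 0.006], dtype="float32")
    print(data_scale * weight_scales)                   # one entry per channel;
    # requantize then rescales each NHWC channel (axis=3) with its own entry.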
@@ -2078,8 +2129,10 @@ class OperatorConverter(object):
         in_expr = self.get_expr(input_tensor_idx)
 
         if self.has_expr(input_tensors[1].tensor_idx):
-            raise tvm.error.OpNotImplemented("For size_splits parameter of SPLIT_V operator, "
-                                             "only constant values are supported.")
+            raise tvm.error.OpNotImplemented(
+                "For size_splits parameter of SPLIT_V operator, "
+                "only constant values are supported."
+            )
         size_splits = list(self.get_tensor_value(input_tensors[1]))
         size_splits = tuple(np.cumsum(size_splits)[:-1])
 
@@ -2160,7 +2213,8 @@ class OperatorConverter(object):
 
         if self.is_quantized(op):
             raise tvm.error.OpNotImplemented(
-                'TFLite does not support quantized REVERSE_SEQUENCE operator yet.')
+                "TFLite does not support quantized REVERSE_SEQUENCE operator yet."
+            )
 
         input_tensors = self.get_input_tensors(op)
         assert len(input_tensors) == 2, "input tensors length should be 2"
@@ -2259,10 +2313,12 @@ class OperatorConverter(object):
         filter_w = pool2d_options.FilterWidth()
         fused_activation_fn = pool2d_options.FusedActivationFunction()
 
-        params = {'pool_size': (filter_h, filter_w),
-                  'strides': (stride_h, stride_w),
-                  'padding': [0, 0],
-                  'layout': 'NHWC'}
+        params = {
+            "pool_size": (filter_h, filter_w),
+            "strides": (stride_h, stride_w),
+            "padding": [0, 0],
+            "layout": "NHWC",
+        }
 
         in_expr = self.get_expr(input_tensor_idx)
 
@@ -2272,16 +2328,18 @@ class OperatorConverter(object):
         elif padding == Padding.SAME:
             pad_top, pad_bottom = get_pad_value(input_h, filter_h, stride_h)
             pad_left, pad_right = get_pad_value(input_w, filter_w, stride_w)
-            params['padding'] = [pad_top, pad_left, pad_bottom, pad_right]
+            params["padding"] = [pad_top, pad_left, pad_bottom, pad_right]
         else:
             raise tvm.error.OpAttributeUnImplemented(
-                'Padding format {} for operator Pool2D is not supported.'.format(padding))
+                "Padding format {} for operator Pool2D is not supported.".format(padding)
+            )
 
         if pool_type == "average":
             if input_tensor.qnn_params:
-                assert self.has_same_qnn_params(input_tensor, output_tensor), \
-                        'TFLite avg_pool2dreshape requires input and output scale' \
-                        'and zero points to be equal'
+                assert self.has_same_qnn_params(input_tensor, output_tensor), (
+                    "TFLite avg_pool2d requires input and output scale "
+                    "and zero points to be equal"
+                )
                 out = _op.cast(in_expr, dtype="int32")
                 out = _op.nn.avg_pool2d(out, **params)
                 out = _op.cast(out, dtype=output_tensor_type_str)
@@ -2289,14 +2347,16 @@ class OperatorConverter(object):
                 out = _op.nn.avg_pool2d(in_expr, **params)
         elif pool_type == "max":
             if input_tensor.qnn_params:
-                assert self.has_same_qnn_params(input_tensor, output_tensor), \
-                        "qnn.op.max_pool2d requires input and output qnn params to be same"
+                assert self.has_same_qnn_params(
+                    input_tensor, output_tensor
+                ), "qnn.op.max_pool2d requires input and output qnn params to be same"
             out = _op.nn.max_pool2d(in_expr, **params)
         elif pool_type == "l2":
             # L2_POOL_2D is equivalent to square_root(avg_pool(square(in_data)))
             # TFLite does not have support for quantised L2_POOL_2D op.
-            assert not input_tensor.qnn_params, \
-                "As TFLite does not have support for quantized L2_POOL_2D, \
+            assert (
+                not input_tensor.qnn_params
+            ), "As TFLite does not have support for quantized L2_POOL_2D, \
                 Quantized input is not expected."
             exp_type = self.get_tensor_type_str(output_tensor.tensor.Type())
             square_exp = _op.power(in_expr, relay.const(2, exp_type))
@@ -2304,18 +2364,20 @@ class OperatorConverter(object):
             out = _op.sqrt(avg_pool_exp)
         else:
             raise tvm.error.OpNotImplemented(
-                'Operator {} is not supported for frontend TFLite.'.format(pool_type + ' pool'))
+                "Operator {} is not supported for frontend TFLite.".format(pool_type + " pool")
+            )
 
         # Handle fused activations
         if output_tensor.qnn_params:
-            scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
-            zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point'])
-            out = self.convert_qnn_fused_activation_function(\
-                    expr=out,
-                    fused_activation_fn=fused_activation_fn,
-                    scale=scale_val,
-                    zero_point=zero_point_val,
-                    dtype=output_tensor_type_str)
+            scale_val = get_scalar_from_constant(output_tensor.qnn_params["scale"])
+            zero_point_val = get_scalar_from_constant(output_tensor.qnn_params["zero_point"])
+            out = self.convert_qnn_fused_activation_function(
+                expr=out,
+                fused_activation_fn=fused_activation_fn,
+                scale=scale_val,
+                zero_point=zero_point_val,
+                dtype=output_tensor_type_str,
+            )
         else:
             out = self.convert_fused_activation_function(out, fused_activation_fn)
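
The L2_POOL_2D branch above leans on the identity L2 pooling == sqrt(avg_pool(square(x))) per window. A quick NumPy spot check on a single 2x2 window (values invented for the example):

    import numpy as np

    window = np.array([[3.0, 4.0], [0.0, 0.0]])   # one 2x2 pooling window
    print(np.sqrt(np.mean(window ** 2)))          # sqrt((9 + 16 + 0 + 0) / 4) = 2.5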
 
@@ -2328,12 +2390,14 @@ class OperatorConverter(object):
         input_tensors = self.get_input_tensors(op)
 
         # TFLite PAD/PADV2 only supports CONSTANT mode
-        assert (len(input_tensors) == 2 or len(input_tensors) == 3), \
-            "input tensor's length should be 2 for PAD and 3 for PADV2"
+        assert (
+            len(input_tensors) == 2 or len(input_tensors) == 3
+        ), "input tensor's length should be 2 for PAD and 3 for PADV2"
 
         if len(input_tensors) == 3:
-            assert input_tensors[0].tensor.Type() == input_tensors[2].tensor.Type(), \
-                "constant_values tensor must be of same type as input tensor"
+            assert (
+                input_tensors[0].tensor.Type() == input_tensors[2].tensor.Type()
+            ), "constant_values tensor must be of same type as input tensor"
 
         input_tensor = input_tensors[0]
         in_expr = self.get_expr(input_tensor.tensor_idx)
@@ -2351,11 +2415,12 @@ class OperatorConverter(object):
             # Check that input and output tensor have same qnn params.
             output_tensors = self.get_output_tensors(op)
             output_tensor = output_tensors[0]
-            assert self.has_same_qnn_params(input_tensor, output_tensor), \
-                "TFLite PADV2 requires input and output scale and zero points to be equal"
+            assert self.has_same_qnn_params(
+                input_tensor, output_tensor
+            ), "TFLite PADV2 requires input and output scale and zero points to be equal"
 
             # The pad value for quantized pad is the input zero point by default.
-            pad_value = float(input_tensor.qnn_params['zero_point'].data.asnumpy())
+            pad_value = float(input_tensor.qnn_params["zero_point"].data.asnumpy())
 
         if len(input_tensors) == 3:
             pad_value = self.get_tensor_value(input_tensors[2])
@@ -2366,26 +2431,28 @@ class OperatorConverter(object):
                 pad_value = pad_value[0]
             if input_tensor.qnn_params:
                 # Check that input tensor and constant_values have same qnn params.
-                assert self.has_same_qnn_params(input_tensor, input_tensors[2]), \
-                    "TFLite PADV2 requires input and constant_values tensors' \
+                assert self.has_same_qnn_params(
+                    input_tensor, input_tensors[2]
+                ), "TFLite PADV2 requires input and constant_values tensors' \
                         scale and zero points to be equal"
 
         out = _op.nn.pad(in_expr, pad_width=paddings, pad_value=pad_value)
         return out
 
-
     def convert_floor_div(self, op):
         """Convert TFLite FLOOR_DIV"""
         if self.is_quantized(op):
             raise tvm.error.OpNotImplemented(
-                'TFlite quantized FLOOR DIV operator is not supported yet.')
+                "TFLite quantized FLOOR_DIV operator is not supported yet."
+            )
         return self._convert_elemwise(_op.floor_divide, op)
 
     def convert_floor_mod(self, op):
         """Convert TFLite FLOOR_MOD"""
         if self.is_quantized(op):
             raise tvm.error.OpNotImplemented(
-                'TFlite quantized FLOOR MOD operator is not supported yet.')
+                "TFLite quantized FLOOR_MOD operator is not supported yet."
+            )
         return self._convert_elemwise(_op.floor_mod, op)
 
     def convert_mirror_pad(self, op):
@@ -2399,7 +2466,8 @@ class OperatorConverter(object):
         # the quantized form of MirrorPad is not yet implemented in TFLite.
         if self.is_quantized(op):
             raise tvm.error.OpNotImplemented(
-                'TFlite quantized MIRROR_PAD operator is not supported yet.')
+                "TFLite quantized MIRROR_PAD operator is not supported yet."
+            )
 
         input_tensors = self.get_input_tensors(op)
         assert len(input_tensors) == 2, "input tensors length should be 2"
@@ -2482,12 +2550,13 @@ class OperatorConverter(object):
             if isinstance(squeezed, _expr.TupleWrapper):
                 squeezed = squeezed[0]
         else:
-            splitted = _op.split(in_expr,
-                                 indices_or_sections=num_unpacks,
-                                 axis=unpack_axis)
+            splitted = _op.split(in_expr, indices_or_sections=num_unpacks, axis=unpack_axis)
             squeezed = _expr.TupleWrapper(
-                _expr.Tuple([_op.squeeze(split_item, axis=squeeze_axis) \
-                             for split_item in splitted]), len(splitted))
+                _expr.Tuple(
+                    [_op.squeeze(split_item, axis=squeeze_axis) for split_item in splitted]
+                ),
+                len(splitted),
+            )
 
         return squeezed
 
@@ -2515,8 +2584,11 @@ class OperatorConverter(object):
         reshaped = _op.reshape(in_expr, newshape=shape1)
 
         # Permute dimensions of reshaped to produce permuted of shape
-        axes = [M] + [axis for i in range(M) for axis in [M + i + 1, i]] + \
-            list(range(2 * M + 1, len(shape1)))
+        axes = (
+            [M]
+            + [axis for i in range(M) for axis in [M + i + 1, i]]
+            + list(range(2 * M + 1, len(shape1)))
+        )
         permuted = _op.transpose(reshaped, axes=axes)
 
         # Reshape permuted to produce reshaped_permuted of shape
@@ -2533,7 +2605,7 @@ class OperatorConverter(object):
                 indices = _op.arange(
                     _expr.const(crop[0]),
                     _expr.const(reshaped_permuted_shape[axis] - crop[1]),
-                    dtype='int32'
+                    dtype="int32",
                 )
                 cropped = _op.take(cropped, indices=indices, axis=axis)
 
@@ -2578,14 +2650,18 @@ class OperatorConverter(object):
         reshaped_padded = _op.reshape(padded, newshape=shape1)
 
         # Permute dimensions of reshaped_padded to produce permuted_reshaped_padded of shape:
-        axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \
-            list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length))
+        axes = (
+            [2 * i + 2 for i in range(M)]
+            + [0]
+            + [2 * i + 1 for i in range(M)]
+            + list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length))
+        )
         permuted_reshaped_padded = _op.transpose(reshaped_padded, axes=axes)
         permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded)
 
         # Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
         # producing an output tensor of shape:
-        shape2 = [batch * np.prod(block_shape)] + list(permuted_reshaped_padded_shape)[M + 1:]
+        shape2 = [batch * np.prod(block_shape)] + list(permuted_reshaped_padded_shape)[M + 1 :]
         reshaped_permuted_reshaped_padded = _op.reshape(permuted_reshaped_padded, newshape=shape2)
 
         return reshaped_permuted_reshaped_padded
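
The pad/reshape/transpose/reshape chain above reproduces SPACE_TO_BATCH_ND purely with shape manipulation: block elements move from the spatial dimensions into the batch dimension. A hedged shape-only check, assuming the usual TensorFlow semantics and an invented 1x4x4x1 NHWC input:

    import numpy as np

    batch, in_h, in_w, channels = 1, 4, 4, 1
    block_shape = [2, 2]
    paddings = [[0, 0], [0, 0]]          # no extra padding needed for this shape

    out_shape = [
        batch * int(np.prod(block_shape)),            # 4: block elements join the batch dim
        (in_h + sum(paddings[0])) // block_shape[0],
        (in_w + sum(paddings[1])) // block_shape[1],
        channels,
    ]
    print(out_shape)                                  # [4, 2, 2, 1]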
@@ -2609,7 +2685,7 @@ class OperatorConverter(object):
         depth_to_space_options = DepthToSpaceOptions()
         depth_to_space_options.Init(op_options.Bytes, op_options.Pos)
         block_size = depth_to_space_options.BlockSize()
-        out = _op.nn.depth_to_space(in_expr, block_size, layout='NHWC')
+        out = _op.nn.depth_to_space(in_expr, block_size, layout="NHWC")
 
         return out
 
@@ -2632,7 +2708,7 @@ class OperatorConverter(object):
         space_to_depth_options = SpaceToDepthOptions()
         space_to_depth_options.Init(op_options.Bytes, op_options.Pos)
         block_size = space_to_depth_options.BlockSize()
-        out = _op.nn.space_to_depth(in_expr, block_size, layout='NHWC')
+        out = _op.nn.space_to_depth(in_expr, block_size, layout="NHWC")
 
         return out
 
@@ -2661,7 +2737,7 @@ class OperatorConverter(object):
             self.get_tensor_expr(indices),
             list(self.get_tensor_value(output_shape)),
             self.get_tensor_expr(values),
-            self.get_tensor_expr(default_value)
+            self.get_tensor_expr(default_value),
         )
 
         return out
@@ -2675,8 +2751,9 @@ class OperatorConverter(object):
         alpha_tensor = input_tensors[1]
         alpha_tensor_type = alpha_tensor.tensor.Type()
         alpha_tensor_type_str = self.get_tensor_type_str(alpha_tensor_type)
-        alpha_expr = self.exp_tab.new_const(self.get_tensor_value(alpha_tensor).flatten(),
-                                            dtype=alpha_tensor_type_str)
+        alpha_expr = self.exp_tab.new_const(
+            self.get_tensor_value(alpha_tensor).flatten(), dtype=alpha_tensor_type_str
+        )
         in_expr = self.get_expr(input_tensor.tensor_idx)
         out = _op.nn.prelu(in_expr, alpha_expr, axis=3)
 
@@ -2701,8 +2778,9 @@ class OperatorConverter(object):
         # Weights tensor. TFLite uses OHWI layout
         weights_tensor = input_tensors[1]
         out_channels, kernel_h, kernel_w, in_channels = weights_tensor.tensor.ShapeAsNumpy()
-        assert input_c == in_channels, \
-            "Input channel in the filter should match to channel in the input"
+        assert (
+            input_c == in_channels
+        ), "Input channel in the filter should match to channel in the input"
         # output_shape Tensor. NHWC layout
         output_shape_tensor = input_tensors[0]
 
@@ -2720,8 +2798,10 @@ class OperatorConverter(object):
         padding = deconv_options.Padding()
         stride_h = deconv_options.StrideH()
         stride_w = deconv_options.StrideW()
-        assert padding in (Padding.VALID, Padding.SAME), \
-            'Padding format {} is not supported for operator TRANSPOSE_CONV'.format(padding)
+        assert padding in (
+            Padding.VALID,
+            Padding.SAME,
+        ), "Padding format {} is not supported for operator TRANSPOSE_CONV".format(padding)
 
         # Data
         in_expr = self.get_expr(input_tensor.tensor_idx)
@@ -2740,8 +2820,9 @@ class OperatorConverter(object):
         # Output shape value
         output_shape_value = self.get_tensor_value(output_shape_tensor)
         # Relay expects filter output channel to match to output tensor channel.
-        assert out_channels == output_shape_value[3], \
-            "Output channel in the filter should match to channel in the output_shape"
+        assert (
+            out_channels == output_shape_value[3]
+        ), "Output channel in the filter should match to channel in the output_shape"
 
         if padding == Padding.SAME:
             pad_top, pad_bottom = get_pad_value(input_h, kernel_h, stride_h)
@@ -2750,14 +2831,17 @@ class OperatorConverter(object):
         else:
             padding = (0, 0, 0, 0)
 
-        out = _op.nn.conv2d_transpose(in_expr, weight_expr_iohw,
-                                      strides=(stride_h, stride_w),
-                                      padding=padding,
-                                      channels=int(out_channels),
-                                      kernel_size=(int(kernel_h), int(kernel_w)),
-                                      data_layout="NHWC",
-                                      kernel_layout="OIHW",
-                                      out_dtype=output_tensor_type_str)
+        out = _op.nn.conv2d_transpose(
+            in_expr,
+            weight_expr_iohw,
+            strides=(stride_h, stride_w),
+            padding=padding,
+            channels=int(out_channels),
+            kernel_size=(int(kernel_h), int(kernel_w)),
+            data_layout="NHWC",
+            kernel_layout="OIHW",
+            out_dtype=output_tensor_type_str,
+        )
 
         return out
 
@@ -2782,12 +2866,14 @@ class OperatorConverter(object):
         if input_tensor_type_str == "float32":
             out = self.quantize(in_expr, output_tensor)
         else:
-            out = _qnn.op.requantize(in_expr,
-                                     input_scale=input_tensor.qnn_params['scale'],
-                                     input_zero_point=input_tensor.qnn_params['zero_point'],
-                                     output_scale=output_tensor.qnn_params['scale'],
-                                     output_zero_point=output_tensor.qnn_params['zero_point'],
-                                     out_dtype=output_tensor_type_str)
+            out = _qnn.op.requantize(
+                in_expr,
+                input_scale=input_tensor.qnn_params["scale"],
+                input_zero_point=input_tensor.qnn_params["zero_point"],
+                output_scale=output_tensor.qnn_params["scale"],
+                output_zero_point=output_tensor.qnn_params["zero_point"],
+                out_dtype=output_tensor_type_str,
+            )
         return out
 
     def convert_dequantize(self, op):
@@ -2813,8 +2899,9 @@ class OperatorConverter(object):
         if "use_regular_nms" in custom_options:
             if custom_options["use_regular_nms"]:
                 raise tvm.error.OpAttributeUnImplemented(
-                    "use_regular_nms=True is not yet supported for operator {}."
-                    .format("TFLite_Detection_PostProcess")
+                    "use_regular_nms=True is not yet supported for operator {}.".format(
+                        "TFLite_Detection_PostProcess"
+                    )
                 )
 
         inputs = self.get_input_tensors(op)
@@ -2828,17 +2915,23 @@ class OperatorConverter(object):
         anchor_expr = self.exp_tab.new_const(anchor_values, dtype=anchor_type)
 
         if inputs[0].qnn_params:
-            loc_prob = _qnn.op.dequantize(data=loc_prob,
-                                          input_scale=inputs[0].qnn_params['scale'],
-                                          input_zero_point=inputs[0].qnn_params['zero_point'])
+            loc_prob = _qnn.op.dequantize(
+                data=loc_prob,
+                input_scale=inputs[0].qnn_params["scale"],
+                input_zero_point=inputs[0].qnn_params["zero_point"],
+            )
         if inputs[1].qnn_params:
-            cls_pred = _qnn.op.dequantize(data=cls_pred,
-                                          input_scale=inputs[1].qnn_params['scale'],
-                                          input_zero_point=inputs[1].qnn_params['zero_point'])
+            cls_pred = _qnn.op.dequantize(
+                data=cls_pred,
+                input_scale=inputs[1].qnn_params["scale"],
+                input_zero_point=inputs[1].qnn_params["zero_point"],
+            )
         if inputs[2].qnn_params:
-            anchor_expr = _qnn.op.dequantize(data=anchor_expr,
-                                             input_scale=inputs[2].qnn_params['scale'],
-                                             input_zero_point=inputs[2].qnn_params['zero_point'])
+            anchor_expr = _qnn.op.dequantize(
+                data=anchor_expr,
+                input_scale=inputs[2].qnn_params["scale"],
+                input_zero_point=inputs[2].qnn_params["zero_point"],
+            )
 
         # reshape the cls_pred and loc_prob tensors so
         # they can be consumed by multibox_transform_loc
@@ -2849,7 +2942,7 @@ class OperatorConverter(object):
         loc_prob = _op.concatenate(
             [loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]], axis=2
         )
-        loc_prob = _op.reshape(loc_prob, [batch_size, anchor_boxes*4])
+        loc_prob = _op.reshape(loc_prob, [batch_size, anchor_boxes * 4])
 
         # anchor coords are in yxhw format
         # need to convert to ltrb
@@ -2858,8 +2951,8 @@ class OperatorConverter(object):
         anchor_x = anchor_coords[1]
         anchor_h = anchor_coords[2]
         anchor_w = anchor_coords[3]
-        plus_half = _expr.const(0.5, dtype='float32')
-        minus_half = _expr.const(-0.5, dtype='float32')
+        plus_half = _expr.const(0.5, dtype="float32")
+        minus_half = _expr.const(-0.5, dtype="float32")
         anchor_l = _op.add(anchor_x, _op.multiply(anchor_w, minus_half))
         anchor_r = _op.add(anchor_x, _op.multiply(anchor_w, plus_half))
         anchor_t = _op.add(anchor_y, _op.multiply(anchor_h, minus_half))
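
To make the yxhw-to-ltrb conversion above concrete: each anchor stores a box center plus extents, and the corners are obtained by offsetting the center by half the extent. A tiny sketch with an invented anchor:

    anchor_y, anchor_x, anchor_h, anchor_w = 0.5, 0.5, 0.4, 0.2

    left = anchor_x - 0.5 * anchor_w      # 0.4
    right = anchor_x + 0.5 * anchor_w     # 0.6
    top = anchor_y - 0.5 * anchor_h       # 0.3
    bottom = anchor_y + 0.5 * anchor_h    # 0.7
    print(left, top, right, bottom)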
@@ -2887,15 +2980,16 @@ class OperatorConverter(object):
         non_max_suppression_attrs["max_output_size"] = custom_options["max_detections"]
         non_max_suppression_attrs["invalid_to_bottom"] = False
 
-        ret = _op.vision.multibox_transform_loc(cls_pred, loc_prob,
-                                                anchor_expr, **multibox_transform_loc_attrs)
+        ret = _op.vision.multibox_transform_loc(
+            cls_pred, loc_prob, anchor_expr, **multibox_transform_loc_attrs
+        )
         ret = _op.vision.non_max_suppression(ret[0], ret[1], ret[1], **non_max_suppression_attrs)
         ret = _op.vision.get_valid_counts(ret, 0)
         valid_count = ret[0]
         # keep only the top 'max_detections' rows
-        ret = _op.strided_slice(ret[1],
-                                [0, 0, 0],
-                                [batch_size, custom_options["max_detections"], anchor_boxes])
+        ret = _op.strided_slice(
+            ret[1], [0, 0, 0], [batch_size, custom_options["max_detections"], anchor_boxes]
+        )
         # the output needs some reshaping to match tflite
         ret = _op.split(ret, 6, axis=2)
         cls_ids = _op.reshape(ret[0], [batch_size, -1])
@@ -2912,8 +3006,9 @@ class OperatorConverter(object):
         if input_tensors[0].qnn_params:
             # Check that input and output tensor have same qnn params.
             output_tensors = self.get_output_tensors(op)
-            assert self.has_same_qnn_params(input_tensors[0], output_tensors[0]), \
-                "TFLite EXPAND_DIMS requires input and output tensors' \
+            assert self.has_same_qnn_params(
+                input_tensors[0], output_tensors[0]
+            ), "TFLite EXPAND_DIMS requires input and output tensors' \
                     scale and zero points to be equal"
 
         input_expr = self.get_tensor_expr(input_tensors[0])
@@ -2923,7 +3018,7 @@ class OperatorConverter(object):
             axis = int(axis)
 
         ndims = len(input_tensors[0].tensor.ShapeAsNumpy())
-        assert (-1-ndims <= axis <= ndims), "axis out of range"
+        assert -1 - ndims <= axis <= ndims, "axis out of range"
 
         out = _op.expand_dims(input_expr, axis, 1)
 
@@ -2941,8 +3036,7 @@ class OperatorConverter(object):
         assert len(input_tensors) == 4, "Input tensor's length should be 4"
 
         # Ensuring input isn't quantized
-        assert all(not i.qnn_params for i in input_tensors), \
-            "Quantized input is not expected."
+        assert all(not i.qnn_params for i in input_tensors), "Quantized input is not expected."
 
         # TFlite ONE_HOT requires both on_value
         # and off_value, making dtype redundant.
@@ -2951,8 +3045,9 @@ class OperatorConverter(object):
         on_value = input_tensors[2]
         off_value = input_tensors[3]
 
-        assert on_value.tensor.Type() == off_value.tensor.Type(), \
-            "on_value and off_value should be the same type"
+        assert (
+            on_value.tensor.Type() == off_value.tensor.Type()
+        ), "on_value and off_value should be the same type"
 
         # Getting relay expr
         indices_expr = self.get_expr(indices.tensor_idx)
@@ -3000,19 +3095,22 @@ class OperatorConverter(object):
         input_tensors = self.get_input_tensors(op)
         assert len(input_tensors) == 2, "input tensor's length should be 2"
 
-        assert input_tensors[0].tensor.Type() == input_tensors[1].tensor.Type(), \
-            "input and diagonal should be the same type of tensors"
+        assert (
+            input_tensors[0].tensor.Type() == input_tensors[1].tensor.Type()
+        ), "input and diagonal should be the same type of tensors"
 
         if input_tensors[0].qnn_params:
             # Check that input and output tensor have same qnn params.
             output_tensors = self.get_output_tensors(op)
-            assert self.has_same_qnn_params(input_tensors[0], output_tensors[0]), \
-                "TFLite MATRIX_SET_DIAG requires input and output tensors' \
+            assert self.has_same_qnn_params(
+                input_tensors[0], output_tensors[0]
+            ), "TFLite MATRIX_SET_DIAG requires input and output tensors' \
                     scale and zero points to be equal"
 
             # Check that input and diagonal tensor have same qnn params.
-            assert self.has_same_qnn_params(input_tensors[0], input_tensors[1]), \
-                "TFLite MATRIX_SET_DIAG requires input and diagonal tensors' \
+            assert self.has_same_qnn_params(
+                input_tensors[0], input_tensors[1]
+            ), "TFLite MATRIX_SET_DIAG requires input and diagonal tensors' \
                     scale and zero points to be equal"
 
         input_expr = self.get_tensor_expr(input_tensors[0])
@@ -3031,8 +3129,9 @@ class OperatorConverter(object):
         if diagonal.qnn_params:
             # Check that diagonal and output tensor have same qnn params.
             output_tensors = self.get_output_tensors(op)
-            assert self.has_same_qnn_params(diagonal, output_tensors[0]), \
-                "TFLite MATRIX_DIAG requires diagonal and output tensors' \
+            assert self.has_same_qnn_params(
+                diagonal, output_tensors[0]
+            ), "TFLite MATRIX_DIAG requires diagonal and output tensors' \
                     scale and zero points to be equal"
 
         shape = diagonal.tensor.ShapeAsNumpy()
@@ -3045,7 +3144,6 @@ class OperatorConverter(object):
         out = _op.matrix_set_diag(input_expr, diagonal_expr)
         return out
 
-
     def get_expr(self, input_tensor_idx):
         return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
 
@@ -3064,21 +3162,26 @@ class OperatorConverter(object):
 
 def get_scalar_from_constant(expr):
     """ Returns scalar value from Relay constant scalar. """
-    assert isinstance(expr, _expr.Constant) and not expr.data.shape, \
-        "Expr is not a constant scalar."
+    assert (
+        isinstance(expr, _expr.Constant) and not expr.data.shape
+    ), "Expr is not a constant scalar."
     value = expr.data.asnumpy()
-    assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \
-        "value must be float32/int32"
+    assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(
+        np.float32
+    ), "value must be float32/int32"
     return np.asscalar(value)
 
+
 def get_tensor_from_constant(expr):
     """ Returns tensor of values from Relay constant node. """
     assert isinstance(expr, _expr.Constant)
     value = expr.data.asnumpy()
-    assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \
-        "value must be float32/int32"
+    assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(
+        np.float32
+    ), "value must be float32/int32"
     return value
 
+
 def build_str_map(obj):
     """Build string map of TFLite enum int value
 
@@ -3093,12 +3196,13 @@ def build_str_map(obj):
     """
     ret = {}
     for field_name in dir(obj):
-        if not field_name.startswith('_'):
+        if not field_name.startswith("_"):
             field_value = getattr(obj, field_name)
             if isinstance(field_value, int):
                 ret[field_value] = field_name
     return ret
 
+
 # SAME padding: https://www.tensorflow.org/api_guides/python/nn
 def get_pad_value(data, kernel, stride):
     """Get the pad tuple of value for SAME padding
@@ -3175,9 +3279,11 @@ def from_tflite(model, shape_dict, dtype_dict):
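
As a rough sketch of the SAME-padding rule referenced above (an assumption based on the linked TensorFlow convention, with an invented helper name, not a copy of get_pad_value's body): the output size is ceil(input / stride), the total padding is whatever makes that output reachable, and the larger half goes after the data.

    import math

    def same_pad(in_size, kernel, stride):
        out_size = math.ceil(in_size / stride)
        pad_total = max((out_size - 1) * stride + kernel - in_size, 0)
        pad_before = pad_total // 2            # smaller half on top/left
        return pad_before, pad_total - pad_before

    print(same_pad(224, 3, 2))   # (0, 1) for a 3x3, stride-2 kernel on a 224 input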
     # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
     try:
         import tflite
+
         assert isinstance(model, tflite.Model)
     except TypeError:
         import tflite.Model
+
         assert isinstance(model, tflite.Model.Model)
 
     # keep the same as tflite
@@ -3201,7 +3307,7 @@ def from_tflite(model, shape_dict, dtype_dict):
     op_converter.convert_op_to_relay()
 
     # params and outputs
-    params = {k:_nd.array(np.array(v)) for k, v in exp_tab.params.items()}
+    params = {k: _nd.array(np.array(v)) for k, v in exp_tab.params.items()}
     outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs]
     outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
     func = _function.Function(analysis.free_vars(outputs), outputs)
index d08570b..7349082 100644 (file)
 import struct
 from enum import IntEnum
 
+
 class BitWidth(IntEnum):
     """Flexbuffer bit width schema from flexbuffers.h"""
+
     BIT_WIDTH_8 = 0
     BIT_WIDTH_16 = 1
     BIT_WIDTH_32 = 2
     BIT_WIDTH_64 = 3
 
+
 class FlexBufferType(IntEnum):
     """Flexbuffer type schema from flexbuffers.h"""
+
     FBT_NULL = 0
     FBT_INT = 1
     FBT_UINT = 2
@@ -40,24 +44,24 @@ class FlexBufferType(IntEnum):
     FBT_INDIRECT_UINT = 7
     FBT_INDIRECT_FLOAT = 8
     FBT_MAP = 9
-    FBT_VECTOR = 10 # Untyped.
-    FBT_VECTOR_INT = 11 # Typed any size (stores no type table).
+    FBT_VECTOR = 10  # Untyped.
+    FBT_VECTOR_INT = 11  # Typed any size (stores no type table).
     FBT_VECTOR_UINT = 12
     FBT_VECTOR_FLOAT = 13
     FBT_VECTOR_KEY = 14
     FBT_VECTOR_STRING = 15
-    FBT_VECTOR_INT2 = 16 # Typed tuple (no type table, no size field).
+    FBT_VECTOR_INT2 = 16  # Typed tuple (no type table, no size field).
     FBT_VECTOR_UINT2 = 17
     FBT_VECTOR_FLOAT2 = 18
-    FBT_VECTOR_INT3 = 19 # Typed triple (no type table, no size field).
+    FBT_VECTOR_INT3 = 19  # Typed triple (no type table, no size field).
     FBT_VECTOR_UINT3 = 20
     FBT_VECTOR_FLOAT3 = 21
-    FBT_VECTOR_INT4 = 22 # Typed quad (no type table, no size field).
+    FBT_VECTOR_INT4 = 22  # Typed quad (no type table, no size field).
     FBT_VECTOR_UINT4 = 23
     FBT_VECTOR_FLOAT4 = 24
     FBT_BLOB = 25
     FBT_BOOL = 26
-    FBT_VECTOR_BOOL = 36 # To Allow the same type of conversion of type to vector type
+    FBT_VECTOR_BOOL = 36  # To allow the same type-to-vector-type conversion as for other types
 
 
 class FlexBufferDecoder(object):
@@ -79,8 +83,7 @@ class FlexBufferDecoder(object):
         elif byte_width == 4:
             unpack_str = "<i"
         assert unpack_str != ""
-        back_jump = struct.unpack(unpack_str,
-                                  self.buffer[offset: offset + byte_width])[0]
+        back_jump = struct.unpack(unpack_str, self.buffer[offset : offset + byte_width])[0]
         return offset - back_jump
 
     def decode_keys(self, end, size, byte_width):
@@ -94,7 +97,7 @@ class FlexBufferDecoder(object):
             start_index = self.indirect_jump(offset_pos, byte_width)
             str_size = self.buffer[start_index:].find(b"\0")
             assert str_size != -1
-            s = self.buffer[start_index: start_index + str_size].decode("utf-8")
+            s = self.buffer[start_index : start_index + str_size].decode("utf-8")
             keys.append(s)
         return keys
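
indirect_jump above resolves FlexBuffer's relative offsets: the stored value says how far to jump backwards from its own position, and decode_keys then reads a NUL-terminated key string at the target. A toy buffer, handcrafted only for illustration (not a real FlexBuffer):

    import struct

    buf = b"mean\x00\x00\x06"      # "mean\0", one filler byte, then a 1-byte back-jump of 6
    offset, byte_width = 6, 1

    back_jump = struct.unpack("<b", buf[offset : offset + byte_width])[0]
    start = offset - back_jump                         # 0, where the key string begins
    size = buf[start:].find(b"\0")                     # key is NUL-terminated
    print(buf[start : start + size].decode("utf-8"))   # "mean"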
 
@@ -107,7 +110,7 @@ class FlexBufferDecoder(object):
         for i in range(0, size):
             value_type_pos = end + size * byte_width + i
             value_type = FlexBufferType(self.buffer[value_type_pos] >> 2)
-            value_bytes = self.buffer[end + i * byte_width: end + (i + 1) * byte_width]
+            value_bytes = self.buffer[end + i * byte_width : end + (i + 1) * byte_width]
             if value_type == FlexBufferType.FBT_BOOL:
                 value = bool(value_bytes[0])
             elif value_type == FlexBufferType.FBT_INT:
@@ -124,7 +127,7 @@ class FlexBufferDecoder(object):
     def decode_map(self, end, byte_width, parent_byte_width):
         """ Decodes the flexbuffer map and returns a dict """
         mid_loc = self.indirect_jump(end, parent_byte_width)
-        map_size = struct.unpack("<i", self.buffer[mid_loc - byte_width:mid_loc])[0]
+        map_size = struct.unpack("<i", self.buffer[mid_loc - byte_width : mid_loc])[0]
 
         # Find keys
         keys_offset = mid_loc - byte_width * 3
index 568dd41..f889f1e 100644 (file)
@@ -25,6 +25,7 @@ from tvm.ir import BaseFunc
 from .expr import Call
 from . import _ffi_api
 
+
 @tvm._ffi.register_object("relay.Function")
 class Function(BaseFunc):
     """A function declaration expression.
@@ -44,17 +45,14 @@ class Function(BaseFunc):
         The additional type parameters, this is only
         used in advanced usecase of template functions.
     """
-    def __init__(self,
-                 params,
-                 body,
-                 ret_type=None,
-                 type_params=None,
-                 attrs=None):
+
+    def __init__(self, params, body, ret_type=None, type_params=None, attrs=None):
         if type_params is None:
             type_params = convert([])
 
         self.__init_handle_by_constructor__(
-            _ffi_api.Function, params, body, ret_type, type_params, attrs)
+            _ffi_api.Function, params, body, ret_type, type_params, attrs
+        )
 
     def __call__(self, *args):
         """Invoke the global function.
index 9af6811..6c2ab2e 100644 (file)
@@ -22,6 +22,7 @@ from .scope_builder import ScopeBuilder
 from . import expr as _expr
 from . import function as _function
 
+
 def while_loop(cond, loop_vars, loop_bodies):
     """
     Construct a while loop.
index 011042b..f6afa44 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=wildcard-import, redefined-builtin
+# pylint: disable=wildcard-import, redefined-builtin
 """Relay core operators."""
 # operator defs
-from .op import get, register_compute, register_gradient, \
-    register_pattern, register_alter_op_layout, register_legalize, \
-    OpPattern, OpStrategy, debug, register_external_compiler
+from .op import (
+    get,
+    register_compute,
+    register_gradient,
+    register_pattern,
+    register_alter_op_layout,
+    register_legalize,
+    OpPattern,
+    OpStrategy,
+    debug,
+    register_external_compiler,
+)
 from . import strategy
 
 # Operators
@@ -48,6 +57,8 @@ def _register_op_make():
     # pylint: disable=import-outside-toplevel
     from . import _make
     from .. import expr
+
     expr._op_make = _make
 
+
 _register_op_make()
index cded2e1..732d501 100644 (file)
@@ -34,6 +34,7 @@ register_pattern("argsort", OpPattern.OPAQUE)
 register_strategy("topk", strategy.topk_strategy)
 register_pattern("topk", OpPattern.OPAQUE)
 
+
 @script
 def _topk_shape_func_input_shape(data_shape, k, axis):
     ndim = data_shape.shape[0]
@@ -53,6 +54,7 @@ def _topk_shape_func_input_shape(data_shape, k, axis):
                 indices_out[i] = int64(k)
     return val_out, indices_out
 
+
 @_reg.register_shape_func("topk", False)
 def topk_shape_func(attrs, inputs, _):
     """
@@ -61,8 +63,7 @@ def topk_shape_func(attrs, inputs, _):
     axis = attrs.axis
     if axis < 0:
         axis += inputs[0].shape[0]
-    val_out, indices_out = \
-        _topk_shape_func_input_shape(inputs[0], attrs.k, convert(axis))
+    val_out, indices_out = _topk_shape_func_input_shape(inputs[0], attrs.k, convert(axis))
     ret_type = attrs.ret_type
     if ret_type == "both":
         ret = [val_out, indices_out]
index 015f5ad..604098f 100644 (file)
@@ -33,6 +33,7 @@ _reg.register_reduce_schedule("prod")
 _reg.register_reduce_schedule("mean")
 _reg.register_reduce_schedule("variance")
 
+
 def _create_axis_record(attrs, inputs):
     axes = attrs.axis if attrs.axis is None else list(get_const_tuple(attrs.axis))
     exclude = get_const_int(attrs.exclude) > 0
@@ -79,6 +80,7 @@ def _reduce_shape_func(data_shape, axis_record):
 
     return out
 
+
 def reduce_shape_func(attrs, inputs, _):
     """
     Shape function for reduce op.
@@ -86,6 +88,7 @@ def reduce_shape_func(attrs, inputs, _):
     axis_record = _create_axis_record(attrs, inputs)
     return [_reduce_shape_func(inputs[0], convert(axis_record))]
 
+
 _reg.register_shape_func("argmax", False, reduce_shape_func)
 _reg.register_shape_func("argmin", False, reduce_shape_func)
 _reg.register_shape_func("all", False, reduce_shape_func)
index c81d4c5..6b7f139 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=invalid-name, unused-argument, len-as-condition
+# pylint: disable=invalid-name, unused-argument, len-as-condition
 """Backend compiler related feature registration"""
 
 from tvm.te.hybrid import script
@@ -96,6 +96,7 @@ def zeros_compute(attrs, inputs, output_type):
     assert not inputs
     return [topi.full(output_type.shape, output_type.dtype, 0.0)]
 
+
 register_broadcast_schedule("zeros")
 register_pattern("zeros", OpPattern.ELEMWISE)
 
@@ -105,6 +106,7 @@ def zeros_like_compute(attrs, inputs, output_type):
     assert len(inputs) == 1
     return [topi.full_like(inputs[0], 0.0)]
 
+
 register_broadcast_schedule("zeros_like")
 
 # ones
@@ -113,6 +115,7 @@ def ones_compute(attrs, inputs, output_type):
     assert not inputs
     return [topi.full(output_type.shape, output_type.dtype, 1.0)]
 
+
 register_broadcast_schedule("ones")
 register_pattern("ones", OpPattern.ELEMWISE)
 
@@ -122,6 +125,7 @@ def ones_like_compute(attrs, inputs, output_type):
     assert len(inputs) == 1
     return [topi.full_like(inputs[0], 1.0)]
 
+
 register_broadcast_schedule("ones_like")
 
 # clip
@@ -130,6 +134,7 @@ def clip_compute(attrs, inputs, output_type):
     assert len(inputs) == 1
     return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)]
 
+
 register_injective_schedule("clip")
 
 # fixed point multiply
@@ -138,6 +143,7 @@ def fixed_point_multiply_compute(attrs, inputs, output_type):
     assert len(inputs) == 1
     return [topi.fixed_point_multiply(inputs[0], attrs.multiplier, attrs.shift)]
 
+
 register_injective_schedule("fixed_point_multiply")
 
 # full
@@ -149,18 +155,21 @@ def _full_shape_func(shape):
         out[i] = int64(shape[i])
     return out
 
+
 def full_shape_func(attrs, inputs, out_ndims):
     """
     Shape func for full.
     """
     return [_full_shape_func(inputs[1])]
 
+
 def no_data_full_shape_func(attrs, inputs, out_ndims):
     """
     Shape func for zeros and ones.
     """
     return [_full_shape_func(inputs[0])]
 
+
 @script
 def _broadcast_shape_func(x, y, ndim):
     out = output_tensor((ndim,), "int64")
@@ -173,34 +182,39 @@ def _broadcast_shape_func(x, y, ndim):
     else:
         ndim1 = x.shape[0]
         ndim2 = y.shape[0]
-        for i in const_range(1, min(ndim1, ndim2)+1):
-            if x[ndim1-i] == y[ndim2-i]:
-                out[ndim-i] = x[ndim1-i]
-            elif x[ndim1-i] == 1:
-                out[ndim-i] = y[ndim2-i]
+        for i in const_range(1, min(ndim1, ndim2) + 1):
+            if x[ndim1 - i] == y[ndim2 - i]:
+                out[ndim - i] = x[ndim1 - i]
+            elif x[ndim1 - i] == 1:
+                out[ndim - i] = y[ndim2 - i]
             else:
                 assert y[ndim2 - i] == 1, "Incompatible broadcast type %s and %s" % (
-                    x[ndim1-i], y[ndim2-i])
-                out[ndim-i] = x[ndim1-i]
-        for i in const_range(min(ndim1, ndim2)+1, ndim+1):
+                    x[ndim1 - i],
+                    y[ndim2 - i],
+                )
+                out[ndim - i] = x[ndim1 - i]
+        for i in const_range(min(ndim1, ndim2) + 1, ndim + 1):
             if ndim1 >= ndim2:
-                out[ndim-i] = x[ndim1-i]
+                out[ndim - i] = x[ndim1 - i]
             else:
-                out[ndim-i] = y[ndim2-i]
+                out[ndim - i] = y[ndim2 - i]
     return out
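
The hybrid-script loop above encodes the usual broadcasting rule: trailing dimensions must be equal or one of them must be 1, and missing leading dimensions come from the longer shape. A quick NumPy illustration with invented shapes:

    import numpy as np

    x = np.zeros((3, 1, 5))
    y = np.zeros((4, 5))
    print(np.broadcast(x, y).shape)    # (3, 4, 5): the 1 broadcasts against 4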
 
+
 def broadcast_shape_func(attrs, inputs, out_ndims):
     """
     Shape function for broadcast op.
     """
     return [_broadcast_shape_func(*inputs, out_ndims[0])]
 
+
 def elemwise_shape_func(attrs, inputs, _):
     """
     Shape function for elemwise op.
     """
     return [topi.math.identity(inputs[0])]
 
+
 register_shape_func("cast", False, elemwise_shape_func)
 register_shape_func("zeros", False, full_shape_func)
 register_shape_func("zeros_like", False, elemwise_shape_func)
index 5069f79..85168a5 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=invalid-name, unused-argument
+# pylint: disable=invalid-name, unused-argument
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
 
@@ -39,7 +39,8 @@ from .tensor import (
     zeros_like,
     equal,
     shape_of,
-    log)
+    log,
+)
 from .transform import (
     broadcast_to_like,
     collapse_sum_like,
@@ -53,7 +54,7 @@ from .transform import (
     where,
     repeat,
     expand_dims,
-    full_like
+    full_like,
 )
 
 
@@ -204,31 +205,27 @@ def relu_grad(orig, grad):
 @register_gradient("add")
 def add_grad(orig, grad):
     """Returns [grad, grad]"""
-    return [collapse_sum_like(grad, orig.args[0]),
-            collapse_sum_like(grad, orig.args[1])]
+    return [collapse_sum_like(grad, orig.args[0]), collapse_sum_like(grad, orig.args[1])]
 
 
 @register_gradient("subtract")
 def subtract_grad(orig, grad):
     """Returns [grad, -grad]"""
-    return [collapse_sum_like(grad, orig.args[0]),
-            collapse_sum_like(negative(grad), orig.args[1])]
+    return [collapse_sum_like(grad, orig.args[0]), collapse_sum_like(negative(grad), orig.args[1])]
 
 
 @register_gradient("multiply")
 def multiply_grad(orig, grad):
     """Returns [grad * y, grad * x]"""
     x, y = orig.args
-    return [collapse_sum_like(grad * y, x),
-            collapse_sum_like(grad * x, y)]
+    return [collapse_sum_like(grad * y, x), collapse_sum_like(grad * x, y)]
 
 
 @register_gradient("divide")
 def divide_grad(orig, grad):
     """Returns [grad / y,  - grad * (x / y) / y]"""
     x, y = orig.args
-    return [collapse_sum_like(grad / y, x),
-            collapse_sum_like(- (grad * orig / y), y)]
+    return [collapse_sum_like(grad / y, x), collapse_sum_like(-(grad * orig / y), y)]
 
 
 @register_gradient("zeros")
@@ -281,9 +278,9 @@ def abs_grad(orig, grad):
 @register_gradient("erf")
 def erf_grad(orig, grad):
     # c_2_div_sqrt_pi = 2.0 / math.sqrt(math.pi)
-    inp, = orig.args
+    (inp,) = orig.args
     c_2_div_sqrt_pi = const(1.1283791670955126, dtype=inp.checked_type.dtype)
-    return [c_2_div_sqrt_pi * exp(- inp * inp) * grad]
+    return [c_2_div_sqrt_pi * exp(-inp * inp) * grad]
 
 
 @register_gradient("clip")
@@ -303,9 +300,15 @@ def clip_grad(orig, grad):
 def max_pool2d_grad(orig, grad):
     """Returns the gradient of max_pool2d."""
     attrs = orig.attrs
-    pool_grad = _nn.max_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
-                                    strides=attrs.strides, padding=attrs.padding,
-                                    layout=attrs.layout, ceil_mode=attrs.ceil_mode)
+    pool_grad = _nn.max_pool2d_grad(
+        grad,
+        orig.args[0],
+        pool_size=attrs.pool_size,
+        strides=attrs.strides,
+        padding=attrs.padding,
+        layout=attrs.layout,
+        ceil_mode=attrs.ceil_mode,
+    )
     return [pool_grad]
 
 
@@ -313,10 +316,16 @@ def max_pool2d_grad(orig, grad):
 def avg_pool2d_grad(orig, grad):
     """Returns the gradient of avg_pool2d."""
     attrs = orig.attrs
-    pool_grad = _nn.avg_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
-                                    strides=attrs.strides, padding=attrs.padding,
-                                    layout=attrs.layout, ceil_mode=attrs.ceil_mode,
-                                    count_include_pad=attrs.count_include_pad)
+    pool_grad = _nn.avg_pool2d_grad(
+        grad,
+        orig.args[0],
+        pool_size=attrs.pool_size,
+        strides=attrs.strides,
+        padding=attrs.padding,
+        layout=attrs.layout,
+        ceil_mode=attrs.ceil_mode,
+        count_include_pad=attrs.count_include_pad,
+    )
     return [pool_grad]
 
 
@@ -334,9 +343,9 @@ def global_avg_pool2d_grad(orig, grad):
     elif layout == "NHWC":
         pool_size = shape[1], shape[2]
 
-    pool_grad = _nn.avg_pool2d_grad(grad, data, pool_size=pool_size,
-                                    strides=(1, 1), padding=(0, 0),
-                                    layout=layout)
+    pool_grad = _nn.avg_pool2d_grad(
+        grad, data, pool_size=pool_size, strides=(1, 1), padding=(0, 0), layout=layout
+    )
     return [pool_grad]
 
 
@@ -364,52 +373,68 @@ def conv2d_grad(orig, grad):
     out_channel, _, filter_h, filter_w = weight_shape
 
     # infer output_padding
-    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(get_const_tuple(attrs.padding),
-                                                                 (filter_h, filter_w))
+    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(
+        get_const_tuple(attrs.padding), (filter_h, filter_w)
+    )
     stride_h, stride_w = get_const_tuple(attrs.strides)
     dilation_h, dilation_w = get_const_tuple(attrs.dilation)
     out_h = (grad_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
     out_w = (grad_w - 1) * stride_w - fpad_left - fpad_right + filter_w
     output_padding = (in_h - out_h, in_w - out_w)
 
-    assert attrs.data_layout == 'NCHW', 'only support NCHW data layout'
-    assert attrs.kernel_layout == 'OIHW', 'only support OIHW kernel layout'
-    assert attrs.out_layout in ['', 'NCHW'], 'only support NCHW output layout'
-
-
-    backward_data = _nn.conv2d_transpose(grad, weight,
-                                         strides=attrs.strides,
-                                         padding=attrs.padding,
-                                         dilation=attrs.dilation,
-                                         groups=attrs.groups,
-                                         output_padding=output_padding)
+    assert attrs.data_layout == "NCHW", "only support NCHW data layout"
+    assert attrs.kernel_layout == "OIHW", "only support OIHW kernel layout"
+    assert attrs.out_layout in ["", "NCHW"], "only support NCHW output layout"
+
+    backward_data = _nn.conv2d_transpose(
+        grad,
+        weight,
+        strides=attrs.strides,
+        padding=attrs.padding,
+        dilation=attrs.dilation,
+        groups=attrs.groups,
+        output_padding=output_padding,
+    )
     grad = tile(grad, [1, in_channel // attrs.groups, 1, 1])
     grad = reshape(grad, [-1, 1, 0, 0])  # batch * oc * ic // groups, 1, oh, ow
     data = reshape(data, [1, -1, 0, 0])  # 1, batch * ic, ih, iw
 
-    backward_weight = _nn.conv2d(data, grad,
-                                 strides=attrs.dilation,
-                                 padding=attrs.padding,
-                                 dilation=attrs.strides,
-                                 groups=in_channel * batch)
+    backward_weight = _nn.conv2d(
+        data,
+        grad,
+        strides=attrs.dilation,
+        padding=attrs.padding,
+        dilation=attrs.strides,
+        groups=in_channel * batch,
+    )
     # infer shape of backward_weight
-    padded_weight_grad_h = (in_h - (grad_h - 1) * stride_h - 1 + fpad_top + fpad_bottom) \
-                           // dilation_h + 1
-    padded_weight_grad_w = (in_w - (grad_w - 1) * stride_w - 1 + fpad_left + fpad_right) \
-                           // dilation_w + 1
-    backward_weight = reshape(backward_weight,
-                              [batch, in_channel // attrs.groups, out_channel,
-                               padded_weight_grad_h, padded_weight_grad_w])
+    padded_weight_grad_h = (
+        in_h - (grad_h - 1) * stride_h - 1 + fpad_top + fpad_bottom
+    ) // dilation_h + 1
+    padded_weight_grad_w = (
+        in_w - (grad_w - 1) * stride_w - 1 + fpad_left + fpad_right
+    ) // dilation_w + 1
+    backward_weight = reshape(
+        backward_weight,
+        [
+            batch,
+            in_channel // attrs.groups,
+            out_channel,
+            padded_weight_grad_h,
+            padded_weight_grad_w,
+        ],
+    )
     backward_weight = _sum(backward_weight, axis=0)
     backward_weight = transpose(backward_weight, [1, 0, 2, 3])
 
     assert padded_weight_grad_h >= filter_h
     assert padded_weight_grad_w >= filter_w
     if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
-        backward_weight = strided_slice(backward_weight,
-                                        begin=[0, 0, 0, 0],
-                                        end=[out_channel, in_channel // attrs.groups,
-                                             filter_h, filter_w])
+        backward_weight = strided_slice(
+            backward_weight,
+            begin=[0, 0, 0, 0],
+            end=[out_channel, in_channel // attrs.groups, filter_h, filter_w],
+        )
 
     return [backward_data, backward_weight]
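# Standalone sketch (plain Python, illustrative numbers) of the output_padding
# inference done in conv2d_grad above: the transposed convolution that produces
# backward_data can undershoot the original input size, and output_padding
# makes up the difference.
in_h, filter_h, stride_h, fpad_top, fpad_bottom = 6, 3, 2, 0, 0
grad_h = (in_h + fpad_top + fpad_bottom - filter_h) // stride_h + 1  # forward conv output: 2
out_h = (grad_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h  # transposed conv output: 5
print(in_h - out_h)  # output_padding along H: 1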
 
@@ -482,30 +507,39 @@ def log_softmax_grad(orig, grad):
 def bias_add_grad(orig, grad):
     """Returns gradient of bias_add"""
     data = orig.args[0]
-    return [collapse_sum_like(grad, data),
-            _sum(grad, orig.attrs.axis, keepdims=False, exclude=True)]
+    return [
+        collapse_sum_like(grad, data),
+        _sum(grad, orig.attrs.axis, keepdims=False, exclude=True),
+    ]
 
 
 @register_gradient("nn.dense")
 def dense_grad(orig, grad):
     """Returns [grad' @ weight, data @ grad']"""
     data, weight = orig.args
-    return [collapse_sum_like(_nn.dense(grad, transpose(weight),
-                                        units=weight.checked_type.shape[1]), data),
-            collapse_sum_like(_nn.dense(transpose(grad), transpose(data),
-                                        units=data.checked_type.shape[1]), weight)]
+    return [
+        collapse_sum_like(
+            _nn.dense(grad, transpose(weight), units=weight.checked_type.shape[1]), data
+        ),
+        collapse_sum_like(
+            _nn.dense(transpose(grad), transpose(data), units=data.checked_type.shape[1]), weight
+        ),
+    ]
 
 
 @register_gradient("nn.batch_matmul")
 def batch_matmul_grad(orig, grad):
     """gradient for nn.batch_matmul: in einsum LHS_bik,RHS_bjk->RES_bij
-       grads: GRAD_OUT_bij,RHS_bjk->GRAD_IN_LHS_bik
-              GRAD_OUT_bij,LHS_bik->GRAD_IN_RHS_bjk
+    grads: GRAD_OUT_bij,RHS_bjk->GRAD_IN_LHS_bik
+           GRAD_OUT_bij,LHS_bik->GRAD_IN_RHS_bjk
     """
     lhs, rhs = orig.args
-    return [collapse_sum_like(_nn.batch_matmul(grad, transpose(rhs, [0, 2, 1])), lhs),
-            collapse_sum_like(_nn.batch_matmul(transpose(grad, [0, 2, 1]),
-                                               transpose(lhs, [0, 2, 1])), rhs)]
+    return [
+        collapse_sum_like(_nn.batch_matmul(grad, transpose(rhs, [0, 2, 1])), lhs),
+        collapse_sum_like(
+            _nn.batch_matmul(transpose(grad, [0, 2, 1]), transpose(lhs, [0, 2, 1])), rhs
+        ),
+    ]
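# Standalone NumPy check (not TVM code) of the einsum gradient rules quoted in
# the batch_matmul_grad docstring above.
import numpy as np

b, i, j, k = 2, 3, 4, 5
lhs, rhs = np.random.randn(b, i, k), np.random.randn(b, j, k)
grad_out = np.random.randn(b, i, j)                  # upstream gradient of RES_bij

res = np.einsum("bik,bjk->bij", lhs, rhs)            # forward batch_matmul
grad_lhs = np.einsum("bij,bjk->bik", grad_out, rhs)  # GRAD_OUT_bij, RHS_bjk -> GRAD_IN_LHS_bik
grad_rhs = np.einsum("bij,bik->bjk", grad_out, lhs)  # GRAD_OUT_bij, LHS_bik -> GRAD_IN_RHS_bjk

# finite-difference spot check on one element of lhs
eps = 1e-6
bumped = lhs.copy()
bumped[0, 0, 0] += eps
fd = ((np.einsum("bik,bjk->bij", bumped, rhs) - res) * grad_out).sum() / eps
assert np.allclose(fd, grad_lhs[0, 0, 0], atol=1e-4)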
 
 
 @register_gradient("reshape")
@@ -604,8 +638,10 @@ def variance_grad(orig, grad):
         mult2 = mult2 * count / (count - 1)
         count -= 1
     mult1 /= count
-    return [(grad * const(mult1, dtype=data.checked_type.dtype)) * data,
-            const(mult2, dtype=data.checked_type.dtype) * grad * data_mean]
+    return [
+        (grad * const(mult1, dtype=data.checked_type.dtype)) * data,
+        const(mult2, dtype=data.checked_type.dtype) * grad * data_mean,
+    ]
 
 
 @register_gradient("copy")
@@ -617,7 +653,7 @@ def copy_grad(orig, grad):
 def cross_entropy_grad(orig, grad):
     x, y = orig.args
     shape = shape_of(x)
-    batch_size = take(shape, const(0, dtype='int32'), axis=0)
+    batch_size = take(shape, const(0, dtype="int32"), axis=0)
     grad = grad / batch_size.astype(x.checked_type.dtype)
     return [-grad * y / x, -grad * log(x)]
 
@@ -626,6 +662,6 @@ def cross_entropy_grad(orig, grad):
 def cross_entropy_with_logits_grad(orig, grad):
     x, y = orig.args
     shape = shape_of(x)
-    batch_size = take(shape, const(0, dtype='int32'), axis=0)
+    batch_size = take(shape, const(0, dtype="int32"), axis=0)
     grad = grad / batch_size.astype(x.checked_type.dtype)
     return [-grad * y, -grad * x]
index 98ff0b3..dca3e9a 100644
@@ -73,6 +73,7 @@ def compute_strided_set(attrs, inputs, output_type):
     """Compute definition of strided_set"""
     return [topi.strided_set(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4])]
 
+
 _reg.register_injective_schedule("strided_set")
 
 # layout_transform
@@ -93,6 +94,7 @@ def compute_argwhere(attrs, inputs, output_type):
     new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
     return [topi.argwhere(new_output_type, inputs[0])]
 
+
 _reg.register_schedule("argwhere", strategy.schedule_argwhere)
 
 # scatter
@@ -101,6 +103,7 @@ def compute_scatter(attrs, inputs, output_type):
     """Compute definition of scatter"""
     return [topi.scatter(inputs[0], inputs[1], inputs[2], attrs.axis)]
 
+
 _reg.register_schedule("scatter", strategy.schedule_scatter)
 
 # scatter_add
@@ -109,18 +112,21 @@ def compute_scatter_add(attrs, inputs, output_type):
     """Compute definition of scatter_add"""
     return [topi.scatter_add(inputs[0], inputs[1], inputs[2], attrs.axis)]
 
+
 _reg.register_schedule("scatter_add", strategy.schedule_scatter_add)
 
 #####################
 #  Shape functions  #
 #####################
 
+
 @script
 def _arange_shape_func(start, stop, step):
     out = output_tensor((1,), "int64")
     out[0] = int64(ceil_div((int64(stop[0]) - int64(start[0])), int64(step[0])))
     return out
 
+
 @_reg.register_shape_func("arange", True)
 def arange_shape_func(attrs, inputs, _):
     """
@@ -128,6 +134,7 @@ def arange_shape_func(attrs, inputs, _):
     """
     return [_arange_shape_func(*inputs)]
 
+
 @script
 def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice_mode):
     ndim = data_shape.shape[0]
@@ -175,8 +182,12 @@ def strided_slice_shape_func(attrs, inputs, _):
     Shape func for strided_slice
     """
     slice_mode = convert(0 if attrs.slice_mode == "end" else 1)
-    return [_strided_slice_shape_func_input_shape(inputs[0], attrs.begin, attrs.end,
-                                                  attrs.strides, slice_mode)]
+    return [
+        _strided_slice_shape_func_input_shape(
+            inputs[0], attrs.begin, attrs.end, attrs.strides, slice_mode
+        )
+    ]
+
 
 @script
 def _concatenate_shape_func(inputs, axis):
@@ -186,14 +197,14 @@ def _concatenate_shape_func(inputs, axis):
         if i != axis:
             out[i] = inputs[0][i]
             for j in const_range(1, len(inputs)):
-                assert out[i] == inputs[j][i], \
-                    "Dims mismatch in the inputs of concatenate."
+                assert out[i] == inputs[j][i], "Dims mismatch in the inputs of concatenate."
         else:
             out[i] = int64(0)
             for j in const_range(len(inputs)):
                 out[i] += inputs[j][i]
     return out
 
+
 @_reg.register_shape_func("concatenate", False)
 def concatenate_shape_func(attrs, inputs, _):
     axis = get_const_int(attrs.axis)
@@ -201,6 +212,7 @@ def concatenate_shape_func(attrs, inputs, _):
         axis += inputs[0].shape[0]
     return [_concatenate_shape_func(inputs, convert(axis))]
 
+
 @script
 def _reshape_shape_func_input_shape(data_shape, newshape, ndim):
     out = output_tensor((ndim,), "int64")
@@ -228,25 +240,25 @@ def _reshape_shape_func_input_shape(data_shape, newshape, ndim):
         elif newshape[i] == -2:
             copy = True
         elif newshape[i] == -3:
-            assert data_shape.shape[0] - src_idx > 1, \
-                "Not enough dims in input shape for -3"
-            out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
+            assert data_shape.shape[0] - src_idx > 1, "Not enough dims in input shape for -3"
+            out[dst_idx] = data_shape[src_idx] * data_shape[src_idx + 1]
             src_idx += 2
             dst_idx += 1
         elif newshape[i] == -4:
             assert len(newshape) - i > 2, "Not enough dims in new shape for -4"
-            if newshape[i+1] == -1:
-                assert newshape[i+2] != -1, "Split dims cannot both be -1."
-                out[dst_idx] = data_shape[src_idx] // int64(newshape[i+2])
-                out[dst_idx+1] = int64(newshape[i+2])
+            if newshape[i + 1] == -1:
+                assert newshape[i + 2] != -1, "Split dims cannot both be -1."
+                out[dst_idx] = data_shape[src_idx] // int64(newshape[i + 2])
+                out[dst_idx + 1] = int64(newshape[i + 2])
             else:
-                out[dst_idx] = int64(newshape[i+1])
-                if newshape[i+2] == -1:
-                    out[dst_idx+1] = data_shape[src_idx] // int64(newshape[i+1])
+                out[dst_idx] = int64(newshape[i + 1])
+                if newshape[i + 2] == -1:
+                    out[dst_idx + 1] = data_shape[src_idx] // int64(newshape[i + 1])
                 else:
-                    out[dst_idx+1] = int64(newshape[i+2])
-            assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\
-                "Product of split dims doesn't match to input dim"
+                    out[dst_idx + 1] = int64(newshape[i + 2])
+            assert (
+                data_shape[src_idx] == out[dst_idx] * out[dst_idx + 1]
+            ), "Product of split dims doesn't match to input dim"
             src_idx += 1
             dst_idx += 2
             skip = 2
@@ -268,12 +280,12 @@ def _reshape_shape_func_input_shape(data_shape, newshape, ndim):
             out[infer_idx] = old_size // new_size
     return out
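# Sketch (assumes the MXNet-style newshape codes that relay.reshape documents
# and the shape function above mirrors): -1 infers a dim, -2 copies the
# remaining dims, -3 merges two dims, -4 splits one dim into the next two entries.
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 3, 4), dtype="float32")
for newshape in ([6, -1], [-3, -2], [-4, 1, 2, 3, 4]):
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.reshape(x, newshape)))
    mod = relay.transform.InferType()(mod)
    print(newshape, "->", mod["main"].body.checked_type.shape)
# expected: [6, -1] -> [6, 4], [-3, -2] -> [6, 4], [-4, 1, 2, 3, 4] -> [1, 2, 3, 4]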
 
+
 @_reg.register_shape_func("reshape", False)
 def reshape_shape_func(attrs, inputs, out_ndims):
     newshape = get_const_tuple(attrs.newshape)
-    return [_reshape_shape_func_input_shape(inputs[0],
-                                            convert(newshape),
-                                            out_ndims[0])]
+    return [_reshape_shape_func_input_shape(inputs[0], convert(newshape), out_ndims[0])]
+
 
 @script
 def _take_no_axis_shape_func(indices_shape, out_ndim):
@@ -282,6 +294,7 @@ def _take_no_axis_shape_func(indices_shape, out_ndim):
         out[i] = indices_shape[i]
     return out
 
+
 @script
 def _take_with_axis_shape_func(data_shape, indices_shape, axis, out_ndim):
     out = output_tensor((out_ndim,), "int64")
@@ -289,15 +302,16 @@ def _take_with_axis_shape_func(data_shape, indices_shape, axis, out_ndim):
         out[i] = data_shape[i]
     if len(indices_shape.shape) == 0:
         # indices is constant
-        for i in const_range(axis+1, len(data_shape)):
-            out[i-1] = data_shape[i]
+        for i in const_range(axis + 1, len(data_shape)):
+            out[i - 1] = data_shape[i]
     else:
         for i in const_range(len(indices_shape)):
-            out[axis+i] = indices_shape[i]
-        for i in const_range(axis+1, len(data_shape)):
-            out[len(indices_shape)+i-1] = data_shape[i]
+            out[axis + i] = indices_shape[i]
+        for i in const_range(axis + 1, len(data_shape)):
+            out[len(indices_shape) + i - 1] = data_shape[i]
     return out
 
+
 @_reg.register_shape_func("take", False)
 def take_shape_func(attrs, inputs, out_ndims):
     """
@@ -312,9 +326,10 @@ def take_shape_func(attrs, inputs, out_ndims):
     assert 0 <= axis < data_ndim
     return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])]
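# NumPy analogue (illustrative only) of the take shape rule implemented above:
# the output keeps the data dims before `axis`, inserts the indices dims, then
# the data dims after `axis`; a scalar (constant) index drops the axis.
import numpy as np

data = np.zeros((2, 3, 4))
indices = np.array([0, 2, 1, 1, 0])
print(np.take(data, indices, axis=1).shape)  # (2, 5, 4)
print(np.take(data, 2, axis=1).shape)        # (2, 4)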
 
+
 @script
 def _argwhere_shape_func_1d(condition):
-    out = output_tensor((2, ), "int64")
+    out = output_tensor((2,), "int64")
     out[0] = int64(0)
     out[1] = int64(1)
     for i1 in range(condition.shape[0]):
@@ -322,9 +337,10 @@ def _argwhere_shape_func_1d(condition):
             out[0] += int64(1)
     return out
 
+
 @script
 def _argwhere_shape_func_2d(condition):
-    out = output_tensor((2, ), "int64")
+    out = output_tensor((2,), "int64")
     out[0] = int64(0)
     out[1] = int64(2)
     for i1 in range(condition.shape[0]):
@@ -333,9 +349,10 @@ def _argwhere_shape_func_2d(condition):
                 out[0] += int64(1)
     return out
 
+
 @script
 def _argwhere_shape_func_3d(condition):
-    out = output_tensor((2, ), "int64")
+    out = output_tensor((2,), "int64")
     out[0] = int64(0)
     out[1] = int64(3)
     for i1 in range(condition.shape[0]):
@@ -345,9 +362,10 @@ def _argwhere_shape_func_3d(condition):
                     out[0] += int64(1)
     return out
 
+
 @script
 def _argwhere_shape_func_4d(condition):
-    out = output_tensor((2, ), "int64")
+    out = output_tensor((2,), "int64")
     out[0] = int64(0)
     out[1] = int64(4)
     for i1 in range(condition.shape[0]):
@@ -358,9 +376,10 @@ def _argwhere_shape_func_4d(condition):
                         out[0] += int64(1)
     return out
 
+
 @script
 def _argwhere_shape_func_5d(condition):
-    out = output_tensor((2, ), "int64")
+    out = output_tensor((2,), "int64")
     out[0] = int64(0)
     out[1] = int64(5)
     for i1 in range(condition.shape[0]):
@@ -372,6 +391,7 @@ def _argwhere_shape_func_5d(condition):
                             out[0] += int64(1)
     return out
 
+
 @_reg.register_shape_func("argwhere", True)
 def argwhere_shape_func(attrs, inputs, out_ndims):
     """
@@ -389,38 +409,38 @@ def argwhere_shape_func(attrs, inputs, out_ndims):
         return [_argwhere_shape_func_5d(inputs[0])]
     raise ValueError("Does not support rank higher than 5 in argwhere")
 
+
 _reg.register_shape_func("scatter", False, elemwise_shape_func)
 _reg.register_shape_func("scatter_add", False, elemwise_shape_func)
 
+
 @script
-def _layout_transform_shape_func(data_shape,
-                                 out_layout_len,
-                                 dst_equal_list,
-                                 dst_mul_list,
-                                 dst_div_list,
-                                 dst_mix_list):
+def _layout_transform_shape_func(
+    data_shape, out_layout_len, dst_equal_list, dst_mul_list, dst_div_list, dst_mix_list
+):
     out = output_tensor((out_layout_len,), "int64")
     for i in const_range(len(dst_equal_list)):
         out[dst_equal_list[i][0]] = data_shape[dst_equal_list[i][1]]
     for i in const_range(len(dst_mul_list)):
-        out[dst_mul_list[i][0]] = data_shape[dst_mul_list[i][1]] * \
-                                  data_shape[dst_mul_list[i][2]]
+        out[dst_mul_list[i][0]] = data_shape[dst_mul_list[i][1]] * data_shape[dst_mul_list[i][2]]
     for i in const_range(len(dst_div_list)):
-        out[dst_div_list[i][0]] = data_shape[dst_div_list[i][1]] \
-                                  // dst_div_list[i][3]
+        out[dst_div_list[i][0]] = data_shape[dst_div_list[i][1]] // dst_div_list[i][3]
         out[dst_div_list[i][2]] = int64(dst_div_list[i][3])
     for i in const_range(len(dst_mix_list)):
-        out[dst_mix_list[i][0]] = data_shape[dst_mix_list[i][1]] * \
-                                  dst_mix_list[i][2] // dst_mix_list[i][4]
+        out[dst_mix_list[i][0]] = (
+            data_shape[dst_mix_list[i][1]] * dst_mix_list[i][2] // dst_mix_list[i][4]
+        )
         out[dst_mix_list[i][3]] = int64(dst_mix_list[i][4])
 
     return out
 
+
 @_reg.register_shape_func("layout_transform", False)
 def layout_transform_shape_func(attrs, inputs, _):
     """
     Shape function for layout_transform op.
     """
+
     def _fetch_axis(layout):
         major_axes = []
         minor_axes = {}
@@ -455,31 +475,47 @@ def layout_transform_shape_func(attrs, inputs, _):
     for key in dst_major_axes:
         if key.lower() not in dst_minor_axes:
             if key.lower() not in src_minor_axes:
-                dst_equal_list.append((dst_letter_list.index(key),
-                                       src_letter_list.index(key)))
+                dst_equal_list.append((dst_letter_list.index(key), src_letter_list.index(key)))
             else:
-                dst_mul_list.append((dst_letter_list.index(key),
-                                     src_letter_list.index(key),
-                                     src_letter_list.index(key.lower())))
+                dst_mul_list.append(
+                    (
+                        dst_letter_list.index(key),
+                        src_letter_list.index(key),
+                        src_letter_list.index(key.lower()),
+                    )
+                )
         else:
             if key.lower() not in src_minor_axes:
-                dst_div_list.append((dst_letter_list.index(key),
-                                     src_letter_list.index(key),
-                                     dst_letter_list.index(key.lower()),
-                                     dst_minor_axes[key.lower()]))
+                dst_div_list.append(
+                    (
+                        dst_letter_list.index(key),
+                        src_letter_list.index(key),
+                        dst_letter_list.index(key.lower()),
+                        dst_minor_axes[key.lower()],
+                    )
+                )
             else:
-                dst_mix_list.append((dst_letter_list.index(key),
-                                     src_letter_list.index(key),
-                                     src_minor_axes[key.lower()],
-                                     dst_letter_list.index(key.lower()),
-                                     dst_minor_axes[key.lower()]))
-
-    return [_layout_transform_shape_func(inputs[0],
-                                         convert(out_layout_len),
-                                         convert(dst_equal_list),
-                                         convert(dst_mul_list),
-                                         convert(dst_div_list),
-                                         convert(dst_mix_list))]
+                dst_mix_list.append(
+                    (
+                        dst_letter_list.index(key),
+                        src_letter_list.index(key),
+                        src_minor_axes[key.lower()],
+                        dst_letter_list.index(key.lower()),
+                        dst_minor_axes[key.lower()],
+                    )
+                )
+
+    return [
+        _layout_transform_shape_func(
+            inputs[0],
+            convert(out_layout_len),
+            convert(dst_equal_list),
+            convert(dst_mul_list),
+            convert(dst_div_list),
+            convert(dst_mix_list),
+        )
+    ]
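# Sketch of what the layout_transform shape function above computes for one
# concrete case: NCHW -> NCHW4c splits C by the minor factor 4 (the
# dst_div_list path), so (1, 8, 5, 5) becomes (1, 2, 5, 5, 4).
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 8, 5, 5), dtype="float32")
y = relay.layout_transform(x, "NCHW", "NCHW4c")
mod = relay.transform.InferType()(tvm.IRModule.from_expr(relay.Function([x], y)))
print(mod["main"].body.checked_type.shape)  # [1, 2, 5, 5, 4]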
+
 
 @script
 def _expand_dim_shape_func(data_shape, ndim, axis, num_newaxis):
@@ -494,6 +530,7 @@ def _expand_dim_shape_func(data_shape, ndim, axis, num_newaxis):
 
     return out
 
+
 @_reg.register_shape_func("expand_dims", False)
 def expand_dim_shape_func(attrs, inputs, _):
     """
@@ -504,10 +541,8 @@ def expand_dim_shape_func(attrs, inputs, _):
     if axis < 0:
         axis = inputs[0].shape[0] + axis + 1
     ndim = inputs[0].shape[0] if inputs[0].shape else 0
-    return [_expand_dim_shape_func(inputs[0],
-                                   convert(ndim),
-                                   convert(axis),
-                                   convert(num_newaxis))]
+    return [_expand_dim_shape_func(inputs[0], convert(ndim), convert(axis), convert(num_newaxis))]
+
 
 @script
 def _transpose_shape_func(data_shape, axes):
@@ -517,6 +552,7 @@ def _transpose_shape_func(data_shape, axes):
 
     return out
 
+
 @_reg.register_shape_func("transpose", False)
 def transpose_shape_func(attrs, inputs, _):
     """
@@ -532,6 +568,7 @@ def transpose_shape_func(attrs, inputs, _):
             axes[i] = inputs[0].shape[0] + axis
     return [_transpose_shape_func(inputs[0], convert(axes))]
 
+
 @script
 def _squeeze_shape_func(data_shape, keep_axes):
     out = output_tensor((len(keep_axes),), "int64")
@@ -540,6 +577,7 @@ def _squeeze_shape_func(data_shape, keep_axes):
 
     return out
 
+
 @_reg.register_shape_func("squeeze", False)
 def squeeze_shape_func(attrs, inputs, _):
     """
@@ -563,6 +601,7 @@ def squeeze_shape_func(attrs, inputs, _):
         out = te.compute((), lambda *indices: 0)
     return [out]
 
+
 @script
 def _reshape_like_shape_func(target_shape):
     out = output_tensor((target_shape.shape[0],), "int64")
@@ -571,6 +610,7 @@ def _reshape_like_shape_func(target_shape):
 
     return out
 
+
 @_reg.register_shape_func("reshape_like", False)
 def reshape_like_shape_func(attrs, inputs, _):
     """
@@ -578,6 +618,7 @@ def reshape_like_shape_func(attrs, inputs, _):
     """
     return [_reshape_like_shape_func(inputs[1])]
 
+
 @script
 def _tile_shape_func(data, reps, ndim, tndim, rndim):
     out = output_tensor((tndim,), "int64")
@@ -601,6 +642,7 @@ def _tile_shape_func(data, reps, ndim, tndim, rndim):
                 out[i] = int64(reps[i]) * data[i - rgap]
     return out
 
+
 @_reg.register_shape_func("tile", False)
 def tile_shape_func(attrs, inputs, _):
     """
@@ -610,8 +652,10 @@ def tile_shape_func(attrs, inputs, _):
     ndim = inputs[0].shape[0].value
     rndim = len(reps)
     tndim = ndim if ndim > rndim else rndim
-    return [_tile_shape_func(inputs[0], convert(reps), convert(ndim),
-                             convert(tndim), convert(rndim))]
+    return [
+        _tile_shape_func(inputs[0], convert(reps), convert(ndim), convert(tndim), convert(rndim))
+    ]
+
 
 @script
 def _split_shape_func(data_shape, index, indices_or_sections, axis):
@@ -619,8 +663,9 @@ def _split_shape_func(data_shape, index, indices_or_sections, axis):
     if len(indices_or_sections) == 1:
         for i in const_range(data_shape.shape[0]):
             if i == axis:
-                assert data_shape[axis] % indices_or_sections[0] == 0, \
-                    "num_sections must be an integer factor of the size of axis"
+                assert (
+                    data_shape[axis] % indices_or_sections[0] == 0
+                ), "num_sections must be an integer factor of the size of axis"
                 out[i] = ceil_div(data_shape[axis], indices_or_sections[0])
             else:
                 out[i] = data_shape[i]
@@ -638,6 +683,7 @@ def _split_shape_func(data_shape, index, indices_or_sections, axis):
                 out[i] = data_shape[i]
     return out
 
+
 @_reg.register_shape_func("split", False)
 def split_shape_func(attrs, inputs, _):
     """
@@ -648,20 +694,24 @@ def split_shape_func(attrs, inputs, _):
         assert indices_or_sections > 0, "Slice count must be > 0"
     else:
         indices_or_sections = list(get_const_tuple(attrs.indices_or_sections))
-        assert sorted(indices_or_sections)[0] > 0 and \
-               indices_or_sections == sorted(indices_or_sections), \
-            "split_indices must be sorted"
+        assert sorted(indices_or_sections)[0] > 0 and indices_or_sections == sorted(
+            indices_or_sections
+        ), "split_indices must be sorted"
 
     axis = get_const_int(attrs.axis)
 
-    num_out = indices_or_sections if isinstance(indices_or_sections, int) \
+    num_out = (
+        indices_or_sections
+        if isinstance(indices_or_sections, int)
         else len(indices_or_sections) + 1
+    )
     if isinstance(indices_or_sections, int):
         indices_or_sections = [indices_or_sections]
-    return [_split_shape_func(inputs[0],
-                              convert(i),
-                              convert(indices_or_sections),
-                              convert(axis)) for i in range(num_out)]
+    return [
+        _split_shape_func(inputs[0], convert(i), convert(indices_or_sections), convert(axis))
+        for i in range(num_out)
+    ]
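# NumPy analogue (illustrative only) of the two indices_or_sections cases
# handled above: an int gives that many equal slices, a sorted index list
# gives len(indices) + 1 slices.
import numpy as np

x = np.arange(12).reshape(6, 2)
print([p.shape for p in np.split(x, 3, axis=0)])       # 3 sections -> [(2, 2)] * 3
print([p.shape for p in np.split(x, [2, 5], axis=0)])  # indices [2, 5] -> (2, 2), (3, 2), (1, 2)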
+
 
 @script
 def _adv_index_shape_func(inputs):
@@ -687,6 +737,7 @@ def _adv_index_shape_func(inputs):
 
     return out
 
+
 @_reg.register_shape_func("adv_index", False)
 def adv_index_shape_func(attrs, inputs, _):
     """
index f3c35b8..e055054 100644
@@ -21,6 +21,7 @@ from . import _make
 from .dyn import _make as _dyn_make
 from ..expr import TupleWrapper, Expr, Constant
 
+
 def argsort(data, axis=-1, is_ascend=1, dtype="int32"):
     """Performs sorting along the given axis and returns an array of indicies
     having same shape as an input array that index data in sorted order.
@@ -50,8 +51,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="int32"):
     return _make.argsort(data, axis, is_ascend, dtype)
 
 
-def topk(data, k=1, axis=-1, ret_type="both",
-         is_ascend=False, dtype="int32"):
+def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"):
     """Get the top k elements in an input tensor along the given axis.
 
     ret_type specifies the return type, can be one of ("both", "values", "indices").
index 7bd5262..0ab1a0b 100644
@@ -43,8 +43,10 @@ def on_device(data, device):
     elif isinstance(device, str):
         device = _nd.context(device).device_type
     else:
-        raise ValueError("device is expected to be the type of TVMContext or "
-                         "str, but received %s" % (type(device)))
+        raise ValueError(
+            "device is expected to be the type of TVMContext or "
+            "str, but received %s" % (type(device))
+        )
     return _make.on_device(data, device)
 
 
@@ -79,6 +81,7 @@ def checkpoint(data):
     """
     return _make.checkpoint(data)
 
+
 reg.register_injective_schedule("annotation.checkpoint")
 
 
index adeeeb1..0c7df75 100644
@@ -55,11 +55,15 @@ def partition_for_arm_compute_lib(mod, params=None):
     ret : annotated and partitioned module.
     """
     if params:
-        mod['main'] = bind_params_by_name(mod['main'], params)
+        mod["main"] = bind_params_by_name(mod["main"], params)
 
-    seq = tvm.transform.Sequential([transform.MergeComposite(arm_compute_lib_pattern_table()),
-                                    transform.AnnotateTarget('arm_compute_lib'),
-                                    transform.PartitionGraph()])
+    seq = tvm.transform.Sequential(
+        [
+            transform.MergeComposite(arm_compute_lib_pattern_table()),
+            transform.AnnotateTarget("arm_compute_lib"),
+            transform.PartitionGraph(),
+        ]
+    )
 
     return seq(mod)
 
@@ -76,10 +80,10 @@ def arm_compute_lib_pattern_table():
         pattern : dataflow_pattern.AltPattern
             Denotes the convolution pattern.
         """
-        pattern = is_op('nn.pad')(wildcard()) | wildcard()
-        pattern = is_op('nn.conv2d')(pattern, is_constant())
-        pattern = pattern.optional(lambda x: is_op('nn.bias_add')(x, is_constant()))
-        pattern = pattern.optional(is_op('nn.relu'))
+        pattern = is_op("nn.pad")(wildcard()) | wildcard()
+        pattern = is_op("nn.conv2d")(pattern, is_constant())
+        pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
+        pattern = pattern.optional(is_op("nn.relu"))
         return pattern
 
     def qnn_conv_pattern():
@@ -90,13 +94,15 @@ def arm_compute_lib_pattern_table():
         pattern : dataflow_pattern.AltPattern
             Denotes the convolution pattern.
         """
-        pattern = is_op('nn.pad')(wildcard()) | wildcard()
-        pattern = is_op('qnn.conv2d')(
-            pattern, is_constant(), is_constant(), is_constant(), is_constant(), is_constant())
-        pattern = pattern.optional(lambda x: is_op('nn.bias_add')(x, is_constant()))
-        pattern = pattern.optional(is_op('nn.relu'))
-        pattern = is_op('qnn.requantize')(
-            pattern, wildcard(), wildcard(), is_constant(), is_constant())
+        pattern = is_op("nn.pad")(wildcard()) | wildcard()
+        pattern = is_op("qnn.conv2d")(
+            pattern, is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
+        )
+        pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
+        pattern = pattern.optional(is_op("nn.relu"))
+        pattern = is_op("qnn.requantize")(
+            pattern, wildcard(), wildcard(), is_constant(), is_constant()
+        )
         return pattern
 
     def dense_pattern():
@@ -107,8 +113,8 @@ def arm_compute_lib_pattern_table():
         pattern : dataflow_pattern.AltPattern
             Denotes the dense pattern.
         """
-        pattern = is_op('nn.dense')(wildcard(), is_constant())
-        pattern = pattern.optional(lambda x: is_op('nn.bias_add')(x, is_constant()))
+        pattern = is_op("nn.dense")(wildcard(), is_constant())
+        pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
         return pattern
 
     def qnn_dense_pattern():
@@ -119,11 +125,13 @@ def arm_compute_lib_pattern_table():
         pattern : dataflow_pattern.AltPattern
             Denotes the qnn.dense pattern.
         """
-        pattern = is_op('qnn.dense')(
-            wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant())
-        pattern = pattern.optional(lambda x: is_op('nn.bias_add')(x, is_constant()))
-        pattern = is_op('qnn.requantize')(
-            pattern, wildcard(), wildcard(), is_constant(), is_constant())
+        pattern = is_op("qnn.dense")(
+            wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
+        )
+        pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
+        pattern = is_op("qnn.requantize")(
+            pattern, wildcard(), wildcard(), is_constant(), is_constant()
+        )
         return pattern
 
     def avg_pool2d_pattern():
@@ -135,9 +143,9 @@ def arm_compute_lib_pattern_table():
         pattern : dataflow_pattern.AltPattern
             Denotes the average pooling pattern.
         """
-        pattern = is_op('cast')(wildcard())
-        pattern = is_op('nn.avg_pool2d')(pattern) | is_op('nn.global_avg_pool2d')(pattern)
-        pattern = is_op('cast')(pattern)
+        pattern = is_op("cast")(wildcard())
+        pattern = is_op("nn.avg_pool2d")(pattern) | is_op("nn.global_avg_pool2d")(pattern)
+        pattern = is_op("cast")(pattern)
         return pattern
 
     def l2_pool2d_pattern():
@@ -148,9 +156,9 @@ def arm_compute_lib_pattern_table():
         pattern : dataflow_pattern.AltPattern
             Denotes the L2 pooling pattern.
         """
-        pattern = is_op('power')(wildcard(), is_expr(const(2.0)))
-        pattern = is_op('nn.avg_pool2d')(pattern)
-        pattern = is_op('sqrt')(pattern)
+        pattern = is_op("power")(wildcard(), is_expr(const(2.0)))
+        pattern = is_op("nn.avg_pool2d")(pattern)
+        pattern = is_op("sqrt")(pattern)
         return pattern
 
     def check_conv(extract):
@@ -199,13 +207,15 @@ def arm_compute_lib_pattern_table():
         pool = extract.args[0]
         return avg_pool2d(pool.attrs, pool.args)
 
-    return [('arm_compute_lib.conv2d', conv_pattern(), check_conv),
-            ('arm_compute_lib.qnn_conv2d', qnn_conv_pattern(), check_qnn_conv),
-            ('arm_compute_lib.dense', dense_pattern(), check_dense),
-            ('arm_compute_lib.qnn_dense', qnn_dense_pattern(), check_qnn_dense),
-            ('arm_compute_lib.qnn_conv2d', qnn_conv_pattern(), check_qnn_conv),
-            ('arm_compute_lib.avg_pool2d', avg_pool2d_pattern(), check_avg_pool2d),
-            ('arm_compute_lib.l2_pool2d', l2_pool2d_pattern(), check_l2_pool2d)]
+    return [
+        ("arm_compute_lib.conv2d", conv_pattern(), check_conv),
+        ("arm_compute_lib.qnn_conv2d", qnn_conv_pattern(), check_qnn_conv),
+        ("arm_compute_lib.dense", dense_pattern(), check_dense),
+        ("arm_compute_lib.qnn_dense", qnn_dense_pattern(), check_qnn_dense),
+        ("arm_compute_lib.qnn_conv2d", qnn_conv_pattern(), check_qnn_conv),
+        ("arm_compute_lib.avg_pool2d", avg_pool2d_pattern(), check_avg_pool2d),
+        ("arm_compute_lib.l2_pool2d", l2_pool2d_pattern(), check_l2_pool2d),
+    ]
 
 
 def _register_external_op_helper(op_name, supported=True):
index dc14c2a..105009a 100644
@@ -30,11 +30,12 @@ def _register_coreml_op(op_name):
         The name of operator that will be registered.
 
     """
+
     def _check_supported(attrs, args):
-        if op_name == 'nn.conv2d':
+        if op_name == "nn.conv2d":
             if not isinstance(args[1], Constant):
                 return False
-            if attrs['kernel_layout'] not in ['HWIO', 'OIHW']:
+            if attrs["kernel_layout"] not in ["HWIO", "OIHW"]:
                 return False
         return True
 
index 27574a8..816cb38 100644
@@ -51,6 +51,7 @@ def _register_external_op_helper(op_name, supported=True):
     f : callable
         A function that returns whether the operator is supported by DNNL.
     """
+
     @tvm.ir.register_op_attr(op_name, "target.dnnl")
     def _func_wrapper(attrs, args):
         return supported
@@ -71,12 +72,12 @@ def make_pattern(with_bias=True):
     data = wildcard()
     weight = wildcard()
     bias = wildcard()
-    conv = is_op('nn.conv2d')(data, weight)
+    conv = is_op("nn.conv2d")(data, weight)
     if with_bias:
-        conv_out = is_op('add')(conv, bias)
+        conv_out = is_op("add")(conv, bias)
     else:
         conv_out = conv
-    return is_op('nn.relu')(conv_out)
+    return is_op("nn.relu")(conv_out)
 
 
 @register_pattern_table("dnnl")
index a93b0e5..213f4d3 100644
@@ -45,13 +45,16 @@ def ethosn_available():
 @register_pattern_table("ethos-n")
 def pattern_table():
     """Get the Ethos-N compiler pattern table."""
+
     def qnn_conv_pattern():
-        pattern = is_op('nn.pad')(wildcard()) | wildcard()
-        pattern = is_op('qnn.conv2d')(
-            pattern, is_constant(), is_constant(), is_constant(), is_constant(), is_constant())
-        pattern = is_op('nn.bias_add')(pattern, is_constant())
-        pattern = is_op('qnn.requantize')(
-            pattern, is_constant(), is_constant(), is_constant(), is_constant())
+        pattern = is_op("nn.pad")(wildcard()) | wildcard()
+        pattern = is_op("qnn.conv2d")(
+            pattern, is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
+        )
+        pattern = is_op("nn.bias_add")(pattern, is_constant())
+        pattern = is_op("qnn.requantize")(
+            pattern, is_constant(), is_constant(), is_constant(), is_constant()
+        )
         return pattern
 
     def check_conv2d(extract):
@@ -88,7 +91,7 @@ def qnn_concatenate(attrs, args):
         qnn_params.append((scale, zero_point))
 
     scale = (max_range - min_range) / 255
-    zero_point = int(-min_range/scale)
+    zero_point = int(-min_range / scale)
     if (scale, zero_point) in qnn_params:
         return True
 
@@ -102,13 +105,13 @@ def split(attrs, args):
         return False
 
     if isinstance(attrs["indices_or_sections"], tvm.tir.IntImm):
-        sp = tvm.relay.split(*args,
-                             indices_or_sections=attrs["indices_or_sections"].value,
-                             axis=attrs["axis"])
+        sp = tvm.relay.split(
+            *args, indices_or_sections=attrs["indices_or_sections"].value, axis=attrs["axis"]
+        )
     else:
-        sp = tvm.relay.split(*args,
-                             indices_or_sections=attrs["indices_or_sections"],
-                             axis=attrs["axis"])
+        sp = tvm.relay.split(
+            *args, indices_or_sections=attrs["indices_or_sections"], axis=attrs["axis"]
+        )
     if not support.split(sp.astuple()):
         return False
 
index b82abdb..278a311 100644
@@ -37,10 +37,12 @@ def register_pattern_table(compiler, table=None):
     fregister : function
         Register function if value is not specified.
     """
+
     def _register(t):
         """internal register function"""
         _PATTERN_TABLES[compiler] = t()
         return t
+
     return _register(table) if table is not None else _register
 
 
index c6dbca3..45bab2b 100644
@@ -21,4 +21,4 @@ from . import _algorithm
 from . import _transform
 from . import _tensor
 
-from .import image
+from . import image
index b98b775..ba903e6 100644
@@ -30,6 +30,7 @@ from ..op import register_strategy
 register_strategy("dyn.topk", strategy.topk_strategy)
 register_pattern("dyn.topk", OpPattern.OPAQUE)
 
+
 @script
 def _topk_shape_func_input_data(data, k, axis):
     ndim = len(data.shape)
@@ -49,6 +50,7 @@ def _topk_shape_func_input_data(data, k, axis):
                 indices_out[i] = int64(k[0])
     return val_out, indices_out
 
+
 @_reg.register_shape_func("dyn.topk", True)
 def topk_shape_func(attrs, inputs, _):
     """
@@ -57,8 +59,7 @@ def topk_shape_func(attrs, inputs, _):
     axis = attrs.axis
     if axis < 0:
         axis += len(inputs[0].shape)
-    val_out, indices_out = \
-        _topk_shape_func_input_data(inputs[0], inputs[1], convert(axis))
+    val_out, indices_out = _topk_shape_func_input_data(inputs[0], inputs[1], convert(axis))
 
     ret_type = attrs.ret_type
     if ret_type == "both":
index cd53641..5d5d555 100644
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=invalid-name, unused-argument, len-as-condition
+# pylint: disable=invalid-name, unused-argument, len-as-condition
 """Backend compiler related feature registration for dynamic ops"""
 
 from tvm import topi
@@ -30,14 +30,17 @@ def ones_compute(attrs, inputs, output_type):
     assert len(inputs) == 1
     return [topi.full(output_type.shape, output_type.dtype, 1.0)]
 
+
 register_broadcast_schedule("dyn.ones")
 register_pattern("dyn.ones", OpPattern.ELEMWISE)
 
+
 @register_compute("dyn.zeros")
 def zeros_compute(attrs, inputs, output_type):
     assert len(inputs) == 1
     return [topi.full(output_type.shape, output_type.dtype, 0.0)]
 
+
 register_broadcast_schedule("dyn.zeros")
 register_pattern("dyn.zeros", OpPattern.ELEMWISE)
 
index 6bf02ec..dedd3df 100644
@@ -29,10 +29,11 @@ _reg.register_injective_schedule("dyn.one_hot")
 _reg.register_injective_schedule("dyn.full")
 _reg.register_injective_schedule("dyn.strided_slice")
 
+
 @script
 def _reshape_shape_func_input_data(data, newshape, ndim):
-    out = output_tensor((ndim, ), "int64")
-    data_shape = allocate((len(data.shape), ), "int64")
+    out = output_tensor((ndim,), "int64")
+    data_shape = allocate((len(data.shape),), "int64")
     for x in const_range(len(data.shape)):
         data_shape[x] = int64(data.shape[x])
     src_idx = 0
@@ -60,8 +61,7 @@ def _reshape_shape_func_input_data(data, newshape, ndim):
         elif newshape[i] == -2:
             assert False, "Value -2 is not valid in newshape argument of dynamic reshape"
         elif newshape[i] == -3:
-            assert data_shape.shape[0] - src_idx > 1, \
-                "Not enough dims in input shape for -3"
+            assert data_shape.shape[0] - src_idx > 1, "Not enough dims in input shape for -3"
             out[dst_idx] = data_shape[src_idx] * data_shape[src_idx + 1]
             src_idx += 2
             dst_idx += 1
@@ -93,7 +93,7 @@ def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
 
 @script
 def _tile_shape_func(data, reps, ndim, tndim, rndim):
-    out = output_tensor((tndim, ), "int64")
+    out = output_tensor((tndim,), "int64")
 
     if ndim == rndim:
         for i in const_range(tndim):
@@ -130,7 +130,7 @@ def tile_shape_func(attrs, inputs, _):
 @script
 def _onehot_shape_func(dshape, k, axis):
     ndim = len(dshape) + 1
-    out = output_tensor((ndim, ), "int64")
+    out = output_tensor((ndim,), "int64")
     for i in const_range(axis):
         out[i] = int64(dshape[i])
     out[axis] = int64(k[0])
@@ -149,8 +149,7 @@ def one_hot_shape_func(attrs, inputs, _):
 
 
 @script
-def _strided_slice_shape_func_input_data(data, begin, end, strides,
-                                         slice_mode):
+def _strided_slice_shape_func_input_data(data, begin, end, strides, slice_mode):
     ndim = len(data.shape)
     out = output_tensor((ndim,), "int64")
     for i in const_range(ndim):
@@ -189,6 +188,7 @@ def _strided_slice_shape_func_input_data(data, begin, end, strides,
         out[i] = int64(ceil_div(slice_range, step))
     return out
 
+
 @_reg.register_shape_func("dyn.strided_slice", True)
 def strided_slice_shape_func(attrs, inputs, _):
     """
index 2d36708..cc00998 100644
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=invalid-name, unused-argument
+# pylint: disable=invalid-name, unused-argument
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
 
@@ -33,22 +33,25 @@ def compute_resize(attrs, inputs, out_type):
     coord_trans = attrs.coordinate_transformation_mode
     out_dtype = attrs.out_dtype
     return [
-        tvm.topi.image.resize(inputs[0], inputs[1], layout, method, coord_trans, out_dtype,
-                              out_type.shape)
+        tvm.topi.image.resize(
+            inputs[0], inputs[1], layout, method, coord_trans, out_dtype, out_type.shape
+        )
     ]
 
 
 reg.register_injective_schedule("dyn.image.resize")
 
+
 @script
 def _resize_shape_func(dshape, size, ndim, height_axis, width_axis):
-    out = output_tensor((ndim, ), "int64")
+    out = output_tensor((ndim,), "int64")
     for i in const_range(ndim):
         out[i] = int64(dshape[i])
     out[height_axis] = int64(size[0])
     out[width_axis] = int64(size[1])
     return out
 
+
 @reg.register_shape_func("dyn.image.resize", True)
 def resize_shape_func(attrs, inputs, _):
     """
@@ -56,8 +59,11 @@ def resize_shape_func(attrs, inputs, _):
     """
     layout = attrs.layout
     if nchw_pack_layout(layout) or nchw_xc_layout(layout):
-        out = [_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)),
-                                  convert(2), convert(3))]
+        out = [
+            _resize_shape_func(
+                inputs[0].shape, inputs[1], convert(len(inputs[0].shape)), convert(2), convert(3)
+            )
+        ]
     else:
         height_axis = width_axis = 1
         for i, letter in enumerate(layout):
@@ -65,6 +71,13 @@ def resize_shape_func(attrs, inputs, _):
                 height_axis = i
             if letter == "W":
                 width_axis = i
-        out = [_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)),
-                                  convert(height_axis), convert(width_axis))]
+        out = [
+            _resize_shape_func(
+                inputs[0].shape,
+                inputs[1],
+                convert(len(inputs[0].shape)),
+                convert(height_axis),
+                convert(width_axis),
+            )
+        ]
     return out
index 0cbc07e..7277151 100644
@@ -35,8 +35,10 @@ def compute_upsampling(attrs, inputs, out_dtype):
     layout = attrs.layout
     method = attrs.method
     align_corners = attrs.align_corners
-    return [topi.nn.upsampling(data, scale_h, scale_w, layout,
-                               method, align_corners, out_dtype.shape)]
+    return [
+        topi.nn.upsampling(data, scale_h, scale_w, layout, method, align_corners, out_dtype.shape)
+    ]
+
 
 # upsampling3d
 @register_compute("dyn.nn.upsampling3d")
@@ -48,8 +50,19 @@ def compute_upsampling3d(attrs, inputs, out_dtype):
     layout = attrs.layout
     method = attrs.method
     coordinate_transformation_mode = attrs.coordinate_transformation_mode
-    return [topi.nn.upsampling3d(data, scale_d, scale_h, scale_w, layout, method,\
-        coordinate_transformation_mode, out_dtype.shape)]
+    return [
+        topi.nn.upsampling3d(
+            data,
+            scale_d,
+            scale_h,
+            scale_w,
+            layout,
+            method,
+            coordinate_transformation_mode,
+            out_dtype.shape,
+        )
+    ]
+
 
 register_injective_schedule("dyn.nn.upsampling")
 register_injective_schedule("dyn.nn.upsampling3d")
@@ -69,6 +82,7 @@ def _upsampling_shape_func(dshape, scale_h, scale_w, height_axis, width_axis):
     out[width_axis] = int64(round(dshape[width_axis] * scale_w[0]))
     return out
 
+
 @register_shape_func("dyn.nn.upsampling", True)
 def upsampling_shape_func(attrs, inputs, _):
     """Shape function for upsampling. Supports NCHW and NHWC layouts."""
@@ -79,13 +93,18 @@ def upsampling_shape_func(attrs, inputs, _):
             height_axis = i
         if letter == "W":
             width_axis = i
-    return [_upsampling_shape_func(inputs[0].shape, inputs[1], inputs[2],
-                                   convert(height_axis), convert(width_axis))]
+    return [
+        _upsampling_shape_func(
+            inputs[0].shape, inputs[1], inputs[2], convert(height_axis), convert(width_axis)
+        )
+    ]
+
 
 # upsampling3d
 @script
-def _upsampling3d_shape_func(dshape, scale_d, scale_h, scale_w,
-                             depth_axis, height_axis, width_axis):
+def _upsampling3d_shape_func(
+    dshape, scale_d, scale_h, scale_w, depth_axis, height_axis, width_axis
+):
     out = output_tensor((5,), "int64")
     for i in const_range(5):
         out[i] = int64(dshape[i])
@@ -107,10 +126,18 @@ def upsampling3d_shape_func(attrs, inputs, _):
             height_axis = i
         if letter == "W":
             width_axis = i
-    return [_upsampling3d_shape_func(inputs[0].shape, inputs[1], inputs[2],
-                                     inputs[3], convert(depth_axis),
-                                     convert(height_axis),
-                                     convert(width_axis))]
+    return [
+        _upsampling3d_shape_func(
+            inputs[0].shape,
+            inputs[1],
+            inputs[2],
+            inputs[3],
+            convert(depth_axis),
+            convert(height_axis),
+            convert(width_axis),
+        )
+    ]
+
 
 # pad
 @script
@@ -121,6 +148,7 @@ def _dyn_pad_shape_func(data, pad_width):
         out[i] = int64(pad_width[i, 0] + pad_width[i, 1] + data.shape[i])
     return out
 
+
 @register_shape_func("dyn.nn.pad", True)
 def pad_shape_func(attrs, inputs, data):
     """
index 2cc3588..adbed84 100644
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=invalid-name, unused-argument
+# pylint: disable=invalid-name, unused-argument
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
 
@@ -38,6 +38,7 @@ def compute_resize(attrs, inputs, out_type):
     out_dtype = attrs.out_dtype
     return [topi.image.resize(inputs[0], size, layout, method, coord_trans, out_dtype)]
 
+
 reg.register_injective_schedule("image.resize")
 
 
@@ -50,6 +51,7 @@ def compute_resize3d(attrs, inputs, out_type):
     out_dtype = attrs.out_dtype
     return [topi.image.resize3d(inputs[0], size, layout, method, coord_trans, out_dtype)]
 
+
 reg.register_injective_schedule("image.resize3d")
 
 
@@ -61,15 +63,27 @@ def compute_crop_and_resize(attrs, inputs, out_type):
     method = attrs.method
     extrapolation_value = attrs.extrapolation_value
     out_dtype = attrs.out_dtype
-    return [topi.image.crop_and_resize(inputs[0], inputs[1], inputs[2],
-                                       crop_size, layout, method,
-                                       extrapolation_value, out_dtype)]
+    return [
+        topi.image.crop_and_resize(
+            inputs[0],
+            inputs[1],
+            inputs[2],
+            crop_size,
+            layout,
+            method,
+            extrapolation_value,
+            out_dtype,
+        )
+    ]
+
 
 reg.register_injective_schedule("image.crop_and_resize")
 
+
 @script
-def _crop_and_resize_func(image_shape, boxes_shape, crop_size,
-                          height_axis, width_axis, channel_axis):
+def _crop_and_resize_func(
+    image_shape, boxes_shape, crop_size, height_axis, width_axis, channel_axis
+):
     out = output_tensor((4,), "int64")
     out[0] = boxes_shape[0]
     out[height_axis] = int64(crop_size[0])
@@ -77,6 +91,7 @@ def _crop_and_resize_func(image_shape, boxes_shape, crop_size,
     out[channel_axis] = image_shape[channel_axis]
     return out
 
+
 @reg.register_shape_func("image.crop_and_resize", False)
 def crop_and_resize_func(attrs, inputs, _):
     """
@@ -92,8 +107,16 @@ def crop_and_resize_func(attrs, inputs, _):
         if letter == "C":
             channel_axis = i
     crop_size = get_const_tuple(attrs.crop_size)
-    return [_crop_and_resize_func(inputs[0], inputs[1], convert(crop_size),
-                                  convert(height_axis), convert(width_axis), convert(channel_axis))]
+    return [
+        _crop_and_resize_func(
+            inputs[0],
+            inputs[1],
+            convert(crop_size),
+            convert(height_axis),
+            convert(width_axis),
+            convert(channel_axis),
+        )
+    ]
 
 
 # dilation2d
@@ -107,6 +130,7 @@ def compute_affine_grid(attrs, inputs, out_dtype):
     target_shape = get_const_tuple(attrs.target_shape)
     return [topi.image.affine_grid(inputs[0], target_shape)]
 
+
 reg.register_injective_schedule("image.affine_grid")
 
 
@@ -117,4 +141,5 @@ def compute_grid_sample(attrs, inputs, out_dtype):
     layout = attrs.layout
     return [topi.image.grid_sample(inputs[0], inputs[1], method, layout)]
 
+
 reg.register_injective_schedule("image.grid_sample")
index 607e1d3..a3f3a3e 100644
@@ -20,12 +20,14 @@ from ..dyn.image import _make as _dyn_make
 from ...expr import Expr
 
 
-def resize(data,
-           size,
-           layout="NCHW",
-           method="bilinear",
-           coordinate_transformation_mode="half_pixel",
-           out_dtype=None):
+def resize(
+    data,
+    size,
+    layout="NCHW",
+    method="bilinear",
+    coordinate_transformation_mode="half_pixel",
+    out_dtype=None,
+):
     """Image resize operator.
 
     This operator takes data as input and does 2D scaling to the given scale factor.
@@ -65,17 +67,20 @@ def resize(data,
         The resized result.
     """
     if isinstance(size, Expr):
-        return _dyn_make.resize(data, size, layout, method, coordinate_transformation_mode,
-                                out_dtype)
+        return _dyn_make.resize(
+            data, size, layout, method, coordinate_transformation_mode, out_dtype
+        )
     return _make.resize(data, size, layout, method, coordinate_transformation_mode, out_dtype)
 
 
-def resize3d(data,
-             size,
-             layout="NCDHW",
-             method="trilinear",
-             coordinate_transformation_mode="half_pixel",
-             out_dtype=None):
+def resize3d(
+    data,
+    size,
+    layout="NCDHW",
+    method="trilinear",
+    coordinate_transformation_mode="half_pixel",
+    out_dtype=None,
+):
     """Image resize 3D operator.
 
     This operator takes data as input and does 3D scaling to the given scale factor.
@@ -116,14 +121,16 @@ def resize3d(data,
     return _make.resize3d(data, size, layout, method, coordinate_transformation_mode, out_dtype)
 
 
-def crop_and_resize(data,
-                    boxes,
-                    box_indices,
-                    crop_size,
-                    layout,
-                    method="bilinear",
-                    extrapolation_value=0,
-                    out_dtype=None):
+def crop_and_resize(
+    data,
+    boxes,
+    box_indices,
+    crop_size,
+    layout,
+    method="bilinear",
+    extrapolation_value=0,
+    out_dtype=None,
+):
     """Crop input images and resize them.
 
     method indicates the algorithm to be used while calculating the out value
@@ -162,18 +169,21 @@ def crop_and_resize(data,
     result: relay.Expr
         The computed result.
     """
-    return _make.crop_and_resize(data, boxes, box_indices, crop_size, layout, method,
-                                 extrapolation_value, out_dtype)
-
-
-def dilation2d(data,
-               weight,
-               strides=(1, 1),
-               padding=(0, 0),
-               dilations=(1, 1),
-               data_layout="NCHW",
-               kernel_layout="IHW",
-               out_dtype=""):
+    return _make.crop_and_resize(
+        data, boxes, box_indices, crop_size, layout, method, extrapolation_value, out_dtype
+    )
+
+
+def dilation2d(
+    data,
+    weight,
+    strides=(1, 1),
+    padding=(0, 0),
+    dilations=(1, 1),
+    data_layout="NCHW",
+    kernel_layout="IHW",
+    out_dtype="",
+):
     r"""Morphological Dilation 2D.
     This operator takes the weight as the dilation kernel and dilates it with
     data to produce an output. In the default case, where the data_layout is `NCHW`
@@ -219,8 +229,9 @@ def dilation2d(data,
         The computed result.
     """
 
-    return _make.dilation2d(data, weight, strides, padding, dilations, data_layout, kernel_layout,
-                            out_dtype)
+    return _make.dilation2d(
+        data, weight, strides, padding, dilations, data_layout, kernel_layout, out_dtype
+    )
 
 
 def affine_grid(data, target_shape=None):
@@ -246,7 +257,7 @@ def affine_grid(data, target_shape=None):
     return _make.affine_grid(data, target_shape)
 
 
-def grid_sample(data, grid, method='bilinear', layout='NCHW'):
+def grid_sample(data, grid, method="bilinear", layout="NCHW"):
     """Applies bilinear sampling to input feature map.
 
     Given :math:`data` and :math:`grid`, then the output is computed by
index b426a0e..57aa7e4 100644
@@ -19,7 +19,8 @@
 from __future__ import absolute_import as _abs
 from . import _make
 
-def alloc_tensor(storage, offset, shape, dtype='float32', assert_shape=None):
+
+def alloc_tensor(storage, offset, shape, dtype="float32", assert_shape=None):
     """Allocate a tensor with the provided shape, and dtype.
 
     Parameters
@@ -45,7 +46,8 @@ def alloc_tensor(storage, offset, shape, dtype='float32', assert_shape=None):
     """
     return _make.alloc_tensor(storage, offset, shape, dtype, assert_shape)
 
-def alloc_storage(size, alignment, ctx, dtype_hint='float32'):
+
+def alloc_storage(size, alignment, ctx, dtype_hint="float32"):
     """Allocate a piece of tensor storage.
 
     Parameters
@@ -64,6 +66,7 @@ def alloc_storage(size, alignment, ctx, dtype_hint='float32'):
     """
     return _make.alloc_storage(size, alignment, ctx, dtype_hint)
 
+
 def flatten_tuple_type(ty):
     """Return a sequence of the types contained in the tuple type in order.
 
@@ -79,6 +82,7 @@ def flatten_tuple_type(ty):
     """
     return _make.FlattenTupleType(ty)
 
+
 def from_tuple_type(ty, expr):
     """Convert an expression with the given type into a sequence of expressions.
        Each expression maps to a field of the tuple or nested tuples in linear
@@ -99,6 +103,7 @@ def from_tuple_type(ty, expr):
     """
     return _make.FromTupleType(ty, expr)
 
+
 def to_tuple_type(ty, exprs):
     """Pack the sequence of expressions into the nested tuple type.
 
index 02cf78d..df29d88 100644
@@ -50,9 +50,10 @@ reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # fifo_buffer
-@reg.register_compute('nn.fifo_buffer')
+@reg.register_compute("nn.fifo_buffer")
 def compute_fifo_buffer(attrs, inputs, out_type):
-    return [topi.nn.fifo_buffer(inputs[0], inputs[1], axis=attrs.get_int('axis'))]
+    return [topi.nn.fifo_buffer(inputs[0], inputs[1], axis=attrs.get_int("axis"))]
+
 
 reg.register_injective_schedule("nn.fifo_buffer")
 reg.register_pattern("nn.fifo_buffer", OpPattern.OPAQUE)
@@ -69,6 +70,7 @@ def compute_sparse_dense(attrs, inputs, out_type):
     """Compute definition of sparse_dense"""
     return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3])]
 
+
 reg.register_strategy("nn.sparse_dense", strategy.sparse_dense_strategy)
 reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
@@ -79,6 +81,7 @@ def compute_sparse_transpose(attrs, inputs, out_type):
     """Compute definition of sparse_transpose"""
     return topi.nn.sparse_transpose(inputs[0], inputs[1], inputs[2])
 
+
 reg.register_schedule("nn.sparse_transpose", strategy.schedule_sparse_transpose)
 reg.register_pattern("nn.sparse_transpose", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
@@ -92,11 +95,13 @@ reg.register_pattern("nn.conv1d", OpPattern.OUT_ELEMWISE_FUSABLE)
 reg.register_strategy("nn.conv2d", strategy.conv2d_strategy)
 reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
+
 @reg.register_alter_op_layout("nn.conv2d")
 def alter_op_layout_conv2d(attrs, inputs, tinfos, out_type):
     """Alternate the layout of conv2d"""
     return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)
 
+
 @reg.register_legalize("nn.conv2d")
 def legalize_conv2d(attrs, inputs, types):
     """Legalize conv2d op.
@@ -117,6 +122,7 @@ def legalize_conv2d(attrs, inputs, types):
     """
     return topi.nn.conv2d_legalize(attrs, inputs, types)
 
+
 @reg.register_convert_op_layout("nn.conv2d")
 def convert_conv2d(attrs, inputs, tinfos, desired_layouts):
     """Convert Layout pass registration for conv2d op.
@@ -140,30 +146,35 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts):
     """
     # pylint: disable=import-outside-toplevel
     from tvm import relay
+
     data, weight = inputs
     new_attrs = dict(attrs)
     assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv2d's inputs"
     desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
     assert desired_data_layout != "default", "Data layout cannot be default"
-    new_attrs['data_layout'] = desired_data_layout
+    new_attrs["data_layout"] = desired_data_layout
 
     if desired_kernel_layout != "default":
-        new_attrs['kernel_layout'] = desired_kernel_layout
+        new_attrs["kernel_layout"] = desired_kernel_layout
         return relay.nn.conv2d(data, weight, **new_attrs)
 
     # Handle default kernel layouts
-    if desired_data_layout == 'NCHW':
-        new_attrs['kernel_layout'] = 'OIHW'
+    if desired_data_layout == "NCHW":
+        new_attrs["kernel_layout"] = "OIHW"
         return relay.nn.conv2d(data, weight, **new_attrs)
-    elif desired_data_layout == 'NHWC':
+    elif desired_data_layout == "NHWC":
         # Check for depthwise convolution.
         data_info, weight_info = tinfos
-        if is_depthwise_conv2d(data_info.shape, attrs['data_layout'],
-                               weight_info.shape, attrs['kernel_layout'],
-                               attrs['groups']):
-            new_attrs['kernel_layout'] = 'HWOI'
+        if is_depthwise_conv2d(
+            data_info.shape,
+            attrs["data_layout"],
+            weight_info.shape,
+            attrs["kernel_layout"],
+            attrs["groups"],
+        ):
+            new_attrs["kernel_layout"] = "HWOI"
         else:
-            new_attrs['kernel_layout'] = 'HWIO'
+            new_attrs["kernel_layout"] = "HWIO"
         return relay.nn.conv2d(data, weight, **new_attrs)
 
     raise ValueError("Layout %s is not yet supported." % desired_data_layout)
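
The convert_conv2d hook above is the callback the ConvertLayout pass invokes; passing "default" for the kernel layout lets the hook choose OIHW, HWIO, or HWOI as in the branches above. A minimal sketch of driving it, with illustrative shapes (the Sequential/PassContext boilerplate follows the usual pattern and is not part of this diff):

import tvm
from tvm import relay

# An NHWC graph that we want rewritten to NCHW data layout.
data = relay.var("data", shape=(1, 224, 224, 3), dtype="float32")
weight = relay.var("weight", shape=(3, 3, 3, 16), dtype="float32")
conv = relay.nn.conv2d(
    data, weight, channels=16, kernel_size=(3, 3), data_layout="NHWC", kernel_layout="HWIO"
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

# "default" for the kernel layout lets the registered callback pick OIHW
# (or HWOI for depthwise convolutions), exactly as in the code above.
desired_layouts = {"nn.conv2d": ["NCHW", "default"]}
seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
print(mod)
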
@@ -194,6 +205,7 @@ def legalize_conv2d_transpose(attrs, inputs, types):
     """
     return topi.nn.conv2d_transpose_legalize(attrs, inputs, types)
 
+
 @reg.register_convert_op_layout("nn.conv2d_transpose")
 def convert_conv2d_transpose(attrs, inputs, tinfos, desired_layouts):
     """Convert Layout pass registration for conv2d_transpose op.
@@ -217,31 +229,34 @@ def convert_conv2d_transpose(attrs, inputs, tinfos, desired_layouts):
     """
     # pylint: disable=import-outside-toplevel
     from tvm import relay
+
     data, weight = inputs
     new_attrs = dict(attrs)
     assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv2d's inputs"
     desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
     assert desired_data_layout != "default", "Data layout cannot be default"
-    new_attrs['data_layout'] = desired_data_layout
+    new_attrs["data_layout"] = desired_data_layout
 
     if desired_kernel_layout != "default":
-        new_attrs['kernel_layout'] = desired_kernel_layout
+        new_attrs["kernel_layout"] = desired_kernel_layout
         return relay.nn.conv2d_transpose(data, weight, **new_attrs)
 
     # Handle default kernel layouts
-    if desired_data_layout == 'NCHW':
-        new_attrs['kernel_layout'] = 'OIHW'
+    if desired_data_layout == "NCHW":
+        new_attrs["kernel_layout"] = "OIHW"
         return relay.nn.conv2d_transpose(data, weight, **new_attrs)
-    elif desired_data_layout == 'NHWC':
-        new_attrs['kernel_layout'] = 'HWIO'
+    elif desired_data_layout == "NHWC":
+        new_attrs["kernel_layout"] = "HWIO"
         return relay.nn.conv2d_transpose(data, weight, **new_attrs)
 
     raise ValueError("Layout %s is not yet supported." % desired_data_layout)
 
+
 # conv3d_transpose
 reg.register_strategy("nn.conv3d_transpose", strategy.conv3d_transpose_strategy)
 reg.register_pattern("nn.conv3d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
 
+
 @reg.register_legalize("nn.conv3d_transpose")
 def legalize_conv3d_transpose(attrs, inputs, types):
     """Legalize conv3d_transpose op.
@@ -267,11 +282,13 @@ def legalize_conv3d_transpose(attrs, inputs, types):
 reg.register_strategy("nn.conv3d", strategy.conv3d_strategy)
 reg.register_pattern("nn.conv3d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
+
 @reg.register_alter_op_layout("nn.conv3d")
 def alter_op_layout_conv3d(attrs, inputs, tinfos, out_type):
     """Alternate the layout of conv3d"""
     return topi.nn.conv3d_alter_layout(attrs, inputs, tinfos, out_type)
 
+
 @reg.register_convert_op_layout("nn.conv3d")
 def convert_conv3d(attrs, inputs, tinfos, desired_layouts):
     """Convert Layout pass registration for conv3d op.
@@ -295,45 +312,51 @@ def convert_conv3d(attrs, inputs, tinfos, desired_layouts):
     """
     # pylint: disable=import-outside-toplevel
     from tvm import relay
+
     data, weight = inputs
     new_attrs = dict(attrs)
     assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv3d's inputs"
     desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
     assert desired_data_layout != "default", "Data layout cannot be default"
-    new_attrs['data_layout'] = desired_data_layout
+    new_attrs["data_layout"] = desired_data_layout
 
     if desired_kernel_layout != "default":
-        new_attrs['kernel_layout'] = desired_kernel_layout
+        new_attrs["kernel_layout"] = desired_kernel_layout
         return relay.nn.conv3d(data, weight, **new_attrs)
 
     # Handle default kernel layouts
-    if desired_data_layout == 'NCDHW':
-        new_attrs['kernel_layout'] = 'OIDHW'
+    if desired_data_layout == "NCDHW":
+        new_attrs["kernel_layout"] = "OIDHW"
         return relay.nn.conv3d(data, weight, **new_attrs)
     elif desired_data_layout == "NDHWC":
-        new_attrs['kernel_layout'] = 'DHWIO'
+        new_attrs["kernel_layout"] = "DHWIO"
         return relay.nn.conv3d(data, weight, **new_attrs)
 
     raise ValueError("Layout %s is not yet supported" % desired_data_layout)
 
 
 # conv3d_winograd related operators
-reg.register_strategy("nn.contrib_conv3d_winograd_without_weight_transform",
-                      strategy.conv3d_winograd_without_weight_transfrom_strategy)
-reg.register_pattern("nn.contrib_conv3d_winograd_without_weight_transform",
-                     OpPattern.OUT_ELEMWISE_FUSABLE)
+reg.register_strategy(
+    "nn.contrib_conv3d_winograd_without_weight_transform",
+    strategy.conv3d_winograd_without_weight_transfrom_strategy,
+)
+reg.register_pattern(
+    "nn.contrib_conv3d_winograd_without_weight_transform", OpPattern.OUT_ELEMWISE_FUSABLE
+)
+
 
 @reg.register_compute("nn.contrib_conv3d_winograd_weight_transform")
 def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype):
     """Compute definition of contrib_conv3d_winograd_weight_transform"""
-    out = topi.nn.conv3d_winograd_weight_transform(
-        inputs[0], attrs.get_int('tile_size'))
+    out = topi.nn.conv3d_winograd_weight_transform(inputs[0], attrs.get_int("tile_size"))
     return [out]
 
-reg.register_schedule("nn.contrib_conv3d_winograd_weight_transform",
-                      strategy.schedule_conv3d_winograd_weight_transform)
-reg.register_pattern("nn.contrib_conv3d_winograd_weight_transform",
-                     OpPattern.OUT_ELEMWISE_FUSABLE)
+
+reg.register_schedule(
+    "nn.contrib_conv3d_winograd_weight_transform",
+    strategy.schedule_conv3d_winograd_weight_transform,
+)
+reg.register_pattern("nn.contrib_conv3d_winograd_weight_transform", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # conv1d_transpose
@@ -436,8 +459,8 @@ reg.register_pattern("nn.batch_flatten", OpPattern.INJECTIVE)
 def compute_lrn(attrs, inputs, out_dtype):
     """Compute definition of lrn"""
     assert len(inputs) == 1
-    return [topi.nn.lrn(inputs[0], attrs.size, attrs.axis,
-                        attrs.alpha, attrs.beta, attrs.bias)]
+    return [topi.nn.lrn(inputs[0], attrs.size, attrs.axis, attrs.alpha, attrs.beta, attrs.bias)]
+
 
 reg.register_schedule("nn.lrn", strategy.schedule_lrn)
 reg.register_pattern("nn.lrn", OpPattern.OPAQUE)
@@ -453,6 +476,7 @@ def compute_upsampling(attrs, inputs, out_dtype):
     align_corners = attrs.align_corners
     return [topi.nn.upsampling(inputs[0], scale_h, scale_w, layout, method, align_corners)]
 
+
 reg.register_injective_schedule("nn.upsampling")
 
 
@@ -465,8 +489,12 @@ def compute_upsampling3d(attrs, inputs, out_dtype):
     layout = attrs.layout
     method = attrs.method
     coordinate_transformation_mode = attrs.coordinate_transformation_mode
-    return [topi.nn.upsampling3d(inputs[0], scale_d, scale_h, scale_w, layout, method,\
-        coordinate_transformation_mode)]
+    return [
+        topi.nn.upsampling3d(
+            inputs[0], scale_d, scale_h, scale_w, layout, method, coordinate_transformation_mode
+        )
+    ]
+
 
 reg.register_injective_schedule("nn.upsampling3d")
 
@@ -483,6 +511,7 @@ def compute_mirror_pad(attrs, inputs, out_dtype):
     out = topi.nn.mirror_pad(inputs[0], pad_before=pad_before, pad_after=pad_after, mode=mode)
     return [out]
 
+
 reg.register_broadcast_schedule("nn.mirror_pad")
 
 
@@ -493,6 +522,7 @@ def _mirror_pad_func(data_shape, pad_width):
         out[i] = data_shape[i] + int64(pad_width[i][0]) + int64(pad_width[i][1])
     return out
 
+
 @reg.register_shape_func("nn.mirror_pad", False)
 def mirror_pad_func(attrs, inputs, _):
     pad_width_tuple = [get_const_tuple(p) for p in attrs.pad_width]
@@ -500,65 +530,75 @@ def mirror_pad_func(attrs, inputs, _):
 
 
 # conv2d_winograd related operators
-reg.register_strategy("nn.contrib_conv2d_winograd_without_weight_transform",
-                      strategy.conv2d_winograd_without_weight_transfrom_strategy)
-reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform",
-                     OpPattern.OUT_ELEMWISE_FUSABLE)
+reg.register_strategy(
+    "nn.contrib_conv2d_winograd_without_weight_transform",
+    strategy.conv2d_winograd_without_weight_transfrom_strategy,
+)
+reg.register_pattern(
+    "nn.contrib_conv2d_winograd_without_weight_transform", OpPattern.OUT_ELEMWISE_FUSABLE
+)
 
 # conv2d_gemm related operators
-reg.register_strategy("nn.contrib_conv2d_gemm_without_weight_transform",
-                      strategy.conv2d_gemm_without_weight_transform_strategy)
-reg.register_pattern("nn.contrib_conv2d_gemm_without_weight_transform",
-                     OpPattern.OUT_ELEMWISE_FUSABLE)
+reg.register_strategy(
+    "nn.contrib_conv2d_gemm_without_weight_transform",
+    strategy.conv2d_gemm_without_weight_transform_strategy,
+)
+reg.register_pattern(
+    "nn.contrib_conv2d_gemm_without_weight_transform", OpPattern.OUT_ELEMWISE_FUSABLE
+)
+
 
 @reg.register_compute("nn.contrib_conv2d_gemm_weight_transform")
 def compute_contrib_conv2d_gemm_weight_transform(attrs, inputs, out_dtype):
     """Compute definition of contrib_conv2d_gemm_weight_transform"""
-    out = topi.nn.conv2d_gemm_weight_transform(
-        inputs[0], attrs.tile_rows, attrs.tile_cols)
+    out = topi.nn.conv2d_gemm_weight_transform(inputs[0], attrs.tile_rows, attrs.tile_cols)
     return [out]
 
-reg.register_schedule("nn.contrib_conv2d_gemm_weight_transform",
-                      strategy.schedule_conv2d_gemm_weight_transform)
-reg.register_pattern("nn.contrib_conv2d_gemm_weight_transform",
-                     OpPattern.OUT_ELEMWISE_FUSABLE)
+
+reg.register_schedule(
+    "nn.contrib_conv2d_gemm_weight_transform", strategy.schedule_conv2d_gemm_weight_transform
+)
+reg.register_pattern("nn.contrib_conv2d_gemm_weight_transform", OpPattern.OUT_ELEMWISE_FUSABLE)
+
 
 @reg.register_compute("nn.contrib_conv2d_winograd_weight_transform")
 def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype):
     """Compute definition of contrib_conv2d_winograd_weight_transform"""
-    out = topi.nn.conv2d_winograd_weight_transform(
-        inputs[0], attrs.get_int('tile_size'))
+    out = topi.nn.conv2d_winograd_weight_transform(inputs[0], attrs.get_int("tile_size"))
     return [out]
 
-reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform",
-                      strategy.schedule_conv2d_winograd_weight_transform)
-reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform",
-                     OpPattern.OUT_ELEMWISE_FUSABLE)
+
+reg.register_schedule(
+    "nn.contrib_conv2d_winograd_weight_transform",
+    strategy.schedule_conv2d_winograd_weight_transform,
+)
+reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform", OpPattern.OUT_ELEMWISE_FUSABLE)
+
 
 @reg.register_compute("nn.contrib_conv2d_winograd_nnpack_weight_transform")
 def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_dtype):
     """Compute definition of contrib_conv2d_winograd_nnpack_weight_transform"""
-    convolution_algorithm = attrs.get_int('convolution_algorithm')
+    convolution_algorithm = attrs.get_int("convolution_algorithm")
     out = topi.nn.conv2d_winograd_nnpack_weight_transform(
-        inputs[0], convolution_algorithm, out_dtype)
+        inputs[0], convolution_algorithm, out_dtype
+    )
     return [out]
 
-reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_weight_transform",
-                      strategy.schedule_conv2d_winograd_nnpack_weight_transform)
-reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_weight_transform",
-                     OpPattern.OPAQUE)
+
+reg.register_schedule(
+    "nn.contrib_conv2d_winograd_nnpack_weight_transform",
+    strategy.schedule_conv2d_winograd_nnpack_weight_transform,
+)
+reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_weight_transform", OpPattern.OPAQUE)
 
 
 # conv2d_NCHWc
 reg.register_strategy("nn.contrib_conv2d_NCHWc", strategy.conv2d_NCHWc_strategy)
-reg.register_pattern("nn.contrib_conv2d_NCHWc",
-                     OpPattern.OUT_ELEMWISE_FUSABLE)
+reg.register_pattern("nn.contrib_conv2d_NCHWc", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # depthwise_conv2d_NCHWc
-reg.register_strategy("nn.contrib_depthwise_conv2d_NCHWc",
-                      strategy.depthwise_conv2d_NCHWc_strategy)
-reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc",
-                     OpPattern.OUT_ELEMWISE_FUSABLE)
+reg.register_strategy("nn.contrib_depthwise_conv2d_NCHWc", strategy.depthwise_conv2d_NCHWc_strategy)
+reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # deformable_conv2d
@@ -578,6 +618,7 @@ def compute_bitpack(attrs, inputs, out_dtype):
     out = topi.nn.bitpack(inputs[0], bits, pack_axis, bit_axis, pack_type, name)
     return [out]
 
+
 reg.register_schedule("nn.bitpack", strategy.schedule_bitpack)
 reg.register_pattern("nn.bitpack", OpPattern.INJECTIVE)
 
@@ -586,6 +627,7 @@ reg.register_pattern("nn.bitpack", OpPattern.INJECTIVE)
 reg.register_strategy("nn.bitserial_conv2d", strategy.bitserial_conv2d_strategy)
 reg.register_pattern("nn.bitserial_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
+
 @reg.register_legalize("nn.bitserial_conv2d")
 def legalize_bitserial_conv2d(attrs, inputs, types):
     """Legalize bitserial_conv2d op.
@@ -618,6 +660,7 @@ def compute_cross_entropy(attrs, inputs, out_dtype):
     x, y = inputs
     return [-topi.sum(topi.log(x) * y) / x.shape[0]]
 
+
 reg.register_reduce_schedule("nn.cross_entropy")
 reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE)
 
@@ -627,6 +670,7 @@ reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE)
 def compute_dilate(attrs, inputs, out_dtype):
     return [topi.nn.dilate(inputs[0], attrs.strides)]
 
+
 reg.register_broadcast_schedule("nn.dilate")
 reg.register_pattern("nn.dilate", OpPattern.INJECTIVE)
 
@@ -637,6 +681,7 @@ def compute_cross_entropy_with_logits(attrs, inputs, out_dtype):
     x, y = inputs
     return [-topi.sum(x * y) / x.shape[0]]
 
+
 reg.register_reduce_schedule("nn.cross_entropy_with_logits")
 reg.register_pattern("nn.cross_entropy_with_logits", OpPattern.OPAQUE)
 
@@ -649,6 +694,7 @@ def compute_depth_to_space(attrs, inputs, out_dtype):
     mode = attrs.mode
     return [topi.nn.depth_to_space(inputs[0], block_size, layout=layout, mode=mode)]
 
+
 reg.register_injective_schedule("nn.depth_to_space")
 reg.register_pattern("nn.depth_to_space", OpPattern.INJECTIVE)
 
@@ -660,6 +706,7 @@ def compute_space_to_depth(attrs, inputs, out_dtype):
     layout = attrs.layout
     return [topi.nn.space_to_depth(inputs[0], block_size, layout=layout)]
 
+
 reg.register_injective_schedule("nn.space_to_depth")
 reg.register_pattern("nn.space_to_depth", OpPattern.INJECTIVE)
 
@@ -673,6 +720,7 @@ reg.register_pattern("nn.correlation", OpPattern.OUT_ELEMWISE_FUSABLE)
 #  Shape functions  #
 #####################
 
+
 @script
 def _conv2d_NCHWc_shape_func(dshape, kshape, strides, padding, dilation, oc_bn):
     out = output_tensor((dshape.shape[0],), "int64")
@@ -701,6 +749,7 @@ def _conv2d_NCHWc_shape_func(dshape, kshape, strides, padding, dilation, oc_bn):
     out[4] = int64(oc_bn)
     return out
 
+
 @reg.register_shape_func("nn.contrib_conv2d_NCHWc", False)
 def conv2d_NCHWc_shape_func(attrs, inputs, _):
     """
@@ -712,13 +761,20 @@ def conv2d_NCHWc_shape_func(attrs, inputs, _):
     out_layout = attrs.out_layout
     oc_bn = int(out_layout[4:-1])
 
-    return [_conv2d_NCHWc_shape_func(inputs[0], inputs[1],
-                                     convert(strides), convert(padding),
-                                     convert(dilation), convert(oc_bn))]
+    return [
+        _conv2d_NCHWc_shape_func(
+            inputs[0],
+            inputs[1],
+            convert(strides),
+            convert(padding),
+            convert(dilation),
+            convert(oc_bn),
+        )
+    ]
+
 
 @script
-def _pool2d_shape_func(data_shape, pool_size, strides,
-                       padding, height_axis, width_axis):
+def _pool2d_shape_func(data_shape, pool_size, strides, padding, height_axis, width_axis):
     out = output_tensor((data_shape.shape[0],), "int64")
     for i in const_range(data_shape.shape[0]):
         if i == height_axis:
@@ -730,6 +786,7 @@ def _pool2d_shape_func(data_shape, pool_size, strides,
 
     return out
 
+
 def pool2d_shape_func(attrs, inputs, _):
     """
     Shape function for pool2d op.
@@ -745,13 +802,22 @@ def pool2d_shape_func(attrs, inputs, _):
     elif len(padding) == 2:
         padding = [padding[0], padding[1], padding[0], padding[1]]
 
-    return [_pool2d_shape_func(inputs[0], convert(pool_size),
-                               convert(strides), convert(padding),
-                               convert(height_axis), convert(width_axis))]
+    return [
+        _pool2d_shape_func(
+            inputs[0],
+            convert(pool_size),
+            convert(strides),
+            convert(padding),
+            convert(height_axis),
+            convert(width_axis),
+        )
+    ]
+
 
 reg.register_shape_func("nn.max_pool2d", False, pool2d_shape_func)
 reg.register_shape_func("nn.avg_pool2d", False, pool2d_shape_func)
 
+
 @script
 def _global_pool2d_shape_func(data_shape, height_axis, width_axis):
     out = output_tensor((data_shape.shape[0],), "int64")
@@ -763,6 +829,7 @@ def _global_pool2d_shape_func(data_shape, height_axis, width_axis):
 
     return out
 
+
 def global_pool2d_shape_func(attrs, inputs, _):
     """
     Shape function for global pool2d op.
@@ -776,9 +843,11 @@ def global_pool2d_shape_func(attrs, inputs, _):
             width_axis = i
     return [_global_pool2d_shape_func(inputs[0], convert(height_axis), convert(width_axis))]
 
+
 reg.register_shape_func("nn.global_max_pool2d", False, global_pool2d_shape_func)
 reg.register_shape_func("nn.global_avg_pool2d", False, global_pool2d_shape_func)
 
+
 @script
 def _batch_flatten_shape_func(data_shape):
     out = output_tensor((2,), "int64")
@@ -789,6 +858,7 @@ def _batch_flatten_shape_func(data_shape):
 
     return out
 
+
 @reg.register_shape_func("nn.batch_flatten", False)
 def batch_flatten_shape_func(attrs, inputs, _):
     """
@@ -796,6 +866,7 @@ def batch_flatten_shape_func(attrs, inputs, _):
     """
     return [_batch_flatten_shape_func(inputs[0])]
 
+
 @script
 def _dense_shape_func(data_shape, weight_shape):
     out = output_tensor((data_shape.shape[0],), "int64")
@@ -805,6 +876,7 @@ def _dense_shape_func(data_shape, weight_shape):
 
     return out
 
+
 @reg.register_shape_func("nn.dense", False)
 def dense_shape_func(attrs, inputs, _):
     """
@@ -813,6 +885,7 @@ def dense_shape_func(attrs, inputs, _):
     ret = [_dense_shape_func(inputs[0], inputs[1])]
     return ret
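
Shape functions like _dense_shape_func above are what let the registered ops carry symbolic dimensions; a small illustrative sketch (not from this diff) with a dynamic batch size:

import tvm
from tvm import relay

# relay.Any() marks a dimension as unknown until runtime; the dense shape
# function registered above computes the concrete output shape when the
# module is compiled for the VM.
x = relay.var("x", shape=(relay.Any(), 784), dtype="float32")
w = relay.var("w", shape=(128, 784), dtype="float32")
y = relay.nn.dense(x, w)

mod = tvm.IRModule.from_expr(relay.Function([x, w], y))
mod = relay.transform.InferType()(mod)
print(mod)  # main should have return type Tensor[(?, 128), float32]
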
 
+
 @script
 def _pad_shape_func(data_shape, pad_width):
     out = output_tensor((data_shape.shape[0],), "int64")
@@ -821,6 +894,7 @@ def _pad_shape_func(data_shape, pad_width):
 
     return out
 
+
 @reg.register_shape_func("nn.pad", False)
 def pad_shape_func(attrs, inputs, _):
     """
@@ -831,6 +905,7 @@ def pad_shape_func(attrs, inputs, _):
         pad_width.append(get_const_tuple(pair))
     return [_pad_shape_func(inputs[0], convert(pad_width))]
 
+
 @script
 def _dilate_shape_func(data_shape, strides):
     out = output_tensor((data_shape.shape[0],), "int64")
@@ -839,6 +914,7 @@ def _dilate_shape_func(data_shape, strides):
 
     return out
 
+
 @reg.register_shape_func("nn.dilate", False)
 def dilate_shape_func(attrs, inputs, _):
     """
@@ -846,6 +922,7 @@ def dilate_shape_func(attrs, inputs, _):
     """
     return [_dilate_shape_func(inputs[0], convert(attrs.strides))]
 
+
 reg.register_shape_func("nn.bias_add", False, elemwise_shape_func)
 reg.register_shape_func("nn.softmax", False, elemwise_shape_func)
 reg.register_shape_func("nn.relu", False, elemwise_shape_func)
index 587dbc7..70dc776 100644
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=invalid-name, too-many-lines
+# pylint: disable=invalid-name, too-many-lines
 """Neural network operations."""
 from tvm.relay import expr
 
@@ -24,18 +24,20 @@ from .util import get_pad_tuple1d, get_pad_tuple2d, get_pad_tuple3d
 from ...expr import const, Expr
 
 
-def conv1d(data,
-           weight,
-           strides=1,
-           padding=0,
-           dilation=1,
-           groups=1,
-           channels=None,
-           kernel_size=None,
-           data_layout="NCW",
-           kernel_layout="OIW",
-           out_layout="",
-           out_dtype=""):
+def conv1d(
+    data,
+    weight,
+    strides=1,
+    padding=0,
+    dilation=1,
+    groups=1,
+    channels=None,
+    kernel_size=None,
+    data_layout="NCW",
+    kernel_layout="OIW",
+    out_layout="",
+    out_dtype="",
+):
     r"""1D convolution.
 
     This operator takes the weight as the convolution kernel
@@ -105,29 +107,42 @@ def conv1d(data,
         The computed result.
     """
     if isinstance(kernel_size, int):
-        kernel_size = (kernel_size, )
+        kernel_size = (kernel_size,)
     if isinstance(strides, int):
-        strides = (strides, )
+        strides = (strides,)
     if isinstance(dilation, int):
-        dilation = (dilation, )
+        dilation = (dilation,)
     padding = get_pad_tuple1d(padding)
-    return _make.conv1d(data, weight, strides, padding, dilation,
-                        groups, channels, kernel_size, data_layout,
-                        kernel_layout, out_layout, out_dtype)
-
-
-def conv2d(data,
-           weight,
-           strides=(1, 1),
-           padding=(0, 0),
-           dilation=(1, 1),
-           groups=1,
-           channels=None,
-           kernel_size=None,
-           data_layout="NCHW",
-           kernel_layout="OIHW",
-           out_layout="",
-           out_dtype=""):
+    return _make.conv1d(
+        data,
+        weight,
+        strides,
+        padding,
+        dilation,
+        groups,
+        channels,
+        kernel_size,
+        data_layout,
+        kernel_layout,
+        out_layout,
+        out_dtype,
+    )
+
+
+def conv2d(
+    data,
+    weight,
+    strides=(1, 1),
+    padding=(0, 0),
+    dilation=(1, 1),
+    groups=1,
+    channels=None,
+    kernel_size=None,
+    data_layout="NCHW",
+    kernel_layout="OIHW",
+    out_layout="",
+    out_dtype="",
+):
     r"""2D convolution.
 
     This operator takes the weight as the convolution kernel
@@ -205,23 +220,36 @@ def conv2d(data,
     # TODO enforce 4-way padding in topi/nn/conv2d after #4644 merged
     # convert 2-way padding to 4-way padding
     padding = get_pad_tuple2d(padding)
-    return _make.conv2d(data, weight, strides, padding, dilation,
-                        groups, channels, kernel_size, data_layout,
-                        kernel_layout, out_layout, out_dtype)
-
-
-def conv3d(data,
-           weight,
-           strides=(1, 1, 1),
-           padding=(0, 0, 0),
-           dilation=(1, 1, 1),
-           groups=1,
-           channels=None,
-           kernel_size=None,
-           data_layout="NCDHW",
-           kernel_layout="OIDHW",
-           out_layout="",
-           out_dtype=""):
+    return _make.conv2d(
+        data,
+        weight,
+        strides,
+        padding,
+        dilation,
+        groups,
+        channels,
+        kernel_size,
+        data_layout,
+        kernel_layout,
+        out_layout,
+        out_dtype,
+    )
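
A minimal usage sketch of the conv2d wrapper whose signature was reflowed above (shapes are illustrative, not taken from this diff):

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 3, 224, 224), dtype="float32")
weight = relay.var("weight", shape=(64, 3, 7, 7), dtype="float32")

# NCHW/OIHW defaults; the 2-way padding is expanded to 4-way by get_pad_tuple2d above.
out = relay.nn.conv2d(
    data, weight, strides=(2, 2), padding=(3, 3), channels=64, kernel_size=(7, 7)
)
print(tvm.IRModule.from_expr(relay.Function([data, weight], out)))
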
+
+
+def conv3d(
+    data,
+    weight,
+    strides=(1, 1, 1),
+    padding=(0, 0, 0),
+    dilation=(1, 1, 1),
+    groups=1,
+    channels=None,
+    kernel_size=None,
+    data_layout="NCDHW",
+    kernel_layout="OIDHW",
+    out_layout="",
+    out_dtype="",
+):
     r"""3D convolution.
 
     This operator takes the weight as the convolution kernel
@@ -297,24 +325,37 @@ def conv3d(data,
     if isinstance(dilation, int):
         dilation = (dilation, dilation, dilation)
     padding = get_pad_tuple3d(padding)
-    return _make.conv3d(data, weight, strides, padding, dilation,
-                        groups, channels, kernel_size, data_layout,
-                        kernel_layout, out_layout, out_dtype)
-
-
-def contrib_conv3d_winograd_without_weight_transform(data,
-                                                     weight,
-                                                     tile_size,
-                                                     strides=(1, 1, 1),
-                                                     padding=(0, 0, 0),
-                                                     dilation=(1, 1, 1),
-                                                     groups=1,
-                                                     channels=None,
-                                                     kernel_size=None,
-                                                     data_layout="NCDHW",
-                                                     kernel_layout="OIDHW",
-                                                     out_layout="",
-                                                     out_dtype=""):
+    return _make.conv3d(
+        data,
+        weight,
+        strides,
+        padding,
+        dilation,
+        groups,
+        channels,
+        kernel_size,
+        data_layout,
+        kernel_layout,
+        out_layout,
+        out_dtype,
+    )
+
+
+def contrib_conv3d_winograd_without_weight_transform(
+    data,
+    weight,
+    tile_size,
+    strides=(1, 1, 1),
+    padding=(0, 0, 0),
+    dilation=(1, 1, 1),
+    groups=1,
+    channels=None,
+    kernel_size=None,
+    data_layout="NCDHW",
+    kernel_layout="OIDHW",
+    out_layout="",
+    out_dtype="",
+):
     r"""3D convolution with winograd algorithm.
 
     The basic parameters are the same as the ones in vanilla conv3d.
@@ -369,23 +410,37 @@ def contrib_conv3d_winograd_without_weight_transform(data,
     # convert 3-way padding to 6-way padding
     padding = get_pad_tuple3d(padding)
     return _make.contrib_conv3d_winograd_without_weight_transform(
-        data, weight, tile_size, strides, padding, dilation,
-        groups, channels, kernel_size, data_layout,
-        kernel_layout, out_layout, out_dtype)
-
-def conv3d_transpose(data,
-                     weight,
-                     strides=(1, 1, 1),
-                     padding=(0, 0, 0),
-                     dilation=(1, 1, 1),
-                     groups=1,
-                     channels=None,
-                     kernel_size=None,
-                     data_layout="NCDHW",
-                     kernel_layout="OIDHW",
-                     out_layout="",
-                     output_padding=(0, 0, 0),
-                     out_dtype=""):
+        data,
+        weight,
+        tile_size,
+        strides,
+        padding,
+        dilation,
+        groups,
+        channels,
+        kernel_size,
+        data_layout,
+        kernel_layout,
+        out_layout,
+        out_dtype,
+    )
+
+
+def conv3d_transpose(
+    data,
+    weight,
+    strides=(1, 1, 1),
+    padding=(0, 0, 0),
+    dilation=(1, 1, 1),
+    groups=1,
+    channels=None,
+    kernel_size=None,
+    data_layout="NCDHW",
+    kernel_layout="OIDHW",
+    out_layout="",
+    output_padding=(0, 0, 0),
+    out_dtype="",
+):
     r"""3D transpose convolution.
 
     Parameters
@@ -440,23 +495,38 @@ def conv3d_transpose(data,
         dilation = (dilation, dilation, dilation)
     padding = get_pad_tuple3d(padding)
 
-    return _make.conv3d_transpose(data, weight, strides, padding, dilation,
-                                  groups, channels, kernel_size, data_layout,
-                                  kernel_layout, out_layout, output_padding, out_dtype)
-
-def conv2d_transpose(data,
-                     weight,
-                     strides=(1, 1),
-                     padding=(0, 0),
-                     dilation=(1, 1),
-                     groups=1,
-                     channels=None,
-                     kernel_size=None,
-                     data_layout="NCHW",
-                     kernel_layout="OIHW",
-                     out_layout="",
-                     output_padding=(0, 0),
-                     out_dtype=""):
+    return _make.conv3d_transpose(
+        data,
+        weight,
+        strides,
+        padding,
+        dilation,
+        groups,
+        channels,
+        kernel_size,
+        data_layout,
+        kernel_layout,
+        out_layout,
+        output_padding,
+        out_dtype,
+    )
+
+
+def conv2d_transpose(
+    data,
+    weight,
+    strides=(1, 1),
+    padding=(0, 0),
+    dilation=(1, 1),
+    groups=1,
+    channels=None,
+    kernel_size=None,
+    data_layout="NCHW",
+    kernel_layout="OIHW",
+    out_layout="",
+    output_padding=(0, 0),
+    out_dtype="",
+):
     """Two dimensional transposed convolution operator.
 
     Parameters
@@ -507,24 +577,38 @@ def conv2d_transpose(data,
     """
     # convert 2-way padding to 4-way padding
     padding = get_pad_tuple2d(padding)
-    return _make.conv2d_transpose(data, weight, strides, padding, dilation,
-                                  groups, channels, kernel_size, data_layout,
-                                  kernel_layout, out_layout, output_padding, out_dtype)
-
-
-def conv1d_transpose(data,
-                     weight,
-                     strides=(1,),
-                     padding=(0,),
-                     dilation=(1,),
-                     groups=1,
-                     channels=None,
-                     kernel_size=None,
-                     data_layout="NCW",
-                     kernel_layout="OIW",
-                     out_layout="",
-                     output_padding=(0,),
-                     out_dtype=""):
+    return _make.conv2d_transpose(
+        data,
+        weight,
+        strides,
+        padding,
+        dilation,
+        groups,
+        channels,
+        kernel_size,
+        data_layout,
+        kernel_layout,
+        out_layout,
+        output_padding,
+        out_dtype,
+    )
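
A matching sketch for the transposed variant above; output_padding disambiguates the output size when strides are greater than one (all values illustrative):

from tvm import relay

x = relay.var("x", shape=(1, 32, 14, 14), dtype="float32")
w = relay.var("w", shape=(32, 32, 4, 4), dtype="float32")

# Upsamples 14x14 -> 28x28: (14 - 1) * 2 - 2 * 1 + 4 + 0 = 28.
y = relay.nn.conv2d_transpose(
    x, w, channels=32, kernel_size=(4, 4), strides=(2, 2), padding=(1, 1), output_padding=(0, 0)
)
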
+
+
+def conv1d_transpose(
+    data,
+    weight,
+    strides=(1,),
+    padding=(0,),
+    dilation=(1,),
+    groups=1,
+    channels=None,
+    kernel_size=None,
+    data_layout="NCW",
+    kernel_layout="OIW",
+    out_layout="",
+    output_padding=(0,),
+    out_dtype="",
+):
     """One dimensional transposed convolution operator.
 
     Parameters
@@ -573,9 +657,21 @@ def conv1d_transpose(data,
     result : tvm.relay.Expr
         The computed result.
     """
-    return _make.conv1d_transpose(data, weight, strides, padding, dilation,
-                                  groups, channels, kernel_size, data_layout,
-                                  kernel_layout, out_layout, output_padding, out_dtype)
+    return _make.conv1d_transpose(
+        data,
+        weight,
+        strides,
+        padding,
+        dilation,
+        groups,
+        channels,
+        kernel_size,
+        data_layout,
+        kernel_layout,
+        out_layout,
+        output_padding,
+        out_dtype,
+    )
 
 
 def softmax(data, axis=-1):
@@ -628,12 +724,7 @@ def log_softmax(data, axis=-1):
     return _make.log_softmax(data, axis)
 
 
-def max_pool1d(data,
-               pool_size=(1,),
-               strides=(1,),
-               padding=(0,),
-               layout="NCW",
-               ceil_mode=False):
+def max_pool1d(data, pool_size=(1,), strides=(1,), padding=(0,), layout="NCW", ceil_mode=False):
     r"""1D maximum pooling operator.
 
     This operator takes data as input and does 1D max value calculation
@@ -677,16 +768,12 @@ def max_pool1d(data,
     if isinstance(strides, int):
         strides = (strides,)
     padding = get_pad_tuple1d(padding)
-    return _make.max_pool1d(data, pool_size, strides, padding,
-                            layout, ceil_mode)
+    return _make.max_pool1d(data, pool_size, strides, padding, layout, ceil_mode)
 
 
-def max_pool2d(data,
-               pool_size=(1, 1),
-               strides=(1, 1),
-               padding=(0, 0),
-               layout="NCHW",
-               ceil_mode=False):
+def max_pool2d(
+    data, pool_size=(1, 1), strides=(1, 1), padding=(0, 0), layout="NCHW", ceil_mode=False
+):
     r"""2D maximum pooling operator.
 
     This operator takes data as input and does 2D max value calculation
@@ -738,15 +825,12 @@ def max_pool2d(data,
     if isinstance(strides, int):
         strides = (strides, strides)
     padding = get_pad_tuple2d(padding)
-    return _make.max_pool2d(data, pool_size, strides, padding,
-                            layout, ceil_mode)
-
-def max_pool3d(data,
-               pool_size=(1, 1, 1),
-               strides=(1, 1, 1),
-               padding=(0, 0, 0),
-               layout="NCDHW",
-               ceil_mode=False):
+    return _make.max_pool2d(data, pool_size, strides, padding, layout, ceil_mode)
+
+
+def max_pool3d(
+    data, pool_size=(1, 1, 1), strides=(1, 1, 1), padding=(0, 0, 0), layout="NCDHW", ceil_mode=False
+):
     r"""3D maximum pooling operator.
 
     This operator takes data as input and does 3D max value calculation
@@ -791,17 +875,18 @@ def max_pool3d(data,
     if isinstance(strides, int):
         strides = (strides, strides, strides)
     padding = get_pad_tuple3d(padding)
-    return _make.max_pool3d(data, pool_size, strides, padding,
-                            layout, ceil_mode)
-
-
-def avg_pool1d(data,
-               pool_size=(1,),
-               strides=(1,),
-               padding=(0,),
-               layout="NCW",
-               ceil_mode=False,
-               count_include_pad=False):
+    return _make.max_pool3d(data, pool_size, strides, padding, layout, ceil_mode)
+
+
+def avg_pool1d(
+    data,
+    pool_size=(1,),
+    strides=(1,),
+    padding=(0,),
+    layout="NCW",
+    ceil_mode=False,
+    count_include_pad=False,
+):
     r"""1D average pooling operator.
 
     This operator takes data as input and does 1D average value calculation
@@ -848,17 +933,18 @@ def avg_pool1d(data,
     if isinstance(strides, int):
         strides = (strides,)
     padding = get_pad_tuple1d(padding)
-    return _make.avg_pool1d(data, pool_size, strides, padding,
-                            layout, ceil_mode, count_include_pad)
-
-
-def avg_pool2d(data,
-               pool_size=(1, 1),
-               strides=(1, 1),
-               padding=(0, 0),
-               layout="NCHW",
-               ceil_mode=False,
-               count_include_pad=False):
+    return _make.avg_pool1d(data, pool_size, strides, padding, layout, ceil_mode, count_include_pad)
+
+
+def avg_pool2d(
+    data,
+    pool_size=(1, 1),
+    strides=(1, 1),
+    padding=(0, 0),
+    layout="NCHW",
+    ceil_mode=False,
+    count_include_pad=False,
+):
     r"""2D average pooling operator.
 
     This operator takes data as input and does 2D average value calculation
@@ -914,16 +1000,18 @@ def avg_pool2d(data,
     if isinstance(strides, int):
         strides = (strides, strides)
     padding = get_pad_tuple2d(padding)
-    return _make.avg_pool2d(data, pool_size, strides, padding,
-                            layout, ceil_mode, count_include_pad)
-
-def avg_pool3d(data,
-               pool_size=(1, 1, 1),
-               strides=(1, 1, 1),
-               padding=(0, 0, 0),
-               layout="NCDHW",
-               ceil_mode=False,
-               count_include_pad=False):
+    return _make.avg_pool2d(data, pool_size, strides, padding, layout, ceil_mode, count_include_pad)
+
+
+def avg_pool3d(
+    data,
+    pool_size=(1, 1, 1),
+    strides=(1, 1, 1),
+    padding=(0, 0, 0),
+    layout="NCDHW",
+    ceil_mode=False,
+    count_include_pad=False,
+):
     r"""3D average pooling operator.
 
     This operator takes data as input and does 3D average value calculation
@@ -971,16 +1059,12 @@ def avg_pool3d(data,
     if isinstance(strides, int):
         strides = (strides, strides, strides)
     padding = get_pad_tuple3d(padding)
-    return _make.avg_pool3d(data, pool_size, strides, padding,
-                            layout, ceil_mode, count_include_pad)
-
-def max_pool2d_grad(out_grad,
-                    data,
-                    pool_size=(1, 1),
-                    strides=(1, 1),
-                    padding=(0, 0),
-                    layout="NCHW",
-                    ceil_mode=False):
+    return _make.avg_pool3d(data, pool_size, strides, padding, layout, ceil_mode, count_include_pad)
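
Illustrative calls to the pooling wrappers reflowed above; count_include_pad applies only to the average-pooling family:

from tvm import relay

x = relay.var("x", shape=(1, 64, 56, 56), dtype="float32")

y = relay.nn.max_pool2d(x, pool_size=(3, 3), strides=(2, 2), padding=(1, 1))
z = relay.nn.avg_pool2d(y, pool_size=(7, 7), strides=(7, 7), count_include_pad=False)
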
+
+
+def max_pool2d_grad(
+    out_grad, data, pool_size=(1, 1), strides=(1, 1), padding=(0, 0), layout="NCHW", ceil_mode=False
+):
     r"""Gradient of 2D maximum pooling operator.
 
     This operator takes out_grad and data as input and calculates gradient of max_pool2d.
@@ -1013,17 +1097,19 @@ def max_pool2d_grad(out_grad,
     result : tvm.relay.Expr
         The computed result.
     """
-    return _make.max_pool2d_grad(out_grad, data, pool_size, strides, padding,
-                                 layout, ceil_mode)
-
-def avg_pool2d_grad(out_grad,
-                    data,
-                    pool_size=(1, 1),
-                    strides=(1, 1),
-                    padding=(0, 0),
-                    layout="NCHW",
-                    ceil_mode=False,
-                    count_include_pad=False):
+    return _make.max_pool2d_grad(out_grad, data, pool_size, strides, padding, layout, ceil_mode)
+
+
+def avg_pool2d_grad(
+    out_grad,
+    data,
+    pool_size=(1, 1),
+    strides=(1, 1),
+    padding=(0, 0),
+    layout="NCHW",
+    ceil_mode=False,
+    count_include_pad=False,
+):
     r"""Gradient of 2D average pooling operator.
 
     This operator takes out_grad and data as input and calculates gradient of avg_pool2d.
@@ -1059,11 +1145,12 @@ def avg_pool2d_grad(out_grad,
     result : tvm.relay.Expr
         The computed result.
     """
-    return _make.avg_pool2d_grad(out_grad, data, pool_size, strides, padding,
-                                 layout, ceil_mode, count_include_pad)
+    return _make.avg_pool2d_grad(
+        out_grad, data, pool_size, strides, padding, layout, ceil_mode, count_include_pad
+    )
+
 
-def global_max_pool2d(data,
-                      layout="NCHW"):
+def global_max_pool2d(data, layout="NCHW"):
     r"""2D global maximum pooling operator.
 
     This operator takes data as input and does 2D max value calculation
@@ -1096,8 +1183,8 @@ def global_max_pool2d(data,
     """
     return _make.global_max_pool2d(data, layout)
 
-def global_avg_pool2d(data,
-                      layout="NCHW"):
+
+def global_avg_pool2d(data, layout="NCHW"):
     r"""2D global average pooling operator.
 
     This operator takes data as input and does 2D average value calculation
@@ -1131,12 +1218,9 @@ def global_avg_pool2d(data,
     return _make.global_avg_pool2d(data, layout)
 
 
-def upsampling(data,
-               scale_h=1,
-               scale_w=1,
-               layout="NCHW",
-               method="nearest_neighbor",
-               align_corners=False):
+def upsampling(
+    data, scale_h=1, scale_w=1, layout="NCHW", method="nearest_neighbor", align_corners=False
+):
     """Upsampling.
 
     This operator takes data as input and does 2D scaling to the given scale factor.
@@ -1181,13 +1265,15 @@ def upsampling(data,
     return _make.upsampling(data, scale_h, scale_w, layout, method, align_corners)
 
 
-def upsampling3d(data,
-                 scale_d=1,
-                 scale_h=1,
-                 scale_w=1,
-                 layout="NCDHW",
-                 method="nearest_neighbor",
-                 coordinate_transformation_mode="half_pixel"):
+def upsampling3d(
+    data,
+    scale_d=1,
+    scale_h=1,
+    scale_w=1,
+    layout="NCDHW",
+    method="nearest_neighbor",
+    coordinate_transformation_mode="half_pixel",
+):
     """3D Upsampling.
 
     This operator takes data as input and does 3D scaling to the given scale factor.
@@ -1236,10 +1322,12 @@ def upsampling3d(data,
             scale_h = const(scale_h, "float64")
         if not isinstance(scale_w, Expr):
             scale_w = const(scale_w, "float64")
-        return _dyn_make.upsampling3d(data, scale_d, scale_h, scale_w, layout, method,
-                                      coordinate_transformation_mode)
-    return _make.upsampling3d(data, scale_d, scale_h, scale_w, layout, method,
-                              coordinate_transformation_mode)
+        return _dyn_make.upsampling3d(
+            data, scale_d, scale_h, scale_w, layout, method, coordinate_transformation_mode
+        )
+    return _make.upsampling3d(
+        data, scale_d, scale_h, scale_w, layout, method, coordinate_transformation_mode
+    )
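
A usage sketch for upsampling3d above; per the branches shown, passing a relay expression for any scale takes the dynamic (_dyn_make) path, while plain numbers take the static one (shapes illustrative):

from tvm import relay

x = relay.var("x", shape=(1, 8, 4, 16, 16), dtype="float32")  # NCDHW

# Static scales -> _make.upsampling3d
y = relay.nn.upsampling3d(x, scale_d=2, scale_h=2, scale_w=2, method="nearest_neighbor")

# An Expr scale -> the remaining scales are wrapped in const() and the
# dynamic _dyn_make.upsampling3d call shown above is used instead.
z = relay.nn.upsampling3d(x, scale_d=relay.const(2.0, "float64"), scale_h=2, scale_w=2)
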
 
 
 def batch_flatten(data):
@@ -1425,10 +1513,7 @@ def prelu(data, alpha, axis=1):
     return _make.prelu(data, alpha, axis)
 
 
-def pad(data,
-        pad_width,
-        pad_value=0,
-        pad_mode='constant'):
+def pad(data, pad_width, pad_value=0, pad_mode="constant"):
     r"""Padding
 
     This operator takes in a tensor and pads each axis by the specified
@@ -1452,7 +1537,7 @@ def pad(data,
     result : tvm.relay.Expr
         The computed result.
     """
-    if (isinstance(pad_width, Expr) or (isinstance(pad_value, Expr))):
+    if isinstance(pad_width, Expr) or (isinstance(pad_value, Expr)):
         if not isinstance(pad_width, Expr):
             pad_width = const(list(pad_width))
         if not isinstance(pad_value, Expr):
@@ -1480,9 +1565,7 @@ def dilate(data, strides):
     return _make.dilate(data, strides)
 
 
-def mirror_pad(data,
-               pad_width,
-               mode="SYMMETRIC"):
+def mirror_pad(data, pad_width, mode="SYMMETRIC"):
     r"""MirrorPadding
 
     This operator takes in a tensor and pads each axis by the specified
@@ -1506,7 +1589,7 @@ def mirror_pad(data,
     return _make.mirror_pad(data, pad_width, mode)
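
Illustrative calls to the padding operators above; pad_width follows the per-axis ((before, after), ...) convention from the docstrings:

from tvm import relay

x = relay.var("x", shape=(1, 3, 32, 32), dtype="float32")

p = relay.nn.pad(x, pad_width=((0, 0), (0, 0), (1, 1), (1, 1)), pad_value=0.0)
m = relay.nn.mirror_pad(x, pad_width=((0, 0), (0, 0), (2, 2), (2, 2)), mode="SYMMETRIC")
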
 
 
-def lrn(data, size=5, axis=1, bias=2, alpha=.00001, beta=0.75):
+def lrn(data, size=5, axis=1, bias=2, alpha=0.00001, beta=0.75):
     """This operator takes data as input and does local response normalization.
 
     Normalize the input in a local region across or within feature maps.
@@ -1616,15 +1699,9 @@ def dropout_raw(data, rate=0.5):
     return _make.dropout(data, rate)
 
 
-def batch_norm(data,
-               gamma,
-               beta,
-               moving_mean,
-               moving_var,
-               axis=1,
-               epsilon=1e-5,
-               center=True,
-               scale=True):
+def batch_norm(
+    data, gamma, beta, moving_mean, moving_var, axis=1, epsilon=1e-5, center=True, scale=True
+):
     r"""
     Batch normalization layer (Ioffe and Szegedy, 2014).
     Normalizes the input at each batch, i.e. applies a transformation
@@ -1704,25 +1781,13 @@ def batch_norm(data,
         new running mean (k-length vector),
         and new running variance (k-length vector)
     """
-    result = _make.batch_norm(data,
-                              gamma,
-                              beta,
-                              moving_mean,
-                              moving_var,
-                              axis,
-                              epsilon,
-                              center,
-                              scale)
+    result = _make.batch_norm(
+        data, gamma, beta, moving_mean, moving_var, axis, epsilon, center, scale
+    )
     return expr.TupleWrapper(result, 3)
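
batch_norm above returns a TupleWrapper over three results; a minimal sketch of pulling them apart (shapes illustrative):

from tvm import relay

x = relay.var("x", shape=(8, 16, 32, 32), dtype="float32")
gamma, beta = relay.var("gamma", shape=(16,)), relay.var("beta", shape=(16,))
mean, var = relay.var("mean", shape=(16,)), relay.var("var", shape=(16,))

bn = relay.nn.batch_norm(x, gamma, beta, mean, var, axis=1, epsilon=1e-5)
out = bn[0]       # normalized output
new_mean = bn[1]  # updated running mean
new_var = bn[2]   # updated running variance
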
 
 
-def instance_norm(data,
-                  gamma,
-                  beta,
-                  axis=1,
-                  epsilon=1e-5,
-                  center=True,
-                  scale=True):
+def instance_norm(data, gamma, beta, axis=1, epsilon=1e-5, center=True, scale=True):
     r"""
     Instance Normalization (Ulyanov and et al., 2016)
     Applies instance normalization to the n-dimensional input array.
@@ -1783,13 +1848,7 @@ def instance_norm(data,
     return _make.instance_norm(data, gamma, beta, axis, epsilon, center, scale)
 
 
-def layer_norm(data,
-               gamma,
-               beta,
-               axis=-1,
-               epsilon=1e-5,
-               center=True,
-               scale=True):
+def layer_norm(data, gamma, beta, axis=-1, epsilon=1e-5, center=True, scale=True):
     r"""
     Layer normalization (Lei Ba and et al., 2016).
     Applies layer normalization to the n-dimensional input array.
@@ -1841,14 +1900,7 @@ def layer_norm(data,
     return _make.layer_norm(data, gamma, beta, axis, epsilon, center, scale)
 
 
-def group_norm(data,
-               gamma,
-               beta,
-               num_groups,
-               axis=1,
-               epsilon=1e-5,
-               center=True,
-               scale=True):
+def group_norm(data, gamma, beta, num_groups, axis=1, epsilon=1e-5, center=True, scale=True):
     r"""
     Group normalization normalizes over group of channels for each training examples.
     We can say that, Group Norm is in between Instance Norm and Layer Norm. When we put
@@ -1934,6 +1986,7 @@ def batch_matmul(x, y):
     """
     return _make.batch_matmul(x, y)
 
+
 def sparse_dense(data, weight):
     r"""
     Computes the matrix multiplication of `data` and `weight`, where `data` is
@@ -1967,6 +2020,7 @@ def sparse_dense(data, weight):
     """
     return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)
 
+
 def sparse_transpose(x):
     r"""
     Computes the fast matrix transpose of x,
@@ -1993,22 +2047,24 @@ def sparse_transpose(x):
         Tuple of output sparse tensor (same shape and format as input),
         i.e. if CSR then output is in ([data, indices, indptr]) form
     """
-    return expr.TupleWrapper(
-        _make.sparse_transpose(x.data, x.indices, x.indptr), 3)
-
-def contrib_conv2d_winograd_without_weight_transform(data,
-                                                     weight,
-                                                     tile_size,
-                                                     strides=(1, 1),
-                                                     padding=(0, 0),
-                                                     dilation=(1, 1),
-                                                     groups=1,
-                                                     channels=None,
-                                                     kernel_size=None,
-                                                     data_layout="NCHW",
-                                                     kernel_layout="OIHW",
-                                                     out_layout="",
-                                                     out_dtype=""):
+    return expr.TupleWrapper(_make.sparse_transpose(x.data, x.indices, x.indptr), 3)
+
+
+def contrib_conv2d_winograd_without_weight_transform(
+    data,
+    weight,
+    tile_size,
+    strides=(1, 1),
+    padding=(0, 0),
+    dilation=(1, 1),
+    groups=1,
+    channels=None,
+    kernel_size=None,
+    data_layout="NCHW",
+    kernel_layout="OIHW",
+    out_layout="",
+    out_dtype="",
+):
     r"""2D convolution with winograd algorithm.
 
     The basic parameters are the same as the ones in vanilla conv2d.
@@ -2063,23 +2119,36 @@ def contrib_conv2d_winograd_without_weight_transform(data,
     # convert 2-way padding to 4-way padding
     padding = get_pad_tuple2d(padding)
     return _make.contrib_conv2d_winograd_without_weight_transform(
-        data, weight, tile_size, strides, padding, dilation,
-        groups, channels, kernel_size, data_layout,
-        kernel_layout, out_layout, out_dtype)
-
-
-def contrib_conv2d_gemm_without_weight_transform(data,
-                                                 weight,
-                                                 strides=(1, 1),
-                                                 padding=(0, 0),
-                                                 dilation=(1, 1),
-                                                 groups=1,
-                                                 channels=None,
-                                                 kernel_size=None,
-                                                 data_layout="NCHW",
-                                                 kernel_layout="OIHW",
-                                                 out_layout="",
-                                                 out_dtype=""):
+        data,
+        weight,
+        tile_size,
+        strides,
+        padding,
+        dilation,
+        groups,
+        channels,
+        kernel_size,
+        data_layout,
+        kernel_layout,
+        out_layout,
+        out_dtype,
+    )
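
The two contrib winograd ops exist so the weight transform can be hoisted out and precomputed (e.g. folded by FoldConstant); a hedged sketch of how they pair up, with illustrative tile_size and shapes:

from tvm import relay

data = relay.var("data", shape=(1, 64, 56, 56), dtype="float32")
weight = relay.var("weight", shape=(64, 64, 3, 3), dtype="float32")

# Transform the 3x3 kernel once ...
wt = relay.nn.contrib_conv2d_winograd_weight_transform(weight, tile_size=4)

# ... then convolve without redoing the transform on every call.
out = relay.nn.contrib_conv2d_winograd_without_weight_transform(
    data, wt, tile_size=4, channels=64, kernel_size=(3, 3), padding=(1, 1)
)
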
+
+
+def contrib_conv2d_gemm_without_weight_transform(
+    data,
+    weight,
+    strides=(1, 1),
+    padding=(0, 0),
+    dilation=(1, 1),
+    groups=1,
+    channels=None,
+    kernel_size=None,
+    data_layout="NCHW",
+    kernel_layout="OIHW",
+    out_layout="",
+    out_dtype="",
+):
     r"""2D convolution with gemm algorithm.
 
     The basic parameters are the same as the ones in vanilla conv2d.
@@ -2131,23 +2200,35 @@ def contrib_conv2d_gemm_without_weight_transform(data,
     # convert 2-way padding to 4-way padding
     padding = get_pad_tuple2d(padding)
     return _make.contrib_conv2d_gemm_without_weight_transform(
-        data, weight, strides, padding, dilation,
-        groups, channels, kernel_size, data_layout,
-        kernel_layout, out_layout, out_dtype)
-
-
-def contrib_conv2d_nchwc(data,
-                         kernel,
-                         strides=(1, 1),
-                         padding=(0, 0),
-                         dilation=(1, 1),
-                         groups=1,
-                         channels=None,
-                         kernel_size=None,
-                         data_layout="NCHW8c",
-                         kernel_layout="OIHW",
-                         out_layout="",
-                         out_dtype=""):
+        data,
+        weight,
+        strides,
+        padding,
+        dilation,
+        groups,
+        channels,
+        kernel_size,
+        data_layout,
+        kernel_layout,
+        out_layout,
+        out_dtype,
+    )
+
+
+def contrib_conv2d_nchwc(
+    data,
+    kernel,
+    strides=(1, 1),
+    padding=(0, 0),
+    dilation=(1, 1),
+    groups=1,
+    channels=None,
+    kernel_size=None,
+    data_layout="NCHW8c",
+    kernel_layout="OIHW",
+    out_layout="",
+    out_dtype="",
+):
     r"""Variant of 2D convolution.
 
     This operator takes the weight as the convolution kernel
@@ -2199,22 +2280,36 @@ def contrib_conv2d_nchwc(data,
     """
     # convert 2-way padding to 4-way padding
     padding = get_pad_tuple2d(padding)
-    return _make.contrib_conv2d_NCHWc(data, kernel, strides, padding, dilation,
-                                      groups, channels, kernel_size, data_layout,
-                                      kernel_layout, out_layout, out_dtype)
-
-def contrib_depthwise_conv2d_nchwc(data,
-                                   kernel,
-                                   strides=(1, 1),
-                                   padding=(0, 0),
-                                   dilation=(1, 1),
-                                   groups=1,
-                                   channels=None,
-                                   kernel_size=None,
-                                   data_layout="NCHW8c",
-                                   kernel_layout="OIHW",
-                                   out_layout="",
-                                   out_dtype=""):
+    return _make.contrib_conv2d_NCHWc(
+        data,
+        kernel,
+        strides,
+        padding,
+        dilation,
+        groups,
+        channels,
+        kernel_size,
+        data_layout,
+        kernel_layout,
+        out_layout,
+        out_dtype,
+    )
+
+
+def contrib_depthwise_conv2d_nchwc(
+    data,
+    kernel,
+    strides=(1, 1),
+    padding=(0, 0),
+    dilation=(1, 1),
+    groups=1,
+    channels=None,
+    kernel_size=None,
+    data_layout="NCHW8c",
+    kernel_layout="OIHW",
+    out_layout="",
+    out_dtype="",
+):
     r"""Variant of 2D depthwise convolution.
 
     This operator takes the weight as the depthwise convolution kernel
@@ -2266,13 +2361,23 @@ def contrib_depthwise_conv2d_nchwc(data,
     """
     # convert 2-way padding to 4-way padding
     padding = get_pad_tuple2d(padding)
-    return _make.contrib_depthwise_conv2d_NCHWc(data, kernel, strides, padding, dilation,
-                                                groups, channels, kernel_size, data_layout,
-                                                kernel_layout, out_layout, out_dtype)
-
-
-def contrib_conv2d_winograd_weight_transform(weight,
-                                             tile_size):
+    return _make.contrib_depthwise_conv2d_NCHWc(
+        data,
+        kernel,
+        strides,
+        padding,
+        dilation,
+        groups,
+        channels,
+        kernel_size,
+        data_layout,
+        kernel_layout,
+        out_layout,
+        out_dtype,
+    )
+
+
+def contrib_conv2d_winograd_weight_transform(weight, tile_size):
     r"""Weight Transformation part for 2D convolution with winograd algorithm.
 
     We separate this as a single op to enable pre-compute for inference.
@@ -2317,8 +2422,7 @@ def contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols):
     return _make.contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols)
 
 
-def contrib_conv3d_winograd_weight_transform(weight,
-                                             tile_size):
+def contrib_conv3d_winograd_weight_transform(weight, tile_size):
     r"""Weight Transformation part for 3D convolution with winograd algorithm.
 
     We separate this as a single op to enable pre-compute for inference.
@@ -2340,9 +2444,7 @@ def contrib_conv3d_winograd_weight_transform(weight,
     return _make.contrib_conv3d_winograd_weight_transform(weight, tile_size)
 
 
-def contrib_conv2d_winograd_nnpack_weight_transform(weight,
-                                                    convolution_algorithm,
-                                                    out_dtype=""):
+def contrib_conv2d_winograd_nnpack_weight_transform(weight, convolution_algorithm, out_dtype=""):
     r"""Weight Transformation part for 2D convolution with winograd algorithm.
 
     We separate this as a single op to enable pre-compute for inference.
@@ -2362,24 +2464,27 @@ def contrib_conv2d_winograd_nnpack_weight_transform(weight,
         The computed result.
     """
     return _make.contrib_conv2d_winograd_nnpack_weight_transform(
-        weight, convolution_algorithm, out_dtype)
-
-
-def deformable_conv2d(data,
-                      offset,
-                      weight,
-                      strides=(1, 1),
-                      padding=(0, 0),
-                      dilation=(1, 1),
-                      deformable_groups=1,
-                      groups=1,
-                      channels=None,
-                      kernel_size=None,
-                      data_layout='NCHW',
-                      kernel_layout='OIHW',
-                      out_layout='',
-                      out_dtype=''):
-    r""" Deformable 2d convolution.
+        weight, convolution_algorithm, out_dtype
+    )
+
+
+def deformable_conv2d(
+    data,
+    offset,
+    weight,
+    strides=(1, 1),
+    padding=(0, 0),
+    dilation=(1, 1),
+    deformable_groups=1,
+    groups=1,
+    channels=None,
+    kernel_size=None,
+    data_layout="NCHW",
+    kernel_layout="OIHW",
+    out_layout="",
+    out_dtype="",
+):
+    r"""Deformable 2d convolution.
 
     The deformable convolution operation is described in https://arxiv.org/abs/1703.06211
 
@@ -2435,17 +2540,25 @@ def deformable_conv2d(data,
     """
     # convert 2-way padding to 4-way padding
     padding = get_pad_tuple2d(padding)
-    return _make.deformable_conv2d(data, offset, weight, strides, padding, dilation,
-                                   deformable_groups, groups, channels, kernel_size, data_layout,
-                                   kernel_layout, out_layout, out_dtype)
-
-
-def bitpack(data,
-            bits=1,
-            pack_axis=1,
-            bit_axis=2,
-            pack_type="uint32",
-            name="BitPack"):
+    return _make.deformable_conv2d(
+        data,
+        offset,
+        weight,
+        strides,
+        padding,
+        dilation,
+        deformable_groups,
+        groups,
+        channels,
+        kernel_size,
+        data_layout,
+        kernel_layout,
+        out_layout,
+        out_dtype,
+    )
+
+
+def bitpack(data, bits=1, pack_axis=1, bit_axis=2, pack_type="uint32", name="BitPack"):
     """Tensor packing for bitserial operations.
 
     The values along the input tensor's pack_axis are quantized
@@ -2486,19 +2599,21 @@ def bitpack(data,
     return _make.bitpack(data, bits, pack_axis, bit_axis, pack_type, name)
 
 
-def bitserial_conv2d(data,
-                     weight,
-                     strides=(1, 1),
-                     padding=(0, 0),
-                     channels=None,
-                     kernel_size=(3, 3),
-                     activation_bits=1,
-                     weight_bits=1,
-                     data_layout='NCHW',
-                     kernel_layout='OIHW',
-                     pack_dtype='uint32',
-                     out_dtype='int16',
-                     unipolar=True):
+def bitserial_conv2d(
+    data,
+    weight,
+    strides=(1, 1),
+    padding=(0, 0),
+    channels=None,
+    kernel_size=(3, 3),
+    activation_bits=1,
+    weight_bits=1,
+    data_layout="NCHW",
+    kernel_layout="OIHW",
+    pack_dtype="uint32",
+    out_dtype="int16",
+    unipolar=True,
+):
     r"""2D convolution using bitserial computation.
 
     Parameters
@@ -2546,20 +2661,33 @@ def bitserial_conv2d(data,
     """
     # convert 2-way padding to 4-way padding
     padding = get_pad_tuple2d(padding)
-    return _make.bitserial_conv2d(data, weight, strides, padding, channels,
-                                  kernel_size, activation_bits, weight_bits,
-                                  data_layout, kernel_layout, pack_dtype,
-                                  out_dtype, unipolar)
-
-
-def bitserial_dense(data,
-                    weight,
-                    units=None,
-                    data_bits=1,
-                    weight_bits=1,
-                    pack_dtype='uint32',
-                    out_dtype='int16',
-                    unipolar=True):
+    return _make.bitserial_conv2d(
+        data,
+        weight,
+        strides,
+        padding,
+        channels,
+        kernel_size,
+        activation_bits,
+        weight_bits,
+        data_layout,
+        kernel_layout,
+        pack_dtype,
+        out_dtype,
+        unipolar,
+    )
+
+
+def bitserial_dense(
+    data,
+    weight,
+    units=None,
+    data_bits=1,
+    weight_bits=1,
+    pack_dtype="uint32",
+    out_dtype="int16",
+    unipolar=True,
+):
     """Bitserial Dense operator.
     Applies matrix multiplication of two quantized matrices
     using a fast bitserial algorithm.
@@ -2599,8 +2727,9 @@ def bitserial_dense(data,
     result : tvm.relay.Expr
         The computed result.
     """
-    return _make.bitserial_dense(data, weight, units, data_bits, weight_bits,
-                                 pack_dtype, out_dtype, unipolar)
+    return _make.bitserial_dense(
+        data, weight, units, data_bits, weight_bits, pack_dtype, out_dtype, unipolar
+    )
 
 
 def cross_entropy(predictions, targets):
@@ -2641,7 +2770,7 @@ def cross_entropy_with_logits(predictions, targets):
     return _make.cross_entropy_with_logits(predictions, targets)
 
 
-def depth_to_space(data, block_size, layout='NCHW', mode='DCR'):
+def depth_to_space(data, block_size, layout="NCHW", mode="DCR"):
     """Convert channels into spatial blocks.
 
     Parameters
@@ -2668,7 +2797,7 @@ def depth_to_space(data, block_size, layout='NCHW', mode='DCR'):
     return _make.depth_to_space(data, block_size, layout, mode)
 
 
-def space_to_depth(data, block_size, layout='NCHW'):
+def space_to_depth(data, block_size, layout="NCHW"):
     """Convert spatial blocks into channels.
 
     Parameters
@@ -2691,9 +2820,7 @@ def space_to_depth(data, block_size, layout='NCHW'):
     return _make.space_to_depth(data, block_size, layout)
 
 
-def adaptive_max_pool2d(data,
-                        output_size=None,
-                        layout="NCHW"):
+def adaptive_max_pool2d(data, output_size=None, layout="NCHW"):
     r"""2D adaptive max pooling operator. This operator is experimental.
 
     This operator takes data as input and does 2D max value calculation
@@ -2738,9 +2865,7 @@ def adaptive_max_pool2d(data,
     return _make.adaptive_max_pool2d(data, output_size, layout)
 
 
-def adaptive_avg_pool2d(data,
-                        output_size=None,
-                        layout="NCHW"):
+def adaptive_avg_pool2d(data, output_size=None, layout="NCHW"):
     r"""2D adaptive average pooling operator. This operator is experimental.
 
     This operator takes data as input and does 2D average value calculation
@@ -2785,9 +2910,7 @@ def adaptive_avg_pool2d(data,
     return _make.adaptive_avg_pool2d(data, output_size, layout)
 
 
-def adaptive_max_pool3d(data,
-                        output_size=None,
-                        layout="NCDHW"):
+def adaptive_max_pool3d(data, output_size=None, layout="NCDHW"):
     r"""3D adaptive max pooling operator. This operator is experimental.
 
     This operator takes data as input and does 3D max value calculation
@@ -2831,9 +2954,7 @@ def adaptive_max_pool3d(data,
     return _make.adaptive_max_pool3d(data, output_size, layout)
 
 
-def adaptive_avg_pool3d(data,
-                        output_size=None,
-                        layout="NCDHW"):
+def adaptive_avg_pool3d(data, output_size=None, layout="NCDHW"):
     r"""3D adaptive avg pooling operator. This operator is experimental.
 
     This operator takes data as input and does 3D avg value calculation
@@ -2877,8 +2998,7 @@ def adaptive_avg_pool3d(data,
     return _make.adaptive_avg_pool3d(data, output_size, layout)
 
 
-def global_max_pool3d(data,
-                      layout="NCDHW"):
+def global_max_pool3d(data, layout="NCDHW"):
     r"""3D global maximum pooling operator.
 
     This operator takes data as input and does 3D max value calculation
@@ -2911,8 +3031,7 @@ def global_max_pool3d(data,
     return _make.adaptive_max_pool3d(data, output_size, layout)
 
 
-def global_avg_pool3d(data,
-                      layout="NCDHW"):
+def global_avg_pool3d(data, layout="NCDHW"):
     r"""3D global average pooling operator.
 
     This operator takes data as input and does 3D average value calculation
@@ -2946,8 +3065,9 @@ def global_avg_pool3d(data,
     return _make.adaptive_avg_pool3d(data, output_size, layout)
 
 
-def correlation(data1, data2, kernel_size, max_displacement, stride1, stride2, padding,
-                is_multiply, layout):
+def correlation(
+    data1, data2, kernel_size, max_displacement, stride1, stride2, padding, is_multiply, layout
+):
     r"""Applies correlation to inputs.
 
     The correlation layer performs multiplicative patch comparisons between two feature maps.
@@ -3025,5 +3145,6 @@ def correlation(data1, data2, kernel_size, max_displacement, stride1, stride2, p
     """
     if isinstance(padding, int):
         padding = (padding, padding)
-    return _make.correlation(data1, data2, kernel_size, max_displacement, stride1, stride2,
-                             padding, is_multiply, layout)
+    return _make.correlation(
+        data1, data2, kernel_size, max_displacement, stride1, stride2, padding, is_multiply, layout
+    )
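
The hunks above collapse hand-wrapped keyword signatures and their `_make.*` calls into Black's vertical style. A minimal before/after sketch of that rewrite, using hypothetical names ("conv_variant", "make_conv") and assuming the roughly 100-character line length this repository appears to configure:

def make_conv(*args):  # stub standing in for a _make.* FFI call, purely illustrative
    return args

# Hand-wrapped style as it looked before this commit:
def conv_variant(data,
                 weight,
                 strides=(1, 1),
                 padding=(0, 0)):
    return make_conv(data, weight, strides,
                     padding)

# After running Black: short signatures collapse onto one line, while signatures longer
# than the configured line length are exploded one argument per line with a trailing
# comma, which is exactly the shape of the hunks above.
def conv_variant(data, weight, strides=(1, 1), padding=(0, 0)):
    return make_conv(data, weight, strides, padding)
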
index 7fad9a2..755659a 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=unused-argument,invalid-name
+# pylint: disable=unused-argument,invalid-name
 """The base node types for the Relay language."""
 import tvm._ffi
 import tvm.ir
@@ -48,6 +48,7 @@ class OpPattern(object):
     --------
     top.tag : Contains explanation of the tag type.
     """
+
     # Elementwise operator
     ELEMWISE = 0
     # Broadcast operator
@@ -67,6 +68,7 @@ class OpPattern(object):
 @tvm._ffi.register_object("relay.OpImplementation")
 class OpImplementation(Object):
     """Operator implementation"""
+
     def compute(self, attrs, inputs, out_type):
         """Call compute function.
 
@@ -118,6 +120,7 @@ class OpSpecialization(Object):
 @tvm._ffi.register_object("relay.OpStrategy")
 class OpStrategy(Object):
     """Operator strategy"""
+
     def __init__(self):
         self.__init_handle_by_constructor__(_make.OpStrategy)
 
@@ -147,6 +150,7 @@ def _wrap_default_fstrategy(compute, schedule, name):
         strategy = OpStrategy()
         strategy.add_implementation(compute, schedule, name=name)
         return strategy
+
     return _fstrategy
 
 
@@ -156,12 +160,12 @@ def _create_fstrategy_from_schedule(op_name, schedule):
     assert compute is not None, "FTVMCompute is not registered for op %s" % op_name
     fstrategy = get_native_generic_func("{}_strategy".format(op_name))
     name_pfx = schedule.__name__
-    name_pfx = name_pfx[name_pfx.index('_')+1:]
+    name_pfx = name_pfx[name_pfx.index("_") + 1 :]
     fstrategy.set_default(
-        _wrap_default_fstrategy(compute, schedule.fdefault, "%s.generic" % name_pfx))
+        _wrap_default_fstrategy(compute, schedule.fdefault, "%s.generic" % name_pfx)
+    )
     for key, sch in schedule.dispatch_dict.items():
-        fstrategy.register(
-            _wrap_default_fstrategy(compute, sch, "%s.%s" % (name_pfx, key)), [key])
+        fstrategy.register(_wrap_default_fstrategy(compute, sch, "%s.%s" % (name_pfx, key)), [key])
     return fstrategy
 
 
@@ -409,6 +413,7 @@ _schedule_reduce = None
 
 __DEBUG_COUNTER__ = 0
 
+
 def debug(expr, debug_func=None):
     """The main entry point to the debugger."""
     global __DEBUG_COUNTER__
@@ -418,8 +423,9 @@ def debug(expr, debug_func=None):
         tvm._ffi.register_func(name, debug_func)
         __DEBUG_COUNTER__ += 1
     else:
-        name = ''
+        name = ""
 
     return _make.debug(expr, name)
 
+
 tvm._ffi._init_api("relay.op", __name__)
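
One change above that is easy to misread as behavioral is the slice spacing in `_create_fstrategy_from_schedule`: Black writes extended slices with spaces around the colon when a bound is an expression, so `name_pfx[name_pfx.index('_')+1:]` becomes `name_pfx[name_pfx.index("_") + 1 :]` with identical semantics. A small illustration of the prefix derivation that line performs, assuming the usual `schedule_<op>` naming convention:

schedule_name = "schedule_conv2d_nchw"  # assumed example input

# Everything after the first underscore becomes the strategy name prefix.
name_pfx = schedule_name[schedule_name.index("_") + 1 :]
print(name_pfx)  # -> "conv2d_nchw", later combined into names like "conv2d_nchw.generic"
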
index 7f91989..5dc2c24 100644 (file)
@@ -83,18 +83,22 @@ class FIFOBufferAttrs(Attrs):
 class UpSamplingAttrs(Attrs):
     """Attributes for nn.upsampling"""
 
+
 @tvm._ffi.register_object("relay.attrs.UpSampling3DAttrs")
 class UpSampling3DAttrs(Attrs):
     """Attributes for nn.upsampling3d"""
 
+
 @tvm._ffi.register_object("relay.attrs.PadAttrs")
 class PadAttrs(Attrs):
     """Attributes for nn.pad"""
 
+
 @tvm._ffi.register_object("relay.attrs.MirrorPadAttrs")
 class MirrorPadAttrs(Attrs):
     """Attributes for nn.mirror_pad"""
 
+
 @tvm._ffi.register_object("relay.attrs.LeakyReluAttrs")
 class LeakyReluAttrs(Attrs):
     """Attributes for nn.leaky_relu"""
@@ -189,10 +193,12 @@ class TransposeAttrs(Attrs):
 class ReshapeAttrs(Attrs):
     """Attributes for transform.reshape"""
 
+
 @tvm._ffi.register_object("relay.attrs.GatherAttrs")
 class GatherAttrs(Attrs):
     """Attributes for transform.gather"""
 
+
 @tvm._ffi.register_object("relay.attrs.TakeAttrs")
 class TakeAttrs(Attrs):
     """Attributes for transform.take"""
@@ -232,6 +238,7 @@ class TileAttrs(Attrs):
 class ReverseAttrs(Attrs):
     """Attributes used in reverse operators"""
 
+
 @tvm._ffi.register_object("relay.attrs.ReverseSequenceAttrs")
 class ReverseSequenceAttrs(Attrs):
     """Attributes used in reverse sequence operators"""
@@ -361,10 +368,12 @@ class BinaryDenseAttrs(Attrs):
 class Conv2DTransposeAttrs(Attrs):
     """Attributes used in Transposed Conv2D operators"""
 
+
 @tvm._ffi.register_object("relay.attrs.Conv3DTransposeAttrs")
 class Conv3DTransposeAttrs(Attrs):
     """Attributes used in Transposed Conv3D operators"""
 
+
 @tvm._ffi.register_object("relay.attrs.DilateAttrs")
 class DilateAttrs(Attrs):
     """Attributes used in dilate operators"""
index 99189f8..368ffb5 100644 (file)
@@ -22,6 +22,7 @@ from .tensor import sqrt, log, exp
 from .transform import squeeze
 from ..expr import Tuple, TupleWrapper
 
+
 def argmax(data, axis=None, keepdims=False, exclude=False):
     """Returns the indices of the maximum values along an axis.
 
@@ -52,6 +53,7 @@ def argmax(data, axis=None, keepdims=False, exclude=False):
     axis = [axis] if isinstance(axis, int) else axis
     return _make.argmax(data, axis, keepdims, exclude)
 
+
 def argmin(data, axis=None, keepdims=False, exclude=False):
     """Returns the indices of the minimum values along an axis.
 
@@ -219,7 +221,7 @@ def any(data, axis=None, keepdims=False, exclude=False):
 
 
 def max(data, axis=None, keepdims=False, exclude=False):
-    """ Computes the max of array elements over given axes.
+    """Computes the max of array elements over given axes.
 
     Parameters
     ----------
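
The reduce hunks above only add the two blank lines Black requires between top-level functions and drop a stray space after the opening `"""`, but the wrappers they touch are user-facing. A hedged usage sketch of two of the reductions whose docstrings appear above, argmax and max (assumes a working tvm installation):

from tvm import relay

x = relay.var("x", shape=(2, 3), dtype="float32")
idx = relay.argmax(x, axis=1, keepdims=False, exclude=False)  # per-row index of the maximum
top = relay.max(x, axis=None)  # maximum over all elements
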
index 0c4edbb..1833dfe 100644 (file)
@@ -24,7 +24,8 @@ from ....target import arm_isa
 from .generic import *
 from .. import op as _op
 
-logger = logging.getLogger('strategy')
+logger = logging.getLogger("strategy")
+
 
 @schedule_reduce.register("arm_cpu")
 def schedule_reduce_cpu(attrs, outs, target):
@@ -32,18 +33,21 @@ def schedule_reduce_cpu(attrs, outs, target):
     with target:
         return topi.x86.schedule_reduce(outs)
 
+
 @schedule_injective.register(["arm_cpu", "micro_dev"])
 def schedule_injective_arm_cpu(_, outs, target):
     """schedule injective ops for arm cpu"""
     with target:
         return topi.arm_cpu.schedule_injective(outs)
 
+
 @schedule_concatenate.register(["arm_cpu", "micro_dev"])
 def schedule_concatenate_arm_cpu(_, outs, target):
     """schedule concatenate for arm cpu"""
     with target:
         return topi.arm_cpu.schedule_concatenate(outs)
 
+
 @conv2d_strategy.register(["arm_cpu", "micro_dev"])
 def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
     """conv2d arm cpu strategy"""
@@ -67,56 +71,69 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack),
                     wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_spatial_pack),
-                    name="conv2d_nchw_spatial_pack.arm_cpu")
+                    name="conv2d_nchw_spatial_pack.arm_cpu",
+                )
 
                 # Intel x86 conv2d schedule.
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.x86.conv2d_nchw),
                     wrap_topi_schedule(topi.x86.schedule_conv2d_nchw),
-                    name="conv2d_nchw.x86")
+                    name="conv2d_nchw.x86",
+                )
 
                 # check if winograd algorithm is applicable
                 _, _, kh, kw = get_const_tuple(kernel.shape)
                 pt, pl, pb, pr = topi.nn.get_pad_tuple(padding, (kh, kw))
-                is_winograd_applicable = "float" in data.dtype and \
-                                         "float" in kernel.dtype and \
-                                         kh == 3 and kw == 3 and \
-                                         stride_h == 1 and stride_w == 1 and \
-                                         dilation_h == 1 and dilation_w == 1
+                is_winograd_applicable = (
+                    "float" in data.dtype
+                    and "float" in kernel.dtype
+                    and kh == 3
+                    and kw == 3
+                    and stride_h == 1
+                    and stride_w == 1
+                    and dilation_h == 1
+                    and dilation_w == 1
+                )
                 if is_winograd_applicable:
                     strategy.add_implementation(
                         wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_winograd),
                         wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_winograd),
                         name="conv2d_nchw_winograd.arm_cpu",
-                        plevel=5)
+                        plevel=5,
+                    )
                     if "nnpack" in target.libs and pt == 1 and pb == 1 and pl == 1 and pr == 1:
                         strategy.add_implementation(
                             wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_winograd_nnpack),
                             wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_winograd_nnpack),
                             name="conv2d_nchw_winograd_nnpack.arm_cpu",
-                            plevel=15)
+                            plevel=15,
+                        )
             elif re.match(r"OIHW\d*o", kernel_layout):
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack),
                     wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_spatial_pack),
-                    name="conv2d_nchw_spatial_pack.arm_cpu")
+                    name="conv2d_nchw_spatial_pack.arm_cpu",
+                )
             else:
-                raise RuntimeError("Unsupported weight layout {} for conv2d NCHW".
-                                   format(kernel_layout))
+                raise RuntimeError(
+                    "Unsupported weight layout {} for conv2d NCHW".format(kernel_layout)
+                )
         elif layout == "HWCN":
             assert kernel_layout == "HWIO"
             logger.warning("conv2d_hwcn is not optimized for arm cpu.")
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.conv2d_hwcn),
                 wrap_topi_schedule(topi.generic.schedule_conv2d_hwcn),
-                name="conv2d_hwcn.generic")
+                name="conv2d_hwcn.generic",
+            )
         elif layout == "NHWC":
             channels = data.shape[3]
             if "SMLAD" in isa and (channels % 4) == 0 and kernel_layout == "HWOI":
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.arm_cpu.conv2d_direct_simd),
                     wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_direct_simd),
-                    name='conv2d_direct_simd.micro_dev')
+                    name="conv2d_direct_simd.micro_dev",
+                )
             elif kernel_layout == "HWIO":
                 is_aarch64 = "aarch64" in str(isa.target)
 
@@ -124,16 +141,18 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                     strategy.add_implementation(
                         wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized),
                         wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized),
-                        name="conv2d_NHWC_quantized.arm_cpu")
+                        name="conv2d_NHWC_quantized.arm_cpu",
+                    )
 
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack),
                     wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack),
-                    name="conv2d_nhwc_spatial_pack.arm_cpu")
+                    name="conv2d_nhwc_spatial_pack.arm_cpu",
+                )
             else:
-                raise RuntimeError("Unsupported kernel layout {} for conv2d NHWC".
-                                   format(kernel_layout))
-
+                raise RuntimeError(
+                    "Unsupported kernel layout {} for conv2d NHWC".format(kernel_layout)
+                )
 
         else:
             raise RuntimeError("Unsupported conv2d layout {} for arm cpu".format(layout))
@@ -145,7 +164,8 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nchw),
                     wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nchw),
-                    name="depthwise_conv2d_nchw.arm_cpu")
+                    name="depthwise_conv2d_nchw.arm_cpu",
+                )
 
             # TODO:
             # This schedule has incorrect result on some hardware platforms (like NV Jetson TX2)
@@ -164,29 +184,31 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.x86.depthwise_conv2d_nchw),
                     wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_nchw),
-                    name="depthwise_conv2d_nchw.x86")
+                    name="depthwise_conv2d_nchw.x86",
+                )
         elif layout == "NHWC":
             assert kernel_layout == "HWOI"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.arm_cpu.compute_depthwise_conv2d_nhwc),
                 wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc),
-                name="depthwise_conv2d_nhwc.arm_cpu")
+                name="depthwise_conv2d_nhwc.arm_cpu",
+            )
         else:
-            raise RuntimeError("Unsupported depthwise_conv2d layout {} for arm cpu".
-                               format(layout))
-    else: # group_conv2d
-        if layout == 'NCHW':
+            raise RuntimeError("Unsupported depthwise_conv2d layout {} for arm cpu".format(layout))
+    else:  # group_conv2d
+        if layout == "NCHW":
             assert kernel_layout == "OIHW"
             logger.warning("group_conv2d with layout NCHW is not optimized for arm cpu.")
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.group_conv2d_nchw, has_groups=True),
                 wrap_topi_schedule(topi.generic.schedule_group_conv2d_nchw),
-                name="group_conv2d_nchw.generic")
+                name="group_conv2d_nchw.generic",
+            )
         else:
-            raise RuntimeError("Unsupported group_conv2d layout {} for arm cpu".
-                               format(layout))
+            raise RuntimeError("Unsupported group_conv2d layout {} for arm cpu".format(layout))
     return strategy
 
+
 @conv2d_NCHWc_strategy.register("arm_cpu")
 def conv2d_NCHWc_strategy_arm_cpu(attrs, inputs, out_type, target):
     """conv2d_NCHWc adopted from x86"""
@@ -194,9 +216,11 @@ def conv2d_NCHWc_strategy_arm_cpu(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_conv2d(topi.x86.conv2d_NCHWc, True, True),
         wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc),
-        name="conv2d_NCHWc.x86")
+        name="conv2d_NCHWc.x86",
+    )
     return strategy
 
+
 @depthwise_conv2d_NCHWc_strategy.register("arm_cpu")
 def depthwise_conv2d_NCHWc_strategy_arm_cpu(attrs, inputs, out_type, target):
     """depthwise_conv2d_NCHWc adopted from x86"""
@@ -204,21 +228,25 @@ def depthwise_conv2d_NCHWc_strategy_arm_cpu(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_conv2d(topi.x86.depthwise_conv2d_NCHWc, True, True),
         wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_NCHWc),
-        name="depthwise_conv2d_NCHWc.x86")
+        name="depthwise_conv2d_NCHWc.x86",
+    )
     return strategy
 
+
 def wrap_compute_conv2d_winograd_nnpack(topi_compute):
     """wrap topi compute for conv2d_winograd NNPack"""
+
     def _compute_conv2d_nnpack(attrs, inputs, out_type):
         padding = attrs.get_int_tuple("padding")
         strides = attrs.get_int_tuple("strides")
         dilation = attrs.get_int_tuple("dilation")
         out_dtype = attrs.get_str("out_dtype")
         out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
-        return [topi_compute(inputs[0], inputs[1], None, strides, padding,
-                             dilation, out_dtype)]
+        return [topi_compute(inputs[0], inputs[1], None, strides, padding, dilation, out_dtype)]
+
     return _compute_conv2d_nnpack
 
+
 @conv2d_winograd_without_weight_transfrom_strategy.register("arm_cpu")
 def conv2d_winograd_without_weight_transfrom_strategy_arm_cpu(attrs, inputs, out_type, target):
     """conv2d_winograd_without_weight_transfrom arm cpu strategy"""
@@ -241,24 +269,30 @@ def conv2d_winograd_without_weight_transfrom_strategy_arm_cpu(attrs, inputs, out
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_winograd),
                 wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_winograd),
-                name="conv2d_nchw_winograd.arm_cpu")
+                name="conv2d_nchw_winograd.arm_cpu",
+            )
         elif len(kernel.shape) == 4:
             # kernel must be packed by winograd nnpack
             assert "nnpack" in target.libs
             strategy.add_implementation(
                 wrap_compute_conv2d_winograd_nnpack(
-                    topi.arm_cpu.conv2d_nchw_winograd_nnpack_without_weight_transform),
+                    topi.arm_cpu.conv2d_nchw_winograd_nnpack_without_weight_transform
+                ),
                 wrap_topi_schedule(
-                    topi.arm_cpu.schedule_conv2d_nchw_winograd_nnpack_without_weight_transform),
+                    topi.arm_cpu.schedule_conv2d_nchw_winograd_nnpack_without_weight_transform
+                ),
                 name="conv2d_nchw_winograd_nnpack_withou_weight_transform.arm_cpu",
-                plevel=15)
+                plevel=15,
+            )
         else:
             raise RuntimeError("Unsupported kernel shape: {}".format(kernel.shape))
     else:
-        raise RuntimeError("Unsupported conv2d_winograd_without_weight_transfrom layout {}".
-                           format(layout))
+        raise RuntimeError(
+            "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout)
+        )
     return strategy
 
+
 def wrap_compute_conv2d_gemm(topi_compute):
     """wrap topi compute for conv2d_gemm"""
 
@@ -267,14 +301,18 @@ def wrap_compute_conv2d_gemm(topi_compute):
         strides = attrs.get_int_tuple("strides")
         dilation = attrs.get_int_tuple("dilation")
         out_dtype = attrs.get_str("out_dtype")
-        channels = attrs['channels']
-        kernel_size = attrs['kernel_size']
+        channels = attrs["channels"]
+        kernel_size = attrs["kernel_size"]
         out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
-        return [topi_compute(inputs[0], inputs[1], strides, padding,
-                             dilation, out_dtype, kernel_size, channels)]
+        return [
+            topi_compute(
+                inputs[0], inputs[1], strides, padding, dilation, out_dtype, kernel_size, channels
+            )
+        ]
 
     return _compute_conv2d_gemm
 
+
 @conv2d_gemm_without_weight_transform_strategy.register("arm_cpu")
 def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_type, target):
     """conv2d_winograd_without_weight_transfrom arm cpu strategy"""
@@ -282,17 +320,20 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
     data = inputs[0]
     strategy = _op.OpStrategy()
 
-    if layout == "NHWC" and data.dtype in ['int8', 'uint8']:
+    if layout == "NHWC" and data.dtype in ["int8", "uint8"]:
         strategy.add_implementation(
             wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_quantized_without_transform),
             wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized),
-            name="conv2d_NHWC_quantized_without_transform.arm_cpu")
+            name="conv2d_NHWC_quantized_without_transform.arm_cpu",
+        )
     else:
         raise RuntimeError(
-            "Unsupported conv2d_NHWC_quantized_without_transform layout {0} with datatype {1}".
-            format(layout, data.dtype))
+            "Unsupported conv2d_NHWC_quantized_without_transform layout {0}"
+            "with datatype {1}".format(layout, data.dtype)
+        )
     return strategy
 
+
 @conv2d_transpose_strategy.register(["arm_cpu", "micro_dev"])
 def conv2d_transpose_strategy_arm_cpu(attrs, inputs, out_type, target):
     """conv2d_transpose arm cpu strategy"""
@@ -306,9 +347,11 @@ def conv2d_transpose_strategy_arm_cpu(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_conv2d_transpose(topi.arm_cpu.conv2d_transpose_nchw),
         wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_transpose_nchw),
-        name="conv2d_tranpose_nchw.arm_cpu")
+        name="conv2d_tranpose_nchw.arm_cpu",
+    )
     return strategy
 
+
 @bitserial_conv2d_strategy.register("arm_cpu")
 def bitserial_conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
     """bitserial_conv2d x86 strategy"""
@@ -318,16 +361,19 @@ def bitserial_conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
         strategy.add_implementation(
             wrap_compute_bitserial_conv2d(topi.x86.bitserial_conv2d_nchw),
             wrap_topi_schedule(topi.x86.schedule_bitserial_conv2d_nchw),
-            name="bitserial_conv2d_nchw.arm_cpu")
+            name="bitserial_conv2d_nchw.arm_cpu",
+        )
     elif layout == "NHWC":
         strategy.add_implementation(
             wrap_compute_bitserial_conv2d(topi.arm_cpu.bitserial_conv2d_nhwc),
             wrap_topi_schedule(topi.arm_cpu.schedule_bitserial_conv2d_nhwc),
-            name="bitserial_conv2d_nhwc.arm_cpu")
+            name="bitserial_conv2d_nhwc.arm_cpu",
+        )
     else:
         raise ValueError("Data layout {} not supported.".format(layout))
     return strategy
 
+
 @bitserial_dense_strategy.register("arm_cpu")
 def schedule_bitserial_dense_arm_cpu(attrs, inputs, out_type, target):
     """bitserial_dense arm cpu strategy"""
@@ -335,5 +381,6 @@ def schedule_bitserial_dense_arm_cpu(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_bitserial_dense(topi.arm_cpu.bitserial_dense),
         wrap_topi_schedule(topi.arm_cpu.schedule_bitserial_dense),
-        name="bitserial_dense.arm_cpu")
+        name="bitserial_dense.arm_cpu",
+    )
     return strategy
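
The biggest readability change in the arm_cpu strategy hunks is the Winograd applicability test, which moves from backslash continuations to a parenthesized boolean. Restated here as a standalone predicate for clarity; the helper name and argument layout are illustrative, only the condition itself comes from the hunk above:

def winograd_applicable(data_dtype, kernel_dtype, kh, kw, strides, dilation):
    """Same condition as above: float tensors, 3x3 kernel, unit stride and dilation."""
    stride_h, stride_w = strides
    dilation_h, dilation_w = dilation
    return (
        "float" in data_dtype
        and "float" in kernel_dtype
        and kh == 3
        and kw == 3
        and stride_h == 1
        and stride_w == 1
        and dilation_h == 1
        and dilation_w == 1
    )

print(winograd_applicable("float32", "float32", 3, 3, (1, 1), (1, 1)))  # True
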
index c975c36..24e68a4 100644 (file)
@@ -41,38 +41,49 @@ def conv2d_strategy_bifrost(attrs, inputs, out_type, target):
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.bifrost.conv2d_nchw_spatial_pack),
                     wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_spatial_pack),
-                    name="conv2d_nchw_spatial_pack.bifrost")
+                    name="conv2d_nchw_spatial_pack.bifrost",
+                )
 
                 _, _, kh, kw = get_const_tuple(kernel.shape)
-                if kh == 3 and kw == 3 and stride_h == 1 and stride_w == 1 and \
-                   dilation_h == 1 and dilation_w == 1:
+                if (
+                    kh == 3
+                    and kw == 3
+                    and stride_h == 1
+                    and stride_w == 1
+                    and dilation_h == 1
+                    and dilation_w == 1
+                ):
                     strategy.add_implementation(
                         wrap_compute_conv2d(topi.bifrost.conv2d_nchw_winograd),
                         wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_winograd),
                         name="conv2d_nchw_winograd.bifrost",
-                        plevel=5)
+                        plevel=5,
+                    )
             elif re.match(r"OIHW\d*o", kernel_layout):
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.bifrost.conv2d_nchw_spatial_pack),
                     wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_spatial_pack),
-                    name="conv2d_nchw_spatial_pack.bifrost")
+                    name="conv2d_nchw_spatial_pack.bifrost",
+                )
         else:
-            raise RuntimeError("Unsupported conv2d layout {} for Mali(Bifrost)".
-                               format(layout))
+            raise RuntimeError("Unsupported conv2d layout {} for Mali(Bifrost)".format(layout))
     elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
         if layout == "NCHW":
             assert kernel_layout == "OIHW"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
                 wrap_topi_schedule(topi.bifrost.schedule_depthwise_conv2d_nchw),
-                name="depthwise_conv2d_nchw.bifrost")
+                name="depthwise_conv2d_nchw.bifrost",
+            )
         else:
-            raise RuntimeError("Unsupported depthwise_conv2d layout {} for Mali(Bifrost)".
-                               format(layout))
-    else: # group_conv2d
+            raise RuntimeError(
+                "Unsupported depthwise_conv2d layout {} for Mali(Bifrost)".format(layout)
+            )
+    else:  # group_conv2d
         raise RuntimeError("group_conv2d is not supported for Mali(Bifrost)")
     return strategy
 
+
 @conv2d_winograd_without_weight_transfrom_strategy.register("bifrost")
 def conv2d_winograd_without_weight_transfrom_strategy_bifrost(attrs, inputs, out_type, target):
     """conv2d_winograd_without_weight_transfrom mali(bifrost) strategy"""
@@ -88,17 +99,22 @@ def conv2d_winograd_without_weight_transfrom_strategy_bifrost(attrs, inputs, out
         strategy.add_implementation(
             wrap_compute_conv2d(topi.bifrost.conv2d_nchw_winograd),
             wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_winograd),
-            name="conv2d_nchw_winograd.bifrost")
+            name="conv2d_nchw_winograd.bifrost",
+        )
     else:
-        raise RuntimeError("Unsupported conv2d_winograd_without_weight_transfrom layout {}".
-                           format(layout))
+        raise RuntimeError(
+            "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout)
+        )
     return strategy
 
+
 @dense_strategy.register("bifrost")
 def dense_strategy_bifrost(attrs, inputs, out_type, target):
     """dense mali(bifrost) strategy"""
     strategy = _op.OpStrategy()
-    strategy.add_implementation(wrap_compute_dense(topi.bifrost.dense),
-                                wrap_topi_schedule(topi.bifrost.schedule_dense),
-                                name="dense.bifrost")
+    strategy.add_implementation(
+        wrap_compute_dense(topi.bifrost.dense),
+        wrap_topi_schedule(topi.bifrost.schedule_dense),
+        name="dense.bifrost",
+    )
     return strategy
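
Every strategy above is built the same way: create an `OpStrategy`, then register one or more implementations with `strategy.add_implementation(compute, schedule, name=..., plevel=...)`, with several hunks attaching alternatives at plevels 5, 15, 20, and 25. A tvm-free sketch of that selection idea; the rule that the highest applicable plevel wins (absent tuning records) is an assumption about how these priorities are used, not something shown in the diff:

# Names and plevels copied from the arm_cpu hunks above; the selection rule is assumed.
implementations = [
    ("conv2d_nchw_winograd.arm_cpu", 5),
    ("conv2d_nchw_winograd_nnpack.arm_cpu", 15),
]
chosen = max(implementations, key=lambda impl: impl[1])
print(chosen[0])  # -> "conv2d_nchw_winograd_nnpack.arm_cpu"
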
index 4b50937..e8132df 100644 (file)
@@ -24,42 +24,49 @@ from .generic import *
 from .. import op as _op
 from .... import get_global_func
 
+
 @schedule_injective.register(["cuda", "gpu"])
 def schedule_injective_cuda(attrs, outs, target):
     """schedule injective ops for cuda"""
     with target:
         return topi.cuda.schedule_injective(outs)
 
+
 @schedule_reduce.register(["cuda", "gpu"])
 def schedule_reduce_cuda(attrs, outs, target):
     """schedule reduction ops for cuda"""
     with target:
         return topi.cuda.schedule_reduce(outs)
 
+
 @schedule_concatenate.register(["cuda", "gpu"])
 def schedule_concatenate_cuda(attrs, outs, target):
     """schedule concatenate for cuda"""
     with target:
         return topi.cuda.schedule_injective(outs)
 
+
 @schedule_pool.register(["cuda", "gpu"])
 def schedule_pool_cuda(attrs, outs, target):
     """schedule pooling ops for cuda"""
     with target:
         return topi.cuda.schedule_pool(outs, attrs.layout)
 
+
 @schedule_pool_grad.register(["cuda", "gpu"])
 def schedule_pool_grad_cuda(attrs, outs, target):
     """schedule pooling gradient ops for cuda"""
     with target:
         return topi.cuda.schedule_pool_grad(outs)
 
+
 @schedule_adaptive_pool.register(["cuda", "gpu"])
 def schedule_adaptive_pool_cuda(attrs, outs, target):
     """schedule adaptive pooling ops for cuda"""
     with target:
         return topi.cuda.schedule_adaptive_pool(outs, attrs.layout)
 
+
 @softmax_strategy.register(["cuda", "gpu"])
 def softmax_strategy_cuda(attrs, inputs, out_type, target):
     """softmax cuda strategy"""
@@ -67,27 +74,32 @@ def softmax_strategy_cuda(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_softmax(topi.nn.softmax),
         wrap_topi_schedule(topi.cuda.schedule_softmax),
-        name="softmax.cuda")
+        name="softmax.cuda",
+    )
     if target.kind.name == "cuda" and "cudnn" in target.libs:
         strategy.add_implementation(
             wrap_compute_softmax(topi.cuda.softmax_cudnn),
             wrap_topi_schedule(topi.cuda.schedule_softmax_cudnn),
             name="softmax.cudnn",
-            plevel=15)
+            plevel=15,
+        )
     return strategy
 
+
 @schedule_log_softmax.register(["cuda", "gpu"])
 def schedule_log_softmax_cuda(attrs, outs, target):
     """scheudle log_softmax for cuda"""
     with target:
         return topi.cuda.schedule_softmax(outs)
 
+
 @schedule_lrn.register(["cuda", "gpu"])
 def schedule_lrn_cuda(attrs, outs, target):
     """schedule LRN for cuda"""
     with target:
         return topi.cuda.schedule_lrn(outs)
 
+
 @conv2d_strategy.register(["cuda", "gpu"])
 def conv2d_strategy_cuda(attrs, inputs, out_type, target):
     """conv2d cuda strategy"""
@@ -105,73 +117,99 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
     if groups == 1:
         if layout == "NCHW":
             assert kernel_layout == "OIHW"
-            if data.dtype in ('int8', 'uint8') and kernel.dtype in ('int8', 'uint8'):
+            if data.dtype in ("int8", "uint8") and kernel.dtype in ("int8", "uint8"):
                 assert data.dtype == kernel.dtype
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.cuda.conv2d_nchw_int8),
                     wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_int8),
-                    name="conv2d_nchw_int8.cuda")
+                    name="conv2d_nchw_int8.cuda",
+                )
             else:
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.cuda.conv2d_nchw),
                     wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
-                    name="conv2d_nchw.cuda")
+                    name="conv2d_nchw.cuda",
+                )
             _, _, kh, kw = get_const_tuple(kernel.shape)
-            if 2 < kh < 8 and 2 < kw < 8 and kh == kw and stride_h == 1 and stride_w == 1 and \
-                dilation_h == 1 and dilation_w == 1:
+            if (
+                2 < kh < 8
+                and 2 < kw < 8
+                and kh == kw
+                and stride_h == 1
+                and stride_w == 1
+                and dilation_h == 1
+                and dilation_w == 1
+            ):
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd),
                     wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd),
                     name="conv2d_nchw_winograd.cuda",
-                    plevel=5)
+                    plevel=5,
+                )
         elif layout == "HWCN":
             assert kernel_layout == "HWIO"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
                 wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
-                name="conv2d_hwcn.cuda")
+                name="conv2d_hwcn.cuda",
+            )
         elif layout == "NHWC":
             assert kernel_layout == "HWIO"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
                 wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
-                name="conv2d_nhwc.cuda")
+                name="conv2d_nhwc.cuda",
+            )
             N, H, W, _ = get_const_tuple(data.shape)
             KH, KW, CI, CO = get_const_tuple(kernel.shape)
             # Winograd shape related judgment
-            judge_winograd_tensorcore, judge_winograd_shape = winograd_judge(N, H, W, KH, KW,
-                                                                             CI, CO, padding,
-                                                                             stride_h, stride_w,
-                                                                             dilation_h, dilation_w,
-                                                                             pre_flag=False)
+            judge_winograd_tensorcore, judge_winograd_shape = winograd_judge(
+                N,
+                H,
+                W,
+                KH,
+                KW,
+                CI,
+                CO,
+                padding,
+                stride_h,
+                stride_w,
+                dilation_h,
+                dilation_w,
+                pre_flag=False,
+            )
             if judge_winograd_shape:
-                if target.kind.name == "cuda" and \
-                    nvcc.have_tensorcore(tvm.gpu(0).compute_version) and \
-                    judge_winograd_tensorcore:
+                if (
+                    target.kind.name == "cuda"
+                    and nvcc.have_tensorcore(tvm.gpu(0).compute_version)
+                    and judge_winograd_tensorcore
+                ):
                     strategy.add_implementation(
                         wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_tensorcore),
-                        wrap_topi_schedule(
-                            topi.cuda.schedule_conv2d_nhwc_winograd_tensorcore),
+                        wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_winograd_tensorcore),
                         name="conv2d_nhwc_winograd_tensorcore.cuda",
-                        plevel=5)
+                        plevel=5,
+                    )
                 else:
                     strategy.add_implementation(
-                        wrap_compute_conv2d(
-                            topi.cuda.conv2d_nhwc_winograd_direct),
-                        wrap_topi_schedule(
-                            topi.cuda.schedule_conv2d_nhwc_winograd_direct),
+                        wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_direct),
+                        wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_winograd_direct),
                         name="conv2d_nhwc_winograd_direct.cuda",
-                        plevel=5)
+                        plevel=5,
+                    )
             if target.kind.name == "cuda":
                 if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
-                    if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
-                            (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
-                            (N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0):
+                    if (
+                        (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0)
+                        or (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0)
+                        or (N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0)
+                    ):
                         strategy.add_implementation(
                             wrap_compute_conv2d(topi.cuda.conv2d_nhwc_tensorcore),
                             wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_tensorcore),
                             name="conv2d_nhwc_tensorcore.cuda",
-                            plevel=20)
+                            plevel=20,
+                        )
         elif layout == "HWNC":
             assert kernel_layout in ["HWOI", "HWOI16o16i", "HWOI8o32i", "HWOI32o16i"]
             _, _, N, in_channels = get_const_tuple(data.shape)
@@ -182,81 +220,91 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
             else:
                 _, _, out_channels, _ = get_const_tuple(kernel.shape)
             if topi.cuda.is_shape_tensorcore_direct_qualified(
-                    batch=N, in_channels=in_channels, num_filter=out_channels, in_dtype=data.dtype):
+                batch=N, in_channels=in_channels, num_filter=out_channels, in_dtype=data.dtype
+            ):
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.cuda.conv2d_hwnc_tensorcore),
                     wrap_topi_schedule(topi.cuda.schedule_conv2d_hwnc_tensorcore),
                     name="conv2d_hwnc_tensorcore_direct.cuda",
-                    plevel=20)
+                    plevel=20,
+                )
             else:
-                raise RuntimeError("Unsupported shape for conv2d HWNC.\
-                                    Need to satisfy tensor core schedule.")
+                raise RuntimeError(
+                    "Unsupported shape for conv2d HWNC.\
+                                    Need to satisfy tensor core schedule."
+                )
         elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
             assert kernel_layout == "OIHW4o4i"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
                 wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
-                name="conv2d_NCHWc_int8.cuda")
+                name="conv2d_NCHWc_int8.cuda",
+            )
         else:
             raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
         # add cudnn implementation
         if target.kind.name == "cuda" and "cudnn" in target.libs:
-            if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \
-                    padding[1] == padding[3]:
+            if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and padding[1] == padding[3]:
                 strategy.add_implementation(
-                    wrap_compute_conv2d(topi.cuda.conv2d_cudnn,
-                                        need_data_layout=True,
-                                        has_groups=True),
+                    wrap_compute_conv2d(
+                        topi.cuda.conv2d_cudnn, need_data_layout=True, has_groups=True
+                    ),
                     wrap_topi_schedule(topi.cuda.schedule_conv2d_cudnn),
                     name="conv2d_cudnn.cuda",
-                    plevel=25)
+                    plevel=25,
+                )
     elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
         if layout == "NCHW":
             assert kernel_layout == "OIHW"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
                 wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
-                name="depthwise_conv2d_nchw.cuda")
+                name="depthwise_conv2d_nchw.cuda",
+            )
         elif layout == "NHWC":
             assert kernel_layout == "HWOI"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
                 wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc),
-                name="depthwise_conv2d_nhwc.cuda")
+                name="depthwise_conv2d_nhwc.cuda",
+            )
         else:
             raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
-    else: # group_conv2d
+    else:  # group_conv2d
         # add cudnn implementation, if any
         cudnn_impl = False
         if target.kind.name == "cuda" and "cudnn" in target.libs:
-            if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \
-                    padding[1] == padding[3]:
+            if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and padding[1] == padding[3]:
                 strategy.add_implementation(
-                    wrap_compute_conv2d(topi.cuda.conv2d_cudnn,
-                                        need_data_layout=True,
-                                        has_groups=True),
+                    wrap_compute_conv2d(
+                        topi.cuda.conv2d_cudnn, need_data_layout=True, has_groups=True
+                    ),
                     wrap_topi_schedule(topi.cuda.schedule_conv2d_cudnn),
                     name="conv2d_cudnn.cuda",
-                    plevel=25)
+                    plevel=25,
+                )
                 cudnn_impl = True
 
-        if layout == 'NCHW':
+        if layout == "NCHW":
             # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
             assert kernel_layout == "OIHW"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
                 wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw),
-                name="group_conv2d_nchw.cuda")
-        elif layout == 'NCHW4c' and data.dtype in ["int8", "uint8"]:
+                name="group_conv2d_nchw.cuda",
+            )
+        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
             assert kernel_layout == "OIHW4o4i"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
                 wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8),
-                name="group_conv2d_NCHWc_int8.cuda")
+                name="group_conv2d_NCHWc_int8.cuda",
+            )
         elif not cudnn_impl:
             raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
     return strategy
 
+
 @conv2d_winograd_without_weight_transfrom_strategy.register(["cuda", "gpu"])
 def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_type, target):
     """conv2d_winograd_without_weight_transfrom cuda strategy"""
@@ -272,38 +320,57 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty
     if layout == "NCHW":
         strategy.add_implementation(
             wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd_without_weight_transform),
-            wrap_topi_schedule(
-                topi.cuda.schedule_conv2d_nchw_winograd_without_weight_transform),
-            name="conv2d_nchw_winograd_without_weight_transform.cuda")
+            wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd_without_weight_transform),
+            name="conv2d_nchw_winograd_without_weight_transform.cuda",
+        )
     elif layout == "NHWC":
         N, H, W, _ = get_const_tuple(data.shape)
         alpha, _, CI, CO = get_const_tuple(kernel.shape)
         dilation_h, dilation_w = dilation
-        judge_winograd_tensorcore, _ = winograd_judge(N, H, W, alpha, alpha, CI, CO,
-                                                      padding, stride_h, stride_w,
-                                                      dilation_h, dilation_w,
-                                                      pre_flag=True)
-        if target.kind.name == "cuda" and \
-            nvcc.have_tensorcore(tvm.gpu(0).compute_version) and \
-            judge_winograd_tensorcore:
+        judge_winograd_tensorcore, _ = winograd_judge(
+            N,
+            H,
+            W,
+            alpha,
+            alpha,
+            CI,
+            CO,
+            padding,
+            stride_h,
+            stride_w,
+            dilation_h,
+            dilation_w,
+            pre_flag=True,
+        )
+        if (
+            target.kind.name == "cuda"
+            and nvcc.have_tensorcore(tvm.gpu(0).compute_version)
+            and judge_winograd_tensorcore
+        ):
             strategy.add_implementation(
                 wrap_compute_conv2d(
-                    topi.cuda.conv2d_nhwc_winograd_tensorcore_without_weight_transform),
+                    topi.cuda.conv2d_nhwc_winograd_tensorcore_without_weight_transform
+                ),
                 wrap_topi_schedule(
-                    topi.cuda.schedule_conv2d_nhwc_winograd_tensorcore_without_weight_transform),
-                name="conv2d_nhwc_winograd_tensorcore_without_weight_transform.cuda")
+                    topi.cuda.schedule_conv2d_nhwc_winograd_tensorcore_without_weight_transform
+                ),
+                name="conv2d_nhwc_winograd_tensorcore_without_weight_transform.cuda",
+            )
         else:
             strategy.add_implementation(
-                wrap_compute_conv2d(
-                    topi.cuda.conv2d_nhwc_winograd_direct_without_weight_transform),
+                wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_direct_without_weight_transform),
                 wrap_topi_schedule(
-                    topi.cuda.schedule_conv2d_nhwc_winograd_direct_without_weight_transform),
-                name="conv2d_nhwc_winograd_direct_without_weight_transform.cuda")
+                    topi.cuda.schedule_conv2d_nhwc_winograd_direct_without_weight_transform
+                ),
+                name="conv2d_nhwc_winograd_direct_without_weight_transform.cuda",
+            )
     else:
-        raise RuntimeError("Unsupported conv2d_winograd_without_weight_transfrom layout {}".
-                           format(layout))
+        raise RuntimeError(
+            "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout)
+        )
     return strategy
 
+
 @deformable_conv2d_strategy.register(["cuda", "gpu"])
 def deformable_conv2d_strategy_cuda(attrs, inputs, out_type, target):
     """deformable_conv2d cuda strategy"""
@@ -313,9 +380,11 @@ def deformable_conv2d_strategy_cuda(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_deformable_conv2d(topi.cuda.deformable_conv2d_nchw),
         wrap_topi_schedule(topi.cuda.schedule_deformable_conv2d_nchw),
-        name="deformable_conv2d_nchw.cuda")
+        name="deformable_conv2d_nchw.cuda",
+    )
     return strategy
 
+
 @conv2d_transpose_strategy.register(["cuda", "gpu"])
 def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
     """conv2d_transpose cuda strategy"""
@@ -329,7 +398,8 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_conv2d_transpose(topi.cuda.conv2d_transpose_nchw),
         wrap_topi_schedule(topi.cuda.schedule_conv2d_transpose_nchw),
-        name="conv2d_transpose_nchw.cuda")
+        name="conv2d_transpose_nchw.cuda",
+    )
     return strategy
 
 
@@ -346,7 +416,8 @@ def conv3d_transpose_strategy_cuda(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_conv3d_transpose(topi.cuda.conv3d_transpose_ncdhw),
         wrap_topi_schedule(topi.cuda.schedule_conv3d_transpose_ncdhw),
-        name="conv3d_transpose_ncdhw.cuda")
+        name="conv3d_transpose_ncdhw.cuda",
+    )
     return strategy
 
 
@@ -360,45 +431,61 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target):
     _, dilation_h, dilation_w = attrs.get_int_tuple("dilation")
     assert layout in ["NCDHW", "NDHWC"], "Not support this layout {} yet".format(layout)
     if layout == "NCDHW":
-        strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_ncdhw),
-                                    wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw),
-                                    name="conv3d_ncdhw.cuda",
-                                    plevel=10)
+        strategy.add_implementation(
+            wrap_compute_conv3d(topi.cuda.conv3d_ncdhw),
+            wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw),
+            name="conv3d_ncdhw.cuda",
+            plevel=10,
+        )
         _, _, _, kh, kw = get_const_tuple(kernel.shape)
-        if 2 < kh < 8 and 2 < kw < 8 and kh == kw and \
-            stride_h == 1 and stride_w == 1 and \
-            dilation_h == 1 and dilation_w == 1:
+        if (
+            2 < kh < 8
+            and 2 < kw < 8
+            and kh == kw
+            and stride_h == 1
+            and stride_w == 1
+            and dilation_h == 1
+            and dilation_w == 1
+        ):
             strategy.add_implementation(
                 wrap_compute_conv3d(topi.cuda.conv3d_ncdhw_winograd),
                 wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw_winograd),
                 name="conv3d_ncdhw_winograd.cuda",
-                plevel=5)
+                plevel=5,
+            )
     else:  # layout == "NDHWC":
         strategy.add_implementation(
             wrap_compute_conv3d(topi.cuda.conv3d_ndhwc),
             wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc),
             name="conv3d_ndhwc.cuda",
-            plevel=10)
+            plevel=10,
+        )
         N, _, _, _, _ = get_const_tuple(data.shape)
         _, _, _, CI, CO = get_const_tuple(kernel.shape)
         if target.kind.name == "cuda":
             if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
-                if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
-                (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
-                (N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0):
+                if (
+                    (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0)
+                    or (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0)
+                    or (N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0)
+                ):
                     strategy.add_implementation(
                         wrap_compute_conv3d(topi.cuda.conv3d_ndhwc_tensorcore),
                         wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc_tensorcore),
                         name="conv3d_ndhwc_tensorcore.cuda",
-                        plevel=20)
+                        plevel=20,
+                    )
 
     if target.kind.name == "cuda" and "cudnn" in target.libs:
-        strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_cudnn, True),
-                                    wrap_topi_schedule(topi.cuda.schedule_conv3d_cudnn),
-                                    name="conv3d_cudnn.cuda",
-                                    plevel=25)
+        strategy.add_implementation(
+            wrap_compute_conv3d(topi.cuda.conv3d_cudnn, True),
+            wrap_topi_schedule(topi.cuda.schedule_conv3d_cudnn),
+            name="conv3d_cudnn.cuda",
+            plevel=25,
+        )
     return strategy
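
The NDHWC tensorcore branch above is gated purely on divisibility of the batch and channel extents. A standalone sketch of that rule, with a made-up helper name used only for illustration:

    def tensorcore_shape_ok(n, ci, co):
        # Hypothetical helper mirroring the divisibility test above: the
        # batch and channel extents must match one of the supported
        # (N, CI, CO) multiple patterns.
        return (
            (n % 16 == 0 and ci % 16 == 0 and co % 16 == 0)
            or (n % 8 == 0 and ci % 16 == 0 and co % 32 == 0)
            or (n % 32 == 0 and ci % 16 == 0 and co % 8 == 0)
        )

    print(tensorcore_shape_ok(16, 64, 256))  # True  -> conv3d_ndhwc_tensorcore is registered
    print(tensorcore_shape_ok(6, 64, 256))   # False -> only conv3d_ndhwc / cudnn remain
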
 
+
 @conv3d_winograd_without_weight_transfrom_strategy.register(["cuda", "gpu"])
 def conv3d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_type, target):
     """conv3d_winograd_without_weight_transfrom cuda strategy"""
@@ -411,14 +498,16 @@ def conv3d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty
     if layout == "NCDHW":
         strategy.add_implementation(
             wrap_compute_conv3d(topi.cuda.conv3d_ncdhw_winograd_without_weight_transform),
-            wrap_topi_schedule(
-                topi.cuda.schedule_conv3d_ncdhw_winograd_without_weight_transform),
-            name="conv3d_ncdhw_winograd_without_weight_transform.cuda")
+            wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw_winograd_without_weight_transform),
+            name="conv3d_ncdhw_winograd_without_weight_transform.cuda",
+        )
     else:
-        raise RuntimeError("Unsupported conv3d_winograd_without_weight_transfrom layout {}".
-                           format(layout))
+        raise RuntimeError(
+            "Unsupported conv3d_winograd_without_weight_transfrom layout {}".format(layout)
+        )
     return strategy
 
+
 @conv1d_strategy.register(["cuda", "gpu"])
 def conv1d_strategy_cuda(attrs, inputs, out_type, target):
     """conv1d cuda strategy"""
@@ -428,17 +517,22 @@ def conv1d_strategy_cuda(attrs, inputs, out_type, target):
         raise ValueError("dilation should be a positive value")
     strategy = _op.OpStrategy()
     if layout == "NCW":
-        strategy.add_implementation(wrap_compute_conv1d(topi.cuda.conv1d_ncw),
-                                    wrap_topi_schedule(topi.cuda.schedule_conv1d_ncw),
-                                    name="conv1d_ncw.cuda")
+        strategy.add_implementation(
+            wrap_compute_conv1d(topi.cuda.conv1d_ncw),
+            wrap_topi_schedule(topi.cuda.schedule_conv1d_ncw),
+            name="conv1d_ncw.cuda",
+        )
     elif layout == "NWC":
-        strategy.add_implementation(wrap_compute_conv1d(topi.cuda.conv1d_nwc),
-                                    wrap_topi_schedule(topi.cuda.schedule_conv1d_nwc),
-                                    name="conv1d_nwc.cuda")
+        strategy.add_implementation(
+            wrap_compute_conv1d(topi.cuda.conv1d_nwc),
+            wrap_topi_schedule(topi.cuda.schedule_conv1d_nwc),
+            name="conv1d_nwc.cuda",
+        )
     else:
         raise ValueError("Unsupported conv1d layout {}".format(layout))
     return strategy
 
+
 @conv1d_transpose_strategy.register(["cuda", "gpu"])
 def conv1d_transpose_strategy_cuda(attrs, inputs, out_type, target):
     """conv1d_transpose cuda strategy"""
@@ -449,11 +543,14 @@ def conv1d_transpose_strategy_cuda(attrs, inputs, out_type, target):
     assert layout == "NCW", "conv1d_transpose ncw only supported"
     assert dilation == (1,), "conv1d_transpose dilation is not supported"
     assert groups == 1, "conv1d_transpose groups == 1 only supported"
-    strategy.add_implementation(wrap_compute_conv1d_transpose(topi.cuda.conv1d_transpose_ncw),
-                                wrap_topi_schedule(topi.cuda.schedule_conv1d_transpose_ncw),
-                                name="conv1d_transpose_ncw.cuda")
+    strategy.add_implementation(
+        wrap_compute_conv1d_transpose(topi.cuda.conv1d_transpose_ncw),
+        wrap_topi_schedule(topi.cuda.schedule_conv1d_transpose_ncw),
+        name="conv1d_transpose_ncw.cuda",
+    )
     return strategy
 
+
 @dense_strategy.register(["cuda", "gpu"])
 def dense_strategy_cuda(attrs, inputs, out_type, target):
     """dense cuda strategy"""
@@ -465,36 +562,44 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
         strategy.add_implementation(
             wrap_compute_dense(topi.cuda.dense_int8),
             wrap_topi_schedule(topi.cuda.schedule_dense_int8),
-            name="dense_int8.cuda")
+            name="dense_int8.cuda",
+        )
     else:
         strategy.add_implementation(
             wrap_compute_dense(topi.cuda.dense_small_batch),
             wrap_topi_schedule(topi.cuda.schedule_dense_small_batch),
-            name="dense_small_batch.cuda")
+            name="dense_small_batch.cuda",
+        )
         with SpecializedCondition(b >= 32):
             strategy.add_implementation(
                 wrap_compute_dense(topi.cuda.dense_large_batch),
                 wrap_topi_schedule(topi.cuda.schedule_dense_large_batch),
                 name="dense_large_batch.cuda",
-                plevel=5)
+                plevel=5,
+            )
         if target.kind.name == "cuda":
             if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
-                if(i % 16 == 0 and b % 16 == 0 and o % 16 == 0) \
-                        or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0) \
-                        or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0):
+                if (
+                    (i % 16 == 0 and b % 16 == 0 and o % 16 == 0)
+                    or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0)
+                    or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0)
+                ):
                     strategy.add_implementation(
                         wrap_compute_dense(topi.cuda.dense_tensorcore),
                         wrap_topi_schedule(topi.cuda.schedule_dense_tensorcore),
                         name="dense_tensorcore.cuda",
-                        plevel=20)
+                        plevel=20,
+                    )
     if target.kind.name == "cuda" and "cublas" in target.libs:
         strategy.add_implementation(
             wrap_compute_dense(topi.cuda.dense_cublas),
             wrap_topi_schedule(topi.cuda.schedule_dense_cublas),
             name="dense_cublas.cuda",
-            plevel=25)
+            plevel=25,
+        )
     return strategy
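
Several dense implementations above are registered side by side and distinguished by plevel; roughly, and leaving AutoTVM tuning records aside, the applicable implementation with the highest plevel is preferred (cuBLAS at 25 over tensorcore at 20 over the plain schedules). A toy model of that selection, not TVM's actual dispatch:

    def pick_dense_impl(batch_ok_for_tensorcore, has_cublas):
        # Toy illustration only: among the applicable implementations,
        # the one with the highest plevel wins.
        candidates = [
            ("dense_small_batch.cuda", 10, True),
            ("dense_tensorcore.cuda", 20, batch_ok_for_tensorcore),
            ("dense_cublas.cuda", 25, has_cublas),
        ]
        name, _plevel, _ok = max((c for c in candidates if c[2]), key=lambda c: c[1])
        return name

    print(pick_dense_impl(batch_ok_for_tensorcore=True, has_cublas=False))   # dense_tensorcore.cuda
    print(pick_dense_impl(batch_ok_for_tensorcore=False, has_cublas=False))  # dense_small_batch.cuda
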
 
+
 @batch_matmul_strategy.register(["cuda", "gpu"])
 def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
     """batch_matmul cuda strategy"""
@@ -503,13 +608,15 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
         wrap_compute_batch_matmul(topi.cuda.batch_matmul),
         wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
         name="batch_matmul.cuda",
-        plevel=10)
+        plevel=10,
+    )
     if target.kind.name == "cuda" and "cublas" in target.libs:
         strategy.add_implementation(
             wrap_compute_batch_matmul(topi.cuda.batch_matmul_cublas),
             wrap_topi_schedule(topi.generic.schedule_extern),
             name="batch_matmul_cublas.cuda",
-            plevel=15)
+            plevel=15,
+        )
     return strategy
 
 
@@ -521,7 +628,8 @@ def sparse_dense_strategy_cuda(attrs, inputs, out_type, target):
         wrap_compute_sparse_dense(topi.cuda.sparse_dense),
         wrap_topi_schedule(topi.cuda.schedule_sparse_dense),
         name="sparse_dense.cuda",
-        plevel=10)
+        plevel=10,
+    )
     return strategy
 
 
@@ -532,28 +640,37 @@ def argsort_strategy_cuda(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_argsort(topi.cuda.argsort),
         wrap_topi_schedule(topi.cuda.schedule_argsort),
-        name="argsort.cuda")
+        name="argsort.cuda",
+    )
     if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
-        strategy.add_implementation(wrap_compute_argsort(topi.cuda.argsort_thrust),
-                                    wrap_topi_schedule(topi.cuda.schedule_argsort),
-                                    name="argsort_thrust.cuda",
-                                    plevel=15)
+        strategy.add_implementation(
+            wrap_compute_argsort(topi.cuda.argsort_thrust),
+            wrap_topi_schedule(topi.cuda.schedule_argsort),
+            name="argsort_thrust.cuda",
+            plevel=15,
+        )
     return strategy
 
+
 @topk_strategy.register(["cuda", "gpu"])
 def topk_strategy_cuda(attrs, inputs, out_type, target):
     """topk cuda strategy"""
     strategy = _op.OpStrategy()
-    strategy.add_implementation(wrap_compute_topk(topi.cuda.topk),
-                                wrap_topi_schedule(topi.cuda.schedule_topk),
-                                name="topk.cuda")
+    strategy.add_implementation(
+        wrap_compute_topk(topi.cuda.topk),
+        wrap_topi_schedule(topi.cuda.schedule_topk),
+        name="topk.cuda",
+    )
     if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
-        strategy.add_implementation(wrap_compute_topk(topi.cuda.topk_thrust),
-                                    wrap_topi_schedule(topi.cuda.schedule_topk),
-                                    name="topk_thrust.cuda",
-                                    plevel=15)
+        strategy.add_implementation(
+            wrap_compute_topk(topi.cuda.topk_thrust),
+            wrap_topi_schedule(topi.cuda.schedule_topk),
+            name="topk_thrust.cuda",
+            plevel=15,
+        )
     return strategy
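
The argsort and topk strategies above register a Thrust-backed kernel at a higher plevel only when the corresponding packed function exists; get_global_func with allow_missing=True returns None instead of raising when TVM was built without Thrust, so the lookup doubles as a feature probe. A minimal standalone probe, assuming a TVM installation is importable:

    import tvm

    # Probe for the optional contrib function; returns None rather than
    # raising if TVM was built without Thrust support.
    thrust_sort = tvm.get_global_func("tvm.contrib.thrust.sort", allow_missing=True)
    if thrust_sort is not None:
        print("Thrust available: *_thrust implementations will be registered")
    else:
        print("Thrust not built in: falling back to the default CUDA sort")
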
 
+
 @multibox_prior_strategy.register(["cuda", "gpu"])
 def multibox_prior_strategy_cuda(attrs, inputs, out_type, target):
     """multibox_prior cuda strategy"""
@@ -561,9 +678,11 @@ def multibox_prior_strategy_cuda(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_multibox_prior(topi.cuda.multibox_prior),
         wrap_topi_schedule(topi.cuda.schedule_multibox_prior),
-        name="multibox_prior.cuda")
+        name="multibox_prior.cuda",
+    )
     return strategy
 
+
 @multibox_transform_loc_strategy.register(["cuda", "gpu"])
 def multibox_transform_loc_strategy_cuda(attrs, inputs, out_type, target):
     """multibox_transform_loc cuda strategy"""
@@ -571,9 +690,11 @@ def multibox_transform_loc_strategy_cuda(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_multibox_transform_loc(topi.cuda.multibox_transform_loc),
         wrap_topi_schedule(topi.cuda.schedule_multibox_transform_loc),
-        name="multibox_transform_loc.cuda")
+        name="multibox_transform_loc.cuda",
+    )
     return strategy
 
+
 @get_valid_counts_strategy.register(["cuda", "gpu"])
 def get_valid_counts_strategy_cuda(attrs, inputs, out_type, target):
     """get_valid_counts cuda strategy"""
@@ -581,9 +702,11 @@ def get_valid_counts_strategy_cuda(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_get_valid_counts(topi.cuda.get_valid_counts),
         wrap_topi_schedule(topi.cuda.schedule_get_valid_counts),
-        name="get_valid_counts.cuda")
+        name="get_valid_counts.cuda",
+    )
     return strategy
 
+
 @nms_strategy.register(["cuda", "gpu"])
 def nms_strategy_cuda(attrs, inputs, out_type, target):
     """nms cuda strategy"""
@@ -591,35 +714,45 @@ def nms_strategy_cuda(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_nms(topi.cuda.non_max_suppression),
         wrap_topi_schedule(topi.cuda.schedule_nms),
-        name="nms.cuda")
+        name="nms.cuda",
+    )
     return strategy
 
+
 @roi_align_strategy.register(["cuda", "gpu"])
 def roi_align_strategy_cuda(attrs, inputs, out_type, target):
     """roi_align cuda strategy"""
     strategy = _op.OpStrategy()
-    strategy.add_implementation(wrap_compute_roi_align(topi.vision.rcnn.roi_align_nchw),
-                                wrap_topi_schedule(topi.cuda.schedule_roi_align),
-                                name="roi_align_nchw.cuda")
+    strategy.add_implementation(
+        wrap_compute_roi_align(topi.vision.rcnn.roi_align_nchw),
+        wrap_topi_schedule(topi.cuda.schedule_roi_align),
+        name="roi_align_nchw.cuda",
+    )
     return strategy
 
+
 @schedule_roi_pool.register(["cuda", "gpu"])
 def schedule_roi_pool_cuda(attrs, outs, target):
     """schedule roi_pool for cuda"""
     with target:
         return topi.cuda.schedule_roi_pool(outs)
 
+
 @proposal_strategy.register(["cuda", "gpu"])
 def proposal_strategy_cuda(attrs, inputs, out_type, target):
     """proposal cuda strategy"""
     strategy = _op.OpStrategy()
-    strategy.add_implementation(wrap_compute_proposal(topi.cuda.proposal),
-                                wrap_topi_schedule(topi.cuda.schedule_proposal),
-                                name="proposal.cuda")
+    strategy.add_implementation(
+        wrap_compute_proposal(topi.cuda.proposal),
+        wrap_topi_schedule(topi.cuda.schedule_proposal),
+        name="proposal.cuda",
+    )
     return strategy
 
-def winograd_judge(N, H, W, KH, KW, CI, CO, padding, stride_h,
-                   stride_w, dilation_h, dilation_w, pre_flag):
+
+def winograd_judge(
+    N, H, W, KH, KW, CI, CO, padding, stride_h, stride_w, dilation_h, dilation_w, pre_flag
+):
     """Winograd judgement about tensorcore and shape"""
     if H % 8 == 0:
         tile_size = 4
@@ -633,14 +766,23 @@ def winograd_judge(N, H, W, KH, KW, CI, CO, padding, stride_h,
     OW = (W + pl + pr - KW) // stride_w + 1
     nH, nW = (OH + tile_size - 1) // tile_size, (OW + tile_size - 1) // tile_size
     P = N * nH * nW
-    judge_winograd_tensorcore = (P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
-                                   (P % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
-                                   (P % 32 == 0 and CI % 16 == 0 and CO % 8 == 0)
-    judge_winograd_shape = 2 < KH < 8 and 2 < KW < 8 and KH == KW and \
-                              stride_h == 1 and stride_w == 1 and \
-                              dilation_h == 1 and dilation_w == 1
+    judge_winograd_tensorcore = (
+        (P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0)
+        or (P % 8 == 0 and CI % 16 == 0 and CO % 32 == 0)
+        or (P % 32 == 0 and CI % 16 == 0 and CO % 8 == 0)
+    )
+    judge_winograd_shape = (
+        2 < KH < 8
+        and 2 < KW < 8
+        and KH == KW
+        and stride_h == 1
+        and stride_w == 1
+        and dilation_h == 1
+        and dilation_w == 1
+    )
     return judge_winograd_tensorcore, judge_winograd_shape
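
winograd_judge above derives the number of Winograd tiles P = N * nH * nW from the padded output extent and then applies the same divisibility rule as the direct tensorcore paths. A worked example of that arithmetic, with assumed shapes and an assumed symmetric padding of 1 (the real code derives the padding via get_pad_tuple):

    # Assumed example: N=8, H=W=56, 3x3 kernel, stride 1, padding 1 on all sides.
    N, H, W, KH, KW = 8, 56, 56, 3, 3
    stride_h = stride_w = 1
    pt = pl = pb = pr = 1
    tile_size = 4 if H % 8 == 0 else 2        # 4, since 56 is a multiple of 8
    OH = (H + pt + pb - KH) // stride_h + 1   # 56
    OW = (W + pl + pr - KW) // stride_w + 1   # 56
    nH = (OH + tile_size - 1) // tile_size    # 14 tiles along the height
    nW = (OW + tile_size - 1) // tile_size    # 14 tiles along the width
    P = N * nH * nW                           # 1568, divisible by 8, 16 and 32
    print(P, P % 16 == 0)                     # 1568 True
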
 
+
 @correlation_strategy.register(["cuda", "gpu"])
 def correlation_strategy_cuda(attrs, inputs, out_type, target):
     """correlation cuda strategy"""
@@ -650,5 +792,6 @@ def correlation_strategy_cuda(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_correlation(topi.cuda.correlation_nchw),
         wrap_topi_schedule(topi.cuda.schedule_correlation_nchw),
-        name="correlation.cuda")
+        name="correlation.cuda",
+    )
     return strategy
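
Every strategy in this file follows the same shape: an @override_native_generic_func definition (shown further down) creates an overridable generic entry point, and @<op>_strategy.register(["cuda", "gpu"]) attaches the target-specific implementations for those targets. A toy model of that override mechanism, not TVM's generic_func machinery:

    # Toy stand-in for override_native_generic_func / .register():
    # a default ("generic") entry plus per-target overrides.
    class ToyGenericFunc:
        def __init__(self, default):
            self._default = default
            self._overrides = {}

        def register(self, targets):
            def _do(fn):
                for t in targets:
                    self._overrides[t] = fn
                return fn
            return _do

        def __call__(self, target, *args):
            return self._overrides.get(target, self._default)(*args)

    conv2d_strategy_toy = ToyGenericFunc(lambda: "conv2d.generic")

    @conv2d_strategy_toy.register(["cuda", "gpu"])
    def _cuda():
        return "conv2d_nchw.cuda"

    print(conv2d_strategy_toy("cuda"))  # conv2d_nchw.cuda
    print(conv2d_strategy_toy("llvm"))  # conv2d.generic
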
index 69c9bd7..070efa4 100644 (file)
@@ -24,15 +24,19 @@ from tvm.topi.util import get_const_int, get_const_float, get_const_tuple, get_f
 from .. import op as _op
 from ....target import generic_func, override_native_generic_func
 
-logger = logging.getLogger('strategy')
+logger = logging.getLogger("strategy")
+
 
 def wrap_topi_schedule(topi_schedule):
     """Wrap TOPI schedule which doesn't use attrs"""
+
     def wrapper(attrs, outs, target):
         with target:
             return topi_schedule(outs)
+
     return wrapper
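
wrap_topi_schedule adapts a plain one-argument TOPI schedule to the (attrs, outs, target) signature used by the strategies above, entering the target context before calling it. The same closure pattern, reduced to a self-contained toy in which the with-target scoping is only passed through:

    def wrap_schedule(schedule_fn):
        # Toy version of the wrapper above: adapt schedule_fn(outs) to the
        # (attrs, outs, target) calling convention used by the strategies.
        def wrapper(attrs, outs, target):
            # TVM enters `with target:` here; the toy just forwards the call.
            return schedule_fn(outs)
        return wrapper

    wrapped = wrap_schedule(lambda outs: ("scheduled", outs))
    print(wrapped(None, ["conv2d"], "cuda"))  # ('scheduled', ['conv2d'])
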
 
+
 def get_conv2d_in_channels(data_shape, data_layout):
     """Get conv2d input channels"""
     data_shape = get_const_tuple(data_shape)
@@ -45,6 +49,7 @@ def get_conv2d_in_channels(data_shape, data_layout):
         return data_shape[1] * data_shape[4]
     raise ValueError("Unknown conv2d data layout {}".format(data_layout))
 
+
 def get_conv2d_out_channels(kernel_shape, kernel_layout):
     """Get conv2d output channels"""
     kernel_shape = get_const_tuple(kernel_shape)
@@ -58,23 +63,27 @@ def get_conv2d_out_channels(kernel_shape, kernel_layout):
         return kernel_shape[0] * kernel_shape[4]
     raise ValueError("Unknown conv2d kernel layout {}".format(kernel_layout))
 
+
 def is_depthwise_conv2d(data_shape, data_layout, kernel_shape, kernel_layout, groups):
     ic = get_conv2d_in_channels(data_shape, data_layout)
     oc = get_conv2d_out_channels(kernel_shape, kernel_layout)
     return ic == oc == groups
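
is_depthwise_conv2d classifies a convolution as depthwise exactly when input channels, output channels and groups all coincide. For a MobileNet-style 3x3 depthwise layer (the shapes below are illustrative assumptions):

    data_shape = (1, 32, 112, 112)   # NCHW: 32 input channels
    kernel_shape = (32, 1, 3, 3)     # OIHW: 32 output channels, 1 channel per group
    groups = 32
    ic, oc = data_shape[1], kernel_shape[0]
    print(ic == oc == groups)        # True -> dispatched to the depthwise_conv2d_* kernels
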
 
+
 @generic_func
 def schedule_injective(attrs, outs, target):
     """Schedule injective ops"""
     with target:
         return topi.generic.schedule_injective(outs)
 
+
 @generic_func
 def schedule_reduce(attrs, outs, target):
     """Schedule reduction ops"""
     with target:
         return topi.generic.schedule_reduce(outs)
 
+
 _op._schedule_injective = schedule_injective
 _op._schedule_reduce = schedule_reduce
 
@@ -85,6 +94,7 @@ def schedule_concatenate(attrs, outs, target):
     with target:
         return topi.generic.schedule_injective(outs)
 
+
 # pool
 @generic_func
 def schedule_pool(attrs, outs, target):
@@ -92,6 +102,7 @@ def schedule_pool(attrs, outs, target):
     with target:
         return topi.generic.schedule_pool(outs, attrs.layout)
 
+
 # pool_grad
 @generic_func
 def schedule_pool_grad(attrs, outs, target):
@@ -99,6 +110,7 @@ def schedule_pool_grad(attrs, outs, target):
     with target:
         return topi.generic.schedule_pool_grad(outs)
 
+
 # adaptive pool
 @generic_func
 def schedule_adaptive_pool(attrs, outs, target):
@@ -106,14 +118,18 @@ def schedule_adaptive_pool(attrs, outs, target):
     with target:
         return topi.generic.schedule_adaptive_pool(outs)
 
+
 # softmax
 def wrap_compute_softmax(topi_compute):
     """Wrap softmax topi compute"""
+
     def _compute_softmax(attrs, inputs, out_type):
         axis = attrs.get_int("axis")
         return [topi_compute(inputs[0], axis)]
+
     return _compute_softmax
 
+
 @override_native_generic_func("softmax_strategy")
 def softmax_strategy(attrs, inputs, out_type, target):
     """softmax generic strategy"""
@@ -121,9 +137,11 @@ def softmax_strategy(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_softmax(topi.nn.softmax),
         wrap_topi_schedule(topi.generic.schedule_softmax),
-        name="softmax.generic")
+        name="softmax.generic",
+    )
     return strategy
 
+
 # log_softmax
 @generic_func
 def schedule_log_softmax(attrs, outs, target):
@@ -131,6 +149,7 @@ def schedule_log_softmax(attrs, outs, target):
     with target:
         return topi.generic.schedule_softmax(outs)
 
+
 # lrn
 @generic_func
 def schedule_lrn(attrs, outs, target):
@@ -138,6 +157,7 @@ def schedule_lrn(attrs, outs, target):
     with target:
         return topi.generic.schedule_lrn(outs)
 
+
 # bitpack
 @generic_func
 def schedule_bitpack(attrs, outs, target):
@@ -145,10 +165,13 @@ def schedule_bitpack(attrs, outs, target):
     with target:
         return topi.generic.schedule_bitpack(outs)
 
+
 # conv2d
-def wrap_compute_conv2d(topi_compute, need_data_layout=False, need_out_layout=False,
-                        has_groups=False):
+def wrap_compute_conv2d(
+    topi_compute, need_data_layout=False, need_out_layout=False, has_groups=False
+):
     """Wrap conv2d topi compute"""
+
     def _compute_conv2d(attrs, inputs, out_type):
         padding = get_const_tuple(attrs.padding)
         strides = get_const_tuple(attrs.strides)
@@ -156,8 +179,7 @@ def wrap_compute_conv2d(topi_compute, need_data_layout=False, need_out_layout=Fa
         data_layout = attrs.get_str("data_layout")
         out_layout = attrs.get_str("out_layout")
         out_dtype = attrs.out_dtype
-        out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
-                     else out_dtype)
+        out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
         args = [inputs[0], inputs[1], strides, padding, dilation]
         if has_groups:
             args.append(attrs.groups)
@@ -167,8 +189,10 @@ def wrap_compute_conv2d(topi_compute, need_data_layout=False, need_out_layout=Fa
             args.append(out_layout)
         args.append(out_dtype)
         return [topi_compute(*args)]
+
     return _compute_conv2d
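
wrap_compute_conv2d builds the positional argument list for the wrapped TOPI compute, appending groups and layout strings only when the corresponding flags are set, and resolving an out_dtype of "same" or "" to the input dtype. The flag-driven assembly, restated as a standalone sketch (the dict-based attrs and inputs are assumptions made purely for illustration):

    def build_conv2d_args(inputs, attrs, has_groups=False,
                          need_data_layout=False, need_out_layout=False):
        # Illustrative only: mirrors the argument assembly in the wrapper above.
        out_dtype = attrs["out_dtype"]
        out_dtype = inputs[0]["dtype"] if out_dtype in ("same", "") else out_dtype
        args = [inputs[0], inputs[1], attrs["strides"], attrs["padding"], attrs["dilation"]]
        if has_groups:
            args.append(attrs["groups"])
        if need_data_layout:
            args.append(attrs["data_layout"])
        if need_out_layout:
            args.append(attrs["out_layout"])
        args.append(out_dtype)
        return args

    attrs = {"out_dtype": "same", "strides": (1, 1), "padding": (1, 1),
             "dilation": (1, 1), "groups": 1, "data_layout": "NCHW", "out_layout": ""}
    inputs = [{"dtype": "float32"}, {"dtype": "float32"}]
    print(build_conv2d_args(inputs, attrs, need_data_layout=True)[-2:])  # ['NCHW', 'float32']
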
 
+
 @override_native_generic_func("conv2d_strategy")
 def conv2d_strategy(attrs, inputs, out_type, target):
     """conv2d generic strategy"""
@@ -189,19 +213,22 @@ def conv2d_strategy(attrs, inputs, out_type, target):
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.conv2d_nchw),
                 wrap_topi_schedule(topi.generic.schedule_conv2d_nchw),
-                name="conv2d_nchw.generic")
+                name="conv2d_nchw.generic",
+            )
         elif layout == "NHWC":
             assert kernel_layout == "HWIO"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.conv2d_nhwc),
                 wrap_topi_schedule(topi.generic.schedule_conv2d_nhwc),
-                name="conv2d_nhwc.generic")
+                name="conv2d_nhwc.generic",
+            )
         elif layout == "HWCN":
             assert kernel_layout == "HWIO"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.conv2d_hwcn),
                 wrap_topi_schedule(topi.generic.schedule_conv2d_hwcn),
-                name="conv2d_hwcn.generic")
+                name="conv2d_hwcn.generic",
+            )
         else:
             raise RuntimeError("Unsupported conv2d layout {}".format(layout))
     elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
@@ -210,26 +237,30 @@ def conv2d_strategy(attrs, inputs, out_type, target):
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
                 wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw),
-                name="depthwise_conv2d_nchw.generic")
+                name="depthwise_conv2d_nchw.generic",
+            )
         elif layout == "NHWC":
             assert kernel_layout == "HWOI"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
                 wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc),
-                name="depthwise_conv2d_nhwc.generic")
+                name="depthwise_conv2d_nhwc.generic",
+            )
         else:
             raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
-    else: # group_conv2d
-        if layout == 'NCHW':
+    else:  # group_conv2d
+        if layout == "NCHW":
             assert kernel_layout == "OIHW"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.group_conv2d_nchw, has_groups=True),
                 wrap_topi_schedule(topi.generic.schedule_group_conv2d_nchw),
-                name="group_conv2d_nchw.generic")
+                name="group_conv2d_nchw.generic",
+            )
         else:
             raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
     return strategy
 
+
 # conv2d_NCHWc
 @override_native_generic_func("conv2d_NCHWc_strategy")
 def conv2d_NCHWc_strategy(attrs, inputs, out_type, target):
@@ -240,14 +271,17 @@ def conv2d_NCHWc_strategy(attrs, inputs, out_type, target):
         strategy.add_implementation(
             wrap_compute_conv2d(topi.nn.conv2d_NCHWc_int8, True, True),
             wrap_topi_schedule(topi.generic.schedule_conv2d_NCHWc_int8),
-            name="conv2d_NCHWc_int8.generic")
+            name="conv2d_NCHWc_int8.generic",
+        )
     else:
         strategy.add_implementation(
             wrap_compute_conv2d(topi.nn.conv2d_NCHWc, True, True),
             wrap_topi_schedule(topi.generic.schedule_conv2d_NCHWc),
-            name="conv2d_NCHWc.generic")
+            name="conv2d_NCHWc.generic",
+        )
     return strategy
 
+
 # depthwise_conv2d_NCHWc
 @override_native_generic_func("depthwise_conv2d_NCHWc_strategy")
 def depthwise_conv2d_NCHWc_strategy(attrs, inputs, out_type, target):
@@ -257,21 +291,25 @@ def depthwise_conv2d_NCHWc_strategy(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_conv2d(topi.nn.depthwise_conv2d_NCHWc, True, True),
         wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_NCHWc),
-        name="depthwise_conv2d_NCHWc.generic")
+        name="depthwise_conv2d_NCHWc.generic",
+    )
     return strategy
 
+
 # conv2d_winograd_without_weight_transform
 @override_native_generic_func("conv2d_winograd_without_weight_transform_strategy")
 def conv2d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, target):
     """conv2d_winograd_without_weight_transfrom generic strategy"""
     raise ValueError("No generic implemenation for conv2d_winograd_without_weight_transform")
 
+
 # conv2d_gemm_without_weight_transform
 @override_native_generic_func("conv2d_gemm_without_weight_transform_strategy")
 def conv2d_gemm_without_weight_transform_strategy(attrs, inputs, out_type, target):
     """conv2d_gemm_without_weight_transfrom generic strategy"""
     raise ValueError("No generic implemenation for conv2d_gemm_without_weight_transform")
 
+
 # conv2d_winograd_weight_transform
 @generic_func
 def schedule_conv2d_winograd_weight_transform(attrs, outs, target):
@@ -279,6 +317,7 @@ def schedule_conv2d_winograd_weight_transform(attrs, outs, target):
     with target:
         return topi.generic.schedule_conv2d_winograd_weight_transform(outs)
 
+
 # conv2d_winograd_nnpack_weight_transform
 @generic_func
 def schedule_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
@@ -286,6 +325,7 @@ def schedule_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
     with target:
         return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)
 
+
 # conv2d_gemm_weight_transform
 @generic_func
 def schedule_conv2d_gemm_weight_transform(attrs, outs, target):
@@ -293,9 +333,11 @@ def schedule_conv2d_gemm_weight_transform(attrs, outs, target):
     with target:
         return topi.generic.schedule_conv2d_gemm_weight_transform(outs)
 
+
 # deformable_conv2d
 def wrap_compute_deformable_conv2d(topi_compute):
     """wrap deformable_conv2d topi compute"""
+
     def _compute_deformable_conv2d(attrs, inputs, out_dtype):
         assert attrs.data_layout == "NCHW"
         padding = get_const_tuple(attrs.padding)
@@ -305,11 +347,22 @@ def wrap_compute_deformable_conv2d(topi_compute):
         groups = attrs.groups
         out_dtype = attrs.out_dtype
         out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
-        out = topi_compute(inputs[0], inputs[1], inputs[2], strides, padding,
-                           dilation, deformable_groups, groups, out_dtype)
+        out = topi_compute(
+            inputs[0],
+            inputs[1],
+            inputs[2],
+            strides,
+            padding,
+            dilation,
+            deformable_groups,
+            groups,
+            out_dtype,
+        )
         return [out]
+
     return _compute_deformable_conv2d
 
+
 @override_native_generic_func("deformable_conv2d_strategy")
 def deformable_conv2d_strategy(attrs, inputs, out_type, target):
     """deformable_conv2d generic strategy"""
@@ -320,25 +373,28 @@ def deformable_conv2d_strategy(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_deformable_conv2d(topi.nn.deformable_conv2d_nchw),
         wrap_topi_schedule(topi.generic.schedule_deformable_conv2d_nchw),
-        name="deformable_conv2d.generic")
+        name="deformable_conv2d.generic",
+    )
     return strategy
 
+
 # conv2d_transpose
 def wrap_compute_conv2d_transpose(topi_compute):
     """wrap conv2d_transpose topi compute"""
+
     def compute_conv2d_transpose(attrs, inputs, out_dtype):
         """Compute definition of conv2d_transpose"""
         padding = get_const_tuple(attrs.padding)
         strides = get_const_tuple(attrs.strides)
         out_dtype = attrs.out_dtype
-        out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
-                     else out_dtype)
+        out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
         output_padding = get_const_tuple(attrs.output_padding)
-        out = topi_compute(
-            inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
+        out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
         return [out]
+
     return compute_conv2d_transpose
 
+
 @override_native_generic_func("conv2d_transpose_strategy")
 def conv2d_transpose_strategy(attrs, inputs, out_type, target):
     """conv2d_transpose generic strategy"""
@@ -353,24 +409,25 @@ def conv2d_transpose_strategy(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
         wrap_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw),
-        name="conv2d_transpose_nchw.generic")
+        name="conv2d_transpose_nchw.generic",
+    )
     return strategy
 
 
 # conv3d_transpose
 def wrap_compute_conv3d_transpose(topi_compute):
     """wrap conv3d_transpose topi compute"""
+
     def compute_conv3d_transpose(attrs, inputs, out_dtype):
         """Compute definition of conv3d_transpose"""
         padding = get_const_tuple(attrs.padding)
         strides = get_const_tuple(attrs.strides)
         output_padding = get_const_tuple(attrs.output_padding)
         out_dtype = attrs.out_dtype
-        out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
-                     else out_dtype)
-        out = topi_compute(
-            inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
+        out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
+        out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
         return [out]
+
     return compute_conv3d_transpose
 
 
@@ -388,12 +445,15 @@ def conv3d_transpose_strategy(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_conv3d_transpose(topi.nn.conv3d_transpose_ncdhw),
         wrap_topi_schedule(topi.generic.schedule_conv3d_transpose_ncdhw),
-        name="conv3d_transpose_ncdhw.generic")
+        name="conv3d_transpose_ncdhw.generic",
+    )
     return strategy
 
+
 # conv3d
 def wrap_compute_conv3d(topi_compute, need_layout=False):
     """wrap conv3d topi compute"""
+
     def _compute_conv3d(attrs, inputs, out_type):
         padding = get_const_tuple(attrs.padding)
         strides = get_const_tuple(attrs.strides)
@@ -401,8 +461,7 @@ def wrap_compute_conv3d(topi_compute, need_layout=False):
         groups = attrs.groups
         layout = attrs.data_layout
         out_dtype = attrs.out_dtype
-        out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
-                     else out_dtype)
+        out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
 
         (dilation_d, dilation_h, dilation_w) = dilation
         if dilation_d < 1 or dilation_h < 1 or dilation_w < 1:
@@ -410,14 +469,14 @@ def wrap_compute_conv3d(topi_compute, need_layout=False):
         if groups != 1:
             raise ValueError("Not support arbitrary group number for conv3d")
         if need_layout:
-            out = topi_compute(inputs[0], inputs[1], strides, padding, dilation,
-                               layout, out_dtype)
+            out = topi_compute(inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype)
         else:
-            out = topi_compute(inputs[0], inputs[1], strides, padding, dilation,
-                               out_dtype)
+            out = topi_compute(inputs[0], inputs[1], strides, padding, dilation, out_dtype)
         return [out]
+
     return _compute_conv3d
 
+
 @override_native_generic_func("conv3d_strategy")
 def conv3d_strategy(attrs, inputs, out_type, target):
     """conv3d generic strategy"""
@@ -428,22 +487,26 @@ def conv3d_strategy(attrs, inputs, out_type, target):
         strategy.add_implementation(
             wrap_compute_conv3d(topi.nn.conv3d_ncdhw),
             wrap_topi_schedule(topi.generic.schedule_conv3d_ncdhw),
-            name="conv3d_ncdhw.generic")
+            name="conv3d_ncdhw.generic",
+        )
     elif layout == "NDHWC":
         strategy.add_implementation(
             wrap_compute_conv3d(topi.nn.conv3d_ndhwc),
             wrap_topi_schedule(topi.generic.schedule_conv3d_ndhwc),
-            name="conv3d_ndhwc.generic")
+            name="conv3d_ndhwc.generic",
+        )
     else:
         raise ValueError("Not support this layout {} yet".format(layout))
     return strategy
 
+
 # conv3d_winograd_without_weight_transform
 @override_native_generic_func("conv3d_winograd_without_weight_transform_strategy")
 def conv3d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, target):
     """conv3d_winograd_without_weight_transfrom generic strategy"""
     raise ValueError("No generic implemenation for conv3d_winograd_without_weight_transform")
 
+
 # conv3d_winograd_weight_transform
 @generic_func
 def schedule_conv3d_winograd_weight_transform(attrs, outs, target):
@@ -451,21 +514,23 @@ def schedule_conv3d_winograd_weight_transform(attrs, outs, target):
     with target:
         return topi.generic.schedule_conv3d_winograd_weight_transform(outs)
 
+
 # conv1d
 def wrap_compute_conv1d(topi_compute):
     """wrap conv1d topi compute"""
+
     def _compute_conv1d(attrs, inputs, out_type):
         """Compute definition of conv1d"""
         strides = get_const_tuple(attrs.strides)
         padding = get_const_tuple(attrs.padding)
         dilation = get_const_tuple(attrs.dilation)
         out_dtype = attrs.out_dtype
-        out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
-                     else out_dtype)
-        return [topi_compute(inputs[0], inputs[1], strides, padding, dilation,
-                             out_dtype)]
+        out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
+        return [topi_compute(inputs[0], inputs[1], strides, padding, dilation, out_dtype)]
+
     return _compute_conv1d
 
+
 @override_native_generic_func("conv1d_strategy")
 def conv1d_strategy(attrs, inputs, out_type, target):
     """conv1d generic strategy"""
@@ -479,29 +544,35 @@ def conv1d_strategy(attrs, inputs, out_type, target):
         strategy.add_implementation(
             wrap_compute_conv1d(topi.nn.conv1d_ncw),
             wrap_topi_schedule(topi.generic.schedule_conv1d_ncw),
-            name="conv1d_ncw.generic")
+            name="conv1d_ncw.generic",
+        )
     elif layout == "NWC":
         strategy.add_implementation(
             wrap_compute_conv1d(topi.nn.conv1d_nwc),
             wrap_topi_schedule(topi.generic.schedule_conv1d_nwc),
-            name="conv1d_nwc.generic")
+            name="conv1d_nwc.generic",
+        )
     else:
         raise ValueError("Unsupported conv1d layout {}".format(layout))
     return strategy
 
+
 # conv1d_transpose
 def wrap_compute_conv1d_transpose(topi_compute):
     """wrap conv1d_transpose topi compute"""
+
     def _compute_conv1d_tranpsoe(attrs, inputs, out_type):
         padding = get_const_tuple(attrs.padding)
         strides = get_const_tuple(attrs.strides)
         out_dtype = attrs.out_dtype
-        out_dtype = (inputs[0].dtype if out_dtype in ("same", "") else out_dtype)
+        out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
         output_padding = get_const_tuple(attrs.output_padding)
         out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
         return [out]
+
     return _compute_conv1d_tranpsoe
 
+
 @override_native_generic_func("conv1d_transpose_strategy")
 def conv1d_transpose_strategy(attrs, inputs, out_type, target):
     """conv1d_transpose generic strategy"""
@@ -513,28 +584,31 @@ def conv1d_transpose_strategy(attrs, inputs, out_type, target):
     assert layout == "NCW", "conv1d_transpose ncw only supported"
     assert dilation == (1,), "conv1d_transpose dilation is not supported"
     assert groups == 1, "conv1d_transpose groups == 1 only supported"
-    strategy.add_implementation(wrap_compute_conv1d_transpose(topi.nn.conv1d_transpose_ncw),
-                                wrap_topi_schedule(topi.generic.schedule_conv1d_transpose_ncw),
-                                name="conv1d_transpose_ncw.generic")
+    strategy.add_implementation(
+        wrap_compute_conv1d_transpose(topi.nn.conv1d_transpose_ncw),
+        wrap_topi_schedule(topi.generic.schedule_conv1d_transpose_ncw),
+        name="conv1d_transpose_ncw.generic",
+    )
     return strategy
 
 
 # dilation2d
 def wrap_compute_dilation2d(topi_compute, need_data_layout=False):
     """Wrap dilation2d topi compute"""
+
     def _compute_dilation2d(attrs, inputs, out_type):
         padding = get_const_tuple(attrs.padding)
         strides = get_const_tuple(attrs.strides)
         dilations = get_const_tuple(attrs.dilations)
         data_layout = attrs.get_str("data_layout")
         out_dtype = attrs.out_dtype
-        out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
-                     else out_dtype)
+        out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
         args = [inputs[0], inputs[1], strides, padding, dilations]
         if need_data_layout:
             args.append(data_layout)
         args.append(out_dtype)
         return [topi_compute(*args)]
+
     return _compute_dilation2d
 
 
@@ -557,13 +631,15 @@ def dilation2d_strategy(attrs, inputs, out_type, target):
         strategy.add_implementation(
             wrap_compute_dilation2d(topi.image.dilation2d_nchw),
             wrap_topi_schedule(topi.generic.schedule_dilation2d_nchw),
-            name="dilation2d_nchw.generic")
+            name="dilation2d_nchw.generic",
+        )
     elif layout == "NHWC":
         assert kernel_layout == "HWI"
         strategy.add_implementation(
             wrap_compute_dilation2d(topi.image.dilation2d_nhwc),
             wrap_topi_schedule(topi.generic.schedule_dilation2d_nhwc),
-            name="dilation2d_nhwc.generic")
+            name="dilation2d_nhwc.generic",
+        )
     else:
         raise RuntimeError("Unsupported dilation2d layout {}".format(layout))
     return strategy
@@ -572,57 +648,75 @@ def dilation2d_strategy(attrs, inputs, out_type, target):
 # dense
 def wrap_compute_dense(topi_compute):
     """wrap dense topi compute"""
+
     def _compute_dense(attrs, inputs, out_type):
         """Compute definition of dense"""
         out_dtype = attrs.out_dtype
         out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
         return [topi_compute(inputs[0], inputs[1], None, out_dtype)]
+
     return _compute_dense
 
+
 @override_native_generic_func("dense_strategy")
 def dense_strategy(attrs, inputs, out_type, target):
     """dense generic strategy"""
     logger.warning("dense is not optimized for this platform.")
     strategy = _op.OpStrategy()
-    strategy.add_implementation(wrap_compute_dense(topi.nn.dense),
-                                wrap_topi_schedule(topi.generic.schedule_dense),
-                                name="dense.generic")
+    strategy.add_implementation(
+        wrap_compute_dense(topi.nn.dense),
+        wrap_topi_schedule(topi.generic.schedule_dense),
+        name="dense.generic",
+    )
     return strategy
 
+
 # batch_matmul
 def wrap_compute_batch_matmul(topi_compute):
     """wrap batch_matmul topi compute"""
+
     def _compute_batch_matmul(attrs, inputs, out_type):
         return [topi_compute(inputs[0], inputs[1])]
+
     return _compute_batch_matmul
 
+
 @override_native_generic_func("batch_matmul_strategy")
 def batch_matmul_strategy(attrs, inputs, out_type, target):
     """batch_matmul generic strategy"""
     logger.warning("batch_matmul is not optimized for this platform.")
     strategy = _op.OpStrategy()
-    strategy.add_implementation(wrap_compute_batch_matmul(topi.nn.batch_matmul),
-                                wrap_topi_schedule(topi.generic.schedule_batch_matmul),
-                                name="batch_matmul.generic")
+    strategy.add_implementation(
+        wrap_compute_batch_matmul(topi.nn.batch_matmul),
+        wrap_topi_schedule(topi.generic.schedule_batch_matmul),
+        name="batch_matmul.generic",
+    )
     return strategy
 
+
 # sparse dense
 def wrap_compute_sparse_dense(topi_compute):
     """wrap sparse dense topi compute"""
+
     def _compute_sparse_dense(attrs, inputs, out_type):
         return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3])]
+
     return _compute_sparse_dense
 
+
 @override_native_generic_func("sparse_dense_strategy")
 def sparse_dense_strategy(attrs, inputs, out_type, target):
     """sparse dense generic strategy"""
     logger.warning("sparse dense is not optimized for this platform.")
     strategy = _op.OpStrategy()
-    strategy.add_implementation(wrap_compute_sparse_dense(topi.nn.sparse_dense),
-                                wrap_topi_schedule(topi.generic.schedule_sparse_dense),
-                                name="sparse_dense.generic")
+    strategy.add_implementation(
+        wrap_compute_sparse_dense(topi.nn.sparse_dense),
+        wrap_topi_schedule(topi.generic.schedule_sparse_dense),
+        name="sparse_dense.generic",
+    )
     return strategy
 
+
 # sparse_transpose
 @generic_func
 def schedule_sparse_transpose(attrs, outs, target):
@@ -630,28 +724,36 @@ def schedule_sparse_transpose(attrs, outs, target):
     with target:
         return topi.generic.schedule_sparse_transpose(outs)
 
+
 # argsort
 def wrap_compute_argsort(topi_compute):
     """Wrap argsort topi compute"""
+
     def _compute_argsort(attrs, inputs, _):
         axis = get_const_int(attrs.axis)
         is_ascend = bool(get_const_int(attrs.is_ascend))
         dtype = attrs.dtype
         return [topi_compute(inputs[0], axis=axis, is_ascend=is_ascend, dtype=dtype)]
+
     return _compute_argsort
 
+
 @override_native_generic_func("argsort_strategy")
 def argsort_strategy(attrs, inputs, out_type, target):
     """argsort generic strategy"""
     strategy = _op.OpStrategy()
-    strategy.add_implementation(wrap_compute_argsort(topi.argsort),
-                                wrap_topi_schedule(topi.generic.schedule_argsort),
-                                name="argsort.generic")
+    strategy.add_implementation(
+        wrap_compute_argsort(topi.argsort),
+        wrap_topi_schedule(topi.generic.schedule_argsort),
+        name="argsort.generic",
+    )
     return strategy
 
+
 # topk
 def wrap_compute_topk(topi_compute):
     """Wrap topk compute"""
+
     def _compute_topk(attrs, inputs, out_type):
         if attrs.k is not None:
             k = attrs.k
@@ -664,20 +766,26 @@ def wrap_compute_topk(topi_compute):
         out = topi_compute(inputs[0], k, axis, ret_type, is_ascend, dtype)
         out = out if isinstance(out, list) else [out]
         return out
+
     return _compute_topk
 
+
 @override_native_generic_func("topk_strategy")
 def topk_strategy(attrs, inputs, out_type, target):
     """topk generic strategy"""
     strategy = _op.OpStrategy()
-    strategy.add_implementation(wrap_compute_topk(topi.topk),
-                                wrap_topi_schedule(topi.generic.schedule_topk),
-                                name="topk.generic")
+    strategy.add_implementation(
+        wrap_compute_topk(topi.topk),
+        wrap_topi_schedule(topi.generic.schedule_topk),
+        name="topk.generic",
+    )
     return strategy
 
+
 # multibox_prior
 def wrap_compute_multibox_prior(topi_compute):
     """Wrap multibox_prior compute"""
+
     def _compute_multibox_prior(attrs, inputs, _):
         """Compute definition of multibox_prior"""
         sizes = get_float_tuple(attrs.sizes)
@@ -686,29 +794,36 @@ def wrap_compute_multibox_prior(topi_compute):
         offsets = get_float_tuple(attrs.offsets)
         clip = bool(get_const_int(attrs.clip))
         return [topi_compute(inputs[0], sizes, ratios, steps, offsets, clip)]
+
     return _compute_multibox_prior
 
+
 @override_native_generic_func("multibox_prior_strategy")
 def multibox_prior_strategy(attrs, inputs, out_type, target):
     """multibox_prior generic strategy"""
     strategy = _op.OpStrategy()
-    strategy.add_implementation(wrap_compute_multibox_prior(topi.vision.ssd.multibox_prior),
-                                wrap_topi_schedule(topi.generic.schedule_multibox_prior),
-                                name="multibox_prior.generic")
+    strategy.add_implementation(
+        wrap_compute_multibox_prior(topi.vision.ssd.multibox_prior),
+        wrap_topi_schedule(topi.generic.schedule_multibox_prior),
+        name="multibox_prior.generic",
+    )
     return strategy
 
+
 # multibox_transform_loc
 def wrap_compute_multibox_transform_loc(topi_compute):
     """Wrap multibox_transform_loc compute"""
+
     def _compute_multibox_transform_loc(attrs, inputs, _):
         """Compute definition of multibox_detection"""
         clip = bool(get_const_int(attrs.clip))
         threshold = get_const_float(attrs.threshold)
         variances = get_float_tuple(attrs.variances)
-        return topi_compute(
-            inputs[0], inputs[1], inputs[2], clip, threshold, variances)
+        return topi_compute(inputs[0], inputs[1], inputs[2], clip, threshold, variances)
+
     return _compute_multibox_transform_loc
 
+
 @override_native_generic_func("multibox_transform_loc_strategy")
 def multibox_transform_loc_strategy(attrs, inputs, out_type, target):
     """schedule multibox_transform_loc"""
@@ -716,31 +831,40 @@ def multibox_transform_loc_strategy(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_multibox_transform_loc(topi.vision.ssd.multibox_transform_loc),
         wrap_topi_schedule(topi.generic.schedule_multibox_transform_loc),
-        name="multibox_transform_loc.generic")
+        name="multibox_transform_loc.generic",
+    )
     return strategy
 
+
 # get_valid_counts
 def wrap_compute_get_valid_counts(topi_compute):
     """wrap get_valid_counts topi compute"""
+
     def _compute_get_valid_counts(attrs, inputs, out_type):
         score_threshold = get_const_float(attrs.score_threshold)
         id_index = get_const_int(attrs.id_index)
         score_index = get_const_int(attrs.score_index)
         return topi_compute(inputs[0], score_threshold, id_index, score_index)
+
     return _compute_get_valid_counts
 
+
 @override_native_generic_func("get_valid_counts_strategy")
 def get_valid_counts_strategy(attrs, inputs, out_type, target):
     """get_valid_counts generic strategy"""
     strategy = _op.OpStrategy()
-    strategy.add_implementation(wrap_compute_get_valid_counts(topi.vision.get_valid_counts),
-                                wrap_topi_schedule(topi.generic.schedule_get_valid_counts),
-                                name="get_valid_counts.generic")
+    strategy.add_implementation(
+        wrap_compute_get_valid_counts(topi.vision.get_valid_counts),
+        wrap_topi_schedule(topi.generic.schedule_get_valid_counts),
+        name="get_valid_counts.generic",
+    )
     return strategy
 
+
 # non-maximum suppression
 def wrap_compute_nms(topi_compute):
     """wrap nms topi compute"""
+
     def _compute_nms(attrs, inputs, out_type):
         max_output_size = inputs[3]
         if attrs.max_output_size is not None:
@@ -754,44 +878,84 @@ def wrap_compute_nms(topi_compute):
         id_index = get_const_int(attrs.id_index)
         invalid_to_bottom = bool(get_const_int(attrs.invalid_to_bottom))
         if return_indices:
-            return topi_compute(inputs[0], inputs[1], inputs[2], max_output_size, iou_threshold,
-                                force_suppress, top_k, coord_start, score_index, id_index,
-                                return_indices, invalid_to_bottom)
-        return [topi_compute(inputs[0], inputs[1], inputs[2], max_output_size, iou_threshold,
-                             force_suppress, top_k, coord_start, score_index, id_index,
-                             return_indices, invalid_to_bottom)]
+            return topi_compute(
+                inputs[0],
+                inputs[1],
+                inputs[2],
+                max_output_size,
+                iou_threshold,
+                force_suppress,
+                top_k,
+                coord_start,
+                score_index,
+                id_index,
+                return_indices,
+                invalid_to_bottom,
+            )
+        return [
+            topi_compute(
+                inputs[0],
+                inputs[1],
+                inputs[2],
+                max_output_size,
+                iou_threshold,
+                force_suppress,
+                top_k,
+                coord_start,
+                score_index,
+                id_index,
+                return_indices,
+                invalid_to_bottom,
+            )
+        ]
+
     return _compute_nms
 
+
 @override_native_generic_func("non_max_suppression_strategy")
 def nms_strategy(attrs, inputs, out_type, target):
     """nms generic strategy"""
     strategy = _op.OpStrategy()
-    strategy.add_implementation(wrap_compute_nms(topi.vision.non_max_suppression),
-                                wrap_topi_schedule(topi.generic.schedule_nms),
-                                name="nms.generic")
+    strategy.add_implementation(
+        wrap_compute_nms(topi.vision.non_max_suppression),
+        wrap_topi_schedule(topi.generic.schedule_nms),
+        name="nms.generic",
+    )
     return strategy
 
+
 # roi_align
 def wrap_compute_roi_align(topi_compute):
     """wrap roi_align topi compute"""
+
     def _compute_roi_align(attrs, inputs, out_type):
         assert attrs.layout == "NCHW"
         pooled_size = get_const_tuple(attrs.pooled_size)
-        return [topi_compute(inputs[0], inputs[1],
-                             pooled_size=pooled_size,
-                             spatial_scale=attrs.spatial_scale,
-                             sample_ratio=attrs.sample_ratio)]
+        return [
+            topi_compute(
+                inputs[0],
+                inputs[1],
+                pooled_size=pooled_size,
+                spatial_scale=attrs.spatial_scale,
+                sample_ratio=attrs.sample_ratio,
+            )
+        ]
+
     return _compute_roi_align
 
+
 @override_native_generic_func("roi_align_strategy")
 def roi_align_strategy(attrs, inputs, out_type, target):
     """roi_align generic strategy"""
     strategy = _op.OpStrategy()
-    strategy.add_implementation(wrap_compute_roi_align(topi.vision.rcnn.roi_align_nchw),
-                                wrap_topi_schedule(topi.generic.schedule_roi_align),
-                                name="roi_align.generic")
+    strategy.add_implementation(
+        wrap_compute_roi_align(topi.vision.rcnn.roi_align_nchw),
+        wrap_topi_schedule(topi.generic.schedule_roi_align),
+        name="roi_align.generic",
+    )
     return strategy
 
+
 # roi_pool
 @generic_func
 def schedule_roi_pool(attrs, outs, target):
@@ -799,9 +963,11 @@ def schedule_roi_pool(attrs, outs, target):
     with target:
         return topi.generic.schedule_roi_pool(outs)
 
+
 # proposal
 def wrap_compute_proposal(topi_compute):
     """wrap proposal topi compute"""
+
     def _compute_proposal(attrs, inputs, out_type):
         scales = get_float_tuple(attrs.scales)
         ratios = get_float_tuple(attrs.ratios)
@@ -811,20 +977,37 @@ def wrap_compute_proposal(topi_compute):
         rpn_post_nms_top_n = attrs.rpn_post_nms_top_n
         rpn_min_size = attrs.rpn_min_size
         iou_loss = bool(get_const_int(attrs.iou_loss))
-        return [topi_compute(inputs[0], inputs[1], inputs[2], scales, ratios,
-                             feature_stride, threshold, rpn_pre_nms_top_n,
-                             rpn_post_nms_top_n, rpn_min_size, iou_loss)]
+        return [
+            topi_compute(
+                inputs[0],
+                inputs[1],
+                inputs[2],
+                scales,
+                ratios,
+                feature_stride,
+                threshold,
+                rpn_pre_nms_top_n,
+                rpn_post_nms_top_n,
+                rpn_min_size,
+                iou_loss,
+            )
+        ]
+
     return _compute_proposal
 
+
 @override_native_generic_func("proposal_strategy")
 def proposal_strategy(attrs, inputs, out_type, target):
     """proposal generic strategy"""
     strategy = _op.OpStrategy()
-    strategy.add_implementation(wrap_compute_proposal(topi.vision.rcnn.proposal),
-                                wrap_topi_schedule(topi.generic.schedule_proposal),
-                                name="proposal.generic")
+    strategy.add_implementation(
+        wrap_compute_proposal(topi.vision.rcnn.proposal),
+        wrap_topi_schedule(topi.generic.schedule_proposal),
+        name="proposal.generic",
+    )
     return strategy
 
+
 # argwhere
 @generic_func
 def schedule_argwhere(attrs, outs, target):
@@ -832,6 +1015,7 @@ def schedule_argwhere(attrs, outs, target):
     with target:
         return topi.generic.schedule_argwhere(outs)
 
+
 # scatter
 @generic_func
 def schedule_scatter(attrs, outs, target):
@@ -839,6 +1023,7 @@ def schedule_scatter(attrs, outs, target):
     with target:
         return topi.generic.schedule_scatter(outs)
 
+
 # scatter_add
 @generic_func
 def schedule_scatter_add(attrs, outs, target):
@@ -846,9 +1031,11 @@ def schedule_scatter_add(attrs, outs, target):
     with target:
         return topi.generic.schedule_scatter_add(outs)
 
+
 # bitserial_conv2d
 def wrap_compute_bitserial_conv2d(topi_compute):
     """wrap bitserial_conv2d topi compute"""
+
     def compute_bitserial_conv2d(attrs, inputs, out_dtype):
         """Compute definition for bitserial conv2d."""
         padding = get_const_tuple(attrs.padding)
@@ -858,10 +1045,23 @@ def wrap_compute_bitserial_conv2d(topi_compute):
         pack_dtype = attrs.pack_dtype
         out_dtype = attrs.out_dtype
         unipolar = attrs.unipolar
-        return [topi_compute(inputs[0], inputs[1], strides, padding, activation_bits,
-                             weight_bits, pack_dtype, out_dtype, unipolar)]
+        return [
+            topi_compute(
+                inputs[0],
+                inputs[1],
+                strides,
+                padding,
+                activation_bits,
+                weight_bits,
+                pack_dtype,
+                out_dtype,
+                unipolar,
+            )
+        ]
+
     return compute_bitserial_conv2d
 
+
 @override_native_generic_func("bitserial_conv2d_strategy")
 def bitserial_conv2d_strategy(attrs, inputs, out_type, target):
     """bitserial_conv2d generic strategy"""
@@ -872,19 +1072,23 @@ def bitserial_conv2d_strategy(attrs, inputs, out_type, target):
         strategy.add_implementation(
             wrap_compute_bitserial_conv2d(topi.nn.bitserial_conv2d_nchw),
             wrap_topi_schedule(topi.generic.schedule_bitserial_conv2d_nchw),
-            name="bitserial_conv2d_nchw.generic")
+            name="bitserial_conv2d_nchw.generic",
+        )
     elif layout == "NHWC":
         strategy.add_implementation(
             wrap_compute_bitserial_conv2d(topi.nn.bitserial_conv2d_nhwc),
             wrap_topi_schedule(topi.generic.schedule_bitserial_conv2d_nhwc),
-            name="bitserial_conv2d_nhwc.generic")
+            name="bitserial_conv2d_nhwc.generic",
+        )
     else:
         raise ValueError("Data layout {} not supported.".format(layout))
     return strategy
 
+
 # bitserial_dense
 def wrap_compute_bitserial_dense(topi_compute):
     """wrap bitserial_dense topi compute"""
+
     def compute_bitserial_dense(attrs, inputs, out_type):
         """Compute definition of bitserial dense"""
         data_bits = attrs.data_bits
@@ -893,10 +1097,15 @@ def wrap_compute_bitserial_dense(topi_compute):
         out_dtype = attrs.out_dtype
         out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
         unipolar = attrs.unipolar
-        return [topi_compute(inputs[0], inputs[1], data_bits, weight_bits,
-                             pack_dtype, out_dtype, unipolar)]
+        return [
+            topi_compute(
+                inputs[0], inputs[1], data_bits, weight_bits, pack_dtype, out_dtype, unipolar
+            )
+        ]
+
     return compute_bitserial_dense
 
+
 @override_native_generic_func("bitserial_dense_strategy")
 def bitserial_dense_strategy(attrs, inputs, out_type, target):
     """bitserial_dense generic strategy"""
@@ -905,12 +1114,15 @@ def bitserial_dense_strategy(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_bitserial_dense(topi.nn.bitserial_dense),
         wrap_topi_schedule(topi.generic.schedule_bitserial_dense),
-        name="bitserial_dense.generic")
+        name="bitserial_dense.generic",
+    )
     return strategy
 
+
 # correlation
 def wrap_compute_correlation(topi_compute):
     """wrap correlation topi compute"""
+
     def _compute_correlation(attrs, inputs, out_type):
         kernel_size = attrs.kernel_size
         max_displacement = attrs.max_displacement
@@ -918,10 +1130,22 @@ def wrap_compute_correlation(topi_compute):
         stride2 = attrs.stride2
         padding = get_const_tuple(attrs.padding)
         is_multiply = attrs.is_multiply
-        return [topi_compute(inputs[0], inputs[1], kernel_size, max_displacement, stride1, stride2,
-                             padding, is_multiply)]
+        return [
+            topi_compute(
+                inputs[0],
+                inputs[1],
+                kernel_size,
+                max_displacement,
+                stride1,
+                stride2,
+                padding,
+                is_multiply,
+            )
+        ]
+
     return _compute_correlation
 
+
 @override_native_generic_func("correlation_strategy")
 def correlation_strategy(attrs, inputs, out_type, target):
     """correlation generic strategy"""
@@ -932,5 +1156,6 @@ def correlation_strategy(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_correlation(topi.nn.correlation_nchw),
         wrap_topi_schedule(topi.generic.schedule_correlation_nchw),
-        name="correlation.generic")
+        name="correlation.generic",
+    )
     return strategy
index 90495fb..761ac9e 100644 (file)
@@ -20,36 +20,42 @@ from tvm import topi
 from .generic import *
 from .. import op as _op
 
+
 @schedule_injective.register("hls")
 def schedule_injective_hls(attrs, outs, target):
     """schedule injective ops for hls"""
     with target:
         return topi.hls.schedule_injective(outs)
 
+
 @schedule_reduce.register("hls")
 def schedule_reduce_hls(attrs, outs, target):
     """schedule reduction ops for hls"""
     with target:
         return topi.hls.schedule_reduce(outs)
 
+
 @schedule_concatenate.register("hls")
 def schedule_concatenate_hls(attrs, outs, target):
     """schedule concatenate for hls"""
     with target:
         return topi.hls.schedule_injective(outs)
 
+
 @schedule_pool.register("hls")
 def schedule_pool_hls(attrs, outs, target):
     """schedule pooling ops for hls"""
     with target:
         return topi.hls.schedule_pool(outs, attrs.layout)
 
+
 @schedule_adaptive_pool.register("hls")
 def schedule_adaptive_pool_hls(attrs, outs, target):
     """schedule adaptive pooling ops for hls"""
     with target:
         return topi.hls.schedule_adaptive_pool(outs)
 
+
 @softmax_strategy.register("hls")
 def softmax_strategy_hls(attrs, inputs, out_type, target):
     """softmax hls strategy"""
@@ -57,15 +63,18 @@ def softmax_strategy_hls(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_softmax(topi.nn.softmax),
         wrap_topi_schedule(topi.hls.schedule_softmax),
-        name="softmax.hls")
+        name="softmax.hls",
+    )
     return strategy
 
+
 @schedule_log_softmax.register("hls")
 def schedule_log_softmax_hls(attrs, outs, target):
     """schedule log_softmax for hls"""
     with target:
         return topi.hls.schedule_softmax(outs)
 
+
 @override_native_generic_func("conv2d_strategy")
 def conv2d_strategy_hls(attrs, inputs, out_type, target):
     """conv2d hls strategy"""
@@ -85,13 +94,15 @@ def conv2d_strategy_hls(attrs, inputs, out_type, target):
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.conv2d_nchw),
                 wrap_topi_schedule(topi.hls.schedule_conv2d_nchw),
-                name="conv2d_nchw.hls")
+                name="conv2d_nchw.hls",
+            )
         elif layout == "NHWC":
             assert kernel_layout == "HWIO"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.conv2d_nhwc),
                 wrap_topi_schedule(topi.hls.schedule_conv2d_nhwc),
-                name="conv2d_nhwc.hls")
+                name="conv2d_nhwc.hls",
+            )
         else:
             raise RuntimeError("Unsupported conv2d layout {}".format(layout))
     elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
@@ -100,19 +111,22 @@ def conv2d_strategy_hls(attrs, inputs, out_type, target):
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
                 wrap_topi_schedule(topi.hls.schedule_depthwise_conv2d_nchw),
-                name="depthwise_conv2d_nchw.hls")
+                name="depthwise_conv2d_nchw.hls",
+            )
         elif layout == "NHWC":
             assert kernel_layout == "HWOI"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
                 wrap_topi_schedule(topi.hls.schedule_depthwise_conv2d_nhwc),
-                name="depthwise_nhwc.hls")
+                name="depthwise_nhwc.hls",
+            )
         else:
             raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
-    else: # group_conv2d
+    else:  # group_conv2d
         raise RuntimeError("group_conv2d is not supported for hls")
     return strategy
 
+
 @override_native_generic_func("conv2d_NCHWc_strategy")
 def conv2d_NCHWc_strategy_hls(attrs, inputs, out_type, target):
     """conv2d_NCHWc hls strategy"""
@@ -120,9 +134,11 @@ def conv2d_NCHWc_strategy_hls(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_conv2d(topi.nn.conv2d_NCHWc, True, True),
         wrap_topi_schedule(topi.hls.schedule_conv2d_NCHWc),
-        name="conv2d_NCHWc.hls")
+        name="conv2d_NCHWc.hls",
+    )
     return strategy
 
+
 @conv2d_transpose_strategy.register("hls")
 def conv2d_transpose_strategy_hls(attrs, inputs, out_type, target):
     """conv2d_transpose hls strategy"""
@@ -136,18 +152,23 @@ def conv2d_transpose_strategy_hls(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
         wrap_topi_schedule(topi.hls.schedule_conv2d_transpose_nchw),
-        name="conv2d_transpose_nchw.hls")
+        name="conv2d_transpose_nchw.hls",
+    )
     return strategy
 
+
 @dense_strategy.register("hls")
 def dense_strategy_hls(attrs, inputs, out_type, target):
     """dense hls strategy"""
     strategy = _op.OpStrategy()
-    strategy.add_implementation(wrap_compute_dense(topi.nn.dense),
-                                wrap_topi_schedule(topi.hls.schedule_dense),
-                                name="dense.hls")
+    strategy.add_implementation(
+        wrap_compute_dense(topi.nn.dense),
+        wrap_topi_schedule(topi.hls.schedule_dense),
+        name="dense.hls",
+    )
     return strategy
 
+
 @bitserial_conv2d_strategy.register("hls")
 def bitserial_conv2d_strategy_hls(attrs, inputs, out_type, target):
     """bitserial_conv2d hls strategy"""
@@ -157,12 +178,14 @@ def bitserial_conv2d_strategy_hls(attrs, inputs, out_type, target):
         strategy.add_implementation(
             wrap_compute_bitserial_conv2d(topi.nn.bitserial_conv2d_nchw),
             wrap_topi_schedule(topi.hls.schedule_bitserial_conv2d_nchw),
-            name="bitserial_conv2d_nchw.hls")
+            name="bitserial_conv2d_nchw.hls",
+        )
     elif layout == "NHWC":
         strategy.add_implementation(
             wrap_compute_bitserial_conv2d(topi.nn.bitserial_conv2d_nhwc),
             wrap_topi_schedule(topi.hls.schedule_bitserial_conv2d_nhwc),
-            name="bitserial_conv2d_nhwc.hls")
+            name="bitserial_conv2d_nhwc.hls",
+        )
     else:
         raise ValueError("Data layout {} not supported.".format(layout))
     return strategy
index 568cbff..a2de49c 100644 (file)
@@ -39,30 +39,33 @@ def conv2d_strategy_intel_graphics(attrs, inputs, out_type, target):
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.intel_graphics.conv2d_nchw),
                 wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_nchw),
-                name="conv2d_nchw.intel_graphics")
+                name="conv2d_nchw.intel_graphics",
+            )
             # conv2d_NCHWc won't work without alter op layout pass
             # TODO(@Laurawly): fix this
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.intel_graphics.conv2d_NCHWc, True, True),
                 wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_NCHWc),
                 name="conv2d_NCHWc.intel_graphics",
-                plevel=5)
+                plevel=5,
+            )
         else:
-            raise RuntimeError("Unsupported conv2d layout {} for intel graphics".
-                               format(layout))
+            raise RuntimeError("Unsupported conv2d layout {} for intel graphics".format(layout))
     elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
         if layout == "NCHW":
             assert kernel_layout == "OIHW"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.intel_graphics.depthwise_conv2d_nchw),
                 wrap_topi_schedule(topi.intel_graphics.schedule_depthwise_conv2d_nchw),
-                name="depthwise_conv2d_nchw.intel_graphics")
+                name="depthwise_conv2d_nchw.intel_graphics",
+            )
         else:
             raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
-    else: # group_conv2d
+    else:  # group_conv2d
         raise RuntimeError("group_conv2d is not supported for intel graphics")
     return strategy
 
+
 @conv2d_NCHWc_strategy.register("intel_graphics")
 def conv2d_NCHWc_strategy_intel_graphics(attrs, inputs, out_type, target):
     """conv2d_NCHWc intel_graphics strategy"""
@@ -70,5 +73,6 @@ def conv2d_NCHWc_strategy_intel_graphics(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_conv2d(topi.intel_graphics.conv2d_NCHWc, True, True),
         wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_NCHWc),
-        name="conv2d_NCHWc.intel_graphics")
+        name="conv2d_NCHWc.intel_graphics",
+    )
     return strategy
index 84af203..f6ea911 100644 (file)
@@ -21,6 +21,7 @@ from tvm import topi
 from .generic import *
 from .. import op as _op
 
+
 @conv2d_strategy.register("mali")
 def conv2d_strategy_mali(attrs, inputs, out_type, target):
     """conv2d mali strategy"""
@@ -40,24 +41,34 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target):
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.mali.conv2d_nchw_spatial_pack),
                     wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_spatial_pack),
-                    name="conv2d_nchw_spatial_pack.mali")
+                    name="conv2d_nchw_spatial_pack.mali",
+                )
                 # check if winograd algorithm is applicable
                 _, _, kh, kw = get_const_tuple(kernel.shape)
-                if kh == 3 and kw == 3 and stride_h == 1 and stride_w == 1 and \
-                   dilation_h == 1 and dilation_w == 1:
+                if (
+                    kh == 3
+                    and kw == 3
+                    and stride_h == 1
+                    and stride_w == 1
+                    and dilation_h == 1
+                    and dilation_w == 1
+                ):
                     strategy.add_implementation(
                         wrap_compute_conv2d(topi.mali.conv2d_nchw_winograd),
                         wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_winograd),
                         name="conv2d_nchw_winograd.mali",
-                        plevel=5)
+                        plevel=5,
+                    )
             elif re.match(r"OIHW\d*o", kernel_layout):
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.mali.conv2d_nchw_spatial_pack),
                     wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_spatial_pack),
-                    name="conv2d_nchw_spatial_pack.mali")
+                    name="conv2d_nchw_spatial_pack.mali",
+                )
             else:
-                raise RuntimeError("Unsupported weight layout {} for conv2d NCHW".
-                                   format(kernel_layout))
+                raise RuntimeError(
+                    "Unsupported weight layout {} for conv2d NCHW".format(kernel_layout)
+                )
         else:
             raise RuntimeError("Unsupported conv2d layout {} for mali".format(layout))
     elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
@@ -66,13 +77,15 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target):
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.mali.depthwise_conv2d_nchw),
                 wrap_topi_schedule(topi.mali.schedule_depthwise_conv2d_nchw),
-                name="depthwise_conv2d_nchw.mali")
+                name="depthwise_conv2d_nchw.mali",
+            )
         else:
             raise RuntimeError("Unsupported depthwise_conv2d layout {} for mali".format(layout))
-    else: # group_conv2d
+    else:  # group_conv2d
         raise RuntimeError("group_conv2d is not supported for mali")
     return strategy
 
+
 @conv2d_winograd_without_weight_transfrom_strategy.register("mali")
 def conv2d_winograd_without_weight_transfrom_strategy_mali(attrs, inputs, out_type, target):
     """conv2d_winograd_without_weight_transfrom mali strategy"""
@@ -90,17 +103,22 @@ def conv2d_winograd_without_weight_transfrom_strategy_mali(attrs, inputs, out_ty
         strategy.add_implementation(
             wrap_compute_conv2d(topi.mali.conv2d_nchw_winograd),
             wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_winograd),
-            name="conv2d_nchw_winograd.mali")
+            name="conv2d_nchw_winograd.mali",
+        )
     else:
-        raise RuntimeError("Unsupported conv2d_winograd_without_weight_transfrom layout {}".
-                           format(layout))
+        raise RuntimeError(
+            "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout)
+        )
     return strategy
 
+
 @dense_strategy.register("mali")
 def dense_strategy_mali(attrs, inputs, out_type, target):
     """dense mali strategy"""
     strategy = _op.OpStrategy()
-    strategy.add_implementation(wrap_compute_dense(topi.mali.dense),
-                                wrap_topi_schedule(topi.mali.schedule_dense),
-                                name="dense.mali")
+    strategy.add_implementation(
+        wrap_compute_dense(topi.mali.dense),
+        wrap_topi_schedule(topi.mali.schedule_dense),
+        name="dense.mali",
+    )
     return strategy
index 01cf621..2410260 100644 (file)
@@ -20,12 +20,14 @@ from tvm import topi
 from .generic import *
 from .. import op as _op
 
+
 @schedule_lrn.register("rocm")
 def schedule_lrn_rocm(attrs, outs, target):
     """schedule LRN for rocm"""
     with target:
         return topi.rocm.schedule_lrn(outs)
 
+
 @conv2d_strategy.register("rocm")
 def conv2d_strategy_rocm(attrs, inputs, out_type, target):
     """conv2d rocm strategy"""
@@ -47,21 +49,31 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.cuda.conv2d_nchw),
                 wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
-                name="conv2d_nchw.cuda")
+                name="conv2d_nchw.cuda",
+            )
             _, _, kh, kw = get_const_tuple(kernel.shape)
-            if 2 < kh < 8 and 2 < kw < 8 and kh == kw and stride_h == 1 and stride_w == 1 and \
-                dilation_h == 1 and dilation_w == 1:
+            if (
+                2 < kh < 8
+                and 2 < kw < 8
+                and kh == kw
+                and stride_h == 1
+                and stride_w == 1
+                and dilation_h == 1
+                and dilation_w == 1
+            ):
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd),
                     wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd),
                     name="conv2d_nchw_winograd.cuda",
-                    plevel=5)
+                    plevel=5,
+                )
         elif layout == "HWCN":
             assert kernel_layout == "HWIO"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
                 wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
-                name="conv2d_hwcn.cuda")
+                name="conv2d_hwcn.cuda",
+            )
         # TODO(@alexgl-github): Re-enable this after fixing conv2d_nhwc for cuda
         # elif layout == "NHWC":
         #     assert kernel_layout == "HWIO"
@@ -74,50 +86,61 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
                 wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
-                name="conv2d_NCHWc_int8.cuda")
+                name="conv2d_NCHWc_int8.cuda",
+            )
         else:
             raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
         # add miopen implementation
-        if "miopen" in target.libs and layout == "NCHW" and padding[0] == padding[2] and \
-            padding[1] == padding[3]:
+        if (
+            "miopen" in target.libs
+            and layout == "NCHW"
+            and padding[0] == padding[2]
+            and padding[1] == padding[3]
+        ):
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
                 wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
                 name="conv2d_nchw_miopen.rocm",
-                plevel=15)
+                plevel=15,
+            )
     elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
         if layout == "NCHW":
             assert kernel_layout == "OIHW"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
                 wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
-                name="depthwise_conv2d_nchw.cuda")
+                name="depthwise_conv2d_nchw.cuda",
+            )
         elif layout == "NHWC":
             assert kernel_layout == "HWOI"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
                 wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc),
-                name="depthwise_conv2d_nhwc.cuda")
+                name="depthwise_conv2d_nhwc.cuda",
+            )
         else:
             raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
-    else: # group_conv2d
-        if layout == 'NCHW':
+    else:  # group_conv2d
+        if layout == "NCHW":
             # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
             assert kernel_layout == "OIHW"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
                 wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw),
-                name="group_conv2d_nchw.cuda")
-        elif layout == 'NCHW4c' and data.dtype in ["int8", "uint8"]:
+                name="group_conv2d_nchw.cuda",
+            )
+        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
             assert kernel_layout == "OIHW4o4i"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
                 wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8),
-                name="group_conv2d_NCHWc_int8.cuda")
+                name="group_conv2d_NCHWc_int8.cuda",
+            )
         else:
             raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
     return strategy
 
+
 @dense_strategy.register("rocm")
 def dense_strategy_rocm(attrs, inputs, out_type, target):
     """Dense strategy for ROCM"""
@@ -126,12 +149,14 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_dense(topi.rocm.dense),
         wrap_topi_schedule(topi.rocm.schedule_dense),
-        name="dense.rocm")
+        name="dense.rocm",
+    )
     if target.kind.name == "rocm" and "rocblas" in target.libs:
         assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
         strategy.add_implementation(
             wrap_compute_dense(topi.rocm.dense_rocblas),
             wrap_topi_schedule(topi.rocm.schedule_dense_rocblas),
             name="dense_rocblas.rocm",
-            plevel=15)
+            plevel=15,
+        )
     return strategy
index d30b6a4..4d7d7e8 100644 (file)
@@ -24,41 +24,47 @@ from tvm.te import SpecializedCondition
 from .generic import *
 from .. import op as _op
 
-logger = logging.getLogger('strategy')
+logger = logging.getLogger("strategy")
 
 _NCHWc_matcher = re.compile("^NCHW[0-9]+c$")
 _OIHWio_matcher = re.compile("^OIHW[0-9]+i[0-9]+o$")
 
+
 @schedule_injective.register("cpu")
 def schedule_injective_cpu(attrs, outs, target):
     """schedule injective ops for x86"""
     with target:
         return topi.x86.schedule_injective(outs)
 
+
 @schedule_reduce.register("cpu")
 def schedule_reduce_cpu(attrs, outs, target):
     """schedule reduction ops for x86"""
     with target:
         return topi.x86.schedule_reduce(outs)
 
+
 @schedule_concatenate.register("cpu")
 def schedule_concatenate_cpu(attrs, outs, target):
     """schedule concatenate op for x86"""
     with target:
         return topi.x86.schedule_concatenate(outs)
 
+
 @schedule_pool.register("cpu")
 def schedule_pool_cpu(attrs, outs, target):
     """schedule pooling ops for x86"""
     with target:
         return topi.x86.schedule_pool(outs, attrs.layout)
 
+
 @schedule_adaptive_pool.register("cpu")
 def schedule_adaptive_pool_cpu(attrs, outs, target):
     """schedule adaptive pooling ops for x86"""
     with target:
         return topi.x86.schedule_adaptive_pool(outs)
 
+
 @softmax_strategy.register("cpu")
 def softmax_strategy_cpu(attrs, inputs, out_type, target):
     """softmax x86 strategy"""
@@ -66,15 +72,18 @@ def softmax_strategy_cpu(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_softmax(topi.nn.softmax),
         wrap_topi_schedule(topi.x86.schedule_softmax),
-        name="softmax.x86")
+        name="softmax.x86",
+    )
     return strategy
 
+
 @schedule_log_softmax.register("cpu")
 def schedule_log_softmax_cpu(attrs, outs, target):
     """schedule log_softmax op for x86"""
     with target:
         return topi.x86.schedule_softmax(outs)
 
+
 @conv2d_strategy.register("cpu")
 def conv2d_strategy_cpu(attrs, inputs, out_type, target):
     """conv2d x86 strategy"""
@@ -94,14 +103,16 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.x86.conv2d_nchw_int8),
                     wrap_topi_schedule(topi.x86.schedule_conv2d_nchw_int8),
-                    name="conv2d_nchw_int8.x86")
+                    name="conv2d_nchw_int8.x86",
+                )
             else:
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.x86.conv2d_nchw),
                     wrap_topi_schedule(topi.x86.schedule_conv2d_nchw),
-                    name="conv2d_nchw.x86")
-        elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc
-            assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio
+                    name="conv2d_nchw.x86",
+                )
+        elif _NCHWc_matcher.match(layout):  # check if layout is NCHWxc
+            assert _OIHWio_matcher.match(kernel_layout)  # check if kernel is OIHWio
             return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
         elif layout == "NHWC":
             assert kernel_layout == "HWIO"
@@ -109,14 +120,16 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.conv2d_nhwc),
                 wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc),
-                name="conv2d_nhwc.x86")
+                name="conv2d_nhwc.x86",
+            )
         elif layout == "HWCN":
             assert kernel_layout == "HWIO"
             logger.warning("conv2d HWCN layout is not optimized for x86.")
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.conv2d_hwcn),
                 wrap_topi_schedule(topi.generic.schedule_conv2d_hwcn),
-                name="conv2d_hwcn.generic")
+                name="conv2d_hwcn.generic",
+            )
         else:
             raise RuntimeError("Unsupported conv2d layout {} for x86".format(layout))
     elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
@@ -127,16 +140,20 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.x86.depthwise_conv2d_nchw),
                     wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_nchw),
-                    name="depthwise_conv2d_nchw.x86")
+                    name="depthwise_conv2d_nchw.x86",
+                )
             else:
-                logger.warning("For x86 target, depthwise_conv2d with channel "
-                               "multiplier greater than 1 is not optimized")
+                logger.warning(
+                    "For x86 target, depthwise_conv2d with channel "
+                    "multiplier greater than 1 is not optimized"
+                )
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
                     wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw),
-                    name="depthwise_conv2d_nchw.generic")
-        elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc
-            assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio
+                    name="depthwise_conv2d_nchw.generic",
+                )
+        elif _NCHWc_matcher.match(layout):  # check if layout is NCHWxc
+            assert _OIHWio_matcher.match(kernel_layout)  # check if kernel is OIHWio
             return depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
         elif layout == "NHWC":
             assert kernel_layout == "HWOI"
@@ -144,21 +161,24 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
                 wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc),
-                name="depthwise_conv2d_nhwc.generic")
+                name="depthwise_conv2d_nhwc.generic",
+            )
         else:
             raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
-    else: # group_conv2d
-        if layout == 'NCHW':
+    else:  # group_conv2d
+        if layout == "NCHW":
             assert kernel_layout == "OIHW"
             logger.warning("group_conv2d is not optimized for x86.")
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.group_conv2d_nchw, has_groups=True),
                 wrap_topi_schedule(topi.generic.schedule_group_conv2d_nchw),
-                name="group_conv2d_nchw.generic")
+                name="group_conv2d_nchw.generic",
+            )
         else:
             raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
     return strategy
 
+
 @conv2d_NCHWc_strategy.register("cpu")
 def conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target):
     """conv2d_NCHWc x86 strategy"""
@@ -168,14 +188,17 @@ def conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target):
         strategy.add_implementation(
             wrap_compute_conv2d(topi.x86.conv2d_NCHWc_int8, True, True),
             wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc_int8),
-            name="conv2d_NCHWc_int8.x86")
+            name="conv2d_NCHWc_int8.x86",
+        )
     else:
         strategy.add_implementation(
             wrap_compute_conv2d(topi.x86.conv2d_NCHWc, True, True),
             wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc),
-            name="conv2d_NCHWc.x86")
+            name="conv2d_NCHWc.x86",
+        )
     return strategy
 
+
 @depthwise_conv2d_NCHWc_strategy.register("cpu")
 def depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target):
     """depthwise_conv2d x86 strategy"""
@@ -183,9 +206,11 @@ def depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_conv2d(topi.x86.depthwise_conv2d_NCHWc, True, True),
         wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_NCHWc),
-        name="depthwise_conv2d_NCHWc.x86")
+        name="depthwise_conv2d_NCHWc.x86",
+    )
     return strategy
 
+
 @conv2d_transpose_strategy.register("cpu")
 def conv2d_transpose_strategy_cpu(attrs, inputs, out_type, target):
     """conv2d_transpose x86 strategy"""
@@ -199,7 +224,8 @@ def conv2d_transpose_strategy_cpu(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_conv2d_transpose(topi.x86.conv2d_transpose_nchw),
         wrap_topi_schedule(topi.x86.schedule_conv2d_transpose_nchw),
-        name="conv2d_transpose_nchw.x86")
+        name="conv2d_transpose_nchw.x86",
+    )
     return strategy
 
 
@@ -216,7 +242,8 @@ def conv3d_transpose_strategy_cpu(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_conv3d_transpose(topi.x86.conv3d_transpose_ncdhw),
         wrap_topi_schedule(topi.x86.schedule_conv3d_transpose_ncdhw),
-        name="conv3d_transpose_ncdhw.x86")
+        name="conv3d_transpose_ncdhw.x86",
+    )
     return strategy
 
 
@@ -226,17 +253,22 @@ def conv3d_strategy_cpu(attrs, inputs, out_type, target):
     strategy = _op.OpStrategy()
     layout = attrs.data_layout
     if layout == "NCDHW":
-        strategy.add_implementation(wrap_compute_conv3d(topi.x86.conv3d_ncdhw),
-                                    wrap_topi_schedule(topi.x86.schedule_conv3d_ncdhw),
-                                    name="conv3d_ncdhw.x86")
+        strategy.add_implementation(
+            wrap_compute_conv3d(topi.x86.conv3d_ncdhw),
+            wrap_topi_schedule(topi.x86.schedule_conv3d_ncdhw),
+            name="conv3d_ncdhw.x86",
+        )
     elif layout == "NDHWC":
-        strategy.add_implementation(wrap_compute_conv3d(topi.x86.conv3d_ndhwc),
-                                    wrap_topi_schedule(topi.x86.schedule_conv3d_ndhwc),
-                                    name="conv3d_ndhwc.x86")
+        strategy.add_implementation(
+            wrap_compute_conv3d(topi.x86.conv3d_ndhwc),
+            wrap_topi_schedule(topi.x86.schedule_conv3d_ndhwc),
+            name="conv3d_ndhwc.x86",
+        )
     else:
         raise ValueError("Not support this layout {} yet".format(layout))
     return strategy
 
+
 @conv1d_strategy.register("cpu")
 def conv1d_strategy_cpu(attrs, inputs, out_type, target):
     """conv1d x86 strategy"""
@@ -246,17 +278,22 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target):
         raise ValueError("dilation should be a positive value")
     strategy = _op.OpStrategy()
     if layout == "NCW":
-        strategy.add_implementation(wrap_compute_conv1d(topi.nn.conv1d_ncw),
-                                    wrap_topi_schedule(topi.x86.schedule_conv1d_ncw),
-                                    name="conv1d_ncw.x86")
+        strategy.add_implementation(
+            wrap_compute_conv1d(topi.nn.conv1d_ncw),
+            wrap_topi_schedule(topi.x86.schedule_conv1d_ncw),
+            name="conv1d_ncw.x86",
+        )
     elif layout == "NWC":
-        strategy.add_implementation(wrap_compute_conv1d(topi.nn.conv1d_nwc),
-                                    wrap_topi_schedule(topi.x86.schedule_conv1d_nwc),
-                                    name="conv1d_nwc.x86")
+        strategy.add_implementation(
+            wrap_compute_conv1d(topi.nn.conv1d_nwc),
+            wrap_topi_schedule(topi.x86.schedule_conv1d_nwc),
+            name="conv1d_nwc.x86",
+        )
     else:
         raise ValueError("Unsupported conv1d layout {}".format(layout))
     return strategy
 
+
 @dense_strategy.register("cpu")
 def dense_strategy_cpu(attrs, inputs, out_type, target):
     """dense x86 strategy"""
@@ -265,10 +302,12 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
     same_type = inputs[0].dtype == inputs[1].dtype == out_type.dtype
     dtype = inputs[0].dtype
     u8s8s32 = dtype == "uint8" and inputs[1].dtype == "int8" and out_type.dtype == "int32"
-    strategy.add_implementation(wrap_compute_dense(topi.x86.dense_nopack),
-                                wrap_topi_schedule(topi.x86.schedule_dense_nopack),
-                                name="dense_nopack.x86",
-                                plevel=10)
+    strategy.add_implementation(
+        wrap_compute_dense(topi.x86.dense_nopack),
+        wrap_topi_schedule(topi.x86.schedule_dense_nopack),
+        name="dense_nopack.x86",
+        plevel=10,
+    )
     if "cblas" in target.libs:
         with SpecializedCondition(same_type and dtype in ["float32", "float64"]):
             strategy.add_implementation(
@@ -295,35 +334,45 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
             )
     with SpecializedCondition(m >= 16):
         # this implementation may not be well-optimized, so use plevel=5 for now.
-        strategy.add_implementation(wrap_compute_dense(topi.x86.dense_pack),
-                                    wrap_topi_schedule(topi.x86.schedule_dense_pack),
-                                    name="dense_pack.x86",
-                                    plevel=5)
+        strategy.add_implementation(
+            wrap_compute_dense(topi.x86.dense_pack),
+            wrap_topi_schedule(topi.x86.schedule_dense_pack),
+            name="dense_pack.x86",
+            plevel=5,
+        )
     return strategy
 
+
 @batch_matmul_strategy.register("cpu")
 def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
     """batch_matmul x86 strategy"""
     strategy = _op.OpStrategy()
-    strategy.add_implementation(wrap_compute_batch_matmul(topi.x86.batch_matmul),
-                                wrap_topi_schedule(topi.x86.schedule_batch_matmul),
-                                name="batch_matmul.x86",
-                                plevel=10)
+    strategy.add_implementation(
+        wrap_compute_batch_matmul(topi.x86.batch_matmul),
+        wrap_topi_schedule(topi.x86.schedule_batch_matmul),
+        name="batch_matmul.x86",
+        plevel=10,
+    )
     if "cblas" in target.libs:
-        strategy.add_implementation(wrap_compute_batch_matmul(topi.x86.batch_matmul_cblas),
-                                    wrap_topi_schedule(topi.x86.schedule_batch_matmul_cblas),
-                                    name="batch_matmul_cblas.x86",
-                                    plevel=15)
+        strategy.add_implementation(
+            wrap_compute_batch_matmul(topi.x86.batch_matmul_cblas),
+            wrap_topi_schedule(topi.x86.schedule_batch_matmul_cblas),
+            name="batch_matmul_cblas.x86",
+            plevel=15,
+        )
     return strategy
 
+
 @sparse_dense_strategy.register("cpu")
 def sparse_dense_strategy_cpu(attrs, inputs, out_type, target):
     """sparse dense x86 strategy"""
     strategy = _op.OpStrategy()
-    strategy.add_implementation(wrap_compute_sparse_dense(topi.nn.sparse_dense),
-                                wrap_topi_schedule(topi.x86.schedule_sparse_dense),
-                                name="sparse_dense.x86",
-                                plevel=10)
+    strategy.add_implementation(
+        wrap_compute_sparse_dense(topi.nn.sparse_dense),
+        wrap_topi_schedule(topi.x86.schedule_sparse_dense),
+        name="sparse_dense.x86",
+        plevel=10,
+    )
     return strategy
 
 
@@ -331,11 +380,14 @@ def sparse_dense_strategy_cpu(attrs, inputs, out_type, target):
 def roi_align_strategy_cpu(attrs, inputs, out_type, target):
     """roi_align x86 strategy"""
     strategy = _op.OpStrategy()
-    strategy.add_implementation(wrap_compute_roi_align(topi.x86.roi_align_nchw),
-                                wrap_topi_schedule(topi.generic.schedule_roi_align),
-                                name="roi_align.x86")
+    strategy.add_implementation(
+        wrap_compute_roi_align(topi.x86.roi_align_nchw),
+        wrap_topi_schedule(topi.generic.schedule_roi_align),
+        name="roi_align.x86",
+    )
     return strategy
 
+
 @bitserial_conv2d_strategy.register("cpu")
 def bitserial_conv2d_strategy_cpu(attrs, inputs, out_type, target):
     """bitserial_conv2d x86 strategy"""
@@ -345,16 +397,19 @@ def bitserial_conv2d_strategy_cpu(attrs, inputs, out_type, target):
         strategy.add_implementation(
             wrap_compute_bitserial_conv2d(topi.x86.bitserial_conv2d_nchw),
             wrap_topi_schedule(topi.x86.schedule_bitserial_conv2d_nchw),
-            name="bitserial_conv2d_nchw.x86")
+            name="bitserial_conv2d_nchw.x86",
+        )
     elif layout == "NHWC":
         strategy.add_implementation(
             wrap_compute_bitserial_conv2d(topi.x86.bitserial_conv2d_nhwc),
             wrap_topi_schedule(topi.x86.schedule_bitserial_conv2d_nhwc),
-            name="bitserial_conv2d_nhwc.x86")
+            name="bitserial_conv2d_nhwc.x86",
+        )
     else:
         raise ValueError("Data layout {} not supported.".format(layout))
     return strategy
 
+
 @bitserial_dense_strategy.register("cpu")
 def bitserial_dense_strategy_cpu(attrs, inputs, out_type, target):
     """bitserial_dense x86 strategy"""
@@ -362,5 +417,6 @@ def bitserial_dense_strategy_cpu(attrs, inputs, out_type, target):
     strategy.add_implementation(
         wrap_compute_bitserial_dense(topi.x86.bitserial_dense),
         wrap_topi_schedule(topi.x86.schedule_bitserial_dense),
-        name="bitserial_dense.x86")
+        name="bitserial_dense.x86",
+    )
     return strategy
index c002c8b..832372a 100644 (file)
@@ -33,6 +33,7 @@ from ..expr import Tuple, Expr
 # - Not put too much burden on FFI to support complicated features
 #   like default value and keyword arguments
 
+
 def log(data):
     """Compute elementwise log of data.
 
@@ -48,6 +49,7 @@ def log(data):
     """
     return _make.log(data)
 
+
 def log2(data):
     """Compute elementwise log to the base 2 of data.
 
@@ -63,6 +65,7 @@ def log2(data):
     """
     return _make.log2(data)
 
+
 def log10(data):
     """Compute elementwise log to the base 10 of data.
 
@@ -78,6 +81,7 @@ def log10(data):
     """
     return _make.log10(data)
 
+
 def tan(data):
     """Compute elementwise tan of data.
 
@@ -93,6 +97,7 @@ def tan(data):
     """
     return _make.tan(data)
 
+
 def cos(data):
     """Compute elementwise cos of data.
 
@@ -108,6 +113,7 @@ def cos(data):
     """
     return _make.cos(data)
 
+
 def cosh(data):
     """Compute elementwise cosh of data.
 
@@ -123,6 +129,7 @@ def cosh(data):
     """
     return _make.cosh(data)
 
+
 def sin(data):
     """Compute elementwise sin of data.
 
@@ -138,6 +145,7 @@ def sin(data):
     """
     return _make.sin(data)
 
+
 def sinh(data):
     """Compute elementwise sinh of data.
 
@@ -153,6 +161,7 @@ def sinh(data):
     """
     return _make.sinh(data)
 
+
 def acos(data):
     """Compute elementwise acos of data.
 
@@ -168,6 +177,7 @@ def acos(data):
     """
     return _make.acos(data)
 
+
 def acosh(data):
     """Compute elementwise acosh of data.
 
@@ -183,6 +193,7 @@ def acosh(data):
     """
     return _make.acosh(data)
 
+
 def asin(data):
     """Compute elementwise asin of data.
 
@@ -198,6 +209,7 @@ def asin(data):
     """
     return _make.asin(data)
 
+
 def asinh(data):
     """Compute elementwise asinh of data.
 
@@ -213,6 +225,7 @@ def asinh(data):
     """
     return _make.asinh(data)
 
+
 def atan(data):
     """Compute elementwise atan of data.
 
@@ -228,6 +241,7 @@ def atan(data):
     """
     return _make.atan(data)
 
+
 def atanh(data):
     """Compute elementwise atanh of data.
 
@@ -243,6 +257,7 @@ def atanh(data):
     """
     return _make.atanh(data)
 
+
 def exp(data):
     """Compute elementwise exp of data.
 
@@ -406,6 +421,7 @@ def abs(data):
     """
     return _make.abs(data)
 
+
 def sign(data):
     """Compute element-wise absolute of data.
 
@@ -421,6 +437,7 @@ def sign(data):
     """
     return _make.sign(data)
 
+
 def tanh(data):
     """Compute element-wise tanh of data.
 
@@ -690,6 +707,7 @@ def logical_xor(lhs, rhs):
     """
     return _make.logical_xor(lhs, rhs)
 
+
 def bitwise_and(lhs, rhs):
     """bitwise AND with numpy-style broadcasting.
 
@@ -1034,6 +1052,7 @@ def clip(a, a_min, a_max):
     """
     return _make.clip(a, a_min, a_max)
 
+
 def fixed_point_multiply(data, multiplier, shift):
     """Fixed point multiplication between data and a fixed point
     constant expressed as multiplier * 2^(-shift), where multiplier
@@ -1145,16 +1164,20 @@ def device_copy(data, src_dev, dst_dev):
     elif isinstance(src_dev, str):
         src_dev = _nd.context(src_dev).device_type
     else:
-        raise ValueError("src_dev is expected to be the type of TVMContext or "
-                         "str, but received %s" % (type(src_dev)))
+        raise ValueError(
+            "src_dev is expected to be the type of TVMContext or "
+            "str, but received %s" % (type(src_dev))
+        )
 
     if isinstance(dst_dev, _TVMContext):
         dst_dev = dst_dev.device_type
     elif isinstance(dst_dev, str):
         dst_dev = _nd.context(dst_dev).device_type
     else:
-        raise ValueError("dst_dev is expected to be the type of TVMContext or "
-                         "str, but received %s" % (type(dst_dev)))
+        raise ValueError(
+            "dst_dev is expected to be the type of TVMContext or "
+            "str, but received %s" % (type(dst_dev))
+        )
     return _make.device_copy(data, src_dev, dst_dev)
 
 
index 0ce59ad..8ccd148 100644 (file)
@@ -42,6 +42,7 @@ def cast(data, dtype):
         The casted result.
     """
     from .. import _ffi_api as _relay_make
+
     return _relay_make.cast(data, dtype)
 
 
@@ -59,6 +60,7 @@ def cast_like(data, dtype_like):
         The casted result.
     """
     from .. import _ffi_api as _relay_make
+
     return _relay_make.cast_like(data, dtype_like)
 
 
@@ -79,6 +81,7 @@ def reinterpret(data, dtype):
         The reinterpreted result.
     """
     from .. import _make as _relay_make
+
     return _relay_make.reinterpret(data, dtype)
 
 
@@ -226,7 +229,7 @@ def reshape(data, newshape):
                 try:
                     tempshape.append(int(shape))
                 except ValueError as err:
-                    raise RuntimeError('Unrecognized shape type: %s' % err)
+                    raise RuntimeError("Unrecognized shape type: %s" % err)
         newshape = tempshape
     return _make.reshape(data, list(newshape))
 
@@ -829,15 +832,16 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
         The computed result.
     """
     strides = strides or [1]
-    if (isinstance(begin, Expr) or isinstance(end, Expr) or isinstance(strides, Expr)):
+    if isinstance(begin, Expr) or isinstance(end, Expr) or isinstance(strides, Expr):
         if isinstance(begin, (tuple, list)):
             begin = const(list(begin))
         if isinstance(end, (tuple, list)):
             end = const(list(end))
         if isinstance(strides, (tuple, list)):
             strides = const(list(strides))
-        normalized_begin = _make.where(begin < cast_like(const(0), begin),
-                                       begin + cast_like(shape_of(data), begin), begin)
+        normalized_begin = _make.where(
+            begin < cast_like(const(0), begin), begin + cast_like(shape_of(data), begin), begin
+        )
         return _dyn_make.strided_slice(data, normalized_begin, end, strides, slice_mode)
     return _make.strided_slice(data, begin, end, strides, slice_mode)
 
index 6f5097d..d20cb97 100644 (file)
@@ -31,9 +31,15 @@ reg.register_pattern("vision.roi_align", OpPattern.OUT_ELEMWISE_FUSABLE)
 def compute_roi_pool(attrs, inputs, _):
     """Compute definition of roi_pool"""
     assert attrs.layout == "NCHW"
-    return [topi.vision.rcnn.roi_pool_nchw(
-        inputs[0], inputs[1], pooled_size=get_const_tuple(attrs.pooled_size),
-        spatial_scale=attrs.spatial_scale)]
+    return [
+        topi.vision.rcnn.roi_pool_nchw(
+            inputs[0],
+            inputs[1],
+            pooled_size=get_const_tuple(attrs.pooled_size),
+            spatial_scale=attrs.spatial_scale,
+        )
+    ]
+
 
 reg.register_schedule("vision.roi_pool", strategy.schedule_roi_pool)
 reg.register_pattern("vision.roi_pool", OpPattern.OUT_ELEMWISE_FUSABLE)
index c94cb5a..85bd8a2 100644 (file)
@@ -43,6 +43,7 @@ reg.register_pattern("vision.get_valid_counts", OpPattern.OPAQUE)
 reg.register_strategy("vision.non_max_suppression", strategy.nms_strategy)
 reg.register_pattern("vision.non_max_suppression", OpPattern.OPAQUE)
 
+
 @script
 def _get_valid_counts_shape_func(data_shape):
     valid_counts_shape = output_tensor((1,), "int64")
@@ -57,10 +58,12 @@ def _get_valid_counts_shape_func(data_shape):
 
     return valid_counts_shape, out_tensor_shape, out_indices_shape
 
+
 @reg.register_shape_func("vision.get_valid_counts", False)
 def get_valid_counts_shape_func(attrs, inputs, _):
     return _get_valid_counts_shape_func(inputs[0])
 
+
 @script
 def _nms_shape_func(data_shape):
     out_shape = output_tensor((2,), "int64")
@@ -72,6 +75,7 @@ def _nms_shape_func(data_shape):
     count_shape[1] = int64(1)
     return out_shape, count_shape
 
+
 @reg.register_shape_func("vision.non_max_suppression", False)
 def nms_shape_func(attrs, inputs, _):
     if attrs.return_indices:
index c58a7a3..3c43cb2 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=invalid-name, unused-argument
+# pylint: disable=invalid-name, unused-argument
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
 from ..op import register_pattern, OpPattern
index 75daff9..c97a8c7 100644 (file)
@@ -19,12 +19,9 @@ from tvm.relay import expr
 from . import _make
 
 
-def multibox_prior(data,
-                   sizes=(1.0,),
-                   ratios=(1.0,),
-                   steps=(-1.0, -1.0),
-                   offsets=(0.5, 0.5),
-                   clip=False):
+def multibox_prior(
+    data, sizes=(1.0,), ratios=(1.0,), steps=(-1.0, -1.0), offsets=(0.5, 0.5), clip=False
+):
     """Generate prior(anchor) boxes from data, sizes and ratios.
 
     Parameters
@@ -55,12 +52,9 @@ def multibox_prior(data,
     return _make.multibox_prior(data, sizes, ratios, steps, offsets, clip)
 
 
-def multibox_transform_loc(cls_prob,
-                           loc_pred,
-                           anchor,
-                           clip=True,
-                           threshold=0.01,
-                           variances=(0.1, 0.1, 0.2, 0.2)):
+def multibox_transform_loc(
+    cls_prob, loc_pred, anchor, clip=True, threshold=0.01, variances=(0.1, 0.1, 0.2, 0.2)
+):
     """Location transformation for multibox detection
 
     Parameters
@@ -88,6 +82,5 @@ def multibox_transform_loc(cls_prob,
     ret : tuple of tvm.relay.Expr
     """
     return expr.TupleWrapper(
-        _make.multibox_transform_loc(cls_prob, loc_pred,
-                                     anchor, clip, threshold,
-                                     variances), 2)
+        _make.multibox_transform_loc(cls_prob, loc_pred, anchor, clip, threshold, variances), 2
+    )
index 60ff7a5..4366609 100644 (file)
@@ -19,10 +19,7 @@ from tvm.relay import expr
 from . import _make
 
 
-def get_valid_counts(data,
-                     score_threshold,
-                     id_index=0,
-                     score_index=1):
+def get_valid_counts(data, score_threshold, id_index=0, score_index=1):
     """Get valid count of bounding boxes given a score threshold.
     Also moves valid boxes to the top of input data.
 
@@ -52,22 +49,24 @@ def get_valid_counts(data,
         Indices in input data
     """
     return expr.TupleWrapper(
-        _make.get_valid_counts(data, score_threshold,
-                               id_index, score_index), 3)
-
-
-def non_max_suppression(data,
-                        valid_count,
-                        indices,
-                        max_output_size=-1,
-                        iou_threshold=0.5,
-                        force_suppress=False,
-                        top_k=-1,
-                        coord_start=2,
-                        score_index=1,
-                        id_index=0,
-                        return_indices=True,
-                        invalid_to_bottom=False):
+        _make.get_valid_counts(data, score_threshold, id_index, score_index), 3
+    )
+
+
+def non_max_suppression(
+    data,
+    valid_count,
+    indices,
+    max_output_size=-1,
+    iou_threshold=0.5,
+    force_suppress=False,
+    top_k=-1,
+    coord_start=2,
+    score_index=1,
+    id_index=0,
+    return_indices=True,
+    invalid_to_bottom=False,
+):
     """Non-maximum suppression operator for object detection.
 
     Parameters
@@ -129,18 +128,20 @@ def non_max_suppression(data,
     """
     if isinstance(max_output_size, int):
         max_output_size = expr.const(max_output_size, "int32")
-    out = _make.non_max_suppression(data,
-                                    valid_count,
-                                    indices,
-                                    max_output_size,
-                                    iou_threshold,
-                                    force_suppress,
-                                    top_k,
-                                    coord_start,
-                                    score_index,
-                                    id_index,
-                                    return_indices,
-                                    invalid_to_bottom)
+    out = _make.non_max_suppression(
+        data,
+        valid_count,
+        indices,
+        max_output_size,
+        iou_threshold,
+        force_suppress,
+        top_k,
+        coord_start,
+        score_index,
+        id_index,
+        return_indices,
+        invalid_to_bottom,
+    )
     if return_indices:
         return expr.TupleWrapper(out, 2)
     return out
index 1798ae9..b87eb07 100644 (file)
@@ -18,7 +18,7 @@
 from . import _make
 
 
-def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout='NCHW'):
+def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout="NCHW"):
     """ROI align operator.
 
     Parameters
@@ -48,7 +48,7 @@ def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout='N
     return _make.roi_align(data, rois, pooled_size, spatial_scale, sample_ratio, layout)
 
 
-def roi_pool(data, rois, pooled_size, spatial_scale, layout='NCHW'):
+def roi_pool(data, rois, pooled_size, spatial_scale, layout="NCHW"):
     """ROI pool operator.
 
     Parameters
@@ -75,17 +75,19 @@ def roi_pool(data, rois, pooled_size, spatial_scale, layout='NCHW'):
     return _make.roi_pool(data, rois, pooled_size, spatial_scale, layout)
 
 
-def proposal(cls_prob,
-             bbox_pred,
-             im_info,
-             scales,
-             ratios,
-             feature_stride,
-             threshold,
-             rpn_pre_nms_top_n,
-             rpn_post_nms_top_n,
-             rpn_min_size,
-             iou_loss):
+def proposal(
+    cls_prob,
+    bbox_pred,
+    im_info,
+    scales,
+    ratios,
+    feature_stride,
+    threshold,
+    rpn_pre_nms_top_n,
+    rpn_post_nms_top_n,
+    rpn_min_size,
+    iou_loss,
+):
     """Proposal operator.
 
     Parameters
@@ -131,5 +133,16 @@ def proposal(cls_prob,
         2-D tensor with shape [batch * rpn_post_nms_top_n, 5]. The last dimension is in format of
         [batch_index, w_start, h_start, w_end, h_end].
     """
-    return _make.proposal(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, threshold,
-                          rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_min_size, iou_loss)
+    return _make.proposal(
+        cls_prob,
+        bbox_pred,
+        im_info,
+        scales,
+        ratios,
+        feature_stride,
+        threshold,
+        rpn_pre_nms_top_n,
+        rpn_post_nms_top_n,
+        rpn_min_size,
+        iou_loss,
+    )
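
A small sketch (not part of the commit) of how the ROI operators touched in this file are typically wired up; the feature-map and ROI shapes below are illustrative assumptions.

    import tvm
    from tvm import relay

    # 1x64x14x14 feature map and 32 ROIs in (batch_index, w_start, h_start, w_end, h_end) form.
    feat = relay.var("feat", shape=(1, 64, 14, 14), dtype="float32")
    rois = relay.var("rois", shape=(32, 5), dtype="float32")
    pooled = relay.vision.roi_align(
        feat, rois, pooled_size=(7, 7), spatial_scale=1.0 / 16.0, sample_ratio=2, layout="NCHW"
    )
    mod = tvm.IRModule.from_expr(relay.Function([feat, rois], pooled))
    print(mod)
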
index 90dc3b8..f556d74 100644
@@ -17,6 +17,7 @@
 """Yolo operations."""
 from . import _make
 
+
 def yolo_reorg(data, stride):
     """Yolo reorg operation used in darknet models.
     This layer shuffles the input tensor values based on the stride value.
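
For reference, an illustrative-only sketch of calling the yolo_reorg wrapper reformatted above (not part of the commit):

    import tvm
    from tvm import relay

    # With stride=2, a (1, 4, 6, 6) input becomes (1, 16, 3, 3): channels grow by
    # stride*stride while the spatial dimensions shrink by stride.
    x = relay.var("x", shape=(1, 4, 6, 6), dtype="float32")
    y = relay.vision.yolo_reorg(x, stride=2)
    print(tvm.IRModule.from_expr(relay.Function([x], y)))
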
index b7fee8c..2a01353 100644
@@ -23,6 +23,7 @@ import tvm._ffi
 _save_param_dict = tvm._ffi.get_global_func("tvm.relay._save_param_dict")
 _load_param_dict = tvm._ffi.get_global_func("tvm.relay._load_param_dict")
 
+
 def save_param_dict(params):
     """Save parameter dictionary to binary bytes.
 
@@ -75,4 +76,4 @@ def load_param_dict(param_bytes):
     if isinstance(param_bytes, (bytes, str)):
         param_bytes = bytearray(param_bytes)
     load_arr = _load_param_dict(param_bytes)
-    return {v.name : v.array for v in load_arr}
+    return {v.name: v.array for v in load_arr}
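
A quick round-trip sketch for the two helpers in this file (not part of the commit); the parameter name and shape are made up for illustration.

    import numpy as np
    import tvm
    from tvm import relay

    params = {"weight": tvm.nd.array(np.ones((2, 2), dtype="float32"))}
    blob = relay.save_param_dict(params)          # serialize the dict to binary bytes
    restored = relay.load_param_dict(bytearray(blob))
    print(restored["weight"].asnumpy())
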
index 893c855..81d82dd 100644
@@ -28,6 +28,7 @@ from .adt import PatternConstructor, PatternVar, PatternWildcard
 from . import op, transform
 from .analysis import free_vars
 
+
 def get_tensor_array_shape(expr, dtype, prelude):
     """Get the static shape of a tensor array if it has fixed rank shape.
 
@@ -59,8 +60,7 @@ def get_tensor_array_shape(expr, dtype, prelude):
     ta_type_str = checked_type.args[0].func.name_hint
     static_ta_ty_start = "static_tensor_{}".format(dtype)
     if ta_type_str.startswith(static_ta_ty_start):
-        shape_str = ta_type_str.replace("{}_".format(static_ta_ty_start), '') \
-            .replace("_t", '')
+        shape_str = ta_type_str.replace("{}_".format(static_ta_ty_start), "").replace("_t", "")
         shape = []
         if "scalar" not in shape_str:
             for dim_str in shape_str.split("_"):
@@ -71,16 +71,18 @@ def get_tensor_array_shape(expr, dtype, prelude):
         return tuple(shape)
     return None
 
+
 def _get_name_static(canonical, dtype, shape):
     """Get name for static shape tensor array op corresponding
     to the canonical name"""
-    shape_str = '_'.join([str(dim) for dim in shape])
+    shape_str = "_".join([str(dim) for dim in shape])
     if len(shape_str) == 0:
         shape_str = "scalar"
-    if canonical == 'tensor_t':
-        return 'static_tensor_{}_{}_t'.format(dtype, shape_str)
+    if canonical == "tensor_t":
+        return "static_tensor_{}_{}_t".format(dtype, shape_str)
     return "{}_{}_{}".format(canonical, dtype, shape_str)
 
+
 class StaticTensorArrayOps(object):
     """Contains tensor array related ops for fixed rank tensor array"""
 
@@ -102,7 +104,7 @@ class StaticTensorArrayOps(object):
     def define_tensor_adt(self):
         """Defines the static tensor ADT, which is the container for tensors
         with fixed shapes."""
-        tensor_type_name = self.get_name('tensor_t')
+        tensor_type_name = self.get_name("tensor_t")
         # Skip register if tensor type is already registered.
         global_type_names = set()
         for g_ty_var in self.prelude.mod.get_global_type_vars():
@@ -113,17 +115,17 @@ class StaticTensorArrayOps(object):
         tensor_type_var = GlobalTypeVar(tensor_type_name)
         setattr(self.prelude, tensor_type_name, tensor_type_var)
         tensor_type = TensorType(self.shape, self.dtype)
-        tensor_constructor_name = self.get_name('tensor_constructor')
+        tensor_constructor_name = self.get_name("tensor_constructor")
 
-        tensor_nil_name = self.get_name('tensor_nil')
+        tensor_nil_name = self.get_name("tensor_nil")
         tensor_nil_case = Constructor(tensor_nil_name, [], tensor_type_var)
         tensor_case = Constructor(tensor_constructor_name, [tensor_type], tensor_type_var)
 
         setattr(self.prelude, tensor_nil_name, tensor_nil_case)
         setattr(self.prelude, tensor_constructor_name, tensor_case)
-        self.prelude.mod[tensor_type_var] = TypeData(tensor_type_var,
-                                                     [],
-                                                     [tensor_nil_case, tensor_case])
+        self.prelude.mod[tensor_type_var] = TypeData(
+            tensor_type_var, [], [tensor_nil_case, tensor_case]
+        )
 
     def define_tensor_array(self):
         """Defines a function to create a tensor array with size n.
@@ -132,20 +134,24 @@ class StaticTensorArrayOps(object):
         tensor_array_constructor_name = self.get_name("tensor_array")
         tensor_array_constructor_var = self._create_global_var(tensor_array_constructor_name)
         setattr(self.prelude, tensor_array_constructor_name, tensor_array_constructor_var)
-        tensor_nil_var = self.get_var('tensor_nil')
-        tensor_type_var = self.get_var('tensor_t')
-        n = Var("x", scalar_type('int32'))
-        body = If(equal(n, const(0)),
-                  self.prelude.nil(),
-                  self.prelude.cons(tensor_nil_var(),
-                                    tensor_array_constructor_var(subtract(n, const(1)))))
-        self.prelude.mod[tensor_array_constructor_var] = \
-            Function([n], body, self.prelude.l(tensor_type_var()), [])
+        tensor_nil_var = self.get_var("tensor_nil")
+        tensor_type_var = self.get_var("tensor_t")
+        n = Var("x", scalar_type("int32"))
+        body = If(
+            equal(n, const(0)),
+            self.prelude.nil(),
+            self.prelude.cons(
+                tensor_nil_var(), tensor_array_constructor_var(subtract(n, const(1)))
+            ),
+        )
+        self.prelude.mod[tensor_array_constructor_var] = Function(
+            [n], body, self.prelude.l(tensor_type_var()), []
+        )
 
     def define_tensor_take(self):
         """Defines a function to return a range of tensor_t on axis 0.
-            tensor_take(t, lower, upper) :
-            tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t
+        tensor_take(t, lower, upper) :
+        tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t
         """
         # We don't register take for scalar tensor.
         ndim = len(self.shape)
@@ -155,29 +161,30 @@ class StaticTensorArrayOps(object):
         take_name = self.get_name("tensor_take")
         take_var = self._create_global_var(take_name)
         setattr(self.prelude, take_name, take_var)
-        origin_tensor_constructor = self.get_var('tensor_constructor')
-
-        output_shape = [Any(),] + list(self.shape[1:])
-        tensor_type_var, tensor_constructor = \
-            self._get_adt_by_shape(output_shape)
-
-        t = Var('tensor', self.get_var('tensor_t')())
-        lower = Var('lower', scalar_type('int32'))
-        upper = Var('upper', scalar_type('int32'))
-        tvar = Var('t')
-        case = Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(tvar)]),
-                      tensor_constructor(op.take(tvar,
-                                                 op.arange(lower, upper, dtype='int32'),
-                                                 axis=0)))
-        self.prelude.mod[take_var] = \
-            Function([t, lower, upper],
-                     Match(t, [case], False), tensor_type_var(), [])
+        origin_tensor_constructor = self.get_var("tensor_constructor")
+
+        output_shape = [
+            Any(),
+        ] + list(self.shape[1:])
+        tensor_type_var, tensor_constructor = self._get_adt_by_shape(output_shape)
+
+        t = Var("tensor", self.get_var("tensor_t")())
+        lower = Var("lower", scalar_type("int32"))
+        upper = Var("upper", scalar_type("int32"))
+        tvar = Var("t")
+        case = Clause(
+            PatternConstructor(origin_tensor_constructor, [PatternVar(tvar)]),
+            tensor_constructor(op.take(tvar, op.arange(lower, upper, dtype="int32"), axis=0)),
+        )
+        self.prelude.mod[take_var] = Function(
+            [t, lower, upper], Match(t, [case], False), tensor_type_var(), []
+        )
 
     def define_tensor_concatenate(self):
         """Defines a function to concatenate two tensor_t on axis 0.
         tensor_concatenate(t) : tensor_t -> tensor_t -> tensor_t
         """
-         # We don't register concatenate for scalar tensor.
+        # We don't register concatenate for scalar tensor.
         ndim = len(self.shape)
         if ndim == 0:
             return
@@ -185,25 +192,35 @@ class StaticTensorArrayOps(object):
         concat_name = self.get_name("tensor_concatenate")
         concat_var = self._create_global_var(concat_name)
         setattr(self.prelude, concat_name, concat_var)
-        output_shape = [Any(),] + list(self.shape[1:])
-        tensor_type_var, tensor_constructor = \
-            self._get_adt_by_shape(output_shape)
+        output_shape = [
+            Any(),
+        ] + list(self.shape[1:])
+        tensor_type_var, tensor_constructor = self._get_adt_by_shape(output_shape)
 
-        origin_tensor_constructor = self.get_var('tensor_constructor')
-        origin_tensor_type_var = self.get_var('tensor_t')
+        origin_tensor_constructor = self.get_var("tensor_constructor")
+        origin_tensor_type_var = self.get_var("tensor_t")
         x = Var("x", origin_tensor_type_var())
         y = Var("y", origin_tensor_type_var())
         t1 = Var("t1")
         t2 = Var("t2")
 
-        case = Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(t1)]),
-                      Match(y,
-                            [Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(t2)]),
-                                    tensor_constructor(op.concatenate([t1, t2], axis=0)))],
-                            False))
-
-        self.prelude.mod[concat_var] = \
-            Function([x, y], Match(x, [case], False), tensor_type_var(), [])
+        case = Clause(
+            PatternConstructor(origin_tensor_constructor, [PatternVar(t1)]),
+            Match(
+                y,
+                [
+                    Clause(
+                        PatternConstructor(origin_tensor_constructor, [PatternVar(t2)]),
+                        tensor_constructor(op.concatenate([t1, t2], axis=0)),
+                    )
+                ],
+                False,
+            ),
+        )
+
+        self.prelude.mod[concat_var] = Function(
+            [x, y], Match(x, [case], False), tensor_type_var(), []
+        )
 
     def define_tensor_expand_dims(self):
         """Defines a function to grow a tensor_t's rank by adding one dimension in front
@@ -213,20 +230,27 @@ class StaticTensorArrayOps(object):
         expand_dims_name = self.get_name("tensor_expand_dims")
         expand_dims_var = self._create_global_var(expand_dims_name)
         setattr(self.prelude, expand_dims_name, expand_dims_var)
-        origin_tensor_type_var = self.get_var('tensor_t')
-        origin_tensor_constructor = self.get_var('tensor_constructor')
+        origin_tensor_type_var = self.get_var("tensor_t")
+        origin_tensor_constructor = self.get_var("tensor_constructor")
         x = Var("x", origin_tensor_type_var())
 
         # Note: we set the added axis to be Any() instead of 1 due to
         # in stack op, we need to recursively concatenate.
-        tensor_type_var, tensor_constructor = \
-            self._get_adt_by_shape([Any(),] + list(self.shape))
+        tensor_type_var, tensor_constructor = self._get_adt_by_shape(
+            [
+                Any(),
+            ]
+            + list(self.shape)
+        )
         t = Var("t")
-        case = Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(t)]),
-                      tensor_constructor(op.expand_dims(t, 0, 1)))
+        case = Clause(
+            PatternConstructor(origin_tensor_constructor, [PatternVar(t)]),
+            tensor_constructor(op.expand_dims(t, 0, 1)),
+        )
 
-        self.prelude.mod[expand_dims_var] = \
-            Function([x], Match(x, [case], False), tensor_type_var(), [])
+        self.prelude.mod[expand_dims_var] = Function(
+            [x], Match(x, [case], False), tensor_type_var(), []
+        )
 
     def define_tensor_array_read(self):
         """Defines a function to get the nth element of a list. Assume the list has at least one
@@ -237,12 +261,13 @@ class StaticTensorArrayOps(object):
         read_name = self.get_name("tensor_array_read")
         read_var = self._create_global_var(read_name)
         setattr(self.prelude, read_name, read_var)
-        tensor_type_var = self.get_var('tensor_t')
+        tensor_type_var = self.get_var("tensor_t")
 
         tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
-        n = Var("x", scalar_type('int32'))
-        self.prelude.mod[read_var] = \
-            Function([tensor_array, n], self.prelude.nth(tensor_array, n), tensor_type_var(), [])
+        n = Var("x", scalar_type("int32"))
+        self.prelude.mod[read_var] = Function(
+            [tensor_array, n], self.prelude.nth(tensor_array, n), tensor_type_var(), []
+        )
 
     def define_tensor_array_write(self):
         """Defines a function to update a tensor array at index n with value v.
@@ -253,13 +278,16 @@ class StaticTensorArrayOps(object):
         write_name = self.get_name("tensor_array_write")
         write_var = self._create_global_var(write_name)
         setattr(self.prelude, write_name, write_var)
-        tensor_type_var = self.get_var('tensor_t')
+        tensor_type_var = self.get_var("tensor_t")
         tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
-        n = Var("x", scalar_type('int32'))
+        n = Var("x", scalar_type("int32"))
         v = Var("v", tensor_type_var())
-        self.prelude.mod[write_var] = \
-            Function([tensor_array, n, v], self.prelude.update(tensor_array, n, v),
-                     self.prelude.l(tensor_type_var()), [])
+        self.prelude.mod[write_var] = Function(
+            [tensor_array, n, v],
+            self.prelude.update(tensor_array, n, v),
+            self.prelude.l(tensor_type_var()),
+            [],
+        )
 
     def define_tensor_array_unstack(self):
         """Defines a function to unstack the values of a tensor_t in a tensor array.
@@ -274,28 +302,34 @@ class StaticTensorArrayOps(object):
         helper_var = self._create_global_var(helper_name)
         setattr(self.prelude, helper_name, helper_var)
         tensor = Var("t", TensorType(self.shape, self.dtype))
-        up = Var("up", scalar_type('int32'))
-        i = Var("i", scalar_type('int32'))
+        up = Var("up", scalar_type("int32"))
+        i = Var("i", scalar_type("int32"))
         tensor_var = Var("tensor", TensorType(self.shape, self.dtype))
 
-        reduced_tensor_type_var, tensor_constructor = \
-            self._get_adt_by_shape(self.shape[1:])
-        helper_body = \
-            If(equal(i, up),
-               self.prelude.nil(),
-               self.prelude.cons(tensor_constructor(op.take(tensor, i, axis=0)),
-                                 helper_var(add(i, const(1)), up, tensor)))
-        self.prelude.mod[helper_var] = \
-            Function([i, up, tensor], helper_body, self.prelude.l(reduced_tensor_type_var()), [])
+        reduced_tensor_type_var, tensor_constructor = self._get_adt_by_shape(self.shape[1:])
+        helper_body = If(
+            equal(i, up),
+            self.prelude.nil(),
+            self.prelude.cons(
+                tensor_constructor(op.take(tensor, i, axis=0)),
+                helper_var(add(i, const(1)), up, tensor),
+            ),
+        )
+        self.prelude.mod[helper_var] = Function(
+            [i, up, tensor], helper_body, self.prelude.l(reduced_tensor_type_var()), []
+        )
 
         unstack_name = self.get_name("tensor_array_unstack")
         unstack_var = self._create_global_var(unstack_name)
         setattr(self.prelude, unstack_name, unstack_var)
         shape = op.shape_of(tensor_var)
         unstack_length = op.take(shape, const(0))
-        self.prelude.mod[unstack_var] = \
-            Function([tensor_var], helper_var(const(0), unstack_length, tensor_var),
-                     self.prelude.l(reduced_tensor_type_var()), [])
+        self.prelude.mod[unstack_var] = Function(
+            [tensor_var],
+            helper_var(const(0), unstack_length, tensor_var),
+            self.prelude.l(reduced_tensor_type_var()),
+            [],
+        )
 
     def define_tensor_array_scatter(self, indices_shape=None, force_update=False):
         """Defines a function to scatter the values of a tensor_t in indices of a tensor array.
@@ -313,33 +347,39 @@ class StaticTensorArrayOps(object):
             return
 
         tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper")
-        tensor_array_scatter_helper_var = \
-            self._create_global_var(tensor_array_scatter_helper_name)
-        tensor_type_var = self.get_var('tensor_t')
+        tensor_array_scatter_helper_var = self._create_global_var(tensor_array_scatter_helper_name)
+        tensor_type_var = self.get_var("tensor_t")
         ta = Var("ta", self.prelude.l(tensor_type_var()))
-        current = Var("current", scalar_type('int32'))
-        limit = Var("limit", scalar_type('int32'))
-        indices_ = Var('indices_', TensorType(indices_shape or [Any()], 'int32'))
-        values_ = Var('values_', self.prelude.l(tensor_type_var()))
-        write_var = self.get_var('tensor_array_write')
-        read_var = self.get_var('tensor_array_read')
-        helper_body = If(equal(current, limit),
-                         ta,
-                         tensor_array_scatter_helper_var(
-                             write_var(ta, op.take(indices_, current),
-                                       read_var(values_, current)),
-                             add(current, const(1)),
-                             limit, indices_, values_))
-        self.prelude.mod[tensor_array_scatter_helper_var] = \
-            Function([ta, current, limit, indices_, values_],
-                     helper_body, self.prelude.l(tensor_type_var()), [])
+        current = Var("current", scalar_type("int32"))
+        limit = Var("limit", scalar_type("int32"))
+        indices_ = Var("indices_", TensorType(indices_shape or [Any()], "int32"))
+        values_ = Var("values_", self.prelude.l(tensor_type_var()))
+        write_var = self.get_var("tensor_array_write")
+        read_var = self.get_var("tensor_array_read")
+        helper_body = If(
+            equal(current, limit),
+            ta,
+            tensor_array_scatter_helper_var(
+                write_var(ta, op.take(indices_, current), read_var(values_, current)),
+                add(current, const(1)),
+                limit,
+                indices_,
+                values_,
+            ),
+        )
+        self.prelude.mod[tensor_array_scatter_helper_var] = Function(
+            [ta, current, limit, indices_, values_],
+            helper_body,
+            self.prelude.l(tensor_type_var()),
+            [],
+        )
 
         tensor_array_scatter_var = self._create_global_var(tensor_array_scatter_name)
         setattr(self.prelude, tensor_array_scatter_name, tensor_array_scatter_var)
         tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
 
-        indices = Var('indices', TensorType(indices_shape or [Any()], 'int32'))
-        values = Var('values', self.prelude.l(tensor_type_var()))
+        indices = Var("indices", TensorType(indices_shape or [Any()], "int32"))
+        values = Var("values", self.prelude.l(tensor_type_var()))
         if indices_shape is None:
             indices_shape = op.shape_of(indices)
             limit = op.take(indices_shape, const(0))
@@ -347,14 +387,11 @@ class StaticTensorArrayOps(object):
             limit = const(indices_shape[0])
 
         body = tensor_array_scatter_helper_var(tensor_array, const(0), limit, indices, values)
-        self.prelude.mod[tensor_array_scatter_var] = \
-            Function([tensor_array, indices, values], body,
-                     self.prelude.l(tensor_type_var()), [])
-
-    def define_tensor_array_split(self,
-                                  value_shape=None,
-                                  lengths_shape=None,
-                                  force_update=False):
+        self.prelude.mod[tensor_array_scatter_var] = Function(
+            [tensor_array, indices, values], body, self.prelude.l(tensor_type_var()), []
+        )
+
+    def define_tensor_array_split(self, value_shape=None, lengths_shape=None, force_update=False):
         """Defines a function to split the values of a tensor_t into a tensor array.
         tensor_array_split(ta, value, lengths) :
             list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t]
@@ -374,80 +411,80 @@ class StaticTensorArrayOps(object):
         if hasattr(self.prelude, split_name) and not force_update:
             return
 
-        tensor_type_var = self.get_var('tensor_t')
+        tensor_type_var = self.get_var("tensor_t")
         tensor_array_split_helper_name = self.get_name("ta_split_helper")
-        tensor_array_split_helper_var = \
-            self._create_global_var(tensor_array_split_helper_name)
+        tensor_array_split_helper_var = self._create_global_var(tensor_array_split_helper_name)
         setattr(self.prelude, tensor_array_split_helper_name, tensor_array_split_helper_var)
-        output_shape = [Any(),] + list(self.shape[1:])
+        output_shape = [
+            Any(),
+        ] + list(self.shape[1:])
         output_tensor_type_var, _ = self._get_adt_by_shape(output_shape)
 
         if value_shape is None:
             value_type_var = tensor_type_var
-            take_var = self.get_var('tensor_take')
+            take_var = self.get_var("tensor_take")
         else:
             value_type_var, _ = self._get_adt_by_shape(value_shape)
             # Also get static shape take operator
             origin_shape = list(self.shape)
             self.shape = value_shape
             self.define_tensor_take()
-            take_var = self.get_var('tensor_take')
+            take_var = self.get_var("tensor_take")
             self.shape = origin_shape
 
-
         ta1 = Var("tensor_array", self.prelude.l(output_tensor_type_var()))
-        value1 = Var('value1', value_type_var())
-        offset1 = Var('offset1', scalar_type('int32'))
-        current1 = Var('current1', scalar_type('int32'))
-        limit1 = Var('limit1', scalar_type('int32'))
-        lengths1 = Var('lengths', TensorType(lengths_shape or [Any()], 'int32'))
+        value1 = Var("value1", value_type_var())
+        offset1 = Var("offset1", scalar_type("int32"))
+        current1 = Var("current1", scalar_type("int32"))
+        limit1 = Var("limit1", scalar_type("int32"))
+        lengths1 = Var("lengths", TensorType(lengths_shape or [Any()], "int32"))
 
         # Register write for output shape
         origin_shape = list(self.shape)
         self.shape = output_shape
         self.define_tensor_array_write()
-        write_var = self.get_var('tensor_array_write')
+        write_var = self.get_var("tensor_array_write")
         self.shape = origin_shape
-        helper1_body = If(equal(current1, limit1),
-                          ta1,
-                          write_var(
-                              tensor_array_split_helper_var(
-                                  ta1,
-                                  value1,
-                                  add(offset1, op.take(lengths1, current1)),
-                                  add(current1, const(1)),
-                                  limit1,
-                                  lengths1
-                              ),
-                              current1,
-                              take_var(value1,
-                                       offset1,
-                                       add(op.take(lengths1, current1), offset1))))
-        self.prelude.mod[tensor_array_split_helper_var] = \
-            Function([ta1, value1, offset1, current1, limit1, lengths1],
-                     helper1_body, self.prelude.l(output_tensor_type_var()), [])
+        helper1_body = If(
+            equal(current1, limit1),
+            ta1,
+            write_var(
+                tensor_array_split_helper_var(
+                    ta1,
+                    value1,
+                    add(offset1, op.take(lengths1, current1)),
+                    add(current1, const(1)),
+                    limit1,
+                    lengths1,
+                ),
+                current1,
+                take_var(value1, offset1, add(op.take(lengths1, current1), offset1)),
+            ),
+        )
+        self.prelude.mod[tensor_array_split_helper_var] = Function(
+            [ta1, value1, offset1, current1, limit1, lengths1],
+            helper1_body,
+            self.prelude.l(output_tensor_type_var()),
+            [],
+        )
         split_var = self._create_global_var(split_name)
         setattr(self.prelude, split_name, split_var)
         tensor_array = Var("tensor_array", self.prelude.l(output_tensor_type_var()))
 
-        value = Var('value', value_type_var())
-        lengths = Var('lengths', TensorType(lengths_shape or [Any()], 'int32'))
+        value = Var("value", value_type_var())
+        lengths = Var("lengths", TensorType(lengths_shape or [Any()], "int32"))
         if lengths_shape is None:
             lengths_shape = op.shape_of(lengths)
             lengths_limit = op.take(lengths_shape, const(0))
         else:
             lengths_limit = const(lengths_shape[0])
         body = tensor_array_split_helper_var(
-            tensor_array,
-            value,
-            const(0),
-            const(0),
-            lengths_limit,
-            lengths)
+            tensor_array, value, const(0), const(0), lengths_limit, lengths
+        )
 
-        self.prelude.mod[split_var] = \
-            Function([tensor_array, value, lengths], body,
-                     self.prelude.l(output_tensor_type_var()), [])
+        self.prelude.mod[split_var] = Function(
+            [tensor_array, value, lengths], body, self.prelude.l(output_tensor_type_var()), []
+        )
 
     def define_tensor_array_concat(self):
         """Defines a function to return the values in the tensor array as concatenated tensor_t.
@@ -462,30 +499,37 @@ class StaticTensorArrayOps(object):
         concat_var = self._create_global_var(concat_name)
         setattr(self.prelude, concat_name, concat_var)
 
-        output_shape = [Any(),] + list(self.shape[1:])
+        output_shape = [
+            Any(),
+        ] + list(self.shape[1:])
         tensor_type_var, _ = self._get_adt_by_shape(output_shape)
 
         # Register tensor concatenate and get tensor_nil var for output shape
         origin_shape = self.shape
         self.shape = output_shape
         self.define_tensor_concatenate()
-        tensor_concat_var = self.get_var('tensor_concatenate')
-        tensor_nil_var = self.get_var('tensor_nil')
+        tensor_concat_var = self.get_var("tensor_concatenate")
+        tensor_nil_var = self.get_var("tensor_nil")
         self.shape = origin_shape
 
         tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
         hd = Var("hd")
         tl = Var("tl")
         nil_case = Clause(PatternConstructor(self.prelude.nil), tensor_nil_var())
-        cons_case = Clause(PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]),
-                           Match(tl, [
-                               Clause(PatternConstructor(self.prelude.nil), hd),
-                               Clause(PatternWildcard(),
-                                      tensor_concat_var(hd, concat_var(tl)))
-                           ], False))
-        self.prelude.mod[concat_var] = \
-            Function([tensor_array],
-                     Match(tensor_array, [nil_case, cons_case], False), tensor_type_var(), [])
+        cons_case = Clause(
+            PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]),
+            Match(
+                tl,
+                [
+                    Clause(PatternConstructor(self.prelude.nil), hd),
+                    Clause(PatternWildcard(), tensor_concat_var(hd, concat_var(tl))),
+                ],
+                False,
+            ),
+        )
+        self.prelude.mod[concat_var] = Function(
+            [tensor_array], Match(tensor_array, [nil_case, cons_case], False), tensor_type_var(), []
+        )
 
     def define_tensor_array_stack(self):
         """Defines a function to get the values in the tensor array as a stack tensor_t.
@@ -494,26 +538,30 @@ class StaticTensorArrayOps(object):
         stack_name = self.get_name("tensor_array_stack")
         stack_var = self._create_global_var(stack_name)
         setattr(self.prelude, stack_name, stack_var)
-        tensor_type_var = self.get_var('tensor_t')
+        tensor_type_var = self.get_var("tensor_t")
         tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
-        expand_dims_var = self.get_var('tensor_expand_dims')
+        expand_dims_var = self.get_var("tensor_expand_dims")
 
         # Register tensor_concatenate for output_shape
         origin_shape = self.shape
-        output_shape = [Any(),] + list(self.shape)
+        output_shape = [
+            Any(),
+        ] + list(self.shape)
         self.shape = output_shape
         self.define_tensor_concatenate()
-        concat_var = self.get_var('tensor_concatenate')
+        concat_var = self.get_var("tensor_concatenate")
         self.shape = origin_shape
 
         tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array)
-        tensors = self.prelude.foldl(concat_var,
-                                     self.prelude.hd(tensor_array_expand_dims),
-                                     self.prelude.tl(tensor_array_expand_dims))
+        tensors = self.prelude.foldl(
+            concat_var,
+            self.prelude.hd(tensor_array_expand_dims),
+            self.prelude.tl(tensor_array_expand_dims),
+        )
         output_tensor_type_var, _ = self._get_adt_by_shape(output_shape)
-        self.prelude.mod[stack_var] = \
-            Function([tensor_array], tensors,
-                     output_tensor_type_var(), [])
+        self.prelude.mod[stack_var] = Function(
+            [tensor_array], tensors, output_tensor_type_var(), []
+        )
 
     def define_tensor_array_gather(self):
         """Defines a function to return the selected values in a tensor array as tensor_t.
@@ -522,55 +570,59 @@ class StaticTensorArrayOps(object):
         helper_name = self.get_name("tensor_array_gather_helper")
         helper_var = self._create_global_var(helper_name)
         setattr(self.prelude, helper_name, helper_var)
-        tensor_type_var = self.get_var('tensor_t')
-        output_shape = [Any(),] + list(self.shape)
+        tensor_type_var = self.get_var("tensor_t")
+        output_shape = [
+            Any(),
+        ] + list(self.shape)
         output_tensor_type_var, _ = self._get_adt_by_shape(output_shape)
-        stack_var = self.get_var('tensor_array_stack')
-        read_var = self.get_var('tensor_array_read')
+        stack_var = self.get_var("tensor_array_stack")
+        read_var = self.get_var("tensor_array_read")
         ta = Var("ta", self.prelude.l(tensor_type_var()))
         accu = Var("accu", self.prelude.l(tensor_type_var()))
-        current = Var("current", scalar_type('int32'))
-        limit = Var("limit", scalar_type('int32'))
-        indices_ = Var('indices_', TensorType([Any()], 'int32'))
-        helper_body = \
-            If(equal(current, const(0)),
-               stack_var(accu),
-               helper_var(
-                   ta,
-                   self.prelude.cons(
-                       read_var(
-                           ta, op.take(indices_, subtract(current, const(1)))), accu),
-                   subtract(current, const(1)),
-                   limit, indices_))
-        self.prelude.mod[helper_var] = \
-            Function([ta, accu, current, limit, indices_],
-                     helper_body, output_tensor_type_var(), [])
+        current = Var("current", scalar_type("int32"))
+        limit = Var("limit", scalar_type("int32"))
+        indices_ = Var("indices_", TensorType([Any()], "int32"))
+        helper_body = If(
+            equal(current, const(0)),
+            stack_var(accu),
+            helper_var(
+                ta,
+                self.prelude.cons(
+                    read_var(ta, op.take(indices_, subtract(current, const(1)))), accu
+                ),
+                subtract(current, const(1)),
+                limit,
+                indices_,
+            ),
+        )
+        self.prelude.mod[helper_var] = Function(
+            [ta, accu, current, limit, indices_], helper_body, output_tensor_type_var(), []
+        )
         gather_name = self.get_name("tensor_array_gather")
         gather_var = self._create_global_var(gather_name)
         setattr(self.prelude, gather_name, gather_var)
         tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
-        indices = Var('indices', TensorType([Any()], 'int32'))
+        indices = Var("indices", TensorType([Any()], "int32"))
         indices_shape = op.shape_of(indices)
         limit = op.take(indices_shape, const(0))
         body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices)
-        self.prelude.mod[gather_var] = \
-            Function([tensor_array, indices], body, output_tensor_type_var(), [])
+        self.prelude.mod[gather_var] = Function(
+            [tensor_array, indices], body, output_tensor_type_var(), []
+        )
 
     def define_tensor_get_data(self):
-        """Defines a function to get a Tensor from tensor_t with given shape.
-        """
+        """Defines a function to get a Tensor from tensor_t with given shape."""
         tensor_get_data_name = self.get_name("tensor_get_data")
         tensor_get_data_var = self._create_global_var(tensor_get_data_name)
         setattr(self.prelude, tensor_get_data_name, tensor_get_data_var)
-        tensor_type_var = self.get_var('tensor_t')
-        tensor_constructor = self.get_var('tensor_constructor')
-        t = Var('tensor', tensor_type_var())
-        tvar = Var('t')
-        case =\
-            Clause(PatternConstructor(tensor_constructor, [PatternVar(tvar)]), tvar)
-        self.prelude.mod[tensor_get_data_var] = \
-            Function([t], Match(t, [case], False),
-                     TensorType(self.shape, self.dtype), [])
+        tensor_type_var = self.get_var("tensor_t")
+        tensor_constructor = self.get_var("tensor_constructor")
+        t = Var("tensor", tensor_type_var())
+        tvar = Var("t")
+        case = Clause(PatternConstructor(tensor_constructor, [PatternVar(tvar)]), tvar)
+        self.prelude.mod[tensor_get_data_var] = Function(
+            [t], Match(t, [case], False), TensorType(self.shape, self.dtype), []
+        )
 
     def register(self):
         """Register all tensor array ops in Prelude"""
@@ -611,6 +663,7 @@ class StaticTensorArrayOps(object):
 
         return gvar
 
+
 class TensorArrayOps(object):
     """Contains tensor array related ops"""
 
@@ -630,7 +683,7 @@ class TensorArrayOps(object):
     def define_tensor_adt(self):
         """Defines the dynamic tensor ADT, which is the container for tensors
         with variable shapes."""
-        tensor_type_name = self.get_name('tensor_t')
+        tensor_type_name = self.get_name("tensor_t")
         tensor_type_var = GlobalTypeVar(tensor_type_name)
         setattr(self.prelude, tensor_type_name, tensor_type_var)
         tensor0_type = TensorType([], self.dtype)
@@ -640,14 +693,14 @@ class TensorArrayOps(object):
         tensor4_type = TensorType([Any(), Any(), Any(), Any()], self.dtype)
         tensor5_type = TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype)
         tensor6_type = TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype)
-        tensor_nil_name = self.get_name('tensor_nil')
-        tensor0_name = self.get_name('tensor0')
-        tensor1_name = self.get_name('tensor1')
-        tensor2_name = self.get_name('tensor2')
-        tensor3_name = self.get_name('tensor3')
-        tensor4_name = self.get_name('tensor4')
-        tensor5_name = self.get_name('tensor5')
-        tensor6_name = self.get_name('tensor6')
+        tensor_nil_name = self.get_name("tensor_nil")
+        tensor0_name = self.get_name("tensor0")
+        tensor1_name = self.get_name("tensor1")
+        tensor2_name = self.get_name("tensor2")
+        tensor3_name = self.get_name("tensor3")
+        tensor4_name = self.get_name("tensor4")
+        tensor5_name = self.get_name("tensor5")
+        tensor6_name = self.get_name("tensor6")
         tensor_nil_case = Constructor(tensor_nil_name, [], tensor_type_var)
         tensor0_case = Constructor(tensor0_name, [tensor0_type], tensor_type_var)
         tensor1_case = Constructor(tensor1_name, [tensor1_type], tensor_type_var)
@@ -664,66 +717,86 @@ class TensorArrayOps(object):
         setattr(self.prelude, tensor4_name, tensor4_case)
         setattr(self.prelude, tensor5_name, tensor5_case)
         setattr(self.prelude, tensor6_name, tensor6_case)
-        self.prelude.mod[tensor_type_var] = TypeData(tensor_type_var, [], [tensor_nil_case,
-                                                                           tensor0_case,
-                                                                           tensor1_case,
-                                                                           tensor2_case,
-                                                                           tensor3_case,
-                                                                           tensor4_case,
-                                                                           tensor5_case,
-                                                                           tensor6_case])
+        self.prelude.mod[tensor_type_var] = TypeData(
+            tensor_type_var,
+            [],
+            [
+                tensor_nil_case,
+                tensor0_case,
+                tensor1_case,
+                tensor2_case,
+                tensor3_case,
+                tensor4_case,
+                tensor5_case,
+                tensor6_case,
+            ],
+        )
 
     def define_tensor_take(self):
         """Defines a function to return a range of tensor_t on axis 0.
-            tensor_take(t, lower, upper) :
-            tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t
+        tensor_take(t, lower, upper) :
+        tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t
         """
         take_name = self.get_name("tensor_take")
         take_var = GlobalVar(take_name)
         setattr(self.prelude, take_name, take_var)
-        tensor_t = self.get_var('tensor_t')
-        tensor1_var = self.get_var('tensor1')
-        tensor2_var = self.get_var('tensor2')
-        tensor3_var = self.get_var('tensor3')
-        tensor4_var = self.get_var('tensor4')
-        tensor5_var = self.get_var('tensor5')
-        tensor6_var = self.get_var('tensor6')
-        t = Var('tensor', tensor_t())
-        lower = Var('lower', scalar_type('int32'))
-        upper = Var('upper', scalar_type('int32'))
-        t1 = Var('t1')
-        t2 = Var('t2')
-        t3 = Var('t3')
-        t4 = Var('t4')
-        t5 = Var('t5')
-        t6 = Var('t6')
-        tensor1_case =\
-            Clause(PatternConstructor(tensor1_var, [PatternVar(t1)]),
-                   tensor1_var(op.take(t1, op.arange(lower, upper, dtype='int32'))))
-        tensor2_case =\
-            Clause(PatternConstructor(tensor2_var, [PatternVar(t2)]),
-                   tensor2_var(op.take(t2, op.arange(lower, upper, dtype='int32'), axis=0)))
-        tensor3_case =\
-            Clause(PatternConstructor(tensor3_var, [PatternVar(t3)]),
-                   tensor3_var(op.take(t3, op.arange(lower, upper, dtype='int32'), axis=0)))
-        tensor4_case =\
-            Clause(PatternConstructor(tensor4_var, [PatternVar(t4)]),
-                   tensor4_var(op.take(t4, op.arange(lower, upper, dtype='int32'), axis=0)))
-        tensor5_case =\
-            Clause(PatternConstructor(tensor5_var, [PatternVar(t5)]),
-                   tensor5_var(op.take(t5, op.arange(lower, upper, dtype='int32'), axis=0)))
-        tensor6_case =\
-            Clause(PatternConstructor(tensor6_var, [PatternVar(t6)]),
-                   tensor6_var(op.take(t6, op.arange(lower, upper, dtype='int32'), axis=0)))
-        self.prelude.mod[take_var] =\
-            Function([t, lower, upper],
-                     Match(t, [tensor1_case,
-                               tensor2_case,
-                               tensor3_case,
-                               tensor4_case,
-                               tensor5_case,
-                               tensor6_case], False),
-                     tensor_t(), [])
+        tensor_t = self.get_var("tensor_t")
+        tensor1_var = self.get_var("tensor1")
+        tensor2_var = self.get_var("tensor2")
+        tensor3_var = self.get_var("tensor3")
+        tensor4_var = self.get_var("tensor4")
+        tensor5_var = self.get_var("tensor5")
+        tensor6_var = self.get_var("tensor6")
+        t = Var("tensor", tensor_t())
+        lower = Var("lower", scalar_type("int32"))
+        upper = Var("upper", scalar_type("int32"))
+        t1 = Var("t1")
+        t2 = Var("t2")
+        t3 = Var("t3")
+        t4 = Var("t4")
+        t5 = Var("t5")
+        t6 = Var("t6")
+        tensor1_case = Clause(
+            PatternConstructor(tensor1_var, [PatternVar(t1)]),
+            tensor1_var(op.take(t1, op.arange(lower, upper, dtype="int32"))),
+        )
+        tensor2_case = Clause(
+            PatternConstructor(tensor2_var, [PatternVar(t2)]),
+            tensor2_var(op.take(t2, op.arange(lower, upper, dtype="int32"), axis=0)),
+        )
+        tensor3_case = Clause(
+            PatternConstructor(tensor3_var, [PatternVar(t3)]),
+            tensor3_var(op.take(t3, op.arange(lower, upper, dtype="int32"), axis=0)),
+        )
+        tensor4_case = Clause(
+            PatternConstructor(tensor4_var, [PatternVar(t4)]),
+            tensor4_var(op.take(t4, op.arange(lower, upper, dtype="int32"), axis=0)),
+        )
+        tensor5_case = Clause(
+            PatternConstructor(tensor5_var, [PatternVar(t5)]),
+            tensor5_var(op.take(t5, op.arange(lower, upper, dtype="int32"), axis=0)),
+        )
+        tensor6_case = Clause(
+            PatternConstructor(tensor6_var, [PatternVar(t6)]),
+            tensor6_var(op.take(t6, op.arange(lower, upper, dtype="int32"), axis=0)),
+        )
+        self.prelude.mod[take_var] = Function(
+            [t, lower, upper],
+            Match(
+                t,
+                [
+                    tensor1_case,
+                    tensor2_case,
+                    tensor3_case,
+                    tensor4_case,
+                    tensor5_case,
+                    tensor6_case,
+                ],
+                False,
+            ),
+            tensor_t(),
+            [],
+        )
 
     def define_tensor_expand_dims(self):
         """Defines a function to grow a tensor_t's rank by adding one dimension in front
@@ -733,7 +806,7 @@ class TensorArrayOps(object):
         expand_dims_name = self.get_name("tensor_expand_dims")
         expand_dims_var = GlobalVar(expand_dims_name)
         setattr(self.prelude, expand_dims_name, expand_dims_var)
-        tensor_type_var = self.get_var('tensor_t')
+        tensor_type_var = self.get_var("tensor_t")
         x = Var("x", tensor_type_var())
         t0 = Var("t0")
         t1 = Var("t1")
@@ -741,33 +814,46 @@ class TensorArrayOps(object):
         t3 = Var("t3")
         t4 = Var("t4")
         t5 = Var("t5")
-        tensor0_var = self.get_var('tensor0')
-        tensor1_var = self.get_var('tensor1')
-        tensor2_var = self.get_var('tensor2')
-        tensor3_var = self.get_var('tensor3')
-        tensor4_var = self.get_var('tensor4')
-        tensor5_var = self.get_var('tensor5')
-        tensor6_var = self.get_var('tensor6')
-        tensor0_case = Clause(PatternConstructor(tensor0_var, [PatternVar(t0)]),
-                              tensor1_var(op.expand_dims(t0, 0, 1)))
-        tensor1_case = Clause(PatternConstructor(tensor1_var, [PatternVar(t1)]),
-                              tensor2_var(op.expand_dims(t1, 0, 1)))
-        tensor2_case = Clause(PatternConstructor(tensor2_var, [PatternVar(t2)]),
-                              tensor3_var(op.expand_dims(t2, 0, 1)))
-        tensor3_case = Clause(PatternConstructor(tensor3_var, [PatternVar(t3)]),
-                              tensor4_var(op.expand_dims(t3, 0, 1)))
-        tensor4_case = Clause(PatternConstructor(tensor4_var, [PatternVar(t4)]),
-                              tensor5_var(op.expand_dims(t4, 0, 1)))
-        tensor5_case = Clause(PatternConstructor(tensor5_var, [PatternVar(t5)]),
-                              tensor6_var(op.expand_dims(t5, 0, 1)))
-        self.prelude.mod[expand_dims_var] =\
-            Function([x],
-                     Match(x, [tensor0_case,
-                               tensor1_case,
-                               tensor2_case,
-                               tensor3_case,
-                               tensor4_case,
-                               tensor5_case], False))
+        tensor0_var = self.get_var("tensor0")
+        tensor1_var = self.get_var("tensor1")
+        tensor2_var = self.get_var("tensor2")
+        tensor3_var = self.get_var("tensor3")
+        tensor4_var = self.get_var("tensor4")
+        tensor5_var = self.get_var("tensor5")
+        tensor6_var = self.get_var("tensor6")
+        tensor0_case = Clause(
+            PatternConstructor(tensor0_var, [PatternVar(t0)]), tensor1_var(op.expand_dims(t0, 0, 1))
+        )
+        tensor1_case = Clause(
+            PatternConstructor(tensor1_var, [PatternVar(t1)]), tensor2_var(op.expand_dims(t1, 0, 1))
+        )
+        tensor2_case = Clause(
+            PatternConstructor(tensor2_var, [PatternVar(t2)]), tensor3_var(op.expand_dims(t2, 0, 1))
+        )
+        tensor3_case = Clause(
+            PatternConstructor(tensor3_var, [PatternVar(t3)]), tensor4_var(op.expand_dims(t3, 0, 1))
+        )
+        tensor4_case = Clause(
+            PatternConstructor(tensor4_var, [PatternVar(t4)]), tensor5_var(op.expand_dims(t4, 0, 1))
+        )
+        tensor5_case = Clause(
+            PatternConstructor(tensor5_var, [PatternVar(t5)]), tensor6_var(op.expand_dims(t5, 0, 1))
+        )
+        self.prelude.mod[expand_dims_var] = Function(
+            [x],
+            Match(
+                x,
+                [
+                    tensor0_case,
+                    tensor1_case,
+                    tensor2_case,
+                    tensor3_case,
+                    tensor4_case,
+                    tensor5_case,
+                ],
+                False,
+            ),
+        )
 
     def define_tensor_concat(self):
         """Defines a function to concatenate two tensor_t on the first axis
@@ -777,14 +863,14 @@ class TensorArrayOps(object):
         concat_name = self.get_name("tensor_concatenate")
         concat_var = GlobalVar(concat_name)
         setattr(self.prelude, concat_name, concat_var)
-        tensor_type_var = self.get_var('tensor_t')
+        tensor_type_var = self.get_var("tensor_t")
         x = Var("x", tensor_type_var())
         y = Var("y", tensor_type_var())
 
-        tensor1_var = self.get_var('tensor1')
-        tensor2_var = self.get_var('tensor2')
-        tensor3_var = self.get_var('tensor3')
-        tensor4_var = self.get_var('tensor4')
+        tensor1_var = self.get_var("tensor1")
+        tensor2_var = self.get_var("tensor2")
+        tensor3_var = self.get_var("tensor3")
+        tensor4_var = self.get_var("tensor4")
         t11 = Var("t11")
         t12 = Var("t12")
         t21 = Var("t21")
@@ -793,28 +879,62 @@ class TensorArrayOps(object):
         t32 = Var("t32")
         t41 = Var("t41")
         t42 = Var("t42")
-        tensor1_case = Clause(PatternConstructor(tensor1_var, [PatternVar(t11)]),
-                              Match(y, [Clause(PatternConstructor(tensor1_var, [PatternVar(t12)]),
-                                               tensor1_var(op.concatenate([t11, t12], axis=0)))],
-                                    False))
-        tensor2_case = Clause(PatternConstructor(tensor2_var, [PatternVar(t21)]),
-                              Match(y, [Clause(PatternConstructor(tensor2_var, [PatternVar(t22)]),
-                                               tensor2_var(op.concatenate([t21, t22], axis=0)))],
-                                    False))
-        tensor3_case = Clause(PatternConstructor(tensor3_var, [PatternVar(t31)]),
-                              Match(y, [Clause(PatternConstructor(tensor3_var, [PatternVar(t32)]),
-                                               tensor3_var(op.concatenate([t31, t32], axis=0)))],
-                                    False))
-        tensor4_case = Clause(PatternConstructor(tensor4_var, [PatternVar(t41)]),
-                              Match(y, [Clause(PatternConstructor(tensor4_var, [PatternVar(t42)]),
-                                               tensor4_var(op.concatenate([t41, t42], axis=0)))],
-                                    False))
+        tensor1_case = Clause(
+            PatternConstructor(tensor1_var, [PatternVar(t11)]),
+            Match(
+                y,
+                [
+                    Clause(
+                        PatternConstructor(tensor1_var, [PatternVar(t12)]),
+                        tensor1_var(op.concatenate([t11, t12], axis=0)),
+                    )
+                ],
+                False,
+            ),
+        )
+        tensor2_case = Clause(
+            PatternConstructor(tensor2_var, [PatternVar(t21)]),
+            Match(
+                y,
+                [
+                    Clause(
+                        PatternConstructor(tensor2_var, [PatternVar(t22)]),
+                        tensor2_var(op.concatenate([t21, t22], axis=0)),
+                    )
+                ],
+                False,
+            ),
+        )
+        tensor3_case = Clause(
+            PatternConstructor(tensor3_var, [PatternVar(t31)]),
+            Match(
+                y,
+                [
+                    Clause(
+                        PatternConstructor(tensor3_var, [PatternVar(t32)]),
+                        tensor3_var(op.concatenate([t31, t32], axis=0)),
+                    )
+                ],
+                False,
+            ),
+        )
+        tensor4_case = Clause(
+            PatternConstructor(tensor4_var, [PatternVar(t41)]),
+            Match(
+                y,
+                [
+                    Clause(
+                        PatternConstructor(tensor4_var, [PatternVar(t42)]),
+                        tensor4_var(op.concatenate([t41, t42], axis=0)),
+                    )
+                ],
+                False,
+            ),
+        )
         # op.concatenate does not support tensor with rank higher than 4
-        self.prelude.mod[concat_var] = \
-            Function([x, y], Match(x, [tensor1_case,
-                                       tensor2_case,
-                                       tensor3_case,
-                                       tensor4_case], False))
+        self.prelude.mod[concat_var] = Function(
+            [x, y], Match(x, [tensor1_case, tensor2_case, tensor3_case, tensor4_case], False)
+        )
 
     def define_tensor_array(self):
         """Defines a function to create a tensor array with size n.
@@ -823,15 +943,19 @@ class TensorArrayOps(object):
         tensor_array_constructor_name = self.get_name("tensor_array")
         tensor_array_constructor_var = GlobalVar(tensor_array_constructor_name)
         setattr(self.prelude, tensor_array_constructor_name, tensor_array_constructor_var)
-        tensor_nil_var = self.get_var('tensor_nil')
-        tensor_type_var = self.get_var('tensor_t')
-        n = Var("x", scalar_type('int32'))
-        body = If(equal(n, const(0)),
-                  self.prelude.nil(),
-                  self.prelude.cons(tensor_nil_var(),
-                                    tensor_array_constructor_var(subtract(n, const(1)))))
-        self.prelude.mod[tensor_array_constructor_var] = \
-            Function([n], body, self.prelude.l(tensor_type_var()), [])
+        tensor_nil_var = self.get_var("tensor_nil")
+        tensor_type_var = self.get_var("tensor_t")
+        n = Var("x", scalar_type("int32"))
+        body = If(
+            equal(n, const(0)),
+            self.prelude.nil(),
+            self.prelude.cons(
+                tensor_nil_var(), tensor_array_constructor_var(subtract(n, const(1)))
+            ),
+        )
+        self.prelude.mod[tensor_array_constructor_var] = Function(
+            [n], body, self.prelude.l(tensor_type_var()), []
+        )
 
     def define_tensor_array_read(self):
         """Defines a function to get the head of a list. Assume the list has at least one
@@ -842,12 +966,13 @@ class TensorArrayOps(object):
         read_name = self.get_name("tensor_array_read")
         read_var = GlobalVar(read_name)
         setattr(self.prelude, read_name, read_var)
-        tensor_type_var = self.get_var('tensor_t')
+        tensor_type_var = self.get_var("tensor_t")
 
         tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
-        n = Var("x", scalar_type('int32'))
-        self.prelude.mod[read_var] =\
-            Function([tensor_array, n], self.prelude.nth(tensor_array, n), tensor_type_var(), [])
+        n = Var("x", scalar_type("int32"))
+        self.prelude.mod[read_var] = Function(
+            [tensor_array, n], self.prelude.nth(tensor_array, n), tensor_type_var(), []
+        )
 
     def define_tensor_array_write(self):
         """Defines a function to update a tensor array at index n with value v.
@@ -857,13 +982,16 @@ class TensorArrayOps(object):
         write_name = self.get_name("tensor_array_write")
         write_var = GlobalVar(write_name)
         setattr(self.prelude, write_name, write_var)
-        tensor_type_var = self.get_var('tensor_t')
+        tensor_type_var = self.get_var("tensor_t")
         tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
-        n = Var("x", scalar_type('int32'))
+        n = Var("x", scalar_type("int32"))
         v = Var("v", tensor_type_var())
-        self.prelude.mod[write_var] =\
-            Function([tensor_array, n, v], self.prelude.update(tensor_array, n, v),
-                     self.prelude.l(tensor_type_var()), [])
+        self.prelude.mod[write_var] = Function(
+            [tensor_array, n, v],
+            self.prelude.update(tensor_array, n, v),
+            self.prelude.l(tensor_type_var()),
+            [],
+        )
 
     def define_tensor_array_unstack_tensor1(self):
         """Defines a function to unstack the values of a tensor_t with rank 1 in a tensor array.
@@ -873,26 +1001,29 @@ class TensorArrayOps(object):
         helper_var = GlobalVar(helper_name)
         setattr(self.prelude, helper_name, helper_var)
         tensor = Var("t", TensorType([Any()], self.dtype))
-        up = Var("up", scalar_type('int32'))
-        i = Var("i", scalar_type('int32'))
-        tensor_type_var = self.get_var('tensor_t')
-        tensor0_var = self.get_var('tensor0')
-        helper_body =\
-            If(equal(i, up),
-               self.prelude.nil(),
-               self.prelude.cons(tensor0_var(op.take(tensor, i)),
-                                 helper_var(add(i, const(1)), up, tensor)))
-        self.prelude.mod[helper_var] =\
-            Function([i, up, tensor], helper_body, self.prelude.l(tensor_type_var()), [])
+        up = Var("up", scalar_type("int32"))
+        i = Var("i", scalar_type("int32"))
+        tensor_type_var = self.get_var("tensor_t")
+        tensor0_var = self.get_var("tensor0")
+        helper_body = If(
+            equal(i, up),
+            self.prelude.nil(),
+            self.prelude.cons(
+                tensor0_var(op.take(tensor, i)), helper_var(add(i, const(1)), up, tensor)
+            ),
+        )
+        self.prelude.mod[helper_var] = Function(
+            [i, up, tensor], helper_body, self.prelude.l(tensor_type_var()), []
+        )
         unstack_name = self.get_name("tensor_array_unstack_tensor1")
         unstack_var = GlobalVar(unstack_name)
         setattr(self.prelude, unstack_name, unstack_var)
         tensor1 = Var("tensor", TensorType([Any()], self.dtype))
         shape = op.shape_of(tensor1)
         ndim = op.take(shape, const(0))
-        self.prelude.mod[unstack_var] =\
-            Function([tensor1], helper_var(const(0), ndim, tensor1),
-                     self.prelude.l(tensor_type_var()), [])
+        self.prelude.mod[unstack_var] = Function(
+            [tensor1], helper_var(const(0), ndim, tensor1), self.prelude.l(tensor_type_var()), []
+        )
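
The unstack helpers for ranks 1 through 6 all follow the recursion just defined, differing only in how many Any() dimensions the input carries. A hedged NumPy analogue of the rank-1 case:

    import numpy as np

    def unstack(tensor):
        def helper(i, up, t):
            # equal(i, up) -> nil; otherwise cons t[i] onto the unstacked rest
            return [] if i == up else [t[i]] + helper(i + 1, up, t)
        return helper(0, tensor.shape[0], tensor)

    assert unstack(np.array([4, 5, 6])) == [4, 5, 6]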
 
     def define_tensor_array_unstack_tensor2(self):
         """Defines a function to unstack the values of a tensor_t with rank 2 in a tensor array.
@@ -903,15 +1034,20 @@ class TensorArrayOps(object):
         helper_var = GlobalVar(helper_name)
         setattr(self.prelude, helper_name, helper_var)
         tensor = Var("t", TensorType([Any(), Any()], self.dtype))
-        up = Var("up", scalar_type('int32'))
-        i = Var("i", scalar_type('int32'))
-
-        helper_body = If(equal(i, up),
-                         self.prelude.nil(),
-                         self.prelude.cons(self.get_var('tensor1')(op.take(tensor, i, axis=0)),
-                                           helper_var(add(i, const(1)), up, tensor)))
-        self.prelude.mod[helper_var] =\
-            Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), [])
+        up = Var("up", scalar_type("int32"))
+        i = Var("i", scalar_type("int32"))
+
+        helper_body = If(
+            equal(i, up),
+            self.prelude.nil(),
+            self.prelude.cons(
+                self.get_var("tensor1")(op.take(tensor, i, axis=0)),
+                helper_var(add(i, const(1)), up, tensor),
+            ),
+        )
+        self.prelude.mod[helper_var] = Function(
+            [i, up, tensor], helper_body, self.prelude.l(self.get_var("tensor_t")()), []
+        )
 
         tensor_array_unstack_tensor2_name = self.get_name("tensor_array_unstack_tensor2")
         tensor_array_unstack_tensor2_var = GlobalVar(tensor_array_unstack_tensor2_name)
@@ -919,9 +1055,12 @@ class TensorArrayOps(object):
         tensor2 = Var("tensor", TensorType([Any(), Any()], self.dtype))
         shape = op.shape_of(tensor2)
         ndim = op.take(shape, const(0))
-        self.prelude.mod[tensor_array_unstack_tensor2_var] =\
-            Function([tensor2], helper_var(const(0), ndim, tensor2),
-                     self.prelude.l(self.get_var('tensor_t')()), [])
+        self.prelude.mod[tensor_array_unstack_tensor2_var] = Function(
+            [tensor2],
+            helper_var(const(0), ndim, tensor2),
+            self.prelude.l(self.get_var("tensor_t")()),
+            [],
+        )
 
     def define_tensor_array_unstack_tensor3(self):
         """Defines a function to unstack the values of a tensor_t with rank 3 in a tensor array.
@@ -932,15 +1071,20 @@ class TensorArrayOps(object):
         helper_var = GlobalVar(helper_name)
         setattr(self.prelude, helper_name, helper_var)
         tensor = Var("t", TensorType([Any(), Any(), Any()], self.dtype))
-        up = Var("up", scalar_type('int32'))
-        i = Var("i", scalar_type('int32'))
-
-        helper_body = If(equal(i, up),
-                         self.prelude.nil(),
-                         self.prelude.cons(self.get_var('tensor2')(op.take(tensor, i, axis=0)),
-                                           helper_var(add(i, const(1)), up, tensor)))
-        self.prelude.mod[helper_var] =\
-            Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), [])
+        up = Var("up", scalar_type("int32"))
+        i = Var("i", scalar_type("int32"))
+
+        helper_body = If(
+            equal(i, up),
+            self.prelude.nil(),
+            self.prelude.cons(
+                self.get_var("tensor2")(op.take(tensor, i, axis=0)),
+                helper_var(add(i, const(1)), up, tensor),
+            ),
+        )
+        self.prelude.mod[helper_var] = Function(
+            [i, up, tensor], helper_body, self.prelude.l(self.get_var("tensor_t")()), []
+        )
 
         tensor_array_unstack_tensor3_name = self.get_name("tensor_array_unstack_tensor3")
         tensor_array_unstack_tensor3_var = GlobalVar(tensor_array_unstack_tensor3_name)
@@ -948,9 +1092,12 @@ class TensorArrayOps(object):
         tensor3 = Var("tensor", TensorType([Any(), Any(), Any()], self.dtype))
         shape = op.shape_of(tensor3)
         ndim = op.take(shape, const(0))
-        self.prelude.mod[tensor_array_unstack_tensor3_var] =\
-            Function([tensor3], helper_var(const(0), ndim, tensor3),
-                     self.prelude.l(self.get_var('tensor_t')()), [])
+        self.prelude.mod[tensor_array_unstack_tensor3_var] = Function(
+            [tensor3],
+            helper_var(const(0), ndim, tensor3),
+            self.prelude.l(self.get_var("tensor_t")()),
+            [],
+        )
 
     def define_tensor_array_unstack_tensor4(self):
         """Defines a function to unstack the values of a tensor_t with rank 4 in a tensor array.
@@ -961,15 +1108,20 @@ class TensorArrayOps(object):
         helper_var = GlobalVar(helper_name)
         setattr(self.prelude, helper_name, helper_var)
         tensor = Var("t", TensorType([Any(), Any(), Any(), Any()], self.dtype))
-        up = Var("up", scalar_type('int32'))
-        i = Var("i", scalar_type('int32'))
-
-        helper_body = If(equal(i, up),
-                         self.prelude.nil(),
-                         self.prelude.cons(self.get_var('tensor3')(op.take(tensor, i, axis=0)),
-                                           helper_var(add(i, const(1)), up, tensor)))
-        self.prelude.mod[helper_var] =\
-            Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), [])
+        up = Var("up", scalar_type("int32"))
+        i = Var("i", scalar_type("int32"))
+
+        helper_body = If(
+            equal(i, up),
+            self.prelude.nil(),
+            self.prelude.cons(
+                self.get_var("tensor3")(op.take(tensor, i, axis=0)),
+                helper_var(add(i, const(1)), up, tensor),
+            ),
+        )
+        self.prelude.mod[helper_var] = Function(
+            [i, up, tensor], helper_body, self.prelude.l(self.get_var("tensor_t")()), []
+        )
 
         tensor_array_unstack_tensor4_name = self.get_name("tensor_array_unstack_tensor4")
         tensor_array_unstack_tensor4_var = GlobalVar(tensor_array_unstack_tensor4_name)
@@ -977,9 +1129,12 @@ class TensorArrayOps(object):
         tensor4 = Var("tensor", TensorType([Any(), Any(), Any(), Any()], self.dtype))
         shape = op.shape_of(tensor4)
         ndim = op.take(shape, const(0))
-        self.prelude.mod[tensor_array_unstack_tensor4_var] =\
-            Function([tensor4], helper_var(const(0), ndim, tensor4),
-                     self.prelude.l(self.get_var('tensor_t')()), [])
+        self.prelude.mod[tensor_array_unstack_tensor4_var] = Function(
+            [tensor4],
+            helper_var(const(0), ndim, tensor4),
+            self.prelude.l(self.get_var("tensor_t")()),
+            [],
+        )
 
     def define_tensor_array_unstack_tensor5(self):
         """Defines a function to unstack the values of a tensor_t with rank 5 in a tensor array.
@@ -990,15 +1145,20 @@ class TensorArrayOps(object):
         helper_var = GlobalVar(helper_name)
         setattr(self.prelude, helper_name, helper_var)
         tensor = Var("t", TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype))
-        up = Var("up", scalar_type('int32'))
-        i = Var("i", scalar_type('int32'))
-
-        helper_body = If(equal(i, up),
-                         self.prelude.nil(),
-                         self.prelude.cons(self.get_var('tensor4')(op.take(tensor, i, axis=0)),
-                                           helper_var(add(i, const(1)), up, tensor)))
-        self.prelude.mod[helper_var] =\
-            Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), [])
+        up = Var("up", scalar_type("int32"))
+        i = Var("i", scalar_type("int32"))
+
+        helper_body = If(
+            equal(i, up),
+            self.prelude.nil(),
+            self.prelude.cons(
+                self.get_var("tensor4")(op.take(tensor, i, axis=0)),
+                helper_var(add(i, const(1)), up, tensor),
+            ),
+        )
+        self.prelude.mod[helper_var] = Function(
+            [i, up, tensor], helper_body, self.prelude.l(self.get_var("tensor_t")()), []
+        )
 
         tensor_array_unstack_tensor5_name = self.get_name("tensor_array_unstack_tensor5")
         tensor_array_unstack_tensor5_var = GlobalVar(tensor_array_unstack_tensor5_name)
@@ -1006,9 +1166,12 @@ class TensorArrayOps(object):
         tensor5 = Var("tensor", TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype))
         shape = op.shape_of(tensor5)
         ndim = op.take(shape, const(0))
-        self.prelude.mod[tensor_array_unstack_tensor5_var] =\
-            Function([tensor5], helper_var(const(0), ndim, tensor5),
-                     self.prelude.l(self.get_var('tensor_t')()), [])
+        self.prelude.mod[tensor_array_unstack_tensor5_var] = Function(
+            [tensor5],
+            helper_var(const(0), ndim, tensor5),
+            self.prelude.l(self.get_var("tensor_t")()),
+            [],
+        )
 
     def define_tensor_array_unstack_tensor6(self):
         """Defines a function to unstack the values of a tensor_t with rank 6 in a tensor array.
@@ -1019,15 +1182,20 @@ class TensorArrayOps(object):
         helper_var = GlobalVar(helper_name)
         setattr(self.prelude, helper_name, helper_var)
         tensor = Var("t", TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype))
-        up = Var("up", scalar_type('int32'))
-        i = Var("i", scalar_type('int32'))
-
-        helper_body = If(equal(i, up),
-                         self.prelude.nil(),
-                         self.prelude.cons(self.get_var('tensor5')(op.take(tensor, i, axis=0)),
-                                           helper_var(add(i, const(1)), up, tensor)))
-        self.prelude.mod[helper_var] =\
-            Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), [])
+        up = Var("up", scalar_type("int32"))
+        i = Var("i", scalar_type("int32"))
+
+        helper_body = If(
+            equal(i, up),
+            self.prelude.nil(),
+            self.prelude.cons(
+                self.get_var("tensor5")(op.take(tensor, i, axis=0)),
+                helper_var(add(i, const(1)), up, tensor),
+            ),
+        )
+        self.prelude.mod[helper_var] = Function(
+            [i, up, tensor], helper_body, self.prelude.l(self.get_var("tensor_t")()), []
+        )
 
         tensor_array_unstack_tensor6_name = self.get_name("tensor_array_unstack_tensor6")
         tensor_array_unstack_tensor6_var = GlobalVar(tensor_array_unstack_tensor6_name)
@@ -1035,9 +1203,12 @@ class TensorArrayOps(object):
         tensor6 = Var("tensor", TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype))
         shape = op.shape_of(tensor6)
         ndim = op.take(shape, const(0))
-        self.prelude.mod[tensor_array_unstack_tensor6_var] =\
-            Function([tensor6], helper_var(const(0), ndim, tensor6),
-                     self.prelude.l(self.get_var('tensor_t')()), [])
+        self.prelude.mod[tensor_array_unstack_tensor6_var] = Function(
+            [tensor6],
+            helper_var(const(0), ndim, tensor6),
+            self.prelude.l(self.get_var("tensor_t")()),
+            [],
+        )
 
     def define_tensor_array_scatter(self):
         """Defines a function to scatter the values of a tensor_t in indices of a tensor array.
@@ -1046,88 +1217,94 @@ class TensorArrayOps(object):
         """
         tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper")
         tensor_array_scatter_helper_var = GlobalVar(tensor_array_scatter_helper_name)
-        tensor_t = self.get_var('tensor_t')
+        tensor_t = self.get_var("tensor_t")
         ta = Var("ta", self.prelude.l(tensor_t()))
-        current = Var("current", scalar_type('int32'))
-        limit = Var("limit", scalar_type('int32'))
-        indices_ = Var('indices_', TensorType([Any()], 'int32'))
-        values_ = Var('values_', self.prelude.l(tensor_t()))
-        write_var = self.get_var('tensor_array_write')
-        read_var = self.get_var('tensor_array_read')
-        helper_body = If(equal(current, limit),
-                         ta,
-                         tensor_array_scatter_helper_var(
-                             write_var(ta, op.take(indices_, current),
-                                       read_var(values_, current)),
-                             add(current, const(1)),
-                             limit, indices_, values_))
-        self.prelude.mod[tensor_array_scatter_helper_var] =\
-            Function([ta, current, limit, indices_, values_],
-                     helper_body, self.prelude.l(tensor_t()), [])
+        current = Var("current", scalar_type("int32"))
+        limit = Var("limit", scalar_type("int32"))
+        indices_ = Var("indices_", TensorType([Any()], "int32"))
+        values_ = Var("values_", self.prelude.l(tensor_t()))
+        write_var = self.get_var("tensor_array_write")
+        read_var = self.get_var("tensor_array_read")
+        helper_body = If(
+            equal(current, limit),
+            ta,
+            tensor_array_scatter_helper_var(
+                write_var(ta, op.take(indices_, current), read_var(values_, current)),
+                add(current, const(1)),
+                limit,
+                indices_,
+                values_,
+            ),
+        )
+        self.prelude.mod[tensor_array_scatter_helper_var] = Function(
+            [ta, current, limit, indices_, values_], helper_body, self.prelude.l(tensor_t()), []
+        )
         tensor_array_scatter_name = self.get_name("tensor_array_scatter")
         tensor_array_scatter_var = GlobalVar(tensor_array_scatter_name)
         setattr(self.prelude, tensor_array_scatter_name, tensor_array_scatter_var)
         tensor_array = Var("tensor_array", self.prelude.l(tensor_t()))
-        indices = Var('indices', TensorType([Any()], 'int32'))
-        values = Var('values', self.prelude.l(tensor_t()))
+        indices = Var("indices", TensorType([Any()], "int32"))
+        values = Var("values", self.prelude.l(tensor_t()))
         indices_shape = op.shape_of(indices)
         limit = op.take(indices_shape, const(0))
         body = tensor_array_scatter_helper_var(tensor_array, const(0), limit, indices, values)
-        self.prelude.mod[tensor_array_scatter_var] =\
-            Function([tensor_array, indices, values], body, self.prelude.l(tensor_t()), [])
+        self.prelude.mod[tensor_array_scatter_var] = Function(
+            [tensor_array, indices, values], body, self.prelude.l(tensor_t()), []
+        )
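
A hedged Python analogue of the scatter recursion above: position indices[k] of the array receives element k of values.

    def tensor_array_scatter(ta, indices, values):
        out = list(ta)
        for k, idx in enumerate(indices):
            out[idx] = values[k]   # write(ta, take(indices, k), read(values, k))
        return out

    assert tensor_array_scatter(["a", "b", "c"], [2, 0], ["x", "y"]) == ["y", "b", "x"]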
 
     def define_tensor_array_split(self):
         """Defines a function to split the values of a tensor_t into a tensor array.
         tensor_array_split(ta, value, lengths) :
             list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t]
         """
-        tensor_t = self.get_var('tensor_t')
+        tensor_t = self.get_var("tensor_t")
         tensor_array_split_helper_name = self.get_name("ta_split_helper")
         tensor_array_split_helper_var = GlobalVar(tensor_array_split_helper_name)
         setattr(self.prelude, tensor_array_split_helper_name, tensor_array_split_helper_var)
         ta1 = Var("tensor_array", self.prelude.l(tensor_t()))
-        value1 = Var('value1', tensor_t())
-        offset1 = Var('offset1', scalar_type('int32'))
-        current1 = Var('current1', scalar_type('int32'))
-        limit1 = Var('limit1', scalar_type('int32'))
-        lengths1 = Var('lengths', TensorType([Any()], 'int32'))
-        write_var = self.get_var('tensor_array_write')
-        take_var = self.get_var('tensor_take')
-        helper1_body = If(equal(current1, limit1),
-                          ta1,
-                          write_var(
-                              tensor_array_split_helper_var(
-                                  ta1,
-                                  value1,
-                                  add(offset1, op.take(lengths1, current1)),
-                                  add(current1, const(1)),
-                                  limit1,
-                                  lengths1
-                              ),
-                              current1,
-                              take_var(value1,
-                                       offset1,
-                                       add(op.take(lengths1, current1), offset1))))
-        self.prelude.mod[tensor_array_split_helper_var] = \
-            Function([ta1, value1, offset1, current1, limit1, lengths1],
-                     helper1_body, self.prelude.l(tensor_t()), [])
+        value1 = Var("value1", tensor_t())
+        offset1 = Var("offset1", scalar_type("int32"))
+        current1 = Var("current1", scalar_type("int32"))
+        limit1 = Var("limit1", scalar_type("int32"))
+        lengths1 = Var("lengths", TensorType([Any()], "int32"))
+        write_var = self.get_var("tensor_array_write")
+        take_var = self.get_var("tensor_take")
+        helper1_body = If(
+            equal(current1, limit1),
+            ta1,
+            write_var(
+                tensor_array_split_helper_var(
+                    ta1,
+                    value1,
+                    add(offset1, op.take(lengths1, current1)),
+                    add(current1, const(1)),
+                    limit1,
+                    lengths1,
+                ),
+                current1,
+                take_var(value1, offset1, add(op.take(lengths1, current1), offset1)),
+            ),
+        )
+        self.prelude.mod[tensor_array_split_helper_var] = Function(
+            [ta1, value1, offset1, current1, limit1, lengths1],
+            helper1_body,
+            self.prelude.l(tensor_t()),
+            [],
+        )
         split_name = self.get_name("tensor_array_split")
         split_var = GlobalVar(split_name)
         setattr(self.prelude, split_name, split_var)
         tensor_array = Var("tensor_array", self.prelude.l(tensor_t()))
-        value = Var('value', tensor_t())
-        lengths = Var('lengths', TensorType([Any()], 'int32'))
+        value = Var("value", tensor_t())
+        lengths = Var("lengths", TensorType([Any()], "int32"))
         lengths_shape = op.shape_of(lengths)
         lengths_limit = op.take(lengths_shape, const(0))
         body = tensor_array_split_helper_var(
-            tensor_array,
-            value,
-            const(0),
-            const(0),
-            lengths_limit,
-            lengths)
-        self.prelude.mod[split_var] =\
-            Function([tensor_array, value, lengths], body, self.prelude.l(tensor_t()), [])
+            tensor_array, value, const(0), const(0), lengths_limit, lengths
+        )
+        self.prelude.mod[split_var] = Function(
+            [tensor_array, value, lengths], body, self.prelude.l(tensor_t()), []
+        )
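
A hedged NumPy analogue of tensor_array_split above: slot i of the array receives a slice of value whose length is lengths[i], with the offset accumulated as the recursion advances.

    import numpy as np

    def tensor_array_split(ta, value, lengths):
        out, offset = list(ta), 0
        for i, length in enumerate(lengths):
            out[i] = value[offset:offset + length]   # take(value, offset, offset + length)
            offset += length
        return out

    parts = tensor_array_split([None, None], np.arange(5), [2, 3])
    assert [p.tolist() for p in parts] == [[0, 1], [2, 3, 4]]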
 
     def define_tensor_array_concat(self):
         """Defines a function to return the values in the tensor array as concatenated tensor_t.
@@ -1136,22 +1313,27 @@ class TensorArrayOps(object):
         concat_name = self.get_name("tensor_array_concat")
         concat_var = GlobalVar(concat_name)
         setattr(self.prelude, concat_name, concat_var)
-        tensor_concat_var = self.get_var('tensor_concatenate')
-        tensor_t = self.get_var('tensor_t')
-        tensor_nil_var = self.get_var('tensor_nil')
+        tensor_concat_var = self.get_var("tensor_concatenate")
+        tensor_t = self.get_var("tensor_t")
+        tensor_nil_var = self.get_var("tensor_nil")
         tensor_array = Var("tensor_array", self.prelude.l(tensor_t()))
         hd = Var("hd")
         tl = Var("tl")
         nil_case = Clause(PatternConstructor(self.prelude.nil), tensor_nil_var())
-        cons_case = Clause(PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]),
-                           Match(tl, [
-                               Clause(PatternConstructor(self.prelude.nil), hd),
-                               Clause(PatternWildcard(),
-                                      tensor_concat_var(hd, concat_var(tl)))
-                           ], False))
-        self.prelude.mod[concat_var] =\
-            Function([tensor_array],
-                     Match(tensor_array, [nil_case, cons_case], False), tensor_t(), [])
+        cons_case = Clause(
+            PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]),
+            Match(
+                tl,
+                [
+                    Clause(PatternConstructor(self.prelude.nil), hd),
+                    Clause(PatternWildcard(), tensor_concat_var(hd, concat_var(tl))),
+                ],
+                False,
+            ),
+        )
+        self.prelude.mod[concat_var] = Function(
+            [tensor_array], Match(tensor_array, [nil_case, cons_case], False), tensor_t(), []
+        )
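
A hedged NumPy analogue of the pattern match above: an empty array yields the nil tensor, a singleton yields its head, and otherwise the head is concatenated with the concat of the tail.

    import numpy as np

    def tensor_array_concat(ta):
        if not ta:
            return None                  # tensor_nil
        if len(ta) == 1:
            return ta[0]
        return np.concatenate([ta[0], tensor_array_concat(ta[1:])], axis=0)

    assert tensor_array_concat([np.array([1, 2]), np.array([3])]).tolist() == [1, 2, 3]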
 
     def define_tensor_array_gather(self):
         """Defines a function to return the selected values in a tensor array as tensor_t.
@@ -1160,36 +1342,41 @@ class TensorArrayOps(object):
         helper_name = self.get_name("tensor_array_gather_helper")
         helper_var = GlobalVar(helper_name)
         setattr(self.prelude, helper_name, helper_var)
-        tensor_type_var = self.get_var('tensor_t')
-        stack_var = self.get_var('tensor_array_stack')
-        read_var = self.get_var('tensor_array_read')
+        tensor_type_var = self.get_var("tensor_t")
+        stack_var = self.get_var("tensor_array_stack")
+        read_var = self.get_var("tensor_array_read")
         ta = Var("ta", self.prelude.l(tensor_type_var()))
         accu = Var("accu", self.prelude.l(tensor_type_var()))
-        current = Var("current", scalar_type('int32'))
-        limit = Var("limit", scalar_type('int32'))
-        indices_ = Var('indices_', TensorType([Any()], 'int32'))
-        helper_body = \
-            If(equal(current, const(0)),
-               stack_var(accu),
-               helper_var(
-                   ta,
-                   self.prelude.cons(
-                       read_var(
-                           ta, op.take(indices_, subtract(current, const(1)))), accu),
-                   subtract(current, const(1)),
-                   limit, indices_))
-        self.prelude.mod[helper_var] = \
-            Function([ta, accu, current, limit, indices_], helper_body, tensor_type_var(), [])
+        current = Var("current", scalar_type("int32"))
+        limit = Var("limit", scalar_type("int32"))
+        indices_ = Var("indices_", TensorType([Any()], "int32"))
+        helper_body = If(
+            equal(current, const(0)),
+            stack_var(accu),
+            helper_var(
+                ta,
+                self.prelude.cons(
+                    read_var(ta, op.take(indices_, subtract(current, const(1)))), accu
+                ),
+                subtract(current, const(1)),
+                limit,
+                indices_,
+            ),
+        )
+        self.prelude.mod[helper_var] = Function(
+            [ta, accu, current, limit, indices_], helper_body, tensor_type_var(), []
+        )
         gather_name = self.get_name("tensor_array_gather")
         gather_var = GlobalVar(gather_name)
         setattr(self.prelude, gather_name, gather_var)
         tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
-        indices = Var('indices', TensorType([Any()], 'int32'))
+        indices = Var("indices", TensorType([Any()], "int32"))
         indices_shape = op.shape_of(indices)
         limit = op.take(indices_shape, const(0))
         body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices)
-        self.prelude.mod[gather_var] = \
-            Function([tensor_array, indices], body, tensor_type_var(), [])
+        self.prelude.mod[gather_var] = Function(
+            [tensor_array, indices], body, tensor_type_var(), []
+        )
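
A hedged NumPy analogue of tensor_array_gather above: the helper walks the indices from the back, consing each read onto the accumulator so the reads come out in order before being stacked.

    import numpy as np

    def tensor_array_gather(ta, indices):
        picked = [ta[i] for i in indices]   # read(ta, indices[k]) for each k
        return np.stack(picked, axis=0)     # tensor_array_stack on the accumulator

    assert tensor_array_gather([np.array([1]), np.array([2])], [1, 0]).tolist() == [[2], [1]]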
 
     def define_tensor_array_stack(self):
         """Defines a function to get the values in the tensor array as a stack tensor_t.
@@ -1198,16 +1385,19 @@ class TensorArrayOps(object):
         stack_name = self.get_name("tensor_array_stack")
         stack_var = GlobalVar(stack_name)
         setattr(self.prelude, stack_name, stack_var)
-        tensor_type_var = self.get_var('tensor_t')
+        tensor_type_var = self.get_var("tensor_t")
         tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
-        expand_dims_var = self.get_var('tensor_expand_dims')
-        concat_var = self.get_var('tensor_concatenate')
+        expand_dims_var = self.get_var("tensor_expand_dims")
+        concat_var = self.get_var("tensor_concatenate")
         tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array)
-        tensors = self.prelude.foldl(concat_var,
-                                     self.prelude.hd(tensor_array_expand_dims),
-                                     self.prelude.tl(tensor_array_expand_dims))
-        self.prelude.mod[stack_var] = \
-            Function([tensor_array], ToANormalFormExpr(tensors), tensor_type_var(), [])
+        tensors = self.prelude.foldl(
+            concat_var,
+            self.prelude.hd(tensor_array_expand_dims),
+            self.prelude.tl(tensor_array_expand_dims),
+        )
+        self.prelude.mod[stack_var] = Function(
+            [tensor_array], ToANormalFormExpr(tensors), tensor_type_var(), []
+        )
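
A hedged NumPy analogue of tensor_array_stack above: expand each element along a new leading axis, then fold concatenation over the result (equivalent to np.stack when the tensors share a shape).

    import numpy as np

    def tensor_array_stack(ta):
        expanded = [np.expand_dims(t, 0) for t in ta]   # map(tensor_expand_dims, ta)
        out = expanded[0]                               # prelude.hd
        for t in expanded[1:]:                          # foldl(tensor_concatenate, hd, tl)
            out = np.concatenate([out, t], axis=0)
        return out

    assert tensor_array_stack([np.array([1, 2]), np.array([3, 4])]).tolist() == [[1, 2], [3, 4]]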
 
     def register(self):
         """Register all tensor array ops in Prelude"""
@@ -1231,6 +1421,7 @@ class TensorArrayOps(object):
         # TODO(wweic): Gather fails in PartialEvaluate
         # self.define_tensor_array_gather()
 
+
 class Prelude:
     """Contains standard definitions."""
 
@@ -1242,8 +1433,8 @@ class Prelude:
 
     def get_name(self, canonical, dtype):
         """Get name corresponding to the canonical name"""
-        if canonical == 'tensor_t':
-            return 'tensor_{}_t'.format(dtype)
+        if canonical == "tensor_t":
+            return "tensor_{}_t".format(dtype)
         return "{}_{}".format(canonical, dtype)
 
     def get_var(self, canonical, dtype):
@@ -1308,14 +1499,16 @@ class Prelude:
         for global_def in GLOBAL_DEFS:
             setattr(self, global_def, self.mod.get_global_var(global_def))
 
-        for dtype in ['float32',
-                      'float16',
-                      'float64',
-                      'int32',
-                      'uint8',
-                      'int8',
-                      'int16',
-                      'uint16',
-                      'int64']:
+        for dtype in [
+            "float32",
+            "float16",
+            "float64",
+            "int32",
+            "uint8",
+            "int8",
+            "int16",
+            "uint16",
+            "int64",
+        ]:
             tensor_array_ops = TensorArrayOps(self, dtype)
             tensor_array_ops.register()
index 3d71438..4105172 100644 (file)
@@ -46,29 +46,34 @@ def convert_qnn_conv2d(attrs, inputs, tinfos, desired_layouts):
     """
     # pylint: disable=import-outside-toplevel
     from tvm import relay
+
     assert len(desired_layouts) == 2, "A desired layout is expected for both of qnn.conv2d's inputs"
     desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
     assert desired_data_layout != "default", "Data layout cannot be default"
 
     new_attrs = dict(attrs)
-    new_attrs['data_layout'] = desired_data_layout
+    new_attrs["data_layout"] = desired_data_layout
 
     if desired_kernel_layout != "default":
-        new_attrs['kernel_layout'] = desired_kernel_layout
+        new_attrs["kernel_layout"] = desired_kernel_layout
         return relay.qnn.op.conv2d(*inputs, **new_attrs)
 
-    if desired_data_layout == 'NCHW':
-        new_attrs['kernel_layout'] = 'OIHW'
+    if desired_data_layout == "NCHW":
+        new_attrs["kernel_layout"] = "OIHW"
         return relay.qnn.op.conv2d(*inputs, **new_attrs)
-    if desired_data_layout == 'NHWC':
+    if desired_data_layout == "NHWC":
         # Check for depthwise convolution.
         data_info, weight_info = tinfos
-        if is_depthwise_conv2d(data_info.shape, attrs['data_layout'],
-                               weight_info.shape, attrs['kernel_layout'],
-                               attrs['groups']):
-            new_attrs['kernel_layout'] = 'HWOI'
+        if is_depthwise_conv2d(
+            data_info.shape,
+            attrs["data_layout"],
+            weight_info.shape,
+            attrs["kernel_layout"],
+            attrs["groups"],
+        ):
+            new_attrs["kernel_layout"] = "HWOI"
         else:
-            new_attrs['kernel_layout'] = 'HWIO'
+            new_attrs["kernel_layout"] = "HWIO"
         return relay.qnn.op.conv2d(*inputs, **new_attrs)
 
-    raise ValueError('Layout %s is not yet supported' % desired_data_layout)
+    raise ValueError("Layout %s is not yet supported" % desired_data_layout)
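
A hedged usage sketch of the converter registered above; `mod` and the chosen desired_layouts entry are assumptions, not taken from this change. ConvertLayout picks the kernel layout by the rules shown, e.g. HWOI for depthwise NHWC convolutions and HWIO otherwise.

    import tvm
    from tvm import relay

    # Assumes `mod` is a Relay IRModule containing qnn.conv2d calls.
    desired_layouts = {"qnn.conv2d": ["NHWC", "default"]}  # data layout, kernel layout
    with tvm.transform.PassContext(opt_level=3):
        mod = relay.transform.ConvertLayout(desired_layouts)(mod)
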
index 62bee30..50e5a02 100644 (file)
@@ -31,11 +31,13 @@ from .. import op as reg
 def legalize_qnn_conv2d(attrs, inputs, types):
     return qnn_conv2d_legalize(attrs, inputs, types)
 
+
 # Registering QNN dense legalization function.
 @reg.register_qnn_legalize("qnn.dense")
 def legalize_qnn_dense(attrs, inputs, types):
     return qnn_dense_legalize(attrs, inputs, types)
 
+
 # Default to None. If overridden by target, this will not be run.
 # Generic QNN Conv2D legalization function.
 @tvm.target.generic_func
@@ -43,28 +45,34 @@ def qnn_conv2d_legalize(attrs, inputs, types):
     """Default legalization is None."""
     return None
 
+
 # Generic QNN Conv2D legalization function.
 @tvm.target.generic_func
 def qnn_dense_legalize(attrs, inputs, types):
     """Default legalization is None."""
     return None
 
+
 ###################
 # Helper functions.
 ###################
 
+
 def get_scalar_from_constant(expr):
     """ Returns scalar value from Relay constant scalar. """
-    assert isinstance(expr, relay.Constant) and not expr.data.shape, \
-        "Expr is not a constant scalar."
+    assert (
+        isinstance(expr, relay.Constant) and not expr.data.shape
+    ), "Expr is not a constant scalar."
     value = expr.data.asnumpy()
-    assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \
-        "value must be float32/int32"
+    assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(
+        np.float32
+    ), "value must be float32/int32"
     return np.asscalar(value)
 
+
+# Helper function for lowering in the absence of fast Int8 arithmetic units.
 def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
-    """ Converts QNN operators into a sequence of Relay operators that are friendly to HW that do
+    """Converts QNN operators into a sequence of Relay operators that are friendly to HW that do
     not have fast Int8 arithmetic. For example, for ARM, LLVM utilizes the assembly instructions
     much more efficiently if the convolution or dense operator input datatypes are int16 instead of
     int8. More details are present at https://github.com/apache/incubator-tvm/pull/4277.
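
A hedged NumPy illustration of that legalization (not TVM code): cast both operands to int16, subtract the zero points, and feed the shifted tensors to an ordinary, non-QNN operator; a dot product stands in for the convolution here.

    import numpy as np

    data = np.array([200, 10], dtype=np.uint8)
    kernel = np.array([-5, 7], dtype=np.int8)
    input_zero_point, kernel_zero_point = 128, 0

    shift_data = data.astype(np.int16) - np.int16(input_zero_point)
    shift_kernel = kernel.astype(np.int16) - np.int16(kernel_zero_point)

    acc = np.dot(shift_data.astype(np.int32), shift_kernel.astype(np.int32))
    assert acc == (200 - 128) * -5 + (10 - 128) * 7
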
@@ -87,13 +95,16 @@ def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
     # Collect the input exprs.
     data, kernel, input_zero_point, kernel_zero_point, _, _ = inputs
 
-    shift_data = relay.subtract(relay.cast(data, dtype='int16'),
-                                relay.cast(input_zero_point, 'int16'))
-    shift_kernel = relay.subtract(relay.cast(kernel, dtype='int16'),
-                                  relay.cast(kernel_zero_point, 'int16'))
-    new_attrs = {k : attrs[k] for k in attrs.keys()}
+    shift_data = relay.subtract(
+        relay.cast(data, dtype="int16"), relay.cast(input_zero_point, "int16")
+    )
+    shift_kernel = relay.subtract(
+        relay.cast(kernel, dtype="int16"), relay.cast(kernel_zero_point, "int16")
+    )
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
     return relay_op(shift_data, shift_kernel, **new_attrs)
 
+
 # Helper function to change dtypes to uint8 x int8. Intel VNNI instructions prefer this setting.
 def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op):
     """Legalizes QNN conv2d/dense op for Intel HW. VNNI supports u8 x i8 fast conv/MM. If the dtypes
@@ -129,17 +140,17 @@ def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op):
 
     def _shift(data, zero_point, out_dtype):
         """Shifts (add/subtracts) the qnn tensor with +/-128)"""
-        if out_dtype == 'uint8':
+        if out_dtype == "uint8":
             shift = 128
-        elif out_dtype == 'int8':
+        elif out_dtype == "int8":
             shift = -128
         else:
             raise ValueError("Unsupported out dtype.")
-        data_modified = relay.cast(data, 'int32')
-        data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
+        data_modified = relay.cast(data, "int32")
+        data_modified = relay.add(data_modified, relay.const(shift, "int32"))
         data_modified = relay.cast(data_modified, out_dtype)
         zero_point_val = get_scalar_from_constant(zero_point)
-        zero_point_modified = relay.const(zero_point_val + shift, 'int32')
+        zero_point_modified = relay.const(zero_point_val + shift, "int32")
         return (data_modified, zero_point_modified)
 
     # Collect the dtypes.
@@ -150,28 +161,29 @@ def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op):
     data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale = inputs
 
     # VNNI supports u8 x i8 fast conv/MM. Don't do anything if it is already satisfied.
-    if data_dtype == 'uint8' and kernel_dtype == 'int8':
+    if data_dtype == "uint8" and kernel_dtype == "int8":
         return None
 
     # Shift input if necessary.
-    if data_dtype == 'int8':
+    if data_dtype == "int8":
         # Compute (QA + 128) and (zp_a + 128)
-        data, input_zero_point = _shift(data, input_zero_point, 'uint8')
+        data, input_zero_point = _shift(data, input_zero_point, "uint8")
 
     # Shift kernel if necessary.
-    if kernel_dtype == 'uint8':
+    if kernel_dtype == "uint8":
         # Compute (QA - 128) and (zp_a - 128)
-        kernel, kernel_zero_point = _shift(kernel, kernel_zero_point, 'int8')
+        kernel, kernel_zero_point = _shift(kernel, kernel_zero_point, "int8")
 
     # Call qnn.conv2d with modified inputs and zero points.
-    new_attrs = {k : attrs[k] for k in attrs.keys()}
-    return relay_op(data, kernel,
-                    input_zero_point, kernel_zero_point,
-                    input_scale, kernel_scale, **new_attrs)
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
+    return relay_op(
+        data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale, **new_attrs
+    )
+
 
 # Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting.
 def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
-    """ Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. However,
+    """Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. However,
     many devices like ARM prefer the datatypes to be same for the HW units. This helper transforms
     conv2d/dense such that both the dtypes are same.
 
@@ -192,17 +204,17 @@ def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
 
     def _shift(data, zero_point, out_dtype):
         """Shifts (adds/subtracts) the qnn tensor by 128)"""
-        if out_dtype == 'uint8':
+        if out_dtype == "uint8":
             shift = 128
-        elif out_dtype == 'int8':
+        elif out_dtype == "int8":
             shift = -128
         else:
             raise ValueError("Unsupported out dtype.")
-        data_modified = relay.cast(data, 'int32')
-        data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
+        data_modified = relay.cast(data, "int32")
+        data_modified = relay.add(data_modified, relay.const(shift, "int32"))
         data_modified = relay.cast(data_modified, out_dtype)
         zero_point_val = get_scalar_from_constant(zero_point)
-        zero_point_modified = relay.const(zero_point_val + shift, 'int32')
+        zero_point_modified = relay.const(zero_point_val + shift, "int32")
         return (data_modified, zero_point_modified)
 
     # Collect the dtypes.
@@ -215,85 +227,99 @@ def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
     # Collect the input exprs.
     data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale = inputs
 
-    assert 'int8' in data_dtype and 'int8' in kernel_dtype, \
-            "Qnn Conv2D/Dense only accepts uint8 or int8 inputs"
+    assert (
+        "int8" in data_dtype and "int8" in kernel_dtype
+    ), "Qnn Conv2D/Dense only accepts uint8 or int8 inputs"
 
     # Shift input if necessary.
     data, input_zero_point = _shift(data, input_zero_point, kernel_dtype)
 
-    new_attrs = {k : attrs[k] for k in attrs.keys()}
-    return relay_op(data, kernel,
-                    input_zero_point, kernel_zero_point,
-                    input_scale, kernel_scale, **new_attrs)
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
+    return relay_op(
+        data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale, **new_attrs
+    )
+
 
 def is_fast_int8_on_intel():
     """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
     target = tvm.target.Target.current(allow_none=False)
-    return target.mcpu in {'skylake-avx512', 'cascadelake'}
+    return target.mcpu in {"skylake-avx512", "cascadelake"}
+
 
 def is_fast_int8_on_arm():
     """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
     target = tvm.target.Target.current(allow_none=False)
     return "+v8.2a" in target.mattr and "+dotprod" in target.mattr
 
+
 def is_aarch64_arm():
     """ Checks whether we are compiling for an AArch64 target. """
     target = tvm.target.Target.current(allow_none=False)
-    return 'aarch64' in target.attrs.get("mtriple", "")
+    return "aarch64" in target.attrs.get("mtriple", "")
+
 
 ########################
 # ARM CPU legalizations.
 ########################
 
-@qnn_conv2d_legalize.register('arm_cpu')
+
+@qnn_conv2d_legalize.register("arm_cpu")
 def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
     # ARM prefers the dtypes to be same.
-    is_depthwise = relay.op.strategy.is_depthwise_conv2d(types[0].shape,
-                                                         attrs['data_layout'],
-                                                         types[1].shape,
-                                                         attrs['kernel_layout'],
-                                                         attrs['groups'])
+    is_depthwise = relay.op.strategy.is_depthwise_conv2d(
+        types[0].shape,
+        attrs["data_layout"],
+        types[1].shape,
+        attrs["kernel_layout"],
+        attrs["groups"],
+    )
     use_int8_on_arm = (not is_depthwise) and is_aarch64_arm() and attrs["data_layout"] == "NHWC"
     if use_int8_on_arm or is_fast_int8_on_arm():
         return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
     return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
 
 
-@qnn_dense_legalize.register('arm_cpu')
+@qnn_dense_legalize.register("arm_cpu")
 def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
     # ARM prefers the dtypes to be same.
     if is_fast_int8_on_arm():
         return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
     return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
 
+
 ##########################
 # Intel CPU legalizations.
 ##########################
 
-@qnn_conv2d_legalize.register('cpu')
+
+@qnn_conv2d_legalize.register("cpu")
 def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types):
     # The VNNI transformations prefer uint8 x int8 datatypes.
     if is_fast_int8_on_intel():
         return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.conv2d)
     return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
 
-@qnn_dense_legalize.register('cpu')
+
+@qnn_dense_legalize.register("cpu")
 def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
     # The VNNI transformations prefer uint8 x int8 datatypes.
     if is_fast_int8_on_intel():
         return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense)
     return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
 
+
 #####################
 # CUDA legalizations.
 #####################
 
-@qnn_conv2d_legalize.register('cuda')
+
+@qnn_conv2d_legalize.register("cuda")
 def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
     # CUDA prefers the dtypes to be same.
     return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
 
-@qnn_dense_legalize.register('cuda')
+
+@qnn_dense_legalize.register("cuda")
 def _qnn_dense_legalize_cuda(attrs, inputs, types):
     # CUDA prefers the dtypes to be same.
     return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
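
A hedged NumPy check of the +/-128 trick used by the _shift helpers above: shifting int8 data to uint8 while moving the zero point by the same amount leaves the represented real values unchanged under the affine convention real = scale * (q - zero_point).

    import numpy as np

    scale, zero_point = 0.05, -3
    q_i8 = np.array([-128, -1, 0, 127], dtype=np.int8)

    # int8 -> uint8: shift data and zero point by +128 (the uint8 -> int8 case uses -128).
    q_u8 = (q_i8.astype(np.int32) + 128).astype(np.uint8)
    zp_u8 = zero_point + 128

    assert np.allclose(scale * (q_i8.astype(np.int32) - zero_point),
                       scale * (q_u8.astype(np.int32) - zp_u8))
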
index 720bac4..32a6122 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=unused-argument
+# pylint: disable=unused-argument
 """The register functions for the QNN dialect."""
 import tvm.ir
 
+
 def register_qnn_legalize(op_name, legal_op=None, level=10):
     """Register legal transformation function for a QNN op
 
index 14d74bf..e4a6cbf 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=invalid-name
+# pylint: disable=invalid-name
 """QNN dialect operators."""
 
 from __future__ import absolute_import as _abs
@@ -22,14 +22,17 @@ from tvm.relay.expr import Tuple, TupleWrapper
 from tvm.relay.op.nn.util import get_pad_tuple2d
 from . import _make
 
-def requantize(data,
-               input_scale,
-               input_zero_point,
-               output_scale,
-               output_zero_point,
-               axis=-1,
-               rounding="UPWARD",
-               out_dtype="int8"):
+
+def requantize(
+    data,
+    input_scale,
+    input_zero_point,
+    output_scale,
+    output_zero_point,
+    axis=-1,
+    rounding="UPWARD",
+    out_dtype="int8",
+):
     r"""Requantized operator.
 
     The requantize operator converts one quantized tensor representation to
@@ -71,22 +74,20 @@ def requantize(data,
         The computed result.
     """
 
-    return _make.requantize(data,
-                            input_scale,
-                            input_zero_point,
-                            output_scale,
-                            output_zero_point,
-                            axis,
-                            rounding,
-                            out_dtype)
-
-
-def quantize(data,
-             output_scale,
-             output_zero_point,
-             axis=-1,
-             out_dtype='int8'):
-    r""" Quantize op
+    return _make.requantize(
+        data,
+        input_scale,
+        input_zero_point,
+        output_scale,
+        output_zero_point,
+        axis,
+        rounding,
+        out_dtype,
+    )
+
+
+def quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"):
+    r"""Quantize op
+    This operator takes float32 as input and produces quantized int8 or uint8 as output.
     The input tensor can be of any shape. The output shape is the same as input shape.
 
@@ -112,18 +113,11 @@ def quantize(data,
         The computed result.
     """
 
-    return _make.quantize(data,
-                          output_scale,
-                          output_zero_point,
-                          axis,
-                          out_dtype)
+    return _make.quantize(data, output_scale, output_zero_point, axis, out_dtype)
 
 
-def dequantize(data,
-               input_scale,
-               input_zero_point,
-               axis=-1):
-    r""" Dequantize op
+def dequantize(data, input_scale, input_zero_point, axis=-1):
+    r"""Dequantize op
+    This operator takes quantized int8 and uint8 as input and produces
     dequantized float32 as output. The output shape is the same as input shape. The input
     tensor can be of any shape.
@@ -144,18 +138,10 @@ def dequantize(data,
         The computed result.
     """
 
-    return _make.dequantize(data,
-                            input_scale,
-                            input_zero_point,
-                            axis)
+    return _make.dequantize(data, input_scale, input_zero_point, axis)
 
 
-def concatenate(data,
-                input_scales,
-                input_zero_points,
-                output_scale,
-                output_zero_point,
-                axis):
+def concatenate(data, input_scales, input_zero_points, output_scale, output_zero_point, axis):
     """Concatenate the quantized input tensors along the given axis.
 
     Parameters
@@ -193,30 +179,29 @@ def concatenate(data,
     input_scales = list(input_scales)
     input_zero_points = list(input_zero_points)
 
-    return _make.concatenate(data,
-                             Tuple(input_scales),
-                             Tuple(input_zero_points),
-                             output_scale,
-                             output_zero_point,
-                             axis)
-
-
-def conv2d(data,
-           kernel,
-           input_zero_point,
-           kernel_zero_point,
-           input_scale,
-           kernel_scale,
-           kernel_size,
-           channels,
-           strides=(1, 1),
-           padding=(0, 0),
-           dilation=(1, 1),
-           groups=1,
-           data_layout="NCHW",
-           kernel_layout="OIHW",
-           out_layout="",
-           out_dtype="int32"):
+    return _make.concatenate(
+        data, Tuple(input_scales), Tuple(input_zero_points), output_scale, output_zero_point, axis
+    )
+
+
+def conv2d(
+    data,
+    kernel,
+    input_zero_point,
+    kernel_zero_point,
+    input_scale,
+    kernel_scale,
+    kernel_size,
+    channels,
+    strides=(1, 1),
+    padding=(0, 0),
+    dilation=(1, 1),
+    groups=1,
+    data_layout="NCHW",
+    kernel_layout="OIHW",
+    out_layout="",
+    out_dtype="int32",
+):
     r"""Quantized 2D convolution.
 
     This operator convolves quantized data with quantized kernel. The scale of
@@ -289,22 +274,29 @@ def conv2d(data,
     # TODO enforce 4-way padding in topi/nn/conv2d after #4644 merged
     # convert 2-way padding to 4-way padding
     padding = get_pad_tuple2d(padding)
-    return _make.conv2d(data, kernel,
-                        input_zero_point, kernel_zero_point,
-                        input_scale, kernel_scale,
-                        strides, padding, dilation,
-                        groups, channels, kernel_size,
-                        data_layout, kernel_layout, out_layout, out_dtype)
-
-
-def add(lhs,
-        rhs,
-        lhs_scale,
-        lhs_zero_point,
-        rhs_scale,
-        rhs_zero_point,
-        output_scale,
-        output_zero_point):
+    return _make.conv2d(
+        data,
+        kernel,
+        input_zero_point,
+        kernel_zero_point,
+        input_scale,
+        kernel_scale,
+        strides,
+        padding,
+        dilation,
+        groups,
+        channels,
+        kernel_size,
+        data_layout,
+        kernel_layout,
+        out_layout,
+        out_dtype,
+    )
+
+
+def add(
+    lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point
+):
     """Quantized addition with numpy-style broadcasting.
 
     Parameters
@@ -339,20 +331,28 @@ def add(lhs,
         The computed result.
 
     """
-    return _make.add(lhs, rhs,
-                     lhs_scale, lhs_zero_point,
-                     rhs_scale, rhs_zero_point,
-                     output_scale, output_zero_point)
-
-
-def dense(data,
-          weight,
-          input_zero_point,
-          kernel_zero_point,
-          input_scale,
-          kernel_scale,
-          units,
-          out_dtype="int32"):
+    return _make.add(
+        lhs,
+        rhs,
+        lhs_scale,
+        lhs_zero_point,
+        rhs_scale,
+        rhs_zero_point,
+        output_scale,
+        output_zero_point,
+    )
+
+
+def dense(
+    data,
+    weight,
+    input_zero_point,
+    kernel_zero_point,
+    input_scale,
+    kernel_scale,
+    units,
+    out_dtype="int32",
+):
     """Qnn Dense operator.
     Applies a quantized linear transformation
 
@@ -388,18 +388,21 @@ def dense(data,
         The computed result.
     """
 
-    return _make.dense(data,
-                       weight,
-                       input_zero_point,
-                       kernel_zero_point,
-                       input_scale,
-                       kernel_scale,
-                       units,
-                       out_dtype)
-
-
-def mul(lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point,
-        output_scale, output_zero_point):
+    return _make.dense(
+        data,
+        weight,
+        input_zero_point,
+        kernel_zero_point,
+        input_scale,
+        kernel_scale,
+        units,
+        out_dtype,
+    )
+
+
+def mul(
+    lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point
+):
     """Quantized multiplication with numpy-style broadcasting.
 
     Parameters
@@ -434,20 +437,21 @@ def mul(lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point,
         The computed result.
 
     """
-    return _make.mul(lhs, rhs,
-                     lhs_scale, lhs_zero_point,
-                     rhs_scale, rhs_zero_point,
-                     output_scale, output_zero_point)
-
-
-def subtract(lhs,
-             rhs,
-             lhs_scale,
-             lhs_zero_point,
-             rhs_scale,
-             rhs_zero_point,
-             output_scale,
-             output_zero_point):
+    return _make.mul(
+        lhs,
+        rhs,
+        lhs_scale,
+        lhs_zero_point,
+        rhs_scale,
+        rhs_zero_point,
+        output_scale,
+        output_zero_point,
+    )
+
+
+def subtract(
+    lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point
+):
     """Quantized subtraction with numpy-style broadcasting.
 
     Parameters
@@ -482,7 +486,13 @@ def subtract(lhs,
         The computed result.
 
     """
-    return _make.subtract(lhs, rhs,
-                          lhs_scale, lhs_zero_point,
-                          rhs_scale, rhs_zero_point,
-                          output_scale, output_zero_point)
+    return _make.subtract(
+        lhs,
+        rhs,
+        lhs_scale,
+        lhs_zero_point,
+        rhs_scale,
+        rhs_zero_point,
+        output_scale,
+        output_zero_point,
+    )
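
The operators above all share the affine convention real ~= scale * (q - zero_point). A hedged NumPy sketch of quantize, dequantize, and requantize under that convention; TVM's actual rounding modes (e.g. UPWARD) and overflow handling differ in detail.

    import numpy as np

    def quantize(x, scale, zero_point, dtype=np.int8):
        info = np.iinfo(dtype)
        q = np.round(x / scale) + zero_point
        return np.clip(q, info.min, info.max).astype(dtype)

    def dequantize(q, scale, zero_point):
        return scale * (q.astype(np.float32) - zero_point)

    def requantize(q, in_scale, in_zp, out_scale, out_zp, dtype=np.int8):
        return quantize(dequantize(q, in_scale, in_zp), out_scale, out_zp, dtype)

    x = np.array([-0.50, 0.00, 0.73], dtype=np.float32)
    q = quantize(x, 0.01, 0)
    assert np.allclose(dequantize(q, 0.01, 0), x, atol=0.01)
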
index 492c739..0485cec 100644 (file)
@@ -20,6 +20,7 @@ QNN pass transformation infrastructure.
 """
 from tvm import relay
 
+
 def CanonicalizeOps():
     """Converts/Lowers an expression containing QNN ops to an expression containing only core
     (non-Dialect) Relay ops. Each QNN op is lowered to a sequence of existing Relay ops. This is a
index 09dfa8f..428c6e9 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=wildcard-import, redefined-builtin
+# pylint: disable=wildcard-import, redefined-builtin
 """Automatic quantization utilities."""
 from __future__ import absolute_import as _abs
 
index 0bccacd..329ba64 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=unused-argument,inconsistent-return-statements
+# pylint: disable=unused-argument,inconsistent-return-statements
 """Internal module for registering attribute for annotation."""
 import warnings
 from tvm import topi
@@ -51,8 +51,7 @@ def simulated_quantize_compute(attrs, inputs, out_type):
 
 
 _reg.register_injective_schedule("relay.op.annotation.simulated_quantize")
-_reg.register_pattern("relay.op.annotation.simulated_quantize",
-                      _reg.OpPattern.ELEMWISE)
+_reg.register_pattern("relay.op.annotation.simulated_quantize", _reg.OpPattern.ELEMWISE)
 _reg.register_injective_schedule("annotation.cast_hint")
 
 
@@ -68,9 +67,9 @@ class QAnnotateExpr(_expr.TempExpr):
     kind: QAnnotateKind
         the kind of annotation field.
     """
+
     def __init__(self, expr, kind):
-        self.__init_handle_by_constructor__(
-            _quantize.make_annotate_expr, expr, kind)
+        self.__init_handle_by_constructor__(_quantize.make_annotate_expr, expr, kind)
 
 
 def _get_expr_kind(anno):
@@ -94,6 +93,7 @@ def register_annotate_function(op_name, frewrite=None, level=10):
     level : int, optional
         The priority level
     """
+
     def default_rewrite(ref_call, new_args, ctx):
         # recover from QAnnotateExpr
         args = [_get_expr_kind(x)[0] for x in new_args]
@@ -101,6 +101,7 @@ def register_annotate_function(op_name, frewrite=None, level=10):
 
     def _register(func):
         """internal register function"""
+
         def frewrite_with_guard(ref_call, new_args, ctx):
             if not current_qconfig().guard(ref_call):
                 return default_rewrite(ref_call, new_args, ctx)
@@ -135,20 +136,21 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
     dom_scale = _expr.var("dom_scale")
     clip_min = _expr.var("clip_min")
     clip_max = _expr.var("clip_max")
-    qnode = _quantize.simulated_quantize(
-        data, dom_scale, clip_min, clip_max, kind, sign, rounding)
+    qnode = _quantize.simulated_quantize(data, dom_scale, clip_min, clip_max, kind, sign, rounding)
     qctx.qnode_map[key] = qnode
     return qnode
 
-tvm._ffi.register_func(
-    "relay.quantize.attach_simulated_quantize", attach_simulated_quantize)
+
+tvm._ffi.register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize)
 
 
 @register_annotate_function("nn.contrib_conv2d_NCHWc")
 def conv2d_nchwc_rewrite(ref_call, new_args, ctx):
-    warnings.warn("NCHWc layout Conv2D detected, please use a lower "
-                  "optimization level before applying the quantization "
-                  "pass as quantization will have no effect here...")
+    warnings.warn(
+        "NCHWc layout Conv2D detected, please use a lower "
+        "optimization level before applying the quantization "
+        "pass as quantization will have no effect here..."
+    )
 
 
 @register_annotate_function("nn.conv2d")
@@ -261,8 +263,9 @@ def add_rewrite(ref_call, new_args, ctx):
             rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
             expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
             return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
-        if (lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT) or \
-            (lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.ACTIVATION):
+        if (lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT) or (
+            lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.ACTIVATION
+        ):
             expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
             return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
     raise ValueError()
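
These annotation hunks are formatting-only. As orientation for readers skimming the diff, here is a minimal sketch (not part of the commit) of how a rewrite is registered with register_annotate_function, reusing only names visible in the hunks above; "nn.some_op" is a placeholder, not a real registration:

@register_annotate_function("nn.some_op")  # hypothetical op name, illustration only
def _some_op_rewrite(ref_call, new_args, ctx):
    # Recover the underlying expression and its current annotation kind.
    expr, kind = _get_expr_kind(new_args[0])
    if kind is None:
        # Not annotated yet: insert a simulated quantize on the input.
        expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT)
    expr = _forward_op(ref_call, [expr])
    return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)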
index 8f553dd..c460dad 100644
@@ -34,14 +34,14 @@ from .kl_divergence import _find_scale_by_kl
 
 
 def _get_profile_runtime(mod):
-    func = mod['main']
+    func = mod["main"]
     func = _quantize.CreateStatsCollector(func)
 
     if tvm.target.Target.current():
         target = tvm.target.Target.current()
         ctx = tvm.context(target.kind.name)
     else:
-        target = 'llvm'
+        target = "llvm"
         ctx = tvm.context(target)
 
     with tvm.transform.PassContext(opt_level=3):
@@ -86,8 +86,8 @@ def collect_stats(mod, dataset, chunk_by=-1):
         for batch in dataset:
             runtime.set_input(**batch)
             runtime.run()
-            for j in range(i, min(i+chunk_by, num_outputs)):
-                outputs[j-i].append(runtime.get_output(j).asnumpy())
+            for j in range(i, min(i + chunk_by, num_outputs)):
+                outputs[j - i].append(runtime.get_output(j).asnumpy())
         yield [np.concatenate(output).reshape(-1) for output in outputs]
 
 
@@ -104,6 +104,7 @@ def _kl_scale(mod, dataset):
         scale = scales[func.scale_idx]
         func.scale_idx += 1
         return scale
+
     func.scale_idx = 0
 
     return func
@@ -115,7 +116,7 @@ def _set_params(mod, input_scale_func, weight_scale_func):
     const_params = {}
 
     def visit_func(expr):
-        '''visitor function for traverse'''
+        """visitor function for traverse"""
         if isinstance(expr, _expr.Call) and expr.op == quantize_op:
             _, ndom_scale, nclip_min, nclip_max = expr.args
             attrs = expr.attrs
@@ -131,19 +132,19 @@ def _set_params(mod, input_scale_func, weight_scale_func):
                 scale = input_scale_func(expr)
 
             def _make_const(val):
-                return _expr.const(val, 'float32')
+                return _expr.const(val, "float32")
 
-            valid_range = 2**valid_bit
+            valid_range = 2 ** valid_bit
             const_params[ndom_scale] = _make_const(scale / valid_range)
-            const_params[nclip_min] = _make_const(- (valid_range - 1))
+            const_params[nclip_min] = _make_const(-(valid_range - 1))
             const_params[nclip_max] = _make_const((valid_range - 1))
 
-    main_func = mod['main']
+    main_func = mod["main"]
     _analysis.post_order_visit(main_func, visit_func)
     main_func = _expr.bind(main_func, const_params)
     func_dict = {}
     for global_var, func in mod.functions.items():
-        if global_var.name_hint != 'main':
+        if global_var.name_hint != "main":
             func_dict[global_var] = func
     return IRModule.from_expr(main_func, func_dict)
 
@@ -154,7 +155,7 @@ def _power2_scale(sq_call):  # pylint: disable=unused-argument
     var = sq_call.args[0]
     assert isinstance(var, _expr.Constant)
     val = np.amax(np.abs(var.data.asnumpy()))
-    return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0
+    return 2 ** np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0
 
 
 def _max_scale(sq_call):
@@ -166,7 +167,7 @@ def _max_scale(sq_call):
 
 
 # input scale functions
-def _global_scale(sq_call): # pylint: disable=unused-argument
+def _global_scale(sq_call):  # pylint: disable=unused-argument
     cfg = quantize.current_qconfig()
     return cfg.global_scale
 
@@ -186,23 +187,25 @@ def calibrate(dataset=None):
     ret: Function
         The module pass function.
     """
+
     def wrapped_func(mod, _):
         """make transform.module pass happy"""
         cfg = quantize.current_qconfig()
 
-        if cfg.calibrate_mode == 'kl_divergence':
+        if cfg.calibrate_mode == "kl_divergence":
             input_scale_func = _kl_scale(mod, dataset)
-        elif cfg.calibrate_mode == 'global_scale':
+        elif cfg.calibrate_mode == "global_scale":
             input_scale_func = _global_scale
         else:
             raise ValueError("Unknown calibrate mode {}".format(cfg.calibrate_mode))
 
-        if cfg.weight_scale == 'max':
+        if cfg.weight_scale == "max":
             weight_scale_func = _max_scale
-        elif cfg.weight_scale == 'power2':
+        elif cfg.weight_scale == "power2":
             weight_scale_func = _power2_scale
         else:
             raise ValueError("Unknown weight scale mode {}".format(cfg.weight_scale))
 
         return _set_params(mod, input_scale_func, weight_scale_func)
+
     return wrapped_func
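
The calibrate hunks are likewise formatting-only. For context, a hedged usage sketch of the calibrate and weight-scale modes branched on in wrapped_func; mod, params, and calib_dataset are assumed to exist, and global_scale=8.0 is only an illustrative value:

from tvm import relay

# KL-divergence calibration needs a representative dataset (an iterable of
# input dicts); "global_scale" needs none and falls back to cfg.global_scale.
with relay.quantize.qconfig(calibrate_mode="kl_divergence", weight_scale="max"):
    qmod = relay.quantize.quantize(mod, params, dataset=calib_dataset)

with relay.quantize.qconfig(calibrate_mode="global_scale", global_scale=8.0):
    qmod = relay.quantize.quantize(mod, params)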
index b72f51c..6892e86 100644
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=unused-argument,inconsistent-return-statements
+# pylint: disable=unused-argument,inconsistent-return-statements
 """Internal module for registering attribute for annotation."""
 import tvm
 from .. import expr as _expr
@@ -22,6 +22,7 @@ from .. import analysis as _analysis
 from . import _quantize
 from .quantize import _forward_op
 
+
 def register_partition_function(op_name, frewrite=None, level=10):
     return tvm.ir.register_op_attr(op_name, "FQPartitionRewrite", frewrite, level)
 
@@ -29,8 +30,7 @@ def register_partition_function(op_name, frewrite=None, level=10):
 @tvm._ffi.register_object("relay.QPartitionExpr")
 class QPartitionExpr(_expr.TempExpr):
     def __init__(self, expr):
-        self.__init_handle_by_constructor__(
-            _quantize.make_partition_expr, expr)
+        self.__init_handle_by_constructor__(_quantize.make_partition_expr, expr)
 
 
 def partition_expr_check(expr):
@@ -58,6 +58,7 @@ def identity_partition_function(ref_call, new_args, ctx):
         return QPartitionExpr(_forward_op(ref_call, [expr]))
     return None
 
+
 register_partition_function("clip", identity_partition_function)
 register_partition_function("nn.relu", identity_partition_function)
 register_partition_function("nn.max_pool2d", identity_partition_function)
@@ -121,6 +122,7 @@ def add_partition_generic(ref_call, new_args, ctx):
 
     raise ValueError
 
+
 def mul_partition_generic(ref_call, new_args, ctx):
     """Rewrite function for ewise mul for partition for generic devices"""
     lhs_cond, lhs = partition_expr_check(new_args[0])
@@ -143,8 +145,8 @@ def mul_partition_generic(ref_call, new_args, ctx):
 def add_partition_function(ref_call, new_args, ctx):
     """Rewrite function for ewise add for partition"""
     target = tvm.target.Target.current()
-    if target and 'cuda' in target.keys:
-        #TODO(wuwei/ziheng) cuda specific rules
+    if target and "cuda" in target.keys:
+        # TODO(wuwei/ziheng) cuda specific rules
         return add_partition_generic(ref_call, new_args, ctx)
     return add_partition_generic(ref_call, new_args, ctx)
 
index d1c3b59..166e864 100644
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=unused-argument, not-context-manager
+# pylint: disable=unused-argument, not-context-manager
 """Utilities for partitioning input quantization and output dequantization expressions."""
 import tvm
 from tvm import relay
@@ -22,7 +22,8 @@ from tvm.relay.expr_functor import ExprMutator, ExprVisitor
 
 # operators that are allowed in prefix/suffix partitions, because they are used
 # to quantize/dequantize
-ALLOWED_CONVERSION_OPS = ['add', 'multiply', 'right_shift', 'clip', 'round', 'cast']
+ALLOWED_CONVERSION_OPS = ["add", "multiply", "right_shift", "clip", "round", "cast"]
+
 
 def partition_conversions(mod, quantized_dtypes, ensure_fully_integral):
     """Partition mod into input quantization, core quantized inference, and output dequantization.
@@ -80,9 +81,9 @@ def partition_conversions(mod, quantized_dtypes, ensure_fully_integral):
     pre_mod, mid_mod = partition_prefix(mod, quantized_dtypes)
     mid_mod, post_mod = partition_suffix(mid_mod, quantized_dtypes)
     if ensure_fully_integral:
-        assert has_only_conversion_ops(pre_mod['main'])
-        assert relay.analysis.all_dtypes(mid_mod['main']).issubset(quantized_dtypes)
-        assert has_only_conversion_ops(post_mod['main'])
+        assert has_only_conversion_ops(pre_mod["main"])
+        assert relay.analysis.all_dtypes(mid_mod["main"]).issubset(quantized_dtypes)
+        assert has_only_conversion_ops(post_mod["main"])
     return fuse_partitions(pre_mod, mid_mod, post_mod)
 
 
@@ -109,33 +110,38 @@ def fuse_partitions(pre_mod, mid_mod, post_mod):
         Module containing the input quantization, core quantized inference,
         output dequantization, and full quantized inference functions
     """
-    pre_func = pre_mod['main']
-    mid_func = mid_mod['main']
-    post_func = post_mod['main']
+    pre_func = pre_mod["main"]
+    mid_func = mid_mod["main"]
+    post_func = post_mod["main"]
     # create a module containing the prefix, middle, and suffix partitions
-    fused_mod = tvm.IRModule(functions={
-        relay.GlobalVar('quantize_inputs'): pre_func,
-        relay.GlobalVar('quantized_main'): mid_func,
-        relay.GlobalVar('dequantize_outputs'): post_func,
-    })
+    fused_mod = tvm.IRModule(
+        functions={
+            relay.GlobalVar("quantize_inputs"): pre_func,
+            relay.GlobalVar("quantized_main"): mid_func,
+            relay.GlobalVar("dequantize_outputs"): post_func,
+        }
+    )
     # construct a `main` that strings together the partitions, such that its
     # behaviour is equivalent to `main` in an *unpartitioned* module
     scope_builder = relay.ScopeBuilder()
     fused_mod_main_params = [relay.Var(param.name_hint) for param in pre_func.params]
-    quantized_inputs = scope_builder.let('quantized_inputs', relay.Call(
-        fused_mod.get_global_var('quantize_inputs'),
-        fused_mod_main_params
-    ))
-    quantized_outputs = scope_builder.let('quantized_outputs', relay.Call(
-        fused_mod.get_global_var('quantized_main'),
-        [relay.TupleGetItem(quantized_inputs, i) for i in range(len(pre_func.ret_type.fields))]
-    ))
-    dequantized_outputs = scope_builder.let('dequantized_outputs', relay.Call(
-        fused_mod.get_global_var('dequantize_outputs'),
-        [quantized_outputs]
-    ))
+    quantized_inputs = scope_builder.let(
+        "quantized_inputs",
+        relay.Call(fused_mod.get_global_var("quantize_inputs"), fused_mod_main_params),
+    )
+    quantized_outputs = scope_builder.let(
+        "quantized_outputs",
+        relay.Call(
+            fused_mod.get_global_var("quantized_main"),
+            [relay.TupleGetItem(quantized_inputs, i) for i in range(len(pre_func.ret_type.fields))],
+        ),
+    )
+    dequantized_outputs = scope_builder.let(
+        "dequantized_outputs",
+        relay.Call(fused_mod.get_global_var("dequantize_outputs"), [quantized_outputs]),
+    )
     scope_builder.ret(dequantized_outputs)
-    fused_mod['main'] = relay.Function(fused_mod_main_params, scope_builder.get())
+    fused_mod["main"] = relay.Function(fused_mod_main_params, scope_builder.get())
     return fused_mod
 
 
@@ -162,7 +168,7 @@ class PrefixCutter(ExprMutator):
 
     def visit_call(self, call):
         # TODO(weberlo) use graph pattern matching?
-        if not hasattr(call.op, 'name') or call.op.name not in ALLOWED_CONVERSION_OPS:
+        if not hasattr(call.op, "name") or call.op.name not in ALLOWED_CONVERSION_OPS:
             new_args = []
             for arg in call.args:
                 new_arg = self.visit(arg)
@@ -173,9 +179,7 @@ class PrefixCutter(ExprMutator):
                     param = next(iter(self.subtree_params))
                     pre_param = self.prefix_sb.let(param.name_hint, new_arg)
                     self.subtree_params.clear()
-                    mid_param = relay.Var(
-                        param.name_hint,
-                        arg.checked_type)
+                    mid_param = relay.Var(param.name_hint, arg.checked_type)
                     self.prefix_binding_map[mid_param] = pre_param
                     # return new parameter, then we can use
                     # relay.analysis.free_vars at the end of the pass to generate
@@ -206,14 +210,12 @@ def partition_prefix(mod, quantized_dtypes):
         Module containing a function with everything except for input quantization
     """
     assert len(mod.functions) == 1
-    func = mod['main']
+    func = mod["main"]
     prefix_cutter = PrefixCutter(func.params, quantized_dtypes)
     mid_body = prefix_cutter.visit(func.body)
-    assert not func.type_params, 'unimplemented'
-    assert func.attrs is None, 'unimplemented'
-    mid_func = relay.Function(
-        relay.analysis.free_vars(mid_body),
-        mid_body)
+    assert not func.type_params, "unimplemented"
+    assert func.attrs is None, "unimplemented"
+    mid_func = relay.Function(relay.analysis.free_vars(mid_body), mid_body)
     mid_mod = tvm.IRModule.from_expr(mid_func)
 
     scope_builder = prefix_cutter.prefix_sb
@@ -252,9 +254,9 @@ class SuffixCutter(ExprMutator):
         self.quantized_dtypes = quantized_dtypes
 
     def visit(self, expr):
-        if hasattr(expr, 'checked_type') and expr.checked_type.dtype in self.quantized_dtypes:
+        if hasattr(expr, "checked_type") and expr.checked_type.dtype in self.quantized_dtypes:
             self.mid_body = expr
-            return relay.Var('input', expr.checked_type)
+            return relay.Var("input", expr.checked_type)
 
         return super().visit(expr)
 
@@ -279,15 +281,12 @@ def partition_suffix(mod, quantized_dtypes):
         Module containing a function with everything except for input quantization
     """
     assert len(mod.functions) == 1
-    func = mod['main']
+    func = mod["main"]
     suffix_cutter = SuffixCutter(quantized_dtypes)
     post_body = suffix_cutter.visit(func.body)
-    assert not func.type_params, 'unimplemented'
-    assert func.attrs is None, 'unimplemented'
-    post_func = relay.Function(
-        relay.analysis.free_vars(post_body),
-        post_body,
-        func.ret_type)
+    assert not func.type_params, "unimplemented"
+    assert func.attrs is None, "unimplemented"
+    post_func = relay.Function(relay.analysis.free_vars(post_body), post_body, func.ret_type)
     post_mod = tvm.IRModule.from_expr(post_func)
 
     mid_body = suffix_cutter.mid_body
@@ -296,15 +295,11 @@ def partition_suffix(mod, quantized_dtypes):
         # quantization boundary in the given mod.  In this case, we use the
         # suffix mod as the middle mod and make the suffix an identity function.
         mid_mod = post_mod
-        post_body = relay.Var('input', mid_mod['main'].ret_type)
-        post_func = relay.Function(
-            [post_body],
-            post_body)
+        post_body = relay.Var("input", mid_mod["main"].ret_type)
+        post_func = relay.Function([post_body], post_body)
         post_mod = tvm.IRModule.from_expr(post_func)
     else:
-        mid_func = relay.Function(
-            func.params,
-            mid_body)
+        mid_func = relay.Function(func.params, mid_body)
         mid_mod = tvm.IRModule.from_expr(mid_func)
 
     return mid_mod, post_mod
@@ -312,12 +307,13 @@ def partition_suffix(mod, quantized_dtypes):
 
 class ConversionOpChecker(ExprVisitor):
     """A pass for checking that the visited function contains only conversion ops"""
+
     def __init__(self):
         ExprVisitor.__init__(self)
         self.valid = True
 
     def visit_call(self, call):
-        if not hasattr(call.op, 'name') or call.op.name not in ALLOWED_CONVERSION_OPS:
+        if not hasattr(call.op, "name") or call.op.name not in ALLOWED_CONVERSION_OPS:
             self.valid = False
         super().visit_call(call)
 
index 7b27b7a..70f8f17 100644
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=unused-argument
+# pylint: disable=unused-argument
 """Internal module for quantization."""
 import tvm._ffi
 
index 6492750..ca6c0b6 100644
@@ -22,8 +22,7 @@ import numpy as np
 from . import _quantize
 
 
-def _find_scale_by_kl(arr, quantized_dtype='int8',
-                      num_bins=8001, num_quantized_bins=255):
+def _find_scale_by_kl(arr, quantized_dtype="int8", num_bins=8001, num_quantized_bins=255):
     """Given a tensor, find the optimal threshold for quantizing it.
     The reference distribution is `q`, and the candidate distribution is `p`.
     `q` is a truncated version of the original distribution.
@@ -36,7 +35,7 @@ def _find_scale_by_kl(arr, quantized_dtype='int8',
     max_val = np.max(arr)
     thres = max(abs(min_val), abs(max_val))
 
-    if min_val >= 0 and quantized_dtype in ['uint8']:
+    if min_val >= 0 and quantized_dtype in ["uint8"]:
         # We need to move negative bins to positive bins to fit uint8 range.
         num_quantized_bins = num_quantized_bins * 2 + 1
 
@@ -48,5 +47,6 @@ def _find_scale_by_kl(arr, quantized_dtype='int8',
     hist_ptr = get_pointer(hist.astype(np.int32), ctypes.c_int)
     hist_edges_ptr = get_pointer(hist_edges, ctypes.c_float)
 
-    return _quantize.FindScaleByKLMinimization(hist_ptr, hist_edges_ptr,
-                                               num_bins, num_quantized_bins)
+    return _quantize.FindScaleByKLMinimization(
+        hist_ptr, hist_edges_ptr, num_bins, num_quantized_bins
+    )
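
For reference, a small usage sketch of the KL threshold search reformatted above; the activation tensor is synthetic and the keyword values simply restate the defaults in the signature:

import numpy as np

# Synthetic activations standing in for collected layer outputs.
acts = np.random.randn(1, 64, 56, 56).astype("float32").ravel()
scale = _find_scale_by_kl(acts, quantized_dtype="int8", num_bins=8001, num_quantized_bins=255)
# `scale` is the clipping threshold chosen by the KL-divergence search.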
index 8a8c82c..3d6870f 100644
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=unused-argument, not-context-manager
+# pylint: disable=unused-argument, not-context-manager
 """Automatic quantization toolkit."""
 import tvm.ir
 import tvm
@@ -30,6 +30,7 @@ from .. import transform as _transform
 class QAnnotateKind(object):
     """Denote the kind of annotation field, corresponding
     to different nbit configure."""
+
     IDENTITY = 0
     INPUT = 1
     WEIGHT = 2
@@ -42,7 +43,7 @@ def kind2str(kind):
         QAnnotateKind.INPUT: "input",
         QAnnotateKind.WEIGHT: "weight",
         QAnnotateKind.ACTIVATION: "activation",
-        QAnnotateKind.IDENTITY: "identity"
+        QAnnotateKind.IDENTITY: "identity",
     }
     assert kind in str_map
     return str_map[kind]
@@ -50,8 +51,7 @@ def kind2str(kind):
 
 def _forward_op(ref_call, args):
     """forward the operator of ref_call with provided arguments"""
-    return _expr.Call(
-        ref_call.op, args, ref_call.attrs, ref_call.type_args)
+    return _expr.Call(ref_call.op, args, ref_call.attrs, ref_call.type_args)
 
 
 @tvm._ffi.register_object("relay.quantize.QConfig")
@@ -112,11 +112,11 @@ class QConfig(Object):
 
     def get_nbit_by_kind(self, kind):
         name = kind2str(kind)
-        return getattr(self, 'nbit_' + name)
+        return getattr(self, "nbit_" + name)
 
     def get_dtype_by_kind(self, kind):
         name = kind2str(kind)
-        return getattr(self, 'dtype_' + name)
+        return getattr(self, "dtype_" + name)
 
     def __enter__(self):
         # pylint: disable=protected-access
@@ -128,8 +128,7 @@ class QConfig(Object):
 
     def __setattr__(self, name, value):
         if name in QConfig._node_defaults:
-            raise AttributeError(
-                "'%s' object cannot set attribute '%s'" % (str(type(self)), name))
+            raise AttributeError("'%s' object cannot set attribute '%s'" % (str(type(self)), name))
         return super(QConfig, self).__setattr__(name, value)
 
 
@@ -197,14 +196,14 @@ def qconfig(**kwargs):
     config: QConfig
         The quantization configuration
     """
-    node_args = {k: v if k not in kwargs else kwargs[k]
-                 for k, v in QConfig._node_defaults.items()}
+    node_args = {k: v if k not in kwargs else kwargs[k] for k, v in QConfig._node_defaults.items()}
     return tvm.ir.make_node("relay.quantize.QConfig", **node_args)
 
 
 class QuantizeContext(object):
     """An internal used global context object for annotation,
     for putting some state variables like `conv2d_counter`."""
+
     Current = None
 
     def __init__(self):
@@ -222,10 +221,10 @@ class QuantizeContext(object):
             # check skip conv layers
             skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
             if self._conv2d_counter in skipped_indices:
-                if ref_call.op.name == 'nn.conv2d':
+                if ref_call.op.name == "nn.conv2d":
                     self._conv2d_counter += 1
                 return True
-            if ref_call.op.name == 'nn.conv2d':
+            if ref_call.op.name == "nn.conv2d":
                 self._conv2d_counter += 1
 
         return False
@@ -292,8 +291,7 @@ def realize():
 
 
 def _bind_params(func, params):
-    """Bind the params to the expression.
-    """
+    """Bind the params to the expression."""
     name_dict = {}
     for arg in func.params:
         name = arg.name_hint
@@ -313,25 +311,28 @@ def _bind_params(func, params):
 
 
 def prerequisite_optimize(mod, params=None):
-    """ Prerequisite optimization passes for quantization. Perform
+    """Prerequisite optimization passes for quantization. Perform
     "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
-    "CanonicalizeOps" optimization before quantization. """
+    "CanonicalizeOps" optimization before quantization."""
     optimize = tvm.transform.Sequential(
-        [_transform.SimplifyInference(),
-         _transform.FoldConstant(),
-         _transform.FoldScaleAxis(),
-         _transform.CanonicalizeOps(),
-         _transform.FoldConstant()])
+        [
+            _transform.SimplifyInference(),
+            _transform.FoldConstant(),
+            _transform.FoldScaleAxis(),
+            _transform.CanonicalizeOps(),
+            _transform.FoldConstant(),
+        ]
+    )
 
     if params:
-        mod['main'] = _bind_params(mod['main'], params)
+        mod["main"] = _bind_params(mod["main"], params)
 
     mod = optimize(mod)
     return mod
 
 
 def quantize(mod, params=None, dataset=None):
-    """ The quantization procedure. Before running the three main
+    """The quantization procedure. Before running the three main
     procedure of quantization, "annotate", "calibrate" and "realize"
     , we need to do "SimplifyInference", "FoldScaleAxis", "FoldConstant"
     first for optimizing.
@@ -356,27 +357,24 @@ def quantize(mod, params=None, dataset=None):
     mod = prerequisite_optimize(mod, params)
 
     calibrate_pass = tvm.transform.module_pass(
-        calibrate(dataset), opt_level=1,
-        name="QuantizeCalibrate")
-    quant_passes = [partition(),
-                    annotate(),
-                    calibrate_pass]
+        calibrate(dataset), opt_level=1, name="QuantizeCalibrate"
+    )
+    quant_passes = [partition(), annotate(), calibrate_pass]
     if not current_qconfig().do_simulation:
         quant_passes.append(realize())
     quant_passes.append(_transform.FoldConstant())
     quantize_seq = tvm.transform.Sequential(quant_passes)
-    with tvm.transform.PassContext(opt_level=3,
-                                   required_pass=["QuantizeAnnotate",
-                                                  "QuantizeCalibrate",
-                                                  "QuantizeRealize"]):
+    with tvm.transform.PassContext(
+        opt_level=3, required_pass=["QuantizeAnnotate", "QuantizeCalibrate", "QuantizeRealize"]
+    ):
         with quantize_context():
             mod = quantize_seq(mod)
 
     q_cfg = current_qconfig()
-    assert q_cfg.partition_conversions in ['disabled', 'enabled', 'fully_integral']
-    if q_cfg.partition_conversions != 'disabled':
+    assert q_cfg.partition_conversions in ["disabled", "enabled", "fully_integral"]
+    if q_cfg.partition_conversions != "disabled":
         quantized_dtypes = {q_cfg.dtype_input, q_cfg.dtype_weight, q_cfg.dtype_activation}
-        ensure_fully_integral = q_cfg.partition_conversions == 'fully_integral'
+        ensure_fully_integral = q_cfg.partition_conversions == "fully_integral"
         return partition_conversions(mod, quantized_dtypes, ensure_fully_integral)
 
     return mod
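
A hedged sketch of how the partition_conversions branch at the end of quantize() is driven from user code; the field names come from the config checks above, while mod, params, and calib_dataset are assumed inputs:

from tvm import relay

with relay.quantize.qconfig(
    calibrate_mode="kl_divergence",
    partition_conversions="fully_integral",
):
    fused = relay.quantize.quantize(mod, params, dataset=calib_dataset)
# With partitioning enabled, the result carries separate quantize_inputs,
# quantized_main, and dequantize_outputs functions (see fuse_partitions above).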
index 86ff805..726b3c6 100644
@@ -22,6 +22,7 @@ from . import ty as _ty
 from . import expr as _expr
 from .._ffi import base as _base
 
+
 class WithScope(object):
     """A wrapper for builder methods which introduce scoping.
 
@@ -43,6 +44,7 @@ class WithScope(object):
             raise value
         self._exit_cb()
 
+
 def _make_lets(bindings, ret_value):
     """Make a nested let expressions.
 
@@ -93,6 +95,7 @@ class ScopeBuilder(object):
 
         print(sb.get().astext())
     """
+
     def __init__(self):
         self._bindings = [[]]
         self._ret_values = [None]
@@ -144,12 +147,14 @@ class ScopeBuilder(object):
         The user must follows with an else scope.
         """
         self._enter_scope()
+
         def _on_exit():
             bindings, ret_value = self._exit_scope()
             if self._ret_values[-1] is not None:
                 raise RuntimeError("result already returned before if scope")
             true_branch = _make_lets(bindings, ret_value)
             self._ret_values[-1] = _expr.If(cond, true_branch, None)
+
         return WithScope(None, _on_exit)
 
     def else_scope(self):
@@ -165,17 +170,13 @@ class ScopeBuilder(object):
         def _on_exit():
             bindings, ret_value = self._exit_scope()
             partial_if = self._ret_values[-1]
-            no_else = (not isinstance(partial_if, _expr.If) or
-                       partial_if.false_branch is not None)
+            no_else = not isinstance(partial_if, _expr.If) or partial_if.false_branch is not None
             if no_else:
                 raise RuntimeError("else scope must follows")
             false_branch = _make_lets(bindings, ret_value)
-            self._ret_values[-1] = _expr.If(
-                partial_if.cond,
-                partial_if.true_branch,
-                false_branch)
-        return WithScope(None, _on_exit)
+            self._ret_values[-1] = _expr.If(partial_if.cond, partial_if.true_branch, false_branch)
 
+        return WithScope(None, _on_exit)
 
     def type_of(self, expr):
         """
index 534015f..d9e4f1e 100644
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=invalid-name
+# pylint: disable=invalid-name
 """Utilities for testing and benchmarks"""
 from __future__ import absolute_import as _abs
 import collections
@@ -66,14 +66,9 @@ def _np_randn_from_type(t, scale=1, mean=0):
     return (mean + (scale * np.random.randn(*(int(d) for d in t.shape)))).astype(t.dtype)
 
 
-def check_grad(func,
-               inputs=None,
-               test_inputs=None,
-               eps=1e-6,
-               atol=1e-5,
-               rtol=1e-3,
-               scale=None,
-               mean=0):
+def check_grad(
+    func, inputs=None, test_inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, mean=0
+):
     """Perform numerical gradient checking given a relay function.
 
     Compare analytical gradients to numerical gradients derived from two-sided approximation. Note
@@ -167,10 +162,12 @@ def rand(dtype, *shape):
 
 def count_ops(expr):
     """count number of times a given op is called in the graph"""
+
     class OpCounter(tvm.relay.ExprVisitor):
         """OpCounter"""
+
         def visit_call(self, call):
-            if hasattr(call, 'op'):
+            if hasattr(call, "op"):
                 self.node_counter[call.op.name] += 1
             return super().visit_call(call)
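
check_grad above compares analytical gradients against a two-sided finite-difference estimate. A self-contained sketch of that estimate (an illustrative helper, not part of the commit; eps matches the default in the signature):

import numpy as np

def numerical_grad(f, x, eps=1e-6):
    """Two-sided finite-difference gradient of a scalar function f at x."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        x_plus, x_minus = x.copy(), x.copy()
        x_plus[idx] += eps
        x_minus[idx] -= eps
        grad[idx] = (f(x_plus) - f(x_minus)) / (2 * eps)
    return grad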
 
index 5ddbcb1..a62a91f 100644
@@ -37,11 +37,13 @@ def convert_image(image):
     imagex = np.flip(imagex, 0)
     return imagex
 
+
 def load_image_color(test_image):
     """To load the image using opencv api and do preprocessing."""
     imagex = cv2.imread(test_image)
     return convert_image(imagex)
 
+
 def _letterbox_image(img, w_in, h_in):
     """To get the image in boxed format."""
     imh, imw, imc = img.shape
@@ -60,11 +62,14 @@ def _letterbox_image(img, w_in, h_in):
     resized = convert_image(resized)
     boxed = np.full((imc, h_in, w_in), 0.5, dtype=float)
     _, resizedh, resizedw = resized.shape
-    boxed[:, int((h_in - new_h) / 2)
-          :int((h_in - new_h) / 2) + resizedh, int((w_in - new_w) / 2)
-          :int((w_in - new_w) / 2) + resizedw] = resized
+    boxed[
+        :,
+        int((h_in - new_h) / 2) : int((h_in - new_h) / 2) + resizedh,
+        int((w_in - new_w) / 2) : int((w_in - new_w) / 2) + resizedw,
+    ] = resized
     return boxed
 
+
 def load_image(img, resize_width, resize_height):
     """Load the image and convert to the darknet model format.
     The image processing of darknet is different from normal.
@@ -87,8 +92,10 @@ def load_image(img, resize_width, resize_height):
     imagex = cv2.imread(img)
     return _letterbox_image(imagex, resize_width, resize_height)
 
+
 class LAYERTYPE(object):
     """Darknet LAYERTYPE Class constant."""
+
     CONVOLUTIONAL = 0
     DECONVOLUTIONAL = 1
     CONNECTED = 2
@@ -119,8 +126,10 @@ class LAYERTYPE(object):
     L2NORM = 27
     BLANK = 28
 
+
 class ACTIVATION(object):
     """Darknet ACTIVATION Class constant."""
+
     LOGISTIC = 0
     RELU = 1
     RELIE = 2
@@ -135,9 +144,11 @@ class ACTIVATION(object):
     HARDTAN = 11
     LHTAN = 12
 
+
 __darknetffi__ = FFI()
 
-__darknetffi__.cdef("""
+__darknetffi__.cdef(
+    """
 typedef struct network network;
 typedef struct layer layer;
 
@@ -494,22 +505,36 @@ image load_image_color(char *filename, int w, int h);
 float *network_predict_image(network *net, image im);
 float *network_predict(network *net, float *input);
 network *make_network(int n);
-layer make_convolutional_layer(int batch, int h, int w, int c, int n, int groups, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam);
-layer make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation, int batch_normalize, int adam);
+layer make_convolutional_layer(
+    int batch,
+    int h, int w, int c, int n,
+    int groups, int size, int stride, int padding,
+    ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam);
+layer make_connected_layer(int batch, int inputs, int outputs,
+    ACTIVATION activation, int batch_normalize, int adam);
 layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride, int padding);
 layer make_avgpool_layer(int batch, int w, int h, int c);
 layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2);
 layer make_batchnorm_layer(int batch, int w, int h, int c);
-layer make_reorg_layer(int batch, int w, int h, int c, int stride, int reverse, int flatten, int extra);
+layer make_reorg_layer(
+    int batch, int w, int h, int c,
+    int stride, int reverse, int flatten, int extra);
 layer make_region_layer(int batch, int w, int h, int n, int classes, int coords);
 layer make_softmax_layer(int batch, int inputs, int groups);
-layer make_rnn_layer(int batch, int inputs, int outputs, int steps, ACTIVATION activation, int batch_normalize, int adam);
+layer make_rnn_layer(int batch, int inputs, int outputs,
+    int steps, ACTIVATION activation, int batch_normalize, int adam);
 layer make_yolo_layer(int batch, int w, int h, int n, int total, int *mask, int classes);
-layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int steps, ACTIVATION activation, int batch_normalize);
-layer make_lstm_layer(int batch, int inputs, int outputs, int steps, int batch_normalize, int adam);
-layer make_gru_layer(int batch, int inputs, int outputs, int steps, int batch_normalize, int adam);
+layer make_crnn_layer(
+    int batch, int h, int w, int c,
+    int hidden_filters, int output_filters, int steps,
+    ACTIVATION activation, int batch_normalize);
+layer make_lstm_layer(
+    int batch, int inputs, int outputs, int steps,
+    int batch_normalize, int adam);
+layer make_gru_layer(int batch, int inputs,
+    int outputs, int steps, int batch_normalize, int adam);
 layer make_upsample_layer(int batch, int w, int h, int c, int stride);
 layer make_l2norm_layer(int batch, int inputs);
 void free_network(network *net);
 """
-                   )
+)
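
The cffi cdef above is only re-indented. A hedged sketch of how this FFI object is typically consumed; the shared-library path and image name are assumptions, not part of the diff:

# Load the darknet shared library through the FFI declared above and call one
# of the declared entry points; the .so path depends on the local build.
_darknet_lib = __darknetffi__.dlopen("./libdarknet.so")
img = _darknet_lib.load_image_color(b"dog.jpg", 416, 416)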
index c3db527..04429ae 100644
@@ -30,6 +30,7 @@ from tvm import relay
 from . import layers
 from .init import create_workload
 
+
 def deconv2d(data, ishape, oshape, kshape, layout, name, stride=(2, 2)):
     """a deconv layer that enlarges the feature map"""
     target_shape = (oshape[-2], oshape[-1])
@@ -46,36 +47,48 @@ def deconv2d(data, ishape, oshape, kshape, layout, name, stride=(2, 2)):
     else:
         raise ValueError("Invalid layout: " + layout)
 
-    net = layers.conv2d_transpose(data,
-                                  kernel_size=kshape,
-                                  strides=stride,
-                                  channels=oshape[0],
-                                  padding=(pad_y, pad_x),
-                                  output_padding=(adj_y, adj_x),
-                                  data_layout=layout,
-                                  kernel_layout=kernel_layout,
-                                  name=name)
+    net = layers.conv2d_transpose(
+        data,
+        kernel_size=kshape,
+        strides=stride,
+        channels=oshape[0],
+        padding=(pad_y, pad_x),
+        output_padding=(adj_y, adj_x),
+        data_layout=layout,
+        kernel_layout=kernel_layout,
+        name=name,
+    )
     return net
 
+
 def deconv2d_bn_relu(data, prefix, **kwargs):
     """a block of deconv + batch norm + relu"""
     eps = 1e-5 + 1e-12
     net = deconv2d(data, name="%s_deconv" % prefix, **kwargs)
-    bn_axis = kwargs.get('layout', "NCHW").index('C')
-    net = layers.batch_norm_infer(net, epsilon=eps, scale=False, axis=bn_axis,
-                                  name="%s_batch_norm" % prefix)
+    bn_axis = kwargs.get("layout", "NCHW").index("C")
+    net = layers.batch_norm_infer(
+        net, epsilon=eps, scale=False, axis=bn_axis, name="%s_batch_norm" % prefix
+    )
     net = relay.nn.relu(net)
     return net
 
-def get_net(batch_size, random_len=100, oshape=(3, 64, 64), ngf=128, code=None,
-            layout='NCHW', dtype="float32"):
+
+def get_net(
+    batch_size,
+    random_len=100,
+    oshape=(3, 64, 64),
+    ngf=128,
+    code=None,
+    layout="NCHW",
+    dtype="float32",
+):
     """get net of dcgan generator"""
     assert oshape[-1] == 64, "Only support 64x64 image"
     assert oshape[-2] == 64, "Only support 64x64 image"
 
     code = relay.var("data", dtype=dtype, shape=(batch_size, random_len)) if code is None else code
     dense_weight = relay.var("dense_weight")
-    dense = relay.nn.dense(code, weight=dense_weight, units=4*4*ngf*8)
+    dense = relay.nn.dense(code, weight=dense_weight, units=4 * 4 * ngf * 8)
     relu = relay.nn.relu(dense)
     # 4 x 4
     if layout == "NCHW":
@@ -85,25 +98,50 @@ def get_net(batch_size, random_len=100, oshape=(3, 64, 64), ngf=128, code=None,
     else:
         raise ValueError("Invalid layout: " + layout)
     # 8 x 8
-    dc8 = deconv2d_bn_relu(reshape, ishape=(ngf * 8, 4, 4), oshape=(ngf * 4, 8, 8), kshape=(4, 4),
-                           layout=layout, prefix="g2")
+    dc8 = deconv2d_bn_relu(
+        reshape,
+        ishape=(ngf * 8, 4, 4),
+        oshape=(ngf * 4, 8, 8),
+        kshape=(4, 4),
+        layout=layout,
+        prefix="g2",
+    )
     # 16x16
-    dc16 = deconv2d_bn_relu(dc8, ishape=(ngf * 4, 8, 8), oshape=(ngf * 2, 16, 16), kshape=(4, 4),
-                            layout=layout, prefix="g3")
+    dc16 = deconv2d_bn_relu(
+        dc8,
+        ishape=(ngf * 4, 8, 8),
+        oshape=(ngf * 2, 16, 16),
+        kshape=(4, 4),
+        layout=layout,
+        prefix="g3",
+    )
     # 32x32
-    dc32 = deconv2d_bn_relu(dc16, ishape=(ngf * 2, 16, 16), oshape=(ngf, 32, 32), kshape=(4, 4),
-                            layout=layout, prefix="g4")
+    dc32 = deconv2d_bn_relu(
+        dc16,
+        ishape=(ngf * 2, 16, 16),
+        oshape=(ngf, 32, 32),
+        kshape=(4, 4),
+        layout=layout,
+        prefix="g4",
+    )
     # 64x64
-    dc64 = deconv2d(dc32, ishape=(ngf, 32, 32), oshape=oshape[-3:], kshape=(4, 4),
-                    layout=layout, name="g5_deconv")
+    dc64 = deconv2d(
+        dc32,
+        ishape=(ngf, 32, 32),
+        oshape=oshape[-3:],
+        kshape=(4, 4),
+        layout=layout,
+        name="g5_deconv",
+    )
     tanh = relay.tanh(dc64)
 
     args = relay.analysis.free_vars(tanh)
     return relay.Function(args, tanh)
 
 
-def get_workload(batch_size, oshape=(3, 64, 64), ngf=128, random_len=100,
-                 layout='NCHW', dtype="float32"):
+def get_workload(
+    batch_size, oshape=(3, 64, 64), ngf=128, random_len=100, layout="NCHW", dtype="float32"
+):
     """Get benchmark workload for a DCGAN generator
 
     Parameters
index de140fb..1ceb626 100644
@@ -25,42 +25,54 @@ from tvm import relay
 from . import layers
 from .init import create_workload
 
+
 def _make_dense_layer(data, growth_rate, bn_size, index):
     """Single densenet layer."""
     bn1 = layers.batch_norm_infer(data, name="batch_1_%s" % index)
     relu1 = relay.nn.relu(bn1)
-    conv1 = layers.conv2d(relu1, channels=bn_size * growth_rate,
-                          kernel_size=(1, 1), name="conv2d_1_%s" % index)
+    conv1 = layers.conv2d(
+        relu1, channels=bn_size * growth_rate, kernel_size=(1, 1), name="conv2d_1_%s" % index
+    )
     bn2 = layers.batch_norm_infer(conv1, name="batch_2_" + index)
     relu2 = relay.nn.relu(bn2)
-    conv2 = layers.conv2d(relu2, channels=growth_rate, kernel_size=(3, 3),
-                          padding=(1, 1), name="conv2d_2_%s" % index)
+    conv2 = layers.conv2d(
+        relu2, channels=growth_rate, kernel_size=(3, 3), padding=(1, 1), name="conv2d_2_%s" % index
+    )
     return conv2
 
+
 def _make_dense_block(data, num_layers, bn_size, growth_rate, index):
     """Makes a block of dense layers of the specified size."""
     layer_out = data
     for i in range(num_layers):
-        layer_out = _make_dense_layer(layer_out, growth_rate, bn_size,
-                                      "%s_%s" % (index, i))
+        layer_out = _make_dense_layer(layer_out, growth_rate, bn_size, "%s_%s" % (index, i))
     return layer_out
 
+
 def _make_transition(data, num_output_features, index):
     """Transition between layers."""
     bn = layers.batch_norm_infer(data, name="batch_t_%s" % index)
     relu = relay.nn.relu(bn)
-    conv = layers.conv2d(relu, channels=num_output_features,
-                         kernel_size=(1, 1), name="conv_t_%s" % index)
+    conv = layers.conv2d(
+        relu, channels=num_output_features, kernel_size=(1, 1), name="conv_t_%s" % index
+    )
     return relay.nn.avg_pool2d(conv, pool_size=(2, 2), strides=(2, 2))
 
-def _make_dense_net(num_init_features, growth_rate, block_config,
-                    data_shape, data_dtype, bn_size=4, classes=1000):
+
+def _make_dense_net(
+    num_init_features, growth_rate, block_config, data_shape, data_dtype, bn_size=4, classes=1000
+):
     """Builds up a densenet."""
-    data = relay.Var("data", relay.TensorType(data_shape, data_dtype)) # (bn_size, 3, 224, 224)))
-    conv1 = layers.conv2d(data, channels=num_init_features,
-                          kernel_size=(7, 7), strides=(2, 2), padding=(3, 3),
-                          name='conv1')
-    bn1 = layers.batch_norm_infer(conv1, name='batch1')
+    data = relay.Var("data", relay.TensorType(data_shape, data_dtype))  # (bn_size, 3, 224, 224)))
+    conv1 = layers.conv2d(
+        data,
+        channels=num_init_features,
+        kernel_size=(7, 7),
+        strides=(2, 2),
+        padding=(3, 3),
+        name="conv1",
+    )
+    bn1 = layers.batch_norm_infer(conv1, name="batch1")
     relu1 = relay.nn.relu(bn1)
     mp = relay.nn.max_pool2d(relu1, pool_size=(3, 3), strides=(2, 2), padding=(1, 1))
 
@@ -68,21 +80,23 @@ def _make_dense_net(num_init_features, growth_rate, block_config,
     layer_out = mp
     for i, num_layers in enumerate(block_config):
         layer_out = _make_dense_block(layer_out, num_layers, growth_rate, bn_size, i)
-        num_features = num_features + num_layers*growth_rate
+        num_features = num_features + num_layers * growth_rate
         if i != len(block_config) - 1:
             layer_out = _make_transition(layer_out, num_features // 2, i)
             num_features = num_features // 2
-    bn2 = layers.batch_norm_infer(layer_out, name='batch2')
+    bn2 = layers.batch_norm_infer(layer_out, name="batch2")
     relu2 = relay.nn.relu(bn2)
     avg = relay.nn.avg_pool2d(relu2, pool_size=(7, 7))
     flat = relay.nn.batch_flatten(avg)
 
-    ret = layers.dense_add_bias(flat, units=classes, name='dense')
+    ret = layers.dense_add_bias(flat, units=classes, name="dense")
 
     return relay.Function(relay.analysis.free_vars(ret), ret)
 
-def get_workload(densenet_size=121, classes=1000, batch_size=4,
-                 image_shape=(3, 224, 224), dtype='float32'):
+
+def get_workload(
+    densenet_size=121, classes=1000, batch_size=4, image_shape=(3, 224, 224), dtype="float32"
+):
     """Gets benchmark workload for densenet.
 
     Parameters
@@ -111,13 +125,16 @@ def get_workload(densenet_size=121, classes=1000, batch_size=4,
     params : dict of str to NDArray
         The benchmark parameters.
     """
-    specs = {121: (64, 32, [6, 12, 24, 16]),
-             161: (96, 48, [6, 12, 36, 24]),
-             169: (69, 32, [6, 12, 32, 32]),
-             201: (64, 32, [6, 12, 48, 32])}
+    specs = {
+        121: (64, 32, [6, 12, 24, 16]),
+        161: (96, 48, [6, 12, 36, 24]),
+        169: (69, 32, [6, 12, 32, 32]),
+        201: (64, 32, [6, 12, 48, 32]),
+    }
 
     num_init_features, growth_rate, block_config = specs[densenet_size]
     data_shape = tuple([batch_size] + list(image_shape))
-    net = _make_dense_net(num_init_features, growth_rate, block_config,
-                          data_shape, dtype, batch_size, classes)
+    net = _make_dense_net(
+        num_init_features, growth_rate, block_config, data_shape, dtype, batch_size, classes
+    )
     return create_workload(net)
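
A short usage sketch of the reformatted get_workload; the keyword values restate its defaults, and the relay.testing.densenet module path is an assumption based on the content of these hunks:

from tvm import relay

# densenet_size picks one of the specs entries above (121 / 161 / 169 / 201).
mod, params = relay.testing.densenet.get_workload(densenet_size=121, batch_size=4)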
index 62cf7cd..dd31ab8 100644
@@ -26,31 +26,53 @@ from tvm import relay
 from . import layers
 from .init import create_workload
 
+
 def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32", layout="NCHW"):
     """get symbol of nature dqn"""
     data_shape = (batch_size,) + image_shape
     data = relay.var("data", shape=data_shape, dtype=dtype)
 
-    bias_axis = layout.index('C')
+    bias_axis = layout.index("C")
 
     conv1_bias = relay.var("conv1_bias")
-    conv1 = layers.conv2d(data, kernel_size=(8, 8), strides=(4, 4), padding=(0, 0),
-                          channels=32, name="conv1", data_layout=layout,
-                          kernel_layout=layers.conv_kernel_layout(layout))
+    conv1 = layers.conv2d(
+        data,
+        kernel_size=(8, 8),
+        strides=(4, 4),
+        padding=(0, 0),
+        channels=32,
+        name="conv1",
+        data_layout=layout,
+        kernel_layout=layers.conv_kernel_layout(layout),
+    )
     conv1 = relay.nn.bias_add(conv1, conv1_bias, bias_axis)
     relu1 = relay.nn.relu(conv1)
 
     conv2_bias = relay.var("conv2_bias")
-    conv2 = layers.conv2d(relu1, kernel_size=(4, 4), strides=(2, 2), padding=(0, 0),
-                          channels=64, name="conv2", data_layout=layout,
-                          kernel_layout=layers.conv_kernel_layout(layout))
+    conv2 = layers.conv2d(
+        relu1,
+        kernel_size=(4, 4),
+        strides=(2, 2),
+        padding=(0, 0),
+        channels=64,
+        name="conv2",
+        data_layout=layout,
+        kernel_layout=layers.conv_kernel_layout(layout),
+    )
     conv2 = relay.nn.bias_add(conv2, conv2_bias, bias_axis)
     relu2 = relay.nn.relu(conv2)
 
     conv3_bias = relay.var("conv3_bias")
-    conv3 = layers.conv2d(relu2, kernel_size=(3, 3), strides=(1, 1), padding=(0, 0),
-                          channels=64, name="conv3", data_layout=layout,
-                          kernel_layout=layers.conv_kernel_layout(layout))
+    conv3 = layers.conv2d(
+        relu2,
+        kernel_size=(3, 3),
+        strides=(1, 1),
+        padding=(0, 0),
+        channels=64,
+        name="conv3",
+        data_layout=layout,
+        kernel_layout=layers.conv_kernel_layout(layout),
+    )
     conv3 = relay.nn.bias_add(conv3, conv3_bias, bias_axis)
     relu3 = relay.nn.relu(conv3)
 
@@ -63,8 +85,9 @@ def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"
     return relay.Function(args, dense2)
 
 
-def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32",
-                 layout="NCHW"):
+def get_workload(
+    batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32", layout="NCHW"
+):
     """Get benchmark workload for a Deep Q Network
     Parameters
     ----------
@@ -83,6 +106,7 @@ def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="flo
     params : dict of str to NDArray
         The parameters.
     """
-    net = get_net(batch_size, num_actions=num_actions, image_shape=image_shape, dtype=dtype,
-                  layout=layout)
+    net = get_net(
+        batch_size, num_actions=num_actions, image_shape=image_shape, dtype=dtype, layout=layout
+    )
     return create_workload(net)
index 8a540e5..111cbc0 100644
@@ -29,153 +29,322 @@ from tvm import relay
 from .init import create_workload
 from . import layers
 
-def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None, suffix=''):
+
+def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None, suffix=""):
     conv = layers.conv2d(
         data=data,
         channels=int(num_filter),
         kernel_size=kernel,
         strides=stride,
         padding=pad,
-        name='%s%s_conv1' % (name, suffix))
+        name="%s%s_conv1" % (name, suffix),
+    )
 
-    bn = layers.batch_norm_infer(data=conv, epsilon=2e-5, scale=False,
-                                 name='%s%s_bn' % (name, suffix))
+    bn = layers.batch_norm_infer(
+        data=conv, epsilon=2e-5, scale=False, name="%s%s_bn" % (name, suffix)
+    )
     act = relay.nn.relu(data=bn)
     return act
 
+
 def Pooling(data, kernel, stride, pad, pool_type, name):
-    if pool_type == 'max':
+    if pool_type == "max":
         return relay.nn.max_pool2d(data=data, pool_size=kernel, strides=stride, padding=pad)
-    if pool_type == 'avg':
-        return relay.nn.avg_pool2d(data=data, pool_size=kernel, strides=stride, padding=pad,
-                                   count_include_pad=True)
+    if pool_type == "avg":
+        return relay.nn.avg_pool2d(
+            data=data, pool_size=kernel, strides=stride, padding=pad, count_include_pad=True
+        )
     raise ValueError("Invalid pooling type: " + pool_type)
 
-def Inception7A(data,
-                num_1x1,
-                num_3x3_red, num_3x3_1, num_3x3_2,
-                num_5x5_red, num_5x5,
-                pool, proj,
-                name):
-    tower_1x1 = Conv(data, num_1x1, name=('%s_conv' % name))
-    tower_5x5 = Conv(data, num_5x5_red, name=('%s_tower' % name), suffix='_conv')
-    tower_5x5 = Conv(tower_5x5, num_5x5, kernel=(5, 5), pad=(2, 2), name=('%s_tower' % name),
-                     suffix='_conv_1')
-    tower_3x3 = Conv(data, num_3x3_red, name=('%s_tower_1' % name), suffix='_conv')
-    tower_3x3 = Conv(tower_3x3, num_3x3_1, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name),
-                     suffix='_conv_1')
-    tower_3x3 = Conv(tower_3x3, num_3x3_2, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name),
-                     suffix='_conv_2')
-    pooling = Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool,
-                      name=('%s_pool_%s_pool' % (pool, name)))
-
-    cproj = Conv(pooling, proj, name=('%s_tower_2' % name), suffix='_conv')
+
+def Inception7A(
+    data, num_1x1, num_3x3_red, num_3x3_1, num_3x3_2, num_5x5_red, num_5x5, pool, proj, name
+):
+    tower_1x1 = Conv(data, num_1x1, name=("%s_conv" % name))
+    tower_5x5 = Conv(data, num_5x5_red, name=("%s_tower" % name), suffix="_conv")
+    tower_5x5 = Conv(
+        tower_5x5, num_5x5, kernel=(5, 5), pad=(2, 2), name=("%s_tower" % name), suffix="_conv_1"
+    )
+    tower_3x3 = Conv(data, num_3x3_red, name=("%s_tower_1" % name), suffix="_conv")
+    tower_3x3 = Conv(
+        tower_3x3,
+        num_3x3_1,
+        kernel=(3, 3),
+        pad=(1, 1),
+        name=("%s_tower_1" % name),
+        suffix="_conv_1",
+    )
+    tower_3x3 = Conv(
+        tower_3x3,
+        num_3x3_2,
+        kernel=(3, 3),
+        pad=(1, 1),
+        name=("%s_tower_1" % name),
+        suffix="_conv_2",
+    )
+    pooling = Pooling(
+        data=data,
+        kernel=(3, 3),
+        stride=(1, 1),
+        pad=(1, 1),
+        pool_type=pool,
+        name=("%s_pool_%s_pool" % (pool, name)),
+    )
+
+    cproj = Conv(pooling, proj, name=("%s_tower_2" % name), suffix="_conv")
     concat = relay.concatenate((tower_1x1, tower_5x5, tower_3x3, cproj), axis=1)
     return concat
 
+
 # First Downsample
-def Inception7B(data,
-                num_3x3,
-                num_d3x3_red, num_d3x3_1, num_d3x3_2,
-                pool,
-                name):
-    tower_3x3 = Conv(data, num_3x3, kernel=(3, 3), pad=(0, 0), stride=(2, 2),
-                     name=('%s_conv' % name))
-    tower_d3x3 = Conv(data, num_d3x3_red, name=('%s_tower' % name), suffix='_conv')
-    tower_d3x3 = Conv(tower_d3x3, num_d3x3_1, kernel=(3, 3), pad=(1, 1), stride=(1, 1),
-                      name=('%s_tower' % name), suffix='_conv_1')
-    tower_d3x3 = Conv(tower_d3x3, num_d3x3_2, kernel=(3, 3), pad=(0, 0), stride=(2, 2),
-                      name=('%s_tower' % name), suffix='_conv_2')
-    pooling = Pooling(data=data, kernel=(3, 3), stride=(2, 2), pad=(0, 0), pool_type="max",
-                      name=('max_pool_%s_pool' % name))
+def Inception7B(data, num_3x3, num_d3x3_red, num_d3x3_1, num_d3x3_2, pool, name):
+    tower_3x3 = Conv(
+        data, num_3x3, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=("%s_conv" % name)
+    )
+    tower_d3x3 = Conv(data, num_d3x3_red, name=("%s_tower" % name), suffix="_conv")
+    tower_d3x3 = Conv(
+        tower_d3x3,
+        num_d3x3_1,
+        kernel=(3, 3),
+        pad=(1, 1),
+        stride=(1, 1),
+        name=("%s_tower" % name),
+        suffix="_conv_1",
+    )
+    tower_d3x3 = Conv(
+        tower_d3x3,
+        num_d3x3_2,
+        kernel=(3, 3),
+        pad=(0, 0),
+        stride=(2, 2),
+        name=("%s_tower" % name),
+        suffix="_conv_2",
+    )
+    pooling = Pooling(
+        data=data,
+        kernel=(3, 3),
+        stride=(2, 2),
+        pad=(0, 0),
+        pool_type="max",
+        name=("max_pool_%s_pool" % name),
+    )
     concat = relay.concatenate((tower_3x3, tower_d3x3, pooling), axis=1)
     return concat
 
-def Inception7C(data,
-                num_1x1,
-                num_d7_red, num_d7_1, num_d7_2,
-                num_q7_red, num_q7_1, num_q7_2, num_q7_3, num_q7_4,
-                pool, proj,
-                name):
-    tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name))
-    tower_d7 = Conv(data=data, num_filter=num_d7_red, name=('%s_tower' % name), suffix='_conv')
-    tower_d7 = Conv(data=tower_d7, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3),
-                    name=('%s_tower' % name), suffix='_conv_1')
-    tower_d7 = Conv(data=tower_d7, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0),
-                    name=('%s_tower' % name), suffix='_conv_2')
-    tower_q7 = Conv(data=data, num_filter=num_q7_red, name=('%s_tower_1' % name), suffix='_conv')
-    tower_q7 = Conv(data=tower_q7, num_filter=num_q7_1, kernel=(7, 1), pad=(3, 0),
-                    name=('%s_tower_1' % name), suffix='_conv_1')
-    tower_q7 = Conv(data=tower_q7, num_filter=num_q7_2, kernel=(1, 7), pad=(0, 3),
-                    name=('%s_tower_1' % name), suffix='_conv_2')
-    tower_q7 = Conv(data=tower_q7, num_filter=num_q7_3, kernel=(7, 1), pad=(3, 0),
-                    name=('%s_tower_1' % name), suffix='_conv_3')
-    tower_q7 = Conv(data=tower_q7, num_filter=num_q7_4, kernel=(1, 7), pad=(0, 3),
-                    name=('%s_tower_1' % name), suffix='_conv_4')
-    pooling = Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool,
-                      name=('%s_pool_%s_pool' % (pool, name)))
-    cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1),
-                 name=('%s_tower_2' % name), suffix='_conv')
+
+def Inception7C(
+    data,
+    num_1x1,
+    num_d7_red,
+    num_d7_1,
+    num_d7_2,
+    num_q7_red,
+    num_q7_1,
+    num_q7_2,
+    num_q7_3,
+    num_q7_4,
+    pool,
+    proj,
+    name,
+):
+    tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=("%s_conv" % name))
+    tower_d7 = Conv(data=data, num_filter=num_d7_red, name=("%s_tower" % name), suffix="_conv")
+    tower_d7 = Conv(
+        data=tower_d7,
+        num_filter=num_d7_1,
+        kernel=(1, 7),
+        pad=(0, 3),
+        name=("%s_tower" % name),
+        suffix="_conv_1",
+    )
+    tower_d7 = Conv(
+        data=tower_d7,
+        num_filter=num_d7_2,
+        kernel=(7, 1),
+        pad=(3, 0),
+        name=("%s_tower" % name),
+        suffix="_conv_2",
+    )
+    tower_q7 = Conv(data=data, num_filter=num_q7_red, name=("%s_tower_1" % name), suffix="_conv")
+    tower_q7 = Conv(
+        data=tower_q7,
+        num_filter=num_q7_1,
+        kernel=(7, 1),
+        pad=(3, 0),
+        name=("%s_tower_1" % name),
+        suffix="_conv_1",
+    )
+    tower_q7 = Conv(
+        data=tower_q7,
+        num_filter=num_q7_2,
+        kernel=(1, 7),
+        pad=(0, 3),
+        name=("%s_tower_1" % name),
+        suffix="_conv_2",
+    )
+    tower_q7 = Conv(
+        data=tower_q7,
+        num_filter=num_q7_3,
+        kernel=(7, 1),
+        pad=(3, 0),
+        name=("%s_tower_1" % name),
+        suffix="_conv_3",
+    )
+    tower_q7 = Conv(
+        data=tower_q7,
+        num_filter=num_q7_4,
+        kernel=(1, 7),
+        pad=(0, 3),
+        name=("%s_tower_1" % name),
+        suffix="_conv_4",
+    )
+    pooling = Pooling(
+        data=data,
+        kernel=(3, 3),
+        stride=(1, 1),
+        pad=(1, 1),
+        pool_type=pool,
+        name=("%s_pool_%s_pool" % (pool, name)),
+    )
+    cproj = Conv(
+        data=pooling, num_filter=proj, kernel=(1, 1), name=("%s_tower_2" % name), suffix="_conv"
+    )
     # concat
     concat = relay.concatenate((tower_1x1, tower_d7, tower_q7, cproj), axis=1)
     return concat
 
-def Inception7D(data,
-                num_3x3_red, num_3x3,
-                num_d7_3x3_red, num_d7_1, num_d7_2, num_d7_3x3,
-                pool,
-                name):
-    tower_3x3 = Conv(data=data, num_filter=num_3x3_red, name=('%s_tower' % name),
-                     suffix='_conv')
-    tower_3x3 = Conv(data=tower_3x3, num_filter=num_3x3, kernel=(3, 3), pad=(0, 0), stride=(2, 2),
-                     name=('%s_tower' % name), suffix='_conv_1')
-    tower_d7_3x3 = Conv(data=data, num_filter=num_d7_3x3_red, name=('%s_tower_1' % name),
-                        suffix='_conv')
-    tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3),
-                        name=('%s_tower_1' % name), suffix='_conv_1')
-    tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0),
-                        name=('%s_tower_1' % name), suffix='_conv_2')
-    tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_3x3, kernel=(3, 3), stride=(2, 2),
-                        name=('%s_tower_1' % name), suffix='_conv_3')
-    pooling = Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type=pool, pad=(0, 0),
-                      name=('%s_pool_%s_pool' % (pool, name)))
+
+def Inception7D(
+    data, num_3x3_red, num_3x3, num_d7_3x3_red, num_d7_1, num_d7_2, num_d7_3x3, pool, name
+):
+    tower_3x3 = Conv(data=data, num_filter=num_3x3_red, name=("%s_tower" % name), suffix="_conv")
+    tower_3x3 = Conv(
+        data=tower_3x3,
+        num_filter=num_3x3,
+        kernel=(3, 3),
+        pad=(0, 0),
+        stride=(2, 2),
+        name=("%s_tower" % name),
+        suffix="_conv_1",
+    )
+    tower_d7_3x3 = Conv(
+        data=data, num_filter=num_d7_3x3_red, name=("%s_tower_1" % name), suffix="_conv"
+    )
+    tower_d7_3x3 = Conv(
+        data=tower_d7_3x3,
+        num_filter=num_d7_1,
+        kernel=(1, 7),
+        pad=(0, 3),
+        name=("%s_tower_1" % name),
+        suffix="_conv_1",
+    )
+    tower_d7_3x3 = Conv(
+        data=tower_d7_3x3,
+        num_filter=num_d7_2,
+        kernel=(7, 1),
+        pad=(3, 0),
+        name=("%s_tower_1" % name),
+        suffix="_conv_2",
+    )
+    tower_d7_3x3 = Conv(
+        data=tower_d7_3x3,
+        num_filter=num_d7_3x3,
+        kernel=(3, 3),
+        stride=(2, 2),
+        name=("%s_tower_1" % name),
+        suffix="_conv_3",
+    )
+    pooling = Pooling(
+        data=data,
+        kernel=(3, 3),
+        stride=(2, 2),
+        pool_type=pool,
+        pad=(0, 0),
+        name=("%s_pool_%s_pool" % (pool, name)),
+    )
     # concat
     concat = relay.concatenate((tower_3x3, tower_d7_3x3, pooling), axis=1)
     return concat
 
-def Inception7E(data,
-                num_1x1,
-                num_d3_red, num_d3_1, num_d3_2,
-                num_3x3_d3_red, num_3x3, num_3x3_d3_1, num_3x3_d3_2,
-                pool, proj,
-                name):
-    tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name))
-    tower_d3 = Conv(data=data, num_filter=num_d3_red, name=('%s_tower' % name), suffix='_conv')
-    tower_d3_a = Conv(data=tower_d3, num_filter=num_d3_1, kernel=(1, 3), pad=(0, 1),
-                      name=('%s_tower' % name), suffix='_mixed_conv')
-    tower_d3_b = Conv(data=tower_d3, num_filter=num_d3_2, kernel=(3, 1), pad=(1, 0),
-                      name=('%s_tower' % name), suffix='_mixed_conv_1')
-    tower_3x3_d3 = Conv(data=data, num_filter=num_3x3_d3_red, name=('%s_tower_1' % name),
-                        suffix='_conv')
-    tower_3x3_d3 = Conv(data=tower_3x3_d3, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1),
-                        name=('%s_tower_1' % name), suffix='_conv_1')
-    tower_3x3_d3_a = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_1, kernel=(1, 3), pad=(0, 1),
-                          name=('%s_tower_1' % name), suffix='_mixed_conv')
-    tower_3x3_d3_b = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_2, kernel=(3, 1), pad=(1, 0),
-                          name=('%s_tower_1' % name), suffix='_mixed_conv_1')
-    pooling = Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool,
-                      name=('%s_pool_%s_pool' % (pool, name)))
-    cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' % name),
-                 suffix='_conv')
+
+def Inception7E(
+    data,
+    num_1x1,
+    num_d3_red,
+    num_d3_1,
+    num_d3_2,
+    num_3x3_d3_red,
+    num_3x3,
+    num_3x3_d3_1,
+    num_3x3_d3_2,
+    pool,
+    proj,
+    name,
+):
+    tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=("%s_conv" % name))
+    tower_d3 = Conv(data=data, num_filter=num_d3_red, name=("%s_tower" % name), suffix="_conv")
+    tower_d3_a = Conv(
+        data=tower_d3,
+        num_filter=num_d3_1,
+        kernel=(1, 3),
+        pad=(0, 1),
+        name=("%s_tower" % name),
+        suffix="_mixed_conv",
+    )
+    tower_d3_b = Conv(
+        data=tower_d3,
+        num_filter=num_d3_2,
+        kernel=(3, 1),
+        pad=(1, 0),
+        name=("%s_tower" % name),
+        suffix="_mixed_conv_1",
+    )
+    tower_3x3_d3 = Conv(
+        data=data, num_filter=num_3x3_d3_red, name=("%s_tower_1" % name), suffix="_conv"
+    )
+    tower_3x3_d3 = Conv(
+        data=tower_3x3_d3,
+        num_filter=num_3x3,
+        kernel=(3, 3),
+        pad=(1, 1),
+        name=("%s_tower_1" % name),
+        suffix="_conv_1",
+    )
+    tower_3x3_d3_a = Conv(
+        data=tower_3x3_d3,
+        num_filter=num_3x3_d3_1,
+        kernel=(1, 3),
+        pad=(0, 1),
+        name=("%s_tower_1" % name),
+        suffix="_mixed_conv",
+    )
+    tower_3x3_d3_b = Conv(
+        data=tower_3x3_d3,
+        num_filter=num_3x3_d3_2,
+        kernel=(3, 1),
+        pad=(1, 0),
+        name=("%s_tower_1" % name),
+        suffix="_mixed_conv_1",
+    )
+    pooling = Pooling(
+        data=data,
+        kernel=(3, 3),
+        stride=(1, 1),
+        pad=(1, 1),
+        pool_type=pool,
+        name=("%s_pool_%s_pool" % (pool, name)),
+    )
+    cproj = Conv(
+        data=pooling, num_filter=proj, kernel=(1, 1), name=("%s_tower_2" % name), suffix="_conv"
+    )
     # concat
     concat = relay.concatenate(
-        (tower_1x1, tower_d3_a, tower_d3_b, tower_3x3_d3_a, tower_3x3_d3_b, cproj), axis=1)
+        (tower_1x1, tower_d3_a, tower_d3_b, tower_3x3_d3_a, tower_3x3_d3_b, cproj), axis=1
+    )
     return concat
 
-def get_net(batch_size,
-            num_classes,
-            image_shape,
-            dtype):
+
+def get_net(batch_size, num_classes, image_shape, dtype):
     """Get network a Inception v3 network.
 
     batch_size : int
@@ -196,72 +365,42 @@ def get_net(batch_size,
         The dataflow.
     """
     data_shape = (batch_size,) + image_shape
-    data = relay.var("data",
-                     shape=data_shape,
-                     dtype=dtype)
+    data = relay.var("data", shape=data_shape, dtype=dtype)
 
     # stage 1
     conv = Conv(data, 32, kernel=(3, 3), stride=(2, 2), name="conv")
     conv_1 = Conv(conv, 32, kernel=(3, 3), name="conv_1")
     conv_2 = Conv(conv_1, 64, kernel=(3, 3), pad=(1, 1), name="conv_2")
-    pool = Pooling(data=conv_2, kernel=(3, 3), stride=(2, 2), pool_type="max", pad=(0, 0),
-                   name="pool")
+    pool = Pooling(
+        data=conv_2, kernel=(3, 3), stride=(2, 2), pool_type="max", pad=(0, 0), name="pool"
+    )
     # stage 2
     conv_3 = Conv(pool, 80, kernel=(1, 1), name="conv_3")
     conv_4 = Conv(conv_3, 192, kernel=(3, 3), name="conv_4")
-    pool1 = Pooling(data=conv_4, kernel=(3, 3), stride=(2, 2), pool_type="max", pad=(0, 0),
-                    name="pool1")
+    pool1 = Pooling(
+        data=conv_4, kernel=(3, 3), stride=(2, 2), pool_type="max", pad=(0, 0), name="pool1"
+    )
 
     # stage 3
-    in3a = Inception7A(pool1, 64,
-                       64, 96, 96,
-                       48, 64,
-                       "avg", 32, "mixed")
-
-    in3b = Inception7A(in3a, 64,
-                       64, 96, 96,
-                       48, 64,
-                       "avg", 64, "mixed_1")
-    in3c = Inception7A(in3b, 64,
-                       64, 96, 96,
-                       48, 64,
-                       "avg", 64, "mixed_2")
-    in3d = Inception7B(in3c, 384,
-                       64, 96, 96,
-                       "max", "mixed_3")
+    in3a = Inception7A(pool1, 64, 64, 96, 96, 48, 64, "avg", 32, "mixed")
+
+    in3b = Inception7A(in3a, 64, 64, 96, 96, 48, 64, "avg", 64, "mixed_1")
+    in3c = Inception7A(in3b, 64, 64, 96, 96, 48, 64, "avg", 64, "mixed_2")
+    in3d = Inception7B(in3c, 384, 64, 96, 96, "max", "mixed_3")
     # stage 4
-    in4a = Inception7C(in3d, 192,
-                       128, 128, 192,
-                       128, 128, 128, 128, 192,
-                       "avg", 192, "mixed_4")
-    in4b = Inception7C(in4a, 192,
-                       160, 160, 192,
-                       160, 160, 160, 160, 192,
-                       "avg", 192, "mixed_5")
-    in4c = Inception7C(in4b, 192,
-                       160, 160, 192,
-                       160, 160, 160, 160, 192,
-                       "avg", 192, "mixed_6")
-    in4d = Inception7C(in4c, 192,
-                       192, 192, 192,
-                       192, 192, 192, 192, 192,
-                       "avg", 192, "mixed_7")
-    in4e = Inception7D(in4d, 192, 320,
-                       192, 192, 192, 192,
-                       "max", "mixed_8")
+    in4a = Inception7C(in3d, 192, 128, 128, 192, 128, 128, 128, 128, 192, "avg", 192, "mixed_4")
+    in4b = Inception7C(in4a, 192, 160, 160, 192, 160, 160, 160, 160, 192, "avg", 192, "mixed_5")
+    in4c = Inception7C(in4b, 192, 160, 160, 192, 160, 160, 160, 160, 192, "avg", 192, "mixed_6")
+    in4d = Inception7C(in4c, 192, 192, 192, 192, 192, 192, 192, 192, 192, "avg", 192, "mixed_7")
+    in4e = Inception7D(in4d, 192, 320, 192, 192, 192, 192, "max", "mixed_8")
     # stage 5
-    in5a = Inception7E(in4e, 320,
-                       384, 384, 384,
-                       448, 384, 384, 384,
-                       "avg", 192, "mixed_9")
-    in5b = Inception7E(in5a, 320,
-                       384, 384, 384,
-                       448, 384, 384, 384,
-                       "max", 192, "mixed_10")
+    in5a = Inception7E(in4e, 320, 384, 384, 384, 448, 384, 384, 384, "avg", 192, "mixed_9")
+    in5b = Inception7E(in5a, 320, 384, 384, 384, 448, 384, 384, 384, "max", 192, "mixed_10")
 
     # pool
-    pool = Pooling(data=in5b, kernel=(8, 8), stride=(1, 1), pool_type="avg", pad=(0, 0),
-                   name="global_pool")
+    pool = Pooling(
+        data=in5b, kernel=(8, 8), stride=(1, 1), pool_type="avg", pad=(0, 0), name="global_pool"
+    )
 
     flatten = relay.nn.batch_flatten(pool)
     fc1 = relay.nn.dense(flatten, relay.var("fc1_weight"), units=num_classes)
@@ -270,8 +409,8 @@ def get_net(batch_size,
     args = relay.analysis.free_vars(inception_v3)
     return relay.Function(args, inception_v3)
 
-def get_workload(batch_size=1, num_classes=1000,
-                 image_shape=(3, 299, 299), dtype="float32"):
+
+def get_workload(batch_size=1, num_classes=1000, image_shape=(3, 299, 299), dtype="float32"):
     """Get benchmark workload for InceptionV3
 
     Parameters
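
For context on how the network definition reformatted above is consumed, a minimal usage sketch follows. It assumes the file shown is tvm/relay/testing/inception_v3.py (so it imports as tvm.relay.testing.inception_v3) and that get_workload, like the other testing workload helpers later in this diff, returns a (module, params) pair built by create_workload:

    from tvm.relay.testing import inception_v3

    # Defaults as declared above: batch_size=1, num_classes=1000,
    # image_shape=(3, 299, 299), dtype="float32".
    mod, params = inception_v3.get_workload(batch_size=1)
    print(mod)          # the Relay IRModule holding the Inception v3 graph
    print(len(params))  # number of randomly initialized weight/bias tensors
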
index 6b8adf3..1d4d8d9 100644 (file)
@@ -24,6 +24,7 @@ from tvm import relay
 
 class Initializer(object):
     """The base class of an initializer."""
+
     def __init__(self, **kwargs):
         self._kwargs = kwargs
 
@@ -38,17 +39,17 @@ class Initializer(object):
         arr : NDArray
             The array to be initialized.
         """
-        if desc.endswith('weight'):
+        if desc.endswith("weight"):
             self._init_weight(desc, arr)
-        elif desc.endswith('bias'):
+        elif desc.endswith("bias"):
             self._init_bias(desc, arr)
-        elif desc.endswith('gamma'):
+        elif desc.endswith("gamma"):
             self._init_gamma(desc, arr)
-        elif desc.endswith('beta'):
+        elif desc.endswith("beta"):
             self._init_beta(desc, arr)
-        elif desc.endswith('mean'):
+        elif desc.endswith("mean"):
             self._init_mean(desc, arr)
-        elif desc.endswith('var'):
+        elif desc.endswith("var"):
             self._init_var(desc, arr)
         else:
             self._init_default(desc, arr)
@@ -74,10 +75,11 @@ class Initializer(object):
 
     def _init_default(self, name, _):
         raise ValueError(
-            'Unknown initialization pattern for %s. ' \
-            'Default initialization is now limited to '\
-            '"weight", "bias", "gamma" (1.0), and "beta" (0.0).' \
-            'Please use mx.sym.Variable(init=mx.init.*) to set initialization pattern' % name)
+            "Unknown initialization pattern for %s. "
+            "Default initialization is now limited to "
+            '"weight", "bias", "gamma" (1.0), and "beta" (0.0).'
+            "Please use mx.sym.Variable(init=mx.init.*) to set initialization pattern" % name
+        )
 
 
 class Xavier(Initializer):
@@ -94,24 +96,27 @@ class Xavier(Initializer):
     magnitude: float, optional
         Scale of random number.
     """
+
     def __init__(self, rnd_type="uniform", factor_type="avg", magnitude=3):
-        super(Xavier, self).__init__(rnd_type=rnd_type,
-                                     factor_type=factor_type,
-                                     magnitude=magnitude)
+        super(Xavier, self).__init__(
+            rnd_type=rnd_type, factor_type=factor_type, magnitude=magnitude
+        )
         self.rnd_type = rnd_type
         self.factor_type = factor_type
         self.magnitude = float(magnitude)
 
     def _init_weight(self, name, arr):
         shape = arr.shape
-        hw_scale = 1.
+        hw_scale = 1.0
         if len(shape) < 2:
-            raise ValueError('Xavier initializer cannot be applied to vector {0}. It requires at'
-                             ' least 2D.'.format(name))
+            raise ValueError(
+                "Xavier initializer cannot be applied to vector {0}. It requires at"
+                " least 2D.".format(name)
+            )
         if len(shape) > 2:
             hw_scale = np.prod(shape[2:])
         fan_in, fan_out = shape[1] * hw_scale, shape[0] * hw_scale
-        factor = 1.
+        factor = 1.0
         if self.factor_type == "avg":
             factor = (fan_in + fan_out) / 2.0
         elif self.factor_type == "in":
@@ -131,11 +136,11 @@ class Xavier(Initializer):
 
 
 class Constant(Initializer):
-    """ Constant initialization of weights. Sum of weights in the matrix is 1.
-    """
+    """Constant initialization of weights. Sum of weights in the matrix is 1."""
+
     def _init_weight(self, name, arr):
-        num_elements = reduce(lambda x, y: x*y, arr.shape)
-        arr[:] = 1./num_elements
+        num_elements = reduce(lambda x, y: x * y, arr.shape)
+        arr[:] = 1.0 / num_elements
 
 
 def create_workload(net, initializer=None, seed=0):
@@ -162,8 +167,7 @@ def create_workload(net, initializer=None, seed=0):
     """
     mod = tvm.IRModule.from_expr(net)
     mod = relay.transform.InferType()(mod)
-    shape_dict = {
-        v.name_hint : v.checked_type for v in mod["main"].params}
+    shape_dict = {v.name_hint: v.checked_type for v in mod["main"].params}
     np.random.seed(seed)
     initializer = initializer if initializer else Xavier()
     params = {}
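
To illustrate the initializer API reformatted above, here is a small sketch. It assumes the dispatch method shown is the initializer's __call__, that the file is tvm/relay/testing/init.py, and that a plain NumPy array is an acceptable stand-in for the NDArray argument since the weight initializer fills the array in place:

    import numpy as np
    from tvm.relay.testing.init import Xavier

    # Names ending in "weight" are routed to _init_weight, which fills the
    # array with Xavier-scaled uniform noise derived from fan_in/fan_out.
    init = Xavier(rnd_type="uniform", factor_type="avg", magnitude=3)
    weight = np.zeros((64, 3, 3, 3), dtype="float32")
    init("conv1_weight", weight)
    print(float(weight.std()))
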
index 5d46b32..48003f2 100644 (file)
 """Simple Layer DSL wrapper to ease creation of neural nets."""
 from tvm import relay
 
-def batch_norm_infer(data,
-                     gamma=None,
-                     beta=None,
-                     moving_mean=None,
-                     moving_var=None,
-                     **kwargs):
+
+def batch_norm_infer(data, gamma=None, beta=None, moving_mean=None, moving_var=None, **kwargs):
     """Wrapper of batch_norm.
 
     This function automatically creates weights and returns
@@ -63,12 +59,9 @@ def batch_norm_infer(data,
         moving_mean = relay.var(name + "_moving_mean")
     if not moving_var:
         moving_var = relay.var(name + "_moving_var")
-    return relay.nn.batch_norm(data,
-                               gamma=gamma,
-                               beta=beta,
-                               moving_mean=moving_mean,
-                               moving_var=moving_var,
-                               **kwargs)[0]
+    return relay.nn.batch_norm(
+        data, gamma=gamma, beta=beta, moving_mean=moving_mean, moving_var=moving_var, **kwargs
+    )[0]
 
 
 def conv2d(data, weight=None, **kwargs):
@@ -96,6 +89,7 @@ def conv2d(data, weight=None, **kwargs):
         weight = relay.var(name + "_weight")
     return relay.nn.conv2d(data, weight, **kwargs)
 
+
 def conv3d(data, weight=None, **kwargs):
     """Wrapper of conv3d which automatically creates weights if not given.
     Parameters
@@ -117,6 +111,7 @@ def conv3d(data, weight=None, **kwargs):
         weight = relay.var(name + "_weight")
     return relay.nn.conv3d(data, weight, **kwargs)
 
+
 def conv2d_transpose(data, weight=None, **kwargs):
     """Wrapper of conv2d_transpose which automatically creates weights if not given.
 
@@ -142,6 +137,7 @@ def conv2d_transpose(data, weight=None, **kwargs):
         weight = relay.var(name + "_weight")
     return relay.nn.conv2d_transpose(data, weight, **kwargs)
 
+
 def dense_add_bias(data, weight=None, bias=None, units=None, **kwargs):
     """Wrapper of dense which automatically creates weights if not given.
 
@@ -174,6 +170,7 @@ def dense_add_bias(data, weight=None, bias=None, units=None, **kwargs):
     data = relay.nn.bias_add(data, bias, axis=-1)
     return data
 
+
 def conv_kernel_layout(data_layout, is_depthwise=False):
     """Map the data layout to corresponding kernel layout.
 
@@ -193,12 +190,12 @@ def conv_kernel_layout(data_layout, is_depthwise=False):
         The corresponding kernel layout.
     """
     conv_layout_map = {
-        'NCHW': 'OIHW',
-        'NHWC': 'HWIO',
+        "NCHW": "OIHW",
+        "NHWC": "HWIO",
     }
     depthwise_conv_layout_map = {
-        'NCHW': 'OIHW',
-        'NHWC': 'HWOI',
+        "NCHW": "OIHW",
+        "NHWC": "HWOI",
     }
     mapping = depthwise_conv_layout_map if is_depthwise else conv_layout_map
     assert data_layout in mapping, "Unknown data layout %s" % data_layout
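
As a quick illustration of the layer helpers reformatted above (a sketch only; it assumes the file is tvm/relay/testing/layers.py and mirrors how mobilenet.py's conv_block, later in this diff, calls these wrappers):

    from tvm import relay
    from tvm.relay.testing import layers

    data = relay.var("data", shape=(1, 3, 224, 224), dtype="float32")
    conv = layers.conv2d(
        data,
        channels=64,
        kernel_size=(3, 3),
        padding=(1, 1),
        data_layout="NCHW",
        kernel_layout=layers.conv_kernel_layout("NCHW"),  # -> "OIHW"
        name="example_conv",
    )
    bn = layers.batch_norm_infer(data=conv, epsilon=1e-5, name="example_bn")
    act = relay.nn.relu(bn)
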
index 2480d15..8a97c18 100644 (file)
@@ -26,6 +26,7 @@ from tvm import relay
 from . import layers
 from .init import create_workload
 
+
 def lstm_cell(num_hidden, batch_size=1, dtype="float32", name=""):
     """Long-Short Term Memory (LSTM) network cell.
 
@@ -49,18 +50,15 @@ def lstm_cell(num_hidden, batch_size=1, dtype="float32", name=""):
     builder = relay.ScopeBuilder()
 
     input_type = relay.TensorType((batch_size, num_hidden), dtype)
-    weight_type = relay.TensorType((4*num_hidden, num_hidden), dtype)
-    bias_type = relay.TensorType((4*num_hidden,), dtype)
+    weight_type = relay.TensorType((4 * num_hidden, num_hidden), dtype)
+    bias_type = relay.TensorType((4 * num_hidden,), dtype)
 
-    dense_type = relay.TensorType((batch_size, 4*num_hidden), dtype)
-    slice_type = relay.TupleType([input_type, input_type,
-                                  input_type, input_type])
-    ret_type = relay.TupleType([input_type,
-                                relay.TupleType([input_type, input_type])])
+    dense_type = relay.TensorType((batch_size, 4 * num_hidden), dtype)
+    slice_type = relay.TupleType([input_type, input_type, input_type, input_type])
+    ret_type = relay.TupleType([input_type, relay.TupleType([input_type, input_type])])
 
     inputs = relay.Var("inputs", input_type)
-    states = relay.Var("states",
-                       relay.TupleType([input_type, input_type]))
+    states = relay.Var("states", relay.TupleType([input_type, input_type]))
 
     i2h_weight = relay.Var("i2h_weight", weight_type)
     i2h_bias = relay.Var("i2h_bias", bias_type)
@@ -68,66 +66,72 @@ def lstm_cell(num_hidden, batch_size=1, dtype="float32", name=""):
     h2h_weight = relay.Var("h2h_weight", weight_type)
     h2h_bias = relay.Var("h2h_bias", bias_type)
 
-    i2h = builder.let(("i2h", dense_type),
-                      layers.dense_add_bias(
-                          data=inputs,
-                          units=num_hidden * 4,
-                          weight=i2h_weight, bias=i2h_bias,
-                          name="%si2h" % name))
-    h2h = builder.let(("h2h", dense_type),
-                      layers.dense_add_bias(
-                          data=relay.TupleGetItem(states, 0),
-                          units=num_hidden * 4,
-                          weight=h2h_weight, bias=h2h_bias,
-                          name="%sh2h" % name))
+    i2h = builder.let(
+        ("i2h", dense_type),
+        layers.dense_add_bias(
+            data=inputs, units=num_hidden * 4, weight=i2h_weight, bias=i2h_bias, name="%si2h" % name
+        ),
+    )
+    h2h = builder.let(
+        ("h2h", dense_type),
+        layers.dense_add_bias(
+            data=relay.TupleGetItem(states, 0),
+            units=num_hidden * 4,
+            weight=h2h_weight,
+            bias=h2h_bias,
+            name="%sh2h" % name,
+        ),
+    )
 
     gates = builder.let(("gates", dense_type), relay.add(i2h, h2h))
-    slice_gates = builder.let(("slice_gates", slice_type),
-                              relay.split(gates,
-                                          indices_or_sections=4,
-                                          axis=1).astuple())
-
-    in_gate = builder.let(("in_gate", input_type),
-                          relay.sigmoid(relay.TupleGetItem(slice_gates, 0)))
-    forget_gate = builder.let(("forget_gate", input_type),
-                              relay.sigmoid(relay.TupleGetItem(slice_gates, 1)))
-    in_transform = builder.let(("in_transform", input_type),
-                               relay.tanh(relay.TupleGetItem(slice_gates, 2)))
-    out_gate = builder.let(("out_gate", input_type),
-                           relay.sigmoid(relay.TupleGetItem(slice_gates, 3)))
-
-    next_c = builder.let(("next_c", input_type),
-                         relay.add(relay.multiply(forget_gate,
-                                                  relay.TupleGetItem(states, 1)),
-                                   relay.multiply(in_gate, in_transform)))
-    next_h = builder.let(("next_h", input_type),
-                         relay.multiply(out_gate, relay.tanh(next_c)))
-    ret = builder.let(("ret", ret_type),
-                      relay.Tuple([next_h, relay.Tuple([next_h, next_c])]))
+    slice_gates = builder.let(
+        ("slice_gates", slice_type), relay.split(gates, indices_or_sections=4, axis=1).astuple()
+    )
+
+    in_gate = builder.let(
+        ("in_gate", input_type), relay.sigmoid(relay.TupleGetItem(slice_gates, 0))
+    )
+    forget_gate = builder.let(
+        ("forget_gate", input_type), relay.sigmoid(relay.TupleGetItem(slice_gates, 1))
+    )
+    in_transform = builder.let(
+        ("in_transform", input_type), relay.tanh(relay.TupleGetItem(slice_gates, 2))
+    )
+    out_gate = builder.let(
+        ("out_gate", input_type), relay.sigmoid(relay.TupleGetItem(slice_gates, 3))
+    )
+
+    next_c = builder.let(
+        ("next_c", input_type),
+        relay.add(
+            relay.multiply(forget_gate, relay.TupleGetItem(states, 1)),
+            relay.multiply(in_gate, in_transform),
+        ),
+    )
+    next_h = builder.let(("next_h", input_type), relay.multiply(out_gate, relay.tanh(next_c)))
+    ret = builder.let(("ret", ret_type), relay.Tuple([next_h, relay.Tuple([next_h, next_c])]))
     builder.ret(ret)
 
     body = builder.get()
 
-    return relay.Function([inputs, states, i2h_weight,
-                           i2h_bias, h2h_weight, h2h_bias],
-                          body, ret_type)
+    return relay.Function(
+        [inputs, states, i2h_weight, i2h_bias, h2h_weight, h2h_bias], body, ret_type
+    )
 
 
 def get_net(iterations, num_hidden, batch_size=1, dtype="float32"):
-    '''Constructs an unrolled RNN with LSTM cells'''
+    """Constructs an unrolled RNN with LSTM cells"""
     input_type = relay.TensorType((batch_size, num_hidden), dtype)
-    weight_type = relay.TensorType((4*num_hidden, num_hidden), dtype)
-    bias_type = relay.TensorType((4*num_hidden,), dtype)
+    weight_type = relay.TensorType((4 * num_hidden, num_hidden), dtype)
+    bias_type = relay.TensorType((4 * num_hidden,), dtype)
 
     state_type = relay.TupleType([input_type, input_type])
     cell_type = relay.TupleType([input_type, state_type])
 
     builder = relay.ScopeBuilder()
 
-    zeros = builder.let(("zeros", input_type),
-                        relay.zeros((batch_size, num_hidden), dtype))
-    init_states = builder.let(("init_states", state_type),
-                              relay.Tuple([zeros, zeros]))
+    zeros = builder.let(("zeros", input_type), relay.zeros((batch_size, num_hidden), dtype))
+    init_states = builder.let(("init_states", state_type), relay.Tuple([zeros, zeros]))
 
     states = init_states
     out = None
@@ -141,14 +145,12 @@ def get_net(iterations, num_hidden, batch_size=1, dtype="float32"):
 
         cell_fn = lstm_cell(num_hidden, batch_size, dtype, "lstm_%s" % i)
 
-        call = builder.let(("call_%s" % i, cell_type),
-                           relay.Call(cell_fn,
-                                      [inputs, states, i2h_weight,
-                                       i2h_bias, h2h_weight, h2h_bias]))
-        new_out = builder.let(("out_%s" % i, input_type),
-                              relay.TupleGetItem(call, 0))
-        new_states = builder.let(("states_%s" % i, state_type),
-                                 relay.TupleGetItem(call, 1))
+        call = builder.let(
+            ("call_%s" % i, cell_type),
+            relay.Call(cell_fn, [inputs, states, i2h_weight, i2h_bias, h2h_weight, h2h_bias]),
+        )
+        new_out = builder.let(("out_%s" % i, input_type), relay.TupleGetItem(call, 0))
+        new_states = builder.let(("states_%s" % i, state_type), relay.TupleGetItem(call, 1))
         states = new_states
         out = new_out
 
index d118731..ac2d422 100644 (file)
@@ -21,10 +21,8 @@ from __future__ import absolute_import
 from tvm import relay
 from .init import create_workload
 
-def get_net(batch_size,
-            num_classes=10,
-            image_shape=(1, 28, 28),
-            dtype="float32"):
+
+def get_net(batch_size, num_classes=10, image_shape=(1, 28, 28), dtype="float32"):
     """Get network a simple multilayer perceptron.
 
     batch_size : int
@@ -45,9 +43,7 @@ def get_net(batch_size,
         The dataflow.
     """
     data_shape = (batch_size,) + image_shape
-    data = relay.var("data",
-                     shape=data_shape,
-                     dtype=dtype)
+    data = relay.var("data", shape=data_shape, dtype=dtype)
     data = relay.nn.batch_flatten(data)
     fc1 = relay.nn.dense(data, relay.var("fc1_weight"), units=128)
     fc1 = relay.nn.bias_add(fc1, relay.var("fc1_bias"), axis=-1)
@@ -62,10 +58,7 @@ def get_net(batch_size,
     return relay.Function(args, mlp)
 
 
-def get_workload(batch_size,
-                 num_classes=10,
-                 image_shape=(1, 28, 28),
-                 dtype="float32"):
+def get_workload(batch_size, num_classes=10, image_shape=(1, 28, 28), dtype="float32"):
     """Get benchmark workload for a simple multilayer perceptron.
 
     Parameters
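
A minimal sketch of exercising the MLP defined above (assuming the file is tvm/relay/testing/mlp.py and that get_workload wraps get_net with create_workload, as the sibling networks in this diff do):

    from tvm.relay.testing import mlp

    # 28x28 single-channel inputs and 10 classes, matching the defaults above.
    mod, params = mlp.get_workload(batch_size=4)
    print(mod["main"].params)  # the data var plus the dense-layer weights and biases
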
index e374e1b..0b5593e 100644 (file)
@@ -24,8 +24,16 @@ from . import layers
 from .init import create_workload
 
 
-def conv_block(data, name, channels, kernel_size=(3, 3), strides=(1, 1),
-               padding=(1, 1), epsilon=1e-5, layout='NCHW'):
+def conv_block(
+    data,
+    name,
+    channels,
+    kernel_size=(3, 3),
+    strides=(1, 1),
+    padding=(1, 1),
+    epsilon=1e-5,
+    layout="NCHW",
+):
     """Helper function to construct conv_bn-relu"""
     # convolution + bn + relu
     conv = layers.conv2d(
@@ -36,15 +44,25 @@ def conv_block(data, name, channels, kernel_size=(3, 3), strides=(1, 1),
         padding=padding,
         data_layout=layout,
         kernel_layout=layers.conv_kernel_layout(layout),
-        name=name+'_conv')
-    bn = layers.batch_norm_infer(data=conv, epsilon=epsilon, name=name + '_bn')
+        name=name + "_conv",
+    )
+    bn = layers.batch_norm_infer(data=conv, epsilon=epsilon, name=name + "_bn")
     act = relay.nn.relu(data=bn)
     return act
 
 
-def separable_conv_block(data, name, depthwise_channels, pointwise_channels,
-                         kernel_size=(3, 3), downsample=False, padding=(1, 1),
-                         epsilon=1e-5, layout='NCHW', dtype="float32"):
+def separable_conv_block(
+    data,
+    name,
+    depthwise_channels,
+    pointwise_channels,
+    kernel_size=(3, 3),
+    downsample=False,
+    padding=(1, 1),
+    epsilon=1e-5,
+    layout="NCHW",
+    dtype="float32",
+):
     """Helper function to get a separable conv block"""
     if downsample:
         strides = (2, 2)
@@ -54,11 +72,11 @@ def separable_conv_block(data, name, depthwise_channels, pointwise_channels,
     # depthwise convolution + bn + relu
     if layout == "NCHW":
         wshape = (depthwise_channels, 1) + kernel_size
-    elif layout == 'NHWC':
+    elif layout == "NHWC":
         wshape = kernel_size + (depthwise_channels, 1)
     else:
         raise ValueError("Invalid layout: " + layout)
-    bn_axis = layout.index('C')
+    bn_axis = layout.index("C")
     weight = relay.var(name + "_weight", shape=wshape, dtype=dtype)
     conv1 = layers.conv2d(
         data=data,
@@ -70,8 +88,9 @@ def separable_conv_block(data, name, depthwise_channels, pointwise_channels,
         padding=padding,
         data_layout=layout,
         kernel_layout=layers.conv_kernel_layout(layout, True),
-        name=name+'_depthwise_conv1')
-    bn1 = layers.batch_norm_infer(data=conv1, epsilon=epsilon, axis=bn_axis, name=name+'_bn1')
+        name=name + "_depthwise_conv1",
+    )
+    bn1 = layers.batch_norm_infer(data=conv1, epsilon=epsilon, axis=bn_axis, name=name + "_bn1")
     act1 = relay.nn.relu(data=bn1)
     # pointwise convolution + bn + relu
     conv2 = layers.conv2d(
@@ -82,66 +101,129 @@ def separable_conv_block(data, name, depthwise_channels, pointwise_channels,
         padding=(0, 0),
         data_layout=layout,
         kernel_layout=layers.conv_kernel_layout(layout),
-        name=name + '_conv2')
-    bn2 = layers.batch_norm_infer(data=conv2, epsilon=epsilon, axis=bn_axis, name=name+'_bn2')
+        name=name + "_conv2",
+    )
+    bn2 = layers.batch_norm_infer(data=conv2, epsilon=epsilon, axis=bn_axis, name=name + "_bn2")
     act2 = relay.nn.relu(data=bn2)
     return act2
 
 
-def mobile_net(num_classes=1000, data_shape=(1, 3, 224, 224),
-               dtype='float32', alpha=1.0, is_shallow=False, layout='NCHW'):
+def mobile_net(
+    num_classes=1000,
+    data_shape=(1, 3, 224, 224),
+    dtype="float32",
+    alpha=1.0,
+    is_shallow=False,
+    layout="NCHW",
+):
     """Function to construct a MobileNet"""
     data = relay.var("data", shape=data_shape, dtype=dtype)
-    body = conv_block(data, 'conv_block_1', int(32*alpha), strides=(2, 2),
-                      layout=layout)
-    body = separable_conv_block(body, 'separable_conv_block_1',
-                                int(32*alpha), int(64*alpha), layout=layout,
-                                dtype=dtype)
-    body = separable_conv_block(body, 'separable_conv_block_2',
-                                int(64*alpha), int(128*alpha), downsample=True,
-                                layout=layout, dtype=dtype)
-    body = separable_conv_block(body, 'separable_conv_block_3',
-                                int(128*alpha), int(128*alpha), layout=layout,
-                                dtype=dtype)
-    body = separable_conv_block(body, 'separable_conv_block_4',
-                                int(128*alpha), int(256*alpha), downsample=True,
-                                layout=layout, dtype=dtype)
-    body = separable_conv_block(body, 'separable_conv_block_5',
-                                int(256*alpha), int(256*alpha), layout=layout,
-                                dtype=dtype)
-    body = separable_conv_block(body, 'separable_conv_block_6',
-                                int(256*alpha), int(512*alpha), downsample=True,
-                                layout=layout, dtype=dtype)
+    body = conv_block(data, "conv_block_1", int(32 * alpha), strides=(2, 2), layout=layout)
+    body = separable_conv_block(
+        body, "separable_conv_block_1", int(32 * alpha), int(64 * alpha), layout=layout, dtype=dtype
+    )
+    body = separable_conv_block(
+        body,
+        "separable_conv_block_2",
+        int(64 * alpha),
+        int(128 * alpha),
+        downsample=True,
+        layout=layout,
+        dtype=dtype,
+    )
+    body = separable_conv_block(
+        body,
+        "separable_conv_block_3",
+        int(128 * alpha),
+        int(128 * alpha),
+        layout=layout,
+        dtype=dtype,
+    )
+    body = separable_conv_block(
+        body,
+        "separable_conv_block_4",
+        int(128 * alpha),
+        int(256 * alpha),
+        downsample=True,
+        layout=layout,
+        dtype=dtype,
+    )
+    body = separable_conv_block(
+        body,
+        "separable_conv_block_5",
+        int(256 * alpha),
+        int(256 * alpha),
+        layout=layout,
+        dtype=dtype,
+    )
+    body = separable_conv_block(
+        body,
+        "separable_conv_block_6",
+        int(256 * alpha),
+        int(512 * alpha),
+        downsample=True,
+        layout=layout,
+        dtype=dtype,
+    )
     if is_shallow:
-        body = separable_conv_block(body, 'separable_conv_block_7',
-                                    int(512*alpha), int(1024*alpha),
-                                    downsample=True, layout=layout, dtype=dtype)
-        body = separable_conv_block(body, 'separable_conv_block_8',
-                                    int(1024*alpha), int(1024*alpha),
-                                    downsample=True, layout=layout, dtype=dtype)
+        body = separable_conv_block(
+            body,
+            "separable_conv_block_7",
+            int(512 * alpha),
+            int(1024 * alpha),
+            downsample=True,
+            layout=layout,
+            dtype=dtype,
+        )
+        body = separable_conv_block(
+            body,
+            "separable_conv_block_8",
+            int(1024 * alpha),
+            int(1024 * alpha),
+            downsample=True,
+            layout=layout,
+            dtype=dtype,
+        )
     else:
         for i in range(7, 12):
-            body = separable_conv_block(body, 'separable_conv_block_%d' % i,
-                                        int(512*alpha), int(512*alpha),
-                                        layout=layout, dtype=dtype)
-        body = separable_conv_block(body, 'separable_conv_block_12',
-                                    int(512*alpha), int(1024*alpha),
-                                    downsample=True, layout=layout, dtype=dtype)
-        body = separable_conv_block(body, 'separable_conv_block_13',
-                                    int(1024*alpha), int(1024*alpha),
-                                    layout=layout, dtype=dtype)
+            body = separable_conv_block(
+                body,
+                "separable_conv_block_%d" % i,
+                int(512 * alpha),
+                int(512 * alpha),
+                layout=layout,
+                dtype=dtype,
+            )
+        body = separable_conv_block(
+            body,
+            "separable_conv_block_12",
+            int(512 * alpha),
+            int(1024 * alpha),
+            downsample=True,
+            layout=layout,
+            dtype=dtype,
+        )
+        body = separable_conv_block(
+            body,
+            "separable_conv_block_13",
+            int(1024 * alpha),
+            int(1024 * alpha),
+            layout=layout,
+            dtype=dtype,
+        )
     pool = relay.nn.global_avg_pool2d(data=body, layout=layout)
     flatten = relay.nn.batch_flatten(data=pool)
-    weight = relay.var('fc_weight')
-    bias = relay.var('fc_bias')
+    weight = relay.var("fc_weight")
+    bias = relay.var("fc_bias")
     fc = relay.nn.dense(data=flatten, weight=weight, units=num_classes)
     fc = relay.nn.bias_add(fc, bias)
     softmax = relay.nn.softmax(data=fc)
     return relay.Function(relay.analysis.free_vars(softmax), softmax)
 
 
-def get_workload(batch_size=1, num_classes=1000, image_shape=(3, 224, 224),
-                 dtype='float32', layout='NCHW'):
+def get_workload(
+    batch_size=1, num_classes=1000, image_shape=(3, 224, 224), dtype="float32", layout="NCHW"
+):
     """Get benchmark workload for mobilenet
 
     Parameters
@@ -171,7 +253,12 @@ def get_workload(batch_size=1, num_classes=1000, image_shape=(3, 224, 224),
         The parameters.
     """
     data_shape = tuple([batch_size] + list(image_shape))
-    net = mobile_net(num_classes=num_classes, data_shape=data_shape,
-                     dtype=dtype, alpha=1.0, is_shallow=False,
-                     layout=layout)
+    net = mobile_net(
+        num_classes=num_classes,
+        data_shape=data_shape,
+        dtype=dtype,
+        alpha=1.0,
+        is_shallow=False,
+        layout=layout,
+    )
     return create_workload(net)
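
The get_workload wrapper above can be exercised as follows; since its last line returns create_workload(net), a (module, params) pair is expected (a sketch, assuming the file is tvm/relay/testing/mobilenet.py; defaults: NCHW layout, alpha=1.0):

    from tvm.relay.testing import mobilenet

    mod, params = mobilenet.get_workload(batch_size=1)
    print(mod["main"].checked_type)  # inferred function type of the network
    print(len(params))               # number of initialized weight/bias tensors
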
index 1ca456b..b694f63 100644 (file)
@@ -25,6 +25,7 @@ from tvm.relay.expr import Var, GlobalVar
 from tvm.relay.function import Function
 from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType
 
+
 def define_nat_adt(prelude):
     """Defines a Peano (unary) natural number ADT.
     Zero is represented by z(). s(n) adds 1 to a nat n.
@@ -46,8 +47,9 @@ def define_nat_double(prelude):
     x = Var("x", prelude.nat())
     y = Var("y")
     z_case = Clause(PatternConstructor(prelude.z), prelude.z())
-    s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]),
-                    prelude.s(prelude.s(prelude.double(y))))
+    s_case = Clause(
+        PatternConstructor(prelude.s, [PatternVar(y)]), prelude.s(prelude.s(prelude.double(y)))
+    )
     prelude.mod[prelude.double] = Function([x], Match(x, [z_case, s_case]))
 
 
@@ -60,13 +62,13 @@ def define_nat_add(prelude):
     y = Var("y", prelude.nat())
     a = Var("a")
     z_case = Clause(PatternConstructor(prelude.z), y)
-    s_case = Clause(PatternConstructor(prelude.s, [PatternVar(a)]),
-                    prelude.s(prelude.add(a, y)))
+    s_case = Clause(PatternConstructor(prelude.s, [PatternVar(a)]), prelude.s(prelude.add(a, y)))
     prelude.mod[prelude.add] = Function([x, y], Match(x, [z_case, s_case]))
 
 
 # versions of prelude functions that use nats instead of scalars
 
+
 def define_nat_nth(prelude):
     """Defines a function to get the nth eleemnt of a list using
     a nat to index into the list.
@@ -80,12 +82,11 @@ def define_nat_nth(prelude):
     y = Var("y")
 
     z_case = Clause(PatternConstructor(prelude.z), prelude.hd(x))
-    s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]),
-                    prelude.nat_nth(prelude.tl(x), y))
+    s_case = Clause(
+        PatternConstructor(prelude.s, [PatternVar(y)]), prelude.nat_nth(prelude.tl(x), y)
+    )
 
-    prelude.mod[prelude.nat_nth] = Function([x, n],
-                                            Match(n, [z_case, s_case]),
-                                            a, [a])
+    prelude.mod[prelude.nat_nth] = Function([x, n], Match(n, [z_case, s_case]), a, [a])
 
 
 def define_nat_update(prelude):
@@ -101,16 +102,15 @@ def define_nat_update(prelude):
     v = Var("v", a)
     y = Var("y")
 
-    z_case = Clause(PatternConstructor(prelude.z),
-                    prelude.cons(v, prelude.tl(l)))
-    s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]),
-                    prelude.cons(
-                        prelude.hd(l),
-                        prelude.nat_update(prelude.tl(l), y, v)))
+    z_case = Clause(PatternConstructor(prelude.z), prelude.cons(v, prelude.tl(l)))
+    s_case = Clause(
+        PatternConstructor(prelude.s, [PatternVar(y)]),
+        prelude.cons(prelude.hd(l), prelude.nat_update(prelude.tl(l), y, v)),
+    )
 
-    prelude.mod[prelude.nat_update] = Function([l, n, v],
-                                               Match(n, [z_case, s_case]),
-                                               prelude.l(a), [a])
+    prelude.mod[prelude.nat_update] = Function(
+        [l, n, v], Match(n, [z_case, s_case]), prelude.l(a), [a]
+    )
 
 
 def define_nat_iterate(prelude):
@@ -127,13 +127,14 @@ def define_nat_iterate(prelude):
     y = Var("y", prelude.nat())
 
     z_case = Clause(PatternConstructor(prelude.z), prelude.id)
-    s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]),
-                    prelude.compose(f, prelude.nat_iterate(f, y)))
-
-    prelude.mod[prelude.nat_iterate] = Function([f, x],
-                                                Match(x, [z_case, s_case]),
-                                                FuncType([a], a),
-                                                [a])
+    s_case = Clause(
+        PatternConstructor(prelude.s, [PatternVar(y)]),
+        prelude.compose(f, prelude.nat_iterate(f, y)),
+    )
+
+    prelude.mod[prelude.nat_iterate] = Function(
+        [f, x], Match(x, [z_case, s_case]), FuncType([a], a), [a]
+    )
 
 
 def add_nat_definitions(prelude):
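
To show how these nat helpers are wired in, a sketch follows (assuming the file is tvm/relay/testing/nat.py and the standard relay Prelude; the expressions below are only constructed, not evaluated):

    import tvm
    from tvm.relay.prelude import Prelude
    from tvm.relay.testing.nat import add_nat_definitions

    mod = tvm.IRModule()
    prelude = Prelude(mod)
    add_nat_definitions(prelude)

    # Peano numerals: z() is zero and s(n) is n + 1, as the ADT docstring above says.
    two = prelude.s(prelude.s(prelude.z()))
    four = prelude.add(two, two)  # a relay Call into the nat add defined above
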
index a782d83..c0dc97c 100644 (file)
@@ -27,7 +27,7 @@ from tvm.relay.expr import Expr, GlobalVar, Var
 from tvm.relay.function import Function
 from tvm.relay.expr_functor import ExprFunctor
 
-OUTPUT_VAR_NAME = '_py_out'
+OUTPUT_VAR_NAME = "_py_out"
 
 # corresponds to:
 #     import numpy
@@ -37,18 +37,19 @@ OUTPUT_VAR_NAME = '_py_out'
 #     from tvm.runtime import import container as _container
 #     from tvm.relay.backend.interpreter import RefValue, ConstructorValue
 PROLOGUE = [
-    ast.Import([alias('numpy', None)]),
-    ast.Import([alias('tvm', None)]),
-    ast.ImportFrom('tvm', [alias('relay', None)], 0),
-    ast.ImportFrom('tvm', [alias('nd', None)], 0),
-    ast.ImportFrom('tvm.runtime', [alias('container', '_container')],
-                   0),
-    ast.ImportFrom('tvm.relay.backend.interpreter',
-                   [alias('RefValue', None),
-                    alias('ConstructorValue', None)],
-                   0),
+    ast.Import([alias("numpy", None)]),
+    ast.Import([alias("tvm", None)]),
+    ast.ImportFrom("tvm", [alias("relay", None)], 0),
+    ast.ImportFrom("tvm", [alias("nd", None)], 0),
+    ast.ImportFrom("tvm.runtime", [alias("container", "_container")], 0),
+    ast.ImportFrom(
+        "tvm.relay.backend.interpreter",
+        [alias("RefValue", None), alias("ConstructorValue", None)],
+        0,
+    ),
 ]
 
+
 class PythonConverter(ExprFunctor):
     """Functor for translating Relay programs into Python ASTs."""
 
@@ -61,7 +62,6 @@ class PythonConverter(ExprFunctor):
         self.var_no = 0
         self.var_map = {}
 
-
     def convert(self, prog: Expr):
         """This method converts the passed Relay expression into a Python
         AST object with equivalent semantics.
@@ -85,7 +85,6 @@ class PythonConverter(ExprFunctor):
 
         return ast.fix_missing_locations(ast.Module(body=body))
 
-
     def optimize(self, prog: Expr):
         """Performs optimizations necessary to be able to generate code for prog."""
         # unwrap tuple wrappers (some op calls produce them)
@@ -95,34 +94,31 @@ class PythonConverter(ExprFunctor):
 
         # necessary pass: SimplifyInference (otherwise we can't generate code for some operators)
         # and fusion (to get primitive functions)
-        opts = tvm.transform.Sequential([relay.transform.SimplifyInference(),
-                                         relay.transform.FuseOps(fuse_opt_level=0)])
+        opts = tvm.transform.Sequential(
+            [relay.transform.SimplifyInference(), relay.transform.FuseOps(fuse_opt_level=0)]
+        )
         mod = opts(mod)
-        optimized = mod['main']
+        optimized = mod["main"]
         return optimized if isinstance(unwrapped, Function) else optimized.body
 
-
     def sanitize(self, name: str) -> str:
         """Removes any invalid characters (only underscores, numbers, and letters permitted)
         from the given name. Since we append a number and underscore to var names anyway,
         it doesn't matter if the name is the empty string."""
-        return re.sub(r'\W', '', name)
-
+        return re.sub(r"\W", "", name)
 
     def generate_var_name(self, name_hint: str) -> str:
         """Generates a unique variable name starting from the hint."""
-        name = '{}_var_{}'.format(self.sanitize(name_hint), self.var_no)
+        name = "{}_var_{}".format(self.sanitize(name_hint), self.var_no)
         self.var_no += 1
         return name
 
-
     def generate_function_name(self, name_hint: str) -> str:
         """Generates a unique function name starting from the hint."""
-        name = '{}_fun_{}'.format(self.sanitize(name_hint), self.fun_no)
+        name = "{}_fun_{}".format(self.sanitize(name_hint), self.fun_no)
         self.fun_no += 1
         return name
 
-
     def get_var_name(self, var: Expr) -> str:
         """Returns the var name for the given Realy variable."""
         if var in self.var_map:
@@ -131,24 +127,21 @@ class PythonConverter(ExprFunctor):
         self.var_map[var] = name
         return name
 
-
     def include_var(self, var: Expr, assign=False):
         """Returns a variable AST node for the given Relay var depending on
         whether it must appear in an assignment or not."""
         name = self.get_var_name(var)
         return Name(name, Store() if assign else Load())
 
-
     def parse_name(self, name: str):
         """Given the name of a Python method with dots (e.g., 'relay.var'),
         returns an appropriate AST object corresponding to that name."""
-        attributes = name.split('.')
+        attributes = name.split(".")
         ret = Name(attributes[0], Load())
         for i in range(len(attributes) - 1):
-            ret = ast.Attribute(ret, attributes[i+1], Load())
+            ret = ast.Attribute(ret, attributes[i + 1], Load())
         return ret
 
-
     def parse_numpy_array(self, arr):
         """Given a Numpy array, produces an appropriate Python array
         or numerical literal representing its contents."""
@@ -163,7 +156,6 @@ class PythonConverter(ExprFunctor):
             elts.append(self.parse_numpy_array(row))
         return ast.List(elts, Load())
 
-
     def convert_fields(self, fields: [Expr]):
         """Given a list of call args or tuple fields, converts
         each and returns their ASTs and their defs lists (in order)."""
@@ -175,7 +167,6 @@ class PythonConverter(ExprFunctor):
             defs += member_defs
         return (bodies, defs)
 
-
     def convert_to_thunk(self, name_hint: str, expr: Expr):
         """Wraps the passed expression in a thunk."""
         body, defs = self.visit(expr)
@@ -183,12 +174,11 @@ class PythonConverter(ExprFunctor):
         thunk = self.create_def(thunk_name, [], defs + [Return(body)])
         return (thunk, thunk_name)
 
-
     def convert_func_node(self, func: Function, name_var=None):
         """Converts the given Relay function into a Python function, with
         special handling for named functions (locally or globally)"""
         if name_var is None:
-            func_name = self.generate_function_name('_anon_func')
+            func_name = self.generate_function_name("_anon_func")
         if isinstance(name_var, GlobalVar):
             func_name = str(name_var.name_hint)
         if isinstance(name_var, Var):
@@ -199,7 +189,6 @@ class PythonConverter(ExprFunctor):
         ret = self.create_def(func_name, var_names, defs + [Return(body)])
         return (ret, func_name)
 
-
     def convert_module(self):
         """Converts all the global functions defined in the module and returns
         them as a list of definitions"""
@@ -215,21 +204,21 @@ class PythonConverter(ExprFunctor):
                 pass
         return defs
 
-
     def create_call(self, func_name: str, arguments):
         """Creates a simple function call."""
         return ast.Call(self.parse_name(func_name), arguments, [])
 
-
     def create_def(self, func_name: str, arguments: [str], body):
         """Wrapper over function definition AST node, whose constructor is inconvenient."""
         return ast.FunctionDef(
             func_name,
-            ast.arguments([ast.arg(argument, None)
-                           for argument in arguments],
-                          None, [], [], None, []),
-            body, [], None)
-
+            ast.arguments(
+                [ast.arg(argument, None) for argument in arguments], None, [], [], None, []
+            ),
+            body,
+            [],
+            None,
+        )
 
     def create_op_call(self, op: Function, relay_args, py_args):
         """Lowers the passed primitive function, registers it in TVM's
@@ -239,14 +228,14 @@ class PythonConverter(ExprFunctor):
         # compile the function and register globally
         cc_key = compile_engine.CCacheKey(op, self.tgt)
         func_hash = tvm.ir.structural_hash(op)
-        op_name = '_lowered_op_{}'.format(func_hash)
+        op_name = "_lowered_op_{}".format(func_hash)
         if not tvm.get_global_func(op_name, allow_missing=True):
             jitted = self.engine.jit(cc_key, self.tgt)
             tvm.register_func(op_name, jitted)
 
         def convert_input(py_input, arg_type):
             """Use the types of the function arguments to determine whether we expect
-               a tensor or tuple (returns list of inputs to the lowered op call)"""
+            a tensor or tuple (returns list of inputs to the lowered op call)"""
             # equivalent: input.data
             if isinstance(arg_type, relay.TensorType):
                 return [py_input]
@@ -255,10 +244,8 @@ class PythonConverter(ExprFunctor):
             ret = []
             for i in range(len(arg_type.fields)):
                 ret += convert_input(
-                    ast.Subscript(
-                        py_input,
-                        ast.Index(Num(i)), Load()),
-                    arg_type.fields[i])
+                    ast.Subscript(py_input, ast.Index(Num(i)), Load()), arg_type.fields[i]
+                )
             return ret
 
         def convert_output(ret_type):
@@ -266,15 +253,16 @@ class PythonConverter(ExprFunctor):
             Returns ([assignments of output vars], [extra arguments to pass to op call],
             expression collecting output)"""
             if isinstance(ret_type, relay.TensorType):
-                output_var_name = self.generate_var_name('_out')
+                output_var_name = self.generate_var_name("_out")
                 output_var = Name(output_var_name, Load())
                 shape = ast.Tuple([Num(dim) for dim in ret_type.concrete_shape], Load())
                 # create a new NDArray of the right shape and dtype
                 assign_output = Assign(
                     [Name(output_var_name, Store())],
-                    self.create_call('nd.array', [
-                        self.create_call('numpy.empty', [shape, Str(ret_type.dtype)])
-                    ]))
+                    self.create_call(
+                        "nd.array", [self.create_call("numpy.empty", [shape, Str(ret_type.dtype)])]
+                    ),
+                )
                 return ([assign_output], [output_var], output_var)
             assert isinstance(ret_type, relay.TupleType)
             assignments = []
@@ -286,21 +274,20 @@ class PythonConverter(ExprFunctor):
                 extra_args += inner_args
                 fields.append(inner_output)
             fields = [ast.List(fields, Load())]
-            return (assignments, extra_args, self.create_call('_container.tuple_object', fields))
+            return (assignments, extra_args, self.create_call("_container.tuple_object", fields))
 
         # create a function to wrap the call of the lowered op and return
         # a call to that function
-        wrap_name = self.generate_function_name('_{}_wrapper'.format(op_name))
-        wrap_args = [self.generate_var_name('_arg_{}'.format(i)) for i in range(len(py_args))]
+        wrap_name = self.generate_function_name("_{}_wrapper".format(op_name))
+        wrap_args = [self.generate_var_name("_arg_{}".format(i)) for i in range(len(py_args))]
 
         inner_call_args = []
         for i in range(len(py_args)):
-            inner_call_args += convert_input(Name(wrap_args[i], Load()),
-                                             relay_args[i].checked_type)
+            inner_call_args += convert_input(Name(wrap_args[i], Load()), relay_args[i].checked_type)
         output_assignments, aux_args, output = convert_output(op.checked_type.ret_type)
         # equiv: _op = tvm.get_global_func(op_name)
-        op_var = self.generate_var_name('_op')
-        op_call = self.create_call('tvm.get_global_func', [Str(op_name)])
+        op_var = self.generate_var_name("_op")
+        op_call = self.create_call("tvm.get_global_func", [Str(op_name)])
         op_assign = Assign([Name(op_var, Store())], op_call)
         # equiv: _op(args)
         inner_call = self.create_call(op_var, inner_call_args + aux_args)
@@ -308,7 +295,6 @@ class PythonConverter(ExprFunctor):
         wrap_def = self.create_def(wrap_name, wrap_args, body)
         return wrap_def, self.create_call(wrap_name, py_args)
 
-
     def create_match_check(self, pattern: Pattern, data):
         """Given an ADT match pattern and a (Python) expression pointing to
         an ADT value, this generates a Python expression that checks if the
@@ -325,9 +311,13 @@ class PythonConverter(ExprFunctor):
             # and also the matches of any nested patterns
 
             # equiv: (arg.tag == pattern_constructor.tag)
-            conds.append(ast.Compare(ast.Attribute(data, 'tag', Load()),
-                                     [ast.Eq()],
-                                     [ast.Num(pattern.constructor.tag)]))
+            conds.append(
+                ast.Compare(
+                    ast.Attribute(data, "tag", Load()),
+                    [ast.Eq()],
+                    [ast.Num(pattern.constructor.tag)],
+                )
+            )
 
         assert isinstance(pattern, (relay.PatternConstructor, relay.PatternTuple))
         # now check for any nested patterns
@@ -339,8 +329,9 @@ class PythonConverter(ExprFunctor):
                 continue
 
             # index into the value corresponding to the subpattern
-            field_index = ast.Subscript(ast.Attribute(data, 'fields', Load()),
-                                        ast.Index(Num(i)), Load())
+            field_index = ast.Subscript(
+                ast.Attribute(data, "fields", Load()), ast.Index(Num(i)), Load()
+            )
             conds.append(self.create_match_check(nested_pat, field_index))
 
         # if we do not need to check nested pattern, just return the single check
@@ -349,7 +340,6 @@ class PythonConverter(ExprFunctor):
         # otherwise AND together any nested checks
         return ast.BoolOp(ast.And(), conds)
 
-
     def create_match_clause_body(self, pattern: Pattern, body: Expr):
         """Given a match clause pattern and a clause body,
         generates a Python function that when called with an ADT
@@ -377,22 +367,23 @@ class PythonConverter(ExprFunctor):
             assignments = []
             for i in range(len(pat.patterns)):
                 # we want the assignments for val.fields[i]
-                field = ast.Subscript(ast.Attribute(val, 'fields', Load()),
-                                      ast.Index(Num(i)), Load())
+                field = ast.Subscript(
+                    ast.Attribute(val, "fields", Load()), ast.Index(Num(i)), Load()
+                )
                 assignments += collect_var_assignments(pat.patterns[i], field)
             return assignments
 
-        func_name = self.generate_function_name('_match_clause_body')
-        arg_name = self.generate_var_name('_match_clause_body')
+        func_name = self.generate_function_name("_match_clause_body")
+        arg_name = self.generate_var_name("_match_clause_body")
 
         clause_body, defs = self.visit(body)
         assignments = collect_var_assignments(pattern, Name(arg_name, Load()))
 
-        func_def = self.create_def(func_name, [arg_name],
-                                   defs + assignments + [Return(clause_body)])
+        func_def = self.create_def(
+            func_name, [arg_name], defs + assignments + [Return(clause_body)]
+        )
         return (func_def, func_name)
 
-
     # Convention for the expr visitor: Each visit function returns a tuple of two members.
     #
     # The first is a Python AST comprised of a single *expression* that evaluates to an equivalent
@@ -407,13 +398,11 @@ class PythonConverter(ExprFunctor):
     def visit_var(self, var: Expr):
         return (self.include_var(var, assign=False), [])
 
-
     def visit_global_var(self, gvar: Expr):
         # we don't need to add numbers to global var names because
         # the *names* are checked for uniqueness in the mod
         return (Name(str(gvar.name_hint), Load()), [])
 
-
     def visit_let(self, letexp: Expr):
         # To properly account for scoping and ensure that the entire node produces an expression,
         # we translate the let binding as a function that we call with the value we intend to bind.
@@ -427,9 +416,10 @@ class PythonConverter(ExprFunctor):
         """
         bind_body, bind_defs = self.visit(letexp.body)
 
-        func_name = self.generate_function_name('_let_func')
-        binding_func = self.create_def(func_name, [self.get_var_name(letexp.var)],
-                                       bind_defs + [Return(bind_body)])
+        func_name = self.generate_function_name("_let_func")
+        binding_func = self.create_def(
+            func_name, [self.get_var_name(letexp.var)], bind_defs + [Return(bind_body)]
+        )
 
         # we call the binding func with the intended value for the bound variable
 
@@ -437,27 +427,26 @@ class PythonConverter(ExprFunctor):
         # recursive by naming it after the var
         if isinstance(letexp.value, Function):
             value_def, value_name = self.convert_func_node(letexp.value, letexp.var)
-            return (self.create_call(func_name, [Name(value_name, Load())]),
-                    [value_def, binding_func])
+            return (
+                self.create_call(func_name, [Name(value_name, Load())]),
+                [value_def, binding_func],
+            )
 
         value_body, value_defs = self.visit(letexp.value)
         value_defs.append(binding_func)
         binding_call = self.create_call(func_name, [value_body])
         return (binding_call, value_defs)
 
-
     def visit_tuple(self, tup: Expr):
         fields, ret_defs = self.convert_fields(tup.fields)
         fields = [ast.List(fields, Load())]
-        return (self.create_call('_container.tuple_object', fields), ret_defs)
-
+        return (self.create_call("_container.tuple_object", fields), ret_defs)
 
     def visit_tuple_getitem(self, tgi: Expr):
         tup, tup_defs = self.visit(tgi.tuple_value)
         ret = ast.Subscript(tup, ast.Index(Num(tgi.index)), Load())
         return (ret, tup_defs)
 
-
     def visit_if(self, if_block: Expr):
         cond_body, cond_defs = self.visit(if_block.cond)
         true_body, true_defs = self.visit(if_block.true_branch)
@@ -465,28 +454,27 @@ class PythonConverter(ExprFunctor):
 
         # need to get the value out of an NDArray to check the condition
         # equivalent to: val.asnumpy()
-        cond_check = ast.Call(ast.Attribute(cond_body, 'asnumpy', Load()), [], [])
+        cond_check = ast.Call(ast.Attribute(cond_body, "asnumpy", Load()), [], [])
         ret = ast.IfExp(cond_check, true_body, false_body)
         return (ret, cond_defs + true_defs + false_defs)
 
-
     def visit_constant(self, constant: Expr):
         """Proceeds by converting constant value to a numpy array
         and converting it to the appropriate value in the generated
         code (whether it be a Python scalar or a Numpy array)"""
         value = constant.data.asnumpy()
-        const_expr = ast.Call(ast.Attribute(Name('numpy', Load()), 'array', Load()),
-                              [self.parse_numpy_array(value)],
-                              [ast.keyword('dtype', Str(constant.checked_type.dtype))])
-        return (self.create_call('nd.array', [const_expr]), [])
-
+        const_expr = ast.Call(
+            ast.Attribute(Name("numpy", Load()), "array", Load()),
+            [self.parse_numpy_array(value)],
+            [ast.keyword("dtype", Str(constant.checked_type.dtype))],
+        )
+        return (self.create_call("nd.array", [const_expr]), [])
 
     def visit_function(self, func: Expr):
         # Python's lambdas are very restrictive, so we do "name" inline functions
         converted_func, func_name = self.convert_func_node(func)
         return (Name(func_name, Load()), [converted_func])
 
-
     def visit_call(self, call: Expr):
         """For calls, we must distinguish between ordinary functions,
         operators, and constructor calls."""
@@ -494,15 +482,17 @@ class PythonConverter(ExprFunctor):
         fields, field_defs = self.convert_fields(call.args)
 
         if isinstance(func, tvm.ir.Op):
-            raise Exception('Operators should have been lowered and eliminated')
+            raise Exception("Operators should have been lowered and eliminated")
 
         if isinstance(func, relay.Constructor):
             # produce a constructor value
-            return (self.create_call('ConstructorValue',
-                                     [ast.Num(func.tag),
-                                      ast.List(fields, Load()),
-                                      NameConstant(None)]),
-                    field_defs)
+            return (
+                self.create_call(
+                    "ConstructorValue",
+                    [ast.Num(func.tag), ast.List(fields, Load()), NameConstant(None)],
+                ),
+                field_defs,
+            )
 
         # lowered operator: generate a call to a function that gets the PackedFunc
         # from TVM's registry
@@ -515,16 +505,13 @@ class PythonConverter(ExprFunctor):
         defs += field_defs
         return (ast.Call(converted_func, fields, []), defs)
 
-
     def visit_ref_create(self, ref: Expr):
         val, defs = self.visit(ref.value)
-        return (self.create_call('RefValue', [val]), defs)
-
+        return (self.create_call("RefValue", [val]), defs)
 
     def visit_ref_read(self, read: Expr):
         ref, defs = self.visit(read.ref)
-        return (ast.Attribute(ref, 'value', Load()), defs)
-
+        return (ast.Attribute(ref, "value", Load()), defs)
 
     def visit_ref_write(self, write: Expr):
         """For writing refs, we wrap the update in a thunk
@@ -534,16 +521,19 @@ class PythonConverter(ExprFunctor):
         in Python but expressions in Relay"""
         ref, ref_defs = self.visit(write.ref)
         val, val_defs = self.visit(write.value)
-        thunk_name = self.generate_function_name('_ref_write_thunk')
+        thunk_name = self.generate_function_name("_ref_write_thunk")
         thunk = self.create_def(
-            thunk_name, [],
-            ref_defs + val_defs + [
-                Assign([ast.Attribute(ref, 'value', Store())], val),
-                Return(self.create_call('_container.tuple_object', []))
-            ])
+            thunk_name,
+            [],
+            ref_defs
+            + val_defs
+            + [
+                Assign([ast.Attribute(ref, "value", Store())], val),
+                Return(self.create_call("_container.tuple_object", [])),
+            ],
+        )
         return (self.create_call(thunk_name, []), [thunk])
 
-
     def visit_match(self, match: Expr):
         """For matches, we wrap the entire expression in a thunk
         because it is easiest to implement them using if statements.
@@ -551,7 +541,7 @@ class PythonConverter(ExprFunctor):
         pattern matches. If yes, we call a function that assigns
         the variables appropriately and invokes the clause body."""
         data, defs = self.visit(match.data)
-        data_var = self.generate_var_name('_match_data')
+        data_var = self.generate_var_name("_match_data")
 
         # must ensure the data clause is executed exactly once
         thunk_body = [Assign([Name(data_var, Store())], data)]
@@ -561,28 +551,28 @@ class PythonConverter(ExprFunctor):
             defs.append(body_def)
 
             # equiv: if check(data): return body(data)
-            thunk_body.append(ast.If(
-                check_expr,
-                [Return(self.create_call(body_name, [Name(data_var, Load())]))],
-                []
-            ))
+            thunk_body.append(
+                ast.If(
+                    check_expr, [Return(self.create_call(body_name, [Name(data_var, Load())]))], []
+                )
+            )
 
         # finally if nothing matches we have a failed assert (should never happen)
-        thunk_body.append(ast.Assert(NameConstant(False), Str('Match was not exhaustive')))
+        thunk_body.append(ast.Assert(NameConstant(False), Str("Match was not exhaustive")))
 
-        thunk_name = self.generate_function_name('_match_thunk')
+        thunk_name = self.generate_function_name("_match_thunk")
         thunk_def = self.create_def(thunk_name, [], defs + thunk_body)
         return (self.create_call(thunk_name, []), [thunk_def])
 
-
     # these are both handled in the "call" case
     def visit_constructor(self, _):
         pass
+
     def visit_op(self, _):
         pass
 
 
-def to_python(expr: Expr, mod=None, target=tvm.target.Target('llvm')):
+def to_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")):
     """Converts the given Relay expression into a Python script (as a Python AST object).
     For easiest debugging, import the astor package and use to_source()."""
     mod = mod if mod is not None else tvm.IRModule()
@@ -590,15 +580,13 @@ def to_python(expr: Expr, mod=None, target=tvm.target.Target('llvm')):
     return converter.convert(expr)
 
 
-def run_as_python(expr: Expr, mod=None, target=tvm.target.Target('llvm')):
+def run_as_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")):
     """Converts the given Relay expression into a Python script and
     executes it."""
     mod = mod if mod is not None else tvm.IRModule()
     py_ast = to_python(expr, mod, target)
-    code = compile(py_ast, '<string>', 'exec')
-    var_map = {
-        OUTPUT_VAR_NAME : None
-    }
-    #pylint: disable=exec-used
+    code = compile(py_ast, "<string>", "exec")
+    var_map = {OUTPUT_VAR_NAME: None}
+    # pylint: disable=exec-used
     exec(code, var_map, var_map)
     return var_map[OUTPUT_VAR_NAME]
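
For a quick check that the reformatted converter still round-trips Relay expressions, the two entry points above can be exercised directly. This is a minimal sketch, not a verified test: it assumes the converter is importable as tvm.relay.testing.py_converter (the file path is not shown in this hunk) and that operator calls are lowered by the converter as its comments describe.

    # Minimal sketch under the assumptions stated above.
    from tvm import relay
    from tvm.relay.testing.py_converter import to_python, run_as_python  # assumed import path

    one = relay.const(1.0, dtype="float32")
    expr = relay.add(one, one)

    py_ast = to_python(expr)        # Python AST; pretty-print with astor.to_source(py_ast) if installed
    result = run_as_python(expr)    # executes the generated script and returns the output value
    print(result.asnumpy())         # expected to print 2.0
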
index ac63afd..bc5f5c4 100644 (file)
@@ -27,15 +27,17 @@ from tvm import relay
 from .init import create_workload
 from . import layers
 
-def residual_unit(data,
-                  num_filter,
-                  stride,
-                  dim_match,
-                  name,
-                  bottle_neck=True,
-                  data_layout="NCHW",
-                  kernel_layout="IOHW"
-                  ):
+
+def residual_unit(
+    data,
+    num_filter,
+    stride,
+    dim_match,
+    name,
+    bottle_neck=True,
+    data_layout="NCHW",
+    kernel_layout="IOHW",
+):
     """Return ResNet Unit symbol for building ResNet
 
     Parameters
@@ -59,74 +61,108 @@ def residual_unit(data,
     name : str
         Base name of the operators
     """
-    bn_axis = data_layout.index('C')
+    bn_axis = data_layout.index("C")
     if bottle_neck:
-        bn1 = layers.batch_norm_infer(data=data,
-                                      epsilon=2e-5,
-                                      axis=bn_axis,
-                                      name=name + '_bn1')
+        bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, axis=bn_axis, name=name + "_bn1")
         act1 = relay.nn.relu(data=bn1)
         conv1 = layers.conv2d(
             data=act1,
-            channels=int(num_filter*0.25),
+            channels=int(num_filter * 0.25),
             kernel_size=(1, 1),
             strides=stride,
             padding=(0, 0),
-            name=name + '_conv1',
+            name=name + "_conv1",
             data_layout=data_layout,
-            kernel_layout=kernel_layout)
-        bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, axis=bn_axis, name=name + '_bn2')
+            kernel_layout=kernel_layout,
+        )
+        bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, axis=bn_axis, name=name + "_bn2")
         act2 = relay.nn.relu(data=bn2)
         conv2 = layers.conv2d(
-            data=act2, channels=int(num_filter*0.25), kernel_size=(3, 3),
-            strides=(1, 1), padding=(1, 1), name=name + '_conv2',
-            data_layout=data_layout, kernel_layout=kernel_layout)
-        bn3 = layers.batch_norm_infer(data=conv2, epsilon=2e-5, axis=bn_axis, name=name + '_bn3')
+            data=act2,
+            channels=int(num_filter * 0.25),
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding=(1, 1),
+            name=name + "_conv2",
+            data_layout=data_layout,
+            kernel_layout=kernel_layout,
+        )
+        bn3 = layers.batch_norm_infer(data=conv2, epsilon=2e-5, axis=bn_axis, name=name + "_bn3")
         act3 = relay.nn.relu(data=bn3)
         conv3 = layers.conv2d(
-            data=act3, channels=num_filter, kernel_size=(1, 1),
-            strides=(1, 1), padding=(0, 0), name=name + '_conv3',
-            data_layout=data_layout, kernel_layout=kernel_layout)
+            data=act3,
+            channels=num_filter,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding=(0, 0),
+            name=name + "_conv3",
+            data_layout=data_layout,
+            kernel_layout=kernel_layout,
+        )
         if dim_match:
             shortcut = data
         else:
             shortcut = layers.conv2d(
-                data=act1, channels=num_filter, kernel_size=(1, 1),
-                strides=stride, name=name+'_sc',
-                data_layout=data_layout, kernel_layout=kernel_layout)
+                data=act1,
+                channels=num_filter,
+                kernel_size=(1, 1),
+                strides=stride,
+                name=name + "_sc",
+                data_layout=data_layout,
+                kernel_layout=kernel_layout,
+            )
         return relay.add(conv3, shortcut)
 
-    bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, axis=bn_axis, name=name + '_bn1')
+    bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, axis=bn_axis, name=name + "_bn1")
     act1 = relay.nn.relu(data=bn1)
     conv1 = layers.conv2d(
-        data=act1, channels=num_filter, kernel_size=(3, 3),
-        strides=stride, padding=(1, 1), name=name + '_conv1',
-        data_layout=data_layout, kernel_layout=kernel_layout)
-    bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, axis=bn_axis, name=name + '_bn2')
+        data=act1,
+        channels=num_filter,
+        kernel_size=(3, 3),
+        strides=stride,
+        padding=(1, 1),
+        name=name + "_conv1",
+        data_layout=data_layout,
+        kernel_layout=kernel_layout,
+    )
+    bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, axis=bn_axis, name=name + "_bn2")
     act2 = relay.nn.relu(data=bn2)
     conv2 = layers.conv2d(
-        data=act2, channels=num_filter, kernel_size=(3, 3),
-        strides=(1, 1), padding=(1, 1), name=name + '_conv2',
-        data_layout=data_layout, kernel_layout=kernel_layout)
+        data=act2,
+        channels=num_filter,
+        kernel_size=(3, 3),
+        strides=(1, 1),
+        padding=(1, 1),
+        name=name + "_conv2",
+        data_layout=data_layout,
+        kernel_layout=kernel_layout,
+    )
 
     if dim_match:
         shortcut = data
     else:
         shortcut = layers.conv2d(
-            data=act1, channels=num_filter, kernel_size=(1, 1),
-            strides=stride, name=name+'_sc',
-            data_layout=data_layout, kernel_layout=kernel_layout)
+            data=act1,
+            channels=num_filter,
+            kernel_size=(1, 1),
+            strides=stride,
+            name=name + "_sc",
+            data_layout=data_layout,
+            kernel_layout=kernel_layout,
+        )
     return relay.add(conv2, shortcut)
 
 
-def resnet(units,
-           num_stages,
-           filter_list,
-           num_classes,
-           data_shape,
-           bottle_neck=True,
-           layout="NCHW",
-           dtype="float32"):
+def resnet(
+    units,
+    num_stages,
+    filter_list,
+    num_classes,
+    data_shape,
+    bottle_neck=True,
+    layout="NCHW",
+    dtype="float32",
+):
     """Return ResNet Program.
 
     Parameters
@@ -158,58 +194,86 @@ def resnet(units,
 
     data_layout = layout
     kernel_layout = "OIHW" if layout == "NCHW" else "HWIO"
-    bn_axis = data_layout.index('C')
+    bn_axis = data_layout.index("C")
 
     num_unit = len(units)
     assert num_unit == num_stages
     data = relay.var("data", shape=data_shape, dtype=dtype)
-    data = layers.batch_norm_infer(data=data, epsilon=2e-5, axis=bn_axis, scale=False,
-                                   name='bn_data')
+    data = layers.batch_norm_infer(
+        data=data, epsilon=2e-5, axis=bn_axis, scale=False, name="bn_data"
+    )
     (_, _, height, _) = data_shape
     if layout == "NHWC":
         (_, height, _, _) = data_shape
-    if height <= 32:            # such as cifar10
+    if height <= 32:  # such as cifar10
         body = layers.conv2d(
-            data=data, channels=filter_list[0], kernel_size=(3, 3),
-            strides=(1, 1), padding=(1, 1), name="conv0",
-            data_layout=data_layout, kernel_layout=kernel_layout)
-    else:                       # often expected to be 224 such as imagenet
+            data=data,
+            channels=filter_list[0],
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding=(1, 1),
+            name="conv0",
+            data_layout=data_layout,
+            kernel_layout=kernel_layout,
+        )
+    else:  # often expected to be 224 such as imagenet
         body = layers.conv2d(
-            data=data, channels=filter_list[0], kernel_size=(7, 7),
-            strides=(2, 2), padding=(3, 3), name="conv0",
-            data_layout=data_layout, kernel_layout=kernel_layout)
-        body = layers.batch_norm_infer(data=body, epsilon=2e-5, axis=bn_axis, name='bn0')
+            data=data,
+            channels=filter_list[0],
+            kernel_size=(7, 7),
+            strides=(2, 2),
+            padding=(3, 3),
+            name="conv0",
+            data_layout=data_layout,
+            kernel_layout=kernel_layout,
+        )
+        body = layers.batch_norm_infer(data=body, epsilon=2e-5, axis=bn_axis, name="bn0")
         body = relay.nn.relu(data=body)
-        body = relay.nn.max_pool2d(data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1),
-                                   layout=data_layout)
+        body = relay.nn.max_pool2d(
+            data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), layout=data_layout
+        )
 
     for i in range(num_stages):
         body = residual_unit(
-            body, filter_list[i+1], (1 if i == 0 else 2, 1 if i == 0 else 2),
-            False, name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck,
-            data_layout=data_layout, kernel_layout=kernel_layout)
-        for j in range(units[i]-1):
+            body,
+            filter_list[i + 1],
+            (1 if i == 0 else 2, 1 if i == 0 else 2),
+            False,
+            name="stage%d_unit%d" % (i + 1, 1),
+            bottle_neck=bottle_neck,
+            data_layout=data_layout,
+            kernel_layout=kernel_layout,
+        )
+        for j in range(units[i] - 1):
             body = residual_unit(
-                body, filter_list[i+1], (1, 1), True,
-                name='stage%d_unit%d' % (i + 1, j + 2), bottle_neck=bottle_neck,
-                data_layout=data_layout, kernel_layout=kernel_layout)
-    bn1 = layers.batch_norm_infer(data=body, epsilon=2e-5, axis=bn_axis, name='bn1')
+                body,
+                filter_list[i + 1],
+                (1, 1),
+                True,
+                name="stage%d_unit%d" % (i + 1, j + 2),
+                bottle_neck=bottle_neck,
+                data_layout=data_layout,
+                kernel_layout=kernel_layout,
+            )
+    bn1 = layers.batch_norm_infer(data=body, epsilon=2e-5, axis=bn_axis, name="bn1")
     relu1 = relay.nn.relu(data=bn1)
     # Although the kernel size is not used when global_pool=True, we still supply one
     pool1 = relay.nn.global_avg_pool2d(data=relu1, layout=data_layout)
     flat = relay.nn.batch_flatten(data=pool1)
-    fc1 = layers.dense_add_bias(data=flat, units=num_classes, name='fc1')
+    fc1 = layers.dense_add_bias(data=flat, units=num_classes, name="fc1")
     net = relay.nn.softmax(data=fc1)
     return relay.Function(relay.analysis.free_vars(net), net)
 
 
-def get_net(batch_size,
-            num_classes,
-            num_layers=50,
-            image_shape=(3, 224, 224),
-            layout="NCHW",
-            dtype="float32",
-            **kwargs):
+def get_net(
+    batch_size,
+    num_classes,
+    num_layers=50,
+    image_shape=(3, 224, 224),
+    layout="NCHW",
+    dtype="float32",
+    **kwargs,
+):
     """
     Adapted from https://github.com/tornadomeet/ResNet/blob/master/train_resnet.py
     Original author Wei Wu
@@ -220,12 +284,12 @@ def get_net(batch_size,
     data_shape = (batch_size,) + image_shape
     if height <= 28:
         num_stages = 3
-        if (num_layers-2) % 9 == 0 and num_layers >= 164:
-            per_unit = [(num_layers-2)//9]
+        if (num_layers - 2) % 9 == 0 and num_layers >= 164:
+            per_unit = [(num_layers - 2) // 9]
             filter_list = [16, 64, 128, 256]
             bottle_neck = True
-        elif (num_layers-2) % 6 == 0 and num_layers < 164:
-            per_unit = [(num_layers-2)//6]
+        elif (num_layers - 2) % 6 == 0 and num_layers < 164:
+            per_unit = [(num_layers - 2) // 6]
             filter_list = [16, 16, 32, 64]
             bottle_neck = False
         else:
@@ -256,23 +320,27 @@ def get_net(batch_size,
         else:
             raise ValueError("no experiments done on num_layers {}".format(num_layers))
 
-    return resnet(units=units,
-                  num_stages=num_stages,
-                  filter_list=filter_list,
-                  num_classes=num_classes,
-                  data_shape=data_shape,
-                  bottle_neck=bottle_neck,
-                  layout=layout,
-                  dtype=dtype)
-
-
-def get_workload(batch_size=1,
-                 num_classes=1000,
-                 num_layers=18,
-                 image_shape=(3, 224, 224),
-                 layout="NCHW",
-                 dtype="float32",
-                 **kwargs):
+    return resnet(
+        units=units,
+        num_stages=num_stages,
+        filter_list=filter_list,
+        num_classes=num_classes,
+        data_shape=data_shape,
+        bottle_neck=bottle_neck,
+        layout=layout,
+        dtype=dtype,
+    )
+
+
+def get_workload(
+    batch_size=1,
+    num_classes=1000,
+    num_layers=18,
+    image_shape=(3, 224, 224),
+    layout="NCHW",
+    dtype="float32",
+    **kwargs,
+):
     """Get benchmark workload for resnet
 
     Parameters
@@ -306,11 +374,13 @@ def get_workload(batch_size=1,
     params : dict of str to NDArray
         The parameters.
     """
-    net = get_net(batch_size=batch_size,
-                  num_classes=num_classes,
-                  num_layers=num_layers,
-                  image_shape=image_shape,
-                  dtype=dtype,
-                  layout=layout,
-                  **kwargs)
+    net = get_net(
+        batch_size=batch_size,
+        num_classes=num_classes,
+        num_layers=num_layers,
+        image_shape=image_shape,
+        dtype=dtype,
+        layout=layout,
+        **kwargs,
+    )
     return create_workload(net)
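
The reformatted builders above are normally reached through get_workload, which wraps the network in create_workload. A minimal sketch, assuming the file is importable as tvm.relay.testing.resnet (the hunk headers do not show the path):

    # Sketch: build a ResNet-18 classification workload with random parameters.
    from tvm.relay.testing import resnet  # assumed import path

    mod, params = resnet.get_workload(
        batch_size=1,
        num_classes=1000,
        num_layers=18,
        image_shape=(3, 224, 224),
        layout="NCHW",
        dtype="float32",
    )
    print(mod)           # module holding the network returned by resnet()
    print(len(params))   # randomly initialized parameter tensors from create_workload
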
index 0330d2e..484f51d 100644 (file)
@@ -25,15 +25,17 @@ from tvm import relay
 from .init import create_workload
 from . import layers
 
-def residual_unit(data,
-                  num_filter,
-                  stride,
-                  dim_match,
-                  name,
-                  bottle_neck=True,
-                  data_layout="NCDHW",
-                  kernel_layout="OIDHW"
-                  ):
+
+def residual_unit(
+    data,
+    num_filter,
+    stride,
+    dim_match,
+    name,
+    bottle_neck=True,
+    data_layout="NCDHW",
+    kernel_layout="OIDHW",
+):
     """Return ResNet Unit symbol for building ResNet
 
     Parameters
@@ -58,71 +60,106 @@ def residual_unit(data,
         Base name of the operators
     """
     if bottle_neck:
-        bn1 = layers.batch_norm_infer(data=data,
-                                      epsilon=2e-5,
-                                      name=name + '_bn1')
+        bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, name=name + "_bn1")
         act1 = relay.nn.relu(data=bn1)
         conv1 = layers.conv3d(
             data=act1,
-            channels=int(num_filter*0.25),
+            channels=int(num_filter * 0.25),
             kernel_size=(1, 1, 1),
             strides=stride,
             padding=(0, 0, 0),
-            name=name + '_conv1',
+            name=name + "_conv1",
             data_layout=data_layout,
-            kernel_layout=kernel_layout)
-        bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2')
+            kernel_layout=kernel_layout,
+        )
+        bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + "_bn2")
         act2 = relay.nn.relu(data=bn2)
         conv2 = layers.conv3d(
-            data=act2, channels=int(num_filter*0.25), kernel_size=(3, 3, 3),
-            strides=(1, 1, 1), padding=(1, 1, 1), name=name + '_conv2',
-            data_layout=data_layout, kernel_layout=kernel_layout)
-        bn3 = layers.batch_norm_infer(data=conv2, epsilon=2e-5, name=name + '_bn3')
+            data=act2,
+            channels=int(num_filter * 0.25),
+            kernel_size=(3, 3, 3),
+            strides=(1, 1, 1),
+            padding=(1, 1, 1),
+            name=name + "_conv2",
+            data_layout=data_layout,
+            kernel_layout=kernel_layout,
+        )
+        bn3 = layers.batch_norm_infer(data=conv2, epsilon=2e-5, name=name + "_bn3")
         act3 = relay.nn.relu(data=bn3)
         conv3 = layers.conv3d(
-            data=act3, channels=num_filter, kernel_size=(1, 1, 1),
-            strides=(1, 1, 1), padding=(0, 0, 0), name=name + '_conv3',
-            data_layout=data_layout, kernel_layout=kernel_layout)
+            data=act3,
+            channels=num_filter,
+            kernel_size=(1, 1, 1),
+            strides=(1, 1, 1),
+            padding=(0, 0, 0),
+            name=name + "_conv3",
+            data_layout=data_layout,
+            kernel_layout=kernel_layout,
+        )
         if dim_match:
             shortcut = data
         else:
             shortcut = layers.conv3d(
-                data=act1, channels=num_filter, kernel_size=(1, 1, 1),
-                strides=stride, name=name+'_sc',
-                data_layout=data_layout, kernel_layout=kernel_layout)
+                data=act1,
+                channels=num_filter,
+                kernel_size=(1, 1, 1),
+                strides=stride,
+                name=name + "_sc",
+                data_layout=data_layout,
+                kernel_layout=kernel_layout,
+            )
         return relay.add(conv3, shortcut)
 
-    bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, name=name + '_bn1')
+    bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, name=name + "_bn1")
     act1 = relay.nn.relu(data=bn1)
     conv1 = layers.conv3d(
-        data=act1, channels=num_filter, kernel_size=(3, 3, 3),
-        strides=stride, padding=(1, 1, 1), name=name + '_conv1',
-        data_layout=data_layout, kernel_layout=kernel_layout)
-    bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2')
+        data=act1,
+        channels=num_filter,
+        kernel_size=(3, 3, 3),
+        strides=stride,
+        padding=(1, 1, 1),
+        name=name + "_conv1",
+        data_layout=data_layout,
+        kernel_layout=kernel_layout,
+    )
+    bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + "_bn2")
     act2 = relay.nn.relu(data=bn2)
     conv2 = layers.conv3d(
-        data=act2, channels=num_filter, kernel_size=(3, 3, 3),
-        strides=(1, 1, 1), padding=(1, 1, 1), name=name + '_conv2',
-        data_layout=data_layout, kernel_layout=kernel_layout)
+        data=act2,
+        channels=num_filter,
+        kernel_size=(3, 3, 3),
+        strides=(1, 1, 1),
+        padding=(1, 1, 1),
+        name=name + "_conv2",
+        data_layout=data_layout,
+        kernel_layout=kernel_layout,
+    )
 
     if dim_match:
         shortcut = data
     else:
         shortcut = layers.conv3d(
-            data=act1, channels=num_filter, kernel_size=(1, 1, 1),
-            strides=stride, name=name+'_sc',
-            data_layout=data_layout, kernel_layout=kernel_layout)
+            data=act1,
+            channels=num_filter,
+            kernel_size=(1, 1, 1),
+            strides=stride,
+            name=name + "_sc",
+            data_layout=data_layout,
+            kernel_layout=kernel_layout,
+        )
     return relay.add(conv2, shortcut)
 
 
-def resnet(units,
-           num_stages,
-           filter_list,
-           num_classes,
-           data_shape,
-           bottle_neck=True,
-           layout="NCDHW",
-           dtype="float32"):
+def resnet(
+    units,
+    num_stages,
+    filter_list,
+    num_classes,
+    data_shape,
+    bottle_neck=True,
+    layout="NCDHW",
+    dtype="float32",
+):
     """Return ResNet Program.
 
     Parameters
@@ -158,53 +195,79 @@ def resnet(units,
     num_unit = len(units)
     assert num_unit == num_stages
     data = relay.var("data", shape=data_shape, dtype=dtype)
-    data = layers.batch_norm_infer(data=data, epsilon=2e-5, scale=False, name='bn_data')
+    data = layers.batch_norm_infer(data=data, epsilon=2e-5, scale=False, name="bn_data")
     if layout == "NCDHW":
         (_, _, _, height, _) = data_shape
     else:
         (_, _, height, _, _) = data_shape
-    if height <= 32:            # such as cifar10
+    if height <= 32:  # such as cifar10
         body = layers.conv3d(
-            data=data, channels=filter_list[0], kernel_size=(3, 3, 3),
-            strides=(1, 1, 1), padding=(1, 1, 1), name="conv0",
-            data_layout=data_layout, kernel_layout=kernel_layout)
-    else:                       # often expected to be 224 such as imagenet
+            data=data,
+            channels=filter_list[0],
+            kernel_size=(3, 3, 3),
+            strides=(1, 1, 1),
+            padding=(1, 1, 1),
+            name="conv0",
+            data_layout=data_layout,
+            kernel_layout=kernel_layout,
+        )
+    else:  # often expected to be 224 such as imagenet
         body = layers.conv3d(
-            data=data, channels=filter_list[0], kernel_size=(3, 7, 7),
-            strides=(1, 2, 2), padding=(1, 3, 3), name="conv0",
-            data_layout=data_layout, kernel_layout=kernel_layout)
-        body = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn0')
+            data=data,
+            channels=filter_list[0],
+            kernel_size=(3, 7, 7),
+            strides=(1, 2, 2),
+            padding=(1, 3, 3),
+            name="conv0",
+            data_layout=data_layout,
+            kernel_layout=kernel_layout,
+        )
+        body = layers.batch_norm_infer(data=body, epsilon=2e-5, name="bn0")
         body = relay.nn.relu(data=body)
-        #body = relay.nn.max_pool3d(data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1),
+        # body = relay.nn.max_pool3d(data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1),
         #                           layout=data_layout)
 
     for i in range(num_stages):
         body = residual_unit(
-            body, filter_list[i+1], (1 if i == 0 else 2, 1 if i == 0 else 2, 1 if i == 0 else 2),
-            False, name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck,
-            data_layout=data_layout, kernel_layout=kernel_layout)
-        for j in range(units[i]-1):
+            body,
+            filter_list[i + 1],
+            (1 if i == 0 else 2, 1 if i == 0 else 2, 1 if i == 0 else 2),
+            False,
+            name="stage%d_unit%d" % (i + 1, 1),
+            bottle_neck=bottle_neck,
+            data_layout=data_layout,
+            kernel_layout=kernel_layout,
+        )
+        for j in range(units[i] - 1):
             body = residual_unit(
-                body, filter_list[i+1], (1, 1, 1), True,
-                name='stage%d_unit%d' % (i + 1, j + 2), bottle_neck=bottle_neck,
-                data_layout=data_layout, kernel_layout=kernel_layout)
-    bn1 = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn1')
+                body,
+                filter_list[i + 1],
+                (1, 1, 1),
+                True,
+                name="stage%d_unit%d" % (i + 1, j + 2),
+                bottle_neck=bottle_neck,
+                data_layout=data_layout,
+                kernel_layout=kernel_layout,
+            )
+    bn1 = layers.batch_norm_infer(data=body, epsilon=2e-5, name="bn1")
     relu1 = relay.nn.relu(data=bn1)
     # Although the kernel size is not used when global_pool=True, we still supply one
     pool1 = relay.nn.global_avg_pool3d(data=relu1, layout=data_layout)
     flat = relay.nn.batch_flatten(data=pool1)
-    fc1 = layers.dense_add_bias(data=flat, units=num_classes, name='fc1')
+    fc1 = layers.dense_add_bias(data=flat, units=num_classes, name="fc1")
     net = relay.nn.softmax(data=fc1)
     return relay.Function(relay.analysis.free_vars(net), net)
 
 
-def get_net(batch_size,
-            num_classes,
-            num_layers=50,
-            image_shape=(3, 16, 112, 112),
-            layout="NCDHW",
-            dtype="float32",
-            **kwargs):
+def get_net(
+    batch_size,
+    num_classes,
+    num_layers=50,
+    image_shape=(3, 16, 112, 112),
+    layout="NCDHW",
+    dtype="float32",
+    **kwargs,
+):
     """
     Adapted from https://github.com/tornadomeet/ResNet/blob/master/train_resnet.py
     Original author Wei Wu
@@ -216,12 +279,12 @@ def get_net(batch_size,
     data_shape = (batch_size,) + image_shape
     if height <= 28:
         num_stages = 3
-        if (num_layers-2) % 9 == 0 and num_layers >= 164:
-            per_unit = [(num_layers-2)//9]
+        if (num_layers - 2) % 9 == 0 and num_layers >= 164:
+            per_unit = [(num_layers - 2) // 9]
             filter_list = [16, 64, 128, 256]
             bottle_neck = True
-        elif (num_layers-2) % 6 == 0 and num_layers < 164:
-            per_unit = [(num_layers-2)//6]
+        elif (num_layers - 2) % 6 == 0 and num_layers < 164:
+            per_unit = [(num_layers - 2) // 6]
             filter_list = [16, 16, 32, 64]
             bottle_neck = False
         else:
@@ -252,23 +315,27 @@ def get_net(batch_size,
         else:
             raise ValueError("no experiments done on num_layers {}".format(num_layers))
 
-    return resnet(units=units,
-                  num_stages=num_stages,
-                  filter_list=filter_list,
-                  num_classes=num_classes,
-                  data_shape=data_shape,
-                  bottle_neck=bottle_neck,
-                  layout=layout,
-                  dtype=dtype)
-
-
-def get_workload(batch_size=1,
-                 num_classes=1000,
-                 num_layers=18,
-                 image_shape=(3, 16, 112, 112),
-                 layout="NCDHW",
-                 dtype="float32",
-                 **kwargs):
+    return resnet(
+        units=units,
+        num_stages=num_stages,
+        filter_list=filter_list,
+        num_classes=num_classes,
+        data_shape=data_shape,
+        bottle_neck=bottle_neck,
+        layout=layout,
+        dtype=dtype,
+    )
+
+
+def get_workload(
+    batch_size=1,
+    num_classes=1000,
+    num_layers=18,
+    image_shape=(3, 16, 112, 112),
+    layout="NCDHW",
+    dtype="float32",
+    **kwargs,
+):
     """Get benchmark workload for resnet
 
     Parameters
@@ -302,11 +369,13 @@ def get_workload(batch_size=1,
     params : dict of str to NDArray
         The parameters.
     """
-    net = get_net(batch_size=batch_size,
-                  num_classes=num_classes,
-                  num_layers=num_layers,
-                  image_shape=image_shape,
-                  dtype=dtype,
-                  layout=layout,
-                  **kwargs)
+    net = get_net(
+        batch_size=batch_size,
+        num_classes=num_classes,
+        num_layers=num_layers,
+        image_shape=image_shape,
+        dtype=dtype,
+        layout=layout,
+        **kwargs,
+    )
     return create_workload(net)
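
The 3-D variant mirrors the 2-D builder; the only user-visible difference in the sketch below (same assumption about the import path, here tvm.relay.testing.resnet_3d) is the five-dimensional input shape:

    # Sketch: video-style ResNet-18 workload over (channels, depth, height, width) frames.
    from tvm.relay.testing import resnet_3d  # assumed import path

    mod, params = resnet_3d.get_workload(
        batch_size=1, num_layers=18, image_shape=(3, 16, 112, 112), layout="NCDHW"
    )
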
index 1a946b6..097f223 100644 (file)
@@ -40,15 +40,20 @@ def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels, pr
     net = relay.concatenate((left, right), axis=1)
     return net
 
+
 def _make_fire_conv(net, channels, kernel_size, padding=0, prefix=""):
-    net = layers.conv2d(net,
-                        channels=channels,
-                        kernel_size=(kernel_size, kernel_size),
-                        padding=(padding, padding), name="%s_conv" % prefix)
+    net = layers.conv2d(
+        net,
+        channels=channels,
+        kernel_size=(kernel_size, kernel_size),
+        padding=(padding, padding),
+        name="%s_conv" % prefix,
+    )
     net = relay.nn.bias_add(net, relay.var("%s_conv_bias" % prefix))
     net = relay.nn.relu(net)
     return net
 
+
 # Net
 def get_net(batch_size, image_shape, num_classes, version, dtype):
     """Get symbol of SqueezeNet
@@ -67,17 +72,16 @@ def get_net(batch_size, image_shape, num_classes, version, dtype):
     version : str, optional
         "1.0" or "1.1" of SqueezeNet
     """
-    assert version in ['1.0', '1.1'], ("Unsupported SqueezeNet version {version}:"
-                                       "1.0 or 1.1 expected".format(version=version))
+    assert version in [
+        "1.0",
+        "1.1",
+    ], "Unsupported SqueezeNet version {version}:" "1.0 or 1.1 expected".format(version=version)
     data_shape = (batch_size,) + image_shape
     net = relay.var("data", shape=data_shape, dtype=dtype)
-    if version == '1.0':
-        net = layers.conv2d(net,
-                            channels=96,
-                            kernel_size=(7, 7),
-                            strides=(2, 2),
-                            padding=(3, 3),
-                            name="conv1")
+    if version == "1.0":
+        net = layers.conv2d(
+            net, channels=96, kernel_size=(7, 7), strides=(2, 2), padding=(3, 3), name="conv1"
+        )
         net = relay.nn.bias_add(net, relay.var("conv1_bias"))
         net = relay.nn.relu(net)
         net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
@@ -92,12 +96,9 @@ def get_net(batch_size, image_shape, num_classes, version, dtype):
         net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
         net = _make_fire(net, 64, 256, 256, "fire8")
     else:
-        net = layers.conv2d(net,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            strides=(2, 2),
-                            padding=(1, 1),
-                            name="conv1")
+        net = layers.conv2d(
+            net, channels=64, kernel_size=(3, 3), strides=(2, 2), padding=(1, 1), name="conv1"
+        )
         net = relay.nn.bias_add(net, relay.var("conv1_bias"))
         net = relay.nn.relu(net)
         net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
@@ -112,8 +113,7 @@ def get_net(batch_size, image_shape, num_classes, version, dtype):
         net = _make_fire(net, 64, 256, 256, "fire7")
         net = _make_fire(net, 64, 256, 256, "fire8")
     net = relay.nn.dropout(net, rate=0.5)
-    net = layers.conv2d(
-        net, channels=num_classes, kernel_size=(1, 1), name="conv_final")
+    net = layers.conv2d(net, channels=num_classes, kernel_size=(1, 1), name="conv_final")
     net = relay.nn.bias_add(net, relay.var("conv_final_bias"))
     net = relay.nn.relu(net)
     net = relay.nn.global_avg_pool2d(net)
@@ -123,11 +123,9 @@ def get_net(batch_size, image_shape, num_classes, version, dtype):
     return relay.Function(args, net)
 
 
-def get_workload(batch_size=1,
-                 num_classes=1000,
-                 version='1.0',
-                 image_shape=(3, 224, 224),
-                 dtype="float32"):
+def get_workload(
+    batch_size=1, num_classes=1000, version="1.0", image_shape=(3, 224, 224), dtype="float32"
+):
     """Get benchmark workload for SqueezeNet
 
     Parameters
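
As with the ResNet files, the SqueezeNet builder is consumed through get_workload. A minimal sketch, assuming the module is importable as tvm.relay.testing.squeezenet and using the version argument validated above:

    # Sketch: SqueezeNet 1.1 workload with random parameters.
    from tvm.relay.testing import squeezenet  # assumed import path

    mod, params = squeezenet.get_workload(
        batch_size=1, num_classes=1000, version="1.1", image_shape=(3, 224, 224)
    )
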
index e8bda77..7b77789 100644 (file)
@@ -51,9 +51,7 @@ def get_net(input_shape=(1, 3, 24, 12), dtype="float32", wtype=None):
     dense = relay.nn.relu(
         relay.nn.dense(
             relay.reshape(data, dense_shape),
-            relay.var(
-                "dense_weight", shape=[input_shape[3], dense_shape[1]], dtype=wtype
-            ),
+            relay.var("dense_weight", shape=[input_shape[3], dense_shape[1]], dtype=wtype),
         )
     )
     dense = relay.reshape_like(dense, data)
@@ -73,9 +71,7 @@ def get_net(input_shape=(1, 3, 24, 12), dtype="float32", wtype=None):
     dense = relay.nn.relu(
         relay.nn.dense(
             relay.reshape(biased, dense_shape),
-            relay.var(
-                "dense2_weight", shape=[input_shape[3], dense_shape[1]], dtype=wtype
-            ),
+            relay.var("dense2_weight", shape=[input_shape[3], dense_shape[1]], dtype=wtype),
         )
     )
     dense = relay.reshape_like(dense, data)
index 67afbd5..12e3652 100644 (file)
@@ -20,10 +20,12 @@ tests."""
 
 from tvm import relay
 
+
 class TempOpAttr(object):
     """ Temporarily changes the attr of an op. """
+
     def __init__(self, op_name, attr_key, attr_value):
-        """ Saves the required info for RAII pattern usage.
+        """Saves the required info for RAII pattern usage.
 
         Parameters
         ----------
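
The docstring above describes RAII-style usage, so the class is presumably driven as a context manager. The sketch below assumes the usual import path tvm.relay.testing.temp_op_attr and that entering/exiting the block restores the original attribute; neither detail is shown in this hunk.

    # Sketch: temporarily override an op attribute around a compilation step.
    from tvm.relay.testing.temp_op_attr import TempOpAttr  # assumed import path

    def custom_legalize(attrs, inputs, types):
        # hypothetical legalization hook used only for illustration; returning None
        # leaves the op unchanged
        return None

    with TempOpAttr("nn.dense", "FTVMLegalize", custom_legalize):
        # passes run inside this block see the overridden attribute
        pass
    # on exit, the previous "FTVMLegalize" value for nn.dense is restored
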
index 9715bd7..38bb30f 100644 (file)
@@ -39,6 +39,7 @@ except ImportError:
 # Some helper functions
 # ---------------------
 
+
 def ProcessGraphDefParam(graph_def):
     """Type-checks and possibly canonicalizes `graph_def`.
 
@@ -62,7 +63,7 @@ def ProcessGraphDefParam(graph_def):
             graph_def = graph_pb2.GraphDef()
             graph_def.MergeFrom(old_graph_def)
         except TypeError:
-            raise TypeError('graph_def must be a GraphDef proto.')
+            raise TypeError("graph_def must be a GraphDef proto.")
     return graph_def
 
 
@@ -71,8 +72,9 @@ def convert_to_list(x):
         x = [x]
     return x
 
+
 def AddShapesToGraphDef(session, out_node):
-    """ Add shapes attribute to nodes of the graph.
+    """Add shapes attribute to nodes of the graph.
         Input graph here is the default graph in context.
 
     Parameters
@@ -93,15 +95,14 @@ def AddShapesToGraphDef(session, out_node):
         session,
         session.graph.as_graph_def(add_shapes=True),
         convert_to_list(out_node),
-        )
+    )
     return graph_def
 
+
 class NodeLookup(object):
     """Converts integer node ID's to human readable labels."""
 
-    def __init__(self,
-                 label_lookup_path=None,
-                 uid_lookup_path=None):
+    def __init__(self, label_lookup_path=None, uid_lookup_path=None):
         self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
 
     def load(self, label_lookup_path, uid_lookup_path):
@@ -122,14 +123,14 @@ class NodeLookup(object):
 
         """
         if not tf_compat_v1.gfile.Exists(uid_lookup_path):
-            tf.logging.fatal('File does not exist %s', uid_lookup_path)
+            tf.logging.fatal("File does not exist %s", uid_lookup_path)
         if not tf_compat_v1.gfile.Exists(label_lookup_path):
-            tf.logging.fatal('File does not exist %s', label_lookup_path)
+            tf.logging.fatal("File does not exist %s", label_lookup_path)
 
         # Loads mapping from string UID to human-readable string
         proto_as_ascii_lines = tf_compat_v1.gfile.GFile(uid_lookup_path).readlines()
         uid_to_human = {}
-        p = re.compile(r'[n\d]*[ \S,]*')
+        p = re.compile(r"[n\d]*[ \S,]*")
         for line in proto_as_ascii_lines:
             parsed_items = p.findall(line)
             uid = parsed_items[0]
@@ -140,17 +141,17 @@ class NodeLookup(object):
         node_id_to_uid = {}
         proto_as_ascii = tf_compat_v1.gfile.GFile(label_lookup_path).readlines()
         for line in proto_as_ascii:
-            if line.startswith('  target_class:'):
-                target_class = int(line.split(': ')[1])
-            if line.startswith('  target_class_string:'):
-                target_class_string = line.split(': ')[1]
+            if line.startswith("  target_class:"):
+                target_class = int(line.split(": ")[1])
+            if line.startswith("  target_class_string:"):
+                target_class_string = line.split(": ")[1]
                 node_id_to_uid[target_class] = target_class_string[1:-2]
 
         # Loads the final mapping of integer node ID to human-readable string
         node_id_to_name = {}
         for key, val in node_id_to_uid.items():
             if val not in uid_to_human:
-                tf.logging.fatal('Failed to locate: %s', val)
+                tf.logging.fatal("Failed to locate: %s", val)
             name = uid_to_human[val]
             node_id_to_name[key] = name
 
@@ -158,11 +159,12 @@ class NodeLookup(object):
 
     def id_to_string(self, node_id):
         if node_id not in self.node_lookup:
-            return ''
+            return ""
         return self.node_lookup[node_id]
 
+
 def get_workload_official(model_url, model_sub_path):
-    """ Import workload from tensorflow official
+    """Import workload from tensorflow official
 
     Parameters
     ----------
@@ -180,25 +182,28 @@ def get_workload_official(model_url, model_sub_path):
     """
 
     model_tar_name = os.path.basename(model_url)
-    model_path = download_testdata(model_url, model_tar_name, module=['tf', 'official'])
+    model_path = download_testdata(model_url, model_tar_name, module=["tf", "official"])
     dir_path = os.path.dirname(model_path)
 
     if model_path.endswith("tgz") or model_path.endswith("gz"):
         import tarfile
+
         tar = tarfile.open(model_path)
         tar.extractall(path=dir_path)
         tar.close()
     elif model_path.endswith("zip"):
         import zipfile
+
         zip_object = zipfile.ZipFile(model_path)
         zip_object.extractall(path=dir_path)
         zip_object.close()
     else:
-        raise RuntimeError('Could not decompress the file: ' + model_path)
+        raise RuntimeError("Could not decompress the file: " + model_path)
     return os.path.join(dir_path, model_sub_path)
 
+
 def get_workload(model_path, model_sub_path=None, inputs_dict=None, output=None):
-    """ Import workload from frozen protobuf
+    """Import workload from frozen protobuf
 
     Parameters
     ----------
@@ -218,15 +223,15 @@ def get_workload(model_path, model_sub_path=None, inputs_dict=None, output=None)
     if model_sub_path:
         path_model = get_workload_official(model_path, model_sub_path)
     else:
-        repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/'
+        repo_base = "https://github.com/dmlc/web-data/raw/master/tensorflow/models/"
         model_url = os.path.join(repo_base, model_path)
-        path_model = download_testdata(model_url, model_path, module='tf')
+        path_model = download_testdata(model_url, model_path, module="tf")
 
     # Creates graph from saved graph_def.pb.
-    with tf_compat_v1.gfile.FastGFile(path_model, 'rb') as f:
+    with tf_compat_v1.gfile.FastGFile(path_model, "rb") as f:
         graph_def = tf_compat_v1.GraphDef()
         graph_def.ParseFromString(f.read())
-        graph = tf_compat_v1.import_graph_def(graph_def, name='', input_map=inputs_dict)
+        graph = tf_compat_v1.import_graph_def(graph_def, name="", input_map=inputs_dict)
 
     if inputs_dict is not None:
         # graph is changed so generate graph_def again
@@ -235,14 +240,17 @@ def get_workload(model_path, model_sub_path=None, inputs_dict=None, output=None)
 
     return graph_def
 
+
 #######################################################################
 # PTB LSTMBlockCell Model
 # -----------------------
 
+
 class PTBSmallConfig(object):
     """Small config.
     This configuration is used when training the model
     """
+
     num_layers = 2
     num_steps = 1
     hidden_size = 200
@@ -250,38 +258,47 @@ class PTBSmallConfig(object):
     vocab_size = 10000
     init_scale = 0.1
 
+
 def get_config():
     """Configuration used for training the model"""
     return PTBSmallConfig()
 
+
 def pick_from_weight(weight, pows=1.0):
     """Identify token from Softmax output.
     This token will be mapped to word in the vocabulary.
     """
-    weight = weight**pows
+    weight = weight ** pows
     t = np.cumsum(weight)
     s = np.sum(weight)
     return int(np.searchsorted(t, 0.5 * s))
 
+
 def do_tf_sample(session, data, in_states, num_samples):
     """Sampled from the model"""
     samples = []
     sample = None
-    #Cell inputs c and h should be passed for each layer explicitly.
-    state_input_name = ['Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros:0',
-                        'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros_1:0',
-                        'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros:0',
-                        'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros_1:0']
+    # Cell inputs c and h should be passed for each layer explicitly.
+    state_input_name = [
+        "Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros:0",
+        "Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros_1:0",
+        "Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros:0",
+        "Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros_1:0",
+    ]
     state = in_states
 
-    #Graph nodes to be fetched as run output. Tensorflow LSTMBlockCell create internal
-    #nodes for intermediate operations (gates) in the cell during run.
-    #Cell state (c) is ':1'and cell output (h) is ':6' for each layer.
-    fetches = [['Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:1',
-                'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:6',
-                'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:1',
-                'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:6'],
-               'Model/Softmax:0']
+    # Graph nodes to be fetched as run output. TensorFlow LSTMBlockCell creates internal
+    # nodes for intermediate operations (gates) in the cell during run.
+    # Cell state (c) is ':1' and cell output (h) is ':6' for each layer.
+    fetches = [
+        [
+            "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:1",
+            "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:6",
+            "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:1",
+            "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:6",
+        ],
+        "Model/Softmax:0",
+    ]
 
     def _get_feed_dict(input_name, input_data):
         """Create feed dict"""
@@ -295,7 +312,7 @@ def do_tf_sample(session, data, in_states, num_samples):
 
     for x in data:
         feed_dict = _get_feed_dict(state_input_name, state)
-        feed_dict['Model/Placeholder:0'] = [[x]]
+        feed_dict["Model/Placeholder:0"] = [[x]]
         state, probs = session.run(fetches, feed_dict)
         sample = pick_from_weight(probs[0])
     if sample is not None:
@@ -306,17 +323,19 @@ def do_tf_sample(session, data, in_states, num_samples):
     k = 1
     while k < num_samples:
         feed_dict = _get_feed_dict(state_input_name, state)
-        feed_dict['Model/Placeholder:0'] = [[samples[-1]]]
+        feed_dict["Model/Placeholder:0"] = [[samples[-1]]]
         state, probs = session.run(fetches, feed_dict)
         sample = pick_from_weight(probs[0])
         samples.append(sample)
         k += 1
     return samples, state
 
+
 def _create_ptb_vocabulary(data_dir):
     """Read the PTB sample data input to create vocabulary"""
-    data_path = os.path.join(data_dir, 'simple-examples/data/')
-    file_name = 'ptb.train.txt'
+    data_path = os.path.join(data_dir, "simple-examples/data/")
+    file_name = "ptb.train.txt"
+
     def _read_words(filename):
         """Read the data for creating vocabulary"""
         with tf_compat_v1.gfile.GFile(filename, "r") as f:
@@ -329,7 +348,7 @@ def _create_ptb_vocabulary(data_dir):
         count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
         words, _ = list(zip(*count_pairs))
         word_to_id = dict(zip(words, range(len(words))))
-        #for python 3.x
+        # for python 3.x
         id_to_word = dict((v, k) for k, v in word_to_id.items())
         return word_to_id, id_to_word
 
@@ -338,10 +357,12 @@ def _create_ptb_vocabulary(data_dir):
         train_path = os.path.join(data_path, file_name)
         word_to_id, id_2_word = _build_vocab(train_path)
         return word_to_id, id_2_word
+
     return ptb_raw_data(data_path, file_name)
 
+
 def get_workload_ptb():
-    """ Import ptb workload from frozen protobuf
+    """Import ptb workload from frozen protobuf
 
     Parameters
     ----------
@@ -358,39 +379,38 @@ def get_workload_ptb():
     id_to_word : dict
         Integer id to English word mapping
     """
-    sample_repo = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/'
-    sample_data_file = 'simple-examples.tgz'
-    sample_url = sample_repo+sample_data_file
-    ptb_model_file = 'RNN/ptb/ptb_model_with_lstmblockcell.pb'
+    sample_repo = "http://www.fit.vutbr.cz/~imikolov/rnnlm/"
+    sample_data_file = "simple-examples.tgz"
+    sample_url = sample_repo + sample_data_file
+    ptb_model_file = "RNN/ptb/ptb_model_with_lstmblockcell.pb"
     # pylint: disable=import-outside-toplevel
     import tarfile
-    file_path = download_testdata(sample_url, sample_data_file, module=['data', 'ptb_data'])
+
+    file_path = download_testdata(sample_url, sample_data_file, module=["data", "ptb_data"])
     dir_path = os.path.dirname(file_path)
-    t = tarfile.open(file_path, 'r')
+    t = tarfile.open(file_path, "r")
     t.extractall(dir_path)
 
     word_to_id, id_to_word = _create_ptb_vocabulary(dir_path)
-    dtype = 'float32'
+    dtype = "float32"
     shape = (1, 200)
 
     # Convert states of LSTMBlockCell to placeholder, so TVM can feed data
     state_name = [
-        'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros:0',
-        'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros_1:0',
-        'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros:0',
-        'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros_1:0',
-        ]
+        "Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros:0",
+        "Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros_1:0",
+        "Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros:0",
+        "Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros_1:0",
+    ]
 
     inputs_dict = {
-        state_name[0]:
-            tf_compat_v1.placeholder(dtype, shape, state_name[0].split(':')[0]),
-        state_name[1]:
-            tf_compat_v1.placeholder(dtype, shape, state_name[1].split(':')[0]),
-        state_name[2]:
-            tf_compat_v1.placeholder(dtype, shape, state_name[2].split(':')[0]),
-        state_name[3]:
-            tf_compat_v1.placeholder(dtype, shape, state_name[3].split(':')[0]),
+        state_name[0]: tf_compat_v1.placeholder(dtype, shape, state_name[0].split(":")[0]),
+        state_name[1]: tf_compat_v1.placeholder(dtype, shape, state_name[1].split(":")[0]),
+        state_name[2]: tf_compat_v1.placeholder(dtype, shape, state_name[2].split(":")[0]),
+        state_name[3]: tf_compat_v1.placeholder(dtype, shape, state_name[3].split(":")[0]),
     }
-    return word_to_id, id_to_word, get_workload(ptb_model_file,
-                                                inputs_dict=inputs_dict,
-                                                output='Model/Softmax')
+    return (
+        word_to_id,
+        id_to_word,
+        get_workload(ptb_model_file, inputs_dict=inputs_dict, output="Model/Softmax"),
+    )
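
Most of the PTB plumbing above is I/O, but pick_from_weight is pure arithmetic and easy to sanity-check in isolation. The snippet below restates the visible body rather than importing it, so no module path is assumed.

    import numpy as np

    def pick_from_weight(weight, pows=1.0):
        # mirrors the body above: raise to a power, accumulate, and pick the index
        # where the cumulative mass first crosses half of the total mass
        weight = weight ** pows
        t = np.cumsum(weight)
        s = np.sum(weight)
        return int(np.searchsorted(t, 0.5 * s))

    probs = np.array([0.1, 0.2, 0.7])
    # cumsum is [0.1, 0.3, 1.0]; half the total mass is 0.5, which lands at index 2
    assert pick_from_weight(probs) == 2
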
index 686230b..339932b 100644 (file)
@@ -30,16 +30,21 @@ def get_feature(internal_layer, layers, filters, batch_norm=False):
     for i, num in enumerate(layers):
         for j in range(num):
             internal_layer = wrapper.conv2d(
-                data=internal_layer, kernel_size=(3, 3), padding=(1, 1),
-                channels=filters[i], name="conv%s_%s" % (i + 1, j + 1))
+                data=internal_layer,
+                kernel_size=(3, 3),
+                padding=(1, 1),
+                channels=filters[i],
+                name="conv%s_%s" % (i + 1, j + 1),
+            )
             internal_layer = relay.nn.bias_add(
-                internal_layer, relay.var("conv%s_%s_bias" % (i + 1, j + 1)))
+                internal_layer, relay.var("conv%s_%s_bias" % (i + 1, j + 1))
+            )
             if batch_norm:
                 internal_layer = wrapper.batch_norm_infer(
-                    data=internal_layer, name="bn%s_%s" %(i + 1, j + 1))
+                    data=internal_layer, name="bn%s_%s" % (i + 1, j + 1)
+                )
             internal_layer = relay.nn.relu(data=internal_layer)
-        internal_layer = relay.nn.max_pool2d(
-            data=internal_layer, pool_size=(2, 2), strides=(2, 2))
+        internal_layer = relay.nn.max_pool2d(data=internal_layer, pool_size=(2, 2), strides=(2, 2))
     return internal_layer
 
 
@@ -78,10 +83,12 @@ def get_net(batch_size, image_shape, num_classes, dtype, num_layers=11, batch_no
     batch_norm : bool, default False
         Use batch normalization.
     """
-    vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
-                13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
-                16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
-                19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
+    vgg_spec = {
+        11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
+        13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
+        16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
+        19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512]),
+    }
     if num_layers not in vgg_spec:
         raise ValueError("Invalide num_layers {}. Choices are 11,13,16,19.".format(num_layers))
     layers, filters = vgg_spec[num_layers]
@@ -94,12 +101,14 @@ def get_net(batch_size, image_shape, num_classes, dtype, num_layers=11, batch_no
     return relay.Function(args, symbol)
 
 
-def get_workload(batch_size,
-                 num_classes=1000,
-                 image_shape=(3, 224, 224),
-                 dtype="float32",
-                 num_layers=11,
-                 batch_norm=False):
+def get_workload(
+    batch_size,
+    num_classes=1000,
+    image_shape=(3, 224, 224),
+    dtype="float32",
+    num_layers=11,
+    batch_norm=False,
+):
     """Get benchmark workload for VGG nets.
 
     Parameters
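A minimal usage sketch for the reformatted get_workload above, assuming this file is tvm.relay.testing.vgg; the batch size, layer count, and target string are illustrative, not part of the diff:

import tvm
from tvm import relay
from tvm.relay.testing import vgg

# Build a VGG-11 Relay workload (an IRModule plus randomly initialized
# parameters) and compile it for a generic CPU target.
mod, params = vgg.get_workload(batch_size=1, num_layers=11, batch_norm=False)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)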
index d0a675f..c457ce3 100644 (file)
@@ -28,142 +28,173 @@ from collections import namedtuple
 from functools import cmp_to_key
 import numpy as np
 
-Box = namedtuple('Box', ['x', 'y', 'w', 'h'])
+Box = namedtuple("Box", ["x", "y", "w", "h"])
+
 
 def nms_comparator(a, b):
-    if 'sort_class' in b and b['sort_class'] >= 0:
-        diff = a['prob'][b['sort_class']] - b['prob'][b['sort_class']]
+    if "sort_class" in b and b["sort_class"] >= 0:
+        diff = a["prob"][b["sort_class"]] - b["prob"][b["sort_class"]]
     else:
-        diff = a['objectness'] - b['objectness']
+        diff = a["objectness"] - b["objectness"]
     return diff
 
+
 def _correct_boxes(dets, w, h, netw, neth, relative):
-    new_w, new_h = (netw, (h*netw)//w) if (netw/w < neth/h) else ((w*neth//h), neth)
+    new_w, new_h = (netw, (h * netw) // w) if (netw / w < neth / h) else ((w * neth // h), neth)
     for det in dets:
-        b = det['bbox']
-        b = b._replace(x=(b.x - (netw - new_w)/2/netw) / (new_w/netw))
-        b = b._replace(y=(b.y - (neth - new_h)/2/neth) / (new_h/neth))
-        b = b._replace(w=b.w * netw/new_w)
-        b = b._replace(h=b.h * neth/new_h)
+        b = det["bbox"]
+        b = b._replace(x=(b.x - (netw - new_w) / 2 / netw) / (new_w / netw))
+        b = b._replace(y=(b.y - (neth - new_h) / 2 / neth) / (new_h / neth))
+        b = b._replace(w=b.w * netw / new_w)
+        b = b._replace(h=b.h * neth / new_h)
         if not relative:
             b = b._replace(x=b.x * w)
             b = b._replace(w=b.w * w)
             b = b._replace(y=b.y * h)
             b = b._replace(h=b.h * h)
-        det['bbox'] = b
+        det["bbox"] = b
     return dets
 
+
 def _overlap(x1, w1, x2, w2):
-    l1 = x1 - w1/2
-    l2 = x2 - w2/2
+    l1 = x1 - w1 / 2
+    l2 = x2 - w2 / 2
     left = l1 if l1 > l2 else l2
-    r1 = x1 + w1/2
-    r2 = x2 + w2/2
+    r1 = x1 + w1 / 2
+    r2 = x2 + w2 / 2
     right = r1 if r1 < r2 else r2
     return right - left
 
+
 def _box_intersection(a, b):
     w = _overlap(a.x, a.w, b.x, b.w)
     h = _overlap(a.y, a.h, b.y, b.h)
     if w < 0 or h < 0:
         return 0
-    return w*h
+    return w * h
+
 
 def _box_union(a, b):
     i = _box_intersection(a, b)
-    u = a.w*a.h + b.w*b.h - i
+    u = a.w * a.h + b.w * b.h - i
     return u
 
+
 def _box_iou(a, b):
-    return _box_intersection(a, b)/_box_union(a, b)
+    return _box_intersection(a, b) / _box_union(a, b)
+
 
 def _get_box(data, biases, n, location, lw, lh, w, h):
     bx = (location[2] + data[location[0]][0][location[1]][location[2]]) / lw
     by = (location[1] + data[location[0]][1][location[1]][location[2]]) / lh
-    bw = np.exp(data[location[0]][2][location[1]][location[2]]) * biases[2*n] / w
-    bh = np.exp(data[location[0]][3][location[1]][location[2]]) * biases[2*n+1] / h
+    bw = np.exp(data[location[0]][2][location[1]][location[2]]) * biases[2 * n] / w
+    bh = np.exp(data[location[0]][3][location[1]][location[2]]) * biases[2 * n + 1] / h
     return Box(bx, by, bw, bh)
 
+
 def _get_yolo_detections(l, im_shape, net_shape, thresh, relative, dets):
-    data = l['output']
+    data = l["output"]
     active_data_loc = np.asarray(np.where(data[:, 4, :, :] > thresh))
     before_correct_dets = []
     for i in range(active_data_loc.shape[1]):
         location = [active_data_loc[0][i], active_data_loc[1][i], active_data_loc[2][i]]
-        box_b = _get_box(data, l['biases'], np.asarray(l['mask'])[location[0]], location,
-                         data.shape[2], data.shape[3], net_shape[0], net_shape[1])
+        box_b = _get_box(
+            data,
+            l["biases"],
+            np.asarray(l["mask"])[location[0]],
+            location,
+            data.shape[2],
+            data.shape[3],
+            net_shape[0],
+            net_shape[1],
+        )
         objectness = data[location[0]][4][location[1]][location[2]]
-        classes = l['classes']
-        prob = objectness*data[location[0], 5:5 + 1 + classes, location[1], location[2]]
+        classes = l["classes"]
+        prob = objectness * data[location[0], 5 : 5 + 1 + classes, location[1], location[2]]
         prob[prob < thresh] = 0
         detection = {}
-        detection['bbox'] = box_b
-        detection['classes'] = classes
-        detection['prob'] = prob
-        detection['objectness'] = objectness
+        detection["bbox"] = box_b
+        detection["classes"] = classes
+        detection["prob"] = prob
+        detection["objectness"] = objectness
         before_correct_dets.append(detection)
-    dets.extend(_correct_boxes(before_correct_dets, im_shape[0], im_shape[1],
-                               net_shape[0], net_shape[1], relative))
+    dets.extend(
+        _correct_boxes(
+            before_correct_dets, im_shape[0], im_shape[1], net_shape[0], net_shape[1], relative
+        )
+    )
+
 
 def _get_region_detections(l, im_shape, net_shape, thresh, relative, dets):
-    data = l['output']
+    data = l["output"]
     before_correct_dets = []
     for row in range(data.shape[2]):
         for col in range(data.shape[3]):
             for n in range(data.shape[0]):
-                prob = [0]*l['classes']
-                scale = data[n, l['coords'], row, col] if not l['background'] else 1
+                prob = [0] * l["classes"]
+                scale = data[n, l["coords"], row, col] if not l["background"] else 1
                 location = [n, row, col]
-                box_b = _get_box(data, l['biases'], n, location,
-                                 data.shape[2], data.shape[3], data.shape[2], data.shape[3])
+                box_b = _get_box(
+                    data,
+                    l["biases"],
+                    n,
+                    location,
+                    data.shape[2],
+                    data.shape[3],
+                    data.shape[2],
+                    data.shape[3],
+                )
                 objectness = scale if scale > thresh else 0
                 if objectness:
-                    prob = scale * data[n, l['coords']+1: l['coords']+1+l['classes'],
-                                        row, col]
+                    prob = (
+                        scale * data[n, l["coords"] + 1 : l["coords"] + 1 + l["classes"], row, col]
+                    )
                     prob[prob < thresh] = 0
                 detection = {}
-                detection['bbox'] = box_b
-                detection['prob'] = prob
-                detection['objectness'] = objectness
+                detection["bbox"] = box_b
+                detection["prob"] = prob
+                detection["objectness"] = objectness
                 before_correct_dets.append(detection)
-    _correct_boxes(before_correct_dets, im_shape[0], im_shape[1],
-                   net_shape[0], net_shape[1], relative)
+    _correct_boxes(
+        before_correct_dets, im_shape[0], im_shape[1], net_shape[0], net_shape[1], relative
+    )
     dets.extend(before_correct_dets)
 
-def fill_network_boxes(net_shape, im_shape,
-                       thresh, relative, tvm_out):
+
+def fill_network_boxes(net_shape, im_shape, thresh, relative, tvm_out):
     dets = []
     for layer in tvm_out:
-        if layer['type'] == 'Yolo':
+        if layer["type"] == "Yolo":
             _get_yolo_detections(layer, im_shape, net_shape, thresh, relative, dets)
-        elif layer['type'] == 'Region':
+        elif layer["type"] == "Region":
             _get_region_detections(layer, im_shape, net_shape, thresh, relative, dets)
     return dets
 
+
 def do_nms_sort(dets, classes, thresh):
     "Does the sorting based on the threshold values"
-    k = len(dets)-1
+    k = len(dets) - 1
     cnt = 0
     while cnt < k:
-        if dets[cnt]['objectness'] == 0:
+        if dets[cnt]["objectness"] == 0:
             dets[k], dets[cnt] = dets[cnt], dets[k]
             k = k - 1
         else:
             cnt = cnt + 1
-    total = k+1
+    total = k + 1
     for k in range(classes):
         for i in range(total):
-            dets[i]['sort_class'] = k
-        dets[0:total] = sorted(dets[0:total],
-                               key=cmp_to_key(nms_comparator), reverse=True)
+            dets[i]["sort_class"] = k
+        dets[0:total] = sorted(dets[0:total], key=cmp_to_key(nms_comparator), reverse=True)
         for i in range(total):
-            if dets[i]['prob'][k] == 0:
+            if dets[i]["prob"][k] == 0:
                 continue
-            a = dets[i]['bbox']
-            for j in range(i+1, total):
-                b = dets[j]['bbox']
+            a = dets[i]["bbox"]
+            for j in range(i + 1, total):
+                b = dets[j]["bbox"]
                 if _box_iou(a, b) > thresh:
-                    dets[j]['prob'][k] = 0
+                    dets[j]["prob"][k] = 0
+
 
 def draw_detections(font_path, im, dets, thresh, names, classes):
     "Draw the markings around the detected region"
@@ -171,44 +202,47 @@ def draw_detections(font_path, im, dets, thresh, names, classes):
         labelstr = []
         category = -1
         for j in range(classes):
-            if det['prob'][j] > thresh:
+            if det["prob"][j] > thresh:
                 if category == -1:
                     category = j
-                labelstr.append(names[j] + " " + str(round(det['prob'][j], 4)))
+                labelstr.append(names[j] + " " + str(round(det["prob"][j], 4)))
         if category > -1:
             imc, imh, imw = im.shape
             width = int(imh * 0.006)
-            offset = category*123457 % classes
+            offset = category * 123457 % classes
             red = _get_color(2, offset, classes)
             green = _get_color(1, offset, classes)
             blue = _get_color(0, offset, classes)
             rgb = [red, green, blue]
-            b = det['bbox']
-            left = int((b.x-b.w/2.)*imw)
-            right = int((b.x+b.w/2.)*imw)
-            top = int((b.y-b.h/2.)*imh)
-            bot = int((b.y+b.h/2.)*imh)
+            b = det["bbox"]
+            left = int((b.x - b.w / 2.0) * imw)
+            right = int((b.x + b.w / 2.0) * imw)
+            top = int((b.y - b.h / 2.0) * imh)
+            bot = int((b.y + b.h / 2.0) * imh)
 
             if left < 0:
                 left = 0
-            if right > imw-1:
-                right = imw-1
+            if right > imw - 1:
+                right = imw - 1
             if top < 0:
                 top = 0
-            if bot > imh-1:
-                bot = imh-1
+            if bot > imh - 1:
+                bot = imh - 1
             _draw_box_width(im, left, top, right, bot, width, red, green, blue)
-            label = _get_label(font_path, ''.join(labelstr), rgb)
+            label = _get_label(font_path, "".join(labelstr), rgb)
             _draw_label(im, top + width, left, label, rgb)
 
+
 def _get_pixel(im, x, y, c):
     return im[c][y][x]
 
+
 def _set_pixel(im, x, y, c, val):
     if x < 0 or y < 0 or c < 0 or x >= im.shape[2] or y >= im.shape[1] or c >= im.shape[0]:
         return
     im[c][y][x] = val
 
+
 def _draw_label(im, r, c, label, rgb):
     w = label.shape[2]
     h = label.shape[1]
@@ -221,7 +255,8 @@ def _draw_label(im, r, c, label, rgb):
                 if i < w and (i + c) < im.shape[2]:
                     for k in range(label.shape[0]):
                         val = _get_pixel(label, i, j, k)
-                        _set_pixel(im, i+c, j+r, k, val)#rgb[k] * val)
+                        _set_pixel(im, i + c, j + r, k, val)  # rgb[k] * val)
+
 
 def _get_label(font_path, labelstr, rgb):
     # pylint: disable=import-outside-toplevel
@@ -231,26 +266,29 @@ def _get_label(font_path, labelstr, rgb):
 
     text = labelstr
     colorText = "black"
-    testDraw = ImageDraw.Draw(Image.new('RGB', (1, 1)))
+    testDraw = ImageDraw.Draw(Image.new("RGB", (1, 1)))
     font = ImageFont.truetype(font_path, 25)
     width, height = testDraw.textsize(labelstr, font=font)
-    img = Image.new('RGB', (width, height), color=(int(rgb[0]*255), int(rgb[1]*255),
-                                                   int(rgb[2]*255)))
+    img = Image.new(
+        "RGB", (width, height), color=(int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255))
+    )
     d = ImageDraw.Draw(img)
     d.text((0, 0), text, fill=colorText, font=font)
     opencvImage = np.divide(np.asarray(img), 255)
     return opencvImage.transpose(2, 0, 1)
 
+
 def _get_color(c, x, max_value):
     c = int(c)
     colors = [[1, 0, 1], [0, 0, 1], [0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 0]]
-    ratio = (float(x)/float(max_value)) * 5
+    ratio = (float(x) / float(max_value)) * 5
     i = int(math.floor(ratio))
     j = int(math.ceil(ratio))
     ratio -= i
-    r = (1-ratio) * colors[i][c] + ratio*colors[j][c]
+    r = (1 - ratio) * colors[i][c] + ratio * colors[j][c]
     return r
 
+
 def _draw_box(im, x1, y1, x2, y2, r, g, b):
     y1 = int(y1)
     y2 = int(y2)
@@ -284,6 +322,7 @@ def _draw_box(im, x1, y1, x2, y2, r, g, b):
         im[2][i][x1] = b
         im[2][i][x2] = b
 
+
 def _draw_box_width(im, x1, y1, x2, y2, w, r, g, b):
     for i in range(int(w)):
-        _draw_box(im, x1+i, y1+i, x2-i, y2-i, r, g, b)
+        _draw_box(im, x1 + i, y1 + i, x2 - i, y2 - i, r, g, b)
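A hedged sketch of how the helpers in this file chain together after a darknet YOLO inference, assuming this file is tvm.relay.testing.yolo_detection; the dummy inputs below (empty tvm_out, zero image, generic class names, font path) are placeholders so the calls run end to end and are not taken from the diff:

import numpy as np
from tvm.relay.testing import yolo_detection

# `tvm_out` normally holds one dict per output layer with keys such as
# "type", "output", "biases", "mask", and "classes"; an empty list keeps
# this sketch self-contained.
tvm_out = []
img = np.zeros((3, 416, 416), dtype="float32")  # image in (C, H, W) layout
names = ["object"] * 80
font_path = "arial.ttf"  # any TrueType font path (illustrative)

dets = yolo_detection.fill_network_boxes((416, 416), (img.shape[2], img.shape[1]), 0.5, 1, tvm_out)
yolo_detection.do_nms_sort(dets, 80, 0.45)
yolo_detection.draw_detections(font_path, img, dets, 0.5, names, 80)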
index e6f17f9..ccc4f76 100644 (file)
@@ -30,19 +30,24 @@ from ... import DataType, register_func
 from .. import ty, expr
 from ..backend import compile_engine
 from ..op.memory import flatten_tuple_type, from_tuple_type, to_tuple_type
-from ...import cpu
+from ... import cpu
 from ..op.memory import alloc_storage
 from ..analysis import context_analysis
 from ..._ffi.runtime_ctypes import TVMContext
 
-def alloc_tensor(storage, shape, dtype='float32', assert_shape=None):
+
+def alloc_tensor(storage, shape, dtype="float32", assert_shape=None):
     offset = expr.const(0, dtype="int64")
     return op.memory.alloc_tensor(storage, offset, shape, dtype, assert_shape)
 
 
 def is_primitive(call):
-    return hasattr(call, 'op') and hasattr(call.op, 'attrs') and \
-           hasattr(call.op.attrs, 'Primitive') and int(call.op.attrs.Primitive) == 1
+    return (
+        hasattr(call, "op")
+        and hasattr(call.op, "attrs")
+        and hasattr(call.op.attrs, "Primitive")
+        and int(call.op.attrs.Primitive) == 1
+    )
 
 
 def is_device_copy(func):
@@ -60,10 +65,14 @@ def is_device_copy(func):
 
 class CheckReshapeOnly(ExprVisitor):
     """A pass to check if the fused op contains only reshape ops."""
+
     def __init__(self):
         super().__init__()
-        self._reshape_ops = [op.get("reshape"), op.get("contrib_reverse_reshape"),
-                             op.get("dyn.reshape")]
+        self._reshape_ops = [
+            op.get("reshape"),
+            op.get("contrib_reverse_reshape"),
+            op.get("dyn.reshape"),
+        ]
         self.reshape_only = True
 
     def visit_call(self, call):
@@ -119,7 +128,7 @@ class ManifestAllocPass(ExprMutator):
         for field in tup.fields:
             field = self.visit(field)
             if isinstance(field, expr.Constant):
-                field = scope.let('const', field)
+                field = scope.let("const", field)
             new_fields.append(field)
         return expr.Tuple(new_fields)
 
@@ -159,10 +168,7 @@ class ManifestAllocPass(ExprMutator):
         size = self.compute_storage(tensor_type)
         alignment = self.compute_alignment(tensor_type.dtype)
         dtype = tensor_type.dtype
-        sto = scope.let("storage_{0}".format(name_hint), alloc_storage(size,
-                                                                       alignment,
-                                                                       ctx,
-                                                                       dtype))
+        sto = scope.let("storage_{0}".format(name_hint), alloc_storage(size, alignment, ctx, dtype))
         # TODO(@jroesch): There is a bug with typing based on the constant shape.
         tensor = alloc_tensor(sto, shape, dtype, tensor_type.shape)
         return scope.let("tensor_{0}".format(name_hint), tensor)
@@ -198,8 +204,7 @@ class ManifestAllocPass(ExprMutator):
             if state == 2:
                 for j, subexp in enumerate(from_tuple_type(arg.type_annotation, arg)):
                     sh_of = self.visit(self.shape_of(subexp))
-                    shape_func_ins.append(
-                        scope.let("in_shape_{0}".format(input_pos + j), sh_of))
+                    shape_func_ins.append(scope.let("in_shape_{0}".format(input_pos + j), sh_of))
                     input_pos += 1
                 is_inputs.append(0)
             # Pass Inputs
@@ -208,8 +213,7 @@ class ManifestAllocPass(ExprMutator):
                 ctx = self.get_context(arg)
                 if ctx.device_type != cpu_ctx.device_type:
                     new_arg = self.device_copy(new_arg, ctx, cpu_ctx)
-                shape_func_ins.append(
-                    scope.let("in_shape_{0}".format(input_pos), new_arg))
+                shape_func_ins.append(scope.let("in_shape_{0}".format(input_pos), new_arg))
                 input_pos += 1
                 is_inputs.append(1)
             else:
@@ -226,9 +230,8 @@ class ManifestAllocPass(ExprMutator):
             out_shapes.append(alloc)
 
         shape_call = self.shape_func(
-            func,
-            expr.Tuple(shape_func_ins),
-            expr.Tuple(out_shapes), is_inputs)
+            func, expr.Tuple(shape_func_ins), expr.Tuple(out_shapes), is_inputs
+        )
 
         scope.let("shape_func", shape_call)
         return out_shapes
@@ -242,18 +245,15 @@ class ManifestAllocPass(ExprMutator):
         for i, (out_shape, out_type) in enumerate(zip(out_shapes, out_types)):
             size = self.compute_storage_in_relay(out_shape, out_type.dtype)
             alignment = self.compute_alignment(out_type.dtype)
-            sto = scope.let("storage_{i}".format(i=i), alloc_storage(
-                size, alignment, func_ctx, out_type.dtype))
+            sto = scope.let(
+                "storage_{i}".format(i=i), alloc_storage(size, alignment, func_ctx, out_type.dtype)
+            )
             storages.append(sto)
 
         outs = []
         sh_ty_storage = zip(out_shapes, out_types, storages)
         for i, (out_shape, out_type, storage) in enumerate(sh_ty_storage):
-            alloc = alloc_tensor(
-                storage,
-                out_shape,
-                out_type.dtype,
-                out_type.shape)
+            alloc = alloc_tensor(storage, out_shape, out_type.dtype, out_type.shape)
             alloc = scope.let("out_{i}".format(i=i), alloc)
             outs.append(alloc)
 
@@ -299,9 +299,9 @@ class ManifestAllocPass(ExprMutator):
                     attr = call.op.body.attrs
                 else:
                     attr = call.attr
-                return self.device_copy(new_args[0],
-                                        TVMContext(attr.src_dev_type, 0),
-                                        TVMContext(attr.dst_dev_type, 0))
+                return self.device_copy(
+                    new_args[0], TVMContext(attr.src_dev_type, 0), TVMContext(attr.dst_dev_type, 0)
+                )
 
             if self.is_dynamic(ret_type):
                 # Handle dynamic case.
@@ -324,6 +324,7 @@ class ManifestAllocPass(ExprMutator):
 
 def mk_analysis_annotator(results):
     """Pretty print the annotated relay program with device info"""
+
     def _annotator(exp):
         if exp in results:
             val = results[exp]
@@ -339,6 +340,7 @@ def mk_analysis_annotator(results):
 @module_pass(opt_level=0)
 class ManifestAlloc:
     """The explicit pass wrapper around ManifestAlloc."""
+
     # TODO(zhiics, jroesch) Port this pass to C++.
     def __init__(self, target_host, targets):
         self.target_host = target_host
index 248a79b..7c7685d 100644 (file)
@@ -49,6 +49,7 @@ class Region:
     The below pass groups sets of allocations into regions,
     then replaces the region with a single allocation.
     """
+
     var: expr.Var
     size: expr.Expr
     alignment: Optional[expr.Expr]
@@ -64,12 +65,15 @@ class Region:
         return Region(region_var, zero, None, None, None, {})
 
     def grow(
-            self, old_storage: expr.Var,
-            size: expr.Expr, alignment: expr.Expr,
-            ctx: TVMContext,
-            dtype: str) -> None:
+        self,
+        old_storage: expr.Var,
+        size: expr.Expr,
+        alignment: expr.Expr,
+        ctx: TVMContext,
+        dtype: str,
+    ) -> None:
         """Grow the region by a given allocation as well as track the old storage
-           for later rewriting the program to use the allocated region.
+        for later rewriting the program to use the allocated region.
         """
         if self.dtype:
             assert self.dtype == dtype, "must have matching dtypes in a region"
@@ -84,14 +88,16 @@ class Region:
             self.alignment = alignment
 
         if self.ctx:
-            assert (self.ctx.device_type == ctx.device_type and
-                    self.ctx.device_id == ctx.device_id), "must have matching context"
+            assert (
+                self.ctx.device_type == ctx.device_type and self.ctx.device_id == ctx.device_id
+            ), "must have matching context"
         else:
             assert ctx
             self.ctx = ctx
 
-        new_size = (size + self.alignment - expr.const(1, "int64")) \
-            / self.alignment * self.alignment
+        new_size = (
+            (size + self.alignment - expr.const(1, "int64")) / self.alignment * self.alignment
+        )
 
         # Record the offset at which we allocate the storage.
         offset_var: expr.RelayExpr = expr.var(f"offset{len(self.offsets)}")
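The new_size expression above is the usual round-up-to-alignment idiom; a plain-integer sketch of the same arithmetic, with the Relay constants replaced by Python ints purely for illustration:

def round_up(size, alignment):
    # (size + alignment - 1) // alignment * alignment pads size to the next
    # multiple of alignment, mirroring the Relay expression in Region.grow.
    return (size + alignment - 1) // alignment * alignment

assert round_up(100, 64) == 128
assert round_up(128, 64) == 128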
@@ -150,7 +156,6 @@ def iterative_let(let, each_binding, kont):
     return kont(bindings, let)
 
 
-
 def mk_let(bindings, body):
     for var, value in reversed(bindings):
         assert var
@@ -160,11 +165,13 @@ def mk_let(bindings, body):
 
     return body
 
+
 def const_eval(mod, exp):
     mod = IRModule.from_expr(exp, type_defs=mod.type_definitions)
     mod = transform.FoldConstant()(mod)
     return mod["main"]
 
+
 class StorageCoalesce(ExprMutator):
     """
     A pass for coalescing allocations into region/arena allocations.
@@ -237,9 +244,9 @@ class StorageCoalesce(ExprMutator):
 
         return expr.If(ite.cond, true_branch, false_branch)
 
-
     def mk_let(self, dynamic_regions):
         """Let bind the dynamic regions"""
+
         def _mk_let(bindings, body):
             for var, value in reversed(bindings):
                 assert var
@@ -255,14 +262,11 @@ class StorageCoalesce(ExprMutator):
 
     def visit_let(self, let):
         dynamic_regions = []
+
         def _each_binding(lhs, rhs):
-            if isinstance(rhs, expr.Call) and rhs.op == op.op.get(
-                    "memory.alloc_storage"
-            ):
+            if isinstance(rhs, expr.Call) and rhs.op == op.op.get("memory.alloc_storage"):
                 return self.process_alloc_storage(dynamic_regions, lhs, rhs)
-            elif isinstance(rhs, expr.Call) and rhs.op == op.op.get(
-                    "memory.alloc_tensor"
-            ):
+            elif isinstance(rhs, expr.Call) and rhs.op == op.op.get("memory.alloc_tensor"):
                 return self.process_alloc_tensor(lhs, rhs)
             else:
                 return lhs, rhs
@@ -297,16 +301,16 @@ class StorageCoalesce(ExprMutator):
         storage, old_offset, shape = call.args
         region, offset = self.new_region_and_offset(storage)
 
-        assert (
-            old_offset.data.asnumpy().item() == 0
-        ), "no offsets should yet be allocated"
+        assert old_offset.data.asnumpy().item() == 0, "no offsets should yet be allocated"
         return (
             lhs,
             expr.Call(call.op, [region.var, offset, shape], call.attrs),
         )
 
+
 class LiftConst(ExprMutator):
     """An internal pass to lift constants to the top level of function."""
+
     def __init__(self):
         self.i = 0
         self.constants = []
@@ -330,12 +334,7 @@ class LiftConst(ExprMutator):
         body = mk_let(self.constants, body)
         self.constants = outer_constant
 
-        return Function(
-            fn.params,
-            body,
-            fn.ret_type,
-            fn.type_params,
-            fn.attrs)
+        return Function(fn.params, body, fn.ret_type, fn.type_params, fn.attrs)
 
     def visit_let(self, let):
         bindings = []
@@ -348,6 +347,7 @@ class LiftConst(ExprMutator):
         new_body = self.visit(let)
         return mk_let(bindings, new_body)
 
+
 @function_pass(opt_level=0)
 class MemoryPlan:
     """An explicit pass wrapper around StorageCoalesce."""
@@ -358,8 +358,10 @@ class MemoryPlan:
         func = sc.visit(func)
         return func
 
+
 register_func("relay.transform.MemoryPlan", MemoryPlan)
 
+
 @function_pass(opt_level=0)
 class LiftConstants:
     """An explicit pass wrapper around LiftConst."""
index 60a7aa3..ade071c 100644 (file)
@@ -31,10 +31,7 @@ from tvm import relay
 from . import _ffi_api
 
 
-def build_config(opt_level=2,
-                 required_pass=None,
-                 disabled_pass=None,
-                 trace=None):
+def build_config(opt_level=2, required_pass=None, disabled_pass=None, trace=None):
     """Configure the build behavior by setting config variables. This function
     will be deprecated in TVM v0.7. Instead, we should directly use
     tvm.transform.PassContext.
@@ -76,8 +73,11 @@ def build_config(opt_level=2,
     pass_context: PassContext
         The pass context for optimizations.
     """
-    warnings.warn("relay.build_config will be deprecated. Please use \
-                  tvm.transform.PassContext directly", DeprecationWarning)
+    warnings.warn(
+        "relay.build_config will be deprecated. Please use \
+                  tvm.transform.PassContext directly",
+        DeprecationWarning,
+    )
     return tvm.transform.PassContext(opt_level, required_pass, disabled_pass, trace)
 
 
@@ -133,6 +133,7 @@ def BackwardFoldScaleAxis():
     """
     return _ffi_api.BackwardFoldScaleAxis()
 
+
 def RemoveUnusedFunctions(entry_functions=None):
     """Remove unused global relay functions in a relay module.
 
@@ -147,9 +148,10 @@ def RemoveUnusedFunctions(entry_functions=None):
         The registered pass to remove unused functions.
     """
     if entry_functions is None:
-        entry_functions = ['main']
+        entry_functions = ["main"]
     return _ffi_api.RemoveUnusedFunctions(entry_functions)
 
+
 def ForwardFoldScaleAxis():
     """Fold the scaling of axis into weights of conv2d/dense.
 
@@ -180,7 +182,7 @@ def SimplifyInference():
 
 
 def FastMath():
-    """ Converts the expensive non linear functions to their fast but approximate counterparts.
+    """Converts the expensive non linear functions to their fast but approximate counterparts.
 
     Returns
     -------
@@ -218,6 +220,7 @@ def DeadCodeElimination(inline_once=False):
     """
     return _ffi_api.DeadCodeElimination(inline_once)
 
+
 def LazyGradientInit():
     """Reduces memory usage of gradient tensors
 
@@ -232,6 +235,7 @@ def LazyGradientInit():
     """
     return _ffi_api.LazyGradientInit()
 
+
 def FoldConstant():
     """Fold the constant expressions in a Relay program.
 
@@ -320,6 +324,7 @@ def CombineParallelDense(min_num_branches=3, to_batch=True):
     """
     return _ffi_api.CombineParallelDense(min_num_branches, to_batch)
 
+
 def CombineParallelBatchMatmul(min_num_branches=3):
     """Combine multiple batch matmul operators into one. For example:
 
@@ -362,11 +367,9 @@ def BatchingOps():
     ret: tvm.transform.Pass
         The sequential pass which apply batching for different operator types.
     """
-    return tvm.transform.Sequential([
-        CombineParallelConv2D(),
-        CombineParallelDense(),
-        CombineParallelBatchMatmul()
-    ])
+    return tvm.transform.Sequential(
+        [CombineParallelConv2D(), CombineParallelDense(), CombineParallelBatchMatmul()]
+    )
 
 
 def AlterOpLayout():
@@ -384,7 +387,7 @@ def AlterOpLayout():
 
 
 def ConvertLayout(desired_layouts):
-    """ Given a dest layout, this pass transforms the expr such that most of the ops input data
+    """Given a dest layout, this pass transforms the expr such that most of the ops input data
     layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms,
     one at the start and one at the end.
 
@@ -515,6 +518,7 @@ def ToANormalForm():
     """
     return _ffi_api.ToANormalForm()
 
+
 def ToANormalFormExpr(e):
     """ToANormalForm, but on expression level.
 
@@ -530,6 +534,7 @@ def ToANormalFormExpr(e):
     """
     return _ffi_api.ToANormalFormExpr(e)
 
+
 def ToBasicBlockNormalForm():
     """Turn an expression to Basic Block Normal Form.
     We define a block as a group of expressions implied by the scope structure.
@@ -661,7 +666,6 @@ def PartitionGraph():
     return _ffi_api.PartitionGraph()
 
 
-
 def AnnotateTarget(targets):
     """Annotate ops in an experession with a provied compiler/target and then
     use it for codegen.
@@ -706,7 +710,7 @@ def Inline():
     return _ffi_api.Inline()
 
 
-def gradient(expr, mod=None, mode='higher_order'):
+def gradient(expr, mod=None, mode="higher_order"):
     """
     Transform the input function,
     returning a function that calculate the original result,
@@ -730,11 +734,12 @@ def gradient(expr, mod=None, mode='higher_order'):
     expr : tvm.relay.Expr
       The transformed expression.
     """
-    if mode == 'first_order':
+    if mode == "first_order":
         return _ffi_api.first_order_gradient(expr, mod)
-    if mode == 'higher_order':
+    if mode == "higher_order":
         return _ffi_api.gradient(expr, mod)
-    raise Exception('unknown mode')
+    raise Exception("unknown mode")
+
 
 def Defunctionalization(func, mod):
     """
@@ -763,6 +768,7 @@ def Defunctionalization(func, mod):
     """
     return _ffi_api.Defunctionalization(func, mod)
 
+
 def to_cps(func, mod=None):
     """
     Turn expression into CPS expression.
@@ -808,8 +814,10 @@ def un_cps(func):
 
 def _wrap_class_function_pass(pass_cls, pass_info):
     """Wrap a python class as function pass"""
+
     class PyFunctionPass(FunctionPass):
         """Internal wrapper class to create a class instance."""
+
         def __init__(self, *args, **kwargs):
             # initialize handle in case pass_cls creation fails
             self.handle = None
@@ -818,8 +826,8 @@ def _wrap_class_function_pass(pass_cls, pass_info):
             # avoid a cyclic dependency
             def _pass_func(func, mod, ctx):
                 return inst.transform_function(func, mod, ctx)
-            self.__init_handle_by_constructor__(
-                _ffi_api.MakeFunctionPass, _pass_func, pass_info)
+
+            self.__init_handle_by_constructor__(_ffi_api.MakeFunctionPass, _pass_func, pass_info)
             self._inst = inst
 
         def __getattr__(self, name):
@@ -916,8 +924,7 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
 
     required = required if required else []
     if not isinstance(required, (list, tuple)):
-        raise TypeError("Required is expected to be the type of " +
-                        "list/tuple.")
+        raise TypeError("Required is expected to be the type of " + "list/tuple.")
 
     def create_function_pass(pass_arg):
         """Internal function that creates a function pass"""
@@ -953,6 +960,7 @@ class ChangeBatch:
     pass: FunctionPass
       The pass.
     """
+
     def __init__(self, data, batch_size=16):
         self.data = data
         self.batch_size = batch_size
@@ -960,6 +968,7 @@ class ChangeBatch:
     def transform_function(self, func, mod, ctx):
         func = relay.Function(func.params, func.body, None, func.type_params, func.attrs)
         change_batch = self
+
         class ChangeBatchMutator(tvm.relay.ExprMutator):
             def visit_var(self, var):
                 if var in change_batch.data:
@@ -968,6 +977,7 @@ class ChangeBatch:
                     new_shape[change_batch.data[var]] = change_batch.batch_size
                     return relay.Var(var.name_hint, relay.TensorType(new_shape, ty.dtype))
                 return var
+
         return ChangeBatchMutator().visit(func)
 
 
index 84bd1ee..affd7f4 100644 (file)
@@ -25,6 +25,7 @@ from . import _ffi_api
 
 Any = _ffi_api.Any
 
+
 def is_dynamic(tensor_type):
     """Check whether type has any or symbolic variables as a shape.
 
index 7139ccb..490464b 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """The type functor of Relay."""
-from .ty import (TypeVar, IncompleteType, TensorType, FuncType,
-                 TupleType, TypeRelation, RefType, GlobalTypeVar, TypeCall)
+from .ty import (
+    TypeVar,
+    IncompleteType,
+    TensorType,
+    FuncType,
+    TupleType,
+    TypeRelation,
+    RefType,
+    GlobalTypeVar,
+    TypeCall,
+)
 from .adt import TypeData
 
+
 class TypeFunctor:
     """
     An abstract visitor defined over Type.
 
     Defines the default dispatch over types.
     """
+
     def __init__(self):
         # TODO(weberlo): make type vars hashable, so we can memoize
         pass
@@ -53,7 +64,7 @@ class TypeFunctor:
         elif isinstance(typ, TypeData):
             return self.visit_type_data(typ)
         else:
-            raise Exception('unhandled case: {0}'.format(type(typ)))
+            raise Exception("unhandled case: {0}".format(type(typ)))
 
     def visit_type_var(self, _):
         raise NotImplementedError()
@@ -92,6 +103,7 @@ class TypeVisitor(TypeFunctor):
 
     The default behavior recursively traverses the AST.
     """
+
     def visit_type_var(self, tv):
         pass
 
@@ -105,9 +117,9 @@ class TypeVisitor(TypeFunctor):
         for arg_type in ft.arg_types:
             self.visit(arg_type)
         self.visit(ft.ret_type)
-        for type_param in getattr(ft, 'type_params', []):
+        for type_param in getattr(ft, "type_params", []):
             self.visit(type_param)
-        for type_constraint in getattr(ft, 'type_constraints', []):
+        for type_constraint in getattr(ft, "type_constraints", []):
             self.visit(type_constraint)
 
     def visit_tuple_type(self, tt):
@@ -142,6 +154,7 @@ class TypeMutator(TypeFunctor):
     The default behavior recursively traverses the AST
     and reconstructs the AST.
     """
+
     def visit_type_var(self, tv):
         return TypeVar(tv.name_hint, tv.kind)
 
@@ -154,27 +167,17 @@ class TypeMutator(TypeFunctor):
     def visit_func_type(self, ft):
         new_arg_types = [self.visit(arg_type) for arg_type in ft.arg_types]
         new_ret_type = self.visit(ft.ret_type)
-        new_type_params = [
-            self.visit(type_param)
-            for type_param in getattr(ft, 'type_params', [])]
+        new_type_params = [self.visit(type_param) for type_param in getattr(ft, "type_params", [])]
         new_type_constraints = [
-            self.visit(type_constraint)
-            for type_constraint in getattr(ft, 'type_constraints', [])]
-        return FuncType(
-            new_arg_types,
-            new_ret_type,
-            new_type_params,
-            new_type_constraints)
+            self.visit(type_constraint) for type_constraint in getattr(ft, "type_constraints", [])
+        ]
+        return FuncType(new_arg_types, new_ret_type, new_type_params, new_type_constraints)
 
     def visit_tuple_type(self, tt):
         return TupleType([self.visit(field) for field in tt.fields])
 
     def visit_type_relation(self, tr):
-        return TypeRelation(
-            tr.func,
-            [self.visit(arg) for arg in tr.args],
-            tr.num_inputs,
-            tr.attrs)
+        return TypeRelation(tr.func, [self.visit(arg) for arg in tr.args], tr.num_inputs, tr.attrs)
 
     def visit_ref_type(self, rt):
         return RefType(self.visit(rt.value))
@@ -183,12 +186,11 @@ class TypeMutator(TypeFunctor):
         return GlobalTypeVar(gtv.name_hint, gtv.kind)
 
     def visit_type_call(self, tc):
-        return TypeCall(
-            self.visit(tc.func),
-            [self.visit(arg) for arg in tc.args])
+        return TypeCall(self.visit(tc.func), [self.visit(arg) for arg in tc.args])
 
     def visit_type_data(self, td):
         return TypeData(
             self.visit(td.header),
             [self.visit(type_var) for type_var in td.type_vars],
-            td.constructors)
+            td.constructors,
+        )
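A hedged sketch of the visitor pattern these classes support: a TypeVisitor subclass that counts TensorType nodes; the import path is assumed from this module's apparent location (tvm/relay/type_functor.py):

from tvm import relay
from tvm.relay.type_functor import TypeVisitor


class TensorTypeCounter(TypeVisitor):
    """Counts TensorType nodes reachable from a Relay type."""

    def __init__(self):
        super().__init__()
        self.count = 0

    def visit_tensor_type(self, tt):
        self.count += 1


# One TensorType argument plus a TensorType return type gives a count of 2.
ft = relay.FuncType([relay.TensorType((1, 10))], relay.TensorType((1,)))
counter = TensorTypeCounter()
counter.visit(ft)
assert counter.count == 2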
index f0e33f8..b2bfa3b 100644 (file)
@@ -28,9 +28,9 @@ import logging
 from .._ffi.base import py_str
 
 # Magic header for RPC data plane
-RPC_MAGIC = 0xff271
+RPC_MAGIC = 0xFF271
 # magic header for RPC tracker(control plane)
-RPC_TRACKER_MAGIC = 0x2f271
+RPC_TRACKER_MAGIC = 0x2F271
 # success response
 RPC_CODE_SUCCESS = RPC_MAGIC + 0
 # duplicate key in proxy
@@ -38,10 +38,12 @@ RPC_CODE_DUPLICATE = RPC_MAGIC + 1
 # cannot find matched key in server
 RPC_CODE_MISMATCH = RPC_MAGIC + 2
 
-logger = logging.getLogger('RPCServer')
+logger = logging.getLogger("RPCServer")
+
 
 class TrackerCode(object):
     """Enumeration code for the RPC tracker"""
+
     FAIL = -1
     SUCCESS = 0
     PING = 1
@@ -52,6 +54,7 @@ class TrackerCode(object):
     SUMMARY = 6
     GET_PENDING_MATCHKEYS = 7
 
+
 RPC_SESS_MASK = 128
 
 
@@ -168,8 +171,8 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
                 raise sock_err
             period = time.time() - tstart
             if period > timeout:
-                raise RuntimeError(
-                    "Failed to connect to server %s" % str(addr))
-            logger.warning("Cannot connect to tracker %s, retry in %g secs...",
-                           str(addr), retry_period)
+                raise RuntimeError("Failed to connect to server %s" % str(addr))
+            logger.warning(
+                "Cannot connect to tracker %s, retry in %g secs...", str(addr), retry_period
+            )
             time.sleep(retry_period)
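A small sketch of the wire framing these constants imply, mirroring the struct.pack("<i", ...) calls in the proxy code later in this diff; the key string is illustrative:

import struct

RPC_MAGIC = 0xFF271
key = b"client:rasp3b"
# Little-endian int32 magic, then a length-prefixed key, as in the handshake.
handshake = struct.pack("<i", RPC_MAGIC) + struct.pack("<i", len(key)) + key
assert struct.unpack("<i", handshake[:4])[0] == RPC_MAGIC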
index 60eb08d..ebdca00 100644 (file)
@@ -36,6 +36,7 @@ class RPCSession(object):
 
     Do not directly create the object, call connect
     """
+
     # pylint: disable=invalid-name
     def __init__(self, sess):
         self._sess = sess
@@ -112,8 +113,7 @@ class RPCSession(object):
                 target = os.path.basename(data)
 
         if "upload" not in self._remote_funcs:
-            self._remote_funcs["upload"] = self.get_function(
-                "tvm.rpc.server.upload")
+            self._remote_funcs["upload"] = self.get_function("tvm.rpc.server.upload")
         self._remote_funcs["upload"](target, blob)
 
     def download(self, path):
@@ -130,8 +130,7 @@ class RPCSession(object):
             The result blob from the file.
         """
         if "download" not in self._remote_funcs:
-            self._remote_funcs["download"] = self.get_function(
-                "tvm.rpc.server.download")
+            self._remote_funcs["download"] = self.get_function("tvm.rpc.server.download")
         return self._remote_funcs["download"](path)
 
     def remove(self, path):
@@ -143,8 +142,7 @@ class RPCSession(object):
             The relative location to remote temp folder.
         """
         if "remove" not in self._remote_funcs:
-            self._remote_funcs["remove"] = self.get_function(
-                "tvm.rpc.server.remove")
+            self._remote_funcs["remove"] = self.get_function("tvm.rpc.server.remove")
         self._remote_funcs["remove"](path)
 
     def load_module(self, path):
@@ -201,6 +199,7 @@ class LocalSession(RPCSession):
     This class can be used to implement functions that
     need to be run both locally and remotely.
     """
+
     def __init__(self):
         self._temp = server._server_env([])
         RPCSession.__init__(self, _ffi_api.LocalSession())
@@ -235,6 +234,7 @@ class PopenSession(RPCSession):
     binary : List[Union[str, bytes]]
         The binary to be executed.
     """
+
     def __init__(self, binary):
         RPCSession.__init__(self, _popen_session(binary))
 
@@ -247,6 +247,7 @@ class TrackerSession(object):
     addr : tuple
         The address tuple
     """
+
     def __init__(self, addr):
         self._addr = addr
         self._sock = None
@@ -291,7 +292,7 @@ class TrackerSession(object):
             addr = item["addr"]
             res += addr[0] + ":" + str(addr[1]) + "\t"
             res += item["key"] + "\n"
-            key = item['key'].split(':')[1]   # 'server:rasp3b` -> 'rasp3b'
+            key = item["key"].split(":")[1]  # 'server:rasp3b` -> 'rasp3b'
             if key not in total_ct:
                 total_ct[key] = 0
             total_ct[key] += 1
@@ -299,7 +300,7 @@ class TrackerSession(object):
         res += "\n"
 
         # compute max length of device key
-        queue_info = data['queue_info']
+        queue_info = data["queue_info"]
         keys = list(queue_info.keys())
         if keys:
             keys.sort()
@@ -308,15 +309,19 @@ class TrackerSession(object):
             max_key_len = 0
 
         res += "Queue Status\n"
-        title = ("%%-%ds" % max_key_len + "   total  free  pending\n") % 'key'
-        separate_line = '-' * len(title) + '\n'
+        title = ("%%-%ds" % max_key_len + "   total  free  pending\n") % "key"
+        separate_line = "-" * len(title) + "\n"
         res += separate_line + title + separate_line
         for k in keys:
             total = total_ct.get(k, 0)
             free, pending = queue_info[k]["free"], queue_info[k]["pending"]
             if total or pending:
-                res += ("%%-%ds" % max_key_len + "   %-5d  %-4d  %-7d\n") % \
-                       (k, total, free, pending)
+                res += ("%%-%ds" % max_key_len + "   %-5d  %-4d  %-7d\n") % (
+                    k,
+                    total,
+                    free,
+                    pending,
+                )
         res += separate_line
         return res
 
@@ -344,8 +349,7 @@ class TrackerSession(object):
             try:
                 if self._sock is None:
                     self._connect()
-                base.sendjson(self._sock,
-                              [base.TrackerCode.REQUEST, key, "", priority])
+                base.sendjson(self._sock, [base.TrackerCode.REQUEST, key, "", priority])
                 value = base.recvjson(self._sock)
                 if value[0] != base.TrackerCode.SUCCESS:
                     raise RuntimeError("Invalid return value %s" % str(value))
@@ -357,15 +361,10 @@ class TrackerSession(object):
             except TVMError as err:
                 last_err = err
         raise RuntimeError(
-            "Cannot request %s after %d retry, last_error:%s" % (
-                key, max_retry, str(last_err)))
-
-    def request_and_run(self,
-                        key,
-                        func,
-                        priority=1,
-                        session_timeout=0,
-                        max_retry=2):
+            "Cannot request %s after %d retry, last_error:%s" % (key, max_retry, str(last_err))
+        )
+
+    def request_and_run(self, key, func, priority=1, session_timeout=0, max_retry=2):
         """Request a resource from tracker and run the func.
 
         This function safe-guard rare server node dropout during execution.
@@ -393,21 +392,18 @@ class TrackerSession(object):
         last_err = None
         for _ in range(max_retry):
             try:
-                sess = self.request(key,
-                                    priority=priority,
-                                    session_timeout=session_timeout)
+                sess = self.request(key, priority=priority, session_timeout=session_timeout)
                 tstart = time.time()
                 return func(sess)
             except TVMError as err:
                 duration = time.time() - tstart
                 # roughly estimate if the error is due to timeout termination
                 if session_timeout and duration >= session_timeout * 0.95:
-                    raise RuntimeError(
-                        "Session timeout when running %s" % func.__name__)
+                    raise RuntimeError("Session timeout when running %s" % func.__name__)
                 last_err = err
         raise RuntimeError(
-            "Failed to run on %s after %d retry, last_error:%s" % (
-                key, max_retry, str(last_err)))
+            "Failed to run on %s after %d retry, last_error:%s" % (key, max_retry, str(last_err))
+        )
 
 
 def connect(url, port, key="", session_timeout=0, session_constructor_args=None):
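A minimal usage sketch for the session API above, assuming this file is tvm.rpc.client; host, port, and file names are illustrative and a server must already be listening:

from tvm import rpc

remote = rpc.connect("127.0.0.1", 9090)
remote.upload("/tmp/mylib.so")        # uses RPCSession.upload shown above
lib = remote.load_module("mylib.so")  # uses RPCSession.load_module
ctx = remote.cpu(0)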
index 760c536..2c9dd29 100644 (file)
@@ -36,8 +36,7 @@ def find_minrpc_server_libpath(server="posix_popen_server"):
     curr_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
     source_dir = os.path.abspath(os.path.join(curr_dir, "..", "..", ".."))
 
-    path = os.path.join(
-        source_dir, "src", "runtime", "rpc", "minrpc", ("%s.cc" % server))
+    path = os.path.join(source_dir, "src", "runtime", "rpc", "minrpc", ("%s.cc" % server))
 
     candidates = [path]
     if not os.path.isfile(path):
@@ -45,9 +44,7 @@ def find_minrpc_server_libpath(server="posix_popen_server"):
     return path
 
 
-def with_minrpc(compile_func,
-                server="posix_popen_server",
-                runtime="libtvm"):
+def with_minrpc(compile_func, server="posix_popen_server", runtime="libtvm"):
     """Attach the compiler function with minrpc related options.
 
     Parameters
@@ -67,8 +64,7 @@ def with_minrpc(compile_func,
         The return compilation.
     """
     server_path = find_minrpc_server_libpath(server)
-    runtime_path = libinfo.find_lib_path(
-        [runtime, runtime + ".so", runtime + ".dylib"])[0]
+    runtime_path = libinfo.find_lib_path([runtime, runtime + ".so", runtime + ".dylib"])[0]
 
     runtime_dir = os.path.abspath(os.path.dirname(runtime_path))
     options = ["-std=c++14"]
@@ -78,9 +74,8 @@ def with_minrpc(compile_func,
     options += ["-Wl,-rpath=" + runtime_dir]
     options += ["-I" + path for path in libinfo.find_include_path()]
     fcompile = cc.cross_compiler(
-        compile_func,
-        options=options,
-        add_files=[server_path, runtime_path])
+        compile_func, options=options, add_files=[server_path, runtime_path]
+    )
     fcompile.__name__ = "with_minrpc"
     fcompile.need_system_lib = True
     return fcompile
index 994e230..2224d50 100644 (file)
@@ -40,7 +40,8 @@ try:
     from . import tornado_util
 except ImportError as error_msg:
     raise ImportError(
-        "RPCProxy module requires tornado package %s. Try 'pip install tornado'." % error_msg)
+        "RPCProxy module requires tornado package %s. Try 'pip install tornado'." % error_msg
+    )
 
 from . import _ffi_api
 from . import base
@@ -51,6 +52,7 @@ from .._ffi.base import py_str
 
 class ForwardHandler(object):
     """Forward handler to forward the message."""
+
     def _init_handler(self):
         """Initialize handler."""
         self._init_message = bytes()
@@ -76,14 +78,14 @@ class ForwardHandler(object):
     def _init_step(self, message):
         if self._magic is None:
             assert len(message) == 4
-            self._magic = struct.unpack('<i', message)[0]
+            self._magic = struct.unpack("<i", message)[0]
             if self._magic != base.RPC_MAGIC:
                 logging.info("Invalid RPC magic from %s", self.name())
                 self.close()
             self._init_req_nbytes = 4
         elif self._rpc_key_length is None:
             assert len(message) == 4
-            self._rpc_key_length = struct.unpack('<i', message)[0]
+            self._rpc_key_length = struct.unpack("<i", message)[0]
             self._init_req_nbytes = self._rpc_key_length
         elif self.rpc_key is None:
             assert len(message) == self._rpc_key_length
@@ -143,13 +145,14 @@ class ForwardHandler(object):
 
 class TCPHandler(tornado_util.TCPHandler, ForwardHandler):
     """Event driven TCP handler."""
+
     def __init__(self, sock, addr):
         super(TCPHandler, self).__init__(sock)
         self._init_handler()
         self.addr = addr
 
     def name(self):
-        return "TCPSocketProxy:%s:%s"  % (str(self.addr[0]), self.rpc_key)
+        return "TCPSocketProxy:%s:%s" % (str(self.addr[0]), self.rpc_key)
 
     def send_data(self, message, binary=True):
         self.write_message(message, True)
@@ -169,6 +172,7 @@ class TCPHandler(tornado_util.TCPHandler, ForwardHandler):
 
 class WebSocketHandler(websocket.WebSocketHandler, ForwardHandler):
     """Handler for websockets."""
+
     def __init__(self, *args, **kwargs):
         super(WebSocketHandler, self).__init__(*args, **kwargs)
         self._init_handler()
@@ -201,6 +205,7 @@ class WebSocketHandler(websocket.WebSocketHandler, ForwardHandler):
 
 class RequestHandler(tornado.web.RequestHandler):
     """Handles html request."""
+
     def __init__(self, *args, **kwargs):
         file_path = kwargs.pop("file_path")
         if file_path.endswith("html"):
@@ -208,8 +213,8 @@ class RequestHandler(tornado.web.RequestHandler):
             web_port = kwargs.pop("rpc_web_port", None)
             if web_port:
                 self.page = self.page.replace(
-                    "ws://localhost:9190/ws",
-                    "ws://localhost:%d/ws" % web_port)
+                    "ws://localhost:9190/ws", "ws://localhost:%d/ws" % web_port
+                )
         else:
             self.page = open(file_path, "rb").read()
         super(RequestHandler, self).__init__(*args, **kwargs)
@@ -223,16 +228,20 @@ class RequestHandler(tornado.web.RequestHandler):
 
 class ProxyServerHandler(object):
     """Internal proxy server handler class."""
+
     current = None
-    def __init__(self,
-                 sock,
-                 listen_port,
-                 web_port,
-                 timeout_client,
-                 timeout_server,
-                 tracker_addr,
-                 index_page=None,
-                 resource_files=None):
+
+    def __init__(
+        self,
+        sock,
+        listen_port,
+        web_port,
+        timeout_client,
+        timeout_server,
+        tracker_addr,
+        index_page=None,
+        resource_files=None,
+    ):
         assert ProxyServerHandler.current is None
         ProxyServerHandler.current = self
         if web_port:
@@ -241,7 +250,8 @@ class ProxyServerHandler(object):
             ]
             if index_page:
                 handlers.append(
-                    (r"/", RequestHandler, {"file_path": index_page, "rpc_web_port": web_port}))
+                    (r"/", RequestHandler, {"file_path": index_page, "rpc_web_port": web_port})
+                )
                 logging.info("Serving RPC index html page at http://localhost:%d", web_port)
             resource_files = resource_files if resource_files else []
             for fname in resource_files:
@@ -254,10 +264,11 @@ class ProxyServerHandler(object):
         self.sock = sock
         self.sock.setblocking(0)
         self.loop = ioloop.IOLoop.current()
+
         def event_handler(_, events):
             self._on_event(events)
-        self.loop.add_handler(
-            self.sock.fileno(), event_handler, self.loop.READ)
+
+        self.loop.add_handler(self.sock.fileno(), event_handler, self.loop.READ)
         self._client_pool = {}
         self._server_pool = {}
         self.timeout_alloc = 5
@@ -272,8 +283,10 @@ class ProxyServerHandler(object):
         self.update_tracker_period = 2
         if tracker_addr:
             logging.info("Tracker address:%s", str(tracker_addr))
+
             def _callback():
                 self._update_tracker(True)
+
             self.loop.call_later(self.update_tracker_period, _callback)
         logging.info("RPCProxy: Websock port bind to %d", web_port)
 
@@ -290,12 +303,12 @@ class ProxyServerHandler(object):
         lhs.forward_proxy = rhs
         rhs.forward_proxy = lhs
 
-        lhs.send_data(struct.pack('<i', base.RPC_CODE_SUCCESS))
-        lhs.send_data(struct.pack('<i', len(rhs.rpc_key)))
+        lhs.send_data(struct.pack("<i", base.RPC_CODE_SUCCESS))
+        lhs.send_data(struct.pack("<i", len(rhs.rpc_key)))
         lhs.send_data(rhs.rpc_key.encode("utf-8"))
 
-        rhs.send_data(struct.pack('<i', base.RPC_CODE_SUCCESS))
-        rhs.send_data(struct.pack('<i', len(lhs.rpc_key)))
+        rhs.send_data(struct.pack("<i", base.RPC_CODE_SUCCESS))
+        rhs.send_data(struct.pack("<i", len(lhs.rpc_key)))
         rhs.send_data(lhs.rpc_key.encode("utf-8"))
         logging.info("Pairup connect %s  and %s", lhs.name(), rhs.name())
 
@@ -318,8 +331,9 @@ class ProxyServerHandler(object):
         """Update information on tracker."""
         try:
             if self._tracker_conn is None:
-                self._tracker_conn = socket.socket(base.get_addr_family(self._tracker_addr),
-                                                   socket.SOCK_STREAM)
+                self._tracker_conn = socket.socket(
+                    base.get_addr_family(self._tracker_addr), socket.SOCK_STREAM
+                )
                 self._tracker_conn.connect(self._tracker_addr)
                 self._tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
                 magic = struct.unpack("<i", base.recvall(self._tracker_conn, 4))[0]
@@ -344,8 +358,9 @@ class ProxyServerHandler(object):
                             update_keys.append(k)
                             v.alloc_time = None
                 if update_keys:
-                    logging.info("RPCProxy: No incoming conn on %s, regenerate keys...",
-                                 str(update_keys))
+                    logging.info(
+                        "RPCProxy: No incoming conn on %s, regenerate keys...", str(update_keys)
+                    )
                     new_keys = self._regenerate_server_keys(update_keys)
                     self._tracker_pending_puts += new_keys
 
@@ -353,9 +368,9 @@ class ProxyServerHandler(object):
             # report new connections
             for key in self._tracker_pending_puts:
                 rpc_key = key.split(":")[0]
-                base.sendjson(self._tracker_conn,
-                              [TrackerCode.PUT, rpc_key,
-                               (self._listen_port, key), None])
+                base.sendjson(
+                    self._tracker_conn, [TrackerCode.PUT, rpc_key, (self._listen_port, key), None]
+                )
                 assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS
                 if rpc_key not in self._key_set:
                     self._key_set.add(rpc_key)
@@ -364,21 +379,24 @@ class ProxyServerHandler(object):
             if need_update_info:
                 keylist = "[" + ",".join(self._key_set) + "]"
                 cinfo = {"key": "server:proxy" + keylist}
-                base.sendjson(self._tracker_conn,
-                              [TrackerCode.UPDATE_INFO, cinfo])
+                base.sendjson(self._tracker_conn, [TrackerCode.UPDATE_INFO, cinfo])
                 assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS
             self._tracker_pending_puts = []
         except (socket.error, IOError) as err:
             logging.info(
                 "Lost tracker connection: %s, try reconnect in %g sec",
-                str(err), self.update_tracker_period)
+                str(err),
+                self.update_tracker_period,
+            )
             self._tracker_conn.close()
             self._tracker_conn = None
             self._regenerate_server_keys(self._server_pool.keys())
 
         if period_update:
+
             def _callback():
                 self._update_tracker(True)
+
             self.loop.call_later(self.update_tracker_period, _callback)
 
     def _handler_ready_tracker_mode(self, handler):
@@ -393,7 +411,7 @@ class ProxyServerHandler(object):
             if handler.match_key in self._server_pool:
                 self._pair_up(self._server_pool.pop(handler.match_key), handler)
             else:
-                handler.send_data(struct.pack('<i', base.RPC_CODE_MISMATCH))
+                handler.send_data(struct.pack("<i", base.RPC_CODE_MISMATCH))
                 handler.signal_close()
 
     def _handler_ready_proxy_mode(self, handler):
@@ -411,18 +429,23 @@ class ProxyServerHandler(object):
             return
         if key not in pool_dst:
             pool_dst[key] = handler
+
             def cleanup():
                 """Cleanup client connection if timeout"""
                 if pool_dst.get(key, None) == handler:
-                    logging.info("Timeout client connection %s, cannot find match key=%s",
-                                 handler.name(), key)
+                    logging.info(
+                        "Timeout client connection %s, cannot find match key=%s",
+                        handler.name(),
+                        key,
+                    )
                     pool_dst.pop(key)
-                    handler.send_data(struct.pack('<i', base.RPC_CODE_MISMATCH))
+                    handler.send_data(struct.pack("<i", base.RPC_CODE_MISMATCH))
                     handler.signal_close()
+
             self.loop.call_later(timeout, cleanup)
         else:
             logging.info("Duplicate connection with same key=%s", key)
-            handler.send_data(struct.pack('<i', base.RPC_CODE_DUPLICATE))
+            handler.send_data(struct.pack("<i", base.RPC_CODE_DUPLICATE))
             handler.signal_close()
 
     def handler_ready(self, handler):
@@ -438,22 +461,26 @@ class ProxyServerHandler(object):
         ioloop.IOLoop.current().start()
 
 
-def _proxy_server(listen_sock,
-                  listen_port,
-                  web_port,
-                  timeout_client,
-                  timeout_server,
-                  tracker_addr,
-                  index_page,
-                  resource_files):
-    handler = ProxyServerHandler(listen_sock,
-                                 listen_port,
-                                 web_port,
-                                 timeout_client,
-                                 timeout_server,
-                                 tracker_addr,
-                                 index_page,
-                                 resource_files)
+def _proxy_server(
+    listen_sock,
+    listen_port,
+    web_port,
+    timeout_client,
+    timeout_server,
+    tracker_addr,
+    index_page,
+    resource_files,
+):
+    handler = ProxyServerHandler(
+        listen_sock,
+        listen_port,
+        web_port,
+        timeout_client,
+        timeout_server,
+        tracker_addr,
+        index_page,
+        resource_files,
+    )
     handler.run()
 
 
@@ -492,16 +519,19 @@ class Proxy(object):
     resource_files : str, optional
         Path to local resources that can be included in the http request
     """
-    def __init__(self,
-                 host,
-                 port=9091,
-                 port_end=9199,
-                 web_port=0,
-                 timeout_client=600,
-                 timeout_server=600,
-                 tracker_addr=None,
-                 index_page=None,
-                 resource_files=None):
+
+    def __init__(
+        self,
+        host,
+        port=9091,
+        port_end=9199,
+        web_port=0,
+        timeout_client=600,
+        timeout_server=600,
+        tracker_addr=None,
+        index_page=None,
+        resource_files=None,
+    ):
         sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
         self.port = None
         for my_port in range(port, port_end):
@@ -519,9 +549,17 @@ class Proxy(object):
         sock.listen(1)
         self.proc = multiprocessing.Process(
             target=_proxy_server,
-            args=(sock, self.port, web_port,
-                  timeout_client, timeout_server,
-                  tracker_addr, index_page, resource_files))
+            args=(
+                sock,
+                self.port,
+                web_port,
+                timeout_client,
+                timeout_server,
+                tracker_addr,
+                index_page,
+                resource_files,
+            ),
+        )
         self.proc.start()
         sock.close()
         self.host = host
@@ -548,13 +586,14 @@ def websocket_proxy_server(url, key=""):
     key : str
         The key to identify the server.
     """
+
     def create_on_message(conn):
         def _fsend(data):
             data = bytes(data)
             conn.write_message(data, binary=True)
             return len(data)
-        on_message = _ffi_api.CreateEventDrivenServer(
-            _fsend, "WebSocketProxyServer", "%toinit")
+
+        on_message = _ffi_api.CreateEventDrivenServer(_fsend, "WebSocketProxyServer", "%toinit")
         return on_message
 
     @gen.coroutine
@@ -563,13 +602,13 @@ def websocket_proxy_server(url, key=""):
         on_message = create_on_message(conn)
         temp = _server_env(None)
         # Start connection
-        conn.write_message(struct.pack('<i', base.RPC_MAGIC), binary=True)
+        conn.write_message(struct.pack("<i", base.RPC_MAGIC), binary=True)
         key = "server:" + key
-        conn.write_message(struct.pack('<i', len(key)), binary=True)
+        conn.write_message(struct.pack("<i", len(key)), binary=True)
         conn.write_message(key.encode("utf-8"), binary=True)
         msg = yield conn.read_message()
         assert len(msg) >= 4
-        magic = struct.unpack('<i', msg[:4])[0]
+        magic = struct.unpack("<i", msg[:4])[0]
         if magic == base.RPC_CODE_DUPLICATE:
             raise RuntimeError("key: %s has already been used in proxy" % key)
         if magic == base.RPC_CODE_MISMATCH:
@@ -594,5 +633,6 @@ def websocket_proxy_server(url, key=""):
         logging.info("WebSocketProxyServer closed...")
         temp.remove()
         ioloop.IOLoop.current().stop()
+
     ioloop.IOLoop.current().spawn_callback(_connect, key)
     ioloop.IOLoop.current().start()
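
A minimal usage sketch of the Proxy class whose constructor is reformatted above; tornado must be installed, the module path tvm.rpc.proxy is assumed from this file, and the web_port value is illustrative:

    from tvm.rpc.proxy import Proxy

    proxy = Proxy("0.0.0.0", port=9091, port_end=9199, web_port=8888)
    print("RPCProxy bound to port", proxy.port)
    proxy.proc.join()  # serving runs in the forked _proxy_server process
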
index 42bcb00..bbfaf28 100644 (file)
@@ -45,9 +45,10 @@ from tvm.runtime.module import load_module as _load_module
 from tvm.contrib import util
 from . import _ffi_api
 from . import base
-from . base import TrackerCode
+from .base import TrackerCode
+
+logger = logging.getLogger("RPCServer")
 
-logger = logging.getLogger('RPCServer')
 
 def _server_env(load_library, work_path=None):
     """Server environment function return temp dir"""
@@ -78,6 +79,7 @@ def _server_env(load_library, work_path=None):
     temp.libs = libs
     return temp
 
+
 def _serve_loop(sock, addr, load_library, work_path=None):
     """Server loop"""
     sockfd = sock.fileno()
@@ -87,6 +89,7 @@ def _serve_loop(sock, addr, load_library, work_path=None):
         temp.remove()
     logger.info("Finish serving %s", addr)
 
+
 def _parse_server_opt(opts):
     # parse client options
     ret = {}
@@ -95,8 +98,10 @@ def _parse_server_opt(opts):
             ret["timeout"] = float(kv[9:])
     return ret
 
+
 def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
     """Listening loop of the server master."""
+
     def _accept_conn(listen_sock, tracker_conn, ping_period=2):
         """Accept connection from the other places.
 
@@ -115,8 +120,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
         # Report resource to tracker
         if tracker_conn:
             matchkey = base.random_key(rpc_key + ":")
-            base.sendjson(tracker_conn,
-                          [TrackerCode.PUT, rpc_key, (port, matchkey), custom_addr])
+            base.sendjson(tracker_conn, [TrackerCode.PUT, rpc_key, (port, matchkey), custom_addr])
             assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
         else:
             matchkey = rpc_key
@@ -141,9 +145,9 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
                     if unmatch_period_count * ping_period > unmatch_timeout + ping_period:
                         logger.info("no incoming connections, regenerate key ...")
                         matchkey = base.random_key(rpc_key + ":", old_keyset)
-                        base.sendjson(tracker_conn,
-                                      [TrackerCode.PUT, rpc_key, (port, matchkey),
-                                       custom_addr])
+                        base.sendjson(
+                            tracker_conn, [TrackerCode.PUT, rpc_key, (port, matchkey), custom_addr]
+                        )
                         assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
                         unmatch_period_count = 0
                     continue
@@ -179,9 +183,8 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
                 if magic != base.RPC_TRACKER_MAGIC:
                     raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr))
                 # report status of current queue
-                cinfo = {"key" : "server:" + rpc_key}
-                base.sendjson(tracker_conn,
-                              [TrackerCode.UPDATE_INFO, cinfo])
+                cinfo = {"key": "server:" + rpc_key}
+                base.sendjson(tracker_conn, [TrackerCode.UPDATE_INFO, cinfo])
                 assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
 
             # step 2: wait for in-coming connections
@@ -198,8 +201,9 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
         # step 3: serving
         work_path = util.tempdir()
         logger.info("connection from %s", addr)
-        server_proc = multiprocessing.Process(target=_serve_loop,
-                                              args=(conn, addr, load_library, work_path))
+        server_proc = multiprocessing.Process(
+            target=_serve_loop, args=(conn, addr, load_library, work_path)
+        )
         server_proc.daemon = True
         server_proc.start()
         # close from our side.
@@ -210,6 +214,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
             logger.info("Timeout in RPC session, kill..")
             # pylint: disable=import-outside-toplevel
             import psutil
+
             parent = psutil.Process(server_proc.pid)
             # terminate worker children
             for child in parent.children(recursive=True):
@@ -243,8 +248,7 @@ def _connect_proxy_loop(addr, key, load_library):
             remote_key = py_str(base.recvall(sock, keylen))
             opts = _parse_server_opt(remote_key.split()[1:])
             logger.info("connected to %s", str(addr))
-            process = multiprocessing.Process(
-                target=_serve_loop, args=(sock, addr, load_library))
+            process = multiprocessing.Process(target=_serve_loop, args=(sock, addr, load_library))
             process.daemon = True
             process.start()
             sock.close()
@@ -260,11 +264,9 @@ def _connect_proxy_loop(addr, key, load_library):
                 raise RuntimeError("Maximum retry error: last error: %s" % str(err))
             time.sleep(retry_period)
 
+
 def _popen(cmd):
-    proc = subprocess.Popen(cmd,
-                            stdout=subprocess.PIPE,
-                            stderr=subprocess.STDOUT,
-                            env=os.environ)
+    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=os.environ)
     (out, _) = proc.communicate()
     if proc.returncode != 0:
         msg = "Server invoke error:\n"
@@ -316,20 +318,22 @@ class Server(object):
     silent: bool, optional
         Whether to run this server in silent mode.
     """
-    def __init__(self,
-                 host,
-                 port=9091,
-                 port_end=9199,
-                 is_proxy=False,
-                 use_popen=False,
-                 tracker_addr=None,
-                 key="",
-                 load_library=None,
-                 custom_addr=None,
-                 silent=False,
-                 utvm_dev_id=None,
-                 utvm_dev_config_args=None,
-                 ):
+
+    def __init__(
+        self,
+        host,
+        port=9091,
+        port_end=9199,
+        is_proxy=False,
+        use_popen=False,
+        tracker_addr=None,
+        key="",
+        load_library=None,
+        custom_addr=None,
+        silent=False,
+        utvm_dev_id=None,
+        utvm_dev_config_args=None,
+    ):
         try:
             if _ffi_api.ServerLoop is None:
                 raise RuntimeError("Please compile with USE_RPC=1")
@@ -345,15 +349,17 @@ class Server(object):
             logger.setLevel(logging.ERROR)
 
         if use_popen:
-            cmd = [sys.executable,
-                   "-m", "tvm.exec.rpc_server",
-                   "--host=%s" % host,
-                   "--port=%s" % port,
-                   "--port-end=%s" % port_end]
+            cmd = [
+                sys.executable,
+                "-m",
+                "tvm.exec.rpc_server",
+                "--host=%s" % host,
+                "--port=%s" % port,
+                "--port-end=%s" % port_end,
+            ]
             if tracker_addr:
                 assert key
-                cmd += ["--tracker=%s:%d" % tracker_addr,
-                        "--key=%s" % key]
+                cmd += ["--tracker=%s:%d" % tracker_addr, "--key=%s" % key]
             if load_library:
                 cmd += ["--load-library", load_library]
             if custom_addr:
@@ -397,14 +403,15 @@ class Server(object):
             sock.listen(1)
             self.sock = sock
             self.proc = multiprocessing.Process(
-                target=_listen_loop, args=(
-                    self.sock, self.port, key, tracker_addr, load_library,
-                    self.custom_addr))
+                target=_listen_loop,
+                args=(self.sock, self.port, key, tracker_addr, load_library, self.custom_addr),
+            )
             self.proc.daemon = True
             self.proc.start()
         else:
             self.proc = multiprocessing.Process(
-                target=_connect_proxy_loop, args=((host, port), key, load_library))
+                target=_connect_proxy_loop, args=((host, port), key, load_library)
+            )
             self.proc.daemon = True
             self.proc.start()
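
For reference, a minimal sketch of starting the Server defined above in tracker mode; the tracker address and key are illustrative values:

    from tvm import rpc

    server = rpc.Server(
        "0.0.0.0",
        port=9091,
        port_end=9199,
        tracker_addr=("127.0.0.1", 9190),
        key="rasp3b",
    )
    server.proc.join()  # _listen_loop runs in the forked process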
 
index fd0d906..7801dec 100644 (file)
@@ -20,6 +20,7 @@ import socket
 import errno
 from tornado import ioloop
 
+
 class TCPHandler(object):
     """TCP socket handler backed tornado event loop.
 
@@ -28,17 +29,20 @@ class TCPHandler(object):
     sock : Socket
         The TCP socket, will set it to non-blocking mode.
     """
+
     def __init__(self, sock):
         self._sock = sock
         self._ioloop = ioloop.IOLoop.current()
         self._sock.setblocking(0)
         self._pending_write = []
         self._signal_close = False
+
         def _event_handler(_, events):
             self._event_handler(events)
+
         self._ioloop.add_handler(
-            self._sock.fileno(), _event_handler,
-            self._ioloop.READ | self._ioloop.ERROR)
+            self._sock.fileno(), _event_handler, self._ioloop.READ | self._ioloop.ERROR
+        )
 
     def signal_close(self):
         """Signal the handler to close.
@@ -96,13 +100,15 @@ class TCPHandler(object):
 
         if self._pending_write:
             self._ioloop.update_handler(
-                self._sock.fileno(), self._ioloop.READ | self._ioloop.ERROR | self._ioloop.WRITE)
+                self._sock.fileno(), self._ioloop.READ | self._ioloop.ERROR | self._ioloop.WRITE
+            )
         else:
             if self._signal_close:
                 self.close()
             else:
                 self._ioloop.update_handler(
-                    self._sock.fileno(), self._ioloop.READ | self._ioloop.ERROR)
+                    self._sock.fileno(), self._ioloop.READ | self._ioloop.ERROR
+                )
 
     def _update_read(self):
         """Update state when there is read event"""
index e3346b1..557c9ae 100644 (file)
@@ -55,7 +55,8 @@ try:
     from . import tornado_util
 except ImportError as error_msg:
     raise ImportError(
-        "RPCTracker module requires tornado package %s. Try 'pip install tornado'." % error_msg)
+        "RPCTracker module requires tornado package %s. Try 'pip install tornado'." % error_msg
+    )
 
 from .._ffi.base import py_str
 from . import base
@@ -63,8 +64,10 @@ from .base import RPC_TRACKER_MAGIC, TrackerCode
 
 logger = logging.getLogger("RPCTracker")
 
+
 class Scheduler(object):
     """Abstratc interface of scheduler."""
+
     def put(self, value):
         """Push a resource into the scheduler.
 
@@ -103,7 +106,6 @@ class Scheduler(object):
             The resource to remove
         """
 
-
     def summary(self):
         """Get summary information of the scheduler."""
         raise NotImplementedError()
@@ -111,6 +113,7 @@ class Scheduler(object):
 
 class PriorityScheduler(Scheduler):
     """Priority based scheduler, FIFO based on time"""
+
     def __init__(self, key):
         self._key = key
         self._values = []
@@ -141,8 +144,7 @@ class PriorityScheduler(Scheduler):
 
     def summary(self):
         """Get summary information of the scheduler."""
-        return {"free": len(self._values),
-                "pending": len(self._requests)}
+        return {"free": len(self._values), "pending": len(self._requests)}
 
 
 class TCPEventHandler(tornado_util.TCPHandler):
@@ -152,6 +154,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
     The message is in the form [nbytes(int32)] [json-str].
     All the information is packed in the json-str.
     """
+
     def __init__(self, tracker, sock, addr):
         super(TCPEventHandler, self).__init__(sock)
         self._data = bytearray()
@@ -178,11 +181,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
         if len(message) != 4:
             logger.warning("Invalid connection from %s", self.name())
             self.close()
-        magic = struct.unpack('<i', message)[0]
+        magic = struct.unpack("<i", message)[0]
         if magic != RPC_TRACKER_MAGIC:
             logger.warning("Invalid magic from %s", self.name())
             self.close()
-        self.write_message(struct.pack('<i', RPC_TRACKER_MAGIC), binary=True)
+        self.write_message(struct.pack("<i", RPC_TRACKER_MAGIC), binary=True)
         self._init_req_nbytes = 0
 
     def on_message(self, message):
@@ -203,12 +206,12 @@ class TCPEventHandler(tornado_util.TCPHandler):
         while True:
             if self._msg_size == 0:
                 if len(self._data) >= 4:
-                    self._msg_size = struct.unpack('<i', self._data[:4])[0]
+                    self._msg_size = struct.unpack("<i", self._data[:4])[0]
                 else:
                     return
             if self._msg_size != 0 and len(self._data) >= self._msg_size + 4:
-                msg = py_str(bytes(self._data[4:4 + self._msg_size]))
-                del self._data[:4 + self._msg_size]
+                msg = py_str(bytes(self._data[4 : 4 + self._msg_size]))
+                del self._data[: 4 + self._msg_size]
                 self._msg_size = 0
                 # pylint: disable=broad-except
                 self.call_handler(json.loads(msg))
@@ -218,8 +221,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
     def ret_value(self, data):
         """return value to the output"""
         data = json.dumps(data)
-        self.write_message(
-            struct.pack('<i', len(data)), binary=True)
+        self.write_message(struct.pack("<i", len(data)), binary=True)
         self.write_message(data.encode("utf-8"), binary=True)
 
     def call_handler(self, args):
@@ -241,6 +243,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
             key = args[1]
             user = args[2]
             priority = args[3]
+
             def _cb(value):
                 # if the connection is already closed
                 if not self._sock:
@@ -250,6 +253,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
                 except (socket.error, IOError):
                     return False
                 return True
+
             self._tracker.request(key, user, priority, _cb)
         elif code == TrackerCode.PING:
             self.ret_value(TrackerCode.SUCCESS)
@@ -282,6 +286,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
 
 class TrackerServerHandler(object):
     """Tracker that tracks the resources."""
+
     def __init__(self, sock, stop_key):
         self._scheduler_map = {}
         self._sock = sock
@@ -289,10 +294,11 @@ class TrackerServerHandler(object):
         self._ioloop = ioloop.IOLoop.current()
         self._stop_key = stop_key
         self._connections = set()
+
         def _event_handler(_, events):
             self._on_event(events)
-        self._ioloop.add_handler(
-            self._sock.fileno(), _event_handler, self._ioloop.READ)
+
+        self._ioloop.add_handler(self._sock.fileno(), _event_handler, self._ioloop.READ)
 
     def _on_event(self, _):
         while True:
@@ -321,8 +327,8 @@ class TrackerServerHandler(object):
 
     def close(self, conn):
         self._connections.remove(conn)
-        if 'key' in conn._info:
-            key = conn._info['key'].split(':')[1]  # 'server:rasp3b' -> 'rasp3b'
+        if "key" in conn._info:
+            key = conn._info["key"].split(":")[1]  # 'server:rasp3b' -> 'rasp3b'
             for value in conn.put_values:
                 self._scheduler_map[key].remove(value)
 
@@ -350,6 +356,7 @@ class TrackerServerHandler(object):
         """Run the tracker server"""
         self._ioloop.start()
 
+
 def _tracker_server(listen_sock, stop_key):
     handler = TrackerServerHandler(listen_sock, stop_key)
     handler.run()
@@ -374,11 +381,8 @@ class Tracker(object):
     silent: bool, optional
         Whether to run in silent mode
     """
-    def __init__(self,
-                 host,
-                 port=9190,
-                 port_end=9199,
-                 silent=False):
+
+    def __init__(self, host, port=9190, port_end=9199, silent=False):
         if silent:
             logger.setLevel(logging.WARN)
 
@@ -398,8 +402,7 @@ class Tracker(object):
             raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
         logger.info("bind to %s:%d", host, self.port)
         sock.listen(1)
-        self.proc = multiprocessing.Process(
-            target=_tracker_server, args=(sock, self.stop_key))
+        self.proc = multiprocessing.Process(target=_tracker_server, args=(sock, self.stop_key))
         self.proc.start()
         self.host = host
         # close the socket on this process
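
A small sketch of the [nbytes(int32)][json-str] framing documented in TCPEventHandler above, together with starting a Tracker; the payload contents and ports are illustrative:

    import json
    import struct

    from tvm.rpc.tracker import Tracker

    def frame(payload):
        data = json.dumps(payload).encode("utf-8")
        return struct.pack("<i", len(data)) + data

    def unframe(buf):
        nbytes = struct.unpack("<i", buf[:4])[0]
        return json.loads(buf[4 : 4 + nbytes])

    assert unframe(frame(["rasp3b", "user", 1])) == ["rasp3b", "user", 1]

    tracker = Tracker("0.0.0.0", port=9190, port_end=9199)
    print("tracker bound to port", tracker.port)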
index 64e0a93..11d317b 100644 (file)
@@ -36,13 +36,11 @@ def NodeGetAttr(obj, name):
 
 
 def SaveJSON(obj):
-    raise RuntimeError(
-        "Do not support object serialization in runtime only mode")
+    raise RuntimeError("Do not support object serialization in runtime only mode")
 
 
 def LoadJSON(json_str):
-    raise RuntimeError(
-        "Do not support object serialization in runtime only mode")
+    raise RuntimeError("Do not support object serialization in runtime only mode")
 
 
 # Exports functions registered via TVM_REGISTER_GLOBAL with the "node" prefix.
index ae87534..63383e7 100644 (file)
@@ -54,8 +54,7 @@ def getitem_helper(obj, elem_getter, length, idx):
         return [elem_getter(obj, i) for i in range(start, stop, step)]
 
     if idx < -length or idx >= length:
-        raise IndexError("Index out of range. size: {}, got index {}"
-                         .format(length, idx))
+        raise IndexError("Index out of range. size: {}, got index {}".format(length, idx))
     if idx < 0:
         idx += length
     return elem_getter(obj, idx)
@@ -73,20 +72,20 @@ class ADT(Object):
     fields : list[Object] or tuple[Object]
         The source tuple.
     """
+
     def __init__(self, tag, fields):
         for f in fields:
-            assert isinstance(f, ObjectTypes), "Expect object or " \
-            "tvm NDArray type, but received : {0}".format(type(f))
-        self.__init_handle_by_constructor__(_ffi_api.ADT, tag,
-                                            *fields)
+            assert isinstance(
+                f, ObjectTypes
+            ), "Expect object or " "tvm NDArray type, but received : {0}".format(type(f))
+        self.__init_handle_by_constructor__(_ffi_api.ADT, tag, *fields)
 
     @property
     def tag(self):
         return _ffi_api.GetADTTag(self)
 
     def __getitem__(self, idx):
-        return getitem_helper(
-            self, _ffi_api.GetADTFields, len(self), idx)
+        return getitem_helper(self, _ffi_api.GetADTFields, len(self), idx)
 
     def __len__(self):
         return _ffi_api.GetADTSize(self)
@@ -107,8 +106,9 @@ def tuple_object(fields=None):
     """
     fields = fields if fields else []
     for f in fields:
-        assert isinstance(f, ObjectTypes), "Expect object or tvm " \
-        "NDArray type, but received : {0}".format(type(f))
+        assert isinstance(
+            f, ObjectTypes
+        ), "Expect object or tvm " "NDArray type, but received : {0}".format(type(f))
     return _ffi_api.Tuple(*fields)
 
 
@@ -121,6 +121,7 @@ class String(str, PyNativeObject):
     content : str
         The content string used to construct the object.
     """
+
     __slots__ = ["__tvm_object__"]
 
     def __new__(cls, content):
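
A sketch of tuple_object/ADT above; NDArray fields satisfy the ObjectTypes check, and the import path is assumed from this file:

    import numpy as np
    import tvm
    from tvm.runtime.container import tuple_object

    fields = [
        tvm.nd.array(np.zeros(3, dtype="float32")),
        tvm.nd.array(np.ones(3, dtype="float32")),
    ]
    tup = tuple_object(fields)
    assert len(tup) == 2
    print(tup[1].asnumpy())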
index 754bb6f..d9166b5 100644 (file)
@@ -35,6 +35,7 @@ ProfileResult = namedtuple("ProfileResult", ["mean", "results"])
 
 class Module(object):
     """Runtime Module."""
+
     __slots__ = ["handle", "_entry", "entry_name"]
 
     def __init__(self, handle):
@@ -79,13 +80,13 @@ class Module(object):
             The result function.
         """
         ret_handle = PackedFuncHandle()
-        check_call(_LIB.TVMModGetFunction(
-            self.handle, c_str(name),
-            ctypes.c_int(query_imports),
-            ctypes.byref(ret_handle)))
+        check_call(
+            _LIB.TVMModGetFunction(
+                self.handle, c_str(name), ctypes.c_int(query_imports), ctypes.byref(ret_handle)
+            )
+        )
         if not ret_handle.value:
-            raise AttributeError(
-                "Module has no function '%s'" %  name)
+            raise AttributeError("Module has no function '%s'" % name)
         return PackedFunc(ret_handle, False)
 
     def import_module(self, module):
@@ -163,7 +164,7 @@ class Module(object):
         """
         _ffi_api.ModuleSaveToFile(self, file_name, fmt)
 
-    def time_evaluator(self, func_name, ctx, number=10, repeat=1, min_repeat_ms=0, f_preproc=''):
+    def time_evaluator(self, func_name, ctx, number=10, repeat=1, min_repeat_ms=0, f_preproc=""):
         """Get an evaluator that measures time cost of running function.
 
         Parameters
@@ -208,8 +209,15 @@ class Module(object):
         """
         try:
             feval = _ffi_api.RPCTimeEvaluator(
-                self, func_name, ctx.device_type, ctx.device_id,
-                number, repeat, min_repeat_ms, f_preproc)
+                self,
+                func_name,
+                ctx.device_type,
+                ctx.device_id,
+                number,
+                repeat,
+                min_repeat_ms,
+                f_preproc,
+            )
 
             def evaluator(*args):
                 """Internal wrapped evaluator."""
@@ -243,11 +251,7 @@ class Module(object):
     def _dso_exportable(self):
         return self.type_key == "llvm" or self.type_key == "c"
 
-    def export_library(self,
-                       file_name,
-                       fcompile=None,
-                       addons=None,
-                       **kwargs):
+    def export_library(self, file_name, fcompile=None, addons=None, **kwargs):
         """Export the module and its imported device code one library.
 
         This function only works on host llvm modules.
@@ -279,8 +283,10 @@ class Module(object):
 
         if self.type_key == "stackvm":
             if not file_name.endswith(".stackvm"):
-                raise ValueError("Module[%s]: can only be saved as stackvm format."
-                                 "did you build with LLVM enabled?" % self.type_key)
+                raise ValueError(
+                    "Module[%s]: can only be saved as stackvm format."
+                    "did you build with LLVM enabled?" % self.type_key
+                )
             self.save(file_name)
             return
 
@@ -303,10 +309,12 @@ class Module(object):
             path_obj = temp.relpath("lib" + str(index) + "." + object_format)
             module.save(path_obj)
             files.append(path_obj)
-            is_system_lib = (module.type_key == "llvm" and
-                             module.get_function("__tvm_is_system_module")())
-            llvm_target_triple = (module.type_key == "llvm" and
-                                  module.get_function("_get_target_triple")())
+            is_system_lib = (
+                module.type_key == "llvm" and module.get_function("__tvm_is_system_module")()
+            )
+            llvm_target_triple = (
+                module.type_key == "llvm" and module.get_function("_get_target_triple")()
+            )
         if not fcompile:
             if file_name.endswith(".tar"):
                 fcompile = _tar.tar
@@ -337,7 +345,7 @@ class Module(object):
                 opts = kwargs["options"]
                 options = opts if isinstance(opts, (list, tuple)) else [opts]
             opts = options + ["-I" + path for path in find_include_path()]
-            kwargs.update({'options': opts})
+            kwargs.update({"options": opts})
 
         fcompile(file_name, files, **kwargs)
 
@@ -390,12 +398,14 @@ def load_module(path, fmt=""):
     if path.endswith(".o"):
         # Extra dependencies during runtime.
         from tvm.contrib import cc as _cc
+
         _cc.create_shared(path + ".so", path)
         path += ".so"
     elif path.endswith(".tar"):
         # Extra dependencies during runtime.
         from tvm.contrib import cc as _cc, util as _util, tar as _tar
-        tar_temp = _util.tempdir(custom_path=path.replace('.tar', ''))
+
+        tar_temp = _util.tempdir(custom_path=path.replace(".tar", ""))
         _tar.untar(path, tar_temp.temp_dir)
         files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
         _cc.create_shared(path + ".so", files)
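
A minimal sketch exercising time_evaluator and export_library above; it assumes an LLVM-enabled build plus a host C compiler, and the function/file names are illustrative:

    import numpy as np
    import tvm
    from tvm import te

    n = 1024
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    func = tvm.build(te.create_schedule(B.op), [A, B], target="llvm", name="vecadd")

    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.rand(n).astype("float32"), ctx)
    b = tvm.nd.empty((n,), "float32", ctx)

    ftimer = func.time_evaluator("vecadd", ctx, number=10, repeat=3)
    prof = ftimer(a, b)  # ProfileResult(mean=..., results=[...])
    print(prof.mean, prof.results)

    func.export_library("/tmp/vecadd.so")
    loaded = tvm.runtime.load_module("/tmp/vecadd.so")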
index 060673d..b0a3c74 100644 (file)
@@ -92,17 +92,19 @@ class NDArray(NDArrayBase):
 
     def __setitem__(self, in_slice, value):
         """Set ndarray value"""
-        if (not isinstance(in_slice, slice) or
-                in_slice.start is not None
-                or in_slice.stop is not None):
-            raise ValueError('Array only support set from numpy array')
+        if (
+            not isinstance(in_slice, slice)
+            or in_slice.start is not None
+            or in_slice.stop is not None
+        ):
+            raise ValueError("Array only support set from numpy array")
         if isinstance(value, NDArrayBase):
             if value.handle is not self.handle:
                 value.copyto(self)
         elif isinstance(value, (np.ndarray, np.generic)):
             self.copyfrom(value)
         else:
-            raise TypeError('type %s not supported' % str(type(value)))
+            raise TypeError("type %s not supported" % str(type(value)))
 
     def copyfrom(self, source_array):
         """Peform an synchronize copy from the array.
@@ -125,8 +127,10 @@ class NDArray(NDArrayBase):
             try:
                 source_array = np.array(source_array, dtype=self.dtype)
             except:
-                raise TypeError('array must be an array_like data,' +
-                                'type %s is not supported' % str(type(source_array)))
+                raise TypeError(
+                    "array must be an array_like data,"
+                    + "type %s is not supported" % str(type(source_array))
+                )
 
         t = DataType(self.dtype)
         shape, dtype = self.shape, self.dtype
@@ -136,10 +140,13 @@ class NDArray(NDArrayBase):
             dtype = str(t)
 
         if source_array.shape != shape:
-            raise ValueError("array shape do not match the shape of NDArray {0} vs {1}".format(
-                source_array.shape, shape))
+            raise ValueError(
+                "array shape do not match the shape of NDArray {0} vs {1}".format(
+                    source_array.shape, shape
+                )
+            )
         source_array = np.ascontiguousarray(source_array, dtype=dtype)
-        assert source_array.flags['C_CONTIGUOUS']
+        assert source_array.flags["C_CONTIGUOUS"]
         data = source_array.ctypes.data_as(ctypes.c_void_p)
         nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
         check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes))
@@ -168,7 +175,7 @@ class NDArray(NDArrayBase):
             t.lanes = 1
             dtype = str(t)
         np_arr = np.empty(shape, dtype=dtype)
-        assert np_arr.flags['C_CONTIGUOUS']
+        assert np_arr.flags["C_CONTIGUOUS"]
         data = np_arr.ctypes.data_as(ctypes.c_void_p)
         nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
         check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
@@ -218,8 +225,8 @@ def context(dev_type, dev_id=0):
       assert tvm.context("cuda", 0) == tvm.gpu(0)
     """
     if isinstance(dev_type, string_types):
-        if '-device=micro_dev' in dev_type:
-            dev_type = TVMContext.STR2MASK['micro_dev']
+        if "-device=micro_dev" in dev_type:
+            dev_type = TVMContext.STR2MASK["micro_dev"]
         else:
             dev_type = dev_type.split()[0]
             if dev_type not in TVMContext.STR2MASK:
@@ -229,10 +236,9 @@ def context(dev_type, dev_id=0):
 
 
 def numpyasarray(np_data):
-    """Return a TVMArray representation of a numpy array.
-    """
+    """Return a TVMArray representation of a numpy array."""
     data = np_data
-    assert data.flags['C_CONTIGUOUS']
+    assert data.flags["C_CONTIGUOUS"]
     arr = TVMArray()
     shape = c_array(tvm_shape_index_t, data.shape)
     arr.data = data.ctypes.data_as(ctypes.c_void_p)
@@ -268,14 +274,18 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
     ndim = ctypes.c_int(len(shape))
     handle = TVMArrayHandle()
     dtype = DataType(dtype)
-    check_call(_LIB.TVMArrayAlloc(
-        shape, ndim,
-        ctypes.c_int(dtype.type_code),
-        ctypes.c_int(dtype.bits),
-        ctypes.c_int(dtype.lanes),
-        ctx.device_type,
-        ctx.device_id,
-        ctypes.byref(handle)))
+    check_call(
+        _LIB.TVMArrayAlloc(
+            shape,
+            ndim,
+            ctypes.c_int(dtype.type_code),
+            ctypes.c_int(dtype.bits),
+            ctypes.c_int(dtype.lanes),
+            ctx.device_type,
+            ctx.device_id,
+            ctypes.byref(handle),
+        )
+    )
     return _make_array(handle, False, False)
 
 
@@ -329,6 +339,7 @@ def gpu(dev_id=0):
     """
     return TVMContext(2, dev_id)
 
+
 def rocm(dev_id=0):
     """Construct a ROCM device
 
@@ -502,5 +513,6 @@ def array(arr, ctx=cpu(0)):
         arr = np.array(arr)
     return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)
 
+
 # Register back to FFI
 _set_class_ndarray(NDArray)
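
A short sketch of the NDArray helpers above (context, empty, copyfrom, array); "llvm" resolves to the CPU context through TVMContext.STR2MASK:

    import numpy as np
    import tvm

    ctx = tvm.context("llvm", 0)
    x = tvm.nd.empty((2, 3), "float32", ctx)
    x.copyfrom(np.arange(6, dtype="float32").reshape(2, 3))
    y = tvm.nd.array(np.ones((2, 3), dtype="float32"), ctx)
    print(x.asnumpy() + y.asnumpy())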
index 2a34b34..35f1f4e 100644 (file)
@@ -41,7 +41,9 @@ def _new_object(cls):
 
 class Object(ObjectBase):
     """Base class for all tvm's runtime objects."""
+
     __slots__ = []
+
     def __repr__(self):
         return _ffi_node_api.AsRepr(self)
 
@@ -55,8 +57,7 @@ class Object(ObjectBase):
         try:
             return _ffi_node_api.NodeGetAttr(self, name)
         except AttributeError:
-            raise AttributeError(
-                "%s has no attribute %s" % (str(type(self)), name))
+            raise AttributeError("%s has no attribute %s" % (str(type(self)), name))
 
     def __hash__(self):
         return _ffi_api.ObjectPtrHash(self)
@@ -69,21 +70,20 @@ class Object(ObjectBase):
 
     def __reduce__(self):
         cls = type(self)
-        return (_new_object, (cls, ), self.__getstate__())
+        return (_new_object, (cls,), self.__getstate__())
 
     def __getstate__(self):
         handle = self.handle
         if handle is not None:
-            return {'handle': _ffi_node_api.SaveJSON(self)}
-        return {'handle': None}
+            return {"handle": _ffi_node_api.SaveJSON(self)}
+        return {"handle": None}
 
     def __setstate__(self, state):
         # pylint: disable=assigning-non-slot, assignment-from-no-return
-        handle = state['handle']
+        handle = state["handle"]
         self.handle = None
         if handle is not None:
-            self.__init_handle_by_constructor__(
-                _ffi_node_api.LoadJSON, handle)
+            self.__init_handle_by_constructor__(_ffi_node_api.LoadJSON, handle)
 
     def _move(self):
         """Create an RValue reference to the object and mark the object as moved.
index 8f559ae..ae03ee9 100644 (file)
@@ -29,6 +29,7 @@ from .module import Module
 
 class ObjectGeneric(object):
     """Base class for all classes that can be converted to object."""
+
     def asobject(self):
         """Convert value to object"""
         raise NotImplementedError()
@@ -53,7 +54,7 @@ def convert_to_object(value):
     if isinstance(value, ObjectTypes):
         return value
     if isinstance(value, bool):
-        return const(value, 'uint1x1')
+        return const(value, "uint1x1")
     if isinstance(value, Number):
         return const(value)
     if isinstance(value, string_types):
@@ -64,8 +65,7 @@ def convert_to_object(value):
     if isinstance(value, dict):
         vlist = []
         for item in value.items():
-            if (not isinstance(item[0], ObjectTypes) and
-                    not isinstance(item[0], string_types)):
+            if not isinstance(item[0], ObjectTypes) and not isinstance(item[0], string_types):
                 raise ValueError("key of map must already been a container type")
             vlist.append(item[0])
             vlist.append(convert_to_object(item[1]))
@@ -100,21 +100,23 @@ def convert(value):
 
 
 def _scalar_type_inference(value):
-    if hasattr(value, 'dtype'):
+    if hasattr(value, "dtype"):
         dtype = str(value.dtype)
     elif isinstance(value, bool):
-        dtype = 'bool'
+        dtype = "bool"
     elif isinstance(value, float):
         # We intentionally convert the float to float32 since it's more common in DL.
-        dtype = 'float32'
+        dtype = "float32"
     elif isinstance(value, int):
         # We intentionally convert the python int to int32 since it's more common in DL.
-        dtype = 'int32'
+        dtype = "int32"
     else:
-        raise NotImplementedError('Cannot automatically inference the type.'
-                                  ' value={}'.format(value))
+        raise NotImplementedError(
+            "Cannot automatically inference the type." " value={}".format(value)
+        )
     return dtype
 
+
 def const(value, dtype=None):
     """construct a constant
 
@@ -134,8 +136,7 @@ def const(value, dtype=None):
     if dtype is None:
         dtype = _scalar_type_inference(value)
     if dtype == "uint64" and value >= (1 << 63):
-        return _ffi_node_api.LargeUIntImm(
-            dtype, value & ((1 << 32) - 1), value >> 32)
+        return _ffi_node_api.LargeUIntImm(dtype, value & ((1 << 32) - 1), value >> 32)
     return _ffi_node_api._const(value, dtype)
 
 
index af4265a..35a4783 100644 (file)
@@ -36,6 +36,7 @@ except (RuntimeError, ImportError):
 
 PackedFuncHandle = ctypes.c_void_p
 
+
 class PackedFunc(PackedFuncBase):
     """The PackedFunc object used in TVM.
 
@@ -58,4 +59,5 @@ class PackedFunc(PackedFuncBase):
     tvm.get_global_func: How to get global function.
     """
 
+
 _set_class_packed_func(PackedFunc)
index fbc7a7d..81a909b 100644 (file)
@@ -157,12 +157,16 @@ class Executable(object):
         if isinstance(bytecode, (bytes, str)):
             code = bytearray(bytecode)
         elif not isinstance(bytecode, (bytearray, TVMByteArray)):
-            raise TypeError("bytecode is expected to be the type of bytearray " +
-                            "or TVMByteArray, but received {}".format(type(code)))
+            raise TypeError(
+                "bytecode is expected to be the type of bytearray "
+                + "or TVMByteArray, but received {}".format(type(code))
+            )
 
         if lib is not None and not isinstance(lib, tvm.runtime.Module):
-            raise TypeError("lib is expected to be the type of tvm.runtime.Module" +
-                            ", but received {}".format(type(lib)))
+            raise TypeError(
+                "lib is expected to be the type of tvm.runtime.Module"
+                + ", but received {}".format(type(lib))
+            )
 
         return Executable(_ffi_api.Load_Executable(bytecode, lib))
 
@@ -296,8 +300,10 @@ class VirtualMachine(object):
 
     def __init__(self, exe, ctx, memory_cfg=None):
         if not isinstance(exe, Executable):
-            raise TypeError("exe is expected to be the type of Executable, " +
-                            "but received {}".format(type(exe)))
+            raise TypeError(
+                "exe is expected to be the type of Executable, "
+                + "but received {}".format(type(exe))
+            )
         self.module = _ffi_api._VirtualMachine(exe.module)
         self._exec = exe
         self._init = self.module["init"]
@@ -310,8 +316,10 @@ class VirtualMachine(object):
         ctxs = ctx
         if not isinstance(ctx, (list, tuple)):
             if not isinstance(ctx, tvm.runtime.TVMContext):
-                raise TypeError("ctx is expected to be TVMContext or \
-                                List[TVMContext]")
+                raise TypeError(
+                    "ctx is expected to be TVMContext or \
+                                List[TVMContext]"
+                )
             ctxs = [ctx]
 
         # CPU is required for executing shape functions
@@ -327,8 +335,10 @@ class VirtualMachine(object):
                 default_alloc_type = VirtualMachine.NAIVE_ALLOCATOR
             memory_cfg = {}
         elif not isinstance(memory_cfg, dict):
-            raise TypeError("memory_cfg is expected be string or dictionary, " +
-                            "but received {}".format(type(memory_cfg)))
+            raise TypeError(
+                "memory_cfg is expected be string or dictionary, "
+                + "but received {}".format(type(memory_cfg))
+            )
         init_args = []
         for context in ctxs:
             init_args.append(context.device_type)
index c40296e..60fc659 100644 (file)
 
 
 ARM_ISA_MAP = {
-    'armv7e-m': ['SMLAD'],
+    "armv7e-m": ["SMLAD"],
 }
 
 
 class IsaAnalyzer(object):
-
     def __init__(self, target):
         self.target = target
         # TODO: actually parse -mcpu
-        arch = 'armv7e-m'
+        arch = "armv7e-m"
         self._isa_map = ARM_ISA_MAP[arch]
 
     def __contains__(self, instruction):
index 79ef46c..0ab4cb0 100644 (file)
@@ -1,4 +1,3 @@
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -74,5 +73,4 @@ def llvm_version_major(allow_none=False):
     except AttributeError:
         if allow_none:
             return None
-        raise RuntimeError(
-            "LLVM version is not available, please check if you build with LLVM")
+        raise RuntimeError("LLVM version is not available, please check if you build with LLVM")
index f93a943..5d3ca5f 100644 (file)
@@ -104,11 +104,11 @@ def register_op(lower_func, op_name, target, type_name, src_type_name=None):
 
     if op_name == "Cast":
         assert src_type_name is not None
-        lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." \
-                          + type_name + "." + src_type_name
+        lower_func_name = (
+            "tvm.datatype.lower." + target + "." + op_name + "." + type_name + "." + src_type_name
+        )
     else:
-        lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." \
-                          + type_name
+        lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." + type_name
     tvm._ffi.register_func(lower_func_name, lower_func)
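
A plain-Python sketch of the registry key composed by register_op above; the "posit32" type name is hypothetical:

    target, op_name, type_name, src_type_name = "llvm", "Cast", "posit32", "float32"
    key = "tvm.datatype.lower." + target + "." + op_name + "." + type_name + "." + src_type_name
    assert key == "tvm.datatype.lower.llvm.Cast.posit32.float32"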
 
 
index 1936ff1..932eaa4 100644 (file)
@@ -26,7 +26,7 @@ except ImportError:
         raise
 
 from tvm.runtime import Object
-from . target import Target
+from .target import Target
 from . import _ffi_api
 
 
@@ -41,6 +41,7 @@ class GenericFunc(Object):
     Do not construct an instance of this object, it should only ever be
     used as a return value from calling into C++.
     """
+
     def __call__(self, *args):
         return _ffi_api.GenericFuncCallFunc(self, *args)
 
@@ -167,27 +168,33 @@ def override_native_generic_func(func_name):
             -------
             The register function is necessary.
             """
+
             def _do_reg(myf):
                 generic_func_node.register(myf, key, override)
                 return myf
+
             if func:
                 return _do_reg(func)
             return _do_reg
 
         def dispatch_func(func, *args, **kwargs):
-            #pylint: disable=unused-argument
+            # pylint: disable=unused-argument
             """The wrapped dispath function"""
             if kwargs:
                 raise RuntimeError(
-                    "Keyword arguments cannot be used when invoking generic_func %s" % func_name)
+                    "Keyword arguments cannot be used when invoking generic_func %s" % func_name
+                )
             return generic_func_node(*args)
+
         fresult = decorate(fdefault, dispatch_func)
         fresult.fdefault = fdefault
         fresult.register = register
         fresult.generic_func_node = generic_func_node
         return fresult
+
     return fdecorate
 
+
 def generic_func(fdefault):
     """Wrap a target generic function.
 
@@ -245,14 +252,15 @@ def generic_func(fdefault):
         -------
         The register function is necessary.
         """
+
         def _do_reg(myf):
             key_list = [key] if isinstance(key, str) else key
             for k in key_list:
                 if k in dispatch_dict and not override:
-                    raise ValueError(
-                        "Key is already registered for %s" % func_name)
+                    raise ValueError("Key is already registered for %s" % func_name)
                 dispatch_dict[k] = myf
             return myf
+
         if func:
             return _do_reg(func)
         return _do_reg
@@ -266,6 +274,7 @@ def generic_func(fdefault):
             if k in dispatch_dict:
                 return dispatch_dict[k](*args, **kwargs)
         return func(*args, **kwargs)
+
     fdecorate = decorate(fdefault, dispatch_func)
     fdecorate.register = register
     fdecorate.fdefault = fdefault
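
A sketch of the generic_func decorator above dispatching on the current target's keys; the function name and return strings are hypothetical:

    from tvm.target import Target, generic_func

    @generic_func
    def my_schedule(n):
        return "generic:%d" % n

    @my_schedule.register("cuda")
    def _(n):
        return "cuda:%d" % n

    print(my_schedule(1))      # no target in scope -> default implementation
    with Target("cuda"):
        print(my_schedule(1))  # dispatches to the "cuda" override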
index c6a8f71..97cbf1e 100644 (file)
@@ -72,7 +72,10 @@ list_tags()
 
 # We purposely maintain all tags in the C++ side to support pure C++ use cases,
 # and the Python API is only used for fast prototyping.
-register_tag("nvidia/gtx1080ti", config={
-    "kind": "cuda",
-    "arch": "sm_61",
-})
+register_tag(
+    "nvidia/gtx1080ti",
+    config={
+        "kind": "cuda",
+        "arch": "sm_61",
+    },
+)
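
A sketch showing that the tag registered above can be passed directly to the Target constructor (whose tag_or_str_or_dict parameter appears in the next file):

    import tvm

    tgt = tvm.target.Target("nvidia/gtx1080ti")
    print(tgt)  # resolves to the cuda config registered above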
index 54b354f..1476f7b 100644 (file)
@@ -28,8 +28,7 @@ from . import _ffi_api
 
 @tvm._ffi.register_object
 class TargetKind(Object):
-    """Kind of a compilation target
-    """
+    """Kind of a compilation target"""
 
 
 @tvm._ffi.register_object
@@ -90,8 +89,7 @@ class Target(Object):
         """
         if not isinstance(tag_or_str_or_dict, (dict, str, Target)):
             raise ValueError("target has to be a string or dictionary.")
-        self.__init_handle_by_constructor__(
-            _ffi_api.Target, tag_or_str_or_dict)
+        self.__init_handle_by_constructor__(_ffi_api.Target, tag_or_str_or_dict)
 
     def __enter__(self):
         _ffi_api.TargetEnterScope(self)
@@ -152,6 +150,7 @@ class Target(Object):
 
 # TODO(@tvm-team): Deprecate the helper functions below. Encourage the usage of config dict instead.
 
+
 def _merge_opts(opts, new_opts):
     """Helper function to merge options"""
     if isinstance(new_opts, str):
@@ -163,7 +162,7 @@ def _merge_opts(opts, new_opts):
     return opts
 
 
-def cuda(model='unknown', options=None):
+def cuda(model="unknown", options=None):
     """Returns a cuda target.
 
     Parameters
@@ -173,11 +172,11 @@ def cuda(model='unknown', options=None):
     options : str or list of str
         Additional options
     """
-    opts = _merge_opts(['-model=%s' % model], options)
+    opts = _merge_opts(["-model=%s" % model], options)
     return Target(" ".join(["cuda"] + opts))
 
 
-def rocm(model='unknown', options=None):
+def rocm(model="unknown", options=None):
     """Returns a ROCM target.
 
     Parameters
@@ -191,7 +190,7 @@ def rocm(model='unknown', options=None):
     return Target(" ".join(["rocm"] + opts))
 
 
-def mali(model='unknown', options=None):
+def mali(model="unknown", options=None):
     """Returns a ARM Mali GPU target.
 
     Parameters
@@ -201,12 +200,12 @@ def mali(model='unknown', options=None):
     options : str or list of str
         Additional options
     """
-    opts = ["-device=mali", '-model=%s' % model]
+    opts = ["-device=mali", "-model=%s" % model]
     opts = _merge_opts(opts, options)
     return Target(" ".join(["opencl"] + opts))
 
 
-def intel_graphics(model='unknown', options=None):
+def intel_graphics(model="unknown", options=None):
     """Returns an Intel Graphics target.
 
     Parameters
@@ -216,13 +215,12 @@ def intel_graphics(model='unknown', options=None):
     options : str or list of str
         Additional options
     """
-    opts = ["-device=intel_graphics", "-model=%s" %
-            model, "-thread_warp_size=16"]
+    opts = ["-device=intel_graphics", "-model=%s" % model, "-thread_warp_size=16"]
     opts = _merge_opts(opts, options)
     return Target(" ".join(["opencl"] + opts))
 
 
-def arm_cpu(model='unknown', options=None):
+def arm_cpu(model="unknown", options=None):
     """Returns a ARM CPU target.
     This function will also download pre-tuned op parameters when there is none.
 
@@ -234,19 +232,27 @@ def arm_cpu(model='unknown', options=None):
         Additional options
     """
     trans_table = {
-        "pixel2":    ["-model=snapdragon835", "-mtriple=arm64-linux-android", "-mattr=+neon"],
-        "mate10":    ["-model=kirin970", "-mtriple=arm64-linux-android", "-mattr=+neon"],
+        "pixel2": ["-model=snapdragon835", "-mtriple=arm64-linux-android", "-mattr=+neon"],
+        "mate10": ["-model=kirin970", "-mtriple=arm64-linux-android", "-mattr=+neon"],
         "mate10pro": ["-model=kirin970", "-mtriple=arm64-linux-android", "-mattr=+neon"],
-        "p20":       ["-model=kirin970", "-mtriple=arm64-linux-android", "-mattr=+neon"],
-        "p20pro":    ["-model=kirin970", "-mtriple=arm64-linux-android", "-mattr=+neon"],
-        "rasp3b":    ["-model=bcm2837", "-mtriple=armv7l-linux-gnueabihf", "-mattr=+neon"],
-        "rasp4b":    ["-model=bcm2711", "-mtriple=armv8l-linux-gnueabihf", "-mattr=+neon",
-                      "-mcpu=cortex-a72"],
-        "rasp4b64":  ["-model=bcm2711", "-mtriple=aarch64-linux-gnu", "-mattr=+neon",
-                      "-mcpu=cortex-a72"],
-        "rk3399":    ["-model=rk3399", "-mtriple=aarch64-linux-gnu", "-mattr=+neon"],
-        "pynq":      ["-model=pynq", "-mtriple=armv7a-linux-eabi", "-mattr=+neon"],
-        "ultra96":   ["-model=ultra96", "-mtriple=aarch64-linux-gnu", "-mattr=+neon"],
+        "p20": ["-model=kirin970", "-mtriple=arm64-linux-android", "-mattr=+neon"],
+        "p20pro": ["-model=kirin970", "-mtriple=arm64-linux-android", "-mattr=+neon"],
+        "rasp3b": ["-model=bcm2837", "-mtriple=armv7l-linux-gnueabihf", "-mattr=+neon"],
+        "rasp4b": [
+            "-model=bcm2711",
+            "-mtriple=armv8l-linux-gnueabihf",
+            "-mattr=+neon",
+            "-mcpu=cortex-a72",
+        ],
+        "rasp4b64": [
+            "-model=bcm2711",
+            "-mtriple=aarch64-linux-gnu",
+            "-mattr=+neon",
+            "-mcpu=cortex-a72",
+        ],
+        "rk3399": ["-model=rk3399", "-mtriple=aarch64-linux-gnu", "-mattr=+neon"],
+        "pynq": ["-model=pynq", "-mtriple=armv7a-linux-eabi", "-mattr=+neon"],
+        "ultra96": ["-model=ultra96", "-mtriple=aarch64-linux-gnu", "-mattr=+neon"],
     }
     pre_defined_opt = trans_table.get(model, ["-model=%s" % model])
 
@@ -263,18 +269,19 @@ def rasp(options=None):
     options : str or list of str
         Additional options
     """
-    warnings.warn('tvm.target.rasp() is going to be deprecated. '
-                  'Please use tvm.target.arm_cpu("rasp3b")')
-    return arm_cpu('rasp3b', options)
+    warnings.warn(
+        "tvm.target.rasp() is going to be deprecated. " 'Please use tvm.target.arm_cpu("rasp3b")'
+    )
+    return arm_cpu("rasp3b", options)
 
 
-def vta(model='unknown', options=None):
-    opts = ["-device=vta", '-keys=vta,cpu', '-model=%s' % model]
+def vta(model="unknown", options=None):
+    opts = ["-device=vta", "-keys=vta,cpu", "-model=%s" % model]
     opts = _merge_opts(opts, options)
     return Target(" ".join(["ext_dev"] + opts))
 
 
-def bifrost(model='unknown', options=None):
+def bifrost(model="unknown", options=None):
     """Return an ARM Mali GPU target (Bifrost architecture).
 
     Parameters
@@ -282,12 +289,12 @@ def bifrost(model='unknown', options=None):
     options : str or list of str
         Additional options
     """
-    opts = ["-device=bifrost", '-model=%s' % model]
+    opts = ["-device=bifrost", "-model=%s" % model]
     opts = _merge_opts(opts, options)
     return Target(" ".join(["opencl"] + opts))
 
 
-def hexagon(cpu_ver='v66', sim_args=None, llvm_args=None, hvx=128):
+def hexagon(cpu_ver="v66", sim_args=None, llvm_args=None, hvx=128):
     """Returns a Hexagon target.
 
     Parameters
@@ -309,87 +316,95 @@ def hexagon(cpu_ver='v66', sim_args=None, llvm_args=None, hvx=128):
     # llvm -mtriple=hexagon -mcpu=hexagonv66 -mattr=+hvxv66,+hvx-length128b
 
     # Check for valid codegen cpu
-    valid_hex = ['v60', 'v62', 'v65', 'v66', 'v67', 'v67t']
+    valid_hex = ["v60", "v62", "v65", "v66", "v67", "v67t"]
     try:
-        cpu_ver = cpu_ver[cpu_ver.index('v'):].lower()
+        cpu_ver = cpu_ver[cpu_ver.index("v") :].lower()
         assert 3 <= len(cpu_ver) <= 4
     except:
-        msg = '{} is not a valid Hexagon version\nvalid versions include {}'
+        msg = "{} is not a valid Hexagon version\nvalid versions include {}"
         raise ValueError(msg.format(cpu_ver, valid_hex)) from None
 
     assert hvx in [0, 64, 128]
 
     # Target string
     def create_target(cpu_ver):
-        target = ' -mtriple=hexagon'
-        mcpu = ' -mcpu=hexagon' + cpu_ver
-        mattr = ''
+        target = " -mtriple=hexagon"
+        mcpu = " -mcpu=hexagon" + cpu_ver
+        mattr = ""
         # HVX enable
         if hvx:
-            mattr = ' -mattr=+hvx' + cpu_ver + ',+hvx-length' + str(hvx) + 'b'
+            mattr = " -mattr=+hvx" + cpu_ver + ",+hvx-length" + str(hvx) + "b"
         return target + mcpu + mattr
 
     # Simulator string
     def create_sim(cpu_ver, sim_args):
         def validate_hvx_length(codegen_hvx, sim_args):
-            if sim_args and '--hvx_length' in sim_args:
+            if sim_args and "--hvx_length" in sim_args:
                 # If --hvx_length was specified, check HVX length of sim
                 # vs codegen
-                i = sim_args.index('hvx_length') + len('hvx_length') + 1
-                sim_hvx = sim_args[i:i+3]
+                i = sim_args.index("hvx_length") + len("hvx_length") + 1
+                sim_hvx = sim_args[i : i + 3]
                 if sim_hvx != str(codegen_hvx):
-                    print('WARNING: sim hvx {} and codegen hvx {} mismatch!'
-                          .format(sim_hvx, codegen_hvx))
+                    print(
+                        "WARNING: sim hvx {} and codegen hvx {} mismatch!".format(
+                            sim_hvx, codegen_hvx
+                        )
+                    )
             elif codegen_hvx != 0:
                 # If --hvx_length was not given, add it if HVX is enabled
-                sim_args = sim_args + ' ' if isinstance(sim_args, str) else ''
-                sim_args += '--hvx_length ' + str(codegen_hvx)
-            return sim_args or ''
+                sim_args = sim_args + " " if isinstance(sim_args, str) else ""
+                sim_args += "--hvx_length " + str(codegen_hvx)
+            return sim_args or ""
 
         if not sim_args:
-            return cpu_ver + ' ' + validate_hvx_length(hvx, sim_args)
+            return cpu_ver + " " + validate_hvx_length(hvx, sim_args)
 
-        sim_cpu = cpu_ver + ' '
+        sim_cpu = cpu_ver + " "
 
         # Add user defined args
         if isinstance(sim_args, list):
-            sim_args = ' '.join(sim_args)
+            sim_args = " ".join(sim_args)
 
         # Check for supplied sim cpu version
-        if 'v6' in sim_args:
-            sim_cpu = ''
+        if "v6" in sim_args:
+            sim_cpu = ""
 
             # Regex match for allowed cpus
-            valid_cpu_str_regex = r'(?P<pre>--.*\s)?(--m)?' +                 \
-                r'(?P<base_version>v6[25678])(?P<sub_version>[a-z])?' +       \
-                r'(?P<l2_size>_[0-9]+)?(?P<rev>_rev[0-9])?\s?(?P<post>--.*)?'
+            valid_cpu_str_regex = (
+                r"(?P<pre>--.*\s)?(--m)?"
+                + r"(?P<base_version>v6[25678])(?P<sub_version>[a-z])?"
+                + r"(?P<l2_size>_[0-9]+)?(?P<rev>_rev[0-9])?\s?(?P<post>--.*)?"
+            )
             m = re.match(valid_cpu_str_regex, sim_args.lower())
             if not m:
-                raise ValueError(
-                    'Invalid simulator argument string "{}"'.format(sim_args))
+                raise ValueError('Invalid simulator argument string "{}"'.format(sim_args))
 
             # Parse options into correct order
-            cpu_attr = {x: str(m.groupdict()[x] or '') for x in m.groupdict()}
-            sim_args = cpu_attr['base_version'] +  \
-                cpu_attr['sub_version'] +  \
-                cpu_attr['l2_size'] +       \
-                cpu_attr['rev'] + ' ' +     \
-                cpu_attr['pre'] + cpu_attr['post']
-
-        return sim_cpu + ' ' + validate_hvx_length(hvx, sim_args)
+            cpu_attr = {x: str(m.groupdict()[x] or "") for x in m.groupdict()}
+            sim_args = (
+                cpu_attr["base_version"]
+                + cpu_attr["sub_version"]
+                + cpu_attr["l2_size"]
+                + cpu_attr["rev"]
+                + " "
+                + cpu_attr["pre"]
+                + cpu_attr["post"]
+            )
+
+        return sim_cpu + " " + validate_hvx_length(hvx, sim_args)
 
     # LLVM string
     def create_llvm(llvm_args):
         # TVM's option parser doesn't allow '=' in values, but '=' can
         # appear in LLVM flags. Replace it with '@', since it's unlikely
         # that '@' will be used in another context.
-        if llvm_args is None or len(llvm_args.replace(' ', '')) == 0:
-            return ''
-        args = [s.replace('=', '@') for s in llvm_args.split()]
-        return '--llvm-options=' + ','.join(args)
+        if llvm_args is None or len(llvm_args.replace(" ", "")) == 0:
+            return ""
+        args = [s.replace("=", "@") for s in llvm_args.split()]
+        return "--llvm-options=" + ",".join(args)
 
     # Sim args
-    os.environ['HEXAGON_SIM_ARGS'] = create_sim(cpu_ver, sim_args)
+    os.environ["HEXAGON_SIM_ARGS"] = create_sim(cpu_ver, sim_args)
 
     target_str = create_target(cpu_ver)
     llvm_str = create_llvm(llvm_args)
@@ -399,10 +414,8 @@ def hexagon(cpu_ver='v66', sim_args=None, llvm_args=None, hvx=128):
 
 
 def create(target):
-    """Deprecated. Use the constructor of :py:mod:`tvm.target.Target` directly.
-    """
-    warnings.warn(
-        'tvm.target.create() is being deprecated. Please use tvm.target.Target() instead')
+    """Deprecated. Use the constructor of :py:mod:`tvm.target.Target` directly."""
+    warnings.warn("tvm.target.create() is being deprecated. Please use tvm.target.Target() instead")
     return Target(target)
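
An illustrative aside, not part of this changeset: a minimal sketch of how the target helpers reformatted above are typically invoked. The model names come from the trans_table shown earlier; everything else is assumed from mainline TVM rather than taken from this diff.

    import tvm

    rasp = tvm.target.arm_cpu("rasp4b")         # bcm2711 / cortex-a72 options from the table above
    mali = tvm.target.bifrost()                 # "opencl -device=bifrost -model=unknown"
    hexa = tvm.target.hexagon("v66", hvx=128)   # llvm -mtriple=hexagon -mcpu=hexagonv66 -mattr=+hvxv66,...
    # tvm.target.create("llvm") still works but emits the deprecation warning above;
    # prefer the Target constructor directly:
    llvm = tvm.target.Target("llvm")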
 
 
index afdedfb..3cd1b01 100644 (file)
@@ -52,6 +52,7 @@ def script(pyfunc):
     # pylint: disable=import-outside-toplevel, missing-docstring
     def wrapped_func(func, *args, **kwargs):
         from .util import _is_tvm_arg_types
+
         if _is_tvm_arg_types(args):
             src = _pruned_source(func)
             closure_vars = inspect.getclosurevars(func).nonlocals
@@ -59,6 +60,7 @@ def script(pyfunc):
             return source_to_op(src, args, func.__globals__, closure_vars)
 
         from .runtime import _enter_hybrid_runtime, _restore_runtime
+
         intersect = _enter_hybrid_runtime(func)
         value = func(*args, **kwargs)
         _restore_runtime(func, intersect)
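
A hypothetical sketch, for orientation only, of the kind of function the script() wrapper above handles. The name outer_product and the shapes are illustrative; output_tensor and range here are hybrid-script intrinsics resolved by the parser, not ordinary Python names.

    import tvm
    from tvm import te

    @te.hybrid.script
    def outer_product(a, b):
        c = output_tensor((a.shape[0], b.shape[0]), "float32")
        for i in range(a.shape[0]):
            for j in range(b.shape[0]):
                c[i, j] = a[i] * b[j]
        return c

    A = te.placeholder((10,), name="A")
    B = te.placeholder((10,), name="B")
    C = outer_product(A, B)   # TVM arguments take the _pruned_source / source_to_op path above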
index 2c2f2bf..2e7fc2b 100644 (file)
@@ -27,14 +27,14 @@ from tvm.tir.stmt import For
 
 from .util import _internal_assert
 
-# pylint: disable=redefined-builtin
+# pylint: disable=redefined-builtin,invalid-name
 
 LOOP_INTRIN = {
-    'range'       : For.Serial,
-    'unroll'      : For.Unrolled,
-    'parallel'    : For.Parallel,
-    'vectorize'   : For.Vectorized,
-    'const_range' : (For.Unrolled, ),
+    "range": For.Serial,
+    "unroll": For.Unrolled,
+    "parallel": For.Parallel,
+    "vectorize": For.Vectorized,
+    "const_range": (For.Unrolled,),
 }
 
 
@@ -42,26 +42,25 @@ def _range(annotation, args):
     """Handling TVM loop types"""
     n = args.__len__()
     if n == 1:
-        low, ext = const(0, dtype='int32'), args[0]
+        low, ext = const(0, dtype="int32"), args[0]
     else:
         _internal_assert(n == 2, "A loop intrinsic should only have 1 or 2 arguments!")
         low, ext = args[0], args[1]
-    if not tvm.tir.analysis.expr_deep_equal(low, const(0, dtype='int32')):
+    if not tvm.tir.analysis.expr_deep_equal(low, const(0, dtype="int32")):
         ext = ext - low
     for_type = LOOP_INTRIN[annotation]
     iter_var = None
     return iter_var, low, ext, for_type
 
 
-range = unroll = vectorize = parallel = const_range = _range #pylint: disable=invalid-name
+range = unroll = vectorize = parallel = const_range = _range  # pylint: disable=invalid-name
 
 
 def bind(func_id, args):
     """Handling TVM thread binding"""
     _internal_assert(func_id == "bind", "This function cannot be directly invoked!")
     _internal_assert(args.__len__() == 2, "A loop bind should only have 2 arguments!")
-    _internal_assert(isinstance(args[0], str), \
-                     "A loop bind's first argument should be a string!")
+    _internal_assert(isinstance(args[0], str), "A loop bind's first argument should be a string!")
     low, ext = const(0, "int32"), args[1]
     iter_var = tvm.te.thread_axis((low, ext), args[0])
     for_type = None
@@ -71,9 +70,13 @@ def bind(func_id, args):
 def _math_intrin(func_id, args):
     # pylint: disable=import-outside-toplevel
     from tvm.tir import op
+
     return getattr(op, func_id)(*args)
 
-sqrt = log = exp = tanh = sigmoid = power = popcount = round = _math_intrin #pylint: disable=invalid-name
+
+sqrt = (
+    log
+) = exp = tanh = sigmoid = power = popcount = round = _math_intrin  # pylint: disable=invalid-name
 
 
 def _min_max(func_id, args):
@@ -81,37 +84,38 @@ def _min_max(func_id, args):
     return getattr(_expr, func_id.title())(args[0], args[1])
 
 
-min = max = _min_max #pylint: disable=invalid-name
+min = max = _min_max  # pylint: disable=invalid-name
 
 
 def _allocate_tensor(func_id, args):
     """Handling TVM tensor allocation.
     You may refer to hybrid.intrin.allocate for more details."""
     n = args.__len__()
-    _internal_assert(isinstance(convert(args[0]), Array), \
-                     "allocate's first argument should be a tuple of shape!")
+    _internal_assert(
+        isinstance(convert(args[0]), Array), "allocate's first argument should be a tuple of shape!"
+    )
     shape = args[0]
     for i in shape:
         _internal_assert(isinstance(i, _expr.PrimExpr), "The shape should be an expression")
     if n > 1:
-        _internal_assert(isinstance(args[1], str),
-                         "The data type should be an str")
-        _internal_assert(args[1].startswith('int') or args[1].startswith('float'), \
-                         "The data type should be either int or float!")
+        _internal_assert(isinstance(args[1], str), "The data type should be an str")
+        _internal_assert(
+            args[1].startswith("int") or args[1].startswith("float"),
+            "The data type should be either int or float!",
+        )
         dtype = args[1]
     else:
-        dtype = 'float32'
+        dtype = "float32"
     if n > 2:
-        _internal_assert(isinstance(args[2], str), \
-                         "The data scope should be an string")
-        _internal_assert(func_id != 'output_tensor', "Output tensor cannot specify scope")
+        _internal_assert(isinstance(args[2], str), "The data scope should be an string")
+        _internal_assert(func_id != "output_tensor", "Output tensor cannot specify scope")
         scope = args[2]
     else:
-        scope = 'global' if func_id != 'output_tensor' else 'output'
+        scope = "global" if func_id != "output_tensor" else "output"
     return (shape, dtype, scope)
 
 
-output_tensor = allocate = _allocate_tensor #pylint: disable=invalid-name
+output_tensor = allocate = _allocate_tensor  # pylint: disable=invalid-name
 
 
 def len(func_id, args):
@@ -120,19 +124,22 @@ def len(func_id, args):
     _internal_assert(func_id == "len", "This function cannot be directly invoked!")
     try:
         return convert(args[0].__len__())
-    except: #pylint: disable=bare-except
+    except:  # pylint: disable=bare-except
         _internal_assert(args[0].shape.__len__() == 1, "Only one-dimension array can get len")
         return convert(args[0].shape[0])
 
 
 def _cast(func_id, args):
-    _internal_assert(args.__len__() == 1 and isinstance(args[0], _expr.PrimExpr), \
-                     "Only one expression can be cast")
+    _internal_assert(
+        args.__len__() == 1 and isinstance(args[0], _expr.PrimExpr),
+        "Only one expression can be cast",
+    )
     return _expr.Cast(func_id, args[0])
 
-float16 = float32 = float64 = _cast #pylint: disable=invalid-name
-int8 = int16 = int32 = int64 = _cast #pylint: disable=invalid-name
-uint8 = uint16 = uint32 = uint64 = _cast #pylint: disable=invalid-name
+
+float16 = float32 = float64 = _cast  # pylint: disable=invalid-name
+int8 = int16 = int32 = int64 = _cast  # pylint: disable=invalid-name
+uint8 = uint16 = uint32 = uint64 = _cast  # pylint: disable=invalid-name
 
 
 def ceil_div(func_id, args):
@@ -145,13 +152,13 @@ def ceil_div(func_id, args):
 
 
 def likely(func_id, args):
-    _internal_assert(args.__len__() == 1, \
-                     "Only one expression can be likely")
+    _internal_assert(args.__len__() == 1, "Only one expression can be likely")
     _internal_assert(func_id == "likely", "This function cannot be directly invoked!")
-    return call_intrin(args[0].dtype, 'tir.likely', *args)
+    return call_intrin(args[0].dtype, "tir.likely", *args)
 
 
 def max_num_threads(func_id, args):
+    """Set the maximum number of threads."""
     _internal_assert(func_id == "max_num_threads", "This function cannot be directly invoked!")
     _internal_assert(args.__len__() <= 1, "At most one argument accepted!")
     if args.__len__() == 0:
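
A hypothetical hybrid-script body, included only to show where the intrinsics mapped in this file (bind, output_tensor, the math wrappers) surface; the thread tag, extents and names are illustrative.

    import tvm
    from tvm import te

    @te.hybrid.script
    def gpu_sqrt(a):
        b = output_tensor((64,), "float32")   # handled by _allocate_tensor with scope="output"
        for tx in bind("threadIdx.x", 64):    # handled by bind() above
            b[tx] = sqrt(a[tx])               # dispatched through _math_intrin
        return b

    A = te.placeholder((64,), name="A")
    B = gpu_sqrt(A)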
index 48b483e..672089c 100644 (file)
@@ -35,7 +35,6 @@ class HybridModule(object):
     lowered. This contradicts the fact that Hybrid Module is originally a text
     format for Phase 0 HalideIR. Thus, a completely separate module is defined."""
 
-
     def __init__(self, src=None, name=None):
         """The constructor of this a hybrid module
 
@@ -51,31 +50,27 @@ class HybridModule(object):
         if src is not None:
             temp = util.tempdir()
             dst = temp.relpath("script.py")
-            with open(dst, 'w') as f:
+            with open(dst, "w") as f:
                 f.write("import tvm\n@tvm.te.hybrid.script\n%s" % src)
 
             if name is not None:
                 self.name = name
             self.load(dst)
 
-
     def __call__(self, *args):
         if _is_tvm_arg_types(args):
             return source_to_op(self.root_, args, globals(), {})
         return self.func_(*args)
 
-
     def get_source(self):
         return self.src_
 
-
     def save(self, path):
-        if not path.endswith('.py'):
-            path = path + '.py'
-        with open(path, 'w') as f:
+        if not path.endswith(".py"):
+            path = path + ".py"
+        with open(path, "w") as f:
             f.write(self.src_)
 
-
     def load(self, path):
         """Load the module from a python file
 
@@ -84,19 +79,19 @@ class HybridModule(object):
         path : str
             Path to the given python file
         """
-        with open(path, 'r') as f:
+        with open(path, "r") as f:
             self.src_ = f.read()
 
         src = self.src_
 
         class FindFunc(ast.NodeVisitor):
             """ Find the function in module to be loaded module. """
-            #pylint: disable=invalid-name
+
+            # pylint: disable=invalid-name
             def __init__(self):
                 self.name = None
                 self.root = None
 
-
             def visit_FunctionDef(self, node):
                 _internal_assert(self.name is None, "For now, only one function supported!")
                 self.name = node.name
@@ -106,14 +101,13 @@ class HybridModule(object):
         root = ast.parse(src)
         finder = FindFunc()
         finder.visit(root)
-        _internal_assert(finder.name is not None and finder.root is not None, \
-                         "No function found!")
+        _internal_assert(finder.name is not None and finder.root is not None, "No function found!")
         if self.name is None:
             self.name = finder.name
         self.root_ = finder.root
 
         _, local_ = {}, {}
-        exec(self.src_, _, local_) #pylint: disable=exec-used
-        local_.pop('tvm')
+        exec(self.src_, _, local_)  # pylint: disable=exec-used
+        local_.pop("tvm")
         assert len(local_) == 1
         self.func_ = list(local_.values())[0]
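
A rough usage sketch of the HybridModule round-trip above; the import path and the file locations are assumed from mainline TVM and are not part of this diff.

    from tvm.te.hybrid.module import HybridModule   # path assumed, not shown in this changeset

    src = (
        "def add_one(a):\n"
        "    b = output_tensor((10,), 'float32')\n"
        "    for i in range(10):\n"
        "        b[i] = a[i] + 1.0\n"
        "    return b\n"
    )
    m = HybridModule(src, name="add_one")   # the constructor prepends the @tvm.te.hybrid.script header
    m.save("/tmp/add_one")                  # save() appends ".py" automatically
    m2 = HybridModule()
    m2.load("/tmp/add_one.py")              # load() re-parses the file and binds self.func_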
index b6f6e51..8704518 100644 (file)
@@ -64,6 +64,7 @@ def visit_list_to_block(visit, lst):
 
 class Symbol(Enum):
     """Enumerates types in the symbol table"""
+
     Callable = 0
     Input = 1
     OutputBuffer = 2
@@ -92,34 +93,27 @@ def _floormod(x, y):
 class HybridParser(ast.NodeVisitor):
     """Python AST visitor pass which finally lowers it to HalideIR"""
 
-
     _binop_maker = {
-        ast.Add     : operator.add,
-        ast.Sub     : operator.sub,
-        ast.Mult    : operator.mul,
-        ast.Div     : operator.div if sys.version_info[0] == 2 else operator.truediv,
+        ast.Add: operator.add,
+        ast.Sub: operator.sub,
+        ast.Mult: operator.mul,
+        ast.Div: operator.div if sys.version_info[0] == 2 else operator.truediv,
         ast.FloorDiv: _floordiv,
-        ast.Mod     : _floormod,
-        ast.BitOr   : operator.or_,
-        ast.BitAnd  : operator.and_,
-        ast.BitXor  : operator.xor,
-        ast.Gt      : operator.gt,
-        ast.GtE     : operator.ge,
-        ast.Lt      : operator.lt,
-        ast.LtE     : operator.le,
-        ast.Eq      : operator.eq,
-        ast.NotEq   : operator.ne,
-        ast.And     : _all,
-        ast.Or      : _any,
-    }
-
-
-    _unaryop_maker = {
-        ast.USub   : operator.neg,
-        ast.Invert : operator.invert,
-        ast.Not    : operator.not_
+        ast.Mod: _floormod,
+        ast.BitOr: operator.or_,
+        ast.BitAnd: operator.and_,
+        ast.BitXor: operator.xor,
+        ast.Gt: operator.gt,
+        ast.GtE: operator.ge,
+        ast.Lt: operator.lt,
+        ast.LtE: operator.le,
+        ast.Eq: operator.eq,
+        ast.NotEq: operator.ne,
+        ast.And: _all,
+        ast.Or: _any,
     }
 
+    _unaryop_maker = {ast.USub: operator.neg, ast.Invert: operator.invert, ast.Not: operator.not_}
 
     def __init__(self, args, usage, symbols, closure_vars, func_name=None):
         """
@@ -146,31 +140,31 @@ class HybridParser(ast.NodeVisitor):
         self.args = list(args)
         self.usage = usage.copy()
 
-        self.symbols = {} # Symbol table
+        self.symbols = {}  # Symbol table
         for k, v in symbols.items():
             if isinstance(v, types.FunctionType):
                 self.add_symbol(k, Symbol.Callable, v)
 
         self.closure_vars = closure_vars
 
-        self.binds = {} # Thread binds
-        self.device = 0 # Is it generating device
+        self.binds = {}  # Thread binds
+        self.device = 0  # Is it generating device
 
-        self.func_name = func_name # The name of the function to be lowered
-        self.outputs = [] # Output tensors' name
-        self.side_effect = set() # Tensors with side effects
-        self.parsed_body = None # The parsed HalideIR body
+        self.func_name = func_name  # The name of the function to be lowered
+        self.outputs = []  # Output tensors' name
+        self.side_effect = set()  # Tensors with side effects
+        self.parsed_body = None  # The parsed HalideIR body
         self.analyzer = tvm.arith.Analyzer()
-        self.returned = False # If this function has a valid return
-
+        self.returned = False  # If this function has a valid return
 
-    def add_symbol(self, key, ty, val): #pylint: disable=invalid-name
+    def add_symbol(self, key, ty, val):  # pylint: disable=invalid-name
         """Add value to the symbol table context"""
         if key in self.symbols.keys():
             old = str(self.symbols[key])
             new = str((ty, val))
-            _internal_assert(False,
-                             "Name conflict in symbol table! [%s] %s -> %s" % (key, old, new))
+            _internal_assert(
+                False, "Name conflict in symbol table! [%s] %s -> %s" % (key, old, new)
+            )
 
         self.symbols[key] = ty, val
 
@@ -179,11 +173,12 @@ class HybridParser(ast.NodeVisitor):
                 self.binds[val.var.name] = val
                 return
             val_ = self.binds[val.var.name]
-            _internal_assert(tvm.tir.analysis.expr_deep_equal(val_.dom.extent, val.dom.extent),
-                             "Thread extents should be uniform!")
+            _internal_assert(
+                tvm.tir.analysis.expr_deep_equal(val_.dom.extent, val.dom.extent),
+                "Thread extents should be uniform!",
+            )
             self.symbols[key] = ty, val_
 
-
     def wrap_up_realize(self, node, body):
         """Wrap up all the variables which will no longer be used"""
         to_pop = []
@@ -196,67 +191,65 @@ class HybridParser(ast.NodeVisitor):
                 continue
             _internal_assert(key in self.symbols.keys(), "Unknown symbol %s!" % key)
 
-            ty, entry = self.symbols[key] #pylint: disable=invalid-name
+            ty, entry = self.symbols[key]  # pylint: disable=invalid-name
             if ty in [Symbol.Input, Symbol.OutputBuffer]:
                 continue
-            if 'Buffer' in ty.name:
+            if "Buffer" in ty.name:
                 _buf = entry
-                _scope = 'global' if ty is Symbol.BufferVar else ty.name[:-6].lower()
+                _scope = "global" if ty is Symbol.BufferVar else ty.name[:-6].lower()
                 to_pop.append(key)
             else:
                 continue
 
-            if _scope == 'global':
+            if _scope == "global":
                 body = self.wrap_up_binds(body)
 
             _domain = [Range.from_min_extent(0, i) for i in _buf.shape]
             _dtype = _buf.dtype
             _true = tvm.runtime.convert(True)
             body = tvm.tir.ProducerRealize(_buf, _domain, _true, body)
-            body = tvm.tir.AttrStmt(_buf.op, 'realize_scope', tvm.runtime.convert(_scope), body)
+            body = tvm.tir.AttrStmt(_buf.op, "realize_scope", tvm.runtime.convert(_scope), body)
 
         for elem in to_pop:
             self.symbols.pop(elem)
 
         return body
 
-
     def wrap_up_binds(self, body):
         for _, iter_var in self.binds.items():
             ext = iter_var.dom.extent
-            body = tvm.tir.AttrStmt(iter_var, 'thread_extent', ext, body)
+            body = tvm.tir.AttrStmt(iter_var, "thread_extent", ext, body)
         self.binds = {}
         return body
 
-
-    #pylint: disable=invalid-name, missing-docstring
+    # pylint: disable=invalid-name, missing-docstring
     def visit_Module(self, node):
-        _internal_assert(len(node.body) == 1, \
-                         "Only one-function source code will be fed to this parser!")
+        _internal_assert(
+            len(node.body) == 1, "Only one-function source code will be fed to this parser!"
+        )
         return self.visit(node.body[0])
 
-
     def visit_FunctionDef(self, node):
-        _internal_assert(len(node.args.args) == len(self.args), \
-                         "The number of arguments passed to the \
-                         function should be the same as it is defined!")
+        _internal_assert(
+            len(node.args.args) == len(self.args),
+            "The number of arguments passed to the \
+                         function should be the same as it is defined!",
+        )
         if self.func_name is None:
             self.func_name = node.name
         for idx, arg in enumerate(node.args.args):
-            _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
+            _attr = "id" if sys.version_info[0] < 3 else "arg"  # To make py2 and 3 compatible
             self.add_symbol(getattr(arg, _attr), Symbol.Input, self.args[idx])
         res = visit_list_to_block(self.visit, node.body)
         res = self.wrap_up_realize(node, res)
         return self.wrap_up_binds(res)
 
-
     def visit_Expr(self, node):
         return self.visit(node.value)
 
-
     def visit_Name(self, node):
         name = node.id
-        if sys.version_info[0] == 2 and name in ['True', 'False']:
+        if sys.version_info[0] == 2 and name in ["True", "False"]:
             return tvm.runtime.convert(ast.literal_eval(name))
 
         if name in self.closure_vars:
@@ -272,28 +265,26 @@ class HybridParser(ast.NodeVisitor):
             return entry if isinstance(node.ctx, ast.Load) else None
         if ty is Symbol.BufferVar:
             if isinstance(node.ctx, ast.Load):
-                return tvm.tir.ProducerLoad(entry, [tvm.runtime.const(0, 'int32')])
-            return entry, [tvm.runtime.const(0, 'int32')]
+                return tvm.tir.ProducerLoad(entry, [tvm.runtime.const(0, "int32")])
+            return entry, [tvm.runtime.const(0, "int32")]
         # Do I need any assertion here?
         return entry
 
-
     def visit_Num(self, node):
         if isinstance(node.n, numbers.Integral):
             dtype = "int32"
         elif isinstance(node.n, float):
             dtype = "float32"
         else:
-            _internal_assert(isinstance(node.n, bool),
-                             "The data type should be one of (int, float, bool)")
+            _internal_assert(
+                isinstance(node.n, bool), "The data type should be one of (int, float, bool)"
+            )
             dtype = "bool"
         return tvm.runtime.const(node.n, dtype)
 
-
     def visit_NameConstant(self, node):
         return tvm.runtime.convert(node.value)
 
-
     def visit_AugAssign(self, node):
         buf = self.visit(node.target)
         rhs = self.visit(node.value)
@@ -301,7 +292,7 @@ class HybridParser(ast.NodeVisitor):
             _internal_assert(len(buf) == 2, "LHS is supposed to be (buf, args)!")
             buf, args = buf
         else:
-            args = [tvm.runtime.const(0, 'int32')]
+            args = [tvm.runtime.const(0, "int32")]
         _internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!")
 
         read = tvm.tir.ProducerLoad(buf, args)
@@ -309,16 +300,18 @@ class HybridParser(ast.NodeVisitor):
 
         return tvm.tir.ProducerStore(buf, value, args)
 
-
     def visit_Assign(self, node):
         rhs = self.visit(node.value)
         if isinstance(rhs, Operation):
             rmap = {}
-            _internal_assert(len(node.targets) == rhs.num_outputs, \
-                             "Unable to detuple the outs to targets")
+            _internal_assert(
+                len(node.targets) == rhs.num_outputs, "Unable to detuple the outs to targets"
+            )
             for i in range(rhs.num_outputs):
-                _internal_assert(isinstance(node.targets[i], ast.Name),
-                                 "You should bind a pure name to the tensors")
+                _internal_assert(
+                    isinstance(node.targets[i], ast.Name),
+                    "You should bind a pure name to the tensors",
+                )
                 self.add_symbol(node.targets[i].id, Symbol.GlobalBuffer, rhs.output(i))
                 rmap[rhs.outputs[i].op] = rhs.output(i)
             return util.replace_io(rhs.body, rmap)
@@ -328,32 +321,35 @@ class HybridParser(ast.NodeVisitor):
         if isinstance(rhs, _expr.PrimExpr):
             rhs = self.analyzer.simplify(rhs)
         if isinstance(lhs, ast.Name):
-            #TODO: support defined intermediate buffer later
+            # TODO: support defined intermediate buffer later
             lhs_ = lhs
             lhs = lhs.id
             if lhs in self.symbols.keys():
                 ty, _ = self.symbols[lhs]
-                _internal_assert(ty != Symbol.LoopVar, \
-                                 "Loop variable cannot be overwritten!")
+                _internal_assert(ty != Symbol.LoopVar, "Loop variable cannot be overwritten!")
             decl, _, rw = self.usage[lhs]
             if decl == lhs_:
-                _internal_assert(lhs not in self.symbols.keys(),
-                                 "This value should not be defined before this point!")
+                _internal_assert(
+                    lhs not in self.symbols.keys(),
+                    "This value should not be defined before this point!",
+                )
                 if isinstance(rhs, tuple):
                     shape, dtype, scope = rhs
                     ph = tvm.te.placeholder(shape, dtype=dtype, name=lhs)
                     self.add_symbol(lhs, getattr(Symbol, scope.title() + "Buffer"), ph)
-                    if scope == 'output':
+                    if scope == "output":
                         self.outputs.append(lhs)
                     return util.make_nop()
                 if isinstance(rhs, util.halide_imm_types) and ast.Store not in rw:
                     self.add_symbol(lhs, Symbol.ConstVar, rhs)
                 else:
-                    _internal_assert(self.device == 0,
-                                     "Single variable not supported in devices' side!\n" + \
-                                     "If you are using GPU, please allocate a 'local' spad " + \
-                                     "outside the bind body")
-                    ph = tvm.te.placeholder((1, ), dtype=rhs.dtype, name=lhs)
+                    _internal_assert(
+                        self.device == 0,
+                        "Single variable not supported in devices' side!\n"
+                        + "If you are using GPU, please allocate a 'local' spad "
+                        + "outside the bind body",
+                    )
+                    ph = tvm.te.placeholder((1,), dtype=rhs.dtype, name=lhs)
                     self.add_symbol(lhs, Symbol.BufferVar, ph)
             lhs = self.visit(lhs_)
             if lhs is not None:
@@ -362,18 +358,17 @@ class HybridParser(ast.NodeVisitor):
             return util.make_nop()
 
         lhs, args = self.visit(lhs)
-        _internal_assert(isinstance(lhs, Tensor), \
-                         "An array access's LHS is expected to be a expr.Call!")
+        _internal_assert(
+            isinstance(lhs, Tensor), "An array access's LHS is expected to be a expr.Call!"
+        )
         res = tvm.tir.ProducerStore(lhs, rhs, args)
         return res
 
-
     def visit_Index(self, node):
         if isinstance(node.value, ast.Tuple):
             return self.visit(node.value)
         return [self.visit(node.value)]
 
-
     def visit_Attribute(self, node):
         buf = self.visit(node.value)
         return getattr(buf, node.attr)
@@ -386,8 +381,9 @@ class HybridParser(ast.NodeVisitor):
                 if isinstance(i, numbers.Integral):
                     arr = arr[i]
                 else:
-                    _internal_assert(isinstance(i, (_expr.IntImm,)), \
-                                     "All indices are supposed to be constants")
+                    _internal_assert(
+                        isinstance(i, (_expr.IntImm,)), "All indices are supposed to be constants"
+                    )
                     arr = arr[i.value]
             return arr
         if isinstance(node.ctx, ast.Load):
@@ -407,7 +403,6 @@ class HybridParser(ast.NodeVisitor):
         self.annotation[option.id] = context.func.id
         return visit_list_to_block(self.visit, node.body)
 
-
     def visit_If(self, node):
         cond = self.analyzer.simplify(self.visit(node.test))
 
@@ -427,17 +422,14 @@ class HybridParser(ast.NodeVisitor):
             else_body = None
         return tvm.tir.IfThenElse(cond, if_body, else_body)
 
-
     def visit_IfExp(self, node):
         cond = self.visit(node.test)
         if_body = self.visit(node.body)
         else_body = self.visit(node.orelse)
         return tvm.tir.Select(cond, if_body, else_body)
 
-
     def visit_Compare(self, node):
-        _internal_assert(len(node.ops) == len(node.comparators),
-                         "#compare ops != #comparators")
+        _internal_assert(len(node.ops) == len(node.comparators), "#compare ops != #comparators")
         ops = [self.visit(node.left)]
         ops += [self.visit(i) for i in node.comparators]
         res = []
@@ -447,34 +439,29 @@ class HybridParser(ast.NodeVisitor):
             res.append(HybridParser._binop_maker[type(node.ops[i])](lhs, rhs))
         return _all(*res)
 
-
     def visit_BoolOp(self, node):
         n = len(node.values)
         if n == 1:
-            _internal_assert(isinstance(node.op, ast.Not), \
-                             "Unary is supposed to be not!")
+            _internal_assert(isinstance(node.op, ast.Not), "Unary is supposed to be not!")
             return operator.not_(self.visit(node.values[0]))
-        _internal_assert(isinstance(node.op, (ast.And, ast.Or)), \
-                         "Binary is supposed to be and/or!")
+        _internal_assert(isinstance(node.op, (ast.And, ast.Or)), "Binary is supposed to be and/or!")
         values = [self.visit(i) for i in node.values]
         return HybridParser._binop_maker[type(node.op)](*values)
 
-
     def visit_UnaryOp(self, node):
         operand = self.visit(node.operand)
         return HybridParser._unaryop_maker[type(node.op)](operand)
 
-
     def visit_BinOp(self, node):
         lhs = self.visit(node.left)
         rhs = self.visit(node.right)
         return HybridParser._binop_maker[type(node.op)](lhs, rhs)
 
-
     def visit_Call(self, node):
         # Yet, no function pointer supported
-        _internal_assert(isinstance(node.func, ast.Name), \
-                         "Only id-function function call is supported so far!")
+        _internal_assert(
+            isinstance(node.func, ast.Name), "Only id-function function call is supported so far!"
+        )
 
         func_id = node.func.id
         args = [self.visit(i) for i in node.args]
@@ -482,35 +469,37 @@ class HybridParser(ast.NodeVisitor):
         if hasattr(calls, func_id):
             return getattr(calls, func_id)(func_id, args)
         # Contexts'
-        _internal_assert(func_id in self.symbols.keys(), \
-                         "The function called (%s) is not in the context either!" % func_id)
+        _internal_assert(
+            func_id in self.symbols.keys(),
+            "The function called (%s) is not in the context either!" % func_id,
+        )
         ty, entry = self.symbols[func_id]
-        _internal_assert(ty is Symbol.Callable, \
-                         "Are you sure what you call is a function?!")
+        _internal_assert(ty is Symbol.Callable, "Are you sure what you call is a function?!")
         outs = entry(*args)
         op = outs.op if isinstance(outs, Tensor) else outs[0].op
         return op
 
-
     def visit_For(self, node):
         iter_var, low, ext, for_type = self.visit(node.iter)
-        _internal_assert(isinstance(node.target, ast.Name), \
-                         "The loop iterator should be a variable!")
+        _internal_assert(
+            isinstance(node.target, ast.Name), "The loop iterator should be a variable!"
+        )
 
         _name = node.target.id
 
         if isinstance(for_type, tuple):
             low = self.analyzer.simplify(low)
             ext = self.analyzer.simplify(ext)
-            _internal_assert(isinstance(low, _expr.ConstExpr) and
-                             isinstance(ext, _expr.ConstExpr), \
-                             "Const range should start from a const " + \
-                             "and iterate const times")
+            _internal_assert(
+                isinstance(low, _expr.ConstExpr) and isinstance(ext, _expr.ConstExpr),
+                "Const range should start from a const " + "and iterate const times",
+            )
 
             low, ext = low.value, ext.value
             if ext > 114514:
-                logging.log(logging.CRITICAL, \
-                            '[Warning] Are you sure to unroll a large loop in Python?')
+                logging.log(
+                    logging.CRITICAL, "[Warning] Are you sure to unroll a large loop in Python?"
+                )
 
             bodies = []
             for i in range(low, low + ext):
@@ -524,7 +513,7 @@ class HybridParser(ast.NodeVisitor):
         if iter_var is None:
             _internal_assert(for_type is not None, "The loop iterating function parse error!")
             offset = iter_var = tvm.te.var(_name)
-            if not tvm.tir.analysis.expr_deep_equal(low, tvm.runtime.const(0, 'int32')):
+            if not tvm.tir.analysis.expr_deep_equal(low, tvm.runtime.const(0, "int32")):
                 offset = iter_var + low
             self.add_symbol(_name, Symbol.LoopVar, offset)
             _body = visit_list_to_block(self.visit, node.body)
@@ -540,42 +529,44 @@ class HybridParser(ast.NodeVisitor):
         if for_type is None:
             res = _body
         else:
-            _internal_assert(not isinstance(for_type, tuple), \
-                            "Micro expansion should be handled before!")
-            res = tvm.tir.For(iter_var, tvm.runtime.const(0, 'int32'), ext, for_type, 0, _body)
+            _internal_assert(
+                not isinstance(for_type, tuple), "Micro expansion should be handled before!"
+            )
+            res = tvm.tir.For(iter_var, tvm.runtime.const(0, "int32"), ext, for_type, 0, _body)
 
         self.symbols.pop(_name)
         return res
 
-
     def visit_Return(self, node):
-        _internal_assert(all(ty != Symbol.LoopVar for ty, _ in self.symbols.values()), \
-                         "Return should not be in a loop body!")
+        _internal_assert(
+            all(ty != Symbol.LoopVar for ty, _ in self.symbols.values()),
+            "Return should not be in a loop body!",
+        )
         ids = []
         if isinstance(node.value, ast.Name):
             ids = [node.value.id]
         else:
-            _internal_assert(isinstance(node.value, ast.Tuple), \
-                             "You should return either a single tensor or a tuple")
-            _internal_assert(all(isinstance(i, ast.Name) for i in node.value.elts), \
-                             "What do you return?")
+            _internal_assert(
+                isinstance(node.value, ast.Tuple),
+                "You should return either a single tensor or a tuple",
+            )
+            _internal_assert(
+                all(isinstance(i, ast.Name) for i in node.value.elts), "What do you return?"
+            )
             ids = [i.id for i in node.value.elts]
         _internal_assert(len(set(ids)) == len(ids), "Duplicated tensors in the return tuples")
         if len(ids) < len(self.outputs):
-            logging.log(logging.CRITICAL, '[Warning] Not all the output buffers returned!')
+            logging.log(logging.CRITICAL, "[Warning] Not all the output buffers returned!")
         self.outputs = [self.symbols[i][1] for i in ids]
         self.returned = True
         return util.make_nop()
 
-
     def visit_Tuple(self, node):
         return tuple(self.visit(i) for i in node.elts)
 
-
     def visit_Str(self, node):
         return node.s
 
-
     def visit_Assert(self, node):
         test = self.visit(node.test)
         mesg = tvm.runtime.convert(self.visit(node.msg))
@@ -612,7 +603,7 @@ def parse_python(src, args, symbols, closure_vars):
     var_usage = determine_variable_usage(root, args, symbols, closure_vars)
     parser = HybridParser(args, var_usage, symbols, closure_vars)
     parser.parsed_body = parser.visit(root)
-    _internal_assert(parser.returned, 'No valid return found in the function body!')
+    _internal_assert(parser.returned, "No valid return found in the function body!")
     return parser
 
 
@@ -644,6 +635,7 @@ def source_to_op(src, args, symbols, closure_vars):
     parser = parse_python(src, args, symbols, closure_vars)
 
     input_tensors = []
+
     def get_input_tensors(arg):
         if isinstance(arg, Tensor):
             input_tensors.append(arg)
@@ -653,7 +645,8 @@ def source_to_op(src, args, symbols, closure_vars):
 
     for i in args:
         get_input_tensors(i)
-    op = tvm.te._ffi_api.HybridOp(parser.func_name, "HybridOp", None, input_tensors,
-                                  parser.outputs, parser.parsed_body)
+    op = tvm.te._ffi_api.HybridOp(
+        parser.func_name, "HybridOp", None, input_tensors, parser.outputs, parser.parsed_body
+    )
     res = [op.output(i) for i in range(len(parser.outputs))]
     return res[0] if len(res) == 1 else res
index 035e8a4..b046231 100644 (file)
@@ -24,8 +24,9 @@ from .util import _internal_assert
 
 class PyVariableUsage(ast.NodeVisitor):
     """The vistor class to determine the declaration, r/w status, and last use of each variable"""
-    #pylint: disable=invalid-name
-    #pylint: disable=missing-docstring
+
+    # pylint: disable=invalid-name
+    # pylint: disable=missing-docstring
     def __init__(self, args, symbols, closure_vars):
         self.status = {}
         self.scope_level = []
@@ -37,46 +38,46 @@ class PyVariableUsage(ast.NodeVisitor):
 
     def visit_FunctionDef(self, node):
         self.scope_level.append(node)
-        _internal_assert(len(node.args.args) == len(self.args), \
-                '#arguments passed should be the same as #arguments defined')
+        _internal_assert(
+            len(node.args.args) == len(self.args),
+            "#arguments passed should be the same as #arguments defined",
+        )
         for idx, arg in enumerate(node.args.args):
-            _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
+            _attr = "id" if sys.version_info[0] < 3 else "arg"  # To make py2 and 3 compatible
             self._args[getattr(arg, _attr)] = self.args[idx]
         for i in node.body:
             self.visit(i)
 
-
     def visit_For(self, node):
-        _internal_assert(isinstance(node.target, ast.Name), \
-                "For's iterator should be an id")
+        _internal_assert(isinstance(node.target, ast.Name), "For's iterator should be an id")
         self.visit(node.iter)
         self.scope_level.append(node)
         for i in node.body:
             self.visit(i)
         self.scope_level.pop()
 
-
     def visit_Call(self, node):
-        #No function pointer supported so far
+        # No function pointer supported so far
         _internal_assert(isinstance(node.func, ast.Name), "Function call should be an id")
         func_id = node.func.id
-        _internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + \
-                         ['range', 'max', 'min', 'len'] + \
-                         list(self.symbols.keys()), \
-                         "Function call id " + func_id + " not in intrinsics' list")
+        _internal_assert(
+            func_id
+            in list(HYBRID_GLOBALS.keys())
+            + ["range", "max", "min", "len"]
+            + list(self.symbols.keys()),
+            "Function call id " + func_id + " not in intrinsics' list",
+        )
         for elem in node.args:
             self.visit(elem)
 
-
     def visit_AugAssign(self, node):
         self.aug_assign_ = True
         self.generic_visit(node)
         self.aug_assign_ = False
 
-
     def visit_Name(self, node):
         # If it is True or False, we do not worry about it!
-        if sys.version_info[0] == 2 and node.id in ['True', 'False']:
+        if sys.version_info[0] == 2 and node.id in ["True", "False"]:
             return
         # If it is from the argument list or loop variable, we do not worry about it!
         if node.id in self._args.keys():
@@ -85,8 +86,10 @@ class PyVariableUsage(ast.NodeVisitor):
         if node.id in fors:
             return
         # The loop variable cannot be overwritten when iteration
-        _internal_assert(not isinstance(node.ctx, ast.Store) or node.id not in fors, \
-                         "Iter var cannot be overwritten")
+        _internal_assert(
+            not isinstance(node.ctx, ast.Store) or node.id not in fors,
+            "Iter var cannot be overwritten",
+        )
 
         if node.id not in self.status.keys():
             # It is a captured value in closure
@@ -97,16 +100,16 @@ class PyVariableUsage(ast.NodeVisitor):
                     raise ValueError("Only support capturing constant values in closure")
                 return
 
-            _internal_assert(isinstance(node.ctx, ast.Store), \
-                             'Undeclared variable %s' % node.id)
+            _internal_assert(isinstance(node.ctx, ast.Store), "Undeclared variable %s" % node.id)
             if self.aug_assign_:
                 raise ValueError('"First store" cannot be an AugAssign')
             self.status[node.id] = (node, self.scope_level[-1], set())
         else:
             decl, loop, usage = self.status[node.id]
             usage.add(type(node.ctx))
-            _internal_assert(loop in self.scope_level,
-                             "%s is used out of the scope it is defined!" % node.id)
+            _internal_assert(
+                loop in self.scope_level, "%s is used out of the scope it is defined!" % node.id
+            )
             self.status[node.id] = (decl, loop, usage)
 
 
index 0f5a2b4..7b90f87 100644 (file)
@@ -33,7 +33,7 @@ class bind(object):  # pylint: disable=invalid-name
             i += 1
 
 
-def allocate(shape, dtype='float32', scope='global'):  # pylint: disable=unused-argument
+def allocate(shape, dtype="float32", scope="global"):  # pylint: disable=unused-argument
     """Allocate a buffer with given shape
 
     Parameters
@@ -112,36 +112,36 @@ def max_num_threads(allow_none=True):
 
 
 HYBRID_GLOBALS = {
-    'unroll': range,
-    'vectorize': range,
-    'parallel': range,
-    'const_range': range,
-    'bind': bind,
-    'allocate': allocate,
-    'output_tensor': allocate,
-    'sqrt': numpy.sqrt,
-    'rsqrt': rsqrt,
-    'log': numpy.log,
-    'tanh': numpy.tanh,
-    'power': numpy.power,
-    'exp': numpy.exp,
-    'sigmoid': sigmoid,
-    'popcount': popcount,
-    'round': round,
-    'likely': lambda cond: cond,
-    'uint8': numpy.uint8,
-    'uint16': numpy.uint16,
-    'uint32': numpy.uint32,
-    'uint64': numpy.uint64,
-    'int8': numpy.int8,
-    'int16': numpy.int16,
-    'int32': numpy.int32,
-    'int64': numpy.int64,
-    'float16': numpy.float16,
-    'float32': numpy.float32,
-    'float64': numpy.float64,
-    'ceil_div': lambda a, b: (a + b - 1) // b,
-    'max_num_threads': max_num_threads
+    "unroll": range,
+    "vectorize": range,
+    "parallel": range,
+    "const_range": range,
+    "bind": bind,
+    "allocate": allocate,
+    "output_tensor": allocate,
+    "sqrt": numpy.sqrt,
+    "rsqrt": rsqrt,
+    "log": numpy.log,
+    "tanh": numpy.tanh,
+    "power": numpy.power,
+    "exp": numpy.exp,
+    "sigmoid": sigmoid,
+    "popcount": popcount,
+    "round": round,
+    "likely": lambda cond: cond,
+    "uint8": numpy.uint8,
+    "uint16": numpy.uint16,
+    "uint32": numpy.uint32,
+    "uint64": numpy.uint64,
+    "int8": numpy.int8,
+    "int16": numpy.int16,
+    "int32": numpy.int32,
+    "int64": numpy.int64,
+    "float16": numpy.float16,
+    "float32": numpy.float32,
+    "float64": numpy.float64,
+    "ceil_div": lambda a, b: (a + b - 1) // b,
+    "max_num_threads": max_num_threads,
 }
 
 
index 213a48e..4560518 100644 (file)
@@ -31,7 +31,7 @@ from tvm.tir import stmt as _stmt
 from tvm.te.tensor import Tensor
 
 
-#pylint: disable=invalid-name
+# pylint: disable=invalid-name
 np_arg_types = tuple(list(numeric_types) + [numpy.ndarray])
 tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr)
 halide_imm_types = (_expr.IntImm, _expr.FloatImm)
@@ -46,7 +46,7 @@ def _internal_assert(cond, err):
 # Useful constants. In avoid of runtime dependences, we use function calls to return them.
 def make_nop():
     """Returns a 'no operation' node in HalideIR."""
-    return _stmt.Evaluate(tvm.runtime.const(0, dtype='int32'))
+    return _stmt.Evaluate(tvm.runtime.const(0, dtype="int32"))
 
 
 def is_docstring(node):
@@ -57,15 +57,16 @@ def is_docstring(node):
 def _pruned_source(func):
     """Prune source code's extra leading spaces"""
     try:
-        lines = inspect.getsource(func).split('\n')
-        leading_space = len(lines[0]) - len(lines[0].lstrip(' '))
+        lines = inspect.getsource(func).split("\n")
+        leading_space = len(lines[0]) - len(lines[0].lstrip(" "))
         lines = [line[leading_space:] for line in lines]
-        return '\n'.join(lines)
+        return "\n".join(lines)
     except IOError as err:
-        if sys.version_info[0] == 2 and str(err) == 'could not get source code':
-            logging.log(logging.CRITICAL, \
-                        'This module is not fully operated under Python2... ' \
-                        'Please move to Python3!')
+        if sys.version_info[0] == 2 and str(err) == "could not get source code":
+            logging.log(
+                logging.CRITICAL,
+                "This module is not fully operated under Python2... " "Please move to Python3!",
+            )
             raise err
 
 
@@ -78,12 +79,12 @@ def replace_io(body, rmap):
         if isinstance(op, _stmt.ProducerStore) and op.producer.op in rmap.keys():
             buf = rmap[op.producer.op]
             return _stmt.ProducerStore(buf, op.value, op.indices)
-        if isinstance(op, _expr.ProducerLoad) and  op.producer.op in rmap.keys():
+        if isinstance(op, _expr.ProducerLoad) and op.producer.op in rmap.keys():
             buf = rmap[op.producer.op]
             return _expr.ProducerLoad(buf, op.indices)
         return None
 
-    return stmt_functor.ir_transform(body, None, replace, ['tir.ProducerStore', 'tir.ProducerLoad'])
+    return stmt_functor.ir_transform(body, None, replace, ["tir.ProducerStore", "tir.ProducerLoad"])
 
 
 def _is_tvm_arg_types(args):
@@ -91,14 +92,17 @@ def _is_tvm_arg_types(args):
     If neither is true, raise a value error."""
     if isinstance(args[0], tvm_arg_types):
         for elem in args[1:]:
-            _internal_assert(isinstance(elem, tvm_arg_types),
-                             "Expecting a Var, Tensor or ConstExpr instance but %s get!" \
-                             % str(type(elem)))
+            _internal_assert(
+                isinstance(elem, tvm_arg_types),
+                "Expecting a Var, Tensor or ConstExpr instance but %s get!" % str(type(elem)),
+            )
         return True
 
-    _internal_assert(isinstance(args[0], np_arg_types), \
-                     "Expect a numpy type but %s get!" % str(type(args[0])))
+    _internal_assert(
+        isinstance(args[0], np_arg_types), "Expect a numpy type but %s get!" % str(type(args[0]))
+    )
     for elem in args[1:]:
-        _internal_assert(isinstance(elem, np_arg_types), \
-                         "Expect a numpy type but %s get!" % str(type(elem)))
+        _internal_assert(
+            isinstance(elem, np_arg_types), "Expect a numpy type but %s get!" % str(type(elem))
+        )
     return False
index 168265f..30d0df3 100644 (file)
@@ -51,8 +51,7 @@ def placeholder(shape, dtype=None, name="placeholder"):
     """
     shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr) else shape
     dtype = "float32" if dtype is None else dtype
-    return _ffi_api.Placeholder(
-        shape, dtype, name)
+    return _ffi_api.Placeholder(shape, dtype, name)
 
 
 def compute(shape, fcompute, name="compute", tag="", attrs=None):
@@ -96,7 +95,7 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
     if code.co_argcount == 0:
         arg_names = ["i%d" % i for i in range(ndim)]
     else:
-        arg_names = code.co_varnames[:code.co_argcount]
+        arg_names = code.co_varnames[: code.co_argcount]
         out_ndim = code.co_argcount
 
     if out_ndim != len(arg_names):
@@ -109,21 +108,22 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
         for i, s in enumerate(shape[out_ndim:]):
             var_name = "ax" + str(i)
             dim_var.append(tvm.tir.IterVar((0, s), var_name, 4))
-        op_node = _ffi_api.TensorComputeOp(name,
-                                           tag,
-                                           dim_var,
-                                           body.reduce_axis,
-                                           out_ndim,
-                                           body.intrin,
-                                           body.tensors,
-                                           body.regions,
-                                           body.scalar_inputs)
+        op_node = _ffi_api.TensorComputeOp(
+            name,
+            tag,
+            dim_var,
+            body.reduce_axis,
+            out_ndim,
+            body.intrin,
+            body.tensors,
+            body.regions,
+            body.scalar_inputs,
+        )
     else:
         if not isinstance(body, (list, tuple)):
             body = [body]
         body = convert(body)
-        op_node = _ffi_api.ComputeOp(
-            name, tag, attrs, dim_var, body)
+        op_node = _ffi_api.ComputeOp(name, tag, attrs, dim_var, body)
 
     num = op_node.num_outputs
     outputs = tuple(op_node.output(i) for i in range(num))
@@ -192,22 +192,22 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attr
     if len(init) != len(update) or len(init) != len(state_placeholder):
         raise ValueError("init, update, state_placeholder must have same length")
     axis = tvm.tir.IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
-    op = _ffi_api.ScanOp(name, tag, attrs,
-                         axis, init, update,
-                         state_placeholder, inputs)
+    op = _ffi_api.ScanOp(name, tag, attrs, axis, init, update, state_placeholder, inputs)
     res = [op.output(i) for i in range(len(update))]
     return res[0] if len(res) == 1 else res
 
 
-def extern(shape,
-           inputs,
-           fcompute,
-           name="extern",
-           dtype=None,
-           in_buffers=None,
-           out_buffers=None,
-           tag="",
-           attrs=None):
+def extern(
+    shape,
+    inputs,
+    fcompute,
+    name="extern",
+    dtype=None,
+    in_buffers=None,
+    out_buffers=None,
+    tag="",
+    attrs=None,
+):
     """Compute several tensors via an extern function.
 
     Parameters
@@ -281,13 +281,17 @@ def extern(shape,
     if in_buffers is not None:
         in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers
         if len(inputs) != len(in_buffers):
-            raise RuntimeError("Number of inputs and in_buffers mismatch: %d vs %d."
-                               % (len(inputs), len(in_buffers)))
+            raise RuntimeError(
+                "Number of inputs and in_buffers mismatch: %d vs %d."
+                % (len(inputs), len(in_buffers))
+            )
     if out_buffers is not None:
         out_buffers = [out_buffers] if not isinstance(out_buffers, list) else out_buffers
         if len(shape) != len(out_buffers):
-            raise RuntimeError("Number of outputs and out_buffers mismatch: %d vs %d."
-                               % (len(shape), len(out_buffers)))
+            raise RuntimeError(
+                "Number of outputs and out_buffers mismatch: %d vs %d."
+                % (len(shape), len(out_buffers))
+            )
     input_placeholders = in_buffers or []
     output_placeholders = out_buffers or []
     types = set()
@@ -295,8 +299,7 @@ def extern(shape,
         if not isinstance(t, _tensor.Tensor):
             raise ValueError("expect inputs to be tensor")
         if in_buffers is None:
-            input_placeholders.append(
-                tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name))
+            input_placeholders.append(tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name))
         types.add(t.dtype)
 
     if dtype is None:
@@ -316,9 +319,7 @@ def extern(shape,
     if not isinstance(body, tvm.tir.Stmt):
         raise ValueError("Function '{}' should return PrimExpr or Stmt".format(fcompute.__name__))
 
-    op = _ffi_api.ExternOp(name, tag, attrs,
-                           inputs, input_placeholders,
-                           output_placeholders, body)
+    op = _ffi_api.ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body)
     res = [op.output(i) for i in range(len(output_placeholders))]
     return res[0] if len(res) == 1 else res
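
For reference (hypothetical usage, not from this changeset): the operations reformatted above are normally driven as below; the shapes and the extern packed-function name are made up.

    import tvm
    from tvm import te

    n = te.var("n")
    A = te.placeholder((n, 16), dtype="float32", name="A")           # placeholder() above
    B = te.compute((n, 16), lambda i, j: A[i, j] * 2.0, name="B")    # compute() builds a ComputeOp
    C = te.extern(
        (n, 16),
        [B],
        lambda ins, outs: tvm.tir.call_packed("my_extern", ins[0], outs[0]),  # "my_extern" is fictitious
        name="C",
        dtype="float32",
    )  # extern() wraps the packed call with the declared input/output buffers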
 
index b611954..7bd7dce 100644 (file)
@@ -63,6 +63,7 @@ def create_schedule(ops):
 @tvm._ffi.register_object
 class Schedule(Object):
     """Schedule for all the stages."""
+
     def __getitem__(self, k):
         if isinstance(k, _tensor.Tensor):
             k = k.op
@@ -112,8 +113,7 @@ class Schedule(Object):
             outputs = [outputs]
         if isinstance(inputs, _tensor.Tensor):
             inputs = [inputs]
-        return _ffi_api.ScheduleCreateGroup(
-            self, outputs, inputs, include_inputs)
+        return _ffi_api.ScheduleCreateGroup(self, outputs, inputs, include_inputs)
 
     def cache_read(self, tensor, scope, readers):
         """Create a cache read of original tensor for readers.
@@ -170,7 +170,7 @@ class Schedule(Object):
         return _ffi_api.ScheduleCacheWrite(self, tensor, scope)
 
     def rfactor(self, tensor, axis, factor_axis=0):
-        """ Factor a reduction axis in tensor's schedule to be an explicit axis.
+        """Factor a reduction axis in tensor's schedule to be an explicit axis.
 
         This will create a new stage that generates the new tensor with axis
         as the first dimension. The tensor's body will be rewritten as a reduction
@@ -197,6 +197,7 @@ class Schedule(Object):
 @tvm._ffi.register_object
 class Stage(Object):
     """A Stage represents schedule for one operation."""
+
     def split(self, parent, factor=None, nparts=None):
         """Split the stage either by factor providing outer scope, or both
 
@@ -340,7 +341,7 @@ class Stage(Object):
         _ffi_api.StageReorder(self, args)
 
     def tile(self, x_parent, y_parent, x_factor, y_factor):
-        """ Perform tiling on two dimensions
+        """Perform tiling on two dimensions
 
         The final loop order from outermost to innermost is
         [x_outer, y_outer, x_inner, y_inner]
@@ -368,7 +369,8 @@ class Stage(Object):
             Inner axis of y dimension
         """
         x_outer, y_outer, x_inner, y_inner = _ffi_api.StageTile(
-            self, x_parent, y_parent, x_factor, y_factor)
+            self, x_parent, y_parent, x_factor, y_factor
+        )
         return x_outer, y_outer, x_inner, y_inner
 
     def vectorize(self, var):
@@ -513,6 +515,7 @@ class Stage(Object):
 @tvm._ffi.register_object
 class SpecializedCondition(Object):
     """Specialized condition to enable op specialization."""
+
     def __init__(self, conditions):
         """Create a specialized condition.
 
@@ -529,8 +532,7 @@ class SpecializedCondition(Object):
         """
         if not isinstance(conditions, (list, _container.Array)):
             conditions = [conditions]
-        self.__init_handle_by_constructor__(
-            _ffi_api.CreateSpecializedCondition, conditions)
+        self.__init_handle_by_constructor__(_ffi_api.CreateSpecializedCondition, conditions)
 
     @staticmethod
     def current():
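As a usage sketch for the Schedule/Stage methods reformatted above (tile in particular; assuming a standard TVM install):

    import tvm
    from tvm import te

    n = te.var("n")
    A = te.placeholder((n, n), name="A")
    B = te.compute((n, n), lambda i, j: A[i, j] * 2.0, name="B")
    s = te.create_schedule(B.op)
    # tile returns the new loop axes in [x_outer, y_outer, x_inner, y_inner] order
    xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=32, y_factor=32)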
index 487e696..42d2134 100644 (file)
 import warnings
 from tvm._ffi.base import decorate
 
+
 class TagScope(object):
     """Tag scope object to set tag for operators, working as context
     manager and decorator both. See also tag_scope.
     """
+
     _current = None
 
     @classmethod
@@ -52,6 +54,7 @@ class TagScope(object):
         def tagged_fdecl(func, *args, **kwargs):
             with self:
                 return func(*args, **kwargs)
+
         return decorate(fdecl, tagged_fdecl)
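A short sketch of how TagScope is normally reached through the public te.tag_scope wrapper (context-manager form; assuming a standard TVM install):

    from tvm import te

    n = te.var("n")
    A = te.placeholder((n,), name="A")
    # every operator created inside the scope receives the tag
    with te.tag_scope(tag="elemwise"):
        B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    assert B.op.tag == "elemwise"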
 
 
index 7d73bf4..6294eab 100644 (file)
@@ -23,6 +23,7 @@ from tvm.tir import expr as _expr, DataProducer
 
 from . import _ffi_api
 
+
 class TensorSlice(ObjectGeneric, _expr.ExprOp):
     """Auxiliary data structure for enable slicing syntax from tensor."""
 
@@ -46,6 +47,7 @@ class TensorSlice(ObjectGeneric, _expr.ExprOp):
         """Data content of the tensor."""
         return self.tensor.dtype
 
+
 @tvm._ffi.register_object
 class TensorIntrinCall(Object):
     """Intermediate structure for calling a tensor intrinsic."""
@@ -71,7 +73,6 @@ class Tensor(DataProducer, _expr.ExprOp):
 
         return _expr.ProducerLoad(self, args)
 
-
     def __getitem__(self, indices):
         return TensorSlice(self, indices)
 
@@ -84,9 +85,11 @@ class Tensor(DataProducer, _expr.ExprOp):
                 return _expr.EqualOp(self, other)
             return False
         if self.ndim == 0 and other.ndim == 0:
-            raise ValueError("Equal == comparison among rank-0 tensor is ambiguous, "
-                             "use Tensor.equal for content expression equvalence, "
-                             "use Tensor.same_as for exact reference comparison")
+            raise ValueError(
+                "Equal == comparison among rank-0 tensor is ambiguous, "
+                "use Tensor.equal for content expression equvalence, "
+                "use Tensor.same_as for exact reference comparison"
+            )
         return _ffi_api.TensorEqual(self, other)
 
     @property
@@ -159,6 +162,7 @@ class PlaceholderOp(Operation):
 @tvm._ffi.register_object
 class BaseComputeOp(Operation):
     """Compute operation."""
+
     @property
     def axis(self):
         """Represent the IterVar axis, defined when it is a ComputeOp"""
@@ -183,6 +187,7 @@ class TensorComputeOp(BaseComputeOp):
 @tvm._ffi.register_object
 class ScanOp(Operation):
     """Scan operation."""
+
     @property
     def scan_axis(self):
         """Represent the scan axis, only defined when it is a ScanOp"""
@@ -197,6 +202,7 @@ class ExternOp(Operation):
 @tvm._ffi.register_object
 class HybridOp(Operation):
     """Hybrid operation."""
+
     @property
     def axis(self):
         """Represent the IterVar axis, also defined when it is a HybridOp"""
index 7d396ee..79b2db5 100644 (file)
@@ -49,6 +49,7 @@ class TensorIntrin(Object):
     --------
     decl_tensor_intrin: Construct a TensorIntrin
     """
+
     def __call__(self, *args, **kwargs):
         tensors = [x.tensor for x in args if isinstance(x, _tensor.TensorSlice)]
         scalar_inputs = [x for x in args if not isinstance(x, _tensor.TensorSlice)]
@@ -64,12 +65,9 @@ class TensorIntrin(Object):
         return _ffi_api.TensorIntrinCall(self, tensors, regions, reduce_axis, scalar_inputs)
 
 
-def decl_tensor_intrin(op,
-                       fcompute,
-                       name="tensor_intrin",
-                       binds=None,
-                       scalar_params=None,
-                       default_buffer_params=None):
+def decl_tensor_intrin(
+    op, fcompute, name="tensor_intrin", binds=None, scalar_params=None, default_buffer_params=None
+):
     """Declare a tensor intrinsic function.
 
     Parameters
@@ -128,20 +126,21 @@ def decl_tensor_intrin(op,
 
     default_buffer_params = {} if default_buffer_params is None else default_buffer_params
     for t in tensors:
-        buf = (binds[t] if t in binds else
-               tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name,
-                                   **default_buffer_params))
+        buf = (
+            binds[t]
+            if t in binds
+            else tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name, **default_buffer_params)
+        )
         binds_list.append(buf)
 
     if scalar_params:
-        body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):], scalar_params)
+        body = fcompute(binds_list[: len(inputs)], binds_list[len(inputs) :], scalar_params)
     else:
-        body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
+        body = fcompute(binds_list[: len(inputs)], binds_list[len(inputs) :])
         scalar_params = []
     if isinstance(body, (tvm.tir.PrimExpr, tvm.tir.Stmt)):
         body = [body]
     body = [tvm.tir.Evaluate(x) if isinstance(x, tvm.tir.PrimExpr) else x for x in body]
     if len(body) < 3:
         body += [None] * (3 - len(body))
-    return _ffi_api.TensorIntrin(
-        name, op, inputs, binds_list, scalar_params, *body)
+    return _ffi_api.TensorIntrin(name, op, inputs, binds_list, scalar_params, *body)
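A compact sketch of decl_tensor_intrin for a hypothetical external vector-add routine named "vadd16" (the symbol is a placeholder, not a real library function):

    import tvm
    from tvm import te

    n = 16
    x = te.placeholder((n,), name="x")
    y = te.placeholder((n,), name="y")
    z = te.compute((n,), lambda i: x[i] + y[i], name="z")

    def intrin_func(ins, outs):
        # emit a call to the external routine on the raw buffer pointers
        ib = tvm.tir.ir_builder.create()
        ib.emit(
            tvm.tir.call_extern(
                "int32", "vadd16",
                ins[0].access_ptr("r"), ins[1].access_ptr("r"), outs[0].access_ptr("w"),
            )
        )
        return ib.get()

    vadd = te.decl_tensor_intrin(z.op, intrin_func, name="vadd16")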
index 270f37f..20e968b 100644 (file)
@@ -67,7 +67,7 @@ from tvm.contrib import nvcc
 
 
 def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
-    """ Version of np.testing.assert_allclose with `atol` and `rtol` fields set
+    """Version of np.testing.assert_allclose with `atol` and `rtol` fields set
     to reasonable defaults.
 
     Arguments `actual` and `desired` are not interchangeable, since the function
@@ -77,8 +77,9 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
     np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, verbose=True)
 
 
-def check_numerical_grads(function, input_values, grad_values, function_value=None,
-                          delta=1e-3, atol=1e-2, rtol=0.1):
+def check_numerical_grads(
+    function, input_values, grad_values, function_value=None, delta=1e-3, atol=1e-2, rtol=0.1
+):
     """A helper function that checks that numerical gradients of a function are
     equal to gradients computed in some different way (analytical gradients).
 
@@ -123,6 +124,7 @@ def check_numerical_grads(function, input_values, grad_values, function_value=No
 
         def _function(_input_len=input_len, _orig_function=function, **kwargs):
             return _orig_function(*(kwargs[str(i)] for i in range(input_len)))
+
         function = _function
 
         grad_values = {str(idx): val for idx, val in enumerate(grad_values)}
@@ -138,19 +140,22 @@ def check_numerical_grads(function, input_values, grad_values, function_value=No
 
     # numerically compute a partial derivative with respect to j-th element of the var `name`
     def derivative(x_name, j, a_delta):
-        modified_values = {n: modify(val, j, a_delta) if n == x_name else val
-                           for n, val in input_values.items()}
-        return (function(**modified_values) - function_value)/a_delta
+        modified_values = {
+            n: modify(val, j, a_delta) if n == x_name else val for n, val in input_values.items()
+        }
+        return (function(**modified_values) - function_value) / a_delta
 
     def compare_derivative(j, n_der, grad):
         der = grad.reshape(-1)[j]
-        return np.abs(n_der - der) < atol + rtol*np.abs(n_der)
+        return np.abs(n_der - der) < atol + rtol * np.abs(n_der)
 
     for x_name, grad in grad_values.items():
         if grad.shape != input_values[x_name].shape:
             raise AssertionError(
-                "Gradient wrt '{}' has unexpected shape {}, expected {} "
-                .format(x_name, grad.shape, input_values[x_name].shape))
+                "Gradient wrt '{}' has unexpected shape {}, expected {} ".format(
+                    x_name, grad.shape, input_values[x_name].shape
+                )
+            )
 
         ngrad = np.zeros_like(grad)
 
@@ -165,13 +170,15 @@ def check_numerical_grads(function, input_values, grad_values, function_value=No
             # precise and expensive methods
             if not compare_derivative(j, nder, grad):
                 # central difference approximation
-                nder = (derivative(x_name, j, -delta) + nder)/2
+                nder = (derivative(x_name, j, -delta) + nder) / 2
 
                 if not compare_derivative(j, nder, grad):
                     # central difference approximation using h = delta/2
-                    cnder2 = (derivative(x_name, j, delta/2) + derivative(x_name, j, -delta/2))/2
+                    cnder2 = (
+                        derivative(x_name, j, delta / 2) + derivative(x_name, j, -delta / 2)
+                    ) / 2
                     # five-point derivative
-                    nder = (4*cnder2 - nder)/3
+                    nder = (4 * cnder2 - nder) / 3
 
             # if the derivatives still don't match, add this position to the
             # list of wrong positions
@@ -180,35 +187,51 @@ def check_numerical_grads(function, input_values, grad_values, function_value=No
 
             ngrad.reshape(-1)[j] = nder
 
-        wrong_percentage = int(100*len(wrong_positions)/np.prod(grad.shape))
+        wrong_percentage = int(100 * len(wrong_positions) / np.prod(grad.shape))
 
-        dist = np.sqrt(np.sum((ngrad - grad)**2))
-        grad_norm = np.sqrt(np.sum(ngrad**2))
+        dist = np.sqrt(np.sum((ngrad - grad) ** 2))
+        grad_norm = np.sqrt(np.sum(ngrad ** 2))
 
         if not (np.isfinite(dist) and np.isfinite(grad_norm)):
             raise ValueError(
                 "NaN or infinity detected during numerical gradient checking wrt '{}'\n"
-                "analytical grad = {}\n numerical grad = {}\n"
-                .format(x_name, grad, ngrad))
+                "analytical grad = {}\n numerical grad = {}\n".format(x_name, grad, ngrad)
+            )
 
         # we multiply atol by this number to make it more universal for different sizes
         sqrt_n = np.sqrt(float(np.prod(grad.shape)))
 
-        if dist > atol*sqrt_n + rtol*grad_norm:
+        if dist > atol * sqrt_n + rtol * grad_norm:
             raise AssertionError(
                 "Analytical and numerical grads wrt '{}' differ too much\n"
                 "analytical grad = {}\n numerical grad = {}\n"
                 "{}% of elements differ, first 10 of wrong positions: {}\n"
                 "distance > atol*sqrt(n) + rtol*grad_norm\n"
-                "distance {} > {}*{} + {}*{}"
-                .format(x_name, grad, ngrad, wrong_percentage, wrong_positions[:10],
-                        dist, atol, sqrt_n, rtol, grad_norm))
+                "distance {} > {}*{} + {}*{}".format(
+                    x_name,
+                    grad,
+                    ngrad,
+                    wrong_percentage,
+                    wrong_positions[:10],
+                    dist,
+                    atol,
+                    sqrt_n,
+                    rtol,
+                    grad_norm,
+                )
+            )
 
         max_diff = np.max(np.abs(ngrad - grad))
         avg_diff = np.mean(np.abs(ngrad - grad))
-        logging.info("Numerical grad test wrt '%s' of shape %s passes, "
-                     "dist = %f, max_diff = %f, avg_diff = %f",
-                     x_name, grad.shape, dist, max_diff, avg_diff)
+        logging.info(
+            "Numerical grad test wrt '%s' of shape %s passes, "
+            "dist = %f, max_diff = %f, avg_diff = %f",
+            x_name,
+            grad.shape,
+            dist,
+            max_diff,
+            avg_diff,
+        )
 
 
 def assert_prim_expr_equal(lhs, rhs):
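A minimal sketch of check_numerical_grads, verifying the analytical gradient of f(x) = sum(x**2):

    import numpy as np
    import tvm.testing

    x = np.random.uniform(size=(3,))
    tvm.testing.check_numerical_grads(
        lambda x: np.sum(x ** 2),  # scalar-valued function under test
        {"x": x},                  # input values, keyed by argument name
        {"x": 2 * x},              # analytical gradients to verify numerically
    )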
@@ -230,7 +253,7 @@ def assert_prim_expr_equal(lhs, rhs):
 
 
 def check_bool_expr_is_true(bool_expr, vranges, cond=None):
-    """ Check that bool_expr holds given the condition cond
+    """Check that bool_expr holds given the condition cond
     for every value of free variables from vranges.
 
     for example, 2x > 4y solves to x > 2y given x in (0, 10) and y in (0, 10)
@@ -253,9 +276,10 @@ def check_bool_expr_is_true(bool_expr, vranges, cond=None):
         bool_expr = tvm.te.any(tvm.tir.Not(cond), bool_expr)
 
     def _run_expr(expr, vranges):
-        """ Evaluate expr for every value of free variables
+        """Evaluate expr for every value of free variables
         given by vranges and return the tensor of results.
         """
+
         def _compute_body(*us):
             vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)}
             return tvm.tir.stmt_functor.substitute(expr, vmap)
@@ -274,13 +298,14 @@ def check_bool_expr_is_true(bool_expr, vranges, cond=None):
         counterex = sorted(counterex, key=lambda x: x[0])
         counterex = ", ".join([v + " = " + str(i) for v, i in counterex])
         ana = tvm.arith.Analyzer()
-        raise AssertionError("Expression {}\nis not true on {}\n"
-                             "Counterexample: {}"
-                             .format(ana.simplify(bool_expr), vranges, counterex))
+        raise AssertionError(
+            "Expression {}\nis not true on {}\n"
+            "Counterexample: {}".format(ana.simplify(bool_expr), vranges, counterex)
+        )
 
 
 def check_int_constraints_trans_consistency(constraints_trans, vranges=None):
-    """ Check IntConstraintsTransform is a bijective transformation.
+    """Check IntConstraintsTransform is a bijective transformation.
 
     Parameters
     ----------
@@ -298,7 +323,7 @@ def check_int_constraints_trans_consistency(constraints_trans, vranges=None):
         all_vranges.update({v: r for v, r in constraints1.ranges.items()})
 
         # Check that the transformation is injective
-        cond_on_vars = tvm.tir.const(1, 'bool')
+        cond_on_vars = tvm.tir.const(1, "bool")
         for v in constraints1.variables:
             if v in varmap:
                 # variable mapping is consistent
@@ -306,7 +331,8 @@ def check_int_constraints_trans_consistency(constraints_trans, vranges=None):
                 cond_on_vars = tvm.te.all(cond_on_vars, v == v_back)
         # Also we have to check that the new relations are true when old relations are true
         cond_subst = tvm.tir.stmt_functor.substitute(
-            tvm.te.all(tvm.tir.const(1, 'bool'), *constraints2.relations), backvarmap)
+            tvm.te.all(tvm.tir.const(1, "bool"), *constraints2.relations), backvarmap
+        )
         # We have to include relations from vranges too
         for v in constraints2.variables:
             if v in constraints2.ranges:
@@ -316,13 +342,23 @@ def check_int_constraints_trans_consistency(constraints_trans, vranges=None):
                 cond_subst = tvm.te.all(cond_subst, range_cond)
         cond_subst = ana.simplify(cond_subst)
         check_bool_expr_is_true(
-            tvm.te.all(cond_subst, cond_on_vars), all_vranges,
-            cond=tvm.te.all(tvm.tir.const(1, 'bool'), *constraints1.relations))
+            tvm.te.all(cond_subst, cond_on_vars),
+            all_vranges,
+            cond=tvm.te.all(tvm.tir.const(1, "bool"), *constraints1.relations),
+        )
 
-    _check_forward(constraints_trans.src, constraints_trans.dst,
-                   constraints_trans.src_to_dst, constraints_trans.dst_to_src)
-    _check_forward(constraints_trans.dst, constraints_trans.src,
-                   constraints_trans.dst_to_src, constraints_trans.src_to_dst)
+    _check_forward(
+        constraints_trans.src,
+        constraints_trans.dst,
+        constraints_trans.src_to_dst,
+        constraints_trans.dst_to_src,
+    )
+    _check_forward(
+        constraints_trans.dst,
+        constraints_trans.src,
+        constraints_trans.dst_to_src,
+        constraints_trans.src_to_dst,
+    )
 
 
 def _get_targets():
@@ -413,8 +449,7 @@ def enabled_targets():
 
 
 def _compose(args, decs):
-    """Helper to apply multiple markers
-    """
+    """Helper to apply multiple markers"""
     if len(args) > 0:
         f = args[0]
         for d in reversed(decs):
@@ -456,8 +491,6 @@ def requires_gpu(*args):
     return _compose(args, _requires_gpu)
 
 
-
-
 def requires_cuda(*args):
     """Mark a test as requiring the CUDA runtime.
 
@@ -470,16 +503,12 @@ def requires_cuda(*args):
     """
     _requires_cuda = [
         pytest.mark.cuda,
-        pytest.mark.skipif(
-            not device_enabled("cuda"), reason="CUDA support not enabled"
-        ),
+        pytest.mark.skipif(not device_enabled("cuda"), reason="CUDA support not enabled"),
         *requires_gpu(),
     ]
     return _compose(args, _requires_cuda)
 
 
-
-
 def requires_opencl(*args):
     """Mark a test as requiring the OpenCL runtime.
 
@@ -492,16 +521,12 @@ def requires_opencl(*args):
     """
     _requires_opencl = [
         pytest.mark.opencl,
-        pytest.mark.skipif(
-            not device_enabled("opencl"), reason="OpenCL support not enabled"
-        ),
+        pytest.mark.skipif(not device_enabled("opencl"), reason="OpenCL support not enabled"),
         *requires_gpu(),
     ]
     return _compose(args, _requires_opencl)
 
 
-
-
 def requires_rocm(*args):
     """Mark a test as requiring the rocm runtime.
 
@@ -514,16 +539,12 @@ def requires_rocm(*args):
     """
     _requires_rocm = [
         pytest.mark.rocm,
-        pytest.mark.skipif(
-            not device_enabled("rocm"), reason="rocm support not enabled"
-        ),
+        pytest.mark.skipif(not device_enabled("rocm"), reason="rocm support not enabled"),
         *requires_gpu(),
     ]
     return _compose(args, _requires_rocm)
 
 
-
-
 def requires_metal(*args):
     """Mark a test as requiring the metal runtime.
 
@@ -536,16 +557,12 @@ def requires_metal(*args):
     """
     _requires_metal = [
         pytest.mark.metal,
-        pytest.mark.skipif(
-            not device_enabled("metal"), reason="metal support not enabled"
-        ),
+        pytest.mark.skipif(not device_enabled("metal"), reason="metal support not enabled"),
         *requires_gpu(),
     ]
     return _compose(args, _requires_metal)
 
 
-
-
 def requires_vulkan(*args):
     """Mark a test as requiring the vulkan runtime.
 
@@ -558,16 +575,12 @@ def requires_vulkan(*args):
     """
     _requires_vulkan = [
         pytest.mark.vulkan,
-        pytest.mark.skipif(
-            not device_enabled("vulkan"), reason="vulkan support not enabled"
-        ),
+        pytest.mark.skipif(not device_enabled("vulkan"), reason="vulkan support not enabled"),
         *requires_gpu(),
     ]
     return _compose(args, _requires_vulkan)
 
 
-
-
 def requires_tensorcore(*args):
     """Mark a test as requiring a tensorcore to run.
 
@@ -589,8 +602,6 @@ def requires_tensorcore(*args):
     return _compose(args, _requires_tensorcore)
 
 
-
-
 def requires_llvm(*args):
     """Mark a test as requiring llvm to run.
 
@@ -601,9 +612,7 @@ def requires_llvm(*args):
     """
     _requires_llvm = [
         pytest.mark.llvm,
-        pytest.mark.skipif(
-            not device_enabled("llvm"), reason="LLVM support not enabled"
-        ),
+        pytest.mark.skipif(not device_enabled("llvm"), reason="LLVM support not enabled"),
     ]
     return _compose(args, _requires_llvm)
 
@@ -654,6 +663,7 @@ def parametrize_targets(*args):
     >>> def test_mytest(target, ctx):
     >>>     ...  # do something
     """
+
     def wrap(targets):
         def func(f):
             params = [
@@ -661,7 +671,9 @@ def parametrize_targets(*args):
                 for target in targets
             ]
             return pytest.mark.parametrize("target,ctx", params)(f)
+
         return func
+
     if len(args) == 1 and callable(args[0]):
         targets = [t for t, _ in enabled_targets()]
         return wrap(targets)(args[0])
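For reference, the pytest markers above are used roughly like this (a sketch; the test names are hypothetical):

    import tvm.testing

    @tvm.testing.requires_cuda
    def test_cuda_only():
        pass  # runs only when the CUDA runtime and a GPU are available

    @tvm.testing.parametrize_targets("llvm", "cuda")
    def test_on_targets(target, ctx):
        pass  # invoked once per enabled target with its context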
index bd7672a..7d2d8ce 100644 (file)
@@ -38,6 +38,7 @@ class Buffer(Object):
     --------
     decl_buffer : Declare a buffer
     """
+
     READ = 1
     WRITE = 2
 
@@ -89,8 +90,7 @@ class Buffer(Object):
                     raise ValueError("Unknown access_mask %s" % access_mask)
             access_mask = mask
         offset = convert(offset)
-        return _ffi_api.BufferAccessPtr(self, access_mask, ptr_type,
-                                        content_lanes, offset)
+        return _ffi_api.BufferAccessPtr(self, access_mask, ptr_type, content_lanes, offset)
 
     def vload(self, begin, dtype=None):
         """Generate an Expr that loads dtype from begin index.
@@ -133,16 +133,18 @@ class Buffer(Object):
         return _ffi_api.BufferVStore(self, begin, value)
 
 
-def decl_buffer(shape,
-                dtype=None,
-                name="buffer",
-                data=None,
-                strides=None,
-                elem_offset=None,
-                scope="",
-                data_alignment=-1,
-                offset_factor=0,
-                buffer_type=""):
+def decl_buffer(
+    shape,
+    dtype=None,
+    name="buffer",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    data_alignment=-1,
+    offset_factor=0,
+    buffer_type="",
+):
     """Declare a new symbolic buffer.
 
     Normally buffer is created automatically during lower and build.
@@ -239,12 +241,21 @@ def decl_buffer(shape,
     strides = () if strides is None else strides
     if offset_factor != 0 and elem_offset is None:
         shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
-        elem_offset = Var('%s_elem_offset' % name, shape_dtype)
+        elem_offset = Var("%s_elem_offset" % name, shape_dtype)
     if data is None:
         data = Var(name, PointerType(PrimType(dtype)))
     return _ffi_api.Buffer(
-        data, dtype, shape, strides, elem_offset, name, scope,
-        data_alignment, offset_factor, buffer_type)
+        data,
+        dtype,
+        shape,
+        strides,
+        elem_offset,
+        name,
+        scope,
+        data_alignment,
+        offset_factor,
+        buffer_type,
+    )
 
 
 @tvm._ffi.register_object("tir.DataProducer")
index 1616473..40805d9 100644 (file)
@@ -20,6 +20,7 @@ import tvm._ffi
 from tvm.runtime import Object
 from . import _ffi_api
 
+
 @tvm._ffi.register_object("tir.Layout")
 class Layout(Object):
     """Layout is composed of upper cases, lower cases and numbers,
@@ -33,6 +34,7 @@ class Layout(Object):
     --------
     layout : Declare a layout
     """
+
     def __len__(self):
         return _ffi_api.LayoutNdim(self)
 
@@ -97,6 +99,7 @@ class BijectiveLayout(Object):
     --------
     bijective_layout : Declare a layout
     """
+
     def forward_index(self, index):
         """Given the indices of the src-layout, infer the dst index.
 
index 6f3d550..60d92e9 100644 (file)
@@ -38,25 +38,27 @@ from . import _ffi_api
 
 def div_ambiguity_error():
     return RuntimeError(
-        "TVM supports multiple types of integer divisions, " +
-        "please call div, indexdiv/indexmod, floordiv/floormod " +
-        " or truncdiv/truncmod directly to avoid ambiguity in the code.")
+        "TVM supports multiple types of integer divisions, "
+        + "please call div, indexdiv/indexmod, floordiv/floormod "
+        + " or truncdiv/truncmod directly to avoid ambiguity in the code."
+    )
 
 
 def _dtype_is_int(value):
     if isinstance(value, int):
         return True
-    return (isinstance(value, ExprOp) and
-            DataType(value.dtype).type_code == DataTypeCode.INT)
+    return isinstance(value, ExprOp) and DataType(value.dtype).type_code == DataTypeCode.INT
+
 
 def _dtype_is_float(value):
     if isinstance(value, float):
         return True
-    return (isinstance(value, ExprOp) and
-            DataType(value.dtype).type_code == DataTypeCode.FLOAT)
+    return isinstance(value, ExprOp) and DataType(value.dtype).type_code == DataTypeCode.FLOAT
+
 
 class ExprOp(object):
     """Operator overloading for Expr like expressions."""
+
     def __add__(self, other):
         return _generic.add(self, other)
 
@@ -165,8 +167,10 @@ class ExprOp(object):
         return _ffi_api._OpGE(self, other)
 
     def __nonzero__(self):
-        raise ValueError("Cannot use and / or / not operator to Expr, hint: " +
-                         "use tvm.tir.all / tvm.tir.any instead")
+        raise ValueError(
+            "Cannot use and / or / not operator to Expr, hint: "
+            + "use tvm.tir.all / tvm.tir.any instead"
+        )
 
     def __bool__(self):
         return self.__nonzero__()
@@ -216,6 +220,7 @@ class EqualOp(ObjectGeneric, ExprOp):
     b : PrimExpr
         Right operand.
     """
+
     # This class is not manipulated by C++, so using Python's identity check function is sufficient
     same_as = object.__eq__
 
@@ -248,6 +253,7 @@ class NotEqualOp(ObjectGeneric, ExprOp):
     b : PrimExpr
         Right operand.
     """
+
     # This class is not manipulated by C++, so using Python's identity check function is sufficient
     same_as = object.__eq__
 
@@ -275,6 +281,7 @@ class IntImmEnum(ObjectGeneric):
     value : int
         The enum value
     """
+
     def __init__(self, value):
         self.value = value
 
@@ -285,6 +292,7 @@ class IntImmEnum(ObjectGeneric):
 
 class PrimExprWithOp(ExprOp, PrimExpr):
     """Helper base class to inherit from PrimExpr."""
+
     # In Python 3, we have to explicitly tell the interpreter to retain __hash__ if we override __eq__
     # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__
     __hash__ = PrimExpr.__hash__
@@ -293,15 +301,19 @@ class PrimExprWithOp(ExprOp, PrimExpr):
 class ConstExpr(PrimExprWithOp):
     pass
 
+
 class BinaryOpExpr(PrimExprWithOp):
     pass
 
+
 class CmpExpr(PrimExprWithOp):
     pass
 
+
 class LogicalExpr(PrimExprWithOp):
     pass
 
+
 @tvm._ffi.register_object("tir.Var")
 class Var(PrimExprWithOp):
     """Symbolic variable.
@@ -314,9 +326,9 @@ class Var(PrimExprWithOp):
     dtype : Union[str, tvm.ir.Type]
         The data type
     """
+
     def __init__(self, name, dtype):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Var, name, dtype)
+        self.__init_handle_by_constructor__(_ffi_api.Var, name, dtype)
 
 
 @tvm._ffi.register_object("tir.SizeVar")
@@ -332,10 +344,10 @@ class SizeVar(Var):
     dtype : int
         The data type
     """
+
     # pylint: disable=super-init-not-called
     def __init__(self, name, dtype):
-        self.__init_handle_by_constructor__(
-            _ffi_api.SizeVar, name, dtype)
+        self.__init_handle_by_constructor__(_ffi_api.SizeVar, name, dtype)
 
 
 @tvm._ffi.register_object("tir.IterVar")
@@ -363,6 +375,7 @@ class IterVar(Object, ExprOp):
     te.thread_axis: Create thread axis IterVar.
     te.reduce_axis: Create reduce axis IterVar.
     """
+
     DataPar = 0
     ThreadIndex = 1
     CommReduce = 2
@@ -386,8 +399,7 @@ class IterVar(Object, ExprOp):
         name = var if var is not None else "iter"
         dtype = "int32" if dom is None else dom.extent.dtype
         var = Var(name, dtype=dtype) if not isinstance(var, Var) else var
-        self.__init_handle_by_constructor__(
-            _ffi_api.IterVar, dom, var, iter_type, thread_tag)
+        self.__init_handle_by_constructor__(_ffi_api.IterVar, dom, var, iter_type, thread_tag)
 
 
 @tvm._ffi.register_object("tir.CommReducer")
@@ -408,9 +420,11 @@ class CommReducer(Object):
     identity_element : List[PrimExpr]
        The identity elements.
     """
+
     def __init__(self, lhs, rhs, result, identity_element):
         self.__init_handle_by_constructor__(
-            _ffi_api.CommReducer, lhs, rhs, result, identity_element)
+            _ffi_api.CommReducer, lhs, rhs, result, identity_element
+        )
 
 
 @tvm._ffi.register_object("tir.Reduce")
@@ -437,10 +451,11 @@ class Reduce(PrimExprWithOp):
     init : list of Expr
         The initial value for output. This can be an int, float or ProducerLoad
     """
+
     def __init__(self, combiner, src, rdom, condition, value_index, init=None):
         self.__init_handle_by_constructor__(
-            _ffi_api.Reduce, combiner, src, rdom,
-            condition, value_index, init)
+            _ffi_api.Reduce, combiner, src, rdom, condition, value_index, init
+        )
 
 
 @tvm._ffi.register_object
@@ -455,9 +470,9 @@ class FloatImm(ConstExpr):
     value : float
         The constant value.
     """
+
     def __init__(self, dtype, value):
-        self.__init_handle_by_constructor__(
-            tvm.ir._ffi_api.FloatImm, dtype, value)
+        self.__init_handle_by_constructor__(tvm.ir._ffi_api.FloatImm, dtype, value)
 
 
 @tvm._ffi.register_object
@@ -472,9 +487,9 @@ class IntImm(ConstExpr):
     value : int
         The constant value.
     """
+
     def __init__(self, dtype, value):
-        self.__init_handle_by_constructor__(
-            tvm.ir._ffi_api.IntImm, dtype, value)
+        self.__init_handle_by_constructor__(tvm.ir._ffi_api.IntImm, dtype, value)
 
     def __hash__(self):
         return self.value
@@ -504,9 +519,9 @@ class StringImm(ConstExpr):
     value : str
         The value of the function.
     """
+
     def __init__(self, value):
-        self.__init_handle_by_constructor__(
-            _ffi_api.StringImm, value)
+        self.__init_handle_by_constructor__(_ffi_api.StringImm, value)
 
     def __eq__(self, other):
         if isinstance(other, ConstExpr):
@@ -531,9 +546,9 @@ class Cast(PrimExprWithOp):
     value : PrimExpr
         The value of the function.
     """
+
     def __init__(self, dtype, value):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Cast, dtype, value)
+        self.__init_handle_by_constructor__(_ffi_api.Cast, dtype, value)
 
 
 @tvm._ffi.register_object("tir.Add")
@@ -548,9 +563,9 @@ class Add(BinaryOpExpr):
     b : PrimExpr
         The right hand operand.
     """
+
     def __init__(self, a, b):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Add, a, b)
+        self.__init_handle_by_constructor__(_ffi_api.Add, a, b)
 
 
 @tvm._ffi.register_object("tir.Sub")
@@ -565,9 +580,9 @@ class Sub(BinaryOpExpr):
     b : PrimExpr
         The right hand operand.
     """
+
     def __init__(self, a, b):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Sub, a, b)
+        self.__init_handle_by_constructor__(_ffi_api.Sub, a, b)
 
 
 @tvm._ffi.register_object("tir.Mul")
@@ -582,9 +597,9 @@ class Mul(BinaryOpExpr):
     b : PrimExpr
         The right hand operand.
     """
+
     def __init__(self, a, b):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Mul, a, b)
+        self.__init_handle_by_constructor__(_ffi_api.Mul, a, b)
 
 
 @tvm._ffi.register_object("tir.Div")
@@ -599,9 +614,9 @@ class Div(BinaryOpExpr):
     b : PrimExpr
         The right hand operand.
     """
+
     def __init__(self, a, b):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Div, a, b)
+        self.__init_handle_by_constructor__(_ffi_api.Div, a, b)
 
 
 @tvm._ffi.register_object("tir.Mod")
@@ -616,9 +631,9 @@ class Mod(BinaryOpExpr):
     b : PrimExpr
         The right hand operand.
     """
+
     def __init__(self, a, b):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Mod, a, b)
+        self.__init_handle_by_constructor__(_ffi_api.Mod, a, b)
 
 
 @tvm._ffi.register_object("tir.FloorDiv")
@@ -633,9 +648,9 @@ class FloorDiv(BinaryOpExpr):
     b : PrimExpr
         The right hand operand.
     """
+
     def __init__(self, a, b):
-        self.__init_handle_by_constructor__(
-            _ffi_api.FloorDiv, a, b)
+        self.__init_handle_by_constructor__(_ffi_api.FloorDiv, a, b)
 
 
 @tvm._ffi.register_object("tir.FloorMod")
@@ -650,9 +665,9 @@ class FloorMod(BinaryOpExpr):
     b : PrimExpr
         The right hand operand.
     """
+
     def __init__(self, a, b):
-        self.__init_handle_by_constructor__(
-            _ffi_api.FloorMod, a, b)
+        self.__init_handle_by_constructor__(_ffi_api.FloorMod, a, b)
 
 
 @tvm._ffi.register_object("tir.Min")
@@ -667,9 +682,9 @@ class Min(BinaryOpExpr):
     b : PrimExpr
         The right hand operand.
     """
+
     def __init__(self, a, b):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Min, a, b)
+        self.__init_handle_by_constructor__(_ffi_api.Min, a, b)
 
 
 @tvm._ffi.register_object("tir.Max")
@@ -684,9 +699,9 @@ class Max(BinaryOpExpr):
     b : PrimExpr
         The right hand operand.
     """
+
     def __init__(self, a, b):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Max, a, b)
+        self.__init_handle_by_constructor__(_ffi_api.Max, a, b)
 
 
 @tvm._ffi.register_object("tir.EQ")
@@ -701,9 +716,9 @@ class EQ(CmpExpr):
     b : PrimExpr
         The right hand operand.
     """
+
     def __init__(self, a, b):
-        self.__init_handle_by_constructor__(
-            _ffi_api.EQ, a, b)
+        self.__init_handle_by_constructor__(_ffi_api.EQ, a, b)
 
 
 @tvm._ffi.register_object("tir.NE")
@@ -718,9 +733,9 @@ class NE(CmpExpr):
     b : PrimExpr
         The right hand operand.
     """
+
     def __init__(self, a, b):
-        self.__init_handle_by_constructor__(
-            _ffi_api.NE, a, b)
+        self.__init_handle_by_constructor__(_ffi_api.NE, a, b)
 
 
 @tvm._ffi.register_object("tir.LT")
@@ -735,9 +750,9 @@ class LT(CmpExpr):
     b : PrimExpr
         The right hand operand.
     """
+
     def __init__(self, a, b):
-        self.__init_handle_by_constructor__(
-            _ffi_api.LT, a, b)
+        self.__init_handle_by_constructor__(_ffi_api.LT, a, b)
 
 
 @tvm._ffi.register_object("tir.LE")
@@ -752,9 +767,9 @@ class LE(CmpExpr):
     b : PrimExpr
         The right hand operand.
     """
+
     def __init__(self, a, b):
-        self.__init_handle_by_constructor__(
-            _ffi_api.LE, a, b)
+        self.__init_handle_by_constructor__(_ffi_api.LE, a, b)
 
 
 @tvm._ffi.register_object("tir.GT")
@@ -769,9 +784,9 @@ class GT(CmpExpr):
     b : PrimExpr
         The right hand operand.
     """
+
     def __init__(self, a, b):
-        self.__init_handle_by_constructor__(
-            _ffi_api.GT, a, b)
+        self.__init_handle_by_constructor__(_ffi_api.GT, a, b)
 
 
 @tvm._ffi.register_object("tir.GE")
@@ -786,9 +801,9 @@ class GE(CmpExpr):
     b : PrimExpr
         The right hand operand.
     """
+
     def __init__(self, a, b):
-        self.__init_handle_by_constructor__(
-            _ffi_api.GE, a, b)
+        self.__init_handle_by_constructor__(_ffi_api.GE, a, b)
 
 
 @tvm._ffi.register_object("tir.And")
@@ -803,9 +818,9 @@ class And(LogicalExpr):
     b : PrimExpr
         The right hand operand.
     """
+
     def __init__(self, a, b):
-        self.__init_handle_by_constructor__(
-            _ffi_api.And, a, b)
+        self.__init_handle_by_constructor__(_ffi_api.And, a, b)
 
 
 @tvm._ffi.register_object("tir.Or")
@@ -820,9 +835,9 @@ class Or(LogicalExpr):
     b : PrimExpr
         The right hand operand.
     """
+
     def __init__(self, a, b):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Or, a, b)
+        self.__init_handle_by_constructor__(_ffi_api.Or, a, b)
 
 
 @tvm._ffi.register_object("tir.Not")
@@ -834,9 +849,9 @@ class Not(LogicalExpr):
     a : PrimExpr
         The input value
     """
+
     def __init__(self, a):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Not, a)
+        self.__init_handle_by_constructor__(_ffi_api.Not, a)
 
 
 @tvm._ffi.register_object("tir.Select")
@@ -862,9 +877,9 @@ class Select(PrimExprWithOp):
         The value to take when condition is false.
 
     """
+
     def __init__(self, condition, true_value, false_value):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Select, condition, true_value, false_value)
+        self.__init_handle_by_constructor__(_ffi_api.Select, condition, true_value, false_value)
 
 
 @tvm._ffi.register_object("tir.Load")
@@ -885,10 +900,10 @@ class Load(PrimExprWithOp):
     predicate : PrimExpr
         The load predicate.
     """
+
     def __init__(self, dtype, buffer_var, index, predicate=None):
         args = [] if predicate is None else [predicate]
-        self.__init_handle_by_constructor__(
-            _ffi_api.Load, dtype, buffer_var, index, *args)
+        self.__init_handle_by_constructor__(_ffi_api.Load, dtype, buffer_var, index, *args)
 
 
 @tvm._ffi.register_object("tir.BufferLoad")
@@ -903,9 +918,9 @@ class BufferLoad(PrimExprWithOp):
     indices : List[PrimExpr]
         The buffer indices.
     """
+
     def __init__(self, buffer, indices):
-        self.__init_handle_by_constructor__(
-            _ffi_api.BufferLoad, buffer, indices)
+        self.__init_handle_by_constructor__(_ffi_api.BufferLoad, buffer, indices)
 
 
 @tvm._ffi.register_object("tir.ProducerLoad")
@@ -920,9 +935,9 @@ class ProducerLoad(PrimExprWithOp):
     indices : List[PrimExpr]
         The buffer indices.
     """
+
     def __init__(self, producer, indices):
-        self.__init_handle_by_constructor__(
-            _ffi_api.ProducerLoad, producer, indices)
+        self.__init_handle_by_constructor__(_ffi_api.ProducerLoad, producer, indices)
 
 
 @tvm._ffi.register_object("tir.Ramp")
@@ -940,9 +955,9 @@ class Ramp(PrimExprWithOp):
     lanes : int
         The lanes of the expression.
     """
+
     def __init__(self, base, stride, lanes):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Ramp, base, stride, lanes)
+        self.__init_handle_by_constructor__(_ffi_api.Ramp, base, stride, lanes)
 
 
 @tvm._ffi.register_object("tir.Broadcast")
@@ -957,9 +972,9 @@ class Broadcast(PrimExprWithOp):
     lanes : int
         The lanes of the expression.
     """
+
     def __init__(self, value, lanes):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Broadcast, value, lanes)
+        self.__init_handle_by_constructor__(_ffi_api.Broadcast, value, lanes)
 
 
 @tvm._ffi.register_object("tir.Shuffle")
@@ -974,13 +989,14 @@ class Shuffle(PrimExprWithOp):
     indices : Array of indices
         The indices
     """
+
     def __init__(self, vectors, indices):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Shuffle, vectors, indices)
+        self.__init_handle_by_constructor__(_ffi_api.Shuffle, vectors, indices)
 
 
 class CallEffectKind:
     """Possible kinds of Call effects."""
+
     # only expose up to opaque
     ExprAnnotation = IntImmEnum(0)
     Pure = IntImmEnum(1)
@@ -1005,16 +1021,20 @@ class Call(PrimExprWithOp):
     args : list of Expr
         The input arguments to the call
     """
+
     def __init__(self, dtype, op, args):
         if isinstance(op, str):
             if not op.startswith("tir."):
                 raise ValueError(
-                    ("Cannot handle str op argument %s. This function only handles str " +
-                     "argument with the tir namespace. If you are " +
-                     "certain about the intrinsic name, pass in Op.get(name) instead") % op)
+                    (
+                        "Cannot handle str op argument %s. This function only handles str "
+                        + "argument with the tir namespace. If you are "
+                        + "certain about the intrinsic name, pass in Op.get(name) instead"
+                    )
+                    % op
+                )
             op = Op.get(op)
-        self.__init_handle_by_constructor__(
-            _ffi_api.Call, dtype, op, args)
+        self.__init_handle_by_constructor__(_ffi_api.Call, dtype, op, args)
 
 
 @tvm._ffi.register_object("tir.Let")
@@ -1032,14 +1052,14 @@ class Let(PrimExprWithOp):
     body : PrimExpr
         The body expression.
     """
+
     def __init__(self, var, value, body):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Let, var, value, body)
+        self.__init_handle_by_constructor__(_ffi_api.Let, var, value, body)
 
 
 @tvm._ffi.register_object("tir.Any")
 class Any(PrimExpr):
-    """Any node.
-    """
+    """Any node."""
+
     def __init__(self):
         self.__init_handle_by_constructor__(_ffi_api.Any)
index 47ad94f..b02ebba 100644 (file)
@@ -46,12 +46,8 @@ class PrimFunc(BaseFunc):
     attrs: Optional[tvm.Attrs]
         Attributes of the function, can be None
     """
-    def __init__(self,
-                 params,
-                 body,
-                 ret_type=None,
-                 buffer_map=None,
-                 attrs=None):
+
+    def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None):
         param_list = []
         buffer_map = {} if buffer_map is None else buffer_map
         for x in params:
@@ -66,7 +62,8 @@ class PrimFunc(BaseFunc):
                 raise TypeError("params can only contain Var or Buffer")
 
         self.__init_handle_by_constructor__(
-            _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs)
+            _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs
+        )
 
     def with_body(self, new_body):
         """Create a new PrimFunc with the same set signatures but a new body.
@@ -81,5 +78,4 @@ class PrimFunc(BaseFunc):
         new_func : PrimFunc
             The created new function.
         """
-        return PrimFunc(
-            self.params, new_body, self.ret_type, self.buffer_map, self.attrs)
+        return PrimFunc(self.params, new_body, self.ret_type, self.buffer_map, self.attrs)
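A minimal sketch of constructing a PrimFunc directly and swapping its body:

    import tvm
    from tvm import tir

    x = tir.Var("x", "int32")
    # a trivial function body that just evaluates its argument
    func = tir.PrimFunc(params=[x], body=tir.Evaluate(x))
    new_func = func.with_body(tir.Evaluate(x + 1))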
index 88be5b1..220f434 100644 (file)
@@ -79,6 +79,7 @@ def multiply(lhs, rhs):
     """
     return _ffi_api._OpMul(lhs, rhs)
 
+
 def divide(lhs, rhs):
     """Generic divide operator.
 
@@ -96,6 +97,7 @@ def divide(lhs, rhs):
     """
     return _ffi_api._OpDiv(lhs, rhs)
 
+
 def floordiv(lhs, rhs):
     """Generic floordiv operator.
 
index b313e58..8b999bf 100644 (file)
@@ -25,6 +25,7 @@ from . import expr as _expr
 
 class WithScope(object):
     """Auxiliary scope  with"""
+
     def __init__(self, enter_value, exit_cb):
         self._enter_value = enter_value
         self._exit_cb = exit_cb
@@ -60,6 +61,7 @@ class BufferVar(ObjectGeneric):
     IRBuilder.buffer_ptr
     IRBuilder.allocate
     """
+
     def __init__(self, builder, buffer_var, content_type):
         self._builder = builder
         self._buffer_var = buffer_var
@@ -83,8 +85,8 @@ class BufferVar(ObjectGeneric):
         value = convert(value)
         if value.dtype != self._content_type:
             raise ValueError(
-                "data type does not match content type %s vs %s" % (
-                    value.dtype, self._content_type))
+                "data type does not match content type %s vs %s" % (value.dtype, self._content_type)
+            )
         t = DataType(self._content_type)
         if t.lanes > 1:
             base = index * t.lanes
@@ -108,6 +110,7 @@ class IRBuilder(object):
         # The result stmt.
         stmt = ib.get()
     """
+
     def __init__(self):
         self._seq_stack = [[]]
         self.nidx = 0
@@ -206,12 +209,13 @@ class IRBuilder(object):
             with ib.for_range(1, 10, name="i") as i:
                 x[i] = x[i - 1] + 1
         """
-        if name == 'i':
+        if name == "i":
             name = chr(ord(name) + self.nidx) if self.nidx < 3 else name + "_" + str(self.nidx - 3)
             self.nidx += 1
         self._seq_stack.append([])
         loop_var = _expr.Var(name, dtype=dtype)
         extent = end if begin == 0 else (end - begin)
+
         def _exit_cb():
             if for_type == "serial":
                 for_type_id = 0
@@ -223,8 +227,8 @@ class IRBuilder(object):
                 for_type_id = 3
             else:
                 raise ValueError("Unknown for_type")
-            self.emit(_stmt.For(
-                loop_var, begin, extent, for_type_id, 0, self._pop_seq()))
+            self.emit(_stmt.For(loop_var, begin, extent, for_type_id, 0, self._pop_seq()))
+
         return WithScope(loop_var, _exit_cb)
 
     def if_scope(self, cond):
@@ -251,8 +255,10 @@ class IRBuilder(object):
                 x[i] = x[i - 1] + 1
         """
         self._seq_stack.append([])
+
         def _exit_cb():
             self.emit(_stmt.IfThenElse(cond, self._pop_seq(), None))
+
         return WithScope(None, _exit_cb)
 
     def else_scope(self):
@@ -284,8 +290,10 @@ class IRBuilder(object):
             raise RuntimeError("else_scope can only follow an if_scope")
         self._seq_stack[-1].pop()
         self._seq_stack.append([])
+
         def _exit_cb():
             self.emit(_stmt.IfThenElse(prev.condition, prev.then_case, self._pop_seq()))
+
         return WithScope(None, _exit_cb)
 
     def new_scope(self):
@@ -299,8 +307,10 @@ class IRBuilder(object):
            The result new scope.
         """
         self._seq_stack.append([])
+
         def _exit_cb():
             self.emit(self._pop_seq())
+
         return WithScope(None, _exit_cb)
 
     def allocate(self, dtype, shape, name="buf", scope=None):
@@ -330,8 +340,7 @@ class IRBuilder(object):
             shape = [shape]
         if scope:
             self.scope_attr(buffer_var, "storage_scope", scope)
-        self.emit(lambda x: _stmt.Allocate(
-            buffer_var, dtype, shape, const(1, dtype="uint1"), x))
+        self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x))
         return BufferVar(self, buffer_var, dtype)
 
     def pointer(self, content_type, name="ptr"):
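A short IRBuilder sketch tying the scope helpers above together:

    import tvm

    ib = tvm.tir.ir_builder.create()
    n = 16
    data = ib.pointer("int32", name="data")
    with ib.for_range(0, n, name="i") as i:
        with ib.if_scope(i > 0):
            data[i] = data[i - 1] + 2
    stmt = ib.get()  # the assembled tir.Stmt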
index 9592e6e..0240485 100644 (file)
@@ -26,18 +26,19 @@ from . import _ffi_api
 
 
 def _pack_buffer(buf):
-    """Build intrinsics that packs the buffer.
-    """
+    """Build intrinsics that packs the buffer."""
     shape = Call("handle", "tir.tvm_stack_make_shape", buf.shape)
     strides = Call("handle", "tir.tvm_stack_make_shape", buf.strides) if buf.strides else 0
-    pack_args = [buf.data,
-                 shape,
-                 strides,
-                 len(buf.shape),
-                 const(0, dtype=buf.dtype),
-                 buf.elem_offset]
-    return Call("handle", Op.get("tir.tvm_stack_make_array"),
-                pack_args)
+    pack_args = [
+        buf.data,
+        shape,
+        strides,
+        len(buf.shape),
+        const(0, dtype=buf.dtype),
+        buf.elem_offset,
+    ]
+    return Call("handle", Op.get("tir.tvm_stack_make_array"), pack_args)
+
 
 def call_packed(*args):
     """Build expression by call an external packed function.
@@ -64,8 +65,7 @@ def call_packed(*args):
     te.extern : Create tensor with extern function call.
     """
     call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
-    return Call(
-        "int32", Op.get("tir.tvm_call_packed"), call_args)
+    return Call("int32", Op.get("tir.tvm_call_packed"), call_args)
 
 
 def call_intrin(dtype, func_name, *args):
@@ -90,8 +90,7 @@ def call_intrin(dtype, func_name, *args):
     call : PrimExpr
         The call expression.
     """
-    return Call(
-        dtype, func_name, convert(args))
+    return Call(dtype, func_name, convert(args))
 
 
 def call_pure_extern(dtype, func_name, *args):
@@ -113,8 +112,7 @@ def call_pure_extern(dtype, func_name, *args):
     call : PrimExpr
         The call expression.
     """
-    return Call(
-        dtype, Op.get("tir.call_pure_extern"), convert((StringImm(func_name),) + args))
+    return Call(dtype, Op.get("tir.call_pure_extern"), convert((StringImm(func_name),) + args))
 
 
 def call_extern(dtype, func_name, *args):
@@ -136,8 +134,7 @@ def call_extern(dtype, func_name, *args):
     call : PrimExpr
         The call expression.
     """
-    return Call(
-        dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args))
+    return Call(dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args))
 
 
 def call_llvm_intrin(dtype, name, *args):
@@ -161,11 +158,12 @@ def call_llvm_intrin(dtype, name, *args):
     """
     # pylint: disable=import-outside-toplevel
     from tvm.target import codegen
+
     llvm_id = codegen.llvm_lookup_intrinsic_id(name)
     assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
     return call_intrin(
-        dtype, Op.get("tir.call_llvm_intrin"),
-        tvm.tir.const(llvm_id, 'uint32'), *args)
+        dtype, Op.get("tir.call_llvm_intrin"), tvm.tir.const(llvm_id, "uint32"), *args
+    )
 
 
 def call_llvm_pure_intrin(dtype, name, *args):
@@ -189,11 +187,12 @@ def call_llvm_pure_intrin(dtype, name, *args):
     """
     # pylint: disable=import-outside-toplevel
     from tvm.target import codegen
+
     llvm_id = codegen.llvm_lookup_intrinsic_id(name)
     assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
     return call_intrin(
-        dtype, Op.get("tir.call_llvm_pure_intrin"),
-        tvm.tir.const(llvm_id, 'uint32'), *args)
+        dtype, Op.get("tir.call_llvm_pure_intrin"), tvm.tir.const(llvm_id, "uint32"), *args
+    )
 
 
 def any(*args):
@@ -247,6 +246,7 @@ def all(*args):
 def _tvm_default_trace_action(*args):
     print(list(args))
 
+
 def trace(args, trace_action="tvm.default_trace_action"):
     """Trace tensor data at the runtime.
 
@@ -276,9 +276,7 @@ def trace(args, trace_action="tvm.default_trace_action"):
         raise Exception("tvm.tir.trace consumes the args as list type")
     call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
     call_args.insert(0, trace_action)
-    return tvm.tir.Call(
-        args[-1].dtype, Op.get("tir.tvm_call_trace_packed"), call_args)
-
+    return tvm.tir.Call(args[-1].dtype, Op.get("tir.tvm_call_trace_packed"), call_args)
 
 
 def min_value(dtype):
@@ -964,6 +962,7 @@ def popcount(x):
     """
     return call_intrin(x.dtype, "tir.popcount", x)
 
+
 def q_multiply_shift(x, y, q, s):
     """Execute a multiplication between two Q-numbers x and y
     followed by a right shift s. The mathematical expression is:
@@ -990,7 +989,8 @@ def q_multiply_shift(x, y, q, s):
     y : PrimExpr
         The result.
     """
-    return call_intrin('int32', "tir.q_multiply_shift", x, y, q, s)
+    return call_intrin("int32", "tir.q_multiply_shift", x, y, q, s)
+
 
 def fmod(x, y):
     """Return the remainder of x divided by y with the same sign as x.
@@ -1229,14 +1229,15 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
         k = te.reduce_axis((0, m), name="k")
         B = te.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")
     """
+
     def _reduce_directly(*args):
         num = len(args)
         # process `where` is None
         if num == 3 and args[2] is None:
             num = 2
         res = args[0]
-        for i in range(num-1):
-            res = fcombine(res, args[i+1])
+        for i in range(num - 1):
+            res = fcombine(res, args[i + 1])
         return res
 
     def _make_reduce(expr, axis, where=None, init=None):
@@ -1263,8 +1264,9 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
                 assert len(init) == size
                 for init_i in range(size):
                     init_i = convert(init_i)
-                    assert isinstance(init_i,
-                                      (tvm.tir.ProducerLoad, tvm.tir.IntImm, tvm.tir.FloatImm))
+                    assert isinstance(
+                        init_i, (tvm.tir.ProducerLoad, tvm.tir.IntImm, tvm.tir.FloatImm)
+                    )
             else:
                 init = convert([])
             lhs = convert(larr)
@@ -1292,11 +1294,13 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
         if where is None:
             where = convert(True)
         if init is None:
-            outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i, convert([]))
-                            for i in range(size))
+            outputs = tuple(
+                tvm.tir.Reduce(combiner, expr, axis, where, i, convert([])) for i in range(size)
+            )
         else:
-            outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i, init)
-                            for i in range(size))
+            outputs = tuple(
+                tvm.tir.Reduce(combiner, expr, axis, where, i, init) for i in range(size)
+            )
         return outputs[0] if size == 1 else outputs
 
     # pylint: disable=keyword-arg-before-vararg
@@ -1344,7 +1348,8 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
     reducer.__doc__ = doc_str.format(name)
     return reducer
 
+
 # pylint: disable=unnecessary-lambda
-sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
+sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
 min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y), max_value, name="min")
 max = comm_reducer(lambda x, y: _ffi_api._OpMax(x, y), min_value, name="max")
index 757b2ac..573bc0e 100644 (file)
@@ -51,9 +51,9 @@ class LetStmt(Stmt):
     body : Stmt
         The body statement.
     """
+
     def __init__(self, var, value, body):
-        self.__init_handle_by_constructor__(
-            _ffi_api.LetStmt, var, value, body)
+        self.__init_handle_by_constructor__(_ffi_api.LetStmt, var, value, body)
 
 
 @tvm._ffi.register_object("tir.AssertStmt")
@@ -71,9 +71,9 @@ class AssertStmt(Stmt):
     body : Stmt
         The body statement.
     """
+
     def __init__(self, condition, message, body):
-        self.__init_handle_by_constructor__(
-            _ffi_api.AssertStmt, condition, message, body)
+        self.__init_handle_by_constructor__(_ffi_api.AssertStmt, condition, message, body)
 
 
 @tvm._ffi.register_object("tir.For")
@@ -100,20 +100,16 @@ class For(Stmt):
     body : Stmt
         The body statement.
     """
+
     Serial = 0
     Parallel = 1
     Vectorized = 2
     Unrolled = 3
-    def __init__(self,
-                 loop_var,
-                 min_val,
-                 extent,
-                 for_type,
-                 device_api,
-                 body):
+
+    def __init__(self, loop_var, min_val, extent, for_type, device_api, body):
         self.__init_handle_by_constructor__(
-            _ffi_api.For, loop_var, min_val, extent,
-            for_type, device_api, body)
+            _ffi_api.For, loop_var, min_val, extent, for_type, device_api, body
+        )
 
 
 @tvm._ffi.register_object("tir.Store")
@@ -134,10 +130,10 @@ class Store(Stmt):
     predicate : PrimExpr
         The store predicate.
     """
+
     def __init__(self, buffer_var, value, index, predicate=None):
         args = [] if predicate is None else [predicate]
-        self.__init_handle_by_constructor__(
-            _ffi_api.Store, buffer_var, value, index, *args)
+        self.__init_handle_by_constructor__(_ffi_api.Store, buffer_var, value, index, *args)
 
 
 @tvm._ffi.register_object("tir.BufferStore")
@@ -155,9 +151,9 @@ class BufferStore(Stmt):
     indices : List[PrimExpr]
         The indices location to be stored.
     """
+
     def __init__(self, buffer, value, indices):
-        self.__init_handle_by_constructor__(
-            _ffi_api.BufferStore, buffer, value, indices)
+        self.__init_handle_by_constructor__(_ffi_api.BufferStore, buffer, value, indices)
 
 
 @tvm._ffi.register_object("tir.BufferRealize")
@@ -178,9 +174,9 @@ class BufferRealize(Stmt):
     body : Stmt
         The body of the statement.
     """
+
     def __init__(self, buffer, bounds, condition, body):
-        self.__init_handle_by_constructor__(
-            _ffi_api.BufferRealize, buffer, bounds, condition, body)
+        self.__init_handle_by_constructor__(_ffi_api.BufferRealize, buffer, bounds, condition, body)
 
 
 @tvm._ffi.register_object("tir.ProducerStore")
@@ -198,9 +194,9 @@ class ProducerStore(Stmt):
     indices : list of Expr
         The index arguments of the store.
     """
+
     def __init__(self, producer, value, indices):
-        self.__init_handle_by_constructor__(
-            _ffi_api.ProducerStore, producer, value, indices)
+        self.__init_handle_by_constructor__(_ffi_api.ProducerStore, producer, value, indices)
 
 
 @tvm._ffi.register_object("tir.Allocate")
@@ -224,15 +220,11 @@ class Allocate(Stmt):
     body : Stmt
         The body statement.
     """
-    def __init__(self,
-                 buffer_var,
-                 dtype,
-                 extents,
-                 condition,
-                 body):
+
+    def __init__(self, buffer_var, dtype, extents, condition, body):
         self.__init_handle_by_constructor__(
-            _ffi_api.Allocate, buffer_var, dtype,
-            extents, condition, body)
+            _ffi_api.Allocate, buffer_var, dtype, extents, condition, body
+        )
 
 
 @tvm._ffi.register_object("tir.AttrStmt")
@@ -253,9 +245,9 @@ class AttrStmt(Stmt):
     body : Stmt
         The body statement.
     """
+
     def __init__(self, node, attr_key, value, body):
-        self.__init_handle_by_constructor__(
-            _ffi_api.AttrStmt, node, attr_key, value, body)
+        self.__init_handle_by_constructor__(_ffi_api.AttrStmt, node, attr_key, value, body)
 
 
 @tvm._ffi.register_object("tir.ProducerRealize")
@@ -276,13 +268,11 @@ class ProducerRealize(Stmt):
     body : Stmt
         The realize body
     """
-    def __init__(self,
-                 producer,
-                 bounds,
-                 condition,
-                 body):
+
+    def __init__(self, producer, bounds, condition, body):
         self.__init_handle_by_constructor__(
-            _ffi_api.ProducerRealize, producer, bounds, condition, body)
+            _ffi_api.ProducerRealize, producer, bounds, condition, body
+        )
 
 
 @tvm._ffi.register_object("tir.SeqStmt")
@@ -294,9 +284,9 @@ class SeqStmt(Stmt):
     seq : List[Stmt]
         The statements
     """
+
     def __init__(self, seq):
-        self.__init_handle_by_constructor__(
-            _ffi_api.SeqStmt, seq)
+        self.__init_handle_by_constructor__(_ffi_api.SeqStmt, seq)
 
     def __getitem__(self, i):
         return self.seq[i]
@@ -320,9 +310,9 @@ class IfThenElse(Stmt):
     else_case : Stmt
         The statement to execute if condition is false.
     """
+
     def __init__(self, condition, then_case, else_case):
-        self.__init_handle_by_constructor__(
-            _ffi_api.IfThenElse, condition, then_case, else_case)
+        self.__init_handle_by_constructor__(_ffi_api.IfThenElse, condition, then_case, else_case)
 
 
 @tvm._ffi.register_object("tir.Evaluate")
@@ -334,9 +324,9 @@ class Evaluate(Stmt):
     value : PrimExpr
         The expression to be evaluated.
     """
+
     def __init__(self, value):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Evaluate, value)
+        self.__init_handle_by_constructor__(_ffi_api.Evaluate, value)
 
 
 @tvm._ffi.register_object("tir.Prefetch")
@@ -351,9 +341,9 @@ class Prefetch(Stmt):
     bounds : list of Range
         The bounds to be prefetched.
     """
+
     def __init__(self, buffer, bounds):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Prefetch, buffer, bounds)
+        self.__init_handle_by_constructor__(_ffi_api.Prefetch, buffer, bounds)
 
 
 def stmt_seq(*args):
index cea8d14..f1e64ba 100644 (file)
@@ -59,7 +59,7 @@ def post_order_visit(stmt, fvisit):
 
 
 def substitute(node, vmap):
-    """ Substitute the var specified by vmap.
+    """Substitute the var specified by vmap.
 
     Parameters
     ----------
index a19cc2f..59b3ecd 100644 (file)
@@ -34,8 +34,10 @@ class PrimFuncPass(Pass):
 
 def _wrap_class_function_pass(pass_cls, pass_info):
     """Wrap a python class as function pass"""
+
     class PyFunctionPass(PrimFuncPass):
         """Internal wrapper class to create a class instance."""
+
         def __init__(self, *args, **kwargs):
             # initialize handle in case pass_cls creation failed
             self.handle = None
@@ -44,8 +46,8 @@ def _wrap_class_function_pass(pass_cls, pass_info):
             # avoid a cyclic dependency
             def _pass_func(func, mod, ctx):
                 return inst.transform_function(func, mod, ctx)
-            self.__init_handle_by_constructor__(
-                _ffi_api.CreatePrimFuncPass, _pass_func, pass_info)
+
+            self.__init_handle_by_constructor__(_ffi_api.CreatePrimFuncPass, _pass_func, pass_info)
             self._inst = inst
 
         def __getattr__(self, name):
@@ -132,8 +134,7 @@ def prim_func_pass(pass_func=None, opt_level=None, name=None, required=None):
 
     required = required if required else []
     if not isinstance(required, (list, tuple)):
-        raise TypeError("Required is expected to be the type of " +
-                        "list/tuple.")
+        raise TypeError("Required is expected to be the type of " + "list/tuple.")
 
     def create_function_pass(pass_arg):
         """Internal function that creates a function pass"""
index 3f7fb41..40dd170 100644 (file)
@@ -38,6 +38,7 @@ def Apply(ftransform):
     # pylint: disable=unused-argument
     def _transform(func, mod, ctx):
         return ftransform(func)
+
     return _fpass.prim_func_pass(_transform, opt_level=0, name="Apply")
 
 
@@ -57,6 +58,7 @@ def Filter(fcond):
     # pylint: disable=unused-argument
     def _transform(func, mod, ctx):
         return func if fcond(func) else None
+
     return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter")
 
 
@@ -226,6 +228,7 @@ def RemoveNoOp():
     """
     return _ffi_api.RemoveNoOp()
 
+
 def BF16Legalize():
     """Legalize bf16 typed Ops.
     Runs BF16Promote, BF16CastElimination and BF16TypeLowering
@@ -237,6 +240,7 @@ def BF16Legalize():
     """
     return _ffi_api.BF16Legalize()
 
+
 def BF16Promote():
     """Promote bf16 to fp32. Add a cast to fp32
     before Ops, then add a cast back to bf16.
@@ -248,6 +252,7 @@ def BF16Promote():
     """
     return _ffi_api.BF16Promote()
 
+
 def BF16CastElimination():
     """Eliminate verbose casting between fp32 and bf16
     Checks if the AST has the pattern:
@@ -266,6 +271,7 @@ def BF16CastElimination():
     """
     return _ffi_api.BF16CastElimination()
 
+
 def BF16TypeLowering():
     """Replace all bf16 type with uint16. Also lower the casting
     between fp32 and bf16
@@ -277,6 +283,7 @@ def BF16TypeLowering():
     """
     return _ffi_api.BF16TypeLowering()
 
+
 def RewriteUnsafeSelect():
     """Detect and rewrite unsafe select that contains memory access.
 
@@ -374,7 +381,7 @@ def SkipAssert():
 
 
 def ThreadSync(storage_scope):
-    """ Insert sync between parallel read/write of shared buffers.
+    """Insert sync between parallel read/write of shared buffers.
 
     Parameters
     ----------
@@ -401,7 +408,7 @@ def LowerThreadAllreduce():
 
 
 def InferFragment():
-    """ Infer the TensorCore fragment infomation using tensor intrinsics.
+    """Infer the TensorCore fragment infomation using tensor intrinsics.
 
     Returns
     -------
@@ -500,7 +507,8 @@ def VerifyMemory():
     """
     return _ffi_api.VerifyMemory()
 
-#pylint: disable=no-else-return,inconsistent-return-statements
+
+# pylint: disable=no-else-return,inconsistent-return-statements
 def HoistIfThenElse(variant=None):
     """Hoist loop-invariant IfThenElse nodes to outside the elligible loops.
 
index c17b6fd..2d44545 100644 (file)
@@ -54,8 +54,10 @@ from . import vision
 from . import image
 from . import sparse
 from . import hls
+
 # error reporting
 from .util import InvalidShapeError
+
 # not import testing by default
 # because testing can have extra deps that are not necessary
 # we can import them from test cases explicitly
index d4bac62..75c19af 100644 (file)
@@ -18,6 +18,7 @@
 """Argwhere operator"""
 from tvm.te import hybrid
 
+
 @hybrid.script
 def hybrid_argwhere_1d(output_shape, condition):
     """Find the indices of elements of a 1-D tensor that are non-zero.
@@ -41,6 +42,7 @@ def hybrid_argwhere_1d(output_shape, condition):
             valid_index += 1
     return a
 
+
 @hybrid.script
 def hybrid_argwhere_2d(output_shape, condition):
     """Find the indices of elements of a 2-D tensor that are non-zero.
@@ -67,6 +69,7 @@ def hybrid_argwhere_2d(output_shape, condition):
                 valid_index += 1
     return a
 
+
 @hybrid.script
 def hybrid_argwhere_3d(output_shape, condition):
     """Find the indices of elements of a 3-D tensor that are non-zero.
@@ -96,6 +99,7 @@ def hybrid_argwhere_3d(output_shape, condition):
                     valid_index += 1
     return a
 
+
 @hybrid.script
 def hybrid_argwhere_4d(output_shape, condition):
     """Find the indices of elements of a 4-D tensor that are non-zero.
@@ -128,6 +132,7 @@ def hybrid_argwhere_4d(output_shape, condition):
                         valid_index += 1
     return a
 
+
 @hybrid.script
 def hybrid_argwhere_5d(output_shape, condition):
     """Find the indices of elements of a 5-D tensor that are non-zero.
@@ -163,6 +168,7 @@ def hybrid_argwhere_5d(output_shape, condition):
                             valid_index += 1
     return a
 
+
 def argwhere(output_shape, condition):
     """Find the indices of elements of a tensor that are non-zero.
 
index e76b374..fb22930 100644 (file)
@@ -28,23 +28,38 @@ from ..nn.bitserial_util import bitpack, binary_op_multiplier
 from ..nn.util import get_pad_tuple
 from ..util import get_const_int, get_const_tuple
 
+
 def _kernel_vec_spatial_pack_nhwc(kernel, kernel_bits, VC, use_bitpack=True):
     if use_bitpack:
-        kernel_q = bitpack(kernel, kernel_bits, pack_axis=2, bit_axis=2, pack_type='uint8')
+        kernel_q = bitpack(kernel, kernel_bits, pack_axis=2, bit_axis=2, pack_type="uint8")
     else:
         kernel_q = kernel
     KH, KW, KB, CI, CO = kernel_q.shape
-    kvshape = (CO//VC, KH, KW, KB, VC, CI)
-    return te.compute(kvshape, lambda co, dh, dw, b, vc, ci: \
-                      kernel_q[dh][dw][b][ci][co*VC+vc], name='kernel_vec')
+    kvshape = (CO // VC, KH, KW, KB, VC, CI)
+    return te.compute(
+        kvshape,
+        lambda co, dh, dw, b, vc, ci: kernel_q[dh][dw][b][ci][co * VC + vc],
+        name="kernel_vec",
+    )
+
 
 @autotvm.register_topi_compute("bitserial_conv2d_nhwc.arm_cpu")
-def bitserial_conv2d_nhwc(cfg, data, kernel, stride, padding, activation_bits, weight_bits,
-                          pack_dtype, out_dtype, unipolar):
+def bitserial_conv2d_nhwc(
+    cfg,
+    data,
+    kernel,
+    stride,
+    padding,
+    activation_bits,
+    weight_bits,
+    pack_dtype,
+    out_dtype,
+    unipolar,
+):
     """ Compute convolution with pack on spatial axes. """
     assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
-    assert pack_dtype == 'uint8', "only support packing into uint8 bits"
-    assert out_dtype == 'int16', "only support output type of int16"
+    assert pack_dtype == "uint8", "only support packing into uint8 bits"
+    assert out_dtype == "int16", "only support output type of int16"
 
     N, H, W, CI = get_const_tuple(data.shape)
     if len(kernel.shape) == 4:
@@ -62,7 +77,7 @@ def bitserial_conv2d_nhwc(cfg, data, kernel, stride, padding, activation_bits, w
         HSTR, WSTR = stride
     else:
         HSTR, WSTR = stride, stride
-    HCAT, WCAT = KH-1, KW-1
+    HCAT, WCAT = KH - 1, KW - 1
 
     PAD_H = H + (TPAD + DPAD)
     PAD_W = W + (LPAD + RPAD)
@@ -85,19 +100,21 @@ def bitserial_conv2d_nhwc(cfg, data, kernel, stride, padding, activation_bits, w
     ci, kh, kw = cfg.reduce_axis(CI_packed), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
     ib, kb = cfg.reduce_axis(activation_bits), cfg.reduce_axis(weight_bits)
 
-    co, vc = cfg.define_split('tile_co', co, num_outputs=2,
-                              filter=lambda x: x.size[-1] == 8)
-    oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2,
-                              filter=lambda x: x.size[-1] >= 2)
-    ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2,
-                              filter=lambda x: x.size[-1] >= 2)
-    ci_o, ci_i = cfg.define_split("tile_ci", ci, num_outputs=2,
-                                  filter=lambda x: x.size[-1] == 8 or x.size[-1] == 16)
-    re_axes = cfg.define_reorder("reorder_0",
-                                 [n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i],
-                                 policy='candidate', candidate=[
-                                     [n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i],
-                                     [n, oh, ow, co, vh, vw, kw, kh, ci_o, kb, ib, vc, ci_i],])
+    co, vc = cfg.define_split("tile_co", co, num_outputs=2, filter=lambda x: x.size[-1] == 8)
+    oh, vh = cfg.define_split("tile_oh", oh, num_outputs=2, filter=lambda x: x.size[-1] >= 2)
+    ow, vw = cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda x: x.size[-1] >= 2)
+    ci_o, ci_i = cfg.define_split(
+        "tile_ci", ci, num_outputs=2, filter=lambda x: x.size[-1] == 8 or x.size[-1] == 16
+    )
+    re_axes = cfg.define_reorder(
+        "reorder_0",
+        [n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i],
+        policy="candidate",
+        candidate=[
+            [n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i],
+            [n, oh, ow, co, vh, vw, kw, kh, ci_o, kb, ib, vc, ci_i],
+        ],
+    )
     # binary ops
     cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * binary_op_multiplier(pack_dtype))
     # ====================
@@ -106,7 +123,7 @@ def bitserial_conv2d_nhwc(cfg, data, kernel, stride, padding, activation_bits, w
     VH = cfg["tile_oh"].size[-1]
     VW = cfg["tile_ow"].size[-1]
 
-    data_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=3, pack_type='uint8')
+    data_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=3, pack_type="uint8")
 
     kernel_vec = _kernel_vec_spatial_pack_nhwc(kernel, weight_bits, VC, len(kernel.shape) == 4)
     idxm = tvm.tir.indexmod
@@ -116,105 +133,151 @@ def bitserial_conv2d_nhwc(cfg, data, kernel, stride, padding, activation_bits, w
     N, H, W, IB, CI = data_q.shape
     OCO, KH, KW, KB, VC, CI = kernel_vec.shape
 
-    dvshape = (N, PAD_H//(VH*HSTR), PAD_W//(VW*WSTR), VH*HSTR+HCAT, VW*WSTR+WCAT, IB, CI)
+    dvshape = (
+        N,
+        PAD_H // (VH * HSTR),
+        PAD_W // (VW * WSTR),
+        VH * HSTR + HCAT,
+        VW * WSTR + WCAT,
+        IB,
+        CI,
+    )
     ovshape = (1, OH // VH, OW // VW, CO // VC, VH, VW, VC)
 
-    if (TPAD != 0 and RPAD != 0):
+    if TPAD != 0 and RPAD != 0:
         data_pad = pad(data_q, (0, TPAD, LPAD, 0, 0), (0, DPAD, RPAD, 0, CI_PAD), name="data_pad")
     elif CI_PAD != 0:
         data_pad = pad(data_q, (0, 0, 0, 0, 0), (0, 0, 0, 0, CI_PAD), name="data_pad")
     else:
         data_pad = data_q
 
-    data_vec = te.compute(dvshape, lambda n, h, w, vh, vw, b, ci: \
-                          data_pad[n][h*VH*HSTR+vh][w*VW*WSTR+vw][b][ci], name='data_vec')
-    ci = te.reduce_axis((0, CI), name='ci')
-    dh = te.reduce_axis((0, KH), name='dh')
-    dw = te.reduce_axis((0, KW), name='dw')
-    ib = te.reduce_axis((0, IB), name='ib')
-    kb = te.reduce_axis((0, KB), name='kb')
+    data_vec = te.compute(
+        dvshape,
+        lambda n, h, w, vh, vw, b, ci: data_pad[n][h * VH * HSTR + vh][w * VW * WSTR + vw][b][ci],
+        name="data_vec",
+    )
+    ci = te.reduce_axis((0, CI), name="ci")
+    dh = te.reduce_axis((0, KH), name="dh")
+    dw = te.reduce_axis((0, KW), name="dw")
+    ib = te.reduce_axis((0, IB), name="ib")
+    kb = te.reduce_axis((0, KB), name="kb")
 
     def _bipolar_conv(n, h, w, co, vh, vw, vc):
-        return te.sum((tvm.tir.popcount(
-            kernel_vec[co, dh, dw, kb, vc, ci].astype('uint16') &
-            data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci].astype('uint16'))
-                       << (kb + ib).astype('uint16')), axis=[dh, dw, kb, ib, ci])
+        return te.sum(
+            (
+                tvm.tir.popcount(
+                    kernel_vec[co, dh, dw, kb, vc, ci].astype("uint16")
+                    & data_vec[n, h, w, vh * HSTR + dh, vw * WSTR + dw, ib, ci].astype("uint16")
+                )
+                << (kb + ib).astype("uint16")
+            ),
+            axis=[dh, dw, kb, ib, ci],
+        )
+
     def _unipolar_conv(n, h, w, co, vh, vw, vc):
         return te.sum(
-            ((tvm.tir.popcount(kernel_vec[co, dh, dw, kb, vc, ci].astype('int16') &
-                               data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci].astype('int16')) -
-              tvm.tir.popcount(~kernel_vec[co, dh, dw, kb, vc, ci].astype('int16') &
-                               data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci]).astype('int16'))
-             << (kb + ib).astype('int16')), axis=[dh, dw, kb, ib, ci])
+            (
+                (
+                    tvm.tir.popcount(
+                        kernel_vec[co, dh, dw, kb, vc, ci].astype("int16")
+                        & data_vec[n, h, w, vh * HSTR + dh, vw * WSTR + dw, ib, ci].astype("int16")
+                    )
+                    - tvm.tir.popcount(
+                        ~kernel_vec[co, dh, dw, kb, vc, ci].astype("int16")
+                        & data_vec[n, h, w, vh * HSTR + dh, vw * WSTR + dw, ib, ci]
+                    ).astype("int16")
+                )
+                << (kb + ib).astype("int16")
+            ),
+            axis=[dh, dw, kb, ib, ci],
+        )
+
     if unipolar:
-        conv_vec = te.compute(ovshape, _unipolar_conv, name='conv_vec', tag='unipolar')
+        conv_vec = te.compute(ovshape, _unipolar_conv, name="conv_vec", tag="unipolar")
     else:
-        conv_vec = te.compute(ovshape, _bipolar_conv, name='conv_vec', tag='bipolar')
-
+        conv_vec = te.compute(ovshape, _bipolar_conv, name="conv_vec", tag="bipolar")
 
-    conv = te.compute(oshape,
-                      lambda n, h, w, co:
-                      conv_vec[n,
-                               idxd(h, VH), idxd(w, VW), idxd(co, VC),
-                               idxm(h, VH), idxm(w, VW), idxm(co, VC)].astype(out_dtype),
-                      name='conv', tag='spatial_bitserial_conv_nhwc')
+    conv = te.compute(
+        oshape,
+        lambda n, h, w, co: conv_vec[
+            n, idxd(h, VH), idxd(w, VW), idxd(co, VC), idxm(h, VH), idxm(w, VW), idxm(co, VC)
+        ].astype(out_dtype),
+        name="conv",
+        tag="spatial_bitserial_conv_nhwc",
+    )
 
     return conv
 
+
 def _intrin_popcount(m, k_i, w_b, x_b, unipolar):
-    pack_dtype = 'uint8'
-    w = te.placeholder((w_b, m, k_i), dtype=pack_dtype, name='w')
-    x = te.placeholder((x_b, k_i,), dtype=pack_dtype, name='x')
-    k = te.reduce_axis((0, k_i), name='k')
-    bw = te.reduce_axis((0, w_b), name='bw')
-    bx = te.reduce_axis((0, x_b), name='bx')
+    pack_dtype = "uint8"
+    w = te.placeholder((w_b, m, k_i), dtype=pack_dtype, name="w")
+    x = te.placeholder(
+        (
+            x_b,
+            k_i,
+        ),
+        dtype=pack_dtype,
+        name="x",
+    )
+    k = te.reduce_axis((0, k_i), name="k")
+    bw = te.reduce_axis((0, w_b), name="bw")
+    bx = te.reduce_axis((0, x_b), name="bx")
     if unipolar:
-        dtype = 'int16'
+        dtype = "int16"
         z = te.compute(
-            (m,), lambda i:
-            te.sum((tvm.tir.popcount(w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype)) -
-                    tvm.tir.popcount(~w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype)))
-                   << (bw+bx).astype(dtype), axis=[bw, bx, k]), name='z')
+            (m,),
+            lambda i: te.sum(
+                (
+                    tvm.tir.popcount(w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype))
+                    - tvm.tir.popcount(~w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype))
+                )
+                << (bw + bx).astype(dtype),
+                axis=[bw, bx, k],
+            ),
+            name="z",
+        )
     else:
-        dtype = 'uint16'
-        z = te.compute((m,), lambda i:
-                       te.sum(tvm.tir.popcount(w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype))
-                              << (bw+bx).astype(dtype), axis=[bw, bx, k]), name='z')
-    Wb = tvm.tir.decl_buffer(w.shape, w.dtype,
-                             name="W",
-                             offset_factor=k_i,
-                             strides=[te.var('ldw'), te.var('ldw'), 1]) # stride can be inferred
-    Xb = tvm.tir.decl_buffer(x.shape, x.dtype,
-                             name="X",
-                             offset_factor=k_i,
-                             strides=[te.var('ldw'), 1])
-    Zb = tvm.tir.decl_buffer(z.shape, z.dtype,
-                             name="Z",
-                             offset_factor=1,
-                             strides=[1])
+        dtype = "uint16"
+        z = te.compute(
+            (m,),
+            lambda i: te.sum(
+                tvm.tir.popcount(w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype))
+                << (bw + bx).astype(dtype),
+                axis=[bw, bx, k],
+            ),
+            name="z",
+        )
+    Wb = tvm.tir.decl_buffer(
+        w.shape, w.dtype, name="W", offset_factor=k_i, strides=[te.var("ldw"), te.var("ldw"), 1]
+    )  # stride can be inferred
+    Xb = tvm.tir.decl_buffer(
+        x.shape, x.dtype, name="X", offset_factor=k_i, strides=[te.var("ldw"), 1]
+    )
+    Zb = tvm.tir.decl_buffer(z.shape, z.dtype, name="Z", offset_factor=1, strides=[1])
 
     def _intrin_func(ins, outs):
         ww, xx = ins
         zz = outs[0]
 
-        args_2 = tvm.tir.const(2, 'uint32')
+        args_2 = tvm.tir.const(2, "uint32")
 
         if unipolar:
             vpadd = "llvm.arm.neon.vpadd.v8i8"
             vpadalu = "llvm.arm.neon.vpadals.v16i8.v8i16"
-            full_dtype = 'int8x16'
-            half_dtype = 'int8x8'
-            return_dtype = 'int16x8'
+            full_dtype = "int8x16"
+            half_dtype = "int8x8"
+            return_dtype = "int16x8"
         else:
             vpadd = "llvm.arm.neon.vpadd.v8u8"
             vpadalu = "llvm.arm.neon.vpadalu.v16u8.v8u16"
-            full_dtype = 'uint8x16'
-            half_dtype = 'uint8x8'
-            return_dtype = 'uint16x8'
+            full_dtype = "uint8x16"
+            half_dtype = "uint8x8"
+            return_dtype = "uint16x8"
 
         def _instr(index):
             irb = tvm.tir.ir_builder.create()
-            if index == 1: # reduce reset
+            if index == 1:  # reduce reset
                 irb.emit(zz.vstore(0, tvm.tir.const(0, return_dtype)))
                 return irb.get()
             # body and reduce update
@@ -225,60 +288,69 @@ def _intrin_popcount(m, k_i, w_b, x_b, unipolar):
                 for bx in range(x_b):
                     if k_i == 16:
                         for i in range(m):
-                            w_ = ww.vload([bw, i, 0], 'uint8x16').astype(full_dtype)
-                            x_ = xx.vload([bx, 0], 'uint8x16').astype(full_dtype)
+                            w_ = ww.vload([bw, i, 0], "uint8x16").astype(full_dtype)
+                            x_ = xx.vload([bx, 0], "uint8x16").astype(full_dtype)
                             if unipolar:
                                 cnts = tvm.tir.popcount(w_ & x_) - tvm.tir.popcount(~w_ & x_)
                             else:
                                 cnts = tvm.tir.popcount(w_ & x_)
-                            upper_half = tvm.tir.call_intrin(
-                                half_dtype, 'tir.vectorhigh', cnts)
-                            lower_half = tvm.tir.call_intrin(
-                                half_dtype, 'tir.vectorlow', cnts)
+                            upper_half = tvm.tir.call_intrin(half_dtype, "tir.vectorhigh", cnts)
+                            lower_half = tvm.tir.call_intrin(half_dtype, "tir.vectorlow", cnts)
                             cnts8[i] = upper_half + lower_half
-                        for i in range(m//2):
+                        for i in range(m // 2):
                             cnts4[i] = tvm.tir.call_llvm_pure_intrin(
-                                half_dtype, vpadd, args_2, cnts8[i*2], cnts8[i*2+1])
-                        for i in range(m//4):
+                                half_dtype, vpadd, args_2, cnts8[i * 2], cnts8[i * 2 + 1]
+                            )
+                        for i in range(m // 4):
                             cnts2[i] = tvm.tir.call_llvm_pure_intrin(
-                                half_dtype, vpadd, args_2, cnts4[i*2], cnts4[i*2+1])
+                                half_dtype, vpadd, args_2, cnts4[i * 2], cnts4[i * 2 + 1]
+                            )
                         cnts = tvm.tir.call_intrin(
-                            full_dtype, 'tir.vectorcombine', cnts2[0], cnts2[1])
-                        shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype)
+                            full_dtype, "tir.vectorcombine", cnts2[0], cnts2[1]
+                        )
+                        shifted_cnts = cnts << tvm.tir.const(bw + bx, pack_dtype)
                         out = tvm.tir.call_llvm_pure_intrin(
-                            return_dtype, vpadalu,
-                            args_2, zz.vload(0, return_dtype), shifted_cnts)
-                    else: # ki == 8
+                            return_dtype, vpadalu, args_2, zz.vload(0, return_dtype), shifted_cnts
+                        )
+                    else:  # ki == 8
                         for i in range(m):
-                            w_ = ww.vload([bw, i, 0], 'uint8x8').astype(half_dtype)
-                            x_ = xx.vload([bx, 0], 'uint8x8').astype(half_dtype)
+                            w_ = ww.vload([bw, i, 0], "uint8x8").astype(half_dtype)
+                            x_ = xx.vload([bx, 0], "uint8x8").astype(half_dtype)
                             if unipolar:
                                 cnts8[i] = tvm.tir.popcount(w_ & x_) - tvm.tir.popcount(~w_ & x_)
                             else:
                                 cnts8[i] = tvm.tir.popcount(w_ & x_)
-                        for i in range(m//2):
+                        for i in range(m // 2):
                             cnts4[i] = tvm.tir.call_llvm_pure_intrin(
-                                half_dtype, vpadd, args_2, cnts8[i*2], cnts8[i*2+1])
-                        for i in range(m//4):
+                                half_dtype, vpadd, args_2, cnts8[i * 2], cnts8[i * 2 + 1]
+                            )
+                        for i in range(m // 4):
                             cnts2[i] = tvm.tir.call_llvm_pure_intrin(
-                                half_dtype, vpadd, args_2, cnts4[i*2], cnts4[i*2+1])
+                                half_dtype, vpadd, args_2, cnts4[i * 2], cnts4[i * 2 + 1]
+                            )
                         cnts = tvm.tir.call_intrin(
-                            full_dtype, 'tir.vectorcombine', cnts2[0], cnts2[1])
-                        shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype)
+                            full_dtype, "tir.vectorcombine", cnts2[0], cnts2[1]
+                        )
+                        shifted_cnts = cnts << tvm.tir.const(bw + bx, pack_dtype)
                         out = tvm.tir.call_llvm_pure_intrin(
-                            return_dtype, vpadalu,
-                            args_2, zz.vload(0, return_dtype), shifted_cnts)
+                            return_dtype, vpadalu, args_2, zz.vload(0, return_dtype), shifted_cnts
+                        )
                     irb.emit(zz.vstore(0, out))
             return irb.get()
+
         # body, reset, update
         return _instr(0), _instr(1), _instr(2)
+
     buffer_params = {"offset_factor": 1}
     return te.decl_tensor_intrin(
-        z.op, _intrin_func, binds={w: Wb, x:Xb, z:Zb}, default_buffer_params=buffer_params)
+        z.op, _intrin_func, binds={w: Wb, x: Xb, z: Zb}, default_buffer_params=buffer_params
+    )
+
 
 # ARM specific schedule that using custom microkernel
-def _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec,
-                                  conv_out, output, last, unipolar):
+def _schedule_spatial_conv2d_nhwc(
+    cfg, s, data_pad, data_vec, kernel_vec, conv_out, output, last, unipolar
+):
     _, _, _, _, _, IB, CI = data_vec.shape
     _, KH, KW, KB, _, _ = kernel_vec.shape
     KB = get_const_int(KB)
@@ -307,20 +379,21 @@ def _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec,
     n, oh, ow, co, vh, vw, vc = s[conv_out].op.axis
     kh, kw, kb, ib, ci = s[conv_out].op.reduce_axis
 
-    ci_o, ci_i = cfg['tile_ci'].apply(s, conv_out, ci)
-    re_axes = cfg["reorder_0"].apply(s, conv_out,
-                                     [n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i])
+    ci_o, ci_i = cfg["tile_ci"].apply(s, conv_out, ci)
+    re_axes = cfg["reorder_0"].apply(
+        s, conv_out, [n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i]
+    )
 
     # Use microkernel
-    kfactor = cfg['tile_ci'].size[1]
+    kfactor = cfg["tile_ci"].size[1]
     if kfactor % 8 == 0:
         pc = _intrin_popcount(VC, kfactor, KB, IB, unipolar)
         s[conv_out].tensorize(kb, pc)
 
     n, h, w, co = s[last].op.axis
-    co, vc = cfg['tile_co'].apply(s, last, co)
-    oh, vh = cfg['tile_oh'].apply(s, last, h)
-    ow, vw = cfg['tile_ow'].apply(s, last, w)
+    co, vc = cfg["tile_co"].apply(s, last, co)
+    oh, vh = cfg["tile_oh"].apply(s, last, h)
+    ow, vw = cfg["tile_ow"].apply(s, last, w)
     s[last].reorder(n, oh, ow, co, vh, vw, vc)
     s[last].vectorize(vc)
     if last != output:
@@ -330,6 +403,7 @@ def _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec,
     s[last].parallel(oh)
     return s
 
+
 @autotvm.register_topi_schedule("bitserial_conv2d_nhwc.arm_cpu")
 def schedule_bitserial_conv2d_nhwc(cfg, outs):
     """Arm cpu schedule for bitserial conv2d"""
@@ -346,7 +420,7 @@ def schedule_bitserial_conv2d_nhwc(cfg, outs):
                 if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
 
-        if 'spatial_bitserial_conv_nhwc' in op.tag:
+        if "spatial_bitserial_conv_nhwc" in op.tag:
             output = op.output(0)
             conv_out = op.input_tensors[0]
             kernel_vec = conv_out.op.input_tensors[0]
@@ -360,13 +434,15 @@ def schedule_bitserial_conv2d_nhwc(cfg, outs):
                 data_q = data
                 data = data.op.input_tensors[0]
             unipolar = "unipolar" in conv_out.op.tag
-            _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec,
-                                          conv_out, output, outs[0], unipolar)
+            _schedule_spatial_conv2d_nhwc(
+                cfg, s, data_pad, data_vec, kernel_vec, conv_out, output, outs[0], unipolar
+            )
         scheduled_ops.append(op)
 
     traverse(outs[0].op)
     return s
 
+
 @bitserial_conv2d_legalize.register("arm_cpu")
 def _bitserial_conv2d_legalize(attrs, inputs, arg_types):
     """Legalizes Bitserial Conv2D op.
@@ -387,18 +463,18 @@ def _bitserial_conv2d_legalize(attrs, inputs, arg_types):
     """
 
     # Fix different kernel layouts where possible.
-    if attrs['data_layout'] == 'NHWC':
+    if attrs["data_layout"] == "NHWC":
         data, kernel = inputs
         if len(kernel.data.shape) == 4:
             # HWIO layout is expected for NHWC input.
-            if attrs['kernel_layout'] == 'HWOI':
+            if attrs["kernel_layout"] == "HWOI":
                 # Handle HWOI layout. This is common in TF depthwise conv2d graph.
                 kernel = relay.transpose(kernel, axes=(0, 1, 3, 2))
-            elif attrs['kernel_layout'] == 'OIHW':
+            elif attrs["kernel_layout"] == "OIHW":
                 kernel = relay.transpose(kernel, axes=(2, 3, 1, 0))
             ## Set new attrs for the transposed conv.
             new_attrs = {k: attrs[k] for k in attrs.keys()}
-            new_attrs['kernel_layout'] = 'HWIO'
+            new_attrs["kernel_layout"] = "HWIO"
 
             conv = relay.nn.bitserial_conv2d(data, kernel, **new_attrs)
             return conv
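The arm_cpu schedules above lean heavily on the AutoTVM cfg.define_split / cfg[...].apply pattern for their tuning knobs. A minimal, self-contained sketch of that pattern on a plain matmul, assuming a standard TVM/AutoTVM install (the template name and shapes are illustrative only):

    import tvm
    from tvm import te, autotvm

    @autotvm.template("example/matmul")  # illustrative task name
    def matmul(N, L, M, dtype="float32"):
        A = te.placeholder((N, L), name="A", dtype=dtype)
        B = te.placeholder((L, M), name="B", dtype=dtype)
        k = te.reduce_axis((0, L), name="k")
        C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")

        s = te.create_schedule(C.op)
        y, x = s[C].op.axis
        k = s[C].op.reduce_axis[0]

        cfg = autotvm.get_config()
        # Declare the search space, then apply whichever split the tuner picked.
        cfg.define_split("tile_y", y, num_outputs=2)
        cfg.define_split("tile_x", x, num_outputs=2)
        yo, yi = cfg["tile_y"].apply(s, C, y)
        xo, xi = cfg["tile_x"].apply(s, C, x)
        s[C].reorder(yo, xo, k, yi, xi)
        return s, [A, B, C]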
index c7aa567..61778b7 100644 (file)
@@ -26,9 +26,9 @@ from .bitserial_conv2d import _intrin_popcount
 from ..nn.pad import pad
 from ..nn.bitserial_util import bitpack, binary_op_multiplier
 
-@autotvm.register_topi_compute('bitserial_dense.arm_cpu')
-def bitserial_dense(cfg, data, weight, data_bits, weight_bits, pack_dtype, out_dtype,
-                    unipolar):
+
+@autotvm.register_topi_compute("bitserial_dense.arm_cpu")
+def bitserial_dense(cfg, data, weight, data_bits, weight_bits, pack_dtype, out_dtype, unipolar):
     """The default implementation of bitserial dense in topi.
 
     Parameters
@@ -57,7 +57,7 @@ def bitserial_dense(cfg, data, weight, data_bits, weight_bits, pack_dtype, out_d
     # out_dim and in_dim need to be multiples of 8
     if out_dim % 8 != 0:
         out_dim_pad = out_dim % 8
-        data_packed = pad(data_packed, [0, 0, 0], [out_dim_pad, 0, 0], name='PaddedInput')
+        data_packed = pad(data_packed, [0, 0, 0], [out_dim_pad, 0, 0], name="PaddedInput")
         out_dim += out_dim_pad
 
     ######## Search space
@@ -65,43 +65,71 @@ def bitserial_dense(cfg, data, weight, data_bits, weight_bits, pack_dtype, out_d
     x, y = cfg.axis(batch), cfg.axis(out_dim)
     db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(in_dim)
 
-    ko, ki = cfg.define_split('tile_k', k, num_outputs=2,
-                              filter=lambda xx: xx.size[-1] == 8 or xx.size[-1] == 16)
-    xo, xi = cfg.define_split('tile_x', x, num_outputs=2)
-    yo, yi = cfg.define_split('tile_y', y, num_outputs=2,
-                              filter=lambda xx: xx.size[-1] == 8)
-
-    cfg.define_reorder('reorder_0', [yo, xo, ko, xi, wb, db, yi, ki],
-                       policy='candidate', candidate=[
-                           [yo, xo, ko, xi, wb, db, yi, ki],
-                           [yo, xo, xi, ko, wb, db, yi, ki],
-                           [yo, xo, ko, xi, wb, db, yi, ki]])
+    ko, ki = cfg.define_split(
+        "tile_k", k, num_outputs=2, filter=lambda xx: xx.size[-1] == 8 or xx.size[-1] == 16
+    )
+    xo, xi = cfg.define_split("tile_x", x, num_outputs=2)
+    yo, yi = cfg.define_split("tile_y", y, num_outputs=2, filter=lambda xx: xx.size[-1] == 8)
+
+    cfg.define_reorder(
+        "reorder_0",
+        [yo, xo, ko, xi, wb, db, yi, ki],
+        policy="candidate",
+        candidate=[
+            [yo, xo, ko, xi, wb, db, yi, ki],
+            [yo, xo, xi, ko, wb, db, yi, ki],
+            [yo, xo, ko, xi, wb, db, yi, ki],
+        ],
+    )
 
     ###### Compute rule
-    VY = cfg['tile_y'].size[-1]
-    VK = cfg['tile_k'].size[-1]
+    VY = cfg["tile_y"].size[-1]
+    VK = cfg["tile_k"].size[-1]
 
-    wvshape = (out_dim//VY, in_dim//VK, WB, VY, VK)
+    wvshape = (out_dim // VY, in_dim // VK, WB, VY, VK)
     oshape = (batch, out_dim)
 
-    k = te.reduce_axis((0, in_dim), name='k')
-    db = te.reduce_axis((0, DB), name='db')
-    wb = te.reduce_axis((0, WB), name='wb')
+    k = te.reduce_axis((0, in_dim), name="k")
+    db = te.reduce_axis((0, DB), name="db")
+    wb = te.reduce_axis((0, WB), name="wb")
 
     # Tile data and weights
-    weight_vec = te.compute(wvshape, lambda yo, ko, wb, vy, vk:
-                            weight_packed[yo*VY+vy][wb][ko*VK+vk], name='weight_vec')
-    matmul_unipolar = te.compute(oshape, lambda x, y: te.sum(
-        (tvm.tir.popcount(weight_vec[y//VY, k//VK, wb, y%VY, k%VK].astype(out_dtype) &
-                          data_packed[x, db, k].astype(out_dtype)) -
-         tvm.tir.popcount(~weight_vec[y//VY, k//VK, wb, y%VY, k%VK].astype(out_dtype) &
-                          data_packed[x, db, k].astype(out_dtype)))
-        << (wb+db).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense_unipolar')
-
-    matmul = te.compute(oshape, lambda x, y: te.sum(
-        tvm.tir.popcount(weight_vec[y//VY, k//VK, wb, y%VY, k%VK].astype(out_dtype) &
-                         data_packed[x, db, k].astype(out_dtype))
-        << (wb+db).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')
+    weight_vec = te.compute(
+        wvshape,
+        lambda yo, ko, wb, vy, vk: weight_packed[yo * VY + vy][wb][ko * VK + vk],
+        name="weight_vec",
+    )
+    matmul_unipolar = te.compute(
+        oshape,
+        lambda x, y: te.sum(
+            (
+                tvm.tir.popcount(
+                    weight_vec[y // VY, k // VK, wb, y % VY, k % VK].astype(out_dtype)
+                    & data_packed[x, db, k].astype(out_dtype)
+                )
+                - tvm.tir.popcount(
+                    ~weight_vec[y // VY, k // VK, wb, y % VY, k % VK].astype(out_dtype)
+                    & data_packed[x, db, k].astype(out_dtype)
+                )
+            )
+            << (wb + db).astype(out_dtype),
+            axis=[wb, db, k],
+        ),
+        tag="bitserial_dense_unipolar",
+    )
+
+    matmul = te.compute(
+        oshape,
+        lambda x, y: te.sum(
+            tvm.tir.popcount(
+                weight_vec[y // VY, k // VK, wb, y % VY, k % VK].astype(out_dtype)
+                & data_packed[x, db, k].astype(out_dtype)
+            )
+            << (wb + db).astype(out_dtype),
+            axis=[wb, db, k],
+        ),
+        tag="bitserial_dense",
+    )
 
     cfg.add_flop(batch * out_dim * in_dim * binary_op_multiplier(pack_dtype))
 
@@ -110,7 +138,7 @@ def bitserial_dense(cfg, data, weight, data_bits, weight_bits, pack_dtype, out_d
     return matmul
 
 
-@autotvm.register_topi_schedule('bitserial_dense.arm_cpu')
+@autotvm.register_topi_schedule("bitserial_dense.arm_cpu")
 def schedule_bitserial_dense(cfg, outs):
     """Schedule for binary_dense.
 
@@ -148,8 +176,8 @@ def schedule_bitserial_dense(cfg, outs):
         fused = s[output].fuse(xo, yo)
         s[output].parallel(fused)
 
-        nfactor = cfg['tile_y'].size[-1]
-        kfactor = cfg['tile_k'].size[-1]
+        nfactor = cfg["tile_y"].size[-1]
+        kfactor = cfg["tile_k"].size[-1]
         if nfactor % 8 == 0:
             pc = _intrin_popcount(nfactor, kfactor, WB, DB, unipolar)
             s[output].tensorize(wb, pc)
@@ -159,14 +187,14 @@ def schedule_bitserial_dense(cfg, outs):
     def traverse(op):
         """Internal traverse function"""
         # inline all one-to-one-mapping operators except the last stage (output)
-        if tag.is_broadcast(op.tag) or 'elemwise' in op.tag:
+        if tag.is_broadcast(op.tag) or "elemwise" in op.tag:
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
                 if isinstance(tensor.op, tvm.te.ComputeOp):
                     traverse(tensor.op)
 
-        elif op.tag == 'bitserial_dense' or 'bitserial_dense_unipolar':
+        elif op.tag == "bitserial_dense" or "bitserial_dense_unipolar":
             output = op.output(0)
             weight_vec = op.input_tensors[0]
 
@@ -174,7 +202,7 @@ def schedule_bitserial_dense(cfg, outs):
             data = data_vec.op.input_tensors[0]
             if "QuantizeInput" in data.op.name:
                 data = data.op.input_tensors[0]
-            unipolar = (output.op.tag == 'bitserial_dense_unipolar')
+            unipolar = output.op.tag == "bitserial_dense_unipolar"
             _schedule(cfg, s, data_vec, weight_vec, output, unipolar)
         else:
             raise RuntimeError("Unsupported operator: %s" % op.tag)
index 4faee42..1eb8c8a 100644 (file)
@@ -27,18 +27,21 @@ from ..util import traverse_inline, get_const_tuple
 from .. import nn
 from ..nn.util import get_const_int, get_pad_tuple
 from ..nn.winograd_util import winograd_transform_matrices
-from .conv2d_spatial_pack import conv2d_spatial_pack_nchw, \
-    conv2d_spatial_pack_nhwc, \
-    schedule_conv2d_spatial_pack_nchw, \
-    schedule_conv2d_spatial_pack_nhwc
+from .conv2d_spatial_pack import (
+    conv2d_spatial_pack_nchw,
+    conv2d_spatial_pack_nhwc,
+    schedule_conv2d_spatial_pack_nchw,
+    schedule_conv2d_spatial_pack_nhwc,
+)
 from .cortex_m7.conv2d import direct_simd
 
 
 @autotvm.register_topi_compute("conv2d_nchw_spatial_pack.arm_cpu")
 def conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype):
     """Compute conv2d with NCHW layout"""
-    return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
-                                    dilation, out_dtype, num_tile=2)
+    return conv2d_spatial_pack_nchw(
+        cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2
+    )
 
 
 @autotvm.register_topi_schedule("conv2d_nchw_spatial_pack.arm_cpu")
@@ -48,7 +51,7 @@ def schedule_conv2d_nchw_spatial_pack(cfg, outs):
 
     def _callback(op):
         # schedule conv2d
-        if 'spatial_conv2d_output' in op.tag:
+        if "spatial_conv2d_output" in op.tag:
             output = op.output(0)
             conv = op.input_tensors[0]
 
@@ -57,15 +60,14 @@ def schedule_conv2d_nchw_spatial_pack(cfg, outs):
             s[data_pad].compute_inline()
 
             kernel_vec = conv.op.input_tensors[1]
-            if kernel_vec.op.name == 'kernel_vec':
+            if kernel_vec.op.name == "kernel_vec":
                 kernel = kernel_vec.op.input_tensors[0]
             else:
                 kernel = kernel_vec
             if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
                 s[kernel].compute_inline()
 
-            schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec,
-                                              conv, output, outs[0])
+            schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec, conv, output, outs[0])
 
     traverse_inline(s, outs[0].op, _callback)
     return s
@@ -74,8 +76,7 @@ def schedule_conv2d_nchw_spatial_pack(cfg, outs):
 @autotvm.register_topi_compute("conv2d_nhwc_spatial_pack.arm_cpu")
 def conv2d_nhwc_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype):
     """Compute conv2d with NHWC layout"""
-    return conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding,
-                                    dilation, out_dtype)
+    return conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype)
 
 
 @autotvm.register_topi_schedule("conv2d_nhwc_spatial_pack.arm_cpu")
@@ -84,7 +85,7 @@ def schedule_conv2d_nhwc_spatial_pack(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'spatial_conv_output_NHWC' in op.tag:
+        if "spatial_conv_output_NHWC" in op.tag:
             schedule_conv2d_spatial_pack_nhwc(cfg, s, op, outs[0])
 
     traverse_inline(s, outs[0].op, _callback)
@@ -95,8 +96,7 @@ def schedule_conv2d_nhwc_spatial_pack(cfg, outs):
 def conv2d_nchw_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype):
     """Compute conv2d_nchw layout using Winograd with weight transform"""
     tile_size = 4
-    return _decl_winograd(cfg, data, kernel, strides, padding, dilation,
-                          out_dtype, tile_size)
+    return _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, tile_size)
 
 
 @autotvm.register_topi_schedule("conv2d_nchw_winograd.arm_cpu")
@@ -105,7 +105,7 @@ def schedule_conv2d_nchw_winograd(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'winograd_conv2d_output' in op.tag:
+        if "winograd_conv2d_output" in op.tag:
             output = op.output(0)
             _schedule_winograd(cfg, s, output, outs[0])
 
@@ -151,24 +151,28 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, til
 
     H = (IH + pt + pb - 3) // HSTR + 1
     W = (IW + pl + pr - 3) // WSTR + 1
-    nH, nW = (H + m-1) // m, (W + m-1) // m
+    nH, nW = (H + m - 1) // m, (W + m - 1) // m
     P = N * nH * nW
 
-    cfg.define_split('tile_p', cfg.axis(P), num_outputs=2, filter=lambda x: x.size[-1] <= 16)
-    cfg.define_split('tile_k', cfg.axis(K), num_outputs=2, filter=lambda x: x.size[-1] <= 16)
-    VP = cfg['tile_p'].size[-1]
-    VK = cfg['tile_k'].size[-1]
+    cfg.define_split("tile_p", cfg.axis(P), num_outputs=2, filter=lambda x: x.size[-1] <= 16)
+    cfg.define_split("tile_k", cfg.axis(K), num_outputs=2, filter=lambda x: x.size[-1] <= 16)
+    VP = cfg["tile_p"].size[-1]
+    VK = cfg["tile_k"].size[-1]
 
     # pack input tile
-    input_tile = te.compute((C, idxd(P, VP), alpha, alpha, VP),
-                            lambda c, b, eps, nu, bb:
-                            data_pad[idxd(b*VP + bb, nH*nW), c,
-                                     idxm(idxd(b*VP + bb, nW), nH) * m + eps,
-                                     idxm(b*VP + bb, nW) * m + nu],
-                            name='d')
+    input_tile = te.compute(
+        (C, idxd(P, VP), alpha, alpha, VP),
+        lambda c, b, eps, nu, bb: data_pad[
+            idxd(b * VP + bb, nH * nW),
+            c,
+            idxm(idxd(b * VP + bb, nW), nH) * m + eps,
+            idxm(b * VP + bb, nW) * m + nu,
+        ],
+        name="d",
+    )
 
     if autotvm.GLOBAL_SCOPE.in_tuning:
-        VC = cfg['tile_k'].size[-1]
+        VC = cfg["tile_k"].size[-1]
         kvshape = (KH + tile_size - 1, KW + tile_size - 1, idxd(CO, VC), CI, VC)
         U = tvm.te.placeholder(kvshape, kernel.dtype, name="U")
     else:
@@ -176,37 +180,60 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, til
         if pre_computed:
             U = kernel
         else:
-            r_kh = te.reduce_axis((0, KH), 'r_kh')
-            r_kw = te.reduce_axis((0, KW), 'r_kw')
-            U = te.compute((alpha, alpha, idxd(K, VK), C, VK), lambda eps, nu, k, c, kk:
-                           te.sum(kernel[k * VK + kk][c][r_kh][r_kw].astype(out_dtype) *
-                                  G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), name='U')
+            r_kh = te.reduce_axis((0, KH), "r_kh")
+            r_kw = te.reduce_axis((0, KW), "r_kw")
+            U = te.compute(
+                (alpha, alpha, idxd(K, VK), C, VK),
+                lambda eps, nu, k, c, kk: te.sum(
+                    kernel[k * VK + kk][c][r_kh][r_kw].astype(out_dtype)
+                    * G[eps][r_kh]
+                    * G[nu][r_kw],
+                    axis=[r_kh, r_kw],
+                ),
+                name="U",
+            )
 
     # transform image
-    r_eps = te.reduce_axis((0, alpha), 'r_eps')
-    r_nu = te.reduce_axis((0, alpha), 'r_nu')
-    V = te.compute((alpha, alpha, idxd(P, VP), C, VP), lambda eps, nu, b, c, bb:
-                   te.sum(input_tile[c][b][r_eps][r_nu][bb].astype(out_dtype) *
-                          B[r_eps][eps] * B[r_nu][nu], axis=[r_eps, r_nu]), name='V')
+    r_eps = te.reduce_axis((0, alpha), "r_eps")
+    r_nu = te.reduce_axis((0, alpha), "r_nu")
+    V = te.compute(
+        (alpha, alpha, idxd(P, VP), C, VP),
+        lambda eps, nu, b, c, bb: te.sum(
+            input_tile[c][b][r_eps][r_nu][bb].astype(out_dtype) * B[r_eps][eps] * B[r_nu][nu],
+            axis=[r_eps, r_nu],
+        ),
+        name="V",
+    )
 
     # batch gemm
-    c = te.reduce_axis((0, C), name='c')
-    M = te.compute((alpha, alpha, K, P), lambda eps, nu, k, b:
-                   te.sum(U[eps][nu][idxd(k, VK)][c][idxm(k, VK)] *
-                          V[eps][nu][idxd(b, VP)][c][idxm(b, VP)], axis=c), name='M')
+    c = te.reduce_axis((0, C), name="c")
+    M = te.compute(
+        (alpha, alpha, K, P),
+        lambda eps, nu, k, b: te.sum(
+            U[eps][nu][idxd(k, VK)][c][idxm(k, VK)] * V[eps][nu][idxd(b, VP)][c][idxm(b, VP)],
+            axis=c,
+        ),
+        name="M",
+    )
 
     # inverse transform
-    r_eps = te.reduce_axis((0, alpha), 'r_eps')
-    r_nu = te.reduce_axis((0, alpha), 'r_nu')
-    Y = te.compute((K, P, m, m), lambda k, b, vh, vw:
-                   te.sum(M[r_eps][r_nu][k][b] * A[r_eps][vh] * A[r_nu][vw],
-                          axis=[r_eps, r_nu]), name='Y')
+    r_eps = te.reduce_axis((0, alpha), "r_eps")
+    r_nu = te.reduce_axis((0, alpha), "r_nu")
+    Y = te.compute(
+        (K, P, m, m),
+        lambda k, b, vh, vw: te.sum(
+            M[r_eps][r_nu][k][b] * A[r_eps][vh] * A[r_nu][vw], axis=[r_eps, r_nu]
+        ),
+        name="Y",
+    )
 
     # unpack output
-    output = te.compute((N, K, H, W), lambda n, k, h, w:
-                        Y[k][n * nH * nW + idxd(h, m) * nW + idxd(w, m),
-                             idxm(h, m), idxm(w, m)],
-                        name='output', tag='winograd_conv2d_output')
+    output = te.compute(
+        (N, K, H, W),
+        lambda n, k, h, w: Y[k][n * nH * nW + idxd(h, m) * nW + idxd(w, m), idxm(h, m), idxm(w, m)],
+        name="output",
+        tag="winograd_conv2d_output",
+    )
 
     # we have to manually assign effective GFLOP for winograd
     cfg.add_flop(2 * N * K * H * W * KH * KW * C)
@@ -230,11 +257,17 @@ def _schedule_winograd(cfg, s, output, last):
     if isinstance(U.op, tvm.te.ComputeOp):
         kernel, G = U.op.input_tensors
         s[G].compute_inline()
-        eps, nu, k, c, kk, = s[U].op.axis
+        (
+            eps,
+            nu,
+            k,
+            c,
+            kk,
+        ) = s[U].op.axis
         if autotvm.GLOBAL_SCOPE.in_tuning:
             # kernel transformation will be pre-computed during compilation, so we skip
             # this part to make tuning records correct
-            s[U].pragma(eps, 'debug_skip_region')
+            s[U].pragma(eps, "debug_skip_region")
         else:
             r_kh, r_kw = s[U].op.reduce_axis
             s[U].reorder(k, c, eps, nu, r_kh, r_kw, kk)
@@ -247,7 +280,7 @@ def _schedule_winograd(cfg, s, output, last):
             s[kernel].compute_inline()
 
     # transform image
-    DD = s.cache_read(d, 'global', [V])
+    DD = s.cache_read(d, "global", [V])
     s[B].compute_inline()
     eps, nu, b, c, bb = s[V].op.axis
     r_eps, r_nu = s[V].op.reduce_axis
@@ -261,17 +294,14 @@ def _schedule_winograd(cfg, s, output, last):
     # batch gemm
     eps, nu, k, b = s[M].op.axis
     c = s[M].op.reduce_axis[0]
-    cfg.define_split('tile_c', c, num_outputs=2, filter=lambda x: x.size[-1] <= 16)
-    co, ci = cfg['tile_c'].apply(s, M, c)
-    xo, xi = cfg['tile_p'].apply(s, M, b)
+    cfg.define_split("tile_c", c, num_outputs=2, filter=lambda x: x.size[-1] <= 16)
+    co, ci = cfg["tile_c"].apply(s, M, c)
+    xo, xi = cfg["tile_p"].apply(s, M, b)
     s[M].reorder(eps, nu, xo, co, k, ci, xi)
-    cfg.define_annotate('ann_reduce', [ci], policy='try_unroll')
-    cfg.define_annotate('ann_spatial', [k, xi], policy='try_unroll_vec')
-    cfg['ann_reduce'].apply(s, M, [ci],
-                            axis_lens=[cfg['tile_c'].size[-1]],
-                            max_unroll=16,
-                            cfg=cfg)
-    cfg['ann_spatial'].apply(s, M, [k, xi])
+    cfg.define_annotate("ann_reduce", [ci], policy="try_unroll")
+    cfg.define_annotate("ann_spatial", [k, xi], policy="try_unroll_vec")
+    cfg["ann_reduce"].apply(s, M, [ci], axis_lens=[cfg["tile_c"].size[-1]], max_unroll=16, cfg=cfg)
+    cfg["ann_spatial"].apply(s, M, [k, xi])
 
     # inverse transform
     s[A].compute_inline()
@@ -282,12 +312,12 @@ def _schedule_winograd(cfg, s, output, last):
 
     # output
     n, co, h, w = s[last].op.axis
-    co, coi = cfg['tile_k'].apply(s, last, co)
+    co, coi = cfg["tile_k"].apply(s, last, co)
     p = s[last].fuse(n, co)
     s[M].compute_at(s[last], p)
     s[last].parallel(p)
 
-    MM = s.cache_read(M, 'global', [Y])
+    MM = s.cache_read(M, "global", [Y])
     m = get_const_int(V.shape[0]) + 1 - 3
     ho, wo, hi, wi = s[last].tile(h, w, m, m)
     s[Y].compute_at(s[last], wo)
@@ -303,15 +333,28 @@ def conv2d_nchw_winograd_nnpack(cfg, data, kernel, strides, padding, dilation, o
     dtype = data.dtype
     if dtype == "float32":
         return _conv2d_arm_cpu_winograd_nnpack(
-            cfg, data, kernel, strides, padding, dilation, out_dtype,
-            tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8)
+            cfg,
+            data,
+            kernel,
+            strides,
+            padding,
+            dilation,
+            out_dtype,
+            tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8,
+        )
     elif dtype == "float16":
         return _conv2d_arm_cpu_winograd_nnpack(
-            cfg, data, kernel, strides, padding, dilation, out_dtype,
-            tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8_FP16)
+            cfg,
+            data,
+            kernel,
+            strides,
+            padding,
+            dilation,
+            out_dtype,
+            tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8_FP16,
+        )
     else:
-        raise ValueError("Unsupported data type {} for conv2d winograd nnpack".
-                         format(dtype))
+        raise ValueError("Unsupported data type {} for conv2d winograd nnpack".format(dtype))
 
 
 @autotvm.register_topi_schedule("conv2d_nchw_winograd_nnpack.arm_cpu")
@@ -320,7 +363,7 @@ def schedule_conv2d_nchw_winograd_nnpack(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'winograd_nnpack_conv2d_output' in op.tag:
+        if "winograd_nnpack_conv2d_output" in op.tag:
             output = op.output(0)
             _schedule_winograd_nnpack(cfg, s, output, outs[0])
 
@@ -329,7 +372,8 @@ def schedule_conv2d_nchw_winograd_nnpack(cfg, outs):
 
 
 def _conv2d_arm_cpu_winograd_nnpack(
-        cfg, data, kernel, strides, padding, dilation, out_dtype, convolution_algorithm):
+    cfg, data, kernel, strides, padding, dilation, out_dtype, convolution_algorithm
+):
     """ TOPI compute callback. Use winograd NNPACK template """
     N, CI, IH, IW = get_const_tuple(data.shape)
 
@@ -343,27 +387,38 @@ def _conv2d_arm_cpu_winograd_nnpack(
     HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
     pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
 
-    assert KH == 3 and KW == 3 and pt == 1 and pb == 1 and pl == 1 and pr == 1 and HSTR == 1\
+    assert (
+        KH == 3
+        and KW == 3
+        and pt == 1
+        and pb == 1
+        and pl == 1
+        and pr == 1
+        and HSTR == 1
         and WSTR == 1
+    )
     H = (IH + pt + pb - 3) // HSTR + 1
     W = (IW + pl + pr - 3) // WSTR + 1
 
-    cfg.define_knob('winograd_nnpack_algorithm', [convolution_algorithm])
+    cfg.define_knob("winograd_nnpack_algorithm", [convolution_algorithm])
 
     assert N == 1
     with tvm.te.tag_scope("winograd_nnpack_conv2d_weight_transform"):
         transformed_kernel = tvm.contrib.nnpack.convolution_inference_weight_transform(
-            kernel, algorithm=cfg['winograd_nnpack_algorithm'].val)
+            kernel, algorithm=cfg["winograd_nnpack_algorithm"].val
+        )
         if autotvm.GLOBAL_SCOPE.in_tuning:
             transformed_kernel = te.compute(transformed_kernel.shape, lambda *args: 0.0)
 
     with tvm.te.tag_scope("winograd_nnpack_conv2d_output"):
         output = tvm.contrib.nnpack.convolution_inference_without_weight_transform(
-            data, transformed_kernel,
+            data,
+            transformed_kernel,
             bias=None,
             padding=[pt, pb, pl, pr],
             stride=[HSTR, WSTR],
-            algorithm=cfg['winograd_nnpack_algorithm'].val)
+            algorithm=cfg["winograd_nnpack_algorithm"].val,
+        )
 
     # we have to manually assign effective GFLOP for winograd
     cfg.add_flop(2 * N * CI * H * W * KH * KW * CO)
@@ -380,12 +435,13 @@ def _schedule_winograd_nnpack(cfg, s, output, last):
     if autotvm.GLOBAL_SCOPE.in_tuning and isinstance(TK.op, te.tensor.ComputeOp):
         # kernel transformation will be pre-computed during compilation, so we skip
         # this part to make tuning records correct
-        s[TK].pragma(s[TK].op.axis[0], 'debug_skip_region')
+        s[TK].pragma(s[TK].op.axis[0], "debug_skip_region")
 
 
 @autotvm.register_topi_compute("conv2d_nchw_winograd_nnpack_without_weight_transform.arm_cpu")
 def conv2d_nchw_winograd_nnpack_without_weight_transform(
-        cfg, data, transformed_kernel, bias, strides, padding, dilation, out_dtype):
+    cfg, data, transformed_kernel, bias, strides, padding, dilation, out_dtype
+):
     """Compute conv2d_nchw using NNPack winograd without weight transform"""
     N, CI, IH, IW = get_const_tuple(data.shape)
     if isinstance(dilation, int):
@@ -399,8 +455,16 @@ def conv2d_nchw_winograd_nnpack_without_weight_transform(
     KH, KW = 3, 3
     pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
 
-    assert KH == 3 and KW == 3 and pt == 1 and pb == 1 and pl == 1 and pr == 1 and HSTR == 1\
+    assert (
+        KH == 3
+        and KW == 3
+        and pt == 1
+        and pb == 1
+        and pl == 1
+        and pr == 1
+        and HSTR == 1
         and WSTR == 1
+    )
     H = (IH + pt + pb - 3) // HSTR + 1
     W = (IW + pl + pr - 3) // WSTR + 1
 
@@ -412,7 +476,8 @@ def conv2d_nchw_winograd_nnpack_without_weight_transform(
             bias=bias,
             padding=[pt, pb, pl, pr],
             stride=[HSTR, WSTR],
-            algorithm=cfg['winograd_nnpack_algorithm'].val)
+            algorithm=cfg["winograd_nnpack_algorithm"].val,
+        )
 
     # we have to manually assign effective GFLOP for winograd
     cfg.add_flop(2 * N * CI * H * W * KH * KW * CO)
@@ -425,18 +490,20 @@ def schedule_conv2d_nchw_winograd_nnpack_without_weight_transform(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'winograd_nnpack_conv2d_output' in op.tag:
+        if "winograd_nnpack_conv2d_output" in op.tag:
             output = op.output(0)
             _schedule_winograd_nnpack(cfg, s, output, outs[0])
 
     traverse_inline(s, outs[0].op, _callback)
     return s
 
+
 @autotvm.register_topi_compute("conv2d_direct_simd.arm_cpu")
 def conv2d_direct_simd(cfg, data, kernel, strides, padding, dilation, out_dtype):
     """Compute conv2d with SIMD (v7e-m)."""
     return direct_simd.conv2d_direct_simd_compute(
-        cfg, data, kernel, strides, padding, dilation, out_dtype)
+        cfg, data, kernel, strides, padding, dilation, out_dtype
+    )
 
 
 @autotvm.register_topi_schedule("conv2d_direct_simd.arm_cpu")
index f37ae57..7bf7e42 100644 (file)
@@ -28,7 +28,7 @@ from ..nn import conv2d_alter_layout
 from ..util import get_const_tuple
 from ..x86.conv2d import _get_default_config as _get_x86_default_config
 
-logger = logging.getLogger('topi')
+logger = logging.getLogger("topi")
 
 
 @conv2d_alter_layout.register(["arm_cpu"])
@@ -37,7 +37,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
     dispatch_ctx = autotvm.task.DispatchContext.current
 
     _, outs = relay.backend.compile_engine.select_implementation(
-        relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)
+        relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target
+    )
     workload = autotvm.task.get_workload(outs)
     if workload is None:
         # The best implementation is not an AutoTVM template,
@@ -67,45 +68,51 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
     idxd = tvm.tir.indexdiv
 
     # We don't perform layout alteration for NHWC layout with real data types
-    if data_layout == "NHWC" and data_dtype not in ['uint8', 'int8']:
+    if data_layout == "NHWC" and data_dtype not in ["uint8", "int8"]:
         return None
 
     if topi_tmpl == "conv2d_nchw_spatial_pack.arm_cpu":
         assert data_layout == "NCHW" and kernel_layout == "OIHW"
         N, CI, H, W = get_const_tuple(data.shape)
         CO, _, KH, KW = get_const_tuple(kernel.shape)
-        VC = cfg['tile_co'].size[-1]
+        VC = cfg["tile_co"].size[-1]
 
-        new_attrs['kernel_layout'] = 'OIHW%do' % VC
+        new_attrs["kernel_layout"] = "OIHW%do" % VC
 
         new_data = data
         new_kernel = te.placeholder((idxd(CO, VC), CI, KH, KW, VC), dtype=kernel.dtype)
         new_workload = autotvm.task.args_to_workload(
             [new_data, new_kernel, strides, padding, dilation, out_dtype],
-            "conv2d_nchw_spatial_pack.arm_cpu")
+            "conv2d_nchw_spatial_pack.arm_cpu",
+        )
         dispatch_ctx.update(target, new_workload, cfg)
 
         return relay.nn.conv2d(*inputs, **new_attrs)
 
     if topi_tmpl == "conv2d_nhwc_spatial_pack.arm_cpu":
-        assert (data.dtype == 'int8' and kernel.dtype == 'int8' or
-                data.dtype == 'uint8' and kernel.dtype == 'uint8')
+        assert (
+            data.dtype == "int8"
+            and kernel.dtype == "int8"
+            or data.dtype == "uint8"
+            and kernel.dtype == "uint8"
+        )
 
         assert data_layout == "NHWC" and kernel_layout == "HWIO"
 
         data_expr, kernel_expr = inputs
 
-        data_int16 = relay.cast(data_expr, dtype='int16')
-        kernel_int16 = relay.cast(kernel_expr, dtype='int16')
+        data_int16 = relay.cast(data_expr, dtype="int16")
+        kernel_int16 = relay.cast(kernel_expr, dtype="int16")
 
-        new_attrs = {k : attrs[k] for k in attrs.keys()}
+        new_attrs = {k: attrs[k] for k in attrs.keys()}
 
-        new_data = te.placeholder(data.shape, 'int16')
-        new_kernel = te.placeholder(kernel.shape, 'int16')
+        new_data = te.placeholder(data.shape, "int16")
+        new_kernel = te.placeholder(kernel.shape, "int16")
 
         new_workload = autotvm.task.args_to_workload(
             [new_data, new_kernel, strides, padding, dilation, out_dtype],
-            'conv2d_nhwc_spatial_pack.arm_cpu')
+            "conv2d_nhwc_spatial_pack.arm_cpu",
+        )
         dispatch_ctx.update(target, new_workload, cfg)
 
         return relay.nn.conv2d(data_int16, kernel_int16, **new_attrs)
@@ -114,74 +121,79 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         assert data_layout == "NCHW" and kernel_layout == "OIHW"
         N, CI, H, W = get_const_tuple(data.shape)
         CO, _, KH, KW = get_const_tuple(kernel.shape)
-        VC = cfg['tile_k'].size[-1]
+        VC = cfg["tile_k"].size[-1]
         tile_size = 4
 
         weight_expr = inputs[1]
         weight_expr = relay.nn.contrib_conv2d_winograd_weight_transform(
-            weight_expr, tile_size=tile_size)
-        weight_expr = relay.reshape(weight_expr,
-                                    newshape=(KH + tile_size - 1,
-                                              KW + tile_size - 1,
-                                              CO // VC, VC, CI))
+            weight_expr, tile_size=tile_size
+        )
+        weight_expr = relay.reshape(
+            weight_expr, newshape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI)
+        )
         weight_expr = relay.transpose(weight_expr, axes=[0, 1, 2, 4, 3])
 
-        new_attrs['tile_size'] = tile_size
-        new_attrs['channels'] = CO
+        new_attrs["tile_size"] = tile_size
+        new_attrs["channels"] = CO
 
         new_data = data
-        new_kernel = te.placeholder((KH + tile_size - 1,
-                                     KW + tile_size -1,
-                                     idxd(CO, VC), CI, VC),
-                                    kernel.dtype)
+        new_kernel = te.placeholder(
+            (KH + tile_size - 1, KW + tile_size - 1, idxd(CO, VC), CI, VC), kernel.dtype
+        )
         new_workload = autotvm.task.args_to_workload(
             [new_data, new_kernel, strides, padding, dilation, out_dtype],
-            'conv2d_nchw_winograd.arm_cpu')
+            "conv2d_nchw_winograd.arm_cpu",
+        )
         dispatch_ctx.update(target, new_workload, cfg)
 
         return relay.nn.contrib_conv2d_winograd_without_weight_transform(
-            inputs[0], weight_expr, **new_attrs)
+            inputs[0], weight_expr, **new_attrs
+        )
 
     if topi_tmpl == "conv2d_nchw_winograd_nnpack.arm_cpu":
         assert data_layout == "NCHW" and kernel_layout == "OIHW"
         N, CI, H, W = get_const_tuple(data.shape)
         CO, _, KH, KW = get_const_tuple(kernel.shape)
-        new_attrs['channels'] = CO
+        new_attrs["channels"] = CO
 
         # pre-compute winograd_nnpack transform
         # for winograd_nnpack_fp16, the precompute prune pass must run on device,
         # where float16 is supported
-        weight_dtype = 'float32'
+        weight_dtype = "float32"
         weight_expr = inputs[1]
         transformed_weight = relay.nn.contrib_conv2d_winograd_nnpack_weight_transform(
             weight_expr,
-            convolution_algorithm=cfg['winograd_nnpack_algorithm'].val,
-            out_dtype=weight_dtype)
+            convolution_algorithm=cfg["winograd_nnpack_algorithm"].val,
+            out_dtype=weight_dtype,
+        )
 
         new_data = data
         new_kernel = te.placeholder((CO, CI, 8, 8), "float32")
 
         new_workload = autotvm.task.args_to_workload(
             [new_data, new_kernel, None, strides, padding, dilation, out_dtype],
-            "conv2d_nchw_winograd_nnpack_without_weight_transform.arm_cpu")
+            "conv2d_nchw_winograd_nnpack_without_weight_transform.arm_cpu",
+        )
         dispatch_ctx.update(target, new_workload, cfg)
         return relay.nn.contrib_conv2d_winograd_without_weight_transform(
-            inputs[0], transformed_weight, **new_attrs)
+            inputs[0], transformed_weight, **new_attrs
+        )
 
     if topi_tmpl == "depthwise_conv2d_nchw_spatial_pack.arm_cpu":
         assert data_layout == "NCHW" and kernel_layout == "OIHW"
         N, CI, H, W = get_const_tuple(data.shape)
         CO, M, KH, KW = get_const_tuple(kernel.shape)
-        VC = cfg['tile_co'].size[-1]
+        VC = cfg["tile_co"].size[-1]
 
-        new_attrs['kernel_layout'] = 'OIHW%do' % (cfg['tile_co'].size[-1])
+        new_attrs["kernel_layout"] = "OIHW%do" % (cfg["tile_co"].size[-1])
 
         # Store the same config for the altered operator (workload)
         new_data = data
         new_kernel = te.placeholder((idxd(CO, VC), M, KH, KW, VC), dtype=kernel.dtype)
         new_workload = autotvm.task.args_to_workload(
             [new_data, new_kernel, strides, padding, dilation, out_dtype],
-            "depthwise_conv2d_nchw_spatial_pack.arm_cpu")
+            "depthwise_conv2d_nchw_spatial_pack.arm_cpu",
+        )
         dispatch_ctx.update(target, new_workload, cfg)
 
         return relay.nn.conv2d(*inputs, **new_attrs)
@@ -190,27 +202,41 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         # Converting NCHW to NCHWc.
         assert data_layout == "NCHW" and kernel_layout == "OIHW"
         if cfg.is_fallback:
-            _get_x86_default_config(cfg, data_tensor, kernel_tensor, strides, padding,
-                                    out_dtype, False, data_layout)
+            _get_x86_default_config(
+                cfg, data_tensor, kernel_tensor, strides, padding, out_dtype, False, data_layout
+            )
         batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
         out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape)
         ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
 
         # update new attrs
-        new_attrs['channels'] = out_channel
-        new_attrs['data_layout'] = 'NCHW%dc' % ic_bn
+        new_attrs["channels"] = out_channel
+        new_attrs["data_layout"] = "NCHW%dc" % ic_bn
         # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
-        new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
-        new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
+        new_attrs["kernel_layout"] = "OIHW%di%do" % (ic_bn, oc_bn)
+        new_attrs["out_layout"] = "NCHW%dc" % oc_bn
 
         # Store altered operator's config
-        new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
-                                  dtype=data_dtype)
-        new_kernel = te.placeholder((out_channel//oc_bn, in_channel//ic_bn,
-                                     kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype)
+        new_data = te.placeholder(
+            (batch_size, in_channel // ic_bn, height, width, ic_bn), dtype=data_dtype
+        )
+        new_kernel = te.placeholder(
+            (out_channel // oc_bn, in_channel // ic_bn, kh, kw, ic_bn, oc_bn),
+            dtype=kernel_tensor.dtype,
+        )
         new_workload = autotvm.task.args_to_workload(
-            [new_data, new_kernel, strides, padding, dilation, new_attrs["data_layout"],
-             new_attrs["out_layout"], out_dtype], topi_tmpl)
+            [
+                new_data,
+                new_kernel,
+                strides,
+                padding,
+                dilation,
+                new_attrs["data_layout"],
+                new_attrs["out_layout"],
+                out_dtype,
+            ],
+            topi_tmpl,
+        )
         dispatch_ctx.update(target, new_workload, cfg)
         return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs)
 
@@ -218,8 +244,9 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         # Converting NCHW to NCHWc.
         assert data_layout == "NCHW" and kernel_layout == "OIHW"
         if cfg.is_fallback:
-            _get_x86_default_config(cfg, data_tensor, kernel_tensor, strides, padding,
-                                    out_dtype, True, data_layout)
+            _get_x86_default_config(
+                cfg, data_tensor, kernel_tensor, strides, padding, out_dtype, True, data_layout
+            )
 
         batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
         out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape)
@@ -227,23 +254,38 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         assert channel_multiplier == 1
 
         # update new attrs
-        new_attrs['channels'] = out_channel
-        new_attrs['data_layout'] = 'NCHW%dc' % ic_bn
-        new_attrs['kernel_layout'] = 'OIHW1i%do' % oc_bn
-        new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
+        new_attrs["channels"] = out_channel
+        new_attrs["data_layout"] = "NCHW%dc" % ic_bn
+        new_attrs["kernel_layout"] = "OIHW1i%do" % oc_bn
+        new_attrs["out_layout"] = "NCHW%dc" % oc_bn
 
         # Store altered operator's config.
-        new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
-                                  dtype=data_dtype)
-        new_kernel = te.placeholder((out_channel//oc_bn, 1, kh, kw, 1, oc_bn), dtype=kernel_dtype)
+        new_data = te.placeholder(
+            (batch_size, in_channel // ic_bn, height, width, ic_bn), dtype=data_dtype
+        )
+        new_kernel = te.placeholder((out_channel // oc_bn, 1, kh, kw, 1, oc_bn), dtype=kernel_dtype)
         new_workload = autotvm.task.args_to_workload(
-            [new_data, new_kernel, strides, padding, dilation, new_attrs['data_layout'],
-             new_attrs['out_layout'], out_dtype], topi_tmpl)
+            [
+                new_data,
+                new_kernel,
+                strides,
+                padding,
+                dilation,
+                new_attrs["data_layout"],
+                new_attrs["out_layout"],
+                out_dtype,
+            ],
+            topi_tmpl,
+        )
         dispatch_ctx.update(target, new_workload, cfg)
         return relay.nn.contrib_depthwise_conv2d_nchwc(*inputs, **new_attrs)
     if topi_tmpl == "conv2d_NHWC_quantized.arm_cpu":
-        assert (data.dtype == 'int8' and kernel.dtype == 'int8' or
-                data.dtype == 'uint8' and kernel.dtype == 'uint8')
+        assert (
+            data.dtype == "int8"
+            and kernel.dtype == "int8"
+            or data.dtype == "uint8"
+            and kernel.dtype == "uint8"
+        )
         assert data_layout == "NHWC" and kernel_layout == "HWIO"
         KH, KW, IC, OC = get_const_tuple(kernel.shape)
         K = KH * KW * IC
@@ -262,20 +304,19 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         N_padded = N + pad_N
         K_padded = K + pad_K
         kernel_expr = relay.nn.contrib_conv2d_gemm_weight_transform(inputs[1], tile_rows, tile_cols)
-        new_kernel = te.placeholder((N_padded // tile_rows,
-                                     K_padded // tile_cols,
-                                     tile_rows,
-                                     tile_cols), kernel.dtype)
+        new_kernel = te.placeholder(
+            (N_padded // tile_rows, K_padded // tile_cols, tile_rows, tile_cols), kernel.dtype
+        )
 
         new_workload_name = "conv2d_NHWC_quantized_without_transform.arm_cpu"
-        new_workload = autotvm.task.args_to_workload([data, new_kernel,
-                                                      strides, padding, dilation,
-                                                      out_dtype, (KH, KW), OC],
-                                                     new_workload_name)
+        new_workload = autotvm.task.args_to_workload(
+            [data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC],
+            new_workload_name,
+        )
         dispatch_ctx.update(target, new_workload, cfg)
 
-        return relay.nn.contrib_conv2d_gemm_without_weight_transform(inputs[0],
-                                                                     kernel_expr,
-                                                                     **new_attrs)
+        return relay.nn.contrib_conv2d_gemm_without_weight_transform(
+            inputs[0], kernel_expr, **new_attrs
+        )
 
     return None
index 62f013a..7f73cc8 100644 (file)
@@ -25,16 +25,17 @@ from ..util import get_const_tuple, get_const_int
 from ..nn.util import get_pad_tuple
 from .tensor_intrin import gemm_quantized, gemm_quantized_impl
 
+
 def is_aarch64_arm():
     """ Checks whether we are compiling for an AArch64 target. """
     target = tvm.target.Target.current(allow_none=False)
-    return 'aarch64' in target.attrs.get("mtriple", "")
+    return "aarch64" in target.attrs.get("mtriple", "")
 
 
 # Compute function
-def compute_conv2d_gemm_without_weight_transform(cfg,
-                                                 data, B_interleaved_t, strides, padding, dilation,
-                                                 out_dtype, kernel_size, output_channels):
+def compute_conv2d_gemm_without_weight_transform(
+    cfg, data, B_interleaved_t, strides, padding, dilation, out_dtype, kernel_size, output_channels
+):
     """Compute conv2d by transforming the input,
     executing GEMM and transforming the output back"""
     batches, IH, IW, IC = get_const_tuple(data.shape)
@@ -52,15 +53,17 @@ def compute_conv2d_gemm_without_weight_transform(cfg,
     dilated_kernel_h = (KH - 1) * dilation_h + 1
     dilated_kernel_w = (KW - 1) * dilation_w + 1
 
-    pad_top, pad_left, pad_down, pad_right = \
-        get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
     HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
 
     OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
     OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
     if pad_top or pad_left:
-        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
-                          name="data_pad")
+        data_pad = nn.pad(
+            data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0], name="data_pad"
+        )
     else:
         data_pad = data
 
@@ -71,14 +74,22 @@ def compute_conv2d_gemm_without_weight_transform(cfg,
 
     A_shape = (batches, M, K)
     if K_AREA == 1:
-        A = te.compute(A_shape, lambda n, x, y: data_pad[n, HSTR * (x // OW), WSTR * (x % OW), y],
-                       name='data_flatten')
+        A = te.compute(
+            A_shape,
+            lambda n, x, y: data_pad[n, HSTR * (x // OW), WSTR * (x % OW), y],
+            name="data_flatten",
+        )
     else:
-        A = te.compute(A_shape, lambda n, x, y:
-                       data_pad[n,
-                                HSTR * (x // OW) + dilation_h * ((y // IC) // KW),
-                                WSTR * (x % OW) + dilation_w * ((y // IC) % KW), y % IC],
-                       name='data_im2col')
+        A = te.compute(
+            A_shape,
+            lambda n, x, y: data_pad[
+                n,
+                HSTR * (x // OW) + dilation_h * ((y // IC) // KW),
+                WSTR * (x % OW) + dilation_w * ((y // IC) % KW),
+                y % IC,
+            ],
+            name="data_im2col",
+        )
     N_transformed = B_interleaved_t.shape[0]
 
     # --- Pad if necessary
@@ -105,52 +116,53 @@ def compute_conv2d_gemm_without_weight_transform(cfg,
     # --- GEMM: A*B'
     k = te.reduce_axis((0, K_padded), "k")
 
-    A_interleaved = te.compute((batches, M_padded // 4, K_padded // 16, 4, 16),
-                               lambda b, x, y, z, w: A[b, z + 4 * x, w + 16 * y],
-                               name='A_interleaved')
-
-    C_interleaved = te.compute((batches, M_padded // 4, N_transformed, 4, 4),
-                               lambda b, x, y, w, z:
-                               te.sum(A_interleaved[b, x, k//16, w, idxm(k, 16)].astype(out_dtype)*
-                                      B_interleaved_t[y, k//16, z, idxm(k, 16)].astype(out_dtype),
-                                      axis=k),
-                               name='C_interleaved')
+    A_interleaved = te.compute(
+        (batches, M_padded // 4, K_padded // 16, 4, 16),
+        lambda b, x, y, z, w: A[b, z + 4 * x, w + 16 * y],
+        name="A_interleaved",
+    )
+
+    C_interleaved = te.compute(
+        (batches, M_padded // 4, N_transformed, 4, 4),
+        lambda b, x, y, w, z: te.sum(
+            A_interleaved[b, x, k // 16, w, idxm(k, 16)].astype(out_dtype)
+            * B_interleaved_t[y, k // 16, z, idxm(k, 16)].astype(out_dtype),
+            axis=k,
+        ),
+        name="C_interleaved",
+    )
 
     # --- Unpack C
-    C = te.compute((batches, M, N),
-                   lambda b, x, y:
-                   C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)],
-                   name="C")
+    C = te.compute(
+        (batches, M, N),
+        lambda b, x, y: C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)],
+        name="C",
+    )
 
     # --- Produce the conv output
     out_shape = (batches, OH, OW, OC)
-    out = te.compute(out_shape, lambda b, x, y, z: C(b, y + OW * x, z),
-                     name='conv2d_gemm_output')
-
+    out = te.compute(out_shape, lambda b, x, y, z: C(b, y + OW * x, z), name="conv2d_gemm_output")
 
     # Configuration space
     x, y = cfg.axis(M_padded // 4), cfg.axis(K_padded // 16)
-    cfg.define_reorder('reorder_gemm',
-                       [x, y],
-                       policy='candidate',
-                       candidate=[[x, y],
-                                  [y, x]])
+    cfg.define_reorder("reorder_gemm", [x, y], policy="candidate", candidate=[[x, y], [y, x]])
 
     outer_loop, inner_loop = cfg.axis(4), cfg.axis(16)
-    cfg.define_annotate("A_interleaved_unroll_vec",
-                        [outer_loop, inner_loop],
-                        policy="try_unroll_vec")
-    cfg.define_knob('gemm_quantized_unroll', [True, False])
-    cfg.define_knob('gemm_quantized_interleave', [True, False])
+    cfg.define_annotate(
+        "A_interleaved_unroll_vec", [outer_loop, inner_loop], policy="try_unroll_vec"
+    )
+    cfg.define_knob("gemm_quantized_unroll", [True, False])
+    cfg.define_knob("gemm_quantized_interleave", [True, False])
 
     # Fallback configuration
     if cfg.is_fallback:
-        cfg['reorder_gemm'] = ReorderEntity([0, 1])
-        cfg['A_interleaved_unroll_vec'] = AnnotateEntity(["unroll", "vec"])
-        cfg['gemm_quantized_unroll'] = OtherOptionEntity(False)
-        cfg['gemm_quantized_interleave'] = OtherOptionEntity(True)
+        cfg["reorder_gemm"] = ReorderEntity([0, 1])
+        cfg["A_interleaved_unroll_vec"] = AnnotateEntity(["unroll", "vec"])
+        cfg["gemm_quantized_unroll"] = OtherOptionEntity(False)
+        cfg["gemm_quantized_interleave"] = OtherOptionEntity(True)
     return out
 
+
 # Schedules
 def schedule_conv2d_gemm(cfg, s, out, final_out):
     """Create schedule for tensors"""
@@ -180,31 +192,30 @@ def schedule_conv2d_gemm(cfg, s, out, final_out):
 
     # Computation (through tensorize)
     b, xo, yo, xi, yi = C_interleaved.op.axis
-    outer_gemm, inner_gemm = cfg['reorder_gemm'].apply(s, C_interleaved, [xo, yo])
+    outer_gemm, inner_gemm = cfg["reorder_gemm"].apply(s, C_interleaved, [xo, yo])
     s[C_interleaved].reorder(yi, xi)
     b_outer_gemm_fused = s[C_interleaved].fuse(b, outer_gemm)
     s[C_interleaved].parallel(b_outer_gemm_fused)
     s[A_interleaved].compute_at(s[C_interleaved], b_outer_gemm_fused)
     _, _, _, outer_A_interleaved, inner_A_interleaved = A_interleaved.op.axis
-    cfg['A_interleaved_unroll_vec'].apply(s,
-                                          A_interleaved,
-                                          [outer_A_interleaved, inner_A_interleaved])
+    cfg["A_interleaved_unroll_vec"].apply(
+        s, A_interleaved, [outer_A_interleaved, inner_A_interleaved]
+    )
 
     in_type = A_interleaved.dtype
     out_type = C.dtype
-    if is_aarch64_arm() and out_type == 'int32':
+    if is_aarch64_arm() and out_type == "int32":
         K = A_interleaved_input.shape[2]
         _, M, N = C.shape
-        assert in_type in ['int8', 'uint8'], "Only int8 and uint8 gemm are supported"
-        unroll = cfg['gemm_quantized_unroll'].val
-        interleave = cfg['gemm_quantized_interleave'].val
+        assert in_type in ["int8", "uint8"], "Only int8 and uint8 gemm are supported"
+        unroll = cfg["gemm_quantized_unroll"].val
+        interleave = cfg["gemm_quantized_interleave"].val
         gemm = gemm_quantized(M, N, K, unroll, interleave, in_type, out_type)
-        s[C_interleaved].pragma(b_outer_gemm_fused, "import_llvm", gemm_quantized_impl(M,
-                                                                                       N,
-                                                                                       K,
-                                                                                       unroll,
-                                                                                       interleave,
-                                                                                       in_type))
+        s[C_interleaved].pragma(
+            b_outer_gemm_fused,
+            "import_llvm",
+            gemm_quantized_impl(M, N, K, unroll, interleave, in_type),
+        )
         s[C_interleaved].tensorize(yi, gemm)
 
     # Output transform
index 9a6e8cc..307f9e1 100644 (file)
@@ -34,16 +34,15 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype):
     wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype)
     is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
     if is_kernel_1x1:
-        conv2d_generic.fallback_schedule_cpu_1x1_int8(
-            cfg, wkl, int32_lanes=2, num_int8_elements=4)
+        conv2d_generic.fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes=2, num_int8_elements=4)
     else:
         conv2d_generic.fallback_schedule_cpu_common_int8(
-            cfg, wkl, int32_lanes=2, num_int8_elements=4)
+            cfg, wkl, int32_lanes=2, num_int8_elements=4
+        )
 
 
 @autotvm.register_topi_compute("conv2d_NCHWc_int8.arm_cpu")
-def conv2d_NCHWc_int8(cfg, data, kernel, strides,
-                      padding, dilation, layout, out_layout, out_dtype):
+def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out_layout, out_dtype):
     """Compute conv2d int8 with NCHWc layout"""
     # layout and out_layout are not used here,
     # we keep them for debug convenience when dumping autotvm workload
@@ -55,17 +54,17 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides,
 
     # If no config was set, we can fallback to NCHW config.
     if cfg.is_fallback:
-        _get_default_config(cfg, te.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
-                            te.placeholder((num_filter, in_channel, kh, kw), dtype=kernel.dtype),
-                            strides, padding, out_dtype)
-    return nn.conv2d_NCHWc_int8_compute(data,
-                                        kernel,
-                                        strides,
-                                        padding,
-                                        dilation,
-                                        layout,
-                                        out_layout,
-                                        out_dtype)
+        _get_default_config(
+            cfg,
+            te.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
+            te.placeholder((num_filter, in_channel, kh, kw), dtype=kernel.dtype),
+            strides,
+            padding,
+            out_dtype,
+        )
+    return nn.conv2d_NCHWc_int8_compute(
+        data, kernel, strides, padding, dilation, layout, out_layout, out_dtype
+    )
 
 
 @autotvm.register_topi_schedule("conv2d_NCHWc_int8.arm_cpu")
@@ -84,13 +83,15 @@ def schedule_conv2d_NCHWc_int8(cfg, outs):
                 if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
 
-        if 'conv2d_NCHWc_int8' in op.tag:
+        if "conv2d_NCHWc_int8" in op.tag:
             conv_out = op.output(0)
             kernel_vec = conv_out.op.input_tensors[1]
             data_vec = conv_out.op.input_tensors[0]
-            data = data_vec.op.input_tensors[0] \
-                if isinstance(data_vec.op, te.tensor.ComputeOp) and "pad" not in data_vec.op.tag \
+            data = (
+                data_vec.op.input_tensors[0]
+                if isinstance(data_vec.op, te.tensor.ComputeOp) and "pad" not in data_vec.op.tag
                 else data_vec
+            )
             if isinstance(data.op, te.tensor.ComputeOp) and "pad" in data.op.tag:
                 data_pad = data
                 data = data_pad.op.input_tensors[0]
@@ -101,10 +102,12 @@ def schedule_conv2d_NCHWc_int8(cfg, outs):
             dtype = "uint" if data.dtype == "uint8" else "int"
             if kh == 1 and kw == 1:
                 conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8(
-                    *args, int32_lanes=4, intrin=dot_int8_int8_int32(int32_lanes=4, dtype=dtype))
+                    *args, int32_lanes=4, intrin=dot_int8_int8_int32(int32_lanes=4, dtype=dtype)
+                )
             else:
                 conv2d_generic.schedule_conv_NCHWc_cpu_common_int8(
-                    *args, int32_lanes=4, intrin=dot_int8_int8_int32(int32_lanes=4, dtype=dtype))
+                    *args, int32_lanes=4, intrin=dot_int8_int8_int32(int32_lanes=4, dtype=dtype)
+                )
 
         scheduled_ops.append(op)
 
@@ -119,18 +122,19 @@ def compute_conv2d_NHWC_quantized(cfg, data, kernel, strides, padding, dilation,
     tile_rows = 4
     tile_cols = 16
     kernel = nn.conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols)
-    return  compute_conv2d_gemm_without_weight_transform(cfg,
-                                                         data, kernel, strides, padding,
-                                                         dilation, out_dtype, (KH, KW), OC)
+    return compute_conv2d_gemm_without_weight_transform(
+        cfg, data, kernel, strides, padding, dilation, out_dtype, (KH, KW), OC
+    )
 
 
 @autotvm.register_topi_compute("conv2d_NHWC_quantized_without_transform.arm_cpu")
-def compute_conv2d_NHWC_quantized_without_transform(cfg, data, B, strides, padding,
-                                                    dilation, out_dtype, kernel_size=None,
-                                                    output_channels=None):
-    return  compute_conv2d_gemm_without_weight_transform(cfg, data, B, strides, padding,
-                                                         dilation, out_dtype, kernel_size,
-                                                         output_channels)
+def compute_conv2d_NHWC_quantized_without_transform(
+    cfg, data, B, strides, padding, dilation, out_dtype, kernel_size=None, output_channels=None
+):
+    """Compute for conv2d_NHWC_quantized without weight transform."""
+    return compute_conv2d_gemm_without_weight_transform(
+        cfg, data, B, strides, padding, dilation, out_dtype, kernel_size, output_channels
+    )
 
 
 @autotvm.register_topi_schedule("conv2d_NHWC_quantized.arm_cpu")
@@ -156,6 +160,5 @@ def schedule_conv2d_NHWC_quantized(cfg, outs):
                 C = conv_out.op.input_tensors[0]
                 s[C].compute_at(s[out], inner)
 
-
     traverse_inline(s, outs[0].op, _callback)
     return s
index b475837..e3649e5 100644 (file)
@@ -24,8 +24,8 @@ from .. import nn
 from ..util import get_const_tuple
 from ..nn.util import get_const_int, get_pad_tuple
 
-def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation,
-                             out_dtype, num_tile):
+
+def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile):
     """compute define for Conv2d Spatial Pack with NCHW layout"""
     out_dtype = out_dtype or data.dtype
     N, CI, IH, IW = get_const_tuple(data.shape)
@@ -46,7 +46,8 @@ def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation,
     dilated_kernel_h = (KH - 1) * dilation_h + 1
     dilated_kernel_w = (KW - 1) * dilation_w + 1
     pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
     HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
     OH = (IH + pad_top + pad_bottom - dilated_kernel_h) // HSTR + 1
     OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
@@ -56,35 +57,41 @@ def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation,
     n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW)
     ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
 
-    if num_tile == 2:     # for arm cpu
-        co, vc = cfg.define_split('tile_co', co, num_outputs=2)
-        oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2)
-        ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2)
-    elif num_tile == 3:   # for mali gpu
-        co, _, vc = cfg.define_split('tile_co', co, num_outputs=3)
-        oh, _, vh = cfg.define_split('tile_oh', oh, num_outputs=3)
-        ow, _, vw = cfg.define_split('tile_ow', ow, num_outputs=3)
+    if num_tile == 2:  # for arm cpu
+        co, vc = cfg.define_split("tile_co", co, num_outputs=2)
+        oh, vh = cfg.define_split("tile_oh", oh, num_outputs=2)
+        ow, vw = cfg.define_split("tile_ow", ow, num_outputs=2)
+    elif num_tile == 3:  # for mali gpu
+        co, _, vc = cfg.define_split("tile_co", co, num_outputs=3)
+        oh, _, vh = cfg.define_split("tile_oh", oh, num_outputs=3)
+        ow, _, vw = cfg.define_split("tile_ow", ow, num_outputs=3)
     else:
         raise RuntimeError("Invalid num_tile")
 
-    cfg.define_reorder("reorder_0",
-                       [n, co, oh, ow, ci, kh, kw, vh, vw, vc],
-                       policy='candidate', candidate=[
-                           [n, co, oh, ow, ci, kh, kw, vh, vw, vc],
-                           [n, co, oh, ow, ci, kh, kw, vc, vh, vw]])
+    cfg.define_reorder(
+        "reorder_0",
+        [n, co, oh, ow, ci, kh, kw, vh, vw, vc],
+        policy="candidate",
+        candidate=[
+            [n, co, oh, ow, ci, kh, kw, vh, vw, vc],
+            [n, co, oh, ow, ci, kh, kw, vc, vh, vw],
+        ],
+    )
 
-    cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
-    cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec')
+    cfg.define_annotate("ann_reduce", [kh, kw], policy="try_unroll")
+    cfg.define_annotate("ann_spatial", [vh, vw, vc], policy="try_unroll_vec")
 
     # fallback support
     if cfg.is_fallback:
-        if num_tile == 2:     # arm cpu
+        if num_tile == 2:  # arm cpu
             ref_log = autotvm.tophub.load_reference_log(
-                'arm_cpu', 'rk3399', 'conv2d_nchw_spatial_pack.arm_cpu')
+                "arm_cpu", "rk3399", "conv2d_nchw_spatial_pack.arm_cpu"
+            )
             cfg.fallback_with_reference_log(ref_log)
         elif num_tile == 3:  # mali gpu
             ref_log = autotvm.tophub.load_reference_log(
-                'mali', 'rk3399', 'conv2d_nchw_spatial_pack.mali')
+                "mali", "rk3399", "conv2d_nchw_spatial_pack.mali"
+            )
             cfg.fallback_with_reference_log(ref_log)
     # ====================================================================
 
@@ -99,15 +106,20 @@ def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation,
     if dilation_h != 1 or dilation_w != 1:
         # undilate input data
         dvshape = (N, OH // VH, OW // VW, CI, KH, KW, VH, VW)
-        data_vec = te.compute(dvshape, lambda n, h, w, ci, kh, kw, vh, vw:
-                              data_pad[n][ci][(h*VH+vh)*HSTR+kh*dilation_h]
-                              [(w*VW+vw)*WSTR+kw*dilation_w],
-                              name='data_vec_undilated')
+        data_vec = te.compute(
+            dvshape,
+            lambda n, h, w, ci, kh, kw, vh, vw: data_pad[n][ci][
+                (h * VH + vh) * HSTR + kh * dilation_h
+            ][(w * VW + vw) * WSTR + kw * dilation_w],
+            name="data_vec_undilated",
+        )
     else:
-        dvshape = (N, OH // VH, OW // VW, CI, VH*HSTR + KH-1, VW*WSTR + KW-1)
-        data_vec = te.compute(dvshape, lambda n, h, w, ci, vh, vw:
-                              data_pad[n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw],
-                              name='data_vec')
+        dvshape = (N, OH // VH, OW // VW, CI, VH * HSTR + KH - 1, VW * WSTR + KW - 1)
+        data_vec = te.compute(
+            dvshape,
+            lambda n, h, w, ci, vh, vw: data_pad[n][ci][h * VH * HSTR + vh][w * VW * WSTR + vw],
+            name="data_vec",
+        )
 
     if autotvm.GLOBAL_SCOPE.in_tuning:
         # use "kernel_autotvm" instead of "kernel" to avoid naming conflict with OpenCL keyword
@@ -116,90 +128,119 @@ def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation,
         if pre_packed:
             kernel_vec = kernel
         else:
-            kernel_vec = te.compute(kvshape, lambda co, ci, kh, kw, vc:
-                                    kernel[co*VC+vc][ci][kh][kw],
-                                    name='kernel_vec')
+            kernel_vec = te.compute(
+                kvshape,
+                lambda co, ci, kh, kw, vc: kernel[co * VC + vc][ci][kh][kw],
+                name="kernel_vec",
+            )
 
-    ci = te.reduce_axis((0, CI), name='ci')
-    kh = te.reduce_axis((0, KH), name='kh')
-    kw = te.reduce_axis((0, KW), name='kw')
+    ci = te.reduce_axis((0, CI), name="ci")
+    kh = te.reduce_axis((0, KH), name="kh")
+    kw = te.reduce_axis((0, KW), name="kw")
 
     if dilation_h != 1 or dilation_w != 1:
-        conv = te.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
-                          te.sum(data_vec[n, h, w, ci, kh, kw, vh, vw].astype(out_dtype) *
-                                 kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
-                                 axis=[ci, kh, kw]), name='conv')
+        conv = te.compute(
+            ovshape,
+            lambda n, co, h, w, vh, vw, vc: te.sum(
+                data_vec[n, h, w, ci, kh, kw, vh, vw].astype(out_dtype)
+                * kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
+                axis=[ci, kh, kw],
+            ),
+            name="conv",
+        )
     else:
-        conv = te.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
-                          te.sum(data_vec[n, h, w, ci, vh*HSTR+kh, vw*WSTR+kw].astype(out_dtype) *
-                                 kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
-                                 axis=[ci, kh, kw]), name='conv')
+        conv = te.compute(
+            ovshape,
+            lambda n, co, h, w, vh, vw, vc: te.sum(
+                data_vec[n, h, w, ci, vh * HSTR + kh, vw * WSTR + kw].astype(out_dtype)
+                * kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
+                axis=[ci, kh, kw],
+            ),
+            name="conv",
+        )
 
     idxdiv = tvm.tir.indexdiv
     idxmod = tvm.tir.indexmod
 
-    output = te.compute(oshape, lambda n, co, h, w:
-                        conv[n,
-                             idxdiv(co, VC), idxdiv(h, VH), idxdiv(w, VW),
-                             idxmod(h, VH), idxmod(w, VW), idxmod(co, VC)],
-                        name='output_unpack', tag='spatial_conv2d_output')
+    output = te.compute(
+        oshape,
+        lambda n, co, h, w: conv[
+            n,
+            idxdiv(co, VC),
+            idxdiv(h, VH),
+            idxdiv(w, VW),
+            idxmod(h, VH),
+            idxmod(w, VW),
+            idxmod(co, VC),
+        ],
+        name="output_unpack",
+        tag="spatial_conv2d_output",
+    )
     return output
 
-def schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec,
-                                      conv, output, last):
+
+def schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec, conv, output, last):
     """schedule implementation"""
     n, co, oh, ow, vh, vw, vc = s[conv].op.axis
     ci, kh, kw = s[conv].op.reduce_axis
 
     # schedule conv
     cfg["reorder_0"].apply(s, conv, [n, co, oh, ow, ci, kh, kw, vh, vw, vc])
-    cfg["ann_reduce"].apply(s, conv, [kh, kw],
-                            axis_lens=[get_const_int(kh.dom.extent),
-                                       get_const_int(kw.dom.extent)],
-                            max_unroll=None,
-                            cfg=cfg)
-    cfg["ann_spatial"].apply(s, conv, [vh, vw, vc],
-                             axis_lens=[cfg['tile_oh'].size[-1],
-                                        cfg['tile_ow'].size[-1],
-                                        cfg['tile_co'].size[-1]],
-                             max_unroll=None,
-                             cfg=cfg)
+    cfg["ann_reduce"].apply(
+        s,
+        conv,
+        [kh, kw],
+        axis_lens=[get_const_int(kh.dom.extent), get_const_int(kw.dom.extent)],
+        max_unroll=None,
+        cfg=cfg,
+    )
+    cfg["ann_spatial"].apply(
+        s,
+        conv,
+        [vh, vw, vc],
+        axis_lens=[cfg["tile_oh"].size[-1], cfg["tile_ow"].size[-1], cfg["tile_co"].size[-1]],
+        max_unroll=None,
+        cfg=cfg,
+    )
 
     # schedule fusion
     n, co, h, w = s[last].op.axis
-    co, vc = cfg['tile_co'].apply(s, last, co)
-    oh, vh = cfg['tile_oh'].apply(s, last, h)
-    ow, vw = cfg['tile_ow'].apply(s, last, w)
+    co, vc = cfg["tile_co"].apply(s, last, co)
+    oh, vh = cfg["tile_oh"].apply(s, last, h)
+    ow, vw = cfg["tile_ow"].apply(s, last, w)
     s[last].reorder(n, co, oh, ow, vh, vw, vc)
     if last != output:
         s[output].compute_inline()
-        cfg["ann_spatial"].apply(s, last, [vh, vw, vc],
-                                 axis_lens=[cfg['tile_oh'].size[-1],
-                                            cfg['tile_ow'].size[-1],
-                                            cfg['tile_co'].size[-1]],
-                                 max_unroll=16,
-                                 cfg=cfg)
+        cfg["ann_spatial"].apply(
+            s,
+            last,
+            [vh, vw, vc],
+            axis_lens=[cfg["tile_oh"].size[-1], cfg["tile_ow"].size[-1], cfg["tile_co"].size[-1]],
+            max_unroll=16,
+            cfg=cfg,
+        )
     s[conv].compute_at(s[last], ow)
 
     # mark parallel
     s[last].parallel(co)
 
-    if data_vec.op.name == 'data_vec_undilated':
+    if data_vec.op.name == "data_vec_undilated":
         _, h, _, _, _, _, _, _ = s[data_vec].op.axis
     else:
         _, h, _, _, _, _ = s[data_vec].op.axis
     s[data_vec].parallel(h)
 
-    if kernel_vec.op.name == 'kernel_vec':
+    if kernel_vec.op.name == "kernel_vec":
         if not autotvm.GLOBAL_SCOPE.in_tuning:
             co, _, _, _, _ = s[kernel_vec].op.axis
             s[kernel_vec].parallel(co)
-    elif kernel_vec.op.name == 'kernel_vec_conv2d_transpose':  # for conv2d transpose
+    elif kernel_vec.op.name == "kernel_vec_conv2d_transpose":  # for conv2d transpose
         co, _, _, _, _ = s[kernel_vec].op.axis
         s[kernel_vec].parallel(co)
 
     return s
 
+
 def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype):
     """Spatial pack compute for Conv2d NHWC"""
     out_dtype = out_dtype or data.dtype
@@ -216,8 +257,9 @@ def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_
     dilated_kernel_h = (KH - 1) * dilation_h + 1
     dilated_kernel_w = (KW - 1) * dilation_w + 1
 
-    pad_top, pad_left, pad_down, pad_right = \
-        get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
     HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
 
     OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
@@ -228,25 +270,29 @@ def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_
     n, oc, oh, ow = cfg.axis(N), cfg.axis(OC), cfg.axis(OH), cfg.axis(OW)
     ic, kh, kw = cfg.reduce_axis(IC), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
 
-    oco, oci = cfg.define_split('tile_co', oc, num_outputs=2)
-    oho, ohi = cfg.define_split('tile_oh', oh, num_outputs=2)
-    owo, owi = cfg.define_split('tile_ow', ow, num_outputs=2)
-
-    cfg.define_reorder('reorder_conv',
-                       [n, oho, owo, oco, kh, kw, ic, ohi, owi, oci],
-                       policy='candidate', candidate=[
-                           [n, oho, owo, oco, kh, kw, ic, ohi, owi, oci],
-                           [n, oho, owo, oco, ohi, kh, kw, ic, owi, oci],
-                           [n, oho, owo, oco, ohi, kh, kw, owi, ic, oci],
-                           [n, oho, owo, ohi, oco, kh, kw, owi, ic, oci]])
-
-    cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
-    cfg.define_annotate("ann_spatial", [ohi, owi, oci], policy='try_unroll_vec')
+    oco, oci = cfg.define_split("tile_co", oc, num_outputs=2)
+    oho, ohi = cfg.define_split("tile_oh", oh, num_outputs=2)
+    owo, owi = cfg.define_split("tile_ow", ow, num_outputs=2)
+
+    cfg.define_reorder(
+        "reorder_conv",
+        [n, oho, owo, oco, kh, kw, ic, ohi, owi, oci],
+        policy="candidate",
+        candidate=[
+            [n, oho, owo, oco, kh, kw, ic, ohi, owi, oci],
+            [n, oho, owo, oco, ohi, kh, kw, ic, owi, oci],
+            [n, oho, owo, oco, ohi, kh, kw, owi, ic, oci],
+            [n, oho, owo, ohi, oco, kh, kw, owi, ic, oci],
+        ],
+    )
+
+    cfg.define_annotate("ann_reduce", [kh, kw], policy="try_unroll")
+    cfg.define_annotate("ann_spatial", [ohi, owi, oci], policy="try_unroll_vec")
     # ====================================================================
 
-    OCI = cfg['tile_co'].size[-1]
-    OHI = cfg['tile_oh'].size[-1]
-    OWI = cfg['tile_ow'].size[-1]
+    OCI = cfg["tile_co"].size[-1]
+    OHI = cfg["tile_oh"].size[-1]
+    OWI = cfg["tile_ow"].size[-1]
     OCO = OC // OCI
     OHO = OH // OHI
     OWO = OW // OWI
@@ -258,47 +304,70 @@ def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_
     if dilation_h != 1 or dilation_w != 1:
         # undilate input data
         dvshape = (N, OHO, OWO, KH, KW, IC, OHI, OWI)
-        data_vec = te.compute(dvshape, lambda n, oho, owo, kh, kw, ic, ohi, owi:
-                              data_pad[n][(oho*OHI+ohi)*HSTR+kh*dilation_h]
-                              [(owo*OWI+owi)*WSTR+kw*dilation_w][ic],
-                              name='data_vec_undilated')
+        data_vec = te.compute(
+            dvshape,
+            lambda n, oho, owo, kh, kw, ic, ohi, owi: data_pad[n][
+                (oho * OHI + ohi) * HSTR + kh * dilation_h
+            ][(owo * OWI + owi) * WSTR + kw * dilation_w][ic],
+            name="data_vec_undilated",
+        )
     else:
-        dvshape = (N, OHO, OWO, KH + (OHI-1)*HSTR, KW + (OWI-1)*WSTR, IC)
-        data_vec = te.compute(dvshape, lambda n, oho, owo, ohi, owi, ic:
-                              data_pad[n][oho*OHI*HSTR+ohi][owo*OWI*WSTR+owi][ic],
-                              name='data_vec')
+        dvshape = (N, OHO, OWO, KH + (OHI - 1) * HSTR, KW + (OWI - 1) * WSTR, IC)
+        data_vec = te.compute(
+            dvshape,
+            lambda n, oho, owo, ohi, owi, ic: data_pad[n][oho * OHI * HSTR + ohi][
+                owo * OWI * WSTR + owi
+            ][ic],
+            name="data_vec",
+        )
 
     if autotvm.GLOBAL_SCOPE.in_tuning:
         kernel_vec = tvm.te.placeholder(kvshape, kernel.dtype, name="kernel")
     else:
-        kernel_vec = te.compute(kvshape, lambda oco, kh, kw, ic, oci: \
-                                kernel[kh][kw][ic][oco*OCI+oci],
-                                name='kernel_vec')
+        kernel_vec = te.compute(
+            kvshape,
+            lambda oco, kh, kw, ic, oci: kernel[kh][kw][ic][oco * OCI + oci],
+            name="kernel_vec",
+        )
 
-    ic = te.reduce_axis((0, IC), name='ic')
-    kh = te.reduce_axis((0, KH), name='kh')
-    kw = te.reduce_axis((0, KW), name='kw')
+    ic = te.reduce_axis((0, IC), name="ic")
+    kh = te.reduce_axis((0, KH), name="kh")
+    kw = te.reduce_axis((0, KW), name="kw")
 
     if dilation_h != 1 or dilation_w != 1:
-        conv = te.compute(ovshape, lambda n, oho, owo, oco, ohi, owi, oci: \
-                          te.sum(data_vec[n, oho, owo, kh, kw, ohi, owi, ic].astype(out_dtype) *
-                                 kernel_vec[oco, kh, kw, ic, oci].astype(out_dtype),
-                                 axis=[ic, kh, kw]), name='conv')
+        conv = te.compute(
+            ovshape,
+            lambda n, oho, owo, oco, ohi, owi, oci: te.sum(
+                data_vec[n, oho, owo, kh, kw, ohi, owi, ic].astype(out_dtype)
+                * kernel_vec[oco, kh, kw, ic, oci].astype(out_dtype),
+                axis=[ic, kh, kw],
+            ),
+            name="conv",
+        )
     else:
         conv = te.compute(
-            ovshape, lambda n, oho, owo, oco, ohi, owi, oci: \
-            te.sum(data_vec[n, oho, owo, ohi*HSTR+kh, owi*WSTR+kw, ic].astype(out_dtype) *
-                   kernel_vec[oco, kh, kw, ic, oci].astype(out_dtype),
-                   axis=[ic, kh, kw]), name='conv')
+            ovshape,
+            lambda n, oho, owo, oco, ohi, owi, oci: te.sum(
+                data_vec[n, oho, owo, ohi * HSTR + kh, owi * WSTR + kw, ic].astype(out_dtype)
+                * kernel_vec[oco, kh, kw, ic, oci].astype(out_dtype),
+                axis=[ic, kh, kw],
+            ),
+            name="conv",
+        )
 
     idiv = tvm.tir.indexdiv
     imod = tvm.tir.indexmod
-    output = te.compute(oshape, lambda n, oho, owo, oc:
-                        conv[n][idiv(oho, OHI)][idiv(owo, OWI)][idiv(oc, OCI)]\
-                        [imod(oho, OHI)][imod(owo, OWI)][imod(oc, OCI)],
-                        name='output_unpack', tag='spatial_conv_output_NHWC')
+    output = te.compute(
+        oshape,
+        lambda n, oho, owo, oc: conv[n][idiv(oho, OHI)][idiv(owo, OWI)][idiv(oc, OCI)][
+            imod(oho, OHI)
+        ][imod(owo, OWI)][imod(oc, OCI)],
+        name="output_unpack",
+        tag="spatial_conv_output_NHWC",
+    )
     return output
 
+
 def schedule_conv2d_spatial_pack_nhwc(cfg, s, op, output):
     """Spatial Pack schedule for Conv2d NHWC"""
     unpack = op.output(0)
@@ -306,23 +375,24 @@ def schedule_conv2d_spatial_pack_nhwc(cfg, s, op, output):
     data_vec = conv.op.input_tensors[0]
     kernel_vec = conv.op.input_tensors[1]
     data_pad = data_vec.op.input_tensors[0]
-    OHI = cfg['tile_oh'].size[-1]
-    OWI = cfg['tile_ow'].size[-1]
-    OCI = cfg['tile_co'].size[-1]
+    OHI = cfg["tile_oh"].size[-1]
+    OWI = cfg["tile_ow"].size[-1]
+    OCI = cfg["tile_co"].size[-1]
 
     # schedule unpack/output
     if output != unpack:
         s[unpack].compute_inline()
     n, oh, ow, oc = s[output].op.axis
-    oco, oci = cfg['tile_co'].apply(s, output, oc)
-    oho, ohi = cfg['tile_oh'].apply(s, output, oh)
-    owo, owi = cfg['tile_ow'].apply(s, output, ow)
+    oco, oci = cfg["tile_co"].apply(s, output, oc)
+    oho, ohi = cfg["tile_oh"].apply(s, output, oh)
+    owo, owi = cfg["tile_ow"].apply(s, output, ow)
     s[output].reorder(n, oho, owo, oco, ohi, owi, oci)
-    cfg['ann_spatial'].apply(s, output, [ohi, owi, oci], axis_lens=[OHI, OWI, OCI],
-                             max_unroll=16, cfg=cfg)
-    cfg.define_knob('compat', [0, 1, 2])
-    if cfg['compat'].val < 2:
-        compat_axis = [owo, oco][cfg['compat'].val] # pylint: disable=R1706
+    cfg["ann_spatial"].apply(
+        s, output, [ohi, owi, oci], axis_lens=[OHI, OWI, OCI], max_unroll=16, cfg=cfg
+    )
+    cfg.define_knob("compat", [0, 1, 2])
+    if cfg["compat"].val < 2:
+        compat_axis = [owo, oco][cfg["compat"].val]  # pylint: disable=R1706
         s[conv].compute_at(s[output], compat_axis)
     paxis = s[output].fuse(n, oho)
     s[output].parallel(paxis)
@@ -330,16 +400,20 @@ def schedule_conv2d_spatial_pack_nhwc(cfg, s, op, output):
     # schedule conv
     n, oho, owo, oco, ohi, owi, oci = s[conv].op.axis
     ic, kh, kw = s[conv].op.reduce_axis
-    cfg['reorder_conv'].apply(s, conv, [n, oho, owo, oco, kh, kw, ohi, owi, ic, oci])
-    cfg['ann_reduce'].apply(s, conv, [kh, kw],
-                            axis_lens=[get_const_int(kh.dom.extent),
-                                       get_const_int(kw.dom.extent)],
-                            max_unroll=16,
-                            cfg=cfg)
-    cfg['ann_spatial'].apply(s, conv, [ohi, owi, oci], axis_lens=[OHI, OWI, OCI],
-                             max_unroll=16, cfg=cfg)
-    if cfg['compat'].val < 2:
-        compat_axis = [owo, oco][cfg['compat'].val] # pylint: disable=R1706
+    cfg["reorder_conv"].apply(s, conv, [n, oho, owo, oco, kh, kw, ohi, owi, ic, oci])
+    cfg["ann_reduce"].apply(
+        s,
+        conv,
+        [kh, kw],
+        axis_lens=[get_const_int(kh.dom.extent), get_const_int(kw.dom.extent)],
+        max_unroll=16,
+        cfg=cfg,
+    )
+    cfg["ann_spatial"].apply(
+        s, conv, [ohi, owi, oci], axis_lens=[OHI, OWI, OCI], max_unroll=16, cfg=cfg
+    )
+    if cfg["compat"].val < 2:
+        compat_axis = [owo, oco][cfg["compat"].val]  # pylint: disable=R1706
         s[kernel_vec].compute_at(s[conv], compat_axis)
         s[data_vec].compute_at(s[conv], compat_axis)
 
@@ -348,11 +422,11 @@ def schedule_conv2d_spatial_pack_nhwc(cfg, s, op, output):
         oco, kh, kw, ic, oci = kernel_vec.op.axis
         s[kernel_vec].vectorize(oci)
         s[kernel_vec].unroll(ic)
-        if cfg['compat'].val == 2:
+        if cfg["compat"].val == 2:
             s[kernel_vec].parallel(oco)
 
     # schedule data pack
-    if data_vec.op.name == 'data_vec_undilated':
+    if data_vec.op.name == "data_vec_undilated":
         n, oho, owo, kh, kw, ic, ohi, owi = s[data_vec].op.axis
         s[data_vec].vectorize(owi)
         s[data_vec].unroll(ohi)
@@ -360,7 +434,7 @@ def schedule_conv2d_spatial_pack_nhwc(cfg, s, op, output):
         n, oho, owo, ohi, owi, ic = s[data_vec].op.axis
         s[data_vec].vectorize(ic)
         s[data_vec].unroll(owi)
-    if cfg['compat'].val == 2:
+    if cfg["compat"].val == 2:
         paxis = s[data_vec].fuse(n, oho)
         s[data_vec].parallel(paxis)
 
index 8152ae2..d482228 100644 (file)
@@ -27,10 +27,8 @@ from ..util import get_const_tuple, traverse_inline
 from .conv2d_spatial_pack import schedule_conv2d_spatial_pack_nchw
 
 
-
 @autotvm.register_topi_compute("conv2d_transpose_nchw.arm_cpu")
-def conv2d_transpose_nchw(cfg, Input, Filter, strides, padding, out_dtype,
-                          output_padding):
+def conv2d_transpose_nchw(cfg, Input, Filter, strides, padding, out_dtype, output_padding):
     """Transposed 2D convolution nchw forward operator.
 
     Parameters
@@ -58,11 +56,14 @@ def conv2d_transpose_nchw(cfg, Input, Filter, strides, padding, out_dtype,
     Output : tvm.te.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    return _decl_spatial_pack(cfg, Input, Filter, strides, padding, "NCHW", out_dtype, 2,
-                              output_padding)
+    return _decl_spatial_pack(
+        cfg, Input, Filter, strides, padding, "NCHW", out_dtype, 2, output_padding
+    )
+
 
-def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile,
-                       output_padding):
+def _decl_spatial_pack(
+    cfg, data, kernel, strides, padding, layout, out_dtype, num_tile, output_padding
+):
     assert layout == "NCHW", "Only support NCHW"
     out_dtype = out_dtype or data.dtype
 
@@ -86,61 +87,83 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n
     n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW)
     ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
 
-    if num_tile == 2:     # for arm cpu
-        co, vc = cfg.define_split('tile_co', co, num_outputs=2)
-        oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2)
-        ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2)
-    elif num_tile == 3:   # for mali gpu
-        co, _, vc = cfg.define_split('tile_co', co, num_outputs=3)
-        oh, _, vh = cfg.define_split('tile_oh', oh, num_outputs=3)
-        ow, _, vw = cfg.define_split('tile_ow', ow, num_outputs=3)
+    if num_tile == 2:  # for arm cpu
+        co, vc = cfg.define_split("tile_co", co, num_outputs=2)
+        oh, vh = cfg.define_split("tile_oh", oh, num_outputs=2)
+        ow, vw = cfg.define_split("tile_ow", ow, num_outputs=2)
+    elif num_tile == 3:  # for mali gpu
+        co, _, vc = cfg.define_split("tile_co", co, num_outputs=3)
+        oh, _, vh = cfg.define_split("tile_oh", oh, num_outputs=3)
+        ow, _, vw = cfg.define_split("tile_ow", ow, num_outputs=3)
     else:
         raise RuntimeError("Invalid num_tile")
 
-    cfg.define_reorder("reorder_0",
-                       [n, co, oh, ow, ci, kh, kw, vh, vw, vc],
-                       policy='candidate', candidate=[
-                           [n, co, oh, ow, ci, kh, kw, vh, vw, vc],
-                           [n, co, oh, ow, ci, kh, kw, vc, vh, vw]])
-
-    cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
-    cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec')
+    cfg.define_reorder(
+        "reorder_0",
+        [n, co, oh, ow, ci, kh, kw, vh, vw, vc],
+        policy="candidate",
+        candidate=[
+            [n, co, oh, ow, ci, kh, kw, vh, vw, vc],
+            [n, co, oh, ow, ci, kh, kw, vc, vh, vw],
+        ],
+    )
+
+    cfg.define_annotate("ann_reduce", [kh, kw], policy="try_unroll")
+    cfg.define_annotate("ann_spatial", [vh, vw, vc], policy="try_unroll_vec")
     # ====================================================================
 
     VC = cfg["tile_co"].size[-1]
     VH = cfg["tile_oh"].size[-1]
     VW = cfg["tile_ow"].size[-1]
 
-    dvshape = (N, OH // VH, OW // VW, CI, VH + KH-1, VW + KW-1)
+    dvshape = (N, OH // VH, OW // VW, CI, VH + KH - 1, VW + KW - 1)
     kvshape = (CO // VC, CI, KH, KW, VC)
     ovshape = (N, CO // VC, OH // VH, OW // VW, VH, VW, VC)
     oshape = (N, CO, OH, OW)
 
-    data_vec = te.compute(dvshape, lambda n, h, w, ci, vh, vw:
-                          data_pad[n][ci][h*VH + vh][w*VW + vw],
-                          name='data_vec')
-
-    kernel_vec = te.compute(kvshape, lambda co, ci, kh, kw, vc:
-                            kernel[ci][co*VC+vc][kh][kw],
-                            name='kernel_vec_conv2d_transpose')
-
-    ci = te.reduce_axis((0, CI), name='ci')
-    kh = te.reduce_axis((0, KH), name='kh')
-    kw = te.reduce_axis((0, KW), name='kw')
-
-    conv = te.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
-                      te.sum(data_vec[n, h, w, ci, vh + kh, vw + kw].astype(out_dtype) *
-                             kernel_vec[co, ci, KH - 1 - kh, KW - 1 - kw, vc].astype(out_dtype),
-                             axis=[ci, kh, kw]), name='conv')
+    data_vec = te.compute(
+        dvshape,
+        lambda n, h, w, ci, vh, vw: data_pad[n][ci][h * VH + vh][w * VW + vw],
+        name="data_vec",
+    )
+
+    kernel_vec = te.compute(
+        kvshape,
+        lambda co, ci, kh, kw, vc: kernel[ci][co * VC + vc][kh][kw],
+        name="kernel_vec_conv2d_transpose",
+    )
+
+    ci = te.reduce_axis((0, CI), name="ci")
+    kh = te.reduce_axis((0, KH), name="kh")
+    kw = te.reduce_axis((0, KW), name="kw")
+
+    conv = te.compute(
+        ovshape,
+        lambda n, co, h, w, vh, vw, vc: te.sum(
+            data_vec[n, h, w, ci, vh + kh, vw + kw].astype(out_dtype)
+            * kernel_vec[co, ci, KH - 1 - kh, KW - 1 - kw, vc].astype(out_dtype),
+            axis=[ci, kh, kw],
+        ),
+        name="conv",
+    )
 
     idxdiv = tvm.tir.indexdiv
     idxmod = tvm.tir.indexmod
 
-    output = te.compute(oshape, lambda n, co, h, w:
-                        conv[n,
-                             idxdiv(co, VC), idxdiv(h, VH), idxdiv(w, VW),
-                             idxmod(h, VH), idxmod(w, VW), idxmod(co, VC)],
-                        name='output_unpack', tag='spatial_conv2d_transpose_output')
+    output = te.compute(
+        oshape,
+        lambda n, co, h, w: conv[
+            n,
+            idxdiv(co, VC),
+            idxdiv(h, VH),
+            idxdiv(w, VW),
+            idxmod(h, VH),
+            idxmod(w, VW),
+            idxmod(co, VC),
+        ],
+        name="output_unpack",
+        tag="spatial_conv2d_transpose_output",
+    )
     return output
 
 
@@ -151,7 +174,7 @@ def schedule_conv2d_transpose_nchw(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'spatial_conv2d_transpose_output' in op.tag:
+        if "spatial_conv2d_transpose_output" in op.tag:
             output = op.output(0)
             conv = op.input_tensors[0]
 
@@ -162,15 +185,14 @@ def schedule_conv2d_transpose_nchw(cfg, outs):
             s[dilated_input].compute_inline()
 
             kernel_vec = conv.op.input_tensors[1]
-            if kernel_vec.op.name == 'kernel_vec':
+            if kernel_vec.op.name == "kernel_vec":
                 kernel = kernel_vec.op.input_tensors[0]
             else:
                 kernel = kernel_vec
             if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
                 s[kernel].compute_inline()
 
-            schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec,
-                                              conv, output, outs[0])
+            schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec, conv, output, outs[0])
 
     traverse_inline(s, outs[0].op, _callback)
     return s
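
The two files above build their tuning spaces entirely from the AutoTVM config-space API (cfg.define_split, cfg.define_reorder, cfg.define_annotate, then cfg[...].apply on the schedule). As a point of reference, here is a minimal self-contained template in the same style; the matmul workload and the task name "sketch/matmul" are illustrative only and are not part of this commit.

import tvm
from tvm import te, autotvm

@autotvm.template("sketch/matmul")  # hypothetical task name, for illustration
def matmul_template(N, L, M, dtype):
    A = te.placeholder((N, L), name="A", dtype=dtype)
    B = te.placeholder((L, M), name="B", dtype=dtype)
    k = te.reduce_axis((0, L), name="k")
    C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
    s = te.create_schedule(C.op)
    y, x = s[C].op.axis

    # Declare the knobs, in the same way as the "tile_co"/"tile_oh"/"tile_ow" splits above ...
    cfg = autotvm.get_config()
    cfg.define_split("tile_y", y, num_outputs=2)
    cfg.define_split("tile_x", x, num_outputs=2)

    # ... then apply the chosen split entities back onto the schedule.
    yo, yi = cfg["tile_y"].apply(s, C, y)
    xo, xi = cfg["tile_x"].apply(s, C, x)
    s[C].reorder(yo, xo, k, yi, xi)
    return s, [A, B, C]
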
index 3f1a5ff..b084066 100644
@@ -23,6 +23,7 @@ from tvm.autotvm.task import deserialize_args
 from tvm.topi.nn.conv2d import conv2d_nchw, conv2d_nhwc
 from tvm.topi.util import get_const_tuple, get_const_int, traverse_inline
 
+
 def conv2d_direct(*args, **kwargs):
     """Schedule function for directly-scheduled conv2d."""
     assert not kwargs, "Do not support kwargs in template function call"
@@ -32,32 +33,33 @@ def conv2d_direct(*args, **kwargs):
     cfg = autotvm.get_config()
     args = [cfg] + args
     conv = conv2d_direct_compute(*args)
-    if layout == 'NHWC':
+    if layout == "NHWC":
         sched = conv2d_direct_nhwc_schedule(cfg, [data, kernel, conv])
-    elif layout == 'NCHW':
+    elif layout == "NCHW":
         sched = conv2d_direct_nchw_schedule(cfg, [data, kernel, conv])
     else:
         raise RuntimeError(f'unsupported data layout "{layout}"')
     return sched, [data, kernel, conv]
 
 
-conv2d_direct.template_key = 'direct'
-conv2d_direct.default_data_layout = 'NHWC'
-conv2d_direct.default_kernel_layout = 'HWIO'
+conv2d_direct.template_key = "direct"
+conv2d_direct.default_data_layout = "NHWC"
+conv2d_direct.default_kernel_layout = "HWIO"
+
 
-@autotvm.register_topi_compute('conv2d_direct.micro_dev')
+@autotvm.register_topi_compute("conv2d_direct.micro_dev")
 def conv2d_direct_compute(*args):
     layout = args[-2]
-    if layout == 'NHWC':
+    if layout == "NHWC":
         return _conv2d_direct_nhwc_compute(*args)
-    if layout == 'NCHW':
+    if layout == "NCHW":
         return _conv2d_direct_nchw_compute(*args)
 
     raise RuntimeError(f'unsupported data layout "{layout}"')
 
 
 def _conv2d_direct_nhwc_compute(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
-    assert layout == 'NHWC'
+    assert layout == "NHWC"
     conv = conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
 
     # Config Space Definition
@@ -67,47 +69,51 @@ def _conv2d_direct_nhwc_compute(cfg, data, kernel, strides, padding, dilation, l
     kh, kw, ci = cfg.reduce_axis(KH), cfg.reduce_axis(KW), cfg.reduce_axis(CI)
 
     # TODO should we add a max_factor attr to these splits?
-    co, vc = cfg.define_split('tile_co', co, num_outputs=2)
-    oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2)
-    ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2)
-
-    cfg.define_reorder('reorder_0',
-                       [n, co, oh, ow, ci, kh, kw, vh, vw, vc],
-                       policy='candidate', candidate=[
-                           [n, co, oh, ow, ci, kh, kw, vh, vw, vc],
-                           [n, co, oh, ow, ci, kh, kw, vc, vh, vw],
-                           [n, co, oh, ow, ci, vh, vw, vc, kh, kw],
-                           [n, co, oh, ow, ci, vc, vh, vw, kh, kw]])
-
-    cfg.define_annotate('ann_reduce', [kh, kw], policy='try_unroll')
-    cfg.define_annotate('ann_spatial', [vh, vw, vc], policy='try_unroll')
-
-    cfg.define_knob('auto_unroll_max_step', [0, 2, 4, 8, 16, 32])
-    cfg.define_knob('unroll_explicit', [0, 1])
+    co, vc = cfg.define_split("tile_co", co, num_outputs=2)
+    oh, vh = cfg.define_split("tile_oh", oh, num_outputs=2)
+    ow, vw = cfg.define_split("tile_ow", ow, num_outputs=2)
+
+    cfg.define_reorder(
+        "reorder_0",
+        [n, co, oh, ow, ci, kh, kw, vh, vw, vc],
+        policy="candidate",
+        candidate=[
+            [n, co, oh, ow, ci, kh, kw, vh, vw, vc],
+            [n, co, oh, ow, ci, kh, kw, vc, vh, vw],
+            [n, co, oh, ow, ci, vh, vw, vc, kh, kw],
+            [n, co, oh, ow, ci, vc, vh, vw, kh, kw],
+        ],
+    )
+
+    cfg.define_annotate("ann_reduce", [kh, kw], policy="try_unroll")
+    cfg.define_annotate("ann_spatial", [vh, vw, vc], policy="try_unroll")
+
+    cfg.define_knob("auto_unroll_max_step", [0, 2, 4, 8, 16, 32])
+    cfg.define_knob("unroll_explicit", [0, 1])
 
     return conv
 
 
 def _conv2d_direct_nchw_compute(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
-    assert layout == 'NCHW'
+    assert layout == "NCHW"
     conv = conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
 
     ###########################
     # Config Space Definition #
     ###########################
-    cfg.define_knob('auto_unroll_max_step', [0, 2, 4, 8, 16, 32])
-    cfg.define_knob('unroll_explicit', [0, 1])
+    cfg.define_knob("auto_unroll_max_step", [0, 2, 4, 8, 16, 32])
+    cfg.define_knob("unroll_explicit", [0, 1])
 
     return conv
 
 
-@autotvm.register_topi_schedule('conv2d_direct_nhwc.micro_dev')
+@autotvm.register_topi_schedule("conv2d_direct_nhwc.micro_dev")
 def conv2d_direct_nhwc_schedule(cfg, outs):
     """Schedule function for directly-scheduled conv2d on NHWC layout."""
     sched = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'conv2d_nhwc' not in op.tag:
+        if "conv2d_nhwc" not in op.tag:
             return
 
         ### extract tensors ###
@@ -125,33 +131,38 @@ def conv2d_direct_nhwc_schedule(cfg, outs):
         data_pad = data_vec.op
         sched[data_pad].compute_inline()
 
-        co, vc = cfg['tile_co'].apply(sched, conv, co)
-        oh, vh = cfg['tile_oh'].apply(sched, conv, oh)
-        ow, vw = cfg['tile_ow'].apply(sched, conv, ow)
-        cfg['reorder_0'].apply(sched, conv, [n, co, oh, ow, ci, kh, kw, vh, vw, vc])
-        cfg['ann_reduce'].apply(sched, conv, [kh, kw],
-                                axis_lens=[get_const_int(kh.dom.extent),
-                                           get_const_int(kw.dom.extent)],
-                                max_unroll=8,
-                                cfg=cfg)
-        cfg['ann_spatial'].apply(sched, conv, [vh, vw, vc],
-                                 axis_lens=[cfg['tile_oh'].size[-1],
-                                            cfg['tile_ow'].size[-1],
-                                            cfg['tile_co'].size[-1]],
-                                 max_unroll=8,
-                                 cfg=cfg)
+        co, vc = cfg["tile_co"].apply(sched, conv, co)
+        oh, vh = cfg["tile_oh"].apply(sched, conv, oh)
+        ow, vw = cfg["tile_ow"].apply(sched, conv, ow)
+        cfg["reorder_0"].apply(sched, conv, [n, co, oh, ow, ci, kh, kw, vh, vw, vc])
+        cfg["ann_reduce"].apply(
+            sched,
+            conv,
+            [kh, kw],
+            axis_lens=[get_const_int(kh.dom.extent), get_const_int(kw.dom.extent)],
+            max_unroll=8,
+            cfg=cfg,
+        )
+        cfg["ann_spatial"].apply(
+            sched,
+            conv,
+            [vh, vw, vc],
+            axis_lens=[cfg["tile_oh"].size[-1], cfg["tile_ow"].size[-1], cfg["tile_co"].size[-1]],
+            max_unroll=8,
+            cfg=cfg,
+        )
 
         kernel_scope = n  # this is the scope to attach global config inside this kernel
 
         # tune unroll
-        sched[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
-        sched[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+        sched[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+        sched[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
 
     traverse_inline(sched, outs[-1].op, _callback)
     return sched
 
 
-@autotvm.register_topi_schedule('conv2d_direct_nchw.micro_dev')
+@autotvm.register_topi_schedule("conv2d_direct_nchw.micro_dev")
 def conv2d_direct_nchw_schedule(cfg, outs):
     """Schedule function for Cortex-M7 direct implementation of conv2d."""
     # use default schedule
@@ -169,7 +180,7 @@ def conv2d_direct_nchw_schedule(cfg, outs):
     kernel_scope = n  # this is the scope to attach global config inside this kernel
 
     # tune unroll
-    sched[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
-    sched[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+    sched[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    sched[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
 
     return sched
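
The micro_dev hunks above are typical of this commit as a whole: single quotes become double quotes and long calls are wrapped to one argument per line, with no behavioral change. A hedged sketch of how a subtree can be re-checked locally; the target path is an example, and the authoritative settings (line length, includes) are the ones in the pyproject.toml this commit adds.

import subprocess

# Exit code 0 means the files already match black's output; 1 means they would be rewritten.
result = subprocess.run(
    ["python3", "-m", "black", "--check", "--diff", "python/tvm/topi/"],
    capture_output=True,
    text=True,
)
print(result.stdout or "already formatted")
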
index 5be9b4c..61dca8a 100644
@@ -25,9 +25,11 @@ from tvm.topi.nn.pad import pad
 from tvm.topi.nn.util import get_pad_tuple
 
 from ..micro_kernel.gemm import (
-        intrin_gemm_MxKxN, gemm_MxKxN_impl,
+    intrin_gemm_MxKxN,
+    gemm_MxKxN_impl,
 )
 
+
 def conv2d_direct_simd(*args, **kwargs):
     """Defines the Cortex-M7 SIMD implementation of conv2d."""
     assert not kwargs, "Do not support kwargs in template function call"
@@ -36,15 +38,16 @@ def conv2d_direct_simd(*args, **kwargs):
     layout = args[-2]
     cfg = autotvm.get_config()
     args = [cfg] + args
-    assert layout == 'NHWC'
+    assert layout == "NHWC"
     conv = conv2d_direct_simd_compute(*args)
     sched = conv2d_direct_simd_nhwc_schedule(cfg, [data, kernel, conv])
     return sched, [data, kernel, conv]
 
 
-conv2d_direct_simd.template_key = 'direct_simd'
-conv2d_direct_simd.default_data_layout = 'NHWC'
-conv2d_direct_simd.default_kernel_layout = 'HWOI'
+conv2d_direct_simd.template_key = "direct_simd"
+conv2d_direct_simd.default_data_layout = "NHWC"
+conv2d_direct_simd.default_kernel_layout = "HWOI"
+
 
 def conv2d_direct_simd_compute(cfg, data, kernel, strides, padding, dilation, out_dtype):
     """Compute function for Cortex-M7 SIMD implementation of conv2d."""
@@ -68,53 +71,68 @@ def conv2d_direct_simd_compute(cfg, data, kernel, strides, padding, dilation, ou
     dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
     dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
     out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
     out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
 
     pad_before = [0, pad_top, pad_left, 0]
     pad_after = [0, pad_down, pad_right, 0]
-    padded_data = pad(data, pad_before, pad_after, name='padded_data')
+    padded_data = pad(data, pad_before, pad_after, name="padded_data")
 
-    rc = te.reduce_axis((0, in_channels), name='rc')
-    ry = te.reduce_axis((0, kernel_h), name='ry')
-    rx = te.reduce_axis((0, kernel_w), name='rx')
+    rc = te.reduce_axis((0, in_channels), name="rc")
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
 
     conv = te.compute(
         (batch_size, out_height, out_width, out_channels),
         lambda nn, yy, xx, ff: te.sum(
-            padded_data[nn, yy * stride_h + ry * dilation_h,
-                        xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
-            kernel[ry, rx, ff, rc].astype(out_dtype), axis=[ry, rx, rc]),
-        name='conv2d', tag='conv2d_nhwc')
+            padded_data[
+                nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc
+            ].astype(out_dtype)
+            * kernel[ry, rx, ff, rc].astype(out_dtype),
+            axis=[ry, rx, rc],
+        ),
+        name="conv2d",
+        tag="conv2d_nhwc",
+    )
 
     ###########################
     # Config Space Definition #
     ###########################
-    n, oh, ow, co = (cfg.axis(batch_size.value),
-                     cfg.axis(out_height.value),
-                     cfg.axis(out_width.value),
-                     cfg.axis(out_channels.value))
-    kh, kw, ci = (cfg.reduce_axis(kernel_h.value),
-                  cfg.reduce_axis(kernel_w.value),
-                  cfg.reduce_axis(in_channels.value))
+    n, oh, ow, co = (
+        cfg.axis(batch_size.value),
+        cfg.axis(out_height.value),
+        cfg.axis(out_width.value),
+        cfg.axis(out_channels.value),
+    )
+    kh, kw, ci = (
+        cfg.reduce_axis(kernel_h.value),
+        cfg.reduce_axis(kernel_w.value),
+        cfg.reduce_axis(in_channels.value),
+    )
 
     assert in_channels.value % 4 == 0
-    owo, owi = cfg.define_split('tile_ow', ow, policy='factors', num_outputs=2)
-    cio, cii = cfg.define_split('tile_ci', ci, policy='factors', num_outputs=2,
-                                filter=lambda x: x.size[-1] % 4 == 0)
-    coo, coi = cfg.define_split('tile_co', co, policy='factors', num_outputs=2)
-
-    cfg.define_reorder('reorder_0_simd',
-                       [n, oh, owo, owi, coo, coi, kh, kw, cio, cii],
-                       policy='candidate', candidate=[
-                           [n, oh, kh, kw, owo, coo, cio, owi, coi, cii],
-                           [n, oh, kh, kw, coo, owo, cio, owi, coi, cii],
-                           [n, kh, kw, oh, owo, coo, cio, owi, coi, cii],
-                           [n, kh, kw, oh, coo, owo, cio, owi, coi, cii]])
-
-    cfg.define_knob('auto_unroll_max_step', [0, 2, 4, 8, 16, 32])
-    cfg.define_knob('unroll_explicit', [0, 1])
+    owo, owi = cfg.define_split("tile_ow", ow, policy="factors", num_outputs=2)
+    cio, cii = cfg.define_split(
+        "tile_ci", ci, policy="factors", num_outputs=2, filter=lambda x: x.size[-1] % 4 == 0
+    )
+    coo, coi = cfg.define_split("tile_co", co, policy="factors", num_outputs=2)
+
+    cfg.define_reorder(
+        "reorder_0_simd",
+        [n, oh, owo, owi, coo, coi, kh, kw, cio, cii],
+        policy="candidate",
+        candidate=[
+            [n, oh, kh, kw, owo, coo, cio, owi, coi, cii],
+            [n, oh, kh, kw, coo, owo, cio, owi, coi, cii],
+            [n, kh, kw, oh, owo, coo, cio, owi, coi, cii],
+            [n, kh, kw, oh, coo, owo, cio, owi, coi, cii],
+        ],
+    )
+
+    cfg.define_knob("auto_unroll_max_step", [0, 2, 4, 8, 16, 32])
+    cfg.define_knob("unroll_explicit", [0, 1])
 
     return conv
 
@@ -124,7 +142,7 @@ def conv2d_direct_simd_nhwc_schedule(cfg, outs):
     sched = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'conv2d_nhwc' not in op.tag:
+        if "conv2d_nhwc" not in op.tag:
             return
 
         # extract tensors
@@ -138,26 +156,26 @@ def conv2d_direct_simd_nhwc_schedule(cfg, outs):
         n, oh, ow, co = sched[conv].op.axis
         kh, kw, ci = sched[conv].op.reduce_axis
 
-        M = cfg['tile_ow'].size[-1]
-        K = cfg['tile_ci'].size[-1]
-        N = cfg['tile_co'].size[-1]
+        M = cfg["tile_ow"].size[-1]
+        K = cfg["tile_ci"].size[-1]
+        N = cfg["tile_co"].size[-1]
 
-        owo, owi = cfg['tile_ow'].apply(sched, conv, ow)
-        cio, cii = cfg['tile_ci'].apply(sched, conv, ci)
-        coo, coi = cfg['tile_co'].apply(sched, conv, co)
+        owo, owi = cfg["tile_ow"].apply(sched, conv, ow)
+        cio, cii = cfg["tile_ci"].apply(sched, conv, ci)
+        coo, coi = cfg["tile_co"].apply(sched, conv, co)
 
-        cfg['reorder_0_simd'].apply(sched, conv, [n, oh, owo, owi, coo, coi, kh, kw, cio, cii])
+        cfg["reorder_0_simd"].apply(sched, conv, [n, oh, owo, owi, coo, coi, kh, kw, cio, cii])
 
         gemm, uniq_id = intrin_gemm_MxKxN(M, K, N, data_vec.dtype, output.dtype)
         sched[output].tensorize(owi, gemm)
-        sched[output].pragma(n, 'import_c', gemm_MxKxN_impl(M, K, N, uniq_id))
+        sched[output].pragma(n, "import_c", gemm_MxKxN_impl(M, K, N, uniq_id))
 
         # this is the scope to attach global config inside this kernel
         kernel_scope = n
 
         # tune unroll
-        sched[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
-        sched[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+        sched[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+        sched[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
 
     traverse_inline(sched, outs[-1].op, _callback)
     return sched
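
Every schedule function in these files follows the same shape: create a schedule over all outputs, then hand a _callback to traverse_inline that dispatches on op.tag or op.name. A minimal sketch of that pattern, using a made-up "demo_op" tag rather than any operator from this commit:

from tvm import te
from tvm.topi.util import traverse_inline

A = te.placeholder((16, 16), name="A")
B = te.compute((16, 16), lambda i, j: A[i, j] * 2, name="B", tag="demo_op")
s = te.create_schedule(B.op)

def _callback(op):
    # Only rewrite the ops we recognise, mirroring the tag checks above.
    if "demo_op" in op.tag:
        out = op.output(0)
        i, j = s[out].op.axis
        s[out].vectorize(j)

traverse_inline(s, B.op, _callback)
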
index 7bd9bdb..943aee0 100644
@@ -38,7 +38,7 @@ def intrin_gemm_MxKxN(M, K, N, in_dtype, out_dtype):
     # instantiation and include it only once, eliminating the need for unique
     # IDs
     UNIQ_ID_LEN = 8
-    uniq_id = ''.join(random.choices(string.ascii_uppercase, k=UNIQ_ID_LEN))
+    uniq_id = "".join(random.choices(string.ascii_uppercase, k=UNIQ_ID_LEN))
 
     if isinstance(M, tvm.tir.IntImm):
         M = M.value
@@ -48,63 +48,74 @@ def intrin_gemm_MxKxN(M, K, N, in_dtype, out_dtype):
         N = N.value
     assert K % 4 == 0
     # TODO(weberlo, areusch): support more dtypes?
-    assert in_dtype == 'int8'
-    assert out_dtype == 'int32'
-    A = te.placeholder((M, K), name='a', dtype=in_dtype)
-    B = te.placeholder((N, K), name='b', dtype=in_dtype)
-    k = te.reduce_axis((0, K), name='k')
+    assert in_dtype == "int8"
+    assert out_dtype == "int32"
+    A = te.placeholder((M, K), name="a", dtype=in_dtype)
+    B = te.placeholder((N, K), name="b", dtype=in_dtype)
+    k = te.reduce_axis((0, K), name="k")
     C = te.compute(
         (M, N),
         lambda i, j: te.sum(A[i, k].astype(out_dtype) * B[j, k].astype(out_dtype), axis=k),
-        name='c')
+        name="c",
+    )
     A_buf = tvm.tir.decl_buffer(
-        A.shape, A.dtype,
-        name="A",
-        offset_factor=1,
-        strides=[te.var("A_s"), 1])
+        A.shape, A.dtype, name="A", offset_factor=1, strides=[te.var("A_s"), 1]
+    )
     B_buf = tvm.tir.decl_buffer(
-        B.shape, B.dtype,
-        name="B",
-        offset_factor=1,
-        strides=[te.var("B_s"), 1])
+        B.shape, B.dtype, name="B", offset_factor=1, strides=[te.var("B_s"), 1]
+    )
     C_buf = tvm.tir.decl_buffer(
-        C.shape, C.dtype,
-        name="C",
-        offset_factor=1,
-        strides=[te.var("C_s"), 1])
+        C.shape, C.dtype, name="C", offset_factor=1, strides=[te.var("C_s"), 1]
+    )
+
     def intrin_func(ins, outs):
         aa, bb = ins
         cc = outs[0]
+
         def _reduce_update():
             ib = tvm.tir.ir_builder.create()
-            ib.emit(tvm.tir.call_extern("int32", f"gemm_{M}x{K}x{N}_update_{uniq_id}",
-                                        aa.access_ptr("r"),
-                                        bb.access_ptr("r"),
-                                        cc.access_ptr("w"),
-                                        aa.strides[0],
-                                        bb.strides[0],
-                                        cc.strides[0]))
+            ib.emit(
+                tvm.tir.call_extern(
+                    "int32",
+                    f"gemm_{M}x{K}x{N}_update_{uniq_id}",
+                    aa.access_ptr("r"),
+                    bb.access_ptr("r"),
+                    cc.access_ptr("w"),
+                    aa.strides[0],
+                    bb.strides[0],
+                    cc.strides[0],
+                )
+            )
             return ib.get()
+
         def _reduce_reset():
             ib = tvm.tir.ir_builder.create()
-            ib.emit(tvm.tir.call_extern("int32", f"gemm_{M}x{K}x{N}_reset_{uniq_id}",
-                                        cc.access_ptr("w"),
-                                        cc.strides[0]))
+            ib.emit(
+                tvm.tir.call_extern(
+                    "int32", f"gemm_{M}x{K}x{N}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0]
+                )
+            )
             return ib.get()
+
         def _body():
             ib = tvm.tir.ir_builder.create()
-            ib.emit(tvm.tir.call_extern("int32", f"gemm_{M}x{K}x{N}_body_{uniq_id}",
-                                        aa.access_ptr("r"),
-                                        bb.access_ptr("r"),
-                                        cc.access_ptr("w"),
-                                        aa.strides[0],
-                                        bb.strides[0],
-                                        cc.strides[0]))
+            ib.emit(
+                tvm.tir.call_extern(
+                    "int32",
+                    f"gemm_{M}x{K}x{N}_body_{uniq_id}",
+                    aa.access_ptr("r"),
+                    bb.access_ptr("r"),
+                    cc.access_ptr("w"),
+                    aa.strides[0],
+                    bb.strides[0],
+                    cc.strides[0],
+                )
+            )
             return ib.get()
+
         return _body(), _reduce_reset(), _reduce_update()
 
-    intrin_decl = te.decl_tensor_intrin(
-        C.op, intrin_func, binds={A: A_buf, B: B_buf, C: C_buf})
+    intrin_decl = te.decl_tensor_intrin(C.op, intrin_func, binds={A: A_buf, B: B_buf, C: C_buf})
     return intrin_decl, uniq_id
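
The tensor intrinsic above declares C[i, j] = sum_k A[i, k] * B[j, k] with int8 operands accumulated into int32 and K a multiple of 4; the gemm_{M}x{K}x{N} body/update/reset externs implement exactly that update. A plain NumPy reference of the same arithmetic, with sizes picked arbitrarily:

import numpy as np

M, K, N = 4, 16, 4  # K % 4 == 0, as asserted above
A = np.random.randint(-128, 128, size=(M, K)).astype(np.int8)
B = np.random.randint(-128, 128, size=(N, K)).astype(np.int8)
# body/update accumulate into a 32-bit C block; reset zeroes it first.
C = A.astype(np.int32) @ B.astype(np.int32).T
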
 
 
index 07749ee..b71c0c9 100644
@@ -32,6 +32,7 @@ def depthwise_conv2d_nchw(_, data, kernel, strides, padding, dilation, out_dtype
     """Compute depthwise_conv2d with NCHW layout"""
     return nn.depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
 
+
 @autotvm.register_topi_schedule("depthwise_conv2d_nchw.arm_cpu")
 def schedule_depthwise_conv2d_nchw(cfg, outs):
     """Schedule depthwise conv2d
@@ -58,54 +59,61 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
 
         ##### space definition begin #####
         n, c, h, w = s[output].op.axis
-        _, vc = cfg.define_split('tile_c', c, num_outputs=2)
-        _, vh = cfg.define_split('tile_h', h, num_outputs=2)
-        _, vw = cfg.define_split('tile_w', w, num_outputs=2)
-        cfg.define_annotate('ann', [vh, vw, vc], policy='try_unroll_vec')
+        _, vc = cfg.define_split("tile_c", c, num_outputs=2)
+        _, vh = cfg.define_split("tile_h", h, num_outputs=2)
+        _, vw = cfg.define_split("tile_w", w, num_outputs=2)
+        cfg.define_annotate("ann", [vh, vw, vc], policy="try_unroll_vec")
 
         # fallback support
         if cfg.is_fallback:
             ref_log = autotvm.tophub.load_reference_log(
-                'arm_cpu', 'rk3399', 'depthwise_conv2d_nchw.arm_cpu')
+                "arm_cpu", "rk3399", "depthwise_conv2d_nchw.arm_cpu"
+            )
             cfg.fallback_with_reference_log(ref_log)
         ##### space definition end #####
 
         # pack data to vector form  [n, c, h, w] -> [n, C, h, w, VC]
         A0 = s.cache_read(data_pad, "global", C)
         n, c, h, w = s[A0].op.axis
-        c, vc = cfg['tile_c'].apply(s, A0, c)
+        c, vc = cfg["tile_c"].apply(s, A0, c)
         s[A0].reorder(n, c, h, w, vc)
-        A1 = s.cache_write(A0, 'global')
+        A1 = s.cache_write(A0, "global")
         s[A0].compute_inline()
 
         # pack kernel to vector form  [co, ci, kh, kw] -> [CO, ci, kh, kw, VC]
         B0 = s.cache_read(B, "global", C)
         c, m, h, w = s[B0].op.axis
-        c, vc, = cfg['tile_c'].apply(s, B0, c)
+        c, vc, = cfg[
+            "tile_c"
+        ].apply(s, B0, c)
         s[B0].reorder(c, m, h, w, vc)
-        B1 = s.cache_write(B0, 'global')
+        B1 = s.cache_write(B0, "global")
         s[B0].compute_inline()
 
         n, c, h, w = s[C].op.axis
-        c, vc, = cfg['tile_c'].apply(s, C, c)
+        c, vc, = cfg[
+            "tile_c"
+        ].apply(s, C, c)
         s[C].reorder(n, c, h, w, vc)
 
         # depthwise conv
-        C0 = s.cache_write(C, 'global')
+        C0 = s.cache_write(C, "global")
         _, c, h, w, vc = s[C0].op.axis
         dh, dw = s[C0].op.reduce_axis
-        oh, ih = cfg['tile_h'].apply(s, C0, h)
-        ow, iw = cfg['tile_w'].apply(s, C0, w)
+        oh, ih = cfg["tile_h"].apply(s, C0, h)
+        ow, iw = cfg["tile_w"].apply(s, C0, w)
         s[C0].reorder(c, oh, ow, dh, dw, ih, iw, vc)
         s[A1].compute_at(s[C0], oh)
 
         # try unroll and vectorization
-        cfg['ann'].apply(s, C0, [ih, iw, vc],
-                         axis_lens=[cfg['tile_h'].size[-1],
-                                    cfg['tile_w'].size[-1],
-                                    cfg['tile_c'].size[-1]],
-                         max_unroll=16,
-                         cfg=cfg)
+        cfg["ann"].apply(
+            s,
+            C0,
+            [ih, iw, vc],
+            axis_lens=[cfg["tile_h"].size[-1], cfg["tile_w"].size[-1], cfg["tile_c"].size[-1]],
+            max_unroll=16,
+            cfg=cfg,
+        )
 
         # fusion
         if C.op not in s.outputs:
@@ -125,7 +133,7 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
         return s
 
     def _callback(op):
-        if op.tag == 'depthwise_conv2d_nchw':
+        if op.tag == "depthwise_conv2d_nchw":
             output = op.output(0)
             kernel = op.input_tensors[1]
             data = op.input_tensors[0]
@@ -181,6 +189,7 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+
 @autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
 def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
     """TOPI compute callback for depthwise_conv2d nhwc
@@ -229,39 +238,47 @@ def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, o
     dilated_kernel_w = (KW - 1) * dilation_w + 1
 
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
     HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
 
     OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
     OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
 
     if pad_top or pad_left or pad_down or pad_right:
-        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
-                          name="data_pad")
+        data_pad = nn.pad(
+            data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0], name="data_pad"
+        )
     else:
         data_pad = data
 
-    output_shape = (N, OH, OW, IC*channel_multiplier)
+    output_shape = (N, OH, OW, IC * channel_multiplier)
 
     idxdiv = tvm.tir.indexdiv
     idxmod = tvm.tir.indexmod
 
-    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
-    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
-
-    out = te.compute(output_shape, lambda n, h, w, c:
-                     te.sum(data_pad[n,
-                                     HSTR*h+dilation_h*reduce_h,
-                                     w*WSTR+reduce_w*dilation_w,
-                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
-                            kernel[reduce_h,
-                                   reduce_w,
-                                   idxdiv(c, channel_multiplier),
-                                   idxmod(c, channel_multiplier)].astype(out_dtype),
-                            axis=[reduce_h, reduce_w]),
-                     name='depthwise_conv2d_nhwc_output')
+    reduce_h = te.reduce_axis((0, KH), name="reduce_h")
+    reduce_w = te.reduce_axis((0, KW), name="reduce_w")
+
+    out = te.compute(
+        output_shape,
+        lambda n, h, w, c: te.sum(
+            data_pad[
+                n,
+                HSTR * h + dilation_h * reduce_h,
+                w * WSTR + reduce_w * dilation_w,
+                idxdiv(c, channel_multiplier),
+            ].astype(out_dtype)
+            * kernel[
+                reduce_h, reduce_w, idxdiv(c, channel_multiplier), idxmod(c, channel_multiplier)
+            ].astype(out_dtype),
+            axis=[reduce_h, reduce_w],
+        ),
+        name="depthwise_conv2d_nhwc_output",
+    )
     return out
 
+
 @autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
 def schedule_depthwise_conv2d_nhwc(cfg, outs):
     """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
@@ -271,17 +288,17 @@ def schedule_depthwise_conv2d_nhwc(cfg, outs):
 
     ##### space definition begin #####
     n, h, w, c = s[out].op.axis
-    cfg.define_split('tile_c', c, num_outputs=2)
-    _, hi = cfg.define_split('tile_h', h, num_outputs=2)
-    _, wi = cfg.define_split('tile_w', w, num_outputs=2)
-    cfg.define_knob('locate_output', [0, 1])
+    cfg.define_split("tile_c", c, num_outputs=2)
+    _, hi = cfg.define_split("tile_h", h, num_outputs=2)
+    _, wi = cfg.define_split("tile_w", w, num_outputs=2)
+    cfg.define_knob("locate_output", [0, 1])
 
     # fallback support
     if cfg.is_fallback:
-        cfg['tile_c'] = SplitEntity([-1, 8])
-        cfg['tile_h'] = SplitEntity([-1, 2])
-        cfg['tile_w'] = SplitEntity([-1, 2])
-        cfg['locate_output'] = OtherOptionEntity(1)
+        cfg["tile_c"] = SplitEntity([-1, 8])
+        cfg["tile_h"] = SplitEntity([-1, 2])
+        cfg["tile_w"] = SplitEntity([-1, 2])
+        cfg["locate_output"] = OtherOptionEntity(1)
     ##### space definition end #####
 
     def schedule_conv(conv):
@@ -289,23 +306,23 @@ def schedule_depthwise_conv2d_nhwc(cfg, outs):
 
         n, w, h, c = conv.op.axis
         r_h, r_w = conv.op.reduce_axis
-        ho, hi = cfg['tile_h'].apply(s, conv, h)
-        wo, wi = cfg['tile_w'].apply(s, conv, w)
-        co, ci = cfg['tile_c'].apply(s, conv, c)
+        ho, hi = cfg["tile_h"].apply(s, conv, h)
+        wo, wi = cfg["tile_w"].apply(s, conv, w)
+        co, ci = cfg["tile_c"].apply(s, conv, c)
 
         if conv_data.name == "data_pad":
             assert isinstance(conv_data.op, tvm.te.ComputeOp)
             # Define a policy for padding computation
-            cfg.define_knob('data_pad_inline', [1, 2, 3])
+            cfg.define_knob("data_pad_inline", [1, 2, 3])
             if cfg.is_fallback:
-                cfg['data_pad_inline'] = OtherOptionEntity(3)
-            if cfg['data_pad_inline'].val == 1:
+                cfg["data_pad_inline"] = OtherOptionEntity(3)
+            if cfg["data_pad_inline"].val == 1:
                 s[conv_data].vectorize(list(s[conv_data].op.axis)[-1])
                 s[conv_data].compute_at(s[conv], ho)
-            if cfg['data_pad_inline'].val == 2:
+            if cfg["data_pad_inline"].val == 2:
                 s[conv_data].vectorize(list(s[conv_data].op.axis)[-1])
                 s[conv_data].compute_at(s[conv], wo)
-            if cfg['data_pad_inline'].val == 3:
+            if cfg["data_pad_inline"].val == 3:
                 s[conv_data].compute_inline()
 
         s[conv].reorder(n, ho, wo, co, hi, wi, r_h, r_w, ci)
@@ -315,12 +332,12 @@ def schedule_depthwise_conv2d_nhwc(cfg, outs):
 
     def schedule_conv_out(out):
         n, h, w, c = out.op.axis
-        co, ci = cfg['tile_c'].apply(s, out, c)
-        wo, wi = cfg['tile_w'].apply(s, out, w)
-        ho, hi = cfg['tile_h'].apply(s, out, h)
+        co, ci = cfg["tile_c"].apply(s, out, c)
+        wo, wi = cfg["tile_w"].apply(s, out, w)
+        ho, hi = cfg["tile_h"].apply(s, out, h)
         s[out].reorder(n, ho, wo, co, hi, wi)
 
-        if out.dtype in ['int8', 'uint8']:
+        if out.dtype in ["int8", "uint8"]:
             # In case of quantized convolution further split the channel in batches of 4 elements
             # so that we can use arm intrinsics to run fixed_point_multiplication
             ci_outer, ci_inner = s[out].split(ci, 4)
@@ -330,14 +347,14 @@ def schedule_depthwise_conv2d_nhwc(cfg, outs):
         return hi, wi, fused_n_ho
 
     def _callback(op):
-        if op.name == 'depthwise_conv2d_nhwc_output':
+        if op.name == "depthwise_conv2d_nhwc_output":
             conv = op.output(0)
             if conv != out:
                 hi, wi, p_axis = schedule_conv_out(out)
                 schedule_conv(conv)
-                if cfg['locate_output'].val == 0:
+                if cfg["locate_output"].val == 0:
                     s[conv].compute_at(s[out], hi)
-                if cfg['locate_output'].val == 1:
+                if cfg["locate_output"].val == 1:
                     s[conv].compute_at(s[out], wi)
             else:
                 p_axis = schedule_conv(out)
@@ -347,6 +364,7 @@ def schedule_depthwise_conv2d_nhwc(cfg, outs):
     traverse_inline(s, outs[0].op, _callback)
     return s
 
+
 @autotvm.register_topi_schedule("depthwise_conv2d_nchw_spatial_pack.arm_cpu")
 def schedule_depthwise_conv2d_nchw_spatial_pack(cfg, outs):
     """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
@@ -354,12 +372,12 @@ def schedule_depthwise_conv2d_nchw_spatial_pack(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == 'spatial_depthwise_conv2d_nchw_output':
+        if op.tag == "spatial_depthwise_conv2d_nchw_output":
             output = op.output(0)
             conv = op.input_tensors[0]
             data_vec = conv.op.input_tensors[0]
             kernel_vec = conv.op.input_tensors[1]
-            if kernel_vec.op.name == 'kernel_vec':
+            if kernel_vec.op.name == "kernel_vec":
                 kernel = kernel_vec.op.input_tensors[0]
             else:
                 kernel = kernel_vec
@@ -393,17 +411,19 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype,
     dilated_kernel_w = (KW - 1) * dilation_w + 1
 
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
     HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
     OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
     OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
     # pack data
     HPAD = pad_top + pad_down
     WPAD = pad_left + pad_right
-    DOPAD = (HPAD != 0 or WPAD != 0)
+    DOPAD = HPAD != 0 or WPAD != 0
     if DOPAD:
-        data_pad = nn.pad(data, (0, 0, pad_top, pad_left), (0, 0, pad_down, pad_right),
-                          name="data_pad")
+        data_pad = nn.pad(
+            data, (0, 0, pad_top, pad_left), (0, 0, pad_down, pad_right), name="data_pad"
+        )
     else:
         data_pad = data
 
@@ -411,7 +431,8 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype,
     # Currently, Mali schedule doesn't use it like conv2d.
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            'arm_cpu', 'rk3399', 'depthwise_conv2d_nchw_spatial_pack.arm_cpu')
+            "arm_cpu", "rk3399", "depthwise_conv2d_nchw_spatial_pack.arm_cpu"
+        )
         cfg.fallback_with_reference_log(ref_log)
 
     # ==================== define configuration space ====================
@@ -420,28 +441,33 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype,
 
     # Currently, Mali schedule doesn't use it like conv2d.
     # Leave num_tile for possible future use of Mali schedule
-    if num_tile == 2:     # for arm cpu
-        co, vc = cfg.define_split('tile_co', c, num_outputs=2)
-        oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2)
-        ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2)
+    if num_tile == 2:  # for arm cpu
+        co, vc = cfg.define_split("tile_co", c, num_outputs=2)
+        oh, vh = cfg.define_split("tile_oh", oh, num_outputs=2)
+        ow, vw = cfg.define_split("tile_ow", ow, num_outputs=2)
     else:
         raise RuntimeError("Invalid num_tile")
 
-    cfg.define_reorder("reorder_0",
-                       [n, co, oh, ow, kh, kw, vh, vw, vc],
-                       policy='candidate', candidate=[
-                           [n, co, oh, ow, kh, kw, vh, vw, vc],
-                           [n, co, oh, ow, kh, kw, vc, vh, vw]])
-
-    cfg.define_reorder("reorder_1",
-                       [n, co, oh, ow, vh, vw, vc],
-                       policy='candidate', candidate=[
-                           [n, co, oh, ow, vh, vw, vc],
-                           [n, co, oh, ow, vc, vh, vw],
-                           [n, co, oh, ow, vh, vc, vw]])
-
-    cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
-    cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec')
+    cfg.define_reorder(
+        "reorder_0",
+        [n, co, oh, ow, kh, kw, vh, vw, vc],
+        policy="candidate",
+        candidate=[[n, co, oh, ow, kh, kw, vh, vw, vc], [n, co, oh, ow, kh, kw, vc, vh, vw]],
+    )
+
+    cfg.define_reorder(
+        "reorder_1",
+        [n, co, oh, ow, vh, vw, vc],
+        policy="candidate",
+        candidate=[
+            [n, co, oh, ow, vh, vw, vc],
+            [n, co, oh, ow, vc, vh, vw],
+            [n, co, oh, ow, vh, vc, vw],
+        ],
+    )
+
+    cfg.define_annotate("ann_reduce", [kh, kw], policy="try_unroll")
+    cfg.define_annotate("ann_spatial", [vh, vw, vc], policy="try_unroll_vec")
     # ====================================================================
 
     VC = cfg["tile_co"].size[-1]
@@ -455,59 +481,80 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype,
     if dilation_h != 1 or dilation_w != 1:
         # undilate input data
         dvshape = (N, OH // VH, OW // VW, C, KH, KW, VH, VW)
-        data_vec = te.compute(dvshape, lambda n, h, w, c, kh, kw, vh, vw:
-                              data_pad[n][c][(h * VH + vh) * HSTR + kh * dilation_h]
-                              [(w*VW+vw)*WSTR+kw*dilation_w],
-                              name='data_vec_undilated')
+        data_vec = te.compute(
+            dvshape,
+            lambda n, h, w, c, kh, kw, vh, vw: data_pad[n][c][
+                (h * VH + vh) * HSTR + kh * dilation_h
+            ][(w * VW + vw) * WSTR + kw * dilation_w],
+            name="data_vec_undilated",
+        )
     else:
-        dvshape = (N, OH // VH, OW // VW, C, VH*HSTR + KH-1, VW*WSTR + KW-1)
-        data_vec = te.compute(dvshape, lambda n, h, w, c, vh, vw:
-                              data_pad[n][c][h * VH * HSTR + vh][w * VW * WSTR + vw],
-                              name='data_vec')
+        dvshape = (N, OH // VH, OW // VW, C, VH * HSTR + KH - 1, VW * WSTR + KW - 1)
+        data_vec = te.compute(
+            dvshape,
+            lambda n, h, w, c, vh, vw: data_pad[n][c][h * VH * HSTR + vh][w * VW * WSTR + vw],
+            name="data_vec",
+        )
 
     if pre_packed:
         kernel_vec = kernel
     else:
-        kernel_vec = te.compute(kvshape, lambda co, m, kh, kw, vc:
-                                kernel[co*VC+vc][m][kh][kw],
-                                name='kernel_vec')
+        kernel_vec = te.compute(
+            kvshape, lambda co, m, kh, kw, vc: kernel[co * VC + vc][m][kh][kw], name="kernel_vec"
+        )
 
-    kh = te.reduce_axis((0, KH), name='kh')
-    kw = te.reduce_axis((0, KW), name='kw')
+    kh = te.reduce_axis((0, KH), name="kh")
+    kw = te.reduce_axis((0, KW), name="kw")
 
     idxdiv = tvm.tir.indexdiv
     idxmod = tvm.tir.indexmod
 
     if dilation_h != 1 or dilation_w != 1:
         conv = te.compute(
-            ovshape, lambda n, co, h, w, vh, vw, vc: \
-            te.sum(data_vec[n, h, w, idxdiv(co * VC + vc, M), kh, kw, vh, vw]
-                   .astype(out_dtype) *
-                   kernel_vec[idxdiv(co, M), idxmod(co, M), kh, kw, vc].astype(out_dtype),
-                   axis=[kh, kw]), name='depthwise_conv')
+            ovshape,
+            lambda n, co, h, w, vh, vw, vc: te.sum(
+                data_vec[n, h, w, idxdiv(co * VC + vc, M), kh, kw, vh, vw].astype(out_dtype)
+                * kernel_vec[idxdiv(co, M), idxmod(co, M), kh, kw, vc].astype(out_dtype),
+                axis=[kh, kw],
+            ),
+            name="depthwise_conv",
+        )
     else:
-        conv = te.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
-                          te.sum(data_vec[n, h, w, idxdiv((co * VC + vc), M), vh * HSTR + kh,
-                                          vw * WSTR + kw].astype(out_dtype) *
-                                 kernel_vec[idxdiv(co, M),
-                                            idxmod(co, M),
-                                            kh, kw, vc].astype(out_dtype),
-                                 axis=[kh, kw]), name='depthwise_conv')
-
-    output = te.compute(oshape, lambda n, co, h, w:
-                        conv[n,
-                             idxdiv(co, VC), idxdiv(h, VH), idxdiv(w, VW),
-                             idxmod(h, VH), idxmod(w, VW), idxmod(co, VC)],
-                        name='output_unpack', tag='spatial_depthwise_conv2d_nchw_output')
+        conv = te.compute(
+            ovshape,
+            lambda n, co, h, w, vh, vw, vc: te.sum(
+                data_vec[n, h, w, idxdiv((co * VC + vc), M), vh * HSTR + kh, vw * WSTR + kw].astype(
+                    out_dtype
+                )
+                * kernel_vec[idxdiv(co, M), idxmod(co, M), kh, kw, vc].astype(out_dtype),
+                axis=[kh, kw],
+            ),
+            name="depthwise_conv",
+        )
+
+    output = te.compute(
+        oshape,
+        lambda n, co, h, w: conv[
+            n,
+            idxdiv(co, VC),
+            idxdiv(h, VH),
+            idxdiv(w, VW),
+            idxmod(h, VH),
+            idxmod(w, VW),
+            idxmod(co, VC),
+        ],
+        name="output_unpack",
+        tag="spatial_depthwise_conv2d_nchw_output",
+    )
     return output
 
-def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
-                           conv, output, last):
+
+def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, last):
     """schedule implementation"""
     n, co, oh, ow, vh, vw, vc = s[conv].op.axis
     kh, kw = s[conv].op.reduce_axis
 
-    if data_vec.op.name == 'data_vec_undilated':
+    if data_vec.op.name == "data_vec_undilated":
         _, dv_oh, dv_ow, dv_c, _, _, dv_vh, dv_vw = s[data_vec].op.axis
     else:
         _, dv_oh, dv_ow, dv_c, dv_vh, dv_vw = s[data_vec].op.axis
@@ -520,80 +567,87 @@ def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
         assert isinstance(data_pad.op, tvm.te.PlaceholderOp)
         has_padding = False
 
-    cfg.define_knob('data_pad_inline', [0, 1, 2, 3, 4])
+    cfg.define_knob("data_pad_inline", [0, 1, 2, 3, 4])
 
-    if cfg['data_pad_inline'].val == 1 and has_padding:
+    if cfg["data_pad_inline"].val == 1 and has_padding:
         s[data_pad].compute_inline()
-    if cfg['data_pad_inline'].val == 2 and has_padding:
+    if cfg["data_pad_inline"].val == 2 and has_padding:
         s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
-    if cfg['data_pad_inline'].val == 3 and has_padding:
+    if cfg["data_pad_inline"].val == 3 and has_padding:
         s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
         s[data_pad].compute_at(s[data_vec], dv_oh)
-    if cfg['data_pad_inline'].val == 4 and has_padding:
+    if cfg["data_pad_inline"].val == 4 and has_padding:
         s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
         s[data_pad].compute_at(s[data_vec], dv_ow)
 
-    cfg.define_knob('data_vec_inline', [0, 1, 2, 3])
-    if cfg['data_vec_inline'].val == 1:
+    cfg.define_knob("data_vec_inline", [0, 1, 2, 3])
+    if cfg["data_vec_inline"].val == 1:
         s[data_vec].compute_at(s[conv], oh)
-    if cfg['data_vec_inline'].val == 2:
+    if cfg["data_vec_inline"].val == 2:
         s[data_vec].compute_at(s[conv], ow)
-    if cfg['data_vec_inline'].val == 3:
+    if cfg["data_vec_inline"].val == 3:
         s[data_vec].compute_at(s[conv], co)
 
     # schedule conv
     cfg["reorder_0"].apply(s, conv, [n, co, oh, ow, kh, kw, vh, vw, vc])
-    cfg["ann_reduce"].apply(s, conv, [kh, kw],
-                            axis_lens=[get_const_int(kh.dom.extent),
-                                       get_const_int(kw.dom.extent)],
-                            max_unroll=16,
-                            cfg=cfg)
-    cfg["ann_spatial"].apply(s, conv, [vh, vw, vc],
-                             axis_lens=[cfg['tile_oh'].size[-1],
-                                        cfg['tile_ow'].size[-1],
-                                        cfg['tile_co'].size[-1]],
-                             max_unroll=16,
-                             cfg=cfg)
+    cfg["ann_reduce"].apply(
+        s,
+        conv,
+        [kh, kw],
+        axis_lens=[get_const_int(kh.dom.extent), get_const_int(kw.dom.extent)],
+        max_unroll=16,
+        cfg=cfg,
+    )
+    cfg["ann_spatial"].apply(
+        s,
+        conv,
+        [vh, vw, vc],
+        axis_lens=[cfg["tile_oh"].size[-1], cfg["tile_ow"].size[-1], cfg["tile_co"].size[-1]],
+        max_unroll=16,
+        cfg=cfg,
+    )
 
     # schedule fusion
     n, co, h, w = s[last].op.axis
-    co, vc = cfg['tile_co'].apply(s, last, co)
-    oh, vh = cfg['tile_oh'].apply(s, last, h)
-    ow, vw = cfg['tile_ow'].apply(s, last, w)
+    co, vc = cfg["tile_co"].apply(s, last, co)
+    oh, vh = cfg["tile_oh"].apply(s, last, h)
+    ow, vw = cfg["tile_ow"].apply(s, last, w)
     cfg["reorder_1"].apply(s, last, [n, co, oh, ow, vh, vw, vc])
     if last != output:
         s[output].compute_inline()
-        cfg["ann_spatial"].apply(s, last, [vh, vw, vc],
-                                 axis_lens=[cfg['tile_oh'].size[-1],
-                                            cfg['tile_ow'].size[-1],
-                                            cfg['tile_co'].size[-1]],
-                                 max_unroll=16,
-                                 cfg=cfg)
+        cfg["ann_spatial"].apply(
+            s,
+            last,
+            [vh, vw, vc],
+            axis_lens=[cfg["tile_oh"].size[-1], cfg["tile_ow"].size[-1], cfg["tile_co"].size[-1]],
+            max_unroll=16,
+            cfg=cfg,
+        )
     else:
         s[last].vectorize(vw)
-    cfg.define_knob('conv_inline', [0, 1, 2, 3])
-    if cfg['conv_inline'].val == 1:
+    cfg.define_knob("conv_inline", [0, 1, 2, 3])
+    if cfg["conv_inline"].val == 1:
         s[conv].compute_at(s[last], ow)
-    if cfg['conv_inline'].val == 2:
+    if cfg["conv_inline"].val == 2:
         s[conv].compute_at(s[last], oh)
-    if cfg['conv_inline'].val == 3:
+    if cfg["conv_inline"].val == 3:
         s[conv].compute_at(s[last], co)
 
     # mark parallel
     s[last].parallel(co)
 
-    if data_vec.op.name == 'data_vec_undilated':
+    if data_vec.op.name == "data_vec_undilated":
         _, h, _, _, _, _, _, _ = s[data_vec].op.axis
     else:
         _, h, _, _, _, _ = s[data_vec].op.axis
     s[data_vec].parallel(h)
 
-    if kernel_vec.op.name == 'kernel_vec':
+    if kernel_vec.op.name == "kernel_vec":
         co, _, _, _, _ = s[kernel_vec].op.axis
         if autotvm.GLOBAL_SCOPE.in_tuning:
             # kernel packing will be pre-computed during compilation, so we skip
             # this part to make tuning records correct
-            s[kernel_vec].pragma(co, 'debug_skip_region')
+            s[kernel_vec].pragma(co, "debug_skip_region")
         else:
             s[kernel_vec].parallel(co)
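
The NHWC depthwise compute earlier in this file's hunks reads the data at channel c // channel_multiplier and the kernel at (c // channel_multiplier, c % channel_multiplier). A NumPy reference for the stride-1, dilation-1, unpadded case, with arbitrary shapes, makes that indexing concrete:

import numpy as np

N, H, W, IC, M, KH, KW = 1, 8, 8, 4, 2, 3, 3  # M is the channel multiplier
data = np.random.rand(N, H, W, IC).astype("float32")
kernel = np.random.rand(KH, KW, IC, M).astype("float32")
OH, OW = H - KH + 1, W - KW + 1

out = np.zeros((N, OH, OW, IC * M), dtype="float32")
for c in range(IC * M):
    ci, m = c // M, c % M  # idxdiv / idxmod in the te.compute above
    for kh in range(KH):
        for kw in range(KW):
            out[:, :, :, c] += data[:, kh:kh + OH, kw:kw + OW, ci] * kernel[kh, kw, ci, m]
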
 
index 3e3c73d..aec86bc 100644
@@ -20,6 +20,7 @@ import tvm
 from tvm import te
 from ..util import is_empty_shape
 
+
 def schedule_injective_from_existing(sch, out):
     """Schedule for injective op from existing schedule.
 
@@ -45,6 +46,7 @@ def schedule_injective_from_existing(sch, out):
         sch[out].parallel(sch[out].op.axis[0])
     return sch
 
+
 def schedule_injective(outs):
     """ARM CPU schedule for injective op.
 
@@ -73,6 +75,7 @@ def schedule_injective(outs):
         schedule_injective_from_existing(s, x)
     return s
 
+
 def schedule_concatenate(outs):
     """Schedule for concatenate op.
 
index 52e67ad..e87bdc4 100644
@@ -21,6 +21,7 @@ import tvm
 from tvm import te
 from tvm.contrib import util, clang
 
+
 def gemm_quantized_4_4_batched():
     return """
            // First half
@@ -114,6 +115,7 @@ def gemm_quantized_4_4_batched():
            "uadalp v31.4s, v15.8h\\n"
     """
 
+
 def gemm_quantized_4_4_interleaved():
     return """
              // First half
@@ -200,23 +202,23 @@ def gemm_quantized_4_4_interleaved():
     """
 
 
-def gemm_quantized_impl(M, N, K, unroll, interleave, data_type='uint8'):
-    """ Assembly implementation of a blocked gemv. Given
+def gemm_quantized_impl(M, N, K, unroll, interleave, data_type="uint8"):
+    """Assembly implementation of a blocked gemv. Given
     a block a of shape (4, k) and a block b' of shape (4, k)
-    produces the output block c = a*b of shape (4,4) """
+    produces the output block c = a*b of shape (4,4)"""
 
     stepA = min(4, M)
     stepB = min(4, N)
-    assert data_type in ['uint8', 'int8'], 'Only uint8/int8 supported for this implementation'
+    assert data_type in ["uint8", "int8"], "Only uint8/int8 supported for this implementation"
 
-    signature = """extern "C" int gemm_quantized_{0}_{0}_int32_{1}_{2}""".format(data_type,
-                                                                                 stepA,
-                                                                                 stepB)
+    signature = """extern "C" int gemm_quantized_{0}_{0}_int32_{1}_{2}""".format(
+        data_type, stepA, stepB
+    )
     if unroll:
-        signature += ("_" + str(K))
+        signature += "_" + str(K)
 
     if interleave:
-        signature += ("_interleaved")
+        signature += "_interleaved"
 
     signature += """(int *c_buffer,
                       unsigned char *a_buffer,
@@ -291,10 +293,12 @@ def gemm_quantized_impl(M, N, K, unroll, interleave, data_type='uint8'):
     blockB = min(64, N * 16)
     main_loop += """// Increment pointers
                     "add %[a_ptr], %[a_ptr], #{0}\\n"
-                    "add %[b_ptr], %[b_ptr], #{1}\\n" """.format(blockA, blockB)
+                    "add %[b_ptr], %[b_ptr], #{1}\\n" """.format(
+        blockA, blockB
+    )
 
     if unroll:
-        k = int(K//16)
+        k = int(K // 16)
         for l in range(0, k):
             cc_code += main_loop
     else:
@@ -363,17 +367,17 @@ def gemm_quantized_impl(M, N, K, unroll, interleave, data_type='uint8'):
         }
     """
 
-    if data_type == 'int8':
-        cc_code = cc_code.replace('unsigned char', 'char')
-        cc_code = cc_code.replace('umull', 'smull')
-        cc_code = cc_code.replace('uadalp', 'sadalp')
+    if data_type == "int8":
+        cc_code = cc_code.replace("unsigned char", "char")
+        cc_code = cc_code.replace("umull", "smull")
+        cc_code = cc_code.replace("uadalp", "sadalp")
 
     temp = util.tempdir()
     ll_path = temp.relpath("temp.ll")
     # Create LLVM ir from c source code
-    ll_code = clang.create_llvm(cc_code,
-                                options=["--target=aarch64-linux-gnu -mattr=+neon"],
-                                output=ll_path)
+    ll_code = clang.create_llvm(
+        cc_code, options=["--target=aarch64-linux-gnu -mattr=+neon"], output=ll_path
+    )
     return ll_code
 
 
@@ -407,29 +411,44 @@ def gemm_quantized(M, N, K, unroll, interleave, in_type, out_type):
     intrin : TensorIntrin
         The ARM uint8/int8 TensorIntrin that can be used in tensorizing schedule
     """
-    A = te.placeholder((K // 16, te.var("m"), 16), dtype=in_type, name='A')
-    B = te.placeholder((K // 16, te.var("n"), 16), dtype=in_type, name='B')
+    A = te.placeholder((K // 16, te.var("m"), 16), dtype=in_type, name="A")
+    B = te.placeholder((K // 16, te.var("n"), 16), dtype=in_type, name="B")
 
     idxm = tvm.tir.indexmod
 
     k = te.reduce_axis((0, K), "k")
 
-    C = te.compute((te.var("m"), te.var("n")),
-                   lambda x, y: te.sum(A[k // 16, x, idxm(k, 16)].astype(out_type) *
-                                       B[k // 16, y, idxm(k, 16)].astype(out_type),
-                                       axis=k), name="C")
-
-    a_buffer = tvm.tir.decl_buffer(A.shape, dtype=in_type, name="a_buffer",
-                                   offset_factor=1, strides=[te.var('sa_1'), te.var('sa_2'), 1])
-
-    b_buffer = tvm.tir.decl_buffer(B.shape, dtype=in_type, name="b_buffer",
-                                   offset_factor=1, strides=[te.var('sb_1'), te.var('sb_2'), 1])
-
-    c_buffer = tvm.tir.decl_buffer(C.shape, dtype=out_type, name="c_buffer",
-                                   offset_factor=1, strides=[te.var('sc'), 1])
+    C = te.compute(
+        (te.var("m"), te.var("n")),
+        lambda x, y: te.sum(
+            A[k // 16, x, idxm(k, 16)].astype(out_type)
+            * B[k // 16, y, idxm(k, 16)].astype(out_type),
+            axis=k,
+        ),
+        name="C",
+    )
+
+    a_buffer = tvm.tir.decl_buffer(
+        A.shape,
+        dtype=in_type,
+        name="a_buffer",
+        offset_factor=1,
+        strides=[te.var("sa_1"), te.var("sa_2"), 1],
+    )
+
+    b_buffer = tvm.tir.decl_buffer(
+        B.shape,
+        dtype=in_type,
+        name="b_buffer",
+        offset_factor=1,
+        strides=[te.var("sb_1"), te.var("sb_2"), 1],
+    )
+
+    c_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype=out_type, name="c_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
 
     def _intrin_func(ins, outs):
-
         def _instr():
             ib = tvm.tir.ir_builder.create()
             aa, bb = ins
@@ -438,27 +457,34 @@ def gemm_quantized(M, N, K, unroll, interleave, in_type, out_type):
             stepB = min(4, N)
             intrin_name = "gemm_quantized_{0}_{0}_int32_{1}_{2}".format(in_type, stepA, stepB)
             if unroll:
-                intrin_name += ("_" + str(K))
+                intrin_name += "_" + str(K)
             if interleave:
                 intrin_name += "_interleaved"
-            ib.emit(tvm.tir.call_extern("int32",
-                                        intrin_name,
-                                        outs[0].access_ptr("w"),
-                                        a_buffer.access_ptr("r"),
-                                        b_buffer.access_ptr("r"),
-                                        K))
+            ib.emit(
+                tvm.tir.call_extern(
+                    "int32",
+                    intrin_name,
+                    outs[0].access_ptr("w"),
+                    a_buffer.access_ptr("r"),
+                    b_buffer.access_ptr("r"),
+                    K,
+                )
+            )
             return ib.get()
 
         # body, reset, update
         return _instr()
 
     buffer_params = {"offset_factor": 1}
-    return te.decl_tensor_intrin(C.op, _intrin_func,
-                                 binds={A:a_buffer, B:b_buffer, C:c_buffer},
-                                 default_buffer_params=buffer_params)
+    return te.decl_tensor_intrin(
+        C.op,
+        _intrin_func,
+        binds={A: a_buffer, B: b_buffer, C: c_buffer},
+        default_buffer_params=buffer_params,
+    )
 
 
-def dot_int8_int8_int32(int32_lanes, dtype='uint'):
+def dot_int8_int8_int32(int32_lanes, dtype="uint"):
     """
     Int8 dot product by every 4 elements using ARM v8.2 udot.
     This function takes two arrays of int8 datatype -- data[4] and
@@ -496,50 +522,58 @@ def dot_int8_int8_int32(int32_lanes, dtype='uint'):
     """
     num_int8_elements = 4  # 4 int8 elements in int32
 
-    data = te.placeholder((num_int8_elements,), dtype='%s8' % dtype, name='data')
-    kernel = te.placeholder((int32_lanes, num_int8_elements), dtype='%s8' % dtype, name='kernel')
-
-    k = te.reduce_axis((0, num_int8_elements), name='k')
-    C = te.compute((int32_lanes,),
-                   lambda i: te.sum(data[k].astype('%s32' % dtype) *
-                                    kernel[i, k].astype('%s32' % dtype),
-                                    axis=k), name="C")
-
-    a_buffer = tvm.tir.decl_buffer(data.shape, dtype='%s8' % dtype, name="a_buffer",
-                                   offset_factor=1,
-                                   strides=[1])
-    b_buffer = tvm.tir.decl_buffer(kernel.shape, dtype='%s8' % dtype, name="b_buffer",
-                                   offset_factor=1,
-                                   strides=[te.var('s'), 1])
+    data = te.placeholder((num_int8_elements,), dtype="%s8" % dtype, name="data")
+    kernel = te.placeholder((int32_lanes, num_int8_elements), dtype="%s8" % dtype, name="kernel")
+
+    k = te.reduce_axis((0, num_int8_elements), name="k")
+    C = te.compute(
+        (int32_lanes,),
+        lambda i: te.sum(
+            data[k].astype("%s32" % dtype) * kernel[i, k].astype("%s32" % dtype), axis=k
+        ),
+        name="C",
+    )
+
+    a_buffer = tvm.tir.decl_buffer(
+        data.shape, dtype="%s8" % dtype, name="a_buffer", offset_factor=1, strides=[1]
+    )
+    b_buffer = tvm.tir.decl_buffer(
+        kernel.shape,
+        dtype="%s8" % dtype,
+        name="b_buffer",
+        offset_factor=1,
+        strides=[te.var("s"), 1],
+    )
 
     def _intrin_func(ins, outs):
         def _instr(index):
             ib = tvm.tir.ir_builder.create()
             if index == 1:
-                ib.emit(outs[0].vstore(0, tvm.tir.const(0, '%s32x%d' % (dtype, int32_lanes))))
+                ib.emit(outs[0].vstore(0, tvm.tir.const(0, "%s32x%d" % (dtype, int32_lanes))))
                 return ib.get()
 
-            dtype_a = '%s8x%d' % (dtype, num_int8_elements)
-            dtype_b = '%s8x%d' % (dtype, int32_lanes * num_int8_elements)
-            dtype_c = '%s32x%d' % (dtype, int32_lanes)
+            dtype_a = "%s8x%d" % (dtype, num_int8_elements)
+            dtype_b = "%s8x%d" % (dtype, int32_lanes * num_int8_elements)
+            dtype_c = "%s32x%d" % (dtype, int32_lanes)
 
             a_int8 = ins[0].vload([0], dtype_a)
-            re_int32 = tvm.tir.call_intrin('%s32' % dtype, 'tir.reinterpret', a_int8)
+            re_int32 = tvm.tir.call_intrin("%s32" % dtype, "tir.reinterpret", a_int8)
             # broadcast a
             vec_ai32 = re_int32.astype(dtype_c)
 
-            vec_a = tvm.tir.call_intrin(dtype_b, 'tir.reinterpret', vec_ai32)
+            vec_a = tvm.tir.call_intrin(dtype_b, "tir.reinterpret", vec_ai32)
             vec_b = ins[1].vload([0, 0], dtype_b)
             vec_c = outs[0].vload([0], dtype_c)
 
-            inst = 'udot' if dtype == 'uint' else 'sdot'
-            inst = 'llvm.aarch64.neon.%s.v%di32.v%di8' % (
-                inst, int32_lanes, int32_lanes * num_int8_elements)
-            vdot = tvm.tir.call_llvm_pure_intrin(
-                dtype_c,
+            inst = "udot" if dtype == "uint" else "sdot"
+            inst = "llvm.aarch64.neon.%s.v%di32.v%di8" % (
                 inst,
-                tvm.tir.const(2, 'uint32'),
-                vec_c, vec_a, vec_b)
+                int32_lanes,
+                int32_lanes * num_int8_elements,
+            )
+            vdot = tvm.tir.call_llvm_pure_intrin(
+                dtype_c, inst, tvm.tir.const(2, "uint32"), vec_c, vec_a, vec_b
+            )
             ib.emit(outs[0].vstore(0, vdot))
             return ib.get()
 
@@ -548,8 +582,12 @@ def dot_int8_int8_int32(int32_lanes, dtype='uint'):
 
     buffer_params = {"offset_factor": 1}
     return te.decl_tensor_intrin(
-        C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer},
-        default_buffer_params=buffer_params)
+        C.op,
+        _intrin_func,
+        binds={data: a_buffer, kernel: b_buffer},
+        default_buffer_params=buffer_params,
+    )
+
 
 def _q_multiply_shift_arm(op):
     """
@@ -574,31 +612,26 @@ def _q_multiply_shift_arm(op):
         return op
 
     # Case 1, shift is negative
-    sqrdmulh = tvm.tir.call_llvm_intrin(op.dtype,
-                                        'llvm.aarch64.neon.sqrdmulh',
-                                        tvm.tir.const(2, 'uint32'),
-                                        x,
-                                        y)
+    sqrdmulh = tvm.tir.call_llvm_intrin(
+        op.dtype, "llvm.aarch64.neon.sqrdmulh", tvm.tir.const(2, "uint32"), x, y
+    )
 
     fixup = (sqrdmulh & (-s)) >> 31
-    fixed_up_x = (sqrdmulh + fixup)
-    out_1 = tvm.tir.call_llvm_intrin(op.dtype,
-                                     'llvm.aarch64.neon.srshl',
-                                     tvm.tir.const(2, 'uint32'),
-                                     sqrdmulh,
-                                     s)
+    fixed_up_x = sqrdmulh + fixup
+    out_1 = tvm.tir.call_llvm_intrin(
+        op.dtype, "llvm.aarch64.neon.srshl", tvm.tir.const(2, "uint32"), sqrdmulh, s
+    )
 
     # Case 2, shift is positive
     x = x * (1 << (s))
-    out_2 = tvm.tir.call_llvm_intrin(op.dtype,
-                                     'llvm.aarch64.neon.sqrdmulh',
-                                     tvm.tir.const(2, 'uint32'),
-                                     x,
-                                     y)
+    out_2 = tvm.tir.call_llvm_intrin(
+        op.dtype, "llvm.aarch64.neon.sqrdmulh", tvm.tir.const(2, "uint32"), x, y
+    )
 
     # Select depending on the shift
     return tvm.tir.Select(s < 0, out_1, out_2)
 
-tvm.target.intrin.register_intrin_rule("llvm.aarch64",
-                                       "q_multiply_shift",
-                                       _q_multiply_shift_arm, override=True)
+
+tvm.target.intrin.register_intrin_rule(
+    "llvm.aarch64", "q_multiply_shift", _q_multiply_shift_arm, override=True
+)
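
For readers following the tensor-intrinsic code above, a minimal sketch of how the gemm_quantized declaration is typically consumed (the import path, sizes, and the tensorized axis are assumptions for illustration, not part of this patch):

from tvm.topi.arm_cpu.tensor_intrin import gemm_quantized  # assumed module path

# K is assumed to be a multiple of 16, matching the (K // 16, ..., 16) packing above.
M, N, K = 4, 4, 64
intrin = gemm_quantized(M, N, K, unroll=False, interleave=False, in_type="uint8", out_type="int32")
# In a dense/conv2d schedule the intrinsic would then replace the innermost tile, e.g.:
# s[C_interleaved].tensorize(inner_gemm_axis, intrin)
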
index ecc67c7..a3be906 100644 (file)
@@ -67,8 +67,9 @@ def conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_
     output : tvm.te.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
-                                    dilation, out_dtype, num_tile=3)
+    return conv2d_spatial_pack_nchw(
+        cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=3
+    )
 
 
 @autotvm.register_topi_schedule("conv2d_nchw_spatial_pack.bifrost")
@@ -92,7 +93,7 @@ def schedule_conv2d_nchw_spatial_pack(cfg, outs):
 
     def _callback(op):
         # schedule conv2d
-        if 'spatial_conv2d_output' in op.tag:
+        if "spatial_conv2d_output" in op.tag:
             output = op.output(0)
             conv = op.input_tensors[0]
 
@@ -101,7 +102,7 @@ def schedule_conv2d_nchw_spatial_pack(cfg, outs):
             s[data_pad].compute_inline()
 
             kernel_vec = conv.op.input_tensors[1]
-            if kernel_vec.op.name == 'kernel_vec':
+            if kernel_vec.op.name == "kernel_vec":
                 kernel = kernel_vec.op.input_tensors[0]
             else:
                 kernel = kernel_vec
@@ -131,7 +132,7 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec):
         s[data_pad].compute_inline()
 
     # schedule data packing
-    if isinstance(data_vec.op, te.tensor.ComputeOp) and data_vec.op.name == 'data_vec_undilated':
+    if isinstance(data_vec.op, te.tensor.ComputeOp) and data_vec.op.name == "data_vec_undilated":
         _, h, w, ci, _, _, vh, vw = s[data_vec].op.axis
     else:
         _, h, w, ci, vh, vw = s[data_vec].op.axis
@@ -141,7 +142,7 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec):
     if vw.dom.extent.value < max_unroll:
         s[data_vec].unroll(vw)
 
-    if isinstance(kernel_vec.op, tvm.te.ComputeOp) and kernel_vec.name == 'kernel_vec':
+    if isinstance(kernel_vec.op, tvm.te.ComputeOp) and kernel_vec.name == "kernel_vec":
         if not autotvm.GLOBAL_SCOPE.in_tuning:
             max_threads = tvm.target.Target.current(allow_none=False).max_num_threads
             co, ci, kh, kw, vc = s[kernel_vec].op.axis
@@ -160,16 +161,23 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec):
     cfg["reorder_0"].apply(s, conv, [n, c, h, w, kc, kh, kw, vh, vw, vc])
     tile_and_bind3d(s, conv, c, h, w, TC, TH, TW)
 
-    cfg["ann_reduce"].apply(s, conv, [kh, kw],
-                            axis_lens=[get_const_int(kernel_vec.shape[2]),
-                                       get_const_int(kernel_vec.shape[3])],
-                            max_unroll=max_unroll)
+    cfg["ann_reduce"].apply(
+        s,
+        conv,
+        [kh, kw],
+        axis_lens=[get_const_int(kernel_vec.shape[2]), get_const_int(kernel_vec.shape[3])],
+        max_unroll=max_unroll,
+    )
 
-    cfg["ann_spatial"].apply(s, conv, [vh, vw, vc],
-                             axis_lens=[VH, VW, VC],
-                             max_unroll=max_unroll,
-                             vec_size=vec_size,
-                             cfg=cfg)
+    cfg["ann_spatial"].apply(
+        s,
+        conv,
+        [vh, vw, vc],
+        axis_lens=[VH, VW, VC],
+        max_unroll=max_unroll,
+        vec_size=vec_size,
+        cfg=cfg,
+    )
 
     # schedule output
     if output.op not in s.outputs:  # has bias
@@ -193,7 +201,7 @@ def schedule_conv2d_nchw_winograd(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'winograd_conv2d_output' in op.tag:
+        if "winograd_conv2d_output" in op.tag:
             _schedule_winograd(cfg, s, op)
 
     traverse_inline(s, outs[0].op, _callback)
@@ -221,7 +229,7 @@ def _decl_winograd_kernel_transform(kernel, tile_size, G):
     """
     CO, CI, KH, KW = [get_const_int(x) for x in kernel.shape]
     # Only support 32 bit floats
-    out_dtype = 'float32'
+    out_dtype = "float32"
 
     alpha = G.shape[0]
     K = CO
@@ -235,22 +243,25 @@ def _decl_winograd_kernel_transform(kernel, tile_size, G):
 
     # Padded Kernel [K_round, C, KH, KW]
     # Pad the number of kernels to multiple of ALIGN
-    padded_kernel = te.compute((K_round, C, KH, KW),
-                               lambda k, c, h, w:
-                               tvm.tir.if_then_else(k < K,
-                                                    kernel[k][c][h][w],
-                                                    tvm.tir.const(0, out_dtype)),
-                               name='padded_kernel')
+    padded_kernel = te.compute(
+        (K_round, C, KH, KW),
+        lambda k, c, h, w: tvm.tir.if_then_else(
+            k < K, kernel[k][c][h][w], tvm.tir.const(0, out_dtype)
+        ),
+        name="padded_kernel",
+    )
 
     # U [alpha, alpha, K_round, C]
     # Perform the kernel transform
-    r_kh = te.reduce_axis((0, KH), 'r_kh')
-    r_kw = te.reduce_axis((0, KW), 'r_kw')
-    U = te.compute((alpha, alpha, K_round, C),
-                   lambda eps, nu, k, c:
-                   te.sum(padded_kernel[k][c][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw],
-                          axis=[r_kh, r_kw]),
-                   name='U')
+    r_kh = te.reduce_axis((0, KH), "r_kh")
+    r_kw = te.reduce_axis((0, KW), "r_kw")
+    U = te.compute(
+        (alpha, alpha, K_round, C),
+        lambda eps, nu, k, c: te.sum(
+            padded_kernel[k][c][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]
+        ),
+        name="U",
+    )
 
     return U
 
@@ -288,7 +299,7 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, til
     C = CI
     H = (IH + pt + pb - 3) // HSTR + 1
     W = (IW + pl + pr - 3) // WSTR + 1
-    nH, nW = (H + m-1) // m, (W + m-1) // m
+    nH, nW = (H + m - 1) // m, (W + m - 1) // m
     P = N * nH * nW
 
     def upround(x, align):
@@ -304,13 +315,10 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, til
     cfg.define_knob("data_transform_wgy", [1, 2, 4, 8, 16, 32, 64])
 
     # Pack input tile
-    input_tile = te.compute((N, C, H + 2, W + 2),
-                            lambda n, c, h, w:
-                            data_pad[n][c][h][w],
-                            name='d')
+    input_tile = te.compute((N, C, H + 2, W + 2), lambda n, c, h, w: data_pad[n][c][h][w], name="d")
 
     if autotvm.GLOBAL_SCOPE.in_tuning:
-        VC = cfg['tile_k'].size[-1]
+        VC = cfg["tile_k"].size[-1]
         kvshape = (KH + tile_size - 1, KW + tile_size - 1, tvm.tir.indexdiv(CO, VC), CI, VC)
         U = tvm.te.placeholder(kvshape, kernel.dtype, name="U")
     else:
@@ -321,33 +329,44 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, til
 
     # V [alpha * alpha, C, P_round)
     # Perform the image transform
-    r_eps = te.reduce_axis((0, alpha), 'r_eps')
-    r_nu = te.reduce_axis((0, alpha), 'r_nu')
-    V = te.compute((alpha * alpha, C, P_round),
-                   lambda epsnu, c, b:
-                   te.sum(input_tile[b // (nH*nW)][c][b // nW % nH * m + r_eps][b % nW * m +r_nu]\
-                          * B[r_eps][epsnu // alpha] * B[r_nu][epsnu % alpha],
-                          axis=[r_eps, r_nu]),
-                   name='V')
+    r_eps = te.reduce_axis((0, alpha), "r_eps")
+    r_nu = te.reduce_axis((0, alpha), "r_nu")
+    V = te.compute(
+        (alpha * alpha, C, P_round),
+        lambda epsnu, c, b: te.sum(
+            input_tile[b // (nH * nW)][c][b // nW % nH * m + r_eps][b % nW * m + r_nu]
+            * B[r_eps][epsnu // alpha]
+            * B[r_nu][epsnu % alpha],
+            axis=[r_eps, r_nu],
+        ),
+        name="V",
+    )
 
     # Winograd GEMM is a wrapper around batched GEMM to convert U to a 3D Tensor
     _, M = decl_winograd_gemm(cfg, U, V)
 
     # Y [K, P, m, m]
     # Winograd output transform
-    r_eps = te.reduce_axis((0, alpha), 'r_eps')
-    r_nu = te.reduce_axis((0, alpha), 'r_nu')
-    Y = te.compute((K, P, m, m), lambda k, b, vh, vw:
-                   te.sum(M[r_eps * alpha + r_nu][k][b] * A[r_eps][vh] * A[r_nu][vw],
-                          axis=[r_eps, r_nu]), name='Y')
+    r_eps = te.reduce_axis((0, alpha), "r_eps")
+    r_nu = te.reduce_axis((0, alpha), "r_nu")
+    Y = te.compute(
+        (K, P, m, m),
+        lambda k, b, vh, vw: te.sum(
+            M[r_eps * alpha + r_nu][k][b] * A[r_eps][vh] * A[r_nu][vw], axis=[r_eps, r_nu]
+        ),
+        name="Y",
+    )
 
     # Output [N, K, H, W]
     # Unpack back to NCHW format
     # The last term ensures alignment is not lost to bound inference
-    output = te.compute((N, K, H, W), lambda n, k, h, w:
-                        Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m]
-                        + tvm.tir.const(0, out_dtype) * M[(alpha*alpha)-1][K_round-1][P_round-1],
-                        name='output', tag='winograd_conv2d_output')
+    output = te.compute(
+        (N, K, H, W),
+        lambda n, k, h, w: Y[k][n * nH * nW + (h // m) * nW + w // m][h % m][w % m]
+        + tvm.tir.const(0, out_dtype) * M[(alpha * alpha) - 1][K_round - 1][P_round - 1],
+        name="output",
+        tag="winograd_conv2d_output",
+    )
 
     return output
 
@@ -401,7 +420,7 @@ def _schedule_winograd(cfg, s, op):
     tile_and_bind3d(s, d, b, h, w, 1, 4, 2)
 
     # Transform data
-    bIL_d = s.cache_read(d, 'local', [V])
+    bIL_d = s.cache_read(d, "local", [V])
 
     s[B].compute_inline()
     epsnu, c, b = s[V].op.axis
@@ -425,8 +444,8 @@ def _schedule_winograd(cfg, s, op):
     )
 
     # Inverse transform
-    CR_M = s.cache_read(M, 'local', [Y])
-    CW_Y = s.cache_write(Y, 'local')
+    CR_M = s.cache_read(M, "local", [Y])
+    CW_Y = s.cache_write(Y, "local")
 
     s[A].compute_inline()
     k, b, vh, vw = s[Y].op.axis
@@ -452,7 +471,6 @@ def _schedule_winograd(cfg, s, op):
     tile_and_bind3d(s, output, k, h, w, 1, 2, 2)
 
 
-
 ##### REGISTER ALTER OP LAYOUT #####
 @nn.conv2d_alter_layout.register("bifrost")
 def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
@@ -460,7 +478,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
     dispatch_ctx = autotvm.task.DispatchContext.current
 
     _, outs = relay.backend.compile_engine.select_implementation(
-        relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)
+        relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target
+    )
     workload = autotvm.task.get_workload(outs)
     if workload is None:
         # The best implementation is not an AutoTVM template,
@@ -488,15 +507,16 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         assert data_layout == "NCHW" and kernel_layout == "OIHW"
         N, CI, H, W = get_const_tuple(data.shape)
         CO, _, KH, KW = get_const_tuple(kernel.shape)
-        VC = cfg['tile_co'].size[-1]
+        VC = cfg["tile_co"].size[-1]
 
-        new_attrs['kernel_layout'] = 'OIHW%do' % VC
+        new_attrs["kernel_layout"] = "OIHW%do" % VC
 
         new_data = data
         new_kernel = te.placeholder((idxd(CO, VC), CI, KH, KW, VC), dtype=kernel.dtype)
         new_workload = autotvm.task.args_to_workload(
             [new_data, new_kernel, strides, padding, dilation, out_dtype],
-            "conv2d_nchw_spatial_pack.bifrost")
+            "conv2d_nchw_spatial_pack.bifrost",
+        )
         dispatch_ctx.update(target, new_workload, cfg)
 
         return relay.nn.conv2d(*inputs, **new_attrs)
@@ -509,21 +529,24 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
 
         weight_expr = inputs[1]
         weight_expr = relay.nn.contrib_conv2d_winograd_weight_transform(
-            weight_expr, tile_size=tile_size)
+            weight_expr, tile_size=tile_size
+        )
         weight_expr = relay.reshape(
-            weight_expr, newshape=(KH + tile_size - 1, KW + tile_size - 1, CO, CI))
+            weight_expr, newshape=(KH + tile_size - 1, KW + tile_size - 1, CO, CI)
+        )
 
-        new_attrs['tile_size'] = tile_size
+        new_attrs["tile_size"] = tile_size
 
         new_data = data
-        new_kernel = te.placeholder(
-            (KH + tile_size - 1, KW + tile_size -1, CO, CI), kernel.dtype)
+        new_kernel = te.placeholder((KH + tile_size - 1, KW + tile_size - 1, CO, CI), kernel.dtype)
         new_workload = autotvm.task.args_to_workload(
             [new_data, new_kernel, strides, padding, dilation, out_dtype],
-            'conv2d_nchw_winograd.bifrost')
+            "conv2d_nchw_winograd.bifrost",
+        )
         dispatch_ctx.update(target, new_workload, cfg)
 
         return relay.nn.contrib_conv2d_winograd_without_weight_transform(
-            inputs[0], weight_expr, **new_attrs)
+            inputs[0], weight_expr, **new_attrs
+        )
 
     return None
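
As a quick numeric check of the Winograd tiling above (a sketch, not part of the patch): nH, nW = (H + m - 1) // m, (W + m - 1) // m is a ceiling division, so the nH * nW output tiles always cover the H x W output.

m = 2                                        # tile_size for the F(2x2, 3x3) kernels
N, H, W = 1, 7, 9                            # example output shape, values assumed
nH, nW = (H + m - 1) // m, (W + m - 1) // m  # -> 4, 5
P = N * nH * nW                              # 20 tiles feed the batched Winograd GEMM
assert nH * m >= H and nW * m >= W
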
index 7104842..4a0158e 100644 (file)
@@ -22,12 +22,14 @@ from tvm import autotvm
 from .. import nn
 from ..util import traverse_inline
 
-@autotvm.register_topi_compute('dense.biforst')
+
+@autotvm.register_topi_compute("dense.biforst")
 def dense(_, data, weight, bias=None, out_dtype=None):
     """Dense operator on Biforst"""
     return nn.dense(data, weight, bias, out_dtype)
 
-@autotvm.register_topi_schedule('dense.bifrost')
+
+@autotvm.register_topi_schedule("dense.bifrost")
 def schedule_dense(cfg, outs):
     """Schedule for dense operator.
 
@@ -48,7 +50,7 @@ def schedule_dense(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == 'dense':
+        if op.tag == "dense":
             vec_size = [1, 2, 4, 8, 16]
             max_unroll = 32
 
@@ -59,47 +61,47 @@ def schedule_dense(cfg, outs):
             c = s[dense_out].op.reduce_axis[0]
 
             ##### space definition begin #####
-            cfg.define_split('tile_y', y, num_outputs=3)
-            cfg.define_split('tile_x', x, num_outputs=3)
-            cfg.define_split('c_unroll', c, num_outputs=2, max_factor=64)
+            cfg.define_split("tile_y", y, num_outputs=3)
+            cfg.define_split("tile_x", x, num_outputs=3)
+            cfg.define_split("c_unroll", c, num_outputs=2, max_factor=64)
 
             # fallback support
             if cfg.is_fallback:
-                ref_log = autotvm.tophub.load_reference_log(
-                    'mali', 'rk3399', 'dense.bifrost')
+                ref_log = autotvm.tophub.load_reference_log("mali", "rk3399", "dense.bifrost")
                 cfg.fallback_with_reference_log(ref_log)
             ##### space definition end #####
 
             if dense_out.op in s.outputs:
-                dense_out = s.cache_write(output, 'local')
+                dense_out = s.cache_write(output, "local")
 
-            by, ty, yi = cfg['tile_y'].apply(s, output, y)
-            bx, tx, xi = cfg['tile_x'].apply(s, output, x)
+            by, ty, yi = cfg["tile_y"].apply(s, output, y)
+            bx, tx, xi = cfg["tile_x"].apply(s, output, x)
 
-            s[output].bind(by, te.thread_axis('blockIdx.y'))
-            s[output].bind(bx, te.thread_axis('blockIdx.x'))
-            s[output].bind(ty, te.thread_axis('threadIdx.y'))
-            s[output].bind(tx, te.thread_axis('threadIdx.x'))
+            s[output].bind(by, te.thread_axis("blockIdx.y"))
+            s[output].bind(bx, te.thread_axis("blockIdx.x"))
+            s[output].bind(ty, te.thread_axis("threadIdx.y"))
+            s[output].bind(tx, te.thread_axis("threadIdx.x"))
 
-            if cfg['tile_y'].size[-1] < max_unroll:
+            if cfg["tile_y"].size[-1] < max_unroll:
                 s[output].unroll(yi)
-            if cfg['tile_x'].size[-1] in vec_size:
+            if cfg["tile_x"].size[-1] in vec_size:
                 s[output].vectorize(xi)
             s[dense_out].compute_at(s[output], tx)
 
             k = s[dense_out].op.reduce_axis[0]
             y, x = s[dense_out].op.axis
-            k, k_unroll = cfg['c_unroll'].apply(s, dense_out, k)
+            k, k_unroll = cfg["c_unroll"].apply(s, dense_out, k)
             s[dense_out].reorder(k, k_unroll, y, x)
             s[dense_out].unroll(k_unroll)
-            if cfg['tile_y'].size[-1] < max_unroll:
+            if cfg["tile_y"].size[-1] < max_unroll:
                 s[dense_out].unroll(y)
-            if cfg['tile_x'].size[-1] in vec_size:
+            if cfg["tile_x"].size[-1] in vec_size:
                 s[dense_out].vectorize(x)
 
     traverse_inline(s, outs[0].op, _callback)
     return s
 
+
 def fuse_and_bind(s, tensor, axis=None, num_thread=None):
     """ fuse all the axis and bind to GPU threads """
     axis = axis or s[tensor].op.axis
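
The cfg["tile_y"] / cfg["tile_x"] splits in the dense schedule above each yield three sub-axes that are bound to block and thread indices. A hand-written equivalent with fixed factors (a sketch; AutoTVM would tune the factors instead) looks like:

import tvm
from tvm import te

n = 1024
A = te.placeholder((n, n), name="A")
B = te.compute((n, n), lambda y, x: A[y, x] + 1.0, name="B")
s = te.create_schedule(B.op)
y, x = s[B].op.axis
by, y_rest = s[B].split(y, factor=64)  # factors assumed for illustration
ty, yi = s[B].split(y_rest, factor=8)
s[B].bind(by, te.thread_axis("blockIdx.y"))
s[B].bind(ty, te.thread_axis("threadIdx.y"))
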
index 7a96705..35da5a5 100644 (file)
@@ -25,6 +25,7 @@ from tvm import te
 from .. import util
 from .. import tag
 
+
 def schedule_depthwise_conv2d_nchw(outs):
     """Schedule for depthwise_conv2d nchw forward.
 
@@ -41,12 +42,13 @@ def schedule_depthwise_conv2d_nchw(outs):
     """
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     s = te.create_schedule([x.op for x in outs])
+
     def _schedule(pad_data, kernel, conv):
         raw_data = s[pad_data].op.input_tensors[0]
 
         if conv.op not in s.outputs:  # has bias or relu
             output = outs[0]
-        else:                         # no bias or relu
+        else:  # no bias or relu
             output = conv
 
         def tile_and_bind3d(tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None):
@@ -72,7 +74,7 @@ def schedule_depthwise_conv2d_nchw(outs):
             VW = VW * 2
         while util.get_const_int(conv.shape[2]) % (VH * 2) == 0 and VH * 2 <= 2:
             VH = VH * 2
-        if raw_data.dtype == 'float16':
+        if raw_data.dtype == "float16":
             if util.get_const_int(conv.shape[3]) % (VW * 2) == 0:
                 VW *= 2
                 num_thread *= 2
@@ -113,10 +115,10 @@ def schedule_depthwise_conv2d_nchw(outs):
                     traverse(tensor.op)
 
         # schedule depthwise_conv2d
-        if op.tag == 'depthwise_conv2d_nchw':
+        if op.tag == "depthwise_conv2d_nchw":
             pad_data = op.input_tensors[0]
             kernel = op.input_tensors[1]
-            if isinstance(kernel.op, tvm.te.ComputeOp) and 'dilate' in kernel.op.tag:
+            if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
                 s[kernel].compute_inline()
             conv = op.output(0)
             _schedule(pad_data, kernel, conv)
index 3dc0108..6f147d9 100644 (file)
 # under the License.
 # pylint: disable=invalid-name,unused-variable,unused-argument
 """GEMM schedules for Mali Bifrost"""
-from .transforms import tile_and_bind, tile_and_bind3d, interleave_transpose, \
-    transpose_interleave
+from .transforms import tile_and_bind, tile_and_bind3d, interleave_transpose, transpose_interleave
 from .. import util
 
+
 def decl_gemm(cfg, A, B):
     """Declare a single GEMM computation for Mali Bifrost GPUs
 
@@ -47,7 +47,6 @@ def decl_gemm(cfg, A, B):
     cfg.define_knob("B_interleave", [1, 4, 8, 16, 32])
     cfg.define_knob("split_k_factor", [1, 4, 16])
 
-
     # The shared k axis of A and B must have equal extent
     assert util.get_const_int(A.shape[1]) == util.get_const_int(B.shape[0])
     n = A.shape[0]
@@ -61,9 +60,10 @@ def decl_gemm(cfg, A, B):
         B_unrolled = te.compute((k_size, m), lambda i, j: B[i, j], name="B_unrolled")
 
         # Declare standard GEMM
-        k = te.reduce_axis((0, A.shape[1]), name='k')
-        C = te.compute((n, m), lambda i, j:
-                       te.sum(A_unrolled[i, k] * B_unrolled[k, j], axis=k), name='C')
+        k = te.reduce_axis((0, A.shape[1]), name="k")
+        C = te.compute(
+            (n, m), lambda i, j: te.sum(A_unrolled[i, k] * B_unrolled[k, j], axis=k), name="C"
+        )
 
         R = te.compute((n, m), lambda i, j: C[i, j], name="R")
 
@@ -71,24 +71,33 @@ def decl_gemm(cfg, A, B):
         unrolled_k_size = k_size // unroll_gemm
 
         # Unroll the two input matrices along the shared k axis
-        A_unrolled = te.compute((unroll_gemm, n, unrolled_k_size), lambda b, i, j:
-                                A[i][unrolled_k_size * b + j], name='A_unrolled')
-
-        B_unrolled = te.compute((unroll_gemm, unrolled_k_size, m), lambda b, i, j:
-                                B[unrolled_k_size * b + i][j], name='B_unrolled')
+        A_unrolled = te.compute(
+            (unroll_gemm, n, unrolled_k_size),
+            lambda b, i, j: A[i][unrolled_k_size * b + j],
+            name="A_unrolled",
+        )
+
+        B_unrolled = te.compute(
+            (unroll_gemm, unrolled_k_size, m),
+            lambda b, i, j: B[unrolled_k_size * b + i][j],
+            name="B_unrolled",
+        )
 
         # Declare a batched GEMM
-        k = te.reduce_axis((0, unrolled_k_size), name='k')
-        C = te.compute((unroll_gemm, n, m), lambda b, i, j:
-                       te.sum(A_unrolled[b][i][k] * B_unrolled[b][k][j], axis=k), name='C')
+        k = te.reduce_axis((0, unrolled_k_size), name="k")
+        C = te.compute(
+            (unroll_gemm, n, m),
+            lambda b, i, j: te.sum(A_unrolled[b][i][k] * B_unrolled[b][k][j], axis=k),
+            name="C",
+        )
 
         # Then declare a reduction to reduce the sub matrices
-        k = te.reduce_axis((0, unroll_gemm), name='k')
-        R = te.compute((n, m), lambda i, j:
-                       te.sum(C[k][i][j], axis=k), name='R')
+        k = te.reduce_axis((0, unroll_gemm), name="k")
+        R = te.compute((n, m), lambda i, j: te.sum(C[k][i][j], axis=k), name="R")
 
     return R
 
+
 def decl_batched_gemm(cfg, A, B):
     """Declare a batched GEMM computation for Mali Bifrost GPUs
     Parameters
@@ -124,12 +133,14 @@ def decl_batched_gemm(cfg, A, B):
     b_size = util.get_const_int(A.shape[0])
 
     # Declare a batched GEMM
-    k = te.reduce_axis((0, k_size), name='k')
-    C = te.compute((b_size, n, m), lambda b, i, j:
-                   te.sum(A[b][i][k] * B[b][k][j], axis=k), name='C')
+    k = te.reduce_axis((0, k_size), name="k")
+    C = te.compute(
+        (b_size, n, m), lambda b, i, j: te.sum(A[b][i][k] * B[b][k][j], axis=k), name="C"
+    )
 
     return C
 
+
 def decl_winograd_gemm(cfg, A, B):
     """Declare a winograd GEMM for Mali Bifrost GPUs
     Winograd uses batched GEMM, however the input tensors are 4D
@@ -154,12 +165,14 @@ def decl_winograd_gemm(cfg, A, B):
     n = util.get_const_int(A.shape[2])
     k = util.get_const_int(A.shape[3])
 
-    A_3D = te.compute((alpha * alpha, n, k), lambda b, i, j:
-                      A[b // alpha][b % alpha][i][j], name='A_3D')
+    A_3D = te.compute(
+        (alpha * alpha, n, k), lambda b, i, j: A[b // alpha][b % alpha][i][j], name="A_3D"
+    )
 
     C = decl_batched_gemm(cfg, A_3D, B)
     return A_3D, C
 
+
 def schedule_gemm(cfg, s, A, B, C, batched=False, schedule_transforms=True):
     """Schedule GEMM, single and batched
 
@@ -223,9 +236,9 @@ def schedule_gemm(cfg, s, A, B, C, batched=False, schedule_transforms=True):
         tile_and_bind(s, inter_trans, y, xo, 4, 4)
 
     # Schedule C
-    CR_A = s.cache_read(A_transposed_interleaved, 'local', [C])
-    CR_B = s.cache_read(B_interleaved_transposed, 'local', [C])
-    CW_C = s.cache_write(C, 'local')
+    CR_A = s.cache_read(A_transposed_interleaved, "local", [C])
+    CR_B = s.cache_read(B_interleaved_transposed, "local", [C])
+    CW_C = s.cache_write(C, "local")
 
     if not batched:
         y, x = s[C].op.axis
@@ -275,6 +288,7 @@ def schedule_gemm(cfg, s, A, B, C, batched=False, schedule_transforms=True):
 
     return trans_inter, inter_trans
 
+
 def schedule_unrollable_gemm(cfg, s, A, B, C, R):
     """Schedule a GEMM that can be unrolled by a constant factor
     along its inner dimension
@@ -314,7 +328,7 @@ def schedule_unrollable_gemm(cfg, s, A, B, C, R):
         s[B].compute_inline()
         schedule_gemm(cfg, s, A, B, C, batched=True)
 
-        CR_C = s.cache_read(C, 'local', [R])
+        CR_C = s.cache_read(C, "local", [R])
 
         y, x = s[R].op.axis
         xo, xi = s[R].split(x, 4)
@@ -330,6 +344,7 @@ def schedule_unrollable_gemm(cfg, s, A, B, C, R):
         s[CR_C].unroll(y)
         s[CR_C].vectorize(x)
 
+
 def get_unrollable_gemm_ops(R):
     """Get all GEMM operators from the final reduction
     This is a helper function to more easily get all the GEMM operations
index 3feb4e6..6a39f19 100644 (file)
@@ -21,6 +21,7 @@ from __future__ import absolute_import as _abs
 import tvm
 from tvm import te
 
+
 def fuse_and_bind(s, tensor, axis=None, num_thread=None):
     """Fuse all the axis and bind to GPU threads"""
     axis = axis or s[tensor].op.axis
@@ -31,6 +32,7 @@ def fuse_and_bind(s, tensor, axis=None, num_thread=None):
     s[tensor].bind(tx, te.thread_axis("threadIdx.x"))
     return bx, tx
 
+
 def tile_and_bind(s, tensor, y, x, y_factor, x_factor=None):
     """Tile and bind to GPU threads"""
     x_factor = x_factor or y_factor
@@ -41,6 +43,7 @@ def tile_and_bind(s, tensor, y, x, y_factor, x_factor=None):
     s[tensor].bind(yi, te.thread_axis("threadIdx.y"))
     return yo, xo, yi, xi
 
+
 def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None):
     """Tile and bind 3d"""
     y_factor = y_factor or z_factor
@@ -56,25 +59,27 @@ def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None
     s[tensor].bind(xi, te.thread_axis("threadIdx.x"))
     return zo, yo, xo, zi, yi, xi
 
+
 def pack_tensor(s, tensor, factor, readers):
     """Do transform X[n, m] -> X[n / factor, m, factor]"""
-    tmp = s.cache_read(tensor, 'global', readers)
+    tmp = s.cache_read(tensor, "global", readers)
     y, x = s[tmp].op.axis
     yo, yi = s[tmp].split(y, factor)
     s[tmp].reorder(yo, x, yi)
     s[tmp].compute_inline()
-    return s.cache_write(tmp, 'global'), tmp
+    return s.cache_write(tmp, "global"), tmp
+
 
 def transpose(s, tensor, y_index, x_index, readers):
     """Do transform X[n, m] -> X[m, n]"""
-    tmp = s.cache_read(tensor, 'global', readers)
+    tmp = s.cache_read(tensor, "global", readers)
     y, x = s[tmp].op.axis[y_index], s[tmp].op.axis[x_index]
     s[tmp].reorder(x, y)
     s[tmp].compute_inline()
     A_transpose = s.cache_write(tmp, "global")
 
-    CR_A = s.cache_read(tensor, 'local', [A_transpose])
-    CW_A_transpose = s.cache_write(A_transpose, 'local')
+    CR_A = s.cache_read(tensor, "local", [A_transpose])
+    CW_A_transpose = s.cache_write(A_transpose, "local")
 
     y, x = s[A_transpose].op.axis[y_index], s[A_transpose].op.axis[x_index]
     yo, xo, yi, xi = s[A_transpose].tile(y, x, 4, 4)
@@ -94,9 +99,10 @@ def transpose(s, tensor, y_index, x_index, readers):
 
     return tmp
 
+
 def interleave_transpose(s, tensor, width, y_index, x_index, readers, batched=False):
     """Interleave the tensor, then transpose it"""
-    tmp = s.cache_read(tensor, 'global', readers)
+    tmp = s.cache_read(tensor, "global", readers)
     y, x = s[tmp].op.axis[y_index], s[tmp].op.axis[x_index]
     xo, xi = s[tmp].split(x, width)
     s[tmp].reorder(xo, y, xi)
@@ -105,11 +111,12 @@ def interleave_transpose(s, tensor, width, y_index, x_index, readers, batched=Fa
         z = s[tmp].op.axis[0]
         s[tmp].fuse(z, xo)
     s[tmp].compute_inline()
-    return s.cache_write(tmp, 'global'), tmp
+    return s.cache_write(tmp, "global"), tmp
+
 
 def transpose_interleave(s, tensor, width, y_index, x_index, readers, batched=False):
     """Transpose the tensor, then interleave it"""
-    tmp = s.cache_read(tensor, 'global', readers)
+    tmp = s.cache_read(tensor, "global", readers)
     y, x = s[tmp].op.axis[y_index], s[tmp].op.axis[x_index]
     yo, yi = s[tmp].split(y, width)
     s[tmp].reorder(yo, x, yi)
@@ -118,4 +125,4 @@ def transpose_interleave(s, tensor, width, y_index, x_index, readers, batched=Fa
         z = s[tmp].op.axis[0]
         s[tmp].fuse(z, yo)
     s[tmp].compute_inline()
-    return s.cache_write(tmp, 'global'), tmp
+    return s.cache_write(tmp, "global"), tmp
index cc36637..2b350ff 100644 (file)
@@ -16,7 +16,7 @@
 # under the License.
 """Broadcast operators"""
 from __future__ import absolute_import as _abs
-from .import cpp as _cpp
+from . import cpp as _cpp
 
 
 def broadcast_to(data, shape):
index 788b888..62e274c 100644 (file)
@@ -16,7 +16,7 @@
 # under the License.
 
 """FFI for C++ TOPI ops and schedules"""
-from .impl import * #pylint: disable=wildcard-import
+from .impl import *  # pylint: disable=wildcard-import
 from . import cuda
 from . import nn
 from . import vision
index 373b1ec..26647dd 100644 (file)
@@ -24,6 +24,7 @@ from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
 from .. import nn
 from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
 
+
 @autotvm.register_topi_compute("batch_matmul.cuda")
 def batch_matmul(cfg, x, y):
     """Compute conv2d with NCHW layout"""
@@ -62,14 +63,14 @@ def schedule_batch_matmul(cfg, outs):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        k, = s[CC].op.reduce_axis
+        (k,) = s[CC].op.reduce_axis
 
         cfg.define_split("tile_y", y, num_outputs=3)
         cfg.define_split("tile_x", x, num_outputs=3)
         cfg.define_split("tile_k", k, num_outputs=2)
         cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
         target = tvm.target.Target.current()
-        if target.kind.name in ['nvptx', 'rocm']:
+        if target.kind.name in ["nvptx", "rocm"]:
             # llvm-based backends cannot do non-explicit unrolling
             cfg.define_knob("unroll_explicit", [1])
         else:
@@ -80,10 +81,10 @@ def schedule_batch_matmul(cfg, outs):
             x_bn = get_max_power2_factor(N, 64)
             y_nthreads = min(y_bn, 8)
             x_nthreads = min(x_bn, 8)
-            cfg['tile_x'] = SplitEntity([-1, x_nthreads, x_bn // x_nthreads])
-            cfg['tile_y'] = SplitEntity([-1, y_nthreads, y_bn // y_nthreads])
-            cfg['tile_k'] = SplitEntity([-1, 8])
-            cfg['auto_unroll_max_step'] = OtherOptionEntity(16)
+            cfg["tile_x"] = SplitEntity([-1, x_nthreads, x_bn // x_nthreads])
+            cfg["tile_y"] = SplitEntity([-1, y_nthreads, y_bn // y_nthreads])
+            cfg["tile_k"] = SplitEntity([-1, 8])
+            cfg["auto_unroll_max_step"] = OtherOptionEntity(16)
 
         by, ty, yi = cfg["tile_y"].apply(s, C, y)
         bx, tx, xi = cfg["tile_x"].apply(s, C, x)
@@ -97,15 +98,15 @@ def schedule_batch_matmul(cfg, outs):
         s[C].bind(bx, te.thread_axis("blockIdx.x"))
         s[C].bind(ty, thread_y)
         s[C].bind(tx, thread_x)
-        s[C].pragma(yi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val)
-        s[C].pragma(yi, 'unroll_explicit', cfg['unroll_explicit'].val)
+        s[C].pragma(yi, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+        s[C].pragma(yi, "unroll_explicit", cfg["unroll_explicit"].val)
 
         s[CC].compute_at(s[C], tx)
         _, yi, xi = s[CC].op.axis
         ko, ki = cfg["tile_k"].apply(s, CC, k)
         s[CC].reorder(ko, ki, yi, xi)
-        s[CC].pragma(ki, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val)
-        s[CC].pragma(ki, 'unroll_explicit', cfg['unroll_explicit'].val)
+        s[CC].pragma(ki, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+        s[CC].pragma(ki, "unroll_explicit", cfg["unroll_explicit"].val)
 
         s[AA].compute_at(s[CC], ko)
         s[AL].compute_at(s[CC], ki)
@@ -117,8 +118,8 @@ def schedule_batch_matmul(cfg, outs):
         s[AA].reorder(ty, tx, yi, ki)
         s[AA].bind(ty, thread_y)
         s[AA].bind(tx, thread_x)
-        s[AA].pragma(yi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val)
-        s[AA].pragma(yi, 'unroll_explicit', cfg['unroll_explicit'].val)
+        s[AA].pragma(yi, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+        s[AA].pragma(yi, "unroll_explicit", cfg["unroll_explicit"].val)
 
         _, x, k = s[BB].op.axis
         ty, xi = s[BB].split(x, nparts=cfg["tile_y"].size[1])
@@ -126,8 +127,8 @@ def schedule_batch_matmul(cfg, outs):
         s[BB].bind(ty, thread_y)
         s[BB].bind(tx, thread_x)
         s[BB].reorder(ty, tx, xi, ki)
-        s[BB].pragma(xi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val)
-        s[BB].pragma(xi, 'unroll_explicit', cfg['unroll_explicit'].val)
+        s[BB].pragma(xi, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+        s[BB].pragma(xi, "unroll_explicit", cfg["unroll_explicit"].val)
 
     def _callback(op):
         if "batch_matmul" in op.tag:
@@ -136,6 +137,7 @@ def schedule_batch_matmul(cfg, outs):
     traverse_inline(s, outs[0].op, _callback)
     return s
 
+
 def batch_matmul_cublas(x, y):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
     data in batch.
index c099d25..416e480 100644 (file)
@@ -25,13 +25,7 @@ from ..util import traverse_inline, get_const_tuple
 
 
 @autotvm.register_topi_compute("conv1d_ncw.cuda")
-def conv1d_ncw(cfg,
-               data,
-               kernel,
-               strides,
-               padding,
-               dilation,
-               out_dtype='float32'):
+def conv1d_ncw(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"):
     return nn.conv1d_ncw(data, kernel, strides, padding, dilation, out_dtype)
 
 
@@ -57,7 +51,7 @@ def schedule_conv1d_ncw(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == 'conv1d_ncw':
+        if op.tag == "conv1d_ncw":
             pad_data = op.input_tensors[0]
             kernel = op.input_tensors[1]
             conv = op.output(0)
@@ -72,29 +66,28 @@ def schedule_conv1d_ncw(cfg, outs):
             cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
 
             target = tvm.target.Target.current()
-            if target.kind.name in ['nvptx', 'rocm']:
+            if target.kind.name in ["nvptx", "rocm"]:
                 cfg.define_knob("unroll_explicit", [1])
             else:
                 cfg.define_knob("unroll_explicit", [0, 1])
 
             ##### space definition end #####
 
-            if isinstance(kernel.op,
-                          tvm.te.ComputeOp) and 'dilate' in kernel.op.tag:
+            if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
                 s[kernel].compute_inline()
 
             if conv.op in s.outputs:
                 output = conv
-                OL = s.cache_write(conv, 'local')
+                OL = s.cache_write(conv, "local")
             else:
                 output = s.outputs[0].output(0)
-                s[conv].set_scope('local')
+                s[conv].set_scope("local")
                 OL = conv
 
             # create cache stage
-            s[pad_data].set_scope('shared')
+            s[pad_data].set_scope("shared")
             AA = pad_data
-            WW = s.cache_read(kernel, 'shared', [OL])
+            WW = s.cache_read(kernel, "shared", [OL])
 
             # tile and bind spatial axes
             n, f, x = s[output].op.axis
@@ -120,7 +113,7 @@ def schedule_conv1d_ncw(cfg, outs):
             # tile reduction axes
             n, f, x = s[OL].op.axis
             rc, rx = s[OL].op.reduce_axis
-            rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
+            rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
             s[OL].reorder(rco, rcm, rx, rci, n, f, x)
 
             s[AA].compute_at(s[OL], rx)
@@ -135,10 +128,8 @@ def schedule_conv1d_ncw(cfg, outs):
                 s[load].bind(tz, te.thread_axis("threadIdx.y"))
                 s[load].bind(tx, te.thread_axis("threadIdx.x"))
 
-            s[output].pragma(kernel_scope, 'auto_unroll_max_step',
-                             cfg['auto_unroll_max_step'].val)
-            s[output].pragma(kernel_scope, 'unroll_explicit',
-                             cfg['unroll_explicit'].val)
+            s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+            s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
 
             N, CO, OW = get_const_tuple(output.shape)
             _, CI, KW = get_const_tuple(kernel.shape)
@@ -150,13 +141,7 @@ def schedule_conv1d_ncw(cfg, outs):
 
 
 @autotvm.register_topi_compute("conv1d_nwc.cuda")
-def conv1d_nwc(cfg,
-               data,
-               kernel,
-               strides,
-               padding,
-               dilation,
-               out_dtype='float32'):
+def conv1d_nwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"):
     return nn.conv1d_nwc(data, kernel, strides, padding, dilation, out_dtype)
 
 
@@ -182,7 +167,7 @@ def schedule_conv1d_nwc(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == 'conv1d_nwc':
+        if op.tag == "conv1d_nwc":
             pad_data = op.input_tensors[0]
             kernel = op.input_tensors[1]
             conv = op.output(0)
@@ -197,29 +182,28 @@ def schedule_conv1d_nwc(cfg, outs):
             cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
 
             target = tvm.target.Target.current()
-            if target.kind.name in ['nvptx', 'rocm']:
+            if target.kind.name in ["nvptx", "rocm"]:
                 cfg.define_knob("unroll_explicit", [1])
             else:
                 cfg.define_knob("unroll_explicit", [0, 1])
 
             ##### space definition end #####
 
-            if isinstance(kernel.op,
-                          tvm.te.ComputeOp) and 'dilate' in kernel.op.tag:
+            if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
                 s[kernel].compute_inline()
 
             if conv.op in s.outputs:
                 output = conv
-                OL = s.cache_write(conv, 'local')
+                OL = s.cache_write(conv, "local")
             else:
                 output = s.outputs[0].output(0)
-                s[conv].set_scope('local')
+                s[conv].set_scope("local")
                 OL = conv
 
             # create cache stage
-            s[pad_data].set_scope('shared')
+            s[pad_data].set_scope("shared")
             AA = pad_data
-            WW = s.cache_read(kernel, 'shared', [OL])
+            WW = s.cache_read(kernel, "shared", [OL])
 
             # tile and bind spatial axes
             n, f, x = s[output].op.axis
@@ -245,7 +229,7 @@ def schedule_conv1d_nwc(cfg, outs):
             # tile reduction axes
             n, x, f = s[OL].op.axis
             rc, rx = s[OL].op.reduce_axis
-            rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
+            rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
             s[OL].reorder(rco, rcm, rx, rci, n, x, f)
 
             s[AA].compute_at(s[OL], rx)
@@ -260,10 +244,8 @@ def schedule_conv1d_nwc(cfg, outs):
                 s[load].bind(tz, te.thread_axis("threadIdx.y"))
                 s[load].bind(tx, te.thread_axis("threadIdx.x"))
 
-            s[output].pragma(kernel_scope, 'auto_unroll_max_step',
-                             cfg['auto_unroll_max_step'].val)
-            s[output].pragma(kernel_scope, 'unroll_explicit',
-                             cfg['unroll_explicit'].val)
+            s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+            s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
 
             N, OW, CO = get_const_tuple(output.shape)
             KW, CI, _ = get_const_tuple(kernel.shape)
index dbfe1f5..c827007 100644 (file)
@@ -23,9 +23,9 @@ from tvm import autotvm
 from .. import nn
 from ..util import get_const_tuple, traverse_inline
 
+
 @autotvm.task.register_topi_compute("conv1d_transpose_nchw.cuda")
-def conv1d_transpose_ncw(cfg, data, kernel, stride, padding, out_dtype,
-                         output_padding):
+def conv1d_transpose_ncw(cfg, data, kernel, stride, padding, out_dtype, output_padding):
     """Transposed 1D convolution ncw forward operator.
 
     Parameters
@@ -69,24 +69,32 @@ def conv1d_transpose_ncw(cfg, data, kernel, stride, padding, out_dtype,
     data = te.compute(
         (batch, inp_channels, pad_left + dilated_width + pad_right),
         lambda n, c, x: tvm.tir.if_then_else(
-            tvm.tir.all(x >= pad_left,
-                        x < pad_left + dilated_width,
-                        tvm.tir.indexmod(x - pad_left, stride).equal(0)),
+            tvm.tir.all(
+                x >= pad_left,
+                x < pad_left + dilated_width,
+                tvm.tir.indexmod(x - pad_left, stride).equal(0),
+            ),
             data[n, c, tvm.tir.indexdiv(x - pad_left, stride)],
-            tvm.tir.const(0., "float32")),
-        name='data_pad')
+            tvm.tir.const(0.0, "float32"),
+        ),
+        name="data_pad",
+    )
 
-    dc = te.reduce_axis((0, inp_channels), name='dc')
-    dw = te.reduce_axis((0, kernel_size), name='dw')
+    dc = te.reduce_axis((0, inp_channels), name="dc")
+    dw = te.reduce_axis((0, kernel_size), name="dw")
     data_out = te.compute(
         (batch, out_channels, out_width),
         lambda b, c, w: te.sum(
-            data[b, dc, w + dw].astype(out_dtype) *
-            kernel[dc, c, kernel_size - 1 - dw].astype(out_dtype),
-            axis=[dc, dw]), tag="conv1d_transpose_ncw")
+            data[b, dc, w + dw].astype(out_dtype)
+            * kernel[dc, c, kernel_size - 1 - dw].astype(out_dtype),
+            axis=[dc, dw],
+        ),
+        tag="conv1d_transpose_ncw",
+    )
 
     return data_out
 
+
 @autotvm.task.register_topi_schedule("conv1d_transpose_nchw.cuda")
 def schedule_conv1d_transpose_ncw(cfg, outs):
     """TOPI Schedule callback for conv1d_transpose operator.
@@ -109,7 +117,7 @@ def schedule_conv1d_transpose_ncw(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == 'conv1d_transpose_ncw':
+        if op.tag == "conv1d_transpose_ncw":
             pad_data = op.input_tensors[0]
             kernel = op.input_tensors[1]
             conv = op.output(0)
@@ -124,28 +132,28 @@ def schedule_conv1d_transpose_ncw(cfg, outs):
             cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
 
             target = tvm.target.Target.current()
-            if target.kind.name in ['nvptx', 'rocm']:
+            if target.kind.name in ["nvptx", "rocm"]:
                 cfg.define_knob("unroll_explicit", [1])
             else:
                 cfg.define_knob("unroll_explicit", [0, 1])
 
             ##### space definition end #####
 
-            if isinstance(kernel.op, tvm.te.ComputeOp) and 'dilate' in kernel.op.tag:
+            if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
                 s[kernel].compute_inline()
 
             if conv.op in s.outputs:
                 output = conv
-                OL = s.cache_write(conv, 'local')
+                OL = s.cache_write(conv, "local")
             else:
                 output = s.outputs[0].output(0)
-                s[conv].set_scope('local')
+                s[conv].set_scope("local")
                 OL = conv
 
             # create cache stage
-            s[pad_data].set_scope('shared')
+            s[pad_data].set_scope("shared")
             AA = pad_data
-            WW = s.cache_read(kernel, 'shared', [OL])
+            WW = s.cache_read(kernel, "shared", [OL])
 
             # tile and bind spatial axes
             n, f, x = s[output].op.axis
@@ -171,7 +179,7 @@ def schedule_conv1d_transpose_ncw(cfg, outs):
             # tile reduction axes
             n, f, x = s[OL].op.axis
             rc, rx = s[OL].op.reduce_axis
-            rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
+            rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
             s[OL].reorder(rco, rcm, rx, rci, n, f, x)
 
             s[AA].compute_at(s[OL], rx)
@@ -186,8 +194,8 @@ def schedule_conv1d_transpose_ncw(cfg, outs):
                 s[load].bind(tz, te.thread_axis("threadIdx.y"))
                 s[load].bind(tx, te.thread_axis("threadIdx.x"))
 
-            s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
-            s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+            s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+            s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
 
     traverse_inline(s, outs[0].op, _callback)
 
index 973c216..cf335ac 100644 (file)
@@ -29,7 +29,7 @@ from .conv2d_nhwc import schedule_conv2d_nhwc_direct
 
 
 @autotvm.register_topi_compute("conv2d_nchw.cuda")
-def conv2d_nchw(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'):
+def conv2d_nchw(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"):
     """Compute conv2d with NCHW layout"""
     return nn.conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
 
@@ -41,7 +41,7 @@ def schedule_conv2d_nchw(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == 'conv2d_nchw':
+        if op.tag == "conv2d_nchw":
             schedule_direct_cuda(cfg, s, op.output(0))
 
     traverse_inline(s, outs[0].op, _callback)
@@ -49,7 +49,7 @@ def schedule_conv2d_nchw(cfg, outs):
 
 
 @autotvm.register_topi_compute("conv2d_nhwc.cuda")
-def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'):
+def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"):
     """Compute conv2d with NHWC layout"""
     return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
 
@@ -59,22 +59,25 @@ def schedule_conv2d_nhwc(cfg, outs):
     """Create the schedule for conv2d_nhwc"""
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     s = te.create_schedule([x.op for x in outs])
+
     def _callback(op):
-        if op.tag == 'conv2d_nhwc':
+        if op.tag == "conv2d_nhwc":
             schedule_conv2d_nhwc_direct(cfg, s, op.output(0))
+
     traverse_inline(s, outs[0].op, _callback)
     return s
 
 
 @autotvm.register_topi_compute("conv2d_cudnn.cuda")
-def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, groups=1,
-                 layout='NCHW', out_dtype='float32'):
+def conv2d_cudnn(
+    cfg, data, kernel, strides, padding, dilation, groups=1, layout="NCHW", out_dtype="float32"
+):
     """Compute conv2d using CuDNN library"""
-    if layout == 'NCHW':
-        tensor_format = 0 # CUDNN_TENSOR_NCHW
+    if layout == "NCHW":
+        tensor_format = 0  # CUDNN_TENSOR_NCHW
         N, _, H, W = get_const_tuple(data.shape)
-    elif layout == 'NHWC':
-        tensor_format = 1 # CUDNN_TENSOR_NHWC
+    elif layout == "NHWC":
+        tensor_format = 1  # CUDNN_TENSOR_NHWC
         N, H, W, _ = get_const_tuple(data.shape)
     else:
         raise ValueError("Unsupported layout %s in cudnn" % layout)
@@ -84,36 +87,50 @@ def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, groups=1,
     stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
     dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation
 
-    if isinstance(padding, (list, tuple)) and len(padding) == 4 and \
-            (padding[0] != padding[2] or padding[1] != padding[3]):
+    if (
+        isinstance(padding, (list, tuple))
+        and len(padding) == 4
+        and (padding[0] != padding[2] or padding[1] != padding[3])
+    ):
         raise ValueError("Cudnn doesn't support asymmetric padding.")
     pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
     OH = (H + pt + pb - KH) // stride_h + 1
     OW = (W + pl + pr - KW) // stride_w + 1
-    cfg.add_flop(groups * 2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) * \
-                 ((KW - 1) * dilation_w + 1))
+    cfg.add_flop(
+        groups
+        * 2
+        * N
+        * OH
+        * OW
+        * CO
+        * CI
+        * ((KH - 1) * dilation_h + 1)
+        * ((KW - 1) * dilation_w + 1)
+    )
 
     if data.dtype == "int8" or kernel.dtype == "int8":
-        if layout == 'NCHW':
+        if layout == "NCHW":
             raise ValueError("NCHW layout do not support int8 in cudnn")
         dtype = "int32"
     else:
         dtype = data.dtype
 
-    cfg.define_knob('algo', range(8))
-    if cfg.is_fallback: # Let CUDNN choose the best algo
-        cfg['algo'] = OtherOptionEntity(-1)
-
-    return cudnn.conv_forward(data,
-                              kernel,
-                              [pt, pl], # cudnn padding pt, pl on both sides of input
-                              [stride_h, stride_w],
-                              [dilation_h, dilation_w],
-                              conv_mode=1,
-                              tensor_format=tensor_format,
-                              algo=cfg['algo'].val,
-                              conv_dtype=dtype,
-                              groups=groups)
+    cfg.define_knob("algo", range(8))
+    if cfg.is_fallback:  # Let CUDNN choose the best algo
+        cfg["algo"] = OtherOptionEntity(-1)
+
+    return cudnn.conv_forward(
+        data,
+        kernel,
+        [pt, pl],  # cudnn padding pt, pl on both sides of input
+        [stride_h, stride_w],
+        [dilation_h, dilation_w],
+        conv_mode=1,
+        tensor_format=tensor_format,
+        algo=cfg["algo"].val,
+        conv_dtype=dtype,
+        groups=groups,
+    )
 
 
 @autotvm.register_topi_schedule("conv2d_cudnn.cuda")
index 89a8569..9bac87c 100644 (file)
@@ -28,7 +28,8 @@ from ..util import get_const_tuple
 from .conv2d_winograd import _infer_tile_size
 from ..nn import conv2d_legalize
 
-logger = logging.getLogger('topi')
+logger = logging.getLogger("topi")
+
 
 @nn.conv2d_alter_layout.register(["cuda", "gpu"])
 def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
@@ -36,7 +37,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
     dispatch_ctx = autotvm.task.DispatchContext.current
 
     _, outs = relay.backend.compile_engine.select_implementation(
-        relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)
+        relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target
+    )
     workload = autotvm.task.get_workload(outs)
     if workload is None:
         # The best implementation is not an AutoTVM template,
@@ -53,7 +55,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
     strides = attrs.get_int_tuple("strides")
     padding = attrs.get_int_tuple("padding")
     dilation = attrs.get_int_tuple("dilation")
-    groups = attrs.get_int('groups')
+    groups = attrs.get_int("groups")
     data_layout = attrs["data_layout"]
     kernel_layout = attrs["kernel_layout"]
     data, kernel = tinfos
@@ -64,21 +66,32 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         N, CI, H, W = get_const_tuple(data.shape)
         CO, _, KH, KW = get_const_tuple(kernel.shape)
 
-        new_layout = 'NCHW4c'
+        new_layout = "NCHW4c"
         new_attrs["channels"] = CO
         new_attrs["data_layout"] = new_layout
-        new_attrs['out_layout'] = new_layout
-        new_attrs['kernel_layout'] = 'OIHW4o4i'
+        new_attrs["out_layout"] = new_layout
+        new_attrs["kernel_layout"] = "OIHW4o4i"
         ic_block_factor = oc_block_factor = 4
 
         # Store the same config for the altered operator (workload)
-        new_data = te.placeholder((N, CI // ic_block_factor, H, W, ic_block_factor),
-                                  dtype=data.dtype)
-        new_kernel = te.placeholder((CO // oc_block_factor, CI // ic_block_factor, KH, KW, \
-                                     oc_block_factor, ic_block_factor), dtype=kernel.dtype)
+        new_data = te.placeholder(
+            (N, CI // ic_block_factor, H, W, ic_block_factor), dtype=data.dtype
+        )
+        new_kernel = te.placeholder(
+            (
+                CO // oc_block_factor,
+                CI // ic_block_factor,
+                KH,
+                KW,
+                oc_block_factor,
+                ic_block_factor,
+            ),
+            dtype=kernel.dtype,
+        )
         new_workload = autotvm.task.args_to_workload(
             [new_data, new_kernel, strides, padding, dilation, new_layout, out_dtype],
-            "conv2d_NCHWc_int8.cuda")
+            "conv2d_NCHWc_int8.cuda",
+        )
         dispatch_ctx.update(target, new_workload, cfg)
         return relay.nn.conv2d(*inputs, **new_attrs)
 
@@ -94,24 +107,26 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         # pre-compute weight transformation in winograd
         tile_size = _infer_tile_size(tinfos[0], tinfos[1])
 
-        weight = relay.nn.contrib_conv2d_winograd_weight_transform(inputs[1],
-                                                                   tile_size=tile_size)
+        weight = relay.nn.contrib_conv2d_winograd_weight_transform(inputs[1], tile_size=tile_size)
         weight = relay.transpose(weight, axes=[0, 1, 3, 2])
-        new_attrs['tile_size'] = tile_size
-        new_attrs['channels'] = CO
+        new_attrs["tile_size"] = tile_size
+        new_attrs["channels"] = CO
 
         # Store the same config for the altered operator (workload)
         new_data = data
-        new_weight = te.placeholder((KH + tile_size - 1, KW + tile_size - 1, CI, CO),
-                                    dtype=kernel.dtype)
+        new_weight = te.placeholder(
+            (KH + tile_size - 1, KW + tile_size - 1, CI, CO), dtype=kernel.dtype
+        )
         new_workload = autotvm.task.args_to_workload(
             [new_data, new_weight, strides, padding, dilation, out_dtype],
-            "conv2d_nchw_winograd_without_weight_transform.cuda")
+            "conv2d_nchw_winograd_without_weight_transform.cuda",
+        )
         dispatch_ctx.update(target, new_workload, cfg)
         return relay.nn.contrib_conv2d_winograd_without_weight_transform(
-            inputs[0], weight, **new_attrs)
+            inputs[0], weight, **new_attrs
+        )
 
-    if topi_tmpl in ('conv2d_nhwc_winograd_direct.cuda', 'conv2d_nhwc_winograd_tensorcore.cuda'):
+    if topi_tmpl in ("conv2d_nhwc_winograd_direct.cuda", "conv2d_nhwc_winograd_tensorcore.cuda"):
         if dilation != (1, 1):
             logger.warning("Does not support weight pre-transform for dilated convolution.")
             return None
@@ -126,48 +141,63 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         else:
             tile_size = 2
         kernel_transform = relay.transpose(inputs[1], axes=[3, 2, 0, 1])
-        weight = relay.nn.contrib_conv2d_winograd_weight_transform(kernel_transform,
-                                                                   tile_size=tile_size)
+        weight = relay.nn.contrib_conv2d_winograd_weight_transform(
+            kernel_transform, tile_size=tile_size
+        )
         weight = relay.transpose(weight, axes=[0, 1, 3, 2])
-        new_attrs['tile_size'] = tile_size
-        new_attrs['channels'] = CO
+        new_attrs["tile_size"] = tile_size
+        new_attrs["channels"] = CO
         # Store the same config for the altered operator (workload)
         new_data = data
-        new_weight = te.placeholder((KH + tile_size - 1, KW + tile_size - 1, CI, CO),
-                                    dtype=kernel.dtype)
+        new_weight = te.placeholder(
+            (KH + tile_size - 1, KW + tile_size - 1, CI, CO), dtype=kernel.dtype
+        )
         if topi_tmpl == "conv2d_nhwc_winograd_direct.cuda":
             new_workload = autotvm.task.args_to_workload(
                 [new_data, new_weight, strides, padding, dilation, out_dtype],
-                "conv2d_nhwc_winograd_direct_without_weight_transform.cuda")
+                "conv2d_nhwc_winograd_direct_without_weight_transform.cuda",
+            )
         elif topi_tmpl == "conv2d_nhwc_winograd_tensorcore.cuda":
             new_workload = autotvm.task.args_to_workload(
                 [new_data, new_weight, strides, padding, dilation, out_dtype],
-                "conv2d_nhwc_winograd_tensorcore_without_weight_transform.cuda")
+                "conv2d_nhwc_winograd_tensorcore_without_weight_transform.cuda",
+            )
         dispatch_ctx.update(target, new_workload, cfg)
         return relay.nn.contrib_conv2d_winograd_without_weight_transform(
-            inputs[0], weight, **new_attrs)
+            inputs[0], weight, **new_attrs
+        )
 
     if topi_tmpl == "group_conv2d_NCHWc_int8.cuda":
         assert data_layout == "NCHW" and kernel_layout == "OIHW"
         N, CI, H, W = get_const_tuple(data.shape)
         CO, _, KH, KW = get_const_tuple(kernel.shape)
 
-        new_layout = 'NCHW4c'
+        new_layout = "NCHW4c"
         new_attrs["channels"] = CO
         new_attrs["data_layout"] = new_layout
-        new_attrs['out_layout'] = new_layout
-        new_attrs['kernel_layout'] = 'OIHW4o4i'
+        new_attrs["out_layout"] = new_layout
+        new_attrs["kernel_layout"] = "OIHW4o4i"
         ic_block_factor = oc_block_factor = 4
 
         # Store the same config for the altered operator (workload)
-        new_data = te.placeholder((N, CI // ic_block_factor, H, W, ic_block_factor),
-                                  dtype=data.dtype)
-        new_kernel = te.placeholder((CO // oc_block_factor, CI // ic_block_factor // groups,
-                                     KH, KW, oc_block_factor, ic_block_factor),
-                                    dtype=kernel.dtype)
+        new_data = te.placeholder(
+            (N, CI // ic_block_factor, H, W, ic_block_factor), dtype=data.dtype
+        )
+        new_kernel = te.placeholder(
+            (
+                CO // oc_block_factor,
+                CI // ic_block_factor // groups,
+                KH,
+                KW,
+                oc_block_factor,
+                ic_block_factor,
+            ),
+            dtype=kernel.dtype,
+        )
         new_workload = autotvm.task.args_to_workload(
             [new_data, new_kernel, strides, padding, dilation, groups, out_dtype],
-            "group_conv2d_NCHWc_int8.cuda")
+            "group_conv2d_NCHWc_int8.cuda",
+        )
         dispatch_ctx.update(target, new_workload, cfg)
         return relay.nn.conv2d(*inputs, **new_attrs)
 
@@ -177,32 +207,47 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         H, W, N, CI = get_const_tuple(data.shape)
         KH, KW, CO, _ = get_const_tuple(kernel.shape)
 
-        if kernel.dtype in ['int4', 'uint4'] and (CI % 32 != 0 or CO % 8 != 0) or \
-            kernel.dtype in ['int8', 'uint8'] and (CI % 16 != 0 or CO % 32 != 0):
+        if (
+            kernel.dtype in ["int4", "uint4"]
+            and (CI % 32 != 0 or CO % 8 != 0)
+            or kernel.dtype in ["int8", "uint8"]
+            and (CI % 16 != 0 or CO % 32 != 0)
+        ):
             return relay.nn.conv2d(*inputs, **new_attrs)
 
         new_attrs["channels"] = CO
-        if kernel.dtype in ['int4', 'uint4']:
-            new_attrs['kernel_layout'] = 'HWOI8o32i'
+        if kernel.dtype in ["int4", "uint4"]:
+            new_attrs["kernel_layout"] = "HWOI8o32i"
             ic_block_factor = 32
             oc_block_factor = 8
         else:
-            new_attrs['kernel_layout'] = 'HWOI32o16i'
+            new_attrs["kernel_layout"] = "HWOI32o16i"
             ic_block_factor = 16
             oc_block_factor = 32
 
-        new_kernel = te.placeholder((KH, KW, CO // oc_block_factor, CI // ic_block_factor,
-                                     oc_block_factor, ic_block_factor), dtype=kernel.dtype)
+        new_kernel = te.placeholder(
+            (
+                KH,
+                KW,
+                CO // oc_block_factor,
+                CI // ic_block_factor,
+                oc_block_factor,
+                ic_block_factor,
+            ),
+            dtype=kernel.dtype,
+        )
 
         new_workload = autotvm.task.args_to_workload(
             [data, new_kernel, strides, padding, dilation, out_dtype],
-            "conv2d_HWNCnc_tensorcore.cuda")
+            "conv2d_HWNCnc_tensorcore.cuda",
+        )
 
         dispatch_ctx.update(target, new_workload, cfg)
         return relay.nn.conv2d(*inputs, **new_attrs)
 
     return None
 
+
 @conv2d_legalize.register("cuda")
 def _conv2d_legalize(attrs, inputs, arg_types):
     """Legalizes Conv2D op.
@@ -246,12 +291,12 @@ def _conv2d_legalize(attrs, inputs, arg_types):
     new_attrs = {k: attrs[k] for k in attrs.keys()}
 
     # Get data layout. Return None if not NCHW
-    data_layout = attrs['data_layout']
-    kernel_layout = attrs['kernel_layout']
+    data_layout = attrs["data_layout"]
+    kernel_layout = attrs["kernel_layout"]
 
     # Pad input and output channels to use int8 schedule.
-    if data_dtype in ['int8', 'uint8']:
-        if data_layout == 'NCHW' and kernel_layout == "OIHW":
+    if data_dtype in ["int8", "uint8"]:
+        if data_layout == "NCHW" and kernel_layout == "OIHW":
             oc_modified = False
             in_channel = data_tensor.shape[1].value
             out_channel = kernel_tensor.shape[0].value
@@ -273,11 +318,10 @@ def _conv2d_legalize(attrs, inputs, arg_types):
                 oc_modified = True
 
             if oc_modified:
-                new_attrs['channels'] = new_out_channel
+                new_attrs["channels"] = new_out_channel
                 out = tvm.relay.nn.conv2d(data, kernel, **new_attrs)
                 original_out_shape = [x.value for x in output_tensor.shape]
-                out = relay.strided_slice(out, begin=[0, 0, 0, 0],
-                                          end=original_out_shape)
+                out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape)
             else:
                 out = relay.nn.conv2d(data, kernel, **new_attrs)
             return out
index 8a26a82..2065ab9 100644 (file)
@@ -21,6 +21,7 @@ from tvm import te
 from tvm import autotvm
 from ..util import get_const_tuple
 
+
 def schedule_direct_cuda(cfg, s, conv):
     """schedule optimized for batch size = 1"""
 
@@ -36,7 +37,7 @@ def schedule_direct_cuda(cfg, s, conv):
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
 
     target = tvm.target.Target.current()
-    if target.kind.name in ['nvptx', 'rocm']:
+    if target.kind.name in ["nvptx", "rocm"]:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
@@ -44,27 +45,28 @@ def schedule_direct_cuda(cfg, s, conv):
     # fallback support
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.kind.name, target.model, 'conv2d_nchw.cuda')
+            target.kind.name, target.model, "conv2d_nchw.cuda"
+        )
         cfg.fallback_with_reference_log(ref_log)
     ##### space definition end #####
 
     pad_data, kernel = s[conv].op.input_tensors
 
     s[pad_data].compute_inline()
-    if isinstance(kernel.op, tvm.te.ComputeOp) and 'dilate' in kernel.op.tag:
+    if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
         s[kernel].compute_inline()
 
     if conv.op in s.outputs:
         output = conv
-        OL = s.cache_write(conv, 'local')
+        OL = s.cache_write(conv, "local")
     else:
         output = s.outputs[0].output(0)
-        s[conv].set_scope('local')
+        s[conv].set_scope("local")
         OL = conv
 
     # create cache stage
-    AA = s.cache_read(pad_data, 'shared', [OL])
-    WW = s.cache_read(kernel, 'shared', [OL])
+    AA = s.cache_read(pad_data, "shared", [OL])
+    WW = s.cache_read(kernel, "shared", [OL])
 
     # tile and bind spatial axes
     n, f, y, x = s[output].op.axis
@@ -90,9 +92,9 @@ def schedule_direct_cuda(cfg, s, conv):
     # tile reduction axes
     n, f, y, x = s[OL].op.axis
     rc, ry, rx = s[OL].op.reduce_axis
-    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
-    ryo, ryi = cfg['tile_ry'].apply(s, OL, ry)
-    rxo, rxi = cfg['tile_rx'].apply(s, OL, rx)
+    rco, rci = cfg["tile_rc"].apply(s, OL, rc)
+    ryo, ryi = cfg["tile_ry"].apply(s, OL, ry)
+    rxo, rxi = cfg["tile_rx"].apply(s, OL, rx)
     s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)
 
     s[AA].compute_at(s[OL], rxo)
@@ -110,8 +112,8 @@ def schedule_direct_cuda(cfg, s, conv):
         s[load].bind(tx, te.thread_axis("threadIdx.x"))
 
     # unroll
-    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
-    s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
 
     N, CO, OH, OW = get_const_tuple(output.shape)
     _, KH, KW, CI = get_const_tuple(kernel.shape)
index e45083f..46a618e 100644 (file)
@@ -24,8 +24,9 @@ from tvm.autotvm.task.space import SplitEntity
 
 from .. import nn, tag
 
+
 @autotvm.register_topi_compute("conv2d_hwcn.cuda")
-def conv2d_hwcn(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'):
+def conv2d_hwcn(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"):
     """Compute conv2d with HWCN layout on CUDA"""
     return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype)
 
@@ -47,6 +48,7 @@ def schedule_conv2d_hwcn(cfg, outs):
     """
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     sch = te.create_schedule([x.op for x in outs])
+
     def schedule(Apad, W, B):
         """Schedule conv2d_hwcn"""
         sch[Apad].compute_inline()
@@ -70,37 +72,37 @@ def schedule_conv2d_hwcn(cfg, outs):
         vthread_cand = [1, 2, 4, 8]
 
         cfg.define_split(
-            'tile_fi',
+            "tile_fi",
             fi,
             num_outputs=4,
-            filter=lambda x:
-            (x.size[1] in vthread_cand and x.size[2] in n_thread_cand))
+            filter=lambda x: (x.size[1] in vthread_cand and x.size[2] in n_thread_cand),
+        )
         cfg.define_split(
-            'tile_ni',
+            "tile_ni",
             ni,
             num_outputs=4,
-            filter=lambda x:
-            (x.size[1] in vthread_cand and x.size[2] in n_thread_cand))
+            filter=lambda x: (x.size[1] in vthread_cand and x.size[2] in n_thread_cand),
+        )
 
         if cfg.is_fallback:
-            cfg['tile_fi'] = SplitEntity([-1, 2, 8, 4])
-            cfg['tile_ni'] = SplitEntity([-1, 2, 8, 4])
+            cfg["tile_fi"] = SplitEntity([-1, 2, 8, 4])
+            cfg["tile_ni"] = SplitEntity([-1, 2, 8, 4])
 
         # Scheduling
         step = 8
 
         bz = sch[Out].fuse(hi, wi)
-        by, tyz, ty, fi = cfg['tile_fi'].apply(sch, Out, fi)
-        bx, txz, tx, ni = cfg['tile_ni'].apply(sch, Out, ni)
+        by, tyz, ty, fi = cfg["tile_fi"].apply(sch, Out, fi)
+        bx, txz, tx, ni = cfg["tile_ni"].apply(sch, Out, ni)
         sch[Out].reorder(bz, by, bx, tyz, txz, ty, tx, fi, ni)
 
-        sch[Out].bind(bz, te.thread_axis('blockIdx.z'))
-        sch[Out].bind(by, te.thread_axis('blockIdx.y'))
-        sch[Out].bind(bx, te.thread_axis('blockIdx.x'))
-        sch[Out].bind(tyz, te.thread_axis('vthread'))
-        sch[Out].bind(txz, te.thread_axis('vthread'))
-        sch[Out].bind(ty, te.thread_axis('threadIdx.y'))
-        sch[Out].bind(tx, te.thread_axis('threadIdx.x'))
+        sch[Out].bind(bz, te.thread_axis("blockIdx.z"))
+        sch[Out].bind(by, te.thread_axis("blockIdx.y"))
+        sch[Out].bind(bx, te.thread_axis("blockIdx.x"))
+        sch[Out].bind(tyz, te.thread_axis("vthread"))
+        sch[Out].bind(txz, te.thread_axis("vthread"))
+        sch[Out].bind(ty, te.thread_axis("threadIdx.y"))
+        sch[Out].bind(tx, te.thread_axis("threadIdx.x"))
 
         # Schedule BL local write
         sch[BL].compute_at(sch[Out], tx)
@@ -118,21 +120,21 @@ def schedule_conv2d_hwcn(cfg, outs):
         sch[WL].compute_at(sch[BL], rci)
         # Schedule for A's shared memory load
         yi, xi, ci, ni = sch[AA].op.axis
-        ty, ci = sch[AA].split(ci, nparts=cfg['tile_fi'].size[2])
-        tx, ni = sch[AA].split(ni, nparts=cfg['tile_ni'].size[2])
+        ty, ci = sch[AA].split(ci, nparts=cfg["tile_fi"].size[2])
+        tx, ni = sch[AA].split(ni, nparts=cfg["tile_ni"].size[2])
         _, ni = sch[AA].split(ni, factor=4)
         sch[AA].reorder(ty, tx, yi, xi, ci, ni)
-        sch[AA].bind(ty, te.thread_axis('threadIdx.y'))
-        sch[AA].bind(tx, te.thread_axis('threadIdx.x'))
+        sch[AA].bind(ty, te.thread_axis("threadIdx.y"))
+        sch[AA].bind(tx, te.thread_axis("threadIdx.x"))
         sch[AA].vectorize(ni)
         # Schedule for W's shared memory load
         yi, xi, ci, fi = sch[WW].op.axis
-        ty, ci = sch[WW].split(ci, nparts=cfg['tile_fi'].size[2])
-        tx, fi = sch[WW].split(fi, nparts=cfg['tile_ni'].size[2])
+        ty, ci = sch[WW].split(ci, nparts=cfg["tile_fi"].size[2])
+        tx, fi = sch[WW].split(fi, nparts=cfg["tile_ni"].size[2])
         _, fi = sch[WW].split(fi, factor=4)
         sch[WW].reorder(ty, tx, yi, xi, ci, fi)
-        sch[WW].bind(ty, te.thread_axis('threadIdx.y'))
-        sch[WW].bind(tx, te.thread_axis('threadIdx.x'))
+        sch[WW].bind(ty, te.thread_axis("threadIdx.y"))
+        sch[WW].bind(tx, te.thread_axis("threadIdx.x"))
         sch[WW].vectorize(fi)
 
     scheduled_ops = []
@@ -145,10 +147,10 @@ def schedule_conv2d_hwcn(cfg, outs):
             for tensor in operator.input_tensors:
                 if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
-        elif operator.tag == 'conv2d_hwcn':
+        elif operator.tag == "conv2d_hwcn":
             Apad = operator.input_tensors[0]
             W = operator.input_tensors[1]
-            if isinstance(W.op, tvm.te.ComputeOp) and 'dilate' in W.op.tag:
+            if isinstance(W.op, tvm.te.ComputeOp) and "dilate" in W.op.tag:
                 sch[W].compute_inline()
             B = operator.output(0)
             schedule(Apad, W, B)
index b7dad79..db5a6c9 100644 (file)
@@ -53,28 +53,27 @@ def unpack_HWNCnc_to_hwnc(packed_out, out_dtype):
     idxdiv = tvm.tir.indexdiv
 
     oshape = (H, W, N * wmma_m, O * wmma_n)
-    unpacked_out = \
-        te.compute(oshape,
-                   lambda h, w, n, o:
-                   packed_out[h, w, idxdiv(n, wmma_m), idxdiv(o, wmma_n),
-                              idxmod(n, wmma_m), idxmod(o, wmma_n)]
-                   .astype(out_dtype),
-                   name='output_unpack',
-                   tag=tag.INJECTIVE + ",unpack_hwncc")
+    unpacked_out = te.compute(
+        oshape,
+        lambda h, w, n, o: packed_out[
+            h, w, idxdiv(n, wmma_m), idxdiv(o, wmma_n), idxmod(n, wmma_m), idxmod(o, wmma_n)
+        ].astype(out_dtype),
+        name="output_unpack",
+        tag=tag.INJECTIVE + ",unpack_hwncc",
+    )
     return unpacked_out
 
 
-def conv2d_hwnc_tensorcore(data, kernel, strides, padding, dilation, in_dtype, out_dtype='int32'):
+def conv2d_hwnc_tensorcore(data, kernel, strides, padding, dilation, in_dtype, out_dtype="int32"):
     """"Compute conv2d with tensorcore for HWNC layout with int8/int4"""
-    assert data.dtype in ('int4', 'uint4', 'int8', 'uint8')
-    assert kernel.dtype in ('int4', 'uint4', 'int8', 'uint8')
-    packed_out = hwnc_tensorcore_cuda(
-        data, kernel, strides, padding, dilation, out_dtype)
+    assert data.dtype in ("int4", "uint4", "int8", "uint8")
+    assert kernel.dtype in ("int4", "uint4", "int8", "uint8")
+    packed_out = hwnc_tensorcore_cuda(data, kernel, strides, padding, dilation, out_dtype)
     return unpack_HWNCnc_to_hwnc(packed_out, out_dtype)
 
 
 @autotvm.register_topi_compute("conv2d_HWNCnc_tensorcore.cuda")
-def hwnc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtype='int32'):
+def hwnc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtype="int32"):
     """Compute declaration for tensorcore"""
     assert isinstance(stride, int) or len(stride) == 2
     assert isinstance(dilation, int) or len(dilation) == 2
@@ -91,7 +90,7 @@ def hwnc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtyp
 
     in_dtype = Input.dtype
 
-    if in_dtype in ['int4', 'uint4']:
+    if in_dtype in ["int4", "uint4"]:
         wmma_n = wmma_m = 8
         wmma_k = 32
     else:
@@ -102,85 +101,84 @@ def hwnc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtyp
     pre_computed = len(Filter.shape) == 6
     in_height, in_width, batch, in_channels = get_const_tuple(Input.shape)
     if pre_computed:
-        kernel_h, kernel_w, oc_chunk, _, oc_block_factor, _\
-            = get_const_tuple(Filter.shape)
+        kernel_h, kernel_w, oc_chunk, _, oc_block_factor, _ = get_const_tuple(Filter.shape)
         num_filter = oc_block_factor * oc_chunk
     else:
         kernel_h, kernel_w, num_filter, _ = get_const_tuple(Filter.shape)
 
-    if in_dtype in ['int4', 'uint4']:
-        assert (batch % 8 == 0 and in_channels %
-                32 == 0 and num_filter % 8 == 0)
+    if in_dtype in ["int4", "uint4"]:
+        assert batch % 8 == 0 and in_channels % 32 == 0 and num_filter % 8 == 0
     else:
-        assert (batch % 8 == 0 and in_channels % 16 == 0 and num_filter % 32 == 0), \
-            "The shape of (batch, in_channels, num_filter) "\
-            "must be multiple of (8, 16, 32) for int8, "\
+        assert batch % 8 == 0 and in_channels % 16 == 0 and num_filter % 32 == 0, (
+            "The shape of (batch, in_channels, num_filter) "
+            "must be multiple of (8, 16, 32) for int8, "
             "and (8, 32, 8) for int4"
+        )
 
     # compute the output shape
     dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
     dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
 
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
 
     out_channels = num_filter
-    out_height = simplify(
-        (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
-    out_width = simplify((in_width - dilated_kernel_w +
-                          pad_left + pad_right) // stride_w + 1)
+    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
+    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
 
-    cfg.add_flop(2 * batch * out_height * out_width *
-                 out_channels * in_channels * kernel_h * kernel_w)
+    cfg.add_flop(
+        2 * batch * out_height * out_width * out_channels * in_channels * kernel_h * kernel_w
+    )
 
     # Input feature map: (H, W, N, IC, n, ic)
-    data_shape = (in_height,
-                  in_width,
-                  batch // wmma_m,
-                  in_channels // wmma_k,
-                  wmma_m,
-                  wmma_k)
+    data_shape = (in_height, in_width, batch // wmma_m, in_channels // wmma_k, wmma_m, wmma_k)
 
     # Kernel: (H, W, OC, IC, oc, ic)
-    kernel_shape = (kernel_h,
-                    kernel_w,
-                    out_channels // wmma_n,
-                    in_channels // wmma_k,
-                    wmma_n,
-                    wmma_k)
+    kernel_shape = (
+        kernel_h,
+        kernel_w,
+        out_channels // wmma_n,
+        in_channels // wmma_k,
+        wmma_n,
+        wmma_k,
+    )
 
     # Reduction axes
-    kh = te.reduce_axis((0, kernel_h), name='kh')
-    kw = te.reduce_axis((0, kernel_w), name='kw')
-    ic = te.reduce_axis((0, in_channels // wmma_k), name='ic')
-    ii = te.reduce_axis((0, wmma_k), name='ii')
+    kh = te.reduce_axis((0, kernel_h), name="kh")
+    kw = te.reduce_axis((0, kernel_w), name="kw")
+    ic = te.reduce_axis((0, in_channels // wmma_k), name="ic")
+    ii = te.reduce_axis((0, wmma_k), name="ii")
 
     if pre_computed:
         packed_kernel = Filter
     else:
-        packed_kernel = te.compute(kernel_shape, lambda kh, kw, o, i, oo, ii:
-                                   Filter[kh, kw, o * wmma_n +
-                                          oo, i * wmma_k + ii],
-                                   name="packed_kernel"
-                                   )
+        packed_kernel = te.compute(
+            kernel_shape,
+            lambda kh, kw, o, i, oo, ii: Filter[kh, kw, o * wmma_n + oo, i * wmma_k + ii],
+            name="packed_kernel",
+        )
 
-    packed_data = te.compute(data_shape,
-                             lambda h, w, n, i, nn, ii: Input[h,
-                                                              w, n * wmma_m + nn, i * wmma_k + ii]
-                             )
+    packed_data = te.compute(
+        data_shape, lambda h, w, n, i, nn, ii: Input[h, w, n * wmma_m + nn, i * wmma_k + ii]
+    )
 
     pad_before = [pad_top, pad_left, 0, 0, 0, 0]
     pad_after = [pad_down, pad_right, 0, 0, 0, 0]
     pad_data = pad(packed_data, pad_before, pad_after, name="pad_data")
 
-    Conv = te.compute((out_height, out_width, batch // wmma_m,
-                       out_channels // wmma_n, wmma_m, wmma_n),
-                      lambda h, w, n, o, nn, oo: te.sum(
-                          (pad_data[h * stride_h + kh, w * stride_w + kw,
-                                    n, ic, nn, ii].astype('int32') *
-                           packed_kernel[kh, kw, o, ic, oo, ii].astype('int32')),
-                          axis=[ic, kh, kw, ii]),
-                      name="Conv", tag="conv2d_HWNCnc_tensorcore")
+    Conv = te.compute(
+        (out_height, out_width, batch // wmma_m, out_channels // wmma_n, wmma_m, wmma_n),
+        lambda h, w, n, o, nn, oo: te.sum(
+            (
+                pad_data[h * stride_h + kh, w * stride_w + kw, n, ic, nn, ii].astype("int32")
+                * packed_kernel[kh, kw, o, ic, oo, ii].astype("int32")
+            ),
+            axis=[ic, kh, kw, ii],
+        ),
+        name="Conv",
+        tag="conv2d_HWNCnc_tensorcore",
+    )
     return Conv
 
 
@@ -190,37 +188,36 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
     ic, kh, kw, ii = s[Conv].op.reduce_axis
     pad_data = s[packed_data].op.input_tensors[0]
 
-    block_x = te.thread_axis('blockIdx.x')
-    block_y = te.thread_axis('blockIdx.y')
-    block_z = te.thread_axis('blockIdx.z')
-    thread_x = te.thread_axis('threadIdx.x')
-    thread_y = te.thread_axis('threadIdx.y')
-    thread_z = te.thread_axis('threadIdx.z')
+    block_x = te.thread_axis("blockIdx.x")
+    block_y = te.thread_axis("blockIdx.y")
+    block_z = te.thread_axis("blockIdx.z")
+    thread_x = te.thread_axis("threadIdx.x")
+    thread_y = te.thread_axis("threadIdx.y")
+    thread_z = te.thread_axis("threadIdx.z")
 
     # Designate the memory hierarchy
-    AS = s.cache_read(packed_data, 'shared', [Conv])
-    WS = s.cache_read(packed_kernel, 'shared', [Conv])
-    AF = s.cache_read(AS, 'wmma.matrix_a', [Conv])
-    WF = s.cache_read(WS, 'wmma.matrix_b', [Conv])
-    ConvF = s.cache_write(Conv, 'wmma.accumulator')
+    AS = s.cache_read(packed_data, "shared", [Conv])
+    WS = s.cache_read(packed_kernel, "shared", [Conv])
+    AF = s.cache_read(AS, "wmma.matrix_a", [Conv])
+    WF = s.cache_read(WS, "wmma.matrix_b", [Conv])
+    ConvF = s.cache_write(Conv, "wmma.accumulator")
 
     if Conv.op in s.outputs:
         output = Conv
-        ConvS = s.cache_read(ConvF, 'shared', [Conv])
+        ConvS = s.cache_read(ConvF, "shared", [Conv])
         OL = ConvS
     else:
         output = s.outputs[0].output(0)
-        s[Conv].set_scope('shared')
+        s[Conv].set_scope("shared")
         OL = Conv
 
     out_dtype = Conv.dtype
 
     if isinstance(packed_kernel.op, te.tensor.ComputeOp) and packed_kernel.name == "packed_kernel":
         if autotvm.GLOBAL_SCOPE.in_tuning:
-            s[packed_kernel].pragma(
-                s[packed_kernel].op.axis[0], "debug_skip_region")
+            s[packed_kernel].pragma(s[packed_kernel].op.axis[0], "debug_skip_region")
         else:
-            with Target('cuda'):
+            with Target("cuda"):
                 schedule_injective_from_existing(s, packed_kernel)
 
     if isinstance(pad_data.op, te.tensor.ComputeOp) and "pad" in pad_data.op.tag:
@@ -262,10 +259,10 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
     if not fuse_pack:
         s[packed_data].compute_inline()
     else:
-        with Target('cuda'):
+        with Target("cuda"):
             schedule_injective_from_existing(s, packed_data)
 
-    if data_dtype in ['int4', 'uint4']:
+    if data_dtype in ["int4", "uint4"]:
         wmma_m = wmma_n = 8
         wmma_k = 32
     else:
@@ -277,7 +274,12 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
 
     # Schedule for output
     if len(s[output].op.axis) == 4:
-        hc, wc, nc, oc, = output.op.axis
+        (
+            hc,
+            wc,
+            nc,
+            oc,
+        ) = output.op.axis
         nc, nnc = s[output].split(nc, factor=wmma_m)
         oc, ooc = s[output].split(oc, factor=wmma_n)
     else:
@@ -286,14 +288,12 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
     kernel_scope, hc = s[output].split(hc, nparts=1)
 
     block_k = s[output].fuse(hc, wc)
-    block_k, split_block_k = s[output].split(
-        block_k, factor=split_block_k_nums)
+    block_k, split_block_k = s[output].split(block_k, factor=split_block_k_nums)
     nc, nci = s[output].split(nc, factor=warp_row_tiles)
     block_i, nc = s[output].split(nc, factor=block_row_warps)
     oc, oci = s[output].split(oc, factor=warp_col_tiles)
     block_j, oc = s[output].split(oc, factor=block_col_warps)
-    s[output].reorder(block_k, split_block_k, block_i,
-                      block_j, nc, oc, nci, oci, nnc, ooc)
+    s[output].reorder(block_k, split_block_k, block_i, block_j, nc, oc, nci, oci, nnc, ooc)
     t = s[output].fuse(nnc, ooc)
     _, tx = s[output].split(t, factor=warp_size)
     s[output].bind(block_k, block_z)
@@ -375,18 +375,17 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
     s[WS].vectorize(ti)
 
     # double buffer
-    cfg.define_knob('AS_double_buffer', [0, 1])
-    cfg.define_knob('WS_double_buffer', [0, 1])
-    if cfg['AS_double_buffer'].val:
+    cfg.define_knob("AS_double_buffer", [0, 1])
+    cfg.define_knob("WS_double_buffer", [0, 1])
+    if cfg["AS_double_buffer"].val:
         s[AS].double_buffer()
-    if cfg['WS_double_buffer'].val:
+    if cfg["WS_double_buffer"].val:
         s[WS].double_buffer()
 
     # unroll
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
-    s[output].pragma(kernel_scope, 'auto_unroll_max_step',
-                     cfg['auto_unroll_max_step'].val)
-    s[output].pragma(kernel_scope, 'unroll_explicit', False)
+    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[output].pragma(kernel_scope, "unroll_explicit", False)
 
     shape = (wmma_m, wmma_n, wmma_k)
 
@@ -397,13 +396,16 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
     CL_shape = (wmma_m, wmma_n)
     CS_shape = (wmma_m, wmma_n)
 
-    AL_gemm = te.placeholder(AL_shape, name='A', dtype=data_dtype)
-    WL_gemm = te.placeholder(WL_shape, name='B', dtype=kernel_dtype)
+    AL_gemm = te.placeholder(AL_shape, name="A", dtype=data_dtype)
+    WL_gemm = te.placeholder(WL_shape, name="B", dtype=kernel_dtype)
     k_gemm = te.reduce_axis((0, wmma_k), name="k")
-    CL_compute = te.compute(CL_shape, lambda ii, jj:
-                            te.sum((AL_gemm[ii, k_gemm].astype(
-                                'int32') * WL_gemm[jj, k_gemm].astype('int32')), axis=k_gemm),
-                            name='C')
+    CL_compute = te.compute(
+        CL_shape,
+        lambda ii, jj: te.sum(
+            (AL_gemm[ii, k_gemm].astype("int32") * WL_gemm[jj, k_gemm].astype("int32")), axis=k_gemm
+        ),
+        name="C",
+    )
 
     AL_strides = [wmma_k, 1]
     AS_strides = [wmma_k, 1]
@@ -412,19 +414,28 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
     CL_strides = [wmma_n, 1]
     CS_strides = [wmma_n, 1]
 
-    s[AF].tensorize(AF.op.axis[-2],
-                    intrin_wmma_load_matrix_A(AL_strides, AS_strides, shape,
-                                              "row_major", AS_shape, AL_shape, data_dtype))
-
-    s[WF].tensorize(WF.op.axis[-2],
-                    intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape,
-                                              "col_major", WS_shape, WL_shape, kernel_dtype))
-
-    s[OL].tensorize(nnc, intrin_wmma_store_matrix(CS_strides, CL_strides,
-                                                  shape, out_dtype, CL_shape, CS_shape))
-
-    s[ConvF].tensorize(nnf, intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides,
-                                             WL_strides, CL_strides, shape))
+    s[AF].tensorize(
+        AF.op.axis[-2],
+        intrin_wmma_load_matrix_A(
+            AL_strides, AS_strides, shape, "row_major", AS_shape, AL_shape, data_dtype
+        ),
+    )
+
+    s[WF].tensorize(
+        WF.op.axis[-2],
+        intrin_wmma_load_matrix_W(
+            WL_strides, WS_strides, shape, "col_major", WS_shape, WL_shape, kernel_dtype
+        ),
+    )
+
+    s[OL].tensorize(
+        nnc, intrin_wmma_store_matrix(CS_strides, CL_strides, shape, out_dtype, CL_shape, CS_shape)
+    )
+
+    s[ConvF].tensorize(
+        nnf,
+        intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides, WL_strides, CL_strides, shape),
+    )
 
     return s
 
@@ -435,7 +446,7 @@ def schedule_conv2d_hwnc_tensorcore(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'conv2d_HWNCnc_tensorcore' in op.tag:
+        if "conv2d_HWNCnc_tensorcore" in op.tag:
             schedule_hwnc_tensorcore_cuda(cfg, s, op.output(0))
 
     traverse_inline(s, outs[0].op, _callback)
index bc8aa35..deeec50 100644 (file)
@@ -29,18 +29,20 @@ from ..nn.util import get_pad_tuple
 from ..util import get_const_tuple, traverse_inline
 
 
-def conv2d_nchw_int8(data, kernel, strides, padding, dilation, out_dtype='int32'):
+def conv2d_nchw_int8(data, kernel, strides, padding, dilation, out_dtype="int32"):
     """Compute conv2d internally using conv2d_nchwc layout for int8 dtype"""
-    assert data.dtype in ('int8', 'uint8')
-    assert kernel.dtype in ('int8', 'uint8')
+    assert data.dtype in ("int8", "uint8")
+    assert kernel.dtype in ("int8", "uint8")
     assert data.dtype == kernel.dtype
     packed_out = conv2d_NCHWc_int8(data, kernel, strides, padding, dilation, "NCHW", out_dtype)
     return unpack_NCHWc_to_nchw(packed_out, out_dtype)
 
+
 def schedule_conv2d_nchw_int8(outs):
     """Create schedule for tensors"""
     return schedule_conv2d_NCHWc_int8(outs)
 
+
 @autotvm.register_topi_compute("conv2d_NCHWc_int8.cuda")
 def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_dtype):
     """Convolution operator in NCHW[x]c layout for int8.
@@ -86,35 +88,42 @@ def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_
     pre_computed = len(kernel.shape) == 6
     if not pre_computed:
         batch, channels, height, width = get_const_tuple(data.shape)
-        assert channels % ic_block_factor == 0, \
-            "Number of input channels should be multiple of {}".format(
-                ic_block_factor)
-        packed_data = te.compute((batch, channels // ic_block_factor, height, width,
-                                  ic_block_factor),
-                                 lambda n, c, h, w, vc: data[n, c*ic_block_factor + vc, h, w],
-                                 name="packed_data")
-
-        out_channels, in_channels, kernel_h, kernel_w = get_const_tuple(
-            kernel.shape)
-        assert out_channels % 4 == 0, \
-            "Number of output channels should be multiple of {}".format(
-                oc_block_factor)
+        assert (
+            channels % ic_block_factor == 0
+        ), "Number of input channels should be multiple of {}".format(ic_block_factor)
+        packed_data = te.compute(
+            (batch, channels // ic_block_factor, height, width, ic_block_factor),
+            lambda n, c, h, w, vc: data[n, c * ic_block_factor + vc, h, w],
+            name="packed_data",
+        )
+
+        out_channels, in_channels, kernel_h, kernel_w = get_const_tuple(kernel.shape)
+        assert out_channels % 4 == 0, "Number of output channels should be multiple of {}".format(
+            oc_block_factor
+        )
         packed_kernel = te.compute(
-            (out_channels // oc_block_factor, in_channels // ic_block_factor, kernel_h, kernel_w,
-             oc_block_factor, ic_block_factor),
-            lambda oc_chunk, ic_chunk, kh, kw, oc_block, ic_block:
-            kernel[oc_chunk * oc_block_factor + oc_block,
-                   ic_chunk * ic_block_factor + ic_block, kh, kw],
-            name="packed_kernel")
+            (
+                out_channels // oc_block_factor,
+                in_channels // ic_block_factor,
+                kernel_h,
+                kernel_w,
+                oc_block_factor,
+                ic_block_factor,
+            ),
+            lambda oc_chunk, ic_chunk, kh, kw, oc_block, ic_block: kernel[
+                oc_chunk * oc_block_factor + oc_block, ic_chunk * ic_block_factor + ic_block, kh, kw
+            ],
+            name="packed_kernel",
+        )
 
     else:
         packed_data = data
         packed_kernel = kernel
 
-    batch, ic_chunk, in_height, in_width, ic_block = get_const_tuple(
-        packed_data.shape)
+    batch, ic_chunk, in_height, in_width, ic_block = get_const_tuple(packed_data.shape)
     oc_chunk, ic_chunk, kernel_h, kernel_w, oc_block, ic_block = get_const_tuple(
-        packed_kernel.shape)
+        packed_kernel.shape
+    )
 
     if isinstance(stride, int):
         stride_h = stride_w = stride
@@ -126,8 +135,7 @@ def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_
     else:
         dilation_h, dilation_w = dilation
 
-    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (kernel_h, kernel_w))
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w))
     # compute graph
     pad_before = [0, 0, pad_top, pad_left, 0]
     pad_after = [0, 0, pad_down, pad_right, 0]
@@ -139,33 +147,47 @@ def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_
 
     oshape = (batch, oc_chunk, out_height, out_width, oc_block)
 
-    icc = te.reduce_axis((0, ic_chunk), name='ic_chunk')
-    icb = te.reduce_axis((0, ic_block), name='ic_block')
-    kh = te.reduce_axis((0, kernel_h), name='kh')
-    kw = te.reduce_axis((0, kernel_w), name='kw')
-
-    conv = te.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
-                      te.sum(pad_data[n, icc, oh*stride_h+kh*dilation_h, \
-                                      ow*stride_w+kw*dilation_w, icb]
-                             .astype('int32') *
-                             packed_kernel[oc_chunk, icc,
-                                           kh, kw, oc_block, icb]
-                             .astype('int32'),
-                             axis=[icc, kh, kw, icb]))
-
-    output = te.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
-                        conv[n, oc_chunk, oh, ow, oc_block].astype(out_dtype),
-                        tag="conv2d_NCHWc_int8")
+    icc = te.reduce_axis((0, ic_chunk), name="ic_chunk")
+    icb = te.reduce_axis((0, ic_block), name="ic_block")
+    kh = te.reduce_axis((0, kernel_h), name="kh")
+    kw = te.reduce_axis((0, kernel_w), name="kw")
+
+    conv = te.compute(
+        oshape,
+        lambda n, oc_chunk, oh, ow, oc_block: te.sum(
+            pad_data[
+                n, icc, oh * stride_h + kh * dilation_h, ow * stride_w + kw * dilation_w, icb
+            ].astype("int32")
+            * packed_kernel[oc_chunk, icc, kh, kw, oc_block, icb].astype("int32"),
+            axis=[icc, kh, kw, icb],
+        ),
+    )
+
+    output = te.compute(
+        oshape,
+        lambda n, oc_chunk, oh, ow, oc_block: conv[n, oc_chunk, oh, ow, oc_block].astype(out_dtype),
+        tag="conv2d_NCHWc_int8",
+    )
 
     # num flop
-    num_flop = batch * oc_chunk * oc_block * out_height * out_width * \
-        ic_chunk * ic_block * kernel_h * kernel_w * 2
+    num_flop = (
+        batch
+        * oc_chunk
+        * oc_block
+        * out_height
+        * out_width
+        * ic_chunk
+        * ic_block
+        * kernel_h
+        * kernel_w
+        * 2
+    )
     cfg.add_flop(num_flop)
 
     return output
 
 
-_dp4a = dp4a('shared', 'shared', 'local')
+_dp4a = dp4a("shared", "shared", "local")
 
 
 @autotvm.register_topi_schedule("conv2d_NCHWc_int8.cuda")
@@ -175,7 +197,7 @@ def schedule_conv2d_NCHWc_int8(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == 'conv2d_NCHWc_int8':
+        if op.tag == "conv2d_NCHWc_int8":
             _schedule_conv2d_NCHWc_int8(cfg, s, op.output(0))
 
     traverse_inline(s, outs[0].op, _callback)
@@ -198,8 +220,7 @@ def _schedule_conv2d_NCHWc_int8(cfg, s, output):
         s[packed_data].pragma(s[packed_data].op.axis[0], "debug_skip_region")
         s[packed_kernel].pragma(s[packed_kernel].op.axis[0], "debug_skip_region")
     else:
-        if isinstance(packed_kernel.op, tvm.te.ComputeOp) and\
-                packed_kernel.name == 'packed_kernel':
+        if isinstance(packed_kernel.op, tvm.te.ComputeOp) and packed_kernel.name == "packed_kernel":
             # data and kernel are not pre-computed, schedule layout transform here
             schedule_injective_from_existing(s, packed_data)
             schedule_injective_from_existing(s, packed_kernel)
@@ -208,10 +229,10 @@ def _schedule_conv2d_NCHWc_int8(cfg, s, output):
         s[pad_data].compute_inline()
 
     # create cache stage
-    AA = s.cache_read(pad_data, 'shared', [conv])
-    WW = s.cache_read(packed_kernel, 'shared', [conv])
+    AA = s.cache_read(pad_data, "shared", [conv])
+    WW = s.cache_read(packed_kernel, "shared", [conv])
 
-    s[conv].set_scope('local')
+    s[conv].set_scope("local")
 
     # handle bias
     if output.op not in s.outputs:
@@ -248,7 +269,7 @@ def _schedule_conv2d_NCHWc_int8(cfg, s, output):
     s[output].bind(vy, te.thread_axis("vthread"))
     s[output].bind(vx, te.thread_axis("vthread"))
 
-    cfg.define_knob("fuse_yx", [0, 1]) # fuse ty,tx or tn,tf
+    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
     if cfg["fuse_yx"].val:
         s[output].bind(tn, te.thread_axis("threadIdx.z"))
         s[output].bind(tf, te.thread_axis("threadIdx.y"))
@@ -278,9 +299,9 @@ def _schedule_conv2d_NCHWc_int8(cfg, s, output):
     cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=2)
     cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=2)
     cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=2)
-    rco, rci = cfg['tile_rc'].apply(s, conv, rc)
-    ryo, ryi = cfg['tile_ry'].apply(s, conv, ry)
-    rxo, rxi = cfg['tile_rx'].apply(s, conv, rx)
+    rco, rci = cfg["tile_rc"].apply(s, conv, rc)
+    ryo, ryi = cfg["tile_ry"].apply(s, conv, ry)
+    rxo, rxi = cfg["tile_rx"].apply(s, conv, rx)
 
     s[conv].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x, c, rc_block)
 
@@ -311,17 +332,16 @@ def _schedule_conv2d_NCHWc_int8(cfg, s, output):
         s[load].bind(tx, te.thread_axis("threadIdx.x"))
 
     # double buffer
-    cfg.define_knob('AA_double_buffer', [0, 1])
-    cfg.define_knob('WW_double_buffer', [0, 1])
-    if cfg['AA_double_buffer'].val:
+    cfg.define_knob("AA_double_buffer", [0, 1])
+    cfg.define_knob("WW_double_buffer", [0, 1])
+    if cfg["AA_double_buffer"].val:
         s[AA].double_buffer()
-    if cfg['WW_double_buffer'].val:
+    if cfg["WW_double_buffer"].val:
         s[WW].double_buffer()
 
     # unroll
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
-    s[output].pragma(kernel_scope, 'auto_unroll_max_step',
-                     cfg['auto_unroll_max_step'].val)
-    s[output].pragma(kernel_scope, 'unroll_explicit', False)
+    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[output].pragma(kernel_scope, "unroll_explicit", False)
 
     return s
index 23607b1..b256345 100644 (file)
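To confirm that a tree already matches the formatting recorded in these hunks, black can be run in check mode. The sketch below is illustrative only, assuming black's standard --check and --line-length command-line flags; the path is a placeholder, since the file names for these hunks are not shown in this view.

import subprocess

# Exit code 0 means the file is already formatted; 1 means black would rewrite it.
result = subprocess.run(
    ["black", "--check", "--line-length", "100", "path/to/file.py"],
    capture_output=True,
    text=True,
)
print(result.returncode, result.stderr.strip())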
@@ -27,18 +27,18 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
     pad_data, kernel = s[Conv].op.input_tensors
     s[pad_data].compute_inline()
 
-    if isinstance(kernel.op, tvm.te.ComputeOp) and 'dilate' in kernel.op.tag:
+    if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
         s[kernel].compute_inline()
 
     if Conv.op in s.outputs:
         output = Conv
-        OL = s.cache_write(Conv, 'local')
+        OL = s.cache_write(Conv, "local")
     else:
         output = s.outputs[0].output(0)
-        s[Conv].set_scope('local')
+        s[Conv].set_scope("local")
         OL = Conv
     # create cache stage
-    AA = s.cache_read(pad_data, 'shared', [OL])
+    AA = s.cache_read(pad_data, "shared", [OL])
     WW = s.cache_read(kernel, "shared", [OL])
     AL = s.cache_read(AA, "local", [OL])
     WL = s.cache_read(WW, "local", [OL])
@@ -56,7 +56,8 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
     target = tvm.target.Target.current()
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.kind.name, target.model, 'conv2d_nhwc.cuda')
+            target.kind.name, target.model, "conv2d_nhwc.cuda"
+        )
         cfg.fallback_with_reference_log(ref_log)
 
     tile_n = cfg["tile_n"].val
index a82508b..a33092d 100644 (file)
@@ -47,39 +47,47 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtyp
 
     batch, in_height, in_width, in_channel = get_const_tuple(Input.shape)
     kernel_h, kernel_w, _, num_filter = get_const_tuple(Filter.shape)
-    assert (batch % 16 == 0 and in_channel % 16 == 0 and num_filter % 16 == 0) or \
-               (batch % 8 == 0 and in_channel % 16 == 0 and num_filter % 32 == 0) or \
-               (batch % 32 == 0 and in_channel % 16 == 0 and num_filter % 8 == 0), \
-               "The shape of (batch, in_channel, num_filter) "\
-               "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
+    assert (
+        (batch % 16 == 0 and in_channel % 16 == 0 and num_filter % 16 == 0)
+        or (batch % 8 == 0 and in_channel % 16 == 0 and num_filter % 32 == 0)
+        or (batch % 32 == 0 and in_channel % 16 == 0 and num_filter % 8 == 0)
+    ), (
+        "The shape of (batch, in_channel, num_filter) "
+        "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
+    )
 
     # compute the output shape
     dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
     dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
     out_channel = num_filter
     out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
     out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
     pad_before = [0, pad_top, pad_left, 0]
     pad_after = [0, pad_down, pad_right, 0]
     PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
-    rc = te.reduce_axis((0, in_channel), name='rc')
-    ry = te.reduce_axis((0, kernel_h), name='ry')
-    rx = te.reduce_axis((0, kernel_w), name='rx')
+    rc = te.reduce_axis((0, in_channel), name="rc")
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
     # convert data type of input feature maps and weights
     TransPaddedInput = te.compute(
-        PaddedInput.shape,
-        lambda n, h, w, c: PaddedInput[n, h, w, c].astype('float16'))
-    TransFilter = te.compute(
-        Filter.shape, lambda h, w, i, o: Filter[h, w, i, o].astype('float16'))
+        PaddedInput.shape, lambda n, h, w, c: PaddedInput[n, h, w, c].astype("float16")
+    )
+    TransFilter = te.compute(Filter.shape, lambda h, w, i, o: Filter[h, w, i, o].astype("float16"))
     Output = te.compute(
         (batch, out_height, out_width, out_channel),
         lambda nn, yy, xx, ff: te.sum(
-            TransPaddedInput[nn, yy * stride_h + ry * dilation_h,
-                             xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
-            TransFilter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]),
-        name="Conv2dOutput", tag="conv2d_nhwc_tensorcore")
+            TransPaddedInput[
+                nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc
+            ].astype(out_dtype)
+            * TransFilter[ry, rx, rc, ff].astype(out_dtype),
+            axis=[ry, rx, rc],
+        ),
+        name="Conv2dOutput",
+        tag="conv2d_nhwc_tensorcore",
+    )
     return Output
 
 
@@ -99,19 +107,19 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv):
     s[paddata[0]].compute_inline()
 
     # Designate the memory hierarchy
-    AS = s.cache_read(trans_paddata, 'shared', [Conv])
-    WS = s.cache_read(kernel, 'shared', [Conv])
-    AF = s.cache_read(AS, 'wmma.matrix_a', [Conv])
-    WF = s.cache_read(WS, 'wmma.matrix_b', [Conv])
-    ConvF = s.cache_write(Conv, 'wmma.accumulator')
+    AS = s.cache_read(trans_paddata, "shared", [Conv])
+    WS = s.cache_read(kernel, "shared", [Conv])
+    AF = s.cache_read(AS, "wmma.matrix_a", [Conv])
+    WF = s.cache_read(WS, "wmma.matrix_b", [Conv])
+    ConvF = s.cache_write(Conv, "wmma.accumulator")
 
     if Conv.op in s.outputs:
         output = Conv
-        ConvS = s.cache_read(ConvF, 'shared', [Conv])
+        ConvS = s.cache_read(ConvF, "shared", [Conv])
         OL = ConvS
     else:
         output = s.outputs[0].output(0)
-        s[Conv].set_scope('shared')
+        s[Conv].set_scope("shared")
         OL = Conv
 
     # Schedule for autotvm
@@ -123,18 +131,19 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv):
     cfg.define_knob("offset", [0, 8])
     cfg.define_knob("vector_width", [1, 2, 4, 8])
 
-    if (batch % 16 == 0 and out_channels % 16 == 0):
+    if batch % 16 == 0 and out_channels % 16 == 0:
         cfg.define_knob("wmma_m", [16, 8, 32])
-    elif (batch % 8 == 0 and out_channels % 32 == 0):
+    elif batch % 8 == 0 and out_channels % 32 == 0:
         cfg.define_knob("wmma_m", [8, 16, 32])
-    elif (batch % 32 == 0 and out_channels % 8 == 0):
+    elif batch % 32 == 0 and out_channels % 8 == 0:
         cfg.define_knob("wmma_m", [32, 16, 8])
 
     # fallback support
     target = tvm.target.Target.current()
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.kind.name, target.model, 'conv2d_nhwc_tensorcore.cuda')
+            target.kind.name, target.model, "conv2d_nhwc_tensorcore.cuda"
+        )
         cfg.fallback_with_reference_log(ref_log)
 
     block_row_warps = cfg["block_row_warps"].val
@@ -156,12 +165,12 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv):
 
     warp_size = 32
 
-    block_x = te.thread_axis('blockIdx.x')
-    block_y = te.thread_axis('blockIdx.y')
-    block_z = te.thread_axis('blockIdx.z')
-    thread_x = te.thread_axis('threadIdx.x')
-    thread_y = te.thread_axis('threadIdx.y')
-    thread_z = te.thread_axis('threadIdx.z')
+    block_x = te.thread_axis("blockIdx.x")
+    block_y = te.thread_axis("blockIdx.y")
+    block_z = te.thread_axis("blockIdx.z")
+    thread_x = te.thread_axis("threadIdx.x")
+    thread_y = te.thread_axis("threadIdx.y")
+    thread_z = te.thread_axis("threadIdx.z")
 
     # Define the intrin strides
     def get_strides(extents):
@@ -277,22 +286,37 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv):
     CL_shape = (wmma_m, 1, 1, wmma_n)
     CS_shape = (wmma_m, 1, 1, wmma_n)
 
-    AL_gemm = te.placeholder(AL_shape, name='A', dtype=in_dtype)
-    WL_gemm = te.placeholder(WL_shape, name='B', dtype=in_dtype)
+    AL_gemm = te.placeholder(AL_shape, name="A", dtype=in_dtype)
+    WL_gemm = te.placeholder(WL_shape, name="B", dtype=in_dtype)
     k_gemm = te.reduce_axis((0, wmma_k), name="k")
-    CL_compute = te.compute(CL_shape, lambda ii, t0, t1, jj:
-                            te.sum(AL_gemm[ii, t0, t1, k_gemm].astype(out_dtype) * \
-                                   WL_gemm[k_gemm, jj].astype(out_dtype), axis=k_gemm),
-                            name='C')
-
-    s[AF].tensorize(nn, intrin_wmma_load_matrix_A(AL_strides, AS_strides, shape,
-                                                  "row_major", AS_shape, AL_shape, in_dtype))
-    s[WF].tensorize(ii, intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape,
-                                                  "row_major", WS_shape, WL_shape, in_dtype))
-    s[OL].tensorize(nnc, intrin_wmma_store_matrix(CS_strides, CL_strides,
-                                                  shape, out_dtype, CL_shape, CS_shape))
-    s[ConvF].tensorize(nnf, intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides,
-                                             WL_strides, CL_strides, shape))
+    CL_compute = te.compute(
+        CL_shape,
+        lambda ii, t0, t1, jj: te.sum(
+            AL_gemm[ii, t0, t1, k_gemm].astype(out_dtype) * WL_gemm[k_gemm, jj].astype(out_dtype),
+            axis=k_gemm,
+        ),
+        name="C",
+    )
+
+    s[AF].tensorize(
+        nn,
+        intrin_wmma_load_matrix_A(
+            AL_strides, AS_strides, shape, "row_major", AS_shape, AL_shape, in_dtype
+        ),
+    )
+    s[WF].tensorize(
+        ii,
+        intrin_wmma_load_matrix_W(
+            WL_strides, WS_strides, shape, "row_major", WS_shape, WL_shape, in_dtype
+        ),
+    )
+    s[OL].tensorize(
+        nnc, intrin_wmma_store_matrix(CS_strides, CL_strides, shape, out_dtype, CL_shape, CS_shape)
+    )
+    s[ConvF].tensorize(
+        nnf,
+        intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides, WL_strides, CL_strides, shape),
+    )
 
     N, OH, OW, CO = get_const_tuple(output.shape)
     KH, KW, CI, _ = get_const_tuple(kernel.shape)
@@ -311,7 +335,7 @@ def schedule_conv2d_nhwc_tensorcore(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'conv2d_nhwc_tensorcore' in op.tag:
+        if "conv2d_nhwc_tensorcore" in op.tag:
             schedule_nhwc_tensorcore_cuda(cfg, s, op.output(0))
 
     traverse_inline(s, outs[0].op, _callback)
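
The hunks above are representative of the whole patch: Black rewrites string literals with double quotes and, once a call such as te.compute no longer fits within the line limit, splits it into one argument per line closed by a trailing comma. A minimal, hypothetical sketch of the resulting style (the tensor names and shapes here are invented for illustration and are not part of this commit):

    from tvm import te

    A = te.placeholder((16, 16), name="A", dtype="float16")  # was: name='A'
    B = te.placeholder((16, 16), name="B", dtype="float16")
    k = te.reduce_axis((0, 16), name="k")
    # long calls are exploded to one argument per line with a trailing comma
    C = te.compute(
        (16, 16),
        lambda i, j: te.sum(A[i, k].astype("float32") * B[k, j].astype("float32"), axis=k),
        name="C",
    )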
index 2f5b85e..cc0bbeb 100644
@@ -30,6 +30,7 @@ from .tensor_intrin import intrin_wmma_load_matrix_W
 from .tensor_intrin import intrin_wmma_store_matrix
 from .tensor_intrin import intrin_wmma_gemm
 
+
 def _infer_tile_size(data, kernel):
     """Compute the tile size"""
     N, H, W, CI = get_const_tuple(data.shape)
@@ -47,12 +48,12 @@ def schedule_bgemm_tensorcore(cfg, s, bgemm, data_pack, kernel_pack):
     out_dtype = C.dtype
 
     # Explicit memory access
-    AS = s.cache_read(A, 'shared', [C])
-    BS = s.cache_read(B, 'shared', [C])
-    AF = s.cache_read(AS, 'wmma.matrix_a', [C])
-    BF = s.cache_read(BS, 'wmma.matrix_b', [C])
-    CF = s.cache_write(C, 'wmma.accumulator')
-    CS = s.cache_read(CF, 'shared', [C])
+    AS = s.cache_read(A, "shared", [C])
+    BS = s.cache_read(B, "shared", [C])
+    AF = s.cache_read(AS, "wmma.matrix_a", [C])
+    BF = s.cache_read(BS, "wmma.matrix_b", [C])
+    CF = s.cache_write(C, "wmma.accumulator")
+    CS = s.cache_read(CF, "shared", [C])
 
     # Create tuning space
     cfg.define_knob("block_row_warps", [1, 2, 4])
@@ -65,11 +66,11 @@ def schedule_bgemm_tensorcore(cfg, s, bgemm, data_pack, kernel_pack):
     cfg.define_knob("vec", [1, 2, 4, 8])
 
     # Ensure that the default parameters are applicable when autotvm is not in use
-    if (P % 16 == 0 and out_dim % 16 == 0):
+    if P % 16 == 0 and out_dim % 16 == 0:
         cfg.define_knob("wmma_m", [16, 8, 32])
-    elif (P % 32 == 0 and out_dim % 8 == 0):
+    elif P % 32 == 0 and out_dim % 8 == 0:
         cfg.define_knob("wmma_m", [32, 16, 8])
-    elif (P % 8 == 0 and out_dim % 32 == 0):
+    elif P % 8 == 0 and out_dim % 32 == 0:
         cfg.define_knob("wmma_m", [8, 16, 32])
 
     warp_size = 32
@@ -101,12 +102,12 @@ def schedule_bgemm_tensorcore(cfg, s, bgemm, data_pack, kernel_pack):
     BF_stride = [wmma_n * warp_col_tiles, 1]
     CF_stride = [warp_col_tiles * wmma_n, 1]
     CS_stride = [CS_align, 1]
-    block_x = te.thread_axis('blockIdx.x')
-    block_y = te.thread_axis('blockIdx.y')
-    block_z = te.thread_axis('blockIdx.z')
-    thread_x = te.thread_axis('threadIdx.x')
-    thread_y = te.thread_axis('threadIdx.y')
-    thread_z = te.thread_axis('threadIdx.z')
+    block_x = te.thread_axis("blockIdx.x")
+    block_y = te.thread_axis("blockIdx.y")
+    block_z = te.thread_axis("blockIdx.z")
+    thread_x = te.thread_axis("threadIdx.x")
+    thread_y = te.thread_axis("threadIdx.y")
+    thread_z = te.thread_axis("threadIdx.z")
 
     # Schedule for computation
     block_factor_b = wmma_m * warp_row_tiles * block_row_warps
@@ -144,7 +145,7 @@ def schedule_bgemm_tensorcore(cfg, s, bgemm, data_pack, kernel_pack):
     _, _, warp_i, warp_j = CF.op.axis
     warp_i, _ii = s[CF].split(warp_i, factor=wmma_m)
     warp_j, _jj = s[CF].split(warp_j, factor=wmma_n)
-    k, = CF.op.reduce_axis
+    (k,) = CF.op.reduce_axis
     k, _k = s[CF].split(k, factor=wmma_k)
     ko, ki = s[CF].split(k, factor=chunk)
     s[CF].reorder(ko, ki, warp_i, warp_j, _ii, _jj, _k)
@@ -182,25 +183,42 @@ def schedule_bgemm_tensorcore(cfg, s, bgemm, data_pack, kernel_pack):
     shared_shedule(BS, BS_align)
 
     shape = (wmma_m, wmma_n, wmma_k)
-    in_dtype = 'float16'
-    AL_gemm = te.placeholder((wmma_m, wmma_k), name='AL_gemm', dtype=in_dtype)
-    BL_gemm = te.placeholder((wmma_k, wmma_n), name='BL_gemm', dtype=in_dtype)
-    k_gemm = te.reduce_axis((0, wmma_k), name='k_gemm')
-    CL_compute = te.compute((wmma_m, wmma_n), lambda ii, jj:
-                            te.sum(AL_gemm[ii, k_gemm].astype(out_dtype) *
-                                   BL_gemm[k_gemm, jj].astype(out_dtype),
-                                   axis=k_gemm), name='CL_compute')
+    in_dtype = "float16"
+    AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype)
+    BL_gemm = te.placeholder((wmma_k, wmma_n), name="BL_gemm", dtype=in_dtype)
+    k_gemm = te.reduce_axis((0, wmma_k), name="k_gemm")
+    CL_compute = te.compute(
+        (wmma_m, wmma_n),
+        lambda ii, jj: te.sum(
+            AL_gemm[ii, k_gemm].astype(out_dtype) * BL_gemm[k_gemm, jj].astype(out_dtype),
+            axis=k_gemm,
+        ),
+        name="CL_compute",
+    )
 
     # Lower the computation loops down to TensorCore hardware intrinsics
     # by mapping the tensorcore to tensor intrinsics
-    s[AF].tensorize(b_ii, intrin_wmma_load_matrix_A(AF_stride, AS_stride, shape, "row_major",
-                                                    (wmma_m, wmma_k), (wmma_m, wmma_k), 'float16'))
-    s[BF].tensorize(i_ii, intrin_wmma_load_matrix_W(BF_stride, BS_stride, shape, "row_major",
-                                                    (wmma_k, wmma_n), (wmma_k, wmma_n), 'float16'))
-    s[CF].tensorize(_ii, intrin_wmma_gemm(AL_gemm, BL_gemm, CL_compute, AF_stride,
-                                          BF_stride, CF_stride, shape))
-    s[CS].tensorize(bbi, intrin_wmma_store_matrix(CS_stride, CF_stride, shape, out_dtype,
-                                                  (wmma_m, wmma_n), (wmma_m, wmma_n)))
+    s[AF].tensorize(
+        b_ii,
+        intrin_wmma_load_matrix_A(
+            AF_stride, AS_stride, shape, "row_major", (wmma_m, wmma_k), (wmma_m, wmma_k), "float16"
+        ),
+    )
+    s[BF].tensorize(
+        i_ii,
+        intrin_wmma_load_matrix_W(
+            BF_stride, BS_stride, shape, "row_major", (wmma_k, wmma_n), (wmma_k, wmma_n), "float16"
+        ),
+    )
+    s[CF].tensorize(
+        _ii, intrin_wmma_gemm(AL_gemm, BL_gemm, CL_compute, AF_stride, BF_stride, CF_stride, shape)
+    )
+    s[CS].tensorize(
+        bbi,
+        intrin_wmma_store_matrix(
+            CS_stride, CF_stride, shape, out_dtype, (wmma_m, wmma_n), (wmma_m, wmma_n)
+        ),
+    )
 
 
 def schedule_bgemm_direct(cfg, s, bgemm, data_pack, kernel_pack):
@@ -210,8 +228,9 @@ def schedule_bgemm_direct(cfg, s, bgemm, data_pack, kernel_pack):
     alpha = get_const_int(b1.dom.extent)
 
     # Create tuning space
-    cfg.define_split("tile_b", cfg.axis(alpha * alpha), num_outputs=4,
-                     filter=lambda x: x.size[-3:] == [1, 1, 1])
+    cfg.define_split(
+        "tile_b", cfg.axis(alpha * alpha), num_outputs=4, filter=lambda x: x.size[-3:] == [1, 1, 1]
+    )
     cfg.define_split("tile_y", y, num_outputs=4)
     cfg.define_split("tile_x", x, num_outputs=4)
     cfg.define_split("tile_rc", rc, num_outputs=2)
@@ -224,9 +243,9 @@ def schedule_bgemm_direct(cfg, s, bgemm, data_pack, kernel_pack):
     A0, B0 = kernel_pack, data_pack
 
     # Designate the memory hierarchy
-    OL = s.cache_write(C, 'local')
-    AA = s.cache_read(A0, 'shared', [OL])
-    BB = s.cache_read(B0, 'shared', [OL])
+    OL = s.cache_write(C, "local")
+    AA = s.cache_read(A0, "shared", [OL])
+    BB = s.cache_read(B0, "shared", [OL])
 
     # Tile and bind spatial axes
     b = s[bgemm].fuse(b1, b2)
@@ -249,8 +268,8 @@ def schedule_bgemm_direct(cfg, s, bgemm, data_pack, kernel_pack):
     s[OL].compute_at(s[C], tx)
     b1, b2, y, x = s[OL].op.axis
     b = s[OL].fuse(b1, b2)
-    rc, = s[OL].op.reduce_axis
-    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
+    (rc,) = s[OL].op.reduce_axis
+    rco, rci = cfg["tile_rc"].apply(s, OL, rc)
     s[OL].reorder(rco, b, y, x, rci)
 
     s[AA].compute_at(s[OL], rco)
@@ -276,8 +295,9 @@ def schedule_bgemm_direct(cfg, s, bgemm, data_pack, kernel_pack):
         s[load].vectorize(ti)
 
 
-def nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
-                       use_tensorcore, pre_computed):
+def nhwc_winograd_cuda(
+    cfg, data, kernel, strides, padding, dilation, out_dtype, use_tensorcore, pre_computed
+):
     """Compute declaration for winograd"""
     tile_size = _infer_tile_size(data, kernel)
     N, H, W, CI = get_const_tuple(data.shape)
@@ -313,9 +333,11 @@ def nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
     P = N * nH * nW
 
     # Determine whether the shape is available with tensorcore
-    shape_judge = (P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
-                      (P % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
-                      (P % 32 == 0 and CI % 16 == 0 and CO % 8 == 0)
+    shape_judge = (
+        (P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0)
+        or (P % 8 == 0 and CI % 16 == 0 and CO % 32 == 0)
+        or (P % 32 == 0 and CI % 16 == 0 and CO % 8 == 0)
+    )
 
     if shape_judge and use_tensorcore:
         trans_type = "float16"
@@ -331,16 +353,19 @@ def nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
         # Check if we are currently tuning, if so we want to avoid counting
         # prepacking in time costs. Just use a placeholder with the packed shape instead.
         if autotvm.GLOBAL_SCOPE.in_tuning:
-            kernel_pack = te.placeholder((alpha, alpha, CI, CO),
-                                         dtype=kernel.dtype,
-                                         name='kernel_pack')
+            kernel_pack = te.placeholder(
+                (alpha, alpha, CI, CO), dtype=kernel.dtype, name="kernel_pack"
+            )
         else:
-            r_kh = te.reduce_axis((0, KH), name='r_kh')
-            r_kw = te.reduce_axis((0, KW), name='r_kw')
-            kernel_pack = te.compute((alpha, alpha, CI, CO), lambda eps, nu, ci, co:
-                                     te.sum((kernel[r_kh][r_kw][ci][co]) *
-                                            G[eps][r_kh] * G[nu][r_kw],
-                                            axis=[r_kh, r_kw]), name='kernel_pack')
+            r_kh = te.reduce_axis((0, KH), name="r_kh")
+            r_kw = te.reduce_axis((0, KW), name="r_kw")
+            kernel_pack = te.compute(
+                (alpha, alpha, CI, CO),
+                lambda eps, nu, ci, co: te.sum(
+                    (kernel[r_kh][r_kw][ci][co]) * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]
+                ),
+                name="kernel_pack",
+            )
     else:
         kernel_pack = kernel
 
@@ -348,46 +373,65 @@ def nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
     idxmod = tvm.tir.indexmod
 
     # Pack input tile
-    input_tile = te.compute((P, CI, alpha, alpha), lambda p, c, eps, nu:
-                            data_pad[idxdiv(p, (nH * nW)),
-                                     idxmod(idxdiv(p, nW), nH) * m + eps,
-                                     idxmod(p, nW) * m + nu,
-                                     c], name='d')
+    input_tile = te.compute(
+        (P, CI, alpha, alpha),
+        lambda p, c, eps, nu: data_pad[
+            idxdiv(p, (nH * nW)), idxmod(idxdiv(p, nW), nH) * m + eps, idxmod(p, nW) * m + nu, c
+        ],
+        name="d",
+    )
 
     # Transform data
-    r_a = te.reduce_axis((0, alpha), 'r_a')
-    r_b = te.reduce_axis((0, alpha), 'r_b')
-    data_pack = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci:
-                           te.sum(input_tile[p][ci][r_a][r_b] * B[r_a][eps] * B[r_b][nu],
-                                  axis=[r_a, r_b]), name='data_pack')
+    r_a = te.reduce_axis((0, alpha), "r_a")
+    r_b = te.reduce_axis((0, alpha), "r_b")
+    data_pack = te.compute(
+        (alpha, alpha, P, CI),
+        lambda eps, nu, p, ci: te.sum(
+            input_tile[p][ci][r_a][r_b] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]
+        ),
+        name="data_pack",
+    )
 
     # Convert data type of input feature maps and weights for tensorcore
     Transdata = te.compute(
-        data_pack.shape, lambda eps, nu, p, ci: data_pack[eps, nu, p, ci].astype(trans_type))
+        data_pack.shape, lambda eps, nu, p, ci: data_pack[eps, nu, p, ci].astype(trans_type)
+    )
     TransFilter = te.compute(
-        kernel_pack.shape, lambda eps, nu, ci, co: kernel_pack[eps, nu, ci, co].astype(trans_type))
+        kernel_pack.shape, lambda eps, nu, ci, co: kernel_pack[eps, nu, ci, co].astype(trans_type)
+    )
 
     # Do batch gemm
-    ci = te.reduce_axis((0, CI), name='ci')
-    bgemm = te.compute((alpha, alpha, P, CO), lambda eps, nu, p, co:
-                       te.sum((Transdata[eps][nu][p][ci]).astype(out_dtype) *
-                              (TransFilter[eps][nu][ci][co]).astype(out_dtype),
-                              axis=[ci]), name='bgemm')
+    ci = te.reduce_axis((0, CI), name="ci")
+    bgemm = te.compute(
+        (alpha, alpha, P, CO),
+        lambda eps, nu, p, co: te.sum(
+            (Transdata[eps][nu][p][ci]).astype(out_dtype)
+            * (TransFilter[eps][nu][ci][co]).astype(out_dtype),
+            axis=[ci],
+        ),
+        name="bgemm",
+    )
 
     # Inverse transform
-    r_a = te.reduce_axis((0, alpha), 'r_a')
-    r_b = te.reduce_axis((0, alpha), 'r_a')
-    inverse = te.compute((P, CO, m, m), lambda p, co, vh, vw:
-                         te.sum(bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw],
-                                axis=[r_a, r_b]), name='inverse')
+    r_a = te.reduce_axis((0, alpha), "r_a")
+    r_b = te.reduce_axis((0, alpha), "r_a")
+    inverse = te.compute(
+        (P, CO, m, m),
+        lambda p, co, vh, vw: te.sum(
+            bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]
+        ),
+        name="inverse",
+    )
 
     # Output
-    output = te.compute((N, H, W, CO), lambda n, h, w, co:
-                        inverse[n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
-                                co,
-                                idxmod(h, m),
-                                idxmod(w, m)],
-                        name='output', tag='conv2d_nhwc_winograd')
+    output = te.compute(
+        (N, H, W, CO),
+        lambda n, h, w, co: inverse[
+            n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), co, idxmod(h, m), idxmod(w, m)
+        ],
+        name="output",
+        tag="conv2d_nhwc_winograd",
+    )
     cfg.add_flop(2 * N * CO * H * W * CI * KH * KW)
     return output
 
@@ -395,8 +439,8 @@ def nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
 def data_weight_transform(s, data_trans, input_tile, thread_num_trans, offset_trans, trans_tag):
     """Schedule for data or kernel transform"""
     kernel_align = thread_num_trans + offset_trans
-    indata_s = s.cache_read(input_tile, 'shared', [data_trans])
-    data_l = s.cache_write(data_trans, 'local')
+    indata_s = s.cache_read(input_tile, "shared", [data_trans])
+    data_l = s.cache_write(data_trans, "local")
     # Schedule for data or kernel transform
     eps, nu, p, c = s[data_trans].op.axis
 
@@ -421,8 +465,9 @@ def data_weight_transform(s, data_trans, input_tile, thread_num_trans, offset_tr
     s[indata_s].compute_at(s[data_l], block_x_l)
     if trans_tag == "data":
         p_is, c_is, eps_is, nu_is = s[indata_s].op.axis
-        data_align = get_const_int(eps_is.dom.extent) * \
-                         get_const_int(nu_is.dom.extent) + offset_trans
+        data_align = (
+            get_const_int(eps_is.dom.extent) * get_const_int(nu_is.dom.extent) + offset_trans
+        )
         s[indata_s].storage_align(c_is, data_align - 1, data_align)
         block_x_is, thread_x_is = s[indata_s].split(c_is, thread_num_trans)
         s[indata_s].bind(thread_x_is, te.thread_axis("threadIdx.x"))
@@ -475,8 +520,9 @@ def schedule_nhwc_winograd_cuda(cfg, s, output, use_tensorcore, pre_computed):
     if not pre_computed and not autotvm.GLOBAL_SCOPE.in_tuning:
         kernel, G = s[kernel_pack].op.input_tensors
         s[G].compute_inline()
-        data_weight_transform(s, kernel_pack, kernel, thread_num_kernel,
-                              offset_kernel, trans_tag="kernel")
+        data_weight_transform(
+            s, kernel_pack, kernel, thread_num_kernel, offset_kernel, trans_tag="kernel"
+        )
     else:
         kernel = kernel_pack
 
@@ -489,9 +535,11 @@ def schedule_nhwc_winograd_cuda(cfg, s, output, use_tensorcore, pre_computed):
     _, _, _, CO = get_const_tuple(TransFilter.shape)
 
     # Determine whether the shape is available with tensorcore
-    shape_judge = (P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
-                      (P % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
-                      (P % 32 == 0 and CI % 16 == 0 and CO % 8 == 0)
+    shape_judge = (
+        (P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0)
+        or (P % 8 == 0 and CI % 16 == 0 and CO % 32 == 0)
+        or (P % 32 == 0 and CI % 16 == 0 and CO % 8 == 0)
+    )
 
     if shape_judge and use_tensorcore:
         schedule_bgemm_tensorcore(cfg, s, bgemm, Transdata, TransFilter)
@@ -503,11 +551,11 @@ def schedule_nhwc_winograd_cuda(cfg, s, output, use_tensorcore, pre_computed):
         OL = None
     else:
         OL = output
-        s[OL].set_scope('local')
+        s[OL].set_scope("local")
         output = s.outputs[0]
 
     s[A].compute_inline()
-    inverse_s = s.cache_read(bgemm, 'shared', [inverse])
+    inverse_s = s.cache_read(bgemm, "shared", [inverse])
 
     m = alpha - 3 + 1
     offset_inverse_in = offset_inverse
@@ -556,8 +604,17 @@ def schedule_nhwc_winograd_cuda(cfg, s, output, use_tensorcore, pre_computed):
 @autotvm.register_topi_compute("conv2d_nhwc_winograd_direct.cuda")
 def conv2d_nhwc_winograd_direct(cfg, data, kernel, strides, padding, dilation, out_dtype):
     """Compute conv2d with winograd for NHWC layout"""
-    return nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
-                              use_tensorcore=False, pre_computed=False)
+    return nhwc_winograd_cuda(
+        cfg,
+        data,
+        kernel,
+        strides,
+        padding,
+        dilation,
+        out_dtype,
+        use_tensorcore=False,
+        pre_computed=False,
+    )
 
 
 @autotvm.register_topi_schedule("conv2d_nhwc_winograd_direct.cuda")
@@ -566,9 +623,10 @@ def schedule_conv2d_nhwc_winograd_direct(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'conv2d_nhwc_winograd' in op.tag:
-            schedule_nhwc_winograd_cuda(cfg, s, op.output(0), use_tensorcore=False,
-                                        pre_computed=False)
+        if "conv2d_nhwc_winograd" in op.tag:
+            schedule_nhwc_winograd_cuda(
+                cfg, s, op.output(0), use_tensorcore=False, pre_computed=False
+            )
 
     traverse_inline(s, outs[0].op, _callback)
     return s
@@ -577,8 +635,17 @@ def schedule_conv2d_nhwc_winograd_direct(cfg, outs):
 @autotvm.register_topi_compute("conv2d_nhwc_winograd_tensorcore.cuda")
 def conv2d_nhwc_winograd_tensorcore(cfg, data, kernel, strides, padding, dilation, out_dtype):
     """Compute conv2d with winograd for NHWC layout"""
-    return nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
-                              use_tensorcore=True, pre_computed=False)
+    return nhwc_winograd_cuda(
+        cfg,
+        data,
+        kernel,
+        strides,
+        padding,
+        dilation,
+        out_dtype,
+        use_tensorcore=True,
+        pre_computed=False,
+    )
 
 
 @autotvm.register_topi_schedule("conv2d_nhwc_winograd_tensorcore.cuda")
@@ -587,20 +654,31 @@ def schedule_conv2d_nhwc_winograd_tensorcore(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'conv2d_nhwc_winograd' in op.tag:
-            schedule_nhwc_winograd_cuda(cfg, s, op.output(0), use_tensorcore=True,
-                                        pre_computed=False)
+        if "conv2d_nhwc_winograd" in op.tag:
+            schedule_nhwc_winograd_cuda(
+                cfg, s, op.output(0), use_tensorcore=True, pre_computed=False
+            )
 
     traverse_inline(s, outs[0].op, _callback)
     return s
 
 
 @autotvm.register_topi_compute("conv2d_nhwc_winograd_direct_without_weight_transform.cuda")
-def conv2d_nhwc_winograd_direct_without_weight_transform(cfg, data, kernel, strides,
-                                                         padding, dilation, out_dtype):
+def conv2d_nhwc_winograd_direct_without_weight_transform(
+    cfg, data, kernel, strides, padding, dilation, out_dtype
+):
     """Compute conv2d with winograd for NHWC layout"""
-    return nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
-                              use_tensorcore=False, pre_computed=True)
+    return nhwc_winograd_cuda(
+        cfg,
+        data,
+        kernel,
+        strides,
+        padding,
+        dilation,
+        out_dtype,
+        use_tensorcore=False,
+        pre_computed=True,
+    )
 
 
 @autotvm.register_topi_schedule("conv2d_nhwc_winograd_direct_without_weight_transform.cuda")
@@ -609,20 +687,31 @@ def schedule_conv2d_nhwc_winograd_direct_without_weight_transform(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'conv2d_nhwc_winograd' in op.tag:
-            schedule_nhwc_winograd_cuda(cfg, s, op.output(0), use_tensorcore=False,
-                                        pre_computed=True)
+        if "conv2d_nhwc_winograd" in op.tag:
+            schedule_nhwc_winograd_cuda(
+                cfg, s, op.output(0), use_tensorcore=False, pre_computed=True
+            )
 
     traverse_inline(s, outs[0].op, _callback)
     return s
 
 
 @autotvm.register_topi_compute("conv2d_nhwc_winograd_tensorcore_without_weight_transform.cuda")
-def conv2d_nhwc_winograd_tensorcore_without_weight_transform(cfg, data, kernel, strides,
-                                                             padding, dilation, out_dtype):
+def conv2d_nhwc_winograd_tensorcore_without_weight_transform(
+    cfg, data, kernel, strides, padding, dilation, out_dtype
+):
     """Compute conv2d with winograd for NHWC layout"""
-    return nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
-                              use_tensorcore=True, pre_computed=True)
+    return nhwc_winograd_cuda(
+        cfg,
+        data,
+        kernel,
+        strides,
+        padding,
+        dilation,
+        out_dtype,
+        use_tensorcore=True,
+        pre_computed=True,
+    )
 
 
 @autotvm.register_topi_schedule("conv2d_nhwc_winograd_tensorcore_without_weight_transform.cuda")
@@ -631,9 +720,10 @@ def schedule_conv2d_nhwc_winograd_tensorcore_without_weight_transform(cfg, outs)
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'conv2d_nhwc_winograd' in op.tag:
-            schedule_nhwc_winograd_cuda(cfg, s, op.output(0), use_tensorcore=True,
-                                        pre_computed=True)
+        if "conv2d_nhwc_winograd" in op.tag:
+            schedule_nhwc_winograd_cuda(
+                cfg, s, op.output(0), use_tensorcore=True, pre_computed=True
+            )
 
     traverse_inline(s, outs[0].op, _callback)
     return s
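
Three other mechanical rewrites recur throughout this file and the ones that follow: single-element tuple unpacking gains explicit parentheses, backslash-continued boolean expressions are wrapped in parentheses instead, and long assert messages become adjacent string literals inside their own parentheses while the condition stays a separate operand. A small, self-contained illustration with made-up values, not taken from this commit:

    P, CI, CO = 32, 16, 8

    # was: shape_judge = (...) or \ (...) or \ (...) with line continuations
    shape_judge = (
        (P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0)
        or (P % 8 == 0 and CI % 16 == 0 and CO % 32 == 0)
        or (P % 32 == 0 and CI % 16 == 0 and CO % 8 == 0)
    )

    # was: rc, = reduce_axes
    reduce_axes = [7]
    (rc,) = reduce_axes

    # condition and message are parenthesized separately, so this remains a
    # two-operand assert rather than an always-true assert on a tuple
    assert shape_judge, (
        "The shape of (P, CI, CO) must be a multiple of "
        "(16, 16, 16), (8, 16, 32) or (32, 16, 8) for now"
    )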
index 7e41209..46ee685 100644
@@ -25,10 +25,8 @@ from .. import nn
 from ..util import get_const_tuple, traverse_inline
 
 
-
 @autotvm.register_topi_compute("conv2d_transpose_nchw.cuda")
-def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype,
-                          output_padding):
+def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype, output_padding):
     """Transposed 2D convolution nchw forward operator.
 
     Parameters
@@ -60,54 +58,66 @@ def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype,
     assert outpad_height < stride_height and outpad_width < stride_width
     cfg.stride = stride
     pad_top, pad_left, pad_bottom, pad_right = nn.get_pad_tuple(
-        padding, (kernel_height, kernel_width))
+        padding, (kernel_height, kernel_width)
+    )
 
-    out_width = (inp_width - 1) * stride_width + \
-        kernel_width - pad_left - pad_right + outpad_width
+    out_width = (inp_width - 1) * stride_width + kernel_width - pad_left - pad_right + outpad_width
     pad_left = kernel_width - 1 - pad_left
     pad_right = kernel_width - 1 - pad_right + outpad_width
     dilated_width = stride_width * (inp_width - 1) + 1
 
-    out_height = (inp_height - 1) * stride_height + \
-        kernel_height - pad_top - pad_bottom + outpad_height
+    out_height = (
+        (inp_height - 1) * stride_height + kernel_height - pad_top - pad_bottom + outpad_height
+    )
     pad_top = kernel_height - 1 - pad_top
     pad_bottom = kernel_height - 1 - pad_bottom + outpad_height
     dilated_height = stride_height * (inp_height - 1) + 1
 
     # compute pad
     data = te.compute(
-        (batch, inp_channels,
-         pad_top + dilated_height + pad_bottom,
-         pad_left + dilated_width + pad_right),
+        (
+            batch,
+            inp_channels,
+            pad_top + dilated_height + pad_bottom,
+            pad_left + dilated_width + pad_right,
+        ),
         lambda n, c, y, x: tvm.tir.if_then_else(
-            tvm.tir.all(x >= pad_left,
-                        x < pad_left + dilated_width,
-                        tvm.tir.indexmod(x - pad_left, stride_width).equal(0),
-                        y >= pad_top,
-                        y < pad_top + dilated_height,
-                        tvm.tir.indexmod(y - pad_top, stride_height).equal(0)),
-            data[n, c,
-                 tvm.tir.indexdiv(y - pad_top, stride_height),
-                 tvm.tir.indexdiv(x - pad_left, stride_width)],
-            tvm.tir.const(0., "float32")),
-        name='data_pad')
+            tvm.tir.all(
+                x >= pad_left,
+                x < pad_left + dilated_width,
+                tvm.tir.indexmod(x - pad_left, stride_width).equal(0),
+                y >= pad_top,
+                y < pad_top + dilated_height,
+                tvm.tir.indexmod(y - pad_top, stride_height).equal(0),
+            ),
+            data[
+                n,
+                c,
+                tvm.tir.indexdiv(y - pad_top, stride_height),
+                tvm.tir.indexdiv(x - pad_left, stride_width),
+            ],
+            tvm.tir.const(0.0, "float32"),
+        ),
+        name="data_pad",
+    )
 
     # compute transposed conv
-    dc = te.reduce_axis((0, inp_channels), name='dc')
-    dh = te.reduce_axis((0, kernel_height), name='dh')
-    dw = te.reduce_axis((0, kernel_width), name='dw')
+    dc = te.reduce_axis((0, inp_channels), name="dc")
+    dh = te.reduce_axis((0, kernel_height), name="dh")
+    dw = te.reduce_axis((0, kernel_width), name="dw")
     data_out = te.compute(
         (batch, out_channels, out_height, out_width),
         lambda b, c, h, w: te.sum(
-            data[b, dc, h + dh, w + dw].astype(out_dtype) *
-            kernel[dc,
-                   c,
-                   kernel_height - 1 - dh,
-                   kernel_width - 1 - dw].astype(out_dtype),
-            axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")
+            data[b, dc, h + dh, w + dw].astype(out_dtype)
+            * kernel[dc, c, kernel_height - 1 - dh, kernel_width - 1 - dw].astype(out_dtype),
+            axis=[dc, dh, dw],
+        ),
+        tag="conv2d_transpose_nchw",
+    )
 
     return data_out
 
+
 @autotvm.register_topi_schedule("conv2d_transpose_nchw.cuda")
 def schedule_conv2d_transpose_nchw(cfg, outs):
     """TOPI Schedule callback for conv2d transpose operator.
@@ -161,7 +171,7 @@ def schedule_conv2d_transpose_nchw(cfg, outs):
         cfg["auto_unroll_max_step"] = OtherOptionEntity(1500)
 
     def _callback(op):
-        if op.tag == 'conv2d_transpose_nchw':
+        if op.tag == "conv2d_transpose_nchw":
             pad_data = op.input_tensors[0]
             kernel = op.input_tensors[1]
             conv = op.output(0)
@@ -177,7 +187,7 @@ def schedule_conv2d_transpose_nchw(cfg, outs):
             cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
 
             target = tvm.target.Target.current()
-            if target.kind.name in ['nvptx', 'rocm']:
+            if target.kind.name in ["nvptx", "rocm"]:
                 cfg.define_knob("unroll_explicit", [1])
             else:
                 cfg.define_knob("unroll_explicit", [0, 1])
@@ -188,21 +198,21 @@ def schedule_conv2d_transpose_nchw(cfg, outs):
 
             ##### space definition end #####
 
-            if isinstance(kernel.op, tvm.te.ComputeOp) and 'dilate' in kernel.op.tag:
+            if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
                 s[kernel].compute_inline()
 
             if conv.op in s.outputs:
                 output = conv
-                OL = s.cache_write(conv, 'local')
+                OL = s.cache_write(conv, "local")
             else:
                 output = s.outputs[0].output(0)
-                s[conv].set_scope('local')
+                s[conv].set_scope("local")
                 OL = conv
 
             # create cache stage
-            s[pad_data].set_scope('shared')
+            s[pad_data].set_scope("shared")
             AA = pad_data
-            WW = s.cache_read(kernel, 'shared', [OL])
+            WW = s.cache_read(kernel, "shared", [OL])
 
             # tile and bind spatial axes
             n, f, y, x = s[output].op.axis
@@ -221,7 +231,7 @@ def schedule_conv2d_transpose_nchw(cfg, outs):
             s[output].bind(vy, te.thread_axis("vthread"))
             s[output].bind(vx, te.thread_axis("vthread"))
 
-            cfg.define_knob("fuse_yx", [0, 1]) # fuse ty,tx or tn,tf
+            cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
 
             if cfg["fuse_yx"].val:
                 s[output].bind(tn, te.thread_axis("threadIdx.z"))
@@ -248,7 +258,7 @@ def schedule_conv2d_transpose_nchw(cfg, outs):
             # tile reduction axes
             n, f, y, x = s[OL].op.axis
             rc, ry, rx = s[OL].op.reduce_axis
-            rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
+            rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
             s[OL].reorder(rco, rcm, ry, rx, rci, n, f, y, x)
 
             s[AA].compute_at(s[OL], rx)
@@ -265,8 +275,8 @@ def schedule_conv2d_transpose_nchw(cfg, outs):
                 s[load].bind(ty, te.thread_axis("threadIdx.y"))
                 s[load].bind(tx, te.thread_axis("threadIdx.x"))
 
-            s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
-            s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+            s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+            s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
 
     traverse_inline(s, outs[0].op, _callback)
 
index f5259ba..69513d5 100644
@@ -27,7 +27,8 @@ from ..util import get_const_int, get_const_tuple, traverse_inline
 from ..nn.winograd_util import winograd_transform_matrices
 
 
-logger = logging.getLogger('conv2d_winograd')
+logger = logging.getLogger("conv2d_winograd")
+
 
 def _infer_tile_size(data, kernel):
     N, CI, H, W = get_const_tuple(data.shape)
@@ -36,8 +37,8 @@ def _infer_tile_size(data, kernel):
         return 4
     return 2
 
-def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
-                  pre_computed):
+
+def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed):
     """Compute declaration for winograd"""
     tile_size = _infer_tile_size(data, kernel)
 
@@ -49,7 +50,7 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
         dilation_h, dilation_w = dilation
     HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides
 
-    if not pre_computed: # kernel tensor is raw tensor, do strict check
+    if not pre_computed:  # kernel tensor is raw tensor, do strict check
         if dilation_h != 1 or dilation_w != 1:
             kernel = nn.dilate(kernel, (1, 1, dilation_h, dilation_w))
         CO, CI, KH, KW = get_const_tuple(kernel.shape)
@@ -71,55 +72,75 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
 
     H = (H + pt + pb - KH) // HSTR + 1
     W = (W + pl + pr - KW) // WSTR + 1
-    nH, nW = (H + m-1) // m, (W + m-1) // m
+    nH, nW = (H + m - 1) // m, (W + m - 1) // m
     P = N * nH * nW
 
     # transform kernel
     if not pre_computed:
-        r_kh = te.reduce_axis((0, KH), name='r_kh')
-        r_kw = te.reduce_axis((0, KW), name='r_kw')
-        kernel_pack = te.compute((alpha, alpha, CI, CO), lambda eps, nu, ci, co:
-                                 te.sum(kernel[co][ci][r_kh][r_kw] *
-                                        G[eps][r_kh] * G[nu][r_kw],
-                                        axis=[r_kh, r_kw]), name='kernel_pack')
+        r_kh = te.reduce_axis((0, KH), name="r_kh")
+        r_kw = te.reduce_axis((0, KW), name="r_kw")
+        kernel_pack = te.compute(
+            (alpha, alpha, CI, CO),
+            lambda eps, nu, ci, co: te.sum(
+                kernel[co][ci][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]
+            ),
+            name="kernel_pack",
+        )
     else:
         kernel_pack = kernel
 
     idxdiv = tvm.tir.indexdiv
     idxmod = tvm.tir.indexmod
     # pack input tile
-    input_tile = te.compute((CI, P, alpha, alpha), lambda c, p, eps, nu:
-                            data_pad[idxdiv(p, (nH * nW))][c][idxmod(idxdiv(p, nW), nH) * m + eps]
-                            [idxmod(p, nW) * m + nu], name='d')
+    input_tile = te.compute(
+        (CI, P, alpha, alpha),
+        lambda c, p, eps, nu: data_pad[idxdiv(p, (nH * nW))][c][
+            idxmod(idxdiv(p, nW), nH) * m + eps
+        ][idxmod(p, nW) * m + nu],
+        name="d",
+    )
 
     # transform data
-    r_a = te.reduce_axis((0, alpha), 'r_a')
-    r_b = te.reduce_axis((0, alpha), 'r_a')
-    data_pack = te.compute((alpha, alpha, CI, P), lambda eps, nu, ci, p:
-                           te.sum(input_tile[ci][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu],
-                                  axis=[r_a, r_b]), name='data_pack')
+    r_a = te.reduce_axis((0, alpha), "r_a")
+    r_b = te.reduce_axis((0, alpha), "r_a")
+    data_pack = te.compute(
+        (alpha, alpha, CI, P),
+        lambda eps, nu, ci, p: te.sum(
+            input_tile[ci][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]
+        ),
+        name="data_pack",
+    )
 
     # do batch gemm
-    ci = te.reduce_axis((0, CI), name='ci')
-    bgemm = te.compute((alpha, alpha, CO, P), lambda eps, nu, co, p:
-                       te.sum(kernel_pack[eps][nu][ci][co] *
-                              data_pack[eps][nu][ci][p],
-                              axis=[ci]), name='bgemm')
+    ci = te.reduce_axis((0, CI), name="ci")
+    bgemm = te.compute(
+        (alpha, alpha, CO, P),
+        lambda eps, nu, co, p: te.sum(
+            kernel_pack[eps][nu][ci][co] * data_pack[eps][nu][ci][p], axis=[ci]
+        ),
+        name="bgemm",
+    )
 
     # inverse transform
-    r_a = te.reduce_axis((0, alpha), 'r_a')
-    r_b = te.reduce_axis((0, alpha), 'r_a')
-    inverse = te.compute((CO, P, m, m), lambda co, p, vh, vw:
-                         te.sum(bgemm[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw],
-                                axis=[r_a, r_b]), name='inverse')
+    r_a = te.reduce_axis((0, alpha), "r_a")
+    r_b = te.reduce_axis((0, alpha), "r_a")
+    inverse = te.compute(
+        (CO, P, m, m),
+        lambda co, p, vh, vw: te.sum(
+            bgemm[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]
+        ),
+        name="inverse",
+    )
 
     # output
-    output = te.compute((N, CO, H, W), lambda n, co, h, w:
-                        inverse[co,
-                                n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
-                                idxmod(h, m),
-                                idxmod(w, m)],
-                        name='output', tag='conv2d_nchw_winograd')
+    output = te.compute(
+        (N, CO, H, W),
+        lambda n, co, h, w: inverse[
+            co, n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), idxmod(h, m), idxmod(w, m)
+        ],
+        name="output",
+        tag="conv2d_nchw_winograd",
+    )
     cfg.add_flop(2 * N * CO * H * W * CI * KH * KW)
 
     return output
@@ -137,7 +158,7 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
     # data transform
     s[B].compute_inline()
 
-    data_l = s.cache_write(data_pack, 'local')
+    data_l = s.cache_write(data_pack, "local")
     eps, nu, c, p = s[data_l].op.axis
     r_a, r_b = s[data_l].op.reduce_axis
     for axis in [eps, nu, r_a, r_b]:
@@ -162,8 +183,8 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
         if autotvm.GLOBAL_SCOPE.in_tuning:
             # skip this part during tuning to make records accurate
             # this part will be pre-computed during pre-compute optimization pass
-            s[G].pragma(s[G].op.axis[0], 'debug_skip_region')
-            s[kernel_pack].pragma(eps, 'debug_skip_region')
+            s[G].pragma(s[G].op.axis[0], "debug_skip_region")
+            s[kernel_pack].pragma(eps, "debug_skip_region")
         else:
             s[G].compute_inline()
             r_a, r_b = s[kernel_pack].op.reduce_axis
@@ -186,14 +207,15 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
     rc = s[bgemm].op.reduce_axis[0]
     alpha = get_const_int(b1.dom.extent)
 
-    cfg.define_split("tile_b", cfg.axis(alpha * alpha), num_outputs=4,
-                     filter=lambda x: x.size[-3:] == [1, 1, 1])
+    cfg.define_split(
+        "tile_b", cfg.axis(alpha * alpha), num_outputs=4, filter=lambda x: x.size[-3:] == [1, 1, 1]
+    )
     cfg.define_split("tile_y", y, num_outputs=4)
     cfg.define_split("tile_x", x, num_outputs=4)
     cfg.define_split("tile_rc", rc, num_outputs=2)
     cfg.define_knob("auto_unroll_max_step", [0, 128, 1500])
     target = tvm.target.Target.current()
-    if target.kind.name in ['nvptx', 'rocm']:
+    if target.kind.name in ["nvptx", "rocm"]:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
@@ -203,9 +225,9 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
     C = bgemm
     A0, B0 = kernel_pack, data_pack
 
-    OL = s.cache_write(C, 'local')
-    AA = s.cache_read(A0, 'shared', [OL])
-    BB = s.cache_read(B0, 'shared', [OL])
+    OL = s.cache_write(C, "local")
+    AA = s.cache_read(A0, "shared", [OL])
+    BB = s.cache_read(B0, "shared", [OL])
 
     b = s[bgemm].fuse(b1, b2)
 
@@ -229,8 +251,8 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
     s[OL].compute_at(s[C], tx)
     b1, b2, y, x = s[OL].op.axis
     b = s[OL].fuse(b1, b2)
-    rc, = s[OL].op.reduce_axis
-    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
+    (rc,) = s[OL].op.reduce_axis
+    rco, rci = cfg["tile_rc"].apply(s, OL, rc)
     s[OL].reorder(rco, rci, b, y, x)
 
     s[AA].compute_at(s[OL], rco)
@@ -246,15 +268,15 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
         s[load].bind(ty, te.thread_axis("threadIdx.y"))
         s[load].bind(tx, te.thread_axis("threadIdx.x"))
 
-    s[C].pragma(bgemm_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
-    s[C].pragma(bgemm_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+    s[C].pragma(bgemm_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[C].pragma(bgemm_scope, "unroll_explicit", cfg["unroll_explicit"].val)
 
     # schedule inverse, output and fusion
     if output.op in s.outputs:
         OL = None
     else:
         OL = output
-        s[OL].set_scope('local')
+        s[OL].set_scope("local")
         output = s.outputs[0]
 
     m = alpha - 3 + 1
@@ -280,17 +302,20 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
 
     return s
 
+
 @autotvm.register_topi_compute("conv2d_nchw_winograd.cuda")
 def conv2d_nchw_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype):
-    return winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
-                         pre_computed=False)
+    return winograd_cuda(
+        cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=False
+    )
+
 
 @autotvm.register_topi_schedule("conv2d_nchw_winograd.cuda")
 def schedule_conv2d_nchw_winograd(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'conv2d_nchw_winograd' in op.tag:
+        if "conv2d_nchw_winograd" in op.tag:
             schedule_winograd_cuda(cfg, s, op.output(0), pre_computed=False)
 
     traverse_inline(s, outs[0].op, _callback)
@@ -298,10 +323,12 @@ def schedule_conv2d_nchw_winograd(cfg, outs):
 
 
 @autotvm.register_topi_compute("conv2d_nchw_winograd_without_weight_transform.cuda")
-def conv2d_nchw_winograd_without_weight_transform(cfg, data, kernel, strides,
-                                                  padding, dilation, out_dtype):
-    return winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
-                         pre_computed=True)
+def conv2d_nchw_winograd_without_weight_transform(
+    cfg, data, kernel, strides, padding, dilation, out_dtype
+):
+    return winograd_cuda(
+        cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=True
+    )
 
 
 @autotvm.register_topi_schedule("conv2d_nchw_winograd_without_weight_transform.cuda")
@@ -310,7 +337,7 @@ def schedule_conv2d_nchw_winograd_without_weight_transform(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'conv2d_nchw_winograd' in op.tag:
+        if "conv2d_nchw_winograd" in op.tag:
             schedule_winograd_cuda(cfg, s, op.output(0), pre_computed=True)
 
     traverse_inline(s, outs[0].op, _callback)
index f244c65..98f351b 100644
@@ -26,7 +26,7 @@ from .conv3d_direct import schedule_direct_conv3d_cuda
 
 
 @autotvm.register_topi_compute("conv3d_ncdhw.cuda")
-def conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'):
+def conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"):
     """Conv3D operator in NCDHW layout for cuda backend.
 
     Parameters
@@ -82,16 +82,15 @@ def schedule_conv3d_ncdhw(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == 'conv3d_ncdhw':
-            schedule_direct_conv3d_cuda(cfg, s, op.output(0), "NCDHW",
-                                        "conv3d_ncdhw.cuda")
+        if op.tag == "conv3d_ncdhw":
+            schedule_direct_conv3d_cuda(cfg, s, op.output(0), "NCDHW", "conv3d_ncdhw.cuda")
 
     traverse_inline(s, outs[0].op, _callback)
     return s
 
 
 @autotvm.register_topi_compute("conv3d_ndhwc.cuda")
-def conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'):
+def conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"):
     """Conv3d operator in NDHWC layout for cuda backend.
 
     Parameters
@@ -141,17 +140,17 @@ def schedule_conv3d_ndhwc(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == 'conv3d_ndhwc':
-            schedule_direct_conv3d_cuda(cfg, s, op.output(0), "NDHWC",
-                                        "conv3d_ndhwc.cuda")
+        if op.tag == "conv3d_ndhwc":
+            schedule_direct_conv3d_cuda(cfg, s, op.output(0), "NDHWC", "conv3d_ndhwc.cuda")
 
     traverse_inline(s, outs[0].op, _callback)
     return s
 
 
 @autotvm.register_topi_compute("conv3d_cudnn.cuda")
-def conv3d_cudnn(cfg, data, kernel, strides, padding, dilation, layout='NCDHW',
-                 out_dtype='float32'):
+def conv3d_cudnn(
+    cfg, data, kernel, strides, padding, dilation, layout="NCDHW", out_dtype="float32"
+):
     """Conv3D operator for cuda backend.
 
     Parameters
@@ -185,38 +184,52 @@ def conv3d_cudnn(cfg, data, kernel, strides, padding, dilation, layout='NCDHW',
     output : tvm.te.Tensor
         5-D with shape [batch, out_channel, out_depth, out_height, out_width]
     """
-    if layout == 'NCDHW':
-        tensor_format = 0 # CUDNN_TENSOR_NCHW
+    if layout == "NCDHW":
+        tensor_format = 0  # CUDNN_TENSOR_NCHW
         N, _, D, H, W = get_const_tuple(data.shape)
-    elif layout == 'NDHWC':
-        tensor_format = 1 # CUDNN_TENSOR_NHWC
+    elif layout == "NDHWC":
+        tensor_format = 1  # CUDNN_TENSOR_NHWC
         N, D, H, W, _ = get_const_tuple(data.shape)
     else:
         raise ValueError("Unsupported layout %s in cudnn" % layout)
     CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
 
     # handle dilation
-    stride_d, stride_h, stride_w = (strides, strides, strides) if isinstance(strides, int) \
-        else strides
+    stride_d, stride_h, stride_w = (
+        (strides, strides, strides) if isinstance(strides, int) else strides
+    )
     pad_d, pad_h, pad_w = (padding, padding, padding) if isinstance(padding, int) else padding
-    dilation_d, dilation_h, dilation_w = (dilation, dilation, dilation) if \
-        isinstance(dilation, int) else dilation
+    dilation_d, dilation_h, dilation_w = (
+        (dilation, dilation, dilation) if isinstance(dilation, int) else dilation
+    )
 
     OD = (D + 2 * pad_d - KD) // stride_d + 1
     OH = (H + 2 * pad_h - KH) // stride_h + 1
     OW = (W + 2 * pad_w - KW) // stride_w + 1
-    cfg.add_flop(2 * N * OD * OH * OW * CO * CI * ((KD - 1) * dilation_d + 1) * \
-                 ((KH - 1) * dilation_h + 1) * ((KW - 1) * dilation_w + 1))
-
-    return cudnn.conv_forward(data,
-                              kernel,
-                              [pad_d, pad_h, pad_w],
-                              [stride_d, stride_h, stride_w],
-                              [dilation_d, dilation_h, dilation_w],
-                              conv_mode=1,
-                              tensor_format=tensor_format,
-                              algo=-1,         # let CUDNN choose the best algo
-                              conv_dtype=dtype)
+    cfg.add_flop(
+        2
+        * N
+        * OD
+        * OH
+        * OW
+        * CO
+        * CI
+        * ((KD - 1) * dilation_d + 1)
+        * ((KH - 1) * dilation_h + 1)
+        * ((KW - 1) * dilation_w + 1)
+    )
+
+    return cudnn.conv_forward(
+        data,
+        kernel,
+        [pad_d, pad_h, pad_w],
+        [stride_d, stride_h, stride_w],
+        [dilation_d, dilation_h, dilation_w],
+        conv_mode=1,
+        tensor_format=tensor_format,
+        algo=-1,  # let CUDNN choose the best algo
+        conv_dtype=dtype,
+    )
 
 
 @autotvm.register_topi_schedule("conv3d_cudnn.cuda")
index fbda456..2dfba50 100644
@@ -27,7 +27,8 @@ from .. import nn
 from ..util import get_const_tuple
 from .conv3d_winograd import _infer_tile_size
 
-logger = logging.getLogger('topi')
+logger = logging.getLogger("topi")
+
 
 @nn.conv3d_alter_layout.register(["cuda", "gpu"])
 def _alter_conv3d_layout(attrs, inputs, tinfos, out_type):
@@ -35,7 +36,8 @@ def _alter_conv3d_layout(attrs, inputs, tinfos, out_type):
     dispatch_ctx = autotvm.task.DispatchContext.current
 
     _, outs = relay.backend.compile_engine.select_implementation(
-        relay.op.get("nn.conv3d"), attrs, tinfos, out_type, target)
+        relay.op.get("nn.conv3d"), attrs, tinfos, out_type, target
+    )
     workload = autotvm.task.get_workload(outs)
     if workload is None:
         # The best implementation is not an AutoTVM template,
@@ -52,7 +54,7 @@ def _alter_conv3d_layout(attrs, inputs, tinfos, out_type):
     strides = attrs.get_int_tuple("strides")
     padding = attrs.get_int_tuple("padding")
     dilation = attrs.get_int_tuple("dilation")
-    groups = attrs.get_int('groups')
+    groups = attrs.get_int("groups")
     data_layout = attrs["data_layout"]
     kernel_layout = attrs["kernel_layout"]
     data, kernel = tinfos
@@ -71,8 +73,8 @@ def _alter_conv3d_layout(attrs, inputs, tinfos, out_type):
         tile_size = _infer_tile_size(tinfos[0], tinfos[1])
 
         weight = relay.nn.contrib_conv3d_winograd_weight_transform(inputs[1], tile_size=tile_size)
-        new_attrs['tile_size'] = tile_size
-        new_attrs['channels'] = CO
+        new_attrs["tile_size"] = tile_size
+        new_attrs["channels"] = CO
 
         # Store the same config for the altered operators (workload)
         new_data = data
@@ -80,16 +82,19 @@ def _alter_conv3d_layout(attrs, inputs, tinfos, out_type):
         if 2 < KD < 8 and KD == KH:
             new_weight = te.placeholder(
                 (KD + tile_size - 1, KH + tile_size - 1, KW + tile_size - 1, CO, CI),
-                dtype=kernel.dtype)
+                dtype=kernel.dtype,
+            )
         else:
             new_weight = te.placeholder(
-                (KH + tile_size - 1, KW + tile_size - 1, KD, CO, CI),
-                dtype=kernel.dtype)
+                (KH + tile_size - 1, KW + tile_size - 1, KD, CO, CI), dtype=kernel.dtype
+            )
         new_workload = autotvm.task.args_to_workload(
             [new_data, new_weight, strides, padding, dilation, out_dtype],
-            "conv3d_ncdhw_winograd_without_weight_transform.cuda")
+            "conv3d_ncdhw_winograd_without_weight_transform.cuda",
+        )
         dispatch_ctx.update(target, new_workload, cfg)
         return relay.nn.contrib_conv3d_winograd_without_weight_transform(
-            inputs[0], weight, **new_attrs)
+            inputs[0], weight, **new_attrs
+        )
 
     return None
index e3dd6f9..aa13e6b 100644
@@ -21,6 +21,7 @@ from tvm import te
 from tvm import autotvm
 from ..util import get_const_tuple
 
+
 def schedule_direct_conv3d_cuda(cfg, s, conv, layout, workload_name):
     """schedule optimized for batch size = 1"""
 
@@ -43,35 +44,34 @@ def schedule_direct_conv3d_cuda(cfg, s, conv, layout, workload_name):
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
 
     target = tvm.target.Target.current()
-    if target.kind.name in ['nvptx', 'rocm']:
+    if target.kind.name in ["nvptx", "rocm"]:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
 
     # fallback support
     if cfg.is_fallback:
-        ref_log = autotvm.tophub.load_reference_log(
-            target.kind.name, target.model, workload_name)
+        ref_log = autotvm.tophub.load_reference_log(target.kind.name, target.model, workload_name)
         cfg.fallback_with_reference_log(ref_log)
     ##### space definition end #####
 
     pad_data, kernel = s[conv].op.input_tensors
 
     s[pad_data].compute_inline()
-    if isinstance(kernel.op, tvm.te.ComputeOp) and 'dilate' in kernel.op.tag:
+    if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
         s[kernel].compute_inline()
 
     if conv.op in s.outputs:
         output = conv
-        OL = s.cache_write(conv, 'local')
+        OL = s.cache_write(conv, "local")
     else:
         output = s.outputs[0].output(0)
-        s[conv].set_scope('local')
+        s[conv].set_scope("local")
         OL = conv
 
     # create cache stage
-    AA = s.cache_read(pad_data, 'shared', [OL])
-    WW = s.cache_read(kernel, 'shared', [OL])
+    AA = s.cache_read(pad_data, "shared", [OL])
+    WW = s.cache_read(kernel, "shared", [OL])
 
     # tile and bind spatial axes
     n, f, d, y, x = s[output].op.axis
@@ -100,10 +100,10 @@ def schedule_direct_conv3d_cuda(cfg, s, conv, layout, workload_name):
     # tile reduction axes
     n, f, d, y, x = s[OL].op.axis
     rc, rd, ry, rx = s[OL].op.reduce_axis
-    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
-    rdo, rdi = cfg['tile_rd'].apply(s, OL, rd)
-    ryo, ryi = cfg['tile_ry'].apply(s, OL, ry)
-    rxo, rxi = cfg['tile_rx'].apply(s, OL, rx)
+    rco, rci = cfg["tile_rc"].apply(s, OL, rc)
+    rdo, rdi = cfg["tile_rd"].apply(s, OL, rd)
+    ryo, ryi = cfg["tile_ry"].apply(s, OL, ry)
+    rxo, rxi = cfg["tile_rx"].apply(s, OL, rx)
     s[OL].reorder(rco, rdo, ryo, rxo, rci, rdi, ryi, rxi, n, f, d, y, x)
 
     s[AA].compute_at(s[OL], rxo)
@@ -122,8 +122,8 @@ def schedule_direct_conv3d_cuda(cfg, s, conv, layout, workload_name):
         s[load].bind(tx, te.thread_axis("threadIdx.x"))
 
     # unroll
-    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
-    s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
 
     N, CO, OD, OH, OW = get_const_tuple(output.shape)
     _, KD, KH, KW, CI = get_const_tuple(kernel.shape)
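
The direct-conv3d schedule reformatted above is the standard TOPI GPU template: cache_write the convolution into "local" registers, cache_read its inputs into "shared" memory, tile and bind the spatial axes to blockIdx/threadIdx, then compute_at the shared-memory stages on an outer reduction axis before applying the unroll pragmas. The following is a minimal, self-contained sketch of the same pattern on a toy matmul, using only te calls that appear in the hunk; the shapes and split factors are illustrative, not a tuned schedule.

import tvm
from tvm import te

# Toy GEMM standing in for the convolution: one reduction axis instead of four.
n = 1024
A = te.placeholder((n, n), name="A")
B = te.placeholder((n, n), name="B")
k = te.reduce_axis((0, n), name="k")
C = te.compute((n, n), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")

s = te.create_schedule(C.op)
CL = s.cache_write(C, "local")        # accumulate in registers
AA = s.cache_read(A, "shared", [CL])  # stage tiles of A in shared memory
BB = s.cache_read(B, "shared", [CL])

i, j = s[C].op.axis
bi, ti = s[C].split(i, factor=16)
bj, tj = s[C].split(j, factor=16)
s[C].reorder(bi, bj, ti, tj)
s[C].bind(bi, te.thread_axis("blockIdx.y"))
s[C].bind(bj, te.thread_axis("blockIdx.x"))
s[C].bind(ti, te.thread_axis("threadIdx.y"))
s[C].bind(tj, te.thread_axis("threadIdx.x"))

s[CL].compute_at(s[C], tj)
(ko,) = s[CL].op.reduce_axis
ko, ki = s[CL].split(ko, factor=8)
s[AA].compute_at(s[CL], ko)           # refill shared memory once per k-tile
s[BB].compute_at(s[CL], ko)

print(tvm.lower(s, [A, B, C], simple_mode=True))
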
index bc4f0e1..b253130 100644
@@ -47,18 +47,22 @@ def ndhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dty
 
     batch, in_depth, in_height, in_width, in_channel = get_const_tuple(Input.shape)
     kernel_d, kernel_h, kernel_w, _, num_filter = get_const_tuple(Filter.shape)
-    assert (batch % 16 == 0 and in_channel % 16 == 0 and num_filter % 16 == 0) or \
-               (batch % 8 == 0 and in_channel % 16 == 0 and num_filter % 32 == 0) or \
-               (batch % 32 == 0 and in_channel % 16 == 0 and num_filter % 8 == 0), \
-               "The shape of (batch, in_channel, num_filter) "\
-               "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
+    assert (
+        (batch % 16 == 0 and in_channel % 16 == 0 and num_filter % 16 == 0)
+        or (batch % 8 == 0 and in_channel % 16 == 0 and num_filter % 32 == 0)
+        or (batch % 32 == 0 and in_channel % 16 == 0 and num_filter % 8 == 0)
+    ), (
+        "The shape of (batch, in_channel, num_filter) "
+        "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
+    )
 
     # compute the output shape
     dilated_kernel_d = (kernel_d - 1) * dilation_d + 1
     dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
     dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
     pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
-        padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w)
+    )
     out_channel = num_filter
     out_depth = simplify((in_depth - dilated_kernel_d + pad_front + pad_back) // stride_d + 1)
     out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
@@ -66,27 +70,33 @@ def ndhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dty
     pad_before = [0, pad_front, pad_top, pad_left, 0]
     pad_after = [0, pad_back, pad_down, pad_right, 0]
     PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
-    rc = te.reduce_axis((0, in_channel), name='rc')
-    rz = te.reduce_axis((0, kernel_d), name='rz')
-    ry = te.reduce_axis((0, kernel_h), name='ry')
-    rx = te.reduce_axis((0, kernel_w), name='rx')
+    rc = te.reduce_axis((0, in_channel), name="rc")
+    rz = te.reduce_axis((0, kernel_d), name="rz")
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
     # convert data type of input feature maps and weights
     TransPaddedInput = te.compute(
-        PaddedInput.shape,
-        lambda n, d, h, w, c: PaddedInput[n, d, h, w, c].astype('float16'))
+        PaddedInput.shape, lambda n, d, h, w, c: PaddedInput[n, d, h, w, c].astype("float16")
+    )
     TransFilter = te.compute(
-        Filter.shape, lambda d, h, w, i, o: Filter[d, h, w, i, o].astype('float16'))
+        Filter.shape, lambda d, h, w, i, o: Filter[d, h, w, i, o].astype("float16")
+    )
     Output = te.compute(
         (batch, out_depth, out_height, out_width, out_channel),
         lambda nn, zz, yy, xx, ff: te.sum(
-            TransPaddedInput[nn,
-                             zz * stride_d + rz * dilation_d,
-                             yy * stride_h + ry * dilation_h,
-                             xx * stride_w + rx * dilation_w,
-                             rc].astype(out_dtype) *
-            TransFilter[rz, ry, rx, rc, ff].astype(out_dtype),
-            axis=[rz, ry, rx, rc]),
-        name="Conv3dOutput", tag="conv3d_ndhwc_tensorcore")
+            TransPaddedInput[
+                nn,
+                zz * stride_d + rz * dilation_d,
+                yy * stride_h + ry * dilation_h,
+                xx * stride_w + rx * dilation_w,
+                rc,
+            ].astype(out_dtype)
+            * TransFilter[rz, ry, rx, rc, ff].astype(out_dtype),
+            axis=[rz, ry, rx, rc],
+        ),
+        name="Conv3dOutput",
+        tag="conv3d_ndhwc_tensorcore",
+    )
     return Output
 
 
@@ -106,19 +116,19 @@ def schedule_ndhwc_tensorcore_cuda(cfg, s, Conv):
     s[paddata[0]].compute_inline()
 
     # Designate the memory hierarchy
-    AS = s.cache_read(trans_paddata, 'shared', [Conv])
-    WS = s.cache_read(kernel, 'shared', [Conv])
-    AF = s.cache_read(AS, 'wmma.matrix_a', [Conv])
-    WF = s.cache_read(WS, 'wmma.matrix_b', [Conv])
-    ConvF = s.cache_write(Conv, 'wmma.accumulator')
+    AS = s.cache_read(trans_paddata, "shared", [Conv])
+    WS = s.cache_read(kernel, "shared", [Conv])
+    AF = s.cache_read(AS, "wmma.matrix_a", [Conv])
+    WF = s.cache_read(WS, "wmma.matrix_b", [Conv])
+    ConvF = s.cache_write(Conv, "wmma.accumulator")
 
     if Conv.op in s.outputs:
         output = Conv
-        ConvS = s.cache_read(ConvF, 'shared', [Conv])
+        ConvS = s.cache_read(ConvF, "shared", [Conv])
         OL = ConvS
     else:
         output = s.outputs[0].output(0)
-        s[Conv].set_scope('shared')
+        s[Conv].set_scope("shared")
         OL = Conv
 
     # Schedule for autotvm
@@ -130,18 +140,19 @@ def schedule_ndhwc_tensorcore_cuda(cfg, s, Conv):
     cfg.define_knob("offset", [0, 8])
     cfg.define_knob("vector_width", [1, 2, 4, 8])
 
-    if (batch % 16 == 0 and out_channels % 16 == 0):
+    if batch % 16 == 0 and out_channels % 16 == 0:
         cfg.define_knob("wmma_m", [16, 8, 32])
-    elif (batch % 8 == 0 and out_channels % 32 == 0):
+    elif batch % 8 == 0 and out_channels % 32 == 0:
         cfg.define_knob("wmma_m", [8, 16, 32])
-    elif (batch % 32 == 0 and out_channels % 8 == 0):
+    elif batch % 32 == 0 and out_channels % 8 == 0:
         cfg.define_knob("wmma_m", [32, 16, 8])
 
     # fallback support
     target = tvm.target.Target.current()
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.kind.name, target.model, 'conv3d_ndhwc_tensorcore.cuda')
+            target.kind.name, target.model, "conv3d_ndhwc_tensorcore.cuda"
+        )
         cfg.fallback_with_reference_log(ref_log)
 
     block_row_warps = cfg["block_row_warps"].val
@@ -163,12 +174,12 @@ def schedule_ndhwc_tensorcore_cuda(cfg, s, Conv):
 
     warp_size = 32
 
-    block_x = te.thread_axis('blockIdx.x')
-    block_y = te.thread_axis('blockIdx.y')
-    block_z = te.thread_axis('blockIdx.z')
-    thread_x = te.thread_axis('threadIdx.x')
-    thread_y = te.thread_axis('threadIdx.y')
-    thread_z = te.thread_axis('threadIdx.z')
+    block_x = te.thread_axis("blockIdx.x")
+    block_y = te.thread_axis("blockIdx.y")
+    block_z = te.thread_axis("blockIdx.z")
+    thread_x = te.thread_axis("threadIdx.x")
+    thread_y = te.thread_axis("threadIdx.y")
+    thread_z = te.thread_axis("threadIdx.z")
 
     # Define the intrin strides
     def get_strides(extents):
@@ -284,22 +295,38 @@ def schedule_ndhwc_tensorcore_cuda(cfg, s, Conv):
     CL_shape = (wmma_m, 1, 1, 1, wmma_n)
     CS_shape = (wmma_m, 1, 1, 1, wmma_n)
 
-    AL_gemm = te.placeholder(AL_shape, name='A', dtype=in_dtype)
-    WL_gemm = te.placeholder(WL_shape, name='B', dtype=in_dtype)
+    AL_gemm = te.placeholder(AL_shape, name="A", dtype=in_dtype)
+    WL_gemm = te.placeholder(WL_shape, name="B", dtype=in_dtype)
     k_gemm = te.reduce_axis((0, wmma_k), name="k")
-    CL_compute = te.compute(CL_shape, lambda ii, t0, t1, t2, jj:
-                            te.sum(AL_gemm[ii, t0, t1, t2, k_gemm].astype(out_dtype) * \
-                                   WL_gemm[k_gemm, jj].astype(out_dtype), axis=k_gemm),
-                            name='C')
-
-    s[AF].tensorize(nn, intrin_wmma_load_matrix_A(AL_strides, AS_strides, shape,
-                                                  "row_major", AS_shape, AL_shape, in_dtype))
-    s[WF].tensorize(ii, intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape,
-                                                  "row_major", WS_shape, WL_shape, in_dtype))
-    s[OL].tensorize(nnc, intrin_wmma_store_matrix(CS_strides, CL_strides,
-                                                  shape, out_dtype, CL_shape, CS_shape))
-    s[ConvF].tensorize(nnf, intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides,
-                                             WL_strides, CL_strides, shape))
+    CL_compute = te.compute(
+        CL_shape,
+        lambda ii, t0, t1, t2, jj: te.sum(
+            AL_gemm[ii, t0, t1, t2, k_gemm].astype(out_dtype)
+            * WL_gemm[k_gemm, jj].astype(out_dtype),
+            axis=k_gemm,
+        ),
+        name="C",
+    )
+
+    s[AF].tensorize(
+        nn,
+        intrin_wmma_load_matrix_A(
+            AL_strides, AS_strides, shape, "row_major", AS_shape, AL_shape, in_dtype
+        ),
+    )
+    s[WF].tensorize(
+        ii,
+        intrin_wmma_load_matrix_W(
+            WL_strides, WS_strides, shape, "row_major", WS_shape, WL_shape, in_dtype
+        ),
+    )
+    s[OL].tensorize(
+        nnc, intrin_wmma_store_matrix(CS_strides, CL_strides, shape, out_dtype, CL_shape, CS_shape)
+    )
+    s[ConvF].tensorize(
+        nnf,
+        intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides, WL_strides, CL_strides, shape),
+    )
 
     N, OD, OH, OW, CO = get_const_tuple(output.shape)
     KD, KH, KW, CI, _ = get_const_tuple(kernel.shape)
@@ -318,7 +345,7 @@ def schedule_conv3d_ndhwc_tensorcore(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'conv3d_ndhwc_tensorcore' in op.tag:
+        if "conv3d_ndhwc_tensorcore" in op.tag:
             schedule_ndhwc_tensorcore_cuda(cfg, s, op.output(0))
 
     traverse_inline(s, outs[0].op, _callback)
index d6ca9bc..69c0e0f 100644
@@ -26,8 +26,7 @@ from .conv3d_direct import schedule_direct_conv3d_cuda
 
 
 @autotvm.register_topi_compute("conv3d_transpose_ncdhw.cuda")
-def conv3d_transpose_ncdhw(cfg, data, kernel, stride, padding, out_dtype,
-                           output_padding):
+def conv3d_transpose_ncdhw(cfg, data, kernel, stride, padding, out_dtype, output_padding):
     """Transposed 3D convolution ncdhw forward operator.
 
     Parameters
@@ -56,71 +55,86 @@ def conv3d_transpose_ncdhw(cfg, data, kernel, stride, padding, out_dtype,
     _, out_channels, kernel_depth, kernel_height, kernel_width = get_const_tuple(kernel.shape)
     stride_depth, stride_height, stride_width = stride
     outpad_depth, outpad_height, outpad_width = output_padding
-    assert (outpad_height < stride_height and outpad_width < stride_width and
-            outpad_depth < stride_depth)
+    assert (
+        outpad_height < stride_height
+        and outpad_width < stride_width
+        and outpad_depth < stride_depth
+    )
     cfg.stride = stride
     pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = nn.get_pad_tuple3d(
-        padding, (kernel_depth, kernel_height, kernel_width))
+        padding, (kernel_depth, kernel_height, kernel_width)
+    )
 
-    out_depth = (inp_depth - 1) * stride_depth + \
-        kernel_depth - pad_front - pad_back + outpad_depth
+    out_depth = (inp_depth - 1) * stride_depth + kernel_depth - pad_front - pad_back + outpad_depth
     pad_front = kernel_depth - 1 - pad_front
     pad_back = kernel_depth - 1 - pad_back
     dilated_depth = stride_depth * (inp_depth - 1) + 1
 
-    out_width = (inp_width - 1) * stride_width + \
-        kernel_width - pad_left - pad_right + outpad_width
+    out_width = (inp_width - 1) * stride_width + kernel_width - pad_left - pad_right + outpad_width
     pad_left = kernel_width - 1 - pad_left
     pad_right = kernel_width - 1 - pad_right
     dilated_width = stride_width * (inp_width - 1) + 1
 
-    out_height = (inp_height - 1) * stride_height + \
-        kernel_height - pad_top - pad_bottom + outpad_height
+    out_height = (
+        (inp_height - 1) * stride_height + kernel_height - pad_top - pad_bottom + outpad_height
+    )
     pad_top = kernel_height - 1 - pad_top
     pad_bottom = kernel_height - 1 - pad_bottom
     dilated_height = stride_height * (inp_height - 1) + 1
 
     # compute pad
     data = te.compute(
-        (batch, inp_channels,
-         pad_front + dilated_depth + pad_back,
-         pad_top + dilated_height + pad_bottom,
-         pad_left + dilated_width + pad_right),
+        (
+            batch,
+            inp_channels,
+            pad_front + dilated_depth + pad_back,
+            pad_top + dilated_height + pad_bottom,
+            pad_left + dilated_width + pad_right,
+        ),
         lambda n, c, d, y, x: tvm.tir.if_then_else(
-            tvm.tir.all(x >= pad_left,
-                        x < pad_left + dilated_width,
-                        tvm.tir.indexmod(x - pad_left, stride_width).equal(0),
-                        y >= pad_top,
-                        y < pad_top + dilated_height,
-                        tvm.tir.indexmod(y - pad_top, stride_height).equal(0),
-                        d >= pad_front,
-                        d < pad_front + dilated_depth,
-                        tvm.tir.indexmod(d - pad_front, stride_depth).equal(0)),
-            data[n, c,
-                 tvm.tir.indexdiv(d - pad_front, stride_depth),
-                 tvm.tir.indexdiv(y - pad_top, stride_height),
-                 tvm.tir.indexdiv(x - pad_left, stride_width)],
-            tvm.tir.const(0., "float32")),
-        name='data_pad')
+            tvm.tir.all(
+                x >= pad_left,
+                x < pad_left + dilated_width,
+                tvm.tir.indexmod(x - pad_left, stride_width).equal(0),
+                y >= pad_top,
+                y < pad_top + dilated_height,
+                tvm.tir.indexmod(y - pad_top, stride_height).equal(0),
+                d >= pad_front,
+                d < pad_front + dilated_depth,
+                tvm.tir.indexmod(d - pad_front, stride_depth).equal(0),
+            ),
+            data[
+                n,
+                c,
+                tvm.tir.indexdiv(d - pad_front, stride_depth),
+                tvm.tir.indexdiv(y - pad_top, stride_height),
+                tvm.tir.indexdiv(x - pad_left, stride_width),
+            ],
+            tvm.tir.const(0.0, "float32"),
+        ),
+        name="data_pad",
+    )
 
     # compute transposed conv
-    dc = te.reduce_axis((0, inp_channels), name='dc')
-    dd = te.reduce_axis((0, kernel_depth), name='dd')
-    dh = te.reduce_axis((0, kernel_height), name='dh')
-    dw = te.reduce_axis((0, kernel_width), name='dw')
+    dc = te.reduce_axis((0, inp_channels), name="dc")
+    dd = te.reduce_axis((0, kernel_depth), name="dd")
+    dh = te.reduce_axis((0, kernel_height), name="dh")
+    dw = te.reduce_axis((0, kernel_width), name="dw")
     data_out = te.compute(
         (batch, out_channels, out_depth, out_height, out_width),
         lambda b, c, d, h, w: te.sum(
-            data[b, dc, d + dd, h + dh, w + dw].astype(out_dtype) *
-            kernel[dc,
-                   c,
-                   kernel_depth - 1 - dd,
-                   kernel_height - 1 - dh,
-                   kernel_width - 1 - dw].astype(out_dtype),
-            axis=[dc, dd, dh, dw]), tag="conv3d_transpose_ncdhw")
+            data[b, dc, d + dd, h + dh, w + dw].astype(out_dtype)
+            * kernel[
+                dc, c, kernel_depth - 1 - dd, kernel_height - 1 - dh, kernel_width - 1 - dw
+            ].astype(out_dtype),
+            axis=[dc, dd, dh, dw],
+        ),
+        tag="conv3d_transpose_ncdhw",
+    )
 
     return data_out
 
+
 @autotvm.register_topi_schedule("conv3d_transpose_ncdhw.cuda")
 def schedule_conv3d_transpose_ncdhw(cfg, outs):
     """TOPI Schedule callback for conv3d transpose operator.
@@ -143,9 +157,10 @@ def schedule_conv3d_transpose_ncdhw(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == 'conv3d_transpose_ncdhw':
-            schedule_direct_conv3d_cuda(cfg, s, op.output(0), "NCDHW",
-                                        "conv3d_transpose_ncdhw.cuda")
+        if op.tag == "conv3d_transpose_ncdhw":
+            schedule_direct_conv3d_cuda(
+                cfg, s, op.output(0), "NCDHW", "conv3d_transpose_ncdhw.cuda"
+            )
 
     traverse_inline(s, outs[0].op, _callback)
     return s
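
The data_pad stage in the transposed-convolution compute above inserts stride-1 zeros between input elements on top of the border padding, so that, together with the flipped kernel indexing in the main compute, the transposed convolution can be evaluated as an ordinary convolution over the dilated, padded input. A 1-D sketch of that if_then_else/indexmod/indexdiv dilation pattern, with made-up sizes, is:

import tvm
from tvm import te

# 1-D illustration (made-up sizes) of the zero-insertion used by data_pad above.
n, stride = 8, 2
x = te.placeholder((n,), name="x")
dilated_extent = stride * (n - 1) + 1

dilated = te.compute(
    (dilated_extent,),
    lambda i: tvm.tir.if_then_else(
        tvm.tir.indexmod(i, stride).equal(0),
        x[tvm.tir.indexdiv(i, stride)],  # original sample at multiples of stride
        tvm.tir.const(0.0, "float32"),   # zero everywhere in between
    ),
    name="data_dilate",
)

s = te.create_schedule(dilated.op)
print(tvm.lower(s, [x, dilated], simple_mode=True))
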
index 3e6b1c1..7f4f139 100644
@@ -26,7 +26,7 @@ from .. import nn
 from ..util import get_const_int, get_const_tuple, traverse_inline, simplify
 from ..nn.winograd_util import winograd_transform_matrices
 
-logger = logging.getLogger('conv3d_winograd')
+logger = logging.getLogger("conv3d_winograd")
 
 
 def _infer_tile_size(data, kernel):
@@ -60,8 +60,14 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_
         # dilation is not supported
         alpha, _, _, CO, CI = get_const_tuple(kernel.shape)
         KD = KH = KW = alpha + 1 - tile_size
-        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and \
-               dilation_d == 1 and dilation_h == 1 and dilation_w == 1
+        assert (
+            DSTR == 1
+            and HSTR == 1
+            and WSTR == 1
+            and dilation_d == 1
+            and dilation_h == 1
+            and dilation_w == 1
+        )
 
     pf, pt, pl, pb, pd, pr = nn.get_pad_tuple3d(padding, (KD, KH, KW))
     data_pad = nn.pad(data, (0, 0, pf, pt, pl), (0, 0, pb, pd, pr), name="data_pad")
@@ -81,78 +87,91 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_
         # Check if we are currently tuning, if so we want to avoid counting
         # prepacking in time costs. Just use a placeholder with the packed shape instead.
         if autotvm.GLOBAL_SCOPE.in_tuning:
-            kernel_pack = te.placeholder((alpha, alpha, alpha, CO, CI),
-                                         dtype=kernel.dtype,
-                                         name='kernel_pack')
+            kernel_pack = te.placeholder(
+                (alpha, alpha, alpha, CO, CI), dtype=kernel.dtype, name="kernel_pack"
+            )
         else:
-            r_kd = te.reduce_axis((0, KD), name='r_kd')
-            r_kh = te.reduce_axis((0, KH), name='r_kh')
-            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            r_kd = te.reduce_axis((0, KD), name="r_kd")
+            r_kh = te.reduce_axis((0, KH), name="r_kh")
+            r_kw = te.reduce_axis((0, KW), name="r_kw")
             kernel_pack = te.compute(
                 (alpha, alpha, alpha, CO, CI),
                 lambda omg, eps, nu, co, ci: te.sum(
                     kernel[co][ci][r_kd][r_kh][r_kw] * G[omg][r_kd] * G[eps][r_kh] * G[nu][r_kw],
-                    axis=[r_kd, r_kh, r_kw]),
-                name='kernel_pack')
+                    axis=[r_kd, r_kh, r_kw],
+                ),
+                name="kernel_pack",
+            )
     else:
         kernel_pack = kernel
 
     idxdiv = tvm.tir.indexdiv
     idxmod = tvm.tir.indexmod
     # pack input tile
-    input_tile = te.compute((CI, P, alpha, alpha, alpha),
-                            lambda c, p, omg, eps, nu: data_pad[idxdiv(p, (nD * nH * nW))]
-                            [c]
-                            [idxmod(idxdiv(p, nH * nW), nD) * m + omg]
-                            [idxmod(idxdiv(p, nW), nH) * m + eps]
-                            [idxmod(p, nW) * m + nu],
-                            name='d')
+    input_tile = te.compute(
+        (CI, P, alpha, alpha, alpha),
+        lambda c, p, omg, eps, nu: data_pad[idxdiv(p, (nD * nH * nW))][c][
+            idxmod(idxdiv(p, nH * nW), nD) * m + omg
+        ][idxmod(idxdiv(p, nW), nH) * m + eps][idxmod(p, nW) * m + nu],
+        name="d",
+    )
 
     # transform data
-    r_a = te.reduce_axis((0, alpha), 'r_a')
-    r_b = te.reduce_axis((0, alpha), 'r_b')
-    r_c = te.reduce_axis((0, alpha), 'r_c')
+    r_a = te.reduce_axis((0, alpha), "r_a")
+    r_b = te.reduce_axis((0, alpha), "r_b")
+    r_c = te.reduce_axis((0, alpha), "r_c")
     data_pack = te.compute(
         (alpha, alpha, alpha, CI, P),
         lambda omg, eps, nu, ci, p: te.sum(
             input_tile[ci][p][r_a][r_b][r_c] * B[r_a][omg] * B[r_b][eps] * B[r_c][nu],
-            axis=[r_a, r_b, r_c]),
-        name='data_pack')
+            axis=[r_a, r_b, r_c],
+        ),
+        name="data_pack",
+    )
 
     # do batch gemm
-    ci = te.reduce_axis((0, CI), name='ci')
+    ci = te.reduce_axis((0, CI), name="ci")
     bgemm = te.compute(
         (alpha, alpha, alpha, CO, P),
         lambda omg, eps, nu, co, p: te.sum(
-            kernel_pack[omg][eps][nu][co][ci] * data_pack[omg][eps][nu][ci][p], axis=[ci]),
-        name='bgemm')
+            kernel_pack[omg][eps][nu][co][ci] * data_pack[omg][eps][nu][ci][p], axis=[ci]
+        ),
+        name="bgemm",
+    )
 
     # inverse transform
-    r_a = te.reduce_axis((0, alpha), 'r_a')
-    r_b = te.reduce_axis((0, alpha), 'r_b')
-    r_c = te.reduce_axis((0, alpha), 'r_c')
-    inverse = te.compute((CO, P, m, m, m),
-                         lambda co, p, vd, vh, vw: te.sum(
-                             bgemm[r_a][r_b][r_c][co][p] * A[r_a][vd] * A[r_b][vh] * A[r_c][vw],
-                             axis=[r_a, r_b, r_c]),
-                         name='inverse')
+    r_a = te.reduce_axis((0, alpha), "r_a")
+    r_b = te.reduce_axis((0, alpha), "r_b")
+    r_c = te.reduce_axis((0, alpha), "r_c")
+    inverse = te.compute(
+        (CO, P, m, m, m),
+        lambda co, p, vd, vh, vw: te.sum(
+            bgemm[r_a][r_b][r_c][co][p] * A[r_a][vd] * A[r_b][vh] * A[r_c][vw], axis=[r_a, r_b, r_c]
+        ),
+        name="inverse",
+    )
 
     # output
-    output = te.compute((N, CO, D, H, W),
-                        lambda n, co, d, h, w: inverse[co, n * nD * nH * nW + idxdiv(d, m) * nH * nW
-                                                       + idxdiv(h, m) * nW + idxdiv(w, m),
-                                                       idxmod(d, m),
-                                                       idxmod(h, m),
-                                                       idxmod(w, m)],
-                        name='output',
-                        tag='conv3d_ncdhw_winograd')
+    output = te.compute(
+        (N, CO, D, H, W),
+        lambda n, co, d, h, w: inverse[
+            co,
+            n * nD * nH * nW + idxdiv(d, m) * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
+            idxmod(d, m),
+            idxmod(h, m),
+            idxmod(w, m),
+        ],
+        name="output",
+        tag="conv3d_ncdhw_winograd",
+    )
     cfg.add_flop(2 * N * CO * D * H * W * CI * KD * KH * KW)
 
     return output
 
 
-def winograd_without_depth_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
-                                pre_computed):
+def winograd_without_depth_cuda(
+    cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed
+):
     """Compute declaration for winograd without transforming depth"""
     tile_size = _infer_tile_size(data, kernel)
 
@@ -188,7 +207,7 @@ def winograd_without_depth_cuda(cfg, data, kernel, strides, padding, dilation, o
 
     H = (H + pt + pd - KH) // HSTR + 1
     W = (W + pl + pr - KW) // WSTR + 1
-    nH, nW = (H + m-1) // m, (W + m-1) // m
+    nH, nW = (H + m - 1) // m, (W + m - 1) // m
     P = N * nH * nW
 
     # transform kernel
@@ -196,58 +215,76 @@ def winograd_without_depth_cuda(cfg, data, kernel, strides, padding, dilation, o
         # During autotuning dont count kernel packing as a time cost
         # as it will later be removed via alter_op_layout.
         if autotvm.GLOBAL_SCOPE.in_tuning:
-            kernel_pack = te.placeholder((alpha, alpha, KD, CO, CI),
-                                         dtype=kernel.dtype,
-                                         name='kernel_pack')
+            kernel_pack = te.placeholder(
+                (alpha, alpha, KD, CO, CI), dtype=kernel.dtype, name="kernel_pack"
+            )
         else:
-            r_kh = te.reduce_axis((0, KH), name='r_kh')
-            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            r_kh = te.reduce_axis((0, KH), name="r_kh")
+            r_kw = te.reduce_axis((0, KW), name="r_kw")
             kernel_pack = te.compute(
                 (alpha, alpha, KD, CO, CI),
                 lambda eps, nu, d, co, ci: te.sum(
-                    kernel[co][ci][d][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]),
-                name='kernel_pack')
+                    kernel[co][ci][d][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]
+                ),
+                name="kernel_pack",
+            )
     else:
         kernel_pack = kernel
 
     idxdiv = tvm.tir.indexdiv
     idxmod = tvm.tir.indexmod
     # pack input tile
-    input_tile = te.compute((CI, D, P, alpha, alpha), lambda c, d, p, eps, nu:
-                            data_pad[idxdiv(p, (nH * nW))][c][d]
-                            [idxmod(idxdiv(p, nW), nH) * m + eps]
-                            [idxmod(p, nW) * m + nu], name='d')
+    input_tile = te.compute(
+        (CI, D, P, alpha, alpha),
+        lambda c, d, p, eps, nu: data_pad[idxdiv(p, (nH * nW))][c][d][
+            idxmod(idxdiv(p, nW), nH) * m + eps
+        ][idxmod(p, nW) * m + nu],
+        name="d",
+    )
 
     # transform data
-    r_a = te.reduce_axis((0, alpha), 'r_a')
-    r_b = te.reduce_axis((0, alpha), 'r_b')
-    data_pack = te.compute((alpha, alpha, CI, D, P), lambda eps, nu, ci, d, p:
-                           te.sum(input_tile[ci][d][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu],
-                                  axis=[r_a, r_b]), name='data_pack')
+    r_a = te.reduce_axis((0, alpha), "r_a")
+    r_b = te.reduce_axis((0, alpha), "r_b")
+    data_pack = te.compute(
+        (alpha, alpha, CI, D, P),
+        lambda eps, nu, ci, d, p: te.sum(
+            input_tile[ci][d][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]
+        ),
+        name="data_pack",
+    )
 
     # do batch gemm
-    ci = te.reduce_axis((0, CI), name='ci')
-    rz = te.reduce_axis((0, KD), name='rz')
-    bgemm = te.compute((alpha, alpha, CO, out_depth, P), lambda eps, nu, co, d, p:
-                       te.sum(kernel_pack[eps][nu][rz][co][ci] *
-                              data_pack[eps][nu][ci][d * DSTR + rz][p],
-                              axis=[ci, rz]), name='bgemm')
+    ci = te.reduce_axis((0, CI), name="ci")
+    rz = te.reduce_axis((0, KD), name="rz")
+    bgemm = te.compute(
+        (alpha, alpha, CO, out_depth, P),
+        lambda eps, nu, co, d, p: te.sum(
+            kernel_pack[eps][nu][rz][co][ci] * data_pack[eps][nu][ci][d * DSTR + rz][p],
+            axis=[ci, rz],
+        ),
+        name="bgemm",
+    )
 
     # inverse transform
-    r_a = te.reduce_axis((0, alpha), 'r_a')
-    r_b = te.reduce_axis((0, alpha), 'r_b')
-    inverse = te.compute((CO, out_depth, P, m, m), lambda co, d, p, vh, vw:
-                         te.sum(bgemm[r_a][r_b][co][d][p] * A[r_a][vh] * A[r_b][vw],
-                                axis=[r_a, r_b]), name='inverse')
+    r_a = te.reduce_axis((0, alpha), "r_a")
+    r_b = te.reduce_axis((0, alpha), "r_b")
+    inverse = te.compute(
+        (CO, out_depth, P, m, m),
+        lambda co, d, p, vh, vw: te.sum(
+            bgemm[r_a][r_b][co][d][p] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]
+        ),
+        name="inverse",
+    )
 
     # output
-    output = te.compute((N, CO, out_depth, H, W), lambda n, co, d, h, w:
-                        inverse[co,
-                                d,
-                                n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
-                                idxmod(h, m),
-                                idxmod(w, m)],
-                        name='output', tag='conv3d_ncdhw_winograd_without_depth')
+    output = te.compute(
+        (N, CO, out_depth, H, W),
+        lambda n, co, d, h, w: inverse[
+            co, d, n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), idxmod(h, m), idxmod(w, m)
+        ],
+        name="output",
+        tag="conv3d_ncdhw_winograd_without_depth",
+    )
     cfg.add_flop(2 * N * CO * D * H * W * CI * KD * KH * KW)
 
     return output
@@ -265,7 +302,7 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
     # data transform
     s[B].compute_inline()
 
-    data_l = s.cache_write(data_pack, 'local')
+    data_l = s.cache_write(data_pack, "local")
     omg, eps, nu, c, p = s[data_l].op.axis
     r_a, r_b, r_c = s[data_l].op.reduce_axis
     # TODO unrolling by omg, eps, nu may improve performance but
@@ -315,13 +352,14 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
         "tile_b",
         cfg.axis(alpha * alpha * alpha),
         num_outputs=4,
-        filter=lambda x: x.size[-3:] == [1, 1, 1])
+        filter=lambda x: x.size[-3:] == [1, 1, 1],
+    )
     cfg.define_split("tile_y", y, num_outputs=4)
     cfg.define_split("tile_x", x, num_outputs=4)
     cfg.define_split("tile_rc", rc, num_outputs=2)
     cfg.define_knob("auto_unroll_max_step", [0, 128, 1500])
     target = tvm.target.Target.current()
-    if target.kind.name in ['nvptx', 'rocm']:
+    if target.kind.name in ["nvptx", "rocm"]:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
@@ -331,9 +369,9 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
     C = bgemm
     A0, B0 = kernel_pack, data_pack
 
-    OL = s.cache_write(C, 'local')
-    AA = s.cache_read(A0, 'shared', [OL])
-    BB = s.cache_read(B0, 'shared', [OL])
+    OL = s.cache_write(C, "local")
+    AA = s.cache_read(A0, "shared", [OL])
+    BB = s.cache_read(B0, "shared", [OL])
 
     b = s[bgemm].fuse(b1, b2, b3)
 
@@ -357,8 +395,8 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
     s[OL].compute_at(s[C], tx)
     b1, b2, b3, y, x = s[OL].op.axis
     b = s[OL].fuse(b1, b2, b3)
-    rc, = s[OL].op.reduce_axis
-    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
+    (rc,) = s[OL].op.reduce_axis
+    rco, rci = cfg["tile_rc"].apply(s, OL, rc)
     s[OL].reorder(rco, rci, b, y, x)
 
     s[AA].compute_at(s[OL], rco)
@@ -374,15 +412,15 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
         s[load].bind(ty, te.thread_axis("threadIdx.y"))
         s[load].bind(tx, te.thread_axis("threadIdx.x"))
 
-    s[C].pragma(bgemm_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
-    s[C].pragma(bgemm_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+    s[C].pragma(bgemm_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[C].pragma(bgemm_scope, "unroll_explicit", cfg["unroll_explicit"].val)
 
     # schedule inverse, output and fusion
     if output.op in s.outputs:
         OL = None
     else:
         OL = output
-        s[OL].set_scope('local')
+        s[OL].set_scope("local")
         output = s.outputs[0]
 
     m = alpha - 3 + 1
@@ -425,7 +463,7 @@ def schedule_winograd_no_depth_cuda(cfg, s, output, pre_computed):
     # data transform
     s[B].compute_inline()
 
-    data_l = s.cache_write(data_pack, 'local')
+    data_l = s.cache_write(data_pack, "local")
     eps, nu, c, d, p = s[data_l].op.axis
     r_a, r_b = s[data_l].op.reduce_axis
     for axis in [eps, nu, r_a, r_b]:
@@ -470,15 +508,16 @@ def schedule_winograd_no_depth_cuda(cfg, s, output, pre_computed):
     rz = s[bgemm].op.reduce_axis[1]
     alpha = get_const_int(b1.dom.extent)
 
-    cfg.define_split("tile_b", cfg.axis(alpha * alpha), num_outputs=4,
-                     filter=lambda x: x.size[-3:] == [1, 1, 1])
+    cfg.define_split(
+        "tile_b", cfg.axis(alpha * alpha), num_outputs=4, filter=lambda x: x.size[-3:] == [1, 1, 1]
+    )
     cfg.define_split("tile_y", y, num_outputs=4)
     cfg.define_split("tile_x", x, num_outputs=4)
     cfg.define_split("tile_rc", rc, num_outputs=2)
     cfg.define_split("tile_rz", rz, num_outputs=2)
     cfg.define_knob("auto_unroll_max_step", [0, 128, 1500])
     target = tvm.target.Target.current()
-    if target.kind.name in ['nvptx', 'rocm']:
+    if target.kind.name in ["nvptx", "rocm"]:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
@@ -488,9 +527,9 @@ def schedule_winograd_no_depth_cuda(cfg, s, output, pre_computed):
     C = bgemm
     A0, B0 = kernel_pack, data_pack
 
-    OL = s.cache_write(C, 'local')
-    AA = s.cache_read(A0, 'shared', [OL])
-    BB = s.cache_read(B0, 'shared', [OL])
+    OL = s.cache_write(C, "local")
+    AA = s.cache_read(A0, "shared", [OL])
+    BB = s.cache_read(B0, "shared", [OL])
 
     b = s[bgemm].fuse(b1, b2)
     # Allow two different tiling strategies as both seem
@@ -500,7 +539,7 @@ def schedule_winograd_no_depth_cuda(cfg, s, output, pre_computed):
     bgemm_scope, b = s[bgemm].split(b, nparts=1)
     bz, vz, tz, zi = cfg["tile_b"].apply(s, C, b)
     by, vy, ty, yi = cfg["tile_y"].apply(s, C, z)
-    if cfg['unroll_axis'].val:
+    if cfg["unroll_axis"].val:
         bx, vx, tx, xi = cfg["tile_x"].apply(s, C, y)
     else:
         bx, vx, tx, xi = cfg["tile_x"].apply(s, C, x)
@@ -514,7 +553,7 @@ def schedule_winograd_no_depth_cuda(cfg, s, output, pre_computed):
     s[C].bind(ty, te.thread_axis("threadIdx.y"))
     s[C].bind(tx, te.thread_axis("threadIdx.x"))
     s[C].reorder(bgemm_scope, bz, by, bx, vz, vy, vx, tz, ty, tx, zi, yi, xi)
-    if cfg['unroll_axis'].val:
+    if cfg["unroll_axis"].val:
         s[C].unroll(x)
     else:
         s[C].unroll(y)
@@ -525,8 +564,8 @@ def schedule_winograd_no_depth_cuda(cfg, s, output, pre_computed):
     y = s[OL].fuse(y1, y2)
     b = s[OL].fuse(b1, b2)
     rc, rz = s[OL].op.reduce_axis
-    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
-    rzo, rzi = cfg['tile_rz'].apply(s, OL, rz)
+    rco, rci = cfg["tile_rc"].apply(s, OL, rc)
+    rzo, rzi = cfg["tile_rz"].apply(s, OL, rz)
     s[OL].reorder(rco, rzo, rci, rzi, b, y, x)
 
     s[AA].compute_at(s[OL], rco)
@@ -542,15 +581,15 @@ def schedule_winograd_no_depth_cuda(cfg, s, output, pre_computed):
         s[load].bind(ty, te.thread_axis("threadIdx.y"))
         s[load].bind(tx, te.thread_axis("threadIdx.x"))
 
-    s[C].pragma(bgemm_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
-    s[C].pragma(bgemm_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+    s[C].pragma(bgemm_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[C].pragma(bgemm_scope, "unroll_explicit", cfg["unroll_explicit"].val)
 
     # schedule inverse, output and fusion
     if output.op in s.outputs:
         OL = None
     else:
         OL = output
-        s[OL].set_scope('local')
+        s[OL].set_scope("local")
         output = s.outputs[0]
 
     m = alpha - 3 + 1
@@ -586,10 +625,12 @@ def conv3d_ncdhw_winograd(cfg, data, kernel, strides, padding, dilation, out_dty
     # Check if we can transform depth.
     if 2 < KD < 8 and KD == KH:
         return winograd_cuda(
-            cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=False)
+            cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=False
+        )
 
     return winograd_without_depth_cuda(
-        cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=False)
+        cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=False
+    )
 
 
 @autotvm.register_topi_schedule("conv3d_ncdhw_winograd.cuda")
@@ -598,9 +639,9 @@ def schedule_conv3d_ncdhw_winograd(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'conv3d_ncdhw_winograd_without_depth' in op.tag:
+        if "conv3d_ncdhw_winograd_without_depth" in op.tag:
             schedule_winograd_no_depth_cuda(cfg, s, op.output(0), pre_computed=False)
-        elif 'conv3d_ncdhw_winograd' in op.tag:
+        elif "conv3d_ncdhw_winograd" in op.tag:
             schedule_winograd_cuda(cfg, s, op.output(0), pre_computed=False)
 
     traverse_inline(s, outs[0].op, _callback)
@@ -608,16 +649,20 @@ def schedule_conv3d_ncdhw_winograd(cfg, outs):
 
 
 @autotvm.register_topi_compute("conv3d_ncdhw_winograd_without_weight_transform.cuda")
-def conv3d_ncdhw_winograd_without_weight_transform(cfg, data, kernel, strides, padding, dilation,
-                                                   out_dtype):
+def conv3d_ncdhw_winograd_without_weight_transform(
+    cfg, data, kernel, strides, padding, dilation, out_dtype
+):
+    """Conv3d NCDHW winograd without weight transform."""
     A, B, C, _, _ = get_const_tuple(kernel.shape)
     # Check if we can transform depth.
     if A == B == C:
         return winograd_cuda(
-            cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=True)
+            cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=True
+        )
 
     return winograd_without_depth_cuda(
-        cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=True)
+        cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=True
+    )
 
 
 @autotvm.register_topi_schedule("conv3d_ncdhw_winograd_without_weight_transform.cuda")
@@ -626,9 +671,9 @@ def schedule_conv3d_ncdhw_winograd_without_weight_transform(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'conv3d_ncdhw_winograd_without_depth' in op.tag:
+        if "conv3d_ncdhw_winograd_without_depth" in op.tag:
             schedule_winograd_no_depth_cuda(cfg, s, op.output(0), pre_computed=True)
-        elif 'conv3d_ncdhw_winograd' in op.tag:
+        elif "conv3d_ncdhw_winograd" in op.tag:
             schedule_winograd_cuda(cfg, s, op.output(0), pre_computed=True)
 
     traverse_inline(s, outs[0].op, _callback)
index dbaabb7..12f5644 100644
@@ -24,8 +24,9 @@ from ..util import traverse_inline
 
 
 @autotvm.register_topi_compute("correlation_nchw.cuda")
-def correlation_nchw(cfg, data1, data2, kernel_size, max_displacement, stride1, stride2, padding,
-                     is_multiply):
+def correlation_nchw(
+    cfg, data1, data2, kernel_size, max_displacement, stride1, stride2, padding, is_multiply
+):
     """Correlation operator in NCHW layout.
 
     Parameters
@@ -62,8 +63,9 @@ def correlation_nchw(cfg, data1, data2, kernel_size, max_displacement, stride1,
         4-D with shape [batch, out_channel, out_height, out_width]
     """
     # pylint: disable=unused-argument
-    return nn.correlation_nchw(data1, data2, kernel_size, max_displacement, stride1, stride2,
-                               padding, is_multiply)
+    return nn.correlation_nchw(
+        data1, data2, kernel_size, max_displacement, stride1, stride2, padding, is_multiply
+    )
 
 
 def _schedule_correlation_nchw(cfg, s, correlation):
@@ -81,7 +83,7 @@ def _schedule_correlation_nchw(cfg, s, correlation):
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
 
     target = tvm.target.Target.current()
-    if target.kind.name in ['nvptx', 'rocm']:
+    if target.kind.name in ["nvptx", "rocm"]:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
@@ -93,9 +95,9 @@ def _schedule_correlation_nchw(cfg, s, correlation):
     s[padded_data2].compute_inline()
 
     # create cache stage
-    s[correlation].set_scope('local')
-    AA = s.cache_read(padded_data1, 'shared', [correlation])
-    BB = s.cache_read(padded_data2, 'shared', [correlation])
+    s[correlation].set_scope("local")
+    AA = s.cache_read(padded_data1, "shared", [correlation])
+    BB = s.cache_read(padded_data2, "shared", [correlation])
 
     output = s.outputs[0].output(0)
 
@@ -123,9 +125,9 @@ def _schedule_correlation_nchw(cfg, s, correlation):
     # tile reduction axes
     n, f, y, x = s[correlation].op.axis
     rc, ry, rx = s[correlation].op.reduce_axis
-    rco, rci = cfg['tile_rc'].apply(s, correlation, rc)
-    ryo, ryi = cfg['tile_ry'].apply(s, correlation, ry)
-    rxo, rxi = cfg['tile_rx'].apply(s, correlation, rx)
+    rco, rci = cfg["tile_rc"].apply(s, correlation, rc)
+    ryo, ryi = cfg["tile_ry"].apply(s, correlation, ry)
+    rxo, rxi = cfg["tile_rx"].apply(s, correlation, rx)
     s[correlation].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)
 
     s[AA].compute_at(s[correlation], rxo)
@@ -143,10 +145,8 @@ def _schedule_correlation_nchw(cfg, s, correlation):
         s[load].bind(tx, te.thread_axis("threadIdx.x"))
 
     # unroll
-    s[output].pragma(kernel_scope, 'auto_unroll_max_step',
-                     cfg['auto_unroll_max_step'].val)
-    s[output].pragma(kernel_scope, 'unroll_explicit',
-                     cfg['unroll_explicit'].val)
+    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
 
 
 @autotvm.register_topi_schedule("correlation_nchw.cuda")
@@ -171,7 +171,7 @@ def schedule_correlation_nchw(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == 'correlation_nchw':
+        if op.tag == "correlation_nchw":
             _schedule_correlation_nchw(cfg, s, op.output(0))
 
     traverse_inline(s, outs[0].op, _callback)
index d97d501..365fde5 100644
@@ -24,10 +24,14 @@ from ..util import traverse_inline
 
 
 @autotvm.register_topi_compute("deformable_conv2d_nchw.cuda")
-def deformable_conv2d_nchw(cfg, data, offset, kernel, strides, padding, dilation,
-                           deformable_groups, groups, out_dtype):
-    return nn.deformable_conv2d_nchw(data, offset, kernel, strides, padding, dilation,
-                                     deformable_groups, groups, out_dtype)
+def deformable_conv2d_nchw(
+    cfg, data, offset, kernel, strides, padding, dilation, deformable_groups, groups, out_dtype
+):
+    """Deformable Conv2d."""
+    return nn.deformable_conv2d_nchw(
+        data, offset, kernel, strides, padding, dilation, deformable_groups, groups, out_dtype
+    )
+
 
 @autotvm.register_topi_schedule("deformable_conv2d_nchw.cuda")
 def schedule_deformable_conv2d_nchw(cfg, outs):
@@ -51,7 +55,7 @@ def schedule_deformable_conv2d_nchw(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == 'deformable_conv2d_nchw':
+        if op.tag == "deformable_conv2d_nchw":
             _schedule_direct_cuda(cfg, s, op.output(0))
 
     traverse_inline(s, outs[0].op, _callback)
@@ -71,7 +75,7 @@ def _schedule_direct_cuda(cfg, s, conv):
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
 
     target = tvm.target.Target.current()
-    if target.kind.name in ['nvptx', 'rocm']:
+    if target.kind.name in ["nvptx", "rocm"]:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
@@ -79,20 +83,20 @@ def _schedule_direct_cuda(cfg, s, conv):
     data_deform, kernel = s[conv].op.input_tensors
 
     s[data_deform].compute_inline()
-    if isinstance(kernel.op, tvm.te.ComputeOp) and 'dilate' in kernel.op.tag:
+    if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
         s[kernel].compute_inline()
 
     if conv.op in s.outputs:
         output = conv
-        OL = s.cache_write(conv, 'local')
+        OL = s.cache_write(conv, "local")
     else:
         output = s.outputs[0].output(0)
-        s[conv].set_scope('local')
+        s[conv].set_scope("local")
         OL = conv
 
     # create cache stage
-    AA = s.cache_read(data_deform, 'shared', [OL])
-    WW = s.cache_read(kernel, 'shared', [OL])
+    AA = s.cache_read(data_deform, "shared", [OL])
+    WW = s.cache_read(kernel, "shared", [OL])
 
     # tile and bind spatial axes
     n, f, y, x = s[output].op.axis
@@ -118,9 +122,9 @@ def _schedule_direct_cuda(cfg, s, conv):
     # tile reduction axes
     n, f, y, x = s[OL].op.axis
     rc, ry, rx = s[OL].op.reduce_axis
-    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
-    ryo, ryi = cfg['tile_ry'].apply(s, OL, ry)
-    rxo, rxi = cfg['tile_rx'].apply(s, OL, rx)
+    rco, rci = cfg["tile_rc"].apply(s, OL, rc)
+    ryo, ryi = cfg["tile_ry"].apply(s, OL, ry)
+    rxo, rxi = cfg["tile_rx"].apply(s, OL, rx)
     s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)
     cfg.define_reorder("reorder_inner", [rco, ryo, rxo], "all")
     cfg["reorder_inner"].apply(s, OL, [rco, ryo, rxo])
@@ -141,5 +145,5 @@ def _schedule_direct_cuda(cfg, s, conv):
         s[load].bind(tx, te.thread_axis("threadIdx.x"))
 
     # unroll
-    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
-    s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
index f5b6563..727992d 100644
@@ -27,13 +27,13 @@ from .. import tag
 from .. import generic
 from ..util import traverse_inline, get_const_tuple
 
-logger = logging.getLogger('topi')
+logger = logging.getLogger("topi")
+
 
 @autotvm.register_topi_compute("dense_cublas.cuda")
 def dense_cublas(cfg, data, weight, bias=None, out_dtype=None):
     """Dense operator on CUDA with CUBLAS"""
-    assert len(data.shape) == 2 and len(weight.shape) == 2, \
-        "only support 2-dim dense"
+    assert len(data.shape) == 2 and len(weight.shape) == 2, "only support 2-dim dense"
     if bias is not None:
         assert len(bias.shape) == 1
     if out_dtype is None:
@@ -44,9 +44,9 @@ def dense_cublas(cfg, data, weight, bias=None, out_dtype=None):
     matmul = cublas.matmul(data, weight, False, True)
     cfg.add_flop(batch * in_dim * out_dim * 2)
     if bias is not None:
-        matmul = te.compute((batch, out_dim),
-                            lambda i, j: matmul[i, j] + bias[j],
-                            tag=tag.BROADCAST)
+        matmul = te.compute(
+            (batch, out_dim), lambda i, j: matmul[i, j] + bias[j], tag=tag.BROADCAST
+        )
     return matmul
 
 
@@ -69,20 +69,21 @@ def schedule_dense_small_batch(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == 'dense':
+        if op.tag == "dense":
             _schedule_dense_small_batch(cfg, s, op.output(0))
 
     traverse_inline(s, outs[0].op, _callback)
     return s
 
+
 def _schedule_dense_small_batch(cfg, s, C):
     A, _ = C.op.input_tensors
     _, in_dim = get_const_tuple(A.shape)
-    cfg.define_split('tile_k', in_dim, num_outputs=2)
+    cfg.define_split("tile_k", in_dim, num_outputs=2)
     if cfg.is_fallback:
         cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])
 
-    _, kf = cfg['tile_k'].apply(s, C, C.op.reduce_axis[0])
+    _, kf = cfg["tile_k"].apply(s, C, C.op.reduce_axis[0])
     CF = s.rfactor(C, kf)
 
     if C.op in s.outputs:
@@ -114,7 +115,7 @@ def schedule_dense_large_batch(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == 'dense':
+        if op.tag == "dense":
             _schedule_dense_large_batch(cfg, s, op.output(0))
 
     traverse_inline(s, outs[0].op, _callback)
@@ -131,39 +132,50 @@ def _schedule_dense_large_batch(cfg, s, C):
     # create tuning space
     try:
         block_cand = [64, 128]
-        vthread_cand = [2**x for x in range(1, 7)]
-        n_thread_cand = [2**x for x in range(3, 7)]
-        cfg.define_split('tile_x', batch, num_outputs=4,
-                         filter=lambda x: (x.size[1] in vthread_cand and
-                                           x.size[2] in n_thread_cand and
-                                           (x.size[1] * x.size[2] * x.size[3]) in block_cand))
-        cfg.define_split('tile_y', out_dim, num_outputs=4,
-                         filter=lambda x: (x.size[1] in vthread_cand and
-                                           x.size[2] in n_thread_cand and
-                                           (x.size[1] * x.size[2] * x.size[3]) in block_cand))
-        cfg.define_split('tile_k', in_dim, num_outputs=3, filter=lambda x: x.size[0] > 2)
+        vthread_cand = [2 ** x for x in range(1, 7)]
+        n_thread_cand = [2 ** x for x in range(3, 7)]
+        cfg.define_split(
+            "tile_x",
+            batch,
+            num_outputs=4,
+            filter=lambda x: (
+                x.size[1] in vthread_cand
+                and x.size[2] in n_thread_cand
+                and (x.size[1] * x.size[2] * x.size[3]) in block_cand
+            ),
+        )
+        cfg.define_split(
+            "tile_y",
+            out_dim,
+            num_outputs=4,
+            filter=lambda x: (
+                x.size[1] in vthread_cand
+                and x.size[2] in n_thread_cand
+                and (x.size[1] * x.size[2] * x.size[3]) in block_cand
+            ),
+        )
+        cfg.define_split("tile_k", in_dim, num_outputs=3, filter=lambda x: x.size[0] > 2)
     except IndexError:
         # Index error happens when no entities left after filtering, which was designed
         # to prune tuning space for better search efficiency.
-        logger.debug(
-            'Tuning space was created without pruning due to unfit shapes')
-        cfg.define_split('tile_x', batch, num_outputs=4)
-        cfg.define_split('tile_y', out_dim, num_outputs=4)
-        cfg.define_split('tile_k', in_dim, num_outputs=3)
+        logger.debug("Tuning space was created without pruning due to unfit shapes")
+        cfg.define_split("tile_x", batch, num_outputs=4)
+        cfg.define_split("tile_y", out_dim, num_outputs=4)
+        cfg.define_split("tile_k", in_dim, num_outputs=3)
 
     if cfg.is_fallback:
         if batch > 1:
-            cfg['tile_x'] = SplitEntity([-1, 2, 16, 2])
+            cfg["tile_x"] = SplitEntity([-1, 2, 16, 2])
         else:
-            cfg['tile_x'] = SplitEntity([1, 1, 1, 1])
+            cfg["tile_x"] = SplitEntity([1, 1, 1, 1])
         if out_dim > 1:
-            cfg['tile_y'] = SplitEntity([-1, 2, 16, 2])
+            cfg["tile_y"] = SplitEntity([-1, 2, 16, 2])
         else:
-            cfg['tile_y'] = SplitEntity([1, 1, 1, 1])
+            cfg["tile_y"] = SplitEntity([1, 1, 1, 1])
         if in_dim > 8:
-            cfg['tile_k'] = SplitEntity([-1, 8, 1])
+            cfg["tile_k"] = SplitEntity([-1, 8, 1])
         else:
-            cfg['tile_k'] = SplitEntity([-1, 1, 1])
+            cfg["tile_k"] = SplitEntity([-1, 1, 1])
 
     # Explicit memory access
     AA = s.cache_read(A, "shared", [C])
@@ -178,8 +190,8 @@ def _schedule_dense_large_batch(cfg, s, C):
         C = s.outputs[0].output(0)
 
     # Split and reorder computation
-    bx, txz, tx, xi = cfg['tile_x'].apply(s, C, C.op.axis[0])
-    by, tyz, ty, yi = cfg['tile_y'].apply(s, C, C.op.axis[1])
+    bx, txz, tx, xi = cfg["tile_x"].apply(s, C, C.op.axis[0])
+    by, tyz, ty, yi = cfg["tile_y"].apply(s, C, C.op.axis[1])
     s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)
     s[CC].compute_at(s[C], tx)
 
@@ -193,7 +205,7 @@ def _schedule_dense_large_batch(cfg, s, C):
 
     # Split reduction
     yo, xo = CC.op.axis
-    ko, kt, ki = cfg['tile_k'].apply(s, CC, k)
+    ko, kt, ki = cfg["tile_k"].apply(s, CC, k)
     s[CC].reorder(ko, kt, ki, yo, xo)
     s[AA].compute_at(s[CC], ko)
     s[BB].compute_at(s[CC], ko)
@@ -202,7 +214,7 @@ def _schedule_dense_large_batch(cfg, s, C):
     s[BL].compute_at(s[CC], kt)
 
     # Schedule for A's shared memory load
-    num_thread_x = cfg['tile_x'].size[2]
+    num_thread_x = cfg["tile_x"].size[2]
     ty, _ = s[AA].split(s[AA].op.axis[0], nparts=num_thread_x)
     _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread_x * 4)
     tx, xi = s[AA].split(xi, nparts=num_thread_x)
@@ -211,7 +223,7 @@ def _schedule_dense_large_batch(cfg, s, C):
     s[AA].double_buffer()
 
     # Schedule for B' shared memory load
-    num_thread_y = cfg['tile_y'].size[2]
+    num_thread_y = cfg["tile_y"].size[2]
     ty, _ = s[BB].split(s[BB].op.axis[0], nparts=num_thread_y)
     _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread_y * 4)
     tx, xi = s[BB].split(xi, nparts=num_thread_y)
@@ -228,19 +240,24 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
 
     batch, in_dim = get_const_tuple(data.shape)
     out_dim, _ = get_const_tuple(weight.shape)
-    k = te.reduce_axis((0, in_dim), name='k')
+    k = te.reduce_axis((0, in_dim), name="k")
 
-    matmul = te.compute((batch, out_dim),
-                        lambda i, j: te.sum(data[i, k].astype(out_dtype) *
-                                            weight[j, k].astype(out_dtype), axis=[k]),
-                        tag="dense_int8")
+    matmul = te.compute(
+        (batch, out_dim),
+        lambda i, j: te.sum(
+            data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=[k]
+        ),
+        tag="dense_int8",
+    )
 
     cfg.add_flop(batch * in_dim * out_dim * 2)
 
     if bias is not None:
-        matmul = te.compute((batch, out_dim),
-                            lambda i, j: matmul[i, j] + bias[j].astype(out_dtype),
-                            tag=tag.BROADCAST)
+        matmul = te.compute(
+            (batch, out_dim),
+            lambda i, j: matmul[i, j] + bias[j].astype(out_dtype),
+            tag=tag.BROADCAST,
+        )
         cfg.add_flop(batch * out_dim)
 
     return matmul
@@ -255,11 +272,13 @@ def schedule_dense_int8(cfg, outs):
     def _callback(op):
         if "dense_int8" in op.tag:
             _schedule_dense_int8(cfg, s, op.output(0))
+
     traverse_inline(s, outs[0].op, _callback)
     return s
 
 
-_dp4a = dp4a('shared', 'shared', 'local')
+_dp4a = dp4a("shared", "shared", "local")
+
 
 def _schedule_dense_int8(cfg, s, output):
     data, weight = s[output].op.input_tensors
@@ -276,12 +295,12 @@ def _schedule_dense_int8(cfg, s, output):
     cfg.define_split("tile_y", batch, num_outputs=4)
     cfg.define_split("tile_x", out_dim, num_outputs=4)
     cfg.define_split("tile_k", in_dim // in_dim_factor, num_outputs=2)
-    cfg.define_knob('auto_unroll_max_step', [0, 512, 1500])
+    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
 
     # create cache stage
-    AA = s.cache_read(data, 'shared', [output])
-    WW = s.cache_read(weight, 'shared', [output])
-    CC = s.cache_write(output, 'local')
+    AA = s.cache_read(data, "shared", [output])
+    WW = s.cache_read(weight, "shared", [output])
+    CC = s.cache_write(output, "local")
 
     # handle bias
     if output.op not in s.outputs:
@@ -295,20 +314,20 @@ def _schedule_dense_int8(cfg, s, output):
 
     ko = CC.op.reduce_axis[0]
     ko, ki = s[CC].split(ko, factor=4)
-    ko, kt = cfg['tile_k'].apply(s, CC, ko)
+    ko, kt = cfg["tile_k"].apply(s, CC, ko)
     s[CC].tensorize(ki, _dp4a)
-    by, vy, ty, yi = cfg['tile_y'].apply(s, output, n)
-    bx, vx, tx, xi = cfg['tile_x'].apply(s, output, x)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, output, n)
+    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
 
     s[output].reorder(by, bx, vy, vx, ty, tx, yi, xi)
-    s[output].bind(by, te.thread_axis('blockIdx.y'))
-    s[output].bind(bx, te.thread_axis('blockIdx.x'))
-    s[output].bind(vy, te.thread_axis('vthread'))
-    s[output].bind(vx, te.thread_axis('vthread'))
-    s[output].bind(ty, te.thread_axis('threadIdx.y'))
-    s[output].bind(tx, te.thread_axis('threadIdx.x'))
-    n_ty = cfg['tile_y'].size[2]
-    n_tx = cfg['tile_x'].size[2]
+    s[output].bind(by, te.thread_axis("blockIdx.y"))
+    s[output].bind(bx, te.thread_axis("blockIdx.x"))
+    s[output].bind(vy, te.thread_axis("vthread"))
+    s[output].bind(vx, te.thread_axis("vthread"))
+    s[output].bind(ty, te.thread_axis("threadIdx.y"))
+    s[output].bind(tx, te.thread_axis("threadIdx.x"))
+    n_ty = cfg["tile_y"].size[2]
+    n_tx = cfg["tile_x"].size[2]
 
     s[CC].compute_at(s[output], tx)
     yo, xo = CC.op.axis[:2]
@@ -324,9 +343,9 @@ def _schedule_dense_int8(cfg, s, output):
 
         fused, tx = s[load].split(fused, factor=n_tx)
         fused, ty = s[load].split(fused, factor=n_ty)
-        s[load].bind(tx, te.thread_axis('threadIdx.x'))
-        s[load].bind(ty, te.thread_axis('threadIdx.y'))
+        s[load].bind(tx, te.thread_axis("threadIdx.x"))
+        s[load].bind(ty, te.thread_axis("threadIdx.y"))
 
-    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
-    s[output].pragma(kernel_scope, 'unroll_explicit', False)
+    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[output].pragma(kernel_scope, "unroll_explicit", False)
     return s
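
For reference, the `dp4a` tensor intrinsic that `_schedule_dense_int8` above tensorizes the inner 4-wide reduction (`ki`) with computes a four-element int8 dot product accumulated into an int32, which on CUDA targets effectively maps to the `__dp4a` instruction. A minimal sketch of the scalar semantics, with plain Python ints standing in for the int8/int32 values (illustrative only, not part of the diff'd code):

    def dp4a_reference(acc, a4, b4):
        # acc: int32 accumulator; a4, b4: four int8 values each.
        # Same result as accumulating the 4-wide int8 dot product into acc.
        return acc + sum(int(a) * int(b) for a, b in zip(a4, b4))

    # e.g. dp4a_reference(10, [1, 2, 3, 4], [5, 6, 7, 8]) == 10 + 70 == 80
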
index bb51c40..8c7d7cc 100644 (file)
@@ -22,8 +22,12 @@ from tvm import te
 import tvm.autotvm as autotvm
 from .. import tag
 from ..util import traverse_inline, get_const_tuple
-from .tensor_intrin import intrin_wmma_load_matrix_A, \
-        intrin_wmma_load_matrix_W, intrin_wmma_store_matrix, intrin_wmma_gemm
+from .tensor_intrin import (
+    intrin_wmma_load_matrix_A,
+    intrin_wmma_load_matrix_W,
+    intrin_wmma_store_matrix,
+    intrin_wmma_gemm,
+)
 
 
 @autotvm.register_topi_compute("dense_tensorcore.cuda")
@@ -40,38 +44,47 @@ def schedule_dense_tensorcore(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == 'dense_tensorcore':
+        if op.tag == "dense_tensorcore":
             _schedule_dense_tensorcore(cfg, s, op.output(0))
+
     traverse_inline(s, outs[0].op, _callback)
     return s
 
 
 def dense_tensorcore_cuda(data, weight, bias=None, out_dtype=None):
     """Dense tensorcore operator on CUDA"""
-    assert len(data.shape) == 2 and len(weight.shape) == 2, \
-        "only support 2-dim dense"
+    assert len(data.shape) == 2 and len(weight.shape) == 2, "only support 2-dim dense"
     if bias is not None:
         assert len(bias.shape) == 1
     if out_dtype is None:
         out_dtype = data.dtype
     batch, in_dim = get_const_tuple(data.shape)
     out_dim, _ = get_const_tuple(weight.shape)
-    assert ((batch % 8 == 0 and in_dim % 16 == 0 and out_dim % 32 == 0) or \
-            (batch % 16 == 0 and in_dim % 16 == 0 and out_dim % 16 == 0) or \
-            (batch % 32 == 0 and in_dim % 16 == 0 and out_dim % 8 == 0)), \
-            "The shape of (batch, in_dim, out_dim) "\
-             "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
-    k = te.reduce_axis((0, in_dim), name='k')
-    data_16 = te.compute((batch, in_dim), lambda b, i: data[b, i].astype('float16'))
-    weight_16 = te.compute((out_dim, in_dim), lambda o, i: weight[o, i].astype('float16'))
-    matmul = te.compute((batch, out_dim), \
-                         lambda i, j: te.sum(data_16[i, k].astype(out_dtype) * \
-                                              weight_16[j, k].astype(out_dtype), axis=k), \
-                         name='T_dense', tag='dense_tensorcore')
+    assert (
+        (batch % 8 == 0 and in_dim % 16 == 0 and out_dim % 32 == 0)
+        or (batch % 16 == 0 and in_dim % 16 == 0 and out_dim % 16 == 0)
+        or (batch % 32 == 0 and in_dim % 16 == 0 and out_dim % 8 == 0)
+    ), (
+        "The shape of (batch, in_dim, out_dim) "
+        "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
+    )
+    k = te.reduce_axis((0, in_dim), name="k")
+    data_16 = te.compute((batch, in_dim), lambda b, i: data[b, i].astype("float16"))
+    weight_16 = te.compute((out_dim, in_dim), lambda o, i: weight[o, i].astype("float16"))
+    matmul = te.compute(
+        (batch, out_dim),
+        lambda i, j: te.sum(
+            data_16[i, k].astype(out_dtype) * weight_16[j, k].astype(out_dtype), axis=k
+        ),
+        name="T_dense",
+        tag="dense_tensorcore",
+    )
     if bias is not None:
-        matmul = te.compute((batch, out_dim), \
-                             lambda i, j: matmul[i, j] + bias[j].astype(out_dtype), \
-                             tag=tag.BROADCAST)
+        matmul = te.compute(
+            (batch, out_dim),
+            lambda i, j: matmul[i, j] + bias[j].astype(out_dtype),
+            tag=tag.BROADCAST,
+        )
     return matmul
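
The shape assertion above mirrors the three WMMA fragment shapes the schedule later chooses from, namely (wmma_m, wmma_n, wmma_k) of (16, 16, 16), (32, 8, 16), or (8, 32, 16): the (batch, in_dim, out_dim) triple must tile evenly by at least one of them. A small illustrative checker expressing the same condition (a sketch, not part of the codebase):

    def tensorcore_shape_ok(batch, in_dim, out_dim):
        # (batch, in_dim, out_dim) must be divisible by (m, k, n) for some
        # supported WMMA fragment shape (m, n, k).
        wmma_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]
        return any(
            batch % m == 0 and in_dim % k == 0 and out_dim % n == 0
            for m, n, k in wmma_shapes
        )

    # e.g. tensorcore_shape_ok(32, 64, 8) is True; tensorcore_shape_ok(30, 64, 8) is False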
 
 
@@ -84,18 +97,19 @@ def _schedule_dense_tensorcore(cfg, s, C):
     s[B].compute_inline()
 
     # Explicit memory access
-    AS = s.cache_read(A, 'shared', [C])
-    BS = s.cache_read(B, 'shared', [C])
-    AF = s.cache_read(AS, 'wmma.matrix_a', [C])
-    BF = s.cache_read(BS, 'wmma.matrix_b', [C])
-    CF = s.cache_write(C, 'wmma.accumulator')
-    CS = s.cache_read(CF, 'shared', [C])
+    AS = s.cache_read(A, "shared", [C])
+    BS = s.cache_read(B, "shared", [C])
+    AF = s.cache_read(AS, "wmma.matrix_a", [C])
+    BF = s.cache_read(BS, "wmma.matrix_b", [C])
+    CF = s.cache_write(C, "wmma.accumulator")
+    CS = s.cache_read(CF, "shared", [C])
 
     # fallback support
     target = tvm.target.Target.current()
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.kind.name, target.model, 'dense_tensorcore.cuda')
+            target.kind.name, target.model, "dense_tensorcore.cuda"
+        )
         cfg.fallback_with_reference_log(ref_log)
 
     # Deal with op fusion, such as bias and relu
@@ -113,12 +127,12 @@ def _schedule_dense_tensorcore(cfg, s, C):
     cfg.define_knob("offsetCS", [0, 8])
     cfg.define_knob("vec", [1, 2, 4, 8])
 
-    #Ensure that the default parameters are applicable when autotvm is not in use
-    if (batch % 32 == 0 and out_dim % 8 == 0):
+    # Ensure that the default parameters are applicable when autotvm is not in use
+    if batch % 32 == 0 and out_dim % 8 == 0:
         cfg.define_knob("wmma_m", [32, 16, 8])
-    elif (batch%16 == 0 and out_dim % 16 == 0):
+    elif batch % 16 == 0 and out_dim % 16 == 0:
         cfg.define_knob("wmma_m", [16, 8, 32])
-    elif (batch % 8 == 0 and out_dim % 32 == 0):
+    elif batch % 8 == 0 and out_dim % 32 == 0:
         cfg.define_knob("wmma_m", [8, 16, 32])
 
     warp_size = 32
@@ -140,7 +154,7 @@ def _schedule_dense_tensorcore(cfg, s, C):
     elif wmma_m == 32:
         wmma_n = 8
 
-    #Define the stride of intrin functions
+    # Define the stride of intrin functions
     AS_align = chunk * wmma_k + offset
     BS_align = chunk * wmma_k + offset
     CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS
@@ -151,13 +165,13 @@ def _schedule_dense_tensorcore(cfg, s, C):
     CF_stride = [warp_col_tiles * wmma_n, 1]
     CS_stride = [CS_align, 1]
 
-    block_x = te.thread_axis('blockIdx.x')
-    block_y = te.thread_axis('blockIdx.y')
-    thread_x = te.thread_axis('threadIdx.x')
-    thread_y = te.thread_axis('threadIdx.y')
-    thread_z = te.thread_axis('threadIdx.z')
+    block_x = te.thread_axis("blockIdx.x")
+    block_y = te.thread_axis("blockIdx.y")
+    thread_x = te.thread_axis("threadIdx.x")
+    thread_y = te.thread_axis("threadIdx.y")
+    thread_z = te.thread_axis("threadIdx.z")
 
-    #Schedule for dense computation
+    # Schedule for dense computation
     block_factor_b = wmma_m * warp_row_tiles * block_row_warps
     block_factor_o = wmma_n * warp_col_tiles * block_col_warps
     b, o = C.op.axis
@@ -176,7 +190,7 @@ def _schedule_dense_tensorcore(cfg, s, C):
     s[C].bind(tx, thread_x)
     s[C].vectorize(vi)
 
-    #Schedule for wmma store
+    # Schedule for wmma store
     s[CS].compute_at(s[C], block_j)
     bb, oo = CS.op.axis
     s[CS].storage_align(bb, CS_align - 1, CS_align)
@@ -186,31 +200,31 @@ def _schedule_dense_tensorcore(cfg, s, C):
     oo, ooii = s[CS].split(oo, factor=warp_col_tiles)
     s[CS].reorder(bb, oo, bbii, ooii, bbi, ooi)
 
-    #Schedule for wmma computation
+    # Schedule for wmma computation
     s[CF].compute_at(s[CS], oo)
     warp_i, warp_j = CF.op.axis
     warp_i, _ii = s[CF].split(warp_i, factor=wmma_m)
     warp_j, _jj = s[CF].split(warp_j, factor=wmma_n)
-    k, = CF.op.reduce_axis
+    (k,) = CF.op.reduce_axis
     k, _k = s[CF].split(k, factor=wmma_k)
     ko, ki = s[CF].split(k, factor=chunk)
     s[CF].reorder(ko, ki, warp_i, warp_j, _ii, _jj, _k)
 
-    #Schedule for  wmma_matrix_a load
+    # Schedule for wmma_matrix_a load
     s[AF].compute_at(s[CF], ki)
     b, i = AF.op.axis
     b, b_ii = s[AF].split(b, factor=wmma_m)
     i, i_jj = s[AF].split(i, factor=wmma_k)
     s[AF].reorder(b, i, b_ii, i_jj)
 
-    #Schedule for  wmma_matrix_b load
+    # Schedule for wmma_matrix_b load
     s[BF].compute_at(s[CF], ki)
     o, i = BF.op.axis
     o, o_ii = s[BF].split(o, factor=wmma_n)
     i, i_ii = s[BF].split(i, factor=wmma_k)
     s[BF].reorder(o, i, o_ii, i_ii)
 
-    #Schedule for A's(B's) shared memory load
+    # Schedule for A's (B's) shared memory load
     def shared_shedule(stage, strides):
         s[stage].compute_at(s[CF], ko)
         xo, yo = stage.op.axis
@@ -229,24 +243,39 @@ def _schedule_dense_tensorcore(cfg, s, C):
     shared_shedule(BS, BS_align)
 
     shape = (wmma_m, wmma_n, wmma_k)
-    in_dtype = 'float16'
-    AL_gemm = te.placeholder((wmma_m, wmma_k), name='AL_gemm', dtype=in_dtype)
-    BL_gemm = te.placeholder((wmma_n, wmma_k), name='BL_gemm', dtype=in_dtype)
-    k_gemm = te.reduce_axis((0, wmma_k), name='k_gemm')
-    CL_compute = te.compute((wmma_m, wmma_n), lambda ii, jj:
-                            te.sum(AL_gemm[ii, k_gemm].astype(out_dtype) *\
-                                   BL_gemm[jj, k_gemm].astype(out_dtype),\
-                                   axis=k_gemm), name='CL_compute')
-
-    #lower the computation loops down to TensorCore hardware intrinsics
-    #by mapping the dense tensorcore to tensor intrinsics
-    s[AF].tensorize(b_ii, intrin_wmma_load_matrix_A( \
-            AF_stride, AS_stride, shape, "row_major",\
-            (wmma_m, wmma_k), (wmma_m, wmma_k), 'float16'))
-    s[BF].tensorize(o_ii, intrin_wmma_load_matrix_W( \
-            BF_stride, BS_stride, shape, "col_major",\
-            (wmma_n, wmma_k), (wmma_n, wmma_k), 'float16'))
-    s[CF].tensorize(_ii, intrin_wmma_gemm( \
-            AL_gemm, BL_gemm, CL_compute, AF_stride, BF_stride, CF_stride, shape))
-    s[CS].tensorize(bbi, intrin_wmma_store_matrix( \
-            CS_stride, CF_stride, shape, out_dtype, (wmma_m, wmma_n), (wmma_m, wmma_n)))
+    in_dtype = "float16"
+    AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype)
+    BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=in_dtype)
+    k_gemm = te.reduce_axis((0, wmma_k), name="k_gemm")
+    CL_compute = te.compute(
+        (wmma_m, wmma_n),
+        lambda ii, jj: te.sum(
+            AL_gemm[ii, k_gemm].astype(out_dtype) * BL_gemm[jj, k_gemm].astype(out_dtype),
+            axis=k_gemm,
+        ),
+        name="CL_compute",
+    )
+
+    # lower the computation loops down to TensorCore hardware intrinsics
+    # by mapping the dense tensorcore to tensor intrinsics
+    s[AF].tensorize(
+        b_ii,
+        intrin_wmma_load_matrix_A(
+            AF_stride, AS_stride, shape, "row_major", (wmma_m, wmma_k), (wmma_m, wmma_k), "float16"
+        ),
+    )
+    s[BF].tensorize(
+        o_ii,
+        intrin_wmma_load_matrix_W(
+            BF_stride, BS_stride, shape, "col_major", (wmma_n, wmma_k), (wmma_n, wmma_k), "float16"
+        ),
+    )
+    s[CF].tensorize(
+        _ii, intrin_wmma_gemm(AL_gemm, BL_gemm, CL_compute, AF_stride, BF_stride, CF_stride, shape)
+    )
+    s[CS].tensorize(
+        bbi,
+        intrin_wmma_store_matrix(
+            CS_stride, CF_stride, shape, out_dtype, (wmma_m, wmma_n), (wmma_m, wmma_n)
+        ),
+    )
index f2f7a04..2908439 100644 (file)
@@ -29,6 +29,7 @@ def depthwise_conv2d_nchw(cfg, data, kernel, strides, padding, dilation, out_dty
     """Compute depthwise_conv2d with NCHW layout."""
     return nn.depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
 
+
 @autotvm.register_topi_schedule("depthwise_conv2d_nchw.cuda")
 def schedule_depthwise_conv2d_nchw(cfg, outs):
     """Schedule for depthwise_conv2d nchw forward.
@@ -48,7 +49,7 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == 'depthwise_conv2d_nchw':
+        if op.tag == "depthwise_conv2d_nchw":
             pad_data = op.input_tensors[0]
             kernel = op.input_tensors[1]
             conv = op.output(0)
@@ -61,7 +62,7 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
             cfg.define_knob("auto_unroll_max_step", [0, 256, 1500])
 
             target = tvm.target.Target.current()
-            if target.kind.name in ['nvptx', 'rocm']:
+            if target.kind.name in ["nvptx", "rocm"]:
                 cfg.define_knob("unroll_explicit", [1])
             else:
                 cfg.define_knob("unroll_explicit", [0, 1])
@@ -69,29 +70,30 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
             # fallback support
             if cfg.is_fallback:
                 ref_log = autotvm.tophub.load_reference_log(
-                    target.kind.name, target.model, 'depthwise_conv2d_nchw.cuda')
+                    target.kind.name, target.model, "depthwise_conv2d_nchw.cuda"
+                )
                 cfg.fallback_with_reference_log(ref_log)
                 # TODO(lmzheng): A bug here, set unroll_explicit to False as workaround
-                cfg['unroll_explicit'].val = 0
+                cfg["unroll_explicit"].val = 0
             ##### space definition end #####
 
             s[pad_data].compute_inline()
-            if isinstance(kernel.op, tvm.te.ComputeOp) and 'dilate' in kernel.op.tag:
+            if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
                 s[kernel].compute_inline()
 
             if conv.op in s.outputs:
                 output = conv
-                OL = s.cache_write(conv, 'local')
+                OL = s.cache_write(conv, "local")
             else:
                 output = s.outputs[0].output(0)
-                s[conv].set_scope('local')
+                s[conv].set_scope("local")
                 OL = conv
 
             # create cache stage
-            AA = s.cache_read(pad_data, 'shared', [OL])
-            WW = s.cache_read(kernel, 'shared', [OL])
-            AL = s.cache_read(AA, 'local', [OL])
-            WL = s.cache_read(WW, 'local', [OL])
+            AA = s.cache_read(pad_data, "shared", [OL])
+            WW = s.cache_read(kernel, "shared", [OL])
+            AL = s.cache_read(AA, "local", [OL])
+            WL = s.cache_read(WW, "local", [OL])
 
             # tile and bind spatial axes
             n, f, y, x = s[output].op.axis
@@ -128,12 +130,13 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
                 s[load].bind(ty, te.thread_axis("threadIdx.y"))
                 s[load].bind(tx, te.thread_axis("threadIdx.x"))
 
-            s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
-            s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+            s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+            s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
 
     traverse_inline(s, outs[0].op, _callback)
     return s
 
+
 def schedule_depthwise_conv2d_nhwc(outs):
     """Schedule for depthwise_conv2d nhwc forward.
 
@@ -203,10 +206,10 @@ def schedule_depthwise_conv2d_nhwc(outs):
                 if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule depthwise_conv2d
-        if OP.tag == 'depthwise_conv2d_nhwc':
+        if OP.tag == "depthwise_conv2d_nhwc":
             PaddedInput = OP.input_tensors[0]
             Filter = OP.input_tensors[1]
-            if isinstance(Filter.op, tvm.te.ComputeOp) and 'dilate' in Filter.op.tag:
+            if isinstance(Filter.op, tvm.te.ComputeOp) and "dilate" in Filter.op.tag:
                 s[Filter].compute_inline()
             DepthwiseConv2d = OP.output(0)
             _schedule(PaddedInput, Filter, DepthwiseConv2d)
@@ -250,7 +253,7 @@ def schedule_depthwise_conv2d_backward_input_nhwc(outs):
 
     def traverse(OP):
         # inline all one-to-one-mapping operators except the last stage (output)
-        if OP.tag == 'depthwise_conv2d_backward_input_nhwc':
+        if OP.tag == "depthwise_conv2d_backward_input_nhwc":
             Padded_out_grad = OP.input_tensors[0]
             Dilated_out_grad = Padded_out_grad.op.input_tensors[0]
             s[Dilated_out_grad].compute_inline()
@@ -262,6 +265,7 @@ def schedule_depthwise_conv2d_backward_input_nhwc(outs):
     traverse(outs[0].op)
     return s
 
+
 def schedule_depthwise_conv2d_backward_weight_nhwc(outs):
     """Schedule for depthwise_conv2d nhwc backward wrt weight.
 
@@ -303,7 +307,7 @@ def schedule_depthwise_conv2d_backward_weight_nhwc(outs):
 
     def traverse(OP):
         # inline all one-to-one-mapping operators except the last stage (output)
-        if OP.tag == 'depthwise_conv2d_backward_weight_nhwc':
+        if OP.tag == "depthwise_conv2d_backward_weight_nhwc":
             Padded_in = OP.input_tensors[1]
             s[Padded_in].compute_inline()
             Weight_grad = OP.output(0)
index ab7db66..35d5119 100644 (file)
@@ -29,8 +29,7 @@ from .. import nn
 
 
 @autotvm.register_topi_compute("group_conv2d_nchw.cuda")
-def group_conv2d_nchw(_, data, kernel, stride, padding, dilation, groups,
-                      out_dtype='float32'):
+def group_conv2d_nchw(_, data, kernel, stride, padding, dilation, groups, out_dtype="float32"):
     return nn.group_conv2d_nchw(data, kernel, stride, padding, dilation, groups, out_dtype)
 
 
@@ -83,7 +82,7 @@ def _schedule_group_conv2d_nchw_direct(cfg, s, conv):
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
 
     target = tvm.target.Target.current()
-    if target.kind.name in ['nvptx', 'rocm']:
+    if target.kind.name in ["nvptx", "rocm"]:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
@@ -94,15 +93,15 @@ def _schedule_group_conv2d_nchw_direct(cfg, s, conv):
 
     if conv.op in s.outputs:
         output = conv
-        OL = s.cache_write(conv, 'local')
+        OL = s.cache_write(conv, "local")
     else:
         output = s.outputs[0].output(0)
-        s[conv].set_scope('local')
+        s[conv].set_scope("local")
         OL = conv
 
     # create cache stage
-    AA = s.cache_read(pad_data, 'shared', [OL])
-    WW = s.cache_read(kernel, 'shared', [OL])
+    AA = s.cache_read(pad_data, "shared", [OL])
+    WW = s.cache_read(kernel, "shared", [OL])
 
     # tile and bind spatial axes
     n, f, y, x = s[output].op.axis
@@ -151,9 +150,9 @@ def _schedule_group_conv2d_nchw_direct(cfg, s, conv):
     # tile reduction axes
     n, f, y, x = s[OL].op.axis
     rc, ry, rx = s[OL].op.reduce_axis
-    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
-    ryo, ryi = cfg['tile_rx'].apply(s, OL, ry)
-    rxo, rxi = cfg['tile_ry'].apply(s, OL, rx)
+    rco, rci = cfg["tile_rc"].apply(s, OL, rc)
+    ryo, ryi = cfg["tile_rx"].apply(s, OL, ry)
+    rxo, rxi = cfg["tile_ry"].apply(s, OL, rx)
     s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)
 
     s[AA].compute_at(s[OL], rxo)
@@ -171,8 +170,8 @@ def _schedule_group_conv2d_nchw_direct(cfg, s, conv):
         s[load].bind(tx, te.thread_axis("threadIdx.x"))
 
     # unroll
-    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
-    s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
 
     N, CO, OH, OW = get_const_tuple(output.shape)
     _, CI_div_groups, KH, KW = get_const_tuple(kernel.shape)
@@ -180,8 +179,9 @@ def _schedule_group_conv2d_nchw_direct(cfg, s, conv):
 
 
 @autotvm.register_topi_compute("group_conv2d_NCHWc_int8.cuda")
-def group_conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, groups,
-                            out_dtype='float32'):
+def group_conv2d_NCHWc_int8(
+    cfg, data, kernel, stride, padding, dilation, groups, out_dtype="float32"
+):
     """Group convolution operator for 'group_conv2d_NCHWc_int8'.
 
     Parameters
@@ -221,46 +221,57 @@ def group_conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, groups
     pre_computed = len(kernel.shape) == 6
     if not pre_computed:
         batch, channels, height, width = get_const_tuple(data.shape)
-        out_channels, in_channels, kernel_h, kernel_w = get_const_tuple(
-            kernel.shape)
+        out_channels, in_channels, kernel_h, kernel_w = get_const_tuple(kernel.shape)
 
         assert channels % groups == 0, "input channels must divide group size"
         assert out_channels % groups == 0, "output channels must divide group size"
-        assert channels % ic_block_factor == 0, \
-            "Number of input channels per group must divide {}".format(ic_block_factor)
-        assert out_channels % oc_block_factor == 0, \
-            "Number of output channels per group must divide {}".format(oc_block_factor)
-
-        packed_data = te.compute((batch, channels // ic_block_factor, height, width,
-                                  ic_block_factor),
-                                 lambda n, c, h, w, vc: data[n, c*ic_block_factor + vc, h, w],
-                                 name="packed_data")
+        assert (
+            channels % ic_block_factor == 0
+        ), "Number of input channels per group must divide {}".format(ic_block_factor)
+        assert (
+            out_channels % oc_block_factor == 0
+        ), "Number of output channels per group must divide {}".format(oc_block_factor)
+
+        packed_data = te.compute(
+            (batch, channels // ic_block_factor, height, width, ic_block_factor),
+            lambda n, c, h, w, vc: data[n, c * ic_block_factor + vc, h, w],
+            name="packed_data",
+        )
         packed_kernel = te.compute(
-            (out_channels // oc_block_factor, in_channels // ic_block_factor, kernel_h, kernel_w,
-             oc_block_factor, ic_block_factor),
-            lambda oc_chunk, ic_chunk, kh, kw, oc_block, ic_block:
-            kernel[oc_chunk * oc_block_factor + oc_block,
-                   ic_chunk * ic_block_factor + ic_block, kh, kw],
-            name="packed_kernel")
+            (
+                out_channels // oc_block_factor,
+                in_channels // ic_block_factor,
+                kernel_h,
+                kernel_w,
+                oc_block_factor,
+                ic_block_factor,
+            ),
+            lambda oc_chunk, ic_chunk, kh, kw, oc_block, ic_block: kernel[
+                oc_chunk * oc_block_factor + oc_block, ic_chunk * ic_block_factor + ic_block, kh, kw
+            ],
+            name="packed_kernel",
+        )
     else:
         packed_data = data
         packed_kernel = kernel
 
-    batch, ic_chunk, in_height, in_width, _ = get_const_tuple(
-        packed_data.shape)
-    oc_chunk, _, kernel_h, kernel_w, oc_block, ic_block = get_const_tuple(
-        packed_kernel.shape)
+    batch, ic_chunk, in_height, in_width, _ = get_const_tuple(packed_data.shape)
+    oc_chunk, _, kernel_h, kernel_w, oc_block, ic_block = get_const_tuple(packed_kernel.shape)
 
     # TODO(kumasento): these assertions ensure that the number of groups
     # should be smaller or equal to the number of blocks, so that each
     # group will have at least one block.
     # Shall we pad the channels to avoid raising assertions?
-    assert groups <= oc_chunk, \
-        ('Number of groups {} should be less than '
-         'output channel chunk size {}'.format(groups, oc_chunk))
-    assert groups <= ic_chunk, \
-        ('Number of groups {} should be less than '
-         'input channel chunk size {}'.format(groups, ic_chunk))
+    assert (
+        groups <= oc_chunk
+    ), "Number of groups {} should be less than " "output channel chunk size {}".format(
+        groups, oc_chunk
+    )
+    assert (
+        groups <= ic_chunk
+    ), "Number of groups {} should be less than " "input channel chunk size {}".format(
+        groups, ic_chunk
+    )
 
     if isinstance(stride, int):
         stride_h = stride_w = stride
@@ -273,24 +284,21 @@ def group_conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, groups
         dilation_h, dilation_w = dilation
 
     # pad the input data
-    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (kernel_h, kernel_w))
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w))
     pad_before = [0, 0, pad_top, pad_left, 0]
     pad_after = [0, 0, pad_down, pad_right, 0]
     pad_data = pad(packed_data, pad_before, pad_after, name="pad_data")
 
     # compute the output shape
-    out_height = (in_height - (kernel_h - 1) * dilation_h -
-                  1 + pad_top + pad_down) // stride_h + 1
-    out_width = (in_width - (kernel_w - 1) * dilation_w -
-                 1 + pad_left + pad_right) // stride_w + 1
+    out_height = (in_height - (kernel_h - 1) * dilation_h - 1 + pad_top + pad_down) // stride_h + 1
+    out_width = (in_width - (kernel_w - 1) * dilation_w - 1 + pad_left + pad_right) // stride_w + 1
 
     oshape = (batch, oc_chunk, out_height, out_width, oc_block)
 
-    icc = te.reduce_axis((0, ic_chunk // groups), name='ic_chunk')
-    icb = te.reduce_axis((0, ic_block_factor), name='ic_block')
-    kh = te.reduce_axis((0, kernel_h), name='kh')
-    kw = te.reduce_axis((0, kernel_w), name='kw')
+    icc = te.reduce_axis((0, ic_chunk // groups), name="ic_chunk")
+    icb = te.reduce_axis((0, ic_block_factor), name="ic_block")
+    kh = te.reduce_axis((0, kernel_h), name="kh")
+    kw = te.reduce_axis((0, kernel_w), name="kw")
 
     # NOTE(kumasento): explanation of this snippet -
     # oc_chunk//groups and ic_chunk//groups give you the number of blocks,
@@ -304,19 +312,38 @@ def group_conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, groups
     # Compared with a normal convolution, group convolution only sums
     # input channels from the group that an output channel resides in.
     conv = te.compute(
-        oshape, lambda n, occ, oh, ow, ocb:
-        te.sum(pad_data[n, occ//(oc_chunk//groups)*(ic_chunk//groups)+icc,
-                        oh*stride_h+kh*dilation_h, ow*stride_w+kw*dilation_w, icb]
-               .astype('int32') *
-               packed_kernel[occ, icc, kh, kw, ocb, icb].astype('int32'),
-               axis=[icc, kh, kw, icb]))
+        oshape,
+        lambda n, occ, oh, ow, ocb: te.sum(
+            pad_data[
+                n,
+                occ // (oc_chunk // groups) * (ic_chunk // groups) + icc,
+                oh * stride_h + kh * dilation_h,
+                ow * stride_w + kw * dilation_w,
+                icb,
+            ].astype("int32")
+            * packed_kernel[occ, icc, kh, kw, ocb, icb].astype("int32"),
+            axis=[icc, kh, kw, icb],
+        ),
+    )
 
     # Type conversion
-    output = te.compute(oshape, lambda *index: conv(*index).astype(out_dtype),
-                        tag='group_conv2d_NCHWc_int8')
-
-    num_flop = batch * oc_chunk * oc_block * out_height * out_width * \
-        ic_chunk * ic_block * kernel_h * kernel_w * 2 // groups
+    output = te.compute(
+        oshape, lambda *index: conv(*index).astype(out_dtype), tag="group_conv2d_NCHWc_int8"
+    )
+
+    num_flop = (
+        batch
+        * oc_chunk
+        * oc_block
+        * out_height
+        * out_width
+        * ic_chunk
+        * ic_block
+        * kernel_h
+        * kernel_w
+        * 2
+        // groups
+    )
     cfg.add_flop(num_flop)
 
     return output
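
The NOTE in the hunk above describes the group-convolution indexing: `occ // (oc_chunk // groups)` recovers which group an output block belongs to, and only that group's slice of input channel blocks is reduced over. A standalone sketch of the same index arithmetic, using made-up block counts for illustration:

    # Hypothetical sizes: 8 output blocks, 4 input blocks, 2 groups.
    oc_chunk, ic_chunk, groups = 8, 4, 2
    ic_blocks_per_group = ic_chunk // groups  # 2

    for occ in range(oc_chunk):
        group = occ // (oc_chunk // groups)     # which group this output block is in
        base = group * ic_blocks_per_group      # first input block of that group
        input_blocks = range(base, base + ic_blocks_per_group)
        # occ 0..3 read input blocks 0..1; occ 4..7 read input blocks 2..3
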
@@ -351,7 +378,7 @@ def schedule_group_conv2d_NCHWc_int8(cfg, outs):
     return s
 
 
-_dp4a = dp4a('shared', 'shared', 'local')
+_dp4a = dp4a("shared", "shared", "local")
 
 
 def _schedule_group_conv2d_NCHWc_int8(cfg, s, output):
@@ -372,11 +399,9 @@ def _schedule_group_conv2d_NCHWc_int8(cfg, s, output):
         # skip this part during tuning to make records accurate
         # this part will be pre-computed during NNVM's pre-compute optimization pass
         s[packed_data].pragma(s[packed_data].op.axis[0], "debug_skip_region")
-        s[packed_kernel].pragma(
-            s[packed_kernel].op.axis[0], "debug_skip_region")
+        s[packed_kernel].pragma(s[packed_kernel].op.axis[0], "debug_skip_region")
     else:
-        if isinstance(packed_kernel.op, tvm.te.ComputeOp) and \
-                packed_kernel.name == 'packed_kernel':
+        if isinstance(packed_kernel.op, tvm.te.ComputeOp) and packed_kernel.name == "packed_kernel":
             # data and kernel are not pre-computed, schedule layout transform here
             schedule_injective_from_existing(s, packed_data)
             schedule_injective_from_existing(s, packed_kernel)
@@ -385,10 +410,10 @@ def _schedule_group_conv2d_NCHWc_int8(cfg, s, output):
         s[pad_data].compute_inline()
 
     # create cache stage
-    AA = s.cache_read(pad_data, 'shared', [conv])
-    WW = s.cache_read(packed_kernel, 'shared', [conv])
+    AA = s.cache_read(pad_data, "shared", [conv])
+    WW = s.cache_read(packed_kernel, "shared", [conv])
 
-    s[conv].set_scope('local')
+    s[conv].set_scope("local")
 
     # handle bias
     if output.op not in s.outputs:
@@ -408,15 +433,14 @@ def _schedule_group_conv2d_NCHWc_int8(cfg, s, output):
     kernel_scope, n = s[output].split(n, nparts=1)
 
     g, f = s[output].split(f, nparts=groups)
-    s[output].bind(n, te.thread_axis('blockIdx.z'))
+    s[output].bind(n, te.thread_axis("blockIdx.z"))
     bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
     bg, vg = cfg["tile_g"].apply(s, output, g)
     bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
     by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
     bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
 
-    s[output].reorder(bn, bg, bf, by, bx, vn, vg, vf, vy,
-                      vx, tn, tf, ty, tx, ni, fi, yi, xi)
+    s[output].reorder(bn, bg, bf, by, bx, vn, vg, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
     s[output].bind(bn, te.thread_axis("blockIdx.z"))
     s[output].bind(s[output].fuse(bg, bf), te.thread_axis("blockIdx.y"))
     s[output].bind(s[output].fuse(by, bx), te.thread_axis("blockIdx.x"))
@@ -455,9 +479,9 @@ def _schedule_group_conv2d_NCHWc_int8(cfg, s, output):
     cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=2)
     cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=2)
     cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=2)
-    rco, rci = cfg['tile_rc'].apply(s, conv, rc)
-    ryo, ryi = cfg['tile_ry'].apply(s, conv, ry)
-    rxo, rxi = cfg['tile_rx'].apply(s, conv, rx)
+    rco, rci = cfg["tile_rc"].apply(s, conv, rc)
+    ryo, ryi = cfg["tile_ry"].apply(s, conv, ry)
+    rxo, rxi = cfg["tile_rx"].apply(s, conv, rx)
 
     s[conv].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x, c, rc_block)
     _, rc_block = s[conv].split(rc_block, factor=4)
@@ -482,17 +506,16 @@ def _schedule_group_conv2d_NCHWc_int8(cfg, s, output):
         s[load].bind(tx, te.thread_axis("threadIdx.x"))
 
     # double buffer
-    cfg.define_knob('AA_double_buffer', [0, 1])
-    cfg.define_knob('WW_double_buffer', [0, 1])
-    if cfg['AA_double_buffer'].val:
+    cfg.define_knob("AA_double_buffer", [0, 1])
+    cfg.define_knob("WW_double_buffer", [0, 1])
+    if cfg["AA_double_buffer"].val:
         s[AA].double_buffer()
-    if cfg['WW_double_buffer'].val:
+    if cfg["WW_double_buffer"].val:
         s[WW].double_buffer()
 
     # unroll
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
-    s[output].pragma(kernel_scope, 'auto_unroll_max_step',
-                     cfg['auto_unroll_max_step'].val)
-    s[output].pragma(kernel_scope, 'unroll_explicit', False)
+    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[output].pragma(kernel_scope, "unroll_explicit", False)
 
     return s
index bd3e01d..8a5f618 100644 (file)
@@ -20,6 +20,7 @@ import tvm
 from tvm import te
 from .. import util
 
+
 def schedule_injective_from_existing(sch, out):
     """Schedule for injective op from existing schedule.
 
@@ -66,6 +67,7 @@ def schedule_injective_from_existing(sch, out):
 
     return sch
 
+
 def schedule_injective(outs):
     """Schedule for injective op.
 
@@ -89,5 +91,6 @@ def schedule_injective(outs):
             schedule_injective_from_existing(s, out)
     return s
 
+
 schedule_elemwise = schedule_injective
 schedule_broadcast = schedule_injective
index 4772080..2041f4c 100644 (file)
@@ -34,25 +34,29 @@ def cuda_atomic_add_rule(op):
         return tvm.tir.call_pure_extern("int32", "atomicAdd", op.args[0], op.args[1])
     raise RuntimeError("only support int32, float32 and float64")
 
+
 def opencl_atomic_add_rule(op):
     if op.dtype == "int32":
         return tvm.tir.call_pure_extern("int32", "atomic_add", op.args[0], op.args[1])
     raise RuntimeError("only support int32")
 
-tvm.target.intrin.register_intrin_rule(
-    "cuda", "atomic_add", cuda_atomic_add_rule, override=True)
+
+tvm.target.intrin.register_intrin_rule("cuda", "atomic_add", cuda_atomic_add_rule, override=True)
 
 tvm.target.intrin.register_intrin_rule(
-    "opencl", "atomic_add", opencl_atomic_add_rule, override=True)
+    "opencl", "atomic_add", opencl_atomic_add_rule, override=True
+)
 
 tvm.ir.register_op_attr("tir.atomic_add", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)
 
+
 def atomic_add(x, y):
     return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)
 
 
-def get_valid_counts_ir(data, valid_count, out, out_indices,
-                        score_threshold, id_index, score_index):
+def get_valid_counts_ir(
+    data, valid_count, out, out_indices, score_threshold, id_index, score_index
+):
     """Low level IR to get valid count of bounding boxes
     given a score threshold. Also prepares to move valid boxes to the
     top of input data.
@@ -94,16 +98,15 @@ def get_valid_counts_ir(data, valid_count, out, out_indices,
     out = ib.buffer_ptr(out)
     out_indices = ib.buffer_ptr(out_indices)
     atomic_add_return = ib.allocate(
-        valid_count.dtype, (1,), name='atomic_add_return', scope='local')
+        valid_count.dtype, (1,), name="atomic_add_return", scope="local"
+    )
     one_count = tvm.tir.const(1, dtype=valid_count.dtype)
     one = tvm.tir.const(1, dtype=out.dtype)
-    score_threshold = tvm.ir.make_node(
-        "FloatImm", dtype="float32", value=score_threshold)
+    score_threshold = tvm.ir.make_node("FloatImm", dtype="float32", value=score_threshold)
     id_index = tvm.ir.make_node("IntImm", dtype="int32", value=id_index)
     score_index = tvm.ir.make_node("IntImm", dtype="int32", value=score_index)
 
-    max_threads = int(tvm.target.Target.current(
-        allow_none=False).max_num_threads)
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     nthread_tx = max_threads
     nthread_bx = batch_size * num_anchors // max_threads + 1
     tx = te.thread_axis("threadIdx.x")
@@ -119,10 +122,14 @@ def get_valid_counts_ir(data, valid_count, out, out_indices,
     with ib.if_scope(tid < batch_size * num_anchors):
         i = idxd(tid, num_anchors)
         with ib.if_scope(
-                tvm.tir.all(data[tid * elem_length + score_index] > score_threshold,
-                            tvm.tir.any(id_index < 0, data[tid * elem_length + id_index] >= 0))):
-            atomic_add_return[0] = atomic_add(tvm.tir.call_intrin("handle", "tir.address_of",
-                                                                       valid_count[i]), one_count)
+            tvm.tir.all(
+                data[tid * elem_length + score_index] > score_threshold,
+                tvm.tir.any(id_index < 0, data[tid * elem_length + id_index] >= 0),
+            )
+        ):
+            atomic_add_return[0] = atomic_add(
+                tvm.tir.call_intrin("handle", "tir.address_of", valid_count[i]), one_count
+            )
             with ib.for_range(0, elem_length) as k:
                 out[tid * elem_length + k] = data[tid * elem_length + k]
                 out_indices[tid + k] = tid + k
@@ -162,31 +169,45 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
     """
     batch_size = data.shape[0]
     num_anchors = data.shape[1]
-    data_buf = tvm.tir.decl_buffer(
-        data.shape, data.dtype, "data_buf", data_alignment=8)
+    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
     valid_count_buf = tvm.tir.decl_buffer(
-        (batch_size,), "int32", "valid_count_buf", data_alignment=8)
-    out_buf = tvm.tir.decl_buffer(
-        data.shape, data.dtype, "out_buf", data_alignment=8)
+        (batch_size,), "int32", "valid_count_buf", data_alignment=8
+    )
+    out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8)
     out_indices_buf = tvm.tir.decl_buffer(
-        (batch_size, num_anchors), "int32", "out_buf", data_alignment=8)
-
-    valid_count, out, out_indices = \
-        te.extern([(batch_size,), data.shape, (batch_size, num_anchors)], [data],
-                  lambda ins, outs: get_valid_counts_ir(
-            ins[0], outs[0], outs[1], outs[2], score_threshold, id_index, score_index),
-            dtype=["int32", data.dtype],
-            in_buffers=[data_buf],
-            out_buffers=[valid_count_buf, out_buf, out_indices_buf],
-            name="get_valid_counts",
-            tag="get_valid_counts_gpu")
+        (batch_size, num_anchors), "int32", "out_buf", data_alignment=8
+    )
+
+    valid_count, out, out_indices = te.extern(
+        [(batch_size,), data.shape, (batch_size, num_anchors)],
+        [data],
+        lambda ins, outs: get_valid_counts_ir(
+            ins[0], outs[0], outs[1], outs[2], score_threshold, id_index, score_index
+        ),
+        dtype=["int32", data.dtype],
+        in_buffers=[data_buf],
+        out_buffers=[valid_count_buf, out_buf, out_indices_buf],
+        name="get_valid_counts",
+        tag="get_valid_counts_gpu",
+    )
 
     return [valid_count, out, out_indices]
 
 
-def nms_ir(data, sorted_index, valid_count, out, box_indices,
-           max_output_size, iou_threshold, force_suppress,
-           top_k, coord_start, id_index, score_index):
+def nms_ir(
+    data,
+    sorted_index,
+    valid_count,
+    out,
+    box_indices,
+    max_output_size,
+    iou_threshold,
+    force_suppress,
+    top_k,
+    coord_start,
+    id_index,
+    score_index,
+):
     """Low level IR routing for transform location in multibox_detection operator.
 
     Parameters
@@ -230,18 +251,27 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
     stmt : Stmt
         The result IR statement.
     """
+
     def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
-        """Calculate overlap of two boxes.
-        """
-        w = tvm.te.max(0.0, tvm.te.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
-                       - tvm.te.max(out_tensor[box_a_idx], out_tensor[box_b_idx]))
-        h = tvm.te.max(0.0, tvm.te.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
-                       - tvm.te.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]))
+        """Calculate overlap of two boxes."""
+        w = tvm.te.max(
+            0.0,
+            tvm.te.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
+            - tvm.te.max(out_tensor[box_a_idx], out_tensor[box_b_idx]),
+        )
+        h = tvm.te.max(
+            0.0,
+            tvm.te.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
+            - tvm.te.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]),
+        )
         i = w * h
-        u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx]) * \
-            (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \
-            (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) * \
-            (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i
+        u = (
+            (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx])
+            * (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1])
+            + (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx])
+            * (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1])
+            - i
+        )
         return tvm.tir.Select(u <= 0.0, 0.0, i / u)
 
     batch_size = data.shape[0]
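
`calculate_overlap` above builds the usual intersection-over-union expression in TIR, guarded by a Select so that a non-positive union yields 0. The same formula in plain Python, over the flat [x1, y1, x2, y2] box layout the IR indexes into (illustrative only):

    def iou(box_a, box_b):
        # box = [x1, y1, x2, y2]
        w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
        h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
        inter = w * h
        area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
        area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
        union = area_a + area_b - inter
        return 0.0 if union <= 0.0 else inter / union

    # e.g. iou([0, 0, 2, 2], [1, 1, 3, 3]) == 1 / 7
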
@@ -255,11 +285,9 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
     valid_count = ib.buffer_ptr(valid_count)
     out = ib.buffer_ptr(out)
     box_indices = ib.buffer_ptr(box_indices)
-    num_valid_boxes = ib.allocate(
-        "int32", (1,), name="num_valid_boxes", scope="local")
+    num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")
 
-    max_threads = int(
-        tvm.target.Target.current(allow_none=False).max_num_threads)
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     nthread_tx = max_threads
     nthread_bx = num_anchors // max_threads + 1
     tx = te.thread_axis("threadIdx.x")
@@ -268,29 +296,26 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
     ib.scope_attr(bx, "thread_extent", nthread_bx)
     j = bx * max_threads + tx
 
-    iou_threshold = tvm.ir.make_node(
-        "FloatImm", dtype="float32", value=iou_threshold)
+    iou_threshold = tvm.ir.make_node("FloatImm", dtype="float32", value=iou_threshold)
     top_k = tvm.ir.make_node("IntImm", dtype="int32", value=top_k)
     coord_start = tvm.ir.make_node("IntImm", dtype="int32", value=coord_start)
     id_index = tvm.ir.make_node("IntImm", dtype="int32", value=id_index)
     score_index = tvm.ir.make_node("IntImm", dtype="int32", value=score_index)
-    force_suppress = tvm.ir.make_node(
-        "IntImm", dtype="int32", value=1 if force_suppress else 0)
+    force_suppress = tvm.ir.make_node("IntImm", dtype="int32", value=1 if force_suppress else 0)
 
     with ib.for_range(0, batch_size, for_type="unroll") as i:
         base_idx = i * num_anchors * box_data_length
         with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
             # Reorder output
             nkeep = if_then_else(
-                tvm.tir.all(top_k > 0, top_k < valid_count[i]),
-                top_k, valid_count[i])
+                tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]
+            )
             with ib.if_scope(j < nkeep):
                 with ib.for_range(0, box_data_length) as k:
-                    out[(base_idx + j * box_data_length + k)] = \
-                        data[(base_idx + sorted_index[i * num_anchors + j]
-                              * box_data_length + k)]
-                box_indices[i * num_anchors +
-                            j] = sorted_index[i * num_anchors + j]
+                    out[(base_idx + j * box_data_length + k)] = data[
+                        (base_idx + sorted_index[i * num_anchors + j] * box_data_length + k)
+                    ]
+                box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j]
             with ib.if_scope(tvm.tir.all(top_k > 0, top_k < valid_count[i])):
                 with ib.if_scope(j < valid_count[i] - nkeep):
                     with ib.for_range(0, box_data_length) as k:
@@ -300,22 +325,31 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
             with ib.for_range(0, valid_count[i]) as k:
                 offset_k = k * box_data_length
                 with ib.if_scope(
-                        tvm.tir.all(out[base_idx + offset_k + score_index] > 0,
-                                    tvm.tir.any(id_index < 0, out[base_idx +
-                                                                  offset_k + id_index] >= 0))):
+                    tvm.tir.all(
+                        out[base_idx + offset_k + score_index] > 0,
+                        tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
+                    )
+                ):
                     with ib.if_scope(j < valid_count[i]):
                         offset_j = j * box_data_length
                         with ib.if_scope(
-                                tvm.tir.all(j > k,
-                                            out[base_idx + offset_j +
-                                                score_index] > 0,
-                                            tvm.tir.any(id_index < 0,
-                                                        out[base_idx + offset_j + id_index] >= 0),
-                                            tvm.tir.any(force_suppress > 0, id_index < 0,
-                                                    out[base_idx + offset_k + id_index] ==
-                                                        out[base_idx + offset_j + id_index]))):
-                            iou = calculate_overlap(out, base_idx + offset_j + coord_start,
-                                                    base_idx + offset_k + coord_start)
+                            tvm.tir.all(
+                                j > k,
+                                out[base_idx + offset_j + score_index] > 0,
+                                tvm.tir.any(id_index < 0, out[base_idx + offset_j + id_index] >= 0),
+                                tvm.tir.any(
+                                    force_suppress > 0,
+                                    id_index < 0,
+                                    out[base_idx + offset_k + id_index]
+                                    == out[base_idx + offset_j + id_index],
+                                ),
+                            )
+                        ):
+                            iou = calculate_overlap(
+                                out,
+                                base_idx + offset_j + coord_start,
+                                base_idx + offset_k + coord_start,
+                            )
                             with ib.if_scope(iou >= iou_threshold):
                                 out[base_idx + offset_j + score_index] = -1.0
                                 with ib.if_scope(id_index >= 0):
@@ -325,14 +359,12 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
             with ib.if_scope(j < valid_count[i]):
                 offset_j = j * box_data_length
                 with ib.for_range(0, box_data_length) as k:
-                    out[(base_idx + offset_j + k)
-                        ] = data[base_idx + offset_j + k]
+                    out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
                 box_indices[i * num_anchors + j] = j
         # Set invalid entry to be -1
         with ib.if_scope(j < num_anchors - valid_count[i]):
             with ib.for_range(0, box_data_length) as k:
-                out[base_idx + (j + valid_count[i]) *
-                    box_data_length + k] = -1.0
+                out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0
             box_indices[i * num_anchors + j + valid_count[i]] = -1
         # Only return max_output_size number of valid boxes
         num_valid_boxes[0] = 0
@@ -350,10 +382,20 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
     return ib.get()
 
 
-def non_max_suppression(data, valid_count, indices, max_output_size=-1,
-                        iou_threshold=0.5, force_suppress=False, top_k=-1,
-                        coord_start=2, score_index=1, id_index=0,
-                        return_indices=True, invalid_to_bottom=False):
+def non_max_suppression(
+    data,
+    valid_count,
+    indices,
+    max_output_size=-1,
+    iou_threshold=0.5,
+    force_suppress=False,
+    top_k=-1,
+    coord_start=2,
+    score_index=1,
+    id_index=0,
+    return_indices=True,
+    invalid_to_bottom=False,
+):
     """Non-maximum suppression operator for object detection.
 
     Parameters
@@ -435,36 +477,49 @@ def non_max_suppression(data, valid_count, indices, max_output_size=-1,
     num_anchors = data.shape[1]
 
     valid_count_dtype = "int32"
-    valid_count_buf = tvm.tir.decl_buffer(valid_count.shape, valid_count_dtype,
-                                          "valid_count_buf", data_alignment=4)
+    valid_count_buf = tvm.tir.decl_buffer(
+        valid_count.shape, valid_count_dtype, "valid_count_buf", data_alignment=4
+    )
     score_axis = score_index
     score_shape = (batch_size, num_anchors)
-    score_tensor = te.compute(
-        score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
+    score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
     if tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True):
         sort_tensor = argsort_thrust(
-            score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype)
+            score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype
+        )
     else:
         sort_tensor = argsort(
-            score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype)
-
-    sort_tensor_buf = tvm.tir.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
-                                          "sort_tensor_buf", data_alignment=8)
-
-    data_buf = tvm.tir.decl_buffer(
-        data.shape, data.dtype, "data_buf", data_alignment=8)
-
-    out, box_indices = \
-        te.extern([data.shape, score_shape],
-                  [data, sort_tensor, valid_count],
-                  lambda ins, outs: nms_ir(
-            ins[0], ins[1], ins[2], outs[0], outs[1],
-            max_output_size, iou_threshold, force_suppress,
-            top_k, coord_start, id_index, score_index),
-            dtype=[data.dtype, "int32"],
-            in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
-            name="nms",
-            tag="nms")
+            score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype
+        )
+
+    sort_tensor_buf = tvm.tir.decl_buffer(
+        sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8
+    )
+
+    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+
+    out, box_indices = te.extern(
+        [data.shape, score_shape],
+        [data, sort_tensor, valid_count],
+        lambda ins, outs: nms_ir(
+            ins[0],
+            ins[1],
+            ins[2],
+            outs[0],
+            outs[1],
+            max_output_size,
+            iou_threshold,
+            force_suppress,
+            top_k,
+            coord_start,
+            id_index,
+            score_index,
+        ),
+        dtype=[data.dtype, "int32"],
+        in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
+        name="nms",
+        tag="nms",
+    )
     # TODO(yongwww): Update cuda nms to be consistent with cpu version
     if return_indices:
         return box_indices
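
Taken together with the score filtering in `get_valid_counts_ir` and the argsort, `nms_ir` implements the usual greedy non-maximum suppression: visit boxes in descending score order and drop any box whose IoU with an already-kept box reaches the threshold. A compact, self-contained reference of that greedy loop (a sketch only; class-id handling and `force_suppress` omitted for brevity):

    def greedy_nms(boxes, scores, iou_threshold=0.5):
        # boxes: list of [x1, y1, x2, y2]; returns indices of kept boxes, best score first.
        def iou(a, b):
            w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
            h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
            inter = w * h
            union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
            return 0.0 if union <= 0.0 else inter / union

        order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
        keep = []
        for i in order:
            # keep a box only if it does not overlap too much with any kept box
            if all(iou(boxes[i], boxes[j]) < iou_threshold for j in keep):
                keep.append(i)
        return keep
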
index 4460f7b..0de3777 100644 (file)
@@ -20,6 +20,7 @@ from __future__ import absolute_import as _abs
 
 from .. import cpp
 
+
 def schedule_lrn(outs):
     """Schedule for LRN
 
index 9839984..a3caf5f 100644 (file)
@@ -22,7 +22,7 @@ from .. import tag
 from ..util import traverse_inline
 
 
-def schedule_adaptive_pool(outs, layout='NCHW'):
+def schedule_adaptive_pool(outs, layout="NCHW"):
     """Schedule for adaptive_pool.
 
     Parameters
@@ -53,7 +53,7 @@ def schedule_adaptive_pool(outs, layout='NCHW'):
             s[Pool].set_scope("local")
 
         by, ty = s[Out].split(s[Out].op.axis[0], factor=num_thread)
-        if layout == 'NHWC':
+        if layout == "NHWC":
             bx, tx = s[Out].split(s[Out].op.axis[3], factor=num_thread)
         else:
             bx, tx = s[Out].split(s[Out].op.axis[1], factor=num_thread)
@@ -79,7 +79,7 @@ def schedule_adaptive_pool(outs, layout='NCHW'):
                 if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule global_pool
-        elif OP.tag.startswith('adaptive_pool'):
+        elif OP.tag.startswith("adaptive_pool"):
             Pool = OP.output(0)
             _schedule(Pool)
         else:
@@ -110,6 +110,7 @@ def schedule_pool(outs, layout):
     """
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     s = te.create_schedule([x.op for x in outs])
+
     def _schedule(PaddedInput, Pool):
         if isinstance(PaddedInput.op, tvm.te.ComputeOp):
             s[PaddedInput].compute_inline()
@@ -141,7 +142,7 @@ def schedule_pool(outs, layout):
                 if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule pool
-        elif OP.tag.startswith('pool'):
+        elif OP.tag.startswith("pool"):
             PaddedInput = OP.input_tensors[0]
             Pool = OP.output(0)
             _schedule(PaddedInput, Pool)
@@ -194,7 +195,7 @@ def schedule_pool_grad(outs):
             s[op].compute_at(s[out], tx)
 
     def _callback(op):
-        if op.tag.startswith('pool_grad'):
+        if op.tag.startswith("pool_grad"):
             _schedule_pool_grad(op)
 
     traverse_inline(s, outs[0].op, _callback)
index 1414384..119b7bd 100644 (file)
@@ -23,8 +23,17 @@ from ...vision.rcnn import generate_anchor, reg_bbox, reg_iou
 from ...util import get_const_tuple, get_const_int
 
 
-def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, ratios,
-                    feature_stride, rpn_min_size, iou_loss):
+def predict_bbox_ir(
+    cls_prob_buf,
+    bbox_pred_buf,
+    im_info_buf,
+    out_buf,
+    scales,
+    ratios,
+    feature_stride,
+    rpn_min_size,
+    iou_loss,
+):
     """Predict bounding boxes based on anchors, scores and deltas.
 
     Parameters
@@ -100,8 +109,10 @@ def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, r
             x2 = anchor[2] + w * feature_stride
             y2 = anchor[3] + h * feature_stride
 
-            delta = [p_delta[((((b * num_anchors + k) * 4 + i) * height + h) * width + w)]
-                     for i in range(4)]
+            delta = [
+                p_delta[((((b * num_anchors + k) * 4 + i) * height + h) * width + w)]
+                for i in range(4)
+            ]
             regression_func = reg_iou if iou_loss else reg_bbox
             pred_x1, pred_y1, pred_x2, pred_y2 = regression_func(x1, y1, x2, y2, *delta)
 
@@ -110,16 +121,17 @@ def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, r
             pred_x2 = tvm.te.max(tvm.te.min(pred_x2, im_width - 1.0), 0.0)
             pred_y2 = tvm.te.max(tvm.te.min(pred_y2, im_height - 1.0), 0.0)
 
-            real_height = (im_height / feature_stride).astype('int32')
-            real_width = (im_width / feature_stride).astype('int32')
+            real_height = (im_height / feature_stride).astype("int32")
+            real_width = (im_width / feature_stride).astype("int32")
 
             bbox_w = pred_x2 - pred_x1 + 1.0
             bbox_h = pred_y2 - pred_y1 + 1.0
             min_size = p_im_info[b * 3 + 2] * rpn_min_size
 
             pred_score = p_score[((b * num_anchors * 2 + num_anchors + k) * height + h) * width + w]
-            pred_score = tvm.tir.Select(tvm.tir.any(h >= real_height, w >= real_width),
-                                        -1.0, pred_score)
+            pred_score = tvm.tir.Select(
+                tvm.tir.any(h >= real_height, w >= real_width), -1.0, pred_score
+            )
             p_out[out_index * 5 + 0] = pred_x1
             p_out[out_index * 5 + 1] = pred_y1
             p_out[out_index * 5 + 2] = pred_x2
@@ -178,15 +190,15 @@ def argsort_ir(data_buf, out_index_buf):
         with ib.for_range(0, num_bbox) as k:
             offset = start + 2 * tid + idxm(k, 2)
             with ib.if_scope(
-                    tvm.tir.all(offset + 1 < num_bbox, p_data[offset] < p_data[offset + 1])):
+                tvm.tir.all(offset + 1 < num_bbox, p_data[offset] < p_data[offset + 1])
+            ):
                 temp_data[0] = p_data[offset]
                 p_data[offset] = p_data[offset + 1]
                 p_data[offset + 1] = temp_data[0]
                 temp_index[0] = index_out[offset]
                 index_out[offset] = index_out[offset + 1]
                 index_out[offset + 1] = temp_index[0]
-            ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
-                                 tvm.runtime.convert(['shared'])))
+            ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
     return ib.get()
 
 
@@ -210,18 +222,29 @@ def nms_ir(sorted_bbox_buf, out_buf, nms_threshold):
     stmt : Stmt
         The result IR statement.
     """
+
     def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
-        """Calculate overlap of two boxes.
-        """
-        w = tvm.te.max(0.0, tvm.te.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
-                       - tvm.te.max(out_tensor[box_a_idx], out_tensor[box_b_idx]) + 1.0)
-        h = tvm.te.max(0.0, tvm.te.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
-                       - tvm.te.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]) + 1.0)
+        """Calculate overlap of two boxes."""
+        w = tvm.te.max(
+            0.0,
+            tvm.te.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
+            - tvm.te.max(out_tensor[box_a_idx], out_tensor[box_b_idx])
+            + 1.0,
+        )
+        h = tvm.te.max(
+            0.0,
+            tvm.te.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
+            - tvm.te.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1])
+            + 1.0,
+        )
         i = w * h
-        u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx] + 1.0) * \
-            (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1] + 1.0) + \
-            (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx] + 1.0) * \
-            (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1] + 1.0) - i
+        u = (
+            (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx] + 1.0)
+            * (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1] + 1.0)
+            + (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx] + 1.0)
+            * (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1] + 1.0)
+            - i
+        )
         return i / u
 
     batch, num_bbox = get_const_tuple(out_buf.shape)
@@ -245,8 +268,7 @@ def nms_ir(sorted_bbox_buf, out_buf, nms_threshold):
                 iou = calculate_overlap(p_data, (base_idx + l) * 5, (base_idx + i) * 5)
                 with ib.if_scope(iou > nms_threshold):
                     p_out[base_idx + i] = True
-        ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
-                             tvm.runtime.convert(['shared'])))
+        ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
     return ib.get()
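
calculate_overlap in the hunk above is a standard IoU on corner-format boxes with the inclusive +1.0 width/height convention. The same arithmetic in ordinary Python, purely for reference:

    def iou(box_a, box_b):
        # Boxes are (x1, y1, x2, y2); widths/heights use the inclusive +1
        # convention of calculate_overlap above.
        w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]) + 1.0)
        h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]) + 1.0)
        inter = w * h
        area_a = (box_a[2] - box_a[0] + 1.0) * (box_a[3] - box_a[1] + 1.0)
        area_b = (box_b[2] - box_b[0] + 1.0) * (box_b[3] - box_b[1] + 1.0)
        return inter / (area_a + area_b - inter)
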
 
 
@@ -277,29 +299,31 @@ def prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf):
     tx = te.thread_axis("threadIdx.x")
     ib = tvm.tir.ir_builder.create()
     ib.scope_attr(tx, "thread_extent", nthread_tx)
-    i = ib.allocate('int32', (1,), 'i', scope='local')
+    i = ib.allocate("int32", (1,), "i", scope="local")
     i[0] = 0
     p_sorted_bbox = ib.buffer_ptr(sorted_bbox_buf)
     p_remove = ib.buffer_ptr(remove_mask_buf)
     p_out = ib.buffer_ptr(out_buf)
     b = tx
 
-    nkeep = ib.allocate('int32', (1,), 'nkeep', scope='local')
-    nkeep[0] = 0 # number of bbox after nms
+    nkeep = ib.allocate("int32", (1,), "nkeep", scope="local")
+    nkeep[0] = 0  # number of bbox after nms
 
     with ib.for_range(0, num_bbox) as j:
         with ib.if_scope(p_remove[b * num_bbox + j] == False):
             nkeep[0] += 1
     with ib.if_scope(nkeep[0] > 0):
-        with ib.for_range(0, te.ceil(
-                tvm.tir.const(rpn_post_nms_top_n, 'float32') / nkeep[0]).astype('int32')):
+        with ib.for_range(
+            0, te.ceil(tvm.tir.const(rpn_post_nms_top_n, "float32") / nkeep[0]).astype("int32")
+        ):
             with ib.for_range(0, num_bbox) as j:
                 offset_j = (b * num_bbox + j) * 5
                 offset_i = (b * rpn_post_nms_top_n + i[0]) * 5
-                with ib.if_scope(tvm.tir.all(i[0] < rpn_post_nms_top_n,
-                                             p_remove[(b*num_bbox+j)] == False)):
-                    p_out[offset_i] = tvm.tir.Cast('float32', b)
-                    with ib.for_range(0, 4, for_type='unroll') as k:
+                with ib.if_scope(
+                    tvm.tir.all(i[0] < rpn_post_nms_top_n, p_remove[(b * num_bbox + j)] == False)
+                ):
+                    p_out[offset_i] = tvm.tir.Cast("float32", b)
+                    with ib.for_range(0, 4, for_type="unroll") as k:
                         p_out[offset_i + k + 1] = p_sorted_bbox[offset_j + k]
                     i[0] = i[0] + 1
 
@@ -307,8 +331,19 @@ def prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf):
     return body
 
 
-def proposal(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, threshold,
-             rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_min_size, iou_loss):
+def proposal(
+    cls_prob,
+    bbox_pred,
+    im_info,
+    scales,
+    ratios,
+    feature_stride,
+    threshold,
+    rpn_pre_nms_top_n,
+    rpn_post_nms_top_n,
+    rpn_min_size,
+    iou_loss,
+):
     """Proposal operator.
 
     Parameters
@@ -359,20 +394,33 @@ def proposal(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, thres
     num_bbox = height * width * num_anchors
     rpn_pre_nms_top_n = min(rpn_pre_nms_top_n, num_bbox) if rpn_pre_nms_top_n > 0 else num_bbox
 
-    bbox = te.extern((batch, num_bbox, 5), [cls_prob, bbox_pred, im_info], lambda ins, outs:
-                     predict_bbox_ir(ins[0], ins[1], ins[2], outs[0], scales, ratios,
-                                     feature_stride, rpn_min_size, iou_loss),
-                     dtype=bbox_pred.dtype)
-    score = te.compute((batch, num_bbox), lambda b, i: bbox[b, i, 4], tag='bbox_score')
-    sorted_index = te.extern([score.shape], [score],
-                             lambda ins, outs: argsort_ir(ins[0], outs[0]),
-                             dtype='int32')
-    sorted_bbox = te.compute((batch, rpn_pre_nms_top_n, 5),
-                             lambda b, i, j: bbox[b, sorted_index[b, i], j], tag='sorted_bbox')
-    nms_remove_mask = te.extern((batch, rpn_pre_nms_top_n), [sorted_bbox],
-                                lambda ins, outs: nms_ir(ins[0], outs[0], threshold),
-                                dtype='bool')
-    nms_out = te.extern((batch * rpn_post_nms_top_n, 5), [sorted_bbox, nms_remove_mask],
-                        lambda ins, outs: prepare_output_ir(ins[0], ins[1], outs[0]),
-                        dtype=sorted_bbox.dtype)
+    bbox = te.extern(
+        (batch, num_bbox, 5),
+        [cls_prob, bbox_pred, im_info],
+        lambda ins, outs: predict_bbox_ir(
+            ins[0], ins[1], ins[2], outs[0], scales, ratios, feature_stride, rpn_min_size, iou_loss
+        ),
+        dtype=bbox_pred.dtype,
+    )
+    score = te.compute((batch, num_bbox), lambda b, i: bbox[b, i, 4], tag="bbox_score")
+    sorted_index = te.extern(
+        [score.shape], [score], lambda ins, outs: argsort_ir(ins[0], outs[0]), dtype="int32"
+    )
+    sorted_bbox = te.compute(
+        (batch, rpn_pre_nms_top_n, 5),
+        lambda b, i, j: bbox[b, sorted_index[b, i], j],
+        tag="sorted_bbox",
+    )
+    nms_remove_mask = te.extern(
+        (batch, rpn_pre_nms_top_n),
+        [sorted_bbox],
+        lambda ins, outs: nms_ir(ins[0], outs[0], threshold),
+        dtype="bool",
+    )
+    nms_out = te.extern(
+        (batch * rpn_post_nms_top_n, 5),
+        [sorted_bbox, nms_remove_mask],
+        lambda ins, outs: prepare_output_ir(ins[0], ins[1], outs[0]),
+        dtype=sorted_bbox.dtype,
+    )
     return nms_out
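
Each stage of the proposal pipeline above is a te.extern call wrapping one of the IR-builder routines over explicit input/output buffers. A minimal, self-contained sketch of that calling convention (the routine double_ir, the shape, and the names are invented for illustration, not part of the operator):

    import tvm
    from tvm import te

    def double_ir(data_buf, out_buf):
        # Trivial IR routine: out[i] = data[i] * 2, using the same
        # ir_builder/buffer_ptr pattern as predict_bbox_ir and friends.
        ib = tvm.tir.ir_builder.create()
        data = ib.buffer_ptr(data_buf)
        out = ib.buffer_ptr(out_buf)
        with ib.for_range(0, data_buf.shape[0]) as i:
            out[i] = data[i] * 2.0
        return ib.get()

    A = te.placeholder((16,), name="A", dtype="float32")
    B = te.extern(
        (16,),
        [A],
        lambda ins, outs: double_ir(ins[0], outs[0]),
        dtype="float32",
        name="double",
    )
    s = te.create_schedule(B.op)
    mod = tvm.lower(s, [A, B])
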
index 664ea44..ee868ac 100644
@@ -22,6 +22,7 @@ from tvm import te
 from .. import tag
 from .injective import schedule_injective_from_existing
 
+
 def _schedule_reduce(op, sch, is_idx_reduce=False):
     if is_idx_reduce:
         data_out = op.input_tensors[0]
@@ -49,8 +50,9 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
         thread_x = te.thread_axis((0, num_thread), "threadIdx.x")
 
     # Fuse and refactor the reduce axis
-    fused_reduce = sch[data_out].fuse(*[sch[data_out].op.reduce_axis[i]
-                                        for i in range(len(sch[data_out].op.reduce_axis))])
+    fused_reduce = sch[data_out].fuse(
+        *[sch[data_out].op.reduce_axis[i] for i in range(len(sch[data_out].op.reduce_axis))]
+    )
     ko, ki = sch[data_out].split(fused_reduce, factor=num_thread)
     if is_idx_reduce:
         data_out_rf, _ = sch.rfactor(data_out, ki)
@@ -67,8 +69,9 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
         real_output = data_out
     if not all_reduce:
         # Fuse and split the axis
-        fused_outer = sch[real_output].fuse(*[sch[real_output].op.axis[i]
-                                              for i in range(len(sch[real_output].op.axis))])
+        fused_outer = sch[real_output].fuse(
+            *[sch[real_output].op.axis[i] for i in range(len(sch[real_output].op.axis))]
+        )
         bx, outer_in = sch[real_output].split(fused_outer, factor=num_thread)
 
         # Bind the axes to threads and blocks
@@ -81,10 +84,8 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
         if is_idx_reduce:
             spatial_axis = sch[real_output].fuse(*(sch[real_output].op.axis))
             sch[real_output].bind(spatial_axis, te.thread_axis("blockIdx.x"))
-            sch[temp_idx_input].compute_at(sch[real_output],
-                                           spatial_axis)
-            sch[temp_val_input].compute_at(sch[real_output],
-                                           spatial_axis)
+            sch[temp_idx_input].compute_at(sch[real_output], spatial_axis)
+            sch[temp_val_input].compute_at(sch[real_output], spatial_axis)
     sch[real_output].set_store_predicate(thread_x.equal(0))
     return sch
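
The fuse/split/rfactor/bind sequence above is the usual cross-thread reduction recipe: fuse all reduce axes, split off a chunk per thread, rfactor the inner part into per-thread partials, then let lane 0 write the final value. A small standalone sketch of the same pattern (array size and split factor are arbitrary):

    import tvm
    from tvm import te

    n = 1024
    A = te.placeholder((n,), name="A")
    k = te.reduce_axis((0, n), name="k")
    B = te.compute((1,), lambda i: te.sum(A[k], axis=k), name="B")

    s = te.create_schedule(B.op)
    ko, ki = s[B].split(B.op.reduce_axis[0], factor=32)
    BF = s.rfactor(B, ki)                      # per-thread partial sums
    tx = te.thread_axis("threadIdx.x")
    s[B].bind(s[B].op.reduce_axis[0], tx)      # cross-thread reduction
    s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
    s[B].set_store_predicate(tx.var.equal(0))  # only lane 0 stores the result
    print(tvm.lower(s, [A, B], simple_mode=True))
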
 
@@ -128,12 +129,12 @@ def schedule_reduce(outs):
                 schedule_injective_from_existing(sch, operator.output(0))
             for tensor in operator.input_tensors:
                 traverse_after_reduce(tensor.op)
-        elif operator.tag == 'comm_reduce':
+        elif operator.tag == "comm_reduce":
             _schedule_reduce(operator, sch, is_idx_reduce=False)
             for tensor in operator.input_tensors:
                 if tensor.op not in scheduled_ops:
                     traverse_before_reduce(tensor.op)
-        elif operator.tag == 'comm_reduce_idx':
+        elif operator.tag == "comm_reduce_idx":
             _schedule_reduce(operator, sch, is_idx_reduce=True)
             input_tensors = operator.input_tensors[0].op.input_tensors
             for tensor in input_tensors:
index dbd0325..99fbdd0 100644
@@ -43,17 +43,21 @@ def schedule_softmax(outs):
     tgt = Target.current(allow_none=False)
 
     op_tag = softmax.op.tag
-    if op_tag == 'softmax_output':
+    if op_tag == "softmax_output":
         expsum = softmax.op.input_tensors[1]
         exp = softmax.op.input_tensors[0]
         max_elem = s[exp].op.input_tensors[1]
-    elif op_tag == 'log_softmax_output':
+    elif op_tag == "log_softmax_output":
         exp = None
         max_elem = softmax.op.input_tensors[1]
         expsum = softmax.op.input_tensors[2]
     else:
-        raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
-                         Got {0}'.format(op_tag))
+        raise ValueError(
+            "Tag is expected to be softmax_output or log_softmax_output. \
+                         Got {0}".format(
+                op_tag
+            )
+        )
 
     # The nvptx and rocm backends only supports 32-bits warp shuffle
     # instructions.
index a8d1572..465299a 100644
@@ -24,9 +24,11 @@ from ..math import identity
 from ..transform import strided_slice, transpose
 from .. import tag
 
+
 def swap(arr, axis):
     """ swap arr[axis] and arr[-1] """
-    return arr[:axis] + [arr[-1]] + arr[axis+1:-1] + [arr[axis]]
+    return arr[:axis] + [arr[-1]] + arr[axis + 1 : -1] + [arr[axis]]
+
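
swap builds the transpose order that moves the sort axis to the innermost position, e.g. (hypothetical call):

    swap([0, 1, 2, 3], 1)  # -> [0, 3, 2, 1]: axis 1 and the last axis trade places
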
 
 def _schedule_sort(outs):
     """Schedule for argsort operator.
@@ -53,10 +55,12 @@ def _schedule_sort(outs):
             if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                 traverse(tensor.op)
         scheduled_ops.append(op)
+
     for out in outs:
         traverse(out.op)
     return s
 
+
 def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
     """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
 
@@ -113,10 +117,10 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
             with ib.if_scope(tid < shape[axis]):
                 values_out[base_idx + tid * axis_mul_after] = data[base_idx + tid * axis_mul_after]
                 if indices_out is not None:
-                    indices_out[base_idx + tid * axis_mul_after] = \
-                        tvm.tir.generic.cast(tid, indices_out.dtype)
-    ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
-                         tvm.runtime.convert(['shared'])))
+                    indices_out[base_idx + tid * axis_mul_after] = tvm.tir.generic.cast(
+                        tid, indices_out.dtype
+                    )
+    ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
     idxd = tvm.tir.indexdiv
     idxm = tvm.tir.indexmod
 
@@ -129,11 +133,15 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
                 with ib.if_scope(tid < idxd(current_sort_num + 1, 2)):
                     offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after
                     if is_ascend:
-                        cond = tvm.tir.all(2 * tid + idxm(k, 2) + 1 < current_sort_num,
-                                           values_out[offset] > values_out[offset + axis_mul_after])
+                        cond = tvm.tir.all(
+                            2 * tid + idxm(k, 2) + 1 < current_sort_num,
+                            values_out[offset] > values_out[offset + axis_mul_after],
+                        )
                     else:
-                        cond = tvm.tir.all(2 * tid + idxm(k, 2) + 1 < current_sort_num,
-                                           values_out[offset] < values_out[offset + axis_mul_after])
+                        cond = tvm.tir.all(
+                            2 * tid + idxm(k, 2) + 1 < current_sort_num,
+                            values_out[offset] < values_out[offset + axis_mul_after],
+                        )
                     with ib.if_scope(cond):
                         temp_data[0] = values_out[offset]
                         values_out[offset] = values_out[offset + axis_mul_after]
@@ -142,8 +150,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
                             temp_index[0] = indices_out[offset]
                             indices_out[offset] = indices_out[offset + axis_mul_after]
                             indices_out[offset + axis_mul_after] = temp_index[0]
-                ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
-                                     tvm.runtime.convert(['shared'])))
+                ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
 
     return ib.get()
 
@@ -215,29 +222,37 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
             with ib.for_range(0, current_sort_num) as k:
                 with ib.if_scope(tid < idxd(current_sort_num + 1, 2)):
                     offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after
-                    with ib.if_scope(tvm.tir.all(is_ascend == 1, \
-                                                 2 * tid + idxm(k, 2) + 1 < current_sort_num, \
-                                                 data[offset] > data[offset + axis_mul_after])):
+                    with ib.if_scope(
+                        tvm.tir.all(
+                            is_ascend == 1,
+                            2 * tid + idxm(k, 2) + 1 < current_sort_num,
+                            data[offset] > data[offset + axis_mul_after],
+                        )
+                    ):
                         temp_data[0] = data[offset]
                         data[offset] = data[offset + axis_mul_after]
                         data[offset + axis_mul_after] = temp_data[0]
                         temp_index[0] = output[offset]
                         output[offset] = output[offset + axis_mul_after]
                         output[offset + axis_mul_after] = temp_index[0]
-                    with ib.if_scope(tvm.tir.all(is_ascend == 0, \
-                                                 2 * tid + idxm(k, 2) + 1 < current_sort_num, \
-                                                 data[offset] < data[offset + axis_mul_after])):
+                    with ib.if_scope(
+                        tvm.tir.all(
+                            is_ascend == 0,
+                            2 * tid + idxm(k, 2) + 1 < current_sort_num,
+                            data[offset] < data[offset + axis_mul_after],
+                        )
+                    ):
                         temp_data[0] = data[offset]
                         data[offset] = data[offset + axis_mul_after]
                         data[offset + axis_mul_after] = temp_data[0]
                         temp_index[0] = output[offset]
                         output[offset] = output[offset + axis_mul_after]
                         output[offset + axis_mul_after] = temp_index[0]
-                ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
-                                     tvm.runtime.convert(['shared'])))
+                ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
 
     return ib.get()
 
+
 def argsort_nms_thrust(data, valid_count, axis=-1, is_ascend=1, dtype="float32"):
     """Performs sorting along the given axis and returns an array of indicies
     having same shape as an input array that index data in sorted order.
@@ -272,23 +287,26 @@ def argsort_nms_thrust(data, valid_count, axis=-1, is_ascend=1, dtype="float32")
         axes = swap(list(range(ndim)), axis)
         data = transpose(data, axes)
 
-    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf",
-                                   data_alignment=8)
-    valid_count_buf = tvm.tir.decl_buffer(valid_count.shape, valid_count.dtype,
-                                          "valid_count_buf", data_alignment=4)
+    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+    valid_count_buf = tvm.tir.decl_buffer(
+        valid_count.shape, valid_count.dtype, "valid_count_buf", data_alignment=4
+    )
     out_bufs = [
         tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8),
-        tvm.tir.decl_buffer(data.shape, "int32", "indices_buf", data_alignment=8)
+        tvm.tir.decl_buffer(data.shape, "int32", "indices_buf", data_alignment=8),
     ]
-    out = te.extern([data.shape, data.shape],
-                    [data, valid_count],
-                    lambda ins, outs: tvm.tir.call_packed(
-                        "tvm.contrib.thrust.sort_nms", ins[0], ins[1], outs[0], outs[1], is_ascend),
-                    in_buffers=[data_buf, valid_count_buf],
-                    out_buffers=out_bufs,
-                    dtype=[data.dtype, "int32"],
-                    name="nms_argsort_gpu",
-                    tag="nms_argsort_gpu")
+    out = te.extern(
+        [data.shape, data.shape],
+        [data, valid_count],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.thrust.sort_nms", ins[0], ins[1], outs[0], outs[1], is_ascend
+        ),
+        in_buffers=[data_buf, valid_count_buf],
+        out_buffers=out_bufs,
+        dtype=[data.dtype, "int32"],
+        name="nms_argsort_gpu",
+        tag="nms_argsort_gpu",
+    )
 
     if axis != ndim - 1:
         axes = swap(list(range(ndim)), axis)
@@ -296,6 +314,7 @@ def argsort_nms_thrust(data, valid_count, axis=-1, is_ascend=1, dtype="float32")
 
     return out[1]
 
+
 def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
     """Performs sorting along the given axis and returns an array of indicies
     having same shape as an input array that index data in sorted order.
@@ -324,32 +343,37 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
     """
     if valid_count is not None:
         sorted_data = identity(data)
-        sorted_data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "sorted_data_buf",
-                                              data_alignment=8)
-        valid_count_buf = tvm.tir.decl_buffer(valid_count.shape, valid_count.dtype,
-                                              "valid_count_buf", data_alignment=4)
+        sorted_data_buf = tvm.tir.decl_buffer(
+            data.shape, data.dtype, "sorted_data_buf", data_alignment=8
+        )
+        valid_count_buf = tvm.tir.decl_buffer(
+            valid_count.shape, valid_count.dtype, "valid_count_buf", data_alignment=4
+        )
         out_buf = tvm.tir.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4)
-        out = te.extern([data.shape],
-                        [sorted_data, valid_count],
-                        lambda ins, outs: sort_nms_ir(
-                            ins[0], ins[1], outs[0], axis, is_ascend),
-                        dtype="int32",
-                        in_buffers=[sorted_data_buf, valid_count_buf],
-                        out_buffers=[out_buf],
-                        name="argsort_nms_gpu",
-                        tag="argsort_nms_gpu")
+        out = te.extern(
+            [data.shape],
+            [sorted_data, valid_count],
+            lambda ins, outs: sort_nms_ir(ins[0], ins[1], outs[0], axis, is_ascend),
+            dtype="int32",
+            in_buffers=[sorted_data_buf, valid_count_buf],
+            out_buffers=[out_buf],
+            name="argsort_nms_gpu",
+            tag="argsort_nms_gpu",
+        )
     else:
         value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
         indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
-        out = te.extern([data.shape, data.shape],
-                        [data],
-                        lambda ins, outs: sort_ir(
-                            ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
-                        out_buffers=[value_buf, indices_buf],
-                        name="argsort_gpu",
-                        tag="argsort_gpu")[1]
+        out = te.extern(
+            [data.shape, data.shape],
+            [data],
+            lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
+            out_buffers=[value_buf, indices_buf],
+            name="argsort_gpu",
+            tag="argsort_gpu",
+        )[1]
     return out
 
+
 def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
     """Performs sorting along the given axis and returns an array of indicies
     having same shape as an input array that index data in sorted order.
@@ -399,6 +423,7 @@ def schedule_argsort(outs):
     """
     return _schedule_sort(outs)
 
+
 def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
     """Get the top k elements in an input tensor along the given axis.
 
@@ -437,21 +462,23 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
     values_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "values_buf", data_alignment=8)
     indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8)
     if ret_type == "values":
-        output = te.extern([data.shape],
-                           [data],
-                           lambda ins, outs: sort_ir(
-                               ins[0], outs[0], axis, is_ascend),
-                           out_buffers=[values_buf],
-                           name="topk_gpu",
-                           tag="topk_gpu")
+        output = te.extern(
+            [data.shape],
+            [data],
+            lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend),
+            out_buffers=[values_buf],
+            name="topk_gpu",
+            tag="topk_gpu",
+        )
     else:
-        output = te.extern([data.shape, data.shape],
-                           [data],
-                           lambda ins, outs: sort_ir(
-                               ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
-                           out_buffers=[values_buf, indices_buf],
-                           name="topk_gpu",
-                           tag="topk_gpu")
+        output = te.extern(
+            [data.shape, data.shape],
+            [data],
+            lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
+            out_buffers=[values_buf, indices_buf],
+            name="topk_gpu",
+            tag="topk_gpu",
+        )
     if k < 1:
         if ret_type == "indices":
             return output[1]
@@ -470,7 +497,7 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
         output = [values_out, indices_out]
     elif ret_type == "values":
         output = [strided_slice(output, beg, end)]
-    else: # ret_type == "indices"
+    else:  # ret_type == "indices"
         indices_out = output[1]
         output = [strided_slice(indices_out, beg, end)]
     return output
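
The ret_type handling above just strided-slices the fully sorted values/indices down to the first k entries along the sort axis. A NumPy reference of the same behaviour (a sketch, last-axis case only):

    import numpy as np

    def topk_ref(data, k, is_ascend=False, ret_type="both"):
        # NumPy reference for the ret_type/strided_slice branches above
        # (last-axis case only, k > 0 assumed).
        order = np.argsort(data, axis=-1)
        if not is_ascend:
            order = order[..., ::-1]
        values = np.take_along_axis(data, order, axis=-1)[..., :k]
        indices = order[..., :k]
        if ret_type == "values":
            return values
        if ret_type == "indices":
            return indices
        return values, indices
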
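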
@@ -519,17 +546,20 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int
     data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
     out_bufs = [
         tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8),
-        tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8)
+        tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8),
     ]
 
-    out = te.extern([data.shape, data.shape],
-                    [data],
-                    lambda ins, outs: tvm.tir.call_packed(
-                        "tvm.contrib.thrust.sort", ins[0], outs[0], outs[1], is_ascend),
-                    in_buffers=[data_buf],
-                    out_buffers=out_bufs,
-                    name="topk_gpu",
-                    tag="topk_gpu")
+    out = te.extern(
+        [data.shape, data.shape],
+        [data],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.thrust.sort", ins[0], outs[0], outs[1], is_ascend
+        ),
+        in_buffers=[data_buf],
+        out_buffers=out_bufs,
+        name="topk_gpu",
+        tag="topk_gpu",
+    )
 
     if k > 0:
         beg = [0] * ndim
index 5b57000..d1d31a6 100644
@@ -63,6 +63,7 @@ def schedule_sparse_dense(cfg, outs):
     """Create schedule for sparse dense"""
     # pylint:disable=invalid-name
     s = te.create_schedule([x.op for x in outs])
+
     def _callback(op):
         if op.tag == "sparse_dense_bsrmm":
             y_bsrmm = op.input_tensors[0]
@@ -85,7 +86,7 @@ def schedule_sparse_dense(cfg, outs):
             cfg.define_split("tile_c", c, num_outputs=2)
             if cfg.is_fallback:
                 cfg["tile_c"] = SplitEntity([-1, 8])
-            _, ci = cfg['tile_c'].apply(s, y_bsrmm, c)
+            _, ci = cfg["tile_c"].apply(s, y_bsrmm, c)
 
             y_bsrmm_factored = s.rfactor(y_bsrmm, ci)
             tx = s[y_bsrmm].op.reduce_axis[0]
index 541af06..dbff63c 100644
@@ -54,8 +54,7 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
     stmt : Stmt
         The result IR statement.
     """
-    max_threads = int(math.sqrt(
-        tvm.target.Target.current(allow_none=False).max_num_threads))
+    max_threads = int(math.sqrt(tvm.target.Target.current(allow_none=False).max_num_threads))
     tx = te.thread_axis("threadIdx.x")
     ty = te.thread_axis("threadIdx.y")
     bx = te.thread_axis("blockIdx.x")
@@ -89,15 +88,25 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
             center_w = (j + offset_w) * steps_w
 
             for k in range(num_sizes + num_ratios - 1):
-                w = if_then_else(k < num_sizes,
-                                 float(size_ratio_concat[k]) * in_height / in_width / 2.0,
-                                 float(size_ratio_concat[0]) * in_height / in_width *
-                                 math.sqrt(size_ratio_concat[k + 1]) / 2.0)
+                w = if_then_else(
+                    k < num_sizes,
+                    float(size_ratio_concat[k]) * in_height / in_width / 2.0,
+                    float(size_ratio_concat[0])
+                    * in_height
+                    / in_width
+                    * math.sqrt(size_ratio_concat[k + 1])
+                    / 2.0,
+                )
                 h = if_then_else(
-                    k < num_sizes, size_ratio_concat[k] / 2.0,
-                    size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0)
-                count = (i * in_width * (num_sizes + num_ratios - 1) +
-                         j * (num_sizes + num_ratios - 1) + k) * 4
+                    k < num_sizes,
+                    size_ratio_concat[k] / 2.0,
+                    size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0,
+                )
+                count = (
+                    i * in_width * (num_sizes + num_ratios - 1)
+                    + j * (num_sizes + num_ratios - 1)
+                    + k
+                ) * 4
                 p_out[count] = center_w - w
                 p_out[count + 1] = center_h - h
                 p_out[count + 2] = center_w + w
@@ -107,8 +116,7 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
     return body
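
For each feature-map location the loop above emits num_sizes + num_ratios - 1 prior boxes: one per size using the first ratio, then one per additional ratio using the first size. The half-width/half-height formulas restated in scalar Python (illustrative only; size_ratio_concat is sizes followed by ratios):

    import math

    def prior_half_extents(k, sizes, ratios, in_height, in_width):
        # Mirrors the if_then_else pair above: k < len(sizes) picks a size-based
        # box, otherwise ratio index k - len(sizes) + 1 is applied to sizes[0].
        if k < len(sizes):
            w = sizes[k] * in_height / in_width / 2.0
            h = sizes[k] / 2.0
        else:
            r = math.sqrt(ratios[k - len(sizes) + 1])
            w = sizes[0] * in_height / in_width * r / 2.0
            h = sizes[0] / r / 2.0
        return w, h
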
 
 
-def multibox_prior(data, sizes=(1,), ratios=(1,), steps=(-1, -1),
-                   offsets=(0.5, 0.5), clip=False):
+def multibox_prior(data, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, 0.5), clip=False):
     """Generate prior(anchor) boxes from data, sizes and ratios.
 
     Parameters
@@ -138,16 +146,18 @@ def multibox_prior(data, sizes=(1,), ratios=(1,), steps=(-1, -1),
     """
     num_sizes = len(sizes)
     num_ratios = len(ratios)
-    oshape = (
-        1, data.shape[2] * data.shape[3] * (num_sizes + num_ratios - 1), 4)
-    out = te.extern(oshape, [data], lambda ins, outs:
-                    multibox_prior_ir(
-                        ins[0], outs[0], sizes, ratios, steps, offsets),
-                    tag="multibox_prior")
+    oshape = (1, data.shape[2] * data.shape[3] * (num_sizes + num_ratios - 1), 4)
+    out = te.extern(
+        oshape,
+        [data],
+        lambda ins, outs: multibox_prior_ir(ins[0], outs[0], sizes, ratios, steps, offsets),
+        tag="multibox_prior",
+    )
     if clip:
         out = topi.clip(out, 0, 1)
     return out
 
+
 def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp_score, threshold):
     """Low level IR routing for transform location data preparation.
 
@@ -192,7 +202,7 @@ def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp
 
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     nthread_tx = max_threads
-    nthread_bx = (batch_size *  num_anchors) // max_threads + 1
+    nthread_bx = (batch_size * num_anchors) // max_threads + 1
     tx = te.thread_axis("threadIdx.x")
     bx = te.thread_axis("blockIdx.x")
     ib.scope_attr(tx, "thread_extent", nthread_tx)
@@ -221,14 +231,26 @@ def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp
         with ib.if_scope(tid < batch_size):
             with ib.for_range(0, num_anchors) as k:
                 with ib.if_scope(k > 0):
-                    temp_valid_count[tid * num_anchors + k] += \
-                        temp_valid_count[tid * num_anchors + k - 1]
+                    temp_valid_count[tid * num_anchors + k] += temp_valid_count[
+                        tid * num_anchors + k - 1
+                    ]
             valid_count[i] = temp_valid_count[tid * num_anchors + num_anchors - 1]
 
     return ib.get()
 
-def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score, out, \
-                     clip, variances, batch_size, num_anchors):
+
+def transform_loc_ir(
+    loc_pred,
+    anchor,
+    temp_valid_count,
+    temp_cls_id,
+    temp_score,
+    out,
+    clip,
+    variances,
+    batch_size,
+    num_anchors,
+):
     """Low level IR routing for transform location in multibox_detection operator.
 
     Parameters
@@ -268,9 +290,9 @@ def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score
     stmt : Stmt
         The result IR statement.
     """
+
     def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, vh):
-        """Transform prior anchor box to output box through location predictions.
-        """
+        """Transform prior anchor box to output box through location predictions."""
         al = anchor[anchor_base_idx]
         at = anchor[anchor_base_idx + 1]
         ar = anchor[anchor_base_idx + 2]
@@ -287,10 +309,12 @@ def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score
         oy = py * vy * ah + ay
         ow = exp(pw * vw) * aw / 2.0
         oh = exp(ph * vh) * ah / 2.0
-        return tvm.tir.if_then_else(clip, tvm.te.max(0.0, tvm.te.min(1.0, ox - ow)), ox - ow), \
-            tvm.tir.if_then_else(clip, tvm.te.max(0.0, tvm.te.min(1.0, oy - oh)), oy - oh), \
-            tvm.tir.if_then_else(clip, tvm.te.max(0.0, tvm.te.min(1.0, ox + ow)), ox + ow), \
-            tvm.tir.if_then_else(clip, tvm.te.max(0.0, tvm.te.min(1.0, oy + oh)), oy + oh)
+        return (
+            tvm.tir.if_then_else(clip, tvm.te.max(0.0, tvm.te.min(1.0, ox - ow)), ox - ow),
+            tvm.tir.if_then_else(clip, tvm.te.max(0.0, tvm.te.min(1.0, oy - oh)), oy - oh),
+            tvm.tir.if_then_else(clip, tvm.te.max(0.0, tvm.te.min(1.0, ox + ow)), ox + ow),
+            tvm.tir.if_then_else(clip, tvm.te.max(0.0, tvm.te.min(1.0, oy + oh)), oy + oh),
+        )
 
     ib = tvm.tir.ir_builder.create()
 
@@ -322,26 +346,49 @@ def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score
                 out_base_idx = i * num_anchors * 6
                 out_loc[out_base_idx] = cls_id[tid] - 1.0
                 out_loc[out_base_idx + 1] = score[tid]
-                out_loc[out_base_idx + 2], out_loc[out_base_idx + 3], out_loc[out_base_idx + 4], \
-                    out_loc[out_base_idx + 5] = transform_loc(loc_pred, tid * 4,
-                                                              anchor, j * 4, clip, variances[0],
-                                                              variances[1], variances[2],
-                                                              variances[3])
+                (
+                    out_loc[out_base_idx + 2],
+                    out_loc[out_base_idx + 3],
+                    out_loc[out_base_idx + 4],
+                    out_loc[out_base_idx + 5],
+                ) = transform_loc(
+                    loc_pred,
+                    tid * 4,
+                    anchor,
+                    j * 4,
+                    clip,
+                    variances[0],
+                    variances[1],
+                    variances[2],
+                    variances[3],
+                )
             with ib.else_scope():
                 out_base_idx = i * num_anchors * 6 + temp_valid_count[tid - 1] * 6
                 out_loc[out_base_idx] = cls_id[tid] - 1.0
                 out_loc[out_base_idx + 1] = score[tid]
-                out_loc[out_base_idx + 2], out_loc[out_base_idx + 3], out_loc[out_base_idx + 4], \
-                    out_loc[out_base_idx + 5] = transform_loc(loc_pred, tid * 4,
-                                                              anchor, j * 4, clip, variances[0],
-                                                              variances[1], variances[2],
-                                                              variances[3])
+                (
+                    out_loc[out_base_idx + 2],
+                    out_loc[out_base_idx + 3],
+                    out_loc[out_base_idx + 4],
+                    out_loc[out_base_idx + 5],
+                ) = transform_loc(
+                    loc_pred,
+                    tid * 4,
+                    anchor,
+                    j * 4,
+                    clip,
+                    variances[0],
+                    variances[1],
+                    variances[2],
+                    variances[3],
+                )
 
     return ib.get()
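
transform_loc in the hunk above decodes one SSD location prediction (px, py, pw, ph) against its anchor, scaling by the four variances and optionally clipping to [0, 1]. A scalar restatement (a sketch; the center/size terms ax, ay, aw, ah follow the usual corner-to-center conversion, which is elided from the hunk):

    import math

    def decode_box(anchor, pred, variances, clip=True):
        # anchor = (al, at, ar, ab) in corner form, pred = (px, py, pw, ph),
        # variances = (vx, vy, vw, vh); mirrors transform_loc above.
        al, at, ar, ab = anchor
        px, py, pw, ph = pred
        vx, vy, vw, vh = variances
        aw, ah = ar - al, ab - at
        ax, ay = al + aw / 2.0, at + ah / 2.0
        ox = px * vx * aw + ax
        oy = py * vy * ah + ay
        ow = math.exp(pw * vw) * aw / 2.0
        oh = math.exp(ph * vh) * ah / 2.0
        box = (ox - ow, oy - oh, ox + ow, oy + oh)
        return tuple(min(1.0, max(0.0, v)) for v in box) if clip else box
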
 
 
-def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, \
-                           threshold=0.01, variances=(0.1, 0.1, 0.2, 0.2)):
+def multibox_transform_loc(
+    cls_prob, loc_pred, anchor, clip=True, threshold=0.01, variances=(0.1, 0.1, 0.2, 0.2)
+):
     """Location transformation for multibox detection
 
     Parameters
@@ -381,46 +428,105 @@ def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, \
     valid_count_dtype = "int32"
     out_loc_dtype = loc_pred.dtype
 
-    valid_count_buf = tvm.tir.decl_buffer((batch_size,), valid_count_dtype,
-                                          "valid_count_buf", data_alignment=4)
-    loc_pred_buf = tvm.tir.decl_buffer(loc_pred.shape, loc_pred.dtype,
-                                       "loc_pred_buf", data_alignment=8)
-    anchor_buf = tvm.tir.decl_buffer(anchor.shape, anchor.dtype,
-                                     "anchor_buf", data_alignment=8)
+    valid_count_buf = tvm.tir.decl_buffer(
+        (batch_size,), valid_count_dtype, "valid_count_buf", data_alignment=4
+    )
+    loc_pred_buf = tvm.tir.decl_buffer(
+        loc_pred.shape, loc_pred.dtype, "loc_pred_buf", data_alignment=8
+    )
+    anchor_buf = tvm.tir.decl_buffer(anchor.shape, anchor.dtype, "anchor_buf", data_alignment=8)
 
     temp_valid_count_buf = tvm.tir.decl_buffer(
-        (batch_size, num_anchors,), valid_count_dtype, "temp_valid_count", data_alignment=8)
+        (
+            batch_size,
+            num_anchors,
+        ),
+        valid_count_dtype,
+        "temp_valid_count",
+        data_alignment=8,
+    )
     temp_cls_id_buf = tvm.tir.decl_buffer(
-        (batch_size, num_anchors,), valid_count_dtype, "temp_cls_id", data_alignment=8)
+        (
+            batch_size,
+            num_anchors,
+        ),
+        valid_count_dtype,
+        "temp_cls_id",
+        data_alignment=8,
+    )
     temp_score_buf = tvm.tir.decl_buffer(
-        (batch_size, num_anchors,), cls_prob.dtype, "temp_score", data_alignment=8)
-
-    valid_count, temp_valid_count, temp_cls_id, temp_score = \
-        te.extern([(batch_size,), (batch_size, num_anchors,), (batch_size, num_anchors,), \
-                   (batch_size, num_anchors,)], [cls_prob],
-                  lambda ins, outs: transform_loc_pre(
-                      ins[0], outs[0], outs[1], outs[2], outs[3], threshold),
-                  dtype=[valid_count_dtype, valid_count_dtype, valid_count_dtype, cls_prob.dtype],
-                  out_buffers=[valid_count_buf, temp_valid_count_buf, \
-                               temp_cls_id_buf, temp_score_buf],
-                  tag="multibox_transform_loc_phase_one")
-
-    out_loc = \
-        te.extern([oshape],
-                  [loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score],
-                  lambda ins, outs: transform_loc_ir(
-                      ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], clip, variances, \
-                      batch_size, num_anchors),
-                  in_buffers=[loc_pred_buf, anchor_buf, temp_valid_count_buf, \
-                              temp_cls_id_buf, temp_score_buf],
-                  dtype=[out_loc_dtype],
-                  tag="multibox_transform_loc")
+        (
+            batch_size,
+            num_anchors,
+        ),
+        cls_prob.dtype,
+        "temp_score",
+        data_alignment=8,
+    )
+
+    valid_count, temp_valid_count, temp_cls_id, temp_score = te.extern(
+        [
+            (batch_size,),
+            (
+                batch_size,
+                num_anchors,
+            ),
+            (
+                batch_size,
+                num_anchors,
+            ),
+            (
+                batch_size,
+                num_anchors,
+            ),
+        ],
+        [cls_prob],
+        lambda ins, outs: transform_loc_pre(ins[0], outs[0], outs[1], outs[2], outs[3], threshold),
+        dtype=[valid_count_dtype, valid_count_dtype, valid_count_dtype, cls_prob.dtype],
+        out_buffers=[valid_count_buf, temp_valid_count_buf, temp_cls_id_buf, temp_score_buf],
+        tag="multibox_transform_loc_phase_one",
+    )
+
+    out_loc = te.extern(
+        [oshape],
+        [loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score],
+        lambda ins, outs: transform_loc_ir(
+            ins[0],
+            ins[1],
+            ins[2],
+            ins[3],
+            ins[4],
+            outs[0],
+            clip,
+            variances,
+            batch_size,
+            num_anchors,
+        ),
+        in_buffers=[
+            loc_pred_buf,
+            anchor_buf,
+            temp_valid_count_buf,
+            temp_cls_id_buf,
+            temp_score_buf,
+        ],
+        dtype=[out_loc_dtype],
+        tag="multibox_transform_loc",
+    )
 
     return [out_loc, valid_count]
 
 
-def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nms_threshold=0.5,
-                       force_suppress=False, variances=(0.1, 0.1, 0.2, 0.2), nms_topk=-1):
+def multibox_detection(
+    cls_prob,
+    loc_pred,
+    anchor,
+    clip=True,
+    threshold=0.01,
+    nms_threshold=0.5,
+    force_suppress=False,
+    variances=(0.1, 0.1, 0.2, 0.2),
+    nms_topk=-1,
+):
     """Convert multibox detection predictions.
 
     Parameters
@@ -457,9 +563,15 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm
     out : tvm.te.Tensor
         3-D tensor with shape (batch_size, num_anchors, 6)
     """
-    inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
-                                       clip, threshold, variances)
-    out = non_max_suppression(inter_out[0], inter_out[1], inter_out[1], max_output_size=-1,
-                              iou_threshold=nms_threshold, force_suppress=force_suppress,
-                              top_k=nms_topk, return_indices=False)
+    inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor, clip, threshold, variances)
+    out = non_max_suppression(
+        inter_out[0],
+        inter_out[1],
+        inter_out[1],
+        max_output_size=-1,
+        iou_threshold=nms_threshold,
+        force_suppress=force_suppress,
+        top_k=nms_topk,
+        return_indices=False,
+    )
     return out
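
multibox_detection above is simply multibox_transform_loc followed by non_max_suppression. A hedged usage sketch with placeholder tensors; the shapes and the topi.cuda.ssd import path are assumptions, and a CUDA target is made current because the IR routines query max_num_threads:

    import tvm
    from tvm import te, topi

    batch, num_classes, num_anchors = 1, 4, 100
    cls_prob = te.placeholder((batch, num_classes, num_anchors), name="cls_prob")
    loc_pred = te.placeholder((batch, num_anchors * 4), name="loc_pred")
    anchors = te.placeholder((1, num_anchors, 4), name="anchors")

    with tvm.target.cuda():
        out = topi.cuda.ssd.multibox_detection(
            cls_prob, loc_pred, anchors, nms_threshold=0.5, nms_topk=400
        )
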
index c2b7d25..499f4b3 100644
@@ -20,7 +20,7 @@ import tvm
 from tvm import te
 
 
-def dp4a(x_scope='local', y_scope='local', z_scope='local'):
+def dp4a(x_scope="local", y_scope="local", z_scope="local"):
     """
     Int8 dot product reduced by every 4 elements using __dp4a
 
@@ -40,13 +40,12 @@ def dp4a(x_scope='local', y_scope='local', z_scope='local'):
     """
 
     n = 4  # dp4a requires operands packed by 4
-    x = te.placeholder((n,), name='x', dtype='int8')
-    y = te.placeholder((n,), name='y', dtype='int8')
+    x = te.placeholder((n,), name="x", dtype="int8")
+    y = te.placeholder((n,), name="y", dtype="int8")
 
-    k = te.reduce_axis((0, n), name='rc')
+    k = te.reduce_axis((0, n), name="rc")
 
-    z = te.compute((1,), lambda i: te.sum(
-        x[k].astype('int32') * y[k].astype('int32'), axis=[k]))
+    z = te.compute((1,), lambda i: te.sum(x[k].astype("int32") * y[k].astype("int32"), axis=[k]))
 
     def _intrin_func(ins, outs):
         def _instr(index):
@@ -58,40 +57,48 @@ def dp4a(x_scope='local', y_scope='local', z_scope='local'):
 
             ib = tvm.tir.ir_builder.create()
 
-            vec_x = xx.vload(0, dtype='int8x4')
-            vec_y = yy.vload(0, dtype='int8x4')
+            vec_x = xx.vload(0, dtype="int8x4")
+            vec_y = yy.vload(0, dtype="int8x4")
             prev_z = 0 if index == 0 else zz.vload(0)
 
-            new_z = tvm.tir.call_pure_extern('int32', '__dp4a', vec_x, vec_y, prev_z)
+            new_z = tvm.tir.call_pure_extern("int32", "__dp4a", vec_x, vec_y, prev_z)
             ib.emit(zz.vstore(0, new_z))
 
             return ib.get()
 
-        return _instr(0), _instr(1), _instr(2) # body, reset, update
+        return _instr(0), _instr(1), _instr(2)  # body, reset, update
 
-    default_buffer_params = {
-        "data_alignment": 4, "offset_factor": 1
-    }
+    default_buffer_params = {"data_alignment": 4, "offset_factor": 1}
     scopes = {x: x_scope, y: y_scope, z: z_scope}
-    binds = {t: tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name,
-                                    scope=scopes[t], **default_buffer_params) for t in [x, y, z]}
+    binds = {
+        t: tvm.tir.decl_buffer(
+            t.shape, t.dtype, t.op.name, scope=scopes[t], **default_buffer_params
+        )
+        for t in [x, y, z]
+    }
 
     return te.decl_tensor_intrin(
-        z.op, _intrin_func, binds=binds, default_buffer_params=default_buffer_params)
+        z.op, _intrin_func, binds=binds, default_buffer_params=default_buffer_params
+    )
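
dp4a above wraps CUDA's __dp4a instruction as a tensor intrinsic: one call multiplies four packed int8 pairs and accumulates the sum into an int32. Its scalar semantics, for reference:

    def dp4a_ref(x, y, acc=0):
        # x, y: four int8 values each; acc: the running int32 accumulator
        # (prev_z in the intrinsic body above).
        assert len(x) == len(y) == 4
        return acc + sum(int(a) * int(b) for a, b in zip(x, y))
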
 
 
 def intrin_wmma_load_matrix_A(strides_dst, strides_from, shape, layout, A_shape, C_shape, in_dtype):
     """Intrin function for loading data from shared memory to wmma.matrix_a"""
     wmma_m, wmma_n, wmma_k = shape
 
-    A = te.placeholder(A_shape, name='A', dtype=in_dtype)
-    BA = tvm.tir.decl_buffer(A.shape, A.dtype,
-                             scope='shared', strides=strides_from,
-                             data_alignment=32, offset_factor=8)
-    C = te.compute(C_shape, lambda *i: A(*i), name='C')
-    BC = tvm.tir.decl_buffer(C.shape, C.dtype,
-                             scope="wmma.matrix_a", strides=strides_dst,
-                             data_alignment=32, offset_factor=8)
+    A = te.placeholder(A_shape, name="A", dtype=in_dtype)
+    BA = tvm.tir.decl_buffer(
+        A.shape, A.dtype, scope="shared", strides=strides_from, data_alignment=32, offset_factor=8
+    )
+    C = te.compute(C_shape, lambda *i: A(*i), name="C")
+    BC = tvm.tir.decl_buffer(
+        C.shape,
+        C.dtype,
+        scope="wmma.matrix_a",
+        strides=strides_dst,
+        data_alignment=32,
+        offset_factor=8,
+    )
 
     def intrin_func(ins, outs):
         ib = tvm.tir.ir_builder.create()
@@ -100,9 +107,20 @@ def intrin_wmma_load_matrix_A(strides_dst, strides_from, shape, layout, A_shape,
         BC = outs[0]
         row = wmma_m * wmma_k
         warp_index = BC.elem_offset // row + BC.elem_offset % row // wmma_k
-        ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync',
-                                    BC.data, wmma_m, wmma_n, wmma_k, warp_index,
-                                    BA.access_ptr('r'), strides_from[0], layout))
+        ib.emit(
+            tvm.tir.call_intrin(
+                "handle",
+                "tir.tvm_load_matrix_sync",
+                BC.data,
+                wmma_m,
+                wmma_n,
+                wmma_k,
+                warp_index,
+                BA.access_ptr("r"),
+                strides_from[0],
+                layout,
+            )
+        )
         return ib.get()
 
     return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
@@ -112,14 +130,19 @@ def intrin_wmma_load_matrix_W(strides_dst, strides_from, shape, layout, A_shape,
     """Intrin function for loading data from shared memory to wmma.matrix_b"""
     wmma_m, wmma_n, wmma_k = shape
 
-    A = te.placeholder(A_shape, name='A', dtype=in_dtype)
-    BA = tvm.tir.decl_buffer(A.shape, A.dtype,
-                             scope='shared', strides=strides_from,
-                             data_alignment=32, offset_factor=8)
-    C = te.compute(C_shape, lambda *i: A(*i), name='C')
-    BC = tvm.tir.decl_buffer(C.shape, C.dtype,
-                             scope="wmma.matrix_b", strides=strides_dst,
-                             data_alignment=32, offset_factor=8)
+    A = te.placeholder(A_shape, name="A", dtype=in_dtype)
+    BA = tvm.tir.decl_buffer(
+        A.shape, A.dtype, scope="shared", strides=strides_from, data_alignment=32, offset_factor=8
+    )
+    C = te.compute(C_shape, lambda *i: A(*i), name="C")
+    BC = tvm.tir.decl_buffer(
+        C.shape,
+        C.dtype,
+        scope="wmma.matrix_b",
+        strides=strides_dst,
+        data_alignment=32,
+        offset_factor=8,
+    )
 
     def intrin_func(ins, outs):
         ib = tvm.tir.ir_builder.create()
@@ -128,9 +151,20 @@ def intrin_wmma_load_matrix_W(strides_dst, strides_from, shape, layout, A_shape,
         BC = outs[0]
         row = wmma_n * wmma_k
         warp_index = BC.elem_offset // row + BC.elem_offset % row // wmma_n
-        ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync',
-                                    BC.data, wmma_m, wmma_n, wmma_k, warp_index,
-                                    BA.access_ptr('r'), strides_from[0], layout))
+        ib.emit(
+            tvm.tir.call_intrin(
+                "handle",
+                "tir.tvm_load_matrix_sync",
+                BC.data,
+                wmma_m,
+                wmma_n,
+                wmma_k,
+                warp_index,
+                BA.access_ptr("r"),
+                strides_from[0],
+                layout,
+            )
+        )
         return ib.get()
 
     return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
@@ -139,15 +173,19 @@ def intrin_wmma_load_matrix_W(strides_dst, strides_from, shape, layout, A_shape,
 def intrin_wmma_store_matrix(strides_dst, strides_from, shape, out_dtype, A_shape, C_shape):
     """Intrin function for storing the results from wmma.accumulator to shared"""
     wmma_m, wmma_n, wmma_k = shape
-    A = te.placeholder(A_shape, name='A', dtype=out_dtype)
-    BA = tvm.tir.decl_buffer(A.shape, A.dtype,
-                             scope='wmma.accumulator',
-                             strides=strides_from, data_alignment=32,
-                             offset_factor=8)
-    C = te.compute(C_shape, lambda *i: A(*i), name='C')
-    BC = tvm.tir.decl_buffer(C.shape, C.dtype,
-                             scope='shared', strides=strides_dst,
-                             data_alignment=32, offset_factor=8)
+    A = te.placeholder(A_shape, name="A", dtype=out_dtype)
+    BA = tvm.tir.decl_buffer(
+        A.shape,
+        A.dtype,
+        scope="wmma.accumulator",
+        strides=strides_from,
+        data_alignment=32,
+        offset_factor=8,
+    )
+    C = te.compute(C_shape, lambda *i: A(*i), name="C")
+    BC = tvm.tir.decl_buffer(
+        C.shape, C.dtype, scope="shared", strides=strides_dst, data_alignment=32, offset_factor=8
+    )
 
     def intrin_func(ins, outs):
         ib = tvm.tir.ir_builder.create()
@@ -156,16 +194,26 @@ def intrin_wmma_store_matrix(strides_dst, strides_from, shape, out_dtype, A_shap
         BC = outs[0]
         row = wmma_m * wmma_n
         warp_index = BA.elem_offset // row + BA.elem_offset % row // wmma_n
-        ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_store_matrix_sync',
-                                    BA.data, wmma_m, wmma_n, wmma_k, warp_index,
-                                    BC.access_ptr('w'), strides_dst[0], 'row_major'))
+        ib.emit(
+            tvm.tir.call_intrin(
+                "handle",
+                "tir.tvm_store_matrix_sync",
+                BA.data,
+                wmma_m,
+                wmma_n,
+                wmma_k,
+                warp_index,
+                BC.access_ptr("w"),
+                strides_dst[0],
+                "row_major",
+            )
+        )
         return ib.get()
 
     return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
 
 
-def intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, strides_A,
-                     strides_W, strides_Conv, shape):
+def intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, strides_A, strides_W, strides_Conv, shape):
     """Intrin for wmma fill_fragment and mma_sync
 
     Parameters
@@ -182,19 +230,37 @@ def intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, strides_A,
     B = WL_gemm
     C = CL_compute
 
-    BA = tvm.tir.decl_buffer(A.shape, A.dtype, name='BA',
-                             scope='wmma.matrix_a', data_alignment=32,
-                             offset_factor=8, strides=strides_A)
-    BB = tvm.tir.decl_buffer(B.shape, B.dtype, name='BB',
-                             scope='wmma.matrix_b', data_alignment=32,
-                             offset_factor=8, strides=strides_W)
-    BC = tvm.tir.decl_buffer(C.shape, C.dtype, name='BC',
-                             scope='wmma.accumulator', data_alignment=32,
-                             offset_factor=8, strides=strides_Conv)
+    BA = tvm.tir.decl_buffer(
+        A.shape,
+        A.dtype,
+        name="BA",
+        scope="wmma.matrix_a",
+        data_alignment=32,
+        offset_factor=8,
+        strides=strides_A,
+    )
+    BB = tvm.tir.decl_buffer(
+        B.shape,
+        B.dtype,
+        name="BB",
+        scope="wmma.matrix_b",
+        data_alignment=32,
+        offset_factor=8,
+        strides=strides_W,
+    )
+    BC = tvm.tir.decl_buffer(
+        C.shape,
+        C.dtype,
+        name="BC",
+        scope="wmma.accumulator",
+        data_alignment=32,
+        offset_factor=8,
+        strides=strides_Conv,
+    )
 
     def intrin_func(ins, outs):
         BA, BB = ins
-        BC, = outs
+        (BC,) = outs
 
         def warp_idnex(offset, row, col):
             row = row * col
@@ -207,18 +273,35 @@ def intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, strides_A,
         def init():
             ib = tvm.tir.ir_builder.create()
             ib.emit(
-                tvm.tir.call_intrin('handle', 'tir.tvm_fill_fragment',
-                                    BC.data, wmma_m, wmma_n, wmma_k,
-                                    warp_index_C, 0.0))
+                tvm.tir.call_intrin(
+                    "handle",
+                    "tir.tvm_fill_fragment",
+                    BC.data,
+                    wmma_m,
+                    wmma_n,
+                    wmma_k,
+                    warp_index_C,
+                    0.0,
+                )
+            )
             return ib.get()
 
         def update():
             ib = tvm.tir.ir_builder.create()
-            ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_mma_sync',
-                                        BC.data, warp_index_C,
-                                        BA.data, warp_index_A,
-                                        BB.data, warp_index_B,
-                                        BC.data, warp_index_C))
+            ib.emit(
+                tvm.tir.call_intrin(
+                    "handle",
+                    "tir.tvm_mma_sync",
+                    BC.data,
+                    warp_index_C,
+                    BA.data,
+                    warp_index_A,
+                    BB.data,
+                    warp_index_B,
+                    BC.data,
+                    warp_index_C,
+                )
+            )
             return ib.get()
 
         return update(), init(), update()
index 7536654..73b24de 100644
@@ -24,21 +24,25 @@ from .. import tag
 from .pooling import schedule_pool
 from .injective import schedule_injective_from_existing
 
+
 def _default_schedule(outs):
     """Default schedule for gpu."""
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     s = te.create_schedule([x.op for x in outs])
     scheduled_ops = []
+
     def traverse(op):
-        if tag.is_broadcast(op.tag) or op.tag in ['bbox_score', 'sorted_bbox']:
+        if tag.is_broadcast(op.tag) or op.tag in ["bbox_score", "sorted_bbox"]:
             schedule_injective_from_existing(s, op.output(0))
         for tensor in op.input_tensors:
             if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                 traverse(tensor.op)
         scheduled_ops.append(op)
+
     traverse(outs[0].op)
     return s
 
+
 def schedule_reorg(outs):
     """Schedule for reorg operator.
     Parameters
@@ -56,6 +60,7 @@ def schedule_reorg(outs):
     cpp_target = cpp.TEST_create_target(target.kind.name)
     return cpp.cuda.schedule_injective(cpp_target, outs)
 
+
 def schedule_nms(outs):
     """Schedule for non-maximum suppression
 
@@ -72,6 +77,7 @@ def schedule_nms(outs):
     """
     return _default_schedule(outs)
 
+
 def schedule_multibox_prior(outs):
     """Schedule for multibox_prior operator.
 
@@ -88,6 +94,7 @@ def schedule_multibox_prior(outs):
     """
     return _default_schedule(outs)
 
+
 def schedule_multibox_transform_loc(outs):
     """Schedule for multibox_transform_loc
 
@@ -105,6 +112,7 @@ def schedule_multibox_transform_loc(outs):
     """
     return _default_schedule(outs)
 
+
 def schedule_multibox_detection(outs):
     """Schedule for multibox_detection operator.
 
@@ -121,11 +129,14 @@ def schedule_multibox_detection(outs):
     """
     return _default_schedule(outs)
 
+
 def schedule_roi_align(outs):
-    return schedule_pool(outs, 'NCHW')
+    return schedule_pool(outs, "NCHW")
+
 
 def schedule_roi_pool(outs):
-    return schedule_pool(outs, 'NCHW')
+    return schedule_pool(outs, "NCHW")
+
 
 def schedule_proposal(outs):
     """Schedule for proposal operator.
@@ -143,6 +154,7 @@ def schedule_proposal(outs):
     """
     return _default_schedule(outs)
 
+
 def schedule_get_valid_counts(outs):
     """Schedule for get_valid_counts operator.
 
index 2d9f78b..122d7d2 100644 (file)
@@ -22,6 +22,7 @@ from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
 from ..util import get_const_tuple
 
+
 def fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements):
     """Fallback schedule for conv2d int8 on cpu.
     Normally the inner most pattern takes two int8/uint8 tensors
@@ -41,10 +42,14 @@ def fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements):
     HSTR, WSTR = wkl.hstride, wkl.wstride
     out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
 
-    assert wkl.out_filter % int32_lanes == 0, \
-        "wkl.out_filter=%d, int32_lanes=%d" % (wkl.out_filter, int32_lanes)
-    assert wkl.in_filter % num_int8_elements == 0, \
-        "wkl.in_filter=%d, num_int8_elements=%d" % (wkl.in_filter, num_int8_elements)
+    assert wkl.out_filter % int32_lanes == 0, "wkl.out_filter=%d, int32_lanes=%d" % (
+        wkl.out_filter,
+        int32_lanes,
+    )
+    assert wkl.in_filter % num_int8_elements == 0, "wkl.in_filter=%d, num_int8_elements=%d" % (
+        wkl.in_filter,
+        num_int8_elements,
+    )
 
     oc_bn = int32_lanes
     ic_bn = 1
@@ -85,10 +90,14 @@ def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes, num_int8_elements):
     out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
     out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
 
-    assert wkl.out_filter % int32_lanes == 0, \
-        "wkl.out_filter=%d, int32_lanes=%d" % (wkl.out_filter, int32_lanes)
-    assert wkl.in_filter % num_int8_elements == 0, \
-        "wkl.in_filter=%d, num_int8_elements=%d" % (wkl.in_filter, num_int8_elements)
+    assert wkl.out_filter % int32_lanes == 0, "wkl.out_filter=%d, int32_lanes=%d" % (
+        wkl.out_filter,
+        int32_lanes,
+    )
+    assert wkl.in_filter % num_int8_elements == 0, "wkl.in_filter=%d, num_int8_elements=%d" % (
+        wkl.in_filter,
+        num_int8_elements,
+    )
 
     oc_bn = int32_lanes
     ic_bn = 1
@@ -109,8 +118,9 @@ def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes, num_int8_elements):
     raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
 
 
-def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out,
-                                        last, int32_lanes=16, intrin=None):
+def schedule_conv_NCHWc_cpu_common_int8(
+    s, cfg, data_vec, kernel_vec, conv_out, last, int32_lanes=16, intrin=None
+):
     """
     Defines the schedule for INT8 for Intel and ARM machines
     Uses the Intel/ARM intrinsics to use INT8 operations
@@ -122,8 +132,7 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out,
     _, _, _, _, oc_bn = get_const_tuple(conv_out.shape)
 
     # schedule pad
-    if isinstance(s[data_vec].op, te.tensor.ComputeOp) \
-            and "pad" in data_vec.op.tag:
+    if isinstance(s[data_vec].op, te.tensor.ComputeOp) and "pad" in data_vec.op.tag:
         batch, ic_chunk, ih, iw, ic_block = s[data_vec].op.axis
         parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
         s[data_vec].parallel(parallel_axis)
@@ -135,8 +144,7 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out,
         # this part will be folded during Relay fold_constant pass.
         s[data_vec].pragma(s[data_vec].op.axis[0], "debug_skip_region")
         s[kernel_vec].pragma(s[kernel_vec].op.axis[0], "debug_skip_region")
-    elif isinstance(kernel_vec.op, te.tensor.ComputeOp) and \
-            kernel_vec.name == 'kernel_vec':
+    elif isinstance(kernel_vec.op, te.tensor.ComputeOp) and kernel_vec.name == "kernel_vec":
         # data and kernel are not pre-computed, schedule layout transform here.
         # this should only be used by x86 conv2d_nchw, which is for
         # testing purpose.
@@ -155,7 +163,7 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out,
 
     # schedule 5-D NCHW[x]c conv
     C, O = conv_out, last
-    CC = s.cache_write(C, 'global')
+    CC = s.cache_write(C, "global")
 
     batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
     ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
@@ -177,12 +185,34 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out,
     oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
 
     if unroll_kw:
-        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, ic_f_inner, kw,
-                      ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
+        s[CC].reorder(
+            oc_chunk,
+            oh,
+            ow_chunk,
+            ic_outer,
+            kh,
+            ic_f_inner,
+            kw,
+            ow_block,
+            oc_f_inner,
+            oc_s_inner,
+            ic_s_inner,
+        )
         s[CC].unroll(kw)
     else:
-        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, kw, ic_f_inner,
-                      ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
+        s[CC].reorder(
+            oc_chunk,
+            oh,
+            ow_chunk,
+            ic_outer,
+            kh,
+            kw,
+            ic_f_inner,
+            ow_block,
+            oc_f_inner,
+            oc_s_inner,
+            ic_s_inner,
+        )
 
     if intrin is not None:
         s[CC].tensorize(oc_s_inner, intrin)
@@ -213,8 +243,10 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out,
 
     return s
 
-def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out,
-                                     last, int32_lanes=16, intrin=None):
+
+def schedule_conv_NCHWc_cpu_1x1_int8(
+    s, cfg, data_vec, kernel_vec, conv_out, last, int32_lanes=16, intrin=None
+):
     """
     Defines the 1x1 conv schedule for INT8 for Intel and ARM machines
     Uses the Intel/ARM intrinsics to use INT8 operations
@@ -226,8 +258,7 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out,
     _, _, _, _, oc_bn = get_const_tuple(conv_out.shape)
 
     # schedule pad
-    if isinstance(s[data_vec].op, te.tensor.ComputeOp) \
-            and "pad" in data_vec.op.tag:
+    if isinstance(s[data_vec].op, te.tensor.ComputeOp) and "pad" in data_vec.op.tag:
         batch, ic_chunk, ih, iw, ic_block = s[data_vec].op.axis
         parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
         s[data_vec].parallel(parallel_axis)
@@ -239,8 +270,7 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out,
         # this part will be folded during Relay fold_constant pass.
         s[data_vec].pragma(s[data_vec].op.axis[0], "debug_skip_region")
         s[kernel_vec].pragma(s[kernel_vec].op.axis[0], "debug_skip_region")
-    elif isinstance(kernel_vec.op, te.tensor.ComputeOp) and \
-            kernel_vec.name == 'kernel_vec':
+    elif isinstance(kernel_vec.op, te.tensor.ComputeOp) and kernel_vec.name == "kernel_vec":
         # data and kernel are not pre-computed, schedule layout transform here.
         # this should only be used by x86 conv2d_nchw, which is for
         # testing purpose.
@@ -258,7 +288,7 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out,
         s[kernel_vec].parallel(parallel_axis)
 
     C, O = conv_out, last
-    CC = s.cache_write(C, 'global')
+    CC = s.cache_write(C, "global")
 
     batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
     oh_outer, oh_inner = s[C].split(oh, factor=oh_factor)
@@ -282,8 +312,20 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out,
     oh_outer, oh_inner = s[CC].split(oh, factor=oh_factor)
     ow_outer, ow_inner = s[CC].split(ow, factor=ow_factor)
 
-    s[CC].reorder(oc_chunk, oh_outer, ow_outer, kh, kw, ic_outer, ic_f_inner, oh_inner,
-                  ow_inner, oc_f_inner, oc_s_inner, ic_s_inner)
+    s[CC].reorder(
+        oc_chunk,
+        oh_outer,
+        ow_outer,
+        kh,
+        kw,
+        ic_outer,
+        ic_f_inner,
+        oh_inner,
+        ow_inner,
+        oc_f_inner,
+        oc_s_inner,
+        ic_s_inner,
+    )
     s[CC].fuse(oc_chunk, oh_outer)
 
     if intrin is not None:
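
The assert rewrites above show the effect of the 100-character line length used for this conversion. Below is a sketch of reproducing one of them programmatically, assuming Black 19.10b0 or newer is installed (its Python API exposes format_str and FileMode); the snippet being formatted is the pre-change assert copied from the hunk above:

```python
import black

# Pre-change, backslash-continued assert from the hunk above; the names stay
# unresolved, Black only needs syntactically valid source.
SRC = (
    "assert wkl.out_filter % int32_lanes == 0, \\\n"
    '    "wkl.out_filter=%d, int32_lanes=%d" % (wkl.out_filter, int32_lanes)\n'
)

# line_length=100 matches the wrapping seen throughout this diff.
print(black.format_str(SRC, mode=black.FileMode(line_length=100)))
```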
index 3b4feb7..cd6fd7a 100644 (file)
@@ -19,6 +19,7 @@
 import tvm
 from .. import cpp
 
+
 def schedule_extern(outs):
     """Schedule for an extern op followed by injective operations.
 
index 6360f8b..6b81098 100644 (file)
@@ -21,6 +21,7 @@ from __future__ import absolute_import as _abs
 import tvm
 from tvm import te
 
+
 def schedule_injective_from_existing(sch, out):
     """Schedule for injective op from existing schedule.
 
@@ -39,6 +40,7 @@ def schedule_injective_from_existing(sch, out):
     sch[out].fuse(*sch[out].op.axis)
     return sch
 
+
 def schedule_injective(outs):
     """Schedule for injective op.
 
@@ -63,5 +65,6 @@ def schedule_injective(outs):
     schedule_injective_from_existing(s, x)
     return s
 
+
 schedule_elemwise = schedule_injective
 schedule_broadcast = schedule_injective
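
schedule_injective_from_existing above simply fuses every axis of the output into one iteration axis. A standalone sketch of that behaviour, assuming a local TVM build; the placeholder shapes are arbitrary:

```python
import tvm
from tvm import te

# A and B are arbitrary; the point is fusing every output axis into one,
# exactly as schedule_injective_from_existing does above.
A = te.placeholder((16, 32), name="A")
B = te.compute(A.shape, lambda i, j: A[i, j] + 1.0, name="B")

s = te.create_schedule(B.op)
s[B].fuse(*s[B].op.axis)

print(tvm.lower(s, [A, B], simple_mode=True))
```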
index 7645588..2b56249 100644 (file)
@@ -293,6 +293,7 @@ def schedule_conv3d_ncdhw(outs):
     """
     return _default_schedule(outs, False)
 
+
 def schedule_conv3d_ndhwc(outs):
     """Schedule for conv3d_ndhwc
 
index 1d5a30d..16e6a5d 100644 (file)
@@ -36,6 +36,7 @@ def schedule_argsort(outs):
     """
     return _default_schedule(outs, False)
 
+
 def schedule_topk(outs):
     """Schedule for topk operator.
 
index d0855a0..e7518b1 100644 (file)
@@ -40,6 +40,7 @@ def schedule_reorg(outs):
     cpp_target = cpp.TEST_create_target(target.kind.name)
     return cpp.generic.default_schedule(cpp_target, outs, False)
 
+
 def schedule_get_valid_counts(outs):
     """Schedule for get_valid_counts
 
@@ -56,6 +57,7 @@ def schedule_get_valid_counts(outs):
     """
     return _default_schedule(outs, False)
 
+
 def schedule_nms(outs):
     """Schedule for non-maximum suppression
 
@@ -72,6 +74,7 @@ def schedule_nms(outs):
     """
     return _default_schedule(outs, False)
 
+
 def schedule_multibox_prior(outs):
     """Schedule for multibox_prior
 
@@ -88,6 +91,7 @@ def schedule_multibox_prior(outs):
     """
     return _default_schedule(outs, False)
 
+
 def schedule_multibox_transform_loc(outs):
     """Schedule for multibox_transform_loc
 
@@ -105,6 +109,7 @@ def schedule_multibox_transform_loc(outs):
     """
     return _default_schedule(outs, False)
 
+
 def schedule_multibox_detection(outs):
     """Schedule for multibox_detection
 
@@ -121,6 +126,7 @@ def schedule_multibox_detection(outs):
     """
     return _default_schedule(outs, False)
 
+
 def schedule_roi_align(outs):
     """Schedule for roi_align
 
@@ -137,6 +143,7 @@ def schedule_roi_align(outs):
     """
     return _default_schedule(outs, False)
 
+
 def schedule_roi_pool(outs):
     """Schedule for roi_align
 
@@ -153,6 +160,7 @@ def schedule_roi_pool(outs):
     """
     return _default_schedule(outs, False)
 
+
 def schedule_proposal(outs):
     """Schedule for proposal operator.
 
index f4695d3..661e24d 100644 (file)
@@ -82,6 +82,7 @@ def _make_bop(broadcast_bop, orig_bop):
         if not isinstance(lhs, te.tensor.Tensor) and not isinstance(rhs, te.tensor.Tensor):
             return orig_bop(lhs, rhs)
         return broadcast_bop(lhs, rhs)
+
     _tensor_bop_impl.__doc__ = _tensor_bop_impl.__doc__.format(op=name)
     return _tensor_bop_impl
 
@@ -98,4 +99,5 @@ def _bind_generic_ops():
         tvm.tir.generic.divide = _make_bop(_broadcast.divide, tvm.tir.generic.divide)
         tvm.tir.generic.cast = _math.cast
 
+
 _bind_generic_ops()
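
_make_bop above (only reformatted here) builds an operator that falls back to the original scalar implementation unless a tensor operand is involved. A self-contained sketch of the same dispatch pattern with stand-in types; every name below is hypothetical, not TVM API:

```python
class FakeTensor:
    """Stand-in for te.tensor.Tensor in this sketch."""

    def __init__(self, value):
        self.value = value


def make_bop(broadcast_bop, orig_bop, tensor_type=FakeTensor):
    # Same dispatch as _make_bop: scalar/scalar falls back to the original
    # operator, anything involving a tensor goes to the broadcast version.
    def impl(lhs, rhs):
        if not isinstance(lhs, tensor_type) and not isinstance(rhs, tensor_type):
            return orig_bop(lhs, rhs)
        return broadcast_bop(lhs, rhs)

    return impl


add = make_bop(
    broadcast_bop=lambda a, b: FakeTensor(getattr(a, "value", a) + getattr(b, "value", b)),
    orig_bop=lambda a, b: a + b,
)

assert add(1, 2) == 3                    # scalar path
assert add(FakeTensor(1), 2).value == 3  # tensor path
```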
index 4c1fdf4..9319359 100644 (file)
@@ -19,6 +19,7 @@
 import tvm
 from tvm import te
 
+
 def schedule_injective_from_existing(sch, out):
     """Schedule for injective op from existing schedule.
 
@@ -39,6 +40,7 @@ def schedule_injective_from_existing(sch, out):
     sch[out].bind(px, te.thread_axis("pipeline"))
     return sch
 
+
 def schedule_injective(outs):
     """Schedule for injective op.
 
@@ -60,5 +62,6 @@ def schedule_injective(outs):
         schedule_injective_from_existing(s, out)
     return s
 
+
 schedule_elemwise = schedule_injective
 schedule_broadcast = schedule_injective
index 3d7ff82..b9053fe 100644 (file)
@@ -152,6 +152,7 @@ def schedule_depthwise_conv2d_nhwc(outs):
     """
     return _schedule_conv2d(outs)
 
+
 def schedule_bitserial_conv2d_nchw(outs):
     """Schedule for bitserial_conv2d_nchw
 
@@ -253,17 +254,21 @@ def schedule_softmax(outs):
     softmax = outs[0]
 
     op_tag = softmax.op.tag
-    if op_tag == 'softmax_output':
+    if op_tag == "softmax_output":
         expsum = softmax.op.input_tensors[1]
         exp = softmax.op.input_tensors[0]
         max_elem = s[exp].op.input_tensors[1]
-    elif op_tag == 'log_softmax_output':
+    elif op_tag == "log_softmax_output":
         exp = None
         max_elem = softmax.op.input_tensors[1]
         expsum = softmax.op.input_tensors[2]
     else:
-        raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
-                         Got {0}'.format(op_tag))
+        raise ValueError(
+            "Tag is expected to be softmax_output or log_softmax_output. \
+                         Got {0}".format(
+                op_tag
+            )
+        )
 
     if exp is not None:
         s[exp].compute_at(s[softmax], s[softmax].op.axis[1])
@@ -304,7 +309,7 @@ def schedule_dense(outs):
                 if isinstance(tensor.op, tvm.te.ComputeOp):
                     traverse(tensor.op)
         # schedule dense
-        elif OP.tag == 'dense':
+        elif OP.tag == "dense":
             Dense = OP.output(0)
             if not Dense.op in s.outputs:
                 Out = outs[0].op.output(0)
@@ -347,7 +352,7 @@ def schedule_pool(outs, layout):
                 if isinstance(tensor.op, tvm.te.ComputeOp):
                     traverse(tensor.op)
         # schedule pool
-        elif OP.tag.startswith('pool'):
+        elif OP.tag.startswith("pool"):
             Pool = OP.output(0)
             if not Pool.op in s.outputs:
                 Out = outs[0].op.output(0)
@@ -390,7 +395,7 @@ def schedule_adaptive_pool(outs):
                 if isinstance(tensor.op, tvm.te.ComputeOp):
                     traverse(tensor.op)
         # schedule global_pool
-        elif OP.tag.startswith('adaptive_pool'):
+        elif OP.tag.startswith("adaptive_pool"):
             Pool = OP.output(0)
             if not Pool.op in s.outputs:
                 Out = outs[0].op.output(0)
index dd16a21..b388782 100644 (file)
@@ -68,14 +68,16 @@ def dilation2d_nchw(input, filter, stride, padding, dilations, out_dtype=None):
 
     batch, in_channel, in_height, in_width = input.shape
     channel, kernel_h, kernel_w = filter.shape
-    assert in_channel.value == channel.value, \
-        "For Dilation2D input and filter channels should be same."
+    assert (
+        in_channel.value == channel.value
+    ), "For Dilation2D input and filter channels should be same."
 
     # compute the output shape
     dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
     dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
 
     out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
     out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
@@ -83,16 +85,20 @@ def dilation2d_nchw(input, filter, stride, padding, dilations, out_dtype=None):
     pad_before = [0, 0, pad_top, pad_left]
     pad_after = [0, 0, pad_down, pad_right]
     temp = pad(input, pad_before, pad_after, name="pad_temp")
-    ry = te.reduce_axis((0, kernel_h), name='ry')
-    rx = te.reduce_axis((0, kernel_w), name='rx')
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
 
     return te.compute(
         (batch, in_channel, out_height, out_width),
         lambda nn, ff, yy, xx: te.max(
-            temp[nn, ff, yy * stride_h + ry * dilation_h,
-                 xx * stride_w + rx * dilation_w].astype(out_dtype) +
-            filter[ff, ry, rx].astype(out_dtype),
-            axis=[ry, rx]), tag="dilation2d_nchw")
+            temp[nn, ff, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w].astype(
+                out_dtype
+            )
+            + filter[ff, ry, rx].astype(out_dtype),
+            axis=[ry, rx],
+        ),
+        tag="dilation2d_nchw",
+    )
 
 
 def dilation2d_nhwc(input, filter, stride, padding, dilations, out_dtype=None):
@@ -139,27 +145,33 @@ def dilation2d_nhwc(input, filter, stride, padding, dilations, out_dtype=None):
 
     batch, in_height, in_width, in_channel = input.shape
     kernel_h, kernel_w, channel = filter.shape
-    assert in_channel.value == channel.value, \
-        "For Dilation2D input and filter channels should be same."
+    assert (
+        in_channel.value == channel.value
+    ), "For Dilation2D input and filter channels should be same."
 
     # compute the output shape
     dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
     dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
 
     out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
     out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
     pad_before = [0, pad_top, pad_left, 0]
     pad_after = [0, pad_down, pad_right, 0]
     padded_input = pad(input, pad_before, pad_after, name="padded_input")
-    ry = te.reduce_axis((0, kernel_h), name='ry')
-    rx = te.reduce_axis((0, kernel_w), name='rx')
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
 
     return te.compute(
         (batch, out_height, out_width, in_channel),
         lambda nn, yy, xx, ff: te.max(
-            padded_input[nn, yy * stride_h + ry * dilation_h,
-                         xx * stride_w + rx * dilation_w, ff].astype(out_dtype) +
-            filter[ry, rx, ff].astype(out_dtype),
-            axis=[ry, rx]), tag="dilation2d_nhcw")
+            padded_input[
+                nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, ff
+            ].astype(out_dtype)
+            + filter[ry, rx, ff].astype(out_dtype),
+            axis=[ry, rx],
+        ),
+        tag="dilation2d_nhcw",
+    )
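
The dilation2d computes above take, at each output location, the maximum of input plus filter over a dilated window, with zero padding supplied by pad. A NumPy reference for the NCHW variant, as a sketch; the helper name and shapes are ours:

```python
import numpy as np


def dilation2d_nchw_ref(data, kernel, stride, dilation, padding):
    # out[n, c, y, x] = max over (ry, rx) of
    #     padded[n, c, y*sh + ry*dh, x*sw + rx*dw] + kernel[c, ry, rx]
    sh, sw = stride
    dh, dw = dilation
    pt, pl, pb, pr = padding  # top, left, bottom, right, as in get_pad_tuple
    n, c, h, w = data.shape
    _, kh, kw = kernel.shape
    padded = np.zeros((n, c, h + pt + pb, w + pl + pr), dtype=data.dtype)
    padded[:, :, pt : pt + h, pl : pl + w] = data
    dkh, dkw = (kh - 1) * dh + 1, (kw - 1) * dw + 1  # dilated kernel extent
    out_h = (h + pt + pb - dkh) // sh + 1
    out_w = (w + pl + pr - dkw) // sw + 1
    out = np.empty((n, c, out_h, out_w), dtype=data.dtype)
    for y in range(out_h):
        for x in range(out_w):
            window = padded[:, :, y * sh : y * sh + dkh : dh, x * sw : x * sw + dkw : dw]
            out[:, :, y, x] = (window + kernel[None]).max(axis=(2, 3))
    return out


x = np.random.uniform(size=(1, 3, 8, 8)).astype("float32")
k = np.random.uniform(size=(3, 3, 3)).astype("float32")
print(dilation2d_nchw_ref(x, k, (1, 1), (1, 1), (1, 1, 1, 1)).shape)  # (1, 3, 8, 8)
```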
index 32b6112..19a69ef 100644 (file)
@@ -41,12 +41,13 @@ def affine_grid(data, target_shape):
     """
     assert target_shape is not None
     assert len(target_shape) == 2
-    assert target_shape[0] > 1 and target_shape[1] > 1, \
-        "target height/width should be greater than 1"
+    assert (
+        target_shape[0] > 1 and target_shape[1] > 1
+    ), "target height/width should be greater than 1"
 
     dtype = data.dtype
-    y_step = tir.const((2.0 - 1e-7)/ (target_shape[0] - 1), dtype=dtype)
-    x_step = tir.const((2.0 - 1e-7)/ (target_shape[1] - 1), dtype=dtype)
+    y_step = tir.const((2.0 - 1e-7) / (target_shape[0] - 1), dtype=dtype)
+    x_step = tir.const((2.0 - 1e-7) / (target_shape[1] - 1), dtype=dtype)
     start = tir.const(-1.0, dtype=dtype)
 
     def _compute(n, dim, i, j):
@@ -55,10 +56,10 @@ def affine_grid(data, target_shape):
         return data[n, dim, 0] * x + data[n, dim, 1] * y + data[n, dim, 2]
 
     oshape = (data.shape[0], len(target_shape), *target_shape)
-    return te.compute(oshape, _compute, tag='affine_grid')
+    return te.compute(oshape, _compute, tag="affine_grid")
 
 
-def grid_sample(data, grid, method='bilinear', layout='NCHW'):
+def grid_sample(data, grid, method="bilinear", layout="NCHW"):
     """Applies bilinear sampling to input feature map.
 
     Given :math:`data` and :math:`grid`, assuming NCHW layout, then the output is computed by
@@ -99,26 +100,32 @@ def grid_sample(data, grid, method='bilinear', layout='NCHW'):
     """
     batch, in_channel, in_height, in_width = data.shape
     out_height, out_width = grid.shape[2:]
-    assert method == 'bilinear', "Only bilinear is supported"
+    assert method == "bilinear", "Only bilinear is supported"
     assert layout == "NCHW", "Only NCHW is supported"
 
     def _get_pixel_value(n, c, h, w):
-        return te.if_then_else(te.all(h >= 0, w >= 0, h < in_height, w < in_width),
-                               data[n, c, h, w], tir.const(0.0, dtype=data.dtype))
+        return te.if_then_else(
+            te.all(h >= 0, w >= 0, h < in_height, w < in_width),
+            data[n, c, h, w],
+            tir.const(0.0, dtype=data.dtype),
+        )
 
     def _bilinear_sample(n, c, h, w):
         x = grid[n, 0, h, w]
         y = grid[n, 1, h, w]
         y = (y + 1) * (in_height - 1) / 2
         x = (x + 1) * (in_width - 1) / 2
-        x0 = te.floor(x).astype('int32')
-        y0 = te.floor(y).astype('int32')
-        x1 = x0 + tir.const(1, 'int32')
-        y1 = y0 + tir.const(1, 'int32')
-        return _get_pixel_value(n, c, y0, x0) * (1.0 - (y - y0)) * (1.0 - (x - x0)) \
-            + _get_pixel_value(n, c, y0, x1) * (1.0 - (y - y0)) * (x - x0) \
-            + _get_pixel_value(n, c, y1, x0) * (y - y0) * (1.0 - (x - x0)) \
+        x0 = te.floor(x).astype("int32")
+        y0 = te.floor(y).astype("int32")
+        x1 = x0 + tir.const(1, "int32")
+        y1 = y0 + tir.const(1, "int32")
+        return (
+            _get_pixel_value(n, c, y0, x0) * (1.0 - (y - y0)) * (1.0 - (x - x0))
+            + _get_pixel_value(n, c, y0, x1) * (1.0 - (y - y0)) * (x - x0)
+            + _get_pixel_value(n, c, y1, x0) * (y - y0) * (1.0 - (x - x0))
             + _get_pixel_value(n, c, y1, x1) * (y - y0) * (x - x0)
+        )
 
-    return te.compute((batch, in_channel, out_height, out_width), _bilinear_sample,
-                      tag='grid_sample')
+    return te.compute(
+        (batch, in_channel, out_height, out_width), _bilinear_sample, tag="grid_sample"
+    )
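
_bilinear_sample above maps normalized grid coordinates from [-1, 1] into pixel space and blends the four surrounding pixels, with out-of-bounds reads contributing zero. A NumPy sketch of the same per-pixel computation; the helper name is ours, not TVM API:

```python
import numpy as np


def bilinear_sample_ref(data, gx, gy, n, c):
    # data: NCHW array; (gx, gy) are normalized grid coordinates in [-1, 1].
    _, _, in_h, in_w = data.shape
    x = (gx + 1) * (in_w - 1) / 2
    y = (gy + 1) * (in_h - 1) / 2
    x0, y0 = int(np.floor(x)), int(np.floor(y))
    x1, y1 = x0 + 1, y0 + 1

    def pixel(h, w):
        # Out-of-bounds reads contribute 0, as in _get_pixel_value above.
        if 0 <= h < in_h and 0 <= w < in_w:
            return float(data[n, c, h, w])
        return 0.0

    return (
        pixel(y0, x0) * (1 - (y - y0)) * (1 - (x - x0))
        + pixel(y0, x1) * (1 - (y - y0)) * (x - x0)
        + pixel(y1, x0) * (y - y0) * (1 - (x - x0))
        + pixel(y1, x1) * (y - y0) * (x - x0)
    )


img = np.arange(16, dtype="float32").reshape(1, 1, 4, 4)
print(bilinear_sample_ref(img, gx=0.0, gy=0.0, n=0, c=0))  # centre of the image: 7.5
```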
index b159723..ca99044 100644 (file)
@@ -22,13 +22,14 @@ from tvm import te
 from tvm.topi.util import nchw_pack_layout, nchw_xc_layout
 from .. import tag
 
-def get_2d_indices(indices, layout='NCHW'):
+
+def get_2d_indices(indices, layout="NCHW"):
     """ Get 2d indices """
     (cc, inum, ic) = (0, 0, 0)
-    if layout == 'NHWC':
+    if layout == "NHWC":
         n, y, x, c = indices
         cc = None
-    elif layout == 'NCHW':
+    elif layout == "NCHW":
         n, c, y, x = indices
         cc = None
     elif nchw_pack_layout(layout):
@@ -40,27 +41,38 @@ def get_2d_indices(indices, layout='NCHW'):
 
     return n, c, y, x, cc, inum, ic
 
+
 def get_2d_pixel(data, layout, boxes, image_height, image_width, n, c, y, x, cc, ib, ic):
     """ Get 2d pixel """
     if boxes is None:
         y = tvm.te.max(tvm.te.min(y, image_height - 1), 0)
         x = tvm.te.max(tvm.te.min(x, image_width - 1), 0)
-    if layout == 'NHWC':
-        return data(n, y, x, c).astype('float')
-    if layout == 'NCHW':
-        return data(n, c, y, x).astype('float')
+    if layout == "NHWC":
+        return data(n, y, x, c).astype("float")
+    if layout == "NCHW":
+        return data(n, c, y, x).astype("float")
     if nchw_pack_layout(layout):
-        return data(n, c, y, x, ib, ic).astype('float')
+        return data(n, c, y, x, ib, ic).astype("float")
 
     # else must be NCHWxc
     assert nchw_xc_layout(layout)
-    return data(n, c, y, x, cc).astype('float')
-
-def resize_nearest_neighbor(indices, data, image_height, image_width,
-                            target_height, target_width, boxes=None,
-                            box_indices=None, extrapolation_value=None, layout='NCHW',
-                            coordinate_transformation_mode="align_corners",
-                            out_dtype=None):
+    return data(n, c, y, x, cc).astype("float")
+
+
+def resize_nearest_neighbor(
+    indices,
+    data,
+    image_height,
+    image_width,
+    target_height,
+    target_width,
+    boxes=None,
+    box_indices=None,
+    extrapolation_value=None,
+    layout="NCHW",
+    coordinate_transformation_mode="align_corners",
+    out_dtype=None,
+):
 
     """Perform resize operation with nearest neighbor method on the data.
     For details about Nearest-neighbor interpolation please refer to
@@ -132,21 +144,24 @@ def resize_nearest_neighbor(indices, data, image_height, image_width,
 
         in_h = (image_height - 1) * (y2 - y1)
         in_w = (image_width - 1) * (x2 - x1)
-        h_scale = in_h.astype('float') / (target_height - 1)
-        w_scale = in_w.astype('float') / (target_width - 1)
+        h_scale = in_h.astype("float") / (target_height - 1)
+        w_scale = in_w.astype("float") / (target_width - 1)
 
         in_y = y1 * (image_height - 1) + h_scale * y
         in_x = x1 * (image_width - 1) + w_scale * x
     else:
         if coordinate_transformation_mode == "align_corners":
-            h_scale = (image_height - 1).astype('float') / (target_height - 1)
-            w_scale = (image_width - 1).astype('float') / (target_width - 1)
+            h_scale = (image_height - 1).astype("float") / (target_height - 1)
+            w_scale = (image_width - 1).astype("float") / (target_width - 1)
         elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]:
-            h_scale = image_height.astype('float') / target_height
-            w_scale = image_width.astype('float') / target_width
+            h_scale = image_height.astype("float") / target_height
+            w_scale = image_width.astype("float") / target_width
         else:
-            raise ValueError("Unsupported coordinate_transformation_mode: {}".format(
-                coordinate_transformation_mode))
+            raise ValueError(
+                "Unsupported coordinate_transformation_mode: {}".format(
+                    coordinate_transformation_mode
+                )
+            )
         in_y = h_scale * y
         in_x = w_scale * x
 
@@ -156,32 +171,53 @@ def resize_nearest_neighbor(indices, data, image_height, image_width,
     else:
         # Add epsilon to floor to prevent gpu rounding errors.
         epsilon = 1e-5
-        closest_y_index = te.floor(in_y + epsilon).astype('int32')
-        closest_x_index = te.floor(in_x + epsilon).astype('int32')
-
-    value = get_2d_pixel(data, layout, boxes, image_height, image_width,
-                         box_idx, c, closest_y_index, closest_x_index, cc, inum, ic)
+        closest_y_index = te.floor(in_y + epsilon).astype("int32")
+        closest_x_index = te.floor(in_x + epsilon).astype("int32")
+
+    value = get_2d_pixel(
+        data,
+        layout,
+        boxes,
+        image_height,
+        image_width,
+        box_idx,
+        c,
+        closest_y_index,
+        closest_x_index,
+        cc,
+        inum,
+        ic,
+    )
 
     if extrapolation_value is not None:
-        out = tvm.tir.if_then_else(in_y < 0,
-                                   extrapolation_value,
-                                   tvm.tir.if_then_else(in_y > image_height - 1,
-                                                        extrapolation_value,
-                                                        value))
+        out = tvm.tir.if_then_else(
+            in_y < 0,
+            extrapolation_value,
+            tvm.tir.if_then_else(in_y > image_height - 1, extrapolation_value, value),
+        )
         # use extrapolation_value if in_x is out of boundary
-        value = tvm.tir.if_then_else(in_x < 0,
-                                     extrapolation_value,
-                                     tvm.tir.if_then_else(in_x > image_width - 1,
-                                                          extrapolation_value,
-                                                          out))
+        value = tvm.tir.if_then_else(
+            in_x < 0,
+            extrapolation_value,
+            tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, out),
+        )
     return _cast_output(value, data.dtype, out_dtype=out_dtype)
 
 
-def resize_bilinear(indices, data, image_height, image_width,
-                    target_height, target_width, boxes=None,
-                    box_indices=None, extrapolation_value=None, layout='NCHW',
-                    coordinate_transformation_mode="align_corners",
-                    out_dtype=None):
+def resize_bilinear(
+    indices,
+    data,
+    image_height,
+    image_width,
+    target_height,
+    target_width,
+    boxes=None,
+    box_indices=None,
+    extrapolation_value=None,
+    layout="NCHW",
+    coordinate_transformation_mode="align_corners",
+    out_dtype=None,
+):
 
     """Perform resize operation with bilinear method on the data.
     For details about Bilinear interpolation please refer to
@@ -257,21 +293,24 @@ def resize_bilinear(indices, data, image_height, image_width,
 
         in_h = (image_height - 1) * (y2 - y1)
         in_w = (image_width - 1) * (x2 - x1)
-        h_scale = in_h.astype('float') / (target_height - 1)
-        w_scale = in_w.astype('float') / (target_width - 1)
+        h_scale = in_h.astype("float") / (target_height - 1)
+        w_scale = in_w.astype("float") / (target_width - 1)
 
         in_y = y1 * (image_height - 1) + h_scale * y
         in_x = x1 * (image_width - 1) + w_scale * x
     else:
         if coordinate_transformation_mode == "align_corners":
-            h_scale = (image_height - 1).astype('float') / (target_height - 1)
-            w_scale = (image_width - 1).astype('float') / (target_width - 1)
+            h_scale = (image_height - 1).astype("float") / (target_height - 1)
+            w_scale = (image_width - 1).astype("float") / (target_width - 1)
         elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]:
-            h_scale = image_height.astype('float') / target_height
-            w_scale = image_width.astype('float') / target_width
+            h_scale = image_height.astype("float") / target_height
+            w_scale = image_width.astype("float") / target_width
         else:
-            raise ValueError("Unsupported coordinate_transformation_mode: {}".format(
-                coordinate_transformation_mode))
+            raise ValueError(
+                "Unsupported coordinate_transformation_mode: {}".format(
+                    coordinate_transformation_mode
+                )
+            )
 
         if coordinate_transformation_mode == "half_pixel":
             in_y = h_scale * (y + 0.5) - 0.5
@@ -280,22 +319,70 @@ def resize_bilinear(indices, data, image_height, image_width,
             in_y = h_scale * y
             in_x = w_scale * x
 
-    top_y_index = te.floor(in_y).astype('int32')
-    bottom_y_index = te.ceil(in_y).astype('int32')
+    top_y_index = te.floor(in_y).astype("int32")
+    bottom_y_index = te.ceil(in_y).astype("int32")
     y_lerp = in_y - top_y_index
 
-    left_x_index = te.floor(in_x).astype('int32')
-    right_x_index = te.ceil(in_x).astype('int32')
+    left_x_index = te.floor(in_x).astype("int32")
+    right_x_index = te.ceil(in_x).astype("int32")
     x_lerp = in_x - left_x_index
 
-    top_left = get_2d_pixel(data, layout, boxes, image_height, image_width,
-                            box_idx, c, top_y_index, left_x_index, cc, inum, ic)
-    top_right = get_2d_pixel(data, layout, boxes, image_height, image_width,
-                             box_idx, c, top_y_index, right_x_index, cc, inum, ic)
-    bottom_left = get_2d_pixel(data, layout, boxes, image_height, image_width,
-                               box_idx, c, bottom_y_index, left_x_index, cc, inum, ic)
-    bottom_right = get_2d_pixel(data, layout, boxes, image_height, image_width,
-                                box_idx, c, bottom_y_index, right_x_index, cc, inum, ic)
+    top_left = get_2d_pixel(
+        data,
+        layout,
+        boxes,
+        image_height,
+        image_width,
+        box_idx,
+        c,
+        top_y_index,
+        left_x_index,
+        cc,
+        inum,
+        ic,
+    )
+    top_right = get_2d_pixel(
+        data,
+        layout,
+        boxes,
+        image_height,
+        image_width,
+        box_idx,
+        c,
+        top_y_index,
+        right_x_index,
+        cc,
+        inum,
+        ic,
+    )
+    bottom_left = get_2d_pixel(
+        data,
+        layout,
+        boxes,
+        image_height,
+        image_width,
+        box_idx,
+        c,
+        bottom_y_index,
+        left_x_index,
+        cc,
+        inum,
+        ic,
+    )
+    bottom_right = get_2d_pixel(
+        data,
+        layout,
+        boxes,
+        image_height,
+        image_width,
+        box_idx,
+        c,
+        bottom_y_index,
+        right_x_index,
+        cc,
+        inum,
+        ic,
+    )
 
     top = _lerp(top_left, top_right, x_lerp)
     bottom = _lerp(bottom_left, bottom_right, x_lerp)
@@ -303,24 +390,33 @@ def resize_bilinear(indices, data, image_height, image_width,
 
     # use extrapolation_value if in_y/in_x is out of boundary
     if extrapolation_value is not None:
-        out = tvm.tir.if_then_else(in_y < 0,
-                                   extrapolation_value,
-                                   tvm.tir.if_then_else(in_y > image_height - 1,
-                                                        extrapolation_value,
-                                                        value))
-        value = tvm.tir.if_then_else(in_x < 0,
-                                     extrapolation_value,
-                                     tvm.tir.if_then_else(in_x > image_width - 1,
-                                                          extrapolation_value,
-                                                          out))
+        out = tvm.tir.if_then_else(
+            in_y < 0,
+            extrapolation_value,
+            tvm.tir.if_then_else(in_y > image_height - 1, extrapolation_value, value),
+        )
+        value = tvm.tir.if_then_else(
+            in_x < 0,
+            extrapolation_value,
+            tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, out),
+        )
     return _cast_output(value, data.dtype, out_dtype=out_dtype)
 
 
-def resize_bicubic(indices, data, image_height, image_width,
-                   target_height, target_width, boxes=None,
-                   box_indices=None, extrapolation_value=None, layout='NCHW',
-                   coordinate_transformation_mode="align_corners",
-                   out_dtype=None):
+def resize_bicubic(
+    indices,
+    data,
+    image_height,
+    image_width,
+    target_height,
+    target_width,
+    boxes=None,
+    box_indices=None,
+    extrapolation_value=None,
+    layout="NCHW",
+    coordinate_transformation_mode="align_corners",
+    out_dtype=None,
+):
     """Perform resize operation with bicubic method on the data.
     More details about Bicubic interpolation please refer to
     https://en.wikipedia.org/wiki/Bicubic_interpolation.
@@ -399,21 +495,24 @@ def resize_bicubic(indices, data, image_height, image_width,
 
         in_h = (image_height - 1) * (y2 - y1)
         in_w = (image_width - 1) * (x2 - x1)
-        h_scale = in_h.astype('float') / (target_height - 1)
-        w_scale = in_w.astype('float') / (target_width - 1)
+        h_scale = in_h.astype("float") / (target_height - 1)
+        w_scale = in_w.astype("float") / (target_width - 1)
 
         in_y = y1 * (image_height - 1) + h_scale * y
         in_x = x1 * (image_width - 1) + w_scale * x
     else:
         if coordinate_transformation_mode == "align_corners":
-            h_scale = (image_height - 1).astype('float') / (target_height - 1)
-            w_scale = (image_width - 1).astype('float') / (target_width - 1)
+            h_scale = (image_height - 1).astype("float") / (target_height - 1)
+            w_scale = (image_width - 1).astype("float") / (target_width - 1)
         elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]:
-            h_scale = image_height.astype('float') / target_height
-            w_scale = image_width.astype('float') / target_width
+            h_scale = image_height.astype("float") / target_height
+            w_scale = image_width.astype("float") / target_width
         else:
-            raise ValueError("Unsupported coordinate_transformation_mode: {}".format(
-                coordinate_transformation_mode))
+            raise ValueError(
+                "Unsupported coordinate_transformation_mode: {}".format(
+                    coordinate_transformation_mode
+                )
+            )
 
         if coordinate_transformation_mode == "half_pixel":
             in_y = h_scale * (y + 0.5) - 0.5
@@ -422,51 +521,67 @@ def resize_bicubic(indices, data, image_height, image_width,
             in_y = h_scale * y
             in_x = w_scale * x
 
-    xint = te.floor(in_x).astype('int32')
+    xint = te.floor(in_x).astype("int32")
     xfract = in_x - te.floor(in_x)
 
-    yint = te.floor(in_y).astype('int32')
+    yint = te.floor(in_y).astype("int32")
     yfract = in_y - te.floor(in_y)
 
     # 1st row
-    p00 = _get_pixel(data, layout, boxes, image_height, image_width,
-                     box_idx, c, yint - 1, xint - 1, cc, inum, ic)
-    p10 = _get_pixel(data, layout, boxes, image_height, image_width,
-                     box_idx, c, yint - 1, xint + 0, cc, inum, ic)
-    p20 = _get_pixel(data, layout, boxes, image_height, image_width,
-                     box_idx, c, yint - 1, xint + 1, cc, inum, ic)
-    p30 = _get_pixel(data, layout, boxes, image_height, image_width,
-                     box_idx, c, yint - 1, xint + 2, cc, inum, ic)
+    p00 = _get_pixel(
+        data, layout, boxes, image_height, image_width, box_idx, c, yint - 1, xint - 1, cc, inum, ic
+    )
+    p10 = _get_pixel(
+        data, layout, boxes, image_height, image_width, box_idx, c, yint - 1, xint + 0, cc, inum, ic
+    )
+    p20 = _get_pixel(
+        data, layout, boxes, image_height, image_width, box_idx, c, yint - 1, xint + 1, cc, inum, ic
+    )
+    p30 = _get_pixel(
+        data, layout, boxes, image_height, image_width, box_idx, c, yint - 1, xint + 2, cc, inum, ic
+    )
 
     # 2nd row
-    p01 = _get_pixel(data, layout, boxes, image_height, image_width,
-                     box_idx, c, yint + 0, xint - 1, cc, inum, ic)
-    p11 = _get_pixel(data, layout, boxes, image_height, image_width,
-                     box_idx, c, yint + 0, xint + 0, cc, inum, ic)
-    p21 = _get_pixel(data, layout, boxes, image_height, image_width,
-                     box_idx, c, yint + 0, xint + 1, cc, inum, ic)
-    p31 = _get_pixel(data, layout, boxes, image_height, image_width,
-                     box_idx, c, yint + 0, xint + 2, cc, inum, ic)
+    p01 = _get_pixel(
+        data, layout, boxes, image_height, image_width, box_idx, c, yint + 0, xint - 1, cc, inum, ic
+    )
+    p11 = _get_pixel(
+        data, layout, boxes, image_height, image_width, box_idx, c, yint + 0, xint + 0, cc, inum, ic
+    )
+    p21 = _get_pixel(
+        data, layout, boxes, image_height, image_width, box_idx, c, yint + 0, xint + 1, cc, inum, ic
+    )
+    p31 = _get_pixel(
+        data, layout, boxes, image_height, image_width, box_idx, c, yint + 0, xint + 2, cc, inum, ic
+    )
 
     # 3rd row
-    p02 = _get_pixel(data, layout, boxes, image_height, image_width,
-                     box_idx, c, yint + 1, xint - 1, cc, inum, ic)
-    p12 = _get_pixel(data, layout, boxes, image_height, image_width,
-                     box_idx, c, yint + 1, xint + 0, cc, inum, ic)
-    p22 = _get_pixel(data, layout, boxes, image_height, image_width,
-                     box_idx, c, yint + 1, xint + 1, cc, inum, ic)
-    p32 = _get_pixel(data, layout, boxes, image_height, image_width,
-                     box_idx, c, yint + 1, xint + 2, cc, inum, ic)
+    p02 = _get_pixel(
+        data, layout, boxes, image_height, image_width, box_idx, c, yint + 1, xint - 1, cc, inum, ic
+    )
+    p12 = _get_pixel(
+        data, layout, boxes, image_height, image_width, box_idx, c, yint + 1, xint + 0, cc, inum, ic
+    )
+    p22 = _get_pixel(
+        data, layout, boxes, image_height, image_width, box_idx, c, yint + 1, xint + 1, cc, inum, ic
+    )
+    p32 = _get_pixel(
+        data, layout, boxes, image_height, image_width, box_idx, c, yint + 1, xint + 2, cc, inum, ic
+    )
 
     # 4th row
-    p03 = _get_pixel(data, layout, boxes, image_height, image_width,
-                     box_idx, c, yint + 2, xint - 1, cc, inum, ic)
-    p13 = _get_pixel(data, layout, boxes, image_height, image_width,
-                     box_idx, c, yint + 2, xint + 0, cc, inum, ic)
-    p23 = _get_pixel(data, layout, boxes, image_height, image_width,
-                     box_idx, c, yint + 2, xint + 1, cc, inum, ic)
-    p33 = _get_pixel(data, layout, boxes, image_height, image_width,
-                     box_idx, c, yint + 2, xint + 2, cc, inum, ic)
+    p03 = _get_pixel(
+        data, layout, boxes, image_height, image_width, box_idx, c, yint + 2, xint - 1, cc, inum, ic
+    )
+    p13 = _get_pixel(
+        data, layout, boxes, image_height, image_width, box_idx, c, yint + 2, xint + 0, cc, inum, ic
+    )
+    p23 = _get_pixel(
+        data, layout, boxes, image_height, image_width, box_idx, c, yint + 2, xint + 1, cc, inum, ic
+    )
+    p33 = _get_pixel(
+        data, layout, boxes, image_height, image_width, box_idx, c, yint + 2, xint + 2, cc, inum, ic
+    )
 
     # Interpolate bicubically
     col0 = _cubic_kernel(p00, p10, p20, p30, xfract)
@@ -477,21 +592,28 @@ def resize_bicubic(indices, data, image_height, image_width,
 
     # use extrapolation_value if in_y/in_x is out of boundary
     if extrapolation_value is not None:
-        out = tvm.tir.if_then_else(in_y < 0,
-                                   extrapolation_value,
-                                   tvm.tir.if_then_else(in_y > image_height - 1,
-                                                        extrapolation_value,
-                                                        value))
-        value = tvm.tir.if_then_else(in_x < 0,
-                                     extrapolation_value,
-                                     tvm.tir.if_then_else(in_x > image_width - 1,
-                                                          extrapolation_value,
-                                                          out))
+        out = tvm.tir.if_then_else(
+            in_y < 0,
+            extrapolation_value,
+            tvm.tir.if_then_else(in_y > image_height - 1, extrapolation_value, value),
+        )
+        value = tvm.tir.if_then_else(
+            in_x < 0,
+            extrapolation_value,
+            tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, out),
+        )
     return _cast_output(value, data.dtype, out_dtype=out_dtype)
 
 
-def resize(data, size, layout="NCHW", method="bilinear",
-           coordinate_transformation_mode="half_pixel", out_dtype=None, output_shape=None):
+def resize(
+    data,
+    size,
+    layout="NCHW",
+    method="bilinear",
+    coordinate_transformation_mode="half_pixel",
+    out_dtype=None,
+    output_shape=None,
+):
     """Perform resize operation on the data.
 
     Parameters
@@ -532,49 +654,67 @@ def resize(data, size, layout="NCHW", method="bilinear",
     """
     method = method.lower()
     if method == "nearest_neighbor" and coordinate_transformation_mode != "asymmetric":
-        raise ValueError('Topi Resize does not support the combination of method %s ' \
-                         'and coordinate_transformation_mode %s' %
-                         (method, coordinate_transformation_mode))
-    if layout == 'NHWC':
+        raise ValueError(
+            "Topi Resize does not support the combination of method %s "
+            "and coordinate_transformation_mode %s" % (method, coordinate_transformation_mode)
+        )
+    if layout == "NHWC":
         in_n, in_h, in_w, in_c = data.shape
         if output_shape is None:
             output_shape = [in_n, size[0], size[1], in_c]
-    elif layout == 'NCHW':
+    elif layout == "NCHW":
         in_n, in_c, in_h, in_w = data.shape
         if output_shape is None:
             output_shape = [in_n, in_c, size[0], size[1]]
-    elif nchw_pack_layout(layout):# for NCHWinic
+    elif nchw_pack_layout(layout):  # for NCHWinic
         in_n, in_c, in_h, in_w, in_inum, in_ic = data.shape
         if output_shape is None:
             output_shape = [in_n, in_c, size[0], size[1], in_inum, in_ic]
-    elif nchw_xc_layout(layout):# for NCHWxc
+    elif nchw_xc_layout(layout):  # for NCHWxc
         in_n, in_c, in_h, in_w, in_cc = data.shape
         if output_shape is None:
             output_shape = [in_n, in_c, size[0], size[1], in_cc]
     else:
-        raise ValueError('%s layout is not supported.' % layout)
-
+        raise ValueError("%s layout is not supported." % layout)
 
     def _nearest_neighbor(*indices):
-        return resize_nearest_neighbor(indices, data, in_h, in_w,
-                                       size[0], size[1], layout=layout,
-                                       coordinate_transformation_mode= \
-                                       coordinate_transformation_mode,
-                                       out_dtype=out_dtype)
+        return resize_nearest_neighbor(
+            indices,
+            data,
+            in_h,
+            in_w,
+            size[0],
+            size[1],
+            layout=layout,
+            coordinate_transformation_mode=coordinate_transformation_mode,
+            out_dtype=out_dtype,
+        )
 
     def _bilinear(*indices):
-        return resize_bilinear(indices, data, in_h, in_w,
-                               size[0], size[1], layout=layout,
-                               coordinate_transformation_mode= \
-                               coordinate_transformation_mode,
-                               out_dtype=out_dtype)
+        return resize_bilinear(
+            indices,
+            data,
+            in_h,
+            in_w,
+            size[0],
+            size[1],
+            layout=layout,
+            coordinate_transformation_mode=coordinate_transformation_mode,
+            out_dtype=out_dtype,
+        )
 
     def _bicubic(*indices):
-        return resize_bicubic(indices, data, in_h, in_w,
-                              size[0], size[1], layout,
-                              coordinate_transformation_mode= \
-                              coordinate_transformation_mode,
-                              out_dtype=out_dtype)
+        return resize_bicubic(
+            indices,
+            data,
+            in_h,
+            in_w,
+            size[0],
+            size[1],
+            layout,
+            coordinate_transformation_mode=coordinate_transformation_mode,
+            out_dtype=out_dtype,
+        )
 
     # Determine which interpolation method to use then run it.
     if method == "nearest_neighbor":
@@ -584,13 +724,21 @@ def resize(data, size, layout="NCHW", method="bilinear",
     elif method == "bicubic":
         compute_func = _bicubic
     else:
-        raise ValueError('%s method is not supported.' % method)
+        raise ValueError("%s method is not supported." % method)
 
-    return te.compute(output_shape, compute_func, name='resize', tag=tag.INJECTIVE)
+    return te.compute(output_shape, compute_func, name="resize", tag=tag.INJECTIVE)
 
 
-def crop_and_resize(data, boxes, box_indices, crop_size, layout="NCHW",
-                    method="bilinear", extrapolation_value=0, out_dtype=None):
+def crop_and_resize(
+    data,
+    boxes,
+    box_indices,
+    crop_size,
+    layout="NCHW",
+    method="bilinear",
+    extrapolation_value=0,
+    out_dtype=None,
+):
     """Perform crop and resize operation on the data.
 
     Parameters
@@ -633,31 +781,56 @@ def crop_and_resize(data, boxes, box_indices, crop_size, layout="NCHW",
     target_h = crop_size[0]
     target_w = crop_size[1]
 
-    if layout == 'NHWC':
+    if layout == "NHWC":
         output_shape = [box_indices.shape[0], crop_size[0], crop_size[1], data.shape[3]]
         image_h = data.shape[1].astype("int32")
         image_w = data.shape[2].astype("int32")
-    elif layout == 'NCHW':
+    elif layout == "NCHW":
         output_shape = [box_indices.shape[0], data.shape[1], crop_size[0], crop_size[1]]
         image_h = data.shape[2].astype("int32")
         image_w = data.shape[3].astype("int32")
-    elif layout.startswith("NCHW"):# for NCHWxc
-        output_shape = [box_indices.shape[0], data.shape[1],
-                        crop_size[0], crop_size[1], data.shape[4]]
+    elif layout.startswith("NCHW"):  # for NCHWxc
+        output_shape = [
+            box_indices.shape[0],
+            data.shape[1],
+            crop_size[0],
+            crop_size[1],
+            data.shape[4],
+        ]
         image_h = data.shape[2].astype("int32")
         image_w = data.shape[3].astype("int32")
     else:
-        raise ValueError('%s layout is not supported.' % layout)
+        raise ValueError("%s layout is not supported." % layout)
 
     def _bilinear(*indices):
-        return resize_bilinear(indices, data, image_h, image_w, target_h,
-                               target_w, boxes, box_indices, extrapolation_value,
-                               layout, out_dtype=out_dtype)
+        return resize_bilinear(
+            indices,
+            data,
+            image_h,
+            image_w,
+            target_h,
+            target_w,
+            boxes,
+            box_indices,
+            extrapolation_value,
+            layout,
+            out_dtype=out_dtype,
+        )
 
     def _nearest_neighbor(*indices):
-        return resize_nearest_neighbor(indices, data, image_h, image_w, target_h,
-                                       target_w, boxes, box_indices, extrapolation_value,
-                                       layout, out_dtype=out_dtype)
+        return resize_nearest_neighbor(
+            indices,
+            data,
+            image_h,
+            image_w,
+            target_h,
+            target_w,
+            boxes,
+            box_indices,
+            extrapolation_value,
+            layout,
+            out_dtype=out_dtype,
+        )
 
     # Determine which interpolation method to use then run it.
     if method == "nearest_neighbor":
@@ -665,14 +838,19 @@ def crop_and_resize(data, boxes, box_indices, crop_size, layout="NCHW",
     elif method == "bilinear":
         compute_func = _bilinear
     else:
-        raise ValueError('%s method is not supported.' % method)
-
-    return te.compute(output_shape, compute_func, name='crop_and_resize', tag=tag.INJECTIVE)
+        raise ValueError("%s method is not supported." % method)
 
+    return te.compute(output_shape, compute_func, name="crop_and_resize", tag=tag.INJECTIVE)
 
 
-def resize3d(data, size, layout="NCDHW", method="nearest_neighbor",
-             coordinate_transformation_mode="align_corners", out_dtype=None):
+def resize3d(
+    data,
+    size,
+    layout="NCDHW",
+    method="nearest_neighbor",
+    coordinate_transformation_mode="align_corners",
+    out_dtype=None,
+):
     """Perform resize operation on the data.
 
     Parameters
@@ -710,10 +888,10 @@ def resize3d(data, size, layout="NCDHW", method="nearest_neighbor",
     """
     method = method.lower()
 
-    if layout == 'NDHWC':
+    if layout == "NDHWC":
         in_n, in_d, in_h, in_w, in_c = data.shape
         output_shape = [in_n, size[0], size[1], size[2], in_c]
-    elif layout == 'NCDHW':
+    elif layout == "NCDHW":
         in_n, in_c, in_d, in_h, in_w = data.shape
         output_shape = [in_n, in_c, size[0], size[1], size[2]]
     # Otherwise layout must be NCHWxc
@@ -722,33 +900,34 @@ def resize3d(data, size, layout="NCDHW", method="nearest_neighbor",
         output_shape = [in_n, in_c, size[0], size[1], size[2], in_cc]
 
     if coordinate_transformation_mode == "align_corners":
-        z_ratio = (in_d - 1).astype('float') / (size[0] - 1)
-        y_ratio = (in_h - 1).astype('float') / (size[1] - 1)
-        x_ratio = (in_w - 1).astype('float') / (size[2] - 1)
+        z_ratio = (in_d - 1).astype("float") / (size[0] - 1)
+        y_ratio = (in_h - 1).astype("float") / (size[1] - 1)
+        x_ratio = (in_w - 1).astype("float") / (size[2] - 1)
     elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]:
-        z_ratio = (in_d).astype('float') / (size[0])
-        y_ratio = (in_h).astype('float') / (size[1])
-        x_ratio = (in_w).astype('float') / (size[2])
+        z_ratio = (in_d).astype("float") / (size[0])
+        y_ratio = (in_h).astype("float") / (size[1])
+        x_ratio = (in_w).astype("float") / (size[2])
     else:
-        raise ValueError("Unsupported coordinate_transformation_mode: {}".format(
-            coordinate_transformation_mode))
+        raise ValueError(
+            "Unsupported coordinate_transformation_mode: {}".format(coordinate_transformation_mode)
+        )
 
     def _get_pixel(n, c, z, y, x, cc):
         z = tvm.te.max(tvm.te.min(z, in_d - 1), 0)
         y = tvm.te.max(tvm.te.min(y, in_h - 1), 0)
         x = tvm.te.max(tvm.te.min(x, in_w - 1), 0)
-        if layout == 'NDHWC':
-            return data(n, z, y, x, c).astype('float')
-        if layout == 'NCDHW':
-            return data(n, c, z, y, x).astype('float')
+        if layout == "NDHWC":
+            return data(n, z, y, x, c).astype("float")
+        if layout == "NCDHW":
+            return data(n, c, z, y, x).astype("float")
         # else must be NCDHWxc
-        return data(n, c, z, y, x, cc).astype('float')
+        return data(n, c, z, y, x, cc).astype("float")
 
     def _get_indices(*indices):
-        if layout == 'NDHWC':
+        if layout == "NDHWC":
             n, z, y, x, c = indices
             cc = None
-        elif layout == 'NCDHW':
+        elif layout == "NCDHW":
             n, c, z, y, x = indices
             cc = None
         else:
@@ -772,18 +951,21 @@ def resize3d(data, size, layout="NCDHW", method="nearest_neighbor",
         in_x = x_ratio * x
 
         if coordinate_transformation_mode == "align_corners":
-            zint = te.round(in_z).astype('int32')
-            yint = te.round(in_y).astype('int32')
-            xint = te.round(in_x).astype('int32')
+            zint = te.round(in_z).astype("int32")
+            yint = te.round(in_y).astype("int32")
+            xint = te.round(in_x).astype("int32")
         elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]:
             # Add epsilon to floor to prevent gpu rounding errors.
             epsilon = 1e-5
-            zint = te.floor(in_z + epsilon).astype('int32')
-            yint = te.floor(in_y + epsilon).astype('int32')
-            xint = te.floor(in_x + epsilon).astype('int32')
+            zint = te.floor(in_z + epsilon).astype("int32")
+            yint = te.floor(in_y + epsilon).astype("int32")
+            xint = te.floor(in_x + epsilon).astype("int32")
         else:
-            raise ValueError("Unsupported coordinate_transformation_mode: {}".format(
-                coordinate_transformation_mode))
+            raise ValueError(
+                "Unsupported coordinate_transformation_mode: {}".format(
+                    coordinate_transformation_mode
+                )
+            )
 
         return _cast_output(_get_pixel(n, c, zint, yint, xint, cc))
 
@@ -803,13 +985,13 @@ def resize3d(data, size, layout="NCDHW", method="nearest_neighbor",
             in_y = y_ratio * y
             in_x = x_ratio * x
 
-        zint = te.floor(in_z).astype('int32')
+        zint = te.floor(in_z).astype("int32")
         zfract = in_z - te.floor(in_z)
 
-        xint = te.floor(in_x).astype('int32')
+        xint = te.floor(in_x).astype("int32")
         xfract = in_x - te.floor(in_x)
 
-        yint = te.floor(in_y).astype('int32')
+        yint = te.floor(in_y).astype("int32")
         yfract = in_y - te.floor(in_y)
 
         p000 = _get_pixel(n, c, zint, yint, xint, cc)
@@ -836,6 +1018,6 @@ def resize3d(data, size, layout="NCDHW", method="nearest_neighbor",
     elif method == "trilinear":
         compute_func = _trilinear
     else:
-        raise ValueError('%s method is not supported.' % method)
+        raise ValueError("%s method is not supported." % method)
 
-    return te.compute(output_shape, compute_func, name='resize3d', tag=tag.INJECTIVE)
+    return te.compute(output_shape, compute_func, name="resize3d", tag=tag.INJECTIVE)
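
Everything in the hunk above is mechanical: Black normalizes string literals to double quotes and, when a call no longer fits within the line limit, breaks it one argument per line with a trailing comma. A minimal, self-contained sketch of those two rewrites in plain Python; the names below are illustrative and are not taken from the diff.

# Illustrative sketch only; none of these names come from this commit.
# 1. Quote normalization: Black rewrites 'NCDHW' as "NCDHW".
SUPPORTED_LAYOUTS = ["NCDHW", "NDHWC"]


# 2. A call that exceeds the line limit is wrapped one argument per line, and the
#    trailing comma keeps it in this exploded form on later runs of the formatter.
def describe_resize(layout, method, coordinate_transformation_mode, out_dtype="float32"):
    return "layout={}, method={}, mode={}, dtype={}".format(
        layout,
        method,
        coordinate_transformation_mode,
        out_dtype,
    )
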
index e4ea196..5bd8581 100644
@@ -152,23 +152,25 @@ def _pack_data(data, kernel, ic_bn, oc_bn):
     ic_chunk = ic // ic_bn
     oc_chunk = oc // oc_bn
 
-    data = te.compute((n, ic_chunk, ih, iw, ic_bn),
-                      lambda bs, c, h, w, vc: data[bs, c*ic_bn + vc, h, w],
-                      name="data_vec")
+    data = te.compute(
+        (n, ic_chunk, ih, iw, ic_bn),
+        lambda bs, c, h, w, vc: data[bs, c * ic_bn + vc, h, w],
+        name="data_vec",
+    )
 
     kernel = te.compute(
         (oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn),
-        lambda occ, icc, k_h, k_w, icb, ocb:
-        kernel[occ * oc_bn + ocb,
-               icc * ic_bn + icb, k_h, k_w],
-        name="kernel_vec")
+        lambda occ, icc, k_h, k_w, icb, ocb: kernel[occ * oc_bn + ocb, icc * ic_bn + icb, k_h, k_w],
+        name="kernel_vec",
+    )
 
     return data, kernel
 
 
 @autotvm.register_topi_compute("conv2d_NCHWc.intel_graphics")
-def conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, layout,
-                 out_layout, out_dtype='float32'):
+def conv2d_NCHWc(
+    cfg, data, kernel, strides, padding, dilation, layout, out_layout, out_dtype="float32"
+):
     """Conv2D operator for Intel Graphics backend.
 
     Parameters
@@ -204,7 +206,8 @@ def conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, layout,
 
     dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
     pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple(
-        padding, (kernel_height, kernel_width))
+        padding, (kernel_height, kernel_width)
+    )
     assert (dh, dw) == (1, 1), "Does not support dilation"
     if isinstance(strides, (tuple, list)):
         stride_h, stride_w = strides
@@ -216,10 +219,16 @@ def conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, layout,
     _create_schedule_template(cfg, data_shape, kernel_shape, strides, padding, dilation)
 
     if cfg.is_fallback:
-        _get_default_config(cfg, te.placeholder((batch, in_channel, ih, iw), dtype=data.dtype),
-                            te.placeholder((num_filter, in_channel, kernel_height, kernel_width),
-                                           dtype=kernel.dtype),
-                            strides, padding, out_dtype)
+        _get_default_config(
+            cfg,
+            te.placeholder((batch, in_channel, ih, iw), dtype=data.dtype),
+            te.placeholder(
+                (num_filter, in_channel, kernel_height, kernel_width), dtype=kernel.dtype
+            ),
+            strides,
+            padding,
+            out_dtype,
+        )
 
     ic_bn = cfg["tile_ic"].val if hasattr(cfg["tile_ic"], "val") else cfg["tile_ic"].size[-1]
     oc_bn = cfg["tile_oc"].val if hasattr(cfg["tile_oc"], "val") else cfg["tile_oc"].size[-1]
@@ -233,9 +242,9 @@ def conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, layout,
     out_width = simplify((iw - kernel_width + pad_left + pad_right) // stride_w + 1)
     oshape = (batch, out_channel // oc_bn, out_height, out_width, oc_bn)
 
-    rc = te.reduce_axis((0, in_channel), name='rc')
-    ry = te.reduce_axis((0, kernel_height), name='ry')
-    rx = te.reduce_axis((0, kernel_width), name='rx')
+    rc = te.reduce_axis((0, in_channel), name="rc")
+    ry = te.reduce_axis((0, kernel_height), name="ry")
+    rx = te.reduce_axis((0, kernel_width), name="rx")
 
     block_h = cfg["block_oh"].val
     block_w = cfg["block_ow"].val
@@ -252,11 +261,14 @@ def conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, layout,
     cshape = (batch, out_channel // oc_bn, c_h, c_w, oc_bn)
 
     pad_before = [0, 0, pad_top, pad_left, 0]
-    pad_after = [0, 0, pad_down + c_h - out_height, pad_right + \
-                 c_w - out_width, 0]
-    DOPAD = (pad_top != 0 or pad_left != 0 or pad_down + c_h - out_height != 0 \
-             or pad_right + c_w - out_width != 0)
-    DOUNPACK = (c_h - out_height != 0 or c_w - out_width != 0)
+    pad_after = [0, 0, pad_down + c_h - out_height, pad_right + c_w - out_width, 0]
+    DOPAD = (
+        pad_top != 0
+        or pad_left != 0
+        or pad_down + c_h - out_height != 0
+        or pad_right + c_w - out_width != 0
+    )
+    DOUNPACK = c_h - out_height != 0 or c_w - out_width != 0
     if DOPAD:
         temp = nn.pad(data, pad_before, pad_after, name="pad_temp")
     else:
@@ -264,19 +276,24 @@ def conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, layout,
 
     conv = te.compute(
         cshape,
-        lambda nn, ff, yy, xx, ff_v: \
-        te.sum(
-            temp[nn, rc//ic_bn, yy * stride_h + ry, xx * stride_w + rx, rc%ic_bn]. \
-            astype(out_dtype) *
-            kernel[ff, rc//ic_bn, ry, rx, rc%ic_bn, ff_v].astype(out_dtype),
-            axis=[rc, ry, rx]), tag="conv2d_NCHWc", name='conv2d_NCHWc')
+        lambda nn, ff, yy, xx, ff_v: te.sum(
+            temp[nn, rc // ic_bn, yy * stride_h + ry, xx * stride_w + rx, rc % ic_bn].astype(
+                out_dtype
+            )
+            * kernel[ff, rc // ic_bn, ry, rx, rc % ic_bn, ff_v].astype(out_dtype),
+            axis=[rc, ry, rx],
+        ),
+        tag="conv2d_NCHWc",
+        name="conv2d_NCHWc",
+    )
 
     if DOUNPACK:
         output = te.compute(
             oshape,
-            lambda nn, ff, yy, xx, ff_v:
-            conv[nn][ff][yy][xx][ff_v],
-            name='output_unpack', tag="conv2d_NCHWc_unpack")
+            lambda nn, ff, yy, xx, ff_v: conv[nn][ff][yy][xx][ff_v],
+            name="output_unpack",
+            tag="conv2d_NCHWc_unpack",
+        )
     else:
         output = conv
 
@@ -324,7 +341,7 @@ def _schedule_cl_spatialpack_NCHWc(cfg, s, op):
             s[output].compute_inline()
             conv = s.outputs[0]
         SCHEDULE_OUTPUT = False
-    else: # conv2d_NCHWc_unpack
+    else:  # conv2d_NCHWc_unpack
         conv = op.input_tensors[0]
         temp = s[conv].op.input_tensors[0]
         kernel = s[conv].op.input_tensors[1]
@@ -420,7 +437,7 @@ def _schedule_cl_spatialpack_NCHWc(cfg, s, op):
         tile_and_bind3d(s, out, w, h, vc, 4, 8, 8)
 
 
-def conv2d_nchw(data, kernel, stride, padding, dilation, out_dtype='float32'):
+def conv2d_nchw(data, kernel, stride, padding, dilation, out_dtype="float32"):
     """Conv2D operator for Intel Graphics backend.
 
     Parameters
@@ -462,14 +479,14 @@ def schedule_conv2d_nchw(outs):
 
     def _callback(op):
         """inline all one-to-one-mapping operators except the last stage (output)"""
-        if 'conv2d' in op.tag:
+        if "conv2d" in op.tag:
             _schedule_cl_spatialpack(s, op)
 
     traverse_inline(s, outs[0].op, _callback)
     return s
 
 
-def _decl_cl_spatialpack(data, kernel, stride, padding, out_dtype='float16'):
+def _decl_cl_spatialpack(data, kernel, stride, padding, out_dtype="float16"):
     batch, in_channel, in_height, in_width = [util.get_const_int(x) for x in data.shape]
     num_filter, channel, kernel_h, kernel_w = [util.get_const_int(x) for x in kernel.shape]
     pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple(padding, (kernel_h, kernel_w))
@@ -484,9 +501,9 @@ def _decl_cl_spatialpack(data, kernel, stride, padding, out_dtype='float16'):
     out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
     oshape = (batch, out_channel, out_height, out_width)
 
-    rc = te.reduce_axis((0, in_channel), name='rc')
-    ry = te.reduce_axis((0, kernel_h), name='ry')
-    rx = te.reduce_axis((0, kernel_w), name='rx')
+    rc = te.reduce_axis((0, in_channel), name="rc")
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
 
     if stride_h == 2:
         if num_filter + kernel_h == 515:
@@ -508,7 +525,7 @@ def _decl_cl_spatialpack(data, kernel, stride, padding, out_dtype='float16'):
     else:
         block_h = 1
         block_w = 16
-    attrs = {'block_h': block_h, 'block_w' : block_w}
+    attrs = {"block_h": block_h, "block_w": block_w}
     c_h = out_height
     c_w = out_width
 
@@ -531,23 +548,26 @@ def _decl_cl_spatialpack(data, kernel, stride, padding, out_dtype='float16'):
     kvshape = (num_filter // nv, channel, kernel_h, kernel_w, nv)
 
     kernel_vec = te.compute(
-        kvshape,
-        lambda co, ci, kh, kw, vc:
-        kernel[co*nv + vc][ci][kh][kw], name='kernel_vec')
+        kvshape, lambda co, ci, kh, kw, vc: kernel[co * nv + vc][ci][kh][kw], name="kernel_vec"
+    )
 
     conv = te.compute(
         cshape,
-        lambda nn, ff, yy, xx, vc: \
-        te.sum(
-            temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx].astype(out_dtype) *
-            kernel_vec[ff, rc, ry, rx, vc].astype(out_dtype),
-            axis=[rc, ry, rx]), name='conv', attrs=attrs)
+        lambda nn, ff, yy, xx, vc: te.sum(
+            temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx].astype(out_dtype)
+            * kernel_vec[ff, rc, ry, rx, vc].astype(out_dtype),
+            axis=[rc, ry, rx],
+        ),
+        name="conv",
+        attrs=attrs,
+    )
 
     output = te.compute(
         oshape,
-        lambda nn, ff, yy, xx:
-        conv[nn][ff//nv][yy][xx][ff%nv],
-        name='output_unpack', tag='conv2d')
+        lambda nn, ff, yy, xx: conv[nn][ff // nv][yy][xx][ff % nv],
+        name="output_unpack",
+        tag="conv2d",
+    )
 
     return output
 
@@ -567,8 +587,8 @@ def _schedule_cl_spatialpack(s, op):
     _, in_channel, temp_h, temp_w = [util.get_const_int(x) for x in temp.shape]
 
     attrs = s[conv].op.attrs
-    OUTPUT_BLOCK_HEIGHT = attrs['block_h']
-    OUTPUT_BLOCK_WIDTH = attrs['block_w']
+    OUTPUT_BLOCK_HEIGHT = attrs["block_h"]
+    OUTPUT_BLOCK_WIDTH = attrs["block_w"]
 
     # schedule conv
     z_factor = 1
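
The backslash-continued lambdas removed in the hunks above are another recurring pattern. Black drops the explicit line continuations and re-wraps at bracket boundaries instead, so the enclosing call carries the line breaks while the lambda header stays on one line. A rough sketch of the resulting shape; compute below is a hypothetical stand-in, not tvm.te.compute, so the snippet runs on its own.

# Sketch only: "compute" is a stand-in so the example is self-contained.
def compute(shape, fcompute, name=None, tag=None):
    return {"shape": shape, "fcompute": fcompute, "name": name, "tag": tag}


# Pre-Black style relied on backslash continuations inside the call:
#     conv = compute(cshape,
#                    lambda n, f, y, x, v: \
#                    body(n, f, y, x, v), name='conv')
# After Black, the call's own brackets carry the wrapping instead:
conv = compute(
    (1, 4, 8, 8, 16),
    lambda n, f, y, x, v: n + f + y + x + v,
    name="conv",
    tag="conv2d_NCHWc",
)
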
index bbe5e7f..46802bb 100644
@@ -36,7 +36,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         workload = cfg.workload
     else:
         _, outs = relay.backend.compile_engine.select_implementation(
-            relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)
+            relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target
+        )
         workload = autotvm.task.get_workload(outs)
         if workload is None:
             # The best implementation is not an AutoTVM template,
@@ -45,7 +46,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         cfg = dispatch_ctx.query(target, workload)
 
     topi_tmpl = workload[0]
-    new_attrs = {k : attrs[k] for k in attrs.keys()}
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
 
     padding = attrs.get_int_tuple("padding")
     strides = attrs.get_int_tuple("strides")
@@ -60,28 +61,39 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
     if topi_tmpl == "conv2d_NCHWc.intel_graphics":
         assert data_layout == "NCHW" and kernel_layout == "OIHW"
         if cfg.is_fallback:
-            _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding,
-                                out_dtype, False)
+            _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding, out_dtype, False)
         batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
         out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape)
         ic_bn = cfg["tile_ic"].val if hasattr(cfg["tile_ic"], "val") else cfg["tile_ic"].size[-1]
         oc_bn = cfg["tile_oc"].val if hasattr(cfg["tile_oc"], "val") else cfg["tile_oc"].size[-1]
 
         # update new attrs
-        new_attrs['channels'] = out_channel
-        new_attrs['data_layout'] = 'NCHW%dc' % ic_bn
+        new_attrs["channels"] = out_channel
+        new_attrs["data_layout"] = "NCHW%dc" % ic_bn
         # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
-        new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
-        new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
+        new_attrs["kernel_layout"] = "OIHW%di%do" % (ic_bn, oc_bn)
+        new_attrs["out_layout"] = "NCHW%dc" % oc_bn
 
         # Store altered operator's config
-        new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
-                                  dtype=data_dtype)
-        new_kernel = te.placeholder((out_channel//oc_bn, in_channel//ic_bn,
-                                     kh, kw, ic_bn, oc_bn), dtype=kernel_dtype)
+        new_data = te.placeholder(
+            (batch_size, in_channel // ic_bn, height, width, ic_bn), dtype=data_dtype
+        )
+        new_kernel = te.placeholder(
+            (out_channel // oc_bn, in_channel // ic_bn, kh, kw, ic_bn, oc_bn), dtype=kernel_dtype
+        )
         new_workload = autotvm.task.args_to_workload(
-            [new_data, new_kernel, strides, padding, dilation, new_attrs["data_layout"],
-             new_attrs["out_layout"], out_dtype], "conv2d_NCHWc.intel_graphics")
+            [
+                new_data,
+                new_kernel,
+                strides,
+                padding,
+                dilation,
+                new_attrs["data_layout"],
+                new_attrs["out_layout"],
+                out_dtype,
+            ],
+            "conv2d_NCHWc.intel_graphics",
+        )
         dispatch_ctx.update(target, new_workload, cfg)
         return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs)
 
index ffeb9af..e236779 100644
@@ -49,7 +49,7 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == 'depthwise_conv2d_nchw':
+        if op.tag == "depthwise_conv2d_nchw":
             pad_data = op.input_tensors[0]
             kernel = op.input_tensors[1]
             conv = op.output(0)
@@ -62,7 +62,7 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
             cfg.define_knob("auto_unroll_max_step", [0, 256, 1500])
 
             target = tvm.target.Target.current()
-            if target.kind.name in ['nvptx', 'rocm']:
+            if target.kind.name in ["nvptx", "rocm"]:
                 cfg.define_knob("unroll_explicit", [1])
             else:
                 cfg.define_knob("unroll_explicit", [0, 1])
@@ -70,28 +70,29 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
             # fallback support
             if cfg.is_fallback:
                 ref_log = autotvm.tophub.load_reference_log(
-                    target.kind.name, target.model, 'depthwise_conv2d_nchw.intel_graphics')
+                    target.kind.name, target.model, "depthwise_conv2d_nchw.intel_graphics"
+                )
                 cfg.fallback_with_reference_log(ref_log)
-                cfg['unroll_explicit'].val = 0
+                cfg["unroll_explicit"].val = 0
             ##### space definition end #####
 
             s[pad_data].compute_inline()
-            if isinstance(kernel.op, tvm.te.ComputeOp) and 'dilate' in kernel.op.tag:
+            if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
                 s[kernel].compute_inline()
 
             if conv.op in s.outputs:
                 output = conv
-                OL = s.cache_write(conv, 'local')
+                OL = s.cache_write(conv, "local")
             else:
                 output = s.outputs[0].output(0)
-                s[conv].set_scope('local')
+                s[conv].set_scope("local")
                 OL = conv
 
             # create cache stage
-            AA = s.cache_read(pad_data, 'shared', [OL])
-            WW = s.cache_read(kernel, 'shared', [OL])
-            AL = s.cache_read(AA, 'local', [OL])
-            WL = s.cache_read(WW, 'local', [OL])
+            AA = s.cache_read(pad_data, "shared", [OL])
+            WW = s.cache_read(kernel, "shared", [OL])
+            AL = s.cache_read(AA, "local", [OL])
+            WL = s.cache_read(WW, "local", [OL])
 
             # tile and bind spatial axes
             n, f, y, x = s[output].op.axis
@@ -128,8 +129,8 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
                 s[load].bind(ty, te.thread_axis("threadIdx.y"))
                 s[load].bind(tx, te.thread_axis("threadIdx.x"))
 
-            s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
-            s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+            s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+            s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
 
     traverse_inline(s, outs[0].op, _callback)
     return s
@@ -204,10 +205,10 @@ def schedule_depthwise_conv2d_nhwc(outs):
                 if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule depthwise_conv2d
-        if OP.tag == 'depthwise_conv2d_nhwc':
+        if OP.tag == "depthwise_conv2d_nhwc":
             PaddedInput = OP.input_tensors[0]
             Filter = OP.input_tensors[1]
-            if isinstance(Filter.op, tvm.te.ComputeOp) and 'dilate' in Filter.op.tag:
+            if isinstance(Filter.op, tvm.te.ComputeOp) and "dilate" in Filter.op.tag:
                 s[Filter].compute_inline()
             DepthwiseConv2d = OP.output(0)
             _schedule(PaddedInput, Filter, DepthwiseConv2d)
@@ -251,7 +252,7 @@ def schedule_depthwise_conv2d_backward_input_nhwc(outs):
 
     def traverse(OP):
         # inline all one-to-one-mapping operators except the last stage (output)
-        if OP.tag == 'depthwise_conv2d_backward_input_nhwc':
+        if OP.tag == "depthwise_conv2d_backward_input_nhwc":
             Padded_out_grad = OP.input_tensors[0]
             Dilated_out_grad = Padded_out_grad.op.input_tensors[0]
             s[Dilated_out_grad].compute_inline()
@@ -263,6 +264,7 @@ def schedule_depthwise_conv2d_backward_input_nhwc(outs):
     traverse(outs[0].op)
     return s
 
+
 def schedule_depthwise_conv2d_backward_weight_nhwc(outs):
     """Schedule for depthwise_conv2d nhwc backward wrt weight.
 
@@ -304,7 +306,7 @@ def schedule_depthwise_conv2d_backward_weight_nhwc(outs):
 
     def traverse(OP):
         # inline all one-to-one-mapping operators except the last stage (output)
-        if OP.tag == 'depthwise_conv2d_backward_weight_nhwc':
+        if OP.tag == "depthwise_conv2d_backward_weight_nhwc":
             Padded_in = OP.input_tensors[1]
             s[Padded_in].compute_inline()
             Weight_grad = OP.output(0)
@@ -315,6 +317,7 @@ def schedule_depthwise_conv2d_backward_weight_nhwc(outs):
     traverse(outs[0].op)
     return s
 
+
 @depthwise_conv2d_infer_layout.register("intel_graphics")
 def _depthwise_conv2d_infer_layout(workload, _):
     """Infer input/output shapes and layouts from a workload and cfg.
index f2b26ee..0ccf1e6 100644
@@ -64,8 +64,10 @@ def conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_
     output : tvm.te.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
-                                    dilation, out_dtype, num_tile=3)
+    return conv2d_spatial_pack_nchw(
+        cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=3
+    )
+
 
 @autotvm.register_topi_schedule("conv2d_nchw_spatial_pack.mali")
 def schedule_conv2d_nchw_spatial_pack(cfg, outs):
@@ -88,7 +90,7 @@ def schedule_conv2d_nchw_spatial_pack(cfg, outs):
 
     def _callback(op):
         # schedule conv2d
-        if 'spatial_conv2d_output' in op.tag:
+        if "spatial_conv2d_output" in op.tag:
             output = op.output(0)
             conv = op.input_tensors[0]
 
@@ -97,7 +99,7 @@ def schedule_conv2d_nchw_spatial_pack(cfg, outs):
             s[data_pad].compute_inline()
 
             kernel_vec = conv.op.input_tensors[1]
-            if kernel_vec.op.name == 'kernel_vec':
+            if kernel_vec.op.name == "kernel_vec":
                 kernel = kernel_vec.op.input_tensors[0]
             else:
                 kernel = kernel_vec
@@ -127,7 +129,7 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec):
         s[data_pad].compute_inline()
 
     # schedule data packing
-    if isinstance(data_vec.op, tvm.te.ComputeOp) and data_vec.op.name == 'data_vec_undilated':
+    if isinstance(data_vec.op, tvm.te.ComputeOp) and data_vec.op.name == "data_vec_undilated":
         _, h, w, ci, _, _, vh, vw = s[data_vec].op.axis
     else:
         _, h, w, ci, vh, vw = s[data_vec].op.axis
@@ -137,7 +139,7 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec):
     if vw.dom.extent.value < max_unroll:
         s[data_vec].unroll(vw)
 
-    if isinstance(kernel_vec.op, tvm.te.ComputeOp) and kernel_vec.name == 'kernel_vec':
+    if isinstance(kernel_vec.op, tvm.te.ComputeOp) and kernel_vec.name == "kernel_vec":
         if not autotvm.GLOBAL_SCOPE.in_tuning:
             max_threads = tvm.target.Target.current(allow_none=False).max_num_threads
             co, ci, kh, kw, vc = s[kernel_vec].op.axis
@@ -156,16 +158,23 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec):
     cfg["reorder_0"].apply(s, conv, [n, c, h, w, kc, kh, kw, vh, vw, vc])
     tile_and_bind3d(s, conv, c, h, w, TC, TH, TW)
 
-    cfg["ann_reduce"].apply(s, conv, [kh, kw],
-                            axis_lens=[get_const_int(kernel_vec.shape[2]),
-                                       get_const_int(kernel_vec.shape[3])],
-                            max_unroll=max_unroll)
-
-    cfg["ann_spatial"].apply(s, conv, [vh, vw, vc],
-                             axis_lens=[VH, VW, VC],
-                             max_unroll=max_unroll,
-                             vec_size=vec_size,
-                             cfg=cfg)
+    cfg["ann_reduce"].apply(
+        s,
+        conv,
+        [kh, kw],
+        axis_lens=[get_const_int(kernel_vec.shape[2]), get_const_int(kernel_vec.shape[3])],
+        max_unroll=max_unroll,
+    )
+
+    cfg["ann_spatial"].apply(
+        s,
+        conv,
+        [vh, vw, vc],
+        axis_lens=[VH, VW, VC],
+        max_unroll=max_unroll,
+        vec_size=vec_size,
+        cfg=cfg,
+    )
 
     # schedule output
     if output.op not in s.outputs:  # has bias
@@ -177,6 +186,7 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec):
 
     return s
 
+
 ##### WINOGRAD TEMPLATE #####
 def _pick_tile_size(data, kernel):
     N, CI, H, W = get_const_tuple(data.shape)
@@ -190,8 +200,7 @@ def _pick_tile_size(data, kernel):
 @autotvm.register_topi_compute("conv2d_nchw_winograd.mali")
 def conv2d_nchw_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype):
     tile_size = _pick_tile_size(data, kernel)
-    return _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype,
-                          tile_size)
+    return _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, tile_size)
 
 
 @autotvm.register_topi_schedule("conv2d_nchw_winograd.mali")
@@ -199,7 +208,7 @@ def schedule_conv2d_nchw_winograd(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'winograd_conv2d_output' in op.tag:
+        if "winograd_conv2d_output" in op.tag:
             _schedule_winograd(cfg, s, op)
 
     traverse_inline(s, outs[0].op, _callback)
@@ -237,43 +246,48 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, til
 
     H = (IH + pt + pb - 3) // HSTR + 1
     W = (IW + pl + pr - 3) // WSTR + 1
-    nH, nW = (H + m-1) // m, (W + m-1) // m
+    nH, nW = (H + m - 1) // m, (W + m - 1) // m
     P = N * nH * nW
 
     ##### space definition begin #####
     tile_bna_candidates = [1, 2, 4, 8, 16]
     factors = get_factors(CO)
-    cfg.define_knob('tile_bna', [x for x in tile_bna_candidates if x in factors])
-    cfg.define_knob('tile_bnb', [1, 2, 4, 8, 16])
-    cfg.define_split('tile_t1', CI, num_outputs=2, max_factor=128)
-    cfg.define_split('tile_t2', CO, num_outputs=2, max_factor=128)
-    cfg.define_split('c_unroll', CI, num_outputs=2, max_factor=8)
-    cfg.define_knob('yt', [1, 2, 4, 8, 16, 32])
+    cfg.define_knob("tile_bna", [x for x in tile_bna_candidates if x in factors])
+    cfg.define_knob("tile_bnb", [1, 2, 4, 8, 16])
+    cfg.define_split("tile_t1", CI, num_outputs=2, max_factor=128)
+    cfg.define_split("tile_t2", CO, num_outputs=2, max_factor=128)
+    cfg.define_split("c_unroll", CI, num_outputs=2, max_factor=8)
+    cfg.define_knob("yt", [1, 2, 4, 8, 16, 32])
     ##### space definition end #####
 
     if cfg.is_fallback:
-        cfg['tile_bnb'].val = 4
-        cfg['tile_bna'].val = 4
-        while CO % cfg['tile_bna'].val != 0:
-            cfg['tile_bna'].val //= 2
-        cfg['yt'].val = 8
-        cfg.fallback_split('tile_t1', [-1, 128])
-        cfg.fallback_split('tile_t2', [-1, 128])
-        cfg.fallback_split('c_unroll', [-1, 8])
-
-    bna = cfg['tile_bna'].val
-    bnb = cfg['tile_bnb'].val
+        cfg["tile_bnb"].val = 4
+        cfg["tile_bna"].val = 4
+        while CO % cfg["tile_bna"].val != 0:
+            cfg["tile_bna"].val //= 2
+        cfg["yt"].val = 8
+        cfg.fallback_split("tile_t1", [-1, 128])
+        cfg.fallback_split("tile_t2", [-1, 128])
+        cfg.fallback_split("c_unroll", [-1, 8])
+
+    bna = cfg["tile_bna"].val
+    bnb = cfg["tile_bnb"].val
 
     P_round = (P + bnb - 1) // bnb * bnb
     assert CO % bna == 0 and P_round % bnb == 0
 
     # pack input tile
     input_tile = te.compute(
-        (CI, P_round // bnb, alpha, alpha, bnb), lambda ci, b, eps, nu, bb: \
-        tvm.tir.if_then_else(
+        (CI, P_round // bnb, alpha, alpha, bnb),
+        lambda ci, b, eps, nu, bb: tvm.tir.if_then_else(
             b * bnb + bb < P,
-            data_pad[(b*bnb+bb) // (nH*nW)][ci][(b*bnb+bb) // nW % nH * m + eps]
-            [(b*bnb+bb) % nW * m + nu], tvm.tir.const(0, data_pad.dtype)), name='d')
+            data_pad[(b * bnb + bb) // (nH * nW)][ci][(b * bnb + bb) // nW % nH * m + eps][
+                (b * bnb + bb) % nW * m + nu
+            ],
+            tvm.tir.const(0, data_pad.dtype),
+        ),
+        name="d",
+    )
 
     if autotvm.GLOBAL_SCOPE.in_tuning:
         kvshape = (alpha, alpha, CO // bna, CI, bna)
@@ -283,50 +297,70 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, til
         if pre_computed:
             U = kernel
         else:
-            r_kh = te.reduce_axis((0, KH), 'r_kh')
-            r_kw = te.reduce_axis((0, KW), 'r_kw')
-            U = te.compute((alpha, alpha, CO // bna, CI, bna), lambda eps, nu, co, ci, vco:
-                           te.sum(kernel[co * bna + vco][ci][r_kh][r_kw] *
-                                  G[eps][r_kh] * G[nu][r_kw],
-                                  axis=[r_kh, r_kw]), name='U')
+            r_kh = te.reduce_axis((0, KH), "r_kh")
+            r_kw = te.reduce_axis((0, KW), "r_kw")
+            U = te.compute(
+                (alpha, alpha, CO // bna, CI, bna),
+                lambda eps, nu, co, ci, vco: te.sum(
+                    kernel[co * bna + vco][ci][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw],
+                    axis=[r_kh, r_kw],
+                ),
+                name="U",
+            )
 
     # transform image
-    r_a = te.reduce_axis((0, alpha), 'r_a')
-    r_b = te.reduce_axis((0, alpha), 'r_b')
-    V = te.compute((alpha, alpha, P_round // bnb, CI, bnb), lambda eps, nu, p, ci, vp:
-                   te.sum(input_tile[ci][p][r_a][r_b][vp] * B[r_a][eps] * B[r_b][nu],
-                          axis=[r_a, r_b]), name='V')
+    r_a = te.reduce_axis((0, alpha), "r_a")
+    r_b = te.reduce_axis((0, alpha), "r_b")
+    V = te.compute(
+        (alpha, alpha, P_round // bnb, CI, bnb),
+        lambda eps, nu, p, ci, vp: te.sum(
+            input_tile[ci][p][r_a][r_b][vp] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]
+        ),
+        name="V",
+    )
 
     idxdiv = tvm.tir.indexdiv
     idxmod = tvm.tir.indexmod
 
     # batch gemm
-    ci = te.reduce_axis((0, CI), name='c')
-    M = te.compute((alpha, alpha, CO, P_round), lambda eps, nu, co, p:
-                   te.sum(U[eps][nu][idxdiv(co, bna)][ci][idxmod(co, bna)] *
-                          V[eps][nu][idxdiv(p, bnb)][ci][idxmod(p, bnb)], axis=ci), name='M')
-
-    r_a = te.reduce_axis((0, alpha), 'r_a')
-    r_b = te.reduce_axis((0, alpha), 'r_b')
-    Y = te.compute((CO, P, m, m), lambda co, p, vh, vw:
-                   te.sum(M[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw],
-                          axis=[r_a, r_b]), name='Y')
+    ci = te.reduce_axis((0, CI), name="c")
+    M = te.compute(
+        (alpha, alpha, CO, P_round),
+        lambda eps, nu, co, p: te.sum(
+            U[eps][nu][idxdiv(co, bna)][ci][idxmod(co, bna)]
+            * V[eps][nu][idxdiv(p, bnb)][ci][idxmod(p, bnb)],
+            axis=ci,
+        ),
+        name="M",
+    )
+
+    r_a = te.reduce_axis((0, alpha), "r_a")
+    r_b = te.reduce_axis((0, alpha), "r_b")
+    Y = te.compute(
+        (CO, P, m, m),
+        lambda co, p, vh, vw: te.sum(M[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]),
+        name="Y",
+    )
 
     # unpack output
-    output = te.compute((N, CO, H, W), lambda n, co, h, w:
-                        Y[co, n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
-                          idxmod(h, m), idxmod(w, m)]
-                        # The following hack term is used to make the padding in batch gemm ("M")
-                        # effective, otherwise the padding will be eliminated by bound inference.
-                        # Use `tvm.tir.Mul` instead of `*` to avoid issues in const folding.
-                        + tvm.tir.Mul(tvm.tir.const(0, out_dtype),
-                                      M[alpha-1][alpha-1][CO-1][P_round-1]),
-                        name='output', tag='winograd_conv2d_output')
+    output = te.compute(
+        (N, CO, H, W),
+        lambda n, co, h, w: Y[
+            co, n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), idxmod(h, m), idxmod(w, m)
+        ]
+        # The following hack term is used to make the padding in batch gemm ("M")
+        # effective, otherwise the padding will be eliminated by bound inference.
+        # Use `tvm.tir.Mul` instead of `*` to avoid issues in const folding.
+        + tvm.tir.Mul(tvm.tir.const(0, out_dtype), M[alpha - 1][alpha - 1][CO - 1][P_round - 1]),
+        name="output",
+        tag="winograd_conv2d_output",
+    )
 
     # we have to manually assign effective GFLOP for winograd
     cfg.add_flop(2 * N * CO * H * W * KH * KW * CI)
     return output
 
+
 def _schedule_winograd(cfg, s, op):
     """schedule winograd fast convolution F(2x2, 3x3) for conv2d"""
     # get ops and tensors
@@ -345,7 +379,13 @@ def _schedule_winograd(cfg, s, op):
     if isinstance(U.op, tvm.te.ComputeOp):
         kernel, G = s[U].op.input_tensors
         s[G].compute_inline()
-        eps, nu, co, ci, vco, = s[U].op.axis
+        (
+            eps,
+            nu,
+            co,
+            ci,
+            vco,
+        ) = s[U].op.axis
         if not autotvm.GLOBAL_SCOPE.in_tuning:
             r_kh, r_kw = s[U].op.reduce_axis
             s[U].reorder(co, ci, eps, nu, r_kh, r_kw, vco)
@@ -359,7 +399,7 @@ def _schedule_winograd(cfg, s, op):
 
     # transform image
     s[B].compute_inline()
-    VL = s.cache_write(V, 'local')
+    VL = s.cache_write(V, "local")
 
     eps, nu, p, ci, vp = s[V].op.axis
     s[V].reorder(p, ci, eps, nu, vp)
@@ -368,9 +408,9 @@ def _schedule_winograd(cfg, s, op):
     s[V].vectorize(vp)
     fused = s[V].fuse(p, ci)
 
-    bb, tt = cfg['tile_t1'].apply(s, V, fused)
-    s[V].bind(bb, te.thread_axis('blockIdx.x'))
-    s[V].bind(tt, te.thread_axis('threadIdx.x'))
+    bb, tt = cfg["tile_t1"].apply(s, V, fused)
+    s[V].bind(bb, te.thread_axis("blockIdx.x"))
+    s[V].bind(tt, te.thread_axis("threadIdx.x"))
 
     eps, nu, p, ci, vp = s[VL].op.axis
     r_a, r_b = s[VL].op.reduce_axis
@@ -381,20 +421,20 @@ def _schedule_winograd(cfg, s, op):
     s[VL].compute_at(s[V], tt)
 
     # batch gemm
-    bna = cfg['tile_bna'].val
-    bnb = cfg['tile_bnb'].val
+    bna = cfg["tile_bna"].val
+    bnb = cfg["tile_bnb"].val
 
     eps, nu, k, b = s[M].op.axis
     alpha = eps.dom.extent
     c = s[M].op.reduce_axis[0]
     yo, xo, yi, xi = s[M].tile(k, b, bna, bnb)
-    c, c_unroll = cfg['c_unroll'].apply(s, M, c)
+    c, c_unroll = cfg["c_unroll"].apply(s, M, c)
     s[M].reorder(yo, xo, c, c_unroll, yi, xi)
     s[M].unroll(c_unroll)
     s[M].unroll(yi)
     s[M].vectorize(xi)
     z = s[M].fuse(eps, nu)
-    tile_and_bind3d(s, M, z, yo, xo, 1, cfg['yt'].val, 1)
+    tile_and_bind3d(s, M, z, yo, xo, 1, cfg["yt"].val, 1)
 
     # inverse transform
     s[A].compute_inline()
@@ -414,9 +454,9 @@ def _schedule_winograd(cfg, s, op):
     s[output].unroll(hi)
     s[output].unroll(wi)
     fused = s[output].fuse(n, co, h, w)
-    bb, tt = cfg['tile_t2'].apply(s, output, fused)
-    s[output].bind(bb, te.thread_axis('blockIdx.x'))
-    s[output].bind(tt, te.thread_axis('threadIdx.x'))
+    bb, tt = cfg["tile_t2"].apply(s, output, fused)
+    s[output].bind(bb, te.thread_axis("blockIdx.x"))
+    s[output].bind(tt, te.thread_axis("threadIdx.x"))
 
     s[Y].compute_at(s[output], tt)
 
@@ -428,7 +468,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
     dispatch_ctx = autotvm.task.DispatchContext.current
 
     _, outs = relay.backend.compile_engine.select_implementation(
-        relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)
+        relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target
+    )
     workload = autotvm.task.get_workload(outs)
     if workload is None:
         # The best implementation is not an AutoTVM template,
@@ -456,15 +497,16 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         assert data_layout == "NCHW" and kernel_layout == "OIHW"
         N, CI, H, W = get_const_tuple(data.shape)
         CO, _, KH, KW = get_const_tuple(kernel.shape)
-        VC = cfg['tile_co'].size[-1]
+        VC = cfg["tile_co"].size[-1]
 
-        new_attrs['kernel_layout'] = 'OIHW%do' % VC
+        new_attrs["kernel_layout"] = "OIHW%do" % VC
 
         new_data = data
         new_kernel = te.placeholder((idxd(CO, VC), CI, KH, KW, VC), dtype=kernel.dtype)
         new_workload = autotvm.task.args_to_workload(
             [new_data, new_kernel, strides, padding, dilation, out_dtype],
-            "conv2d_nchw_spatial_pack.mali")
+            "conv2d_nchw_spatial_pack.mali",
+        )
         dispatch_ctx.update(target, new_workload, cfg)
 
         return relay.nn.conv2d(*inputs, **new_attrs)
@@ -473,31 +515,32 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         N, CI, H, W = get_const_tuple(data.shape)
         CO, _, KH, KW = get_const_tuple(kernel.shape)
         tile_size = _pick_tile_size(data, kernel)
-        VC = cfg['tile_bna'].val
+        VC = cfg["tile_bna"].val
 
         weight_expr = inputs[1]
         weight_expr = relay.nn.contrib_conv2d_winograd_weight_transform(
-            weight_expr, tile_size=tile_size)
-        weight_expr = relay.reshape(weight_expr,
-                                    newshape=(KH + tile_size - 1,
-                                              KW + tile_size - 1,
-                                              idxd(CO, VC), VC, CI))
+            weight_expr, tile_size=tile_size
+        )
+        weight_expr = relay.reshape(
+            weight_expr, newshape=(KH + tile_size - 1, KW + tile_size - 1, idxd(CO, VC), VC, CI)
+        )
         weight_expr = relay.transpose(weight_expr, axes=[0, 1, 2, 4, 3])
 
-        new_attrs['tile_size'] = tile_size
+        new_attrs["tile_size"] = tile_size
 
         new_data = data
-        new_kernel = te.placeholder((KH + tile_size - 1,
-                                     KW + tile_size -1,
-                                     idxd(CO, VC), CI, VC),
-                                    kernel.dtype)
+        new_kernel = te.placeholder(
+            (KH + tile_size - 1, KW + tile_size - 1, idxd(CO, VC), CI, VC), kernel.dtype
+        )
         new_workload = autotvm.task.args_to_workload(
             [new_data, new_kernel, strides, padding, dilation, out_dtype],
-            'conv2d_nchw_winograd.mali')
+            "conv2d_nchw_winograd.mali",
+        )
         dispatch_ctx.update(target, new_workload, cfg)
 
         return relay.nn.contrib_conv2d_winograd_without_weight_transform(
-            inputs[0], weight_expr, **new_attrs)
+            inputs[0], weight_expr, **new_attrs
+        )
     else:
         return None
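
One oddity in this file's hunks is the unpacking of s[U].op.axis into a parenthesized, one-name-per-line tuple. That is Black's magic trailing comma at work: the original assignment ended with a trailing comma after vco, and Black treats a pre-existing trailing comma as a request to keep the sequence exploded, adding parentheses if the target list had none. A small sketch with made-up names:

# Sketch: a trailing comma inside the parentheses keeps the unpacking exploded,
# exactly as happened to the s[U].op.axis assignment above.
axes = ("eps", "nu", "co", "ci", "vco")

(
    eps_axis,
    nu_axis,
    co_axis,
    ci_axis,
    vco_axis,
) = axes

print(eps_axis, vco_axis)
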
 
index 8ec5d19..7605ace 100644
@@ -23,14 +23,13 @@ from .. import nn
 from ..util import traverse_inline
 
 
-
-@autotvm.register_topi_compute('dense.mali')
+@autotvm.register_topi_compute("dense.mali")
 def dense(_, data, weight, bias=None, out_dtype=None):
     """Dense operator on Mali"""
     return nn.dense(data, weight, bias, out_dtype)
 
 
-@autotvm.register_topi_schedule('dense.mali')
+@autotvm.register_topi_schedule("dense.mali")
 def schedule_dense(cfg, outs):
     """Schedule for dense operator.
 
@@ -51,7 +50,7 @@ def schedule_dense(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == 'dense':
+        if op.tag == "dense":
             vec_size = [1, 2, 4, 8, 16]
             max_unroll = 32
 
@@ -62,42 +61,41 @@ def schedule_dense(cfg, outs):
             c = s[dense_out].op.reduce_axis[0]
 
             ##### space definition begin #####
-            cfg.define_split('tile_y', y, num_outputs=3)
-            cfg.define_split('tile_x', x, num_outputs=3)
-            cfg.define_split('c_unroll', c, num_outputs=2, max_factor=64)
+            cfg.define_split("tile_y", y, num_outputs=3)
+            cfg.define_split("tile_x", x, num_outputs=3)
+            cfg.define_split("c_unroll", c, num_outputs=2, max_factor=64)
 
             # fallback support
             if cfg.is_fallback:
-                ref_log = autotvm.tophub.load_reference_log(
-                    'mali', 'rk3399', 'dense.mali')
+                ref_log = autotvm.tophub.load_reference_log("mali", "rk3399", "dense.mali")
                 cfg.fallback_with_reference_log(ref_log)
             ##### space definition end #####
 
             if dense_out.op in s.outputs:
-                dense_out = s.cache_write(output, 'local')
+                dense_out = s.cache_write(output, "local")
 
-            by, ty, yi = cfg['tile_y'].apply(s, output, y)
-            bx, tx, xi = cfg['tile_x'].apply(s, output, x)
+            by, ty, yi = cfg["tile_y"].apply(s, output, y)
+            bx, tx, xi = cfg["tile_x"].apply(s, output, x)
 
-            s[output].bind(by, te.thread_axis('blockIdx.y'))
-            s[output].bind(bx, te.thread_axis('blockIdx.x'))
-            s[output].bind(ty, te.thread_axis('threadIdx.y'))
-            s[output].bind(tx, te.thread_axis('threadIdx.x'))
+            s[output].bind(by, te.thread_axis("blockIdx.y"))
+            s[output].bind(bx, te.thread_axis("blockIdx.x"))
+            s[output].bind(ty, te.thread_axis("threadIdx.y"))
+            s[output].bind(tx, te.thread_axis("threadIdx.x"))
 
-            if cfg['tile_y'].size[-1] < max_unroll:
+            if cfg["tile_y"].size[-1] < max_unroll:
                 s[output].unroll(yi)
-            if cfg['tile_x'].size[-1] in vec_size:
+            if cfg["tile_x"].size[-1] in vec_size:
                 s[output].vectorize(xi)
             s[dense_out].compute_at(s[output], tx)
 
             k = s[dense_out].op.reduce_axis[0]
             y, x = s[dense_out].op.axis
-            k, k_unroll = cfg['c_unroll'].apply(s, dense_out, k)
+            k, k_unroll = cfg["c_unroll"].apply(s, dense_out, k)
             s[dense_out].reorder(k, k_unroll, y, x)
             s[dense_out].unroll(k_unroll)
-            if cfg['tile_y'].size[-1] < max_unroll:
+            if cfg["tile_y"].size[-1] < max_unroll:
                 s[dense_out].unroll(y)
-            if cfg['tile_x'].size[-1] in vec_size:
+            if cfg["tile_x"].size[-1] in vec_size:
                 s[dense_out].vectorize(x)
 
     traverse_inline(s, outs[0].op, _callback)
index 785128c..b64135c 100644
@@ -59,18 +59,18 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
         ##### space definition begin #####
         n, c, y, x = s[conv].op.axis
         bc, tc, ci = cfg.define_split("tile_c", c, num_outputs=3)
-        by, ty, yi = cfg.define_split('tile_y', y, num_outputs=3)
+        by, ty, yi = cfg.define_split("tile_y", y, num_outputs=3)
         bx, tx, xi = cfg.define_split("tile_x", x, num_outputs=3)
-        cfg.define_annotate('ann_spatial', [ci, yi, xi], policy='try_unroll_vec')
+        cfg.define_annotate("ann_spatial", [ci, yi, xi], policy="try_unroll_vec")
 
         # fallback support
         if cfg.is_fallback:
             ref_log = autotvm.tophub.load_reference_log(
-                'mali', 'rk3399', 'depthwise_conv2d_nchw.mali')
+                "mali", "rk3399", "depthwise_conv2d_nchw.mali"
+            )
             cfg.fallback_with_reference_log(ref_log)
         ###### space definition end ######
 
-
         # schedule padding
         n, c, y, x = s[pad_data].op.axis
         tile_and_bind3d(s, pad_data, c, y, x, cfg["tile_c"].size[1], 1, 1)
@@ -81,17 +81,17 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
 
         # schedule conv
         if conv.op not in s.outputs:
-            s[conv].set_scope('local')
+            s[conv].set_scope("local")
             OL = conv
             output = s.outputs[0].output(0)
         else:
-            OL = s.cache_write(conv, 'local')
+            OL = s.cache_write(conv, "local")
             output = conv
 
         n, c, y, x = s[output].op.axis
-        bc, tc, ci = cfg['tile_c'].apply(s, output, c)
-        by, ty, yi = cfg['tile_y'].apply(s, output, y)
-        bx, tx, xi = cfg['tile_x'].apply(s, output, x)
+        bc, tc, ci = cfg["tile_c"].apply(s, output, c)
+        by, ty, yi = cfg["tile_y"].apply(s, output, y)
+        bx, tx, xi = cfg["tile_x"].apply(s, output, x)
 
         bc = s[output].fuse(n, bc)
         s[output].bind(bc, te.thread_axis("blockIdx.z"))
@@ -108,17 +108,20 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
         s[OL].compute_at(s[output], tx)
         n, ci, yi, xi = s[OL].op.axis
 
-        cfg["ann_spatial"].apply(s, OL, [ci, yi, xi],
-                                 axis_lens=[cfg['tile_c'].size[2], cfg['tile_y'].size[2],
-                                            cfg['tile_x'].size[2]],
-                                 max_unroll=max_unroll,
-                                 vec_size=vec_size,
-                                 cfg=cfg)
+        cfg["ann_spatial"].apply(
+            s,
+            OL,
+            [ci, yi, xi],
+            axis_lens=[cfg["tile_c"].size[2], cfg["tile_y"].size[2], cfg["tile_x"].size[2]],
+            max_unroll=max_unroll,
+            vec_size=vec_size,
+            cfg=cfg,
+        )
 
     def _callback(op):
         """traverse to find op to schedule"""
         # schedule depthwise_conv2d
-        if op.tag == 'depthwise_conv2d_nchw':
+        if op.tag == "depthwise_conv2d_nchw":
             pad_data = op.input_tensors[0]
             kernel = op.input_tensors[1]
             conv = op.output(0)
index 046b103..6d71348 100644
@@ -278,6 +278,7 @@ def atan(x):
     """
     return te.compute(x.shape, lambda *i: te.atan(x(*i)))
 
+
 @tvm.te.tag_scope(tag=tag.ELEMWISE)
 def atanh(x):
     """Take atanh of input x.
@@ -294,6 +295,7 @@ def atanh(x):
     """
     return te.compute(x.shape, lambda *i: te.atanh(x(*i)))
 
+
 @tvm.te.tag_scope(tag=tag.ELEMWISE)
 def floor(x):
     """Take floor of input x.
@@ -605,13 +607,16 @@ def clip(x, a_min, a_max):
     y : tvm.te.Tensor
         The result.
     """
+
     def _compute(*indices):
         value = x(*indices)
         const_min = tvm.tir.const(a_min, value.dtype)
         const_max = tvm.tir.const(a_max, value.dtype)
         return tvm.te.max(tvm.te.min(value, const_max), const_min)
+
     return te.compute(x.shape, _compute)
 
+
 @tvm.te.tag_scope(tag=tag.ELEMWISE)
 def fixed_point_multiply(x, multiplier, shift):
     """Fixed point multiplication between data and a fixed point
@@ -632,14 +637,19 @@ def fixed_point_multiply(x, multiplier, shift):
     y : tvm.te.Tensor
         The result.
     """
+
     def _compute(*indices):
         value = x(*indices)
-        return tvm.tir.q_multiply_shift(value,
-                                        tvm.tir.const(multiplier, 'int32'),
-                                        tvm.tir.const(31, 'int32'),
-                                        tvm.tir.const(shift, 'int32'))
+        return tvm.tir.q_multiply_shift(
+            value,
+            tvm.tir.const(multiplier, "int32"),
+            tvm.tir.const(31, "int32"),
+            tvm.tir.const(shift, "int32"),
+        )
+
     return te.compute(x.shape, _compute)
 
+
 def cast(x, dtype):
     """Cast input to specified data type.
 
@@ -657,10 +667,10 @@ def cast(x, dtype):
         The result.
     """
     if isinstance(x, te.tensor.Tensor):
-        return te.compute(
-            x.shape, lambda *i: x(*i).astype(dtype), tag=tag.ELEMWISE)
+        return te.compute(x.shape, lambda *i: x(*i).astype(dtype), tag=tag.ELEMWISE)
     # pylint: disable=import-outside-toplevel
     from tvm.tir import _ffi_api
+
     return _ffi_api._cast(dtype, x)
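
A fair number of hunks in this file add nothing but whitespace. Black enforces two blank lines between top-level definitions and a single blank line separating a nested def from the surrounding statements, which is why clip, fixed_point_multiply, cast, and the atan/atanh pair above all pick up extra empty lines. A tiny sketch of the same rules with hypothetical names:

# Sketch of Black's blank-line rules; the names are made up.
def clip_value(value, lo, hi):
    def _bound(v):
        return max(lo, min(v, hi))

    return _bound(value)


def cast_value(value, to_type):
    return to_type(value)
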
 
 
index 0d9f351..7c8fead 100644
@@ -19,6 +19,7 @@
 from tvm import te
 from ..util import get_const_tuple
 
+
 def batch_matmul(x, y):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
     data in batch.
@@ -43,7 +44,7 @@ def batch_matmul(x, y):
     assert x_shape[2] == y_shape[2], "shapes of x and y are inconsistent"
     batch, M, K = x.shape
     N = y.shape[1]
-    k = te.reduce_axis((0, K), name='k')
-    return te.compute((batch, M, N),
-                      lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k),
-                      tag='batch_matmul')
+    k = te.reduce_axis((0, K), name="k")
+    return te.compute(
+        (batch, M, N), lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), tag="batch_matmul"
+    )
index e1a7697..d104519 100644
@@ -24,8 +24,18 @@ from .util import get_pad_tuple
 from .bitserial_util import bitpack
 from ..util import get_const_tuple
 
-def bitserial_conv2d_nchw(data, kernel, stride, padding, activation_bits, weight_bits,
-                          pack_dtype='uint32', out_dtype='int16', unipolar=True):
+
+def bitserial_conv2d_nchw(
+    data,
+    kernel,
+    stride,
+    padding,
+    activation_bits,
+    weight_bits,
+    pack_dtype="uint32",
+    out_dtype="int16",
+    unipolar=True,
+):
     """Bitserial Conv2D operator.
 
     Parameters
@@ -88,35 +98,67 @@ def bitserial_conv2d_nchw(data, kernel, stride, padding, activation_bits, weight
     out_height = (in_height - kernel_h + TPAD + DPAD) // stride_h + 1
     out_width = (in_width - kernel_w + LPAD + RPAD) // stride_w + 1
 
-    rc = te.reduce_axis((0, in_channel), name='rc')
-    ry = te.reduce_axis((0, kernel_h), name='ry')
-    rx = te.reduce_axis((0, kernel_w), name='rx')
-    b1 = te.reduce_axis((0, activation_bits), name='b1')
-    b2 = te.reduce_axis((0, weight_bits), name='b2')
+    rc = te.reduce_axis((0, in_channel), name="rc")
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
+    b1 = te.reduce_axis((0, activation_bits), name="b1")
+    b2 = te.reduce_axis((0, weight_bits), name="b2")
 
     if unipolar:
+
         def _conv(nn, ff, yy, xx):
-            b1b2 = (b1+b2).astype(out_dtype)
+            b1b2 = (b1 + b2).astype(out_dtype)
             return te.sum(
-                ((tvm.tir.popcount(PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] &
-                                   Filter_q[ff, rc, ry, rx, b2]) -
-                  tvm.tir.popcount(PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] &
-                                   ~Filter_q[ff, rc, ry, rx, b2]))
-                 << (b1b2)).astype(out_dtype),
-                axis=[rc, ry, rx, b2, b1]).astype(out_dtype)
-    else:
-        def _conv(nn, ff, yy, xx):
-            b1b2 = (b1+b2).astype(out_dtype)
-            return te.sum((tvm.tir.popcount(
-                PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] &
-                Filter_q[ff, rc, ry, rx, b2])<< (b1b2)).astype(out_dtype),
-                          axis=[rc, ry, rx, b2, b1]).astype(out_dtype)
+                (
+                    (
+                        tvm.tir.popcount(
+                            PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx]
+                            & Filter_q[ff, rc, ry, rx, b2]
+                        )
+                        - tvm.tir.popcount(
+                            PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx]
+                            & ~Filter_q[ff, rc, ry, rx, b2]
+                        )
+                    )
+                    << (b1b2)
+                ).astype(out_dtype),
+                axis=[rc, ry, rx, b2, b1],
+            ).astype(out_dtype)
 
-    return te.compute((batch, out_channel, out_height, out_width), _conv,
-                      name="Conv2dOutput", tag="bitserial_conv2d_nchw")
+    else:
 
-def bitserial_conv2d_nhwc(data, kernel, stride, padding, activation_bits, weight_bits,
-                          pack_dtype='uint32', out_dtype='int16', unipolar=True):
+        def _conv(nn, ff, yy, xx):
+            b1b2 = (b1 + b2).astype(out_dtype)
+            return te.sum(
+                (
+                    tvm.tir.popcount(
+                        PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx]
+                        & Filter_q[ff, rc, ry, rx, b2]
+                    )
+                    << (b1b2)
+                ).astype(out_dtype),
+                axis=[rc, ry, rx, b2, b1],
+            ).astype(out_dtype)
+
+    return te.compute(
+        (batch, out_channel, out_height, out_width),
+        _conv,
+        name="Conv2dOutput",
+        tag="bitserial_conv2d_nchw",
+    )
+
+
+def bitserial_conv2d_nhwc(
+    data,
+    kernel,
+    stride,
+    padding,
+    activation_bits,
+    weight_bits,
+    pack_dtype="uint32",
+    out_dtype="int16",
+    unipolar=True,
+):
     """Bitserial Conv2D operator.
 
     Parameters
@@ -180,36 +222,58 @@ def bitserial_conv2d_nhwc(data, kernel, stride, padding, activation_bits, weight
     out_width = (in_width - kernel_w + LPAD + RPAD) // stride_w + 1
     PadInput_q = pad(Input_q, pad_before, pad_after, name="PaddedInput")
 
-    rc = te.reduce_axis((0, in_channel_q), name='rc')
-    ry = te.reduce_axis((0, kernel_h), name='ry')
-    rx = te.reduce_axis((0, kernel_w), name='rx')
-    b1 = te.reduce_axis((0, activation_bits), name='b1')
-    b2 = te.reduce_axis((0, weight_bits), name='b2')
+    rc = te.reduce_axis((0, in_channel_q), name="rc")
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
+    b1 = te.reduce_axis((0, activation_bits), name="b1")
+    b2 = te.reduce_axis((0, weight_bits), name="b2")
 
     if unipolar:
+
         def _conv(nn, yy, xx, ff):
-            b1b2 = (b1+b2).astype(out_dtype)
+            b1b2 = (b1 + b2).astype(out_dtype)
             return te.sum(
-                ((tvm.tir.popcount(PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] &
-                                   Filter_q[ry, rx, rc, ff, b2]) -
-                  tvm.tir.popcount(PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] &
-                                   ~Filter_q[ry, rx, rc, ff, b2]))
-                 << b1b2).astype(out_dtype),
-                axis=[rc, ry, rx, b2, b1])
+                (
+                    (
+                        tvm.tir.popcount(
+                            PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1]
+                            & Filter_q[ry, rx, rc, ff, b2]
+                        )
+                        - tvm.tir.popcount(
+                            PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1]
+                            & ~Filter_q[ry, rx, rc, ff, b2]
+                        )
+                    )
+                    << b1b2
+                ).astype(out_dtype),
+                axis=[rc, ry, rx, b2, b1],
+            )
 
     else:
-        def _conv(nn, yy, xx, ff):
-            b1b2 = (b1+b2).astype(out_dtype)
-            return te.sum((tvm.tir.popcount(
-                PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] &
-                Filter_q[ry, rx, rc, ff, b2]) << b1b2).astype(out_dtype),
-                          axis=[rc, ry, rx, b2, b1])
 
-    conv = te.compute((batch, out_height, out_width, out_channel), _conv,
-                      name="Conv2dOutput", tag="bitserial_conv2d_nhwc")
+        def _conv(nn, yy, xx, ff):
+            b1b2 = (b1 + b2).astype(out_dtype)
+            return te.sum(
+                (
+                    tvm.tir.popcount(
+                        PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1]
+                        & Filter_q[ry, rx, rc, ff, b2]
+                    )
+                    << b1b2
+                ).astype(out_dtype),
+                axis=[rc, ry, rx, b2, b1],
+            )
+
+    conv = te.compute(
+        (batch, out_height, out_width, out_channel),
+        _conv,
+        name="Conv2dOutput",
+        tag="bitserial_conv2d_nhwc",
+    )
 
     return conv
 
+
 @tvm.target.generic_func
 def bitserial_conv2d_legalize(attrs, inputs, types):
     """Legalizes Bitserial Conv2D op.
index 97d1fb2..0b86e2e 100644
@@ -22,8 +22,10 @@ from tvm import te
 from tvm.topi.util import get_const_tuple
 from .bitserial_util import bitpack
 
-def bitserial_dense(data, weight, data_bits, weight_bits, pack_dtype='uint32',
-                    out_dtype='int16', unipolar=True):
+
+def bitserial_dense(
+    data, weight, data_bits, weight_bits, pack_dtype="uint32", out_dtype="int16", unipolar=True
+):
     """The default implementation of bitserial dense in topi.
 
     Parameters
@@ -47,20 +49,32 @@ def bitserial_dense(data, weight, data_bits, weight_bits, pack_dtype='uint32',
     X, WB, _ = get_const_tuple(weight_packed.shape)
 
     oshape = (Y, X)
-    k = te.reduce_axis((0, K), name='k')
-    db = te.reduce_axis((0, DB), name='db')
-    wb = te.reduce_axis((0, WB), name='wb')
-
-    matmul_unipolar = te.compute(oshape, lambda i, j: te.sum(
-        (tvm.tir.popcount(weight_packed[j, wb, k] & data_packed[i, db, k]) -
-         tvm.tir.popcount(~weight_packed[j, wb, k] & data_packed[i, db, k])).astype(out_dtype)
-        << (db+wb).astype(out_dtype), axis=[wb, db, k]),
-                                 tag='bitserial_dense_unipolar')
+    k = te.reduce_axis((0, K), name="k")
+    db = te.reduce_axis((0, DB), name="db")
+    wb = te.reduce_axis((0, WB), name="wb")
 
-    matmul = te.compute(oshape, lambda i, j: te.sum(
-        tvm.tir.popcount(weight_packed[j, wb, k] & data_packed[i, db, k]).astype(out_dtype)
-        << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')
+    matmul_unipolar = te.compute(
+        oshape,
+        lambda i, j: te.sum(
+            (
+                tvm.tir.popcount(weight_packed[j, wb, k] & data_packed[i, db, k])
+                - tvm.tir.popcount(~weight_packed[j, wb, k] & data_packed[i, db, k])
+            ).astype(out_dtype)
+            << (db + wb).astype(out_dtype),
+            axis=[wb, db, k],
+        ),
+        tag="bitserial_dense_unipolar",
+    )
 
+    matmul = te.compute(
+        oshape,
+        lambda i, j: te.sum(
+            tvm.tir.popcount(weight_packed[j, wb, k] & data_packed[i, db, k]).astype(out_dtype)
+            << (db + wb).astype(out_dtype),
+            axis=[wb, db, k],
+        ),
+        tag="bitserial_dense",
+    )
 
     if unipolar:
         return matmul_unipolar
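
For reference, the unipolar compute above relies on the bit-serial identity dot(w, x) = popcount(w & x) - popcount(~w & x) for sign-encoded weight bits and binary activation bits, accumulated per bit plane with a shift of (db + wb). A minimal, self-contained sketch of the single-bit case (illustrative only, not part of this patch):

# Illustrative sketch (not part of this commit): single-bit unipolar dot product.
# Weights are in {-1, +1} encoded as bits (bit 1 -> +1, bit 0 -> -1); activations are in {0, 1}.
def popcount(x):
    return bin(x).count("1")

def unipolar_dot(w_bits, x_bits, n_bits):
    mask = (1 << n_bits) - 1
    return popcount(w_bits & x_bits) - popcount(~w_bits & x_bits & mask)

# w_bits = 0b1010 encodes weights (+1, -1, +1, -1) from MSB to LSB,
# x_bits = 0b1101 encodes activations (1, 1, 0, 1) from MSB to LSB,
# so the dot product is (+1) + (-1) + 0 + (-1) = -1.
assert unipolar_dot(0b1010, 0b1101, 4) == -1
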
index 2b320b8..ae43668 100644 (file)
@@ -22,6 +22,7 @@ from tvm import te
 from tvm.topi.transform import concatenate
 from ..util import get_const_int
 
+
 def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
     """Packs data into format necessary for bitserial computation
 
@@ -34,20 +35,20 @@ def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
     """
     ishape = data.shape
     n = len(ishape)
-    if pack_type == 'uint8':
+    if pack_type == "uint8":
         data_width = 8
-    elif pack_type == 'uint16':
+    elif pack_type == "uint16":
         data_width = 16
-    elif pack_type == 'uint32':
+    elif pack_type == "uint32":
         data_width = 32
-    elif pack_type == 'uint64':
+    elif pack_type == "uint64":
         data_width = 64
 
     # Data must be in multiples of the data_width
     assert get_const_int(ishape[pack_axis]) % data_width == 0, "Not a multiple of word size"
 
     shape_vec = list(ishape)
-    shape_vec[pack_axis] = (shape_vec[pack_axis] // data_width)
+    shape_vec[pack_axis] = shape_vec[pack_axis] // data_width
     shape_vec.insert(bit_axis, 1)
     bitserial_oshape = tuple(shape_vec)
     masks = np.array([0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80])
@@ -62,7 +63,7 @@ def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
             # Translate indices for packed data back to original
             idx = [0] * n
             j = 0
-            for i in range(n+1):
+            for i in range(n + 1):
                 if i == bit_axis:
                     continue
                 if i == pack_axis:
@@ -73,9 +74,10 @@ def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
 
             element = data(*idx)
             for b in range(bits):
-                extracted_bit = (
-                    (element & tvm.tir.const(masks[b], "int32")) >> b).astype(pack_type)
-                packed_data[b] = (packed_data[b] | extracted_bit)
+                extracted_bit = ((element & tvm.tir.const(masks[b], "int32")) >> b).astype(
+                    pack_type
+                )
+                packed_data[b] = packed_data[b] | extracted_bit
                 if k < data_width - 1:
                     packed_data[b] = packed_data[b] << 1
 
@@ -83,14 +85,15 @@ def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
                 return tuple(packed_data)
         return tuple(packed_data)
 
-    output_tuple = te.compute(bitserial_oshape, _bitpack, name=name, tag='bitpack')
+    output_tuple = te.compute(bitserial_oshape, _bitpack, name=name, tag="bitpack")
 
     if bits > 1:
         return concatenate(output_tuple, axis=bit_axis)
     return output_tuple
 
+
 def binary_op_multiplier(pack_dtype):
-    """"Returns number of bits packed into
+    """ "Returns number of bits packed into
     pack_dtype: string
         pack type for the operator (must be a uint)"""
     return int(pack_dtype[4:])
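
The bitpack helper above splits a quantized tensor into bit planes and packs each plane along pack_axis into machine words; binary_op_multiplier simply reads the word width back out of the dtype name. A rough NumPy analogue for pack_type="uint8" and bits=2 (illustrative only; the stacking order and dtypes are assumptions, not the kernel's exact layout):

# Illustrative sketch (not part of this commit): bit-plane packing in NumPy.
import numpy as np

x = np.random.randint(0, 4, size=(2, 16), dtype=np.uint8)   # 2-bit activations
planes = [np.right_shift(x, b) & 1 for b in range(2)]        # one 0/1 plane per bit
packed = np.stack(
    [np.packbits(p.astype(np.uint8), axis=-1) for p in planes], axis=1
)                                                            # shape (2, 2, 16 // 8)
print(packed.shape)  # (2, 2, 2): batch, bit plane, packed words
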
index d7355fb..6c36b37 100644 (file)
@@ -47,22 +47,21 @@ def binarize_pack(data, axis=None, name="PackedInput"):
         axis = len(ishape) - 1
     assert get_const_int(ishape[axis]) % 32 == 0
     n = len(ishape)
-    oshape = tuple(simplify(ishape[i] // 32) if i == axis \
-                   else ishape[i] for i in range(n))
+    oshape = tuple(simplify(ishape[i] // 32) if i == axis else ishape[i] for i in range(n))
 
     def _binarize_pack(*indices):
         start_idx = [indices[i] * 32 if i == axis else indices[i] for i in range(n)]
-        packed = tvm.tir.const(0, 'uint32')
+        packed = tvm.tir.const(0, "uint32")
         for j in range(32):
             idx = [start_idx[i] + j if i == axis else start_idx[i] for i in range(n)]
             sign = (data(*idx) >= 0).astype("uint32")
-            packed = (packed | sign)
+            packed = packed | sign
             if j == 31:
                 return packed
             packed = packed << 1
         raise RuntimeError("not resach")
 
-    return te.compute(oshape, _binarize_pack, name=name, tag='binarize_pack')
+    return te.compute(oshape, _binarize_pack, name=name, tag="binarize_pack")
 
 
 def binary_dense(data, weight):
@@ -81,17 +80,19 @@ def binary_dense(data, weight):
     output : tvm.te.Tensor
         2-D with shape [batch, out_dim], dtype is float32.
     """
-    assert data.dtype == 'uint32' and weight.dtype == 'uint32', \
-        "dtype of data and weight should be uint32"
-    assert len(data.shape) == 2 and len(weight.shape) == 2, \
-        "only support 2-dim binary dense"
+    assert (
+        data.dtype == "uint32" and weight.dtype == "uint32"
+    ), "dtype of data and weight should be uint32"
+    assert len(data.shape) == 2 and len(weight.shape) == 2, "only support 2-dim binary dense"
     batch, in_dim = data.shape
     out_dim, _ = weight.shape
-    k = te.reduce_axis((0, in_dim), name='k')
-    matmul = te.compute((batch, out_dim), lambda i, j: \
-                        te.sum(tvm.tir.popcount(data[i, k] ^ weight[j, k]), axis=k), \
-                        tag='binary_dense')
+    k = te.reduce_axis((0, in_dim), name="k")
+    matmul = te.compute(
+        (batch, out_dim),
+        lambda i, j: te.sum(tvm.tir.popcount(data[i, k] ^ weight[j, k]), axis=k),
+        tag="binary_dense",
+    )
 
-    return te.compute((batch, out_dim), lambda i, j: \
-                      32 * in_dim - 2. * matmul(i, j), \
-                      tag=tag.ELEMWISE)
+    return te.compute(
+        (batch, out_dim), lambda i, j: 32 * in_dim - 2.0 * matmul(i, j), tag=tag.ELEMWISE
+    )
index 8049dff..cffed66 100644 (file)
@@ -22,14 +22,8 @@ from ..util import simplify
 from .util import get_pad_tuple1d
 
 
-def conv1d(data,
-           kernel,
-           strides=1,
-           padding='VALID',
-           dilation=1,
-           layout='NCW',
-           out_dtype=None):
-    """ 1D convolution forward operator.
+def conv1d(data, kernel, strides=1, padding="VALID", dilation=1, layout="NCW", out_dtype=None):
+    """1D convolution forward operator.
 
     Parameters
     ----------
@@ -63,20 +57,15 @@ def conv1d(data,
     if isinstance(dilation, (tuple, list)):
         dilation = dilation[0]
 
-    if layout == 'NCW':
+    if layout == "NCW":
         return conv1d_ncw(data, kernel, strides, padding, dilation, out_dtype)
-    if layout == 'NWC':
+    if layout == "NWC":
         return conv1d_nwc(data, kernel, strides, padding, dilation, out_dtype)
     raise ValueError("This layout is not yet supported: {}".format(layout))
 
 
-def conv1d_ncw(data,
-               kernel,
-               strides=1,
-               padding='VALID',
-               dilation=1,
-               out_dtype=None):
-    """ 1D convolution forward operator for NCW layout.
+def conv1d_ncw(data, kernel, strides=1, padding="VALID", dilation=1, out_dtype=None):
+    """1D convolution forward operator for NCW layout.
 
     Parameters
     ----------
@@ -111,36 +100,32 @@ def conv1d_ncw(data,
 
     # Compute the output shape
     dilated_kernel_size = (kernel_size - 1) * dilation + 1
-    pad_left, pad_right = get_pad_tuple1d(padding, (dilated_kernel_size, ))
+    pad_left, pad_right = get_pad_tuple1d(padding, (dilated_kernel_size,))
     out_channels = simplify(out_channels)
-    out_width = simplify(
-        (data_width - dilated_kernel_size + pad_left + pad_right) // strides + 1)
+    out_width = simplify((data_width - dilated_kernel_size + pad_left + pad_right) // strides + 1)
 
     # Apply padding
     pad_before = [0, 0, pad_left]
     pad_after = [0, 0, pad_right]
-    temp = pad(data, pad_before, pad_after, name='pad_temp')
+    temp = pad(data, pad_before, pad_after, name="pad_temp")
 
     # Compute graph
-    rc = te.reduce_axis((0, in_channels), name='rc')
-    rw = te.reduce_axis((0, kernel_size), name='rw')
+    rc = te.reduce_axis((0, in_channels), name="rc")
+    rw = te.reduce_axis((0, kernel_size), name="rw")
 
     return te.compute(
         (batch, out_channels, out_width),
         lambda b, c, w: te.sum(
             temp[b, rc, w * strides + rw * dilation].astype(out_dtype)
             * kernel[c, rc, rw].astype(out_dtype),
-            axis=[rc, rw]),
-        tag="conv1d_ncw")
+            axis=[rc, rw],
+        ),
+        tag="conv1d_ncw",
+    )
 
 
-def conv1d_nwc(data,
-               kernel,
-               strides=1,
-               padding='VALID',
-               dilation=1,
-               out_dtype=None):
-    """ 1D convolution forward operator for NWC layout.
+def conv1d_nwc(data, kernel, strides=1, padding="VALID", dilation=1, out_dtype=None):
+    """1D convolution forward operator for NWC layout.
 
     Parameters
     ----------
@@ -175,24 +160,25 @@ def conv1d_nwc(data,
 
     # Compute the output shape
     dilated_kernel_size = (kernel_size - 1) * dilation + 1
-    pad_left, pad_right = get_pad_tuple1d(padding, (dilated_kernel_size, ))
+    pad_left, pad_right = get_pad_tuple1d(padding, (dilated_kernel_size,))
     out_channels = simplify(out_channels)
-    out_width = simplify(
-        (data_width - dilated_kernel_size + pad_left + pad_right) // strides + 1)
+    out_width = simplify((data_width - dilated_kernel_size + pad_left + pad_right) // strides + 1)
 
     # Apply padding
     pad_before = [0, pad_left, 0]
     pad_after = [0, pad_right, 0]
-    temp = pad(data, pad_before, pad_after, name='pad_temp')
+    temp = pad(data, pad_before, pad_after, name="pad_temp")
 
     # Compute graph
-    rc = te.reduce_axis((0, in_channels), name='rc')
-    rw = te.reduce_axis((0, kernel_size), name='rw')
+    rc = te.reduce_axis((0, in_channels), name="rc")
+    rw = te.reduce_axis((0, kernel_size), name="rw")
 
     return te.compute(
         (batch, out_width, out_channels),
         lambda b, w, c: te.sum(
             temp[b, w * strides + rw * dilation, rc].astype(out_dtype)
             * kernel[rw, rc, c].astype(out_dtype),
-            axis=[rc, rw]),
-        tag="conv1d_nwc")
+            axis=[rc, rw],
+        ),
+        tag="conv1d_nwc",
+    )
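
Both conv1d variants above use the same output-width arithmetic: out_width = (data_width - dilated_kernel_size + pad_left + pad_right) // strides + 1 with dilated_kernel_size = (kernel_size - 1) * dilation + 1. A quick numeric sanity check (illustrative only; the concrete sizes are arbitrary):

# Illustrative sketch (not part of this commit): the conv1d output-width arithmetic.
def conv1d_out_width(data_width, kernel_size, stride, dilation, pad_left, pad_right):
    dilated_kernel_size = (kernel_size - 1) * dilation + 1
    return (data_width - dilated_kernel_size + pad_left + pad_right) // stride + 1

# e.g. W=32, k=3, stride=2, dilation=1, padding of 1 on each side -> 16
assert conv1d_out_width(32, 3, 2, 1, 1, 1) == 16
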
index b5b55d2..813377e 100644 (file)
@@ -23,8 +23,7 @@ from ..util import simplify
 from .util import get_pad_tuple1d
 
 
-def conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype,
-                         output_padding):
+def conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype, output_padding):
     """Transposed 1D convolution ncw forward operator.
 
     Parameters
@@ -64,27 +63,31 @@ def conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype,
     _, channels_out, kernel_width = kernel.shape
     assert output_padding < stride
     channels_out = simplify(channels_out)
-    data = dilate(data, [1, 1, stride], name='data_dilate')
+    data = dilate(data, [1, 1, stride], name="data_dilate")
     pad_left, pad_right = get_pad_tuple1d(padding, (kernel_width,))
     pad_left = kernel_width - 1 - pad_left
     pad_right = kernel_width - 1 - pad_right + output_padding
-    data = pad(data, [0, 0, pad_left], [0, 0, pad_right], name='data_pad')
+    data = pad(data, [0, 0, pad_left], [0, 0, pad_right], name="data_pad")
 
     # transpose kernel, switch kernel layout to IOW
-    kernel = te.compute((channels_out, channels_in, kernel_width), \
-                        lambda o, i, w: kernel[i][o][kernel_width-1-w],\
-                        name='kernel')
+    kernel = te.compute(
+        (channels_out, channels_in, kernel_width),
+        lambda o, i, w: kernel[i][o][kernel_width - 1 - w],
+        name="kernel",
+    )
 
     # convolution
     _, _, data_width = data.shape
     out_w = simplify(data_width - kernel_width + 1)
-    dc = te.reduce_axis((0, channels_in), name='dc')
-    dw = te.reduce_axis((0, kernel_width), name='dw')
+    dc = te.reduce_axis((0, channels_in), name="dc")
+    dw = te.reduce_axis((0, kernel_width), name="dw")
     output = te.compute(
         (batch, channels_out, out_w),
         lambda b, c, w: te.sum(
-            data[b, dc, w+dw].astype(out_dtype) *
-            kernel[c, dc, dw].astype(out_dtype),
-            axis=[dc, dw]), tag="conv1d_transpose_ncw")
+            data[b, dc, w + dw].astype(out_dtype) * kernel[c, dc, dw].astype(out_dtype),
+            axis=[dc, dw],
+        ),
+        tag="conv1d_transpose_ncw",
+    )
 
     return output
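
conv1d_transpose_ncw above realizes the transposed convolution as dilate-by-stride, pad by (kernel_width - 1 - pad) on each side (plus output_padding on the right), then a valid convolution. The resulting width matches the usual closed form, as this small sketch shows (illustrative only, not part of this patch):

# Illustrative sketch (not part of this commit): how dilate + pad + valid conv in
# conv1d_transpose_ncw reproduces the usual transposed-conv output width.
def conv1d_transpose_out_width(in_w, kernel_w, stride, pad_left, pad_right, output_padding):
    dilated = (in_w - 1) * stride + 1                    # dilate(data, [1, 1, stride])
    padded = dilated + (kernel_w - 1 - pad_left) + (kernel_w - 1 - pad_right + output_padding)
    valid_conv = padded - kernel_w + 1                   # out_w = data_width - kernel_width + 1
    assert valid_conv == (in_w - 1) * stride + kernel_w - pad_left - pad_right + output_padding
    return valid_conv

# e.g. in_w=8, k=4, stride=2, pad=1 on each side, output_padding=0 -> 16
assert conv1d_transpose_out_width(8, 4, 2, 1, 1, 0) == 16
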
index d3be6bb..5245584 100644 (file)
@@ -28,11 +28,27 @@ from ..util import simplify, get_const_tuple, get_const_int, tag
 from .winograd_util import winograd_transform_matrices
 
 # workload description of conv2d
-Workload = namedtuple('Workload',
-                      ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'groups',
-                       'out_filter', 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
-
-def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=None):
+Workload = namedtuple(
+    "Workload",
+    [
+        "in_dtype",
+        "out_dtype",
+        "height",
+        "width",
+        "in_filter",
+        "groups",
+        "out_filter",
+        "hkernel",
+        "wkernel",
+        "hpad",
+        "wpad",
+        "hstride",
+        "wstride",
+    ],
+)
+
+
+def conv2d(input, filter, strides, padding, dilation, layout="NCHW", out_dtype=None):
     """Conv2D operator.
 
     Parameters
@@ -64,11 +80,11 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
     """
     # search platform specific declaration first
     # default declaration
-    if layout == 'NCHW':
+    if layout == "NCHW":
         return conv2d_nchw(input, filter, strides, padding, dilation, out_dtype)
-    if layout == 'HWCN':
+    if layout == "HWCN":
         return conv2d_hwcn(input, filter, strides, padding, dilation, out_dtype)
-    if layout == 'NHWC':
+    if layout == "NHWC":
         return conv2d_nhwc(input, filter, strides, padding, dilation, out_dtype)
     raise ValueError("not support this layout {} yet".format(layout))
 
@@ -117,6 +133,7 @@ def conv2d_alter_layout(attrs, inputs, tinfos, out_type):
     # not to change by default
     return None
 
+
 @tvm.target.generic_func
 def conv2d_infer_layout(workload, cfg):
     """Infer input/output shapes and layouts from a workload and cfg.
@@ -137,19 +154,18 @@ def conv2d_infer_layout(workload, cfg):
     raise ValueError("missing register for topi.nn.conv2d_infer_layout")
 
 
-
-def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'):
+def _get_workload(data, kernel, stride, padding, out_dtype, data_layout="NCHW"):
     """ Get the workload structure. """
-    if data_layout == 'NCHW':
+    if data_layout == "NCHW":
         _, CI, IH, IW = get_const_tuple(data.shape)
-    elif data_layout == 'NHWC':
+    elif data_layout == "NHWC":
         _, IH, IW, CI = get_const_tuple(data.shape)
-    elif data_layout == 'HWCN':
+    elif data_layout == "HWCN":
         IH, IW, CI, _ = get_const_tuple(data.shape)
     else:
         raise ValueError("not support this layout {} yet".format(data_layout))
 
-    if data_layout == 'NCHW':
+    if data_layout == "NCHW":
         CO, CIG, KH, KW = get_const_tuple(kernel.shape)
     else:
         KH, KW, CIG, CO = get_const_tuple(kernel.shape)
@@ -160,9 +176,12 @@ def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'):
         HSTR, WSTR = stride
     else:
         HSTR, WSTR = stride, stride
-    assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
-        "Do not support inputs with different data types now. ' \
-        '{} vs. {}".format(data.dtype, kernel.dtype)
+    assert (data.dtype == kernel.dtype) or (
+        data.dtype == "uint8" and kernel.dtype == "int8"
+    ), "Do not support inputs with different data types now. ' \
+        '{} vs. {}".format(
+        data.dtype, kernel.dtype
+    )
     return Workload(data.dtype, out_dtype, IH, IW, CI, GRPS, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
 
 
@@ -213,7 +232,8 @@ def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
     dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
     dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
     out_channel = num_filter
     out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
     out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
@@ -221,16 +241,20 @@ def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
     pad_before = [0, 0, pad_top, pad_left]
     pad_after = [0, 0, pad_down, pad_right]
     temp = pad(Input, pad_before, pad_after, name="pad_temp")
-    rc = te.reduce_axis((0, in_channel), name='rc')
-    ry = te.reduce_axis((0, kernel_h), name='ry')
-    rx = te.reduce_axis((0, kernel_w), name='rx')
+    rc = te.reduce_axis((0, in_channel), name="rc")
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
     return te.compute(
         (batch, out_channel, out_height, out_width),
         lambda nn, ff, yy, xx: te.sum(
-            temp[nn, rc, yy * stride_h + ry * dilation_h,
-                 xx * stride_w + rx * dilation_w].astype(out_dtype) *
-            Filter[ff, rc, ry, rx].astype(out_dtype),
-            axis=[rc, ry, rx]), tag="conv2d_nchw")
+            temp[nn, rc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w].astype(
+                out_dtype
+            )
+            * Filter[ff, rc, ry, rx].astype(out_dtype),
+            axis=[rc, ry, rx],
+        ),
+        tag="conv2d_nchw",
+    )
 
 
 def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None):
@@ -281,27 +305,33 @@ def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None):
     dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
     dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
     out_channel = num_filter
     out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
     out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
     pad_before = [pad_top, pad_left, 0, 0]
     pad_after = [pad_down, pad_right, 0, 0]
     PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
-    rc = te.reduce_axis((0, in_channel), name='rc')
-    ry = te.reduce_axis((0, kernel_h), name='ry')
-    rx = te.reduce_axis((0, kernel_w), name='rx')
+    rc = te.reduce_axis((0, in_channel), name="rc")
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
     Output = te.compute(
         (out_height, out_width, out_channel, batch),
         lambda yy, xx, ff, nn: te.sum(
-            PaddedInput[yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w,
-                        rc, nn].astype(out_dtype) *
-            Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]),
-        name="Conv2dOutput", tag="conv2d_hwcn")
+            PaddedInput[
+                yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc, nn
+            ].astype(out_dtype)
+            * Filter[ry, rx, rc, ff].astype(out_dtype),
+            axis=[ry, rx, rc],
+        ),
+        name="Conv2dOutput",
+        tag="conv2d_hwcn",
+    )
     return Output
 
 
-def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
+def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype="float32"):
     """Convolution operator in NHWC layout.
 
     Parameters
@@ -347,27 +377,33 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
     dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
     dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
     out_channel = num_filter
     out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
     out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
     pad_before = [0, pad_top, pad_left, 0]
     pad_after = [0, pad_down, pad_right, 0]
     PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
-    rc = te.reduce_axis((0, in_channel), name='rc')
-    ry = te.reduce_axis((0, kernel_h), name='ry')
-    rx = te.reduce_axis((0, kernel_w), name='rx')
+    rc = te.reduce_axis((0, in_channel), name="rc")
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
     Output = te.compute(
         (batch, out_height, out_width, out_channel),
         lambda nn, yy, xx, ff: te.sum(
-            PaddedInput[nn, yy * stride_h + ry * dilation_h,
-                        xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
-            Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]),
-        name="Conv2dOutput", tag="conv2d_nhwc")
+            PaddedInput[
+                nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc
+            ].astype(out_dtype)
+            * Filter[ry, rx, rc, ff].astype(out_dtype),
+            axis=[ry, rx, rc],
+        ),
+        name="Conv2dOutput",
+        tag="conv2d_nhwc",
+    )
     return Output
 
 
-def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, out_dtype='float32'):
+def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="float32"):
     """Conv2D operator for nChw[x]c layout.
 
     Parameters
@@ -409,14 +445,14 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
     # layout and out_layout are not used here,
     # we keep them for debug convenience when dumping autotvm workload
     HSTR, WSTR = stride if isinstance(stride, (tuple, list)) else (stride, stride)
-    dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \
-        else (dilation, dilation)
+    dilation_h, dilation_w = (
+        dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
+    )
 
     n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
     in_channel = ic_chunk * ic_bn
     target = tvm.target.Target.current(allow_none=False)
-    oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \
-        get_const_tuple(kernel.shape)
+    oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
     num_filter = oc_chunk * oc_bn
     groups = ic_chunk // ic_chunk_group
 
@@ -424,7 +460,8 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
     dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
 
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
     HPAD = pad_top + pad_down
     WPAD = pad_left + pad_right
 
@@ -436,37 +473,40 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
     pad_after = (0, 0, pad_down, pad_right, 0)
 
     # DOPAD
-    DOPAD = (HPAD != 0 or WPAD != 0)
+    DOPAD = HPAD != 0 or WPAD != 0
     if DOPAD:
         data_pad = pad(data, pad_before, pad_after, name="data_pad")
     else:
         data_pad = data
 
-    ic = te.reduce_axis((0, in_channel), name='ic')
-    kh = te.reduce_axis((0, kernel_height), name='kh')
-    kw = te.reduce_axis((0, kernel_width), name='kw')
+    ic = te.reduce_axis((0, in_channel), name="ic")
+    kh = te.reduce_axis((0, kernel_height), name="kh")
+    kw = te.reduce_axis((0, kernel_width), name="kw")
 
     idxdiv = tvm.tir.indexdiv
     idxmod = tvm.tir.indexmod
 
-    return te.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
-                      te.sum(data_pad[n,
-                                      idxdiv(ic, ic_bn),
-                                      oh * HSTR + kh * dilation_h,
-                                      ow * WSTR + kw * dilation_w,
-                                      idxmod(ic, ic_bn)].astype(out_dtype)
-                             * kernel[oc_chunk,
-                                      idxdiv(ic, ic_bn),
-                                      kh,
-                                      kw,
-                                      idxmod(ic, ic_bn),
-                                      oc_block],
-                             axis=[ic, kh, kw]),
-                      name='conv2d_NCHWc', tag="conv2d_NCHWc")
-
-
-def conv2d_NCHWc_int8(data, kernel, stride, padding, dilation, layout, out_layout,
-                      out_dtype='int32'):
+    return te.compute(
+        oshape,
+        lambda n, oc_chunk, oh, ow, oc_block: te.sum(
+            data_pad[
+                n,
+                idxdiv(ic, ic_bn),
+                oh * HSTR + kh * dilation_h,
+                ow * WSTR + kw * dilation_w,
+                idxmod(ic, ic_bn),
+            ].astype(out_dtype)
+            * kernel[oc_chunk, idxdiv(ic, ic_bn), kh, kw, idxmod(ic, ic_bn), oc_block],
+            axis=[ic, kh, kw],
+        ),
+        name="conv2d_NCHWc",
+        tag="conv2d_NCHWc",
+    )
+
+
+def conv2d_NCHWc_int8(
+    data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="int32"
+):
     """Conv2D operator for nChw[x]c layout.
 
     Parameters
@@ -508,22 +548,24 @@ def conv2d_NCHWc_int8(data, kernel, stride, padding, dilation, layout, out_layou
     # layout and out_layout are not used here,
     # we keep them for debug convenience when dumping autotvm workload
     HSTR, WSTR = stride if isinstance(stride, (tuple, list)) else (stride, stride)
-    dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \
-        else (dilation, dilation)
+    dilation_h, dilation_w = (
+        dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
+    )
 
     n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
     in_channel = ic_chunk * ic_bn
-    oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \
-        get_const_tuple(kernel.shape)
+    oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(
+        kernel.shape
+    )
     num_filter = oc_chunk * oc_bn
     groups = ic_chunk // ic_chunk_group
 
     dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
     dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
 
-
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
     HPAD = pad_top + pad_down
     WPAD = pad_left + pad_right
 
@@ -535,59 +577,62 @@ def conv2d_NCHWc_int8(data, kernel, stride, padding, dilation, layout, out_layou
     pad_after = (0, 0, pad_down, pad_right, 0)
 
     # DOPAD
-    DOPAD = (HPAD != 0 or WPAD != 0)
+    DOPAD = HPAD != 0 or WPAD != 0
     if DOPAD:
         data_pad = pad(data, pad_before, pad_after, name="data_pad")
     else:
         data_pad = data
 
-    ic = te.reduce_axis((0, in_channel), name='ic')
-    kh = te.reduce_axis((0, kernel_height), name='kh')
-    kw = te.reduce_axis((0, kernel_width), name='kw')
+    ic = te.reduce_axis((0, in_channel), name="ic")
+    kh = te.reduce_axis((0, kernel_height), name="kh")
+    kw = te.reduce_axis((0, kernel_width), name="kw")
 
     if groups == 1:
         n_elems = 4
-        ic_outer = te.reduce_axis((0, in_channel//ic_bn), name='ic_outer')
-        ic_f_inner = te.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
-        ic_s_inner = te.reduce_axis((0, n_elems), name='ic_s_inner')
-        return te.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
-                          te.sum(data_pad[n,
-                                          ic_outer,
-                                          oh * HSTR + kh * dilation_h,
-                                          ow * WSTR + kw * dilation_w,
-                                          ic_f_inner * n_elems + ic_s_inner].astype(out_dtype)
-                                 * kernel[oc_chunk,
-                                          ic_outer,
-                                          kh,
-                                          kw,
-                                          ic_f_inner,
-                                          oc_block,
-                                          ic_s_inner].astype(out_dtype),
-                                 axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
-                          name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
+        ic_outer = te.reduce_axis((0, in_channel // ic_bn), name="ic_outer")
+        ic_f_inner = te.reduce_axis((0, ic_bn // n_elems), name="ic_f_inner")
+        ic_s_inner = te.reduce_axis((0, n_elems), name="ic_s_inner")
+        return te.compute(
+            oshape,
+            lambda n, oc_chunk, oh, ow, oc_block: te.sum(
+                data_pad[
+                    n,
+                    ic_outer,
+                    oh * HSTR + kh * dilation_h,
+                    ow * WSTR + kw * dilation_w,
+                    ic_f_inner * n_elems + ic_s_inner,
+                ].astype(out_dtype)
+                * kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner].astype(
+                    out_dtype
+                ),
+                axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner],
+            ),
+            name="conv2d_NCHWc_int8",
+            tag="conv2d_NCHWc_int8",
+        )
     # for int8 group conv support
     n_elems = 4
-    ic_chunk = in_channel//ic_bn
-    ic_outer = te.reduce_axis((0, ic_chunk//groups), name='ic_outer')
-    ic_f_inner = te.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
-    ic_s_inner = te.reduce_axis((0, n_elems), name='ic_s_inner')
+    ic_chunk = in_channel // ic_bn
+    ic_outer = te.reduce_axis((0, ic_chunk // groups), name="ic_outer")
+    ic_f_inner = te.reduce_axis((0, ic_bn // n_elems), name="ic_f_inner")
+    ic_s_inner = te.reduce_axis((0, n_elems), name="ic_s_inner")
     oshape = (n, oc_chunk, out_height, out_width, oc_bn)
-    return te.compute(oshape, lambda n, occ, oh, ow, oc_block:
-                      te.sum(data_pad[n,
-                                      (occ * oc_bn // (oc_chunk * oc_bn // groups))
-                                      * (ic_chunk // groups) + ic_outer,
-                                      oh * HSTR + kh,
-                                      ow * WSTR + kw,
-                                      ic_f_inner * n_elems +  ic_s_inner].astype(out_dtype)
-                             * kernel[occ,
-                                      ic_outer,
-                                      kh,
-                                      kw,
-                                      ic_f_inner,
-                                      oc_block,
-                                      ic_s_inner].astype(out_dtype),
-                             axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
-                      name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
+    return te.compute(
+        oshape,
+        lambda n, occ, oh, ow, oc_block: te.sum(
+            data_pad[
+                n,
+                (occ * oc_bn // (oc_chunk * oc_bn // groups)) * (ic_chunk // groups) + ic_outer,
+                oh * HSTR + kh,
+                ow * WSTR + kw,
+                ic_f_inner * n_elems + ic_s_inner,
+            ].astype(out_dtype)
+            * kernel[occ, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner].astype(out_dtype),
+            axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner],
+        ),
+        name="conv2d_NCHWc_int8",
+        tag="conv2d_NCHWc_int8",
+    )
 
 
 def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols):
@@ -611,9 +656,9 @@ def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols):
     K = KH * KW * IC
     N = OC
 
-    kernel_flat = te.compute((K, N), lambda x, y:
-                             kernel[(x // IC) // KW, (x // IC) % KW, x % IC, y],
-                             'weight_flatten')
+    kernel_flat = te.compute(
+        (K, N), lambda x, y: kernel[(x // IC) // KW, (x // IC) % KW, x % IC, y], "weight_flatten"
+    )
 
     pad_K = 0
     pad_N = 0
@@ -628,15 +673,15 @@ def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols):
     K_padded = K + pad_K
 
     if pad_K != 0 or pad_N != 0:
-        kernel_flat = pad(kernel_flat, pad_before=(0, 0), pad_after=(pad_K, pad_N),
-                          name='weight_padding')
+        kernel_flat = pad(
+            kernel_flat, pad_before=(0, 0), pad_after=(pad_K, pad_N), name="weight_padding"
+        )
 
-    return te.compute((N_padded // tile_rows,
-                       K_padded // tile_cols,
-                       tile_rows,
-                       tile_cols), lambda x, y, z, w:
-                      kernel_flat[w + tile_cols * y, z + tile_rows * x],
-                      name='weight_block_reshape')
+    return te.compute(
+        (N_padded // tile_rows, K_padded // tile_cols, tile_rows, tile_cols),
+        lambda x, y, z, w: kernel_flat[w + tile_cols * y, z + tile_rows * x],
+        name="weight_block_reshape",
+    )
 
 
 def conv2d_winograd_weight_transform(kernel, tile_size):
@@ -663,12 +708,15 @@ def conv2d_winograd_weight_transform(kernel, tile_size):
 
     _, _, G = winograd_transform_matrices(tile_size, K, kernel.dtype)
 
-    r_kh = te.reduce_axis((0, K), name='r_kh')
-    r_kw = te.reduce_axis((0, K), name='r_kw')
-    return te.compute(shape, lambda eps, nu, co, ci:
-                      te.sum(kernel[co][ci][r_kh][r_kw] *
-                             G[eps][r_kh] * G[nu][r_kw],
-                             axis=[r_kh, r_kw]), name='transform_weight')
+    r_kh = te.reduce_axis((0, K), name="r_kh")
+    r_kw = te.reduce_axis((0, K), name="r_kw")
+    return te.compute(
+        shape,
+        lambda eps, nu, co, ci: te.sum(
+            kernel[co][ci][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]
+        ),
+        name="transform_weight",
+    )
 
 
 def conv2d_winograd_nnpack_weight_transform(kernel, convolution_algorithm, out_dtype):
@@ -688,8 +736,10 @@ def conv2d_winograd_nnpack_weight_transform(kernel, convolution_algorithm, out_d
     """
     # pylint: disable=import-outside-toplevel
     from tvm.contrib import nnpack
+
     return nnpack.convolution_inference_weight_transform(
-        kernel, algorithm=convolution_algorithm, dtype=out_dtype)
+        kernel, algorithm=convolution_algorithm, dtype=out_dtype
+    )
 
 
 def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtype=None):
@@ -745,29 +795,36 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp
     assert in_channel % groups == 0, "input channels must divide group size"
     assert num_filter % groups == 0, "output channels must divide group size"
 
-    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (kernel_h, kernel_w))
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w))
     # compute the output shape
     out_channel = num_filter
     out_height = simplify(
-        (in_height - (kernel_h - 1) * dilation_h - 1 + pad_top + pad_down) // stride_h + 1)
+        (in_height - (kernel_h - 1) * dilation_h - 1 + pad_top + pad_down) // stride_h + 1
+    )
     out_width = simplify(
-        (in_width - (kernel_w - 1) * dilation_w - 1 + pad_left + pad_right) // stride_w + 1)
+        (in_width - (kernel_w - 1) * dilation_w - 1 + pad_left + pad_right) // stride_w + 1
+    )
     # compute graph
     pad_before = [0, 0, pad_top, pad_left]
     pad_after = [0, 0, pad_down, pad_right]
     temp = pad(Input, pad_before, pad_after, name="pad_temp")
-    rc = te.reduce_axis((0, in_channel // groups), name='rc')
-    ry = te.reduce_axis((0, kernel_h), name='ry')
-    rx = te.reduce_axis((0, kernel_w), name='rx')
+    rc = te.reduce_axis((0, in_channel // groups), name="rc")
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
     return te.compute(
         (batch, out_channel, out_height, out_width),
         lambda nn, ff, yy, xx: te.sum(
-            temp[nn, ff // (num_filter//groups) * (in_channel//groups) + rc,
-                 yy * stride_h + ry * dilation_h,
-                 xx * stride_w + rx * dilation_w].astype(out_dtype) *
-            Filter[ff, rc, ry, rx].astype(out_dtype),
-            axis=[rc, ry, rx]), tag='group_conv2d_nchw')
+            temp[
+                nn,
+                ff // (num_filter // groups) * (in_channel // groups) + rc,
+                yy * stride_h + ry * dilation_h,
+                xx * stride_w + rx * dilation_w,
+            ].astype(out_dtype)
+            * Filter[ff, rc, ry, rx].astype(out_dtype),
+            axis=[rc, ry, rx],
+        ),
+        tag="group_conv2d_nchw",
+    )
 
 
 def unpack_NCHWc_to_nchw(packed_out, out_dtype):
@@ -792,11 +849,12 @@ def unpack_NCHWc_to_nchw(packed_out, out_dtype):
     idxdiv = tvm.tir.indexdiv
 
     oshape = (n, oc_chunk * oc_bn, oh, ow)
-    unpacked_out = \
-        te.compute(oshape,
-                   lambda n, c, h, w:
-                   packed_out[n, idxdiv(c, oc_bn), h, w, idxmod(c, oc_bn)]
-                   .astype(out_dtype),
-                   name='output_unpack',
-                   tag=tag.INJECTIVE+",unpack_nchwc")
+    unpacked_out = te.compute(
+        oshape,
+        lambda n, c, h, w: packed_out[n, idxdiv(c, oc_bn), h, w, idxmod(c, oc_bn)].astype(
+            out_dtype
+        ),
+        name="output_unpack",
+        tag=tag.INJECTIVE + ",unpack_nchwc",
+    )
     return unpacked_out
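
unpack_NCHWc_to_nchw above reads packed_out[n, c // oc_bn, h, w, c % oc_bn] via idxdiv/idxmod to rebuild plain NCHW. In NumPy the same relayout is a reshape plus transpose, sketched below (illustrative only; the concrete sizes, including bn = 4, are arbitrary assumptions):

# Illustrative sketch (not part of this commit): NCHW <-> NCHW[x]c relayout in NumPy.
import numpy as np

n, c, h, w, bn = 1, 8, 5, 5, 4                       # bn plays the role of oc_bn
x_nchw = np.arange(n * c * h * w, dtype=np.float32).reshape(n, c, h, w)

# pack: split C into (C // bn, bn) and move the inner block to the last axis
x_nchwc = x_nchw.reshape(n, c // bn, bn, h, w).transpose(0, 1, 3, 4, 2)

# unpack: the inverse, matching packed_out[n, c // bn, h, w, c % bn]
x_back = x_nchwc.transpose(0, 1, 4, 2, 3).reshape(n, c, h, w)
assert np.array_equal(x_back, x_nchw)
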
index 1fe981d..d1edbaa 100644 (file)
@@ -25,9 +25,7 @@ from .util import get_pad_tuple
 from ..util import simplify
 
 
-
-def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype,
-                          output_padding):
+def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype, output_padding):
     """Transposed 2D convolution nchw forward operator.
 
     Parameters
@@ -55,41 +53,44 @@ def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype,
     Output : tvm.te.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    return declaration_conv2d_transpose_impl(Input, Filter, strides, padding, out_dtype,
-                                             output_padding=output_padding)
+    return declaration_conv2d_transpose_impl(
+        Input, Filter, strides, padding, out_dtype, output_padding=output_padding
+    )
 
 
 def conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype, output_padding):
     """Preprocess data and kernel to make the compute pattern
-       of conv2d_transpose the same as conv2d"""
+    of conv2d_transpose the same as conv2d"""
     batch, in_c, in_h, in_w = data.shape
     _, out_c, filter_h, filter_w = kernel.shape
     stride_h, stride_w = strides
     opad_h, opad_w = output_padding
     assert opad_h < stride_h and opad_w < stride_w
     # dilate data
-    data_dilate = dilate(data, [1, 1, stride_h, stride_w], name='data_dilate')
+    data_dilate = dilate(data, [1, 1, stride_h, stride_w], name="data_dilate")
     # pad data
     fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w))
     bpad_top = filter_h - 1 - fpad_top
     bpad_bottom = filter_h - 1 - fpad_bottom + opad_h
     bpad_left = filter_w - 1 - fpad_left
     bpad_right = filter_w - 1 - fpad_right + opad_w
-    data_pad = pad(data_dilate, \
-                   [0, 0, bpad_top, bpad_left], \
-                   [0, 0, bpad_bottom, bpad_right], \
-                   name='data_pad')
+    data_pad = pad(
+        data_dilate, [0, 0, bpad_top, bpad_left], [0, 0, bpad_bottom, bpad_right], name="data_pad"
+    )
     # transform kernel layout from IOHW to OIHW, and rotate kernel by 180 degrees
-    kernel_transform = te.compute((out_c, in_c, filter_h, filter_w), \
-                                  lambda o, i, h, w: kernel[i][o][filter_h-1-h][filter_w-1-w], \
-                                  name='kernel_transform')
+    kernel_transform = te.compute(
+        (out_c, in_c, filter_h, filter_w),
+        lambda o, i, h, w: kernel[i][o][filter_h - 1 - h][filter_w - 1 - w],
+        name="kernel_transform",
+    )
     return data_pad, kernel_transform
 
 
 def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype, output_padding):
     """Implementation of conv2d transpose"""
-    data_pad, kernel_transform = \
-        conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype, output_padding)
+    data_pad, kernel_transform = conv2d_transpose_nchw_preprocess(
+        data, kernel, strides, padding, out_dtype, output_padding
+    )
     batch, in_c, in_h, in_w = data_pad.shape
     out_c, _, filter_h, filter_w = kernel_transform.shape
 
@@ -98,16 +99,19 @@ def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype,
 
     out_h = simplify(in_h - filter_h + 1 + output_padding[0])
     out_w = simplify(in_w - filter_w + 1 + output_padding[1])
-    dc = tvm.reduce_axis((0, in_c), name='dc')
-    dh = tvm.reduce_axis((0, filter_h), name='dh')
-    dw = tvm.reduce_axis((0, filter_w), name='dw')
+    dc = te.reduce_axis((0, in_c), name="dc")
+    dh = te.reduce_axis((0, filter_h), name="dh")
+    dw = te.reduce_axis((0, filter_w), name="dw")
 
     Output = te.compute(
         (batch, out_c, out_h, out_w),
         lambda b, c, h, w: te.sum(
-            data_pad[b, dc, h+dh, w+dw].astype(out_dtype) *
-            kernel_transform[c, dc, dh, dw].astype(out_dtype),
-            axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")
+            data_pad[b, dc, h + dh, w + dw].astype(out_dtype)
+            * kernel_transform[c, dc, dh, dw].astype(out_dtype),
+            axis=[dc, dh, dw],
+        ),
+        tag="conv2d_transpose_nchw",
+    )
 
     return Output
 
@@ -130,24 +134,24 @@ def conv2d_transpose_legalize(attrs, inputs, types):
     result : tvm.relay.Expr
         The legalized expr
     """
-    if attrs['data_layout'] == 'NHWC':
+    if attrs["data_layout"] == "NHWC":
         data, kernel = inputs
-        kernel_layout = attrs['kernel_layout']
+        kernel_layout = attrs["kernel_layout"]
         # Convert Kernel layout to IOHW
         # kernel_layout is different from input kernel layout - IO is swapped
-        if kernel_layout == 'HWIO':
+        if kernel_layout == "HWIO":
             # input kernel layout is swapped to HWOI
             # output kernel layout will be IOHW
             kernel = relay.transpose(kernel, axes=(3, 2, 0, 1))
-        elif kernel_layout == 'HWOI':
+        elif kernel_layout == "HWOI":
             # input kernel layout is swapped to HWIO
             # output kernel layout will be IOHW
             kernel = relay.transpose(kernel, axes=(2, 3, 0, 1))
-        elif kernel_layout == 'IOHW':
+        elif kernel_layout == "IOHW":
             # input kernel layout is swapped to OIHW
             # output kernel layout will be IOHW
             kernel = relay.transpose(kernel, axes=(1, 0, 2, 3))
-        elif kernel_layout == 'OIHW':
+        elif kernel_layout == "OIHW":
             # input kernel layout is swapped to IOHW
             # output kernel layout will be IOHW
             pass
@@ -157,9 +161,9 @@ def conv2d_transpose_legalize(attrs, inputs, types):
 
         # Set new attrs for conv2d_transpose.
         new_attrs = {k: attrs[k] for k in attrs.keys()}
-        new_attrs['data_layout'] = 'NCHW'
+        new_attrs["data_layout"] = "NCHW"
         # layout of kernel should be IOHW, but kernel_layout should be swapped - OIHW
-        new_attrs['kernel_layout'] = 'OIHW'
+        new_attrs["kernel_layout"] = "OIHW"
 
         # Convert data to NCHW.
         data = relay.transpose(data, axes=(0, 3, 1, 2))
index 2bac284..1696ac6 100644 (file)
@@ -72,7 +72,8 @@ def conv3d_ncdhw(Input, Filter, stride, padding, dilation, out_dtype=None):
     dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
     dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
     pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
-        padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w)
+    )
     out_channel = num_filter
     out_depth = simplify((in_depth - dilated_kernel_d + pad_front + pad_back) // stride_d + 1)
     out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
@@ -81,21 +82,29 @@ def conv3d_ncdhw(Input, Filter, stride, padding, dilation, out_dtype=None):
     pad_before = [0, 0, pad_front, pad_top, pad_left]
     pad_after = [0, 0, pad_back, pad_down, pad_right]
     temp = pad(Input, pad_before, pad_after, name="pad_temp")
-    rc = te.reduce_axis((0, in_channel), name='rc')
-    rz = te.reduce_axis((0, kernel_d), name='rz')
-    ry = te.reduce_axis((0, kernel_h), name='ry')
-    rx = te.reduce_axis((0, kernel_w), name='rx')
+    rc = te.reduce_axis((0, in_channel), name="rc")
+    rz = te.reduce_axis((0, kernel_d), name="rz")
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
 
     return te.compute(
         (batch, out_channel, out_depth, out_height, out_width),
         lambda nn, ff, zz, yy, xx: te.sum(
-            temp[nn, rc, zz * stride_d + rz * dilation_d, yy * stride_h + ry * dilation_h,
-                 xx * stride_w + rx * dilation_w].astype(out_dtype) *
-            Filter[ff, rc, rz, ry, rx].astype(out_dtype),
-            axis=[rc, rz, ry, rx]), tag="conv3d_ncdhw")
-
-
-def conv3d_ndhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
+            temp[
+                nn,
+                rc,
+                zz * stride_d + rz * dilation_d,
+                yy * stride_h + ry * dilation_h,
+                xx * stride_w + rx * dilation_w,
+            ].astype(out_dtype)
+            * Filter[ff, rc, rz, ry, rx].astype(out_dtype),
+            axis=[rc, rz, ry, rx],
+        ),
+        tag="conv3d_ncdhw",
+    )
+
+
+def conv3d_ndhwc(Input, Filter, stride, padding, dilation, out_dtype="float32"):
     """Convolution operator in NDHWC layout.
 
     Parameters
@@ -141,7 +150,8 @@ def conv3d_ndhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
     dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
 
     pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
-        padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w)
+    )
     out_channel = num_filter
     out_depth = simplify((in_depth - dilated_kernel_d + pad_front + pad_back) // stride_d + 1)
     out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
@@ -149,17 +159,26 @@ def conv3d_ndhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
     pad_before = [0, pad_front, pad_top, pad_left, 0]
     pad_after = [0, pad_back, pad_down, pad_right, 0]
     PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
-    rd = te.reduce_axis((0, kernel_d), name='rd')
-    rh = te.reduce_axis((0, kernel_h), name='rh')
-    rw = te.reduce_axis((0, kernel_w), name='rw')
-    rc = te.reduce_axis((0, in_channel), name='rc')
+    rd = te.reduce_axis((0, kernel_d), name="rd")
+    rh = te.reduce_axis((0, kernel_h), name="rh")
+    rw = te.reduce_axis((0, kernel_w), name="rw")
+    rc = te.reduce_axis((0, in_channel), name="rc")
     Output = te.compute(
         (batch, out_depth, out_height, out_width, out_channel),
         lambda nn, dd, hh, ww, cc: te.sum(
-            PaddedInput[nn, dd * stride_d + rd * dilation_d, hh * stride_h + rh * dilation_h,
-                        ww * stride_w + rw * dilation_w, rc].astype(out_dtype) *
-            Filter[rd, rh, rw, rc, cc].astype(out_dtype), axis=[rd, rh, rw, rc]),
-        name="Conv3dOutput", tag="conv3d_ndhwc")
+            PaddedInput[
+                nn,
+                dd * stride_d + rd * dilation_d,
+                hh * stride_h + rh * dilation_h,
+                ww * stride_w + rw * dilation_w,
+                rc,
+            ].astype(out_dtype)
+            * Filter[rd, rh, rw, rc, cc].astype(out_dtype),
+            axis=[rd, rh, rw, rc],
+        ),
+        name="Conv3dOutput",
+        tag="conv3d_ndhwc",
+    )
     return Output
 
 
@@ -189,26 +208,29 @@ def conv3d_winograd_weight_transform(kernel, tile_size):
 
     r = tile_size + KH - 1
 
-    r_kh = te.reduce_axis((0, KH), name='r_kh')
-    r_kw = te.reduce_axis((0, KW), name='r_kw')
+    r_kh = te.reduce_axis((0, KH), name="r_kh")
+    r_kw = te.reduce_axis((0, KW), name="r_kw")
     _, _, G = winograd_transform_matrices(tile_size, KH, kernel.dtype)
     if depth_transform:
         shape = (r, r, r, CO, CI)
-        r_kd = te.reduce_axis((0, KD), name='r_kd')
+        r_kd = te.reduce_axis((0, KD), name="r_kd")
         return te.compute(
             shape,
             lambda omg, eps, nu, co, ci: te.sum(
                 kernel[co][ci][r_kd][r_kh][r_kw] * G[omg][r_kd] * G[eps][r_kh] * G[nu][r_kw],
-                axis=[r_kd, r_kh, r_kw]),
-            name='transform_weight')
+                axis=[r_kd, r_kh, r_kw],
+            ),
+            name="transform_weight",
+        )
     else:
         shape = (r, r, KD, CO, CI)
         return te.compute(
             shape,
             lambda eps, nu, d, co, ci: te.sum(
-                kernel[co][ci][d][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]),
-            name='transform_weight')
-
+                kernel[co][ci][d][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]
+            ),
+            name="transform_weight",
+        )
 
 
 @tvm.target.generic_func
index cd57264..9a8828f 100644 (file)
@@ -53,45 +53,51 @@ def conv3d_transpose_ncdhw(Input, Filter, strides, padding, out_dtype, output_pa
     Output : tvm.te.Tensor
         5-D with shape [batch, out_channel, out_depth, out_height, out_width]
     """
-    return declaration_conv3d_transpose_impl(Input, Filter, strides, padding,
-                                             out_dtype, output_padding)
+    return declaration_conv3d_transpose_impl(
+        Input, Filter, strides, padding, out_dtype, output_padding
+    )
 
 
 def conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding, out_dtype, output_padding):
     """Preprocess data and kernel to make the compute pattern
-       of conv3d_transpose the same as conv3d"""
+    of conv3d_transpose the same as conv3d"""
     batch, in_c, in_d, in_h, in_w = data.shape
     _, out_c, filter_d, filter_h, filter_w = kernel.shape
     stride_d, stride_h, stride_w = strides
     opad_d, opad_h, opad_w = output_padding
     assert opad_d < stride_d and opad_h < stride_h and opad_w < stride_w
     # dilate data
-    data_dilate = dilate(data, [1, 1, stride_d, stride_h, stride_w], name='data_dilate')
+    data_dilate = dilate(data, [1, 1, stride_d, stride_h, stride_w], name="data_dilate")
     # pad data
     fpad_front, fpad_top, fpad_left, fpad_back, fpad_bottom, fpad_right = get_pad_tuple3d(
-        padding, (filter_d, filter_h, filter_w))
+        padding, (filter_d, filter_h, filter_w)
+    )
     bpad_front = filter_d - 1 - fpad_front
     bpad_back = filter_d - 1 - fpad_back + opad_d
     bpad_top = filter_h - 1 - fpad_top
     bpad_bottom = filter_h - 1 - fpad_bottom + opad_h
     bpad_left = filter_w - 1 - fpad_left
     bpad_right = filter_w - 1 - fpad_right + opad_w
-    data_pad = pad(data_dilate, \
-                   [0, 0, bpad_front, bpad_top, bpad_left], \
-                   [0, 0, bpad_back, bpad_bottom, bpad_right], \
-                   name='data_pad')
+    data_pad = pad(
+        data_dilate,
+        [0, 0, bpad_front, bpad_top, bpad_left],
+        [0, 0, bpad_back, bpad_bottom, bpad_right],
+        name="data_pad",
+    )
     # transform kernel layout from IODHW to OIDHW, and rotate kernel by 180 degrees
-    kernel_transform = te.compute((out_c, in_c, filter_d, filter_h, filter_w), \
-                                  lambda o, i, d, h, w: kernel[i][o][filter_d-1-d] \
-                                        [filter_h-1-h][filter_w-1-w], \
-                                  name='kernel_transform')
+    kernel_transform = te.compute(
+        (out_c, in_c, filter_d, filter_h, filter_w),
+        lambda o, i, d, h, w: kernel[i][o][filter_d - 1 - d][filter_h - 1 - h][filter_w - 1 - w],
+        name="kernel_transform",
+    )
     return data_pad, kernel_transform
 
 
 def declaration_conv3d_transpose_impl(data, kernel, strides, padding, out_dtype, output_padding):
     """Implementation of conv3d transpose"""
-    data_pad, kernel_transform = \
-        conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding, out_dtype, output_padding)
+    data_pad, kernel_transform = conv3d_transpose_ncdhw_preprocess(
+        data, kernel, strides, padding, out_dtype, output_padding
+    )
     batch, in_c, in_d, in_h, in_w = data_pad.shape
     out_c, _, filter_d, filter_h, filter_w = kernel_transform.shape
     stride_d, stride_h, stride_w = strides
@@ -101,17 +107,20 @@ def declaration_conv3d_transpose_impl(data, kernel, strides, padding, out_dtype,
     out_d = simplify(in_d - filter_d + 1)
     out_h = simplify(in_h - filter_h + 1)
     out_w = simplify(in_w - filter_w + 1)
-    dc = te.reduce_axis((0, in_c), name='dc')
-    dd = te.reduce_axis((0, filter_d), name='dd')
-    dh = te.reduce_axis((0, filter_h), name='dh')
-    dw = te.reduce_axis((0, filter_w), name='dw')
+    dc = te.reduce_axis((0, in_c), name="dc")
+    dd = te.reduce_axis((0, filter_d), name="dd")
+    dh = te.reduce_axis((0, filter_h), name="dh")
+    dw = te.reduce_axis((0, filter_w), name="dw")
 
     Output = te.compute(
         (batch, out_c, out_d, out_h, out_w),
         lambda b, c, d, h, w: te.sum(
-            data_pad[b, dc, d+dd, h+dh, w+dw].astype(out_dtype) *
-            kernel_transform[c, dc, dd, dh, dw].astype(out_dtype),
-            axis=[dc, dd, dh, dw]), tag="conv3d_transpose_ncdhw")
+            data_pad[b, dc, d + dd, h + dh, w + dw].astype(out_dtype)
+            * kernel_transform[c, dc, dd, dh, dw].astype(out_dtype),
+            axis=[dc, dd, dh, dw],
+        ),
+        tag="conv3d_transpose_ncdhw",
+    )
 
     return Output
 
@@ -134,24 +143,24 @@ def conv3d_transpose_legalize(attrs, inputs, types):
     result : tvm.relay.Expr
         The legalized expr
     """
-    if attrs['data_layout'] == 'NDHWC':
+    if attrs["data_layout"] == "NDHWC":
         data, kernel = inputs
-        kernel_layout = attrs['kernel_layout']
+        kernel_layout = attrs["kernel_layout"]
         # Convert Kernel layout to IODHW
         # kernel_layout is different from input kernel layout - IO is swapped
-        if kernel_layout == 'DHWIO':
+        if kernel_layout == "DHWIO":
             # input kernel layout is swapped to DHWOI
             # output kernel layout will be IODHW
             kernel = relay.transpose(kernel, axes=(4, 3, 0, 1, 2))
-        elif kernel_layout == 'DHWOI':
+        elif kernel_layout == "DHWOI":
             # input kernel layout is swapped to DHWIO
             # output kernel layout will be IODHW
             kernel = relay.transpose(kernel, axes=(3, 4, 0, 1, 2))
-        elif kernel_layout == 'IODHW':
+        elif kernel_layout == "IODHW":
             # input kernel layout is swapped to OIDHW
             # output kernel layout will be IODHW
             kernel = relay.transpose(kernel, axes=(1, 0, 2, 3, 4))
-        elif kernel_layout == 'OIDHW':
+        elif kernel_layout == "OIDHW":
             # input kernel layout is swapped to IODHW
             # output kernel layout will be IODHW
             pass
@@ -161,9 +170,9 @@ def conv3d_transpose_legalize(attrs, inputs, types):
 
         # Set new attrs for conv3d_transpose.
         new_attrs = {k: attrs[k] for k in attrs.keys()}
-        new_attrs['data_layout'] = 'NCDHW'
+        new_attrs["data_layout"] = "NCDHW"
         # layout of kernel should be IODHW, but kernel_layout should be swapped - OIDHW
-        new_attrs['kernel_layout'] = 'OIDHW'
+        new_attrs["kernel_layout"] = "OIDHW"
 
         # Convert data to NCDHW.
         data = relay.transpose(data, axes=(0, 4, 1, 2, 3))
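For reference, the axis permutations used in the legalization above can be checked with a small NumPy sketch (the shapes below are made-up values, not taken from this patch):

import numpy as np

# Hypothetical kernel whose actual layout is DHWOI (the "DHWIO" attribute case,
# with I and O swapped as noted in the comments above): D=3, H=5, W=7, O=11, I=13.
kernel = np.zeros((3, 5, 7, 11, 13))

# relay.transpose(kernel, axes=(4, 3, 0, 1, 2)) reorders dimensions the same way
# numpy.transpose does: output dim i takes input dim axes[i].
iodhw = np.transpose(kernel, (4, 3, 0, 1, 2))
print(iodhw.shape)  # (13, 11, 3, 5, 7) -> I, O, D, H, W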
index 94aea55..583002e 100644 (file)
@@ -21,8 +21,9 @@ from .pad import pad
 from ..util import get_const_tuple
 
 
-def correlation_nchw(data1, data2, kernel_size, max_displacement, stride1, stride2, padding,
-                     is_multiply):
+def correlation_nchw(
+    data1, data2, kernel_size, max_displacement, stride1, stride2, padding, is_multiply
+):
     """Correlation operator in NCHW layout.
 
     Parameters
@@ -92,9 +93,9 @@ def correlation_nchw(data1, data2, kernel_size, max_displacement, stride1, strid
     out_height = (padded_height - 2 * border_size + stride1 - 1) // stride1
     out_width = (padded_width - 2 * border_size + stride1 - 1) // stride1
 
-    rc = te.reduce_axis((0, channel), name='rc')
-    ry = te.reduce_axis((0, kernel_size), name='ry')
-    rx = te.reduce_axis((0, kernel_size), name='rx')
+    rc = te.reduce_axis((0, channel), name="rc")
+    ry = te.reduce_axis((0, kernel_size), name="ry")
+    rx = te.reduce_axis((0, kernel_size), name="rx")
 
     if is_multiply:
         corr_func = lambda x, y: x * y
@@ -108,9 +109,14 @@ def correlation_nchw(data1, data2, kernel_size, max_displacement, stride1, strid
         # location in data2
         y2 = y1 + (te.indexdiv(q, displacement_size) - displacement_radius) * stride2
         x2 = x1 + (te.indexmod(q, displacement_size) - displacement_radius) * stride2
-        return te.sum(corr_func(padded_data1[n, rc, y1 + ry, x1 + rx],
-                                padded_data2[n, rc, y2 + ry, x2 + rx]), axis=[rc, ry, rx])
-
-    correlation = te.compute((batch, out_channel, out_height, out_width), lambda n, q, i, j:
-                             _compute_correlation(n, q, i, j), tag="correlation_nchw")
+        return te.sum(
+            corr_func(padded_data1[n, rc, y1 + ry, x1 + rx], padded_data2[n, rc, y2 + ry, x2 + rx]),
+            axis=[rc, ry, rx],
+        )
+
+    correlation = te.compute(
+        (batch, out_channel, out_height, out_width),
+        lambda n, q, i, j: _compute_correlation(n, q, i, j),
+        tag="correlation_nchw",
+    )
     return correlation / (kernel_size * kernel_size * channel)
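A minimal usage sketch for the operator above, assuming it is importable as tvm.topi.nn.correlation_nchw at this revision (the import path and the FlowNet-style parameter values below are assumptions):

from tvm import te, topi

data1 = te.placeholder((1, 3, 32, 32), name="data1")
data2 = te.placeholder((1, 3, 32, 32), name="data2")
out = topi.nn.correlation_nchw(
    data1, data2, kernel_size=1, max_displacement=4,
    stride1=1, stride2=1, padding=4, is_multiply=True,
)
print(out.shape)  # the number of output channels grows with the displacement window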
index 39be6d6..3d2b7ce 100644 (file)
@@ -23,8 +23,10 @@ from .util import get_pad_tuple
 from ..util import get_const_tuple
 from ..cpp.util import bilinear_sample_nchw
 
-def deformable_conv2d_nchw(data, offset, kernel, strides, padding, dilation, deformable_groups,
-                           groups, out_dtype):
+
+def deformable_conv2d_nchw(
+    data, offset, kernel, strides, padding, dilation, deformable_groups, groups, out_dtype
+):
     """Deformable conv2D operator in NCHW layout.
 
     The deformable convolution operation is described in https://arxiv.org/abs/1703.06211
@@ -84,11 +86,10 @@ def deformable_conv2d_nchw(data, offset, kernel, strides, padding, dilation, def
 
     dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
     dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
-    pad_top, pad_left, _, _ = get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
-    rc = te.reduce_axis((0, in_channel), name='rc')
-    ry = te.reduce_axis((0, kernel_h), name='ry')
-    rx = te.reduce_axis((0, kernel_w), name='rx')
+    pad_top, pad_left, _, _ = get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
+    rc = te.reduce_axis((0, in_channel), name="rc")
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
 
     zero = tvm.tir.const(0.0, data.dtype)
 
@@ -97,19 +98,35 @@ def deformable_conv2d_nchw(data, offset, kernel, strides, padding, dilation, def
         val = bilinear_sample_nchw(data, (n, c, h, w), in_height - 1, in_width - 1)
         return tvm.tir.if_then_else(outside, zero, val)
 
-    data_deform = \
-        te.compute((batch, in_channel, kernel_h, kernel_w, out_height, out_width),
-                   lambda n, c, kh, kw, y, x:
-                   _bilinear(n, c,
-                             y * stride_h - pad_top + kh * dilation_h +
-                             offset[n, c // ic_per_dgroup * (kernel_w*kernel_h*2) +
-                                    (kh * kernel_w + kw) * 2, y, x],
-                             x * stride_w - pad_left + kw * dilation_w +
-                             offset[n, c // ic_per_dgroup * (kernel_w*kernel_h*2) +
-                                    (kh * kernel_w + kw) * 2 + 1, y, x]), tag="data_deform")
+    data_deform = te.compute(
+        (batch, in_channel, kernel_h, kernel_w, out_height, out_width),
+        lambda n, c, kh, kw, y, x: _bilinear(
+            n,
+            c,
+            y * stride_h
+            - pad_top
+            + kh * dilation_h
+            + offset[
+                n, c // ic_per_dgroup * (kernel_w * kernel_h * 2) + (kh * kernel_w + kw) * 2, y, x
+            ],
+            x * stride_w
+            - pad_left
+            + kw * dilation_w
+            + offset[
+                n,
+                c // ic_per_dgroup * (kernel_w * kernel_h * 2) + (kh * kernel_w + kw) * 2 + 1,
+                y,
+                x,
+            ],
+        ),
+        tag="data_deform",
+    )
     return te.compute(
         (batch, out_channel, out_height, out_width),
         lambda n, f, y, x: te.sum(
-            data_deform[n, rc, ry, rx, y, x].astype(out_dtype) *
-            kernel[f, rc, ry, rx].astype(out_dtype),
-            axis=[rc, ry, rx]), tag="deformable_conv2d_nchw")
+            data_deform[n, rc, ry, rx, y, x].astype(out_dtype)
+            * kernel[f, rc, ry, rx].astype(out_dtype),
+            axis=[rc, ry, rx],
+        ),
+        tag="deformable_conv2d_nchw",
+    )
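The offset indexing above packs, for every deformable group, 2 * kernel_h * kernel_w offset channels ordered as (dy, dx) per kernel tap; a plain-Python check with made-up sizes:

kernel_h, kernel_w = 3, 3
ic_per_dgroup = 4  # in_channel // deformable_groups (assumed value)

def offset_channels(c, kh, kw):
    base = c // ic_per_dgroup * (kernel_w * kernel_h * 2)
    dy = base + (kh * kernel_w + kw) * 2
    return dy, dy + 1  # (y-offset channel, x-offset channel)

print(offset_channels(c=0, kh=0, kw=0))  # (0, 1)
print(offset_channels(c=0, kh=2, kw=2))  # (16, 17)
print(offset_channels(c=4, kh=0, kw=0))  # (18, 19) -> second deformable group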
index 7d7ef6c..0ce0f9e 100644 (file)
@@ -18,6 +18,7 @@
 from tvm import te
 from .. import tag
 
+
 def dense(data, weight, bias=None, out_dtype=None):
     """The default implementation of dense in topi.
 
@@ -40,21 +41,24 @@ def dense(data, weight, bias=None, out_dtype=None):
     output : tvm.te.Tensor
         2-D with shape [batch, out_dim]
     """
-    assert len(data.shape) == 2 and len(weight.shape) == 2, \
-        "only support 2-dim dense"
+    assert len(data.shape) == 2 and len(weight.shape) == 2, "only support 2-dim dense"
     if bias is not None:
         assert len(bias.shape) == 1
     if out_dtype is None:
         out_dtype = data.dtype
     batch, in_dim = data.shape
     out_dim, _ = weight.shape
-    k = te.reduce_axis((0, in_dim), name='k')
-    matmul = te.compute((batch, out_dim), \
-                        lambda i, j: te.sum(data[i, k].astype(out_dtype) * \
-                                            weight[j, k].astype(out_dtype), axis=k), \
-                        name='T_dense', tag='dense')
+    k = te.reduce_axis((0, in_dim), name="k")
+    matmul = te.compute(
+        (batch, out_dim),
+        lambda i, j: te.sum(data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k),
+        name="T_dense",
+        tag="dense",
+    )
     if bias is not None:
-        matmul = te.compute((batch, out_dim), \
-                            lambda i, j: matmul[i, j] + bias[j].astype(out_dtype), \
-                            tag=tag.BROADCAST)
+        matmul = te.compute(
+            (batch, out_dim),
+            lambda i, j: matmul[i, j] + bias[j].astype(out_dtype),
+            tag=tag.BROADCAST,
+        )
     return matmul
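Since the weight is stored as [out_dim, in_dim], the compute above is equivalent to a matmul against the transposed weight; a NumPy reference with made-up shapes:

import numpy as np

data = np.random.rand(4, 16).astype("float32")    # [batch, in_dim]
weight = np.random.rand(8, 16).astype("float32")  # [out_dim, in_dim]
bias = np.random.rand(8).astype("float32")

ref = data @ weight.T + bias  # [batch, out_dim]
print(ref.shape)  # (4, 8)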
index a9fbfea..bf970a4 100644 (file)
@@ -22,7 +22,7 @@ from tvm import te
 from .. import tag
 
 
-def depth_to_space(data, block_size, layout='NCHW', mode='DCR'):
+def depth_to_space(data, block_size, layout="NCHW", mode="DCR"):
     """Perform depth to space transformation on the data
 
     Parameters
@@ -46,23 +46,21 @@ def depth_to_space(data, block_size, layout='NCHW', mode='DCR'):
     output : tvm.te.Tensor
         Output of shape [N, C / block_size**2, H * block_size, W * block_size]
     """
-    if layout == 'NCHW':
+    if layout == "NCHW":
         in_n, in_c, in_h, in_w = data.shape
         channel_factor = tvm.tir.truncdiv(in_c, (block_size * block_size))
-        output_shape = [in_n, channel_factor,
-                        in_h * block_size, in_w * block_size]
-    elif layout == 'NHWC':
+        output_shape = [in_n, channel_factor, in_h * block_size, in_w * block_size]
+    elif layout == "NHWC":
         in_n, in_h, in_w, in_c = data.shape
         channel_factor = tvm.tir.truncdiv(in_c, (block_size * block_size))
-        output_shape = [in_n, in_h * block_size,
-                        in_w * block_size, channel_factor]
+        output_shape = [in_n, in_h * block_size, in_w * block_size, channel_factor]
     else:
         raise ValueError("Only NCHW and NHWC layouts are currently supported.")
 
     def _get_indices(*indices):
-        if layout == 'NCHW':
+        if layout == "NCHW":
             n, c, y, x = indices
-        elif layout == 'NHWC':
+        elif layout == "NHWC":
             n, y, x, c = indices
         return n, c, y, x
 
@@ -76,7 +74,7 @@ def depth_to_space(data, block_size, layout='NCHW', mode='DCR'):
         else:
             channel_idx = (c * block_size * block_size) + ((block_size * idx_y) + idx_x)
 
-        if layout == 'NCHW':
+        if layout == "NCHW":
             output = data(n, channel_idx, block_y, block_x)
         else:
             output = data(n, block_y, block_x, channel_idx)
@@ -86,4 +84,4 @@ def depth_to_space(data, block_size, layout='NCHW', mode='DCR'):
         n, c, y, x = _get_indices(*indices)
         return _get_pixel(n, c, y, x)
 
-    return te.compute(output_shape, _compute, name='depth_to_space', tag=tag.INJECTIVE)
+    return te.compute(output_shape, _compute, name="depth_to_space", tag=tag.INJECTIVE)
index 32a9258..c863a15 100644 (file)
@@ -27,9 +27,24 @@ from .util import get_pad_tuple
 from ..util import simplify
 
 # workload description of depthwise-conv2d
-Workload = namedtuple('Workload',
-                      ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'out_filter',
-                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
+Workload = namedtuple(
+    "Workload",
+    [
+        "in_dtype",
+        "out_dtype",
+        "height",
+        "width",
+        "in_filter",
+        "out_filter",
+        "hkernel",
+        "wkernel",
+        "hpad",
+        "wpad",
+        "hstride",
+        "wstride",
+    ],
+)
+
 
 def _get_workload(data, kernel, stride, padding, out_dtype):
     """ Get the workload structure. """
@@ -41,11 +56,26 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
         HSTR, WSTR = stride
     else:
         HSTR, WSTR = stride, stride
-    assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
-        "Do not support inputs with different data types now. ' \
-        '{} vs. {}".format(data.dtype, kernel.dtype)
-    return Workload(data.dtype, out_dtype, height, width, in_channel,
-                    out_channel, kh, kw, HPAD, WPAD, HSTR, WSTR)
+    assert (data.dtype == kernel.dtype) or (
+        data.dtype == "uint8" and kernel.dtype == "int8"
+    ), "Do not support inputs with different data types now. ' \
+        '{} vs. {}".format(
+        data.dtype, kernel.dtype
+    )
+    return Workload(
+        data.dtype,
+        out_dtype,
+        height,
+        width,
+        in_channel,
+        out_channel,
+        kh,
+        kw,
+        HPAD,
+        WPAD,
+        HSTR,
+        WSTR,
+    )
 
 
 def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
@@ -95,7 +125,8 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No
     dilated_kernel_h = (filter_height - 1) * dilation_h + 1
     dilated_kernel_w = (filter_width - 1) * dilation_w + 1
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
     out_channel = simplify(in_channel * channel_multiplier)
     out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
     out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
@@ -107,17 +138,27 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No
     # depthconv stage
     idxdiv = tvm.tir.indexdiv
     idxmod = tvm.tir.indexmod
-    di = te.reduce_axis((0, filter_height), name='di')
-    dj = te.reduce_axis((0, filter_width), name='dj')
+    di = te.reduce_axis((0, filter_height), name="di")
+    dj = te.reduce_axis((0, filter_width), name="dj")
     Output = te.compute(
         (batch, out_channel, out_height, out_width),
         lambda b, c, i, j: te.sum(
-            (PaddedInput[b, idxdiv(c, channel_multiplier), i*stride_h+di*dilation_h,
-                         j*stride_w+dj*dilation_w].astype(out_dtype) *
-             Filter[idxdiv(c, channel_multiplier),
-                    idxmod(c, channel_multiplier), di, dj].astype(out_dtype)),
-            axis=[di, dj]),
-        name='DepthwiseConv2d', tag="depthwise_conv2d_nchw")
+            (
+                PaddedInput[
+                    b,
+                    idxdiv(c, channel_multiplier),
+                    i * stride_h + di * dilation_h,
+                    j * stride_w + dj * dilation_w,
+                ].astype(out_dtype)
+                * Filter[
+                    idxdiv(c, channel_multiplier), idxmod(c, channel_multiplier), di, dj
+                ].astype(out_dtype)
+            ),
+            axis=[di, dj],
+        ),
+        name="DepthwiseConv2d",
+        tag="depthwise_conv2d_nchw",
+    )
     return Output
 
 
@@ -168,7 +209,8 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No
     dilated_kernel_h = (filter_height - 1) * dilation_h + 1
     dilated_kernel_w = (filter_width - 1) * dilation_w + 1
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
     out_channel = simplify(in_channel * channel_multiplier)
     out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
     out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
@@ -181,20 +223,30 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No
     idxdiv = tvm.tir.indexdiv
     idxmod = tvm.tir.indexmod
 
-    di = te.reduce_axis((0, filter_height), name='di')
-    dj = te.reduce_axis((0, filter_width), name='dj')
+    di = te.reduce_axis((0, filter_height), name="di")
+    dj = te.reduce_axis((0, filter_width), name="dj")
     Output = te.compute(
         (batch, out_height, out_width, out_channel),
         lambda b, i, j, c: te.sum(
-            (PaddedInput[b, i*stride_h + di*dilation_h, j*stride_w + dj*dilation_w,
-                         idxdiv(c, channel_multiplier)].astype(out_dtype) *
-             Filter[di, dj,
+            (
+                PaddedInput[
+                    b,
+                    i * stride_h + di * dilation_h,
+                    j * stride_w + dj * dilation_w,
                     idxdiv(c, channel_multiplier),
-                    idxmod(c, channel_multiplier)].astype(out_dtype)),
-            axis=[di, dj]),
-        name='DepthwiseConv2d', tag="depthwise_conv2d_nhwc")
+                ].astype(out_dtype)
+                * Filter[
+                    di, dj, idxdiv(c, channel_multiplier), idxmod(c, channel_multiplier)
+                ].astype(out_dtype)
+            ),
+            axis=[di, dj],
+        ),
+        name="DepthwiseConv2d",
+        tag="depthwise_conv2d_nhwc",
+    )
     return Output
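A worked example of the output-shape arithmetic shared by the NCHW and NHWC kernels above (plain Python, values made up):

in_height, filter_height = 32, 3
dilation_h, stride_h = 1, 1
pad_top, pad_down = 1, 1
in_channel, channel_multiplier = 16, 1

dilated_kernel_h = (filter_height - 1) * dilation_h + 1                           # 3
out_channel = in_channel * channel_multiplier                                     # 16
out_height = (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1  # 32
print(out_channel, out_height)  # 16 32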
 
+
 def depthwise_conv2d_backward_input_nhwc(Filter, Out_grad, oshape, ishape, stride, padding):
     """Depthwise convolution nhwc backward wrt input operator.
 
@@ -225,7 +277,7 @@ def depthwise_conv2d_backward_input_nhwc(Filter, Out_grad, oshape, ishape, strid
     else:
         stride_h, stride_w = stride
 
-    dilated_out_grad = dilate(Out_grad, [1, stride_h, stride_w, 1], name='dilated_out_grad')
+    dilated_out_grad = dilate(Out_grad, [1, stride_h, stride_w, 1], name="dilated_out_grad")
 
     # padding params in forward propagation
     fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w))
@@ -235,20 +287,26 @@ def depthwise_conv2d_backward_input_nhwc(Filter, Out_grad, oshape, ishape, strid
     bpad_left = filter_w - 1 - fpad_left
     bpad_right = (filter_w - 1 - fpad_right) + (stride_w - 1)
 
-    padded_out_grad = pad(dilated_out_grad, \
-                          [0, bpad_top, bpad_left, 0], \
-                          [0, bpad_bottom, bpad_right, 0], \
-                          name='padded_out_grad')
+    padded_out_grad = pad(
+        dilated_out_grad,
+        [0, bpad_top, bpad_left, 0],
+        [0, bpad_bottom, bpad_right, 0],
+        name="padded_out_grad",
+    )
 
-    dh = te.reduce_axis((0, filter_h), name='dh')
-    dw = te.reduce_axis((0, filter_w), name='dw')
-    dc = te.reduce_axis((0, channel_multiplier), name='dc')
+    dh = te.reduce_axis((0, filter_h), name="dh")
+    dw = te.reduce_axis((0, filter_w), name="dw")
+    dc = te.reduce_axis((0, channel_multiplier), name="dc")
 
     In_grad = te.compute(
         (batch, in_h, in_w, in_c),
-        lambda b, h, w, c: te.sum(padded_out_grad[b, h+dh, w+dw, c*channel_multiplier + dc] * \
-                                  Filter[filter_h-1-dh, filter_w-1-dw, c, dc],
-                                  axis=[dh, dw, dc]), tag='depthwise_conv2d_backward_input_nhwc')
+        lambda b, h, w, c: te.sum(
+            padded_out_grad[b, h + dh, w + dw, c * channel_multiplier + dc]
+            * Filter[filter_h - 1 - dh, filter_w - 1 - dw, c, dc],
+            axis=[dh, dw, dc],
+        ),
+        tag="depthwise_conv2d_backward_input_nhwc",
+    )
 
     return In_grad
 
@@ -285,29 +343,32 @@ def depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape, strid
 
     pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (filter_h, filter_w))
 
-    padded_in = pad(Input, \
-                    [0, pad_top, pad_left, 0], \
-                    [0, pad_bottom, pad_right, 0], \
-                    name='padded_in')
+    padded_in = pad(
+        Input, [0, pad_top, pad_left, 0], [0, pad_bottom, pad_right, 0], name="padded_in"
+    )
 
-    dh = te.reduce_axis((0, Out_grad.shape[1].value), name='dh')
-    dw = te.reduce_axis((0, Out_grad.shape[2].value), name='dw')
-    db = te.reduce_axis((0, batch), name='db')
+    dh = te.reduce_axis((0, Out_grad.shape[1].value), name="dh")
+    dw = te.reduce_axis((0, Out_grad.shape[2].value), name="dw")
+    db = te.reduce_axis((0, batch), name="db")
     idxdiv = tvm.tir.indexdiv
     idxmod = tvm.tir.indexmod
 
     Weight_grad = te.compute(
         (filter_h, filter_w, in_c, channel_multiplier),
         lambda fh, fw, c, m: te.sum(
-            Out_grad[db, dh, dw, c*channel_multiplier+idxmod(m, channel_multiplier)] *
-            padded_in[db, fh+dh*stride_h, fw+dw*stride_w, c], axis=[db, dh, dw]),
-        tag='depthwise_conv2d_backward_weight_nhwc')
+            Out_grad[db, dh, dw, c * channel_multiplier + idxmod(m, channel_multiplier)]
+            * padded_in[db, fh + dh * stride_h, fw + dw * stride_w, c],
+            axis=[db, dh, dw],
+        ),
+        tag="depthwise_conv2d_backward_weight_nhwc",
+    )
 
     return Weight_grad
 
 
-def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, dilation,
-                           layout, out_layout, out_dtype=None):
+def depthwise_conv2d_NCHWc(
+    Input, Filter, stride, padding, dilation, layout, out_layout, out_dtype=None
+):
     """Depthwise convolution NCHW[x]c forward operator.
 
     Parameters
@@ -345,6 +406,7 @@ def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, dilation,
     """
     raise ValueError("missing register for topi.nn.depthwise_conv2d_NCHWc")
 
+
 @tvm.target.generic_func
 def depthwise_conv2d_infer_layout(workload, cfg):
     """Infer input/output shapes and layouts from a workload and cfg.
index ebcf478..836e29a 100644 (file)
@@ -21,7 +21,8 @@ from tvm import te
 from .. import util
 from .. import tag
 
-@te.tag_scope(tag=tag.INJECTIVE+",dilate")
+
+@te.tag_scope(tag=tag.INJECTIVE + ",dilate")
 def dilate(data, strides, name="DilatedInput"):
     """Dilate data with zeros.
 
@@ -43,11 +44,9 @@ def dilate(data, strides, name="DilatedInput"):
     """
     n = len(data.shape)
     if len(strides) != n:
-        raise ValueError("data dimension and strides size dismatch : %d vs %d" % (
-            n, len(strides)))
+        raise ValueError("data dimension and strides size dismatch : %d vs %d" % (n, len(strides)))
     ana = tvm.arith.Analyzer()
-    out_shape = tuple(
-        ana.simplify((data.shape[i] - 1) * strides[i] + 1) for i in range(n))
+    out_shape = tuple(ana.simplify((data.shape[i] - 1) * strides[i] + 1) for i in range(n))
 
     def _dilate(*indices):
         not_zero = []
@@ -63,7 +62,8 @@ def dilate(data, strides, name="DilatedInput"):
         if not_zero:
             not_zero = tvm.tir.all(*not_zero)
             return tvm.tir.if_then_else(
-                not_zero, data(*index_tuple), tvm.tir.const(0.0, data.dtype))
+                not_zero, data(*index_tuple), tvm.tir.const(0.0, data.dtype)
+            )
         return data(*index_tuple)
 
     return te.compute(out_shape, _dilate, name=name)
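A NumPy reference for 1-D dilation with stride 2: the output length is (d - 1) * stride + 1, with the original elements on the stride grid and zeros elsewhere:

import numpy as np

data = np.array([1.0, 2.0, 3.0, 4.0])
stride = 2
out = np.zeros((len(data) - 1) * stride + 1, dtype=data.dtype)
out[::stride] = data
print(out)  # [1. 0. 2. 0. 3. 0. 4.]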
index e851c64..03fffc7 100644 (file)
@@ -21,6 +21,7 @@ from tvm import te
 from .. import tag
 from ..util import get_const_int
 
+
 @tvm.te.tag_scope(tag=tag.ELEMWISE)
 def relu(x):
     """Take relu of input x.
@@ -55,12 +56,15 @@ def leaky_relu(x, alpha):
     y : tvm.te.Tensor
         The result.
     """
+
     def _compute(*indices):
         value = x(*indices)
         calpha = tvm.tir.const(alpha, value.dtype)
         return tvm.tir.Select(value > 0, value, value * calpha)
+
     return te.compute(x.shape, _compute)
 
+
 @tvm.te.tag_scope(tag=tag.BROADCAST)
 def prelu(x, slope, axis=1):
     """PReLU.
@@ -97,4 +101,5 @@ def prelu(x, slope, axis=1):
     def _compute_channelwise(*indices):
         xval = x(*indices)
         return tvm.tir.Select(xval > 0, xval, xval * slope(indices[axis]))
+
     return te.compute(x.shape, _compute_channelwise)
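A NumPy reference for prelu with axis=1 on NCHW data (the slope is broadcast per channel); shapes below are made up:

import numpy as np

x = np.random.randn(1, 3, 4, 4).astype("float32")
slope = np.array([0.1, 0.2, 0.3], dtype="float32")
ref = np.where(x > 0, x, x * slope.reshape(1, -1, 1, 1))
print(ref.shape)  # (1, 3, 4, 4)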
index de283e0..0f12d9f 100644 (file)
@@ -22,7 +22,8 @@ from tvm import te
 from .. import tag
 from ..transform import concatenate, strided_slice
 
-@tvm.te.tag_scope(tag=tag.INJECTIVE+",fifo_buffer")
+
+@tvm.te.tag_scope(tag=tag.INJECTIVE + ",fifo_buffer")
 def fifo_buffer(data, buffer, axis):
     """
     FIFO buffer to enable computation reuse in CNNs with sliding window input
@@ -55,11 +56,12 @@ def fifo_buffer(data, buffer, axis):
     result : tvm.te.Tensor
         Updated value for the buffer
     """
-    assert len(data.shape) == len(buffer.shape), \
-        'buffer and data must have same number of dimensions, ' + \
-        'buffer.shape = {}, data.shape = {}'.format(buffer.shape, data.shape)
-    assert len(buffer.shape) >= 1, 'Zero-dimension tensor not supported'
-    assert 0 <= axis < len(buffer.shape), 'buffer axis out of range'
+    assert len(data.shape) == len(buffer.shape), (
+        "buffer and data must have same number of dimensions, "
+        + "buffer.shape = {}, data.shape = {}".format(buffer.shape, data.shape)
+    )
+    assert len(buffer.shape) >= 1, "Zero-dimension tensor not supported"
+    assert 0 <= axis < len(buffer.shape), "buffer axis out of range"
     for i in range(len(data.shape)):
         if i == axis:
             assert int(str(data.shape[i])) <= int(str(buffer.shape[i]))
@@ -71,81 +73,109 @@ def fifo_buffer(data, buffer, axis):
 
     # Explicitly write out formula up to 4D, and then use concat+slice combo for 5D and higher
     if len(buffer.shape) == 1:
-        return te.compute(buffer.shape,
-                          lambda i:
-                          tvm.tir.if_then_else(i < buflen - data_size,
-                                               buffer[i + data_size],
-                                               data[i - buflen + data_size]),
-                          name='new_buffer')
+        return te.compute(
+            buffer.shape,
+            lambda i: tvm.tir.if_then_else(
+                i < buflen - data_size, buffer[i + data_size], data[i - buflen + data_size]
+            ),
+            name="new_buffer",
+        )
     if len(buffer.shape) == 2:
         if axis == 0:
-            return te.compute(buffer.shape,
-                              lambda i, j:
-                              tvm.tir.if_then_else(i < buflen - data_size,
-                                                   buffer[i + data_size, j],
-                                                   data[i - buflen + data_size, j]),
-                              name='new_buffer')
+            return te.compute(
+                buffer.shape,
+                lambda i, j: tvm.tir.if_then_else(
+                    i < buflen - data_size,
+                    buffer[i + data_size, j],
+                    data[i - buflen + data_size, j],
+                ),
+                name="new_buffer",
+            )
         if axis == 1:
-            return te.compute(buffer.shape,
-                              lambda i, j:
-                              tvm.tir.if_then_else(j < buflen - data_size,
-                                                   buffer[i, j + data_size],
-                                                   data[i, j - buflen + data_size]),
-                              name='new_buffer')
-        assert False, 'Invalid value for axis; it should be at most {}'.format(len(buffer.shape))
+            return te.compute(
+                buffer.shape,
+                lambda i, j: tvm.tir.if_then_else(
+                    j < buflen - data_size,
+                    buffer[i, j + data_size],
+                    data[i, j - buflen + data_size],
+                ),
+                name="new_buffer",
+            )
+        assert False, "Invalid value for axis; it should be at most {}".format(len(buffer.shape))
     elif len(buffer.shape) == 3:
         if axis == 0:
-            return te.compute(buffer.shape,
-                              lambda i, j, k:
-                              tvm.tir.if_then_else(i < buflen - data_size,
-                                                   buffer[i + data_size, j, k],
-                                                   data[i - buflen + data_size, j, k]),
-                              name='new_buffer')
+            return te.compute(
+                buffer.shape,
+                lambda i, j, k: tvm.tir.if_then_else(
+                    i < buflen - data_size,
+                    buffer[i + data_size, j, k],
+                    data[i - buflen + data_size, j, k],
+                ),
+                name="new_buffer",
+            )
         if axis == 1:
-            return te.compute(buffer.shape,
-                              lambda i, j, k:
-                              tvm.tir.if_then_else(j < buflen - data_size,
-                                                   buffer[i, j + data_size, k],
-                                                   data[i, j - buflen + data_size, k]),
-                              name='new_buffer')
+            return te.compute(
+                buffer.shape,
+                lambda i, j, k: tvm.tir.if_then_else(
+                    j < buflen - data_size,
+                    buffer[i, j + data_size, k],
+                    data[i, j - buflen + data_size, k],
+                ),
+                name="new_buffer",
+            )
         if axis == 2:
-            return te.compute(buffer.shape,
-                              lambda i, j, k:
-                              tvm.tir.if_then_else(k < buflen - data_size,
-                                                   buffer[i, j, k + data_size],
-                                                   data[i, j, k - buflen + data_size]),
-                              name='new_buffer')
-        assert False, 'Invalid value for axis; it should be at most {}'.format(len(buffer.shape))
+            return te.compute(
+                buffer.shape,
+                lambda i, j, k: tvm.tir.if_then_else(
+                    k < buflen - data_size,
+                    buffer[i, j, k + data_size],
+                    data[i, j, k - buflen + data_size],
+                ),
+                name="new_buffer",
+            )
+        assert False, "Invalid value for axis; it should be at most {}".format(len(buffer.shape))
     elif len(buffer.shape) == 4:
         if axis == 0:
-            return te.compute(buffer.shape,
-                              lambda i, j, k, l:
-                              tvm.tir.if_then_else(i < buflen - data_size,
-                                                   buffer[i + data_size, j, k, l],
-                                                   data[i - buflen + data_size, j, k, l]),
-                              name='new_buffer')
+            return te.compute(
+                buffer.shape,
+                lambda i, j, k, l: tvm.tir.if_then_else(
+                    i < buflen - data_size,
+                    buffer[i + data_size, j, k, l],
+                    data[i - buflen + data_size, j, k, l],
+                ),
+                name="new_buffer",
+            )
         if axis == 1:
-            return te.compute(buffer.shape,
-                              lambda i, j, k, l:
-                              tvm.tir.if_then_else(j < buflen - data_size,
-                                                   buffer[i, j + data_size, k, l],
-                                                   data[i, j - buflen + data_size, k, l]),
-                              name='new_buffer')
+            return te.compute(
+                buffer.shape,
+                lambda i, j, k, l: tvm.tir.if_then_else(
+                    j < buflen - data_size,
+                    buffer[i, j + data_size, k, l],
+                    data[i, j - buflen + data_size, k, l],
+                ),
+                name="new_buffer",
+            )
         if axis == 2:
-            return te.compute(buffer.shape,
-                              lambda i, j, k, l:
-                              tvm.tir.if_then_else(k < buflen - data_size,
-                                                   buffer[i, j, k + data_size, l],
-                                                   data[i, j, k - buflen + data_size, l]),
-                              name='new_buffer')
+            return te.compute(
+                buffer.shape,
+                lambda i, j, k, l: tvm.tir.if_then_else(
+                    k < buflen - data_size,
+                    buffer[i, j, k + data_size, l],
+                    data[i, j, k - buflen + data_size, l],
+                ),
+                name="new_buffer",
+            )
         if axis == 3:
-            return te.compute(buffer.shape,
-                              lambda i, j, k, l:
-                              tvm.tir.if_then_else(l < buflen - data_size,
-                                                   buffer[i, j, k, l + data_size],
-                                                   data[i, j, k, l - buflen + data_size]),
-                              name='new_buffer')
-        assert False, 'Invalid value for axis; it should be at most {}'.format(len(buffer.shape))
+            return te.compute(
+                buffer.shape,
+                lambda i, j, k, l: tvm.tir.if_then_else(
+                    l < buflen - data_size,
+                    buffer[i, j, k, l + data_size],
+                    data[i, j, k, l - buflen + data_size],
+                ),
+                name="new_buffer",
+            )
+        assert False, "Invalid value for axis; it should be at most {}".format(len(buffer.shape))
     else:
         # Implement FIFO buffer as combination of concat and slice
         begin = [0] * len(buffer.shape)
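Semantically, every branch above (and the concat+slice fallback) performs the same update; a NumPy reference for a 2-D buffer with axis=1 and made-up values:

import numpy as np

buffer = np.arange(6).reshape(1, 6)  # buflen = 6 along axis 1
data = np.array([[100, 101]])        # data_size = 2

new_buffer = np.concatenate([buffer, data], axis=1)[:, data.shape[1]:]
print(new_buffer)  # [[  2   3   4   5 100 101]]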
index 11fe0d8..4d8e92f 100644 (file)
@@ -20,6 +20,7 @@ import tvm
 from tvm import te
 from .. import tag
 
+
 @tvm.te.tag_scope(tag=tag.INJECTIVE)
 def flatten(data):
     """Flattens the input array into a 2-D array by collapsing the higher dimensions.
index 35c76d2..06a6f65 100644 (file)
@@ -19,6 +19,7 @@
 from __future__ import absolute_import
 from .. import cpp
 
+
 def lrn(data, size, axis=1, alpha=0.0001, beta=0.75, bias=2):
     """Perform the across channels local response normalisation
     on the input data.
index 12558a8..c048fc8 100644 (file)
@@ -21,6 +21,7 @@ import tvm
 from tvm import te
 from .. import tag
 
+
 @tvm.te.tag_scope(tag=tag.BROADCAST)
 def scale_shift_nchw(Input, Scale, Shift):
     """Batch normalization operator in inference.
@@ -41,7 +42,9 @@ def scale_shift_nchw(Input, Scale, Shift):
     Output : tvm.te.Tensor
         Output tensor, layout is NCHW
     """
-    return te.compute(Input.shape, lambda b, c, i, j: Input[b, c, i, j] * Scale[c] + Shift[c], name='ScaleShift')
+    return te.compute(
+        Input.shape, lambda b, c, i, j: Input[b, c, i, j] * Scale[c] + Shift[c], name="ScaleShift"
+    )
 
 
 @tvm.te.tag_scope(tag=tag.BROADCAST)
@@ -64,4 +67,6 @@ def scale_shift_nhwc(Input, Scale, Shift):
     Output : tvm.te.Tensor
         Output tensor, layout is NHWC
     """
-    return te.compute(Input.shape, lambda b, i, j, c: Input[b, i, j, c] * Scale[c] + Shift[c], name='ScaleShift')
+    return te.compute(
+        Input.shape, lambda b, i, j, c: Input[b, i, j, c] * Scale[c] + Shift[c], name="ScaleShift"
+    )
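A NumPy reference for scale_shift_nchw: the per-channel scale and shift broadcast over N, H and W (made-up shapes):

import numpy as np

x = np.random.rand(2, 3, 8, 8).astype("float32")
scale = np.random.rand(3).astype("float32")
shift = np.random.rand(3).astype("float32")
ref = x * scale.reshape(1, -1, 1, 1) + shift.reshape(1, -1, 1, 1)
print(ref.shape)  # (2, 3, 8, 8)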
index b298a0a..2998d1d 100644 (file)
@@ -21,7 +21,8 @@ from tvm import te
 from ..util import equal_const_int
 from .. import tag
 
-@tvm.te.tag_scope(tag=tag.INJECTIVE+",pad")
+
+@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
 def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
     """Pad Input with zeros.
 
@@ -50,16 +51,19 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
     n = len(data.shape)
     pad_after = pad_after if pad_after else pad_before
     if len(pad_before) != n:
-        raise ValueError("Input dimension and pad_before dismatch : %d vs %d" % (
-            n, len(pad_before)))
+        raise ValueError(
+            "Input dimension and pad_before dismatch : %d vs %d" % (n, len(pad_before))
+        )
     if len(pad_after) != n:
-        raise ValueError("Input dimension and pad_after dismatch : %d vs %d" % (
-            n, len(pad_before)))
+        raise ValueError("Input dimension and pad_after dismatch : %d vs %d" % (n, len(pad_before)))
     ana = tvm.arith.Analyzer()
-    out_shape = tuple(
-        ana.simplify(data.shape[i] + pad_before[i] + pad_after[i]) for i in range(n))
-    pad_value = (pad_value if isinstance(pad_value, tvm.tir.PrimExpr)
-                 else tvm.tir.const(pad_value, data.dtype))
+    out_shape = tuple(ana.simplify(data.shape[i] + pad_before[i] + pad_after[i]) for i in range(n))
+    pad_value = (
+        pad_value
+        if isinstance(pad_value, tvm.tir.PrimExpr)
+        else tvm.tir.const(pad_value, data.dtype)
+    )
+
     def _pad(*indices):
         not_zero = []
         index_tuple = []
@@ -74,15 +78,12 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
             not_zero = tvm.tir.all(*not_zero)
             return tvm.tir.if_then_else(not_zero, data(*index_tuple), pad_value)
         return data(*index_tuple)
+
     return te.compute(out_shape, _pad, name=name)
 
 
 @tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
-def mirror_pad(data,
-               pad_before,
-               pad_after=None,
-               mode='SYMMETRIC',
-               name="MirrorPadInput"):
+def mirror_pad(data, pad_before, pad_after=None, mode="SYMMETRIC", name="MirrorPadInput"):
     """Pad Input with mirroring either symmetric or reflected.
 
     Parameters
@@ -110,25 +111,22 @@ def mirror_pad(data,
     n = len(data.shape)
     pad_after = pad_after if pad_after else pad_before
     if len(pad_before) != n:
-        raise ValueError("Input dimension and pad_before dismatch : %d vs %d" %
-                         (n, len(pad_before)))
+        raise ValueError(
+            "Input dimension and pad_before dismatch : %d vs %d" % (n, len(pad_before))
+        )
     if len(pad_after) != n:
-        raise ValueError("Input dimension and pad_after dismatch : %d vs %d" %
-                         (n, len(pad_before)))
+        raise ValueError("Input dimension and pad_after dismatch : %d vs %d" % (n, len(pad_before)))
     ana = tvm.arith.Analyzer()
-    out_shape = tuple(
-        ana.simplify(data.shape[i] + pad_before[i] + pad_after[i])
-        for i in range(n))
-    assert mode in ('SYMMETRIC', 'REFLECT')
-    mode = int(mode == 'SYMMETRIC')
+    out_shape = tuple(ana.simplify(data.shape[i] + pad_before[i] + pad_after[i]) for i in range(n))
+    assert mode in ("SYMMETRIC", "REFLECT")
+    mode = int(mode == "SYMMETRIC")
 
     def _pad(*indices):
         index_tuple = []
         above = []
         below = []
         for i in range(n):
-            if equal_const_int(pad_before[i], 0) and equal_const_int(
-                    pad_after[i], 0):
+            if equal_const_int(pad_before[i], 0) and equal_const_int(pad_after[i], 0):
                 index_tuple.append(indices[i])
                 above.append(False)
                 below.append(False)
@@ -140,7 +138,8 @@ def mirror_pad(data,
         for i, axis in enumerate(index_tuple):
             mapped_axis = tvm.tir.if_then_else(below[i], -axis - mode, axis)
             mapped_axis = tvm.tir.if_then_else(
-                above[i], (2 * (data.shape[i] - 1)) - axis + mode, mapped_axis)
+                above[i], (2 * (data.shape[i] - 1)) - axis + mode, mapped_axis
+            )
             mapped_tuple.append(mapped_axis)
         return data(*mapped_tuple)
 
index 52317c2..8c4be5a 100644 (file)
 from __future__ import absolute_import
 from .. import cpp
 
-POOL_TYPE_CODE = {
-    "avg": 0,
-    "max": 1
-}
+POOL_TYPE_CODE = {"avg": 0, "max": 1}
+
 
 def global_pool(data, pool_type, layout="NCHW"):
     """Perform global pooling on height and width dimension of data.
@@ -58,14 +56,9 @@ def global_pool(data, pool_type, layout="NCHW"):
     return cpp.nn.global_pool(data, POOL_TYPE_CODE[pool_type], layout)
 
 
-def pool(data,
-         kernel,
-         stride,
-         padding,
-         pool_type,
-         ceil_mode=False,
-         layout="NCHW",
-         count_include_pad=True):
+def pool(
+    data, kernel, stride, padding, pool_type, ceil_mode=False, layout="NCHW", count_include_pad=True
+):
     """Perform pooling on height and width dimension of data.
        It decides the height and width dimensions according to the layout string,
        in which 'W' and 'H' mean width and height respectively.
@@ -111,18 +104,29 @@ def pool(data,
     output : tvm.te.Tensor
         n-D in the same layout
     """
-    return cpp.nn.pool(data, kernel, stride, padding,
-                       POOL_TYPE_CODE[pool_type], ceil_mode, layout, count_include_pad)
-
-def pool_grad(grads,
-              data,
-              kernel,
-              stride,
-              padding,
-              pool_type,
-              ceil_mode=False,
-              layout="NCHW",
-              count_include_pad=True):
+    return cpp.nn.pool(
+        data,
+        kernel,
+        stride,
+        padding,
+        POOL_TYPE_CODE[pool_type],
+        ceil_mode,
+        layout,
+        count_include_pad,
+    )
+
+
+def pool_grad(
+    grads,
+    data,
+    kernel,
+    stride,
+    padding,
+    pool_type,
+    ceil_mode=False,
+    layout="NCHW",
+    count_include_pad=True,
+):
     """Gradient of pooling on height and width dimension of data.
        It decides the height and width dimensions according to the layout string,
        in which 'W' and 'H' mean width and height respectively.
@@ -171,15 +175,20 @@ def pool_grad(grads,
     output : tvm.te.Tensor
         n-D in the same layout
     """
-    return cpp.nn.pool_grad(grads, data, kernel,
-                            stride, padding, POOL_TYPE_CODE[pool_type],
-                            ceil_mode, layout, count_include_pad)
-
-
-def adaptive_pool(data,
-                  output_size,
-                  pool_type,
-                  layout="NCHW"):
+    return cpp.nn.pool_grad(
+        grads,
+        data,
+        kernel,
+        stride,
+        padding,
+        POOL_TYPE_CODE[pool_type],
+        ceil_mode,
+        layout,
+        count_include_pad,
+    )
+
+
+def adaptive_pool(data, output_size, pool_type, layout="NCHW"):
     """Perform pooling on height and width dimension of data.
        The pooling kernel and stride sizes are automatically chosen for desired
        output sizes.
@@ -218,24 +227,16 @@ def adaptive_pool(data,
     return cpp.nn.adaptive_pool(data, output_size, POOL_TYPE_CODE[pool_type], layout)
 
 
-def adaptive_pool3d(data,
-                    output_size,
-                    pool_type,
-                    layout="NCDHW"):
+def adaptive_pool3d(data, output_size, pool_type, layout="NCDHW"):
     """Perform pooling on three dimensional data.
-       See the two dimensional version above for details.
+    See the two dimensional version above for details.
     """
     return cpp.nn.adaptive_pool3d(data, output_size, POOL_TYPE_CODE[pool_type], layout)
 
 
-def pool1d(data,
-           kernel,
-           stride,
-           padding,
-           pool_type,
-           ceil_mode=False,
-           layout="NCW",
-           count_include_pad=True):
+def pool1d(
+    data, kernel, stride, padding, pool_type, ceil_mode=False, layout="NCW", count_include_pad=True
+):
     """Perform pooling on width dimension of data.
        Width axis is determined according to the layout string,
        in which 'w' means width.
@@ -282,21 +283,35 @@ def pool1d(data,
         n-D in the same layout
     """
     if isinstance(kernel, int):
-        kernel = [kernel, ]
+        kernel = [
+            kernel,
+        ]
     if isinstance(stride, int):
-        stride = [stride, ]
-    return cpp.nn.pool1d(data, kernel, stride, padding,
-                         POOL_TYPE_CODE[pool_type], ceil_mode, layout, count_include_pad)
-
-
-def pool3d(data,
-           kernel,
-           stride,
-           padding,
-           pool_type,
-           ceil_mode=False,
-           layout="NCDHW",
-           count_include_pad=True):
+        stride = [
+            stride,
+        ]
+    return cpp.nn.pool1d(
+        data,
+        kernel,
+        stride,
+        padding,
+        POOL_TYPE_CODE[pool_type],
+        ceil_mode,
+        layout,
+        count_include_pad,
+    )
+
+
+def pool3d(
+    data,
+    kernel,
+    stride,
+    padding,
+    pool_type,
+    ceil_mode=False,
+    layout="NCDHW",
+    count_include_pad=True,
+):
     """Perform pooling on depth, height and width dimension of data.
        It decides the depth, height and width dimensions according to the layout string,
        in which 'D', 'W' and 'H' mean depth, width and height respectively.
@@ -342,5 +357,13 @@ def pool3d(data,
     output : tvm.te.Tensor
         n-D in the same layout
     """
-    return cpp.nn.pool3d(data, kernel, stride, padding,
-                         POOL_TYPE_CODE[pool_type], ceil_mode, layout, count_include_pad)
+    return cpp.nn.pool3d(
+        data,
+        kernel,
+        stride,
+        padding,
+        POOL_TYPE_CODE[pool_type],
+        ceil_mode,
+        layout,
+        count_include_pad,
+    )
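A minimal usage sketch for pool(), assuming the import path tvm.topi.nn and a 2-D padding order of [pad_top, pad_left, pad_bottom, pad_right] (both are assumptions, not taken from this patch):

from tvm import te, topi

data = te.placeholder((1, 16, 32, 32), name="data")
out = topi.nn.pool(
    data, kernel=[2, 2], stride=[2, 2], padding=[0, 0, 0, 0],
    pool_type="max", layout="NCHW",
)
print(out.shape)  # expected (1, 16, 16, 16)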
index fb51384..f6f20d7 100644 (file)
@@ -20,7 +20,8 @@ from __future__ import absolute_import
 import tvm
 from tvm import te
 
-@tvm.te.tag_scope(tag='softmax_output')
+
+@tvm.te.tag_scope(tag="softmax_output")
 def softmax(x, axis=-1):
     """Perform softmax activation on the data
 
@@ -43,8 +44,8 @@ def softmax(x, axis=-1):
     if axis >= len(shape):
         ValueError("axis parameter should be less than input dim")
 
-    k1 = te.reduce_axis((0, shape[axis]), name='k')
-    k2 = te.reduce_axis((0, shape[axis]), name='k')
+    k1 = te.reduce_axis((0, shape[axis]), name="k")
+    k2 = te.reduce_axis((0, shape[axis]), name="k")
 
     def insert_reduce_index(indices, reduce_index):
         return indices[:axis] + (reduce_index,) + indices[axis:]
@@ -69,15 +70,20 @@ def softmax(x, axis=-1):
         return exp[indices] / expsum[non_reduce_indices]
 
     reduced_shape = tuple([dim for (i, dim) in enumerate(shape) if i != axis])
-    max_elem = te.compute(reduced_shape, _compute_max, name='T_softmax_maxelem')
-    exp = te.compute(shape, lambda *indices: _compute_exp(max_elem, *indices),
-                     name='T_softmax_exp')
-    expsum = te.compute(reduced_shape, lambda *indices: _compute_expsum(exp, *indices),
-                        name='T_softmax_expsum')
-    return te.compute(shape, lambda *indices: _normalize(exp, expsum, *indices),
-                      name='T_softmax_norm', attrs={"axis" : axis})
-
-@tvm.te.tag_scope(tag='log_softmax_output')
+    max_elem = te.compute(reduced_shape, _compute_max, name="T_softmax_maxelem")
+    exp = te.compute(shape, lambda *indices: _compute_exp(max_elem, *indices), name="T_softmax_exp")
+    expsum = te.compute(
+        reduced_shape, lambda *indices: _compute_expsum(exp, *indices), name="T_softmax_expsum"
+    )
+    return te.compute(
+        shape,
+        lambda *indices: _normalize(exp, expsum, *indices),
+        name="T_softmax_norm",
+        attrs={"axis": axis},
+    )
+
+
+@tvm.te.tag_scope(tag="log_softmax_output")
 def log_softmax(x):
     """Perform log softmax activation on the data
 
@@ -94,10 +100,8 @@ def log_softmax(x):
 
     assert len(x.shape) == 2, "only support 2-dim log softmax"
     m, n = x.shape
-    k = te.reduce_axis((0, n), name='k')
-    max_elem = te.compute((m, ), lambda i: tvm.te.max(x[i, k], axis=k))
-    k = te.reduce_axis((0, n), name='k')
-    expsum = te.compute(
-        (m, ), lambda i: te.sum(te.exp(x[i, k] - max_elem[i]), axis=k))
-    return te.compute(
-        x.shape, lambda i, j: x[i, j] - max_elem[i] - te.log(expsum[i]))
+    k = te.reduce_axis((0, n), name="k")
+    max_elem = te.compute((m,), lambda i: tvm.te.max(x[i, k], axis=k))
+    k = te.reduce_axis((0, n), name="k")
+    expsum = te.compute((m,), lambda i: te.sum(te.exp(x[i, k] - max_elem[i]), axis=k))
+    return te.compute(x.shape, lambda i, j: x[i, j] - max_elem[i] - te.log(expsum[i]))
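A NumPy reference for the two computes above; both subtract the row maximum first for numerical stability:

import numpy as np

x = np.random.randn(4, 10).astype("float32")
m = x.max(axis=-1, keepdims=True)
e = np.exp(x - m)
softmax_ref = e / e.sum(axis=-1, keepdims=True)
log_softmax_ref = x - m - np.log(e.sum(axis=-1, keepdims=True))
print(softmax_ref.sum(axis=-1))  # ~1.0 per row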
index b90bd11..aedbc4a 100644 (file)
@@ -22,7 +22,7 @@ from tvm import te
 from .. import tag
 
 
-def space_to_depth(data, block_size, layout='NCHW'):
+def space_to_depth(data, block_size, layout="NCHW"):
     """Perform space to depth transformation on the data
 
     Parameters
@@ -42,21 +42,29 @@ def space_to_depth(data, block_size, layout='NCHW'):
         Output of shape [N, C * block_size**2, H / block_size, W / block_size]
     """
 
-    if layout == 'NCHW':
+    if layout == "NCHW":
         in_n, in_c, in_h, in_w = data.shape
-        output_shape = [in_n, in_c * block_size * block_size,
-                        tvm.tir.truncdiv(in_h, block_size), tvm.tir.truncdiv(in_w, block_size)]
-    elif layout == 'NHWC':
+        output_shape = [
+            in_n,
+            in_c * block_size * block_size,
+            tvm.tir.truncdiv(in_h, block_size),
+            tvm.tir.truncdiv(in_w, block_size),
+        ]
+    elif layout == "NHWC":
         in_n, in_h, in_w, in_c = data.shape
-        output_shape = [in_n, tvm.tir.truncdiv(in_h, block_size), tvm.tir.truncdiv(
-            in_w, block_size), in_c * block_size * block_size]
+        output_shape = [
+            in_n,
+            tvm.tir.truncdiv(in_h, block_size),
+            tvm.tir.truncdiv(in_w, block_size),
+            in_c * block_size * block_size,
+        ]
     else:
         raise ValueError("Only NCHW and NHWC layouts are currently supported.")
 
     def _get_indices(*indices):
-        if layout == 'NCHW':
+        if layout == "NCHW":
             n, c, y, x = indices
-        elif layout == 'NHWC':
+        elif layout == "NHWC":
             n, y, x, c = indices
         return n, c, y, x
 
@@ -66,16 +74,14 @@ def space_to_depth(data, block_size, layout='NCHW'):
         x_idx = tvm.tir.truncmod(block_offset, block_size)
         y_idx = tvm.tir.truncdiv(block_offset, block_size)
 
-        if layout == 'NCHW':
-            output = data(n, channel_idx, y_idx +
-                          (y * block_size), x_idx + (x * block_size))
+        if layout == "NCHW":
+            output = data(n, channel_idx, y_idx + (y * block_size), x_idx + (x * block_size))
         else:
-            output = data(n, y_idx + (y * block_size), x_idx +
-                          (x * block_size), channel_idx)
+            output = data(n, y_idx + (y * block_size), x_idx + (x * block_size), channel_idx)
         return output
 
     def _compute(*indices):
         n, c, y, x = _get_indices(*indices)
         return _get_pixel(n, c, y, x)
 
-    return te.compute(output_shape, _compute, name='space_to_depth', tag=tag.INJECTIVE)
+    return te.compute(output_shape, _compute, name="space_to_depth", tag=tag.INJECTIVE)
index b24121b..e3c144a 100644 (file)
@@ -59,9 +59,7 @@ def sparse_dense(data, weight_data, weight_indices, weight_indptr):
 
 
 def _sparse_dense_csrmm(data, weight_data, weight_indices, weight_indptr):
-    oshape = (
-        get_const_tuple(data.shape)[0],
-        get_const_tuple(weight_indptr.shape)[0] - 1)
+    oshape = (get_const_tuple(data.shape)[0], get_const_tuple(weight_indptr.shape)[0] - 1)
 
     def f(i, row):
         row_start = weight_indptr[row]
@@ -72,21 +70,21 @@ def _sparse_dense_csrmm(data, weight_data, weight_indices, weight_indptr):
         a_val = weight_data[elem]
         weight_val = data[i, weight_indices[elem]]
         return te.sum(a_val * weight_val, axis=elem_idx)
+
     return te.compute(oshape, f, tag="sparse_dense_csrmm")
 
 
 def _sparse_dense_bsrmm(data, weight_data, weight_indices, weight_indptr):
     (m, _) = get_const_tuple(data.shape)
     (_, bs_r, bs_c) = get_const_tuple(weight_data.shape)
-    (num_blocks_plus_1, ) = get_const_tuple(weight_indptr.shape)
+    (num_blocks_plus_1,) = get_const_tuple(weight_indptr.shape)
     num_blocks = num_blocks_plus_1 - 1
 
     def _compute_block(i, nb_j, j):
         row_start = weight_indptr[nb_j]
         row_end = weight_indptr[nb_j + 1]
         row_elems = row_end - row_start
-        elem_idx = te.reduce_axis(
-            (0, row_elems), name="elem_idx")
+        elem_idx = te.reduce_axis((0, row_elems), name="elem_idx")
         block_offset = row_start + elem_idx
         c = te.reduce_axis((0, bs_c), name="c")
         block_j = weight_indices[block_offset]
@@ -97,13 +95,12 @@ def _sparse_dense_bsrmm(data, weight_data, weight_indices, weight_indptr):
     idxd = tvm.tir.indexdiv
     idxm = tvm.tir.indexmod
 
-    bsrmm_block = te.compute(
-        (m, num_blocks, bs_r), _compute_block,
-        tag="sparse_dense_bsrmm_block")
+    bsrmm_block = te.compute((m, num_blocks, bs_r), _compute_block, tag="sparse_dense_bsrmm_block")
     return te.compute(
         (m, num_blocks * bs_r),
         lambda m, n: bsrmm_block[m, idxd(n, bs_r), idxm(n, bs_r)],
-        tag="sparse_dense_bsrmm")
+        tag="sparse_dense_bsrmm",
+    )
 
 
 def sparse_transpose(sparse_data, sparse_indices, sparse_indptr):
@@ -140,18 +137,20 @@ def sparse_transpose(sparse_data, sparse_indices, sparse_indptr):
 
     nnz = get_const_tuple(sparse_data.shape)[0]
     n = get_const_tuple(sparse_indptr.shape)[0] - 1
-    output_shape = [(nnz,), (nnz,), (n+1,)]
+    output_shape = [(nnz,), (nnz,), (n + 1,)]
 
     # TODO: Add BSR transpose support
 
     output_data, output_indices, output_indptr = te.extern(
         shape=output_shape,
         inputs=[sparse_data, sparse_indices, sparse_indptr],
-        fcompute=lambda ins, outs:
-        _csr_transpose_ir(ins[0], ins[1], ins[2], outs[0], outs[1], outs[2]),
+        fcompute=lambda ins, outs: _csr_transpose_ir(
+            ins[0], ins[1], ins[2], outs[0], outs[1], outs[2]
+        ),
         tag="sparse_transpose_csr",
-        dtype=['float32', 'int32', 'int32'],
-        name='out')
+        dtype=["float32", "int32", "int32"],
+        name="out",
+    )
 
     return [output_data, output_indices, output_indptr]
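A SciPy reference for the CSR path of sparse_dense: with the weight given as the (data, indices, indptr) triple of an [N, K] CSR matrix, the kernel computes x @ weight.T (made-up shapes below):

import numpy as np
import scipy.sparse as sp

x = np.random.rand(4, 16).astype("float32")                              # [M, K] dense
weight = sp.random(8, 16, density=0.25, format="csr", dtype="float32")   # [N, K] CSR
# weight.data, weight.indices, weight.indptr correspond to the three weight_* inputs above.
ref = weight.dot(x.T).T  # equals x @ weight.toarray().T, shape [M, N]
print(ref.shape)  # (4, 8)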
 
@@ -171,26 +170,26 @@ def _csr_transpose_ir(data, indices, indptr, out_data, out_indices, out_indptr):
     n = get_const_tuple(indptr.shape)[0] - 1
     nnz = get_const_tuple(data.shape)[0]
 
-    with irb.for_range(0, n, for_type="parallel", name='col') as col:
+    with irb.for_range(0, n, for_type="parallel", name="col") as col:
         out_indptr_ptr[col] = 0
 
-    with irb.for_range(0, nnz, for_type="serial", name='nz_idx') as nz_idx:
+    with irb.for_range(0, nnz, for_type="serial", name="nz_idx") as nz_idx:
         out_indptr_ptr[indices_ptr[nz_idx]] += 1
 
-    cumsum = irb.allocate('int32', (1,), name='cumsum', scope='local')
-    temp = irb.allocate('int32', (1,), name='temp', scope='local')
+    cumsum = irb.allocate("int32", (1,), name="cumsum", scope="local")
+    temp = irb.allocate("int32", (1,), name="temp", scope="local")
     cumsum[0] = 0
-    with irb.for_range(0, n, for_type="serial", name='col') as col:
+    with irb.for_range(0, n, for_type="serial", name="col") as col:
         temp[0] = out_indptr_ptr[col]
         out_indptr_ptr[col] = cumsum[0]
         cumsum[0] += temp[0]
 
     out_indptr_ptr[n] = nnz
 
-    with irb.for_range(0, n, for_type="serial", name='row') as row:
+    with irb.for_range(0, n, for_type="serial", name="row") as row:
         offset = indptr_ptr[row]
-        diff = indptr_ptr[row+1] - indptr_ptr[row]
-        with irb.for_range(0, diff, for_type="serial", name='idx') as idx:
+        diff = indptr_ptr[row + 1] - indptr_ptr[row]
+        with irb.for_range(0, diff, for_type="serial", name="idx") as idx:
             real_idx = offset + idx
             col = indices_ptr[real_idx]
             dest = out_indptr_ptr[col]
@@ -199,8 +198,8 @@ def _csr_transpose_ir(data, indices, indptr, out_data, out_indices, out_indptr):
             out_data_ptr[dest] = data_ptr[real_idx]
             out_indptr_ptr[col] += 1
 
-    last = irb.allocate('int32', (1,), name='last', scope='local')
-    temp2 = irb.allocate('int32', (1,), name='temp2', scope='local')
+    last = irb.allocate("int32", (1,), name="last", scope="local")
+    temp2 = irb.allocate("int32", (1,), name="temp2", scope="local")
     last[0] = 0
     with irb.for_range(0, n, for_type="serial", name="col") as col:
         temp2[0] = out_indptr_ptr[col]
index 6c07cf4..b390b80 100644 (file)
@@ -20,8 +20,15 @@ from tvm import te
 from ..util import simplify
 
 
-def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor',
-               align_corners=False, output_shape=None):
+def upsampling(
+    data,
+    scale_h,
+    scale_w,
+    layout="NCHW",
+    method="nearest_neighbor",
+    align_corners=False,
+    output_shape=None,
+):
     """Perform upsampling on the data.
        Nearest neighbor and bilinear upsampling are supported.
 
@@ -56,34 +63,55 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor',
     """
     base_layout = layout[0:4]
     if base_layout == "NCHW":
-        if not output_shape: #static case
+        if not output_shape:  # static case
             scaled_h = data.shape[2] * scale_h
             scaled_w = data.shape[3] * scale_w
-            reshape_size = (simplify(topi.cast(te.round(scaled_h), data.shape[2].dtype)),
-                            simplify(topi.cast(te.round(scaled_w), data.shape[3].dtype)))
-        else: #dynamic case -- we don't need to scale; already done in shape func
-            reshape_size = (simplify(topi.cast(te.round(output_shape[2]), output_shape[2].dtype)),
-                            simplify(topi.cast(te.round(output_shape[3]), output_shape[3].dtype)))
+            reshape_size = (
+                simplify(topi.cast(te.round(scaled_h), data.shape[2].dtype)),
+                simplify(topi.cast(te.round(scaled_w), data.shape[3].dtype)),
+            )
+        else:  # dynamic case -- we don't need to scale; already done in shape func
+            reshape_size = (
+                simplify(topi.cast(te.round(output_shape[2]), output_shape[2].dtype)),
+                simplify(topi.cast(te.round(output_shape[3]), output_shape[3].dtype)),
+            )
     elif layout == "NHWC":
-        if not output_shape: #static case
+        if not output_shape:  # static case
             scaled_h = data.shape[1] * scale_h
             scaled_w = data.shape[2] * scale_w
-            reshape_size = (simplify(topi.cast(te.round(scaled_h), data.shape[1].dtype)),
-                            simplify(topi.cast(te.round(scaled_w), data.shape[2].dtype)))
-        else: #dynamic case
-            reshape_size = (simplify(topi.cast(te.round(output_shape[1]), output_shape[1].dtype)),
-                            simplify(topi.cast(te.round(output_shape[2]), output_shape[2].dtype)))
+            reshape_size = (
+                simplify(topi.cast(te.round(scaled_h), data.shape[1].dtype)),
+                simplify(topi.cast(te.round(scaled_w), data.shape[2].dtype)),
+            )
+        else:  # dynamic case
+            reshape_size = (
+                simplify(topi.cast(te.round(output_shape[1]), output_shape[1].dtype)),
+                simplify(topi.cast(te.round(output_shape[2]), output_shape[2].dtype)),
+            )
 
     else:
         raise ValueError("not support this layout {} yet".format(layout))
     coord_trans = "align_corners" if align_corners else "asymmetric"
-    return topi.image.resize(data, reshape_size, layout=layout,
-                             method=method, coordinate_transformation_mode=coord_trans,
-                             output_shape=output_shape)
-
-
-def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='nearest_neighbor',
-                 coordinate_transformation_mode="half_pixel", output_shape=None):
+    return topi.image.resize(
+        data,
+        reshape_size,
+        layout=layout,
+        method=method,
+        coordinate_transformation_mode=coord_trans,
+        output_shape=output_shape,
+    )
+
+
+def upsampling3d(
+    data,
+    scale_d,
+    scale_h,
+    scale_w,
+    layout="NCDHW",
+    method="nearest_neighbor",
+    coordinate_transformation_mode="half_pixel",
+    output_shape=None,
+):
     """Perform upsampling on the data.
        Nearest neighbor and bilinear upsampling are supported.
 
@@ -127,30 +155,43 @@ def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='neares
     """
     base_layout = layout[0:5]
     if base_layout == "NCDHW":
-        if not output_shape: # static case
+        if not output_shape:  # static case
             scaled_d = data.shape[2] * scale_d
             scaled_h = data.shape[3] * scale_h
             scaled_w = data.shape[4] * scale_w
-            resize_shape = (simplify(topi.cast(te.round(scaled_d), data.shape[2].dtype)),
-                            simplify(topi.cast(te.round(scaled_h), data.shape[3].dtype)),
-                            simplify(topi.cast(te.round(scaled_w), data.shape[4].dtype)))
-        else: # dynamic case -- don't need to scale; already done in shape func
-            resize_shape = (simplify(topi.cast(te.round(output_shape[2]), data.shape[2].dtype)),
-                            simplify(topi.cast(te.round(output_shape[3]), data.shape[3].dtype)),
-                            simplify(topi.cast(te.round(output_shape[4]), data.shape[4].dtype)))
+            resize_shape = (
+                simplify(topi.cast(te.round(scaled_d), data.shape[2].dtype)),
+                simplify(topi.cast(te.round(scaled_h), data.shape[3].dtype)),
+                simplify(topi.cast(te.round(scaled_w), data.shape[4].dtype)),
+            )
+        else:  # dynamic case -- don't need to scale; already done in shape func
+            resize_shape = (
+                simplify(topi.cast(te.round(output_shape[2]), data.shape[2].dtype)),
+                simplify(topi.cast(te.round(output_shape[3]), data.shape[3].dtype)),
+                simplify(topi.cast(te.round(output_shape[4]), data.shape[4].dtype)),
+            )
     elif layout == "NDHWC":
-        if not output_shape: # static case
+        if not output_shape:  # static case
             scaled_d = data.shape[1] * scale_d
             scaled_h = data.shape[2] * scale_h
             scaled_w = data.shape[3] * scale_w
-            resize_shape = (simplify(topi.cast(te.round(scaled_d), data.shape[1].dtype)),
-                            simplify(topi.cast(te.round(scaled_h), data.shape[2].dtype)),
-                            simplify(topi.cast(te.round(scaled_w), data.shape[3].dtype)))
-        else: # dynamic case
-            resize_shape = (simplify(topi.cast(te.round(output_shape[1]), data.shape[1].dtype)),
-                            simplify(topi.cast(te.round(output_shape[2]), data.shape[2].dtype)),
-                            simplify(topi.cast(te.round(output_shape[3]), data.shape[3].dtype)))
+            resize_shape = (
+                simplify(topi.cast(te.round(scaled_d), data.shape[1].dtype)),
+                simplify(topi.cast(te.round(scaled_h), data.shape[2].dtype)),
+                simplify(topi.cast(te.round(scaled_w), data.shape[3].dtype)),
+            )
+        else:  # dynamic case
+            resize_shape = (
+                simplify(topi.cast(te.round(output_shape[1]), data.shape[1].dtype)),
+                simplify(topi.cast(te.round(output_shape[2]), data.shape[2].dtype)),
+                simplify(topi.cast(te.round(output_shape[3]), data.shape[3].dtype)),
+            )
     else:
         raise ValueError("not support this layout {} yet".format(layout))
-    return topi.image.resize3d(data, resize_shape, layout=layout, method=method,
-                               coordinate_transformation_mode=coordinate_transformation_mode)
+    return topi.image.resize3d(
+        data,
+        resize_shape,
+        layout=layout,
+        method=method,
+        coordinate_transformation_mode=coordinate_transformation_mode,
+    )
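Note: the hunk above only re-wraps `topi.nn.upsampling` (and `upsampling3d` below it follows the same pattern); behaviour is unchanged. A minimal sketch of driving the reformatted entry point, assuming a local TVM build of this revision where `tvm.topi` is importable; shapes and scale factors are illustrative only:

from tvm import te, topi

# NCHW placeholder, 2x nearest-neighbor upsampling through the signature shown in the diff.
data = te.placeholder((1, 3, 32, 32), name="data")
out = topi.nn.upsampling(data, scale_h=2.0, scale_w=2.0, layout="NCHW")
print(out.shape)  # expected to report the scaled spatial dims, e.g. [1, 3, 64, 64]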
index 5a9b49e..0894656 100644 (file)
@@ -21,6 +21,7 @@ from __future__ import absolute_import
 import tvm
 from ..util import get_const_int
 
+
 def infer_pad(data, data_pad):
     """Infer the padding from stages in reverse.
 
@@ -47,6 +48,7 @@ def infer_pad(data, data_pad):
     wpad = (TW - IW) // 2
     return get_const_int(hpad), get_const_int(wpad)
 
+
 def infer_pad3d(data, data_pad, layout):
     """Infer the padding from stages in reverse.
 
@@ -78,11 +80,12 @@ def infer_pad3d(data, data_pad, layout):
         _, _, TD, TH, TW = data_pad.shape
     else:
         raise ValueError("Layout {} is not supported".format(layout))
-    dpad = (TD - ID)
-    hpad = (TH - IH)
-    wpad = (TW - IW)
+    dpad = TD - ID
+    hpad = TH - IH
+    wpad = TW - IW
     return get_const_int(dpad), get_const_int(hpad), get_const_int(wpad)
 
+
 def infer_stride(data, kernel, out):
     """Infer the stride from stages in reverse.
 
@@ -199,8 +202,7 @@ def get_pad_tuple3d(padding, kernel):
             pad_h = padding[1] * 2
             pad_w = padding[2] * 2
         elif len(padding) == 6:
-            return padding[0], padding[1], padding[2], padding[3], \
-                padding[4], padding[5]
+            return padding[0], padding[1], padding[2], padding[3], padding[4], padding[5]
         else:
             raise ValueError("Size of padding can only be 3 or 6")
     elif isinstance(padding, int):
@@ -245,7 +247,7 @@ def get_pad_tuple1d(padding, kernel):
         if len(padding) == 1:
             pad_w = padding[0] * 2
         elif len(padding) == 2:
-            return  padding[0], padding[1]
+            return padding[0], padding[1]
         else:
             raise ValueError("Size of padding can only be 2 or 4")
     elif isinstance(padding, int):
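Note: the `infer_pad` arithmetic touched above simply recovers symmetric padding from the padded stage. A tiny worked example with plain integers, mirroring the visible `(TW - IW) // 2` line (values are illustrative):

# Input spatial dims (IH, IW) and padded-stage dims (TH, TW).
IH, IW = 32, 32
TH, TW = 34, 36
hpad, wpad = (TH - IH) // 2, (TW - IW) // 2
print(hpad, wpad)  # 1 2 -> one padded row and two padded columns on each side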
index d967431..d43586d 100644 (file)
@@ -34,34 +34,38 @@ def _cook_toom_convolution(a, n, r):
     """Compute Cook-Toom convolution A,B,G matrices"""
 
     def _F_m(a, n):
-        f = lambda j, i: reduce(mul, ((a[i]-a[k] if k != i else 1) for k in range(0, n-1)), 1)
-        F = np.fromfunction(np.vectorize(f), (1, n-1), dtype=int)
+        f = lambda j, i: reduce(mul, ((a[i] - a[k] if k != i else 1) for k in range(0, n - 1)), 1)
+        F = np.fromfunction(np.vectorize(f), (1, n - 1), dtype=int)
         F = np.diagflat(F)
-        F = np.append(F, np.zeros((n-1, 1), dtype=int), axis=1)
-        f = lambda i, j: (1 if j == (n-1) else 0)
+        F = np.append(F, np.zeros((n - 1, 1), dtype=int), axis=1)
+        f = lambda i, j: (1 if j == (n - 1) else 0)
         z = np.fromfunction(np.vectorize(f), (1, n), dtype=int)
 
         return np.append(F, z, axis=0)
 
     def _A_m(a, m, n):
-        f = lambda i, j: a[i]**j
-        A = np.fromfunction(np.vectorize(f), (m-1, n), dtype=int)
-        f = lambda i, j: (1 if j == (n-1) else 0)
+        f = lambda i, j: a[i] ** j
+        A = np.fromfunction(np.vectorize(f), (m - 1, n), dtype=int)
+        f = lambda i, j: (1 if j == (n - 1) else 0)
         z = np.fromfunction(np.vectorize(f), (1, n), dtype=int)
 
         return np.append(A, z, axis=0)
 
     def _B_m(a, n):
-        f = lambda j, i: reduce(mul, ((a[i]-a[k] if k != i else 1) for k in range(0, n-1)), 1)
-        Ff = np.fromfunction(np.vectorize(f), (1, n-1), dtype=int)
-        f = lambda i, nth: (reduce(mul, [(np.poly1d([1, -a[k]]) if k != i else 1) \
-                                         for k in range(0, n-1)], 1)).coef[n-1-nth-1]/Ff[0, i]
-        F = np.fromfunction(np.vectorize(f), (n-1, n-1), dtype=int)
-        f = lambda i, j: -a[i]**(n-1)
-        t = np.fromfunction(np.vectorize(f), (n-1, 1), dtype=int)
-        T = np.append(np.eye(n-1), t, axis=1)
-
-        return np.append(F.T.dot(T), np.array([np.eye(n)[n-1]]), axis=0)
+        f = lambda j, i: reduce(mul, ((a[i] - a[k] if k != i else 1) for k in range(0, n - 1)), 1)
+        Ff = np.fromfunction(np.vectorize(f), (1, n - 1), dtype=int)
+        f = (
+            lambda i, nth: (
+                reduce(mul, [(np.poly1d([1, -a[k]]) if k != i else 1) for k in range(0, n - 1)], 1)
+            ).coef[n - 1 - nth - 1]
+            / Ff[0, i]
+        )
+        F = np.fromfunction(np.vectorize(f), (n - 1, n - 1), dtype=int)
+        f = lambda i, j: -a[i] ** (n - 1)
+        t = np.fromfunction(np.vectorize(f), (n - 1, 1), dtype=int)
+        T = np.append(np.eye(n - 1), t, axis=1)
+
+        return np.append(F.T.dot(T), np.array([np.eye(n)[n - 1]]), axis=0)
 
     alpha = n + r - 1
 
@@ -80,6 +84,7 @@ def _cook_toom_convolution(a, n, r):
 
     return (A, B, G)
 
+
 def _interpolation_points(degree):
     """Propose filter points"""
 
@@ -96,47 +101,64 @@ def _interpolation_points(degree):
     in_pts = [
         #   {invalid}
         [],
-        #01 {E=4.63E-08 on conv2d  [1]}
+        # 01 {E=4.63E-08 on conv2d  [1]}
         [],
-        #02 {E=7.65E-08 on F( 2,3) [1]}
-        [0,   -1,    1],
-        #03 {E=2.35E-07 on F( 3,3) [1]}
-        [0,   -1,    1,  1/2],
-        #04 {E=3.29E-07 on F( 4,3) [1]}
-        [0,   -1,    1,  1/2,   -2],
-        #05 {E=6.81E-07 on F( 5,3) [1]}
-        [0,   -1,    1,  1/2,   -2, -1/2],
-        #06 {E=8.79E-07 on F( 6,3) [1]}
-        [0,   -1,    1,  1/2, -1/2,    2,   -2],
-        #07 {E=3.71E-06 on F( 7,3) [1]}
-        [0,   -1,    1,  1/2, -1/2,    2,   -2, -1/4],
-        #08 {E=7.35E-06 on F( 8,3) [1]}
-        [0,   -1,    1,  1/2, -1/2,    2,   -2, -1/4,    4],
-        #09 {E=2.20E-05 on F( 9,3) [1]}
-        [0,   -1,    1,  1/2, -1/2,    2,   -2, -1/4,  3/4, -4/3],
-        #10 {E=3.22E-05 on F(10,3) [1]}
-        [0,   -1,    1,  1/2, -1/2,    2,   -2, -1/4,    4,  3/4, -4/3],
-        #11 {E=1.09E-04 on F(11,3) [1]}
-        [0,   -1,    1,  1/2, -1/2,    2,   -2, -1/4,    4,  3/4, -4/3,  1/4],
-        #12 {E=1.99E-04 on F(12,3) [1]}
-        [0,   -1,    1,  1/2, -1/2,    2,   -2, -1/4,    4,  1/4, -3/4,  4/3,   -4],
-        #13 {E=5.54E-04 on F(13,3) [1]}
-        [0,   -1,    1,  1/2, -1/2,    2,   -2, -1/4,    4,  1/4, -3/4,  4/3,  3/4, -4/3],
-        #14 {E=8.80E-04 on F(14,3) [1]}
-        [0,   -1,    1,  1/2, -1/2,    2,   -2, -1/4,    4,  1/4, -3/4,  4/3,   -4,  3/4, -4/3],
-        #15 {E=1.07E-02 on F(15,3) [1]}
-        [0,   -1,    1,  1/2, -1/2,    2,   -2, -1/4,    4,  1/4, -3/4,  4/3,   -4,  2/3, -3/2,  3/2],
-        #16 {E=1.93E-02 on F(16,3) [1]}
-        [0,   -1,    1,  1/2, -1/2,    2,   -2, -1/4,    4,  1/4, -3/4,  4/3,   -4,  2/3, -3/2, -2/3,  3/2]
-    ] # pylint: enable=bad-whitespace,line-too-long
-
-    return np.array(in_pts[degree-1], dtype=np.float64)
+        # 02 {E=7.65E-08 on F( 2,3) [1]}
+        [0, -1, 1],
+        # 03 {E=2.35E-07 on F( 3,3) [1]}
+        [0, -1, 1, 1 / 2],
+        # 04 {E=3.29E-07 on F( 4,3) [1]}
+        [0, -1, 1, 1 / 2, -2],
+        # 05 {E=6.81E-07 on F( 5,3) [1]}
+        [0, -1, 1, 1 / 2, -2, -1 / 2],
+        # 06 {E=8.79E-07 on F( 6,3) [1]}
+        [0, -1, 1, 1 / 2, -1 / 2, 2, -2],
+        # 07 {E=3.71E-06 on F( 7,3) [1]}
+        [0, -1, 1, 1 / 2, -1 / 2, 2, -2, -1 / 4],
+        # 08 {E=7.35E-06 on F( 8,3) [1]}
+        [0, -1, 1, 1 / 2, -1 / 2, 2, -2, -1 / 4, 4],
+        # 09 {E=2.20E-05 on F( 9,3) [1]}
+        [0, -1, 1, 1 / 2, -1 / 2, 2, -2, -1 / 4, 3 / 4, -4 / 3],
+        # 10 {E=3.22E-05 on F(10,3) [1]}
+        [0, -1, 1, 1 / 2, -1 / 2, 2, -2, -1 / 4, 4, 3 / 4, -4 / 3],
+        # 11 {E=1.09E-04 on F(11,3) [1]}
+        [0, -1, 1, 1 / 2, -1 / 2, 2, -2, -1 / 4, 4, 3 / 4, -4 / 3, 1 / 4],
+        # 12 {E=1.99E-04 on F(12,3) [1]}
+        [0, -1, 1, 1 / 2, -1 / 2, 2, -2, -1 / 4, 4, 1 / 4, -3 / 4, 4 / 3, -4],
+        # 13 {E=5.54E-04 on F(13,3) [1]}
+        [0, -1, 1, 1 / 2, -1 / 2, 2, -2, -1 / 4, 4, 1 / 4, -3 / 4, 4 / 3, 3 / 4, -4 / 3],
+        # 14 {E=8.80E-04 on F(14,3) [1]}
+        [0, -1, 1, 1 / 2, -1 / 2, 2, -2, -1 / 4, 4, 1 / 4, -3 / 4, 4 / 3, -4, 3 / 4, -4 / 3],
+        # 15 {E=1.07E-02 on F(15,3) [1]}
+        [0, -1, 1, 1 / 2, -1 / 2, 2, -2, -1 / 4, 4, 1 / 4, -3 / 4, 4 / 3, -4, 2 / 3, -3 / 2, 3 / 2],
+        # 16 {E=1.93E-02 on F(16,3) [1]}
+        [
+            0,
+            -1,
+            1,
+            1 / 2,
+            -1 / 2,
+            2,
+            -2,
+            -1 / 4,
+            4,
+            1 / 4,
+            -3 / 4,
+            4 / 3,
+            -4,
+            2 / 3,
+            -3 / 2,
+            -2 / 3,
+            3 / 2,
+        ],
+    ]  # pylint: enable=bad-whitespace,line-too-long
+
+    return np.array(in_pts[degree - 1], dtype=np.float64)
 
 
 @memoize("topi.nn.winograd_matrices", save_at_exit=False)
 def winograd_transform_matrices(tile_size, kernel_size, out_dtype):
-    """Compute the A, B, and G transform matrices for `tile_size` as a `tvm.Expr`.
-    """
+    """Compute the A, B, and G transform matrices for `tile_size` as a `tvm.Expr`."""
     if not 1 < tile_size < 9:
         raise ValueError("Unsupported tile size for Winograd: {}".format(tile_size))
     if not 2 < kernel_size < 8:
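Note: the interpolation-point table and `winograd_transform_matrices` above are only re-wrapped. A sketch of requesting the transform matrices, assuming the memoized helper is importable from `tvm.topi.nn.winograd_util` at this revision; the arguments are illustrative and satisfy the bounds checked just above:

from tvm.topi.nn.winograd_util import winograd_transform_matrices

# F(4, 3): tile_size=4 and kernel_size=3 fall inside 1 < tile_size < 9 and 2 < kernel_size < 8.
A, B, G = winograd_transform_matrices(4, 3, "float32")
print(A, B, G)  # the three constant transform matrices (container type depends on the build)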
index 74ba688..77f9ad4 100644 (file)
@@ -19,6 +19,7 @@
 from __future__ import absolute_import as _abs
 from . import cpp
 
+
 def _get_real_axis(ndim, axis):
     if axis is None:
         real_axis = list(range(ndim))
@@ -33,7 +34,8 @@ def _get_real_axis(ndim, axis):
                 ele += ndim
             if ele >= ndim:
                 raise ValueError(
-                    "{} exceeds the maximum dimension {}. Received axis={}".format(ele, ndim, axis))
+                    "{} exceeds the maximum dimension {}. Received axis={}".format(ele, ndim, axis)
+                )
             real_axis.append(ele)
         real_axis.sort()
         real_axis = list(set(real_axis))  # Remove the duplicates
index bc5d5c3..0857d09 100644 (file)
@@ -23,9 +23,11 @@ from .. import generic
 from ..util import get_const_tuple
 from ..nn.util import get_pad_tuple
 
+
 @autotvm.register_topi_compute("conv2d_nchw_miopen.rocm")
-def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation,
-                       layout='NCHW', out_dtype='float32'):
+def conv2d_nchw_miopen(
+    cfg, data, kernel, strides, padding, dilation, layout="NCHW", out_dtype="float32"
+):
     """Conv2D operator for rocm backend.
 
     Parameters
@@ -59,7 +61,7 @@ def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation,
     CO, CI, KH, KW = get_const_tuple(kernel.shape)
     N, _, H, W = get_const_tuple(data.shape)
 
-    assert layout == 'NCHW'
+    assert layout == "NCHW"
 
     # handle dilation
     stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
@@ -69,19 +71,13 @@ def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation,
     assert (pt == pb) and (pl == pr)
     OH = (H + 2 * pad_h - KH) // stride_h + 1
     OW = (W + 2 * pad_w - KW) // stride_w + 1
-    cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) *\
-                 ((KW - 1) * dilation_w + 1))
-
-    return miopen.conv2d_forward(data,
-                                 kernel,
-                                 stride_h,
-                                 stride_w,
-                                 pt,
-                                 pl,
-                                 dilation_h,
-                                 dilation_w,
-                                 conv_mode=0,
-                                 data_type=1)
+    cfg.add_flop(
+        2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) * ((KW - 1) * dilation_w + 1)
+    )
+
+    return miopen.conv2d_forward(
+        data, kernel, stride_h, stride_w, pt, pl, dilation_h, dilation_w, conv_mode=0, data_type=1
+    )
 
 
 @autotvm.register_topi_schedule("conv2d_nchw_miopen.rocm")
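Note: the `cfg.add_flop` reflow above keeps the same cost model. A plain-Python recap of the output-size and FLOP formulas from that hunk, with illustrative numbers:

# Illustrative shapes: N x CI x H x W input, CO x CI x KH x KW kernel.
N, CI, H, W = 1, 3, 224, 224
CO, KH, KW = 64, 7, 7
stride_h = stride_w = 2
pad_h = pad_w = 3
dilation_h = dilation_w = 1

OH = (H + 2 * pad_h - KH) // stride_h + 1
OW = (W + 2 * pad_w - KW) // stride_w + 1
flops = 2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) * ((KW - 1) * dilation_w + 1)
print(OH, OW, flops)  # 112 112 236027904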
index 989cc2a..4a771c6 100644 (file)
@@ -23,7 +23,8 @@ from .. import generic, nn
 from .. import tag
 from ..util import traverse_inline
 
-@autotvm.register_topi_compute('dense.rocm')
+
+@autotvm.register_topi_compute("dense.rocm")
 def dense(cfg, data, weight, bias=None, out_dtype=None):
     """Dense operator for rocm backend.
 
@@ -46,8 +47,7 @@ def dense(cfg, data, weight, bias=None, out_dtype=None):
     output : tvm.te.Tensor
         2-D with shape [batch, out_dim]
     """
-    assert len(data.shape) == 2 and len(weight.shape) == 2, \
-        "only support 2-dim dense"
+    assert len(data.shape) == 2 and len(weight.shape) == 2, "only support 2-dim dense"
     if bias is not None:
         assert len(bias.shape) == 1
     if out_dtype is None:
@@ -55,7 +55,7 @@ def dense(cfg, data, weight, bias=None, out_dtype=None):
     return nn.dense(data, weight, bias, out_dtype)
 
 
-@autotvm.register_topi_schedule('dense.rocm')
+@autotvm.register_topi_schedule("dense.rocm")
 def schedule_dense(cfg, outs):
     """Schedule for dense operator.
 
@@ -74,7 +74,7 @@ def schedule_dense(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == 'dense':
+        if op.tag == "dense":
             Dense = op.output(0)
             num_thread = 64
             k = Dense.op.reduce_axis[0]
@@ -100,7 +100,7 @@ def schedule_dense(cfg, outs):
     return s
 
 
-@autotvm.register_topi_compute('dense_rocblas.rocm')
+@autotvm.register_topi_compute("dense_rocblas.rocm")
 def dense_rocblas(cfg, data, weight, bias=None, out_dtype=None):
     """Dense operator for rocm backend with cblas.
 
@@ -131,13 +131,13 @@ def dense_rocblas(cfg, data, weight, bias=None, out_dtype=None):
     out_dim, _ = weight.shape
     cfg.add_flop(batch * in_dim * out_dim * 2)
     if bias is not None:
-        matmul = te.compute((batch, out_dim),
-                            lambda i, j: matmul[i, j] + bias[j],
-                            tag=tag.BROADCAST)
+        matmul = te.compute(
+            (batch, out_dim), lambda i, j: matmul[i, j] + bias[j], tag=tag.BROADCAST
+        )
     return matmul
 
 
-@autotvm.register_topi_schedule('dense_rocblas.rocm')
+@autotvm.register_topi_schedule("dense_rocblas.rocm")
 def schedule_dense_rocblas(_, outs):
     """Schedule for dense operator with rocm cblas"""
     return generic.schedule_extern(outs)
index 5f134cb..c963375 100644 (file)
@@ -19,5 +19,6 @@ from __future__ import absolute_import as _abs
 
 from .. import cpp
 
+
 def schedule_lrn(outs):
     return cpp.rocm.schedule_lrn(outs)
index e4e9886..3471057 100644 (file)
@@ -25,8 +25,7 @@ def _scatter_1d(data, indices, updates):
     for i in range(data.shape[0]):
         out[i] = data[i]
     for i in range(indices.shape[0]):
-        out[indices[i] if indices[i] >= 0 else indices[i] +
-            data.shape[0]] = updates[i]
+        out[indices[i] if indices[i] >= 0 else indices[i] + data.shape[0]] = updates[i]
     return out
 
 
@@ -39,13 +38,15 @@ def _scatter_2d(data, indices, updates, axis):
     if axis == 0:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
-                out[indices[i, j] if indices[i, j] >=
-                    0 else indices[i, j] + data.shape[axis], j] = updates[i, j]
+                out[
+                    indices[i, j] if indices[i, j] >= 0 else indices[i, j] + data.shape[axis], j
+                ] = updates[i, j]
     else:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
-                out[i, indices[i, j] if indices[i, j] >=
-                    0 else indices[i, j] + data.shape[axis]] = updates[i, j]
+                out[
+                    i, indices[i, j] if indices[i, j] >= 0 else indices[i, j] + data.shape[axis]
+                ] = updates[i, j]
 
     return out
 
@@ -61,20 +62,35 @@ def _scatter_3d(data, indices, updates, axis):
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
                 for k in const_range(indices.shape[2]):
-                    out[indices[i, j, k] if indices[i, j, k] >=
-                        0 else indices[i, j, k] + data.shape[axis], j, k] = updates[i, j, k]
+                    out[
+                        indices[i, j, k]
+                        if indices[i, j, k] >= 0
+                        else indices[i, j, k] + data.shape[axis],
+                        j,
+                        k,
+                    ] = updates[i, j, k]
     elif axis == 1:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
                 for k in const_range(indices.shape[2]):
-                    out[i, indices[i, j, k] if indices[i, j, k] >=
-                        0 else indices[i, j, k] + data.shape[axis], k] = updates[i, j, k]
+                    out[
+                        i,
+                        indices[i, j, k]
+                        if indices[i, j, k] >= 0
+                        else indices[i, j, k] + data.shape[axis],
+                        k,
+                    ] = updates[i, j, k]
     else:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
                 for k in const_range(indices.shape[2]):
-                    out[i, j, indices[i, j, k] if indices[i, j, k] >=
-                        0 else indices[i, j, k] + data.shape[axis]] = updates[i, j, k]
+                    out[
+                        i,
+                        j,
+                        indices[i, j, k]
+                        if indices[i, j, k] >= 0
+                        else indices[i, j, k] + data.shape[axis],
+                    ] = updates[i, j, k]
 
     return out
 
@@ -93,36 +109,53 @@ def _scatter_4d(data, indices, updates, axis):
             for j in range(indices.shape[1]):
                 for k in const_range(indices.shape[2]):
                     for l in const_range(indices.shape[3]):
-                        out[indices[i, j, k, l] if indices[i, j, k, l] >=
-                            0 else indices[i, j, k, l] + data.shape[axis],
-                            j, k, l] = updates[i, j, k, l]
+                        out[
+                            indices[i, j, k, l]
+                            if indices[i, j, k, l] >= 0
+                            else indices[i, j, k, l] + data.shape[axis],
+                            j,
+                            k,
+                            l,
+                        ] = updates[i, j, k, l]
     elif axis == 1:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
                 for k in const_range(indices.shape[2]):
                     for l in const_range(indices.shape[3]):
-                        out[i,
-                            indices[i, j, k, l] if indices[i, j, k, l] >=
-                            0 else indices[i, j, k, l] + data.shape[axis],
-                            k, l] = updates[i, j, k, l]
+                        out[
+                            i,
+                            indices[i, j, k, l]
+                            if indices[i, j, k, l] >= 0
+                            else indices[i, j, k, l] + data.shape[axis],
+                            k,
+                            l,
+                        ] = updates[i, j, k, l]
     elif axis == 2:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
                 for k in const_range(indices.shape[2]):
                     for l in const_range(indices.shape[3]):
-                        out[i, j,
-                            indices[i, j, k, l] if indices[i, j, k, l] >=
-                            0 else indices[i, j, k, l] + data.shape[axis],
-                            l] = updates[i, j, k, l]
+                        out[
+                            i,
+                            j,
+                            indices[i, j, k, l]
+                            if indices[i, j, k, l] >= 0
+                            else indices[i, j, k, l] + data.shape[axis],
+                            l,
+                        ] = updates[i, j, k, l]
     else:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
                 for k in const_range(indices.shape[2]):
                     for l in const_range(indices.shape[3]):
-                        out[i, j, k,
-                            indices[i, j, k, l] if indices[i, j, k, l] >=
-                            0 else indices[i, j, k, l] + data.shape[axis]
-                            ] = updates[i, j, k, l]
+                        out[
+                            i,
+                            j,
+                            k,
+                            indices[i, j, k, l]
+                            if indices[i, j, k, l] >= 0
+                            else indices[i, j, k, l] + data.shape[axis],
+                        ] = updates[i, j, k, l]
 
     return out
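Note: the scatter reflow above is cosmetic; the ternaries it wraps are the usual negative-index wrap-around, and the scatter_add hunks that follow use the same pattern with `+=` accumulation. A small NumPy analogue of the 1-D case, for orientation only (this is not the TVM implementation):

import numpy as np

data = np.zeros(5)
indices = np.array([0, -1, 2])   # -1 wraps to index 4, mirroring `idx + data.shape[0]`
updates = np.array([1.0, 2.0, 3.0])

out = data.copy()
out[np.where(indices >= 0, indices, indices + data.shape[0])] = updates
print(out)  # [1. 0. 3. 0. 2.]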
 
index 046972b..4c77a07 100644 (file)
@@ -25,8 +25,7 @@ def _scatter_add_1d(data, indices, updates):
     for i in range(data.shape[0]):
         out[i] = data[i]
     for i in range(indices.shape[0]):
-        out[indices[i] if indices[i] >= 0 else indices[i] +
-            data.shape[0]] += updates[i]
+        out[indices[i] if indices[i] >= 0 else indices[i] + data.shape[0]] += updates[i]
     return out
 
 
@@ -39,13 +38,15 @@ def _scatter_add_2d(data, indices, updates, axis):
     if axis == 0:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
-                out[indices[i, j] if indices[i, j] >=
-                    0 else indices[i, j] + data.shape[axis], j] += updates[i, j]
+                out[
+                    indices[i, j] if indices[i, j] >= 0 else indices[i, j] + data.shape[axis], j
+                ] += updates[i, j]
     else:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
-                out[i, indices[i, j] if indices[i, j] >=
-                    0 else indices[i, j] + data.shape[axis]] += updates[i, j]
+                out[
+                    i, indices[i, j] if indices[i, j] >= 0 else indices[i, j] + data.shape[axis]
+                ] += updates[i, j]
 
     return out
 
@@ -61,20 +62,35 @@ def _scatter_add_3d(data, indices, updates, axis):
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
                 for k in const_range(indices.shape[2]):
-                    out[indices[i, j, k] if indices[i, j, k] >=
-                        0 else indices[i, j, k] + data.shape[axis], j, k] += updates[i, j, k]
+                    out[
+                        indices[i, j, k]
+                        if indices[i, j, k] >= 0
+                        else indices[i, j, k] + data.shape[axis],
+                        j,
+                        k,
+                    ] += updates[i, j, k]
     elif axis == 1:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
                 for k in const_range(indices.shape[2]):
-                    out[i, indices[i, j, k] if indices[i, j, k] >=
-                        0 else indices[i, j, k] + data.shape[axis], k] += updates[i, j, k]
+                    out[
+                        i,
+                        indices[i, j, k]
+                        if indices[i, j, k] >= 0
+                        else indices[i, j, k] + data.shape[axis],
+                        k,
+                    ] += updates[i, j, k]
     else:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
                 for k in const_range(indices.shape[2]):
-                    out[i, j, indices[i, j, k] if indices[i, j, k] >=
-                        0 else indices[i, j, k] + data.shape[axis]] += updates[i, j, k]
+                    out[
+                        i,
+                        j,
+                        indices[i, j, k]
+                        if indices[i, j, k] >= 0
+                        else indices[i, j, k] + data.shape[axis],
+                    ] += updates[i, j, k]
 
     return out
 
@@ -93,36 +109,53 @@ def _scatter_add_4d(data, indices, updates, axis):
             for j in range(indices.shape[1]):
                 for k in const_range(indices.shape[2]):
                     for l in const_range(indices.shape[3]):
-                        out[indices[i, j, k, l] if indices[i, j, k, l] >=
-                            0 else indices[i, j, k, l] + data.shape[axis],
-                            j, k, l] += updates[i, j, k, l]
+                        out[
+                            indices[i, j, k, l]
+                            if indices[i, j, k, l] >= 0
+                            else indices[i, j, k, l] + data.shape[axis],
+                            j,
+                            k,
+                            l,
+                        ] += updates[i, j, k, l]
     elif axis == 1:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
                 for k in const_range(indices.shape[2]):
                     for l in const_range(indices.shape[3]):
-                        out[i,
-                            indices[i, j, k, l] if indices[i, j, k, l] >=
-                            0 else indices[i, j, k, l] + data.shape[axis],
-                            k, l] += updates[i, j, k, l]
+                        out[
+                            i,
+                            indices[i, j, k, l]
+                            if indices[i, j, k, l] >= 0
+                            else indices[i, j, k, l] + data.shape[axis],
+                            k,
+                            l,
+                        ] += updates[i, j, k, l]
     elif axis == 2:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
                 for k in const_range(indices.shape[2]):
                     for l in const_range(indices.shape[3]):
-                        out[i, j,
-                            indices[i, j, k, l] if indices[i, j, k, l] >=
-                            0 else indices[i, j, k, l] + data.shape[axis],
-                            l] += updates[i, j, k, l]
+                        out[
+                            i,
+                            j,
+                            indices[i, j, k, l]
+                            if indices[i, j, k, l] >= 0
+                            else indices[i, j, k, l] + data.shape[axis],
+                            l,
+                        ] += updates[i, j, k, l]
     else:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
                 for k in const_range(indices.shape[2]):
                     for l in const_range(indices.shape[3]):
-                        out[i, j, k,
-                            indices[i, j, k, l] if indices[i, j, k, l] >=
-                            0 else indices[i, j, k, l] + data.shape[axis]
-                            ] += updates[i, j, k, l]
+                        out[
+                            i,
+                            j,
+                            k,
+                            indices[i, j, k, l]
+                            if indices[i, j, k, l] >= 0
+                            else indices[i, j, k, l] + data.shape[axis],
+                        ] += updates[i, j, k, l]
 
     return out
 
index f79eb52..86e2bad 100644 (file)
@@ -20,6 +20,7 @@ import tvm
 from tvm import te
 from .util import get_const_tuple
 
+
 def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
     """Performs sorting along the given axis and returns an array
     of indices having the same shape as an input array that index
@@ -69,33 +70,35 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
     data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
     if valid_count is not None:
         valid_count_buf = tvm.tir.decl_buffer(
-            valid_count.shape, valid_count.dtype,
-            "valid_count_buf", data_alignment=4)
+            valid_count.shape, valid_count.dtype, "valid_count_buf", data_alignment=4
+        )
         out_buf = tvm.tir.decl_buffer(data.shape, "int32", "out_buf", data_alignment=8)
-        out = \
-            te.extern(data.shape,
-                      [data, valid_count],
-                      lambda ins, outs: tvm.tir.call_packed(
-                          "tvm.contrib.sort.argsort_nms", ins[0], ins[1],
-                          outs[0], axis, is_ascend),
-                      dtype="int32",
-                      in_buffers=[data_buf, valid_count_buf],
-                      out_buffers=out_buf,
-                      name="argsort_nms_cpu",
-                      tag="argsort_nms_cpu")
+        out = te.extern(
+            data.shape,
+            [data, valid_count],
+            lambda ins, outs: tvm.tir.call_packed(
+                "tvm.contrib.sort.argsort_nms", ins[0], ins[1], outs[0], axis, is_ascend
+            ),
+            dtype="int32",
+            in_buffers=[data_buf, valid_count_buf],
+            out_buffers=out_buf,
+            name="argsort_nms_cpu",
+            tag="argsort_nms_cpu",
+        )
     else:
         out_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
-        out = \
-            te.extern(data.shape,
-                      [data],
-                      lambda ins, outs: tvm.tir.call_packed(
-                          "tvm.contrib.sort.argsort", ins[0],
-                          outs[0], axis, is_ascend),
-                      dtype=dtype,
-                      in_buffers=[data_buf],
-                      out_buffers=out_buf,
-                      name="argsort_cpu",
-                      tag="argsort_cpu")
+        out = te.extern(
+            data.shape,
+            [data],
+            lambda ins, outs: tvm.tir.call_packed(
+                "tvm.contrib.sort.argsort", ins[0], outs[0], axis, is_ascend
+            ),
+            dtype=dtype,
+            in_buffers=[data_buf],
+            out_buffers=out_buf,
+            name="argsort_cpu",
+            tag="argsort_cpu",
+        )
     return out
 
 
@@ -146,12 +149,15 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
     out_shapes = [out_shape] * len(out_bufs)
 
     kv = kvar if not isinstance(k, int) else k
-    out = te.extern(out_shapes,
-                    [data],
-                    lambda ins, outs: tvm.tir.call_packed(
-                        "tvm.contrib.sort.topk", ins[0], *outs, kv, axis, ret_type, is_ascend),
-                    in_buffers=[data_buf],
-                    out_buffers=out_bufs,
-                    name="topk_cpu",
-                    tag="topk_cpu")
+    out = te.extern(
+        out_shapes,
+        [data],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.sort.topk", ins[0], *outs, kv, axis, ret_type, is_ascend
+        ),
+        in_buffers=[data_buf],
+        out_buffers=out_bufs,
+        name="topk_cpu",
+        tag="topk_cpu",
+    )
     return out
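Note: the `te.extern` calls above still delegate to the packed functions `tvm.contrib.sort.argsort` / `tvm.contrib.sort.topk`; only the call sites are re-wrapped. A rough NumPy analogue of the ascending/descending index contract, for orientation only:

import numpy as np

data = np.array([[3.0, 1.0, 2.0]])
idx_ascend = np.argsort(data, axis=-1)   # ascending order of values
idx_descend = idx_ascend[..., ::-1]      # reversed for descending sorts
print(idx_ascend)   # [[1 2 0]]
print(idx_descend)  # [[0 2 1]]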
index 8dc0894..954f9dd 100644 (file)
@@ -21,6 +21,7 @@ from tvm import te
 from .. import tag
 from ..util import simplify
 
+
 def csrmm_default(data, indices, indptr, weight, bias=None):
     # pylint: disable=invalid-name
     """The default implementation of csrmm in topi.
@@ -47,14 +48,20 @@ def csrmm_default(data, indices, indptr, weight, bias=None):
     output : tvm.te.Tensor
         2-D with shape [m, n]
     """
-    assert len(data.shape) == 1 and len(indices.shape) == 1 and len(indptr.shape) == 1 \
-        and len(weight.shape) == 2, "only support 2-dim csrmm"
-    assert isinstance(weight, te.tensor.Tensor), \
-        "weight matrix is assumed to be tvm.te.Tensor, but weight is `%s`" % (type(weight))
+    assert (
+        len(data.shape) == 1
+        and len(indices.shape) == 1
+        and len(indptr.shape) == 1
+        and len(weight.shape) == 2
+    ), "only support 2-dim csrmm"
+    assert isinstance(
+        weight, te.tensor.Tensor
+    ), "weight matrix is assumed to be tvm.te.Tensor, but weight is `%s`" % (type(weight))
     if bias is not None:
         assert len(bias.shape) == 1
-    M = simplify(indptr.shape[0]-1)
+    M = simplify(indptr.shape[0] - 1)
     _, N = weight.shape
+
     def csrmm_default_ir(data, indices, indptr, weight, out):
         """define ir for csrmm"""
         irb = tvm.tir.ir_builder.create()
@@ -63,28 +70,33 @@ def csrmm_default(data, indices, indptr, weight, bias=None):
         indptr_ptr = irb.buffer_ptr(indptr)
         weight_ptr = irb.buffer_ptr(weight)
         out_ptr = irb.buffer_ptr(out)
-        M = simplify(indptr.shape[0]-1)
+        M = simplify(indptr.shape[0] - 1)
         _, N = weight.shape
-        with irb.for_range(0, N, for_type="vectorize", name='n') as n:
-            with irb.for_range(0, M, for_type="parallel", name='row') as row:
-                dot = irb.allocate('float32', (1,), name='dot', scope='local')
-                out_ptr[row*N+n] = 0.
-                dot[0] = 0.
+        with irb.for_range(0, N, for_type="vectorize", name="n") as n:
+            with irb.for_range(0, M, for_type="parallel", name="row") as row:
+                dot = irb.allocate("float32", (1,), name="dot", scope="local")
+                out_ptr[row * N + n] = 0.0
+                dot[0] = 0.0
                 row_start = indptr_ptr[row]
-                row_end = indptr_ptr[row+1]
-                row_elems = row_end-row_start
-                with irb.for_range(0, row_elems, name='idx') as idx:
-                    elem = row_start+idx
-                    dot[0] += data_ptr[elem] * weight_ptr[indices_ptr[elem]*N+n]
-                out_ptr[row*N+n] += dot[0]
+                row_end = indptr_ptr[row + 1]
+                row_elems = row_end - row_start
+                with irb.for_range(0, row_elems, name="idx") as idx:
+                    elem = row_start + idx
+                    dot[0] += data_ptr[elem] * weight_ptr[indices_ptr[elem] * N + n]
+                out_ptr[row * N + n] += dot[0]
         return irb.get()
+
     oshape = (M, N)
-    matmul = te.extern(oshape, [data, indices, indptr, weight],
-                       lambda ins, outs: csrmm_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
-                       tag="csrmm", dtype='float32', name='out')
+    matmul = te.extern(
+        oshape,
+        [data, indices, indptr, weight],
+        lambda ins, outs: csrmm_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
+        tag="csrmm",
+        dtype="float32",
+        name="out",
+    )
     if bias is not None:
-        matmul = te.compute(oshape, lambda i, j: matmul[i, j] + bias[i], \
-                            tag=tag.BROADCAST)
+        matmul = te.compute(oshape, lambda i, j: matmul[i, j] + bias[i], tag=tag.BROADCAST)
     return matmul
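Note: the IR-builder loops above walk one CSR row at a time (`row_start = indptr[row]`, `row_elems = indptr[row + 1] - row_start`); the csrmv and sparse dense kernels further down follow the same row walk. A compact NumPy rendering of that traversal with a tiny illustrative matrix:

import numpy as np

# CSR storage of a 3x4 sparse matrix: row extents in indptr, column ids in indices.
data = np.array([10.0, 20.0, 30.0, 40.0])
indices = np.array([0, 2, 1, 3])
indptr = np.array([0, 2, 3, 4])                        # M = len(indptr) - 1 = 3 rows
weight = np.arange(8, dtype="float64").reshape(4, 2)   # dense [K, N] operand

M, N = len(indptr) - 1, weight.shape[1]
out = np.zeros((M, N))
for row in range(M):
    for elem in range(indptr[row], indptr[row + 1]):   # same row walk as the IR above
        out[row] += data[elem] * weight[indices[elem]]
print(out)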
 
 
index c0aa1b4..afe3bc7 100644 (file)
@@ -20,6 +20,7 @@ import tvm
 from tvm import te
 from .. import tag
 
+
 def csrmv_default(data, indices, indptr, weight, bias=None):
     """The default implementation of csrmv in topi.
 
@@ -45,13 +46,14 @@ def csrmv_default(data, indices, indptr, weight, bias=None):
     output : tvm.te.Tensor
         2-D with shape [m, 1]
     """
-    assert len(data.shape) == 1 and len(weight.shape) == 2, \
-        "only support 2-dim csrmv"
-    assert isinstance(weight, te.tensor.Tensor), \
-        "weight matrix is assumed to be tvm.te.Tensor, but weight is `%s`" % (type(weight))
+    assert len(data.shape) == 1 and len(weight.shape) == 2, "only support 2-dim csrmv"
+    assert isinstance(
+        weight, te.tensor.Tensor
+    ), "weight matrix is assumed to be tvm.te.Tensor, but weight is `%s`" % (type(weight))
     if bias is not None:
         assert len(bias.shape) == 1
-    batch = indptr.shape[0]-1
+    batch = indptr.shape[0] - 1
+
     def csrmv_default_ir(data, indices, indptr, weight, out):
         """define ir for csrmv"""
         irb = tvm.tir.ir_builder.create()
@@ -60,26 +62,31 @@ def csrmv_default(data, indices, indptr, weight, bias=None):
         indptr_ptr = irb.buffer_ptr(indptr)
         weight_ptr = irb.buffer_ptr(weight)
         out_ptr = irb.buffer_ptr(out)
-        num_rows = indptr.shape[0]-1
-        with irb.for_range(0, num_rows, for_type="parallel", name='row') as row:
-            dot = irb.allocate('float32', (1,), name='dot', scope='local')
-            out_ptr[row] = 0.
-            dot[0] = 0.
+        num_rows = indptr.shape[0] - 1
+        with irb.for_range(0, num_rows, for_type="parallel", name="row") as row:
+            dot = irb.allocate("float32", (1,), name="dot", scope="local")
+            out_ptr[row] = 0.0
+            dot[0] = 0.0
             row_start = indptr_ptr[row]
-            row_end = indptr_ptr[row+1]
-            row_elems = row_end-row_start
-            with irb.for_range(0, row_elems, name='elemidx') as elemidx:
-                elem = row_start+elemidx
+            row_end = indptr_ptr[row + 1]
+            row_elems = row_end - row_start
+            with irb.for_range(0, row_elems, name="elemidx") as elemidx:
+                elem = row_start + elemidx
                 dot[0] += data_ptr[elem] * weight_ptr[indices_ptr[elem]]
             out_ptr[row] += dot[0]
         return irb.get()
+
     oshape = (batch, 1)
-    matmul = te.extern(oshape, [data, indices, indptr, weight],
-                       lambda ins, outs: csrmv_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
-                       tag="csrmv", dtype='float32', name='csrmv')
+    matmul = te.extern(
+        oshape,
+        [data, indices, indptr, weight],
+        lambda ins, outs: csrmv_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
+        tag="csrmv",
+        dtype="float32",
+        name="csrmv",
+    )
     if bias is not None:
-        matmul = te.compute((batch, 1), lambda i, j: matmul[i, 0] + bias[i], \
-                            tag=tag.BROADCAST)
+        matmul = te.compute((batch, 1), lambda i, j: matmul[i, 0] + bias[i], tag=tag.BROADCAST)
     return matmul
 
 
index 9f01405..d86f5dd 100644 (file)
@@ -21,6 +21,7 @@ from tvm import te
 from .. import tag
 from ..util import simplify
 
+
 def dense_si(data, indices, indptr, weight, bias=None):
     # pylint: disable=invalid-name
     """The implementation of dense in topi, assuming sparse input.
@@ -47,15 +48,21 @@ def dense_si(data, indices, indptr, weight, bias=None):
     output : tvm.te.Tensor
         2-D with shape [m, n]
     """
-    assert len(data.shape) == 1 and len(indices.shape) == 1 and len(indptr.shape) == 1 \
-        and len(weight.shape) == 2, "only support 2-dim dense"
-    assert isinstance(weight, te.tensor.Tensor), \
-        "weight matrix is assumed to be tvm.te.Tensor, but weight is `%s`" % (type(weight))
+    assert (
+        len(data.shape) == 1
+        and len(indices.shape) == 1
+        and len(indptr.shape) == 1
+        and len(weight.shape) == 2
+    ), "only support 2-dim dense"
+    assert isinstance(
+        weight, te.tensor.Tensor
+    ), "weight matrix is assumed to be tvm.te.Tensor, but weight is `%s`" % (type(weight))
     if bias is not None:
         assert len(bias.shape) == 1
     dtype = data.dtype
-    M = simplify(indptr.shape[0]-1)
+    M = simplify(indptr.shape[0] - 1)
     N, _ = weight.shape
+
     def dense_default_ir(data, indices, indptr, weight, out):
         """Define IR for Dense"""
         dtype = data.dtype
@@ -65,27 +72,32 @@ def dense_si(data, indices, indptr, weight, bias=None):
         indptr_ptr = irb.buffer_ptr(indptr)
         weight_ptr = irb.buffer_ptr(weight)
         out_ptr = irb.buffer_ptr(out)
-        M = simplify(indptr.shape[0]-1)
+        M = simplify(indptr.shape[0] - 1)
         N, K = weight.shape
-        with irb.for_range(0, N, for_type="vectorize", name='n') as n:
-            with irb.for_range(0, M, for_type="parallel", name='m') as m:
-                dot = irb.allocate(dtype, (1,), name='dot', scope='local')
-                out_ptr[m*N+n] = tvm.tir.const(0, dtype)
+        with irb.for_range(0, N, for_type="vectorize", name="n") as n:
+            with irb.for_range(0, M, for_type="parallel", name="m") as m:
+                dot = irb.allocate(dtype, (1,), name="dot", scope="local")
+                out_ptr[m * N + n] = tvm.tir.const(0, dtype)
                 dot[0] = tvm.tir.const(0, dtype)
                 row_start = indptr_ptr[m]
-                row_elems = indptr_ptr[m+1]-row_start
-                with irb.for_range(0, row_elems, name='k') as k:
-                    elem = row_start+k
-                    dot[0] += data_ptr[elem] * weight_ptr[indices_ptr[elem]+n*K]
-                out_ptr[m*N+n] += dot[0]
+                row_elems = indptr_ptr[m + 1] - row_start
+                with irb.for_range(0, row_elems, name="k") as k:
+                    elem = row_start + k
+                    dot[0] += data_ptr[elem] * weight_ptr[indices_ptr[elem] + n * K]
+                out_ptr[m * N + n] += dot[0]
         return irb.get()
+
     oshape = (M, N)
-    matmul = te.extern(oshape, [data, indices, indptr, weight],
-                       lambda ins, outs: dense_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
-                       tag="dense", dtype=dtype, name='out')
+    matmul = te.extern(
+        oshape,
+        [data, indices, indptr, weight],
+        lambda ins, outs: dense_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
+        tag="dense",
+        dtype=dtype,
+        name="out",
+    )
     if bias is not None:
-        matmul = te.compute(oshape, lambda i, j: matmul[i, j] + bias[j], \
-                            tag=tag.BROADCAST)
+        matmul = te.compute(oshape, lambda i, j: matmul[i, j] + bias[j], tag=tag.BROADCAST)
     return matmul
 
 
@@ -115,15 +127,21 @@ def dense_sw(data, w_data, w_indices, w_indptr, bias=None):
     output : tvm.te.Tensor
         2-D with shape [m, n]
     """
-    assert len(w_data.shape) == 1 and len(w_indices.shape) == 1 and len(w_indptr.shape) == 1 \
-        and len(data.shape) == 2, "only support 2-dim dense"
-    assert isinstance(data, te.tensor.Tensor), \
-        "data matrix is assumed to be tvm.te.Tensor, but weight is `%s`" % (type(data))
+    assert (
+        len(w_data.shape) == 1
+        and len(w_indices.shape) == 1
+        and len(w_indptr.shape) == 1
+        and len(data.shape) == 2
+    ), "only support 2-dim dense"
+    assert isinstance(
+        data, te.tensor.Tensor
+    ), "data matrix is assumed to be tvm.te.Tensor, but weight is `%s`" % (type(data))
     if bias is not None:
         assert len(bias.shape) == 1
     dtype = data.dtype
     M, _ = data.shape
-    N = simplify(w_indptr.shape[0]-1)
+    N = simplify(w_indptr.shape[0] - 1)
+
     def dense_default_ir(data, w_data, w_indices, w_indptr, out):
         """Define IR for Dense"""
         dtype = data.dtype
@@ -134,26 +152,31 @@ def dense_sw(data, w_data, w_indices, w_indptr, bias=None):
         w_indptr_ptr = irb.buffer_ptr(w_indptr)
         out_ptr = irb.buffer_ptr(out)
         M, K = data.shape
-        N = simplify(w_indptr.shape[0]-1)
-        with irb.for_range(0, M, for_type="vectorize", name='m') as m:
-            with irb.for_range(0, N, for_type="parallel", name='n') as n:
-                dot = irb.allocate(dtype, (1,), name='dot', scope='local')
-                out_ptr[m*N+n] = tvm.tir.const(0, dtype)
+        N = simplify(w_indptr.shape[0] - 1)
+        with irb.for_range(0, M, for_type="vectorize", name="m") as m:
+            with irb.for_range(0, N, for_type="parallel", name="n") as n:
+                dot = irb.allocate(dtype, (1,), name="dot", scope="local")
+                out_ptr[m * N + n] = tvm.tir.const(0, dtype)
                 dot[0] = tvm.tir.const(0, dtype)
                 row_start = w_indptr_ptr[n]
-                row_elems = w_indptr_ptr[n+1]-row_start
-                with irb.for_range(0, row_elems, name='k') as k:
-                    elem = row_start+k
-                    dot[0] += w_data_ptr[elem] * data_ptr[w_indices_ptr[elem]+m*K]
-                out_ptr[m*N+n] += dot[0]
+                row_elems = w_indptr_ptr[n + 1] - row_start
+                with irb.for_range(0, row_elems, name="k") as k:
+                    elem = row_start + k
+                    dot[0] += w_data_ptr[elem] * data_ptr[w_indices_ptr[elem] + m * K]
+                out_ptr[m * N + n] += dot[0]
         return irb.get()
+
     oshape = (M, N)
-    matmul = te.extern(oshape, [data, w_data, w_indices, w_indptr],
-                       lambda ins, outs: dense_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
-                       tag="dense", dtype=dtype, name='out')
+    matmul = te.extern(
+        oshape,
+        [data, w_data, w_indices, w_indptr],
+        lambda ins, outs: dense_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
+        tag="dense",
+        dtype=dtype,
+        name="out",
+    )
     if bias is not None:
-        matmul = te.compute(oshape, lambda i, j: matmul[i, j] + bias[j], \
-                            tag=tag.BROADCAST)
+        matmul = te.compute(oshape, lambda i, j: matmul[i, j] + bias[j], tag=tag.BROADCAST)
     return matmul
 
 
@@ -178,13 +201,21 @@ def dense(data, weight, bias=None):
         2-D with shape [batch, out_dim]
     """
     ret = None
-    if isinstance(data, tvm.contrib.sparse.CSRPlaceholderOp) and \
-       isinstance(weight, te.tensor.Tensor):
+    if isinstance(data, tvm.contrib.sparse.CSRPlaceholderOp) and isinstance(
+        weight, te.tensor.Tensor
+    ):
         ret = dense_si(data.data, data.indices, data.indptr, weight, bias)
-    elif isinstance(data, te.tensor.Tensor) and \
-            isinstance(weight, tvm.contrib.sparse.CSRPlaceholderOp):
+    elif isinstance(data, te.tensor.Tensor) and isinstance(
+        weight, tvm.contrib.sparse.CSRPlaceholderOp
+    ):
         ret = dense_sw(data, weight.data, weight.indices, weight.indptr, bias)
     else:
-        raise NotImplementedError("implementation for %s as data and %s as weights, "
-                                  "is not supported yet." % (type(data), type(weight), ))
+        raise NotImplementedError(
+            "implementation for %s as data and %s as weights, "
+            "is not supported yet."
+            % (
+                type(data),
+                type(weight),
+            )
+        )
     return ret
index 57cd411..0e1330d 100644 (file)
@@ -83,6 +83,4 @@ def is_injective(tag):
     """
     if tag in (ELEMWISE, BROADCAST, INJECTIVE):
         return True
-    return (tag.startswith(ELEMWISE) or
-            tag.startswith(BROADCAST) or
-            tag.startswith(INJECTIVE))
+    return tag.startswith(ELEMWISE) or tag.startswith(BROADCAST) or tag.startswith(INJECTIVE)
index 0071242..31ebe86 100644 (file)
@@ -19,6 +19,7 @@
 from __future__ import absolute_import as _abs
 from . import cpp
 
+
 def elemwise_sum(xs):
     """Perform element-wise sum on inputs
 
index ce0554f..5b23e8f 100644 (file)
@@ -56,8 +56,14 @@ from .one_hot import one_hot
 from .depth_to_space import depth_to_space_python
 from .space_to_depth import space_to_depth_python
 from .crop_and_resize_python import crop_and_resize_python
-from .common import get_injective_schedule, get_reduce_schedule, get_broadcast_schedule, \
-    get_elemwise_schedule, get_conv2d_nchw_implement, dispatch
+from .common import (
+    get_injective_schedule,
+    get_reduce_schedule,
+    get_broadcast_schedule,
+    get_elemwise_schedule,
+    get_conv2d_nchw_implement,
+    dispatch,
+)
 from .adaptive_pool_python import adaptive_pool
 from .grid_sample_python import affine_grid_python, grid_sample_nchw_python
 from .matrix_set_diag import matrix_set_diag
index 3f464ce..79f42c8 100644 (file)
@@ -85,11 +85,11 @@ def adaptive_pool_nhwc(np_data, out_size, pool_op, np_op):
     for i in range(n):
         for j in range(c):
             if len(out_size) == 2:
-                np_out[i, :, :, j] = pool_op(ishape[1:-1], out_size,
-                                             np_data[i, :, :, j], np_op)
+                np_out[i, :, :, j] = pool_op(ishape[1:-1], out_size, np_data[i, :, :, j], np_op)
             else:
-                np_out[i, :, :, :, j] = pool_op(ishape[1:-1], out_size,
-                                                np_data[i, :, :, :, j], np_op)
+                np_out[i, :, :, :, j] = pool_op(
+                    ishape[1:-1], out_size, np_data[i, :, :, :, j], np_op
+                )
 
     return np_out
 
index c864f8d..0a991f6 100644 (file)
@@ -18,6 +18,7 @@
 """Batch matmul in python"""
 import numpy as np
 
+
 def batch_matmul(x, y):
     """batch_matmul operator implemented in numpy.
 
index c43fd2c..8d78f13 100644 (file)
@@ -20,12 +20,13 @@ import math
 import numpy as np
 from tvm.topi.util import nchw_pack_layout
 
+
 def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mode="align_corners"):
     """ Bilinear scaling using python"""
     (new_h, new_w) = out_size
     (ib, ic) = (1, 1)
 
-    if layout == 'NHWC':
+    if layout == "NHWC":
         (batch, h, w, channel) = image.shape
         scaled_image = np.ones((batch, new_h, new_w, channel))
     # NCHWinic
@@ -37,8 +38,8 @@ def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mo
         scaled_image = np.ones((batch, channel, new_h, new_w))
 
     if coordinate_transformation_mode == "align_corners":
-        height_scale = np.float32(h-1) / np.float32(out_size[0]-1)
-        width_scale = np.float32(w-1) / np.float32(out_size[1]-1)
+        height_scale = np.float32(h - 1) / np.float32(out_size[0] - 1)
+        width_scale = np.float32(w - 1) / np.float32(out_size[1] - 1)
     else:
         height_scale = np.float32(h) / np.float32(out_size[0])
         width_scale = np.float32(w) / np.float32(out_size[1])
@@ -67,7 +68,7 @@ def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mo
                 x0 = max(x0, 0)
                 x_lerp = in_x - math.floor(in_x)
 
-                if layout == 'NHWC':
+                if layout == "NHWC":
                     A = image[b][y0][x0][i]
                     B = image[b][y0][x1][i]
                     C = image[b][y1][x0][i]
@@ -88,7 +89,7 @@ def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mo
 
                 pixel = np.float32(_lerp(top, bottom, y_lerp))
 
-                if layout == 'NHWC':
+                if layout == "NHWC":
                     scaled_image[b][j][k][i] = pixel
                 elif nchw_pack_layout(layout):
                     scaled_image[b][i][j][k][m][n] = pixel
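Note: the bilinear path above blends a 2x2 neighbourhood through `_lerp`, which is defined earlier in the same file and not shown in this hunk. A self-contained sketch of that blend, assuming the conventional linear-interpolation definition of `_lerp`:

def _lerp(a, b, t):
    # assumed definition: plain linear interpolation between a and b
    return a * (1.0 - t) + b * t

# 2x2 neighbourhood A B over C D, sampled at fractional offsets x_lerp, y_lerp.
A, B, C, D = 1.0, 2.0, 3.0, 4.0
x_lerp, y_lerp = 0.25, 0.5
top, bottom = _lerp(A, B, x_lerp), _lerp(C, D, x_lerp)
print(_lerp(top, bottom, y_lerp))  # 2.25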
index 3a7c605..51ea19a 100644 (file)
@@ -32,9 +32,10 @@ _reduce_schedule = {
     "generic": topi.generic.schedule_reduce,
     "cpu": topi.x86.schedule_reduce,
     "gpu": topi.cuda.schedule_reduce,
-    "hls": topi.cuda.schedule_reduce
+    "hls": topi.cuda.schedule_reduce,
 }
 
+
 def dispatch(target, dispatch_map):
     if isinstance(target, str):
         target = tvm.target.Target(target)
@@ -44,29 +45,35 @@ def dispatch(target, dispatch_map):
             return dispatch_map[key]
     return dispatch_map["generic"]
 
+
 def get_injective_schedule(target):
     return dispatch(target, _injective_schedule)
 
+
 def get_reduce_schedule(target):
     return dispatch(target, _reduce_schedule)
 
+
 get_broadcast_schedule = get_injective_schedule
 get_elemwise_schedule = get_injective_schedule
 
 _conv2d_nchw_implement = {
     "generic": (topi.nn.conv2d_nchw, topi.generic.schedule_conv2d_nchw),
     "cpu": (topi.x86.conv2d_nchw, topi.x86.schedule_conv2d_nchw),
-    "arm_cpu": (topi.arm_cpu.conv2d_nchw_spatial_pack,
-                topi.arm_cpu.schedule_conv2d_nchw_spatial_pack),
+    "arm_cpu": (
+        topi.arm_cpu.conv2d_nchw_spatial_pack,
+        topi.arm_cpu.schedule_conv2d_nchw_spatial_pack,
+    ),
     "gpu": (topi.cuda.conv2d_nchw, topi.cuda.schedule_conv2d_nchw),
-    "mali": (topi.mali.conv2d_nchw_spatial_pack,
-             topi.mali.schedule_conv2d_nchw_spatial_pack),
-    "bifrost": (topi.bifrost.conv2d_nchw_spatial_pack,
-                topi.bifrost.schedule_conv2d_nchw_spatial_pack),
-    "intel_graphics": (topi.intel_graphics.conv2d_nchw,
-                       topi.intel_graphics.schedule_conv2d_nchw),
-    "hls": (topi.nn.conv2d_nchw, topi.hls.schedule_conv2d_nchw)
+    "mali": (topi.mali.conv2d_nchw_spatial_pack, topi.mali.schedule_conv2d_nchw_spatial_pack),
+    "bifrost": (
+        topi.bifrost.conv2d_nchw_spatial_pack,
+        topi.bifrost.schedule_conv2d_nchw_spatial_pack,
+    ),
+    "intel_graphics": (topi.intel_graphics.conv2d_nchw, topi.intel_graphics.schedule_conv2d_nchw),
+    "hls": (topi.nn.conv2d_nchw, topi.hls.schedule_conv2d_nchw),
 }
 
+
 def get_conv2d_nchw_implement(target):
     return dispatch(target, _conv2d_nchw_implement)
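
The dispatch helper reformatted above scans the target's keys and falls back to the "generic" entry. A minimal sketch of the same lookup pattern, using plain strings in place of tvm.target.Target keys; the toy table below is illustrative, not taken from the file:

    _toy_schedule = {"generic": "schedule_generic", "cpu": "schedule_cpu", "gpu": "schedule_gpu"}

    def toy_dispatch(target_keys, dispatch_map):
        # first key present in the table wins, otherwise fall back to "generic"
        for key in target_keys:
            if key in dispatch_map:
                return dispatch_map[key]
        return dispatch_map["generic"]

    assert toy_dispatch(["cuda", "gpu"], _toy_schedule) == "schedule_gpu"
    assert toy_dispatch(["vulkan"], _toy_schedule) == "schedule_generic"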
index 84a463f..1405adb 100644 (file)
@@ -21,7 +21,7 @@ from tvm.topi.nn.util import get_pad_tuple1d
 
 
 def dilate_np(x, dilation):
-    """ 1D dilation using numpy
+    """1D dilation using numpy
 
     Parameters
     ----------
@@ -38,7 +38,7 @@ def dilate_np(x, dilation):
     """
     irange = range(len(x) - 1)
     for d in range(dilation - 1):
-        indices = [(d + 1)*(i + 1) for i in irange]
+        indices = [(d + 1) * (i + 1) for i in irange]
         x = np.insert(x, indices, 0)
     return x
 
@@ -81,13 +81,14 @@ def conv1d_ncw_python(a_np, w_np, stride, padding, dilation):
     out_w = ((in_w - dilated_filter_w + pad_left + pad_right) // stride) + 1
 
     padded_a_np = np.zeros((batch, in_c, in_w + pad_left + pad_right))
-    padded_a_np[:, :, pad_left:(in_w + pad_left)] = a_np
+    padded_a_np[:, :, pad_left : (in_w + pad_left)] = a_np
 
     b_np = np.zeros((batch, out_c, out_w))
     for n in range(batch):
         for f in range(out_c):
             for c in range(in_c):
                 out = np.convolve(
-                    padded_a_np[n, c], np.flip(dilate_np(w_np[f, c], dilation)), mode='valid')
+                    padded_a_np[n, c], np.flip(dilate_np(w_np[f, c], dilation)), mode="valid"
+                )
                 b_np[n, f] += out[::stride]
     return b_np
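
The reference above builds strided, dilated 1D convolution from a flipped-kernel np.convolve followed by a stride slice. A tiny standalone check of that pattern, with arbitrarily chosen values:

    import numpy as np

    x = np.arange(8, dtype="float32")              # one padded input row
    w = np.array([1.0, 0.0, -1.0], dtype="float32")
    # np.convolve flips its second argument, so pre-flipping w yields cross-correlation
    full = np.convolve(x, np.flip(w), mode="valid")
    strided = full[::2]                            # stride of 2, as in out[::stride]
    expected = np.array([x[i:i + 3] @ w for i in range(0, len(x) - 2, 2)])
    assert np.allclose(strided, expected)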
index 0a5d22c..3a1bc61 100644 (file)
@@ -21,6 +21,7 @@ import scipy
 import tvm.topi.testing
 from tvm.topi.nn.util import get_pad_tuple1d
 
+
 def conv1d_transpose_ncw_python(a_np, w_np, stride, padding, output_padding):
     """Transposed 1D convolution operator in NCW layout.
 
@@ -64,15 +65,14 @@ def conv1d_transpose_ncw_python(a_np, w_np, stride, padding, output_padding):
     # padding stage
     bpad_left = filter_w - 1 - fpad_left
     bpad_right = filter_w - 1 - fpad_right + opad
-    padded_a_np = np.zeros((batch, in_c, dilated_a_np.shape[2]+bpad_left+bpad_right))
-    padded_a_np[:, :, bpad_left:dilated_a_np.shape[2]+bpad_left] = dilated_a_np
+    padded_a_np = np.zeros((batch, in_c, dilated_a_np.shape[2] + bpad_left + bpad_right))
+    padded_a_np[:, :, bpad_left : dilated_a_np.shape[2] + bpad_left] = dilated_a_np
     # convolution stage
     out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + opad
     b_np = np.zeros((batch, out_c, out_w))
     for n in range(batch):
         for f in range(out_c):
             for c in range(in_c):
-                out = scipy.signal.convolve(
-                    padded_a_np[n, c], w_np[c, f], mode='valid')
+                out = scipy.signal.convolve(padded_a_np[n, c], w_np[c, f], mode="valid")
                 b_np[n, f] += out
     return b_np
index fd5d9a7..9a06edd 100644 (file)
@@ -69,10 +69,9 @@ def conv2d_hwcn_python(a_np, w_np, stride, padding):
             for c in range(in_channel):
                 if pad_h > 0 or pad_w > 0:
                     apad = np.zeros((in_height + pad_h, in_width + pad_w))
-                    apad[pad_top:pad_top + in_height, pad_left:pad_left + in_width] = at[n, c]
+                    apad[pad_top : pad_top + in_height, pad_left : pad_left + in_width] = at[n, c]
                 else:
                     apad = at[n, c]
-                out = scipy.signal.convolve2d(
-                    apad, np.rot90(np.rot90(wt[f, c])), mode='valid')
+                out = scipy.signal.convolve2d(apad, np.rot90(np.rot90(wt[f, c])), mode="valid")
                 bt[n, f] += out[::stride, ::stride]
     return bt.transpose((2, 3, 1, 0))
index cb855a4..38bed4a 100644 (file)
@@ -65,11 +65,10 @@ def _conv2d_nchw_python(a_np, w_np, stride, padding):
             for c in range(in_channel):
                 if pad_h > 0 or pad_w > 0:
                     apad = np.zeros((in_height + pad_h, in_width + pad_w))
-                    apad[pad_top:pad_top + in_height, pad_left:pad_left + in_width] = a_np[n, c]
+                    apad[pad_top : pad_top + in_height, pad_left : pad_left + in_width] = a_np[n, c]
                 else:
                     apad = a_np[n, c]
-                out = scipy.signal.convolve2d(
-                    apad, np.rot90(np.rot90(w_np[f, c])), mode='valid')
+                out = scipy.signal.convolve2d(apad, np.rot90(np.rot90(w_np[f, c])), mode="valid")
                 b_np[n, f] += out[::stride_h, ::stride_w]
     return b_np
 
@@ -103,7 +102,9 @@ def conv2d_nchw_python(a_np, w_np, stride, padding, groups=1):
     """
     a_slices = np.array_split(a_np, groups, axis=1)
     w_slices = np.array_split(w_np, groups, axis=0)
-    b_slices = [_conv2d_nchw_python(a_slice, w_slice, stride, padding)
-                for a_slice, w_slice in zip(a_slices, w_slices)]
+    b_slices = [
+        _conv2d_nchw_python(a_slice, w_slice, stride, padding)
+        for a_slice, w_slice in zip(a_slices, w_slices)
+    ]
     b_np = np.concatenate(b_slices, axis=1)
     return b_np
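
Grouped convolution in the hunk above is per-group application of the single-group routine plus a channel concatenation. A minimal sketch of that split/apply/concat pattern with a shape-only stand-in for the convolution; toy_op and all shapes below are illustrative:

    import numpy as np

    def apply_grouped(a, w, groups, single_group_op):
        # split activations on channel axis 1 and weights on output-channel axis 0,
        # run each pair independently, then re-concatenate along channels
        a_slices = np.array_split(a, groups, axis=1)
        w_slices = np.array_split(w, groups, axis=0)
        outs = [single_group_op(xs, ws) for xs, ws in zip(a_slices, w_slices)]
        return np.concatenate(outs, axis=1)

    toy_op = lambda xs, ws: np.zeros((xs.shape[0], ws.shape[0], 4, 4))  # shape-only "conv"
    out = apply_grouped(np.zeros((1, 8, 6, 6)), np.zeros((16, 4, 3, 3)), 2, toy_op)
    assert out.shape == (1, 16, 4, 4)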
index 17d072a..136fb6b 100644 (file)
@@ -68,14 +68,14 @@ def _conv2d_nhwc_python(a_np, w_np, stride, padding):
             for c in range(in_channel):
                 if pad_h > 0 or pad_w > 0:
                     apad = np.zeros((in_height + pad_h, in_width + pad_w))
-                    apad[pad_top:pad_top + in_height, pad_left:pad_left + in_width] = at[n, c]
+                    apad[pad_top : pad_top + in_height, pad_left : pad_left + in_width] = at[n, c]
                 else:
                     apad = at[n, c]
-                out = scipy.signal.convolve2d(
-                    apad, np.rot90(np.rot90(wt[f, c])), mode='valid')
+                out = scipy.signal.convolve2d(apad, np.rot90(np.rot90(wt[f, c])), mode="valid")
                 bt[n, f] += out[::stride_h, ::stride_w]
     return bt.transpose((0, 2, 3, 1))
 
+
 def conv2d_nhwc_python(a_np, w_np, stride, padding, groups=1):
     """Convolution operator in NHWC layout.
 
@@ -106,7 +106,9 @@ def conv2d_nhwc_python(a_np, w_np, stride, padding, groups=1):
 
     a_slices = np.array_split(a_np, groups, axis=3)
     w_slices = np.array_split(w_np, groups, axis=3)
-    b_slices = [_conv2d_nhwc_python(a_slice, w_slice, stride, padding)
-                for a_slice, w_slice in zip(a_slices, w_slices)]
+    b_slices = [
+        _conv2d_nhwc_python(a_slice, w_slice, stride, padding)
+        for a_slice, w_slice in zip(a_slices, w_slices)
+    ]
     b_np = np.concatenate(b_slices, axis=3)
     return b_np
index 47f9cf1..04e60a7 100644 (file)
@@ -66,10 +66,20 @@ def conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
     bpad_bottom = filter_h - 1 - fpad_bottom + opad_h
     bpad_left = filter_w - 1 - fpad_left
     bpad_right = filter_w - 1 - fpad_right + opad_w
-    padded_a_np = np.zeros((batch, in_c, dilated_a_np.shape[2]+bpad_top+bpad_bottom, \
-                            dilated_a_np.shape[3]+bpad_left+bpad_right))
-    padded_a_np[:, :, bpad_top:dilated_a_np.shape[2]+bpad_top, \
-                bpad_left:dilated_a_np.shape[3]+bpad_left] = dilated_a_np
+    padded_a_np = np.zeros(
+        (
+            batch,
+            in_c,
+            dilated_a_np.shape[2] + bpad_top + bpad_bottom,
+            dilated_a_np.shape[3] + bpad_left + bpad_right,
+        )
+    )
+    padded_a_np[
+        :,
+        :,
+        bpad_top : dilated_a_np.shape[2] + bpad_top,
+        bpad_left : dilated_a_np.shape[3] + bpad_left,
+    ] = dilated_a_np
     # convolution stage
     out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + opad_h
     out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + opad_w
@@ -77,14 +87,14 @@ def conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
     for n in range(batch):
         for f in range(out_c):
             for c in range(in_c):
-                out = scipy.signal.convolve2d(
-                    padded_a_np[n, c], w_np[c, f], mode='valid')
+                out = scipy.signal.convolve2d(padded_a_np[n, c], w_np[c, f], mode="valid")
                 b_np[n, f] += out
     return b_np
 
 
-def conv2d_transpose_nhwc_python(a_nhwc, weight, weight_format, stride, padding,
-                                 output_padding=(0, 0)):
+def conv2d_transpose_nhwc_python(
+    a_nhwc, weight, weight_format, stride, padding, output_padding=(0, 0)
+):
     """Transposed convolution operator in NHWC layout.
 
     Parameters
@@ -115,18 +125,19 @@ def conv2d_transpose_nhwc_python(a_nhwc, weight, weight_format, stride, padding,
     a_nchw = np.transpose(a_nhwc, (0, 3, 1, 2))
 
     # conv2d_transpose_nchw_python needs kernel layout to be IOHW
-    if weight_format == 'HWIO':
+    if weight_format == "HWIO":
         w_iohw = np.transpose(weight, (2, 3, 0, 1))
-    elif weight_format == 'HWOI':
+    elif weight_format == "HWOI":
         w_iohw = np.transpose(weight, (3, 2, 0, 1))
-    elif weight_format == 'OIHW':
+    elif weight_format == "OIHW":
         w_iohw = np.transpose(weight, (1, 0, 2, 3))
-    elif weight_format == 'IOHW':
+    elif weight_format == "IOHW":
         w_iohw = weight
     else:
-        raise ValueError('Valid weight_formats are HWIO, HWOI, OIHW or IOHW')
+        raise ValueError("Valid weight_formats are HWIO, HWOI, OIHW or IOHW")
 
-    res_nchw = conv2d_transpose_nchw_python(a_nchw, w_iohw, stride, padding,
-                                            output_padding=output_padding)
+    res_nchw = conv2d_transpose_nchw_python(
+        a_nchw, w_iohw, stride, padding, output_padding=output_padding
+    )
     res_nhwc = np.transpose(res_nchw, (0, 2, 3, 1))
     return res_nhwc
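
The weight-format handling above is a fixed set of axis permutations into IOHW. A quick shape check for the HWIO case, using an illustrative 3x3 kernel with 16 input and 32 output channels:

    import numpy as np

    w_hwio = np.zeros((3, 3, 16, 32))            # (H, W, I, O), illustrative shape
    w_iohw = np.transpose(w_hwio, (2, 3, 0, 1))  # same permutation as the HWIO branch above
    assert w_iohw.shape == (16, 32, 3, 3)        # (I, O, H, W)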
index 85b124a..11b0e23 100644 (file)
@@ -29,8 +29,9 @@ def _conv3d_ncdhw_python(a_np, w_np, stride, padding):
     else:
         stride_d, stride_h, stride_w = stride
 
-    pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = \
-        get_pad_tuple3d(padding, (kernel_d, kernel_h, kernel_w))
+    pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = get_pad_tuple3d(
+        padding, (kernel_d, kernel_h, kernel_w)
+    )
     pad_d = pad_front + pad_back
     pad_h = pad_top + pad_bottom
     pad_w = pad_left + pad_right
@@ -47,12 +48,14 @@ def _conv3d_ncdhw_python(a_np, w_np, stride, padding):
             for c in range(in_channel):
                 if pad_d > 0 or pad_h > 0 or pad_w > 0:
                     apad = np.zeros((in_depth + pad_d, in_height + pad_h, in_width + pad_w))
-                    apad[pad_front:pad_front + in_depth, pad_top:pad_top + in_height,\
-                         pad_left:pad_left + in_width] = a_np[n, c]
+                    apad[
+                        pad_front : pad_front + in_depth,
+                        pad_top : pad_top + in_height,
+                        pad_left : pad_left + in_width,
+                    ] = a_np[n, c]
                 else:
                     apad = a_np[n, c]
-                out = scipy.signal.convolve(
-                    apad, np.flip(w_np[f, c]), mode='valid')
+                out = scipy.signal.convolve(apad, np.flip(w_np[f, c]), mode="valid")
                 b_np[n, f] += out[::stride_d, ::stride_h, ::stride_w]
     return b_np
 
@@ -84,7 +87,9 @@ def conv3d_ncdhw_python(a_np, w_np, stride, padding, groups=1):
     """
     a_slices = np.array_split(a_np, groups, axis=1)
     w_slices = np.array_split(w_np, groups, axis=0)
-    b_slices = [_conv3d_ncdhw_python(a_slice, w_slice, stride, padding)
-                for a_slice, w_slice in zip(a_slices, w_slices)]
+    b_slices = [
+        _conv3d_ncdhw_python(a_slice, w_slice, stride, padding)
+        for a_slice, w_slice in zip(a_slices, w_slices)
+    ]
     b_np = np.concatenate(b_slices, axis=1)
     return b_np
index b9330ec..52974d4 100644 (file)
@@ -52,8 +52,9 @@ def conv3d_ndhwc_python(a_np, w_np, stride, padding):
     else:
         stride_d, stride_h, stride_w = stride
 
-    pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = \
-        get_pad_tuple3d(padding, (kernel_d, kernel_h, kernel_w))
+    pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = get_pad_tuple3d(
+        padding, (kernel_d, kernel_h, kernel_w)
+    )
     pad_d = pad_front + pad_back
     pad_h = pad_top + pad_bottom
     pad_w = pad_left + pad_right
@@ -72,11 +73,13 @@ def conv3d_ndhwc_python(a_np, w_np, stride, padding):
             for c in range(in_channel):
                 if pad_d > 0 or pad_h > 0 or pad_w > 0:
                     apad = np.zeros((in_depth + pad_d, in_height + pad_h, in_width + pad_w))
-                    apad[pad_front:pad_front + in_depth, pad_top:pad_top + in_height,\
-                         pad_left:pad_left + in_width] = at[n, c]
+                    apad[
+                        pad_front : pad_front + in_depth,
+                        pad_top : pad_top + in_height,
+                        pad_left : pad_left + in_width,
+                    ] = at[n, c]
                 else:
                     apad = at[n, c]
-                out = scipy.signal.convolve(
-                    apad, np.flip(wt[f, c]), mode='valid')
+                out = scipy.signal.convolve(apad, np.flip(wt[f, c]), mode="valid")
                 bt[n, f] += out[::stride_d, ::stride_h, ::stride_w]
     return bt.transpose((0, 2, 3, 4, 1))
index 711f04b..779371a 100644 (file)
@@ -63,7 +63,8 @@ def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding, output_padding):
 
     # padding stage
     fpad_front, fpad_top, fpad_left, fpad_back, fpad_bottom, fpad_right = get_pad_tuple3d(
-        padding, (filter_d, filter_h, filter_w))
+        padding, (filter_d, filter_h, filter_w)
+    )
 
     bpad_front = filter_d - 1 - fpad_front
     bpad_back = filter_d - 1 - fpad_back + opad_d
@@ -72,16 +73,23 @@ def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding, output_padding):
     bpad_left = filter_w - 1 - fpad_left
     bpad_right = filter_w - 1 - fpad_right + opad_w
 
-    padded_a_np = np.zeros((batch,
-                            in_c,
-                            dilated_a_np.shape[2]+bpad_front+bpad_back,
-                            dilated_a_np.shape[3]+bpad_top+bpad_bottom,
-                            dilated_a_np.shape[4]+bpad_left+bpad_right))
-
-    padded_a_np[:, :, bpad_front:dilated_a_np.shape[2]+bpad_front,
-                bpad_top:dilated_a_np.shape[3]+bpad_top,
-                bpad_left:dilated_a_np.shape[4]+bpad_left] = dilated_a_np
-
+    padded_a_np = np.zeros(
+        (
+            batch,
+            in_c,
+            dilated_a_np.shape[2] + bpad_front + bpad_back,
+            dilated_a_np.shape[3] + bpad_top + bpad_bottom,
+            dilated_a_np.shape[4] + bpad_left + bpad_right,
+        )
+    )
+
+    padded_a_np[
+        :,
+        :,
+        bpad_front : dilated_a_np.shape[2] + bpad_front,
+        bpad_top : dilated_a_np.shape[3] + bpad_top,
+        bpad_left : dilated_a_np.shape[4] + bpad_left,
+    ] = dilated_a_np
 
     # convolution stage
     out_d = (in_d - 1) * stride_d - bpad_front - bpad_back + filter_d
@@ -89,6 +97,8 @@ def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding, output_padding):
     out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w
 
     w_np = np.flip(w_np, axis=[2, 3, 4]).transpose((1, 0, 2, 3, 4))
-    b_np = tvm.topi.testing.conv3d_ncdhw_python(padded_a_np, w_np, stride=(1, 1, 1), padding=(0, 0, 0))
+    b_np = tvm.topi.testing.conv3d_ncdhw_python(
+        padded_a_np, w_np, stride=(1, 1, 1), padding=(0, 0, 0)
+    )
 
     return b_np
index f053656..ac12e81 100644 (file)
@@ -19,7 +19,9 @@
 import numpy as np
 
 
-def correlation_nchw_python(data1, data2, kernel_size, max_displacement, stride1, stride2, padding, is_multiply):
+def correlation_nchw_python(
+    data1, data2, kernel_size, max_displacement, stride1, stride2, padding, is_multiply
+):
     """Correlationn operator in NCHW layout.
 
     Parameters
@@ -65,15 +67,15 @@ def correlation_nchw_python(data1, data2, kernel_size, max_displacement, stride1
     out_channel = neighborhood_grid_width * neighborhood_grid_width
 
     out = np.zeros((data1.shape[0], out_channel, out_height, out_width))
-    pad_data1 = np.zeros((data1.shape[0], data1.shape[1],
-                          pad_data_height, pad_data_width))
-    pad_data2 = np.zeros((data1.shape[0], data1.shape[1],
-                          pad_data_height, pad_data_width))
+    pad_data1 = np.zeros((data1.shape[0], data1.shape[1], pad_data_height, pad_data_width))
+    pad_data2 = np.zeros((data1.shape[0], data1.shape[1], pad_data_height, pad_data_width))
 
-    pad_data1[:, :, padding:padding + data1.shape[2],
-              padding:padding + data1.shape[3]] = data1[:, :, :, :]
-    pad_data2[:, :, padding:padding + data2.shape[2],
-              padding:padding + data2.shape[3]] = data2[:, :, :, :]
+    pad_data1[:, :, padding : padding + data1.shape[2], padding : padding + data1.shape[3]] = data1[
+        :, :, :, :
+    ]
+    pad_data2[:, :, padding : padding + data2.shape[2], padding : padding + data2.shape[3]] = data2[
+        :, :, :, :
+    ]
 
     if is_multiply:
         corr_func = lambda x, y: x * y
@@ -96,8 +98,10 @@ def correlation_nchw_python(data1, data2, kernel_size, max_displacement, stride1
                     for h in range(kernel_size):
                         for w in range(kernel_size):
                             for channel in range(data1.shape[1]):
-                                out[nbatch, q, i, j] += corr_func(pad_data1[nbatch, channel, y1 + h, x1 + w],
-                                                                  pad_data2[nbatch, channel, y2 + h, x2 + w])
+                                out[nbatch, q, i, j] += corr_func(
+                                    pad_data1[nbatch, channel, y1 + h, x1 + w],
+                                    pad_data2[nbatch, channel, y2 + h, x2 + w],
+                                )
 
-    out /= float(kernel_size** 2 *data1.shape[1])
+    out /= float(kernel_size ** 2 * data1.shape[1])
     return out
index a5f2cc0..1796f99 100644 (file)
 import math
 import numpy as np
 
-def crop_and_resize_python(image, boxes, box_indices, crop_size, layout,
-                           method='bilinear', extrapolation_value=0):
+
+def crop_and_resize_python(
+    image, boxes, box_indices, crop_size, layout, method="bilinear", extrapolation_value=0
+):
     """Crop and resize using python"""
     (target_h, target_w) = crop_size
 
-    if layout == 'NHWC':
+    if layout == "NHWC":
         batch = boxes.shape[0]
         image_height, image_width, channel = image.shape[1], image.shape[2], image.shape[3]
         scaled_image = np.ones((batch, target_h, target_w, channel))
@@ -40,8 +42,8 @@ def crop_and_resize_python(image, boxes, box_indices, crop_size, layout,
 
         in_h = (image_height - 1) * (y2 - y1)
         in_w = (image_width - 1) * (x2 - x1)
-        h_scale = np.float32(in_h)/np.float32(target_h - 1)
-        w_scale = np.float32(in_w)/np.float32(target_w - 1)
+        h_scale = np.float32(in_h) / np.float32(target_h - 1)
+        w_scale = np.float32(in_w) / np.float32(target_w - 1)
 
         for y in range(target_h):
 
@@ -50,13 +52,13 @@ def crop_and_resize_python(image, boxes, box_indices, crop_size, layout,
             if in_y < 0 or in_y > image_height - 1:
                 for x in range(target_w):
                     for d in range(channel):
-                        if layout == 'NHWC':
+                        if layout == "NHWC":
                             scaled_image[n][y][x][d] = extrapolation_value
                         else:
                             scaled_image[n][d][y][x] = extrapolation_value
                 continue
 
-            if method == 'bilinear':
+            if method == "bilinear":
                 top_y_index = math.floor(in_y)
                 bottom_y_index = math.ceil(in_y)
                 y_lerp = in_y - top_y_index
@@ -65,7 +67,7 @@ def crop_and_resize_python(image, boxes, box_indices, crop_size, layout,
                     in_x = x1 * (image_width - 1) + x * w_scale
                     if in_x < 0 or in_x > image_width - 1:
                         for d in range(channel):
-                            if layout == 'NHWC':
+                            if layout == "NHWC":
                                 scaled_image[n][y][x][d] = extrapolation_value
                             else:
                                 scaled_image[n][d][y][x] = extrapolation_value
@@ -93,12 +95,12 @@ def crop_and_resize_python(image, boxes, box_indices, crop_size, layout,
                             bottom = bottom_left + (bottom_right - bottom_left) * x_lerp
                             scaled_image[n][d][y][x] = top + (bottom - top) * y_lerp
 
-            elif method == 'nearest_neighbor':
+            elif method == "nearest_neighbor":
                 for x in range(target_w):
                     in_x = x1 * (image_width - 1) + x * w_scale
                     if in_x < 0 or in_x > image_width - 1:
                         for d in range(channel):
-                            if layout == 'NHWC':
+                            if layout == "NHWC":
                                 scaled_image[n][y][x][d] = extrapolation_value
                             else:
                                 scaled_image[n][d][y][x] = extrapolation_value
@@ -107,8 +109,12 @@ def crop_and_resize_python(image, boxes, box_indices, crop_size, layout,
                     closest_y_index = np.round(in_y).astype("int32")
                     for d in range(channel):
                         if layout == "NHWC":
-                            scaled_image[n][y][x][d] = image[b_in][closest_y_index][closest_x_index][d]
+                            scaled_image[n][y][x][d] = image[b_in][closest_y_index][
+                                closest_x_index
+                            ][d]
                         else:
-                            scaled_image[n][d][y][x] = image[b_in][d][closest_y_index][closest_x_index]
+                            scaled_image[n][d][y][x] = image[b_in][d][closest_y_index][
+                                closest_x_index
+                            ]
 
     return scaled_image
index fe48ea5..cc66c5f 100644 (file)
@@ -20,8 +20,10 @@ import itertools
 import numpy as np
 from tvm.topi.nn.util import get_pad_tuple
 
-def deformable_conv2d_nchw_python(a_np, offset_np, w_np, stride, padding, dilation,
-                                  deformable_groups, groups):
+
+def deformable_conv2d_nchw_python(
+    a_np, offset_np, w_np, stride, padding, dilation, deformable_groups, groups
+):
     """Deformable convolution operator in NCHW layout.
 
     Parameters
@@ -77,7 +79,6 @@ def deformable_conv2d_nchw_python(a_np, offset_np, w_np, stride, padding, dilati
     else:
         dilation_h, dilation_w = dilation
 
-
     def _bilinear(n, c, h, w):
         low_h, low_w = int(h), int(w)
         high_h = min(low_h + 1, in_height - 1)
@@ -89,7 +90,6 @@ def deformable_conv2d_nchw_python(a_np, offset_np, w_np, stride, padding, dilati
         top = (1 - x_lerp) * a_np[n, c, high_h, low_w] + x_lerp * a_np[n, c, high_h, high_w]
         return (1 - y_lerp) * bottom + y_lerp * top
 
-
     a_deform = np.zeros((batch, in_channel, out_height, out_width, kernel_h, kernel_w), dtype=dtype)
     for n, h, w in itertools.product(range(batch), range(out_height), range(out_width)):
         offset = offset_np[n, :, h, w].reshape(deformable_groups, kernel_h, kernel_w, 2)
@@ -99,7 +99,8 @@ def deformable_conv2d_nchw_python(a_np, offset_np, w_np, stride, padding, dilati
         index_h_base, index_w_base = np.meshgrid(
             np.arange(in_h, in_h + kernel_h * dilation_h, dilation_h, dtype=offset_np.dtype),
             np.arange(in_w, in_w + kernel_w * dilation_w, dilation_w, dtype=offset_np.dtype),
-            indexing='ij')
+            indexing="ij",
+        )
 
         for c, kh, kw in itertools.product(range(in_channel), range(kernel_h), range(kernel_w)):
             dg = c // ic_per_dgroup
@@ -112,8 +113,9 @@ def deformable_conv2d_nchw_python(a_np, offset_np, w_np, stride, padding, dilati
             a_deform[n, c, h, w, kh, kw] = _bilinear(n, c, y, x)
 
     b_np = np.zeros((batch, out_channel, out_height, out_width), dtype=dtype)
-    for n, c, f, h, w in itertools.product(range(batch), range(in_channel), range(out_channel),
-                                           range(out_height), range(out_width)):
+    for n, c, f, h, w in itertools.product(
+        range(batch), range(in_channel), range(out_channel), range(out_height), range(out_width)
+    ):
         b_np[n, f, h, w] += np.tensordot(a_deform[n, c, h, w], w_np[f, c])
 
     return b_np
index f4a60bc..b5170d5 100644 (file)
@@ -19,7 +19,7 @@
 import numpy as np
 
 
-def depth_to_space_python(data, block_size, mode='DCR'):
+def depth_to_space_python(data, block_size, mode="DCR"):
     """Depth to Space operator in python for NCHW layout.
 
     Parameters
@@ -41,13 +41,11 @@ def depth_to_space_python(data, block_size, mode='DCR'):
     new_w = int(in_h * block_size)
     new_c = int(in_c / (block_size * block_size))
 
-    if mode == 'DCR':
-        expanded = np.reshape(
-            data, newshape=[in_n, block_size, block_size, new_c, in_h, in_w])
+    if mode == "DCR":
+        expanded = np.reshape(data, newshape=[in_n, block_size, block_size, new_c, in_h, in_w])
         transposed = np.transpose(expanded, axes=[0, 3, 4, 1, 5, 2])
     else:
-        expanded = np.reshape(
-            data, newshape=(in_n, new_c, block_size, block_size, in_h, in_w))
+        expanded = np.reshape(data, newshape=(in_n, new_c, block_size, block_size, in_h, in_w))
         transposed = np.transpose(expanded, axes=(0, 1, 4, 2, 5, 3))
     newshape = [in_n, new_c, new_h, new_w]
     d2s_out = np.reshape(transposed, newshape=newshape)
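
For the DCR branch above, the whole operator is one reshape, one transpose with axes [0, 3, 4, 1, 5, 2], and a final reshape. A tiny standalone check on a 1x4x1x1 input with block size 2, values chosen for illustration:

    import numpy as np

    data = np.arange(4, dtype="float32").reshape(1, 4, 1, 1)   # channels 0..3
    bs, new_c, h, w = 2, 1, 1, 1
    expanded = data.reshape(1, bs, bs, new_c, h, w)
    transposed = np.transpose(expanded, axes=[0, 3, 4, 1, 5, 2])
    out = transposed.reshape(1, new_c, h * bs, w * bs)
    # channels fill the 2x2 spatial block in row-major order
    assert np.array_equal(out[0, 0], np.array([[0.0, 1.0], [2.0, 3.0]]))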
index a541bea..06f26ab 100644 (file)
@@ -19,6 +19,7 @@
 import numpy as np
 from scipy import signal
 
+
 def depthwise_conv2d_python_nchw(input_np, filter_np, stride, padding):
     """Depthwise convolution operator in NCHW layout.
 
@@ -49,22 +50,29 @@ def depthwise_conv2d_python_nchw(input_np, filter_np, stride, padding):
         stride_h, stride_w = stride
 
     # calculate output shape
-    if padding == 'VALID':
+    if padding == "VALID":
         out_channel = in_channel * channel_multiplier
         out_height = (in_height - filter_height) // stride_h + 1
         out_width = (in_width - filter_width) // stride_w + 1
         output_np = np.zeros((batch, out_channel, out_height, out_width))
         for i in range(batch):
             for j in range(out_channel):
-                output_np[i, j, :, :] = signal.convolve2d(input_np[i, j//channel_multiplier, :, :], \
-                                                          np.rot90(filter_np[j//channel_multiplier, j%channel_multiplier, :, :], 2), \
-                                                          mode='valid')[0:(in_height - filter_height + 1):stride_h, 0:(in_width - filter_width + 1):stride_w]
-    if padding == 'SAME':
+                output_np[i, j, :, :] = signal.convolve2d(
+                    input_np[i, j // channel_multiplier, :, :],
+                    np.rot90(filter_np[j // channel_multiplier, j % channel_multiplier, :, :], 2),
+                    mode="valid",
+                )[
+                    0 : (in_height - filter_height + 1) : stride_h,
+                    0 : (in_width - filter_width + 1) : stride_w,
+                ]
+    if padding == "SAME":
         out_channel = in_channel * channel_multiplier
         out_height = np.int(np.ceil(float(in_height) / float(stride_h)))
         out_width = np.int(np.ceil(float(in_width) / float(stride_w)))
         output_np = np.zeros((batch, out_channel, out_height, out_width))
-        pad_along_height = np.int(np.max((out_height - 1) * stride_h + filter_height - in_height, 0))
+        pad_along_height = np.int(
+            np.max((out_height - 1) * stride_h + filter_height - in_height, 0)
+        )
         pad_along_width = np.int(np.max((out_width - 1) * stride_w + filter_width - in_width, 0))
         pad_top_tvm = np.int(np.ceil(float(pad_along_height) / 2))
         pad_left_tvm = np.int(np.ceil(float(pad_along_width) / 2))
@@ -74,12 +82,15 @@ def depthwise_conv2d_python_nchw(input_np, filter_np, stride, padding):
         index_w = pad_left_scipy - pad_left_tvm
         for i in range(batch):
             for j in range(out_channel):
-                output_np[i, j, :, :] = signal.convolve2d(input_np[i, j//channel_multiplier, :, :], \
-                                                          np.rot90(filter_np[j//channel_multiplier, j%channel_multiplier, :, :], 2), \
-                                                          mode='same')[index_h:in_height:stride_h, index_w:in_width:stride_w]
+                output_np[i, j, :, :] = signal.convolve2d(
+                    input_np[i, j // channel_multiplier, :, :],
+                    np.rot90(filter_np[j // channel_multiplier, j % channel_multiplier, :, :], 2),
+                    mode="same",
+                )[index_h:in_height:stride_h, index_w:in_width:stride_w]
 
     return output_np
 
+
 def depthwise_conv2d_python_nhwc(input_np, filter_np, stride, padding):
     """Depthwise convolution operator in nchw layout.
 
@@ -110,22 +121,29 @@ def depthwise_conv2d_python_nhwc(input_np, filter_np, stride, padding):
         stride_h, stride_w = stride
 
     # calculate output shape
-    if padding == 'VALID':
+    if padding == "VALID":
         out_channel = in_channel * channel_multiplier
         out_height = (in_height - filter_height) // stride_h + 1
         out_width = (in_width - filter_width) // stride_w + 1
         output_np = np.zeros((batch, out_height, out_width, out_channel))
         for i in range(batch):
             for j in range(out_channel):
-                output_np[i, :, :, j] = signal.convolve2d(input_np[i, :, :, j//channel_multiplier], \
-                                                          np.rot90(filter_np[:, :, j//channel_multiplier, j%channel_multiplier], 2), \
-                                                          mode='valid')[0:(in_height - filter_height + 1):stride_h, 0:(in_width - filter_width + 1):stride_w]
-    if padding == 'SAME':
+                output_np[i, :, :, j] = signal.convolve2d(
+                    input_np[i, :, :, j // channel_multiplier],
+                    np.rot90(filter_np[:, :, j // channel_multiplier, j % channel_multiplier], 2),
+                    mode="valid",
+                )[
+                    0 : (in_height - filter_height + 1) : stride_h,
+                    0 : (in_width - filter_width + 1) : stride_w,
+                ]
+    if padding == "SAME":
         out_channel = in_channel * channel_multiplier
         out_height = np.int(np.ceil(float(in_height) / float(stride_h)))
         out_width = np.int(np.ceil(float(in_width) / float(stride_w)))
         output_np = np.zeros((batch, out_height, out_width, out_channel))
-        pad_along_height = np.int(np.max((out_height - 1) * stride_h + filter_height - in_height, 0))
+        pad_along_height = np.int(
+            np.max((out_height - 1) * stride_h + filter_height - in_height, 0)
+        )
         pad_along_width = np.int(np.max((out_width - 1) * stride_w + filter_width - in_width, 0))
         pad_top_tvm = np.int(np.ceil(float(pad_along_height) / 2))
         pad_left_tvm = np.int(np.ceil(float(pad_along_width) / 2))
@@ -135,8 +153,10 @@ def depthwise_conv2d_python_nhwc(input_np, filter_np, stride, padding):
         index_w = pad_left_scipy - pad_left_tvm
         for i in range(batch):
             for j in range(out_channel):
-                output_np[i, :, :, j] = signal.convolve2d(input_np[i, :, :, j//channel_multiplier], \
-                                                          np.rot90(filter_np[:, :, j//channel_multiplier, j%channel_multiplier], 2), \
-                                                          mode='same')[index_h:in_height:stride_h, index_w:in_width:stride_w]
+                output_np[i, :, :, j] = signal.convolve2d(
+                    input_np[i, :, :, j // channel_multiplier],
+                    np.rot90(filter_np[:, :, j // channel_multiplier, j % channel_multiplier], 2),
+                    mode="same",
+                )[index_h:in_height:stride_h, index_w:in_width:stride_w]
 
     return output_np
index 8eaef92..b4fff24 100644 (file)
@@ -36,12 +36,14 @@ def dilate_python(input_np, strides):
         n-D, the same layout as Input.
     """
     n = len(input_np.shape)
-    assert len(strides) == n, \
-        "Input dimension and strides size dismatch : %d vs %d" %(n, len(strides))
+    assert len(strides) == n, "Input dimension and strides size dismatch : %d vs %d" % (
+        n,
+        len(strides),
+    )
     output_size = ()
     no_zero = ()
     for i in range(n):
-        output_size += ((input_np.shape[i]-1)*strides[i]+1,)
+        output_size += ((input_np.shape[i] - 1) * strides[i] + 1,)
         no_zero += ((range(0, output_size[i], strides[i])),)
     output_np = np.zeros(shape=output_size)
     output_np[np.ix_(*no_zero)] = input_np
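
The dilation reference above allocates (shape[i] - 1) * strides[i] + 1 per axis and scatters the input onto a strided grid via np.ix_. The same scatter in 1D, with illustrative values:

    import numpy as np

    x = np.array([1.0, 2.0, 3.0])
    stride = 2
    out = np.zeros((len(x) - 1) * stride + 1)          # (3 - 1) * 2 + 1 == 5
    out[np.ix_(range(0, len(out), stride))] = x        # place inputs on every stride-th slot
    assert np.array_equal(out, np.array([1.0, 0.0, 2.0, 0.0, 3.0]))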
index 25c8a9f..89e24c1 100644 (file)
@@ -18,8 +18,9 @@
 """gather_nd in python"""
 import numpy as np
 
+
 def gather_nd_python(a_np, indices_np):
-    """ Python version of GatherND operator
+    """Python version of GatherND operator
 
     Parameters
     ----------
@@ -35,7 +36,7 @@ def gather_nd_python(a_np, indices_np):
         Numpy array
     """
     a_shape = a_np.shape
-    indices_np = indices_np.astype('int32')
+    indices_np = indices_np.astype("int32")
     indices_shape = indices_np.shape
     assert len(indices_shape) > 1
     assert indices_shape[0] <= len(a_shape)
index 0f3573c..33479e5 100644 (file)
@@ -18,8 +18,9 @@
 """gather in python"""
 import numpy as np
 
+
 def gather_python(data, axis, indices):
-    """ Python version of Gather operator
+    """Python version of Gather operator
 
     Parameters
     ----------
index 964d8a2..a8c304f 100644 (file)
@@ -21,8 +21,7 @@ import numpy as np
 
 
 def affine_grid_python(data, target_shape):
-    yv, xv = np.meshgrid(
-        np.arange(target_shape[0]), np.arange(target_shape[1]))
+    yv, xv = np.meshgrid(np.arange(target_shape[0]), np.arange(target_shape[1]))
     yv = yv.T * 2 / (target_shape[0] - 1) - 1
     xv = xv.T * 2 / (target_shape[1] - 1) - 1
     ones = np.ones_like(xv)
@@ -59,7 +58,7 @@ def _bilinear_sample_nchw_python(data, grid):
     return out
 
 
-def grid_sample_nchw_python(data, grid, method='bilinear'):
-    if method == 'bilinear':
+def grid_sample_nchw_python(data, grid, method="bilinear"):
+    if method == "bilinear":
         return _bilinear_sample_nchw_python(data, grid)
     raise ValueError("invalid method")
index c333fa5..42b1266 100644 (file)
@@ -18,6 +18,7 @@
 """L2 normalize in python"""
 import numpy as np
 
+
 def l2_normalize_python(a_np, eps, axis=None):
     """L2 normalize operator in NCHW layout.
 
index 9af662f..6bc6b44 100644 (file)
@@ -19,6 +19,7 @@
 from itertools import product
 import numpy as np
 
+
 def lrn_python(a_np, size, axis, bias, alpha, beta):
     """Local response normalization operator in NCHW layout.
 
@@ -52,18 +53,20 @@ def lrn_python(a_np, size, axis, bias, alpha, beta):
     for i, j, k, l in product(*[range(_axis) for _axis in a_np.shape]):
         axis_size = a_np.shape[axis]
         if axis == 1:
-            #NCHW layout
-            sum_start = j-radius if j-radius >= 0 else 0
-            sum_end = j+radius+1 if j+radius+1 < axis_size else axis_size
-            sqr_sum[i, j, k, l] = sum(a_np[i, sum_start:sum_end, k, l] * \
-                                      a_np[i, sum_start:sum_end, k, l])
+            # NCHW layout
+            sum_start = j - radius if j - radius >= 0 else 0
+            sum_end = j + radius + 1 if j + radius + 1 < axis_size else axis_size
+            sqr_sum[i, j, k, l] = sum(
+                a_np[i, sum_start:sum_end, k, l] * a_np[i, sum_start:sum_end, k, l]
+            )
         elif axis == 3:
-            #NHWC layout
-            sum_start = l-radius if l-radius >= 0 else 0
-            sum_end = l+radius+1 if l+radius+1 < axis_size else axis_size
-            sqr_sum[i, j, k, l] = sum(a_np[i, j, k, sum_start:sum_end] * \
-                                      a_np[i, j, k, sum_start:sum_end])
+            # NHWC layout
+            sum_start = l - radius if l - radius >= 0 else 0
+            sum_end = l + radius + 1 if l + radius + 1 < axis_size else axis_size
+            sqr_sum[i, j, k, l] = sum(
+                a_np[i, j, k, sum_start:sum_end] * a_np[i, j, k, sum_start:sum_end]
+            )
 
-    sqr_sum_up = np.power((bias + (alpha * sqr_sum /size)), beta)
+    sqr_sum_up = np.power((bias + (alpha * sqr_sum / size)), beta)
     lrn_out = np.divide(a_np, sqr_sum_up)
     return lrn_out
index e0a8914..63edd0a 100644 (file)
@@ -18,6 +18,7 @@
 """MatrixSetDiag in Python"""
 import numpy as np
 
+
 def matrix_set_diag(input_np, diagonal):
     """matrix_set_diag operator implemented in numpy.
 
index 05834e3..0c4b060 100644 (file)
@@ -18,6 +18,7 @@
 """OneHot in python"""
 import numpy as np
 
+
 def one_hot(indices, on_value, off_value, depth, axis, dtype):
     """one_hot operator implemented in numpy.
 
index 90f2a07..d83b722 100644 (file)
@@ -20,11 +20,17 @@ import math
 import numpy as np
 
 
-def pool1d_ncw_python(np_data, kernel,
-                      strides, padding,
-                      out_shape, pool_type,
-                      count_include_pad=True,
-                      ceil_mode=False, dtype="float32"):
+def pool1d_ncw_python(
+    np_data,
+    kernel,
+    strides,
+    padding,
+    out_shape,
+    pool_type,
+    count_include_pad=True,
+    ceil_mode=False,
+    dtype="float32",
+):
     """Baseline for max_pool1d and avg_pool1d, default layout is NCW"""
     in_n, in_c, in_w = in_shape = np_data.shape
     k_w = kernel[0]
@@ -32,11 +38,9 @@ def pool1d_ncw_python(np_data, kernel,
     pl, pr = padding
 
     if ceil_mode:
-        assert out_shape[2] == int(
-            math.ceil(float(in_shape[2] - k_w + pl + pr) / s_w) + 1)
+        assert out_shape[2] == int(math.ceil(float(in_shape[2] - k_w + pl + pr) / s_w) + 1)
     else:
-        assert out_shape[2] == int(math.floor(
-            float(in_shape[2] - k_w + pl + pr) / s_w) + 1)
+        assert out_shape[2] == int(math.floor(float(in_shape[2] - k_w + pl + pr) / s_w) + 1)
 
     pad_np = np.zeros(shape=(in_n, in_c, in_w + pl + pr)).astype(dtype)
 
@@ -44,20 +48,19 @@ def pool1d_ncw_python(np_data, kernel,
     pad_np[np.ix_(*no_zero)] = np_data
     ret_np = np.zeros(shape=out_shape).astype(dtype)
 
-    if pool_type == 'avg':
+    if pool_type == "avg":
         for k in range(out_shape[2]):
             if count_include_pad:
-                ret_np[:, :, k] = np.mean(
-                    pad_np[:, :, k * s_w: k * s_w + k_w], axis=(2,))
+                ret_np[:, :, k] = np.mean(pad_np[:, :, k * s_w : k * s_w + k_w], axis=(2,))
             else:
-                pad_count = np.sum(
-                    pad_np[:, :, k * s_w: k * s_w + k_w] > 0, axis=(2,))
+                pad_count = np.sum(pad_np[:, :, k * s_w : k * s_w + k_w] > 0, axis=(2,))
                 ret_np[:, :, k] = np.sum(
-                    pad_np[:, :, k * s_w: k * s_w + k_w], axis=(2,)) / np.maximum(pad_count, 1)
+                    pad_np[:, :, k * s_w : k * s_w + k_w], axis=(2,)
+                ) / np.maximum(pad_count, 1)
 
-    elif pool_type == 'max':
+    elif pool_type == "max":
         for k in range(out_shape[2]):
-            ret_np[:, :, k] = np.max(pad_np[:, :, k * s_w: k * s_w + k_w], axis=(2,))
+            ret_np[:, :, k] = np.max(pad_np[:, :, k * s_w : k * s_w + k_w], axis=(2,))
 
     else:
         raise ValueError("Pool type {} is not supported".format(pool_type))
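
The avg branch above is a mean over each k_w-wide window starting at k * s_w. A one-channel NCW example of that windowing, with arbitrary sizes:

    import numpy as np

    x = np.arange(6, dtype="float32")[None, None, :]   # NCW with N = C = 1
    k_w, s_w = 2, 2
    out = np.stack(
        [np.mean(x[:, :, k * s_w : k * s_w + k_w], axis=2) for k in range(3)], axis=2
    )
    assert np.array_equal(out[0, 0], np.array([0.5, 2.5, 4.5]))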
index ee671c2..bf281e5 100644 (file)
 import numpy as np
 
 
-def pool_grad_nchw(a_np, out_grad_np,
-                   pool_size,
-                   strides,
-                   padding,
-                   pool_type,
-                   ceil_mode,
-                   count_include_pad=True):
+def pool_grad_nchw(
+    a_np, out_grad_np, pool_size, strides, padding, pool_type, ceil_mode, count_include_pad=True
+):
     """pool_grad for NCHW layout in python"""
     dtype = a_np.dtype
     n, ic, ih, iw = a_np.shape
@@ -33,37 +29,40 @@ def pool_grad_nchw(a_np, out_grad_np,
     sh, sw = strides
     pt, pl, pb, pr = padding
 
-    pad_np = np.zeros(shape=(n, ic, ih+pt+pb, iw+pl+pr)).astype(dtype)
-    no_zero = (range(n), range(ic), (range(pt, ih+pt)), (range(pl, iw+pl)))
+    pad_np = np.zeros(shape=(n, ic, ih + pt + pb, iw + pl + pr)).astype(dtype)
+    no_zero = (range(n), range(ic), (range(pt, ih + pt)), (range(pl, iw + pl)))
     pad_np[np.ix_(*no_zero)] = a_np
     _, _, oh, ow = out_grad_np.shape
     pool_grad_np = np.zeros(shape=a_np.shape)
     pad_pool_grad_np = np.zeros(shape=pad_np.shape)
 
-    if pool_type == 'avg':
+    if pool_type == "avg":
         for i in range(oh):
             for j in range(ow):
                 if count_include_pad:
-                    shape = pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw].shape
+                    shape = pad_np[:, :, i * sh : i * sh + kh, j * sw : j * sw + kw].shape
                     # this can be different from kh*kw if input size cannot divide stride
                     pad_count = shape[2] * shape[3]
                 else:
                     pad_count = np.sum(
-                        pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2, 3))
+                        pad_np[:, :, i * sh : i * sh + kh, j * sw : j * sw + kw] > 0, axis=(2, 3)
+                    )
                     # take the first element, as they are the same across batch and channel
                     pad_count = pad_count.ravel()[0]
-                pad_pool_grad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] += \
-                    out_grad_np[:, :, i, j].reshape(n, ic, 1, 1) / np.maximum(pad_count, 1)
-    elif pool_type == 'max':
+                pad_pool_grad_np[:, :, i * sh : i * sh + kh, j * sw : j * sw + kw] += out_grad_np[
+                    :, :, i, j
+                ].reshape(n, ic, 1, 1) / np.maximum(pad_count, 1)
+    elif pool_type == "max":
         for i in range(oh):
             for j in range(ow):
-                a_patch = pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw]
+                a_patch = pad_np[:, :, i * sh : i * sh + kh, j * sw : j * sw + kw]
                 a_patch = np.reshape(a_patch, (n, ic, -1))
                 max_indices = np.argmax(a_patch, axis=2)
                 c_idx, n_idx = np.meshgrid(range(ic), range(n), sparse=True)
                 h_idx, w_idx = np.unravel_index(max_indices, (kh, kw))
-                pad_pool_grad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw][n_idx, c_idx, h_idx, w_idx] += \
-                    out_grad_np[n_idx, c_idx, i, j]
+                pad_pool_grad_np[:, :, i * sh : i * sh + kh, j * sw : j * sw + kw][
+                    n_idx, c_idx, h_idx, w_idx
+                ] += out_grad_np[n_idx, c_idx, i, j]
     for i in range(pool_grad_np.shape[2]):
         for j in range(pool_grad_np.shape[3]):
             pool_grad_np[:, :, i, j] = pad_pool_grad_np[:, :, i + pt, j + pl]
index 00b1dee..9ef35c4 100644 (file)
@@ -18,6 +18,7 @@
 """Reorg in python"""
 import numpy as np
 
+
 def reorg_python(a_np, stride):
     """Reorg operator
 
@@ -36,12 +37,12 @@ def reorg_python(a_np, stride):
     """
 
     batch, in_channel, in_height, in_width = a_np.shape
-    a_np = np.reshape(a_np, batch*in_channel*in_height*in_width)
-    out_c = int(in_channel/(stride*stride))
-    out_channel = in_channel*stride*stride
-    out_height = int(in_height/stride)
-    out_width = int(in_width/stride)
-    b_np = np.zeros(batch*out_channel*out_height*out_width)
+    a_np = np.reshape(a_np, batch * in_channel * in_height * in_width)
+    out_c = int(in_channel / (stride * stride))
+    out_channel = in_channel * stride * stride
+    out_height = int(in_height / stride)
+    out_width = int(in_width / stride)
+    b_np = np.zeros(batch * out_channel * out_height * out_width)
     cnt = 0
     for b in range(batch):
         for k in range(in_channel):
@@ -49,10 +50,12 @@ def reorg_python(a_np, stride):
                 for i in range(in_width):
                     c2 = k % out_c
                     offset = int(k / out_c)
-                    w2 = int(i*stride + offset % stride)
-                    h2 = int(j*stride + offset / stride)
-                    out_index = int(w2 + in_width*stride*(h2 + in_height*stride*(c2 + out_c*b)))
+                    w2 = int(i * stride + offset % stride)
+                    h2 = int(j * stride + offset / stride)
+                    out_index = int(
+                        w2 + in_width * stride * (h2 + in_height * stride * (c2 + out_c * b))
+                    )
                     b_np[cnt] = a_np[int(out_index)]
-                    cnt = cnt+1
+                    cnt = cnt + 1
     b_np = np.reshape(b_np, (batch, out_channel, out_height, out_width))
     return b_np
index d328549..5bb292c 100644 (file)
@@ -19,6 +19,7 @@
 import math
 import numpy as np
 
+
 def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio):
     """Roi align in python"""
     _, channel, height, width = a_np.shape
@@ -43,10 +44,12 @@ def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_rati
 
         ly = y - y_low
         lx = x - x_low
-        return (1 - ly) * (1 - lx) * a_np[b, c, y_low, x_low] + \
-               (1 - ly) * lx * a_np[b, c, y_low, x_high] + \
-            ly * (1 - lx) * a_np[b, c, y_high, x_low] + \
-            ly * lx * a_np[b, c, y_high, x_high]
+        return (
+            (1 - ly) * (1 - lx) * a_np[b, c, y_low, x_low]
+            + (1 - ly) * lx * a_np[b, c, y_low, x_high]
+            + ly * (1 - lx) * a_np[b, c, y_high, x_low]
+            + ly * lx * a_np[b, c, y_high, x_high]
+        )
 
     for i in range(num_roi):
         roi = rois_np[i]
@@ -69,7 +72,7 @@ def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_rati
         for c in range(channel):
             for ph in range(pooled_size_h):
                 for pw in range(pooled_size_w):
-                    total = 0.
+                    total = 0.0
                     for iy in range(roi_bin_grid_h):
                         for ix in range(roi_bin_grid_w):
                             y = roi_start_h + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h
index 075d7a9..08d9f16 100644 (file)
@@ -19,6 +19,7 @@
 import math
 import numpy as np
 
+
 def roi_pool_nchw_python(a_np, rois_np, pooled_size, spatial_scale):
     """Roi pool in python"""
     _, channel, height, width = a_np.shape
@@ -57,7 +58,7 @@ def roi_pool_nchw_python(a_np, rois_np, pooled_size, spatial_scale):
 
                 for c in range(channel):
                     if is_empty:
-                        b_np[i, c, ph, pw] = 0.
+                        b_np[i, c, ph, pw] = 0.0
                     else:
                         b_np[i, c, ph, pw] = np.max(a_np[batch_index, c, hstart:hend, wstart:wend])
     return b_np
index d77eb6f..9b67fb4 100644 (file)
@@ -18,6 +18,7 @@
 """Sequence mask in python"""
 import numpy as np
 
+
 def sequence_mask(data, valid_length, mask_value, axis):
     """batch_matmul operator implemented in numpy.
 
@@ -46,7 +47,8 @@ def sequence_mask(data, valid_length, mask_value, axis):
     val_len_expand_shape[1 - axis] = in_shape[1 - axis]
     seq_len_expand_shape = [1 for _ in range(len(in_shape))]
     seq_len_expand_shape[axis] = in_shape[axis]
-    mask = np.broadcast_to(np.arange(max_length).reshape(seq_len_expand_shape),
-                           in_shape) >= valid_length.reshape(val_len_expand_shape)
+    mask = np.broadcast_to(
+        np.arange(max_length).reshape(seq_len_expand_shape), in_shape
+    ) >= valid_length.reshape(val_len_expand_shape)
     out = data * (1 - mask) + mask_value * mask
     return out
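
The mask above comes from broadcasting an arange over the sequence axis against the per-example valid lengths and blending with mask_value. A minimal (batch, max_len) example; the layout, lengths, and mask value of -1 are assumptions for illustration:

    import numpy as np

    data = np.ones((2, 5))                         # (batch, max_len)
    valid_length = np.array([3, 5])
    mask = np.arange(5)[None, :] >= valid_length[:, None]
    out = data * (1 - mask) + (-1.0) * mask        # mask_value = -1
    assert np.array_equal(out[0], np.array([1.0, 1.0, 1.0, -1.0, -1.0]))
    assert np.array_equal(out[1], np.ones(5))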
index 4a85988..c01a7b5 100644 (file)
@@ -16,6 +16,7 @@
 # under the License.
 """Slice axis in python"""
 
+
 def slice_axis_python(data, axis, begin, end=None):
     """Slice input array along specific axis.
 
index 119625c..da2893d 100644 (file)
@@ -18,6 +18,7 @@
 """Softmax and log_softmax operation in python"""
 import numpy as np
 
+
 def softmax_python(a_np):
     """Softmax operator.
     Parameters
@@ -33,11 +34,12 @@ def softmax_python(a_np):
     assert len(a_np.shape) == 2, "only support 2-dim softmax"
     max_elem = np.amax(a_np, axis=1)
     max_elem = max_elem.reshape(max_elem.shape[0], 1)
-    e = np.exp(a_np-max_elem)
+    e = np.exp(a_np - max_elem)
     expsum = np.sum(e, axis=1)
     out_np = e / expsum[:, None]
     return out_np
 
+
 def log_softmax_python(a_np):
     """Log_softmax operator.
     Parameters
@@ -53,7 +55,7 @@ def log_softmax_python(a_np):
     assert len(a_np.shape) == 2, "only support 2-dim log_softmax"
     max_elem = np.amax(a_np, axis=1)
     max_elem = max_elem.reshape(max_elem.shape[0], 1)
-    e = np.exp(a_np-max_elem)
+    e = np.exp(a_np - max_elem)
     expsum = np.sum(e, axis=1)
     out_np = a_np - max_elem - np.log(expsum[:, None])
     return out_np
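
Both softmax paths above subtract the per-row maximum before exponentiating so large inputs do not overflow. A short standalone check that the stabilized form still normalizes; the random input is illustrative only:

    import numpy as np

    a = (np.random.randn(4, 10) * 50.0).astype("float32")  # deliberately large magnitudes
    m = np.amax(a, axis=1, keepdims=True)
    e = np.exp(a - m)                                       # shift prevents overflow
    p = e / np.sum(e, axis=1, keepdims=True)
    assert np.isfinite(p).all()
    assert np.allclose(p.sum(axis=1), 1.0, atol=1e-5)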
index 3a3b941..e56a12d 100644 (file)
@@ -41,8 +41,7 @@ def space_to_depth_python(data, block_size):
     new_w = int(in_h / block_size)
     new_c = int(in_c * (block_size * block_size))
 
-    expanded = np.reshape(
-        data, newshape=[in_n, in_c, new_h, block_size, new_w, block_size])
+    expanded = np.reshape(data, newshape=[in_n, in_c, new_h, block_size, new_w, block_size])
     transposed = np.transpose(expanded, axes=[0, 3, 5, 1, 2, 4])
     newshape = [in_n, new_c, new_h, new_w]
     d2s_out = np.reshape(transposed, newshape=newshape)
index 970e1de..c5eb723 100644 (file)
@@ -65,9 +65,7 @@ def strided_slice_python(data, begin, end, strides, slice_mode="end"):
         else:
             new_end = end[i]
 
-        slices.append(slice(new_begin,
-                            new_end,
-                            new_stride))
+        slices.append(slice(new_begin, new_end, new_stride))
     return data[tuple(slices)]
 
 
@@ -100,9 +98,12 @@ def strided_set_python(data, v, begin, end, strides):
     slices = []
     res = data.copy()
     for i in range(len(data.shape)):
-        slices.append(slice(
-            begin[i] if i < len(begin) else None,
-            end[i] if i < len(end) else None,
-            strides[i] if i < len(strides) else None))
+        slices.append(
+            slice(
+                begin[i] if i < len(begin) else None,
+                end[i] if i < len(end) else None,
+                strides[i] if i < len(strides) else None,
+            )
+        )
     res[tuple(slices)] = v
     return res
index cc8fdd6..de1e230 100644 (file)
 import math
 import numpy as np
 
-def trilinear_resize3d_python(data_in, out_size, layout,
-                              coordinate_transformation_mode="align_corners"):
+
+def trilinear_resize3d_python(
+    data_in, out_size, layout, coordinate_transformation_mode="align_corners"
+):
     """ Trilinear 3d scaling using python"""
     (new_d, new_h, new_w) = out_size
 
-    if layout == 'NDHWC':
+    if layout == "NDHWC":
         (batch, d, h, w, channel) = data_in.shape
         data_out = np.ones((batch, new_d, new_h, new_w, channel))
     else:
@@ -32,16 +34,17 @@ def trilinear_resize3d_python(data_in, out_size, layout,
         data_out = np.ones((batch, channel, new_d, new_h, new_w))
 
     if coordinate_transformation_mode == "align_corners":
-        depth_scale = np.float32(d-1) / np.float32(out_size[0]-1)
-        height_scale = np.float32(h-1) / np.float32(out_size[1]-1)
-        width_scale = np.float32(w-1) / np.float32(out_size[2]-1)
+        depth_scale = np.float32(d - 1) / np.float32(out_size[0] - 1)
+        height_scale = np.float32(h - 1) / np.float32(out_size[1] - 1)
+        width_scale = np.float32(w - 1) / np.float32(out_size[2] - 1)
     elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]:
         depth_scale = np.float32(d) / np.float32(out_size[0])
         height_scale = np.float32(h) / np.float32(out_size[1])
         width_scale = np.float32(w) / np.float32(out_size[2])
     else:
-        raise ValueError("Unsupported coordinate_transformation_mode: {}".format(
-            coordinate_transformation_mode))
+        raise ValueError(
+            "Unsupported coordinate_transformation_mode: {}".format(coordinate_transformation_mode)
+        )
 
     def _lerp(A, B, t):
         return A * (1.0 - t) + B * t
@@ -62,14 +65,17 @@ def trilinear_resize3d_python(data_in, out_size, layout,
             for m in range(new_d):
                 for j in range(new_h):
                     for k in range(new_w):
-                        z0, z1, z_lerp = _in_coord(m, depth_scale, d,\
-                                                   coordinate_transformation_mode)
-                        y0, y1, y_lerp = _in_coord(j, height_scale, h,\
-                                                   coordinate_transformation_mode)
-                        x0, x1, x_lerp = _in_coord(k, width_scale, w,\
-                                                   coordinate_transformation_mode)
+                        z0, z1, z_lerp = _in_coord(
+                            m, depth_scale, d, coordinate_transformation_mode
+                        )
+                        y0, y1, y_lerp = _in_coord(
+                            j, height_scale, h, coordinate_transformation_mode
+                        )
+                        x0, x1, x_lerp = _in_coord(
+                            k, width_scale, w, coordinate_transformation_mode
+                        )
 
-                        if layout == 'NDHWC':
+                        if layout == "NDHWC":
                             A0 = data_in[b][z0][y0][x0][i]
                             B0 = data_in[b][z0][y0][x1][i]
                             C0 = data_in[b][z0][y1][x0][i]
@@ -97,7 +103,7 @@ def trilinear_resize3d_python(data_in, out_size, layout,
 
                         pixel = np.float32(_lerp(top, bottom, y_lerp))
 
-                        if layout == 'NDHWC':
+                        if layout == "NDHWC":
                             data_out[b][m][j][k][i] = pixel
                         else:
                             data_out[b][i][m][j][k] = pixel
index 8cc00ad..203e804 100644 (file)
@@ -34,13 +34,18 @@ def upsample_nearest(arr, scale):
             out[y, x] = arr[in_y, in_x]
     return out
 
-def upsampling_python(data, scale, layout='NCHW'):
+
+def upsampling_python(data, scale, layout="NCHW"):
     """ Python version of scaling using nearest neighbour """
 
     ishape = data.shape
-    if layout == 'NCHW':
-        oshape = (ishape[0], ishape[1], int(round(ishape[2]*scale[0])),
-                  int(round(ishape[3]*scale[1])))
+    if layout == "NCHW":
+        oshape = (
+            ishape[0],
+            ishape[1],
+            int(round(ishape[2] * scale[0])),
+            int(round(ishape[3] * scale[1])),
+        )
         output_np = np.zeros(oshape, dtype=data.dtype)
         for b in range(oshape[0]):
             for c in range(oshape[1]):
@@ -48,19 +53,31 @@ def upsampling_python(data, scale, layout='NCHW'):
         return output_np
     # NCHWinic
     if nchw_pack_layout(layout):
-        oshape = (ishape[0], ishape[1], int(round(ishape[2]*scale[0])),
-                  int(round(ishape[3]*scale[1])), ishape[4], ishape[5])
+        oshape = (
+            ishape[0],
+            ishape[1],
+            int(round(ishape[2] * scale[0])),
+            int(round(ishape[3] * scale[1])),
+            ishape[4],
+            ishape[5],
+        )
         output_np = np.zeros(oshape, dtype=data.dtype)
         for b in range(oshape[0]):
             for ib in range(oshape[4]):
                 for c in range(oshape[1]):
                     for ic in range(oshape[5]):
-                        output_np[b, c, :, :, ib, ic] = upsample_nearest(data[b, c, :, :, ib, ic], scale)
+                        output_np[b, c, :, :, ib, ic] = upsample_nearest(
+                            data[b, c, :, :, ib, ic], scale
+                        )
         return output_np
 
-    if layout == 'NHWC':
-        oshape = (ishape[0], int(round(ishape[1]*scale[0])),
-                  int(round(ishape[2]*scale[1])), ishape[3])
+    if layout == "NHWC":
+        oshape = (
+            ishape[0],
+            int(round(ishape[1] * scale[0])),
+            int(round(ishape[2] * scale[1])),
+            ishape[3],
+        )
         output_np = np.zeros(oshape, dtype=data.dtype)
         for b in range(oshape[0]):
             for c in range(oshape[3]):
@@ -68,6 +85,7 @@ def upsampling_python(data, scale, layout='NCHW'):
         return output_np
     raise ValueError("not support this layout {} yet".format(layout))
 
+
 def upsample3d_nearest(arr, scale):
     """ Populate the array by scale factor"""
     d, h, w = arr.shape
@@ -84,25 +102,32 @@ def upsample3d_nearest(arr, scale):
                 out[z, y, x] = arr[in_z, in_y, in_x]
     return out
 
-def upsampling3d_python(data, scale, layout='NCDHW'):
+
+def upsampling3d_python(data, scale, layout="NCDHW"):
     """ Python version of 3D scaling using nearest neighbour """
 
     ishape = data.shape
-    if layout == 'NCDHW':
-        oshape = (ishape[0], ishape[1],
-                  int(round(ishape[2]*scale[0])),
-                  int(round(ishape[3]*scale[1])),
-                  int(round(ishape[4]*scale[2])))
+    if layout == "NCDHW":
+        oshape = (
+            ishape[0],
+            ishape[1],
+            int(round(ishape[2] * scale[0])),
+            int(round(ishape[3] * scale[1])),
+            int(round(ishape[4] * scale[2])),
+        )
         output_np = np.zeros(oshape, dtype=data.dtype)
         for b in range(oshape[0]):
             for c in range(oshape[1]):
                 output_np[b, c, :, :, :] = upsample3d_nearest(data[b, c, :, :, :], scale)
         return output_np
-    if layout == 'NDHWC':
-        oshape = (ishape[0],
-                  int(round(ishape[1]*scale[0])),
-                  int(round(ishape[2]*scale[1])),
-                  int(round(ishape[3]*scale[2])), ishape[4])
+    if layout == "NDHWC":
+        oshape = (
+            ishape[0],
+            int(round(ishape[1] * scale[0])),
+            int(round(ishape[2] * scale[1])),
+            int(round(ishape[3] * scale[2])),
+            ishape[4],
+        )
         output_np = np.zeros(oshape, dtype=data.dtype)
         for b in range(oshape[0]):
             for c in range(oshape[4]):
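
Aside: upsample_nearest above copies each output pixel from the floor of its scaled coordinate. A vectorized NumPy sketch of the same idea, standalone and with an illustrative name:

import numpy as np

def upsample_nearest_np(arr, scale):
    # Vectorized nearest-neighbour upsampling of a 2-D array; each output
    # pixel copies the input pixel at the floor of its scaled coordinate,
    # mirroring the looped reference helper above.
    h, w = arr.shape
    out_h, out_w = int(round(h * scale[0])), int(round(w * scale[1]))
    rows = np.minimum((np.arange(out_h) / scale[0]).astype(int), h - 1)
    cols = np.minimum((np.arange(out_w) / scale[1]).astype(int), w - 1)
    return arr[rows[:, None], cols[None, :]]

a = np.array([[1, 2], [3, 4]])
print(upsample_nearest_np(a, (2, 2)))
# [[1 1 2 2]
#  [1 1 2 2]
#  [3 3 4 4]
#  [3 3 4 4]]
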
index 1681d87..6af0828 100644 (file)
@@ -83,8 +83,11 @@ def expand_like(a, shape_like, axis):
         if len(a.shape) == 1 and len(axis) == len(shape_like.shape):
             # A special case: `a` is a scalar represented as a 1-dim tensor
             return te.compute(shape_like.shape, lambda *idxs: a(0))
-        raise ValueError("shape inconsistent when expand_like ({}, {}, {})".format(
-            len(axis), len(a.shape), len(shape_like.shape)))
+        raise ValueError(
+            "shape inconsistent when expand_like ({}, {}, {})".format(
+                len(axis), len(a.shape), len(shape_like.shape)
+            )
+        )
 
     real_axis = topi.reduction._get_real_axis(len(shape_like.shape), axis)
     real_axis = sorted(real_axis)
@@ -97,6 +100,7 @@ def expand_like(a, shape_like, axis):
                 indices.append(idxs[i])
                 axis_index += 1
         return a(*indices)
+
     return te.compute(shape_like.shape, _compute)
 
 
@@ -200,7 +204,8 @@ def strided_slice(a, begin, end, strides=None, slice_mode="end"):
         strides = []
     return cpp.strided_slice(a, begin, end, strides, slice_mode)
 
-@tvm.te.tag_scope(tag=tag.INJECTIVE+",strided_set")
+
+@tvm.te.tag_scope(tag=tag.INJECTIVE + ",strided_set")
 def strided_set(a, v, begin, end, strides=None):
     """Set slice of an array.
 
@@ -231,59 +236,57 @@ def strided_set(a, v, begin, end, strides=None):
 
     if len(begin.shape) != 1:
         raise ValueError("begin should be a vector")
-    if not begin.dtype == 'int32':
+    if not begin.dtype == "int32":
         raise TypeError("begin should be int32")
     if len(end.shape) != 1:
         raise ValueError("end should be a vector")
-    if not end.dtype == 'int32':
+    if not end.dtype == "int32":
         raise TypeError("end should be int32")
     if strides is not None:
         if len(strides.shape) != 1:
             raise ValueError("strides should be a vector")
-        if not strides.dtype == 'int32':
+        if not strides.dtype == "int32":
             raise TypeError("strides should be int32")
 
     def _max(a, b):
         return tvm.tir.Select(a > b, a, b)
 
     if strides is None:
-        strides = [tvm.tir.const(1, 'int32')] * n
+        strides = [tvm.tir.const(1, "int32")] * n
     else:
-        strides = [tvm.tir.if_then_else(strides.shape[0] > i,
-                                        strides[i],
-                                        tvm.tir.const(1, 'int32'))
-                   for i in range(n)]
-
-    begin = [tvm.tir.if_then_else(begin.shape[0] > i,
-                                  begin[i],
-                                  tvm.tir.Select(strides[i] > 0,
-                                                 tvm.tir.const(0, 'int32'),
-                                                 a.shape[i]))
-             for i in range(n)]
-    end = [tvm.tir.if_then_else(end.shape[0] > i,
-                                end[i],
-                                tvm.tir.Select(strides[i] > 0,
-                                               a.shape[i] + 1,
-                                               -(a.shape[i] + 1)))
-           for i in range(n)]
-
+        strides = [
+            tvm.tir.if_then_else(strides.shape[0] > i, strides[i], tvm.tir.const(1, "int32"))
+            for i in range(n)
+        ]
+
+    begin = [
+        tvm.tir.if_then_else(
+            begin.shape[0] > i,
+            begin[i],
+            tvm.tir.Select(strides[i] > 0, tvm.tir.const(0, "int32"), a.shape[i]),
+        )
+        for i in range(n)
+    ]
+    end = [
+        tvm.tir.if_then_else(
+            end.shape[0] > i,
+            end[i],
+            tvm.tir.Select(strides[i] > 0, a.shape[i] + 1, -(a.shape[i] + 1)),
+        )
+        for i in range(n)
+    ]
 
     # Convert negative indexes
     for i in range(n):
-        begin[i] = tvm.tir.if_then_else(begin[i] < 0,
-                                        begin[i] + a.shape[i],
-                                        begin[i])
-        end[i] = tvm.tir.if_then_else(end[i] < 0,
-                                      end[i] + a.shape[i],
-                                      end[i])
+        begin[i] = tvm.tir.if_then_else(begin[i] < 0, begin[i] + a.shape[i], begin[i])
+        end[i] = tvm.tir.if_then_else(end[i] < 0, end[i] + a.shape[i], end[i])
 
     def _select(*indices):
         from_val = []
         index_tuple = []
         for i in range(n):
             from_val.append(within_index(begin[i], end[i], strides[i], indices[i]))
-            index_tuple.append(
-                make_idx(begin[i], end[i], strides[i], a.shape[i], indices[i]))
+            index_tuple.append(make_idx(begin[i], end[i], strides[i], a.shape[i], indices[i]))
         return tvm.tir.if_then_else(tvm.tir.all(*from_val), v(*index_tuple), a(*indices))
 
     return te.compute(a.shape, _select, name="strided_set")
@@ -657,8 +660,9 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0):
         depending on the value of `axis`.
     """
 
-    assert len(data.shape) >= 2,\
-        "only support data.ndim >= 2, received data.shape = {}".format(data.shape)
+    assert len(data.shape) >= 2, "only support data.ndim >= 2, received data.shape = {}".format(
+        data.shape
+    )
     assert axis in (0, 1), "only support axis = 0, 1, received axis = {}".format(axis)
     return cpp.sequence_mask(data, valid_length, mask_value, axis)
 
@@ -703,6 +707,7 @@ def where(condition, x, y):
     """
     return cpp.where(condition, x, y)
 
+
 def one_hot(indices, on_value, off_value, depth, axis, dtype):
     """
     Returns a one-hot tensor where the locations represented by indices take value on_value,
@@ -751,25 +756,26 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype):
 def unravel_index(indices, shape):
     """Convert a flat index or array of flat indices into a tuple of coordinate arrays.
 
-       Example::
-       -   unravel_index([22, 41, 37], [7, 6]) = [[3, 6, 6], [4, 5, 1]]
+    Example::
+    -   unravel_index([22, 41, 37], [7, 6]) = [[3, 6, 6], [4, 5, 1]]
 
-       Parameters
-       ----------
-       indices : relay.Expr
-           An integer array containing indices.
+    Parameters
+    ----------
+    indices : relay.Expr
+        An integer array containing indices.
 
-       shape : relay.Expr
-           The shape of the array.
+    shape : relay.Expr
+        The shape of the array.
 
-       Returns
-       -------
-       result : relay.Expr
-           The tuple of coordinate arrays.
+    Returns
+    -------
+    result : relay.Expr
+        The tuple of coordinate arrays.
     """
 
     return cpp.unravel_index(indices, shape)
 
+
 def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0):
     """Converts a sparse representation into a dense tensor.
 
@@ -799,6 +805,7 @@ def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0
 
     return cpp.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value)
 
+
 def matrix_set_diag(data, diagonal):
     """
     Returns a tensor with the diagonal of input tensor replaced with the provided diagonal values.
@@ -839,6 +846,7 @@ def matrix_set_diag(data, diagonal):
     """
     return cpp.matrix_set_diag(data, diagonal)
 
+
 def adv_index(data, indices):
     """Numpy style indexing with tensors.
 
index 5bde1cb..0a5c93c 100644 (file)
@@ -24,16 +24,20 @@ from tvm import te
 from tvm.tir import layout, bijective_layout
 from . import tag, cpp
 
+
 class InvalidShapeError(ValueError):
     """Invalid shape for a topi function. i.e. call winograd template for non-3x3 kernel)"""
 
+
 def nchw_pack_layout(layout_info):
     """Check whether the layout type is NCHWinic"""
-    return layout_info[:4] == 'NCHW' and 'c' in layout_info and 'n' in layout_info
+    return layout_info[:4] == "NCHW" and "c" in layout_info and "n" in layout_info
+
 
 def nchw_xc_layout(layout_info):
     """Check whether the layout type is NCHWxc"""
-    return layout_info[:4] == 'NCHW' and 'c' in layout_info and layout_info[4:-1].isnumeric()
+    return layout_info[:4] == "NCHW" and "c" in layout_info and layout_info[4:-1].isnumeric()
+
 
 def traverse_inline(s, final_op, callback):
     """Traverse computation graph and do auto inline
@@ -290,9 +294,11 @@ def const_matrix(matrix, name="const_matrix"):
         now = tvm.tir.const(0.0, dtype)
         for ii in range(row):
             for jj in range(col):
-                now = tvm.tir.Select(tvm.tir.all(idxm(i, row) == ii, idxm(j, col) == jj),
-                                     tvm.tir.const(matrix[ii][jj], dtype),
-                                     now)
+                now = tvm.tir.Select(
+                    tvm.tir.all(idxm(i, row) == ii, idxm(j, col) == jj),
+                    tvm.tir.const(matrix[ii][jj], dtype),
+                    now,
+                )
         return now
 
     return te.compute(matrix.shape, select_array, name=name)
@@ -352,12 +358,13 @@ def get_shape(src_shape, src_layout, dst_layout):
     if isinstance(dst_layout, str):
         dst_layout = layout(dst_layout)
 
-    assert len(src_layout) == len(dst_layout), \
-        "Incompatible layout %s vs %s" % (src_layout, dst_layout)
+    assert len(src_layout) == len(dst_layout), "Incompatible layout %s vs %s" % (
+        src_layout,
+        dst_layout,
+    )
 
     layout_mapping = bijective_layout(src_layout, dst_layout)
-    dst_indices = layout_mapping.forward_index(
-        tvm.runtime.convert(list(range(len(src_layout)))))
+    dst_indices = layout_mapping.forward_index(tvm.runtime.convert(list(range(len(src_layout)))))
 
     return get_const_tuple(tuple([src_shape[i.value] for i in dst_indices]))
 
@@ -387,9 +394,7 @@ def within_index(b, e, s, i):
     """
     bc = tvm.tir.Select(s < 0, i <= e, i < b)
     ec = tvm.tir.Select(s < 0, i > b, i >= e)
-    ss = te.if_then_else(s < 0,
-                         ((i - e) + (e % te.abs(s)) + 1) % te.abs(s),
-                         (i - b) % s)
+    ss = te.if_then_else(s < 0, ((i - e) + (e % te.abs(s)) + 1) % te.abs(s), (i - b) % s)
     return tvm.tir.Select(tvm.tir.Or(bc, ec), tvm.tir.const(False), ss.equal(0))
 
 
@@ -428,9 +433,7 @@ def make_idx(b, e, s, z, i):
     # Clamp to array size
     b = tvm.tir.Select(z < b, z - 1, b)
 
-    ss = tvm.tir.if_then_else(s < 0,
-                              (b - i) // te.abs(s),
-                              (i - b) // s)
+    ss = tvm.tir.if_then_else(s < 0, (b - i) // te.abs(s), (i - b) // s)
     return tvm.tir.if_then_else(tvm.tir.Or(bc, ec), 88, ss)
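
Aside: within_index above tests whether an index is picked by slice(b, e, s); for positive strides this reduces to plain range membership. A small Python sketch of that case only (the negative-stride branch in the TIR version counts from the end instead):

def within_index_py(b, e, s, i):
    # Plain-Python sketch of the positive-stride case of the TIR helper
    # above: is index `i` selected by slice(b, e, s)?
    return b <= i < e and (i - b) % s == 0

print(within_index_py(1, 8, 3, 4))  # True  (slice(1, 8, 3) picks 1, 4, 7)
print(within_index_py(1, 8, 3, 5))  # False
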
 
 
index 1ee9e83..76e1808 100644 (file)
@@ -22,6 +22,7 @@ from tvm import te
 from tvm.te import hybrid
 from ..sort import argsort
 
+
 @hybrid.script
 def hybrid_rearrange_box_out(data, one, batch_size, num_anchors):
     """Hybrid routine to rearrange nms output to
@@ -50,10 +51,7 @@ def hybrid_rearrange_box_out(data, one, batch_size, num_anchors):
         [batch_size, num_anchors, 6].
     """
     elem_length = data.shape[2]
-    output = output_tensor((batch_size,
-                            num_anchors,
-                            elem_length),
-                           data.dtype)
+    output = output_tensor((batch_size, num_anchors, elem_length), data.dtype)
 
     for i in parallel(batch_size):
         valid_idx = 0
@@ -120,8 +118,9 @@ def hybrid_rearrange_indices_out(data, one, batch_size, num_anchors):
 
 
 @hybrid.script
-def hybrid_get_valid_counts(data, score_threshold, id_index, score_index,
-                            one, batch_size, num_anchors):
+def hybrid_get_valid_counts(
+    data, score_threshold, id_index, score_index, one, batch_size, num_anchors
+):
     """Hybrid routine to get valid count of bounding boxes
     given a score threshold. Also moves valid boxes to the
     top of input data.
@@ -164,17 +163,13 @@ def hybrid_get_valid_counts(data, score_threshold, id_index, score_index,
     """
     box_data_length = data.shape[2]
     valid_count = output_tensor((batch_size,), "int32")
-    out_tensor = output_tensor((batch_size,
-                                num_anchors,
-                                box_data_length),
-                               data.dtype)
+    out_tensor = output_tensor((batch_size, num_anchors, box_data_length), data.dtype)
     out_indices = output_tensor((batch_size, num_anchors), "int32")
     for i in parallel(batch_size):
         valid_count[i] = 0
         for j in range(num_anchors):
             score = data[i, j, score_index]
-            if score > score_threshold and \
-                    (id_index < 0 or data[i, j, id_index] >= 0):
+            if score > score_threshold and (id_index < 0 or data[i, j, id_index] >= 0):
                 for k in range(box_data_length):
                     out_tensor[i, valid_count[i], k] = data[i, j, k]
                 out_indices[i, valid_count[i]] = j
@@ -219,16 +214,36 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
     score_threshold_const = tvm.tir.const(score_threshold, data.dtype)
     id_index_const = tvm.tir.const(id_index, "int32")
     score_index_const = tvm.tir.const(score_index, "int32")
-    return hybrid_get_valid_counts(data, score_threshold_const,
-                                   id_index_const, score_index_const,
-                                   tvm.tir.const(1, data.dtype),
-                                   data.shape[0], data.shape[1])
+    return hybrid_get_valid_counts(
+        data,
+        score_threshold_const,
+        id_index_const,
+        score_index_const,
+        tvm.tir.const(1, data.dtype),
+        data.shape[0],
+        data.shape[1],
+    )
 
 
 @hybrid.script
-def hybrid_nms(data, sorted_index, valid_count, indices, batch_size, num_anchors,
-               max_output_size, iou_threshold, force_suppress, top_k, coord_start,
-               score_index, id_index, return_indices, zero, one):
+def hybrid_nms(
+    data,
+    sorted_index,
+    valid_count,
+    indices,
+    batch_size,
+    num_anchors,
+    max_output_size,
+    iou_threshold,
+    force_suppress,
+    top_k,
+    coord_start,
+    score_index,
+    id_index,
+    return_indices,
+    zero,
+    one,
+):
     """Hybrid routing for non-maximum suppression.
 
     Parameters
@@ -305,9 +320,14 @@ def hybrid_nms(data, sorted_index, valid_count, indices, batch_size, num_anchors
 
     # box_indices is the expected indices of boxes
     box_indices = output_tensor((batch_size, num_anchors), sorted_index.dtype)
-    output = output_tensor((batch_size,
-                            num_anchors,
-                            box_data_length,), data.dtype)
+    output = output_tensor(
+        (
+            batch_size,
+            num_anchors,
+            box_data_length,
+        ),
+        data.dtype,
+    )
 
     for i in range(batch_size):
         if iou_threshold > 0:
@@ -342,21 +362,33 @@ def hybrid_nms(data, sorted_index, valid_count, indices, batch_size, num_anchors
                     is_valid_box = 1
 
                     # a_l: left, a_t: top, a_r: right, a_b: bottom
-                    a_l = min(output[batch_idx, box_a_idx, box_start_idx],
-                              output[batch_idx, box_a_idx, box_start_idx + 2])
-                    a_t = min(output[batch_idx, box_a_idx, box_start_idx + 1],
-                              output[batch_idx, box_a_idx, box_start_idx + 3])
-                    a_r = max(output[batch_idx, box_a_idx, box_start_idx],
-                              output[batch_idx, box_a_idx, box_start_idx + 2])
-                    a_b = max(output[batch_idx, box_a_idx, box_start_idx + 1],
-                              output[batch_idx, box_a_idx, box_start_idx + 3])
+                    a_l = min(
+                        output[batch_idx, box_a_idx, box_start_idx],
+                        output[batch_idx, box_a_idx, box_start_idx + 2],
+                    )
+                    a_t = min(
+                        output[batch_idx, box_a_idx, box_start_idx + 1],
+                        output[batch_idx, box_a_idx, box_start_idx + 3],
+                    )
+                    a_r = max(
+                        output[batch_idx, box_a_idx, box_start_idx],
+                        output[batch_idx, box_a_idx, box_start_idx + 2],
+                    )
+                    a_b = max(
+                        output[batch_idx, box_a_idx, box_start_idx + 1],
+                        output[batch_idx, box_a_idx, box_start_idx + 3],
+                    )
 
                     # check if current box j is valid by calculating iou with
                     # all existing valid boxes
                     for k in range(j):
                         check_iou = 0
-                        if is_valid_box == 1 and k < j and output[i, k, score_index] > 0 \
-                                and (id_index < 0 or output[i, k, id_index] >= 0):
+                        if (
+                            is_valid_box == 1
+                            and k < j
+                            and output[i, k, score_index] > 0
+                            and (id_index < 0 or output[i, k, id_index] >= 0)
+                        ):
                             if force_suppress:
                                 check_iou = 1
                             elif id_index < 0 or output[i, j, id_index] == output[i, k, id_index]:
@@ -366,14 +398,22 @@ def hybrid_nms(data, sorted_index, valid_count, indices, batch_size, num_anchors
                             box_b_idx = k
 
                             # b_l: left, b_t: top, b_r: right, b_b: bottom
-                            b_l = min(output[batch_idx, box_b_idx, box_start_idx],
-                                      output[batch_idx, box_b_idx, box_start_idx + 2])
-                            b_t = min(output[batch_idx, box_b_idx, box_start_idx + 1],
-                                      output[batch_idx, box_b_idx, box_start_idx + 3])
-                            b_r = max(output[batch_idx, box_b_idx, box_start_idx],
-                                      output[batch_idx, box_b_idx, box_start_idx + 2])
-                            b_b = max(output[batch_idx, box_b_idx, box_start_idx + 1],
-                                      output[batch_idx, box_b_idx, box_start_idx + 3])
+                            b_l = min(
+                                output[batch_idx, box_b_idx, box_start_idx],
+                                output[batch_idx, box_b_idx, box_start_idx + 2],
+                            )
+                            b_t = min(
+                                output[batch_idx, box_b_idx, box_start_idx + 1],
+                                output[batch_idx, box_b_idx, box_start_idx + 3],
+                            )
+                            b_r = max(
+                                output[batch_idx, box_b_idx, box_start_idx],
+                                output[batch_idx, box_b_idx, box_start_idx + 2],
+                            )
+                            b_b = max(
+                                output[batch_idx, box_b_idx, box_start_idx + 1],
+                                output[batch_idx, box_b_idx, box_start_idx + 3],
+                            )
 
                             # Overlapping width and height
                             w = max(zero, min(a_r, b_r) - max(a_l, b_l))
@@ -419,11 +459,22 @@ def hybrid_nms(data, sorted_index, valid_count, indices, batch_size, num_anchors
 
     return output, box_indices
 
+
 @tvm.target.generic_func
-def non_max_suppression(data, valid_count, indices, max_output_size=-1,
-                        iou_threshold=0.5, force_suppress=False, top_k=-1,
-                        coord_start=2, score_index=1, id_index=0,
-                        return_indices=True, invalid_to_bottom=False):
+def non_max_suppression(
+    data,
+    valid_count,
+    indices,
+    max_output_size=-1,
+    iou_threshold=0.5,
+    force_suppress=False,
+    top_k=-1,
+    coord_start=2,
+    score_index=1,
+    id_index=0,
+    return_indices=True,
+    invalid_to_bottom=False,
+):
     """Non-maximum suppression operator for object detection.
 
     Parameters
@@ -506,27 +557,37 @@ def non_max_suppression(data, valid_count, indices, max_output_size=-1,
     score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis])
     sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
 
-    out, box_indices = hybrid_nms(data,
-                                  sort_tensor,
-                                  valid_count,
-                                  indices,
-                                  batch_size,
-                                  num_anchors,
-                                  max_output_size,
-                                  tvm.tir.const(iou_threshold, dtype=data.dtype),
-                                  tvm.tir.const(force_suppress, dtype="bool"),
-                                  tvm.tir.const(top_k, dtype="int32"),
-                                  tvm.tir.const(coord_start, dtype="int32"),
-                                  tvm.tir.const(score_index, dtype="int32"),
-                                  tvm.tir.const(id_index, dtype="int32"),
-                                  tvm.tir.const(return_indices, dtype="bool"),
-                                  zero=tvm.tir.const(0, dtype=data.dtype),
-                                  one=tvm.tir.const(1, dtype=data.dtype))
+    out, box_indices = hybrid_nms(
+        data,
+        sort_tensor,
+        valid_count,
+        indices,
+        batch_size,
+        num_anchors,
+        max_output_size,
+        tvm.tir.const(iou_threshold, dtype=data.dtype),
+        tvm.tir.const(force_suppress, dtype="bool"),
+        tvm.tir.const(top_k, dtype="int32"),
+        tvm.tir.const(coord_start, dtype="int32"),
+        tvm.tir.const(score_index, dtype="int32"),
+        tvm.tir.const(id_index, dtype="int32"),
+        tvm.tir.const(return_indices, dtype="bool"),
+        zero=tvm.tir.const(0, dtype=data.dtype),
+        one=tvm.tir.const(1, dtype=data.dtype),
+    )
     if return_indices:
-        return hybrid_rearrange_indices_out(box_indices, one=tvm.tir.const(1, dtype="int32"),
-                                            batch_size=batch_size, num_anchors=num_anchors)
+        return hybrid_rearrange_indices_out(
+            box_indices,
+            one=tvm.tir.const(1, dtype="int32"),
+            batch_size=batch_size,
+            num_anchors=num_anchors,
+        )
 
     if invalid_to_bottom:
-        out = hybrid_rearrange_box_out(out, one=tvm.tir.const(1, dtype=data.dtype),
-                                       batch_size=batch_size, num_anchors=num_anchors)
+        out = hybrid_rearrange_box_out(
+            out,
+            one=tvm.tir.const(1, dtype=data.dtype),
+            batch_size=batch_size,
+            num_anchors=num_anchors,
+        )
     return out
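
Aside: the suppression loop above keeps box j only while its IoU with every already-kept box stays under iou_threshold. A minimal NumPy sketch of that overlap test for two corner-format boxes (function name illustrative):

import numpy as np

def iou(box_a, box_b):
    # Intersection-over-union of two corner-format boxes (l, t, r, b),
    # matching the width/height clamping in hybrid_nms above.
    w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    inter = w * h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter)

print(iou([0, 0, 2, 2], [1, 1, 3, 3]))  # 1 / 7 ~= 0.143
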
index e99ebe0..cda7522 100644 (file)
@@ -22,17 +22,22 @@ from tvm import te
 from ...util import get_const_tuple, get_const_int
 from ...sort import argsort
 
+
 def generate_anchor(ratio, scale, base_size):
     """Generate anchor"""
     w = h = float(base_size)
-    x_ctr = 0.5 * (w - 1.)
-    y_ctr = 0.5 * (h - 1.)
+    x_ctr = 0.5 * (w - 1.0)
+    y_ctr = 0.5 * (h - 1.0)
     size = w * h
     size_ratios = math.floor(size / ratio)
     new_w = math.floor(math.sqrt(size_ratios) + 0.5) * scale
     new_h = math.floor((new_w / scale * ratio) + 0.5) * scale
-    return (x_ctr - 0.5 * (new_w - 1.0), y_ctr - 0.5 * (new_h - 1.0),
-            x_ctr + 0.5 * (new_w - 1.0), y_ctr + 0.5 * (new_h - 1.0))
+    return (
+        x_ctr - 0.5 * (new_w - 1.0),
+        y_ctr - 0.5 * (new_h - 1.0),
+        x_ctr + 0.5 * (new_w - 1.0),
+        y_ctr + 0.5 * (new_h - 1.0),
+    )
 
 
 def reg_bbox(x1, y1, x2, y2, dx, dy, dw, dh):
@@ -62,8 +67,18 @@ def reg_iou(x1, y1, x2, y2, dx1, dy1, dx2, dy2):
     pred_y2 = y2 + dy2
     return pred_x1, pred_y1, pred_x2, pred_y2
 
-def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, ratios,
-                    feature_stride, rpn_min_size, iou_loss):
+
+def predict_bbox_ir(
+    cls_prob_buf,
+    bbox_pred_buf,
+    im_info_buf,
+    out_buf,
+    scales,
+    ratios,
+    feature_stride,
+    rpn_min_size,
+    iou_loss,
+):
     """Predict bounding boxes based on anchors, scores and deltas.
 
     Parameters
@@ -131,8 +146,10 @@ def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, r
             x2 = anchor[2] + w * feature_stride
             y2 = anchor[3] + h * feature_stride
 
-            delta = [p_delta[((((b * num_anchors + k) * 4 + i) * height + h) * width + w)]
-                     for i in range(4)]
+            delta = [
+                p_delta[((((b * num_anchors + k) * 4 + i) * height + h) * width + w)]
+                for i in range(4)
+            ]
             regression_func = reg_iou if iou_loss else reg_bbox
             pred_x1, pred_y1, pred_x2, pred_y2 = regression_func(x1, y1, x2, y2, *delta)
 
@@ -141,16 +158,17 @@ def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, r
             pred_x2 = tvm.te.max(tvm.te.min(pred_x2, im_width - 1.0), 0.0)
             pred_y2 = tvm.te.max(tvm.te.min(pred_y2, im_height - 1.0), 0.0)
 
-            real_height = (im_height / feature_stride).astype('int32')
-            real_width = (im_width / feature_stride).astype('int32')
+            real_height = (im_height / feature_stride).astype("int32")
+            real_width = (im_width / feature_stride).astype("int32")
 
             bbox_w = pred_x2 - pred_x1 + 1.0
             bbox_h = pred_y2 - pred_y1 + 1.0
             min_size = p_im_info[b * 3 + 2] * rpn_min_size
 
             pred_score = p_score[((b * num_anchors * 2 + num_anchors + k) * height + h) * width + w]
-            pred_score = tvm.tir.Select(tvm.tir.any(h >= real_height, w >= real_width),
-                                        -1.0, pred_score)
+            pred_score = tvm.tir.Select(
+                tvm.tir.any(h >= real_height, w >= real_width), -1.0, pred_score
+            )
             p_out[out_index * 5 + 0] = pred_x1
             p_out[out_index * 5 + 1] = pred_y1
             p_out[out_index * 5 + 2] = pred_x2
@@ -200,8 +218,9 @@ def argsort_ir(data_buf, out_index_buf):
         with ib.for_range(0, num_bbox) as k:
             with ib.for_range(0, (num_bbox + 1) // 2) as tid:
                 offset = start + 2 * tid + idxm(k, 2)
-                with ib.if_scope(tvm.tir.all(offset + 1 < num_bbox,
-                                             p_data[offset] < p_data[offset + 1])):
+                with ib.if_scope(
+                    tvm.tir.all(offset + 1 < num_bbox, p_data[offset] < p_data[offset + 1])
+                ):
                     temp_data[0] = p_data[offset]
                     p_data[offset] = p_data[offset + 1]
                     p_data[offset + 1] = temp_data[0]
@@ -231,18 +250,29 @@ def nms_ir(sorted_bbox_buf, out_buf, nms_threshold):
     stmt : Stmt
         The result IR statement.
     """
+
     def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
-        """Calculate overlap of two boxes.
-        """
-        w = tvm.te.max(0.0, tvm.te.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
-                       - tvm.te.max(out_tensor[box_a_idx], out_tensor[box_b_idx]) + 1.0)
-        h = tvm.te.max(0.0, tvm.te.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
-                       - tvm.te.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]) + 1.0)
+        """Calculate overlap of two boxes."""
+        w = tvm.te.max(
+            0.0,
+            tvm.te.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
+            - tvm.te.max(out_tensor[box_a_idx], out_tensor[box_b_idx])
+            + 1.0,
+        )
+        h = tvm.te.max(
+            0.0,
+            tvm.te.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
+            - tvm.te.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1])
+            + 1.0,
+        )
         i = w * h
-        u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx] + 1.0) * \
-            (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1] + 1.0) + \
-            (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx] + 1.0) * \
-            (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1] + 1.0) - i
+        u = (
+            (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx] + 1.0)
+            * (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1] + 1.0)
+            + (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx] + 1.0)
+            * (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1] + 1.0)
+            - i
+        )
         return i / u
 
     batch, num_bbox = get_const_tuple(out_buf.shape)
@@ -286,12 +316,12 @@ def prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf):
     batch, num_bbox, _ = get_const_tuple(sorted_bbox_buf.shape)
     rpn_post_nms_top_n = get_const_int(out_buf.shape[0]) // batch
     ib = tvm.tir.ir_builder.create()
-    i = ib.allocate('int32', (batch,), 'i', scope='local')
+    i = ib.allocate("int32", (batch,), "i", scope="local")
     p_sorted_bbox = ib.buffer_ptr(sorted_bbox_buf)
     p_remove = ib.buffer_ptr(remove_mask_buf)
     p_out = ib.buffer_ptr(out_buf)
 
-    nkeep = ib.allocate('int32', (batch,), 'nkeep', scope='local')
+    nkeep = ib.allocate("int32", (batch,), "nkeep", scope="local")
 
     with ib.for_range(0, batch) as b:
         nkeep[b] = 0
@@ -303,15 +333,19 @@ def prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf):
                 nkeep[b] += 1
     with ib.for_range(0, batch) as b:
         with ib.if_scope(nkeep[b] > 0):
-            with ib.for_range(0, te.ceil(
-                    tvm.tir.const(rpn_post_nms_top_n, 'float32') / nkeep[b]).astype('int32')):
+            with ib.for_range(
+                0, te.ceil(tvm.tir.const(rpn_post_nms_top_n, "float32") / nkeep[b]).astype("int32")
+            ):
                 with ib.for_range(0, num_bbox) as j:
                     offset_j = (b * num_bbox + j) * 5
                     offset_i = (b * rpn_post_nms_top_n + i[b]) * 5
-                    with ib.if_scope(tvm.tir.all(i[b] < rpn_post_nms_top_n,
-                                                 p_remove[(b*num_bbox+j)] == False)):
-                        p_out[offset_i] = tvm.tir.Cast('float32', b)
-                        with ib.for_range(0, 4, for_type='unroll') as k:
+                    with ib.if_scope(
+                        tvm.tir.all(
+                            i[b] < rpn_post_nms_top_n, p_remove[(b * num_bbox + j)] == False
+                        )
+                    ):
+                        p_out[offset_i] = tvm.tir.Cast("float32", b)
+                        with ib.for_range(0, 4, for_type="unroll") as k:
                             p_out[offset_i + k + 1] = p_sorted_bbox[offset_j + k]
                         i[b] = i[b] + 1
 
@@ -319,8 +353,19 @@ def prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf):
     return body
 
 
-def proposal(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, threshold,
-             rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_min_size, iou_loss):
+def proposal(
+    cls_prob,
+    bbox_pred,
+    im_info,
+    scales,
+    ratios,
+    feature_stride,
+    threshold,
+    rpn_pre_nms_top_n,
+    rpn_post_nms_top_n,
+    rpn_min_size,
+    iou_loss,
+):
     """Proposal operator.
 
     Parameters
@@ -371,20 +416,33 @@ def proposal(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, thres
     num_bbox = height * width * num_anchors
     rpn_pre_nms_top_n = min(rpn_pre_nms_top_n, num_bbox) if rpn_pre_nms_top_n > 0 else num_bbox
 
-    bbox = te.extern((batch, num_bbox, 5), [cls_prob, bbox_pred, im_info], lambda ins, outs:
-                     predict_bbox_ir(ins[0], ins[1], ins[2], outs[0], scales, ratios,
-                                     feature_stride, rpn_min_size, iou_loss),
-                     dtype=bbox_pred.dtype)
-    score = te.compute((batch, num_bbox), lambda b, i: bbox[b, i, 4], tag='bbox_score')
+    bbox = te.extern(
+        (batch, num_bbox, 5),
+        [cls_prob, bbox_pred, im_info],
+        lambda ins, outs: predict_bbox_ir(
+            ins[0], ins[1], ins[2], outs[0], scales, ratios, feature_stride, rpn_min_size, iou_loss
+        ),
+        dtype=bbox_pred.dtype,
+    )
+    score = te.compute((batch, num_bbox), lambda b, i: bbox[b, i, 4], tag="bbox_score")
     valid_count_shape = (1,)
     valid_count = te.compute(valid_count_shape, lambda i: num_bbox)
     sorted_index = argsort(score, valid_count=valid_count, axis=1, is_ascend=False)
-    sorted_bbox = te.compute((batch, rpn_pre_nms_top_n, 5),
-                             lambda b, i, j: bbox[b, sorted_index[b, i], j], tag='sorted_bbox')
-    nms_remove_mask = te.extern((batch, rpn_pre_nms_top_n), [sorted_bbox],
-                                lambda ins, outs: nms_ir(ins[0], outs[0], threshold),
-                                dtype='bool')
-    nms_out = te.extern((batch * rpn_post_nms_top_n, 5), [sorted_bbox, nms_remove_mask],
-                        lambda ins, outs: prepare_output_ir(ins[0], ins[1], outs[0]),
-                        dtype=sorted_bbox.dtype)
+    sorted_bbox = te.compute(
+        (batch, rpn_pre_nms_top_n, 5),
+        lambda b, i, j: bbox[b, sorted_index[b, i], j],
+        tag="sorted_bbox",
+    )
+    nms_remove_mask = te.extern(
+        (batch, rpn_pre_nms_top_n),
+        [sorted_bbox],
+        lambda ins, outs: nms_ir(ins[0], outs[0], threshold),
+        dtype="bool",
+    )
+    nms_out = te.extern(
+        (batch * rpn_post_nms_top_n, 5),
+        [sorted_bbox, nms_remove_mask],
+        lambda ins, outs: prepare_output_ir(ins[0], ins[1], outs[0]),
+        dtype=sorted_bbox.dtype,
+    )
     return nms_out
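
Aside: generate_anchor above follows the classic Faster R-CNN anchor recipe. A standalone restatement with one worked ratio/scale, to make the arithmetic concrete (the printed values follow directly from the formulas above, so treat this as a sketch):

import math

def generate_anchor(ratio, scale, base_size):
    # Same arithmetic as the helper reformatted above, restated standalone.
    w = h = float(base_size)
    x_ctr = 0.5 * (w - 1.0)
    y_ctr = 0.5 * (h - 1.0)
    size = w * h
    size_ratios = math.floor(size / ratio)
    new_w = math.floor(math.sqrt(size_ratios) + 0.5) * scale
    new_h = math.floor((new_w / scale * ratio) + 0.5) * scale
    return (
        x_ctr - 0.5 * (new_w - 1.0),
        y_ctr - 0.5 * (new_h - 1.0),
        x_ctr + 0.5 * (new_w - 1.0),
        y_ctr + 0.5 * (new_h - 1.0),
    )

# ratio 0.5 at scale 8 on a 16-pixel base gives a wide 184x96 anchor
# centred at (7.5, 7.5):
print(generate_anchor(0.5, 8, 16))  # (-84.0, -40.0, 99.0, 55.0)
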
index 9aa1ef9..eafdc21 100644 (file)
@@ -67,7 +67,7 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
 
     def _sample(i, c, ph, pw):
         roi = rois[i]
-        batch_index = roi[0].astype('int32')
+        batch_index = roi[0].astype("int32")
         roi_start_w, roi_start_h, roi_end_w, roi_end_h = roi[1], roi[2], roi[3], roi[4]
         roi_start_h *= spatial_scale
         roi_end_h *= spatial_scale
@@ -82,20 +82,27 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
         bin_w = roi_w / pooled_size_w
 
         if sample_ratio > 0:
-            roi_bin_grid_h = roi_bin_grid_w = tvm.tir.const(sample_ratio, 'int32')
+            roi_bin_grid_h = roi_bin_grid_w = tvm.tir.const(sample_ratio, "int32")
         else:
-            roi_bin_grid_h = te.ceil(roi_h / pooled_size_h).astype('int32')
-            roi_bin_grid_w = te.ceil(roi_w / pooled_size_w).astype('int32')
+            roi_bin_grid_h = te.ceil(roi_h / pooled_size_h).astype("int32")
+            roi_bin_grid_w = te.ceil(roi_w / pooled_size_w).astype("int32")
 
         count = roi_bin_grid_h * roi_bin_grid_w
         rh = te.reduce_axis((0, roi_bin_grid_h))
         rw = te.reduce_axis((0, roi_bin_grid_w))
         roi_start_h += ph * bin_h
         roi_start_w += pw * bin_w
-        return te.sum(_bilinear(batch_index, c,
-                                roi_start_h + (rh + 0.5) * bin_h / roi_bin_grid_h,
-                                roi_start_w + (rw + 0.5) * bin_w / roi_bin_grid_w) / count,
-                      axis=[rh, rw])
-
-    return te.compute((num_roi, channel, pooled_size_h, pooled_size_w), _sample,
-                      tag='pool,roi_align_nchw')
+        return te.sum(
+            _bilinear(
+                batch_index,
+                c,
+                roi_start_h + (rh + 0.5) * bin_h / roi_bin_grid_h,
+                roi_start_w + (rw + 0.5) * bin_w / roi_bin_grid_w,
+            )
+            / count,
+            axis=[rh, rw],
+        )
+
+    return te.compute(
+        (num_roi, channel, pooled_size_h, pooled_size_w), _sample, tag="pool,roi_align_nchw"
+    )
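
Aside: roi_align_nchw averages a small grid of bilinear samples inside each pooling bin. A minimal single-channel NumPy sketch of one bin (helper names illustrative; edge handling simplified):

import numpy as np

def bilinear(fmap, y, x):
    # Bilinear read of a 2-D feature map at fractional (y, x).
    y0, x0 = int(np.floor(y)), int(np.floor(x))
    y1 = min(y0 + 1, fmap.shape[0] - 1)
    x1 = min(x0 + 1, fmap.shape[1] - 1)
    ly, lx = y - y0, x - x0
    return ((1 - ly) * (1 - lx) * fmap[y0, x0] + (1 - ly) * lx * fmap[y0, x1]
            + ly * (1 - lx) * fmap[y1, x0] + ly * lx * fmap[y1, x1])

def roi_align_bin(fmap, bin_y0, bin_x0, bin_h, bin_w, samples=2):
    # Average `samples` x `samples` bilinear reads taken at the centres of a
    # regular grid inside one pooling bin, as in the reduction above.
    vals = [bilinear(fmap,
                     bin_y0 + (ry + 0.5) * bin_h / samples,
                     bin_x0 + (rx + 0.5) * bin_w / samples)
            for ry in range(samples) for rx in range(samples)]
    return sum(vals) / len(vals)

fmap = np.arange(16, dtype="float32").reshape(4, 4)
print(roi_align_bin(fmap, 0.0, 0.0, 2.0, 2.0))  # 5.0
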
index a206f34..2254b74 100644 (file)
@@ -20,6 +20,7 @@ import tvm
 from tvm import te
 from ...util import get_const_tuple
 
+
 def roi_pool_nchw(data, rois, pooled_size, spatial_scale):
     """ROI pool operator in NCHW layout.
 
@@ -55,27 +56,27 @@ def roi_pool_nchw(data, rois, pooled_size, spatial_scale):
 
     def _pool(i, c, ph, pw):
         roi = rois[i]
-        batch_index = roi[0].astype('int32')
+        batch_index = roi[0].astype("int32")
         roi_start_w, roi_start_h, roi_end_w, roi_end_h = roi[1], roi[2], roi[3], roi[4]
 
-        roi_start_h = te.round(roi_start_h * spatial_scale).astype('int32')
-        roi_start_w = te.round(roi_start_w * spatial_scale).astype('int32')
-        roi_end_h = te.round(roi_end_h * spatial_scale).astype('int32')
-        roi_end_w = te.round(roi_end_w * spatial_scale).astype('int32')
+        roi_start_h = te.round(roi_start_h * spatial_scale).astype("int32")
+        roi_start_w = te.round(roi_start_w * spatial_scale).astype("int32")
+        roi_end_h = te.round(roi_end_h * spatial_scale).astype("int32")
+        roi_end_w = te.round(roi_end_w * spatial_scale).astype("int32")
 
         # force malformed ROIs to be 1x1
-        roi_h = tvm.te.max(roi_end_h - roi_start_h + 1, tvm.tir.const(1, 'int32'))
-        roi_w = tvm.te.max(roi_end_w - roi_start_w + 1, tvm.tir.const(1, 'int32'))
+        roi_h = tvm.te.max(roi_end_h - roi_start_h + 1, tvm.tir.const(1, "int32"))
+        roi_w = tvm.te.max(roi_end_w - roi_start_w + 1, tvm.tir.const(1, "int32"))
 
         bin_h = roi_h.astype(dtype) / pooled_size_h
         bin_w = roi_w.astype(dtype) / pooled_size_w
 
         # use epsilon to prevent floating point precision loss in floor/ceil
         epsilon = tvm.tir.const(0.00001, dtype)
-        hstart = te.floor(ph * bin_h + epsilon).astype('int32')
-        wstart = te.floor(pw * bin_w + epsilon).astype('int32')
-        hend = te.ceil((ph + 1) * bin_h - epsilon).astype('int32')
-        wend = te.ceil((pw + 1) * bin_w - epsilon).astype('int32')
+        hstart = te.floor(ph * bin_h + epsilon).astype("int32")
+        wstart = te.floor(pw * bin_w + epsilon).astype("int32")
+        hend = te.ceil((ph + 1) * bin_h - epsilon).astype("int32")
+        wend = te.ceil((pw + 1) * bin_w - epsilon).astype("int32")
         hstart = tvm.te.min(tvm.te.max(hstart + roi_start_h, 0), height)
         wstart = tvm.te.min(tvm.te.max(wstart + roi_start_w, 0), width)
         hend = tvm.te.min(tvm.te.max(hend + roi_start_h, 0), height)
@@ -83,11 +84,12 @@ def roi_pool_nchw(data, rois, pooled_size, spatial_scale):
 
         non_empty = tvm.tir.all(hstart < hend, wstart < wend)
         min_value = lambda dtype: tvm.tir.if_then_else(
-            non_empty, tvm.te.min_value(dtype), tvm.tir.const(0.0, dtype))
+            non_empty, tvm.te.min_value(dtype), tvm.tir.const(0.0, dtype)
+        )
         # pylint: disable=unnecessary-lambda
-        _max = te.comm_reducer(lambda x, y: tvm.te.max(x, y), min_value, name='max')
-        rh = te.reduce_axis((0, hend - hstart), 'rh')
-        rw = te.reduce_axis((0, wend - wstart), 'rw')
-        return _max(data[batch_index, c, hstart+rh, wstart+rw], axis=[rh, rw])
+        _max = te.comm_reducer(lambda x, y: tvm.te.max(x, y), min_value, name="max")
+        rh = te.reduce_axis((0, hend - hstart), "rh")
+        rw = te.reduce_axis((0, wend - wstart), "rw")
+        return _max(data[batch_index, c, hstart + rh, wstart + rw], axis=[rh, rw])
 
     return te.compute((num_roi, channel, pooled_size_h, pooled_size_w), _pool, tag="pool,roi_pool")
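
Aside: roi_pool_nchw derives each bin's bounds with floor/ceil plus a small epsilon to dodge float rounding. A plain-Python sketch of that boundary arithmetic (clamping to the feature-map extent, done in the reference above, is omitted here):

import math

def roi_pool_bin_bounds(roi_start, roi_end, pooled_size, ph, eps=1e-5):
    # Start/end rows (or columns) of pooling bin `ph`, mirroring the
    # floor/ceil-with-epsilon arithmetic in the reference above.
    roi_len = max(roi_end - roi_start + 1, 1)  # force malformed ROIs to 1x1
    bin_len = roi_len / pooled_size
    start = int(math.floor(ph * bin_len + eps)) + roi_start
    end = int(math.ceil((ph + 1) * bin_len - eps)) + roi_start
    return start, end

# A 7-pixel-long ROI starting at row 3, pooled into 2 bins:
print(roi_pool_bin_bounds(3, 9, 2, 0))  # (3, 7)  rows 3..6
print(roi_pool_bin_bounds(3, 9, 2, 1))  # (6, 10) rows 6..9 (bins may overlap)
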
index ec790fa..9883085 100644 (file)
@@ -22,6 +22,7 @@ Reorg operator, used in darknet.
 from __future__ import absolute_import as _abs
 from .. import cpp
 
+
 def reorg(data, stride):
     """Reorg forward operators.
 
index 6534503..cbb2c1b 100644 (file)
@@ -25,6 +25,7 @@ from tvm import topi
 
 from ..nms import non_max_suppression
 
+
 @hybrid.script
 def hybrid_multibox_prior(data, sizes, ratios, steps, offsets):
     """Hybrid routing for multibox_prior operator.
@@ -75,11 +76,18 @@ def hybrid_multibox_prior(data, sizes, ratios, steps, offsets):
                     w = float32(sizes[k] * in_height) / in_width / 2.0
                     h = sizes[k] / 2.0
                 else:
-                    w = float32(sizes[0] * in_height) / in_width \
-                        * sqrt(ratios[k - num_sizes + 1] * 1.0) / 2.0
+                    w = (
+                        float32(sizes[0] * in_height)
+                        / in_width
+                        * sqrt(ratios[k - num_sizes + 1] * 1.0)
+                        / 2.0
+                    )
                     h = sizes[0] / sqrt(ratios[k - num_sizes + 1] * 1.0) / 2.0
-                count = i * in_width * (num_sizes + num_ratios - 1) \
-                    + j * (num_sizes + num_ratios - 1) + k
+                count = (
+                    i * in_width * (num_sizes + num_ratios - 1)
+                    + j * (num_sizes + num_ratios - 1)
+                    + k
+                )
                 output[0, count, 0] = center_w - w
                 output[0, count, 1] = center_h - h
                 output[0, count, 2] = center_w + w
@@ -116,8 +124,13 @@ def multibox_prior(data, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5,
     out : tvm.te.Tensor
         3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4]
     """
-    out = hybrid_multibox_prior(data, tvm.runtime.convert(sizes), tvm.runtime.convert(ratios),
-                                tvm.runtime.convert(steps), tvm.runtime.convert(offsets))
+    out = hybrid_multibox_prior(
+        data,
+        tvm.runtime.convert(sizes),
+        tvm.runtime.convert(ratios),
+        tvm.runtime.convert(steps),
+        tvm.runtime.convert(offsets),
+    )
     if clip:
         out = topi.clip(out, 0, 1)
     return out
@@ -125,8 +138,7 @@ def multibox_prior(data, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5,
 
 @hybrid.script
 def _hybridy_transform_loc(box, pred_loc, variance, clip):
-    """Transform prior anchor box to output box through location predictions.
-    """
+    """Transform prior anchor box to output box through location predictions."""
     al = box[0]
     at = box[1]
     ar = box[2]
@@ -158,9 +170,9 @@ def _hybridy_transform_loc(box, pred_loc, variance, clip):
     output[3] = max(0.0, min(1.0, oy + oh)) if clip else oy + oh
     return output
 
+
 @hybrid.script
-def hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
-                                  clip, threshold, variances):
+def hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor, clip, threshold, variances):
     """Hybrid routing for transform location in multibox_detection operator.
 
     Parameters
@@ -196,8 +208,7 @@ def hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
     num_anchors = cls_prob.shape[2]
     box_coord = allocate((4,), loc_pred.dtype)
     pred_coord = allocate((4,), loc_pred.dtype)
-    out_loc = output_tensor((batch_size, num_anchors, 6),
-                            loc_pred.dtype)
+    out_loc = output_tensor((batch_size, num_anchors, 6), loc_pred.dtype)
     valid_count = output_tensor((batch_size,), "int32")
 
     for i in parallel(batch_size):
@@ -221,8 +232,7 @@ def hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
                 for l in range(4):
                     box_coord[l] = anchor[0, j, l]
                     pred_coord[l] = loc_pred[i, j * 4 + l]
-                out_coord = _hybridy_transform_loc(box_coord, pred_coord,
-                                                   variances, clip)
+                out_coord = _hybridy_transform_loc(box_coord, pred_coord, variances, clip)
                 out_loc[i, valid_count[i], 2] = out_coord[0]
                 out_loc[i, valid_count[i], 3] = out_coord[1]
                 out_loc[i, valid_count[i], 4] = out_coord[2]
@@ -231,8 +241,10 @@ def hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
 
     return out_loc, valid_count
 
-def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01,
-                           variances=(0.1, 0.1, 0.2, 0.2)):
+
+def multibox_transform_loc(
+    cls_prob, loc_pred, anchor, clip=True, threshold=0.01, variances=(0.1, 0.1, 0.2, 0.2)
+):
     """Location transformation for multibox detection
 
     Parameters
@@ -259,13 +271,27 @@ def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01
     -------
     ret : tuple of tvm.te.Tensor
     """
-    return hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
-                                         tvm.tir.const(clip, "bool"),
-                                         tvm.tir.const(threshold, "float32"),
-                                         tvm.runtime.convert(variances))
-
-def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nms_threshold=0.5,
-                       force_suppress=False, variances=(0.1, 0.1, 0.2, 0.2), nms_topk=-1):
+    return hybrid_multibox_transform_loc(
+        cls_prob,
+        loc_pred,
+        anchor,
+        tvm.tir.const(clip, "bool"),
+        tvm.tir.const(threshold, "float32"),
+        tvm.runtime.convert(variances),
+    )
+
+
+def multibox_detection(
+    cls_prob,
+    loc_pred,
+    anchor,
+    clip=True,
+    threshold=0.01,
+    nms_threshold=0.5,
+    force_suppress=False,
+    variances=(0.1, 0.1, 0.2, 0.2),
+    nms_topk=-1,
+):
     """Convert multibox detection predictions.
 
     Parameters
@@ -302,9 +328,15 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm
     out : tvm.te.Tensor
         3-D tensor with shape (batch_size, num_anchors, 6)
     """
-    inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
-                                       clip, threshold, variances)
-    out = non_max_suppression(inter_out[0], inter_out[1], inter_out[1], max_output_size=-1,
-                              iou_threshold=nms_threshold, force_suppress=force_suppress,
-                              top_k=nms_topk, return_indices=False)
+    inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor, clip, threshold, variances)
+    out = non_max_suppression(
+        inter_out[0],
+        inter_out[1],
+        inter_out[1],
+        max_output_size=-1,
+        iou_threshold=nms_threshold,
+        force_suppress=force_suppress,
+        top_k=nms_topk,
+        return_indices=False,
+    )
     return out
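
Aside: _hybridy_transform_loc applies the usual SSD-style location decode against an anchor with per-coordinate variances. A hedged NumPy sketch of that decode for one box (corner-format anchor assumed; names illustrative):

import numpy as np

def decode_box(anchor, pred, variances=(0.1, 0.1, 0.2, 0.2), clip=True):
    # SSD-style location decode, a sketch of what _hybridy_transform_loc
    # computes: anchor is (l, t, r, b) in [0, 1], pred is (dx, dy, dw, dh).
    al, at, ar, ab = anchor
    aw, ah = ar - al, ab - at
    ax, ay = al + aw / 2.0, at + ah / 2.0
    ox = pred[0] * variances[0] * aw + ax
    oy = pred[1] * variances[1] * ah + ay
    ow = np.exp(pred[2] * variances[2]) * aw / 2.0
    oh = np.exp(pred[3] * variances[3]) * ah / 2.0
    box = np.array([ox - ow, oy - oh, ox + ow, oy + oh])
    return np.clip(box, 0.0, 1.0) if clip else box

# Zero deltas give back the anchor itself:
print(decode_box((0.1, 0.1, 0.3, 0.3), (0.0, 0.0, 0.0, 0.0)))  # [0.1 0.1 0.3 0.3]
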
index 539a918..333d3be 100644 (file)
@@ -42,8 +42,7 @@ def batch_matmul(cfg, x, y):
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
-    assert len(x.shape) == 3 and len(
-        y.shape) == 3, "only support 3-dim batch_matmul"
+    assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
     XB, M, XK = get_const_tuple(x.shape)
     YB, N, YK = get_const_tuple(y.shape)
     assert XB == YB, "batch dimension doesn't match"
@@ -53,11 +52,10 @@ def batch_matmul(cfg, x, y):
     if cfg.is_fallback:
         _default_batch_matmul_config(cfg, M, N, K)
 
-    k = te.reduce_axis((0, K), name='k')
+    k = te.reduce_axis((0, K), name="k")
     C = te.compute(
-        (B, M, N),
-        lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k),
-        tag='batch_matmul')
+        (B, M, N), lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), tag="batch_matmul"
+    )
     return C
 
 
@@ -108,7 +106,7 @@ def schedule_batch_matmul(cfg, outs):
             s[O].parallel(bxyo)
 
             s[CC].compute_at(s[O], bxyo)
-            k, = s[CC].op.reduce_axis
+            (k,) = s[CC].op.reduce_axis
             ko, ki = cfg["tile_k"].apply(s, CC, k)
 
             Crf = s.rfactor(CC, ki)
@@ -116,7 +114,7 @@ def schedule_batch_matmul(cfg, outs):
             _, _, y, x = s[Crf].op.axis
             s[Crf].fuse(y, x)
             s[Crf].vectorize(s[Crf].op.axis[0])
-            s[O].pragma(bxyo, 'auto_unroll_max_step', 16)
+            s[O].pragma(bxyo, "auto_unroll_max_step", 16)
 
     traverse_inline(s, outs[0].op, _callback)
     return s
@@ -148,8 +146,7 @@ def batch_matmul_cblas(cfg, x, y):
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
-    assert len(x.shape) == 3 and len(
-        y.shape) == 3, "only support 3-dim batch_matmul"
+    assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
     XB, M, XK = get_const_tuple(x.shape)
     YB, N, YK = get_const_tuple(y.shape)
     assert XB == YB, "batch dimension doesn't match"
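
Aside: the compute above defines C[b, i, j] = sum_k x[b, i, k] * y[b, j, k], i.e. the second operand is taken pre-transposed. A NumPy check of that convention:

import numpy as np

x = np.random.rand(2, 3, 4).astype("float32")  # [batch, M, K]
y = np.random.rand(2, 5, 4).astype("float32")  # [batch, N, K] -- already transposed
# C[b, i, j] = sum_k x[b, i, k] * y[b, j, k]
c = np.einsum("bik,bjk->bij", x, y)
assert np.allclose(c, x @ y.transpose(0, 2, 1))
print(c.shape)  # (2, 3, 5)
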
index b4a01a5..34fcbfb 100644 (file)
@@ -41,7 +41,7 @@ def schedule_binarize_pack(outs):
 
     def traverse(OP):
         # schedule binarize_pack
-        if OP.tag == 'binarize_pack':
+        if OP.tag == "binarize_pack":
             Out = OP.output(0)
             _schedule(Out)
         else:
index d90694e..be02cb9 100644 (file)
@@ -58,7 +58,7 @@ def schedule_binary_dense(outs):
                 if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule binary_dense
-        elif OP.tag == 'binary_dense':
+        elif OP.tag == "binary_dense":
             output = OP.output(0)
             data = OP.input_tensors[0]
             weight = OP.input_tensors[1]
index 37fe352..5fcc9e1 100644 (file)
@@ -25,9 +25,20 @@ from ..nn.pad import pad
 from ..nn.util import get_pad_tuple
 from ..nn.bitserial_util import bitpack, binary_op_multiplier
 
+
 @autotvm.register_topi_compute("bitserial_conv2d_nchw.x86")
-def bitserial_conv2d_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
-                          pack_dtype='uint32', out_dtype='int16', unipolar=True):
+def bitserial_conv2d_nchw(
+    cfg,
+    data,
+    kernel,
+    stride,
+    padding,
+    in_bits,
+    weight_bits,
+    pack_dtype="uint32",
+    out_dtype="int16",
+    unipolar=True,
+):
     """ Compute convolution with pack on spatial axes. """
     assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
     data_q = bitpack(data, in_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype)
@@ -54,7 +65,7 @@ def bitserial_conv2d_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bi
         HSTR, WSTR = stride
     else:
         HSTR, WSTR = stride, stride
-    HCAT, WCAT = KH-1, KW-1
+    HCAT, WCAT = KH - 1, KW - 1
 
     TH = H + TPAD + DPAD
     TW = W + LPAD + RPAD
@@ -66,17 +77,17 @@ def bitserial_conv2d_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bi
     ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
     ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits)
 
-    co, vc = cfg.define_split('tile_co', co, num_outputs=2,
-                              filter=lambda x: max(x.size[1:]) <= 16)
-    oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2,
-                              filter=lambda x: max(x.size[1:]) <= 16)
-    ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2,
-                              filter=lambda x: max(x.size[1:]) <= 16)
-    cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
-
-    cfg.define_reorder("reorder_0",
-                       [n, co, oh, ow, vc, vh, vw, kh, kw, kb, ib, ci],
-                       policy='interval_all', interval=(6, 11))
+    co, vc = cfg.define_split("tile_co", co, num_outputs=2, filter=lambda x: max(x.size[1:]) <= 16)
+    oh, vh = cfg.define_split("tile_oh", oh, num_outputs=2, filter=lambda x: max(x.size[1:]) <= 16)
+    ow, vw = cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda x: max(x.size[1:]) <= 16)
+    cfg.define_annotate("ann_reduce", [ib, kb, kh, kw], policy="try_unroll")
+
+    cfg.define_reorder(
+        "reorder_0",
+        [n, co, oh, ow, vc, vh, vw, kh, kw, kb, ib, ci],
+        policy="interval_all",
+        interval=(6, 11),
+    )
     # binary ops
     cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * binary_op_multiplier(pack_dtype))
     # ====================
@@ -85,59 +96,91 @@ def bitserial_conv2d_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bi
     VH = cfg["tile_oh"].size[-1]
     VW = cfg["tile_ow"].size[-1]
 
-    dvshape = (1, TH//(VH*HSTR), TW//(VW*WSTR), CI, VH*HSTR+HCAT, VW*WSTR+WCAT, IB)
-    kvshape = (CO//VC, CI, KH, KW, KB, VC)
-    ovshape = (1, CO//VC, OH//VH, OW//VW, VH, VW, VC)
+    dvshape = (1, TH // (VH * HSTR), TW // (VW * WSTR), CI, VH * HSTR + HCAT, VW * WSTR + WCAT, IB)
+    kvshape = (CO // VC, CI, KH, KW, KB, VC)
+    ovshape = (1, CO // VC, OH // VH, OW // VW, VH, VW, VC)
     oshape = (1, CO, OH, OW)
 
-    if (TPAD != 0 and RPAD != 0):
+    if TPAD != 0 and RPAD != 0:
         data_pad = pad(data_q, pad_before, pad_after, name="data_pad")
     else:
         data_pad = data_q
 
-    data_vec = te.compute(dvshape, lambda n, h, w, ci, vh, vw, b: \
-                          data_pad[b][n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], name='data_vec')
+    data_vec = te.compute(
+        dvshape,
+        lambda n, h, w, ci, vh, vw, b: data_pad[b][n][ci][h * VH * HSTR + vh][w * VW * WSTR + vw],
+        name="data_vec",
+    )
 
     if len(kernel.shape) == 4:
-        kernel_vec = te.compute(kvshape, lambda co, ci, dh, dw, b, vc: \
-                                kernel_q[b][co*VC+vc][ci][dh][dw], name='kernel_vec')
-
-    ci = te.reduce_axis((0, CI), name='ci')
-    dh = te.reduce_axis((0, KH), name='dh')
-    dw = te.reduce_axis((0, KW), name='dw')
-    b1 = te.reduce_axis((0, IB), name='ib')
-    b2 = te.reduce_axis((0, KB), name='kb')
+        kernel_vec = te.compute(
+            kvshape,
+            lambda co, ci, dh, dw, b, vc: kernel_q[b][co * VC + vc][ci][dh][dw],
+            name="kernel_vec",
+        )
+
+    ci = te.reduce_axis((0, CI), name="ci")
+    dh = te.reduce_axis((0, KH), name="dh")
+    dw = te.reduce_axis((0, KW), name="dw")
+    b1 = te.reduce_axis((0, IB), name="ib")
+    b2 = te.reduce_axis((0, KB), name="kb")
 
     def _conv(n, co, h, w, vh, vw, vc):
-        b1b2 = (b1+b2).astype(out_dtype)
+        b1b2 = (b1 + b2).astype(out_dtype)
         if unipolar:
-            return te.sum((tvm.tir.popcount(
-                data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1].astype(out_dtype) &
-                kernel_vec[co, ci, dh, dw, b2, vc].astype(out_dtype))  -
-                           tvm.tir.popcount(
-                               data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1].astype(out_dtype)
-                               & ~kernel_vec[co, ci, dh, dw, b2, vc]).astype(out_dtype)) << b1b2,
-                          axis=[ci, dh, dw, b1, b2])
-
-        return te.sum((tvm.tir.popcount(
-            data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1] &
-            kernel_vec[co, ci, dh, dw, b2, vc])).astype(out_dtype) << b1b2,
-                      axis=[ci, dh, dw, b1, b2])
-
-    conv = te.compute(ovshape, _conv, name='conv_out')
+            return te.sum(
+                (
+                    tvm.tir.popcount(
+                        data_vec[n, h, w, ci, vh * HSTR + dh, vw * WSTR + dw, b1].astype(out_dtype)
+                        & kernel_vec[co, ci, dh, dw, b2, vc].astype(out_dtype)
+                    )
+                    - tvm.tir.popcount(
+                        data_vec[n, h, w, ci, vh * HSTR + dh, vw * WSTR + dw, b1].astype(out_dtype)
+                        & ~kernel_vec[co, ci, dh, dw, b2, vc]
+                    ).astype(out_dtype)
+                )
+                << b1b2,
+                axis=[ci, dh, dw, b1, b2],
+            )
+
+        return te.sum(
+            (
+                tvm.tir.popcount(
+                    data_vec[n, h, w, ci, vh * HSTR + dh, vw * WSTR + dw, b1]
+                    & kernel_vec[co, ci, dh, dw, b2, vc]
+                )
+            ).astype(out_dtype)
+            << b1b2,
+            axis=[ci, dh, dw, b1, b2],
+        )
+
+    conv = te.compute(ovshape, _conv, name="conv_out")
     idxd = tvm.tir.indexdiv
     idxm = tvm.tir.indexmod
 
     return te.compute(
-        oshape, lambda n, co, h, w:
-        conv[n,
-             idxd(co, VC), idxd(h, VH), idxd(w, VW),
-             idxm(h, VH), idxm(w, VW), idxm(co, VC)],
-        name='conv_vec', tag='spatial_bitserial_conv_nchw')
+        oshape,
+        lambda n, co, h, w: conv[
+            n, idxd(co, VC), idxd(h, VH), idxd(w, VW), idxm(h, VH), idxm(w, VW), idxm(co, VC)
+        ],
+        name="conv_vec",
+        tag="spatial_bitserial_conv_nchw",
+    )
+
 
 @autotvm.register_topi_compute("bitserial_conv2d_nhwc.x86")
-def bitserial_conv2d_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
-                          pack_dtype='uint32', out_dtype='int16', unipolar=True):
+def bitserial_conv2d_nhwc(
+    cfg,
+    data,
+    kernel,
+    stride,
+    padding,
+    in_bits,
+    weight_bits,
+    pack_dtype="uint32",
+    out_dtype="int16",
+    unipolar=True,
+):
     """ Compute convolution with pack on spatial axes. """
     assert data.shape[0].value == 1, "spatial pack convolution only supports batch size=1"
     data_q = bitpack(data, in_bits, pack_axis=3, bit_axis=4, pack_type=pack_dtype)
@@ -162,7 +205,7 @@ def bitserial_conv2d_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bi
         HSTR, WSTR = stride
     else:
         HSTR, WSTR = stride, stride
-    HCAT, WCAT = KH-1, KW-1
+    HCAT, WCAT = KH - 1, KW - 1
 
     PAD_H = H + (TPAD + DPAD)
     PAD_W = W + (LPAD + RPAD)
@@ -175,16 +218,16 @@ def bitserial_conv2d_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bi
     ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
     ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits)
 
-    co, vc = cfg.define_split('tile_co', co, num_outputs=2,
-                              filter=lambda x: max(x.size[1:]) <= 16)
-    oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2,
-                              filter=lambda x: max(x.size[1:]) <= 16)
-    ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2,
-                              filter=lambda x: max(x.size[1:]) <= 16)
-    cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
-    cfg.define_reorder("reorder_0",
-                       [n, oh, ow, co, vh, vw, kh, kw, kb, ib, vc, ci],
-                       policy='interval_all', interval=(3, 7))
+    co, vc = cfg.define_split("tile_co", co, num_outputs=2, filter=lambda x: max(x.size[1:]) <= 16)
+    oh, vh = cfg.define_split("tile_oh", oh, num_outputs=2, filter=lambda x: max(x.size[1:]) <= 16)
+    ow, vw = cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda x: max(x.size[1:]) <= 16)
+    cfg.define_annotate("ann_reduce", [ib, kb, kh, kw], policy="try_unroll")
+    cfg.define_reorder(
+        "reorder_0",
+        [n, oh, ow, co, vh, vw, kh, kw, kb, ib, vc, ci],
+        policy="interval_all",
+        interval=(3, 7),
+    )
     # binary ops
     cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * binary_op_multiplier(pack_dtype))
     # ====================
@@ -193,62 +236,95 @@ def bitserial_conv2d_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bi
     VH = cfg["tile_oh"].size[-1]
     VW = cfg["tile_ow"].size[-1]
 
-    dvshape = (1, PAD_H//(VH*HSTR), PAD_W//(VW*WSTR), VH*HSTR+HCAT, VW*WSTR+WCAT, CI, IB)
+    dvshape = (
+        1,
+        PAD_H // (VH * HSTR),
+        PAD_W // (VW * WSTR),
+        VH * HSTR + HCAT,
+        VW * WSTR + WCAT,
+        CI,
+        IB,
+    )
     kvshape = (CO, KH, KW, CI, VC, KB)
     ovshape = (1, OH, OW, CO, VH, VW, VC)
     oshape = (1, OH, OW, CO)
 
-    if (DPAD != 0 and RPAD != 0):
+    if DPAD != 0 and RPAD != 0:
         data_pad = pad(data_q, pad_before, pad_after, name="data_pad")
     else:
         data_pad = data_q
 
-    data_vec = te.compute(dvshape, lambda n, h, w, vh, vw, ci, b: \
-                          data_pad[n][h*VH*HSTR+vh][w*VW*WSTR+vw][ci][b], name='data_vec')
+    data_vec = te.compute(
+        dvshape,
+        lambda n, h, w, vh, vw, ci, b: data_pad[n][h * VH * HSTR + vh][w * VW * WSTR + vw][ci][b],
+        name="data_vec",
+    )
 
-    kernel_vec = te.compute(kvshape, lambda co, dh, dw, ci, vc, b: \
-                            kernel_q[dh][dw][ci][co*VC+vc][b], name='kernel_vec')
+    kernel_vec = te.compute(
+        kvshape,
+        lambda co, dh, dw, ci, vc, b: kernel_q[dh][dw][ci][co * VC + vc][b],
+        name="kernel_vec",
+    )
 
-    ci = te.reduce_axis((0, CI), name='ci')
-    dh = te.reduce_axis((0, KH), name='dh')
-    dw = te.reduce_axis((0, KW), name='dw')
-    b1 = te.reduce_axis((0, IB), name='ib')
-    b2 = te.reduce_axis((0, KB), name='kb')
+    ci = te.reduce_axis((0, CI), name="ci")
+    dh = te.reduce_axis((0, KH), name="dh")
+    dw = te.reduce_axis((0, KW), name="dw")
+    b1 = te.reduce_axis((0, IB), name="ib")
+    b2 = te.reduce_axis((0, KB), name="kb")
 
     def _conv(n, h, w, co, vh, vw, vc):
-        b1b2 = (b1+b2).astype(out_dtype)
+        b1b2 = (b1 + b2).astype(out_dtype)
         if unipolar:
             return te.sum(
-                ((tvm.tir.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] &
-                                   kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype) -
-                  tvm.tir.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1]&
-                                   ~kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype)) << b1b2),
-                axis=[dh, dw, ci, b1, b2])
-
-        return te.sum(tvm.tir.popcount(
-            data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] &
-            kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype) << b1b2,
-                      axis=[dh, dw, ci, b1, b2])
-
-    conv = te.compute(ovshape, _conv, name='conv')
+                (
+                    (
+                        tvm.tir.popcount(
+                            data_vec[n, h, w, vh * HSTR + dh, vw * WSTR + dw, ci, b1]
+                            & kernel_vec[co, dh, dw, ci, vc, b2]
+                        ).astype(out_dtype)
+                        - tvm.tir.popcount(
+                            data_vec[n, h, w, vh * HSTR + dh, vw * WSTR + dw, ci, b1]
+                            & ~kernel_vec[co, dh, dw, ci, vc, b2]
+                        ).astype(out_dtype)
+                    )
+                    << b1b2
+                ),
+                axis=[dh, dw, ci, b1, b2],
+            )
+
+        return te.sum(
+            tvm.tir.popcount(
+                data_vec[n, h, w, vh * HSTR + dh, vw * WSTR + dw, ci, b1]
+                & kernel_vec[co, dh, dw, ci, vc, b2]
+            ).astype(out_dtype)
+            << b1b2,
+            axis=[dh, dw, ci, b1, b2],
+        )
+
+    conv = te.compute(ovshape, _conv, name="conv")
 
     idxd = tvm.tir.indexdiv
     idxm = tvm.tir.indexmod
     return te.compute(
-        oshape, lambda n, h, w, co:
-        conv[n,
-             idxd(h, VH), idxd(w, VW), idxd(co, VC),
-             idxm(h, VH), idxm(w, VW), idxm(co, VC)],
-        name='output_unpack', tag='spatial_bitserial_conv_nhwc')
+        oshape,
+        lambda n, h, w, co: conv[
+            n, idxd(h, VH), idxd(w, VW), idxd(co, VC), idxm(h, VH), idxm(w, VW), idxm(co, VC)
+        ],
+        name="output_unpack",
+        tag="spatial_bitserial_conv_nhwc",
+    )
+
 
 @autotvm.register_topi_schedule("bitserial_conv2d_nchw.x86")
 def schedule_bitserial_conv2d_nchw(cfg, outs):
     return _schedule_bitserial_conv2d(cfg, outs)
 
+
 @autotvm.register_topi_schedule("bitserial_conv2d_nhwc.x86")
 def schedule_bitserial_conv2d_nhwc(cfg, outs):
     return _schedule_bitserial_conv2d(cfg, outs)
 
+
 def _schedule_bitserial_conv2d(cfg, outs):
     """CPU schedule for bitserial convolutions NCHW and NHWC"""
     s = te.create_schedule([x.op for x in outs])
@@ -258,7 +334,7 @@ def _schedule_bitserial_conv2d(cfg, outs):
         """Traverse operators from computation graph"""
         output = op.output(0)
         # inline all one-to-one-mapping operators except the last stage (output)
-        if tag.is_broadcast(op.tag) or 'elemwise' in op.tag:
+        if tag.is_broadcast(op.tag) or "elemwise" in op.tag:
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
@@ -266,7 +342,7 @@ def _schedule_bitserial_conv2d(cfg, outs):
                     if isinstance(tensor.op, tvm.te.ComputeOp):
                         traverse(tensor.op)
 
-        elif 'spatial_bitserial_conv_nchw' in op.tag or 'spatial_bitserial_conv_nhwc' in op.tag:
+        elif "spatial_bitserial_conv_nchw" in op.tag or "spatial_bitserial_conv_nhwc" in op.tag:
             conv_out = op.input_tensors[0]
             kernel_vec = conv_out.op.input_tensors[1]
             kernel_q = kernel_vec.op.input_tensors[0]
@@ -283,22 +359,41 @@ def _schedule_bitserial_conv2d(cfg, outs):
                 # Need to go up 1 further, from the combine in bitpack
                 data = data.op.input_tensors[0]
 
-            if 'spatial_bitserial_conv_nchw' in op.tag:
-                _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec,
-                                                kernel_q, kernel_vec,
-                                                conv_out, output, outs[0])
-            elif 'spatial_bitserial_conv_nhwc' in op.tag:
-                _schedule_bitserial_conv2d_nhwc(cfg, s, data_q, data_pad, data_vec,
-                                                kernel_q, kernel_vec,
-                                                conv_out, output, outs[0])
+            if "spatial_bitserial_conv_nchw" in op.tag:
+                _schedule_bitserial_conv2d_nchw(
+                    cfg,
+                    s,
+                    data_q,
+                    data_pad,
+                    data_vec,
+                    kernel_q,
+                    kernel_vec,
+                    conv_out,
+                    output,
+                    outs[0],
+                )
+            elif "spatial_bitserial_conv_nhwc" in op.tag:
+                _schedule_bitserial_conv2d_nhwc(
+                    cfg,
+                    s,
+                    data_q,
+                    data_pad,
+                    data_vec,
+                    kernel_q,
+                    kernel_vec,
+                    conv_out,
+                    output,
+                    outs[0],
+                )
         scheduled_ops.append(op)
 
     traverse(outs[0].op)
     return s
 
-def _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec,
-                                    kernel_q, kernel_vec,
-                                    conv_out, output, last):
+
+def _schedule_bitserial_conv2d_nchw(
+    cfg, s, data_q, data_pad, data_vec, kernel_q, kernel_vec, conv_out, output, last
+):
     IB, _, CI, IH, IW = data_q.shape
     KB, CO, _, KH, KW = kernel_q.shape
     _, _, OH, OW = output.shape
@@ -340,7 +435,6 @@ def _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec,
     s[data_vec].pragma(paxis, "parallel_stride_pattern")
     s[data_vec].pragma(oaxis, "parallel_barrier_when_finish")
 
-
     ##### Schedule Kernel bitpacking
     co, _, _, _, _, _ = s[kernel_vec].op.axis
     cfg.define_split("tile_bco", cfg.axis(co), num_outputs=2, max_factor=32)
@@ -357,20 +451,25 @@ def _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec,
     s[kernel_vec].pragma(paxis, "parallel_stride_pattern")
     s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish")
 
-
-   ##### Schedule Convolution
+    ##### Schedule Convolution
     n, co, oh, ow, vh, vw, vc = s[conv_out].op.axis
     ci, dh, dw, ib, kb = s[conv_out].op.reduce_axis
 
     # s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2)
     cfg["reorder_0"].apply(s, conv_out, [n, co, oh, ow, vc, vh, vw, dh, dw, kb, ib, ci])
-    cfg["ann_reduce"].apply(s, conv_out, [kb, ib, dh, dw],
-                            axis_lens=[get_const_int(kb.dom.extent),
-                                       get_const_int(ib.dom.extent),
-                                       get_const_int(dh.dom.extent),
-                                       get_const_int(dw.dom.extent)],
-                            max_unroll=16,
-                            cfg=cfg)
+    cfg["ann_reduce"].apply(
+        s,
+        conv_out,
+        [kb, ib, dh, dw],
+        axis_lens=[
+            get_const_int(kb.dom.extent),
+            get_const_int(ib.dom.extent),
+            get_const_int(dh.dom.extent),
+            get_const_int(dw.dom.extent),
+        ],
+        max_unroll=16,
+        cfg=cfg,
+    )
 
     s[conv_out].vectorize(vc)
 
@@ -395,9 +494,10 @@ def _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec,
     s[last].parallel(oco)
     return s
 
-def _schedule_bitserial_conv2d_nhwc(cfg, s, data_q, data_pad, data_vec,
-                                    kernel_q, kernel_vec,
-                                    conv_out, output, last):
+
+def _schedule_bitserial_conv2d_nhwc(
+    cfg, s, data_q, data_pad, data_vec, kernel_q, kernel_vec, conv_out, output, last
+):
     # no stride and padding info here
     _, IH, IW, CI, IB = data_q.shape
     KH, KW, _, CO, KB = kernel_q.shape
@@ -428,13 +528,19 @@ def _schedule_bitserial_conv2d_nhwc(cfg, s, data_q, data_pad, data_vec,
 
     # s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2)
     cfg["reorder_0"].apply(s, conv_out, [n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2])
-    cfg["ann_reduce"].apply(s, conv_out, [b1, b2, dh, dw],
-                            axis_lens=[get_const_int(b1.dom.extent),
-                                       get_const_int(b2.dom.extent),
-                                       get_const_int(dh.dom.extent),
-                                       get_const_int(dw.dom.extent)],
-                            max_unroll=16,
-                            cfg=cfg)
+    cfg["ann_reduce"].apply(
+        s,
+        conv_out,
+        [b1, b2, dh, dw],
+        axis_lens=[
+            get_const_int(b1.dom.extent),
+            get_const_int(b2.dom.extent),
+            get_const_int(dh.dom.extent),
+            get_const_int(dw.dom.extent),
+        ],
+        max_unroll=16,
+        cfg=cfg,
+    )
 
     s[conv_out].unroll(b1)
     s[conv_out].unroll(b2)
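
The accumulation in `_conv` above is easy to lose track of once it is spread over the packed layout. The following is a minimal pure-Python sketch of the same bitserial dot product on scalar bit-planes (`bitserial_dot`, `x_bits`, and `w_bits` are illustrative names, not TVM APIs): each input/weight bit-plane pair contributes a popcount shifted left by the sum of the two bit positions, and the unipolar variant additionally subtracts the popcount against the complemented weights.

# Minimal sketch: scalar bitserial dot product over packed bit-planes.
def bitserial_dot(x_bits, w_bits, mask, unipolar=True):
    acc = 0
    for ib, xb in enumerate(x_bits):          # input bit-planes (LSB first)
        for kb, wb in enumerate(w_bits):      # weight bit-planes (LSB first)
            pos = bin(xb & wb).count("1")
            if unipolar:
                # weights encode {-1, +1}: subtract matches against ~w
                neg = bin(xb & ~wb & mask).count("1")
                acc += (pos - neg) << (ib + kb)
            else:
                acc += pos << (ib + kb)
    return acc

# 8 lanes, 2 activation bits, 1 weight bit.
x_bits = [0b10110010, 0b01100001]
w_bits = [0b11001010]
print(bitserial_dot(x_bits, w_bits, mask=0xFF))
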
index 8d5736b..e9546ac 100644 (file)
@@ -24,9 +24,11 @@ from tvm.topi.util import get_const_int, get_const_tuple
 from .. import tag
 from ..nn.bitserial_util import bitpack, binary_op_multiplier
 
-@autotvm.register_topi_compute('bitserial_dense.x86')
-def bitserial_dense(cfg, data, weight, data_bits, weight_bits, pack_dtype='uint32',
-                    out_dtype='int16', unipolar=True):
+
+@autotvm.register_topi_compute("bitserial_dense.x86")
+def bitserial_dense(
+    cfg, data, weight, data_bits, weight_bits, pack_dtype="uint32", out_dtype="int16", unipolar=True
+):
     """Bitserial dense implementation. TODO: Why are these separate
 
     Parameters
@@ -51,45 +53,66 @@ def bitserial_dense(cfg, data, weight, data_bits, weight_bits, pack_dtype='uint3
     ######## Search space
     x, y = cfg.axis(X), cfg.axis(Y)
     db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(K)
-    ko, ki = cfg.define_split('tile_k', k, num_outputs=2)
-    yo, yi = cfg.define_split('tile_y', y, num_outputs=2)
-    xo, xi = cfg.define_split('tile_x', x, num_outputs=2)
+    ko, ki = cfg.define_split("tile_k", k, num_outputs=2)
+    yo, yi = cfg.define_split("tile_y", y, num_outputs=2)
+    xo, xi = cfg.define_split("tile_x", x, num_outputs=2)
 
-    cfg.define_reorder('reorder_0', [yo, xo, ko, yi, wb, db, ki, xi],
-                       policy='candidate', candidate=[
-                           [yo, xo, ko, yi, wb, db, ki, xi],
-                           [yo, xo, yi, ko, wb, db, ki, xi]])
+    cfg.define_reorder(
+        "reorder_0",
+        [yo, xo, ko, yi, wb, db, ki, xi],
+        policy="candidate",
+        candidate=[[yo, xo, ko, yi, wb, db, ki, xi], [yo, xo, yi, ko, wb, db, ki, xi]],
+    )
 
-    cfg.define_annotate('ann_reduce', [db, wb], policy='try_unroll')
-    cfg.define_annotate('ann_spatial', [yi, xi], policy='try_unroll_vec')
+    cfg.define_annotate("ann_reduce", [db, wb], policy="try_unroll")
+    cfg.define_annotate("ann_spatial", [yi, xi], policy="try_unroll_vec")
 
     ###### Compute rule
-    VX = cfg['tile_x'].size[-1]
+    VX = cfg["tile_x"].size[-1]
 
-    wvshape = (X//VX, WB, VX, K)
+    wvshape = (X // VX, WB, VX, K)
     oshape = (Y, X)
 
-    k = te.reduce_axis((0, K), name='k')
-    db = te.reduce_axis((0, DB), name='db')
-    wb = te.reduce_axis((0, WB), name='wb')
+    k = te.reduce_axis((0, K), name="k")
+    db = te.reduce_axis((0, DB), name="db")
+    wb = te.reduce_axis((0, WB), name="wb")
 
     # Tile data and weights
-    weight_vec = te.compute(wvshape, lambda xo, wb, vx, k:
-                            weight_packed[xo*VX+vx][wb][k], name='weight_vec')
+    weight_vec = te.compute(
+        wvshape, lambda xo, wb, vx, k: weight_packed[xo * VX + vx][wb][k], name="weight_vec"
+    )
 
     idxdiv = tvm.tir.indexdiv
     idxmod = tvm.tir.indexmod
 
-    matmul_unipolar = te.compute(oshape, lambda i, j: te.sum(
-        (tvm.tir.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]) -
-         tvm.tir.popcount(~weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k])
-         ).astype(out_dtype)
-        << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense_unipolar')
-
-    matmul = te.compute(oshape, lambda i, j: te.sum(
-        tvm.tir.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]
-                         ).astype(out_dtype)
-        << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')
+    matmul_unipolar = te.compute(
+        oshape,
+        lambda i, j: te.sum(
+            (
+                tvm.tir.popcount(
+                    weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]
+                )
+                - tvm.tir.popcount(
+                    ~weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]
+                )
+            ).astype(out_dtype)
+            << (db + wb).astype(out_dtype),
+            axis=[wb, db, k],
+        ),
+        tag="bitserial_dense_unipolar",
+    )
+
+    matmul = te.compute(
+        oshape,
+        lambda i, j: te.sum(
+            tvm.tir.popcount(
+                weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]
+            ).astype(out_dtype)
+            << (db + wb).astype(out_dtype),
+            axis=[wb, db, k],
+        ),
+        tag="bitserial_dense",
+    )
 
     # binary ops
     cfg.add_flop(2 * Y * X * K * binary_op_multiplier(pack_dtype))
@@ -98,7 +121,8 @@ def bitserial_dense(cfg, data, weight, data_bits, weight_bits, pack_dtype='uint3
         return matmul_unipolar
     return matmul
 
-@autotvm.register_topi_schedule('biserial_dense.x86')
+
+@autotvm.register_topi_schedule("biserial_dense.x86")
 def schedule_bitserial_dense(cfg, outs):
     """Schedule for bitserial_dense.
 
@@ -127,18 +151,23 @@ def schedule_bitserial_dense(cfg, outs):
         xo, xi = cfg["tile_x"].apply(s, output, x)
         ko, ki = cfg["tile_k"].apply(s, output, k)
 
-
         cfg["reorder_0"].apply(s, output, [yo, xo, ko, yi, wb, db, ki, xi])
-        cfg["ann_reduce"].apply(s, output, [db, wb],
-                                axis_lens=[get_const_int(db.dom.extent),
-                                           get_const_int(wb.dom.extent)],
-                                max_unroll=8,
-                                cfg=cfg)
-        cfg["ann_spatial"].apply(s, output, [yi, xi],
-                                 axis_lens=[cfg['tile_y'].size[-1],
-                                            cfg['tile_x'].size[-1]],
-                                 max_unroll=8,
-                                 cfg=cfg)
+        cfg["ann_reduce"].apply(
+            s,
+            output,
+            [db, wb],
+            axis_lens=[get_const_int(db.dom.extent), get_const_int(wb.dom.extent)],
+            max_unroll=8,
+            cfg=cfg,
+        )
+        cfg["ann_spatial"].apply(
+            s,
+            output,
+            [yi, xi],
+            axis_lens=[cfg["tile_y"].size[-1], cfg["tile_x"].size[-1]],
+            max_unroll=8,
+            cfg=cfg,
+        )
         s[output].vectorize(xi)
         s[output].parallel(yo)
         return s
@@ -146,14 +175,14 @@ def schedule_bitserial_dense(cfg, outs):
     def traverse(op):
         """Internal traverse function"""
         # inline all one-to-one-mapping operators except the last stage (output)
-        if tag.is_broadcast(op.tag) or 'elemwise' in op.tag:
+        if tag.is_broadcast(op.tag) or "elemwise" in op.tag:
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
                 if isinstance(tensor.op, tvm.te.ComputeOp):
                     traverse(tensor.op)
 
-        elif op.tag == 'bitserial_dense' or 'bitserial_dense_unipolar':
+        elif op.tag in ("bitserial_dense", "bitserial_dense_unipolar"):
             output = op.output(0)
             weight_vec = op.input_tensors[0]
 
index 1e30c9f..121c1c2 100644 (file)
@@ -32,8 +32,8 @@ def schedule_conv1d_ncw(outs):
         if tag.is_broadcast(op.tag):
             if op not in s.outputs:
                 s[op].compute_inline()
-            else: # inject custom schedule
-                if len(op.axis) == 3: # schedule bias + bn + relu
+            else:  # inject custom schedule
+                if len(op.axis) == 3:  # schedule bias + bn + relu
                     n, c, w = op.axis
                     fused = s[op].fuse(n, c)
                     s[op].parallel(fused)
@@ -42,7 +42,7 @@ def schedule_conv1d_ncw(outs):
                 if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
 
-        if 'conv1d_ncw' in op.tag:
+        if "conv1d_ncw" in op.tag:
             conv = op.output(0)
             kernel = op.input_tensors[1]
             if isinstance(kernel.op, te.tensor.ComputeOp) and "dilate" in kernel.op.tag:
@@ -62,7 +62,7 @@ def schedule_conv1d_ncw(outs):
             rc, rw = C.op.reduce_axis
             n_out, c_out, w_out = output_op.axis
             s[C].vectorize(w)
-            if op != output_op: # fuse bias + bn + relu into conv
+            if op != output_op:  # fuse bias + bn + relu into conv
                 s[C].compute_at(s[output_op], w_out)
             else:
                 fused = s[C].fuse(n, c)
@@ -86,8 +86,8 @@ def schedule_conv1d_nwc(outs):
         if tag.is_broadcast(op.tag):
             if op not in s.outputs:
                 s[op].compute_inline()
-            else: # inject custom schedule
-                if len(op.axis) == 3: # schedule bias + bn + relu
+            else:  # inject custom schedule
+                if len(op.axis) == 3:  # schedule bias + bn + relu
                     n, w, c = op.axis
                     fused = s[op].fuse(n, w)
                     s[op].parallel(fused)
@@ -96,7 +96,7 @@ def schedule_conv1d_nwc(outs):
                 if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
 
-        if 'conv1d_nwc' in op.tag:
+        if "conv1d_nwc" in op.tag:
             conv = op.output(0)
             kernel = op.input_tensors[1]
             if isinstance(kernel.op, te.tensor.ComputeOp) and "dilate" in kernel.op.tag:
@@ -116,7 +116,7 @@ def schedule_conv1d_nwc(outs):
             rc, rw = C.op.reduce_axis
             n_out, w_out, c_out = output_op.axis
             s[C].vectorize(c)
-            if op != output_op: # fuse bias + bn + relu into conv
+            if op != output_op:  # fuse bias + bn + relu into conv
                 s[C].compute_at(s[output_op], c_out)
             else:
                 fused = s[C].fuse(n, w)
index 610369d..47fb48e 100644 (file)
@@ -31,10 +31,12 @@ from ..nn.util import get_pad_tuple
 from ..util import get_const_tuple, traverse_inline
 from . import conv2d_avx_1x1, conv2d_avx_common
 
-logger = logging.getLogger('topi')
+logger = logging.getLogger("topi")
 
-def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False,
-                        layout='NCHW'):
+
+def _get_default_config(
+    cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False, layout="NCHW"
+):
     """
     Get default schedule config for the workload
     """
@@ -48,6 +50,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth
     if is_depthwise:
         wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype)
         from .depthwise_conv2d import _fallback_schedule
+
         _fallback_schedule(cfg, wkl)
     else:
         wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
@@ -57,6 +60,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth
         else:
             conv2d_avx_common._fallback_schedule(cfg, wkl)
 
+
 @conv2d_infer_layout.register("cpu")
 def _conv2d_infer_layout(workload, cfg):
     _, data, kernel, strides, padding, dilation, layout, _, dtype = workload
@@ -74,6 +78,7 @@ def _conv2d_infer_layout(workload, cfg):
     out_layout = "NCHW%dc" % tile_oc
     return ((in_shape, in_layout),), ((out_shape, out_layout),)
 
+
 def schedule_conv2d_nhwc(outs):
     """Create schedule for conv2d_nhwc"""
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
@@ -81,7 +86,7 @@ def schedule_conv2d_nhwc(outs):
     output_op = outs[0].op
 
     def _callback(op):
-        if 'conv2d_nhwc' in op.tag:
+        if "conv2d_nhwc" in op.tag:
             conv = op.output(0)
             kernel = op.input_tensors[1]
             if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
@@ -101,7 +106,7 @@ def schedule_conv2d_nhwc(outs):
             s[C].vectorize(c)
 
             O = output_op.output(0)
-            if len(O.op.axis) == 4: # schedule bias + bn + relu
+            if len(O.op.axis) == 4:  # schedule bias + bn + relu
                 n, h, w, c = O.op.axis
                 fused = s[O].fuse(n, h, w)
                 s[O].parallel(fused)
@@ -115,16 +120,18 @@ def schedule_conv2d_nhwc(outs):
     traverse_inline(s, output_op, _callback)
     return s
 
+
 def conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype):
     layout = "NCHW"
-    packed_out = conv2d_NCHWc(data, kernel, strides, padding, dilation,
-                              layout, layout, out_dtype)
+    packed_out = conv2d_NCHWc(data, kernel, strides, padding, dilation, layout, layout, out_dtype)
     return unpack_NCHWc_to_nchw(packed_out, out_dtype)
 
+
 def schedule_conv2d_nchw(outs):
     """Create schedule for tensors"""
     return schedule_conv2d_NCHWc(outs)
 
+
 def _pack_data(cfg, data, kernel):
     n, _, ih, iw = get_const_tuple(data.shape)
     oc, ic, kh, kw = get_const_tuple(kernel.shape)
@@ -133,18 +140,21 @@ def _pack_data(cfg, data, kernel):
     ic_chunk = ic // ic_bn
     oc_chunk = oc // oc_bn
 
-    data = te.compute((n, ic_chunk, ih, iw, ic_bn),
-                      lambda bs, c, h, w, vc: data[bs, c*ic_bn + vc, h, w],
-                      name="data_vec")
+    data = te.compute(
+        (n, ic_chunk, ih, iw, ic_bn),
+        lambda bs, c, h, w, vc: data[bs, c * ic_bn + vc, h, w],
+        name="data_vec",
+    )
 
     kernel = te.compute(
         (oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn),
-        lambda occ, icc, k_h, k_w, icb, ocb:
-        kernel[occ * oc_bn + ocb, icc * ic_bn + icb, k_h, k_w],
-        name="kernel_vec")
+        lambda occ, icc, k_h, k_w, icb, ocb: kernel[occ * oc_bn + ocb, icc * ic_bn + icb, k_h, k_w],
+        name="kernel_vec",
+    )
 
     return data, kernel
 
+
 @autotvm.register_topi_compute("conv2d_NCHWc.x86")
 def conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, layout, out_layout, out_dtype):
     """Compute conv2d with NCHWc layout."""
@@ -152,8 +162,9 @@ def conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, layout, out_layo
     # we keep them for debug convenience when dumping autotvm workload
     if len(data.shape) == 5:
         n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
-        oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \
-            get_const_tuple(kernel.shape)
+        oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = get_const_tuple(
+            kernel.shape
+        )
         in_channel = ic_chunk * ic_bn
         num_filter = oc_chunk * oc_bn
     else:
@@ -169,8 +180,9 @@ def conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, layout, out_layo
 
     cfg.define_split("tile_ic", in_channel, num_outputs=2)
     cfg.define_split("tile_oc", num_filter, num_outputs=2)
-    cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 64,
-                     policy="verbose")
+    cfg.define_split(
+        "tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 64, policy="verbose"
+    )
     if is_kernel_1x1:
         cfg.define_knob("tile_oh", [1, 2] if oh > 1 else [1])
     else:
@@ -178,36 +190,38 @@ def conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, layout, out_layo
 
     # If no config was set, we can fallback to default config.
     if cfg.is_fallback:
-        _get_default_config(cfg, te.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
-                            te.placeholder((num_filter, in_channel, kernel_height, kernel_width),
-                                           dtype=kernel.dtype),
-                            strides, padding, out_dtype)
+        _get_default_config(
+            cfg,
+            te.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
+            te.placeholder(
+                (num_filter, in_channel, kernel_height, kernel_width), dtype=kernel.dtype
+            ),
+            strides,
+            padding,
+            out_dtype,
+        )
 
     # Pack data if raw 4-D data is provided.
     # This can only happen when autotuning.
     if len(data.shape) == 4:
         if autotvm.GLOBAL_SCOPE.in_tuning:
             # Directly use modified data layout placeholder.
-            dshape = (n, in_channel // cfg["tile_ic"].size[-1],
-                      ih, iw, cfg["tile_ic"].size[-1])
+            dshape = (n, in_channel // cfg["tile_ic"].size[-1], ih, iw, cfg["tile_ic"].size[-1])
             data = tvm.te.placeholder(dshape, data.dtype, name="data")
-            kshape = (num_filter // cfg["tile_oc"].size[-1],
-                      in_channel // cfg["tile_ic"].size[-1],
-                      kernel_height, kernel_width,
-                      cfg["tile_ic"].size[-1],
-                      cfg["tile_oc"].size[-1])
+            kshape = (
+                num_filter // cfg["tile_oc"].size[-1],
+                in_channel // cfg["tile_ic"].size[-1],
+                kernel_height,
+                kernel_width,
+                cfg["tile_ic"].size[-1],
+                cfg["tile_oc"].size[-1],
+            )
             kernel = tvm.te.placeholder(kshape, kernel.dtype, name="kernel")
         else:
             data, kernel = _pack_data(cfg, data, kernel)
 
-    return nn.conv2d_NCHWc(data,
-                           kernel,
-                           strides,
-                           padding,
-                           dilation,
-                           layout,
-                           out_layout,
-                           out_dtype)
+    return nn.conv2d_NCHWc(data, kernel, strides, padding, dilation, layout, out_layout, out_dtype)
+
 
 @autotvm.register_topi_schedule("conv2d_NCHWc.x86")
 def schedule_conv2d_NCHWc(cfg, outs):
@@ -216,13 +230,20 @@ def schedule_conv2d_NCHWc(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'conv2d_NCHWc' in op.tag:
+        if "conv2d_NCHWc" in op.tag:
             conv_out = op.output(0)
             kernel_vec = conv_out.op.input_tensors[1]
             data_vec = conv_out.op.input_tensors[0]
 
             args = [s, cfg, data_vec, kernel_vec, conv_out, outs[0]]
-            _, _, kh, kw, _, _, = get_const_tuple(kernel_vec.shape)
+            (
+                _,
+                _,
+                kh,
+                kw,
+                _,
+                _,
+            ) = get_const_tuple(kernel_vec.shape)
             if kh == 1 and kw == 1:
                 conv2d_avx_1x1._schedule_conv_NCHWc(*args)
             else:
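
The `_pack_data` compute in this file only re-indexes its inputs. As a sanity check, here is the equivalent NumPy layout transform for the data side (NCHW -> NCHW{ic_bn}c); `pack_nchwc` is an illustrative helper, not a TOPI API, and the kernel side follows the same pattern with two blocked axes.

import numpy as np

def pack_nchwc(data, ic_bn):
    n, c, h, w = data.shape
    assert c % ic_bn == 0
    # data_vec[bs, cc, h, w, vc] == data[bs, cc * ic_bn + vc, h, w]
    return data.reshape(n, c // ic_bn, ic_bn, h, w).transpose(0, 1, 3, 4, 2)

x = np.arange(1 * 8 * 2 * 2).reshape(1, 8, 2, 2)
print(pack_nchwc(x, ic_bn=4).shape)  # (1, 2, 2, 2, 4)
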
index 992353e..1c90841 100644 (file)
@@ -30,11 +30,12 @@ from ..util import get_const_tuple
 from ..nn import conv2d_legalize, conv2d_alter_layout
 from ..nn.util import get_pad_tuple
 
-logger = logging.getLogger('topi')
+logger = logging.getLogger("topi")
 
 _NCHWc_matcher = re.compile("^NCHW[0-9]+c$")
 _OIHWio_matcher = re.compile("^OIHW[0-9]+i[0-9]+o$")
 
+
 @conv2d_alter_layout.register("cpu")
 def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
     target = tvm.target.Target.current(allow_none=False)
@@ -44,7 +45,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         workload = cfg.workload
     else:
         _, outs = relay.backend.compile_engine.select_implementation(
-            relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)
+            relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target
+        )
         workload = autotvm.task.get_workload(outs)
         if workload is None:
             # The best implementation is not an AutoTVM template,
@@ -53,7 +55,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         cfg = dispatch_ctx.query(target, workload)
 
     topi_tmpl = workload[0]
-    new_attrs = {k : attrs[k] for k in attrs.keys()}
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
 
     # Parse the attributes.
     padding = attrs.get_int_tuple("padding")
@@ -70,27 +72,41 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         # we only convert conv2d_NCHW to conv2d_NCHWc for x86
         if data_layout == "NCHW" and kernel_layout == "OIHW":
             if cfg.is_fallback:
-                _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding,
-                                    out_dtype, False, data_layout)
+                _get_default_config(
+                    cfg, data_tensor, kernel_tensor, strides, padding, out_dtype, False, data_layout
+                )
             batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
             out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape)
             ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
 
             # update new attrs
-            new_attrs['channels'] = out_channel
-            new_attrs['data_layout'] = 'NCHW%dc' % ic_bn
+            new_attrs["channels"] = out_channel
+            new_attrs["data_layout"] = "NCHW%dc" % ic_bn
             # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
-            new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
-            new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
+            new_attrs["kernel_layout"] = "OIHW%di%do" % (ic_bn, oc_bn)
+            new_attrs["out_layout"] = "NCHW%dc" % oc_bn
 
             # Store altered operator's config
-            new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
-                                      dtype=data_dtype)
-            new_kernel = te.placeholder((out_channel//oc_bn, in_channel//ic_bn,
-                                         kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype)
+            new_data = te.placeholder(
+                (batch_size, in_channel // ic_bn, height, width, ic_bn), dtype=data_dtype
+            )
+            new_kernel = te.placeholder(
+                (out_channel // oc_bn, in_channel // ic_bn, kh, kw, ic_bn, oc_bn),
+                dtype=kernel_tensor.dtype,
+            )
             new_workload = autotvm.task.args_to_workload(
-                [new_data, new_kernel, strides, padding, dilation, new_attrs["data_layout"],
-                 new_attrs["out_layout"], out_dtype], topi_tmpl)
+                [
+                    new_data,
+                    new_kernel,
+                    strides,
+                    padding,
+                    dilation,
+                    new_attrs["data_layout"],
+                    new_attrs["out_layout"],
+                    out_dtype,
+                ],
+                topi_tmpl,
+            )
             dispatch_ctx.update(target, new_workload, cfg)
         else:
             assert _NCHWc_matcher.match(data_layout)
@@ -101,8 +117,9 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         # TODO(@icemelon9, @anijain2305): Need to support data layout NHWC with kernel layout HWIO
         assert data_layout == "NCHW" and kernel_layout == "OIHW"
         if cfg.is_fallback:
-            _get_default_config_int8(cfg, data_tensor, kernel_tensor, strides, padding,
-                                     out_dtype, False, data_layout)
+            _get_default_config_int8(
+                cfg, data_tensor, kernel_tensor, strides, padding, out_dtype, False, data_layout
+            )
 
         batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
         out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape)
@@ -112,32 +129,43 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         # convert kernel data layout from 4D to 7D
         data_expr, kernel_expr = inputs
         kernel_IHWO = relay.transpose(kernel_expr, axes=(1, 2, 3, 0))
-        kernel_IHWOo = relay.reshape(kernel_IHWO, (in_channel, kh, kw, out_channel//oc_bn, oc_bn))
+        kernel_IHWOo = relay.reshape(kernel_IHWO, (in_channel, kh, kw, out_channel // oc_bn, oc_bn))
         kernel_OHWoI = relay.transpose(kernel_IHWOo, axes=(3, 1, 2, 4, 0))
-        kernel_OHWoIi = relay.reshape(kernel_OHWoI, (out_channel//oc_bn, kh, kw, oc_bn,
-                                                     in_channel//ic_bn, ic_bn))
-        kernel_OHWoIie = relay.reshape(kernel_OHWoIi, (out_channel//oc_bn, kh, kw, oc_bn,
-                                                       in_channel//ic_bn, ic_bn//n_elems, n_elems))
+        kernel_OHWoIi = relay.reshape(
+            kernel_OHWoI, (out_channel // oc_bn, kh, kw, oc_bn, in_channel // ic_bn, ic_bn)
+        )
+        kernel_OHWoIie = relay.reshape(
+            kernel_OHWoIi,
+            (out_channel // oc_bn, kh, kw, oc_bn, in_channel // ic_bn, ic_bn // n_elems, n_elems),
+        )
         kernel_OIHWioe = relay.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6))
 
         # update new attrs
-        new_attrs['channels'] = out_channel
-        new_attrs['data_layout'] = 'NCHW%dc' % ic_bn
-        new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
+        new_attrs["channels"] = out_channel
+        new_attrs["data_layout"] = "NCHW%dc" % ic_bn
+        new_attrs["out_layout"] = "NCHW%dc" % oc_bn
 
         # Store altered operator's config.
-        new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
-                                  dtype=data_dtype)
-        new_kernel = te.placeholder((out_channel // oc_bn,
-                                     in_channel // ic_bn,
-                                     kh,
-                                     kw,
-                                     ic_bn // n_elems,
-                                     oc_bn,
-                                     n_elems), dtype=kernel_dtype)
+        new_data = te.placeholder(
+            (batch_size, in_channel // ic_bn, height, width, ic_bn), dtype=data_dtype
+        )
+        new_kernel = te.placeholder(
+            (out_channel // oc_bn, in_channel // ic_bn, kh, kw, ic_bn // n_elems, oc_bn, n_elems),
+            dtype=kernel_dtype,
+        )
         new_workload = autotvm.task.args_to_workload(
-            [new_data, new_kernel, strides, padding, dilation, new_attrs['data_layout'],
-             new_attrs['out_layout'], out_dtype], topi_tmpl)
+            [
+                new_data,
+                new_kernel,
+                strides,
+                padding,
+                dilation,
+                new_attrs["data_layout"],
+                new_attrs["out_layout"],
+                out_dtype,
+            ],
+            topi_tmpl,
+        )
         dispatch_ctx.update(target, new_workload, cfg)
 
         return relay.nn.contrib_conv2d_nchwc(data_expr, kernel_OIHWioe, **new_attrs)
@@ -145,8 +173,9 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
     if topi_tmpl == "depthwise_conv2d_NCHWc.x86":
         if data_layout == "NCHW" and kernel_layout == "OIHW":
             if cfg.is_fallback:
-                _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding,
-                                    out_dtype, True, data_layout)
+                _get_default_config(
+                    cfg, data_tensor, kernel_tensor, strides, padding, out_dtype, True, data_layout
+                )
 
             batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
             out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape)
@@ -154,19 +183,31 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
             assert channel_multiplier == 1
 
             # update new attrs
-            new_attrs['channels'] = out_channel
-            new_attrs['data_layout'] = 'NCHW%dc' % ic_bn
-            new_attrs['kernel_layout'] = 'OIHW1i%do' % oc_bn
-            new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
+            new_attrs["channels"] = out_channel
+            new_attrs["data_layout"] = "NCHW%dc" % ic_bn
+            new_attrs["kernel_layout"] = "OIHW1i%do" % oc_bn
+            new_attrs["out_layout"] = "NCHW%dc" % oc_bn
 
             # Store altered operator's config.
-            new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
-                                      dtype=data_dtype)
-            new_kernel = te.placeholder((out_channel//oc_bn, 1, kh, kw, 1, oc_bn),
-                                        dtype=kernel_dtype)
+            new_data = te.placeholder(
+                (batch_size, in_channel // ic_bn, height, width, ic_bn), dtype=data_dtype
+            )
+            new_kernel = te.placeholder(
+                (out_channel // oc_bn, 1, kh, kw, 1, oc_bn), dtype=kernel_dtype
+            )
             new_workload = autotvm.task.args_to_workload(
-                [new_data, new_kernel, strides, padding, dilation, new_attrs['data_layout'],
-                 new_attrs['out_layout'], out_dtype], topi_tmpl)
+                [
+                    new_data,
+                    new_kernel,
+                    strides,
+                    padding,
+                    dilation,
+                    new_attrs["data_layout"],
+                    new_attrs["out_layout"],
+                    out_dtype,
+                ],
+                topi_tmpl,
+            )
             dispatch_ctx.update(target, new_workload, cfg)
         else:
             assert _NCHWc_matcher.match(data_layout)
@@ -228,36 +269,36 @@ def _conv2d_legalize(attrs, inputs, arg_types):
     #   C = (A' conv B) - 128 (conv) B
     # where A' = A + 128
     # and 128 (conv) B is basically a reduce on CRS axis for weights.
-    if data_tensor.dtype == 'int8' and kernel_tensor.dtype == 'int8':
+    if data_tensor.dtype == "int8" and kernel_tensor.dtype == "int8":
         is_int8_inputs = True
         padding = attrs.get_int_tuple("padding")
         kh, kw = attrs.get_int_tuple("kernel_size")
         pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw))
 
-        if attrs['data_layout'] == 'NHWC' and attrs['kernel_layout'] == 'HWIO':
-            adjust_shift = relay.sum(relay.cast(kernel, dtype='int32'), axis=(0, 1, 2))
+        if attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWIO":
+            adjust_shift = relay.sum(relay.cast(kernel, dtype="int32"), axis=(0, 1, 2))
             pad_width = ((0, 0), (pt, pb), (pl, pr), (0, 0))
-        elif attrs['data_layout'] == 'NCHW' and attrs['kernel_layout'] == 'OIHW':
+        elif attrs["data_layout"] == "NCHW" and attrs["kernel_layout"] == "OIHW":
             pad_width = ((0, 0), (0, 0), (pt, pb), (pl, pr))
-            adjust_shift = relay.sum(relay.cast(kernel, dtype='int32'), axis=(1, 2, 3))
+            adjust_shift = relay.sum(relay.cast(kernel, dtype="int32"), axis=(1, 2, 3))
             adjust_shift = relay.expand_dims(adjust_shift, axis=1, num_newaxis=2)
         else:
             return None
 
-        data = relay.cast(data, 'int32')
-        data = relay.add(data, relay.const(128, 'int32'))
-        data = relay.cast(data, 'uint8')
+        data = relay.cast(data, "int32")
+        data = relay.add(data, relay.const(128, "int32"))
+        data = relay.cast(data, "uint8")
 
         # Do external padding as pad value has to be 128.
         if not (padding[0] == 0 and padding[1] == 0):
             data = relay.nn.pad(data, pad_width=pad_width, pad_value=128)
-        new_attrs['padding'] = (0, 0)
+        new_attrs["padding"] = (0, 0)
 
         # The data type is now shifted to uint8
-        data_dtype = 'uint8'
+        data_dtype = "uint8"
 
         # Multiply 128 to adjust shift.
-        adjust_shift = relay.multiply(adjust_shift, relay.const(128, 'int32'))
+        adjust_shift = relay.multiply(adjust_shift, relay.const(128, "int32"))
 
     # Legalize if the datatypes are suitable for fast Int8 instructions.  Int8 instructions require
     # input channel to be a multiple of 4 and output channels to be a multiple of 16. For input
@@ -271,10 +312,10 @@ def _conv2d_legalize(attrs, inputs, arg_types):
         # Find the value of input and output channel.
         in_channel = -1
         out_channel = -1
-        if attrs['data_layout'] == 'NHWC' and attrs['kernel_layout'] == 'HWIO':
+        if attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWIO":
             in_channel = data_tensor.shape[3].value
             out_channel = kernel_tensor.shape[3].value
-        elif attrs['data_layout'] == 'NCHW' and attrs['kernel_layout'] == 'OIHW':
+        elif attrs["data_layout"] == "NCHW" and attrs["kernel_layout"] == "OIHW":
             in_channel = data_tensor.shape[1].value
             out_channel = kernel_tensor.shape[0].value
         else:
@@ -283,11 +324,11 @@ def _conv2d_legalize(attrs, inputs, arg_types):
         if in_channel % 4 != 0:
             new_in_channel = ((in_channel + 4) // 4) * 4
             diff = new_in_channel - in_channel
-            if attrs['data_layout'] == 'NHWC' and attrs['kernel_layout'] == 'HWIO':
+            if attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWIO":
                 data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, 0), (0, diff)))
                 kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, diff), (0, 0)))
                 ic_modified = True
-            elif attrs['data_layout'] == 'NCHW' and attrs['kernel_layout'] == 'OIHW':
+            elif attrs["data_layout"] == "NCHW" and attrs["kernel_layout"] == "OIHW":
                 pad_width = ((0, 0), (0, diff), (0, 0), (0, 0))
                 data = relay.nn.pad(data, pad_width=pad_width)
                 kernel = relay.nn.pad(kernel, pad_width=pad_width)
@@ -299,22 +340,20 @@ def _conv2d_legalize(attrs, inputs, arg_types):
         if out_channel % 16 != 0:
             new_out_channel = ((out_channel + 16) // 16) * 16
             diff = new_out_channel - out_channel
-            if attrs['data_layout'] == 'NHWC' and attrs['kernel_layout'] == 'HWIO':
+            if attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWIO":
                 kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 0), (0, diff)))
                 oc_modified = True
-            elif attrs['data_layout'] == 'NCHW' and attrs['kernel_layout'] == 'OIHW':
+            elif attrs["data_layout"] == "NCHW" and attrs["kernel_layout"] == "OIHW":
                 kernel = relay.nn.pad(kernel, pad_width=((0, diff), (0, 0), (0, 0), (0, 0)))
                 oc_modified = True
             else:
                 return None
 
         if oc_modified:
-            new_attrs['channels'] = new_out_channel
+            new_attrs["channels"] = new_out_channel
             out = tvm.relay.nn.conv2d(data, kernel, **new_attrs)
             original_out_shape = [x.value for x in output_tensor.shape]
-            out = relay.strided_slice(out,
-                                      begin=[0, 0, 0, 0],
-                                      end=original_out_shape)
+            out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape)
         else:
             out = relay.nn.conv2d(data, kernel, **new_attrs)
 
index c6ed832..8ca20be 100644 (file)
@@ -28,6 +28,7 @@ from ..util import get_const_tuple, simplify
 from .tensor_intrin import dot_16x1x16_uint8_int8_int32
 from .util import get_fp32_len
 
+
 def _fallback_schedule(cfg, wkl):
     simd_width = get_fp32_len()
     HPAD, WPAD = wkl.hpad, wkl.wpad
@@ -65,8 +66,7 @@ def _schedule_conv_NCHWc(s, cfg, data_vec, kernel_vec, conv_out, last):
     _, _, _, _, ic_bn = get_const_tuple(data_vec.shape)
 
     # schedule pad
-    if isinstance(s[data_vec].op, tvm.te.ComputeOp) \
-            and "pad" in data_vec.op.tag:
+    if isinstance(s[data_vec].op, tvm.te.ComputeOp) and "pad" in data_vec.op.tag:
         batch, ic_chunk, ih, iw, ic_block = s[data_vec].op.axis
         s[data_vec].vectorize(ic_block)
         parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
@@ -74,8 +74,7 @@ def _schedule_conv_NCHWc(s, cfg, data_vec, kernel_vec, conv_out, last):
         data_vec = data_vec.op.input_tensors[0]
 
     oc_bn = cfg["tile_oc"].size[-1]
-    if isinstance(kernel_vec.op, tvm.te.ComputeOp) and \
-            kernel_vec.name == 'kernel_vec':
+    if isinstance(kernel_vec.op, tvm.te.ComputeOp) and kernel_vec.name == "kernel_vec":
         # data and kernel are not pre-computed, schedule layout transform here.
         # this should only be used by x86 conv2d_nchw, which is for
         # testing purposes.
@@ -91,7 +90,7 @@ def _schedule_conv_NCHWc(s, cfg, data_vec, kernel_vec, conv_out, last):
         s[kernel_vec].parallel(parallel_axis)
 
     C, O = conv_out, last
-    CC = s.cache_write(C, 'global')
+    CC = s.cache_write(C, "global")
 
     batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
     oh_outer, oh_inner = s[C].split(oh, factor=oh_factor)
@@ -148,9 +147,16 @@ def _schedule_conv_NCHWc(s, cfg, data_vec, kernel_vec, conv_out, last):
 
 
 def _schedule_conv_NCHWc_int8(s, cfg, data_vec, kernel_vec, conv_out, last):
-    return conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec,
-                                                           conv_out, last, int32_lanes=16,
-                                                           intrin=dot_16x1x16_uint8_int8_int32())
+    return conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8(
+        s,
+        cfg,
+        data_vec,
+        kernel_vec,
+        conv_out,
+        last,
+        int32_lanes=16,
+        intrin=dot_16x1x16_uint8_int8_int32(),
+    )
 
 
 def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, out_dtype):
@@ -174,7 +180,8 @@ def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, o
     dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
     dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
     out_channel = num_filter
     out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
     out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
@@ -189,25 +196,29 @@ def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, o
     idxm = tvm.tir.indexmod
 
     packw_shape = (kernel_h, kernel_w, idxd(num_filter, 16), 16 * idxd(channel, 4), 4)
-    PackW = te.compute(packw_shape,
-                       lambda a, b, c, d, e:
-                       Filter[a, b,
-                              c*16 + idxm(d, 16),
-                              idxd(d, 16) * 4 + e],
-                       name="packed_filter")
-
-    rc = te.reduce_axis((0, in_channel), name='rc')
-    ry = te.reduce_axis((0, kernel_h), name='ry')
-    rx = te.reduce_axis((0, kernel_w), name='rx')
+    PackW = te.compute(
+        packw_shape,
+        lambda a, b, c, d, e: Filter[a, b, c * 16 + idxm(d, 16), idxd(d, 16) * 4 + e],
+        name="packed_filter",
+    )
+
+    rc = te.reduce_axis((0, in_channel), name="rc")
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
     Output = te.compute(
         (batch, out_height, out_width, out_channel),
         lambda nn, yy, xx, ff: te.sum(
-            PaddedInput[nn, yy * stride_h + ry * dilation_h,
-                        xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
-            PackW[ry, rx, idxd(ff, 16),
-                  idxd(rc, 4) * 16 + idxm(ff, 16),
-                  idxm(rc, 4)].astype(out_dtype), axis=[ry, rx, rc]),
-        name="Conv2d_1x1_Output_int8", tag="conv2d_nhwc_pack_int8")
+            PaddedInput[
+                nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc
+            ].astype(out_dtype)
+            * PackW[ry, rx, idxd(ff, 16), idxd(rc, 4) * 16 + idxm(ff, 16), idxm(rc, 4)].astype(
+                out_dtype
+            ),
+            axis=[ry, rx, rc],
+        ),
+        name="Conv2d_1x1_Output_int8",
+        tag="conv2d_nhwc_pack_int8",
+    )
     return Output
 
 
index aea954f..28a698c 100644
@@ -24,6 +24,7 @@ from ..util import get_const_tuple
 from .tensor_intrin import dot_16x1x16_uint8_int8_int32
 from .util import get_fp32_len
 
+
 def _fallback_schedule(cfg, wkl):
     simd_width = get_fp32_len()
     HPAD, WPAD = wkl.hpad, wkl.wpad
@@ -87,8 +88,7 @@ def _schedule_conv_NCHWc(s, cfg, data_vec, kernel_vec, conv_out, last):
     _, _, _, _, ic_bn = get_const_tuple(data_vec.shape)
 
     # schedule pad
-    if isinstance(s[data_vec].op, tvm.te.ComputeOp) \
-            and "pad" in data_vec.op.tag:
+    if isinstance(s[data_vec].op, tvm.te.ComputeOp) and "pad" in data_vec.op.tag:
         batch, ic_chunk, ih, iw, ic_block = s[data_vec].op.axis
         s[data_vec].vectorize(ic_block)
         parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
@@ -96,8 +96,7 @@ def _schedule_conv_NCHWc(s, cfg, data_vec, kernel_vec, conv_out, last):
         data_vec = data_vec.op.input_tensors[0]
 
     oc_bn = cfg["tile_oc"].size[-1]
-    if isinstance(kernel_vec.op, tvm.te.ComputeOp) and \
-            kernel_vec.name == 'kernel_vec':
+    if isinstance(kernel_vec.op, tvm.te.ComputeOp) and kernel_vec.name == "kernel_vec":
         # data and kernel are not pre-computed, schedule layout transform here.
         # this should only be used by x86 conv2d_nchw, which is for
         # testing purpose.
@@ -112,10 +111,9 @@ def _schedule_conv_NCHWc(s, cfg, data_vec, kernel_vec, conv_out, last):
         parallel_axis = s[kernel_vec].fuse(oc_chunk, oh)
         s[kernel_vec].parallel(parallel_axis)
 
-
     # schedule 5-D NCHW[x]c conv
     C, O = conv_out, last
-    CC = s.cache_write(C, 'global')
+    CC = s.cache_write(C, "global")
 
     batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
     ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
@@ -167,6 +165,13 @@ def _schedule_conv_NCHWc(s, cfg, data_vec, kernel_vec, conv_out, last):
 
 
 def _schedule_conv_NCHWc_int8(s, cfg, data_vec, kernel_vec, conv_out, last):
-    return conv2d_generic.schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec,
-                                                              conv_out, last, int32_lanes=16,
-                                                              intrin=dot_16x1x16_uint8_int8_int32())
+    return conv2d_generic.schedule_conv_NCHWc_cpu_common_int8(
+        s,
+        cfg,
+        data_vec,
+        kernel_vec,
+        conv_out,
+        last,
+        int32_lanes=16,
+        intrin=dot_16x1x16_uint8_int8_int32(),
+    )
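
A minimal, self-contained sketch of the call-wrapping pattern applied in the hunk just above (the names below are hypothetical and not taken from this patch): a call that no longer fits on one line is split to one argument per line, ending with a trailing comma so that adding or removing an argument later only touches a single line.

    def schedule_conv(s, cfg, data_vec, kernel_vec, conv_out, last, int32_lanes=16, intrin=None):
        # Hypothetical helper, not a TVM API: just collect the arguments so the call below runs.
        return (s, cfg, data_vec, kernel_vec, conv_out, last, int32_lanes, intrin)

    result = schedule_conv(
        "s",
        "cfg",
        "data_vec",
        "kernel_vec",
        "conv_out",
        "last",
        int32_lanes=16,
        intrin=None,
    )
    print(len(result))  # 8
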
index 4b11143..e2862ec 100644
@@ -31,8 +31,10 @@ from ..util import get_const_tuple, traverse_inline
 from .. import nn
 from . import conv2d_avx_1x1, conv2d_avx_common
 
-def _get_default_config_int8(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False,
-                             layout='NCHW'):
+
+def _get_default_config_int8(
+    cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False, layout="NCHW"
+):
     """
     Get default schedule config for the workload
     """
@@ -40,16 +42,19 @@ def _get_default_config_int8(cfg, data, kernel, strides, padding, out_dtype, is_
         # Fallback to FP32 default config until a VNNI schedule is defined.
         wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype)
         from .depthwise_conv2d import _fallback_schedule
+
         _fallback_schedule(cfg, wkl)
     else:
         wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
         is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
         if is_kernel_1x1:
             conv2d_generic.fallback_schedule_cpu_1x1_int8(
-                cfg, wkl, int32_lanes=16, num_int8_elements=4)
+                cfg, wkl, int32_lanes=16, num_int8_elements=4
+            )
         else:
             conv2d_generic.fallback_schedule_cpu_common_int8(
-                cfg, wkl, int32_lanes=16, num_int8_elements=4)
+                cfg, wkl, int32_lanes=16, num_int8_elements=4
+            )
 
 
 def is_int8_hw_support(data_dtype, kernel_dtype):
@@ -60,7 +65,7 @@ def is_int8_hw_support(data_dtype, kernel_dtype):
     3) Target is skylake and above.
     """
     # 1) Check datatypes
-    is_dtype_support = data_dtype == 'uint8' and kernel_dtype == 'int8'
+    is_dtype_support = data_dtype == "uint8" and kernel_dtype == "int8"
 
     # 2) Check LLVM support
     llvm_version = tvm.target.codegen.llvm_version_major()
@@ -69,7 +74,7 @@ def is_int8_hw_support(data_dtype, kernel_dtype):
     # 3) Check target
     mcpu = tvm.target.Target.current().mcpu
     is_target_support = False
-    if mcpu in ('skylake-avx512', 'cascadelake'):
+    if mcpu in ("skylake-avx512", "cascadelake"):
         is_target_support = True
 
     return is_dtype_support and is_llvm_support and is_target_support
@@ -78,8 +83,9 @@ def is_int8_hw_support(data_dtype, kernel_dtype):
 def conv2d_nchw_int8(data, kernel, strides, padding, dilation, out_dtype):
     """Compute conv2d with NCHW layout and int8 dtype"""
     layout = "NCHW"
-    packed_out = conv2d_NCHWc_int8(data, kernel, strides, padding, dilation,
-                                   layout, layout, out_dtype)
+    packed_out = conv2d_NCHWc_int8(
+        data, kernel, strides, padding, dilation, layout, layout, out_dtype
+    )
     return unpack_NCHWc_to_nchw(packed_out, out_dtype)
 
 
@@ -97,34 +103,36 @@ def _pack_data(cfg, data, kernel):
     ic_chunk = ic // ic_bn
     oc_chunk = oc // oc_bn
 
-    data = te.compute((n, ic_chunk, ih, iw, ic_bn),
-                      lambda bs, c, h, w, vc: data[bs, c*ic_bn + vc, h, w],
-                      name="data_vec")
+    data = te.compute(
+        (n, ic_chunk, ih, iw, ic_bn),
+        lambda bs, c, h, w, vc: data[bs, c * ic_bn + vc, h, w],
+        name="data_vec",
+    )
 
     kernel = te.compute(
-        (oc_chunk, ic_chunk, kh, kw, ic_bn//n_elems, oc_bn, n_elems),
-        lambda occ, icc, k_h, k_w, icbc, ocb, icbb:
-        kernel[occ * oc_bn + ocb,
-               icc * ic_bn + icbc * ic_bn//n_elems + icbb, k_h, k_w],
-        name="kernel_vec")
+        (oc_chunk, ic_chunk, kh, kw, ic_bn // n_elems, oc_bn, n_elems),
+        lambda occ, icc, k_h, k_w, icbc, ocb, icbb: kernel[
+            occ * oc_bn + ocb, icc * ic_bn + icbc * ic_bn // n_elems + icbb, k_h, k_w
+        ],
+        name="kernel_vec",
+    )
 
     return data, kernel
 
 
 @autotvm.register_topi_compute("conv2d_NCHWc_int8.x86")
-def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding,
-                      dilation, layout, out_layout, out_dtype):
+def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out_layout, out_dtype):
     """Compute conv2d with NCHWc layout and int8 dtype"""
     if len(data.shape) == 5:
         n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
         in_channel = ic_chunk * ic_bn
-        oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ \
-            = get_const_tuple(kernel.shape)
+        oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(
+            kernel.shape
+        )
         num_filter = oc_chunk * oc_bn
     else:
         n, in_channel, ih, iw = get_const_tuple(data.shape)
-        num_filter, _, kernel_height, kernel_width = \
-            get_const_tuple(kernel.shape)
+        num_filter, _, kernel_height, kernel_width = get_const_tuple(kernel.shape)
 
     # Define autotvm tuning space
     is_kernel_1x1 = kernel_height == 1 and kernel_width == 1
@@ -133,10 +141,8 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding,
     oh = (ih - kernel_height + pt + pb) // sh + 1
     ow = (iw - kernel_width + pl + pr) // sw + 1
 
-    cfg.define_split('tile_ic', in_channel, num_outputs=2,
-                     filter=lambda y: y.size[-1] % 4 == 0)
-    cfg.define_split('tile_oc', num_filter, num_outputs=2,
-                     filter=lambda y: y.size[-1] % 16 == 0)
+    cfg.define_split("tile_ic", in_channel, num_outputs=2, filter=lambda y: y.size[-1] % 4 == 0)
+    cfg.define_split("tile_oc", num_filter, num_outputs=2, filter=lambda y: y.size[-1] % 16 == 0)
     cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 64)
     if is_kernel_1x1:
         cfg.define_knob("tile_oh", [1, 2] if oh > 1 else [1])
@@ -146,24 +152,24 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding,
     # If no config was set, we can fallback to default config.
     if cfg.is_fallback:
         _get_default_config_int8(
-            cfg, te.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
-            te.placeholder((num_filter, in_channel, kernel_height, kernel_width),
-                           dtype=kernel.dtype),
-            strides, padding, out_dtype)
+            cfg,
+            te.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
+            te.placeholder(
+                (num_filter, in_channel, kernel_height, kernel_width), dtype=kernel.dtype
+            ),
+            strides,
+            padding,
+            out_dtype,
+        )
 
     # Pack data if raw 4-D data is provided.
     # This can only happen when autotuning.
     if len(data.shape) == 4:
         data, kernel = _pack_data(cfg, data, kernel)
 
-    return nn.conv2d_NCHWc_int8(data,
-                                kernel,
-                                strides,
-                                padding,
-                                dilation,
-                                layout,
-                                out_layout,
-                                out_dtype)
+    return nn.conv2d_NCHWc_int8(
+        data, kernel, strides, padding, dilation, layout, out_layout, out_dtype
+    )
 
 
 @autotvm.register_topi_schedule("conv2d_NCHWc_int8.x86")
@@ -173,7 +179,7 @@ def schedule_conv2d_NCHWc_int8(cfg, outs):
 
     def _callback(op):
         """Traverse operators from computation graph"""
-        if 'conv2d_NCHWc_int8' in op.tag:
+        if "conv2d_NCHWc_int8" in op.tag:
             conv_out = op.output(0)
             kernel_vec = conv_out.op.input_tensors[1]
             data_vec = conv_out.op.input_tensors[0]
@@ -203,8 +209,8 @@ def schedule_conv2d_nhwc_pack_int8(cfg, outs):
         if tag.is_broadcast(op.tag):
             if op not in s.outputs:
                 s[op].compute_inline()
-            else: # inject custom schedule
-                if len(op.axis) == 4: # schedule bias + bn + relu
+            else:  # inject custom schedule
+                if len(op.axis) == 4:  # schedule bias + bn + relu
                     n, h, w, c = op.axis
                     fused = s[op].fuse(n, h, w)
                     s[op].parallel(fused)
@@ -213,29 +219,33 @@ def schedule_conv2d_nhwc_pack_int8(cfg, outs):
                 if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
 
-        if 'conv2d_nhwc_pack_int8' in op.tag:
+        if "conv2d_nhwc_pack_int8" in op.tag:
             conv_out = op.output(0)
             kernel = conv_out.op.input_tensors[1]
             data_vec = conv_out.op.input_tensors[0]
-            data = data_vec.op.input_tensors[0] \
-                if isinstance(data_vec.op, te.tensor.ComputeOp) and "pad" not in data_vec.op.tag \
+            data = (
+                data_vec.op.input_tensors[0]
+                if isinstance(data_vec.op, te.tensor.ComputeOp) and "pad" not in data_vec.op.tag
                 else data_vec
+            )
             if isinstance(data.op, te.tensor.ComputeOp) and "pad" in data.op.tag:
                 data_pad = data
                 data = data_pad.op.input_tensors[0]
 
             args = [s, cfg, data_vec, conv_out, outs[0]]
-            if data.dtype == 'uint8':
+            if data.dtype == "uint8":
                 kh, kw, _, _, _ = get_const_tuple(kernel.shape)
                 if kh == 1 and kw == 1:
                     conv2d_avx_1x1._schedule_conv_nhwc_pack_int8(*args)
                 else:
-                    raise ValueError("Only support 1x1 kernel with "
-                                     "schedule_conv2d_nhwc_pack.")
+                    raise ValueError("Only support 1x1 kernel with " "schedule_conv2d_nhwc_pack.")
             else:
-                raise ValueError("Not support this data type {} with "
-                                 "schedule_conv2d_nhwc_pack. Only support int8".format(data.dtype))
+                raise ValueError(
+                    "Not support this data type {} with "
+                    "schedule_conv2d_nhwc_pack. Only support int8".format(data.dtype)
+                )
 
         scheduled_ops.append(op)
+
     traverse(output_op)
     return s
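
A short illustrative snippet (names made up, not part of the patch) of the conditional-expression style used in the schedule callback above: a multi-line value-if-condition-else-value expression is wrapped in parentheses instead of being continued with backslashes, and a long error message can remain two adjacent string literals, which Python joins into one string at compile time.

    class FakeOp:
        # Stand-in object, not a TVM op.
        tag = "conv2d_nhwc_pack_int8"

    data_vec = FakeOp()
    has_pad = False

    data = (
        "unpadded input"
        if not has_pad and "pack" in data_vec.tag
        else "padded input"
    )
    message = "Only support 1x1 kernel with " "schedule_conv2d_nhwc_pack."
    print(data)     # unpadded input
    print(message)  # prints one concatenated string
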
index 7ec2817..105c455 100644
@@ -23,20 +23,27 @@ from .conv2d import conv2d_nchw, schedule_conv2d_nchw
 
 
 def conv2d_transpose_nchw(data, kernel, strides, padding, out_dtype, output_padding):
-    data_pad, kernel_transform = \
-        nn.conv2d_transpose_nchw_preprocess(data, kernel, strides, padding,
-                                            out_dtype, output_padding)
+    data_pad, kernel_transform = nn.conv2d_transpose_nchw_preprocess(
+        data, kernel, strides, padding, out_dtype, output_padding
+    )
     # reuse conv2d_nchw implementation
-    return conv2d_nchw(data_pad, kernel_transform, strides=(1, 1),
-                       padding=(0, 0), dilation=(1, 1), out_dtype=out_dtype)
+    return conv2d_nchw(
+        data_pad,
+        kernel_transform,
+        strides=(1, 1),
+        padding=(0, 0),
+        dilation=(1, 1),
+        out_dtype=out_dtype,
+    )
 
 
 def schedule_conv2d_transpose_nchw(outs):
     """Create schedule for tensors"""
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     s = schedule_conv2d_nchw(outs)
+
     def _callback(op):
-        if 'unpack_nchwc' in op.tag:
+        if "unpack_nchwc" in op.tag:
             conv_out = op.input_tensors[0]
             # retrieve data
             data_vec = conv_out.op.input_tensors[0]
index f0dee31..479a27b 100644
@@ -28,11 +28,29 @@ from ..nn.pad import pad
 from ..util import get_const_tuple, simplify, get_const_int
 from .util import get_fp32_len
 
-Workload3D = namedtuple('Workload',
-                        ['in_dtype', 'out_dtype', 'depth', 'height', 'width',
-                         'in_filter', 'groups', 'out_filter', 'dkernel',
-                         'hkernel', 'wkernel', 'dpad', 'hpad', 'wpad',
-                         'dstride', 'hstride', 'wstride'])
+Workload3D = namedtuple(
+    "Workload",
+    [
+        "in_dtype",
+        "out_dtype",
+        "depth",
+        "height",
+        "width",
+        "in_filter",
+        "groups",
+        "out_filter",
+        "dkernel",
+        "hkernel",
+        "wkernel",
+        "dpad",
+        "hpad",
+        "wpad",
+        "dstride",
+        "hstride",
+        "wstride",
+    ],
+)
+
 
 @autotvm.register_topi_compute("conv3d_ndhwc.x86")
 def conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype):
@@ -109,6 +127,7 @@ def conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, out_dtype):
         _get_default_config(cfg, data, kernel, strides, padding, out_dtype, layout)
     return _conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
 
+
 @autotvm.register_topi_schedule("conv3d_ndhwc.x86")
 def schedule_conv3d_ndhwc(cfg, outs):
     """TOPI schedule callback for conv3d
@@ -127,7 +146,7 @@ def schedule_conv3d_ndhwc(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _traverse(op):
-        if 'conv3d_ndhwc' in op.tag:
+        if "conv3d_ndhwc" in op.tag:
             output = op.output(0)
             conv_out = op.input_tensors[0]
             kernel_vec = conv_out.op.input_tensors[1]
@@ -148,6 +167,7 @@ def schedule_conv3d_ndhwc(cfg, outs):
     traverse_inline(s, outs[0].op, _traverse)
     return s
 
+
 @autotvm.register_topi_schedule("conv3d_ncdhw.x86")
 def schedule_conv3d_ncdhw(cfg, outs):
     """TOPI schedule callback for conv3d
@@ -166,7 +186,7 @@ def schedule_conv3d_ncdhw(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _traverse(op):
-        if 'conv3d_ncdhw' in op.tag:
+        if "conv3d_ncdhw" in op.tag:
             output = op.output(0)
             conv_out = op.input_tensors[0]
             kernel_vec = conv_out.op.input_tensors[1]
@@ -206,7 +226,8 @@ def _conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype):
     dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
 
     pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
-        padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w)
+    )
 
     pad_d = pad_front + pad_back
     pad_h = pad_top + pad_down
@@ -221,59 +242,80 @@ def _conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype):
     out_width = simplify((in_width + pad_w - dilated_kernel_w) // WSTR + 1)
 
     # pack data
-    DOPAD = (pad_d != 0 or pad_h != 0 or pad_w != 0)
+    DOPAD = pad_d != 0 or pad_h != 0 or pad_w != 0
     if DOPAD:
-        data_pad = pad(data, (0, pad_front, pad_top, pad_left, 0),
-                       (0, pad_back, pad_down, pad_right, 0), name="data_pad")
+        data_pad = pad(
+            data,
+            (0, pad_front, pad_top, pad_left, 0),
+            (0, pad_back, pad_down, pad_right, 0),
+            name="data_pad",
+        )
     else:
         data_pad = data
 
     # fetch schedule
     ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
     shape = (batch_size, in_channel // ic_bn, pad_depth, pad_height, ic_bn, pad_width)
-    data_vec = te.compute(shape,
-                          lambda n, C, d, h, c, w: data_pad[n, d, h, w, C * ic_bn + c],
-                          name='data_vec')
+    data_vec = te.compute(
+        shape, lambda n, C, d, h, c, w: data_pad[n, d, h, w, C * ic_bn + c], name="data_vec"
+    )
 
     # pack kernel
-    shape = (num_filter//oc_bn, in_channel//ic_bn,
-             kernel_depth, kernel_height, kernel_width, ic_bn, oc_bn)
-    kernel_vec = te.compute(shape,
-                            lambda CO, CI, d, h, w, ci, co:
-                            kernel[d, h, w, CI * ic_bn + ci, CO * oc_bn + co],
-                            name='kernel_vec')
+    shape = (
+        num_filter // oc_bn,
+        in_channel // ic_bn,
+        kernel_depth,
+        kernel_height,
+        kernel_width,
+        ic_bn,
+        oc_bn,
+    )
+    kernel_vec = te.compute(
+        shape,
+        lambda CO, CI, d, h, w, ci, co: kernel[d, h, w, CI * ic_bn + ci, CO * oc_bn + co],
+        name="kernel_vec",
+    )
 
     # convolution
-    oshape = (batch_size, num_filter//oc_bn, out_depth, out_height, out_width, oc_bn)
+    oshape = (batch_size, num_filter // oc_bn, out_depth, out_height, out_width, oc_bn)
     unpack_shape = (batch_size, out_depth, out_height, out_width, num_filter)
 
-    ic = te.reduce_axis((0, in_channel), name='ic')
-    kh = te.reduce_axis((0, kernel_height), name='kh')
-    kw = te.reduce_axis((0, kernel_width), name='kw')
-    kd = te.reduce_axis((0, kernel_depth), name='kd')
+    ic = te.reduce_axis((0, in_channel), name="ic")
+    kh = te.reduce_axis((0, kernel_height), name="kh")
+    kw = te.reduce_axis((0, kernel_width), name="kw")
+    kd = te.reduce_axis((0, kernel_depth), name="kd")
     idxmod = tvm.tir.indexmod
     idxdiv = tvm.tir.indexdiv
 
-    conv = te.compute(oshape, lambda n, oc_chunk, od, oh, ow, oc_block:
-                      te.sum(data_vec[n,
-                                      idxdiv(ic, ic_bn),
-                                      od*DSTR+kd*dilation_d,
-                                      oh*HSTR+kh*dilation_h,
-                                      idxmod(ic, ic_bn),
-                                      ow*WSTR+kw*dilation_w].astype(out_dtype) *
-                             kernel_vec[oc_chunk, idxdiv(ic, ic_bn), kd, kh, kw,
-                                        idxmod(ic, ic_bn),
-                                        oc_block].astype(out_dtype),
-                             axis=[kd, kh, kw, ic]), name='conv')
-    conv_unpacked = te.compute(unpack_shape,
-                               lambda n, d, h, w, c: conv[n, idxdiv(c, oc_bn),
-                                                          d, h, w,
-                                                          idxmod(c, oc_bn)]
-                               .astype(out_dtype),
-                               name='output_unpack',
-                               tag='conv3d_ndhwc')
+    conv = te.compute(
+        oshape,
+        lambda n, oc_chunk, od, oh, ow, oc_block: te.sum(
+            data_vec[
+                n,
+                idxdiv(ic, ic_bn),
+                od * DSTR + kd * dilation_d,
+                oh * HSTR + kh * dilation_h,
+                idxmod(ic, ic_bn),
+                ow * WSTR + kw * dilation_w,
+            ].astype(out_dtype)
+            * kernel_vec[
+                oc_chunk, idxdiv(ic, ic_bn), kd, kh, kw, idxmod(ic, ic_bn), oc_block
+            ].astype(out_dtype),
+            axis=[kd, kh, kw, ic],
+        ),
+        name="conv",
+    )
+    conv_unpacked = te.compute(
+        unpack_shape,
+        lambda n, d, h, w, c: conv[n, idxdiv(c, oc_bn), d, h, w, idxmod(c, oc_bn)].astype(
+            out_dtype
+        ),
+        name="output_unpack",
+        tag="conv3d_ndhwc",
+    )
     return conv_unpacked
 
+
 def _conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
     out_dtype = data.dtype if out_dtype is None else out_dtype
 
@@ -292,7 +334,8 @@ def _conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, layout, out_dty
     dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
 
     pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
-        padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w)
+    )
 
     pad_d = pad_front + pad_back
     pad_h = pad_top + pad_down
@@ -307,10 +350,14 @@ def _conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, layout, out_dty
     out_width = simplify((in_width + pad_w - dilated_kernel_w) // WSTR + 1)
 
     # pack data
-    DOPAD = (pad_d != 0 or pad_h != 0 or pad_w != 0)
+    DOPAD = pad_d != 0 or pad_h != 0 or pad_w != 0
     if DOPAD:
-        data_pad = pad(data, (0, 0, pad_front, pad_top, pad_left),
-                       (0, 0, pad_back, pad_down, pad_right), name="data_pad")
+        data_pad = pad(
+            data,
+            (0, 0, pad_front, pad_top, pad_left),
+            (0, 0, pad_back, pad_down, pad_right),
+            name="data_pad",
+        )
     else:
         data_pad = data
 
@@ -318,63 +365,78 @@ def _conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, layout, out_dty
     ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
 
     shape = (batch_size, in_channel // ic_bn, pad_depth, pad_height, ic_bn, pad_width)
-    data_vec = te.compute(shape,
-                          lambda n, C, d, h, c, w: data_pad[n, C * ic_bn + c, d, h, w],
-                          name='data_vec')
+    data_vec = te.compute(
+        shape, lambda n, C, d, h, c, w: data_pad[n, C * ic_bn + c, d, h, w], name="data_vec"
+    )
 
     # pack kernel
-    shape = (num_filter//oc_bn, in_channel//ic_bn,
-             kernel_depth, kernel_height, kernel_width, ic_bn, oc_bn)
-    kernel_vec = te.compute(shape,
-                            lambda CO, CI, d, h, w, ci, co:
-                            kernel[CO * oc_bn + co, CI * ic_bn + ci, d, h, w],
-                            name='kernel_vec')
+    shape = (
+        num_filter // oc_bn,
+        in_channel // ic_bn,
+        kernel_depth,
+        kernel_height,
+        kernel_width,
+        ic_bn,
+        oc_bn,
+    )
+    kernel_vec = te.compute(
+        shape,
+        lambda CO, CI, d, h, w, ci, co: kernel[CO * oc_bn + co, CI * ic_bn + ci, d, h, w],
+        name="kernel_vec",
+    )
 
     # convolution
-    oshape = (batch_size, num_filter//oc_bn,
-              out_depth, out_height, out_width, oc_bn)
+    oshape = (batch_size, num_filter // oc_bn, out_depth, out_height, out_width, oc_bn)
     unpack_shape = (batch_size, num_filter, out_depth, out_height, out_width)
 
-    ic = te.reduce_axis((0, in_channel), name='ic')
-    kh = te.reduce_axis((0, kernel_height), name='kh')
-    kw = te.reduce_axis((0, kernel_width), name='kw')
-    kd = te.reduce_axis((0, kernel_depth), name='kd')
+    ic = te.reduce_axis((0, in_channel), name="ic")
+    kh = te.reduce_axis((0, kernel_height), name="kh")
+    kw = te.reduce_axis((0, kernel_width), name="kw")
+    kd = te.reduce_axis((0, kernel_depth), name="kd")
     idxmod = tvm.tir.indexmod
     idxdiv = tvm.tir.indexdiv
 
-    conv = te.compute(oshape, lambda n, oc_chunk, od, oh, ow, oc_block:
-                      te.sum(data_vec[n,
-                                      idxdiv(ic, ic_bn),
-                                      od*DSTR+kd*dilation_d,
-                                      oh*HSTR+kh*dilation_h,
-                                      idxmod(ic, ic_bn),
-                                      ow*WSTR+kw*dilation_w].astype(out_dtype) *
-                             kernel_vec[oc_chunk, idxdiv(ic, ic_bn), kd, kh, kw,
-                                        idxmod(ic, ic_bn),
-                                        oc_block].astype(out_dtype),
-                             axis=[ic, kd, kh, kw]), name='conv')
-    conv_unpacked = te.compute(unpack_shape,
-                               lambda n, c, d, h, w: conv[n, idxdiv(c, oc_bn),
-                                                          d, h, w,
-                                                          idxmod(c, oc_bn)]
-                               .astype(out_dtype),
-                               name='output_unpack',
-                               tag='conv3d_ncdhw')
+    conv = te.compute(
+        oshape,
+        lambda n, oc_chunk, od, oh, ow, oc_block: te.sum(
+            data_vec[
+                n,
+                idxdiv(ic, ic_bn),
+                od * DSTR + kd * dilation_d,
+                oh * HSTR + kh * dilation_h,
+                idxmod(ic, ic_bn),
+                ow * WSTR + kw * dilation_w,
+            ].astype(out_dtype)
+            * kernel_vec[
+                oc_chunk, idxdiv(ic, ic_bn), kd, kh, kw, idxmod(ic, ic_bn), oc_block
+            ].astype(out_dtype),
+            axis=[ic, kd, kh, kw],
+        ),
+        name="conv",
+    )
+    conv_unpacked = te.compute(
+        unpack_shape,
+        lambda n, c, d, h, w: conv[n, idxdiv(c, oc_bn), d, h, w, idxmod(c, oc_bn)].astype(
+            out_dtype
+        ),
+        name="output_unpack",
+        tag="conv3d_ncdhw",
+    )
     return conv_unpacked
 
+
 def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
     """Create schedule configuration from input arguments"""
     dshape = get_const_tuple(data.shape)
     kshape = get_const_tuple(kernel.shape)
-    if layout == 'NDHWC':
+    if layout == "NDHWC":
         n, d, h, w, ic = dshape
         kd, kh, kw, _, oc = kshape
-    elif layout == 'NCDHW':
+    elif layout == "NCDHW":
         n, ic, d, h, w = dshape
         oc, _, kd, kh, kw = kshape
     else:
-        raise ValueError("Not support this layout {} with "
-                         "schedule template.".format(layout))
+        raise ValueError("Not support this layout {} with " "schedule template.".format(layout))
 
     # pad_front, pad_top, pad_left, pad_back, pad_down(bottom), pad_right
     pf, pt, pl, pb, pd, pr = get_pad_tuple3d(padding, (kd, kh, kw))
@@ -389,11 +451,12 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
     cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 8)
     cfg.define_knob("unroll_kw", [True, False])
 
+
 def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, layout):
     """
     Get default schedule config for the workload
     """
-    if layout not in ['NDHWC', 'NCDHW']:
+    if layout not in ["NDHWC", "NCDHW"]:
         raise ValueError("Layout {} is not supported".format(layout))
 
     static_data_shape = []
@@ -406,19 +469,21 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, layout):
     wkl = _get_conv3d_workload(data, kernel, strides, padding, out_dtype, layout)
     _fallback_schedule(cfg, wkl)
 
-def _get_conv3d_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'):
+
+def _get_conv3d_workload(data, kernel, stride, padding, out_dtype, data_layout="NCHW"):
     """ Get the workload structure. """
-    if data_layout == 'NCDHW':
+    if data_layout == "NCDHW":
         _, CI, ID, IH, IW = get_const_tuple(data.shape)
         CO, CIG, KD, KH, KW = get_const_tuple(kernel.shape)
-    elif data_layout == 'NDHWC':
+    elif data_layout == "NDHWC":
         _, ID, IH, IW, CI = get_const_tuple(data.shape)
         KD, KH, KW, CIG, CO = get_const_tuple(kernel.shape)
     else:
         raise ValueError("not support this layout {} yet".format(data_layout))
 
     pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
-        padding, (get_const_int(KD), get_const_int(KH), get_const_int(KW)))
+        padding, (get_const_int(KD), get_const_int(KH), get_const_int(KW))
+    )
     DPAD = pad_front + pad_back
     HPAD = pad_top + pad_down
     WPAD = pad_left + pad_right
@@ -427,11 +492,31 @@ def _get_conv3d_workload(data, kernel, stride, padding, out_dtype, data_layout='
         DSTR, HSTR, WSTR = stride
     else:
         DSTR, HSTR, WSTR = stride, stride, stride
-    assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
-        "Do not support inputs with different data types now. ' \
-        '{} vs. {}".format(data.dtype, kernel.dtype)
-    return Workload3D(data.dtype, out_dtype, ID, IH, IW, CI, GRPS, CO, KD, KH, KW,
-                      DPAD, HPAD, WPAD, DSTR, HSTR, WSTR)
+    assert (data.dtype == kernel.dtype) or (
+        data.dtype == "uint8" and kernel.dtype == "int8"
+    ), "Do not support inputs with different data types now. ' \
+        '{} vs. {}".format(
+        data.dtype, kernel.dtype
+    )
+    return Workload3D(
+        data.dtype,
+        out_dtype,
+        ID,
+        IH,
+        IW,
+        CI,
+        GRPS,
+        CO,
+        KD,
+        KH,
+        KW,
+        DPAD,
+        HPAD,
+        WPAD,
+        DSTR,
+        HSTR,
+        WSTR,
+    )
 
 
 def _fallback_schedule(cfg, wkl):
@@ -465,13 +550,17 @@ def _fallback_schedule(cfg, wkl):
 
 def _schedule_conv3d_ndhwc(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last):
     # fetch schedule
-    ic_bn, oc_bn, reg_n, unroll_kw = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
-                                      cfg["tile_ow"].size[-1], cfg["unroll_kw"].val)
+    ic_bn, oc_bn, reg_n, unroll_kw = (
+        cfg["tile_ic"].size[-1],
+        cfg["tile_oc"].size[-1],
+        cfg["tile_ow"].size[-1],
+        cfg["unroll_kw"].val,
+    )
 
     # get padding size
     padding = infer_pad3d(data, data_pad, "NDHWC")
     DPAD, HPAD, WPAD = padding
-    DOPAD = (DPAD != 0 or HPAD != 0 or WPAD != 0)
+    DOPAD = DPAD != 0 or HPAD != 0 or WPAD != 0
 
     A, W = data, kernel_vec
     A0, A1 = data_pad, data_vec
@@ -493,7 +582,7 @@ def _schedule_conv3d_ndhwc(s, cfg, data, data_pad, data_vec, kernel_vec, conv_ou
 
     # schedule conv
     C, O0, O = conv_out, output, last
-    CC = s.cache_write(C, 'global')
+    CC = s.cache_write(C, "global")
 
     _, oc_chunk, od, oh, ow, oc_block = s[C].op.axis
     ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
@@ -532,15 +621,20 @@ def _schedule_conv3d_ndhwc(s, cfg, data, data_pad, data_vec, kernel_vec, conv_ou
     s[O].parallel(parallel_axis)
     return s
 
+
 def _schedule_conv3d_ncdhw(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last):
     # fetch schedule
-    ic_bn, oc_bn, reg_n, unroll_kw = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
-                                      cfg["tile_ow"].size[-1], cfg["unroll_kw"].val)
+    ic_bn, oc_bn, reg_n, unroll_kw = (
+        cfg["tile_ic"].size[-1],
+        cfg["tile_oc"].size[-1],
+        cfg["tile_ow"].size[-1],
+        cfg["unroll_kw"].val,
+    )
 
     # get padding size
     padding = infer_pad3d(data, data_pad, "NCDHW")
     DPAD, HPAD, WPAD = padding
-    DOPAD = (DPAD != 0 or HPAD != 0 or WPAD != 0)
+    DOPAD = DPAD != 0 or HPAD != 0 or WPAD != 0
 
     A, W = data, kernel_vec
     A0, A1 = data_pad, data_vec
@@ -562,7 +656,7 @@ def _schedule_conv3d_ncdhw(s, cfg, data, data_pad, data_vec, kernel_vec, conv_ou
 
     # schedule conv
     C, O0, O = conv_out, output, last
-    CC = s.cache_write(C, 'global')
+    CC = s.cache_write(C, "global")
 
     _, oc_chunk, od, oh, ow, oc_block = s[C].op.axis
     ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
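
A hypothetical, reduced example of the collection layout used for the Workload3D definition earlier in this file's hunks: a literal that overflows the line limit is exploded to one element per line with a trailing comma (the real Workload3D above carries 17 fields; only five are shown here).

    from collections import namedtuple

    # Reduced, hypothetical field list for illustration only.
    Workload = namedtuple(
        "Workload",
        [
            "in_dtype",
            "out_dtype",
            "depth",
            "height",
            "width",
        ],
    )

    wkl = Workload("float32", "float32", 8, 32, 32)
    print(wkl.depth)  # 8
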
index 698702a..f986ccf 100644
@@ -23,21 +23,23 @@ from ..util import traverse_inline
 from .. import nn
 from .conv3d import conv3d_ncdhw, schedule_conv3d_ncdhw
 
+
 def conv3d_transpose_ncdhw(data, kernel, strides, padding, out_dtype, output_padding):
-    data_pad, kernel_transform = \
-        nn.conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding,
-                                             out_dtype, output_padding)
+    data_pad, kernel_transform = nn.conv3d_transpose_ncdhw_preprocess(
+        data, kernel, strides, padding, out_dtype, output_padding
+    )
 
     # reuse conv3d_ncdhw implementation
-    return conv3d_ncdhw(data_pad, kernel_transform, (1, 1, 1),
-                        (0, 0, 0), (1, 1, 1), out_dtype)
+    return conv3d_ncdhw(data_pad, kernel_transform, (1, 1, 1), (0, 0, 0), (1, 1, 1), out_dtype)
+
 
 def schedule_conv3d_transpose_ncdhw(outs):
     """Create schedule for tensors"""
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     s = schedule_conv3d_ncdhw(outs)
+
     def _callback(op):
-        if 'unpack_ncdhwc' in op.tag:
+        if "unpack_ncdhwc" in op.tag:
             conv_out = op.input_tensors[0]
             # retrieve data
             data_vec = conv_out.op.input_tensors[0]
index c2e5b55..e318493 100644
@@ -29,12 +29,13 @@ from .util import get_fp32_len
 from .. import generic, tag
 from ..util import traverse_inline, get_const_tuple
 
+
 def _schedule_dense_pack_template(cfg, s, C):
     A, packedB = s[C].op.input_tensors
 
     CC = s.cache_write(C, "global")
     y, x = s[C].op.axis
-    k, = s[CC].op.reduce_axis
+    (k,) = s[CC].op.reduce_axis
 
     yt, yo, yi = cfg["tile_y"].apply(s, C, y)
     xt, xo, xi = cfg["tile_x"].apply(s, C, x)
@@ -61,7 +62,7 @@ def _schedule_dense_pack_template(cfg, s, C):
 
 def _schedule_dense_nopack_template(cfg, s, C):
     y, x = s[C].op.axis
-    kk, = s[C].op.reduce_axis
+    (kk,) = s[C].op.reduce_axis
     yo, yi = cfg["tile_y"].apply(s, C, y)
     xo, xi = cfg["tile_x"].apply(s, C, x)
     s[C].reorder(yo, xo, yi, xi)
@@ -69,10 +70,10 @@ def _schedule_dense_nopack_template(cfg, s, C):
     s[C].parallel(xyo)
     s[C].unroll(kk)
 
-    CC, = s[C].op.input_tensors
+    (CC,) = s[C].op.input_tensors
     s[CC].compute_at(s[C], xyo)
     z, y, x = s[CC].op.axis
-    k, = s[CC].op.reduce_axis
+    (k,) = s[CC].op.reduce_axis
     yz = s[CC].fuse(z, y)
     s[CC].reorder(k, yz, x)
     s[CC].unroll(yz)
@@ -91,7 +92,7 @@ def _default_dense_pack_config(cfg, M, N, K):
 
     vec_width = get_fp32_len()
     tilex_ii = 1
-    for bn in range(vec_width*2, 0, -1):
+    for bn in range(vec_width * 2, 0, -1):
         if N % bn == 0:
             tilex_ii = bn
             break
@@ -128,7 +129,7 @@ def _default_dense_nopack_config(cfg, M, N, K):
 
     vec_width = get_fp32_len()
     tilek_bn = 1
-    for bn in range(vec_width*2, 0, -1):
+    for bn in range(vec_width * 2, 0, -1):
         if K % bn == 0:
             tilek_bn = bn
             break
@@ -136,6 +137,7 @@ def _default_dense_nopack_config(cfg, M, N, K):
     cfg["tile_x"] = SplitEntity([N, 1])
     cfg["tile_y"] = SplitEntity([1, M])
 
+
 @autotvm.register_topi_compute("dense_nopack.x86")
 def dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
     """Compute dense without packing"""
@@ -152,18 +154,18 @@ def dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
 
     vec = cfg["tile_k"].size[-1]
     k = te.reduce_axis((0, K // vec), "k")
-    CC = te.compute((M, N, vec),
-                    lambda z, y, x: te.sum(
-                        data[z, k * vec + x].astype(out_dtype) *
-                        weight[y, k * vec + x].astype(out_dtype), axis=k))
+    CC = te.compute(
+        (M, N, vec),
+        lambda z, y, x: te.sum(
+            data[z, k * vec + x].astype(out_dtype) * weight[y, k * vec + x].astype(out_dtype),
+            axis=k,
+        ),
+    )
 
     kk = te.reduce_axis((0, vec), "kk")
-    C = te.compute((M, N),
-                   lambda y, x: te.sum(CC[y, x, kk], axis=kk),
-                   tag="dense_nopack")
+    C = te.compute((M, N), lambda y, x: te.sum(CC[y, x, kk], axis=kk), tag="dense_nopack")
     if bias is not None:
-        C = te.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
-                       tag=tag.BROADCAST)
+        C = te.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST)
     return C
 
 
@@ -173,18 +175,20 @@ def schedule_dense_nopack(cfg, outs):
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'dense_nopack' in op.tag:
+        if "dense_nopack" in op.tag:
             _schedule_dense_nopack_template(cfg, s, op.output(0))
+
     traverse_inline(s, outs[0].op, _callback)
     return s
 
+
 @autotvm.register_topi_compute("dense_pack.x86")
 def dense_pack(cfg, data, weight, bias=None, out_dtype=None):
     """Compute dense with packing"""
     if out_dtype is None:
         out_dtype = data.dtype
-    M, K = get_const_tuple(data.shape) # batch, in_dim
-    N, _ = get_const_tuple(weight.shape) # out_dim
+    M, K = get_const_tuple(data.shape)  # batch, in_dim
+    N, _ = get_const_tuple(weight.shape)  # out_dim
     # create tuning space
     cfg.define_split("tile_y", M, num_outputs=3)
     cfg.define_split("tile_x", N, num_outputs=3)
@@ -194,23 +198,27 @@ def dense_pack(cfg, data, weight, bias=None, out_dtype=None):
 
     packw_bn = cfg["tile_x"].size[-1]
     packw_shape = (N // packw_bn, K, packw_bn)
-    packw = te.compute(packw_shape,
-                       lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
+    packw = te.compute(
+        packw_shape, lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight"
+    )
 
     idxdiv = tvm.tir.indexdiv
     idxmod = tvm.tir.indexmod
     k = te.reduce_axis((0, K), name="k")
-    C = te.compute((M, N),
-                   lambda y, x: te.sum(
-                       data[y, k].astype(out_dtype) *
-                       packw[idxdiv(x, packw_bn), k, idxmod(x, packw_bn)].astype(out_dtype),
-                       axis=k),
-                   tag="dense_pack")
+    C = te.compute(
+        (M, N),
+        lambda y, x: te.sum(
+            data[y, k].astype(out_dtype)
+            * packw[idxdiv(x, packw_bn), k, idxmod(x, packw_bn)].astype(out_dtype),
+            axis=k,
+        ),
+        tag="dense_pack",
+    )
     if bias is not None:
-        C = te.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
-                       tag=tag.BROADCAST)
+        C = te.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST)
     return C
 
+
 @autotvm.register_topi_schedule("dense_pack.x86")
 def schedule_dense_pack(cfg, outs):
     """Create the schedule for dense_pack"""
@@ -219,9 +227,11 @@ def schedule_dense_pack(cfg, outs):
     def _callback(op):
         if "dense_pack" in op.tag:
             _schedule_dense_pack_template(cfg, s, op.output(0))
+
     traverse_inline(s, outs[0].op, _callback)
     return s
 
+
 def dense_blas_common(cfg, data, weight, bias, out_dtype, lib):
     """Compute dense using a BLAS library"""
     M, K = get_const_tuple(data.shape)
@@ -234,43 +244,46 @@ def dense_blas_common(cfg, data, weight, bias, out_dtype, lib):
                 "(matmulu8s8s32 not imlemented)"
             )
         C = lib.matmul_u8s8s32(data, weight, False, True, dtype=out_dtype)
-    elif data.dtype == 'float32' or data.dtype == 'float64':
+    elif data.dtype == "float32" or data.dtype == "float64":
         C = lib.matmul(data, weight, False, True)
     else:
-        raise NotImplementedError(
-            f"Dense with {lib.__name__} for {data.dtype} is not supported"
-        )
+        raise NotImplementedError(f"Dense with {lib.__name__} for {data.dtype} is not supported")
 
     if bias is not None:
-        C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j].astype(out_dtype),
-                       tag=tag.BROADCAST)
+        C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST)
     return C
 
+
 @autotvm.register_topi_compute("dense_cblas.x86")
 def dense_cblas(cfg, data, weight, bias=None, out_dtype=None):
     """Compute dense using a cblas"""
     return dense_blas_common(cfg, data, weight, bias, out_dtype, cblas)
 
+
 @autotvm.register_topi_schedule("dense_cblas.x86")
 def schedule_dense_cblas(_, outs):
     """Create schedule for dense_cblas"""
     return generic.schedule_extern(outs)
 
+
 @autotvm.register_topi_compute("dense_mkl.x86")
 def dense_mkl(cfg, data, weight, bias=None, out_dtype=None):
     """Compute dense using mkl"""
     return dense_blas_common(cfg, data, weight, bias, out_dtype, mkl)
 
+
 @autotvm.register_topi_schedule("dense_mkl.x86")
 def schedule_dense_mkl(_, outs):
     """Create schedule for dense_mkl"""
     return generic.schedule_extern(outs)
 
+
 @autotvm.register_topi_compute("dense_mkldnn.x86")
 def dense_mkldnn(cfg, data, weight, bias=None, out_dtype=None):
     """Compute dense using mkldnn"""
     return dense_blas_common(cfg, data, weight, bias, out_dtype, mkldnn)
 
+
 @autotvm.register_topi_schedule("dense_mkldnn.x86")
 def schedule_dense_mkldnn(_, outs):
     """Create schedule for dense_mkldnn"""
index acbe0f7..1921f7f 100644
@@ -29,6 +29,7 @@ from ..nn.conv2d import unpack_NCHWc_to_nchw
 from ..util import traverse_inline
 from .util import get_fp32_len
 
+
 def _fallback_schedule(cfg, wkl):
     """
     Get default schedule for the workload
@@ -68,17 +69,21 @@ def _fallback_schedule(cfg, wkl):
     cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
     cfg["unroll_kw"] = OtherOptionEntity(False)
 
+
 def depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype):
     """Compute depthwise conv2d with NCHW layout."""
     layout = "NCHW"
-    packed_out = depthwise_conv2d_NCHWc(data, kernel, strides, padding, dilation,
-                                        layout, layout, out_dtype)
+    packed_out = depthwise_conv2d_NCHWc(
+        data, kernel, strides, padding, dilation, layout, layout, out_dtype
+    )
     return unpack_NCHWc_to_nchw(packed_out, out_dtype)
 
+
 def schedule_depthwise_conv2d_nchw(outs):
     """Create schedule for depthwise_conv2d_nchw."""
     return schedule_depthwise_conv2d_NCHWc(outs)
 
+
 def _pack_data(cfg, data, kernel):
     n, ic, ih, iw = get_const_tuple(data.shape)
     filters, cm, kh, kw = get_const_tuple(kernel.shape)
@@ -88,29 +93,40 @@ def _pack_data(cfg, data, kernel):
     ic_chunk = ic // ic_bn
     oc_chunk = oc // oc_bn
 
-    data = te.compute((n, ic_chunk, ih, iw, ic_bn),
-                      lambda bs, c, h, w, vc: data[bs, c*ic_bn + vc, h, w],
-                      name="data_vec")
+    data = te.compute(
+        (n, ic_chunk, ih, iw, ic_bn),
+        lambda bs, c, h, w, vc: data[bs, c * ic_bn + vc, h, w],
+        name="data_vec",
+    )
 
     kernel = te.compute(
         (oc_chunk, 1, kh, kw, 1, oc_bn),
-        lambda occ, icc, k_h, k_w, icb, ocb:
-        kernel[(occ * oc_bn + ocb) // cm,
-               (occ * oc_bn + ocb) % cm, k_h, k_w],
-        name="kernel_vec")
+        lambda occ, icc, k_h, k_w, icb, ocb: kernel[
+            (occ * oc_bn + ocb) // cm, (occ * oc_bn + ocb) % cm, k_h, k_w
+        ],
+        name="kernel_vec",
+    )
 
     return data, kernel
 
+
 @autotvm.register_topi_compute("depthwise_conv2d_NCHWc.x86")
-def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation,
-                           layout, out_layout, out_dtype=None):
+def depthwise_conv2d_NCHWc(
+    cfg, data, kernel, strides, padding, dilation, layout, out_layout, out_dtype=None
+):
     """Compute depthwise conv2d with NCHWc layout"""
     out_dtype = data.dtype if out_dtype is None else out_dtype
 
     if len(data.shape) == 5:
         batch, in_channel_chunk, in_height, in_width, in_channel_block = get_const_tuple(data.shape)
-        out_channel_chunk, cm_chunk, filter_height, filter_width, cm_block, out_channel_block \
-            = get_const_tuple(kernel.shape)
+        (
+            out_channel_chunk,
+            cm_chunk,
+            filter_height,
+            filter_width,
+            cm_block,
+            out_channel_block,
+        ) = get_const_tuple(kernel.shape)
         in_channel = in_channel_chunk * in_channel_block
         out_channel = out_channel_chunk * out_channel_block
         channel_multiplier = cm_chunk * cm_block
@@ -128,7 +144,8 @@ def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation,
     dilated_kernel_h = (filter_height - 1) * dh + 1
     dilated_kernel_w = (filter_width - 1) * dw + 1
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
     HPAD = pad_top + pad_down
     WPAD = pad_left + pad_right
 
@@ -143,9 +160,13 @@ def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation,
     # get workload and related schedule config
     wkl = _get_workload(
         te.placeholder((batch, in_channel, in_height, in_width), dtype=data.dtype),
-        te.placeholder((out_channel, channel_multiplier, filter_height, filter_width),
-                       dtype=kernel.dtype),
-        strides, (pad_top, pad_down), out_dtype)
+        te.placeholder(
+            (out_channel, channel_multiplier, filter_height, filter_width), dtype=kernel.dtype
+        ),
+        strides,
+        (pad_top, pad_down),
+        out_dtype,
+    )
     if cfg.is_fallback:
         _fallback_schedule(cfg, wkl)
 
@@ -165,11 +186,10 @@ def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation,
         else:
             data, kernel = _pack_data(cfg, data, kernel)
             _, _, _, _, in_channel_block = get_const_tuple(data.shape)
-            out_channel_chunk, _, _, _, _, out_channel_block \
-                = get_const_tuple(kernel.shape)
+            out_channel_chunk, _, _, _, _, out_channel_block = get_const_tuple(kernel.shape)
 
     # padding stage
-    DOPAD = (pad_top != 0 or pad_left != 0 or pad_down != 0 or pad_right != 0)
+    DOPAD = pad_top != 0 or pad_left != 0 or pad_down != 0 or pad_right != 0
     if DOPAD:
         pad_before = [0, 0, pad_top, pad_left, 0]
         pad_after = [0, 0, pad_down, pad_right, 0]
@@ -177,27 +197,37 @@ def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation,
     else:
         data_pad = data
 
-
     # depthconv stage
     idxdiv = tvm.tir.indexdiv
     idxmod = tvm.tir.indexmod
 
-    kh = te.reduce_axis((0, filter_height), name='kh')
-    kw = te.reduce_axis((0, filter_width), name='kw')
+    kh = te.reduce_axis((0, filter_height), name="kh")
+    kw = te.reduce_axis((0, filter_width), name="kw")
     Output = te.compute(
         (batch, out_channel_chunk, out_height, out_width, out_channel_block),
         lambda b, oco, oh, ow, oci: te.sum(
-            (data_pad[
-                b,
-                idxdiv(idxdiv(oco * out_channel_block + oci, channel_multiplier), in_channel_block),
-                oh*HSTR+kh*dh, ow*WSTR+kw*dw,
-                idxmod(idxdiv(oco * out_channel_block + oci, channel_multiplier), in_channel_block)]
-             .astype(out_dtype) *
-             kernel[oco, 0, kh, kw, 0, oci].astype(out_dtype)),
-            axis=[kh, kw]),
-        name='DepthwiseConv2d', tag="depthwise_conv2d_NCHWc")
+            (
+                data_pad[
+                    b,
+                    idxdiv(
+                        idxdiv(oco * out_channel_block + oci, channel_multiplier), in_channel_block
+                    ),
+                    oh * HSTR + kh * dh,
+                    ow * WSTR + kw * dw,
+                    idxmod(
+                        idxdiv(oco * out_channel_block + oci, channel_multiplier), in_channel_block
+                    ),
+                ].astype(out_dtype)
+                * kernel[oco, 0, kh, kw, 0, oci].astype(out_dtype)
+            ),
+            axis=[kh, kw],
+        ),
+        name="DepthwiseConv2d",
+        tag="depthwise_conv2d_NCHWc",
+    )
     return Output
 
+
 @autotvm.register_topi_schedule("depthwise_conv2d_NCHWc.x86")
 def schedule_depthwise_conv2d_NCHWc(cfg, outs):
     """CPU schedule for depthwise conv2d in NCHW[x]c layout"""
@@ -206,7 +236,7 @@ def schedule_depthwise_conv2d_NCHWc(cfg, outs):
 
     def _callback(op):
         """Traverse operators from computation graph"""
-        if 'depthwise_conv2d_NCHWc' in op.tag:
+        if "depthwise_conv2d_NCHWc" in op.tag:
             conv_out = op.output(0)
             data = conv_out.op.input_tensors[0]
             kernel = conv_out.op.input_tensors[1]
@@ -215,20 +245,20 @@ def schedule_depthwise_conv2d_NCHWc(cfg, outs):
     traverse_inline(s, outs[0].op, _callback)
     return s
 
+
 def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data_vec, kernel_vec, conv_out, output):
     tile_ow, oc_bn = cfg["tile_ow"].size[-1], cfg["tile_oc"].size[-1]
     unroll_kw = cfg["unroll_kw"].val
 
     # schedule pad
-    if isinstance(s[data_vec].op, tvm.te.ComputeOp) \
-            and "pad" in data_vec.op.tag:
+    if isinstance(s[data_vec].op, tvm.te.ComputeOp) and "pad" in data_vec.op.tag:
         batch, ic_chunk, ih, iw, ic_block = s[data_vec].op.axis
         s[data_vec].vectorize(ic_block)
         parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
         s[data_vec].parallel(parallel_axis)
 
     C, O = conv_out, output
-    CC = s.cache_write(C, 'global')
+    CC = s.cache_write(C, "global")
 
     _, ic_chunk, oh, ow, ic_block = s[C].op.axis
     ow_chunk, ow_block = s[C].split(ow, factor=tile_ow)
@@ -271,6 +301,7 @@ def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data_vec, kernel_vec, conv_out
 
     return s
 
+
 @depthwise_conv2d_infer_layout.register("cpu")
 def _depthwise_conv2d_infer_layout(workload, cfg):
     _, data, kernel, strides, padding, dilation, _, _, dtype = workload
index 7c37ac7..eaf39db 100644
@@ -19,6 +19,7 @@
 from tvm import te
 from ..util import is_empty_shape
 
+
 def schedule_injective_from_existing(sch, out):
     """Schedule for injective op from existing schedule.
 
@@ -50,6 +51,7 @@ def schedule_injective_from_existing(sch, out):
         sch[out].vectorize(li)
     return sch
 
+
 def schedule_injective(outs):
     """X86 schedule for injective op.
 
@@ -73,6 +75,7 @@ def schedule_injective(outs):
         schedule_injective_from_existing(s, x)
     return s
 
+
 def schedule_concatenate(outs):
     """X86 schedule for concatenate op.
 
@@ -87,6 +90,7 @@ def schedule_concatenate(outs):
     sch: Schedule
         The computation schedule for the op.
     """
+
     def vectorize(sch, tensor, vectorize_limit):
         """Internal vectorization function for concatenate."""
         inner_axis = s[tensor].op.axis[len(s[tensor].op.axis) - 1]
@@ -118,5 +122,6 @@ def schedule_concatenate(outs):
         s[x].parallel(s[x].op.axis[0])
     return s
 
+
 schedule_elemwise = schedule_injective
 schedule_broadcast = schedule_injective
index 8f884b8..0994700 100644
@@ -18,6 +18,7 @@
 """x86 nn operators"""
 from tvm import te
 
+
 def schedule_softmax(outs):
     """Schedule for softmax
 
@@ -37,19 +38,23 @@ def schedule_softmax(outs):
     s = te.create_schedule([x.op for x in outs])
 
     op_tag = softmax.op.tag
-    if op_tag == 'softmax_output':
+    if op_tag == "softmax_output":
         exp = softmax.op.input_tensors[0]
         expsum = softmax.op.input_tensors[1]
         max_elem = s[exp].op.input_tensors[1]
-        axis = int(softmax.op.attrs['axis'])
-    elif op_tag == 'log_softmax_output':
+        axis = int(softmax.op.attrs["axis"])
+    elif op_tag == "log_softmax_output":
         exp = None
         max_elem = softmax.op.input_tensors[1]
         expsum = softmax.op.input_tensors[2]
         axis = 1
     else:
-        raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
-                         Got {0}'.format(op_tag))
+        raise ValueError(
+            "Tag is expected to be softmax_output or log_softmax_output. \
+                         Got {0}".format(
+                op_tag
+            )
+        )
 
     # only parallelize outer dimensions up to axis
     outer_axes = [s[softmax].op.axis[i] for i in range(0, axis)]
index f7664d9..91108ac 100644
@@ -19,6 +19,7 @@
 from tvm import te
 from .. import tag
 
+
 def _parallel_sch(sch, oshape, do_vectorize=False):
     def vectorize(fused_axis, num_parallel_axis, vectorize_limit=64):
         """Internal vectorization utility function."""
@@ -95,7 +96,7 @@ def schedule_pool(outs, layout):
                 if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule pool
-        elif OP.tag.startswith('pool'):
+        elif OP.tag.startswith("pool"):
             # Average pool accumulation and division happens in different for loops (#3607).
             # To ensure good parallel support, apply multi-threading on the second loop.
             if OP != outs[0].op:
@@ -143,7 +144,7 @@ def schedule_adaptive_pool(outs):
                 if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule pool
-        elif OP.tag.startswith('adaptive_pool'):
+        elif OP.tag.startswith("adaptive_pool"):
             if OP != outs[0].op:
                 output = outs[0]
                 output_fused = s[output].fuse(output.op.axis[0], output.op.axis[1])
index 0dfc3f2..69659de 100644
@@ -22,6 +22,7 @@ from .injective import schedule_injective_from_existing
 from .. import tag
 from ..util import get_const_tuple
 
+
 def _schedule_reduce(sch, op, is_idx_reduce=False):
     if is_idx_reduce:
         real_out = op.output(0)
@@ -97,12 +98,12 @@ def schedule_reduce(outs):
                 schedule_injective_from_existing(sch, operator)
             for tensor in operator.input_tensors:
                 traverse_after_reduce(tensor.op)
-        elif operator.tag == 'comm_reduce':
+        elif operator.tag == "comm_reduce":
             _schedule_reduce(sch, operator, is_idx_reduce=False)
             for tensor in operator.input_tensors:
                 if tensor.op not in scheduled_ops:
                     traverse_before_reduce(tensor.op)
-        elif operator.tag == 'comm_reduce_idx':
+        elif operator.tag == "comm_reduce_idx":
             _schedule_reduce(sch, operator, is_idx_reduce=True)
             input_tensors = operator.input_tensors[0].op.input_tensors
             for tensor in input_tensors:
index e5cfcfe..fd65053 100644
@@ -160,22 +160,36 @@ def roi_align_nchw_ir(data, rois, w_pc, pos_pc, pooled_size, spatial_scale, samp
                     output_val = 0.0
                     for iy in range(roi_bin_grid_h):
                         for ix in range(roi_bin_grid_w):
-                            output_val += w_pc[n, pre_calc_index, 0] \
-                                * data[roi_batch_index, c,
-                                       pos_pc[n, pre_calc_index, 2],
-                                       pos_pc[n, pre_calc_index, 0]] \
-                                + w_pc[n, pre_calc_index, 1] \
-                                * data[roi_batch_index, c,
-                                       pos_pc[n, pre_calc_index, 2],
-                                       pos_pc[n, pre_calc_index, 1]] \
-                                + w_pc[n, pre_calc_index, 2] \
-                                * data[roi_batch_index, c,
-                                       pos_pc[n, pre_calc_index, 3],
-                                       pos_pc[n, pre_calc_index, 0]] \
-                                + w_pc[n, pre_calc_index, 3] \
-                                * data[roi_batch_index, c,
-                                       pos_pc[n, pre_calc_index, 3],
-                                       pos_pc[n, pre_calc_index, 1]]
+                            output_val += (
+                                w_pc[n, pre_calc_index, 0]
+                                * data[
+                                    roi_batch_index,
+                                    c,
+                                    pos_pc[n, pre_calc_index, 2],
+                                    pos_pc[n, pre_calc_index, 0],
+                                ]
+                                + w_pc[n, pre_calc_index, 1]
+                                * data[
+                                    roi_batch_index,
+                                    c,
+                                    pos_pc[n, pre_calc_index, 2],
+                                    pos_pc[n, pre_calc_index, 1],
+                                ]
+                                + w_pc[n, pre_calc_index, 2]
+                                * data[
+                                    roi_batch_index,
+                                    c,
+                                    pos_pc[n, pre_calc_index, 3],
+                                    pos_pc[n, pre_calc_index, 0],
+                                ]
+                                + w_pc[n, pre_calc_index, 3]
+                                * data[
+                                    roi_batch_index,
+                                    c,
+                                    pos_pc[n, pre_calc_index, 3],
+                                    pos_pc[n, pre_calc_index, 1],
+                                ]
+                            )
                             pre_calc_index += 1
 
                     output_val /= count
@@ -221,13 +235,17 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
         _, _, height, width = get_const_tuple(data.shape)
         max_roi_bin_grid_h = math.ceil(height / pooled_size[0])
         max_roi_bin_grid_w = math.ceil(width / pooled_size[1])
-    max_pc_shape = (rois.shape[0], max_roi_bin_grid_h * max_roi_bin_grid_w
-                    * pooled_size[0] * pooled_size[1], 4)
+    max_pc_shape = (
+        rois.shape[0],
+        max_roi_bin_grid_h * max_roi_bin_grid_w * pooled_size[0] * pooled_size[1],
+        4,
+    )
     w_pc_buffer = full(max_pc_shape, data.dtype, 0)
     pos_pc_buffer = full(max_pc_shape, "int32", 0)
 
     pooled_size = tvm.runtime.convert(pooled_size)
     spatial_scale = tvm.tir.const(spatial_scale, "float32")
     sample_ratio = tvm.tir.const(sample_ratio, "int32")
-    return roi_align_nchw_ir(data, rois, w_pc_buffer, pos_pc_buffer,
-                             pooled_size, spatial_scale, sample_ratio)
+    return roi_align_nchw_ir(
+        data, rois, w_pc_buffer, pos_pc_buffer, pooled_size, spatial_scale, sample_ratio
+    )
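The reformatted expression above is the four-point bilinear accumulation used by ROI align: each sample is the sum of four precomputed weights (w_pc[..., 0..3]) applied to the four neighbouring feature-map values addressed through pos_pc. As a rough illustration of that pattern only (toy NumPy values, not code from this commit):

    import numpy as np

    grid = np.arange(16.0).reshape(4, 4)  # toy 4x4 feature map
    y, x = 1.25, 2.5                      # fractional sampling position
    y0, x0 = int(y), int(x)
    ly, lx = y - y0, x - x0
    # four bilinear weights and the four surrounding grid values,
    # mirroring w_pc and the pos_pc-indexed data loads above
    weights = [(1 - ly) * (1 - lx), (1 - ly) * lx, ly * (1 - lx), ly * lx]
    corners = [grid[y0, x0], grid[y0, x0 + 1], grid[y0 + 1, x0], grid[y0 + 1, x0 + 1]]
    print(sum(w * c for w, c in zip(weights, corners)))  # 7.5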
index 02cbd2d..8c4a387 100644 (file)
@@ -21,16 +21,17 @@ from tvm import te
 from ..util import traverse_inline, get_const_int
 from .util import get_fp32_len
 
+
 def schedule_sparse_dense(outs):
     """Create schedule for sparse dense"""
     s = te.create_schedule([x.op for x in outs])
+
     def _callback(op):
         simd_width = get_fp32_len()
         if op.tag == "sparse_dense_csrmm" and op != outs[0].op:
             (_, v_i) = s[op].op.axis
             s[op].vectorize(v_i)
-            (y_o, y_i) = s[outs[0].op].split(
-                s[outs[0].op].op.axis[1], 2 * simd_width)
+            (y_o, y_i) = s[outs[0].op].split(s[outs[0].op].op.axis[1], 2 * simd_width)
             s[op].compute_at(s[outs[0]], y_o)
             s[outs[0].op].vectorize(y_i)
         if op.tag == "sparse_dense_bsrmm":
@@ -47,8 +48,7 @@ def schedule_sparse_dense(outs):
             s[y_bsrmm].compute_at(s[y_reshape], noi)
             s[y_reshape].vectorize(noi)
             if op != s[outs[0]].op:
-                (y_o, y_i) = s[outs[0].op].split(
-                    s[outs[0].op].op.axis[1], 2 * simd_width)
+                (y_o, y_i) = s[outs[0].op].split(s[outs[0].op].op.axis[1], 2 * simd_width)
                 s[y_reshape].compute_at(s[outs[0]], y_o)
                 s[outs[0].op].parallel(y_o)
                 s[outs[0].op].vectorize(y_i)
index 17c0b36..818765d 100644 (file)
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Core kernel of dot product of 4 Int8 operations"""
-#pylint: disable=invalid-name
+# pylint: disable=invalid-name
 import tvm
 from tvm import te
 import tvm.target.codegen
@@ -25,8 +25,10 @@ def dot_16x1x16_uint8_int8_int32():
     """Dispatch the most optimized intrin depending on the target"""
     mcpu = tvm.target.Target.current().mcpu
 
-    assert mcpu in ("skylake-avx512", "cascadelake"), \
-        "An old Intel machine that does not have fast Int8 support."
+    assert mcpu in (
+        "skylake-avx512",
+        "cascadelake",
+    ), "An old Intel machine that does not have fast Int8 support."
     if mcpu == "skylake-avx512":
         return dot_16x1x16_uint8_int8_int32_skylake()
     # cascadelake
@@ -62,60 +64,67 @@ def dot_16x1x16_uint8_int8_int32_skylake():
         The Skylake int8 TensorIntrin that can be used in tensorizing schedule
     """
 
-    int32_lanes = 16 # 16 int32 lanes in AVX512
-    num_int8_elements = 4 # 4 int8 elements in int32
-    data = te.placeholder((num_int8_elements,), dtype='uint8', name='data')
-    kernel = te.placeholder((int32_lanes, num_int8_elements), dtype='int8', name='kernel')
-    k = te.reduce_axis((0, num_int8_elements), name='k')
-    C = te.compute((int32_lanes,),
-                   lambda i: te.sum(data[k].astype('int32') *
-                                    kernel[i, k].astype('int32'),
-                                    axis=k),
-                   name="C")
-
-    a_buffer = tvm.tir.decl_buffer(data.shape, dtype='uint8', name="a_buffer",
-                                   offset_factor=1,
-                                   strides=[1])
-    b_buffer = tvm.tir.decl_buffer(kernel.shape, dtype='int8', name="b_buffer",
-                                   offset_factor=1,
-                                   strides=[te.var('ldw'), 1])
+    int32_lanes = 16  # 16 int32 lanes in AVX512
+    num_int8_elements = 4  # 4 int8 elements in int32
+    data = te.placeholder((num_int8_elements,), dtype="uint8", name="data")
+    kernel = te.placeholder((int32_lanes, num_int8_elements), dtype="int8", name="kernel")
+    k = te.reduce_axis((0, num_int8_elements), name="k")
+    C = te.compute(
+        (int32_lanes,),
+        lambda i: te.sum(data[k].astype("int32") * kernel[i, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    a_buffer = tvm.tir.decl_buffer(
+        data.shape, dtype="uint8", name="a_buffer", offset_factor=1, strides=[1]
+    )
+    b_buffer = tvm.tir.decl_buffer(
+        kernel.shape, dtype="int8", name="b_buffer", offset_factor=1, strides=[te.var("ldw"), 1]
+    )
 
     def _intrin_func(ins, outs):
         def _instr(index):
             ib = tvm.tir.ir_builder.create()
             if index == 1:
-                ib.emit(outs[0].vstore(0, tvm.tir.const(0, 'int32x16')))
+                ib.emit(outs[0].vstore(0, tvm.tir.const(0, "int32x16")))
                 return ib.get()
 
             a_int8 = ins[0].vload([0], "uint8x4")
-            re_int32 = tvm.tir.call_intrin('int32', 'tir.reinterpret', a_int8)
-            vec_ai32 = re_int32.astype('int32x16')
-            vec_a = tvm.tir.call_intrin('int8x64', 'tir.reinterpret', vec_ai32)
+            re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_int8)
+            vec_ai32 = re_int32.astype("int32x16")
+            vec_a = tvm.tir.call_intrin("int8x64", "tir.reinterpret", vec_ai32)
             vec_b = ins[1].vload([0, 0], "int8x64")
             vec_one = tvm.tir.const(1, "int16x32")
             pair_reduction = tvm.tir.call_llvm_pure_intrin(
-                'int16x32',
-                'llvm.x86.avx512.pmaddubs.w.512',
-                tvm.tir.const(0, 'uint32'),
-                vec_a, vec_b)
+                "int16x32",
+                "llvm.x86.avx512.pmaddubs.w.512",
+                tvm.tir.const(0, "uint32"),
+                vec_a,
+                vec_b,
+            )
             quad_reduction = tvm.tir.call_llvm_pure_intrin(
-                'int32x16',
-                'llvm.x86.avx512.pmaddw.d.512',
-                tvm.tir.const(0, 'uint32'),
-                pair_reduction, vec_one)
+                "int32x16",
+                "llvm.x86.avx512.pmaddw.d.512",
+                tvm.tir.const(0, "uint32"),
+                pair_reduction,
+                vec_one,
+            )
             if index == 0:
                 ib.emit(outs[0].vstore(0, quad_reduction))
             else:
-                ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], 'int32x16')))
+                ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], "int32x16")))
             return ib.get()
 
         # body, reset, update
         return _instr(0), _instr(1), _instr(2)
 
-    buffer_params = {"offset_factor" : 1}
+    buffer_params = {"offset_factor": 1}
     return te.decl_tensor_intrin(
-        C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer},
-        default_buffer_params=buffer_params)
+        C.op,
+        _intrin_func,
+        binds={data: a_buffer, kernel: b_buffer},
+        default_buffer_params=buffer_params,
+    )
 
 
 def dot_16x1x16_uint8_int8_int16():
@@ -149,22 +158,21 @@ def dot_16x1x16_uint8_int8_int16():
         The Skylake int8 TensorIntrin that can be used in tensorizing schedule
     """
 
-    int16_lanes = 4*32 # 4*32 int32 lanes in 4 AVX512 vector registers
-    num_int8_elements = 2 # 2 int8 elements in int16
-    data = te.placeholder((num_int8_elements,), dtype='uint8', name='data')
-    kernel = te.placeholder((int16_lanes, num_int8_elements), dtype='int8', name='kernel')
-    k = te.reduce_axis((0, num_int8_elements), name='k')
-    C = te.compute((int16_lanes, ),
-                   lambda i: te.sum(data[k].astype('int16') *
-                                    kernel[i, k].astype('int16'),
-                                    axis=k),
-                   name="C")
-
-    a_buffer = tvm.tir.decl_buffer(data.shape, dtype='uint8', name="a_buffer",
-                                   offset_factor=1,
-                                   strides=[1])
-    b_buffer = tvm.tir.decl_buffer(kernel.shape, dtype='int8', name="b_buffer",
-                                   offset_factor=1)
+    int16_lanes = 4 * 32  # 4*32 int32 lanes in 4 AVX512 vector registers
+    num_int8_elements = 2  # 2 int8 elements in int16
+    data = te.placeholder((num_int8_elements,), dtype="uint8", name="data")
+    kernel = te.placeholder((int16_lanes, num_int8_elements), dtype="int8", name="kernel")
+    k = te.reduce_axis((0, num_int8_elements), name="k")
+    C = te.compute(
+        (int16_lanes,),
+        lambda i: te.sum(data[k].astype("int16") * kernel[i, k].astype("int16"), axis=k),
+        name="C",
+    )
+
+    a_buffer = tvm.tir.decl_buffer(
+        data.shape, dtype="uint8", name="a_buffer", offset_factor=1, strides=[1]
+    )
+    b_buffer = tvm.tir.decl_buffer(kernel.shape, dtype="int8", name="b_buffer", offset_factor=1)
     # strides=[te.var('ldw'), 1, 1])
 
     def _intrin_func(ins, outs):
@@ -172,34 +180,43 @@ def dot_16x1x16_uint8_int8_int16():
             ib = tvm.tir.ir_builder.create()
             if index == 1:
                 for i in range(4):
-                    ib.emit(outs[0].vstore([i*32], tvm.tir.const(0, 'int16x32')))
+                    ib.emit(outs[0].vstore([i * 32], tvm.tir.const(0, "int16x32")))
                 return ib.get()
 
             a_int8 = ins[0].vload([0], "uint8x2")
-            re_int16 = tvm.tir.call_intrin('int16', 'tir.reinterpret', a_int8)
-            vec_ai16 = re_int16.astype('int16x32')
-            vec_a = tvm.tir.call_intrin('int8x64', 'tir.reinterpret', vec_ai16)
+            re_int16 = tvm.tir.call_intrin("int16", "tir.reinterpret", a_int8)
+            vec_ai16 = re_int16.astype("int16x32")
+            vec_a = tvm.tir.call_intrin("int8x64", "tir.reinterpret", vec_ai16)
 
             for i in range(4):
-                vec_b = ins[1].vload([i*32, 0], "int8x64")
+                vec_b = ins[1].vload([i * 32, 0], "int8x64")
                 pair_reduction = tvm.tir.call_llvm_pure_intrin(
-                    'int16x32',
-                    'llvm.x86.avx512.pmaddubs.w.512',
-                    tvm.tir.const(0, 'uint32'),
-                    vec_a, vec_b)
+                    "int16x32",
+                    "llvm.x86.avx512.pmaddubs.w.512",
+                    tvm.tir.const(0, "uint32"),
+                    vec_a,
+                    vec_b,
+                )
                 if index == 0:
-                    ib.emit(outs[0].vstore([i*32], pair_reduction))
+                    ib.emit(outs[0].vstore([i * 32], pair_reduction))
                 else:
-                    ib.emit(outs[0].vstore([i*32], pair_reduction + outs[0].vload([i*32],
-                                                                                  'int16x32')))
+                    ib.emit(
+                        outs[0].vstore(
+                            [i * 32], pair_reduction + outs[0].vload([i * 32], "int16x32")
+                        )
+                    )
             return ib.get()
 
         # body, reset, update
         return _instr(0), _instr(1), _instr(2)
-    buffer_params = {"offset_factor" : 1}
+
+    buffer_params = {"offset_factor": 1}
     return te.decl_tensor_intrin(
-        C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer},
-        default_buffer_params=buffer_params)
+        C.op,
+        _intrin_func,
+        binds={data: a_buffer, kernel: b_buffer},
+        default_buffer_params=buffer_params,
+    )
 
 
 def dot_16x1x16_uint8_int8_int32_cascadelake():
@@ -231,72 +248,81 @@ def dot_16x1x16_uint8_int8_int32_cascadelake():
         The Cascade Lake int8 TensorIntrin that can be used in tensorizing schedule
     """
 
-    int32_lanes = 16 # 16 int32 lanes in AVX512
-    num_int8_elements = 4 # 4 int8 elements in int32
-    data = te.placeholder((num_int8_elements,), dtype='uint8', name='data')
-    kernel = te.placeholder((int32_lanes, num_int8_elements), dtype='int8', name='kernel')
-    k = te.reduce_axis((0, num_int8_elements), name='k')
-    C = te.compute((int32_lanes,),
-                   lambda i: te.sum(data[k].astype('int32') *
-                                    kernel[i, k].astype('int32'),
-                                    axis=k),
-                   name="C")
-
-    a_buffer = tvm.tir.decl_buffer(data.shape, dtype='uint8', name="a_buffer",
-                                   offset_factor=1,
-                                   strides=[1])
-    b_buffer = tvm.tir.decl_buffer(kernel.shape, dtype='int8', name="b_buffer",
-                                   offset_factor=1,
-                                   strides=[te.var('ldw'), 1])
+    int32_lanes = 16  # 16 int32 lanes in AVX512
+    num_int8_elements = 4  # 4 int8 elements in int32
+    data = te.placeholder((num_int8_elements,), dtype="uint8", name="data")
+    kernel = te.placeholder((int32_lanes, num_int8_elements), dtype="int8", name="kernel")
+    k = te.reduce_axis((0, num_int8_elements), name="k")
+    C = te.compute(
+        (int32_lanes,),
+        lambda i: te.sum(data[k].astype("int32") * kernel[i, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    a_buffer = tvm.tir.decl_buffer(
+        data.shape, dtype="uint8", name="a_buffer", offset_factor=1, strides=[1]
+    )
+    b_buffer = tvm.tir.decl_buffer(
+        kernel.shape, dtype="int8", name="b_buffer", offset_factor=1, strides=[te.var("ldw"), 1]
+    )
 
     def _intrin_func(ins, outs):
         def _instr(index):
             ib = tvm.tir.ir_builder.create()
             if index == 1:
-                ib.emit(outs[0].vstore(0, tvm.tir.const(0, 'int32x16')))
+                ib.emit(outs[0].vstore(0, tvm.tir.const(0, "int32x16")))
                 return ib.get()
 
             a_int8 = ins[0].vload([0], "uint8x4")
-            re_int32 = tvm.tir.call_intrin('int32', 'tir.reinterpret', a_int8)
-            vec_ai32 = re_int32.astype('int32x16')
+            re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_int8)
+            vec_ai32 = re_int32.astype("int32x16")
             vec_b = ins[1].vload([0, 0], "int8x64")
 
-            vnni_inst_name = 'llvm.x86.avx512.vpdpbusd.512'
+            vnni_inst_name = "llvm.x86.avx512.vpdpbusd.512"
             llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id(vnni_inst_name)
 
-            if llvm_id != 0: # VNNI is available for current LLVM version
-                vec_bi32 = tvm.tir.call_intrin('int32x16', 'tir.reinterpret', vec_b)
+            if llvm_id != 0:  # VNNI is available for current LLVM version
+                vec_bi32 = tvm.tir.call_intrin("int32x16", "tir.reinterpret", vec_b)
                 vec_zero = tvm.tir.const(0, "int32x16")
                 quad_reduction = tvm.tir.call_llvm_pure_intrin(
-                    'int32x16',
-                    'llvm.x86.avx512.vpdpbusd.512',
-                    tvm.tir.const(0, 'uint32'),
+                    "int32x16",
+                    "llvm.x86.avx512.vpdpbusd.512",
+                    tvm.tir.const(0, "uint32"),
                     vec_zero,
-                    vec_ai32, vec_bi32)
-            else: # Fall back to the normal AVX512
-                vec_a = tvm.tir.call_intrin('int8x64', 'tir.reinterpret', vec_ai32)
+                    vec_ai32,
+                    vec_bi32,
+                )
+            else:  # Fall back to the normal AVX512
+                vec_a = tvm.tir.call_intrin("int8x64", "tir.reinterpret", vec_ai32)
                 vec_one = tvm.tir.const(1, "int16x32")
                 pair_reduction = tvm.tir.call_llvm_pure_intrin(
-                    'int16x32',
-                    'llvm.x86.avx512.pmaddubs.w.512',
-                    tvm.tir.const(0, 'uint32'),
-                    vec_a, vec_b)
+                    "int16x32",
+                    "llvm.x86.avx512.pmaddubs.w.512",
+                    tvm.tir.const(0, "uint32"),
+                    vec_a,
+                    vec_b,
+                )
                 quad_reduction = tvm.tir.call_llvm_pure_intrin(
-                    'int32x16',
-                    'llvm.x86.avx512.pmaddw.d.512',
-                    tvm.tir.const(0, 'uint32'),
-                    pair_reduction, vec_one)
+                    "int32x16",
+                    "llvm.x86.avx512.pmaddw.d.512",
+                    tvm.tir.const(0, "uint32"),
+                    pair_reduction,
+                    vec_one,
+                )
 
             if index == 0:
                 ib.emit(outs[0].vstore(0, quad_reduction))
             else:
-                ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], 'int32x16')))
+                ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], "int32x16")))
             return ib.get()
 
         # body, reset, update
         return _instr(0), _instr(1), _instr(2)
 
-    buffer_params = {"offset_factor" : 1}
+    buffer_params = {"offset_factor": 1}
     return te.decl_tensor_intrin(
-        C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer},
-        default_buffer_params=buffer_params)
+        C.op,
+        _intrin_func,
+        binds={data: a_buffer, kernel: b_buffer},
+        default_buffer_params=buffer_params,
+    )
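The intrinsics in this file are chosen in dot_16x1x16_uint8_int8_int32 by inspecting the mcpu attribute of the active target, and the Cascade Lake variant additionally asks LLVM for the VNNI intrinsic (llvm.x86.avx512.vpdpbusd.512) before falling back to the pmaddubs/pmaddw pair. A minimal sketch of reading that attribute, using a hypothetical target string that does not appear in this commit:

    import tvm

    # Target.mcpu exposes the -mcpu flag of the target string; the dispatcher
    # above compares it against "skylake-avx512" and "cascadelake".
    target = tvm.target.Target("llvm -mcpu=cascadelake")
    print(target.mcpu)  # expected: cascadelake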
index f2a35d2..92c11a7 100644 (file)
@@ -21,6 +21,6 @@ import tvm
 def get_fp32_len():
     mcpu = tvm.target.Target.current().mcpu
     fp32_vec_len = 8
-    if mcpu in ('skylake-avx512', 'cascadelake'):
+    if mcpu in ("skylake-avx512", "cascadelake"):
         fp32_vec_len = 16
     return fp32_vec_len
index ddfa03b..d34b440 100755 (executable)
@@ -28,9 +28,10 @@ from tvm.relay import testing
 
 CWD = osp.dirname(osp.abspath(osp.expanduser(__file__)))
 
+
 def _get_model(dshape):
-    data = relay.var('data', shape=dshape)
-    fc = relay.nn.dense(data, relay.var("dense_weight"), units=dshape[-1]*2)
+    data = relay.var("data", shape=dshape)
+    fc = relay.nn.dense(data, relay.var("dense_weight"), units=dshape[-1] * 2)
     fc = relay.nn.bias_add(fc, relay.var("dense_bias"))
     left, right = relay.split(fc, indices_or_sections=2, axis=1)
     one = relay.const(1, dtype="float32")
@@ -41,13 +42,13 @@ def main():
     dshape = (32, 16)
     net = _get_model(dshape)
     mod, params = testing.create_workload(net)
-    graph, lib, params = relay.build(
-        mod, 'llvm', params=params)
+    graph, lib, params = relay.build(mod, "llvm", params=params)
 
-    with open(osp.join(CWD, 'graph.json'), 'w') as f_resnet:
+    with open(osp.join(CWD, "graph.json"), "w") as f_resnet:
         f_resnet.write(graph)
-    with open(osp.join(CWD, 'graph.params'), 'wb') as f_params:
+    with open(osp.join(CWD, "graph.params"), "wb") as f_params:
         f_params.write(relay.save_param_dict(params))
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     main()
index cb7c4f7..e743e48 100755 (executable)
@@ -29,27 +29,28 @@ from tvm.relay import testing
 
 
 def _get_model(dshape):
-    data = relay.var('data', shape=dshape)
-    fc = relay.nn.dense(data, relay.var("dense_weight"), units=dshape[-1]*2)
+    data = relay.var("data", shape=dshape)
+    fc = relay.nn.dense(data, relay.var("dense_weight"), units=dshape[-1] * 2)
     fc = relay.nn.bias_add(fc, relay.var("dense_bias"))
     left, right = relay.split(fc, indices_or_sections=2, axis=1)
     one = relay.const(1, dtype="float32")
     return relay.Tuple([(left + one), (right - one), fc])
 
+
 def main():
     dshape = (4, 8)
     net = _get_model(dshape)
     mod, params = testing.create_workload(net)
-    graph, lib, params = relay.build(
-        mod, 'llvm --system-lib', params=params)
+    graph, lib, params = relay.build(mod, "llvm --system-lib", params=params)
 
     out_dir = sys.argv[1]
-    lib.save(osp.join(sys.argv[1], 'graph.o'))
-    with open(osp.join(out_dir, 'graph.json'), 'w') as f_resnet:
+    lib.save(osp.join(sys.argv[1], "graph.o"))
+    with open(osp.join(out_dir, "graph.json"), "w") as f_resnet:
         f_resnet.write(graph)
 
-    with open(osp.join(out_dir, 'graph.params'), 'wb') as f_params:
+    with open(osp.join(out_dir, "graph.params"), "wb") as f_params:
         f_params.write(relay.save_param_dict(params))
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     main()
index bf7e60a..2a9ca23 100755 (executable)
@@ -24,15 +24,17 @@ import sys
 import tvm
 from tvm import te
 
+
 def main():
-    n = te.var('n')
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
-    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    n = te.var("n")
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
+    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C")
     s = tvm.te.create_schedule(C.op)
     s[C].parallel(s[C].op.axis[0])
     print(tvm.lower(s, [A, B, C], simple_mode=True))
-    tvm.build(s, [A, B, C], 'llvm --system-lib').save(osp.join(sys.argv[1], 'test.o'))
+    tvm.build(s, [A, B, C], "llvm --system-lib").save(osp.join(sys.argv[1], "test.o"))
+
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
index cb7353f..4b270fa 100755 (executable)
@@ -25,17 +25,19 @@ import tvm
 from tvm import te
 from tvm.contrib import cc
 
+
 def main():
-    n = te.var('n')
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
-    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    n = te.var("n")
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
+    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C")
     s = tvm.te.create_schedule(C.op)
     s[C].parallel(s[C].op.axis[0])
     print(tvm.lower(s, [A, B, C], simple_mode=True))
-    obj_file = osp.join(sys.argv[1], 'test.o')
-    tvm.build(s, [A, B, C], 'llvm').save(obj_file)
-    cc.create_shared(osp.join(sys.argv[1], 'test.so'), [obj_file])
+    obj_file = osp.join(sys.argv[1], "test.o")
+    tvm.build(s, [A, B, C], "llvm").save(obj_file)
+    cc.create_shared(osp.join(sys.argv[1], "test.so"), [obj_file])
+
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
index e598bde..42da22d 100755 (executable)
@@ -24,15 +24,19 @@ import sys
 import tvm
 from tvm import te
 
+
 def main():
-    n = te.var('n')
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
-    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    n = te.var("n")
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
+    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C")
     s = tvm.te.create_schedule(C.op)
     s[C].parallel(s[C].op.axis[0])
     print(tvm.lower(s, [A, B, C], simple_mode=True))
-    tvm.build(s, [A, B, C], 'llvm -mtriple=wasm32-unknown-unknown --system-lib').save(osp.join(sys.argv[1], 'test.o'))
+    tvm.build(s, [A, B, C], "llvm -mtriple=wasm32-unknown-unknown --system-lib").save(
+        osp.join(sys.argv[1], "test.o")
+    )
+
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
index 4dc1a2c..14e2eee 100644 (file)
@@ -30,19 +30,25 @@ from tvm import relay
 from tvm.relay import testing
 from tvm.contrib import graph_runtime, cc
 
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
 logger = logging.getLogger(__name__)
 
-parser = argparse.ArgumentParser(description='Resnet build example')
+parser = argparse.ArgumentParser(description="Resnet build example")
 aa = parser.add_argument
-aa('--build-dir', type=str, required=True, help='directory to put the build artifacts')
-aa('--pretrained', action='store_true', help='use a pretrained resnet')
-aa('--batch-size', type=int, default=1, help='input image batch size')
-aa('--opt-level', type=int, default=3,
-   help='level of optimization. 0 is unoptimized and 3 is the highest level')
-aa('--target', type=str, default='llvm', help='target context for compilation')
-aa('--image-shape', type=str, default='3,224,224', help='input image dimensions')
-aa('--image-name', type=str, default='cat.png', help='name of input image to download')
+aa("--build-dir", type=str, required=True, help="directory to put the build artifacts")
+aa("--pretrained", action="store_true", help="use a pretrained resnet")
+aa("--batch-size", type=int, default=1, help="input image batch size")
+aa(
+    "--opt-level",
+    type=int,
+    default=3,
+    help="level of optimization. 0 is unoptimized and 3 is the highest level",
+)
+aa("--target", type=str, default="llvm", help="target context for compilation")
+aa("--image-shape", type=str, default="3,224,224", help="input image dimensions")
+aa("--image-name", type=str, default="cat.png", help="name of input image to download")
 args = parser.parse_args()
 
 build_dir = args.build_dir
@@ -52,9 +58,10 @@ target = tvm.target.Target(args.target)
 image_shape = tuple(map(int, args.image_shape.split(",")))
 data_shape = (batch_size,) + image_shape
 
+
 def build(target_dir):
     """ Compiles resnet18 with TVM"""
-    deploy_lib = osp.join(target_dir, 'deploy_lib.o')
+    deploy_lib = osp.join(target_dir, "deploy_lib.o")
     if osp.exists(deploy_lib):
         return
 
@@ -64,15 +71,17 @@ def build(target_dir):
 
         # if `--pretrained` is enabled, it downloads a pretrained
         # resnet18 trained on imagenet1k dataset for image classification task
-        block = get_model('resnet18_v1', pretrained=True)
+        block = get_model("resnet18_v1", pretrained=True)
         net, params = relay.frontend.from_mxnet(block, {"data": data_shape})
         # we want a probability so add a softmax operator
-        net = relay.Function(net.params, relay.nn.softmax(net.body),
-            None, net.type_params, net.attrs)
+        net = relay.Function(
+            net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs
+        )
     else:
         # use random weights from relay.testing
         net, params = relay.testing.resnet.get_workload(
-            num_layers=18, batch_size=batch_size, image_shape=image_shape)
+            num_layers=18, batch_size=batch_size, image_shape=image_shape
+        )
 
     # compile the model
     with tvm.transform.PassContext(opt_level=opt_level):
@@ -80,26 +89,30 @@ def build(target_dir):
 
     # save the model artifacts
     lib.save(deploy_lib)
-    cc.create_shared(osp.join(target_dir, "deploy_lib.so"),
-                    [osp.join(target_dir, "deploy_lib.o")])
+    cc.create_shared(osp.join(target_dir, "deploy_lib.so"), [osp.join(target_dir, "deploy_lib.o")])
 
     with open(osp.join(target_dir, "deploy_graph.json"), "w") as fo:
         fo.write(graph)
 
-    with open(osp.join(target_dir,"deploy_param.params"), "wb") as fo:
+    with open(osp.join(target_dir, "deploy_param.params"), "wb") as fo:
         fo.write(relay.save_param_dict(params))
 
+
 def download_img_labels():
     """ Download an image and imagenet1k class labels for test"""
     from mxnet.gluon.utils import download
 
-    img_name = 'cat.png'
-    synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
-                      '4d0b62f3d01426887599d4f7ede23ee5/raw/',
-                      '596b27d23537e5a1b5751d2b0481ef172f58b539/',
-                      'imagenet1000_clsid_to_human.txt'])
-    synset_name = 'synset.txt'
-    download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', img_name)
+    img_name = "cat.png"
+    synset_url = "".join(
+        [
+            "https://gist.githubusercontent.com/zhreshold/",
+            "4d0b62f3d01426887599d4f7ede23ee5/raw/",
+            "596b27d23537e5a1b5751d2b0481ef172f58b539/",
+            "imagenet1000_clsid_to_human.txt",
+        ]
+    )
+    synset_name = "synset.txt"
+    download("https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true", img_name)
     download(synset_url, synset_name)
 
     with open(synset_name) as fin:
@@ -109,11 +122,12 @@ def download_img_labels():
         w = csv.writer(fout)
         w.writerows(synset.items())
 
+
 def test_build(build_dir):
     """ Sanity check with random input"""
     graph = open(osp.join(build_dir, "deploy_graph.json")).read()
     lib = tvm.runtime.load_module(osp.join(build_dir, "deploy_lib.so"))
-    params = bytearray(open(osp.join(build_dir,"deploy_param.params"), "rb").read())
+    params = bytearray(open(osp.join(build_dir, "deploy_param.params"), "rb").read())
     input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32"))
     ctx = tvm.cpu()
     module = graph_runtime.create(graph, lib, ctx)
@@ -122,7 +136,7 @@ def test_build(build_dir):
     out = module.get_output(0).asnumpy()
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     logger.info("building the model")
     build(build_dir)
     logger.info("build was successful")
index f781aa0..b9672fb 100755 (executable)
@@ -25,25 +25,24 @@ from tvm.contrib import cc
 
 
 def main(target, out_dir):
-    n = te.var('n')
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
-    C = te.compute(A.shape, lambda i: A[i] + B[i], name='C')
+    n = te.var("n")
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
+    C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
     s = te.create_schedule(C.op)
 
-    if target == 'cuda':
+    if target == "cuda":
         bx, tx = s[C].split(C.op.axis[0], factor=64)
-        s[C].bind(bx, te.thread_axis('blockIdx.x'))
-        s[C].bind(tx, te.thread_axis('threadIdx.x'))
+        s[C].bind(bx, te.thread_axis("blockIdx.x"))
+        s[C].bind(tx, te.thread_axis("threadIdx.x"))
 
-    fadd = tvm.build(s, [A, B, C], target, target_host='llvm', name='myadd')
+    fadd = tvm.build(s, [A, B, C], target, target_host="llvm", name="myadd")
 
-    fadd.save(osp.join(out_dir, 'test_add.o'))
-    if target == 'cuda':
-        fadd.imported_modules[0].save(osp.join(out_dir, 'test_add.ptx'))
-    cc.create_shared(
-        osp.join(out_dir, 'test_add.so'), [osp.join(out_dir, 'test_add.o')])
+    fadd.save(osp.join(out_dir, "test_add.o"))
+    if target == "cuda":
+        fadd.imported_modules[0].save(osp.join(out_dir, "test_add.ptx"))
+    cc.create_shared(osp.join(out_dir, "test_add.so"), [osp.join(out_dir, "test_add.o")])
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main(sys.argv[1], sys.argv[2])
index df21c2d..3227bc3 100644 (file)
@@ -26,46 +26,48 @@ def check_output(args, **kw):
     proc = subprocess.Popen(args, **kw, stdout=subprocess.PIPE)
     out, _ = proc.communicate()
     if proc.returncode:
-      sys.stderr.write('exited with code %d: %s\n' % (proc.returncode, ' '.join(args)))
-      sys.exit(2)
+        sys.stderr.write("exited with code %d: %s\n" % (proc.returncode, " ".join(args)))
+        sys.exit(2)
 
     if sys.version_info[0] == 2:
-      return unicode(out, 'utf-8')
+        return unicode(out, "utf-8")
     else:
-      return str(out, 'utf-8')
+        return str(out, "utf-8")
 
 
 def main():
     script_dir = os.path.dirname(__file__) or os.getcwd()
-    toplevel_dir = check_output(['git', 'rev-parse', '--show-toplevel'], cwd=script_dir).strip('\n')
+    toplevel_dir = check_output(["git", "rev-parse", "--show-toplevel"], cwd=script_dir).strip("\n")
     # NOTE: --ignore-submodules because this can drag in some problems related to mounting a git
     # worktree in the docker VM in a different location than it exists on the host. The problem
     # isn't quite clear, but anyhow it shouldn't be necessary to filter untracked files in
     # submodules here.
-    git_status_output = check_output(['git', 'status', '-s', '--ignored'],
-                                     cwd=toplevel_dir)
-    untracked = [line[3:]
-                 for line in git_status_output.split('\n')
-                 if line.startswith('?? ') or line.startswith('!! ')]
+    git_status_output = check_output(["git", "status", "-s", "--ignored"], cwd=toplevel_dir)
+    untracked = [
+        line[3:]
+        for line in git_status_output.split("\n")
+        if line.startswith("?? ") or line.startswith("!! ")
+    ]
 
     # also add .git in case rat picks up files in .git or the .git file (if a worktree).
-    toplevel_git_dentry = os.path.join(toplevel_dir, '.git')
+    toplevel_git_dentry = os.path.join(toplevel_dir, ".git")
     if os.path.isfile(toplevel_git_dentry):
-        untracked.append('.git')
+        untracked.append(".git")
     else:
-        untracked.append('.git/')
+        untracked.append(".git/")
 
     for line in sys.stdin:
         cleaned_line = line
-        if line[:2] == './':
+        if line[:2] == "./":
             cleaned_line = line[2:]
-        cleaned_line = cleaned_line.strip('\n')
-        if any((cleaned_line.startswith(u) if u[-1] == '/' else cleaned_line == u)
-               for u in untracked):
+        cleaned_line = cleaned_line.strip("\n")
+        if any(
+            (cleaned_line.startswith(u) if u[-1] == "/" else cleaned_line == u) for u in untracked
+        ):
             continue
 
         sys.stdout.write(line)
 
 
-if __name__ == '__main__':
-  main()
+if __name__ == "__main__":
+    main()
index 2a3f88c..bf9539c 100644 (file)
@@ -82,7 +82,7 @@ enable=indexing-exception,old-raise-syntax
 # --enable=similarities". If you want to run only the classes checker, but have
 # no Warning level messages displayed, use"--disable=all --enable=classes
 # --disable=W"
-disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,protected-access,useless-object-inheritance,consider-using-get
+disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,protected-access,useless-object-inheritance,consider-using-get,bad-continuation,too-many-lines
 
 [REPORTS]
 
index 4b76efe..cc4066d 100644 (file)
@@ -32,7 +32,8 @@ from tvm.micro import create_micro_mod
 # Ex : export CMSIS_ST_PATH="/home/yourid/st/STM32Cube_FW_F7_V1.16.0/Drivers/CMSIS"
 DEV_CONFIG_A = micro.device.arm.stm32f746xx.generate_config("127.0.0.1", 6666)
 DEV_CONFIG_B = micro.device.arm.stm32f746xx.generate_config("127.0.0.1", 6666)
-TARGET = 'micro_dev'
+TARGET = "micro_dev"
+
 
 def relay_micro_build(func, dev_config, params=None):
     """Create a graph runtime module with a micro device context from a Relay function.
@@ -53,9 +54,9 @@ def relay_micro_build(func, dev_config, params=None):
     mod : tvm.runtime.Module
         graph runtime module for the target device
     """
-    with tvm.transform.PassContext(disabled_pass={'FuseOps'}, config={
-        "tir.disable_vectorize": True
-    }):
+    with tvm.transform.PassContext(
+        disabled_pass={"FuseOps"}, config={"tir.disable_vectorize": True}
+    ):
         graph, c_mod, params = relay.build(func, target=TARGET, params=params)
     micro_mod = micro.create_micro_mod(c_mod, dev_config)
     ctx = tvm.micro_dev(0)
@@ -73,14 +74,14 @@ break UTVMDone
 
 
 def reset_gdbinit():
-    if 'server_port' not in DEV_CONFIG_A:
+    if "server_port" not in DEV_CONFIG_A:
         return
     try:
-        gdb_init_dir = os.environ['MICRO_GDB_INIT_DIR']
+        gdb_init_dir = os.environ["MICRO_GDB_INIT_DIR"]
     except KeyError:
         return
-    with open(f'{gdb_init_dir}/.gdbinit', 'w') as f:
-        gdb_port = DEV_CONFIG_A['server_port'] - 3333
+    with open(f"{gdb_init_dir}/.gdbinit", "w") as f:
+        gdb_port = DEV_CONFIG_A["server_port"] - 3333
         f.write(GDB_INIT_TEMPLATE.format(gdb_port=gdb_port))
 
 
@@ -129,13 +130,10 @@ def test_add():
         micro_func(a, b, c)
 
         # ensure inputs weren't corrupted
-        tvm.testing.assert_allclose(
-                a.asnumpy(), a_np)
-        tvm.testing.assert_allclose(
-                b.asnumpy(), b_np)
+        tvm.testing.assert_allclose(a.asnumpy(), a_np)
+        tvm.testing.assert_allclose(b.asnumpy(), b_np)
         # ensure output is correct
-        tvm.testing.assert_allclose(
-                c.asnumpy(), a.asnumpy() + b.asnumpy())
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
 
 def test_workspace_add():
@@ -168,11 +166,9 @@ def test_workspace_add():
         micro_func(a, c)
 
         # ensure input wasn't corrupted
-        tvm.testing.assert_allclose(
-                a.asnumpy(), a_np)
+        tvm.testing.assert_allclose(a.asnumpy(), a_np)
         # ensure output is correct
-        tvm.testing.assert_allclose(
-                c.asnumpy(), a.asnumpy() + 2.0)
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 2.0)
 
 
 def test_graph_runtime():
@@ -195,10 +191,8 @@ def test_graph_runtime():
         mod.run(x=x_in)
         result = mod.get_output(0).asnumpy()
 
-        tvm.testing.assert_allclose(
-                mod.get_input(0).asnumpy(), x_in)
-        tvm.testing.assert_allclose(
-                result, x_in * x_in + 1.0)
+        tvm.testing.assert_allclose(mod.get_input(0).asnumpy(), x_in)
+        tvm.testing.assert_allclose(result, x_in * x_in + 1.0)
 
 
 def test_conv2d():
@@ -209,29 +203,23 @@ def test_conv2d():
     from tvm.relay import transform
 
     dshape = (1, 4, 16, 16)
-    dtype = 'int8'
-    func_name = 'fused_nn_conv2d'
+    dtype = "int8"
+    func_name = "fused_nn_conv2d"
 
     reset_gdbinit()
 
     # Construct Relay program.
     x = relay.var("x", shape=dshape, dtype=dtype)
-    conv_expr = relay.nn.conv2d(
-            x, relay.var("w"),
-            kernel_size=(3, 3),
-            padding=(1, 1),
-            channels=4)
+    conv_expr = relay.nn.conv2d(x, relay.var("w"), kernel_size=(3, 3), padding=(1, 1), channels=4)
     func = relay.Function(relay.analysis.free_vars(conv_expr), conv_expr)
     mod = tvm.IRModule.from_expr(func)
     mod = transform.InferType()(mod)
 
-    x_shape = list(map(lambda x: x.value, mod['main'].params[0].checked_type.shape))
-    w_shape = list(map(lambda x: x.value, mod['main'].params[1].checked_type.shape))
-    out_shape = list(map(lambda x: x.value, mod['main'].ret_type.shape))
+    x_shape = list(map(lambda x: x.value, mod["main"].params[0].checked_type.shape))
+    w_shape = list(map(lambda x: x.value, mod["main"].params[1].checked_type.shape))
+    out_shape = list(map(lambda x: x.value, mod["main"].ret_type.shape))
 
-    with tvm.transform.PassContext(config={
-        "tir.disable_vectorize": True
-    }):
+    with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
         graph, c_mod, params = relay.build(mod, target="c")
 
     with micro.Session(DEV_CONFIG_A):
@@ -242,7 +230,7 @@ def test_conv2d():
                 micro_func = micro_mod[candidate_func_name]
                 break
             except tvm.TVMError as e:
-                candidate_func_name = f'{func_name}_{i}'
+                candidate_func_name = f"{func_name}_{i}"
         else:
             assert False
         ctx = tvm.micro_dev(0)
@@ -253,9 +241,9 @@ def test_conv2d():
         micro_func(x_data, w_data, result)
 
         out_data = np.zeros(out_shape, dtype=dtype)
-        params = { 'x': x_data.asnumpy(), 'w': w_data.asnumpy() }
-        intrp = create_executor('debug')
-        expected_result = intrp.evaluate(mod['main'])(x_data, w_data)
+        params = {"x": x_data.asnumpy(), "w": w_data.asnumpy()}
+        intrp = create_executor("debug")
+        expected_result = intrp.evaluate(mod["main"])(x_data, w_data)
 
         tvm.testing.assert_allclose(result.asnumpy(), expected_result.asnumpy())
 
@@ -284,14 +272,12 @@ def test_interleave_sessions():
         add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A)
         add_const_mod.run(x=micro_tensor_a)
         add_result = add_const_mod.get_output(0).asnumpy()
-        tvm.testing.assert_allclose(
-                add_result, np_tensor_a + 1.0)
+        tvm.testing.assert_allclose(add_result, np_tensor_a + 1.0)
     with sess_b:
         add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_B)
         add_const_mod.run(x=micro_tensor_b)
         add_result = add_const_mod.get_output(0).asnumpy()
-        tvm.testing.assert_allclose(
-                add_result, np_tensor_b + 1.0)
+        tvm.testing.assert_allclose(add_result, np_tensor_b + 1.0)
 
 
 def test_nested_sessions():
@@ -317,8 +303,7 @@ def test_nested_sessions():
         add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A)
         add_const_mod.run(x=micro_tensor_a)
         add_result = add_const_mod.get_output(0).asnumpy()
-        tvm.testing.assert_allclose(
-                add_result, np_tensor_a + 1.0)
+        tvm.testing.assert_allclose(add_result, np_tensor_a + 1.0)
 
 
 def test_inactive_session_use():
@@ -344,8 +329,7 @@ def test_inactive_session_use():
         # These objects belong to `sess_a`.
         add_const_mod.run(x=micro_tensor_a)
         add_result = add_const_mod.get_output(0).asnumpy()
-        tvm.testing.assert_allclose(
-                add_result, np_tensor_a + 1.0)
+        tvm.testing.assert_allclose(add_result, np_tensor_a + 1.0)
 
 
 # TODO add workspace alloc/free stress test
@@ -353,34 +337,34 @@ def test_inactive_session_use():
 if __name__ == "__main__":
     test_alloc()
     print()
-    print('finished alloc test')
-    input('[press enter to continue]')
+    print("finished alloc test")
+    input("[press enter to continue]")
     test_add()
     print()
-    print('finished add test')
-    input('[press enter to continue]')
+    print("finished add test")
+    input("[press enter to continue]")
     test_workspace_add()
     print()
-    print('finished workspace add test')
-    input('[press enter to continue]')
+    print("finished workspace add test")
+    input("[press enter to continue]")
     test_graph_runtime()
     print()
-    print('finished graph runtime test')
-    input('[press enter to continue]')
+    print("finished graph runtime test")
+    input("[press enter to continue]")
     test_conv2d()
     print()
-    print('finished conv2d test')
-    input('[press enter to continue]')
+    print("finished conv2d test")
+    input("[press enter to continue]")
     # disable for now as these are currently broken
-    #test_interleave_sessions()
-    #print()
-    #print('finished interleaved sessions test')
-    #input('[press enter to continue]')
+    # test_interleave_sessions()
+    # print()
+    # print('finished interleaved sessions test')
+    # input('[press enter to continue]')
     # test_nested_sessions()
-    #print()
-    #print('finished nested sessions test')
-    #input('[press enter to continue]')
+    # print()
+    # print('finished nested sessions test')
+    # input('[press enter to continue]')
     test_inactive_session_use()
     print()
-    print('finished use inactive session test')
-    input('[press enter to continue]')
+    print("finished use inactive session test")
+    input("[press enter to continue]")
index cc4818e..42b111b 100644 (file)
@@ -67,6 +67,7 @@ class Device:
     cross_compile : str
         Specify path to cross compiler to use when connecting a remote device from a non-arm platform.
     """
+
     connection_type = "local"
     host = "localhost"
     port = 9090
@@ -82,17 +83,15 @@ class Device:
     def _get_remote(cls):
         """Get a remote (or local) device to use for testing."""
         if cls.connection_type == "tracker":
-            device = request_remote(cls.device_key,
-                                    cls.host,
-                                    cls.port,
-                                    timeout=1000)
+            device = request_remote(cls.device_key, cls.host, cls.port, timeout=1000)
         elif cls.connection_type == "remote":
             device = rpc.connect(cls.host, cls.port)
         elif cls.connection_type == "local":
             device = rpc.LocalSession()
         else:
-            raise ValueError("connection_type in test_config.json should be one of: "
-                             "local, tracker, remote.")
+            raise ValueError(
+                "connection_type in test_config.json should be one of: " "local, tracker, remote."
+            )
 
         return device
 
@@ -106,7 +105,9 @@ class Device:
         location = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
         config_file = os.path.join(location, file_name)
         if not os.path.exists(config_file):
-            warnings.warn("Config file doesn't exist, resuming Arm Compute Library tests with default config.")
+            warnings.warn(
+                "Config file doesn't exist, resuming Arm Compute Library tests with default config."
+            )
             return
         with open(config_file, mode="r") as config:
             test_config = json.load(config)
@@ -121,6 +122,7 @@ class Device:
 
 def get_cpu_op_count(mod):
     """Traverse graph counting ops offloaded to TVM."""
+
     class Counter(tvm.relay.ExprVisitor):
         def __init__(self):
             super().__init__()
@@ -146,7 +148,10 @@ def skip_runtime_test():
 
     # Remote device is in use or ACL runtime not present
     # Note: Ensure that the device config has been loaded before this check
-    if not Device.connection_type != "local" and not arm_compute_lib.is_arm_compute_runtime_enabled():
+    if (
+        not Device.connection_type != "local"
+        and not arm_compute_lib.is_arm_compute_runtime_enabled()
+    ):
         print("Skip because runtime isn't present or a remote device isn't being used.")
         return True
 
@@ -166,22 +171,35 @@ def build_module(mod, target, params=None, enable_acl=True, tvm_ops=0, acl_parti
         if enable_acl:
             mod = arm_compute_lib.partition_for_arm_compute_lib(mod, params)
             tvm_op_count = get_cpu_op_count(mod)
-            assert tvm_op_count == tvm_ops, \
-                "Got {} TVM operators, expected {}".format(tvm_op_count, tvm_ops)
+            assert tvm_op_count == tvm_ops, "Got {} TVM operators, expected {}".format(
+                tvm_op_count, tvm_ops
+            )
             partition_count = 0
             for global_var in mod.get_global_vars():
                 if "arm_compute_lib" in global_var.name_hint:
                     partition_count += 1
 
-            assert acl_partitions == partition_count, \
-                "Got {} Arm Compute Library partitions, expected {}".format(
-                    partition_count, acl_partitions)
+            assert (
+                acl_partitions == partition_count
+            ), "Got {} Arm Compute Library partitions, expected {}".format(
+                partition_count, acl_partitions
+            )
         relay.backend.compile_engine.get().clear()
         return relay.build(mod, target=target, params=params)
 
 
-def build_and_run(mod, inputs, outputs, params, device, enable_acl=True, no_runs=1,
-                  tvm_ops=0, acl_partitions=1, config=None):
+def build_and_run(
+    mod,
+    inputs,
+    outputs,
+    params,
+    device,
+    enable_acl=True,
+    no_runs=1,
+    tvm_ops=0,
+    acl_partitions=1,
+    config=None,
+):
     """Build and run the relay module."""
     if config is None:
         config = {}
@@ -196,7 +214,7 @@ def build_and_run(mod, inputs, outputs, params, device, enable_acl=True, no_runs
         raise Exception(err_msg)
 
     lib = update_lib(lib, device.device, device.cross_compile)
-    gen_module = graph_runtime.GraphModule(lib['default'](device.device.cpu(0)))
+    gen_module = graph_runtime.GraphModule(lib["default"](device.device.cpu(0)))
     gen_module.set_input(**inputs)
     out = []
     for _ in range(no_runs):
@@ -225,18 +243,20 @@ def verify(answers, atol, rtol, verify_saturation=False, config=None):
         config = {}
 
     if len(answers) < 2:
-        raise RuntimeError(
-            f"No results to compare: expected at least two, found {len(answers)}")
+        raise RuntimeError(f"No results to compare: expected at least two, found {len(answers)}")
     for answer in zip_longest(*answers):
         for outs in combinations(answer, 2):
             try:
                 if verify_saturation:
-                    assert np.count_nonzero(outs[0].asnumpy() == 255) < 0.25 * outs[0].asnumpy().size, \
-                        "Output is saturated: {}".format(outs[0])
-                    assert np.count_nonzero(outs[0].asnumpy() == 0) < 0.25 * outs[0].asnumpy().size, \
-                        "Output is saturated: {}".format(outs[0])
+                    assert (
+                        np.count_nonzero(outs[0].asnumpy() == 255) < 0.25 * outs[0].asnumpy().size
+                    ), "Output is saturated: {}".format(outs[0])
+                    assert (
+                        np.count_nonzero(outs[0].asnumpy() == 0) < 0.25 * outs[0].asnumpy().size
+                    ), "Output is saturated: {}".format(outs[0])
                 tvm.testing.assert_allclose(
-                   outs[0].asnumpy(), outs[1].asnumpy(), rtol=rtol, atol=atol)
+                    outs[0].asnumpy(), outs[1].asnumpy(), rtol=rtol, atol=atol
+                )
             except AssertionError as e:
                 err_msg = "Results not within the acceptable tolerance.\n"
                 if config:
@@ -247,19 +267,25 @@ def verify(answers, atol, rtol, verify_saturation=False, config=None):
 
 def extract_acl_modules(module):
     """Get the ACL module(s) from llvm module."""
-    return list(filter(lambda mod: mod.type_key == "arm_compute_lib",
-                       module.get_lib().imported_modules))
+    return list(
+        filter(lambda mod: mod.type_key == "arm_compute_lib", module.get_lib().imported_modules)
+    )
 
 
-def verify_codegen(module, known_good_codegen, num_acl_modules,
-                   target="llvm -mtriple=aarch64-linux-gnu -mattr=+neon"):
+def verify_codegen(
+    module,
+    known_good_codegen,
+    num_acl_modules,
+    target="llvm -mtriple=aarch64-linux-gnu -mattr=+neon",
+):
     """Check acl codegen against a known good output."""
     module = build_module(module, target)
     acl_modules = extract_acl_modules(module)
 
-    assert len(acl_modules) == num_acl_modules, \
-        f"The number of Arm Compute Library modules produced ({len(acl_modules)}) does not " \
+    assert len(acl_modules) == num_acl_modules, (
+        f"The number of Arm Compute Library modules produced ({len(acl_modules)}) does not "
         f"match the expected value ({num_acl_modules})."
+    )
 
     for mod in acl_modules:
         source = mod.get_source("json")
@@ -271,10 +297,11 @@ def verify_codegen(module, known_good_codegen, num_acl_modules,
         codegen_str = json.dumps(codegen, sort_keys=True, indent=2)
         known_good_codegen_str = json.dumps(known_good_codegen, sort_keys=True, indent=2)
 
-        assert codegen_str == known_good_codegen_str, \
-            f"The JSON produced by codegen does not match the expected result. \n" \
-            f"Actual={codegen_str} \n" \
+        assert codegen_str == known_good_codegen_str, (
+            f"The JSON produced by codegen does not match the expected result. \n"
+            f"Actual={codegen_str} \n"
             f"Expected={known_good_codegen_str}"
+        )
 
 
 def generate_trials(space, r_factor=3):
index 37575cc..4496a2a 100644 (file)
@@ -21,14 +21,32 @@ import numpy as np
 import tvm
 from tvm import relay
 
-from .infrastructure import skip_runtime_test, skip_codegen_test, build_and_run, \
-    verify, verify_codegen, generate_trials
+from .infrastructure import (
+    skip_runtime_test,
+    skip_codegen_test,
+    build_and_run,
+    verify,
+    verify_codegen,
+    generate_trials,
+)
 from .infrastructure import Device
 
 
-def _get_model(shape, kernel_h, kernel_w, padding, strides,
-               dilation, groups, dtype, channels, var_names,
-               has_bias=False, has_activation=False, has_pad=False):
+def _get_model(
+    shape,
+    kernel_h,
+    kernel_w,
+    padding,
+    strides,
+    dilation,
+    groups,
+    dtype,
+    channels,
+    var_names,
+    has_bias=False,
+    has_activation=False,
+    has_pad=False,
+):
     """Return a model and any parameters it may have"""
     a = relay.var(next(var_names), shape=shape, dtype=dtype)
     if has_pad:
@@ -38,8 +56,7 @@ def _get_model(shape, kernel_h, kernel_w, padding, strides,
     else:
         if len(padding) == 2:
             padding = (padding[0], padding[1], padding[0], padding[1])
-        shape = (shape[0], shape[1] + padding[0] * 2,
-                 shape[2] + padding[1] * 2, shape[3])
+        shape = (shape[0], shape[1] + padding[0] * 2, shape[2] + padding[1] * 2, shape[3])
     weight_shape = (kernel_h, kernel_w, shape[3] // groups, channels)
     w = tvm.nd.array(np.random.uniform(-128, 127, weight_shape).astype(dtype))
     weights = relay.const(w, dtype)
@@ -54,7 +71,7 @@ def _get_model(shape, kernel_h, kernel_w, padding, strides,
         padding=padding,
         groups=groups,
         channels=channels,
-        out_dtype=dtype
+        out_dtype=dtype,
     )
     params = {"w": w}
     if has_bias:
@@ -70,26 +87,43 @@ def _get_model(shape, kernel_h, kernel_w, padding, strides,
 def _get_qnn_params(input_zp, input_sc, kernel_zp, kernel_sc, kernel_h, kernel_w, channels):
     """Get output qnn parameters given input and kernel parameters."""
     input_max = input_sc * (255 - input_zp)
-    input_min = - input_sc * input_zp
+    input_min = -input_sc * input_zp
     kernel_max = kernel_sc * (255 - kernel_zp)
-    kernel_min = - kernel_sc * kernel_zp
-    output_limits = [kernel_max * kernel_h * kernel_w * channels * input_max,
-                     kernel_min * kernel_h * kernel_w * channels * input_max,
-                     kernel_min * kernel_h * kernel_w * channels * input_min,
-                     kernel_max * kernel_h * kernel_w * channels * input_min]
+    kernel_min = -kernel_sc * kernel_zp
+    output_limits = [
+        kernel_max * kernel_h * kernel_w * channels * input_max,
+        kernel_min * kernel_h * kernel_w * channels * input_max,
+        kernel_min * kernel_h * kernel_w * channels * input_min,
+        kernel_max * kernel_h * kernel_w * channels * input_min,
+    ]
     output_max = max(output_limits)
     output_min = min(output_limits)
     output_sc = (output_max - output_min) / 255
-    output_zp = - int(output_min / output_sc)
+    output_zp = -int(output_min / output_sc)
     return output_zp, output_sc
 
 
-def _get_qnn_model(shape, kernel_h, kernel_w,
-                   padding, strides, dilation, groups, dtype,
-                   channels, input_zp, input_sc,
-                   kernel_zp, kernel_sc, output_zp,
-                   output_sc, var_names, has_bias=False,
-                   has_activation=False, has_pad=False):
+def _get_qnn_model(
+    shape,
+    kernel_h,
+    kernel_w,
+    padding,
+    strides,
+    dilation,
+    groups,
+    dtype,
+    channels,
+    input_zp,
+    input_sc,
+    kernel_zp,
+    kernel_sc,
+    output_zp,
+    output_sc,
+    var_names,
+    has_bias=False,
+    has_activation=False,
+    has_pad=False,
+):
     """Return a model and any parameters it may have."""
     a = relay.var(next(var_names), shape=shape, dtype=dtype)
     if has_pad:
@@ -99,8 +133,7 @@ def _get_qnn_model(shape, kernel_h, kernel_w,
     else:
         if len(padding) == 2:
             padding = (padding[0], padding[1], padding[0], padding[1])
-        shape = (shape[0], shape[1] + padding[0] * 2,
-                 shape[2] + padding[1] * 2, shape[3])
+        shape = (shape[0], shape[1] + padding[0] * 2, shape[2] + padding[1] * 2, shape[3])
     weight_shape = (kernel_h, kernel_w, shape[3] // groups, channels)
     w = tvm.nd.array(np.random.uniform(0, 255, weight_shape).astype(dtype))
     weights = relay.const(w, dtype)
@@ -119,30 +152,40 @@ def _get_qnn_model(shape, kernel_h, kernel_w,
         padding=padding,
         groups=groups,
         channels=channels,
-        out_dtype="int32"
+        out_dtype="int32",
     )
     params = {"w": w}
     if has_bias:
         b = tvm.nd.array(np.random.uniform(0, 255, weight_shape[3]).astype("int32"))
         biasc = relay.const(b, "int32")
         out = relay.nn.bias_add(out, biasc, axis=3)
-        params['b'] = b
+        params["b"] = b
     if has_activation:
         out = relay.nn.relu(out)
     req = relay.qnn.op.requantize(
         out,
-        relay.const(input_sc * kernel_sc, 'float32'),  # input scale
-        relay.const(0, 'int32'),  # input zero point
-        relay.const(output_sc, 'float32'),  # output scale
-        relay.const(output_zp, 'int32'),  # output zero point
-        out_dtype="uint8"
+        relay.const(input_sc * kernel_sc, "float32"),  # input scale
+        relay.const(0, "int32"),  # input zero point
+        relay.const(output_sc, "float32"),  # output scale
+        relay.const(output_zp, "int32"),  # output zero point
+        out_dtype="uint8",
     )
     return req, params
 
 
-def _get_expected_codegen(shape, kernel_h, kernel_w, padding, strides,
-                          dilation, groups, dtype, channels,
-                          has_bias=False, has_activation=False):
+def _get_expected_codegen(
+    shape,
+    kernel_h,
+    kernel_w,
+    padding,
+    strides,
+    dilation,
+    groups,
+    dtype,
+    channels,
+    has_bias=False,
+    has_activation=False,
+):
     if len(padding) == 2:
         padding = (padding[0], padding[1], padding[0], padding[1])
     weight_shape = (channels, kernel_h, kernel_w, shape[3] // groups)
@@ -168,62 +211,51 @@ def _get_expected_codegen(shape, kernel_h, kernel_w, padding, strides,
             "shape": [[list(output_shape)]],
             "dtype": [[dtype]],
             "padding": [[str(p) for p in padding]],
-            "strides": [[str(s) for s in strides]]
+            "strides": [[str(s) for s in strides]],
         },
     }
 
     if has_activation:
         node["attrs"]["activation_type"] = [["relu"]]
 
-    inputs = [{
-        "op": "input",
-        "name": "",
-        "attrs": {
-            "shape": [[list(shape)]],
-            "dtype": [[str(dtype)]]
-        }}, {
-        "op": "const",
-        "name": "",
-        "attrs": {
-            "shape": [[list(weight_shape)]],
-            "dtype": [[str(dtype)]]
-        }}]
+    inputs = [
+        {"op": "input", "name": "", "attrs": {"shape": [[list(shape)]], "dtype": [[str(dtype)]]}},
+        {
+            "op": "const",
+            "name": "",
+            "attrs": {"shape": [[list(weight_shape)]], "dtype": [[str(dtype)]]},
+        },
+    ]
 
     # qnn.conv2d params, input and kernel
     if dtype == "uint8":
         node["name"] = "qnn.conv2d"
         for param_dtype in ["int32", "float32"]:
             for _ in range(2):
-                inputs.append({
-                    "op": "const",
-                    "name": "",
-                    "attrs": {
-                        "shape": [[[]]],
-                        "dtype": [[param_dtype]]
+                inputs.append(
+                    {
+                        "op": "const",
+                        "name": "",
+                        "attrs": {"shape": [[[]]], "dtype": [[param_dtype]]},
                     }
-                })
+                )
 
     if has_bias:
         bias_dtype = "int32" if dtype == "uint8" else "float32"
-        inputs.append({
-            "op": "const",
-            "name": "",
-            "attrs": {
-                "shape": [[[weight_shape[0]]]],
-                "dtype": [[bias_dtype]]}
-        })
+        inputs.append(
+            {
+                "op": "const",
+                "name": "",
+                "attrs": {"shape": [[[weight_shape[0]]]], "dtype": [[bias_dtype]]},
+            }
+        )
 
     # qnn.conv2d params, output
     if dtype == "uint8":
         for param_dtype in ["float32", "int32"]:
-            inputs.append({
-                "op": "const",
-                "name": "",
-                "attrs": {
-                    "shape": [[[]]],
-                    "dtype": [[param_dtype]]
-                }
-            })
+            inputs.append(
+                {"op": "const", "name": "", "attrs": {"shape": [[[]]], "dtype": [[param_dtype]]}}
+            )
 
     input_idx = 0
     for _ in range(len(inputs)):
@@ -251,11 +283,17 @@ def test_conv2d():
     out_channels = [4, 7, 16]
     input_shapes = [(10, 10, 14), (12, 15, 16), (20, 20, 20)]
     # composite operator (pad, bias, activation)
-    composite = [(False, False, False), (False, True, False), (False, False, True),
-                 (False, True, True), (True, False, False)]
+    composite = [
+        (False, False, False),
+        (False, True, False),
+        (False, False, True),
+        (False, True, True),
+        (True, False, False),
+    ]
     dtype = "float32"
-    trials = generate_trials([kernel_hs, kernel_ws, pad, strides, dilation, out_channels,
-                              input_shapes, composite], 3)
+    trials = generate_trials(
+        [kernel_hs, kernel_ws, pad, strides, dilation, out_channels, input_shapes, composite], 3
+    )
 
     for kernel_h, kernel_w, pad, stride, dilation, out_channels, input_shapes, composite in trials:
         groups = 1
@@ -265,16 +303,23 @@ def test_conv2d():
             "a": tvm.nd.array(np.random.uniform(-128, 127, shape).astype(dtype)),
         }
 
-        func, params = _get_model(shape, kernel_h, kernel_w,
-                                  pad, stride, dilation, groups,
-                                  dtype, out_channels, iter(inputs),
-                                  has_pad=composite[0],
-                                  has_bias=composite[1],
-                                  has_activation=composite[2])
+        func, params = _get_model(
+            shape,
+            kernel_h,
+            kernel_w,
+            pad,
+            stride,
+            dilation,
+            groups,
+            dtype,
+            out_channels,
+            iter(inputs),
+            has_pad=composite[0],
+            has_bias=composite[1],
+            has_activation=composite[2],
+        )
         for acl in [False, True]:
-            outputs.append(build_and_run(func, inputs, 1,
-                                         params, device,
-                                         enable_acl=acl)[0])
+            outputs.append(build_and_run(func, inputs, 1, params, device, enable_acl=acl)[0])
 
         config = {
             "shape": shape,
@@ -284,7 +329,7 @@ def test_conv2d():
             "stride": stride,
             "dilation": dilation,
             "out channels": out_channels,
-            "composite operators (pad, bias, activation)": composite
+            "composite operators (pad, bias, activation)": composite,
         }
         verify(outputs, atol=0.002, rtol=0.01, config=config)
 
@@ -303,11 +348,17 @@ def test_codegen_conv2d():
     out_channels = [4, 7, 16]
     input_shapes = [(10, 10, 14), (12, 15, 16), (20, 20, 20)]
     # composite operator (pad, bias, activation)
-    composite = [(False, False, False), (False, True, False), (False, False, True),
-                 (False, True, True), (True, False, False)]
+    composite = [
+        (False, False, False),
+        (False, True, False),
+        (False, False, True),
+        (False, True, True),
+        (True, False, False),
+    ]
     dtype = "float32"
-    trials = generate_trials([kernel_hs, kernel_ws, pad, strides, dilation, out_channels,
-                              input_shapes, composite], 3)
+    trials = generate_trials(
+        [kernel_hs, kernel_ws, pad, strides, dilation, out_channels, input_shapes, composite], 3
+    )
 
     for kernel_h, kernel_w, pad, stride, dilation, out_channels, input_shapes, composite in trials:
         groups = 1
@@ -316,13 +367,16 @@ def test_codegen_conv2d():
 
         args = (shape, kernel_h, kernel_w, pad, stride, dilation, groups, dtype, out_channels)
 
-        func, params = _get_model(*args, var_names=iter(inputs),
-                                  has_pad=composite[0],
-                                  has_bias=composite[1],
-                                  has_activation=composite[2])
-        exp_codegen = _get_expected_codegen(*args,
-                                            has_bias=composite[1],
-                                            has_activation=composite[2])
+        func, params = _get_model(
+            *args,
+            var_names=iter(inputs),
+            has_pad=composite[0],
+            has_bias=composite[1],
+            has_activation=composite[2],
+        )
+        exp_codegen = _get_expected_codegen(
+            *args, has_bias=composite[1], has_activation=composite[2]
+        )
         verify_codegen(func, exp_codegen, 1)
 
 
@@ -343,42 +397,55 @@ def test_qnn_conv2d():
     out_channels = [4, 7, 16]
     input_shapes = [(10, 10, 14), (12, 15, 16), (20, 20, 20)]
     # composite operator (pad, bias, activation)
-    composite = [(False, False, False), (False, True, False), (False, False, True),
-                 (False, True, True), (True, False, False)]
+    composite = [
+        (False, False, False),
+        (False, True, False),
+        (False, False, True),
+        (False, True, True),
+        (True, False, False),
+    ]
     dtype = "uint8"
-    trials = generate_trials([kernel_hs, kernel_ws, pad, strides, dilation, out_channels,
-                              input_shapes, composite], 3)
+    trials = generate_trials(
+        [kernel_hs, kernel_ws, pad, strides, dilation, out_channels, input_shapes, composite], 3
+    )
 
     for kernel_h, kernel_w, pad, stride, dilation, out_channels, input_shapes, composite in trials:
         groups = 1
         shape = (1, *input_shapes)
         outputs = []
-        inputs = {
-            "a": tvm.nd.array(np.random.uniform(0, 255, shape).astype(dtype))
-        }
+        inputs = {"a": tvm.nd.array(np.random.uniform(0, 255, shape).astype(dtype))}
 
         input_zp = 100
         input_sc = 0.5
         kernel_zp = 25
         kernel_sc = 0.03
-        output_zp, output_sc = _get_qnn_params(input_zp, input_sc,
-                                               kernel_zp, kernel_sc,
-                                               kernel_h, kernel_w, shape[3])
-
-        func, params = _get_qnn_model(shape, kernel_h, kernel_w,
-                                      pad, stride, dilation, groups,
-                                      dtype, out_channels,
-                                      input_zp, input_sc,
-                                      kernel_zp, kernel_sc,
-                                      output_zp, output_sc,
-                                      iter(inputs),
-                                      has_pad=composite[0],
-                                      has_bias=composite[1],
-                                      has_activation=composite[2])
+        output_zp, output_sc = _get_qnn_params(
+            input_zp, input_sc, kernel_zp, kernel_sc, kernel_h, kernel_w, shape[3]
+        )
+
+        func, params = _get_qnn_model(
+            shape,
+            kernel_h,
+            kernel_w,
+            pad,
+            stride,
+            dilation,
+            groups,
+            dtype,
+            out_channels,
+            input_zp,
+            input_sc,
+            kernel_zp,
+            kernel_sc,
+            output_zp,
+            output_sc,
+            iter(inputs),
+            has_pad=composite[0],
+            has_bias=composite[1],
+            has_activation=composite[2],
+        )
         for acl in [False, True]:
-            outputs.append(build_and_run(func, inputs, 1,
-                                         params, device,
-                                         enable_acl=acl)[0])
+            outputs.append(build_and_run(func, inputs, 1, params, device, enable_acl=acl)[0])
 
         config = {
             "shape": shape,
@@ -394,7 +461,7 @@ def test_qnn_conv2d():
             "kernel scale": kernel_sc,
             "kernel zero point": kernel_zp,
             "output scale": output_sc,
-            "output zero point": output_zp
+            "output zero point": output_zp,
         }
         verify(outputs, atol=1, rtol=0, config=config, verify_saturation=True)
 
@@ -411,11 +478,17 @@ def test_codegen_qnn_conv2d():
     out_channels = [4, 7, 16]
     input_shapes = [(10, 10, 14), (12, 15, 16), (20, 20, 20)]
     # composite operator (pad, bias, activation)
-    composite = [(False, False, False), (False, True, False), (False, False, True),
-                 (False, True, True), (True, False, False)]
+    composite = [
+        (False, False, False),
+        (False, True, False),
+        (False, False, True),
+        (False, True, True),
+        (True, False, False),
+    ]
     dtype = "uint8"
-    trials = generate_trials([kernel_hs, kernel_ws, pad, strides, dilation, out_channels,
-                              input_shapes, composite], 3)
+    trials = generate_trials(
+        [kernel_hs, kernel_ws, pad, strides, dilation, out_channels, input_shapes, composite], 3
+    )
 
     for kernel_h, kernel_w, pad, stride, dilation, out_channels, input_shapes, composite in trials:
         groups = 1
@@ -426,23 +499,28 @@ def test_codegen_qnn_conv2d():
         input_sc = 0.5
         kernel_zp = 25
         kernel_sc = 0.03
-        output_zp, output_sc = _get_qnn_params(input_zp, input_sc,
-                                               kernel_zp, kernel_sc,
-                                               kernel_h, kernel_w, shape[3])
+        output_zp, output_sc = _get_qnn_params(
+            input_zp, input_sc, kernel_zp, kernel_sc, kernel_h, kernel_w, shape[3]
+        )
 
         args = (shape, kernel_h, kernel_w, pad, stride, dilation, groups, dtype, out_channels)
 
-        func, params = _get_qnn_model(*args,
-                                      input_zp=input_zp, input_sc=input_sc,
-                                      kernel_zp=kernel_zp, kernel_sc=kernel_sc,
-                                      output_zp=output_zp, output_sc=output_sc,
-                                      var_names=iter(inputs),
-                                      has_pad=composite[0],
-                                      has_bias=composite[1],
-                                      has_activation=composite[2])
-        exp_codegen = _get_expected_codegen(*args,
-                                            has_bias=composite[1],
-                                            has_activation=composite[2])
+        func, params = _get_qnn_model(
+            *args,
+            input_zp=input_zp,
+            input_sc=input_sc,
+            kernel_zp=kernel_zp,
+            kernel_sc=kernel_sc,
+            output_zp=output_zp,
+            output_sc=output_sc,
+            var_names=iter(inputs),
+            has_pad=composite[0],
+            has_bias=composite[1],
+            has_activation=composite[2],
+        )
+        exp_codegen = _get_expected_codegen(
+            *args, has_bias=composite[1], has_activation=composite[2]
+        )
         verify_codegen(func, exp_codegen, 1)
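
For reference, the output quantisation arithmetic used by _get_qnn_params in this file can be reproduced standalone. The sketch below is a minimal copy of that derivation, called with the same constants test_qnn_conv2d uses (input_zp=100, input_sc=0.5, kernel_zp=25, kernel_sc=0.03) and an illustrative 3x3 kernel over 14 channels; the helper name qnn_output_params is ours, not part of the test module.

def qnn_output_params(input_zp, input_sc, kernel_zp, kernel_sc, kernel_h, kernel_w, channels):
    # Worst-case dequantised input and kernel values for uint8 data.
    input_max = input_sc * (255 - input_zp)
    input_min = -input_sc * input_zp
    kernel_max = kernel_sc * (255 - kernel_zp)
    kernel_min = -kernel_sc * kernel_zp
    # Each output element accumulates kernel_h * kernel_w * channels products.
    acc = kernel_h * kernel_w * channels
    limits = [
        kernel_max * acc * input_max,
        kernel_min * acc * input_max,
        kernel_min * acc * input_min,
        kernel_max * acc * input_min,
    ]
    # Spread the worst-case accumulator range over the 256 uint8 levels.
    output_sc = (max(limits) - min(limits)) / 255
    output_zp = -int(min(limits) / output_sc)
    return output_zp, output_sc


zp, sc = qnn_output_params(100, 0.5, 25, 0.03, 3, 3, 14)
print(zp, sc)  # roughly 100 and 434.7 for this configuration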
 
 
index 45e2eb7..8a3632a 100644
@@ -21,53 +21,65 @@ import numpy as np
 import tvm
 from tvm import relay
 
-from .infrastructure import Device, skip_runtime_test, skip_codegen_test, \
-    build_and_run, verify, verify_codegen, generate_trials
-
-
-def _get_model(shape, weight_shape, units, dtype, var_names,
-               has_bias=False):
+from .infrastructure import (
+    Device,
+    skip_runtime_test,
+    skip_codegen_test,
+    build_and_run,
+    verify,
+    verify_codegen,
+    generate_trials,
+)
+
+
+def _get_model(shape, weight_shape, units, dtype, var_names, has_bias=False):
     """Return a model and any parameters it may have"""
     a = relay.var(next(var_names), shape=shape, dtype=dtype)
     w = tvm.nd.array(np.random.uniform(-128, 127, weight_shape).astype(dtype))
     weights = relay.const(w, dtype)
-    out = relay.nn.dense(
-        a,
-        weights,
-        units=units,
-        out_dtype=dtype
-    )
+    out = relay.nn.dense(a, weights, units=units, out_dtype=dtype)
     params = {"w": w}
     if has_bias:
         b = tvm.nd.array(np.random.randint(-128, 127, weight_shape[0]).astype(dtype))
         biasc = relay.const(b, dtype)
         out = relay.nn.bias_add(out, biasc)
-        params['b'] = b
+        params["b"] = b
     return out, params
 
 
-def _get_qnn_params(input_zp, input_sc, kernel_zp, kernel_sc,
-                    kernel_h, kernel_w):
+def _get_qnn_params(input_zp, input_sc, kernel_zp, kernel_sc, kernel_h, kernel_w):
     """Get output qnn parameters given input and kernel parameters."""
     input_max = input_sc * (255 - input_zp)
-    input_min = - input_sc * input_zp
+    input_min = -input_sc * input_zp
     kernel_max = kernel_sc * (255 - kernel_zp)
-    kernel_min = - kernel_sc * kernel_zp
-    output_limits = [kernel_max * kernel_h * kernel_w * input_max,
-                     kernel_min * kernel_h * kernel_w * input_max,
-                     kernel_min * kernel_h * kernel_w * input_min,
-                     kernel_max * kernel_h * kernel_w * input_min]
+    kernel_min = -kernel_sc * kernel_zp
+    output_limits = [
+        kernel_max * kernel_h * kernel_w * input_max,
+        kernel_min * kernel_h * kernel_w * input_max,
+        kernel_min * kernel_h * kernel_w * input_min,
+        kernel_max * kernel_h * kernel_w * input_min,
+    ]
     output_max = max(output_limits)
     output_min = min(output_limits)
     output_sc = (output_max - output_min) / 255
-    output_zp = - int(output_min / output_sc)
+    output_zp = -int(output_min / output_sc)
     return output_zp, output_sc
 
 
-def _get_qnn_model(shape, weight_shape, units, dtype,
-                   input_zp, input_sc, kernel_zp,
-                   kernel_sc, output_zp, output_sc, var_names,
-                   has_bias=False):
+def _get_qnn_model(
+    shape,
+    weight_shape,
+    units,
+    dtype,
+    input_zp,
+    input_sc,
+    kernel_zp,
+    kernel_sc,
+    output_zp,
+    output_sc,
+    var_names,
+    has_bias=False,
+):
     a = relay.var(next(var_names), shape=shape, dtype=dtype)
     w = tvm.nd.array(np.random.uniform(-128, 127, weight_shape).astype(dtype))
     weights = relay.const(w, dtype)
@@ -79,27 +91,26 @@ def _get_qnn_model(shape, weight_shape, units, dtype,
         kernel_zero_point=relay.const(kernel_zp, "int32"),
         input_scale=relay.const(input_sc, "float32"),
         kernel_scale=relay.const(kernel_sc, "float32"),
-        out_dtype="int32"
+        out_dtype="int32",
     )
     params = {"w": w}
     if has_bias:
         b = tvm.nd.array(np.random.randint(0, 255, weight_shape[0]).astype("int32"))
         biasc = relay.const(b, "int32")
         out = relay.nn.bias_add(out, biasc)
-        params['b'] = b
+        params["b"] = b
     out = relay.qnn.op.requantize(
         out,
-        relay.const(input_sc * kernel_sc, 'float32'),  # input scale
-        relay.const(input_zp * kernel_zp, 'int32'),  # input zero point
-        relay.const(output_sc, 'float32'),  # output scale
-        relay.const(output_zp, 'int32'),  # output zero point
-        out_dtype="uint8"
+        relay.const(input_sc * kernel_sc, "float32"),  # input scale
+        relay.const(input_zp * kernel_zp, "int32"),  # input zero point
+        relay.const(output_sc, "float32"),  # output scale
+        relay.const(output_zp, "int32"),  # output zero point
+        out_dtype="uint8",
     )
     return out, params
 
 
-def _get_expected_codegen(shape, weight_shape, units, dtype,
-                          has_bias=False):
+def _get_expected_codegen(shape, weight_shape, units, dtype, has_bias=False):
     output_shape = (shape[0], units)
     out_dtype = "int32" if dtype == "uint8" else "float32"
 
@@ -112,59 +123,48 @@ def _get_expected_codegen(shape, weight_shape, units, dtype,
             "out_dtype": [[out_dtype]],
             "shape": [[list(output_shape)]],
             "dtype": [[dtype]],
-            "units": [[str(units)]]
-        }
+            "units": [[str(units)]],
+        },
     }
 
-    inputs = [{
-        "op": "input",
-        "name": "",
-        "attrs": {
-            "shape": [[list(shape)]],
-            "dtype": [[str(dtype)]]
-        }}, {
-        "op": "const",
-        "name": "",
-        "attrs": {
-            "shape": [[list(weight_shape)]],
-            "dtype": [[str(dtype)]]
-        }}]
+    inputs = [
+        {"op": "input", "name": "", "attrs": {"shape": [[list(shape)]], "dtype": [[str(dtype)]]}},
+        {
+            "op": "const",
+            "name": "",
+            "attrs": {"shape": [[list(weight_shape)]], "dtype": [[str(dtype)]]},
+        },
+    ]
 
     # qnn.dense params, input and kernel
     if dtype == "uint8":
         node["name"] = "qnn.dense"
         for param_dtype in ["int32", "float32"]:
             for _ in range(2):
-                inputs.append({
-                    "op": "const",
-                    "name": "",
-                    "attrs": {
-                        "shape": [[[]]],
-                        "dtype": [[param_dtype]]
+                inputs.append(
+                    {
+                        "op": "const",
+                        "name": "",
+                        "attrs": {"shape": [[[]]], "dtype": [[param_dtype]]},
                     }
-                })
+                )
 
     if has_bias:
         bias_dtype = "int32" if dtype == "uint8" else "float32"
-        inputs.append({
-            "op": "const",
-            "name": "",
-            "attrs": {
-                "shape": [[[weight_shape[0]]]],
-                "dtype": [[bias_dtype]]}
-        })
+        inputs.append(
+            {
+                "op": "const",
+                "name": "",
+                "attrs": {"shape": [[[weight_shape[0]]]], "dtype": [[bias_dtype]]},
+            }
+        )
 
     # qnn.dense params, output
     if dtype == "uint8":
         for param_dtype in ["float32", "int32"]:
-            inputs.append({
-                "op": "const",
-                "name": "",
-                "attrs": {
-                    "shape": [[[]]],
-                    "dtype": [[param_dtype]]
-                }
-            })
+            inputs.append(
+                {"op": "const", "name": "", "attrs": {"shape": [[[]]], "dtype": [[param_dtype]]}}
+            )
 
     input_idx = 0
     for _ in range(len(inputs)):
@@ -191,21 +191,19 @@ def test_dense():
 
     for dtype, (shape, weight_shape, units), composite in trials:
         outputs = []
-        inputs = {
-            "a": tvm.nd.array(np.random.uniform(-128, 127, shape).astype(dtype))
-        }
-        func, params = _get_model(shape, weight_shape, units, dtype, var_names=iter(inputs),
-                                  has_bias=composite)
+        inputs = {"a": tvm.nd.array(np.random.uniform(-128, 127, shape).astype(dtype))}
+        func, params = _get_model(
+            shape, weight_shape, units, dtype, var_names=iter(inputs), has_bias=composite
+        )
         for acl in [False, True]:
-            outputs.append(build_and_run(func, inputs, 1, params,
-                                         device, enable_acl=acl)[0])
+            outputs.append(build_and_run(func, inputs, 1, params, device, enable_acl=acl)[0])
 
         config = {
             "shape": shape,
             "weight_shape": weight_shape,
             "units": units,
             "dtype": dtype,
-            "composite operators (bias)": composite
+            "composite operators (bias)": composite,
         }
         verify(outputs, atol=0.001, rtol=0.01, config=config)
 
@@ -226,8 +224,7 @@ def test_codegen_dense():
 
         args = (shape, weight_shape, units, dtype)
 
-        func, params = _get_model(*args, var_names=iter(inputs),
-                                  has_bias=composite)
+        func, params = _get_model(*args, var_names=iter(inputs), has_bias=composite)
         exp_codegen = _get_expected_codegen(*args, has_bias=composite)
         verify_codegen(func, exp_codegen, 1)
 
@@ -248,25 +245,32 @@ def test_qnn_dense():
 
     for dtype, (shape, weight_shape, units), composite in trials:
         outputs = []
-        inputs = {
-            "a": tvm.nd.array(np.random.uniform(0, 255, shape).astype(dtype))
-        }
+        inputs = {"a": tvm.nd.array(np.random.uniform(0, 255, shape).astype(dtype))}
         input_zp = 100
         input_sc = 0.5
         kernel_zp = 50
         kernel_sc = 0.03
-        output_zp, output_sc = _get_qnn_params(input_zp, input_sc,
-                                               kernel_zp, kernel_sc,
-                                               weight_shape[0], weight_shape[1])
-
-        func, params = _get_qnn_model(shape, weight_shape, units, dtype,
-                                      input_zp, input_sc, kernel_zp,
-                                      kernel_sc, output_zp, output_sc,
-                                      var_names=iter(inputs), has_bias=composite)
+        output_zp, output_sc = _get_qnn_params(
+            input_zp, input_sc, kernel_zp, kernel_sc, weight_shape[0], weight_shape[1]
+        )
+
+        func, params = _get_qnn_model(
+            shape,
+            weight_shape,
+            units,
+            dtype,
+            input_zp,
+            input_sc,
+            kernel_zp,
+            kernel_sc,
+            output_zp,
+            output_sc,
+            var_names=iter(inputs),
+            has_bias=composite,
+        )
 
         for acl in [False, True]:
-            outputs.append(build_and_run(func, inputs, 1, params,
-                                         device, enable_acl=acl)[0])
+            outputs.append(build_and_run(func, inputs, 1, params, device, enable_acl=acl)[0])
 
         config = {
             "shape": shape,
@@ -279,7 +283,7 @@ def test_qnn_dense():
             "kernel scale": kernel_sc,
             "kernel zero point": kernel_zp,
             "output scale": output_sc,
-            "output zero point": output_zp
+            "output zero point": output_zp,
         }
         verify(outputs, atol=1, rtol=0, config=config, verify_saturation=True)
 
@@ -303,15 +307,21 @@ def test_codegen_qnn_dense():
         input_sc = 0.5
         kernel_zp = 25
         kernel_sc = 0.03
-        output_zp, output_sc = _get_qnn_params(input_zp, input_sc,
-                                               kernel_zp, kernel_sc,
-                                               weight_shape[0], weight_shape[1])
-
-        func, params = _get_qnn_model(*args, var_names=iter(inputs),
-                                      input_zp=input_zp, input_sc=input_sc,
-                                      kernel_zp=kernel_zp, kernel_sc=kernel_sc,
-                                      output_zp=output_zp, output_sc=output_sc,
-                                      has_bias=composite)
+        output_zp, output_sc = _get_qnn_params(
+            input_zp, input_sc, kernel_zp, kernel_sc, weight_shape[0], weight_shape[1]
+        )
+
+        func, params = _get_qnn_model(
+            *args,
+            var_names=iter(inputs),
+            input_zp=input_zp,
+            input_sc=input_sc,
+            kernel_zp=kernel_zp,
+            kernel_sc=kernel_sc,
+            output_zp=output_zp,
+            output_sc=output_sc,
+            has_bias=composite,
+        )
         exp_codegen = _get_expected_codegen(*args, has_bias=composite)
         verify_codegen(func, exp_codegen, 1)
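
For reference, the float32 dense graph that _get_model assembles in this file can be sketched on its own. The shapes below (batch 1, 512 inputs, 10 units) are illustrative only, and the expression is simply wrapped in a Relay function and printed rather than compiled.

import numpy as np
import tvm
from tvm import relay

dtype = "float32"
data = relay.var("a", shape=(1, 512), dtype=dtype)
# Dense weights are (units, input_dim); the bias length matches the number of units.
w = tvm.nd.array(np.random.uniform(-128, 127, (10, 512)).astype(dtype))
b = tvm.nd.array(np.random.randint(-128, 127, 10).astype(dtype))

out = relay.nn.dense(data, relay.const(w, dtype), units=10, out_dtype=dtype)
out = relay.nn.bias_add(out, relay.const(b, dtype))

func = relay.Function([data], out)
print(func)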
 
index e1bb83b..2526a58 100644
@@ -38,10 +38,18 @@ def _build_and_run_network(mod, params, inputs, device, tvm_ops, acl_partitions,
 
     outputs = []
     for acl in [False, True]:
-        outputs.append(build_and_run(mod, data, 1, params,
-                                     device, enable_acl=acl,
-                                     tvm_ops=tvm_ops,
-                                     acl_partitions=acl_partitions)[0])
+        outputs.append(
+            build_and_run(
+                mod,
+                data,
+                1,
+                params,
+                device,
+                enable_acl=acl,
+                tvm_ops=tvm_ops,
+                acl_partitions=acl_partitions,
+            )[0]
+        )
     verify(outputs, atol=atol, rtol=rtol, verify_saturation=False)
 
 
@@ -49,7 +57,7 @@ def _get_tflite_model(tflite_model_path, inputs_dict):
     """Convert TFlite graph to relay."""
     import tflite.Model
 
-    with open(tflite_model_path, 'rb') as f:
+    with open(tflite_model_path, "rb") as f:
         tflite_model_buffer = f.read()
 
     try:
@@ -63,11 +71,7 @@ def _get_tflite_model(tflite_model_path, inputs_dict):
         shape_dict[input] = input_shape
         dtype_dict[input] = input_dtype
 
-    return relay.frontend.from_tflite(
-        tflite_model,
-        shape_dict=shape_dict,
-        dtype_dict=dtype_dict
-    )
+    return relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict)
 
 
 def _get_keras_model(keras_model, inputs_dict):
@@ -88,15 +92,15 @@ def test_vgg16():
 
     def get_model():
         from keras.applications import VGG16
-        vgg16 = VGG16(include_top=True, weights='imagenet',
-                      input_shape=(224, 224, 3), classes=1000)
+
+        vgg16 = VGG16(include_top=True, weights="imagenet", input_shape=(224, 224, 3), classes=1000)
         inputs = {vgg16.input_names[0]: ((1, 224, 224, 3), "float32")}
         mod, params = _get_keras_model(vgg16, inputs)
         return mod, params, inputs
 
-    _build_and_run_network(*get_model(), device=device,
-                           tvm_ops=4, acl_partitions=21,
-                           atol=0.002, rtol=0.01)
+    _build_and_run_network(
+        *get_model(), device=device, tvm_ops=4, acl_partitions=21, atol=0.002, rtol=0.01
+    )
 
 
 def test_mobilenet():
@@ -109,15 +113,17 @@ def test_mobilenet():
 
     def get_model():
         from keras.applications import MobileNet
-        mobilenet = MobileNet(include_top=True, weights='imagenet',
-                              input_shape=(224, 224, 3), classes=1000)
+
+        mobilenet = MobileNet(
+            include_top=True, weights="imagenet", input_shape=(224, 224, 3), classes=1000
+        )
         inputs = {mobilenet.input_names[0]: ((1, 224, 224, 3), "float32")}
         mod, params = _get_keras_model(mobilenet, inputs)
         return mod, params, inputs
 
-    _build_and_run_network(*get_model(), device=device,
-                           tvm_ops=73, acl_partitions=18,
-                           atol=0.002, rtol=0.01)
+    _build_and_run_network(
+        *get_model(), device=device, tvm_ops=73, acl_partitions=18, atol=0.002, rtol=0.01
+    )
 
 
 def test_quantized_mobilenet():
@@ -132,20 +138,17 @@ def test_quantized_mobilenet():
 
     def get_model():
         model_path = tf_testing.get_workload_official(
-            "https://storage.googleapis.com/download.tensorflow.org/" \
+            "https://storage.googleapis.com/download.tensorflow.org/"
             "models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
             "mobilenet_v1_1.0_224_quant.tflite",
         )
         inputs = {"input": ((1, 224, 224, 3), "uint8")}
-        mod, params = _get_tflite_model(
-            model_path,
-            inputs_dict=inputs
-        )
+        mod, params = _get_tflite_model(model_path, inputs_dict=inputs)
         return mod, params, inputs
 
-    _build_and_run_network(*get_model(), device=device,
-                           tvm_ops=42, acl_partitions=17,
-                           atol=8, rtol=0)
+    _build_and_run_network(
+        *get_model(), device=device, tvm_ops=42, acl_partitions=17, atol=8, rtol=0
+    )
 
 
 if __name__ == "__main__":
index c104a06..3501717 100644
@@ -21,8 +21,13 @@ import numpy as np
 import tvm
 from tvm import relay
 
-from .infrastructure import skip_runtime_test, skip_codegen_test, build_and_run, \
-    verify, verify_codegen
+from .infrastructure import (
+    skip_runtime_test,
+    skip_codegen_test,
+    build_and_run,
+    verify,
+    verify_codegen,
+)
 from .infrastructure import Device
 
 
@@ -33,29 +38,48 @@ def _calculate_output_shape(shape, sizes, padding, strides):
     return 1, int(output_height), int(output_width), shape[3]
 
 
-def _get_pooling_model(shape, dtype, typef, sizes, strides, padding,
-                       ceil_mode, count_include_pad, var_names):
+def _get_pooling_model(
+    shape, dtype, typef, sizes, strides, padding, ceil_mode, count_include_pad, var_names
+):
     """Return a model and any parameters it may have."""
     if len(padding) == 2:
         padding = (padding[0], padding[1], padding[0], padding[1])
     out = relay.var(next(var_names), shape=shape, dtype=dtype)
 
     if typef == "nn.max_pool2d":
-        out = relay.nn.max_pool2d(out, pool_size=sizes, strides=strides, padding=padding,
-                                  ceil_mode=ceil_mode, layout="NHWC")
+        out = relay.nn.max_pool2d(
+            out,
+            pool_size=sizes,
+            strides=strides,
+            padding=padding,
+            ceil_mode=ceil_mode,
+            layout="NHWC",
+        )
     elif typef == "nn.avg_pool2d":
         if dtype == "uint8":
-            out = relay.cast(out, 'int32')
-        out = relay.nn.avg_pool2d(out, pool_size=sizes, strides=strides, padding=padding,
-                                  ceil_mode=ceil_mode, count_include_pad=count_include_pad,
-                                  layout="NHWC")
+            out = relay.cast(out, "int32")
+        out = relay.nn.avg_pool2d(
+            out,
+            pool_size=sizes,
+            strides=strides,
+            padding=padding,
+            ceil_mode=ceil_mode,
+            count_include_pad=count_include_pad,
+            layout="NHWC",
+        )
         if dtype == "uint8":
-            out = relay.cast(out, 'uint8')
+            out = relay.cast(out, "uint8")
     elif typef == "nn.l2_pool2d":
         out = relay.power(out, relay.const(2.0))
-        out = relay.nn.avg_pool2d(out, pool_size=sizes, strides=strides, padding=padding,
-                                  ceil_mode=ceil_mode, count_include_pad=count_include_pad,
-                                  layout="NHWC")
+        out = relay.nn.avg_pool2d(
+            out,
+            pool_size=sizes,
+            strides=strides,
+            padding=padding,
+            ceil_mode=ceil_mode,
+            count_include_pad=count_include_pad,
+            layout="NHWC",
+        )
         out = relay.sqrt(out)
     else:
         raise ValueError("Function not supported")
@@ -71,18 +95,19 @@ def _get_global_pooling_model(shape, dtype, typef, var_names):
         out = relay.nn.global_max_pool2d(out, layout="NHWC")
     elif typef == "nn.global_avg_pool2d":
         if dtype == "uint8":
-            out = relay.cast(out, 'int32')
+            out = relay.cast(out, "int32")
         out = relay.nn.global_avg_pool2d(out, layout="NHWC")
         if dtype == "uint8":
-            out = relay.cast(out, 'uint8')
+            out = relay.cast(out, "uint8")
     else:
         raise ValueError("Function not supported")
 
     return out
 
 
-def _get_expected_pooling_codegen(shape, dtype, typef, sizes, strides,
-                                  padding, ceil_mode, count_include_pad):
+def _get_expected_pooling_codegen(
+    shape, dtype, typef, sizes, strides, padding, ceil_mode, count_include_pad
+):
     if len(padding) == 2:
         padding = (padding[0], padding[1], padding[0], padding[1])
     output_shape = _calculate_output_shape(shape, sizes, padding, strides)
@@ -100,17 +125,14 @@ def _get_expected_pooling_codegen(shape, dtype, typef, sizes, strides,
             "padding": [[str(p) for p in padding]],
             "strides": [[str(s) for s in strides]],
             "pool_size": [[str(s) for s in sizes]],
-            "ceil_mode": [[str(1 if ceil_mode else 0)]]
+            "ceil_mode": [[str(1 if ceil_mode else 0)]],
         },
     }
 
     if typef == "nn.avg_pool2d" or typef == "nn.l2_pool2d":
         node["attrs"]["count_include_pad"] = [["1" if count_include_pad else "0"]]
 
-    input = {
-        "op": "input",
-        "name": "",
-        "attrs": {"shape": [[list(shape)]], "dtype": [[dtype]]}}
+    input = {"op": "input", "name": "", "attrs": {"shape": [[list(shape)]], "dtype": [[dtype]]}}
     return [input, node]
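
For reference, _calculate_output_shape (whose body falls outside these hunks) supplies the expected NHWC output shape used above. The sketch below assumes the standard floor-based sliding-window formula, which is consistent with the (1, H, W, C) return value shown earlier, but it is an assumption rather than a copy of the helper.

def pooling_output_shape(shape, sizes, padding, strides):
    # shape is NHWC; padding has been expanded to (top, left, bottom, right) as above.
    output_height = (shape[1] + padding[0] + padding[2] - sizes[0]) // strides[0] + 1
    output_width = (shape[2] + padding[1] + padding[3] - sizes[1]) // strides[1] + 1
    return 1, int(output_height), int(output_width), shape[3]


# A 2x2, stride-2 pool over a (1, 16, 16, 16) input with no padding gives (1, 8, 8, 16).
print(pooling_output_shape((1, 16, 16, 16), (2, 2), (0, 0, 0, 0), (2, 2)))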
 
 
@@ -124,14 +146,11 @@ def _get_expected_global_pooling_codegen(shape, dtype, typef):
             "num_outputs": "1",
             "layout": [["NHWC"]],
             "shape": [[[1, 1, 1, shape[3]]]],
-            "dtype": [[dtype]]
-        }
+            "dtype": [[dtype]],
+        },
     }
 
-    input = {
-        "op": "input",
-        "name": "",
-        "attrs": {"shape": [[list(shape)]], "dtype": [[dtype]]}}
+    input = {"op": "input", "name": "", "attrs": {"shape": [[list(shape)]], "dtype": [[dtype]]}}
     return [input, node]
 
 
@@ -147,30 +166,41 @@ def test_pooling():
     fp32_dtype = ("float32", -127, 128, 0.001, 0.001)
     uint8_dtype = ("uint8", 0, 255, 1, 0)
 
-    trials = [["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)],
-              ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), True, True, (15, 15, 16)],
-              ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), False, False, (16, 16, 16)],
-              ["nn.max_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)],
-              ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), True, True, (15, 15, 16)],
-              ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, False, (16, 16, 16)],
-              ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)],
-              ["nn.avg_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 1), True, False, (15, 15, 16)],
-              ["nn.avg_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), False, True, (16, 16, 16)],
-              ["nn.avg_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)],
-              ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), True, False, (16, 16, 16)],
-              ["nn.l2_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 0), False, False, (16, 16, 16)],
-              ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, True, (15, 15, 16)]]
-
-    for typef, (dtype, low, high, atol, rtol), size, stride, pad, ceil_mode, count_include_pad, \
-            input_shape in trials:
+    trials = [
+        ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)],
+        ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), True, True, (15, 15, 16)],
+        ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), False, False, (16, 16, 16)],
+        ["nn.max_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)],
+        ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), True, True, (15, 15, 16)],
+        ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, False, (16, 16, 16)],
+        ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)],
+        ["nn.avg_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 1), True, False, (15, 15, 16)],
+        ["nn.avg_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), False, True, (16, 16, 16)],
+        ["nn.avg_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)],
+        ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), True, False, (16, 16, 16)],
+        ["nn.l2_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 0), False, False, (16, 16, 16)],
+        ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, True, (15, 15, 16)],
+    ]
+
+    for (
+        typef,
+        (dtype, low, high, atol, rtol),
+        size,
+        stride,
+        pad,
+        ceil_mode,
+        count_include_pad,
+        input_shape,
+    ) in trials:
         shape = (1, *input_shape)
         outputs = []
         inputs = {
             "a": tvm.nd.array(np.random.uniform(low, high, shape).astype(dtype)),
         }
 
-        func = _get_pooling_model(shape, dtype, typef, size,
-                                  stride, pad, ceil_mode, count_include_pad, iter(inputs))
+        func = _get_pooling_model(
+            shape, dtype, typef, size, stride, pad, ceil_mode, count_include_pad, iter(inputs)
+        )
 
         config = {
             "size": size,
@@ -180,13 +210,14 @@ def test_pooling():
             "dtype": dtype,
             "padding": pad,
             "ceil_mode": ceil_mode,
-            "count_include_pad": count_include_pad
+            "count_include_pad": count_include_pad,
         }
         verify_saturation = True if dtype == "uint8" else False
 
         for acl in [False, True]:
-            outputs.append(build_and_run(func, inputs, 1, None, device,
-                                         enable_acl=acl, config=config)[0])
+            outputs.append(
+                build_and_run(func, inputs, 1, None, device, enable_acl=acl, config=config)[0]
+            )
 
         verify(outputs, atol=atol, rtol=rtol, config=config, verify_saturation=verify_saturation)
 
@@ -203,16 +234,18 @@ def test_global_pooling():
     fp32_dtype = ("float32", -127, 128, 0.001, 0.001)
     uint8_dtype = ("uint8", 0, 255, 1, 0)
 
-    trials = [["nn.global_max_pool2d", fp32_dtype, (8, 8, 16)],
-              ["nn.global_max_pool2d", fp32_dtype, (9, 9, 16)],
-              ["nn.global_max_pool2d", fp32_dtype, (8, 8, 16)],
-              ["nn.global_max_pool2d", uint8_dtype, (8, 8, 16)],
-              ["nn.global_max_pool2d", uint8_dtype, (9, 9, 16)],
-              ["nn.global_avg_pool2d", fp32_dtype, (8, 8, 16)],
-              ["nn.global_avg_pool2d", fp32_dtype, (8, 8, 16)],
-              ["nn.global_avg_pool2d", fp32_dtype, (9, 9, 16)],
-              ["nn.global_avg_pool2d", uint8_dtype, (8, 8, 16)],
-              ["nn.global_avg_pool2d", uint8_dtype, (8, 8, 16)]]
+    trials = [
+        ["nn.global_max_pool2d", fp32_dtype, (8, 8, 16)],
+        ["nn.global_max_pool2d", fp32_dtype, (9, 9, 16)],
+        ["nn.global_max_pool2d", fp32_dtype, (8, 8, 16)],
+        ["nn.global_max_pool2d", uint8_dtype, (8, 8, 16)],
+        ["nn.global_max_pool2d", uint8_dtype, (9, 9, 16)],
+        ["nn.global_avg_pool2d", fp32_dtype, (8, 8, 16)],
+        ["nn.global_avg_pool2d", fp32_dtype, (8, 8, 16)],
+        ["nn.global_avg_pool2d", fp32_dtype, (9, 9, 16)],
+        ["nn.global_avg_pool2d", uint8_dtype, (8, 8, 16)],
+        ["nn.global_avg_pool2d", uint8_dtype, (8, 8, 16)],
+    ]
 
     for typef, (dtype, low, high, atol, rtol), input_shape in trials:
         shape = (1, *input_shape)
@@ -231,8 +264,9 @@ def test_global_pooling():
         verify_saturation = True if dtype == "uint8" else False
 
         for acl in [False, True]:
-            outputs.append(build_and_run(func, inputs, 1, None, device,
-                                         enable_acl=acl, config=config)[0])
+            outputs.append(
+                build_and_run(func, inputs, 1, None, device, enable_acl=acl, config=config)[0]
+            )
 
         verify(outputs, atol=atol, rtol=rtol, config=config, verify_saturation=verify_saturation)
 
@@ -244,26 +278,35 @@ def test_codegen_pooling():
     fp32_dtype = ("float32", -127, 128)
     uint8_dtype = ("uint8", 0, 255)
 
-    trials = [["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)],
-              ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), True, True, (15, 15, 16)],
-              ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), False, False, (16, 16, 16)],
-              ["nn.max_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)],
-              ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), True, True, (15, 15, 16)],
-              ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, False, (16, 16, 16)],
-              ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)],
-              ["nn.avg_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 1), True, False, (15, 15, 16)],
-              ["nn.avg_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), False, True, (16, 16, 16)],
-              ["nn.avg_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)],
-              ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), True, False, (15, 15, 16)],
-              ["nn.l2_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 0), False, False, (16, 16, 16)],
-              ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, True, (15, 15, 16)]]
-
-    for typef, (dtype, low, high), size, stride, pad, ceil_mode, count_include_pad, \
-            input_shape in trials:
+    trials = [
+        ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)],
+        ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), True, True, (15, 15, 16)],
+        ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), False, False, (16, 16, 16)],
+        ["nn.max_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)],
+        ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), True, True, (15, 15, 16)],
+        ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, False, (16, 16, 16)],
+        ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)],
+        ["nn.avg_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 1), True, False, (15, 15, 16)],
+        ["nn.avg_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), False, True, (16, 16, 16)],
+        ["nn.avg_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)],
+        ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), True, False, (15, 15, 16)],
+        ["nn.l2_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 0), False, False, (16, 16, 16)],
+        ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, True, (15, 15, 16)],
+    ]
+
+    for (
+        typef,
+        (dtype, low, high),
+        size,
+        stride,
+        pad,
+        ceil_mode,
+        count_include_pad,
+        input_shape,
+    ) in trials:
         shape = (1, *input_shape)
         inputs = {"a"}
-        args = (shape, dtype, typef, size,
-                stride, pad, False, False)
+        args = (shape, dtype, typef, size, stride, pad, False, False)
         func = _get_pooling_model(*args, iter(inputs))
         exp_codegen = _get_expected_pooling_codegen(*args)
         verify_codegen(func, exp_codegen, 1)
@@ -276,16 +319,18 @@ def test_codegen_global_pooling():
     fp32_dtype = ("float32", -127, 128)
     uint8_dtype = ("uint8", 0, 255)
 
-    trials = [["nn.global_max_pool2d", fp32_dtype, (8, 8, 16)],
-              ["nn.global_max_pool2d", fp32_dtype, (9, 9, 16)],
-              ["nn.global_max_pool2d", fp32_dtype, (8, 8, 16)],
-              ["nn.global_max_pool2d", uint8_dtype, (8, 8, 16)],
-              ["nn.global_max_pool2d", uint8_dtype, (9, 9, 16)],
-              ["nn.global_avg_pool2d", fp32_dtype, (8, 8, 16)],
-              ["nn.global_avg_pool2d", fp32_dtype, (8, 8, 16)],
-              ["nn.global_avg_pool2d", fp32_dtype, (9, 9, 16)],
-              ["nn.global_avg_pool2d", uint8_dtype, (8, 8, 16)],
-              ["nn.global_avg_pool2d", uint8_dtype, (8, 8, 16)]]
+    trials = [
+        ["nn.global_max_pool2d", fp32_dtype, (8, 8, 16)],
+        ["nn.global_max_pool2d", fp32_dtype, (9, 9, 16)],
+        ["nn.global_max_pool2d", fp32_dtype, (8, 8, 16)],
+        ["nn.global_max_pool2d", uint8_dtype, (8, 8, 16)],
+        ["nn.global_max_pool2d", uint8_dtype, (9, 9, 16)],
+        ["nn.global_avg_pool2d", fp32_dtype, (8, 8, 16)],
+        ["nn.global_avg_pool2d", fp32_dtype, (8, 8, 16)],
+        ["nn.global_avg_pool2d", fp32_dtype, (9, 9, 16)],
+        ["nn.global_avg_pool2d", uint8_dtype, (8, 8, 16)],
+        ["nn.global_avg_pool2d", uint8_dtype, (8, 8, 16)],
+    ]
 
     for typef, (dtype, low, high), input_shape in trials:
         shape = (1, *input_shape)
index b6a8754..9547aef 100644
@@ -21,8 +21,13 @@ import numpy as np
 import tvm
 from tvm import relay
 
-from .infrastructure import skip_runtime_test, skip_codegen_test, build_and_run, \
-    verify, verify_codegen
+from .infrastructure import (
+    skip_runtime_test,
+    skip_codegen_test,
+    build_and_run,
+    verify,
+    verify_codegen,
+)
 from .infrastructure import Device
 
 
@@ -44,14 +49,15 @@ def _get_expected_codegen(input_shape, output_shape, dtype):
             "newshape": [[str(s) for s in output_shape]],
             "shape": [[list(output_shape)]],
             "dtype": [[dtype]],
-            "reverse": [["0"]]
+            "reverse": [["0"]],
         },
     }
 
     input = {
         "op": "input",
         "name": "",
-        "attrs": {"shape": [[list(input_shape)]], "dtype": [[dtype]]}}
+        "attrs": {"shape": [[list(input_shape)]], "dtype": [[dtype]]},
+    }
 
     return [input, node]
 
@@ -65,18 +71,17 @@ def test_reshape():
     device = Device()
     np.random.seed(0)
 
-    for dtype, low, high, atol, rtol in [("float32", -127, 128, 0.001, 0.001), ("uint8", 0, 255, 0, 0)]:
-        inputs = {
-            "a": tvm.nd.array(
-                np.random.uniform(low, high, (1, 1, 1, 1000)).astype(dtype))
-        }
+    for dtype, low, high, atol, rtol in [
+        ("float32", -127, 128, 0.001, 0.001),
+        ("uint8", 0, 255, 0, 0),
+    ]:
+        inputs = {"a": tvm.nd.array(np.random.uniform(low, high, (1, 1, 1, 1000)).astype(dtype))}
 
         for new_shape in [(1, 1000), (10, 10, 10)]:
             outputs = []
             func = _get_model(inputs["a"].shape, new_shape, dtype, iter(inputs))
             for acl in [False, True]:
-                outputs.append(build_and_run(func, inputs, 1, None, device,
-                                             enable_acl=acl)[0])
+                outputs.append(build_and_run(func, inputs, 1, None, device, enable_acl=acl)[0])
 
             config = {
                 "new shape": inputs["a"].shape,
index 1ce2909..316dfad 100644
@@ -46,15 +46,14 @@ def test_multiple_ops():
         out = relay.reshape(out, (1, 1000))
         return out
 
-    inputs = {
-        "a": tvm.nd.array(np.random.uniform(0, 1, (1, 1, 1, 1000)).astype("float32"))
-    }
+    inputs = {"a": tvm.nd.array(np.random.uniform(0, 1, (1, 1, 1, 1000)).astype("float32"))}
 
     outputs = []
     for acl in [False, True]:
         func = get_model(inputs["a"].shape, iter(inputs))
-        outputs.append(build_and_run(func, inputs, 1, None, device,
-                                     enable_acl=acl, acl_partitions=2)[0])
+        outputs.append(
+            build_and_run(func, inputs, 1, None, device, enable_acl=acl, acl_partitions=2)[0]
+        )
     verify(outputs, atol=0.002, rtol=0.01)
 
 
@@ -79,16 +78,16 @@ def test_heterogeneous():
         out = relay.reshape(out, (1, 1000))
         return out
 
-    inputs = {
-        "a": tvm.nd.array(np.random.uniform(-127, 128, (1, 1, 1, 1000)).astype("float32"))
-    }
+    inputs = {"a": tvm.nd.array(np.random.uniform(-127, 128, (1, 1, 1, 1000)).astype("float32"))}
 
     outputs = []
     for acl in [False, True]:
         func = get_model(inputs["a"].shape, iter(inputs))
-        outputs.append(build_and_run(func, inputs, 1, None, device,
-                                     enable_acl=acl, tvm_ops=1,
-                                     acl_partitions=2)[0])
+        outputs.append(
+            build_and_run(
+                func, inputs, 1, None, device, enable_acl=acl, tvm_ops=1, acl_partitions=2
+            )[0]
+        )
     verify(outputs, atol=0.002, rtol=0.01)
 
 
@@ -115,7 +114,7 @@ def test_multiple_runs():
             kernel_layout="OHWI",
             strides=(1, 1),
             padding=(0, 0),
-            dilation=(1, 1)
+            dilation=(1, 1),
         )
         params = {"w": w}
         return conv, params
@@ -125,10 +124,7 @@ def test_multiple_runs():
     }
 
     func, params = get_model()
-    outputs = build_and_run(func, inputs, 1,
-                  params, device,
-                  enable_acl=True,
-                  no_runs=3)
+    outputs = build_and_run(func, inputs, 1, params, device, enable_acl=True, no_runs=3)
     verify(outputs, atol=0.002, rtol=0.01)
 
 
index 3aa0583..83b220f 100644
@@ -32,6 +32,7 @@ from tvm.contrib.binutil import *
 
 TOOLCHAIN_PREFIX = ""
 
+
 def make_binary():
     prog = "int a = 7; \
             int main() { \
@@ -43,8 +44,7 @@ def make_binary():
     tmp_obj = tmp_dir.relpath("obj.obj")
     with open(tmp_source, "w") as f:
         f.write(prog)
-    cc.create_executable(tmp_obj, tmp_source, [],
-                     cc="{}gcc".format(TOOLCHAIN_PREFIX))
+    cc.create_executable(tmp_obj, tmp_source, [], cc="{}gcc".format(TOOLCHAIN_PREFIX))
     prog_bin = bytearray(open(tmp_obj, "rb").read())
     return prog_bin
 
@@ -56,14 +56,21 @@ def test_tvm_callback_get_section_size(binary=None):
     tmp_bin = tmp_dir.relpath("obj.bin")
     with open(tmp_bin, "wb") as f:
         f.write(binary)
+
     def verify():
-        print("Text section size: %d" %
-              tvm_callback_get_section_size(tmp_bin, "text", TOOLCHAIN_PREFIX))
-        print("Data section size: %d" %
-              tvm_callback_get_section_size(tmp_bin, "data", TOOLCHAIN_PREFIX))
-        print("Bss section size: %d" %
-              tvm_callback_get_section_size(tmp_bin, "bss", TOOLCHAIN_PREFIX))
+        print(
+            "Text section size: %d"
+            % tvm_callback_get_section_size(tmp_bin, "text", TOOLCHAIN_PREFIX)
+        )
+        print(
+            "Data section size: %d"
+            % tvm_callback_get_section_size(tmp_bin, "data", TOOLCHAIN_PREFIX)
+        )
+        print(
+            "Bss section size: %d" % tvm_callback_get_section_size(tmp_bin, "bss", TOOLCHAIN_PREFIX)
+        )
         print()
+
     verify()
 
 
@@ -73,6 +80,7 @@ def test_tvm_callback_relocate_binary():
     tmp_bin = tmp_dir.relpath("obj.bin")
     with open(tmp_bin, "wb") as f:
         f.write(binary)
+
     def verify():
         word_size = 8
         text_loc = 0x0
@@ -81,40 +89,36 @@ def test_tvm_callback_relocate_binary():
         bss_loc = 0x30000
         stack_end = 0x50000
         rel_bin = tvm_callback_relocate_binary(
-            tmp_bin,
-            word_size,
-            text_loc,
-            rodata_loc,
-            data_loc,
-            bss_loc,
-            stack_end,
-            TOOLCHAIN_PREFIX)
+            tmp_bin, word_size, text_loc, rodata_loc, data_loc, bss_loc, stack_end, TOOLCHAIN_PREFIX
+        )
         print("Relocated binary section sizes")
         test_tvm_callback_get_section_size(binary=rel_bin)
         relf = tmp_dir.relpath("rel.bin")
         with open(relf, "wb") as f:
             f.write(rel_bin)
-        nm_proc = subprocess.Popen(["nm", "-C", "--defined-only", relf],
-                                   stdout=subprocess.PIPE,
-                                   stderr=subprocess.STDOUT)
+        nm_proc = subprocess.Popen(
+            ["nm", "-C", "--defined-only", relf], stdout=subprocess.PIPE, stderr=subprocess.STDOUT
+        )
         (out, _) = nm_proc.communicate()
         symbol_entries = out.decode("utf-8").split("\n")
         for entry in symbol_entries:
             if len(entry) == 0:
                 continue
-            sym_loc, section, sym_name = entry.split(' ')
+            sym_loc, section, sym_name = entry.split(" ")
             sym_loc = int(sym_loc, 16)
-            if section == 'T':  # text
+            if section == "T":  # text
                 assert sym_loc >= text_loc and sym_loc < data_loc
-            elif section == 'D':  # data
+            elif section == "D":  # data
                 assert sym_loc >= data_loc and sym_loc < bss_loc
-            elif section == 'B':  # bss
+            elif section == "B":  # bss
                 assert sym_loc >= bss_loc
+
     verify()
 
 
 def test_tvm_callback_read_binary_section():
     binary = make_binary()
+
     def verify():
         text_bin = tvm_callback_read_binary_section(binary, "text", TOOLCHAIN_PREFIX)
         data_bin = tvm_callback_read_binary_section(binary, "data", TOOLCHAIN_PREFIX)
@@ -123,6 +127,7 @@ def test_tvm_callback_read_binary_section():
         print("Read data section part of binary? %r" % (data_bin in binary))
         print("Read bss section part of binary? %r" % (bss_bin in binary))
         print()
+
     verify()
 
 
@@ -132,6 +137,7 @@ def test_tvm_callback_get_symbol_map():
     tmp_bin = tmp_dir.relpath("obj.bin")
     with open(tmp_bin, "wb") as f:
         f.write(binary)
+
     def verify():
         word_size = 8
         text_loc = 0x0
@@ -140,22 +146,17 @@ def test_tvm_callback_get_symbol_map():
         bss_loc = 0x30000
         stack_end = 0x50000
         rel_bin = tvm_callback_relocate_binary(
-            tmp_bin,
-            word_size,
-            text_loc,
-            rodata_loc,
-            data_loc,
-            bss_loc,
-            stack_end,
-            TOOLCHAIN_PREFIX)
+            tmp_bin, word_size, text_loc, rodata_loc, data_loc, bss_loc, stack_end, TOOLCHAIN_PREFIX
+        )
         symbol_map = tvm_callback_get_symbol_map(rel_bin, TOOLCHAIN_PREFIX)
         symbols = set()
-        for i, line in enumerate(symbol_map.split('\n')):
+        for i, line in enumerate(symbol_map.split("\n")):
             # Every other line is the value the symbol maps to.
             if i % 2 == 0:
                 symbols.add(line)
         assert "a" in symbols
         assert "main" in symbols
+
     verify()
 
 
index 7247ab7..dd9f777 100644 (file)
@@ -24,14 +24,15 @@ from tvm.contrib import mkl
 from tvm.contrib import mkldnn
 import tvm.testing
 
+
 def verify_matmul_add(m, l, n, lib, transa=False, transb=False, dtype="float32"):
-    bias = te.var('bias', dtype=dtype)
+    bias = te.var("bias", dtype=dtype)
     ashape = (l, n) if transa else (n, l)
     bshape = (m, l) if transb else (l, m)
-    A = te.placeholder(ashape, name='A', dtype=dtype)
-    B = te.placeholder(bshape, name='B', dtype=dtype)
+    A = te.placeholder(ashape, name="A", dtype=dtype)
+    B = te.placeholder(bshape, name="B", dtype=dtype)
     C = lib.matmul(A, B, transa, transb)
-    D = te.compute(C.shape, lambda i, j: C[i,j] + bias, name="D")
+    D = te.compute(C.shape, lambda i, j: C[i, j] + bias, name="D")
     s = te.create_schedule(D.op)
 
     def get_numpy(a, b, bb, transa, transb):
@@ -56,9 +57,12 @@ def verify_matmul_add(m, l, n, lib, transa=False, transb=False, dtype="float32")
         bb = 10.0
         f(a, b, d, bb)
         tvm.testing.assert_allclose(
-            d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), bb, transa, transb), rtol=1e-5)
+            d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), bb, transa, transb), rtol=1e-5
+        )
+
     verify()
 
+
 def test_matmul_add():
     verify_matmul_add(235, 128, 1024, cblas)
     verify_matmul_add(235, 128, 1024, cblas, True, False)
@@ -85,19 +89,20 @@ def test_matmul_add():
     verify_matmul_add(1, 16, 3, mkldnn, False, False)
     verify_matmul_add(1, 16, 3, mkldnn, True, True)
 
+
 def verify_quantized_matmul_add(m, l, n, transa=False, transb=False):
     if not tvm.get_global_func("tvm.contrib.mkl.matmul_u8s8s32", True):
         pytest.skip("Quantized dense is supported only for MKL. TVM GPU CI uses openblas")
     data_dtype = "uint8"
     kernel_dtype = "int8"
     out_dtype = "int32"
-    bias = te.var('bias', dtype=out_dtype)
+    bias = te.var("bias", dtype=out_dtype)
     ashape = (l, n) if transa else (n, l)
     bshape = (m, l) if transb else (l, m)
-    A = te.placeholder(ashape, name='A', dtype=data_dtype)
-    B = te.placeholder(bshape, name='B', dtype=kernel_dtype)
+    A = te.placeholder(ashape, name="A", dtype=data_dtype)
+    B = te.placeholder(bshape, name="B", dtype=kernel_dtype)
     C = mkl.matmul_u8s8s32(A, B, transa, transb, dtype=out_dtype)
-    D = te.compute(C.shape, lambda i, j: C[i,j] + bias, name="D")
+    D = te.compute(C.shape, lambda i, j: C[i, j] + bias, name="D")
     s = te.create_schedule(D.op)
 
     def get_numpy(a, b, bb, transa, transb):
@@ -122,11 +127,14 @@ def verify_quantized_matmul_add(m, l, n, transa=False, transb=False):
         bb = 10
         f(a, b, d, bb)
         tvm.testing.assert_allclose(
-                d.asnumpy(),
-                get_numpy(a.asnumpy().astype('int32'), b.asnumpy().astype('int32'), bb, transa, transb),
-                rtol=1e-5)
+            d.asnumpy(),
+            get_numpy(a.asnumpy().astype("int32"), b.asnumpy().astype("int32"), bb, transa, transb),
+            rtol=1e-5,
+        )
+
     verify()
 
+
 def test_quantized_matmul_add():
     verify_quantized_matmul_add(235, 128, 1024)
     verify_quantized_matmul_add(235, 128, 1024, True, False)
@@ -137,13 +145,16 @@ def test_quantized_matmul_add():
     verify_quantized_matmul_add(1, 16, 3, False, True)
     verify_quantized_matmul_add(1, 16, 3, True, True)
 
-def verify_batch_matmul(batch, m, l, n, lib, transa=False, transb=False, iterative=False, dtype="float32"):
+
+def verify_batch_matmul(
+    batch, m, l, n, lib, transa=False, transb=False, iterative=False, dtype="float32"
+):
     ashape = (batch, l, n) if transa else (batch, n, l)
     bshape = (batch, m, l) if transb else (batch, l, m)
-    A = te.placeholder(ashape, name='A', dtype=dtype)
-    B = te.placeholder(bshape, name='B', dtype=dtype)
+    A = te.placeholder(ashape, name="A", dtype=dtype)
+    B = te.placeholder(bshape, name="B", dtype=dtype)
     C = cblas.batch_matmul(A, B, transa, transb)
-    D = te.compute(C.shape, lambda k, i, j: C[k, i,j], name="D")
+    D = te.compute(C.shape, lambda k, i, j: C[k, i, j], name="D")
     s = te.create_schedule(D.op)
 
     def get_numpy(a, b, transa, transb):
@@ -167,9 +178,12 @@ def verify_batch_matmul(batch, m, l, n, lib, transa=False, transb=False, iterati
         d = tvm.nd.array(np.zeros((batch, n, m), dtype=D.dtype), ctx)
         f(a, b, d)
         tvm.testing.assert_allclose(
-            d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), transa, transb), rtol=1e-5)
+            d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), transa, transb), rtol=1e-5
+        )
+
     verify()
 
+
 def test_batch_matmul():
     verify_batch_matmul(16, 235, 128, 1024, cblas)
     verify_batch_matmul(16, 235, 128, 1024, cblas, True, False)
@@ -190,6 +204,7 @@ def test_batch_matmul():
     verify_batch_matmul(1, 1, 16, 3, mkl, True, True)
     verify_batch_matmul(1, 1, 16, 3, mkl, iterative=True)
 
+
 if __name__ == "__main__":
     test_matmul_add()
     test_quantized_matmul_add()
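
The quote churn in the hunk above follows one rule worth calling out: string literals are normalized to double quotes unless that would force extra escaping. A hedged sketch with hypothetical strings:

    name = 'A'            # rewritten to: name = "A"
    msg = "it's fine"     # already double-quoted, left alone
    quip = 'say "hi"'     # kept single-quoted to avoid escaping the inner quotes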
index 300239c..cd51668 100644 (file)
@@ -40,8 +40,8 @@ def _create_graph():
     shape = (10, 10)
     mod = tvm.IRModule()
 
-    x = relay.var('x', shape=shape)
-    y = relay.var('y', shape=shape)
+    x = relay.var("x", shape=shape)
+    y = relay.var("y", shape=shape)
     z = x + x
     p = y * y
     func = relay.Function([x, y], p - z)
@@ -78,8 +78,8 @@ def _create_graph_annotated():
     mod[gv2] = func2
 
     # body
-    x = relay.var('x', shape=shape)
-    y = relay.var('y', shape=shape)
+    x = relay.var("x", shape=shape)
+    y = relay.var("y", shape=shape)
     func = relay.Function([x, y], gv0(y) - gv2(x))
     mod["main"] = func
 
@@ -97,17 +97,17 @@ def test_annotate():
 
 @pytest.mark.skipif(not _has_xcode(), reason="Xcode is not available")
 def test_compile_and_run():
-    ctx=tvm.cpu()
-    target="llvm"
-    tol=1e-3
+    ctx = tvm.cpu()
+    target = "llvm"
+    tol = 1e-3
 
     with relay.build_config(opt_level=3):
         json, lib, params = relay.build(_create_graph_annotated(), target=target)
     m = tvm.contrib.graph_runtime.create(json, lib, ctx)
 
     shape = (10, 10)
-    x_data = np.random.rand(*shape).astype('float32')
-    y_data = np.random.rand(*shape).astype('float32')
+    x_data = np.random.rand(*shape).astype("float32")
+    y_data = np.random.rand(*shape).astype("float32")
 
     m.set_input("x", x_data)
     m.set_input("y", y_data)
@@ -120,8 +120,8 @@ def test_compile_and_run():
     tvm.testing.assert_allclose(out.asnumpy(), expected, rtol=tol, atol=tol)
 
 
-@mock.patch('tvm.contrib.coreml_runtime.create')
-@mock.patch('tvm.contrib.xcode.compile_coreml')
+@mock.patch("tvm.contrib.coreml_runtime.create")
+@mock.patch("tvm.contrib.xcode.compile_coreml")
 def _construct_model(func, m1, m2):
     mod = tvm.IRModule()
     mod["main"] = func
@@ -131,14 +131,13 @@ def _construct_model(func, m1, m2):
     fcompile = tvm._ffi.get_global_func("relay.ext.coremlcompiler")
 
     for var, func in mod.functions.items():
-        if func.attrs and 'Compiler' in func.attrs and \
-           func.attrs['Compiler'] == 'coremlcompiler':
+        if func.attrs and "Compiler" in func.attrs and func.attrs["Compiler"] == "coremlcompiler":
             fcompile(func)
 
 
 def test_add():
     shape = (10, 10)
-    x = relay.var('x', shape=shape)
+    x = relay.var("x", shape=shape)
     y = x + x
     func = relay.Function([x], y)
     _construct_model(func)
@@ -146,7 +145,7 @@ def test_add():
 
 def test_multiply():
     shape = (10, 10)
-    x = relay.var('x', shape=shape)
+    x = relay.var("x", shape=shape)
     y = x * x
     func = relay.Function([x], y)
     _construct_model(func)
@@ -154,7 +153,7 @@ def test_multiply():
 
 def test_clip():
     shape = (10, 10)
-    x = relay.var('x', shape=shape)
+    x = relay.var("x", shape=shape)
     y = relay.clip(x, a_min=0.0, a_max=1.0)
     func = relay.Function([x], y)
     _construct_model(func)
@@ -162,7 +161,7 @@ def test_clip():
 
 def test_batch_flatten():
     shape = (10, 10, 10)
-    x = relay.var('x', shape=shape)
+    x = relay.var("x", shape=shape)
     y = relay.nn.batch_flatten(x)
     func = relay.Function([x], y)
     _construct_model(func)
@@ -170,7 +169,7 @@ def test_batch_flatten():
 
 def test_expand_dims():
     shape = (10, 10)
-    x = relay.var('x', shape=shape)
+    x = relay.var("x", shape=shape)
     y = relay.expand_dims(x, axis=0)
     func = relay.Function([x], y)
     _construct_model(func)
@@ -182,7 +181,7 @@ def test_expand_dims():
 
 def test_relu():
     shape = (10, 10)
-    x = relay.var('x', shape=shape)
+    x = relay.var("x", shape=shape)
     y = relay.nn.relu(x)
     func = relay.Function([x], y)
     _construct_model(func)
@@ -190,15 +189,15 @@ def test_relu():
 
 def test_softmax():
     shape = (10, 10)
-    x = relay.var('x', shape=shape)
+    x = relay.var("x", shape=shape)
     y = relay.nn.softmax(x, axis=1)
     func = relay.Function([x], y)
     _construct_model(func)
 
 
 def test_conv2d():
-    x = relay.var('x', shape=(1,3,224,224))
-    w = relay.const(np.zeros((16,3,3,3), dtype='float32'))
+    x = relay.var("x", shape=(1, 3, 224, 224))
+    w = relay.const(np.zeros((16, 3, 3, 3), dtype="float32"))
     y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3])
     func = relay.Function([x], y)
     _construct_model(func)
@@ -206,7 +205,7 @@ def test_conv2d():
 
 def test_global_avg_pool2d():
     shape = (10, 10, 10, 10)
-    x = relay.var('x', shape=shape)
+    x = relay.var("x", shape=shape)
     y = relay.nn.global_avg_pool2d(x)
     func = relay.Function([x], y)
     _construct_model(func)
index 1d99fe9..f6b9d9e 100644 (file)
@@ -28,7 +28,8 @@ proxy_port = os.environ.get("TVM_IOS_RPC_PROXY_PORT", 9090)
 destination = os.environ.get("TVM_IOS_RPC_DESTINATION", "")
 key = "iphone"
 
-@pytest.mark.skip('skip because coremltools is not available in CI')
+
+@pytest.mark.skip("skip because coremltools is not available in CI")
 def test_coreml_runtime():
 
     import coremltools
@@ -39,23 +40,20 @@ def test_coreml_runtime():
         alpha = 2
 
         inputs = [
-            ('input0', coremltools.models.datatypes.Array(*shape)),
-            ('input1', coremltools.models.datatypes.Array(*shape))
+            ("input0", coremltools.models.datatypes.Array(*shape)),
+            ("input1", coremltools.models.datatypes.Array(*shape)),
         ]
         outputs = [
-            ('output0', coremltools.models.datatypes.Array(*shape)),
-            ('output1', coremltools.models.datatypes.Array(*shape)),
+            ("output0", coremltools.models.datatypes.Array(*shape)),
+            ("output1", coremltools.models.datatypes.Array(*shape)),
         ]
         builder = NeuralNetworkBuilder(inputs, outputs)
-        builder.add_elementwise(name='Add',
-                                input_names=['input0', 'input1'],
-                                output_name='output0',
-                                mode='ADD')
-        builder.add_elementwise(name='Mul',
-                                alpha=alpha,
-                                input_names=['input0'],
-                                output_name='output1',
-                                mode='MULTIPLY')
+        builder.add_elementwise(
+            name="Add", input_names=["input0", "input1"], output_name="output0", mode="ADD"
+        )
+        builder.add_elementwise(
+            name="Mul", alpha=alpha, input_names=["input0"], output_name="output1", mode="MULTIPLY"
+        )
         return coremltools.models.MLModel(builder.spec)
 
     def verify(coreml_model, model_path, ctx):
@@ -74,7 +72,7 @@ def test_coreml_runtime():
         coreml_outputs = [coreml_model.predict(inputs)[name] for name in out_names]
 
         # inference via tvm coreml runtime
-        runtime = coreml_runtime.create('main', model_path, ctx)
+        runtime = coreml_runtime.create("main", model_path, ctx)
         for name in inputs:
             runtime.set_input(name, tvm.nd.array(inputs[name], ctx))
         runtime.invoke()
@@ -86,8 +84,9 @@ def test_coreml_runtime():
     def check_remote(coreml_model):
         temp = util.tempdir()
         compiled_model = xcode.compile_coreml(coreml_model, out_dir=temp.temp_dir)
-        xcode.popen_test_rpc(proxy_host, proxy_port, key, destination=destination,
-                             libs=[compiled_model])
+        xcode.popen_test_rpc(
+            proxy_host, proxy_port, key, destination=destination, libs=[compiled_model]
+        )
         compiled_model = os.path.basename(compiled_model)
         remote = rpc.connect(proxy_host, proxy_port, key=key)
         ctx = remote.cpu(0)
index f387f35..175a747 100644 (file)
@@ -21,12 +21,13 @@ from tvm.contrib import cublas
 from tvm.contrib import cublaslt
 import tvm.testing
 
+
 def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5):
     n = 1024
     l = 128
     m = 236
-    A = te.placeholder((n, l), name='A', dtype=in_dtype)
-    B = te.placeholder((l, m), name='B', dtype=in_dtype)
+    A = te.placeholder((n, l), name="A", dtype=in_dtype)
+    B = te.placeholder((l, m), name="B", dtype=in_dtype)
     C = cublas.matmul(A, B, dtype=out_dtype)
     s = te.create_schedule(C.op)
 
@@ -41,12 +42,16 @@ def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5):
         c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
         f(a, b, c)
         tvm.testing.assert_allclose(
-            c.asnumpy(), np.dot(a.asnumpy().astype(C.dtype), b.asnumpy().astype(C.dtype)), rtol=rtol)
+            c.asnumpy(), np.dot(a.asnumpy().astype(C.dtype), b.asnumpy().astype(C.dtype)), rtol=rtol
+        )
+
     verify()
 
+
 def roundoff(v, d):
     return int(np.floor((v + d - 1) / d) * d)
 
+
 def verify_matmul_add_igemm(in_dtype, out_dtype, rtol=1e-5):
     n = 1024
     l = 1024
@@ -55,8 +60,8 @@ def verify_matmul_add_igemm(in_dtype, out_dtype, rtol=1e-5):
     N = roundoff(n, 8)
     N_out = roundoff(n, 32)
 
-    A = te.placeholder((N, L), name='A', dtype=in_dtype)
-    B = te.placeholder((m, L), name='B', dtype=in_dtype)
+    A = te.placeholder((N, L), name="A", dtype=in_dtype)
+    B = te.placeholder((m, L), name="B", dtype=in_dtype)
     # C has CUBLASLT_ORDER_COL32 layout, thus a different shape
     C = cublaslt.matmul(A, B, False, True, m, N_out, dtype=out_dtype)
     s = te.create_schedule(C.op)
@@ -71,18 +76,23 @@ def verify_matmul_add_igemm(in_dtype, out_dtype, rtol=1e-5):
         b_old = np.random.uniform(0, 128, size=(l, m))
 
         # Transform a to become CUBLASLT_ORDER_COL4_4R2_8C layout
-        a_new = np.hstack((a_old.astype(A.dtype), np.zeros([n, L-l])))
-        a_new = np.vstack((a_new.astype(A.dtype), np.zeros([N-n, L])))
+        a_new = np.hstack((a_old.astype(A.dtype), np.zeros([n, L - l])))
+        a_new = np.vstack((a_new.astype(A.dtype), np.zeros([N - n, L])))
         a_even = np.vsplit(a_new[::2], N / 8)
         a_odd = np.vsplit(a_new[1::2], N / 8)
-        a_new = [None]*(len(a_even) + len(a_odd))
+        a_new = [None] * (len(a_even) + len(a_odd))
         a_new[::2] = a_even
         a_new[1::2] = a_odd
         a_new = np.vstack(a_new)
-        a_new = np.vstack(np.vstack(np.vstack(np.hsplit(i, 8)).reshape([4, 32]) for i in np.vsplit(j, N/4)) for j in np.hsplit(a_new, L/32))
+        a_new = np.vstack(
+            np.vstack(np.vstack(np.hsplit(i, 8)).reshape([4, 32]) for i in np.vsplit(j, N / 4))
+            for j in np.hsplit(a_new, L / 32)
+        )
         a_new = a_new.reshape([N, L])
         # Transform b to become CUBLASLT_ORDER_COL32 layout
-        b_new = np.vstack(np.hsplit(np.hstack((b_old.T.astype(B.dtype), np.zeros([m, L - l]))), L / 32))
+        b_new = np.vstack(
+            np.hsplit(np.hstack((b_old.T.astype(B.dtype), np.zeros([m, L - l]))), L / 32)
+        )
         b_new = b_new.reshape([m, L])
 
         a = tvm.nd.array(a_new.astype(A.dtype), ctx)
@@ -96,16 +106,19 @@ def verify_matmul_add_igemm(in_dtype, out_dtype, rtol=1e-5):
         c_out = c_out[:, :n]
         c_out = c_out.T
         tvm.testing.assert_allclose(
-            c_out, np.dot(a_old.astype(C.dtype), b_old.astype(C.dtype)), rtol=rtol)
+            c_out, np.dot(a_old.astype(C.dtype), b_old.astype(C.dtype)), rtol=rtol
+        )
+
     verify()
 
+
 def verify_batch_matmul(in_dtype, out_dtype, rtol=1e-5):
     j = 16
     n = 1024
     l = 128
     m = 236
-    A = te.placeholder((j, n, l), name='A', dtype=in_dtype)
-    B = te.placeholder((j, l, m), name='B', dtype=in_dtype)
+    A = te.placeholder((j, n, l), name="A", dtype=in_dtype)
+    B = te.placeholder((j, l, m), name="B", dtype=in_dtype)
     C = cublas.batch_matmul(A, B, dtype=out_dtype)
     s = te.create_schedule(C.op)
 
@@ -120,29 +133,35 @@ def verify_batch_matmul(in_dtype, out_dtype, rtol=1e-5):
         c = tvm.nd.array(np.zeros((j, n, m), dtype=C.dtype), ctx)
         f(a, b, c)
         tvm.testing.assert_allclose(
-            c.asnumpy(), np.matmul(a.asnumpy().astype(C.dtype),
-                                   b.asnumpy().astype(C.dtype)).astype(C.dtype), rtol=rtol)
+            c.asnumpy(),
+            np.matmul(a.asnumpy().astype(C.dtype), b.asnumpy().astype(C.dtype)).astype(C.dtype),
+            rtol=rtol,
+        )
+
     verify()
 
+
 @tvm.testing.requires_cuda
 def test_matmul_add():
-    verify_matmul_add('float', 'float', rtol=1e-3)
-    verify_matmul_add('float16', 'float')
-    verify_matmul_add('float16', 'float16', rtol=1e-2)
-    verify_matmul_add('int8', 'int32')
+    verify_matmul_add("float", "float", rtol=1e-3)
+    verify_matmul_add("float16", "float")
+    verify_matmul_add("float16", "float16", rtol=1e-2)
+    verify_matmul_add("int8", "int32")
+
 
 @tvm.testing.requires_cuda
 def test_matmul_add_igemm():
-    verify_matmul_add_igemm('int8', 'int32')
+    verify_matmul_add_igemm("int8", "int32")
+
 
 @tvm.testing.requires_cuda
 def test_batch_matmul():
-    verify_batch_matmul('float', 'float')
-    verify_batch_matmul('float16', 'float')
-    verify_batch_matmul('float16', 'float16', rtol=1e-2)
+    verify_batch_matmul("float", "float")
+    verify_batch_matmul("float16", "float")
+    verify_batch_matmul("float16", "float16", rtol=1e-2)
+
 
 if __name__ == "__main__":
     test_matmul_add()
     test_batch_matmul()
     test_matmul_add_igemm()
-
index 5777c3b..b07f2b2 100644 (file)
@@ -22,6 +22,7 @@ import numpy as np
 import tvm.topi.testing
 import tvm.testing
 
+
 def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1):
     in_channel = 4
     out_channel = 16
@@ -52,18 +53,20 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1):
         xshape = [batch, height, width, in_channel]
         wshape = [out_channel, filter_h, filter_w, in_channel // groups]
 
-    X = te.placeholder(xshape, name='X', dtype=data_dtype)
-    W = te.placeholder(wshape, name='W', dtype=data_dtype)
-    Y = cudnn.conv_forward(X,
-                           W,
-                           [pad_h, pad_w],
-                           [stride_h, stride_w],
-                           [dilation_h, dilation_w],
-                           conv_mode=1,
-                           tensor_format=tensor_format,
-                           conv_dtype=conv_dtype,
-                           algo=-1,
-                           groups=groups)
+    X = te.placeholder(xshape, name="X", dtype=data_dtype)
+    W = te.placeholder(wshape, name="W", dtype=data_dtype)
+    Y = cudnn.conv_forward(
+        X,
+        W,
+        [pad_h, pad_w],
+        [stride_h, stride_w],
+        [dilation_h, dilation_w],
+        conv_mode=1,
+        tensor_format=tensor_format,
+        conv_dtype=conv_dtype,
+        algo=-1,
+        groups=groups,
+    )
     yshape = [x.value for x in Y.shape]
     s = te.create_schedule(Y.op)
 
@@ -79,12 +82,13 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1):
     if tensor_format == 0:
         c_np = tvm.topi.testing.conv2d_nchw_python(x_np, w_np, 1, 1, groups=groups)
     elif tensor_format == 1:
-        wt = w_np.transpose((1, 2, 3, 0))  #OHWI => HWIO
+        wt = w_np.transpose((1, 2, 3, 0))  # OHWI => HWIO
         c_np = tvm.topi.testing.conv2d_nhwc_python(x_np, wt, 1, 1, groups=groups)
 
     f(x, w, y)
     tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=1e-2, rtol=1e-2)
 
+
 @tvm.testing.requires_gpu
 def test_conv2d():
     verify_conv2d("float32", "float32", tensor_format=0)
@@ -97,6 +101,7 @@ def test_conv2d():
     verify_conv2d("float16", "float16", tensor_format=0, groups=2)
     verify_conv2d("int8", "int32", tensor_format=1, groups=2)
 
+
 def verify_conv3d(data_dtype, conv_dtype, tensor_format=0, groups=1):
     in_channel = 4
     out_channel = 16
@@ -125,18 +130,20 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0, groups=1):
     xshape = [batch, in_channel, depth, height, width]
     wshape = [out_channel, in_channel // groups, filter_d, filter_h, filter_w]
 
-    X = te.placeholder(xshape, name='X', dtype=data_dtype)
-    W = te.placeholder(wshape, name='W', dtype=data_dtype)
-    Y = cudnn.conv_forward(X,
-                           W,
-                           [pad_d, pad_h, pad_w],
-                           [stride_d, stride_h, stride_w],
-                           [dilation_d, dilation_h, dilation_w],
-                           conv_mode=1,
-                           tensor_format=tensor_format,
-                           algo=-1,
-                           conv_dtype=conv_dtype,
-                           groups=groups)
+    X = te.placeholder(xshape, name="X", dtype=data_dtype)
+    W = te.placeholder(wshape, name="W", dtype=data_dtype)
+    Y = cudnn.conv_forward(
+        X,
+        W,
+        [pad_d, pad_h, pad_w],
+        [stride_d, stride_h, stride_w],
+        [dilation_d, dilation_h, dilation_w],
+        conv_mode=1,
+        tensor_format=tensor_format,
+        algo=-1,
+        conv_dtype=conv_dtype,
+        groups=groups,
+    )
     yshape = [x.value for x in Y.shape]
     s = te.create_schedule(Y.op)
 
@@ -157,13 +164,15 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0, groups=1):
     f(x, w, y)
     tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=3e-5, rtol=1e-4)
 
+
 @tvm.testing.requires_gpu
 def test_conv3d():
     verify_conv3d("float32", "float32", tensor_format=0)
     verify_conv3d("float32", "float32", tensor_format=0, groups=2)
 
+
 def verify_softmax(shape, axis, dtype="float32"):
-    A = te.placeholder(shape, dtype=dtype, name='A')
+    A = te.placeholder(shape, dtype=dtype, name="A")
     B = cudnn.softmax(A, axis)
     s = te.create_schedule([B.op])
 
@@ -176,15 +185,16 @@ def verify_softmax(shape, axis, dtype="float32"):
     f(a, b)
     tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3)
 
+
 def verify_softmax_4d(shape, dtype="float32"):
-    A = te.placeholder(shape, dtype=dtype, name='A')
+    A = te.placeholder(shape, dtype=dtype, name="A")
     B = cudnn.softmax(A, axis=1)
     s = te.create_schedule([B.op])
 
     ctx = tvm.gpu(0)
     n, c, h, w = shape
     a_np = np.random.uniform(size=shape).astype(dtype)
-    b_np = tvm.topi.testing.softmax_python(a_np.transpose(0, 2, 3, 1).reshape(h*w, c))
+    b_np = tvm.topi.testing.softmax_python(a_np.transpose(0, 2, 3, 1).reshape(h * w, c))
     b_np = b_np.reshape(n, h, w, c).transpose(0, 3, 1, 2)
     a = tvm.nd.array(a_np, ctx)
     b = tvm.nd.array(b_np, ctx)
@@ -192,6 +202,7 @@ def verify_softmax_4d(shape, dtype="float32"):
     f(a, b)
     tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3)
 
+
 @tvm.testing.requires_gpu
 def test_softmax():
     if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
@@ -204,6 +215,7 @@ def test_softmax():
     verify_softmax_4d((1, 16, 256, 256))
     verify_softmax_4d((1, 16, 256, 256), "float64")
 
+
 if __name__ == "__main__":
     test_conv2d()
     test_conv3d()
index 453556c..661e284 100644 (file)
@@ -19,6 +19,7 @@ from tvm import te
 import numpy as np
 from tvm.contrib.dlpack import to_pytorch_func
 
+
 def test():
     a = np.random.randn(1337)
     tvm_a = tvm.nd.array(a)
@@ -33,23 +34,25 @@ def test():
         np.testing.assert_equal(x.numpy(), tvm_x.asnumpy())
         y = tvm.nd.from_dlpack(tvm_x.to_dlpack())
         np.testing.assert_equal(y.asnumpy(), tvm_x.asnumpy())
-        np.testing.assert_equal(torch.utils.dlpack.from_dlpack(y.to_dlpack()).numpy(), tvm_x.asnumpy())
+        np.testing.assert_equal(
+            torch.utils.dlpack.from_dlpack(y.to_dlpack()).numpy(), tvm_x.asnumpy()
+        )
 
         n = tvm.runtime.convert(137)
-        xx = torch.rand(137,137)
-        yy = torch.rand(137,137)
-        zz2 = torch.empty(137,137)
+        xx = torch.rand(137, 137)
+        yy = torch.rand(137, 137)
+        zz2 = torch.empty(137, 137)
         zz = xx.mm(yy)
-        XX = te.placeholder((n,n), name='X')
-        YY = te.placeholder((n,n), name='Y')
+        XX = te.placeholder((n, n), name="X")
+        YY = te.placeholder((n, n), name="Y")
 
-        k = te.reduce_axis((0, n), name='k')
-        ZZ = te.compute((n,n), lambda i,j : te.sum(XX[i,k]*YY[k,j], axis=k))
+        k = te.reduce_axis((0, n), name="k")
+        ZZ = te.compute((n, n), lambda i, j: te.sum(XX[i, k] * YY[k, j], axis=k))
         s = te.create_schedule(ZZ.op)
-        f = tvm.build(s, [XX, YY, ZZ], target_host='llvm', name='f')
+        f = tvm.build(s, [XX, YY, ZZ], target_host="llvm", name="f")
 
         f_pytorch = to_pytorch_func(f)
-        zz2 = torch.empty(137,137)
+        zz2 = torch.empty(137, 137)
         f_pytorch(xx, yy, zz2)
         tvm.testing.assert_allclose(zz.numpy(), zz2.numpy(), rtol=1e-6)
 
@@ -57,5 +60,5 @@ def test():
         pass
 
 
-if __name__ ==  '__main__':
+if __name__ == "__main__":
     test()
index 625dc94..eef9e0b 100644 (file)
@@ -20,17 +20,19 @@ from tvm import te
 import numpy as np
 from tvm import rpc
 from tvm.contrib import util, tflite_runtime
+
 # import tflite_runtime.interpreter as tflite
 
 
 def skipped_test_tflite_runtime():
-
     def get_tflite_model_path(target_edgetpu):
         # Return a path to the model
-        edgetpu_path = os.getenv('EDGETPU_PATH', "/home/mendel/edgetpu")
+        edgetpu_path = os.getenv("EDGETPU_PATH", "/home/mendel/edgetpu")
         # Obtain mobilenet model from the edgetpu repo path
         if target_edgetpu:
-            model_path = os.path.join(edgetpu_path, "test_data/mobilenet_v1_1.0_224_quant_edgetpu.tflite")
+            model_path = os.path.join(
+                edgetpu_path, "test_data/mobilenet_v1_1.0_224_quant_edgetpu.tflite"
+            )
         else:
             model_path = os.path.join(edgetpu_path, "test_data/mobilenet_v1_1.0_224_quant.tflite")
         return model_path
@@ -38,11 +40,11 @@ def skipped_test_tflite_runtime():
     def init_interpreter(model_path, target_edgetpu):
         # Initialize interpreter
         if target_edgetpu:
-            edgetpu_path = os.getenv('EDGETPU_PATH', "/home/mendel/edgetpu")
+            edgetpu_path = os.getenv("EDGETPU_PATH", "/home/mendel/edgetpu")
             libedgetpu = os.path.join(edgetpu_path, "libedgetpu/direct/aarch64/libedgetpu.so.1")
             interpreter = tflite.Interpreter(
-                    model_path=model_path,
-                    experimental_delegates=[tflite.load_delegate(libedgetpu)])
+                model_path=model_path, experimental_delegates=[tflite.load_delegate(libedgetpu)]
+            )
         else:
             interpreter = tflite.Interpreter(model_path=model_path)
         return interpreter
@@ -56,18 +58,18 @@ def skipped_test_tflite_runtime():
         input_details = interpreter.get_input_details()
         output_details = interpreter.get_output_details()
 
-        input_shape = input_details[0]['shape']
+        input_shape = input_details[0]["shape"]
         tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.uint8)
-        interpreter.set_tensor(input_details[0]['index'], tflite_input)
+        interpreter.set_tensor(input_details[0]["index"], tflite_input)
         interpreter.invoke()
-        tflite_output = interpreter.get_tensor(output_details[0]['index'])
+        tflite_output = interpreter.get_tensor(output_details[0]["index"])
 
         # inference via remote tvm tflite runtime
         server = rpc.Server("localhost")
         remote = rpc.connect(server.host, server.port)
         ctx = remote.cpu(0)
 
-        with open(tflite_model_path, 'rb') as model_fin:
+        with open(tflite_model_path, "rb") as model_fin:
             runtime = tflite_runtime.create(model_fin.read(), ctx)
             runtime.set_input(0, tvm.nd.array(tflite_input, ctx))
             runtime.invoke()
@@ -79,6 +81,7 @@ def skipped_test_tflite_runtime():
     # Target EdgeTPU on coral board
     check_remote(target_edgetpu=True)
 
+
 if __name__ == "__main__":
     # skipped_test_tflite_runtime()
     pass
index deba5e5..109872e 100644 (file)
@@ -15,4 +15,3 @@
 # specific language governing permissions and limitations
 # under the License.
 """Infrastructure and tests for EthosN"""
-
index 070348f..2c88d56 100644 (file)
@@ -86,9 +86,9 @@ def get_host_op_count(mod):
 
 def build(mod, params, npu=True, expected_host_ops=0, npu_partitions=1):
     relay.backend.compile_engine.get().clear()
-    with tvm.transform.PassContext(opt_level=3, config={
-            "relay.ext.ethos-n.options": {"variant": 0}
-    }):
+    with tvm.transform.PassContext(
+        opt_level=3, config={"relay.ext.ethos-n.options": {"variant": 0}}
+    ):
         with tvm.target.Target("llvm"):
             if npu:
                 f = relay.build_module.bind_params_by_name(mod["main"], params)
@@ -100,15 +100,17 @@ def build(mod, params, npu=True, expected_host_ops=0, npu_partitions=1):
                 mod = relay.transform.MergeCompilerRegions()(mod)
                 mod = relay.transform.PartitionGraph()(mod)
                 host_op_count = get_host_op_count(mod)
-                assert host_op_count == expected_host_ops, \
-                    "Got {} host operators, expected {}".format(host_op_count, expected_host_ops)
+                assert (
+                    host_op_count == expected_host_ops
+                ), "Got {} host operators, expected {}".format(host_op_count, expected_host_ops)
                 partition_count = 0
                 for global_var in mod.get_global_vars():
                     if "ethos-n" in global_var.name_hint:
                         partition_count += 1
 
-                assert npu_partitions == partition_count, \
-                    "Got {} ethos-n partitions, expected {}".format(partition_count, npu_partitions)
+                assert (
+                    npu_partitions == partition_count
+                ), "Got {} ethos-n partitions, expected {}".format(partition_count, npu_partitions)
 
             return relay.build(mod, params=params)
 
@@ -130,7 +132,9 @@ def run(graph, lib, params, inputs, outputs, npu=True):
     return out
 
 
-def build_and_run(mod, inputs, outputs, params, ctx=tvm.cpu(), npu=True, expected_host_ops=0, npu_partitions=1):
+def build_and_run(
+    mod, inputs, outputs, params, ctx=tvm.cpu(), npu=True, expected_host_ops=0, npu_partitions=1
+):
     graph, lib, params = build(mod, params, npu, expected_host_ops, npu_partitions)
     return run(graph, lib, params, inputs, outputs, npu)
 
@@ -138,26 +142,24 @@ def build_and_run(mod, inputs, outputs, params, ctx=tvm.cpu(), npu=True, expecte
 def verify(answers, atol, rtol=1e-07, verify_saturation=True):
     """Compare the array of answers. Each entry is a list of outputs"""
     if len(answers) < 2:
-        print("No results to compare: expected at least two, found ",
-              len(answers))
+        print("No results to compare: expected at least two, found ", len(answers))
     for answer in zip_longest(*answers):
         for outs in combinations(answer, 2):
             if verify_saturation:
-                assert np.count_nonzero(outs[0].asnumpy() == 255) < 0.25 * outs[0].asnumpy().size, \
-                    "Output is saturated: {}".format(outs[0])
-                assert np.count_nonzero(outs[0].asnumpy() == 0) < 0.25 * outs[0].asnumpy().size, \
-                    "Output is saturated: {}".format(outs[0])
-            tvm.testing.assert_allclose(
-                outs[0].asnumpy(), outs[1].asnumpy(), rtol=rtol, atol=atol
-            )
+                assert (
+                    np.count_nonzero(outs[0].asnumpy() == 255) < 0.25 * outs[0].asnumpy().size
+                ), "Output is saturated: {}".format(outs[0])
+                assert (
+                    np.count_nonzero(outs[0].asnumpy() == 0) < 0.25 * outs[0].asnumpy().size
+                ), "Output is saturated: {}".format(outs[0])
+            tvm.testing.assert_allclose(outs[0].asnumpy(), outs[1].asnumpy(), rtol=rtol, atol=atol)
 
 
 def inference_result(checksum, outputs):
     """Set the expected results of an Ethos inference, if the testing
     infrastructure is available. This assumes that the entire graph
     was offloaded to the neural processor."""
-    if tvm.get_global_func(
-            "relay.ethos-n.test.infra.inference_result", True):
+    if tvm.get_global_func("relay.ethos-n.test.infra.inference_result", True):
         return _infrastructure.inference_result(checksum, *outputs)
     return False
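
The wrapped assertions in the hunk above keep the failure message outside the parentheses; only the condition is enclosed. The distinction matters in Python: assert (condition, message) builds a two-element tuple, which is always truthy, so such an assert can never fail. A hedged sketch of the safe wrapped form, with hypothetical values:

    host_ops = 3
    expected_host_ops = 3
    # Only the condition sits inside the parentheses; the message follows the
    # closing parenthesis, so the assert still evaluates the condition alone.
    # If the two values differed, this would raise AssertionError with the message.
    assert (
        host_ops == expected_host_ops
    ), "Got {} host operators, expected {}".format(host_ops, expected_host_ops)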
 
index cca61d1..a529e04 100644 (file)
@@ -42,12 +42,14 @@ def _get_model(shapes, dtype, axis):
 
     zeroi = relay.const(1, "int32")
     zerof = relay.const(0.5, "float32")
-    con = relay.qnn.op.concatenate(tup,
-                                   input_scales=[zerof]*len(shapes),
-                                   input_zero_points=[zeroi]*len(shapes),
-                                   output_scale=zerof,
-                                   output_zero_point=zeroi,
-                                   axis=axis)
+    con = relay.qnn.op.concatenate(
+        tup,
+        input_scales=[zerof] * len(shapes),
+        input_zero_points=[zeroi] * len(shapes),
+        output_scale=zerof,
+        output_zero_point=zeroi,
+        axis=axis,
+    )
     return con
 
 
@@ -58,7 +60,7 @@ def test_concatenate():
     trials = [
         ([(1, 4), (1, 6)], 1),
         ([(1, 16, 4), (1, 16, 4)], 1),
-        ([(1, 25, 4, 16)]*3, 3),
+        ([(1, 25, 4, 16)] * 3, 3),
         ([(1, 25, 4, 16), (1, 25, 5, 16), (1, 25, 6, 16)], 2),
     ]
 
@@ -79,10 +81,30 @@ def test_concatenate_failure():
 
     trials = [
         ([(1, 4, 4, 4, 4), (1, 4, 4, 4, 4)], "uint8", 1, "dimensions=5, dimensions must be <= 4;"),
-        ([(1, 4, 4, 4), (1, 4, 4, 4)], "uint8", 3, "Concatenation along the channels dimension (axis 3) requires input tensors with a multiple of 16 channels;"),
-        ([(1, 4, 4, 4), (1, 4, 4, 4)], "int8", 2, "dtype='int8', dtype must be either uint8 or int32; dtype='int8', dtype must be either uint8 or int32;"),
-        ([(2, 4, 4, 4), (2, 4, 4, 4)], "uint8", 2, "batch size=2, batch size must = 1; batch size=2, batch size must = 1;"),
-        ([(1, 4, 4, 4), (1, 4, 4, 4)], "uint8", 0, "Concatenation cannot be performed along batch axis (axis 0);"),
+        (
+            [(1, 4, 4, 4), (1, 4, 4, 4)],
+            "uint8",
+            3,
+            "Concatenation along the channels dimension (axis 3) requires input tensors with a multiple of 16 channels;",
+        ),
+        (
+            [(1, 4, 4, 4), (1, 4, 4, 4)],
+            "int8",
+            2,
+            "dtype='int8', dtype must be either uint8 or int32; dtype='int8', dtype must be either uint8 or int32;",
+        ),
+        (
+            [(2, 4, 4, 4), (2, 4, 4, 4)],
+            "uint8",
+            2,
+            "batch size=2, batch size must = 1; batch size=2, batch size must = 1;",
+        ),
+        (
+            [(1, 4, 4, 4), (1, 4, 4, 4)],
+            "uint8",
+            0,
+            "Concatenation cannot be performed along batch axis (axis 0);",
+        ),
     ]
 
     for shapes, dtype, axis, err_msg in trials:
index 52e3de9..64052ce 100644 (file)
@@ -40,20 +40,34 @@ def _get_same_padding(data, kernel, dilation, stride):
     return [pad_top, pad_left, pad_bottom, pad_right]
 
 
-def _get_model(shape, kernel_h, kernel_w,
-               input_zp, input_sc,
-               kernel_zp, kernel_sc,
-               output_zp, output_sc,
-               pad, strides, dilation,
-               groups, dtype,
-               out_channels, weight_format):
+def _get_model(
+    shape,
+    kernel_h,
+    kernel_w,
+    input_zp,
+    input_sc,
+    kernel_zp,
+    kernel_sc,
+    output_zp,
+    output_sc,
+    pad,
+    strides,
+    dilation,
+    groups,
+    dtype,
+    out_channels,
+    weight_format,
+):
     """Return a model and any parameters it may have"""
     a = relay.var("a", shape=shape, dtype=dtype)
     if pad == "op" or pad == "both":
         p = _get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, strides)
-        a = relay.nn.pad(a,
-                         pad_width=[(0, 0), (p[0], p[2]), (p[1], p[3]), (0, 0)],
-                         pad_value=input_zp, pad_mode="constant")
+        a = relay.nn.pad(
+            a,
+            pad_width=[(0, 0), (p[0], p[2]), (p[1], p[3]), (0, 0)],
+            pad_value=input_zp,
+            pad_mode="constant",
+        )
         shape = (shape[0], shape[1] + p[0] + p[2], shape[2] + p[1] + p[3], shape[3])
 
     p = _get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, strides)
@@ -61,7 +75,11 @@ def _get_model(shape, kernel_h, kernel_w,
         weight_shape = (kernel_h, kernel_w, shape[3] // groups, out_channels)
     else:
         weight_shape = (kernel_h, kernel_w, out_channels, 1)
-    w = tvm.nd.array(np.random.randint(np.iinfo(dtype).min, high=np.iinfo(dtype).max, size=weight_shape, dtype=dtype))
+    w = tvm.nd.array(
+        np.random.randint(
+            np.iinfo(dtype).min, high=np.iinfo(dtype).max, size=weight_shape, dtype=dtype
+        )
+    )
     weights = relay.const(w, dtype)
     conv = relay.qnn.op.conv2d(
         a,
@@ -85,30 +103,31 @@ def _get_model(shape, kernel_h, kernel_w,
     bias = relay.nn.bias_add(conv, biasc, axis=3)
     req = relay.qnn.op.requantize(
         bias,
-        relay.const(input_sc * kernel_sc, 'float32'),  # input zero scale
-        relay.const(0, 'int32'),                       # input zero point
-        relay.const(output_sc, 'float32'),             # output zero scale
-        relay.const(output_zp, 'int32'),               # output zero point
-        out_dtype="uint8"
+        relay.const(input_sc * kernel_sc, "float32"),  # input zero scale
+        relay.const(0, "int32"),  # input zero point
+        relay.const(output_sc, "float32"),  # output zero scale
+        relay.const(output_zp, "int32"),  # output zero point
+        out_dtype="uint8",
     )
-    params = {"w": w,
-              "b": b}
+    params = {"w": w, "b": b}
     return req, params
 
 
 def _get_conv2d_qnn_params(input_zp, input_sc, kernel_zp, kernel_sc, kernel_h, kernel_w, channels):
     input_max = input_sc * (255 - input_zp)
-    input_min = - input_sc * input_zp
+    input_min = -input_sc * input_zp
     kernel_max = kernel_sc * (255 - kernel_zp)
-    kernel_min = - kernel_sc * kernel_zp
-    output_limits = [kernel_max * kernel_h * kernel_w * channels * input_max,
-                     kernel_min * kernel_h * kernel_w * channels * input_max,
-                     kernel_min * kernel_h * kernel_w * channels * input_min,
-                     kernel_max * kernel_h * kernel_w * channels * input_min]
+    kernel_min = -kernel_sc * kernel_zp
+    output_limits = [
+        kernel_max * kernel_h * kernel_w * channels * input_max,
+        kernel_min * kernel_h * kernel_w * channels * input_max,
+        kernel_min * kernel_h * kernel_w * channels * input_min,
+        kernel_max * kernel_h * kernel_w * channels * input_min,
+    ]
     output_max = max(output_limits)
     output_min = min(output_limits)
     output_sc = (output_max - output_min) / 255
-    output_zp = - int(output_min / output_sc)
+    output_zp = -int(output_min / output_sc)
     return output_zp, output_sc
 
 
@@ -117,18 +136,18 @@ def test_conv2d():
         return
 
     trials = [
-        [(1, 17, 20, 26), 4, 3, 1, 'attr', (2, 2), (1, 1)],
-        [(1, 30, 27, 30), 5, 5, 3, 'none', (1, 1), (1, 1)],
-        [(1, 14, 28, 11), 6, 2, 2, 'op', (2, 2), (1, 1)],
-        [(1, 9, 20, 30), 7, 1, 5, 'none', (1, 1), (1, 1)],
-        [(1, 21, 21, 22), 8, 5, 1, 'attr', (2, 2), (1, 1)],
-        [(1, 21, 25, 29), 9, 2, 5, 'op', (1, 1), (1, 1)],
-        [(1, 31, 28, 15), 10, 1, 2, 'attr', (2, 2), (1, 1)],
-        [(1, 21, 21, 8), 11, 3, 3, 'none', (1, 1), (1, 1)],
-        [(1, 5, 11, 6), 12, 5, 2, 'op', (2, 2), (1, 1)],
-        [(1, 12, 7, 18), 13, 1, 3, 'op', (1, 1), (1, 1)],
-        [(1, 24, 6, 26), 14, 3, 5, 'none', (2, 2), (1, 1)],
-        [(1, 19, 24, 16), 15, 2, 1, 'attr', (1, 1), (1, 1)],
+        [(1, 17, 20, 26), 4, 3, 1, "attr", (2, 2), (1, 1)],
+        [(1, 30, 27, 30), 5, 5, 3, "none", (1, 1), (1, 1)],
+        [(1, 14, 28, 11), 6, 2, 2, "op", (2, 2), (1, 1)],
+        [(1, 9, 20, 30), 7, 1, 5, "none", (1, 1), (1, 1)],
+        [(1, 21, 21, 22), 8, 5, 1, "attr", (2, 2), (1, 1)],
+        [(1, 21, 25, 29), 9, 2, 5, "op", (1, 1), (1, 1)],
+        [(1, 31, 28, 15), 10, 1, 2, "attr", (2, 2), (1, 1)],
+        [(1, 21, 21, 8), 11, 3, 3, "none", (1, 1), (1, 1)],
+        [(1, 5, 11, 6), 12, 5, 2, "op", (2, 2), (1, 1)],
+        [(1, 12, 7, 18), 13, 1, 3, "op", (1, 1), (1, 1)],
+        [(1, 24, 6, 26), 14, 3, 5, "none", (2, 2), (1, 1)],
+        [(1, 19, 24, 16), 15, 2, 1, "attr", (1, 1), (1, 1)],
     ]
 
     np.random.seed(0)
@@ -152,16 +171,27 @@ def test_conv2d():
             input_sc = np.random.random() * 2
             kernel_zp = np.random.randint(0, 255)
             kernel_sc = np.random.random() * 2
-            output_zp, output_sc = _get_conv2d_qnn_params(input_zp, input_sc,
-                                                          kernel_zp, kernel_sc,
-                                                          kernel_h, kernel_w, shape[3])
-            model, params = _get_model(shape, kernel_h, kernel_w,
-                                       input_zp, input_sc,
-                                       kernel_zp, kernel_sc,
-                                       output_zp, output_sc,
-                                       pad, stride, dilation,
-                                       groups, "uint8",
-                                       out_channels, weight_format)
+            output_zp, output_sc = _get_conv2d_qnn_params(
+                input_zp, input_sc, kernel_zp, kernel_sc, kernel_h, kernel_w, shape[3]
+            )
+            model, params = _get_model(
+                shape,
+                kernel_h,
+                kernel_w,
+                input_zp,
+                input_sc,
+                kernel_zp,
+                kernel_sc,
+                output_zp,
+                output_sc,
+                pad,
+                stride,
+                dilation,
+                groups,
+                "uint8",
+                out_channels,
+                weight_format,
+            )
             for npu in [False, True]:
                 mod = tei.make_module(model, params)
                 outputs.append(tei.build_and_run(mod, inputs, 1, params, npu=npu))
@@ -174,31 +204,160 @@ def test_conv2d_failure():
         return
 
     trials = [
-        ((1, 4, 4, 4), 1, 1, 0, 1, 0, 1, 0, 1, "none", (1, 1), (1, 1), 1, "uint8", 8, "HWIO",
-         "Overall scale (of the input * weights / output) should be in the range [0, 1)"),
-        ((1, 4, 4, 4), 1, 1, 0, 1, 0, 1, 0, 1, "none", (1, 1), (1, 1), 1, "int8", 8, "HWIO",
-         "dtype='int8', dtype must be either uint8 or int32"),
-        ((1, 4, 4, 4), 2, 2, 0, 1, 0, 1, 0, 2, "both", (1, 1), (1, 1), 1, "uint8", 8, "HWIO",
-         "both op and attr padding exist, must be either op/attr only or no padding"),
-        ((1, 4, 4, 4), 1, 1, 0, 1, 0, 1, 0, 2, "none", (1, 1, 1), (1, 1), 1, "uint8", 8, "HWIO",
-         "stride size=3, stride size must = 2"),
-        ((1, 4, 4, 4), 1, 1, 0, 1, 0, 1, 0, 2, "none", (1, 1), (2, 1), 1, "uint8", 8, "HWIO",
-         "dilation=[2, 1], dilation must = [1, 1]"),
-        ((2, 4, 4, 4), 1, 1, 0, 1, 0, 1, 0, 2, "none", (1, 1), (1, 1), 1, "uint8", 8, "HWIO",
-         "batch size=2, batch size must = 1"),
+        (
+            (1, 4, 4, 4),
+            1,
+            1,
+            0,
+            1,
+            0,
+            1,
+            0,
+            1,
+            "none",
+            (1, 1),
+            (1, 1),
+            1,
+            "uint8",
+            8,
+            "HWIO",
+            "Overall scale (of the input * weights / output) should be in the range [0, 1)",
+        ),
+        (
+            (1, 4, 4, 4),
+            1,
+            1,
+            0,
+            1,
+            0,
+            1,
+            0,
+            1,
+            "none",
+            (1, 1),
+            (1, 1),
+            1,
+            "int8",
+            8,
+            "HWIO",
+            "dtype='int8', dtype must be either uint8 or int32",
+        ),
+        (
+            (1, 4, 4, 4),
+            2,
+            2,
+            0,
+            1,
+            0,
+            1,
+            0,
+            2,
+            "both",
+            (1, 1),
+            (1, 1),
+            1,
+            "uint8",
+            8,
+            "HWIO",
+            "both op and attr padding exist, must be either op/attr only or no padding",
+        ),
+        (
+            (1, 4, 4, 4),
+            1,
+            1,
+            0,
+            1,
+            0,
+            1,
+            0,
+            2,
+            "none",
+            (1, 1, 1),
+            (1, 1),
+            1,
+            "uint8",
+            8,
+            "HWIO",
+            "stride size=3, stride size must = 2",
+        ),
+        (
+            (1, 4, 4, 4),
+            1,
+            1,
+            0,
+            1,
+            0,
+            1,
+            0,
+            2,
+            "none",
+            (1, 1),
+            (2, 1),
+            1,
+            "uint8",
+            8,
+            "HWIO",
+            "dilation=[2, 1], dilation must = [1, 1]",
+        ),
+        (
+            (2, 4, 4, 4),
+            1,
+            1,
+            0,
+            1,
+            0,
+            1,
+            0,
+            2,
+            "none",
+            (1, 1),
+            (1, 1),
+            1,
+            "uint8",
+            8,
+            "HWIO",
+            "batch size=2, batch size must = 1",
+        ),
     ]
 
     np.random.seed(0)
-    for shape, kernel_h, kernel_w, input_zp, input_sc, kernel_zp,\
-        kernel_sc, output_zp, output_sc, pad, stride, dilation,\
-        groups, dtype, out_channels, weight_format, err_msg in trials:
-        model, params = _get_model(shape, kernel_h, kernel_w,
-                                   input_zp, input_sc,
-                                   kernel_zp, kernel_sc,
-                                   output_zp, output_sc,
-                                   pad, stride, dilation,
-                                   groups, dtype,
-                                   out_channels, weight_format)
+    for (
+        shape,
+        kernel_h,
+        kernel_w,
+        input_zp,
+        input_sc,
+        kernel_zp,
+        kernel_sc,
+        output_zp,
+        output_sc,
+        pad,
+        stride,
+        dilation,
+        groups,
+        dtype,
+        out_channels,
+        weight_format,
+        err_msg,
+    ) in trials:
+        model, params = _get_model(
+            shape,
+            kernel_h,
+            kernel_w,
+            input_zp,
+            input_sc,
+            kernel_zp,
+            kernel_sc,
+            output_zp,
+            output_sc,
+            pad,
+            stride,
+            dilation,
+            groups,
+            dtype,
+            out_channels,
+            weight_format,
+        )
         model = tei.make_ethosn_composite(model, "ethos-n.qnn_conv2d")
         mod = tei.make_ethosn_partition(model)
         tei.test_error(mod, {}, err_msg)
index d5ff9bf..3f1c577 100644 (file)
@@ -62,7 +62,13 @@ def test_split_failure():
         ((1, 4, 4, 4), "int8", 4, 2, "dtype='int8', dtype must be either uint8 or int32;"),
         ((2, 4, 4, 4), "uint8", 4, 2, "batch size=2, batch size must = 1;"),
         ((1, 4, 4, 4), "uint8", 1, 0, "Split cannot be performed along batch axis (axis 0);"),
-        ((1, 4, 4, 4), "uint8", 4, 3, "Split along the channels dimension (axis 3) requires all output sizes (specified in splitInfo.m_Sizes) to be multiples of 16;"),
+        (
+            (1, 4, 4, 4),
+            "uint8",
+            4,
+            3,
+            "Split along the channels dimension (axis 3) requires all output sizes (specified in splitInfo.m_Sizes) to be multiples of 16;",
+        ),
     ]
 
     for shape, dtype, splits, axis, err_msg in trials:
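
One limit visible in the hunk above: the rewrapping never splits string literals, so the long error-message strings still overrun the line-length setting even after the surrounding tuples are exploded. If shorter lines were wanted, the strings would have to be broken up manually with implicit concatenation, as in this hedged sketch:

    err_msg = (
        "Split along the channels dimension (axis 3) requires all output sizes "
        "(specified in splitInfo.m_Sizes) to be multiples of 16;"
    )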
index 942186d..0cf5720 100644 (file)
@@ -32,18 +32,22 @@ def test_split_with_asym_concats():
         split = relay.op.split(a, indices_or_sections=splits, axis=axis)
         zeroi = relay.const(1, "int32")
         zerof = relay.const(0.5, "float32")
-        con1 = relay.qnn.op.concatenate([split[0], split[1]],
-                                        input_scales=[zerof]*2,
-                                        input_zero_points=[zeroi]*2,
-                                        output_scale=zerof,
-                                        output_zero_point=zeroi,
-                                        axis=axis)
-        con2 = relay.qnn.op.concatenate([split[2], split[3]],
-                                        input_scales=[zerof]*2,
-                                        input_zero_points=[zeroi]*2,
-                                        output_scale=zerof,
-                                        output_zero_point=zeroi,
-                                        axis=axis)
+        con1 = relay.qnn.op.concatenate(
+            [split[0], split[1]],
+            input_scales=[zerof] * 2,
+            input_zero_points=[zeroi] * 2,
+            output_scale=zerof,
+            output_zero_point=zeroi,
+            axis=axis,
+        )
+        con2 = relay.qnn.op.concatenate(
+            [split[2], split[3]],
+            input_scales=[zerof] * 2,
+            input_zero_points=[zeroi] * 2,
+            output_scale=zerof,
+            output_zero_point=zeroi,
+            axis=axis,
+        )
         return relay.Tuple((con2, con1))
 
     trials = [
@@ -96,12 +100,14 @@ def test_input_tuples():
 
         zeroi = relay.const(1, "int32")
         zerof = relay.const(0.5, "float32")
-        con = relay.qnn.op.concatenate(tup,
-                                       input_scales=[zerof]*len(shapes),
-                                       input_zero_points=[zeroi]*len(shapes),
-                                       output_scale=zerof,
-                                       output_zero_point=zeroi,
-                                       axis=axis)
+        con = relay.qnn.op.concatenate(
+            tup,
+            input_scales=[zerof] * len(shapes),
+            input_zero_points=[zeroi] * len(shapes),
+            output_scale=zerof,
+            output_zero_point=zeroi,
+            axis=axis,
+        )
 
         return con
 
index 9ae2c9f..5475978 100644 (file)
@@ -26,11 +26,11 @@ def benchmark_fc_int8_acc16():
     n = 128
     k = 128
 
-    X = te.placeholder((m, k), name='X', dtype="uint8")
-    W = te.placeholder((n, k), name='W', dtype="int8")
+    X = te.placeholder((m, k), name="X", dtype="uint8")
+    W = te.placeholder((n, k), name="W", dtype="int8")
 
-    peak = 512/16*2*2*2
-    gops_per_mm = 2*n*m*k
+    peak = 512 / 16 * 2 * 2 * 2
+    gops_per_mm = 2 * n * m * k
     print("Peak {} Gops/s \n".format(peak))
 
     def verify(target="llvm -mcpu=skylake-avx512"):
@@ -39,17 +39,25 @@ def benchmark_fc_int8_acc16():
             return
 
         ctx = tvm.context(target, 0)
-        X = te.placeholder((m, k), name='X', dtype="uint8")
-        W = te.placeholder((n, k), name='W', dtype="int8")
+        X = te.placeholder((m, k), name="X", dtype="uint8")
+        W = te.placeholder((n, k), name="W", dtype="int8")
         pc = dot_16x1x16_uint8_int8_int16()
-        ak = te.reduce_axis((0, k), name='k')
-
-        packedW = te.placeholder((n//128, 128*(k//2), 2), name='packedW', dtype="int8")
-        t_fc = te.compute((m, n), lambda i, j: te.sum(X[i, ak].astype("int16") * packedW[j//128, (ak//2)*128+j%128, ak%2].astype("int16"), axis=ak), name="F")
+        ak = te.reduce_axis((0, k), name="k")
+
+        packedW = te.placeholder((n // 128, 128 * (k // 2), 2), name="packedW", dtype="int8")
+        t_fc = te.compute(
+            (m, n),
+            lambda i, j: te.sum(
+                X[i, ak].astype("int16")
+                * packedW[j // 128, (ak // 2) * 128 + j % 128, ak % 2].astype("int16"),
+                axis=ak,
+            ),
+            name="F",
+        )
 
         t_sch = te.create_schedule(t_fc.op)
         a_x, a_y = t_fc.op.axis
-        a_k, = t_fc.op.reduce_axis
+        (a_k,) = t_fc.op.reduce_axis
 
         a_yo, a_yi = t_sch[t_fc].split(a_y, factor=128)
         a_ko, a_ki = t_sch[t_fc].split(a_k, factor=2)
@@ -58,34 +66,40 @@ def benchmark_fc_int8_acc16():
         a_koo, a_koi = t_sch[t_fc].split(a_ko, factor=32)
         t_sch[t_fc].reorder(a_yo, a_xo, a_koo, a_xi, a_koi, a_yi, a_ki)
 
-               t_sch[t_fc].tensorize(a_yi, pc)
+        t_sch[t_fc].tensorize(a_yi, pc)
         # print(tvm.lower(t_sch, [X, packedW, t_fc], simple_mode=True))
         t_func = tvm.build(t_sch, [X, packedW, t_fc], target, name="intrinsic")
         t_evaluator = t_func.time_evaluator(t_func.entry_name, ctx, number=10)
 
-           # generate the plain data
+        # generate the plain data
         a_ = np.random.uniform(1, 10, size=(m, k)).astype("uint8")
-        b_ = np.random.uniform(1, 10,  size=(n, k)).astype("int8")
+        b_ = np.random.uniform(1, 10, size=(n, k)).astype("int8")
 
-        packW = np.random.uniform(1, 10,  size=(n//128, 128*(k//2), 2)).astype("int8")
+        packW = np.random.uniform(1, 10, size=(n // 128, 128 * (k // 2), 2)).astype("int8")
         # This occurs in pre_compute stage
-        for r_idx in range(n//128):
-            for s_idx in range(128*(k//2)):
+        for r_idx in range(n // 128):
+            for s_idx in range(128 * (k // 2)):
                 for t_idx in range(2):
-                    packW[r_idx][s_idx][t_idx] = b_[r_idx*128+s_idx%128][s_idx//128*2+t_idx]
+                    packW[r_idx][s_idx][t_idx] = b_[r_idx * 128 + s_idx % 128][
+                        s_idx // 128 * 2 + t_idx
+                    ]
 
         x = tvm.nd.array(a_, ctx)
         w = tvm.nd.array(packW, ctx)
         y = tvm.nd.array(np.zeros((m, n), dtype="int16"), ctx)
 
         result = t_evaluator(x, w, y)
-        gops_per_sec = gops_per_mm/result.mean/1e9
-        tvm.testing.assert_allclose(
-           y.asnumpy(), np.dot(a_, b_.T), rtol=1e-5)
-        print('Tensorization: running time: {:.3f} ms, {:.2f} Gops/s, effiency: {:.2f}.'.format(result.mean*1000, gops_per_sec, gops_per_sec/peak))
-        #t_func.export_library("gemm_tensorize.o")
+        gops_per_sec = gops_per_mm / result.mean / 1e9
+        tvm.testing.assert_allclose(y.asnumpy(), np.dot(a_, b_.T), rtol=1e-5)
+        print(
+            "Tensorization: running time: {:.3f} ms, {:.2f} Gops/s, effiency: {:.2f}.".format(
+                result.mean * 1000, gops_per_sec, gops_per_sec / peak
+            )
+        )
+        # t_func.export_library("gemm_tensorize.o")
 
     verify()
 
+
 if __name__ == "__main__":
     benchmark_fc_int8_acc16()
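
The acc16 hunk above packs the int8 weight matrix into (n // 128, 128 * (k // 2), 2) blocks so the tensorized inner loop can read two k-values per lane. A standalone sketch, not part of the patch (shapes chosen to match the benchmark), checking that this packing layout round-trips:

import numpy as np

n, k = 128, 128
W = np.random.randint(-10, 10, size=(n, k)).astype("int8")

packW = np.zeros((n // 128, 128 * (k // 2), 2), dtype="int8")
for r_idx in range(n // 128):
    for s_idx in range(128 * (k // 2)):
        for t_idx in range(2):
            # same index expression as the pre_compute loop in the benchmark
            packW[r_idx, s_idx, t_idx] = W[r_idx * 128 + s_idx % 128, s_idx // 128 * 2 + t_idx]

# reading back with the index expression used inside the te.compute lambda
for j in range(n):
    for ak in range(k):
        assert packW[j // 128, (ak // 2) * 128 + j % 128, ak % 2] == W[j, ak]
print("acc16 packing layout round-trips")
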
index 5380040..3e0d5db 100644
@@ -32,8 +32,8 @@ def test_fc_int8_acc32():
     n = 1024
     k = 1024
 
-    X = te.placeholder((m, k), name='X', dtype="uint8")
-    W = te.placeholder((n, k), name='W', dtype="int8")
+    X = te.placeholder((m, k), name="X", dtype="uint8")
+    W = te.placeholder((n, k), name="W", dtype="int8")
 
     peak = 280
     print("Peak {} Gops/s".format(peak))
@@ -50,15 +50,21 @@ def test_fc_int8_acc32():
 
         ctx = tvm.context(target, 0)
         pc = dot_16x1x16_uint8_int8_int32_cascadelake()
-        ak = te.reduce_axis((0, k), name='k')
-        packedW = te.placeholder(
-            (n // 16, 16 * (k // 4), 4), name='packedW', dtype="int8")
-
-        t_fc = te.compute((m, n), lambda i, j: te.sum(X[i, ak].astype(
-            "int32") * packedW[j / 16, (ak / 4) * 16 + j % 16, ak % 4].astype("int32"), axis=ak), name="F")
+        ak = te.reduce_axis((0, k), name="k")
+        packedW = te.placeholder((n // 16, 16 * (k // 4), 4), name="packedW", dtype="int8")
+
+        t_fc = te.compute(
+            (m, n),
+            lambda i, j: te.sum(
+                X[i, ak].astype("int32")
+                * packedW[j / 16, (ak / 4) * 16 + j % 16, ak % 4].astype("int32"),
+                axis=ak,
+            ),
+            name="F",
+        )
         t_sch = te.create_schedule(t_fc.op)
         a_x, a_y = t_fc.op.axis
-        a_k, = t_fc.op.reduce_axis
+        (a_k,) = t_fc.op.reduce_axis
 
         a_yo, a_yi = t_sch[t_fc].split(a_y, factor=16)
         a_xo, a_xi = t_sch[t_fc].split(a_x, factor=32)
@@ -76,14 +82,14 @@ def test_fc_int8_acc32():
         a_ = np.random.uniform(1, 10, size=(m, k)).astype("uint8")
         b_ = np.random.uniform(1, 10, size=(n, k)).astype("int8")
 
-        packW = np.random.uniform(1, 10, size=(
-            n // 16, 16 * (k // 4), 4)).astype("int8")
+        packW = np.random.uniform(1, 10, size=(n // 16, 16 * (k // 4), 4)).astype("int8")
         # This occurs in pre_compute stage
         for r_idx in range(n // 16):
             for s_idx in range(16 * (k // 4)):
                 for t_idx in range(4):
-                    packW[r_idx][s_idx][t_idx] = b_[r_idx * 16 + s_idx %
-                                                    16][(s_idx // 16) * 4 + t_idx]
+                    packW[r_idx][s_idx][t_idx] = b_[r_idx * 16 + s_idx % 16][
+                        (s_idx // 16) * 4 + t_idx
+                    ]
 
         x = tvm.nd.array(a_, ctx)
         w = tvm.nd.array(packW, ctx)
@@ -93,8 +99,11 @@ def test_fc_int8_acc32():
         gops_per_sec = gops_per_mm / result.mean / 1e9
         # verify the correctness
         tvm.testing.assert_allclose(y.asnumpy(), np.dot(a_, b_.T), rtol=0)
-        print('Tensorization: running time: {:.3f} ms, {:.2f} Gops/s, effiency: {:.2f}'.format(
-            result.mean * 1000, gops_per_sec, gops_per_sec / peak))
+        print(
+            "Tensorization: running time: {:.3f} ms, {:.2f} Gops/s, effiency: {:.2f}".format(
+                result.mean * 1000, gops_per_sec, gops_per_sec / peak
+            )
+        )
         t_func.export_library("tensorize_acc32.o")
 
     verify()
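
Both GEMM hunks above also show black rewriting the single-element unpacking `a_k, = t_fc.op.reduce_axis` as `(a_k,) = t_fc.op.reduce_axis`. A tiny illustration of my own (not patch content) that the two spellings are equivalent:

reduce_axes = ("k",)   # stand-in for t_fc.op.reduce_axis
a_k, = reduce_axes     # original spelling
(b_k,) = reduce_axes   # black's preferred, explicitly parenthesized spelling
assert a_k == b_k == "k"
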
index e8d348e..c80d465 100644
@@ -39,21 +39,15 @@ def test_conv2d():
         return
     wshape = (out_channel, in_channel, filter_h, filter_w)
 
-    X = te.placeholder(xshape, name='X')
-    W = te.placeholder(wshape, name='W')
-    Y = miopen.conv2d_forward(X,
-                              W,
-                              stride_h,
-                              stride_w,
-                              pad_h,
-                              pad_w,
-                              dilation_h,
-                              dilation_w,
-                              conv_mode=0,
-                              data_type=1)
+    X = te.placeholder(xshape, name="X")
+    W = te.placeholder(wshape, name="W")
+    Y = miopen.conv2d_forward(
+        X, W, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, conv_mode=0, data_type=1
+    )
 
     yshape = [x.value for x in Y.shape]
     from tvm import topi
+
     s = te.create_schedule(Y.op)
 
     def verify():
@@ -64,8 +58,9 @@ def test_conv2d():
         y = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), ctx)
         f(x, w, y)
 
-        Y_ref = topi.nn.conv2d_nchw(X, W, (stride_h, stride_w), (pad_h, pad_w),
-                                    (dilation_h, dilation_w))
+        Y_ref = topi.nn.conv2d_nchw(
+            X, W, (stride_h, stride_w), (pad_h, pad_w), (dilation_h, dilation_w)
+        )
         s_ref = te.create_schedule(Y_ref.op)
         f_ref = tvm.build(s_ref, [X, W, Y_ref], "rocm", target_host="llvm")
         y_ref = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), ctx)
index 1f0906e..2e2c143 100644
@@ -19,18 +19,16 @@ from tvm import te
 import numpy as np
 from tvm.contrib import mps
 
+
 @tvm.testing.requires_metal
 def test_matmul():
     n = 1024
     l = 128
     m = 256
-    A = te.placeholder((n, l), name='A')
-    B = te.placeholder((l, m), name='B')
+    A = te.placeholder((n, l), name="A")
+    B = te.placeholder((l, m), name="B")
     C = mps.matmul(A, B)
-    D = te.compute(
-        C.shape,
-        lambda *i: C(*i) + 1.
-    )
+    D = te.compute(C.shape, lambda *i: C(*i) + 1.0)
     s = te.create_schedule(D.op)
     yo, xo = D.op.axis
     block_y = te.thread_axis("blockIdx.y")
@@ -44,8 +42,6 @@ def test_matmul():
     s[D].bind(ty, thread_y)
     s[D].bind(tx, thread_x)
 
-
-
     def verify(A, B, D, s, target="metal"):
         if not tvm.get_global_func("tvm.contrib.mps.matmul", True):
             print("skip because extern function is not available")
@@ -56,10 +52,11 @@ def test_matmul():
         b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
         c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
         f(a, b, c)
-        tvm.testing.assert_allclose(
-            c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + 1, rtol=1e-5)
+        tvm.testing.assert_allclose(c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + 1, rtol=1e-5)
+
     verify(A, B, D, s)
 
+
 @tvm.testing.requires_metal
 def test_conv2d():
     n = 1
@@ -72,7 +69,7 @@ def test_conv2d():
     stride = 2
     A = te.placeholder((n, h, w, ci), name="x")
     B = te.placeholder((co, kh, kw, ci), name="w")
-    C = mps.conv2d(A, B, 'SAME', 2)
+    C = mps.conv2d(A, B, "SAME", 2)
     s1 = te.create_schedule(C.op)
 
     def verify(A, B, C, target="llvm"):
@@ -92,6 +89,5 @@ def test_conv2d():
 
 
 if __name__ == "__main__":
-    #test_matmul()
+    # test_matmul()
     test_conv2d()
-
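
Two smaller normalizations are visible in this mps hunk: black completes bare float literals (`1.` becomes `1.0`) and puts a space after the hash in comments (`#test_matmul()` becomes `# test_matmul()`). A standalone illustration, not patch content:

def add_one(*values):
    # black writes the float literal as 1.0 rather than a bare `1.`
    return sum(values) + 1.0

# test_matmul()  <- commented-out calls are kept, just with "# " spacing
assert add_one(2, 3) == 6.0
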
index 230e8db..afe739c 100644
@@ -56,9 +56,7 @@ def mxnet_check():
     mxf(xx, yy, zz, 10.0)
     mxf(xx, yy, zz, 10.0)
 
-
-    tvm.testing.assert_allclose(
-        zz.asnumpy(), (xx.asnumpy() + yy.asnumpy()) * 10)
+    tvm.testing.assert_allclose(zz.asnumpy(), (xx.asnumpy() + yy.asnumpy()) * 10)
 
 
 if __name__ == "__main__":
index bbee2b6..5c4d87b 100644
@@ -28,9 +28,9 @@ def test_fully_connected_inference():
     n = 1024
     l = 128
     m = 235
-    bias = te.var('bias', dtype="float32")
-    A = te.placeholder((l, ), name='A')
-    B = te.placeholder((m, l), name='B')
+    bias = te.var("bias", dtype="float32")
+    A = te.placeholder((l,), name="A")
+    B = te.placeholder((m, l), name="B")
     C = nnpack.fully_connected_inference(A, B)
     D = te.compute(C.shape, lambda i: C[i] + bias, name="D")
     s = te.create_schedule(D.op)
@@ -45,13 +45,14 @@ def test_fully_connected_inference():
         f = tvm.build(s, [A, B, D, bias], target)
         a = tvm.nd.array(np.random.uniform(size=(l)).astype(A.dtype), ctx)
         b = tvm.nd.array(np.random.uniform(size=(m, l)).astype(B.dtype), ctx)
-        d = tvm.nd.array(np.zeros((m, ), dtype=D.dtype), ctx)
+        d = tvm.nd.array(np.zeros((m,), dtype=D.dtype), ctx)
         bb = 10.0
         f(a, b, d, bb)
-        tvm.testing.assert_allclose(
-            d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy().T) + bb, rtol=1e-5)
+        tvm.testing.assert_allclose(d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy().T) + bb, rtol=1e-5)
+
     verify()
 
+
 def np_conv(na, nw, padding, stride=1):
     batch, in_channel, in_height, in_width = na.shape
     _, num_filter, kernel_h, kernel_w = nw.shape
@@ -73,14 +74,14 @@ def np_conv(na, nw, padding, stride=1):
             for c in range(in_channel):
                 if pad_h > 0 or pad_w > 0:
                     apad = np.zeros((in_height + pad_h, in_width + pad_w))
-                    apad[pad_top:pad_top + in_height, pad_left:pad_left + in_width] = na[n, c]
+                    apad[pad_top : pad_top + in_height, pad_left : pad_left + in_width] = na[n, c]
                 else:
                     apad = na[n, c]
-                out = scipy.signal.convolve2d(
-                    apad, np.rot90(np.rot90(nw[f, c])), mode='valid')
+                out = scipy.signal.convolve2d(apad, np.rot90(np.rot90(nw[f, c])), mode="valid")
                 nb[n, f] += out[::stride, ::stride]
     return nb
 
+
 @tvm.testing.requires_llvm
 def test_convolution_inference():
     BATCH = 8
@@ -92,19 +93,18 @@ def test_convolution_inference():
     PAD = 1
     STRIDE = 1
 
-    OH = (IH + 2*PAD - K) + 1
-    OW = (IW + 2*PAD - K) + 1
+    OH = (IH + 2 * PAD - K) + 1
+    OW = (IW + 2 * PAD - K) + 1
     dshape = (BATCH, IC, IH, IW)
     kshape = (OC, IC, K, K)
-    bshape = (OC, )
+    bshape = (OC,)
     oshape = (BATCH, OC, OH, OW)
 
-    data = te.placeholder(dshape, name='data')
-    kernel = te.placeholder(kshape, name='kernel')
-    bias = te.placeholder(bshape, name='bias')
-    def verify(target="llvm",
-               algorithm=nnpack.ConvolutionAlgorithm.AUTO,
-               with_bias=True):
+    data = te.placeholder(dshape, name="data")
+    kernel = te.placeholder(kshape, name="kernel")
+    bias = te.placeholder(bshape, name="bias")
+
+    def verify(target="llvm", algorithm=nnpack.ConvolutionAlgorithm.AUTO, with_bias=True):
         if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True):
             pytest.skip("extern function is not available")
         if not nnpack.is_available():
@@ -112,9 +112,13 @@ def test_convolution_inference():
 
         ctx = tvm.cpu(0)
         output = nnpack.convolution_inference(
-            data, kernel, bias if with_bias else None,
-            [PAD, PAD, PAD, PAD], [STRIDE, STRIDE],
-            algorithm=algorithm)
+            data,
+            kernel,
+            bias if with_bias else None,
+            [PAD, PAD, PAD, PAD],
+            [STRIDE, STRIDE],
+            algorithm=algorithm,
+        )
         s = te.create_schedule(output.op)
 
         f = tvm.build(s, [data, kernel, bias, output], target)
@@ -127,16 +131,18 @@ def test_convolution_inference():
         tc = tvm.nd.array(nc, ctx)
         td = tvm.nd.array(np.zeros(oshape, dtype=output.dtype), ctx)
         f(ta, tb, tc, td)
-        nd = np_conv(np.reshape(na, (BATCH, IC, IH, IW)), nb, PAD, STRIDE) + nc.reshape(1, bshape[0], 1, 1)
-        tvm.testing.assert_allclose(
-            td.asnumpy(), nd.reshape(BATCH, IC, IH, IW), rtol=1e-5)
+        nd = np_conv(np.reshape(na, (BATCH, IC, IH, IW)), nb, PAD, STRIDE) + nc.reshape(
+            1, bshape[0], 1, 1
+        )
+        tvm.testing.assert_allclose(td.asnumpy(), nd.reshape(BATCH, IC, IH, IW), rtol=1e-5)
+
     for algorithm in [
-            nnpack.ConvolutionAlgorithm.AUTO,
-            nnpack.ConvolutionAlgorithm.FFT_8x8,
-            nnpack.ConvolutionAlgorithm.FFT_16x16,
-            nnpack.ConvolutionAlgorithm.WT_8x8,
-            nnpack.ConvolutionAlgorithm.IMPLICIT_GEMM,
-            nnpack.ConvolutionAlgorithm.WT_8x8_FP16,
+        nnpack.ConvolutionAlgorithm.AUTO,
+        nnpack.ConvolutionAlgorithm.FFT_8x8,
+        nnpack.ConvolutionAlgorithm.FFT_16x16,
+        nnpack.ConvolutionAlgorithm.WT_8x8,
+        nnpack.ConvolutionAlgorithm.IMPLICIT_GEMM,
+        nnpack.ConvolutionAlgorithm.WT_8x8_FP16,
     ]:
         for with_bias in [True, False]:
             verify(algorithm=algorithm, with_bias=with_bias)
@@ -153,19 +159,18 @@ def test_convolution_inference_without_weight_transform():
     PAD = 1
     STRIDE = 1
 
-    OH = (IH + 2*PAD - K) + 1
-    OW = (IW + 2*PAD - K) + 1
+    OH = (IH + 2 * PAD - K) + 1
+    OW = (IW + 2 * PAD - K) + 1
     dshape = (BATCH, IC, IH, IW)
     kshape = (OC, IC, K, K)
-    bshape = (OC, )
+    bshape = (OC,)
     oshape = (BATCH, OC, OH, OW)
 
-    data = te.placeholder(dshape, name='data')
-    kernel = te.placeholder(kshape, name='kernel')
-    bias = te.placeholder(bshape, name='bias')
-    def verify(target="llvm",
-               algorithm=nnpack.ConvolutionAlgorithm.AUTO,
-               with_bias=True):
+    data = te.placeholder(dshape, name="data")
+    kernel = te.placeholder(kshape, name="kernel")
+    bias = te.placeholder(bshape, name="bias")
+
+    def verify(target="llvm", algorithm=nnpack.ConvolutionAlgorithm.AUTO, with_bias=True):
         if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True):
             pytest.skip("extern function is not available")
         if not nnpack.is_available():
@@ -173,11 +178,16 @@ def test_convolution_inference_without_weight_transform():
 
         ctx = tvm.cpu(0)
         transformed_kernel = nnpack.convolution_inference_weight_transform(
-            kernel, algorithm=algorithm)
+            kernel, algorithm=algorithm
+        )
         output = nnpack.convolution_inference_without_weight_transform(
-            data, transformed_kernel, bias if with_bias else None,
-            [PAD, PAD, PAD, PAD], [STRIDE, STRIDE],
-            algorithm=algorithm)
+            data,
+            transformed_kernel,
+            bias if with_bias else None,
+            [PAD, PAD, PAD, PAD],
+            [STRIDE, STRIDE],
+            algorithm=algorithm,
+        )
 
         s = te.create_schedule(output.op)
 
@@ -185,15 +195,21 @@ def test_convolution_inference_without_weight_transform():
 
         na = np.random.uniform(size=dshape).astype(data.dtype)
         nb = np.random.uniform(size=kshape).astype(kernel.dtype)
-        nc = np.random.uniform(size=bshape).astype(bias.dtype) if with_bias else np.zeros(bshape, dtype=bias.dtype)
+        nc = (
+            np.random.uniform(size=bshape).astype(bias.dtype)
+            if with_bias
+            else np.zeros(bshape, dtype=bias.dtype)
+        )
         ta = tvm.nd.array(na, ctx)
         tb = tvm.nd.array(nb, ctx)
         tc = tvm.nd.array(nc, ctx)
         td = tvm.nd.array(np.zeros(oshape, dtype=output.dtype), ctx)
         f(ta, tb, tc, td)
-        nd = np_conv(np.reshape(na, (BATCH, IC, IH, IW)), nb, PAD, STRIDE) + nc.reshape(1, bshape[0], 1, 1)
-        tvm.testing.assert_allclose(
-            td.asnumpy(), nd.reshape(BATCH, IC, IH, IW), rtol=1e-5)
+        nd = np_conv(np.reshape(na, (BATCH, IC, IH, IW)), nb, PAD, STRIDE) + nc.reshape(
+            1, bshape[0], 1, 1
+        )
+        tvm.testing.assert_allclose(td.asnumpy(), nd.reshape(BATCH, IC, IH, IW), rtol=1e-5)
+
     for algorithm in [nnpack.ConvolutionAlgorithm.WT_8x8]:
         for with_bias in [True, False]:
             verify(algorithm=algorithm, with_bias=with_bias)
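
The np_conv helper above zero-pads each channel before the scipy convolution, and black adds spaces around the slice colon when a bound is an expression. A self-contained sketch of symmetric zero-padding in that formatted style (not patch content; the padding amounts here are my own):

import numpy as np

pad_top, pad_left = 1, 1
in_height, in_width = 3, 3
channel = np.arange(1, 10, dtype="float32").reshape(in_height, in_width)

apad = np.zeros((in_height + 2 * pad_top, in_width + 2 * pad_left), dtype="float32")
# complex slice bounds get spaces around the colon under black
apad[pad_top : pad_top + in_height, pad_left : pad_left + in_width] = channel

assert apad[pad_top, pad_left] == channel[0, 0]
assert apad[0, 0] == 0.0
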
index ccc122f..6e9cf3a 100644
@@ -17,8 +17,9 @@
 
 """Relay to ONNX serialization test cases"""
 import pytest
-pytest.importorskip('onnx')
-pytest.importorskip('onnxruntime')
+
+pytest.importorskip("onnx")
+pytest.importorskip("onnxruntime")
 
 import numpy as np
 import onnxruntime as rt
@@ -30,7 +31,7 @@ from tvm.contrib.target.onnx import to_onnx
 
 def func_to_onnx(func, name):
     mod = tvm.IRModule()
-    mod['main'] = func
+    mod["main"] = func
     onnx_model = to_onnx(mod, {}, name, path=None)
     return onnx_model.SerializeToString()
 
@@ -46,8 +47,8 @@ def run_onnx(onnx_model, input_data):
 
 
 def run_relay(func, data_tuple):
-    target = 'llvm'
-    ctx = tvm.context('llvm', 0)
+    target = "llvm"
+    ctx = tvm.context("llvm", 0)
     intrp = relay.create_executor("graph", ctx=ctx, target=target)
     relay_res = intrp.evaluate(func)(*data_tuple)
 
@@ -68,7 +69,7 @@ def verify_results(relay_func, indata, test_name, rtol=1e-7, atol=0):
 
 
 def test_add():
-    dtype = 'float32'
+    dtype = "float32"
     t1 = relay.TensorType((5, 10, 5))
     t2 = relay.TensorType((5, 10, 5))
     x = relay.var("x", t1, dtype=dtype)
@@ -79,14 +80,14 @@ def test_add():
     x_data = np.random.rand(5, 10, 5).astype(dtype)
     y_data = np.random.rand(5, 10, 5).astype(dtype)
 
-    verify_results(func, [x_data, y_data], 'test_add')
+    verify_results(func, [x_data, y_data], "test_add")
 
 
 def test_bias_add():
-    for dtype in ['float16', 'float32']:
+    for dtype in ["float16", "float32"]:
         xshape = (10, 2, 3, 4)
         bshape = (2,)
-        rtol = 1e-2 if dtype == 'float16' else 1e-5
+        rtol = 1e-2 if dtype == "float16" else 1e-5
         x = relay.var("x", shape=xshape, dtype=dtype)
         bias = relay.var("bias", dtype=dtype)
         z = relay.nn.bias_add(x, bias)
@@ -95,73 +96,82 @@ def test_bias_add():
         x_data = np.random.uniform(size=xshape).astype(dtype)
         y_data = np.random.uniform(size=bshape).astype(dtype)
 
-        verify_results(func, [x_data, y_data], 'test_bias_add', rtol=rtol)
+        verify_results(func, [x_data, y_data], "test_bias_add", rtol=rtol)
 
 
 def test_conv2d():
-    def verify_conv2d(dtype, scale, dshape, kshape,
-                      padding=(1, 1),
-                      groups=1,
-                      dilation=(1, 1),
-                      **attrs):
+    def verify_conv2d(
+        dtype, scale, dshape, kshape, padding=(1, 1), groups=1, dilation=(1, 1), **attrs
+    ):
         x = relay.var("x", shape=dshape, dtype=dtype)
         w = relay.var("w", shape=kshape, dtype=dtype)
-        y = relay.nn.conv2d(x, w,
-                            padding=padding,
-                            dilation=dilation,
-                            groups=groups,
-                            **attrs)
+        y = relay.nn.conv2d(x, w, padding=padding, dilation=dilation, groups=groups, **attrs)
         func = relay.Function([x, w], y)
         data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
         kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
-        verify_results(func, [data, kernel], 'test_conv2d', rtol=1e-5, atol=1e-5)
+        verify_results(func, [data, kernel], "test_conv2d", rtol=1e-5, atol=1e-5)
 
     dshape = (1, 32, 18, 18)
     kshape = (32, 1, 3, 3)
-    verify_conv2d("float32", 1, dshape, kshape,
-                  padding=(1, 1), channels=32, groups=32, kernel_size=(3, 3))
+    verify_conv2d(
+        "float32", 1, dshape, kshape, padding=(1, 1), channels=32, groups=32, kernel_size=(3, 3)
+    )
 
     dshape = (1, 32, 18, 18)
     kshape = (32, 4, 3, 3)
-    verify_conv2d("float32", 1, dshape, kshape,
-                  padding=(1, 1), channels=32, groups=8, kernel_size=(3, 3))
+    verify_conv2d(
+        "float32", 1, dshape, kshape, padding=(1, 1), channels=32, groups=8, kernel_size=(3, 3)
+    )
 
     # also group conv2d
     dshape = (1, 32, 18, 18)
     kshape = (64, 1, 3, 3)
-    verify_conv2d("float32", 1, dshape, kshape,
-                  padding=(1, 1), channels=64, groups=32, kernel_size=(3, 3))
+    verify_conv2d(
+        "float32", 1, dshape, kshape, padding=(1, 1), channels=64, groups=32, kernel_size=(3, 3)
+    )
 
     # normal conv2d
     dshape = (1, 3, 224, 224)
     kshape = (10, 3, 3, 3)
-    verify_conv2d("float32", 1, dshape, kshape,
-                  padding=(1, 1), channels=10, kernel_size=(3, 3))
+    verify_conv2d("float32", 1, dshape, kshape, padding=(1, 1), channels=10, kernel_size=(3, 3))
 
     dshape = (1, 3, 224, 224)
     kshape = (10, 3, 3, 3)
-    verify_conv2d("float32", 1, dshape, kshape,
-                  padding=(2, 2), channels=10, kernel_size=(3, 3))
+    verify_conv2d("float32", 1, dshape, kshape, padding=(2, 2), channels=10, kernel_size=(3, 3))
 
     dshape = (1, 3, 18, 18)
     kshape = (10, 3, 3, 3)
-    verify_conv2d("float32", 1, dshape, kshape,
-                  padding=(1, 1), channels=10, kernel_size=(3, 3), dilation=(3, 3))
+    verify_conv2d(
+        "float32",
+        1,
+        dshape,
+        kshape,
+        padding=(1, 1),
+        channels=10,
+        kernel_size=(3, 3),
+        dilation=(3, 3),
+    )
 
     dshape = (1, 3, 18, 18)
     kshape = (10, 3, 2, 2)
-    verify_conv2d("float32", 1, dshape, kshape,
-                  padding=(2, 2), channels=10, kernel_size=(2, 2), dilation=(1, 1))
+    verify_conv2d(
+        "float32",
+        1,
+        dshape,
+        kshape,
+        padding=(2, 2),
+        channels=10,
+        kernel_size=(2, 2),
+        dilation=(1, 1),
+    )
 
     dshape = (1, 3, 18, 18)
     kshape = (10, 3, 4, 4)
-    verify_conv2d("float32", 1, dshape, kshape,
-                  padding=(1, 1), channels=10, kernel_size=(4, 4))
+    verify_conv2d("float32", 1, dshape, kshape, padding=(1, 1), channels=10, kernel_size=(4, 4))
 
     dshape = (1, 3, 18, 18)
     kshape = (10, 3, 4, 4)
-    verify_conv2d("float32", 1, dshape, kshape,
-                  padding=(1, 1), channels=10, kernel_size=(4, 4))
+    verify_conv2d("float32", 1, dshape, kshape, padding=(1, 1), channels=10, kernel_size=(4, 4))
 
 
 def test_reshape():
@@ -171,7 +181,7 @@ def test_reshape():
 
         func = relay.Function([x], z)
         x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
-        verify_results(func, [x_data], 'test_reshape', rtol=1e-5, atol=1e-5)
+        verify_results(func, [x_data], "test_reshape", rtol=1e-5, atol=1e-5)
 
     verify_reshape((2, 3, 4), tuple(np.array([4, 2, 3], dtype=np.int64)))
     verify_reshape((2, 3, 4), tuple(np.array([2, 0, 0], dtype=np.int64)))
@@ -185,7 +195,7 @@ def test_transpose():
         z = relay.transpose(x, newshape)
         func = relay.Function([x], z)
         x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
-        verify_results(func, [x_data], 'test_transpose', rtol=1e-5, atol=1e-5)
+        verify_results(func, [x_data], "test_transpose", rtol=1e-5, atol=1e-5)
 
     verify_reshape((1, 2, 3, 4), (0, 2, 3, 1))
     verify_reshape((1, 2, 3, 4), (0, 3, 2, 1))
@@ -198,7 +208,7 @@ def test_dense():
         func = relay.Function([data, weight], relay.nn.dense(data, weight))
         x_data = np.random.uniform(size=d_shape).astype("float32")
         w_data = np.random.uniform(size=w_shape).astype("float32")
-        verify_results(func, [x_data, w_data], 'test_dense', rtol=1e-5, atol=1e-5)
+        verify_results(func, [x_data, w_data], "test_dense", rtol=1e-5, atol=1e-5)
 
     verify_dense((1, 8), (16, 8))
     verify_dense((1, 4), (3, 4))
@@ -207,13 +217,16 @@ def test_dense():
 def test_max_pool():
     def verify_max_pool(x_shape, pool_size, strides, padding, ceil_mode):
         x = relay.var("x", relay.TensorType(x_shape, "float32"))
-        y = tvm.relay.nn.max_pool2d(x, pool_size=pool_size, strides=strides, padding=padding,
-                                    ceil_mode=ceil_mode)
+        y = tvm.relay.nn.max_pool2d(
+            x, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode
+        )
         func = relay.Function([x], y)
         x_data = np.random.uniform(size=x_shape).astype("float32")
-        verify_results(func, [x_data], 'test_max_pool', rtol=1e-5, atol=1e-5)
+        verify_results(func, [x_data], "test_max_pool", rtol=1e-5, atol=1e-5)
 
-    verify_max_pool((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), ceil_mode=False)
+    verify_max_pool(
+        (1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), ceil_mode=False
+    )
 
 
 def test_batch_flatten():
@@ -221,7 +234,7 @@ def test_batch_flatten():
         data = relay.var("data", relay.TensorType(d_shape, "float32"))
         func = relay.Function([data], relay.nn.batch_flatten(data))
         x_data = np.random.uniform(size=d_shape).astype("float32")
-        verify_results(func, [x_data], 'test_batch_flatten', rtol=1e-5, atol=1e-5)
+        verify_results(func, [x_data], "test_batch_flatten", rtol=1e-5, atol=1e-5)
 
     verify_test_batch_flatten((1, 2, 3, 4))
     verify_test_batch_flatten((1, 8))
@@ -229,7 +242,7 @@ def test_batch_flatten():
 
 def test_batch_norm():
     def verify_batch_norm(axis=1):
-        for dtype in ['float16', 'float32']:
+        for dtype in ["float16", "float32"]:
             data = relay.var("data", relay.TensorType((2, 4, 4, 1), dtype))
             gamma_shape = (data.type_annotation.shape[axis].value,)
             beta = relay.var("beta", relay.TensorType(gamma_shape, dtype))
@@ -244,8 +257,13 @@ def test_batch_norm():
             gamma = np.random.uniform(size=gamma_shape).astype(dtype)
             moving_mean = np.random.uniform(size=gamma_shape).astype(dtype)
             moving_var = np.random.uniform(size=gamma_shape).astype(dtype)
-            verify_results(func, [x_data, gamma, beta, moving_mean, moving_var], 'test_batch_norm', rtol=1e-1,
-                           atol=1e-1)
+            verify_results(
+                func,
+                [x_data, gamma, beta, moving_mean, moving_var],
+                "test_batch_norm",
+                rtol=1e-1,
+                atol=1e-1,
+            )
 
     verify_batch_norm(axis=1)
     verify_batch_norm(axis=3)
@@ -253,26 +271,26 @@ def test_batch_norm():
 
 def test_pad():
     def verify_pad():
-        for dtype in ['float16', 'float32']:
+        for dtype in ["float16", "float32"]:
             dshape = (4, 10, 7, 7)
             x = relay.var("x", shape=dshape, dtype=dtype)
             y = relay.nn.pad(x, ((1, 1), (2, 2), (3, 3), (4, 4)))
             func = relay.Function([x], y)
             x_data = np.random.uniform(size=dshape).astype(dtype)
-            verify_results(func, [x_data], 'test_pad', rtol=1e-5, atol=1e-5)
+            verify_results(func, [x_data], "test_pad", rtol=1e-5, atol=1e-5)
 
     verify_pad()
 
 
 def test_sofmax():
     def verify_sofmax():
-        for dtype in ['float32']:
+        for dtype in ["float32"]:
             shape = (10, 4)
             x = relay.var("x", shape=shape, dtype=dtype)
             y = relay.nn.softmax(x, axis=1)
             func = relay.Function([x], y)
             x_data = np.random.uniform(size=shape).astype(dtype)
-            verify_results(func, [x_data], 'test_softmax', rtol=1e-5, atol=1e-5)
+            verify_results(func, [x_data], "test_softmax", rtol=1e-5, atol=1e-5)
 
     verify_sofmax()
 
@@ -283,21 +301,27 @@ def test_squeeze():
         z = relay.squeeze(x, axis=axis)
         func = relay.Function([x], z)
         x_data = np.random.random_sample(shape).astype(dtype)
-        verify_results(func, [x_data], 'test_squeeze', rtol=1e-5, atol=1e-5)
+        verify_results(func, [x_data], "test_squeeze", rtol=1e-5, atol=1e-5)
 
     verify_squeeze((1, 3, 2, 5), "float32", None)
-    verify_squeeze((1, 3, 1), "float32", [2, ])
+    verify_squeeze(
+        (1, 3, 1),
+        "float32",
+        [
+            2,
+        ],
+    )
     verify_squeeze((1, 2, 1, 2, 1), "float32", [0, 2])
 
 
 def test_mean():
     def verify_mean(data_shape, axis, exclude, keepdims):
         dtype = "float32"
-        x = relay.var('x', shape=data_shape, dtype=dtype)
+        x = relay.var("x", shape=data_shape, dtype=dtype)
         y = relay.mean(x, axis, keepdims, exclude)
         func = relay.Function([x], y)
         x_data = np.random.uniform(size=data_shape).astype(dtype)
-        verify_results(func, [x_data], 'test_mean', rtol=1e-5, atol=1e-5)
+        verify_results(func, [x_data], "test_mean", rtol=1e-5, atol=1e-5)
 
     verify_mean((1, 2), 0, False, False)
     verify_mean((1, 2), 0, True, False)
@@ -314,7 +338,7 @@ def test_split():
         func = relay.Function([x], y.astuple())
         x_data = np.random.uniform(size=dshape).astype(dtype)
 
-        verify_results(func, [x_data], 'test_split', rtol=1e-5, atol=1e-5)
+        verify_results(func, [x_data], "test_split", rtol=1e-5, atol=1e-5)
 
     verify_split((5, 5, 2, 2), 5, axis=1)
     verify_split((5, 5, 2, 2), 5, axis=0)
@@ -332,33 +356,37 @@ def test_concatenate():
 
         out_tensor = relay.concatenate(in_vars, axis)
         func = relay.Function(in_vars, out_tensor)
-        verify_results(func, in_data, 'test_concatenate', rtol=1e-5, atol=1e-5)
+        verify_results(func, in_data, "test_concatenate", rtol=1e-5, atol=1e-5)
 
     verify_concatenate([(2,), (2,), (2,)], -1)
     verify_concatenate([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1)
     verify_concatenate([(1, 2, 4), (1, 2, 3), (1, 2, 7), (1, 2, 8), (1, 2, 1)], -1)
-    verify_concatenate([(5, 6, 7, 3),
-                        (16, 6, 7, 3),
-                        (12, 6, 7, 3),
-                        (8, 6, 7, 3),
-                        (2, 6, 7, 3)], 0)
+    verify_concatenate([(5, 6, 7, 3), (16, 6, 7, 3), (12, 6, 7, 3), (8, 6, 7, 3), (2, 6, 7, 3)], 0)
     verify_concatenate([(1, 14400), (1, 2400), (1, 640), (1, 240)], 1)
 
 
 def test_strided_slice():
     def verify_strided_slice(dshape, begin, end, strides, mode):
         x = relay.var("x", relay.TensorType(dshape, "float32"))
-        if mode == 'size':
+        if mode == "size":
             strides = None
         z = relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=mode)
         func = relay.Function([x], z)
         x_data = np.random.uniform(size=dshape).astype("float32")
-        verify_results(func, [x_data], 'test_strided_slice', rtol=1e-5, atol=1e-5)
+        verify_results(func, [x_data], "test_strided_slice", rtol=1e-5, atol=1e-5)
 
-    for mode in ['end', 'size']:
+    for mode in ["end", "size"]:
         verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 2, 3], None, mode)
         verify_strided_slice((3, 4, 3), [1, -1, 0], [4, -1, 3], [1, 2], mode)
-        verify_strided_slice((3, 4, 3), [1, ], [4, -3], None, mode)
+        verify_strided_slice(
+            (3, 4, 3),
+            [
+                1,
+            ],
+            [4, -3],
+            None,
+            mode,
+        )
         verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], mode)
         verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, -3], [2, 1, 1], mode)
         verify_strided_slice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], mode)
@@ -372,10 +400,7 @@ def test_strided_slice():
 
 
 def test_cmp_type():
-    for op, ref in ((relay.greater, np.greater),
-                    (relay.less, np.less),
-                    (relay.equal, np.equal)
-                    ):
+    for op, ref in ((relay.greater, np.greater), (relay.less, np.less), (relay.equal, np.equal)):
         x_shape = (10, 4)
         y_shape = (5, 10, 1)
         t1 = relay.TensorType(x_shape)
@@ -386,19 +411,23 @@ def test_cmp_type():
         x_data = np.random.rand(*x_shape).astype(t1.dtype)
         y_data = np.random.rand(*y_shape).astype(t2.dtype)
         func = relay.Function([x, y], z)
-        verify_results(func, [x_data, y_data], 'test_cmp_type', rtol=1e-5, atol=1e-5)
+        verify_results(func, [x_data, y_data], "test_cmp_type", rtol=1e-5, atol=1e-5)
 
 
 def test_unary_identity():
     for dtype in ["int16", "float32", "float64"]:
-        for op, ref in [(relay.zeros_like, np.zeros_like),
-                        (relay.ones_like, np.ones_like)]:
+        for op, ref in [(relay.zeros_like, np.zeros_like), (relay.ones_like, np.ones_like)]:
             shape = (8, 9, 4)
             x = relay.var("x", relay.TensorType(shape, dtype))
             y = op(x)
-            func = relay.Function([x, ], y)
+            func = relay.Function(
+                [
+                    x,
+                ],
+                y,
+            )
             x_data = np.random.rand(*shape).astype(dtype)
-            verify_results(func, [x_data], 'test_cmp_type', rtol=1e-5, atol=1e-5)
+            verify_results(func, [x_data], "test_cmp_type", rtol=1e-5, atol=1e-5)
 
 
 def test_binary_op():
@@ -411,53 +440,56 @@ def test_binary_op():
         x_data = np.random.rand(5, 10, 5).astype(dtype)
         y_data = np.random.rand(5, 10, 5).astype(dtype)
         func = relay.Function([x, y], z)
-        verify_results(func, [x_data, y_data], 'test_binary_op', rtol=1e-5, atol=1e-5)
-
-    for opfunc, ref in [(relay.add, np.add),
-                        (relay.subtract, np.subtract),
-                        (relay.multiply, np.multiply),
-                        (relay.divide, np.divide),
-                        ]:
-        for dtype in ['float32']:
+        verify_results(func, [x_data, y_data], "test_binary_op", rtol=1e-5, atol=1e-5)
+
+    for opfunc, ref in [
+        (relay.add, np.add),
+        (relay.subtract, np.subtract),
+        (relay.multiply, np.multiply),
+        (relay.divide, np.divide),
+    ]:
+        for dtype in ["float32"]:
             check_binary_op(opfunc, dtype)
 
 
 def test_tuple_types():
-    def verify_tuple_types(dshape, indices_or_sections, axis=None, dtype = "float32"):
+    def verify_tuple_types(dshape, indices_or_sections, axis=None, dtype="float32"):
         x = relay.var("x", relay.ty.TensorType(dshape, dtype))
         y = relay.split(x, indices_or_sections, axis=axis)
         z = relay.concatenate(y, axis=axis)
         func = relay.Function([x], z)
         x_data = np.random.uniform(size=dshape).astype(dtype)
-        verify_results(func, [x_data], 'test_tuple_types', rtol=1e-5, atol=1e-5)
+        verify_results(func, [x_data], "test_tuple_types", rtol=1e-5, atol=1e-5)
 
         split_z = relay.split(z, indices_or_sections, axis=axis)
         func = relay.Function([x], split_z.astuple())
-        verify_results(func, [x_data], 'test_tuple_types', rtol=1e-5, atol=1e-5)
+        verify_results(func, [x_data], "test_tuple_types", rtol=1e-5, atol=1e-5)
 
         out = relay.Tuple([y[0] + y[1], y[0] - y[1]])
         func = relay.Function([x], out)
-        verify_results(func, [x_data], 'test_tuple_types', rtol=1e-5, atol=1e-5)
+        verify_results(func, [x_data], "test_tuple_types", rtol=1e-5, atol=1e-5)
 
         z = relay.concatenate(out, axis=axis)
         func = relay.Function([x], z)
-        verify_results(func, [x_data], 'test_tuple_types', rtol=1e-5, atol=1e-5)
+        verify_results(func, [x_data], "test_tuple_types", rtol=1e-5, atol=1e-5)
 
     verify_tuple_types((5, 5, 2, 2), 5, axis=1)
     verify_tuple_types((5, 5, 2, 2), 5, axis=0)
     verify_tuple_types((5, 5, 2, 2), [1, 3, 4], axis=0)
     verify_tuple_types((5, 5, 2, 2), [1, 3, 4], axis=1)
 
+
 def test_layout_transform():
     def verify_layout_transform(dshape, src_layout, dst_layout, dtype="float32"):
         x = relay.var("x", relay.ty.TensorType(dshape, dtype))
         y = relay.layout_transform(x, src_layout, dst_layout)
         func = relay.Function([x], y)
         x_data = np.random.uniform(size=dshape).astype(dtype)
-        verify_results(func, [x_data], 'test_layout_transform', rtol=1e-5, atol=1e-5)
+        verify_results(func, [x_data], "test_layout_transform", rtol=1e-5, atol=1e-5)
+
+    verify_layout_transform((1, 3, 8, 8), "NCHW", "NHWC")
+    verify_layout_transform((1, 8, 8, 3), "NHWC", "NCHW")
 
-    verify_layout_transform((1, 3, 8, 8), 'NCHW', 'NHWC')
-    verify_layout_transform((1, 8, 8, 3), 'NHWC', 'NCHW')
 
 def test_clip():
     def verify_clip(dshape, a_min, a_max, dtype="float32"):
@@ -465,23 +497,25 @@ def test_clip():
         y = relay.clip(x, a_min, a_max)
         func = relay.Function([x], y)
         x_data = np.random.uniform(size=dshape).astype(dtype)
-        verify_results(func, [x_data], 'test_clip', rtol=1e-5, atol=1e-5)
+        verify_results(func, [x_data], "test_clip", rtol=1e-5, atol=1e-5)
 
     verify_clip((5, 5, 2, 5), 0, 0.2)
     verify_clip((5, 5, 2, 5), 0.2, 0.5)
 
+
 def test_expand_dims():
     def verify_expand_dims(dshape, axis, num_newaxis, dtype="float32"):
         x = relay.var("x", relay.ty.TensorType(dshape, dtype))
         y = relay.expand_dims(x, axis, num_newaxis)
         func = relay.Function([x], y)
         x_data = np.random.uniform(size=dshape).astype(dtype)
-        verify_results(func, [x_data], 'test_expand_dims', rtol=1e-5, atol=1e-5)
+        verify_results(func, [x_data], "test_expand_dims", rtol=1e-5, atol=1e-5)
 
     verify_expand_dims((1, 1001), 0, 2)
     verify_expand_dims((1, 1, 1001), 2, 2)
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     test_add()
     test_bias_add()
     test_conv2d()
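
The oddest-looking rewrites in this file, `[2, ]` and `[1, ]` exploding onto one element per line in test_squeeze and test_strided_slice, come from black's magic trailing comma: a trailing comma inside brackets tells black to keep the collection multi-line. A standalone illustration, not patch content:

axis_exploded = [
    2,
]  # source had [2, ] with a trailing comma, so black keeps it expanded
axis_compact = [2]  # no trailing comma, so it stays on one line
assert axis_exploded == axis_compact == [2]
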
index 8766c0d..a3f3717 100644
@@ -17,8 +17,9 @@
 
 """Relay to ONNX target test cases"""
 import pytest
-pytest.importorskip('onnx')
-pytest.importorskip('onnxruntime')
+
+pytest.importorskip("onnx")
+pytest.importorskip("onnxruntime")
 
 from collections import OrderedDict
 import numpy as np
@@ -48,7 +49,7 @@ def run_onnx(mod, params, name, input_data):
     return res[0]
 
 
-def get_data(in_data_shapes, dtype='float32'):
+def get_data(in_data_shapes, dtype="float32"):
     in_data = OrderedDict()
     for name, shape in in_data_shapes.items():
         in_data[name] = np.random.uniform(size=shape).astype(dtype)
@@ -56,8 +57,8 @@ def get_data(in_data_shapes, dtype='float32'):
 
 
 def run_relay(mod, params, in_data):
-    target = 'llvm'
-    ctx = tvm.context('llvm', 0)
+    target = "llvm"
+    ctx = tvm.context("llvm", 0)
     intrp = relay.create_executor("graph", mod, ctx=ctx, target=target)
     in_data = [tvm.nd.array(value) for value in in_data.values()]
     return intrp.evaluate()(*in_data, **params).asnumpy()
@@ -65,7 +66,7 @@ def run_relay(mod, params, in_data):
 
 def _verify_results(mod, params, in_data):
     a = run_relay(mod, params, in_data)
-    b = run_onnx(mod, params, 'test_resent', in_data.values())
+    b = run_onnx(mod, params, "test_resent", in_data.values())
     np.testing.assert_allclose(a, b, rtol=1e-7, atol=1e-7)
 
 
@@ -74,31 +75,30 @@ def test_resnet():
     in_data_shapes = OrderedDict({"data": (1, 3, 224, 224)})
     in_data = get_data(in_data_shapes, dtype="float32")
     for n in [18, 34, 50, 101]:
-        mod, params = tvm.relay.testing.resnet.get_workload(
-            1, num_class, num_layers=n)
+        mod, params = tvm.relay.testing.resnet.get_workload(1, num_class, num_layers=n)
         _verify_results(mod, params, in_data)
 
 
 def test_squeezenet():
     in_data_shapes = OrderedDict({"data": (1, 3, 224, 224)})
     in_data = get_data(in_data_shapes, dtype="float32")
-    for version in ['1.0', '1.1']:
+    for version in ["1.0", "1.1"]:
         mod, params = tvm.relay.testing.squeezenet.get_workload(1, version=version)
         _verify_results(mod, params, in_data)
 
 
 @pytest.mark.skip("USE_TARGET_ONNX should be ON")
 def test_partition():
-    in_1 = relay.var('in_1', shape=(10, 10), dtype='float32')
-    in_2 = relay.var('in_2', shape=(10, 10), dtype='float32')
-    in_3 = relay.var('in_3', shape=(10, 10), dtype='float32')
-    in_4 = relay.var('in_4', shape=(10, 10), dtype='float32')
-    in_5 = relay.var('in_5', shape=(10, 10), dtype='float32')
-    in_6 = relay.var('in_6', shape=(10, 10), dtype='float32')
-    in_7 = relay.var('in_7', shape=(10, 10), dtype='float32')
-    in_8 = relay.var('in_8', shape=(10, 10), dtype='float32')
-    in_9 = relay.var('in_9', shape=(10, 10), dtype='float32')
-    in_10 = relay.var('in_10', shape=(10, 10), dtype='float32')
+    in_1 = relay.var("in_1", shape=(10, 10), dtype="float32")
+    in_2 = relay.var("in_2", shape=(10, 10), dtype="float32")
+    in_3 = relay.var("in_3", shape=(10, 10), dtype="float32")
+    in_4 = relay.var("in_4", shape=(10, 10), dtype="float32")
+    in_5 = relay.var("in_5", shape=(10, 10), dtype="float32")
+    in_6 = relay.var("in_6", shape=(10, 10), dtype="float32")
+    in_7 = relay.var("in_7", shape=(10, 10), dtype="float32")
+    in_8 = relay.var("in_8", shape=(10, 10), dtype="float32")
+    in_9 = relay.var("in_9", shape=(10, 10), dtype="float32")
+    in_10 = relay.var("in_10", shape=(10, 10), dtype="float32")
 
     begin0 = compiler_begin(in_1, "onnx")
     begin1 = compiler_begin(in_2, "onnx")
@@ -147,11 +147,11 @@ def test_partition():
 
     func = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end7)
 
-    target = 'llvm'
+    target = "llvm"
     mod = IRModule.from_expr(func)
     mod = transform.PartitionGraph()(mod)
 
-    with tvm.transform.PassContext(opt_level=3, disabled_pass=['FuseOps']):
+    with tvm.transform.PassContext(opt_level=3, disabled_pass=["FuseOps"]):
         graph_json, mod1, params = relay.build(mod, target)
 
     assert mod1.type_key == "metadata"
@@ -161,9 +161,8 @@ def test_partition():
     assert mod1.imported_modules[1].get_source()
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_resnet()
     test_squeezenet()
     # test_partition needs USE_TARGET_ONNX to be ON
     test_partition()
-
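
Almost every hunk in these test files shows the same quote change: black normalizes string literals to double quotes unless that would force extra escaping. A quick standalone illustration, not patch content:

target = "llvm"  # 'llvm' in the source becomes "llvm"
message = 'say "hello"'  # left single-quoted: switching would require \" escapes
assert target == "llvm" and message.count('"') == 2
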
index c3601c7..fd87a06 100644
@@ -21,10 +21,11 @@ from tvm.contrib import random
 from tvm import rpc
 import tvm.testing
 
+
 def test_randint():
     m = 10240
     n = 10240
-    A = random.randint(-127, 128, size=(m, n), dtype='int32')
+    A = random.randint(-127, 128, size=(m, n), dtype="int32")
     s = te.create_schedule(A.op)
 
     def verify(target="llvm"):
@@ -42,6 +43,7 @@ def test_randint():
         assert abs(np.mean(na)) < 0.3
         assert np.min(na) == -127
         assert np.max(na) == 127
+
     verify()
 
 
@@ -66,6 +68,7 @@ def test_uniform():
         assert abs(np.mean(na) - 0.5) < 1e-1
         assert abs(np.min(na) - 0.0) < 1e-3
         assert abs(np.max(na) - 1.0) < 1e-3
+
     verify()
 
 
@@ -89,8 +92,10 @@ def test_normal():
         na = a.asnumpy()
         assert abs(np.mean(na) - 3) < 1e-1
         assert abs(np.std(na) - 4) < 1e-2
+
     verify()
 
+
 @tvm.testing.uses_gpu
 def test_random_fill():
     def test_local(ctx, dtype):
@@ -127,15 +132,27 @@ def test_random_fill():
         np_values = value.asnumpy()
         assert np.isfinite(np_values * np_values + np_values).any()
 
-    for dtype in ["bool", "int8", "uint8", "int16", "uint16", "int32", "int32",
-                  "int64", "uint64", "float16", "float32", "float64"]:
+    for dtype in [
+        "bool",
+        "int8",
+        "uint8",
+        "int16",
+        "uint16",
+        "int32",
+        "int32",
+        "int64",
+        "uint64",
+        "float16",
+        "float32",
+        "float64",
+    ]:
         for _, ctx in tvm.testing.enabled_targets():
             test_local(ctx, dtype)
         test_rpc(dtype)
 
+
 if __name__ == "__main__":
     test_randint()
     test_uniform()
     test_normal()
     test_random_fill()
-
index f5ec5be..670268f 100644
@@ -19,13 +19,14 @@ from tvm import te
 import numpy as np
 from tvm.contrib import rocblas
 
+
 @tvm.testing.requires_rocm
 def test_matmul_add():
     n = 1024
     l = 128
     m = 235
-    A = te.placeholder((n, l), name='A')
-    B = te.placeholder((l, m), name='B')
+    A = te.placeholder((n, l), name="A")
+    B = te.placeholder((l, m), name="B")
     C = rocblas.matmul(A, B)
     s = te.create_schedule(C.op)
 
@@ -39,8 +40,8 @@ def test_matmul_add():
         b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
         c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
         f(a, b, c)
-        tvm.testing.assert_allclose(
-            c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
+        tvm.testing.assert_allclose(c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
+
     verify()
 
 
index 6cd865e..26d1831 100644
@@ -22,6 +22,7 @@ import time
 import multiprocessing
 from tvm import rpc
 
+
 def rpc_proxy_check():
     """This is a simple test function for RPC Proxy
 
@@ -35,20 +36,25 @@ def rpc_proxy_check():
 
     try:
         from tvm.rpc import proxy
+
         web_port = 8888
         prox = proxy.Proxy("localhost", web_port=web_port)
+
         def check():
             if not tvm.runtime.enabled("rpc"):
                 return
+
             @tvm.register_func("rpc.test2.addone")
             def addone(x):
                 return x + 1
+
             @tvm.register_func("rpc.test2.strcat")
             def addone(name, x):
                 return "%s:%d" % (name, x)
+
             server = multiprocessing.Process(
-                target=proxy.websocket_proxy_server,
-                args=("ws://localhost:%d/ws" % web_port,"x1"))
+                target=proxy.websocket_proxy_server, args=("ws://localhost:%d/ws" % web_port, "x1")
+            )
             # Need to make sure that the connection start after proxy comes up
             time.sleep(0.1)
             server.deamon = True
@@ -58,10 +64,12 @@ def rpc_proxy_check():
             assert f1(10) == 11
             f2 = client.get_function("rpc.test2.strcat")
             assert f2("abc", 11) == "abc:11"
+
         check()
     except ImportError:
         print("Skipping because tornado is not avaliable...")
 
+
 if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
     rpc_proxy_check()
index 2443c70..83cc52f 100644
@@ -22,6 +22,7 @@ import time
 import multiprocessing
 from tvm import rpc
 
+
 def check_server_drop():
     """test when server drops"""
     try:
@@ -37,24 +38,17 @@ def check_server_drop():
             base.recvjson(tclient._sock)
 
         tserver = tracker.Tracker("localhost", 8888)
-        tproxy = proxy.Proxy("localhost", 8881,
-                             tracker_addr=("localhost", tserver.port))
+        tproxy = proxy.Proxy("localhost", 8881, tracker_addr=("localhost", tserver.port))
         tclient = rpc.connect_tracker("localhost", tserver.port)
 
         server0 = rpc.Server(
-            "localhost", port=9099,
-            tracker_addr=("localhost", tserver.port),
-            key="abc")
+            "localhost", port=9099, tracker_addr=("localhost", tserver.port), key="abc"
+        )
         server1 = rpc.Server(
-            "localhost", port=9099,
-            tracker_addr=("localhost", tserver.port),
-            key="xyz")
-        server2 = rpc.Server(
-            "localhost", tproxy.port, is_proxy=True,
-            key="xyz")
-        server3 = rpc.Server(
-            "localhost", tproxy.port, is_proxy=True,
-            key="xyz1")
+            "localhost", port=9099, tracker_addr=("localhost", tserver.port), key="xyz"
+        )
+        server2 = rpc.Server("localhost", tproxy.port, is_proxy=True, key="xyz")
+        server3 = rpc.Server("localhost", tproxy.port, is_proxy=True, key="xyz1")
 
         # Fault tolerence to un-handled requested value
         _put(tclient, [TrackerCode.REQUEST, "abc", "", 1])
@@ -71,6 +65,7 @@ def check_server_drop():
                 time.sleep(sleeptime)
                 f1 = remote.get_function("rpc.test2.addone")
                 assert f1(10) == 11
+
             try:
                 tclient.request_and_run("xyz", myfunc, session_timeout=timeout)
             except RuntimeError:
index 9297a32..5660efa 100644
@@ -18,24 +18,33 @@ import tvm
 from tvm import te
 import numpy as np
 
+
 def test_sort():
     n = 2
     l = 5
     m = 3
-    data = te.placeholder((n, l, m), name='data')
+    data = te.placeholder((n, l, m), name="data")
     sort_num = te.placeholder((n, m), name="sort_num", dtype="int32")
     axis = 1
     is_ascend = False
-    out = te.extern(data.shape, [data, sort_num],
-                     lambda ins, outs: tvm.tir.call_packed(
-                         "tvm.contrib.sort.argsort_nms", ins[0],
-                         ins[1], outs[0], axis, is_ascend),
-                     dtype='int32', name="sort_tensor")
-    input = [[[1, 2, 3], [2, 4.5, 3.5], [1.1, 0.5, 1], [3.2, -5, 0.5], [1.5, 0, 0]],
-             [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]]]
+    out = te.extern(
+        data.shape,
+        [data, sort_num],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.sort.argsort_nms", ins[0], ins[1], outs[0], axis, is_ascend
+        ),
+        dtype="int32",
+        name="sort_tensor",
+    )
+    input = [
+        [[1, 2, 3], [2, 4.5, 3.5], [1.1, 0.5, 1], [3.2, -5, 0.5], [1.5, 0, 0]],
+        [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]],
+    ]
     sort_num_input = [[1, 2, 3], [4, 5, 5]]
-    sorted_index = [[[0, 1, 1], [1, 0, 0], [2, 2, 2], [3, 3, 3], [4, 4, 4]],
-                    [[3, 4, 4], [2, 3, 3], [1, 2, 2], [0, 1, 1], [4, 0, 0]]]
+    sorted_index = [
+        [[0, 1, 1], [1, 0, 0], [2, 2, 2], [3, 3, 3], [4, 4, 4]],
+        [[3, 4, 4], [2, 3, 3], [1, 2, 2], [0, 1, 1], [4, 0, 0]],
+    ]
 
     ctx = tvm.cpu(0)
     target = "llvm"
@@ -47,18 +56,23 @@ def test_sort():
     f(a, b, c)
     tvm.testing.assert_allclose(c.asnumpy(), np.array(sorted_index).astype(out.dtype), rtol=1e-5)
 
+
 def test_sort_np():
     dshape = (1, 2, 3, 4, 5, 6)
     axis = 4
     reduced_shape = (1, 2, 3, 4, 6)
     is_ascend = True
-    data = te.placeholder(dshape, name='data')
+    data = te.placeholder(dshape, name="data")
     sort_num = te.placeholder(reduced_shape, name="sort_num", dtype="int32")
-    out = te.extern(data.shape, [data, sort_num],
-                     lambda ins, outs: tvm.tir.call_packed(
-                         "tvm.contrib.sort.argsort_nms", ins[0],
-                         ins[1], outs[0], axis, is_ascend),
-                     dtype='int32', name="sort_tensor")
+    out = te.extern(
+        data.shape,
+        [data, sort_num],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.sort.argsort_nms", ins[0], ins[1], outs[0], axis, is_ascend
+        ),
+        dtype="int32",
+        name="sort_tensor",
+    )
 
     ctx = tvm.cpu(0)
     target = "llvm"
@@ -74,6 +88,7 @@ def test_sort_np():
     f(a, b, c)
     tvm.testing.assert_allclose(c.asnumpy(), np_out, rtol=1e-5)
 
+
 if __name__ == "__main__":
     test_sort()
     test_sort_np()
index 5e0ca5c..9cea49b 100644
@@ -21,48 +21,50 @@ import tvm.runtime.ndarray as _nd
 import numpy as np
 from collections import namedtuple
 
+
 def test_static_tensor():
-    dtype = 'float32'
-    stype = 'csr'
-    target = 'llvm'
+    dtype = "float32"
+    stype = "csr"
+    target = "llvm"
     ctx = tvm.context(target, 0)
-    m = te.size_var('m')
-    n = te.size_var('n')
-    A = tvmsp.placeholder(shape=(m, n), name='A', dtype=dtype)
-    assert(A.stype == 'csr')
+    m = te.size_var("m")
+    n = te.size_var("n")
+    A = tvmsp.placeholder(shape=(m, n), name="A", dtype=dtype)
+    assert A.stype == "csr"
     n = 3
-    a = np.maximum(np.random.uniform(size=(n,n)).astype(dtype)-.6, 0.)
+    a = np.maximum(np.random.uniform(size=(n, n)).astype(dtype) - 0.6, 0.0)
     a = tvmsp.array(a, ctx)
-    A.data = te.placeholder(a.data.shape, dtype, name='A_data')
-    Ab = tvm.tir.decl_buffer(a.data.shape, dtype, name='A_data')
+    A.data = te.placeholder(a.data.shape, dtype, name="A_data")
+    Ab = tvm.tir.decl_buffer(a.data.shape, dtype, name="A_data")
     binds = {A.data: Ab}
-    C = te.compute(A.data.shape, lambda i: A.data[i] * 2., tag='cs_scatter')
+    C = te.compute(A.data.shape, lambda i: A.data[i] * 2.0, tag="cs_scatter")
     s = te.create_schedule(C.op)
     f = tvm.build(s, [A.data, C], target, binds=binds)
-    c = tvmsp.array(np.zeros((n,n), dtype), ctx)
+    c = tvmsp.array(np.zeros((n, n), dtype), ctx)
     c.data = tvm.nd.empty(a.data.shape, dtype)
     c.indices = a.indices
     c.indptr = a.indptr
     f(a.data, c.data)
-    tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() * 2., rtol=1e-5)
+    tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() * 2.0, rtol=1e-5)
+
 
 def test_dynamic_tensor():
-    dtype = 'float32'
-    stype = 'csr'
-    target = 'llvm'
+    dtype = "float32"
+    stype = "csr"
+    target = "llvm"
     ctx = tvm.context(target, 0)
-    nr, nc, n = te.size_var('nr'), te.size_var('nc'), te.size_var('n')
-    A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, name='A', dtype=dtype)
-    assert(A.stype == 'csr')
-    C = te.compute(A.data.shape, lambda i: A.data[i] * 2., tag='cs_scatter')
+    nr, nc, n = te.size_var("nr"), te.size_var("nc"), te.size_var("n")
+    A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, name="A", dtype=dtype)
+    assert A.stype == "csr"
+    C = te.compute(A.data.shape, lambda i: A.data[i] * 2.0, tag="cs_scatter")
     s = te.create_schedule(C.op)
     _nr, _nc = 3, 5
-    a = np.maximum(np.random.uniform(size=(_nr, _nc)).astype(dtype)-.6, 0.)
+    a = np.maximum(np.random.uniform(size=(_nr, _nc)).astype(dtype) - 0.6, 0.0)
     a = tvmsp.array(a, ctx)
     assert a.data.dtype == a.dtype
-    Ab = namedtuple('CSRBuffer', ['data', 'indices', 'indptr'])
-    Ab.data = tvm.tir.decl_buffer(a.data.shape, a.data.dtype, name='A_data')
-    Ab.indices = tvm.tir.decl_buffer(a.data.shape, a.data.dtype, name='A_indices')
+    Ab = namedtuple("CSRBuffer", ["data", "indices", "indptr"])
+    Ab.data = tvm.tir.decl_buffer(a.data.shape, a.data.dtype, name="A_data")
+    Ab.indices = tvm.tir.decl_buffer(a.data.shape, a.data.dtype, name="A_indices")
     binds = {A.data: Ab.data, A.indices: Ab.indices}
     f = tvm.build(s, [nr, A.data, C], target, binds=binds)
     c = tvmsp.array(np.zeros((_nr, _nc), dtype), ctx)
@@ -70,20 +72,21 @@ def test_dynamic_tensor():
     c.indices = a.indices
     c.indptr = a.indptr
     f(a.data.shape[0], a.data, c.data)
-    tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() * 2., rtol=1e-5)
+    tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() * 2.0, rtol=1e-5)
+
 
 def test_sparse_array_tuple():
-    dtype, itype = 'float32', 'int32'
-    stype = 'csr'
-    target = 'llvm'
+    dtype, itype = "float32", "int32"
+    stype = "csr"
+    target = "llvm"
     ctx = tvm.context(target, 0)
-    nr, nc, n = te.size_var('nr'), te.size_var('nc'), te.size_var('n')
-    A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, name='A', dtype=dtype)
-    assert(A.stype == 'csr')
-    C = te.compute(A.data.shape, lambda i: A.data[i] * 2., tag='cs_scatter')
+    nr, nc, n = te.size_var("nr"), te.size_var("nc"), te.size_var("n")
+    A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, name="A", dtype=dtype)
+    assert A.stype == "csr"
+    C = te.compute(A.data.shape, lambda i: A.data[i] * 2.0, tag="cs_scatter")
     s = te.create_schedule(C.op)
     _nr, _nc = 3, 5
-    a = np.maximum(np.random.uniform(size=(_nr, _nc)).astype(dtype)-.6, 0.)
+    a = np.maximum(np.random.uniform(size=(_nr, _nc)).astype(dtype) - 0.6, 0.0)
     # convert to sparse array tuple
     source_array = a
     ridx, cidx = np.nonzero(source_array)
@@ -91,16 +94,16 @@ def test_sparse_array_tuple():
     a_data = _nd.array(data, ctx)
     indices = np.nonzero(source_array)[1].astype(itype)
     a_indices = _nd.array(indices, ctx)
-    indptr = [0]+np.apply_along_axis(np.count_nonzero, axis=1, arr=source_array).tolist()
+    indptr = [0] + np.apply_along_axis(np.count_nonzero, axis=1, arr=source_array).tolist()
     indptr = np.cumsum(np.array(indptr, itype)).astype(itype)
     a_indptr = _nd.array(indptr, ctx)
     a_init = (a_data, a_indices, a_indptr)
     # construct tvm sparse array with tuple
     a = tvmsp.array(a_init, shape=source_array.shape, ctx=ctx)
     assert a.data.dtype == a.dtype
-    Ab = namedtuple('CSRBuffer', ['data', 'indices', 'indptr'])
-    Ab.data = tvm.tir.decl_buffer(a.data.shape, a.data.dtype, name='A_data')
-    Ab.indices = tvm.tir.decl_buffer(a.data.shape, a.data.dtype, name='A_indices')
+    Ab = namedtuple("CSRBuffer", ["data", "indices", "indptr"])
+    Ab.data = tvm.tir.decl_buffer(a.data.shape, a.data.dtype, name="A_data")
+    Ab.indices = tvm.tir.decl_buffer(a.data.shape, a.data.dtype, name="A_indices")
     binds = {A.data: Ab.data, A.indices: Ab.indices}
     f = tvm.build(s, [nr, A.data, C], target, binds=binds)
     c = tvmsp.array(np.zeros((_nr, _nc), dtype), ctx)
@@ -108,10 +111,10 @@ def test_sparse_array_tuple():
     c.indices = a.indices
     c.indptr = a.indptr
     f(a.data.shape[0], a.data, c.data)
-    tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() * 2., rtol=1e-5)
+    tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() * 2.0, rtol=1e-5)
+
 
 if __name__ == "__main__":
     test_static_tensor()
     test_dynamic_tensor()
     test_sparse_array_tuple()
-
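
The rewrites in the hunks above are typical of the mechanical normalizations black applies throughout this series: single-quoted strings become double-quoted, bare float literals such as 2. and .6 gain explicit zeros, and binary operators get surrounding spaces. A minimal sketch of reproducing them through black's Python API, assuming black is installed and that the repository is configured for a 100-character line length:

    import black

    # Pre-format style matching the removed lines above: single quotes and
    # float literals without explicit zeros.
    src = "stype = 'csr'\nc = a * 2.\nb = np.maximum(a-.6, 0.)\n"

    # FileMode carries the options black would otherwise read from pyproject.toml;
    # line_length=100 mirrors the assumed repository setting.
    formatted = black.format_str(src, mode=black.FileMode(line_length=100))
    print(formatted)
    # Expected output (hedged):
    #   stype = "csr"
    #   c = a * 2.0
    #   b = np.maximum(a - 0.6, 0.0)
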
index c6c480e..03e063e 100644
@@ -22,23 +22,25 @@ from tvm import topi
 
 def findany(pattern, str):
     matches = re.findall(pattern, str)
-    assert (len(matches) >
-            0), 'Pattern not found.\nPattern: ' + pattern + '\nString:  ' + str
+    assert len(matches) > 0, "Pattern not found.\nPattern: " + pattern + "\nString:  " + str
 
 
 def checkdepdency():
     import pkg_resources
-    return not {'graphviz', 'ipython'} - {pkg.key for pkg in pkg_resources.working_set}
+
+    return not {"graphviz", "ipython"} - {pkg.key for pkg in pkg_resources.working_set}
+
 
 def test_dfg():
-    A = te.placeholder((1024, 4096), dtype='float32', name='A')
+    A = te.placeholder((1024, 4096), dtype="float32", name="A")
     B = topi.nn.softmax(A)
     # confirm lower works
     s = te.create_schedule([B.op])
 
     def verify():
         from tvm.contrib import tedd
-        str = tedd.viz_dataflow_graph(s, False, '', True)
+
+        str = tedd.viz_dataflow_graph(s, False, "", True)
         # Check all edges are available
         findany(r"digraph \"Dataflow Graph\"", str)
         findany(r"Stage_0:O_0 -> Tensor_0_0", str)
@@ -52,6 +54,7 @@ def test_dfg():
         findany(r"Tensor_2_0 -> Stage_4:I_0", str)
         findany(r"Tensor_3_0 -> Stage_4:I_1", str)
         findany(r"Stage_4:O_0 -> Tensor_4_0", str)
+
     if checkdepdency():
         verify()
 
@@ -59,16 +62,17 @@ def test_dfg():
 def test_itervar_relationship_graph():
     n = te.var("n")
     m = te.var("m")
-    A = te.placeholder((n, m), name='A')
+    A = te.placeholder((n, m), name="A")
     k = te.reduce_axis((0, m), "k")
-    B = te.compute((n, ), lambda i: te.sum(A[i, k], axis=k), name="B")
+    B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
 
     s = te.create_schedule(B.op)
     s[B].split(B.op.reduce_axis[0], factor=16)
 
     def verify():
         from tvm.contrib import tedd
-        str = tedd.viz_itervar_relationship_graph(s, False, '', True)
+
+        str = tedd.viz_itervar_relationship_graph(s, False, "", True)
         findany(r"digraph \"IterVar Relationship Graph\"", str)
         findany(r"subgraph cluster_legend", str)
         # Check subgraphs for stages
@@ -89,19 +93,24 @@ def test_itervar_relationship_graph():
 
 
 def test_schedule_tree():
-    block_x = te.thread_axis('blockIdx.x')
-    thread_x = te.thread_axis('threadIdx.x')
+    block_x = te.thread_axis("blockIdx.x")
+    thread_x = te.thread_axis("threadIdx.x")
     n = te.var("n")
     m = te.var("m")
     l = te.var("l")
-    A = te.placeholder((n, m, l), name='A')
-    B = te.compute((n, m, l), lambda bi, bj, bk: A[bi, bj, bk] + 1, name='B')
+    A = te.placeholder((n, m, l), name="A")
+    B = te.compute((n, m, l), lambda bi, bj, bk: A[bi, bj, bk] + 1, name="B")
     r = te.reduce_axis((0, m), "r")
-    C = te.compute((n, m,),
-                   lambda ci, cj: te.sum(B[ci, cj, r], axis=r),
-                   name="C")
+    C = te.compute(
+        (
+            n,
+            m,
+        ),
+        lambda ci, cj: te.sum(B[ci, cj, r], axis=r),
+        name="C",
+    )
     s = te.create_schedule(C.op)
-    s.cache_read(A, 'shared', [B])
+    s.cache_read(A, "shared", [B])
     s[B].vectorize(B.op.axis[-1])
     s[C].reorder(C.op.reduce_axis[0], C.op.axis[0])
     _, ki = s[C].split(C.op.reduce_axis[0], factor=16)
@@ -112,14 +121,18 @@ def test_schedule_tree():
 
     def verify():
         from tvm.contrib import tedd
-        str = tedd.viz_schedule_tree(s, False, '', True)
+
+        str = tedd.viz_schedule_tree(s, False, "", True)
         findany(r"digraph \"Schedule Tree\"", str)
         findany(r"subgraph cluster_legend", str)
         # Check the A_shared stage, including memory scope, itervars,
         # and compute
-        findany(r"Stage_1.*A\.shared<br/>Scope: shared.+>0.+>" \
-            r"ax0\(kDataPar\).+>1.+ax1\(kDataPar\).+>2.+>ax2\(kDataPar\).+>" \
-            r"\[A\(ax0, ax1, ax2\)\]", str)
+        findany(
+            r"Stage_1.*A\.shared<br/>Scope: shared.+>0.+>"
+            r"ax0\(kDataPar\).+>1.+ax1\(kDataPar\).+>2.+>ax2\(kDataPar\).+>"
+            r"\[A\(ax0, ax1, ax2\)\]",
+            str,
+        )
         # Check itervars of types different from KDataPar
         findany(r"bk\(kVectorized\)", str)
         findany(r"r.outer\(kCommReduce\)", str)
index 1b911b7..c24747d 100644
@@ -34,21 +34,26 @@ def _create_tflite_model():
     try:
         import tensorflow as tf
     except ImportError:
-        print('skip because tensorflow not installed...')
+        print("skip because tensorflow not installed...")
         return
 
     root = tf.Module()
-    root.const = tf.constant([1., 2.], tf.float32)
+    root.const = tf.constant([1.0, 2.0], tf.float32)
     root.f = tf.function(lambda x: root.const * x)
 
-    input_signature = tf.TensorSpec(shape=[2,  ], dtype=tf.float32)
+    input_signature = tf.TensorSpec(
+        shape=[
+            2,
+        ],
+        dtype=tf.float32,
+    )
     concrete_func = root.f.get_concrete_function(input_signature)
     converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
     tflite_model = converter.convert()
     return tflite_model
 
 
-@pytest.mark.skip('skip because accessing output tensor is flakey')
+@pytest.mark.skip("skip because accessing output tensor is flakey")
 def test_local():
     if not tvm.runtime.enabled("tflite"):
         print("skip because tflite runtime is not enabled...")
@@ -60,14 +65,14 @@ def test_local():
     try:
         import tensorflow as tf
     except ImportError:
-        print('skip because tensorflow not installed...')
+        print("skip because tensorflow not installed...")
         return
 
     tflite_fname = "model.tflite"
     tflite_model = _create_tflite_model()
     temp = util.tempdir()
     tflite_model_path = temp.relpath(tflite_fname)
-    open(tflite_model_path, 'wb').write(tflite_model)
+    open(tflite_model_path, "wb").write(tflite_model)
 
     # inference via tflite interpreter python apis
     interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
@@ -75,14 +80,14 @@ def test_local():
     input_details = interpreter.get_input_details()
     output_details = interpreter.get_output_details()
 
-    input_shape = input_details[0]['shape']
+    input_shape = input_details[0]["shape"]
     tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
-    interpreter.set_tensor(input_details[0]['index'], tflite_input)
+    interpreter.set_tensor(input_details[0]["index"], tflite_input)
     interpreter.invoke()
-    tflite_output = interpreter.get_tensor(output_details[0]['index'])
+    tflite_output = interpreter.get_tensor(output_details[0]["index"])
 
     # inference via tvm tflite runtime
-    with open(tflite_model_path, 'rb') as model_fin:
+    with open(tflite_model_path, "rb") as model_fin:
         runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0))
         runtime.set_input(0, tvm.nd.array(tflite_input))
         runtime.invoke()
@@ -101,14 +106,14 @@ def test_remote():
     try:
         import tensorflow as tf
     except ImportError:
-        print('skip because tensorflow not installed...')
+        print("skip because tensorflow not installed...")
         return
 
     tflite_fname = "model.tflite"
     tflite_model = _create_tflite_model()
     temp = util.tempdir()
     tflite_model_path = temp.relpath(tflite_fname)
-    open(tflite_model_path, 'wb').write(tflite_model)
+    open(tflite_model_path, "wb").write(tflite_model)
 
     # inference via tflite interpreter python apis
     interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
@@ -116,11 +121,11 @@ def test_remote():
     input_details = interpreter.get_input_details()
     output_details = interpreter.get_output_details()
 
-    input_shape = input_details[0]['shape']
+    input_shape = input_details[0]["shape"]
     tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
-    interpreter.set_tensor(input_details[0]['index'], tflite_input)
+    interpreter.set_tensor(input_details[0]["index"], tflite_input)
     interpreter.invoke()
-    tflite_output = interpreter.get_tensor(output_details[0]['index'])
+    tflite_output = interpreter.get_tensor(output_details[0]["index"])
 
     # inference via remote tvm tflite runtime
     server = rpc.Server("localhost")
@@ -128,7 +133,7 @@ def test_remote():
     ctx = remote.cpu(0)
     a = remote.upload(tflite_model_path)
 
-    with open(tflite_model_path, 'rb') as model_fin:
+    with open(tflite_model_path, "rb") as model_fin:
         runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0))
         runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0)))
         runtime.invoke()
index 816e651..29c6fbf 100644
@@ -23,66 +23,65 @@ from tvm.contrib import util
 
 
 def validate_debug_dir_path(temp_dir, expected_basename):
-  dirname, basename = os.path.split(temp_dir.temp_dir)
-  assert basename == expected_basename, 'unexpected basename: %s' % (basename,)
-
-  parent_dir = os.path.basename(dirname)
-  create_time = datetime.datetime.strptime(parent_dir.split('___', 1)[0], '%Y-%m-%dT%H-%M-%S')
-  assert abs(datetime.datetime.now() - create_time) < datetime.timedelta(seconds=60)
+    dirname, basename = os.path.split(temp_dir.temp_dir)
+    assert basename == expected_basename, "unexpected basename: %s" % (basename,)
 
+    parent_dir = os.path.basename(dirname)
+    create_time = datetime.datetime.strptime(parent_dir.split("___", 1)[0], "%Y-%m-%dT%H-%M-%S")
+    assert abs(datetime.datetime.now() - create_time) < datetime.timedelta(seconds=60)
 
 
 def test_tempdir():
-  assert util.TempDirectory._KEEP_FOR_DEBUG == False, "don't submit with KEEP_FOR_DEBUG == True"
-
-  temp_dir = util.tempdir()
-  assert os.path.exists(temp_dir.temp_dir)
-
-  old_debug_mode = util.TempDirectory._KEEP_FOR_DEBUG
-  old_tempdirs = util.TempDirectory.TEMPDIRS
-  try:
-    for temp_dir_number in range(0, 3):
-      with util.TempDirectory.set_keep_for_debug():
-        debug_temp_dir = util.tempdir()
-        try:
-          validate_debug_dir_path(debug_temp_dir, '0000' + str(temp_dir_number))
-        finally:
-          shutil.rmtree(debug_temp_dir.temp_dir)
-
-    with util.TempDirectory.set_keep_for_debug():
-      # Create 2 temp_dir within the same session.
-      debug_temp_dir = util.tempdir()
-      try:
-        validate_debug_dir_path(debug_temp_dir, '00003')
-      finally:
-        shutil.rmtree(debug_temp_dir.temp_dir)
-
-      debug_temp_dir = util.tempdir()
-      try:
-        validate_debug_dir_path(debug_temp_dir, '00004')
-      finally:
-        shutil.rmtree(debug_temp_dir.temp_dir)
-
-      with util.TempDirectory.set_keep_for_debug(False):
-        debug_temp_dir = util.tempdir()  # This one should get deleted.
-
-        # Simulate atexit hook
-        util.TempDirectory.remove_tempdirs()
-
-        # Calling twice should be a no-op.
-        util.TempDirectory.remove_tempdirs()
-
-        # Creating a new TempDirectory should fail now
-        try:
-          util.tempdir()
-          assert False, 'creation should fail'
-        except util.DirectoryCreatedPastAtExit:
-          pass
-
-  finally:
-    util.TempDirectory.DEBUG_MODE = old_debug_mode
-    util.TempDirectory.TEMPDIRS = old_tempdirs
-
-
-if __name__ == '__main__':
-  test_tempdir()
+    assert util.TempDirectory._KEEP_FOR_DEBUG == False, "don't submit with KEEP_FOR_DEBUG == True"
+
+    temp_dir = util.tempdir()
+    assert os.path.exists(temp_dir.temp_dir)
+
+    old_debug_mode = util.TempDirectory._KEEP_FOR_DEBUG
+    old_tempdirs = util.TempDirectory.TEMPDIRS
+    try:
+        for temp_dir_number in range(0, 3):
+            with util.TempDirectory.set_keep_for_debug():
+                debug_temp_dir = util.tempdir()
+                try:
+                    validate_debug_dir_path(debug_temp_dir, "0000" + str(temp_dir_number))
+                finally:
+                    shutil.rmtree(debug_temp_dir.temp_dir)
+
+        with util.TempDirectory.set_keep_for_debug():
+            # Create 2 temp_dir within the same session.
+            debug_temp_dir = util.tempdir()
+            try:
+                validate_debug_dir_path(debug_temp_dir, "00003")
+            finally:
+                shutil.rmtree(debug_temp_dir.temp_dir)
+
+            debug_temp_dir = util.tempdir()
+            try:
+                validate_debug_dir_path(debug_temp_dir, "00004")
+            finally:
+                shutil.rmtree(debug_temp_dir.temp_dir)
+
+            with util.TempDirectory.set_keep_for_debug(False):
+                debug_temp_dir = util.tempdir()  # This one should get deleted.
+
+                # Simulate atexit hook
+                util.TempDirectory.remove_tempdirs()
+
+                # Calling twice should be a no-op.
+                util.TempDirectory.remove_tempdirs()
+
+                # Creating a new TempDirectory should fail now
+                try:
+                    util.tempdir()
+                    assert False, "creation should fail"
+                except util.DirectoryCreatedPastAtExit:
+                    pass
+
+    finally:
+        util.TempDirectory.DEBUG_MODE = old_debug_mode
+        util.TempDirectory.TEMPDIRS = old_tempdirs
+
+
+if __name__ == "__main__":
+    test_tempdir()
index 8567e4b..2005090 100644
@@ -21,9 +21,11 @@ Caffe testcases
 This file is a test script for testing Caffe operators with Relay.
 """
 import os
-os.environ['GLOG_minloglevel'] = '2'
+
+os.environ["GLOG_minloglevel"] = "2"
 import sys
 import logging
+
 logging.basicConfig(level=logging.ERROR)
 
 import numpy as np
@@ -37,7 +39,7 @@ from tvm import relay
 from tvm.contrib import util, graph_runtime
 from tvm.contrib.download import download_testdata
 
-CURRENT_DIR = os.path.join(os.path.expanduser('~'), '.tvm_test_data', 'caffe_test')
+CURRENT_DIR = os.path.join(os.path.expanduser("~"), ".tvm_test_data", "caffe_test")
 
 #######################################################################
 # Generic functions for TVM & Caffe
@@ -54,7 +56,7 @@ def _list_to_str(ll):
     """ Convert list or tuple to str, separated by underline. """
     if isinstance(ll, (tuple, list)):
         tmp = [str(i) for i in ll]
-        return '_'.join(tmp)
+        return "_".join(tmp)
 
 
 def _gen_filename_str(op_name, data_shape, *args, **kwargs):
@@ -66,14 +68,14 @@ def _gen_filename_str(op_name, data_shape, *args, **kwargs):
     res += shape_str
     for arg in args:
         if isinstance(arg, (tuple, list)):
-            res += ("_" + _list_to_str(arg))
+            res += "_" + _list_to_str(arg)
         elif isinstance(arg, (int, float, str)):
-            res += ("_" + str(arg))
+            res += "_" + str(arg)
     for _, v in kwargs.items():
         if isinstance(v, (tuple, list)):
-            res += ("_" + _list_to_str(v))
+            res += "_" + _list_to_str(v)
         elif isinstance(v, (int, float, str)):
-            res += ("_" + str(v))
+            res += "_" + str(v)
     res = res.replace(".", "_")
     res = res.replace("-", "_")
     proto_file = os.path.join(file_dir, res + ".prototxt")
@@ -86,7 +88,7 @@ def _gen_filename_str(op_name, data_shape, *args, **kwargs):
 def _save_prototxt(n_netspec, f_path):
     """ Generate .prototxt file according to caffe.NetSpec"""
     s = n_netspec.to_proto()
-    with open(f_path, 'w') as f:
+    with open(f_path, "w") as f:
         f.write(str(s))
 
 
@@ -106,7 +108,7 @@ def _save_solver(solver_file, proto_file, blob_file):
     s.snapshot = 100000
     s.snapshot_prefix = blob_file_prefix
 
-    with open(solver_file, 'w') as f:
+    with open(solver_file, "w") as f:
         f.write(str(s))
 
 
@@ -125,7 +127,7 @@ def _gen_model_files(n_netspec, proto_file, blob_file, solver_file):
 def _siso_op(data, func, *args, **kwargs):
     """ Create single input and single output Caffe op """
     n = caffe.NetSpec()
-    n.data = L.Input(input_param={'shape': {'dim': list(data.shape)}})
+    n.data = L.Input(input_param={"shape": {"dim": list(data.shape)}})
     n.output = func(n.data, *args, **kwargs)
     return n
 
@@ -134,15 +136,11 @@ def _miso_op(data_list, func, *args, **kwargs):
     """ Create multi input and single output Caffe op """
     n = caffe.NetSpec()
     if not isinstance(data_list, (tuple, list)):
-        raise TypeError("Need tuple or list but get {}".format(
-            type(data_list)))
+        raise TypeError("Need tuple or list but get {}".format(type(data_list)))
     input_list = list()
     for idx, data in enumerate(data_list):
-        n['data' +
-          str(idx)] = L.Input(input_param={'shape': {
-              'dim': list(data.shape)
-          }})
-        input_list.append(n['data' + str(idx)])
+        n["data" + str(idx)] = L.Input(input_param={"shape": {"dim": list(data.shape)}})
+        input_list.append(n["data" + str(idx)])
     n.output = func(*input_list, *args, **kwargs)
     return n
 
@@ -150,10 +148,10 @@ def _miso_op(data_list, func, *args, **kwargs):
 def _simo_op(data, func, *args, **kwargs):
     """ Create single input and multi output Caffe op """
     n = caffe.NetSpec()
-    n.data = L.Input(input_param={'shape': {'dim': list(data.shape)}})
+    n.data = L.Input(input_param={"shape": {"dim": list(data.shape)}})
     output_list = func(n.data, *args, **kwargs)
     for idx, out in enumerate(output_list):
-        n['output' + str(idx)] = out
+        n["output" + str(idx)] = out
     return n
 
 
@@ -162,17 +160,17 @@ def _run_caffe(data, proto_file, blob_file):
     net = caffe.Net(proto_file, blob_file, caffe.TEST)
     if isinstance(data, (list, tuple)):
         for idx, d in enumerate(data):
-            net.blobs['data' + str(idx)].data[...] = d
+            net.blobs["data" + str(idx)].data[...] = d
     else:
-        net.blobs['data'].data[...] = data
+        net.blobs["data"].data[...] = data
     out = net.forward()
 
     caffe_output = list()
     for i in range(len(out.keys())):
-        if 'output'+str(i) not in out.keys():
+        if "output" + str(i) not in out.keys():
             caffe_output.clear()
             return list(out.values())
-        caffe_output.append(out['output'+str(i)])
+        caffe_output.append(out["output" + str(i)])
     return caffe_output
 
 
@@ -182,41 +180,37 @@ def _run_tvm(data, proto_file, blob_file):
     predict_net = pb.NetParameter()
 
     # load model
-    with open(proto_file, 'r') as f:
+    with open(proto_file, "r") as f:
         text_format.Merge(f.read(), predict_net)
     # load blob
-    with open(blob_file, 'rb') as f:
+    with open(blob_file, "rb") as f:
         init_net.ParseFromString(f.read())
 
     shape_dict = dict()
     dtype_dict = dict()
     if isinstance(data, (tuple, list)):
         for idx, d in enumerate(data):
-            shape_dict['data' + str(idx)] = d.shape
-            dtype_dict['data' + str(idx)] = 'float32'
+            shape_dict["data" + str(idx)] = d.shape
+            dtype_dict["data" + str(idx)] = "float32"
     else:
-        shape_dict = {'data': data.shape}
-        dtype_dict = {'data': 'float32'}
+        shape_dict = {"data": data.shape}
+        dtype_dict = {"data": "float32"}
 
-    mod, params = relay.frontend.from_caffe(
-        init_net, predict_net, shape_dict, dtype_dict)
+    mod, params = relay.frontend.from_caffe(init_net, predict_net, shape_dict, dtype_dict)
 
-    target = 'llvm'
-    target_host = 'llvm'
+    target = "llvm"
+    target_host = "llvm"
 
     ctx = tvm.cpu(0)
     with tvm.transform.PassContext(opt_level=3):
-        lib = relay.build(mod,
-                          target=target,
-                          target_host=target_host,
-                          params=params)
-    dtype = 'float32'
-    m = graph_runtime.GraphModule(lib['default'](ctx))
+        lib = relay.build(mod, target=target, target_host=target_host, params=params)
+    dtype = "float32"
+    m = graph_runtime.GraphModule(lib["default"](ctx))
     if isinstance(data, (tuple, list)):
         for idx, d in enumerate(data):
-            m.set_input('data' + str(idx), tvm.nd.array(d.astype(dtype)))
+            m.set_input("data" + str(idx), tvm.nd.array(d.astype(dtype)))
     else:
-        m.set_input('data', tvm.nd.array(data.astype(dtype)))
+        m.set_input("data", tvm.nd.array(data.astype(dtype)))
     # execute
     m.run()
     tvm_output = list()
@@ -230,10 +224,7 @@ def _compare_caffe_tvm(caffe_out, tvm_out, is_network=False):
     for i in range(len(caffe_out)):
         if is_network:
             caffe_out[i] = caffe_out[i][:1]
-        tvm.testing.assert_allclose(caffe_out[i],
-                                    tvm_out[i],
-                                    rtol=1e-5,
-                                    atol=1e-5)
+        tvm.testing.assert_allclose(caffe_out[i], tvm_out[i], rtol=1e-5, atol=1e-5)
 
 
 def _test_op(data, func_op, op_name, **kwargs):
@@ -245,8 +236,8 @@ def _test_op(data, func_op, op_name, **kwargs):
             shape_list.extend(list(d.shape))
     else:
         output_num = 1
-        if 'ntop' in kwargs.keys():
-            output_num = kwargs['ntop']
+        if "ntop" in kwargs.keys():
+            output_num = kwargs["ntop"]
         if output_num == 1:
             n = _siso_op(data, func_op, **kwargs)
         else:
@@ -254,8 +245,7 @@ def _test_op(data, func_op, op_name, **kwargs):
         shape_list = list(data.shape)
 
     # obtain the .caffemodel file and .prototxt file
-    (proto_file, blob_file,
-     solver_file) = _gen_filename_str(op_name, shape_list, **kwargs)
+    (proto_file, blob_file, solver_file) = _gen_filename_str(op_name, shape_list, **kwargs)
     _gen_model_files(n, proto_file, blob_file, solver_file)
     # run model in Caffe
     caffe_out = _run_caffe(data, proto_file, blob_file)
@@ -279,11 +269,9 @@ def _test_network(data, proto_file, blob_file):
 
 def _test_batchnorm(data, moving_average_fraction=0.999, eps=1e-5):
     """ One iteration of BatchNorm """
-    _test_op(data,
-             L.BatchNorm,
-             "BatchNorm",
-             moving_average_fraction=moving_average_fraction,
-             eps=eps)
+    _test_op(
+        data, L.BatchNorm, "BatchNorm", moving_average_fraction=moving_average_fraction, eps=eps
+    )
 
 
 def test_forward_BatchNorm():
@@ -305,12 +293,8 @@ def _test_concat(data_list, axis=1):
 
 def test_forward_Concat():
     """ Concat """
-    _test_concat([np.random.rand(1, 3, 10, 10),
-                  np.random.rand(1, 2, 10, 10)],
-                 axis=1)
-    _test_concat([np.random.rand(3, 10, 10),
-                  np.random.rand(2, 10, 10)],
-                 axis=0)
+    _test_concat([np.random.rand(1, 3, 10, 10), np.random.rand(1, 2, 10, 10)], axis=1)
+    _test_concat([np.random.rand(3, 10, 10), np.random.rand(2, 10, 10)], axis=0)
     _test_concat([np.random.rand(3, 10), np.random.rand(2, 10)], axis=0)
 
 
@@ -327,55 +311,65 @@ def _test_convolution(data, **kwargs):
 def test_forward_Convolution():
     """ Convolution """
     data = np.random.rand(1, 3, 10, 10).astype(np.float32)
-    _test_convolution(data,
-                      num_output=20,
-                      bias_term=True,
-                      pad=0,
-                      kernel_size=3,
-                      stride=2,
-                      dilation=1,
-                      weight_filler=dict(type="xavier"),
-                      bias_filler=dict(type="xavier"))
-    _test_convolution(data,
-                      num_output=20,
-                      bias_term=False,
-                      pad=[1, 2],
-                      kernel_size=3,
-                      stride=2,
-                      dilation=1,
-                      weight_filler=dict(type="xavier"),
-                      bias_filler=dict(type="xavier"))
-    _test_convolution(data,
-                      num_output=20,
-                      bias_term=True,
-                      pad=[1, 2],
-                      kernel_size=[3, 5],
-                      stride=[2, 1],
-                      dilation=[1, 2],
-                      weight_filler=dict(type="xavier"),
-                      bias_filler=dict(type="xavier"))
-    _test_convolution(np.random.rand(1, 2, 10, 10).astype(np.float32),
-                      num_output=20,
-                      bias_term=True,
-                      pad=[1, 2],
-                      kernel_size=[3, 5],
-                      stride=[2, 1],
-                      dilation=[1, 2],
-                      weight_filler=dict(type="xavier"),
-                      bias_filler=dict(type="xavier"),
-                      group=2)
-    _test_convolution(data,
-                      num_output=20,
-                      bias_term=True,
-                      pad_h=1,
-                      pad_w=2,
-                      kernel_h=3,
-                      kernel_w=5,
-                      stride_h=2,
-                      stride_w=1,
-                      dilation=[1, 2],
-                      weight_filler=dict(type="xavier"),
-                      bias_filler=dict(type="xavier"))
+    _test_convolution(
+        data,
+        num_output=20,
+        bias_term=True,
+        pad=0,
+        kernel_size=3,
+        stride=2,
+        dilation=1,
+        weight_filler=dict(type="xavier"),
+        bias_filler=dict(type="xavier"),
+    )
+    _test_convolution(
+        data,
+        num_output=20,
+        bias_term=False,
+        pad=[1, 2],
+        kernel_size=3,
+        stride=2,
+        dilation=1,
+        weight_filler=dict(type="xavier"),
+        bias_filler=dict(type="xavier"),
+    )
+    _test_convolution(
+        data,
+        num_output=20,
+        bias_term=True,
+        pad=[1, 2],
+        kernel_size=[3, 5],
+        stride=[2, 1],
+        dilation=[1, 2],
+        weight_filler=dict(type="xavier"),
+        bias_filler=dict(type="xavier"),
+    )
+    _test_convolution(
+        np.random.rand(1, 2, 10, 10).astype(np.float32),
+        num_output=20,
+        bias_term=True,
+        pad=[1, 2],
+        kernel_size=[3, 5],
+        stride=[2, 1],
+        dilation=[1, 2],
+        weight_filler=dict(type="xavier"),
+        bias_filler=dict(type="xavier"),
+        group=2,
+    )
+    _test_convolution(
+        data,
+        num_output=20,
+        bias_term=True,
+        pad_h=1,
+        pad_w=2,
+        kernel_h=3,
+        kernel_w=5,
+        stride_h=2,
+        stride_w=1,
+        dilation=[1, 2],
+        weight_filler=dict(type="xavier"),
+        bias_filler=dict(type="xavier"),
+    )
 
 
 #######################################################################
@@ -390,36 +384,17 @@ def _test_crop(data, **kwargs):
 
 def test_forward_Crop():
     """ Crop """
+    _test_crop([np.random.rand(10, 10, 120, 120), np.random.rand(10, 5, 50, 60)])
+    _test_crop([np.random.rand(10, 10, 120, 120), np.random.rand(10, 5, 50, 60)], axis=1)
+    _test_crop([np.random.rand(10, 10, 120, 120), np.random.rand(10, 5, 50, 60)], axis=1, offset=2)
     _test_crop(
-        [np.random.rand(10, 10, 120, 120),
-         np.random.rand(10, 5, 50, 60)])
-    _test_crop(
-        [np.random.rand(10, 10, 120, 120),
-         np.random.rand(10, 5, 50, 60)],
-        axis=1)
-    _test_crop(
-        [np.random.rand(10, 10, 120, 120),
-         np.random.rand(10, 5, 50, 60)],
-        axis=1,
-        offset=2)
-    _test_crop(
-        [np.random.rand(10, 10, 120, 120),
-         np.random.rand(10, 5, 50, 60)],
-        axis=1,
-        offset=[1, 2, 4])
+        [np.random.rand(10, 10, 120, 120), np.random.rand(10, 5, 50, 60)], axis=1, offset=[1, 2, 4]
+    )
     _test_crop(
-        [np.random.rand(10, 10, 120, 120),
-         np.random.rand(10, 5, 50, 60)],
-        axis=2,
-        offset=[2, 4])
-    _test_crop([np.random.rand(10, 120, 120),
-                np.random.rand(5, 50, 60)],
-               axis=1,
-               offset=[2, 4])
-    _test_crop([np.random.rand(120, 120),
-                np.random.rand(50, 60)],
-               axis=0,
-               offset=[2, 4])
+        [np.random.rand(10, 10, 120, 120), np.random.rand(10, 5, 50, 60)], axis=2, offset=[2, 4]
+    )
+    _test_crop([np.random.rand(10, 120, 120), np.random.rand(5, 50, 60)], axis=1, offset=[2, 4])
+    _test_crop([np.random.rand(120, 120), np.random.rand(50, 60)], axis=0, offset=[2, 4])
 
 
 #######################################################################
@@ -435,39 +410,48 @@ def _test_deconvolution(data, **kwargs):
 def test_forward_Deconvolution():
     """ Deconvolution """
     data = np.random.rand(1, 16, 32, 32).astype(np.float32)
-    _test_deconvolution(data,
-                        convolution_param=dict(
-                            num_output=20,
-                            bias_term=True,
-                            pad=0,
-                            kernel_size=3,
-                            stride=2,
-                            dilation=1,
-                            weight_filler=dict(type="xavier"),
-                            bias_filler=dict(type="xavier")))
-    _test_deconvolution(data,
-                        convolution_param=dict(
-                            num_output=20,
-                            bias_term=False,
-                            pad=[1, 2],
-                            kernel_size=3,
-                            stride=2,
-                            dilation=1,
-                            weight_filler=dict(type="xavier"),
-                            bias_filler=dict(type="xavier")))
-    _test_deconvolution(data,
-                        convolution_param=dict(
-                            num_output=20,
-                            bias_term=True,
-                            pad_h=1,
-                            pad_w=2,
-                            kernel_h=3,
-                            kernel_w=5,
-                            stride_h=2,
-                            stride_w=1,
-                            dilation=1,
-                            weight_filler=dict(type="xavier"),
-                            bias_filler=dict(type="xavier")))
+    _test_deconvolution(
+        data,
+        convolution_param=dict(
+            num_output=20,
+            bias_term=True,
+            pad=0,
+            kernel_size=3,
+            stride=2,
+            dilation=1,
+            weight_filler=dict(type="xavier"),
+            bias_filler=dict(type="xavier"),
+        ),
+    )
+    _test_deconvolution(
+        data,
+        convolution_param=dict(
+            num_output=20,
+            bias_term=False,
+            pad=[1, 2],
+            kernel_size=3,
+            stride=2,
+            dilation=1,
+            weight_filler=dict(type="xavier"),
+            bias_filler=dict(type="xavier"),
+        ),
+    )
+    _test_deconvolution(
+        data,
+        convolution_param=dict(
+            num_output=20,
+            bias_term=True,
+            pad_h=1,
+            pad_w=2,
+            kernel_h=3,
+            kernel_w=5,
+            stride_h=2,
+            stride_w=1,
+            dilation=1,
+            weight_filler=dict(type="xavier"),
+            bias_filler=dict(type="xavier"),
+        ),
+    )
 
 
 #######################################################################
@@ -499,27 +483,35 @@ def _test_eltwise(data_list, **kwargs):
 
 def test_forward_Eltwise():
     """ Eltwise """
-    _test_eltwise([
-        np.random.rand(1, 3, 10, 11).astype(np.float32),
-        np.random.rand(1, 3, 10, 11).astype(np.float32)
-    ],
-                  operation=0)
-    _test_eltwise([
-        np.random.rand(1, 3, 10, 11).astype(np.float32),
-        np.random.rand(1, 3, 10, 11).astype(np.float32)
-    ],
-                  operation=1)
-    _test_eltwise([
-        np.random.rand(1, 3, 10, 11).astype(np.float32),
-        np.random.rand(1, 3, 10, 11).astype(np.float32)
-    ],
-                  operation=2)
-    _test_eltwise([
-        np.random.rand(1, 3, 10, 11).astype(np.float32),
-        np.random.rand(1, 3, 10, 11).astype(np.float32)
-    ],
-                  operation=1,
-                  coeff=[0.5, 1])
+    _test_eltwise(
+        [
+            np.random.rand(1, 3, 10, 11).astype(np.float32),
+            np.random.rand(1, 3, 10, 11).astype(np.float32),
+        ],
+        operation=0,
+    )
+    _test_eltwise(
+        [
+            np.random.rand(1, 3, 10, 11).astype(np.float32),
+            np.random.rand(1, 3, 10, 11).astype(np.float32),
+        ],
+        operation=1,
+    )
+    _test_eltwise(
+        [
+            np.random.rand(1, 3, 10, 11).astype(np.float32),
+            np.random.rand(1, 3, 10, 11).astype(np.float32),
+        ],
+        operation=2,
+    )
+    _test_eltwise(
+        [
+            np.random.rand(1, 3, 10, 11).astype(np.float32),
+            np.random.rand(1, 3, 10, 11).astype(np.float32),
+        ],
+        operation=1,
+        coeff=[0.5, 1],
+    )
 
 
 #######################################################################
@@ -529,7 +521,7 @@ def test_forward_Eltwise():
 
 def _test_flatten(data, axis=1):
     """ One iteration of Flatten """
-    _test_op(data, L.Flatten, 'Flatten', axis=axis)
+    _test_op(data, L.Flatten, "Flatten", axis=axis)
 
 
 def test_forward_Flatten():
@@ -552,20 +544,21 @@ def _test_inner_product(data, **kwargs):
 def test_forward_InnerProduct():
     """ InnerProduct """
     data = np.random.rand(1, 3, 10, 10)
-    _test_inner_product(data,
-                        num_output=20,
-                        bias_term=False,
-                        weight_filler=dict(type='xavier'))
-    _test_inner_product(data,
-                        num_output=20,
-                        bias_term=True,
-                        weight_filler=dict(type='xavier'),
-                        bias_filler=dict(type='xavier'))
-    _test_inner_product(np.random.rand(20, 10).astype(np.float32),
-                        num_output=30,
-                        bias_term=True,
-                        weight_filler=dict(type='xavier'),
-                        bias_filler=dict(type='xavier'))
+    _test_inner_product(data, num_output=20, bias_term=False, weight_filler=dict(type="xavier"))
+    _test_inner_product(
+        data,
+        num_output=20,
+        bias_term=True,
+        weight_filler=dict(type="xavier"),
+        bias_filler=dict(type="xavier"),
+    )
+    _test_inner_product(
+        np.random.rand(20, 10).astype(np.float32),
+        num_output=30,
+        bias_term=True,
+        weight_filler=dict(type="xavier"),
+        bias_filler=dict(type="xavier"),
+    )
 
 
 #######################################################################
@@ -573,15 +566,9 @@ def test_forward_InnerProduct():
 # -----------
 
 
-def _test_lrn(data, local_size=5, alpha=1., beta=0.75, k=1.):
+def _test_lrn(data, local_size=5, alpha=1.0, beta=0.75, k=1.0):
     """ One iteration of LRN """
-    _test_op(data,
-             L.LRN,
-             'LRN',
-             local_size=local_size,
-             alpha=alpha,
-             beta=beta,
-             k=k)
+    _test_op(data, L.LRN, "LRN", local_size=local_size, alpha=alpha, beta=beta, k=k)
 
 
 def test_forward_LRN():
@@ -589,14 +576,14 @@ def test_forward_LRN():
     data = np.random.rand(1, 3, 10, 10).astype(np.float32)
     _test_lrn(data)
     _test_lrn(data, local_size=3)
-    _test_lrn(data, local_size=3, alpha=2.)
+    _test_lrn(data, local_size=3, alpha=2.0)
     _test_lrn(
         data,
         local_size=3,
-        alpha=2.,
+        alpha=2.0,
         beta=0.5,
     )
-    _test_lrn(data, local_size=3, alpha=2., beta=0.5, k=2.)
+    _test_lrn(data, local_size=3, alpha=2.0, beta=0.5, k=2.0)
 
 
 #######################################################################
@@ -614,26 +601,16 @@ def test_forward_Pooling():
     data = np.random.rand(1, 3, 10, 10).astype(np.float32)
     # MAX Pooling
     _test_pooling(data, kernel_size=2, stride=2, pad=0, pool=P.Pooling.MAX)
-    _test_pooling(data,
-                  kernel_h=2,
-                  kernel_w=3,
-                  stride_h=2,
-                  stride_w=1,
-                  pad_h=1,
-                  pad_w=2,
-                  pool=P.Pooling.MAX)
+    _test_pooling(
+        data, kernel_h=2, kernel_w=3, stride_h=2, stride_w=1, pad_h=1, pad_w=2, pool=P.Pooling.MAX
+    )
     _test_pooling(data, pool=P.Pooling.MAX, global_pooling=True)
 
     # AVE Pooling
     _test_pooling(data, kernel_size=2, stride=2, pad=0, pool=P.Pooling.AVE)
-    _test_pooling(data,
-                  kernel_h=2,
-                  kernel_w=3,
-                  stride_h=2,
-                  stride_w=1,
-                  pad_h=1,
-                  pad_w=2,
-                  pool=P.Pooling.AVE)
+    _test_pooling(
+        data, kernel_h=2, kernel_w=3, stride_h=2, stride_w=1, pad_h=1, pad_w=2, pool=P.Pooling.AVE
+    )
     _test_pooling(data, pool=P.Pooling.AVE, global_pooling=True)
 
 
@@ -650,7 +627,7 @@ def _test_prelu(data, **kwargs):
 def test_forward_PReLU():
     """ PReLU """
     data = np.random.rand(1, 3, 10, 10).astype(np.float32)
-    _test_prelu(data, filler=dict(type='constant', value=0.5))
+    _test_prelu(data, filler=dict(type="constant", value=0.5))
     _test_prelu(data)
     _test_prelu(np.random.rand(10, 20).astype(np.float32))
 
@@ -685,37 +662,17 @@ def _test_reshape(data, **kwargs):
 def test_forward_Reshape():
     """ Reshape """
     data = np.random.rand(1, 8, 6).astype(np.float32)
-    _test_reshape(data, reshape_param={'shape': {'dim': [4, 3, 4]}})
-    _test_reshape(data, reshape_param={'shape': {'dim': [2, 0, 3]}})
-    _test_reshape(data, reshape_param={'shape': {'dim': [2, 0, -1]}})
-    _test_reshape(data, reshape_param={'shape': {'dim': [0, -1]}})
-
-    _test_reshape(data, reshape_param={'shape': {'dim': [2, 3]}, 'axis': 2})
-    _test_reshape(data, reshape_param={'shape': {'dim': [4, 3, 4]}, 'axis': 1})
-    _test_reshape(data,
-                  reshape_param={
-                      'shape': {
-                          'dim': [4, 3, 4]
-                      },
-                      'axis': -3
-                  })
-
-    _test_reshape(data,
-                  reshape_param={
-                      'shape': {
-                          'dim': [2, 4]
-                      },
-                      'axis': 1,
-                      'num_axes': 1
-                  })
-    _test_reshape(data,
-                  reshape_param={
-                      'shape': {
-                          'dim': [3, 16]
-                      },
-                      'axis': 1,
-                      'num_axes': 2
-                  })
+    _test_reshape(data, reshape_param={"shape": {"dim": [4, 3, 4]}})
+    _test_reshape(data, reshape_param={"shape": {"dim": [2, 0, 3]}})
+    _test_reshape(data, reshape_param={"shape": {"dim": [2, 0, -1]}})
+    _test_reshape(data, reshape_param={"shape": {"dim": [0, -1]}})
+
+    _test_reshape(data, reshape_param={"shape": {"dim": [2, 3]}, "axis": 2})
+    _test_reshape(data, reshape_param={"shape": {"dim": [4, 3, 4]}, "axis": 1})
+    _test_reshape(data, reshape_param={"shape": {"dim": [4, 3, 4]}, "axis": -3})
+
+    _test_reshape(data, reshape_param={"shape": {"dim": [2, 4]}, "axis": 1, "num_axes": 1})
+    _test_reshape(data, reshape_param={"shape": {"dim": [3, 16]}, "axis": 1, "num_axes": 2})
 
 
 #######################################################################
@@ -732,10 +689,7 @@ def test_forward_Scale():
     """ Scale """
     data = np.random.rand(1, 3, 10, 10).astype(np.float32)
     _test_scale(data, filler=dict(type="xavier"))
-    _test_scale(data,
-                filler=dict(type="xavier"),
-                bias_term=True,
-                bias_filler=dict(type="xavier"))
+    _test_scale(data, filler=dict(type="xavier"), bias_term=True, bias_filler=dict(type="xavier"))
 
 
 #######################################################################
@@ -823,14 +777,14 @@ def _test_mobilenetv2(data):
     data_process = data_process / 58.8
     data_process = data_process.astype(np.float32)
 
-    proto_file_url = ("https://github.com/shicai/MobileNet-Caffe/raw/"
-                        "master/mobilenet_v2_deploy.prototxt")
-    blob_file_url = ("https://github.com/shicai/MobileNet-Caffe/blob/"
-                        "master/mobilenet_v2.caffemodel?raw=true")
-    proto_file = download_testdata(proto_file_url, 'mobilenetv2.prototxt',
-                                     module='model')
-    blob_file = download_testdata(blob_file_url, 'mobilenetv2.caffemodel',
-                                     module='model')
+    proto_file_url = (
+        "https://github.com/shicai/MobileNet-Caffe/raw/" "master/mobilenet_v2_deploy.prototxt"
+    )
+    blob_file_url = (
+        "https://github.com/shicai/MobileNet-Caffe/blob/" "master/mobilenet_v2.caffemodel?raw=true"
+    )
+    proto_file = download_testdata(proto_file_url, "mobilenetv2.prototxt", module="model")
+    blob_file = download_testdata(blob_file_url, "mobilenetv2.caffemodel", module="model")
     _test_network(data_process, proto_file, blob_file)
 
 
@@ -853,13 +807,12 @@ def _test_alexnet(data):
     data_process = data - mean_val
     data_process = data_process.astype(np.float32)
 
-    proto_file_url = ("https://github.com/BVLC/caffe/raw/master/models/"
-                        "bvlc_alexnet/deploy.prototxt")
-    blob_file_url = 'http://dl.caffe.berkeleyvision.org/bvlc_alexnet.caffemodel'
-    proto_file = download_testdata(proto_file_url, 'alexnet.prototxt',
-                                    module="model")
-    blob_file = download_testdata(blob_file_url, 'alexnet.caffemodel',
-                                    module='model')
+    proto_file_url = (
+        "https://github.com/BVLC/caffe/raw/master/models/" "bvlc_alexnet/deploy.prototxt"
+    )
+    blob_file_url = "http://dl.caffe.berkeleyvision.org/bvlc_alexnet.caffemodel"
+    proto_file = download_testdata(proto_file_url, "alexnet.prototxt", module="model")
+    blob_file = download_testdata(blob_file_url, "alexnet.caffemodel", module="model")
     _test_network(data_process, proto_file, blob_file)
 
 
@@ -881,16 +834,16 @@ def _test_resnet50(data):
     mean_val = np.tile(mean_val, (1, 1, 224, 224))
     data_process = data - mean_val
     data_process = data_process.astype(np.float32)
-    
-    proto_file_url = ("https://github.com/fernchen/CaffeModels/raw/"
-                        "master/resnet/ResNet-50-deploy.prototxt")
-    blob_file_url = ("https://github.com/fernchen/CaffeModels/raw/"
-                       "master/resnet/ResNet-50-model.caffemodel")
 
-    proto_file = download_testdata(proto_file_url, 'resnet50.prototxt',
-                                    module="model")
-    blob_file = download_testdata(blob_file_url, 'resnet50.caffemodel',
-                                    module='model')
+    proto_file_url = (
+        "https://github.com/fernchen/CaffeModels/raw/" "master/resnet/ResNet-50-deploy.prototxt"
+    )
+    blob_file_url = (
+        "https://github.com/fernchen/CaffeModels/raw/" "master/resnet/ResNet-50-model.caffemodel"
+    )
+
+    proto_file = download_testdata(proto_file_url, "resnet50.prototxt", module="model")
+    blob_file = download_testdata(blob_file_url, "resnet50.caffemodel", module="model")
 
     _test_network(data_process, proto_file, blob_file)
 
@@ -915,13 +868,12 @@ def _test_inceptionv1(data):
     data_process = data_process / 58.8
     data_process = data_process.astype(np.float32)
 
-    proto_file_url = ("https://github.com/BVLC/caffe/raw/master/models"
-                        "/bvlc_googlenet/deploy.prototxt")
-    blob_file_url = 'http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel'
-    proto_file = download_testdata(proto_file_url, 'inceptionv1.prototxt',
-                                    module="model")
-    blob_file = download_testdata(blob_file_url, 'inceptionv1.caffemodel',
-                                    module='model')
+    proto_file_url = (
+        "https://github.com/BVLC/caffe/raw/master/models" "/bvlc_googlenet/deploy.prototxt"
+    )
+    blob_file_url = "http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel"
+    proto_file = download_testdata(proto_file_url, "inceptionv1.prototxt", module="model")
+    blob_file = download_testdata(blob_file_url, "inceptionv1.caffemodel", module="model")
     _test_network(data_process, proto_file, blob_file)
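
A note on the reflowed URL constants in the hunks above (proto_file_url and blob_file_url): black leaves two string literals side by side inside the parentheses, and Python's parser concatenates adjacent literals, so the pieces join with no separator and the downloaded URL is unchanged. A small illustration with hypothetical values:

    # Adjacent string literals are concatenated at parse time, so this is one URL.
    proto_url = "https://example.com/raw/" "master/deploy.prototxt"
    assert proto_url == "https://example.com/raw/master/deploy.prototxt"
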
 
 
index ed34819..0d29358 100644
@@ -24,22 +24,24 @@ from . import squeezenet
 from caffe2.python.models.download import ModelDownloader
 
 models = [
-    'squeezenet',
-    'resnet50',
-    'vgg19',
+    "squeezenet",
+    "resnet50",
+    "vgg19",
 ]
 
 mf = ModelDownloader()
 
+
 class Model:
     def __init__(self, model_name):
         self.init_net, self.predict_net, self.value_info = mf.get_c2_model(model_name)
 
+
 for model in models:
     try:
-        locals()['c2_' + model] = importlib.import_module('caffe2.python.models.' + model)
+        locals()["c2_" + model] = importlib.import_module("caffe2.python.models." + model)
     except ImportError:
-        locals()['c2_' + model] = Model(model)
+        locals()["c2_" + model] = Model(model)
 
 # squeezenet
 def relay_squeezenet():
index 3c21138..5777656 100644
@@ -41,10 +41,13 @@ def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels, pr
 
 
 def _make_fire_conv(net, channels, kernel_size, padding=0, prefix=""):
-    net = relay.nn.conv2d(net, relay.var("%s_weight" % prefix),
-                        channels=channels,
-                        kernel_size=(kernel_size, kernel_size),
-                        padding=(padding, padding))
+    net = relay.nn.conv2d(
+        net,
+        relay.var("%s_weight" % prefix),
+        channels=channels,
+        kernel_size=(kernel_size, kernel_size),
+        padding=(padding, padding),
+    )
     net = relay.nn.bias_add(net, relay.var("%s_bias" % prefix))
     net = relay.nn.relu(net)
     return net
@@ -71,15 +74,18 @@ def get_net(batch_size, image_shape, num_classes, dtype):
     """
     data_shape = (batch_size,) + image_shape
     net = relay.var("data", shape=data_shape, dtype=dtype)
-    net = relay.nn.conv2d(net, relay.var("conv1_weight"),
-                        channels=64,
-                        kernel_size=(3, 3),
-                        strides=(2, 2),
-                        padding=(0, 0))
+    net = relay.nn.conv2d(
+        net,
+        relay.var("conv1_weight"),
+        channels=64,
+        kernel_size=(3, 3),
+        strides=(2, 2),
+        padding=(0, 0),
+    )
     net = relay.nn.bias_add(net, relay.var("conv1_bias"))
     net = relay.nn.relu(net)
     net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
-    net = _make_fire(net, 16, 64, 64, 'fire2')
+    net = _make_fire(net, 16, 64, 64, "fire2")
     net = _make_fire(net, 16, 64, 64, "fire3")
     net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
     net = _make_fire(net, 32, 128, 128, "fire4")
@@ -90,7 +96,7 @@ def get_net(batch_size, image_shape, num_classes, dtype):
     net = _make_fire(net, 64, 256, 256, "fire8")
     net = _make_fire(net, 64, 256, 256, "fire9")
     net = relay.nn.dropout(net, rate=0.5)
-    net = relay.nn.conv2d(net, relay.var('conv10_weight'), channels=num_classes, kernel_size=(1, 1))
+    net = relay.nn.conv2d(net, relay.var("conv10_weight"), channels=num_classes, kernel_size=(1, 1))
     net = relay.nn.bias_add(net, relay.var("conv10_bias"))
     net = relay.nn.relu(net)
     net = relay.nn.global_avg_pool2d(net)
@@ -99,10 +105,7 @@ def get_net(batch_size, image_shape, num_classes, dtype):
     return relay.Function(args, net)
 
 
-def get_workload(batch_size=1,
-                 image_shape=(3, 224, 224),
-                 num_classes=1000,
-                 dtype="float32"):
+def get_workload(batch_size=1, image_shape=(3, 224, 224), num_classes=1000, dtype="float32"):
     """Get benchmark workload for SqueezeNet
 
     Parameters
index 84d03d9..75f9371 100644
@@ -26,12 +26,7 @@ from collections import namedtuple
 import tvm.testing
 
 
-def get_tvm_output(model,
-                   input_data,
-                   target,
-                   ctx,
-                   output_shape,
-                   output_dtype='float32'):
+def get_tvm_output(model, input_data, target, ctx, output_shape, output_dtype="float32"):
     """ Generic function to execute and get tvm output"""
     # supporting multiple inputs in caffe2 is a bit tricky,
     # because the input names can appear at the beginning or end of model.predict_net.external_input
@@ -42,7 +37,8 @@ def get_tvm_output(model,
     shape_dict = {input_names: input_data.shape}
     dtype_dict = {input_names: input_data.dtype}
     mod, params = relay.frontend.from_caffe2(
-        model.init_net, model.predict_net, shape_dict, dtype_dict)
+        model.init_net, model.predict_net, shape_dict, dtype_dict
+    )
     with tvm.transform.PassContext(opt_level=3):
         graph, lib, params = relay.build(mod, target, params=params)
 
@@ -63,12 +59,11 @@ def get_tvm_output(model,
             tvm_output_list.append(tvm_output.asnumpy())
         return tvm_output_list
     else:
-        tvm_output = m.get_output(0, tvm.nd.empty((output_shape),
-                                                  output_dtype))
+        tvm_output = m.get_output(0, tvm.nd.empty((output_shape), output_dtype))
         return tvm_output.asnumpy()
 
 
-def get_caffe2_output(model, x, dtype='float32'):
+def get_caffe2_output(model, x, dtype="float32"):
     workspace.RunNetOnce(model.init_net)
 
     input_blob = model.predict_net.op[0].input[0]
@@ -81,7 +76,7 @@ def get_caffe2_output(model, x, dtype='float32'):
 
 
 def verify_caffe2_forward_impl(model, data_shape, out_shape):
-    dtype = 'float32'
+    dtype = "float32"
     data = np.random.uniform(size=data_shape).astype(dtype)
     c2_out = get_caffe2_output(model, data, dtype)
     for target, ctx in tvm.testing.enabled_targets():
@@ -104,43 +99,47 @@ def test_forward_vgg19():
     verify_caffe2_forward_impl(c2_vgg19, (1, 3, 224, 224), (1, 1000))
 
 
-Model = namedtuple('Model', ['init_net', 'predict_net'])
+Model = namedtuple("Model", ["init_net", "predict_net"])
 
 
 @tvm.testing.uses_gpu
 def test_elementwise_add():
     data_shape = (1, 16, 9, 9)
     init_net = caffe2_pb2.NetDef()
-    init_net.name = 'test_init_net'
-    init_net.external_output[:] = ['A', 'B']
-    init_net.op.extend([
-        core.CreateOperator(
-            'GivenTensorFill',
-            [],
-            ['A'],
-            shape=data_shape,
-            values=np.random.uniform(size=data_shape).flatten().tolist(),
-        ),
-        core.CreateOperator(
-            'GivenTensorFill',
-            [],
-            ['B'],
-            shape=data_shape,
-            values=np.random.uniform(size=data_shape).flatten().tolist(),
-        ),
-    ])
+    init_net.name = "test_init_net"
+    init_net.external_output[:] = ["A", "B"]
+    init_net.op.extend(
+        [
+            core.CreateOperator(
+                "GivenTensorFill",
+                [],
+                ["A"],
+                shape=data_shape,
+                values=np.random.uniform(size=data_shape).flatten().tolist(),
+            ),
+            core.CreateOperator(
+                "GivenTensorFill",
+                [],
+                ["B"],
+                shape=data_shape,
+                values=np.random.uniform(size=data_shape).flatten().tolist(),
+            ),
+        ]
+    )
 
     predict_net = caffe2_pb2.NetDef()
-    predict_net.name = 'test_predict_net'
-    predict_net.external_input[:] = ['A', 'B']
-    predict_net.external_output[:] = ['C']
-    predict_net.op.extend([
-        core.CreateOperator(
-            'Add',
-            ['A', 'B'],
-            ['C'],
-        )
-    ])
+    predict_net.name = "test_predict_net"
+    predict_net.external_input[:] = ["A", "B"]
+    predict_net.external_output[:] = ["C"]
+    predict_net.op.extend(
+        [
+            core.CreateOperator(
+                "Add",
+                ["A", "B"],
+                ["C"],
+            )
+        ]
+    )
 
     model = Model(init_net, predict_net)
     verify_caffe2_forward_impl(model, data_shape, data_shape)
@@ -150,37 +149,41 @@ def test_elementwise_add():
 def test_elementwise_add_with_broadcast():
     data_shape = (1, 16, 9, 9)
     init_net = caffe2_pb2.NetDef()
-    init_net.name = 'test_init_net'
-    init_net.external_output[:] = ['A', 'B']
-    init_net.op.extend([
-        core.CreateOperator(
-            'GivenTensorFill',
-            [],
-            ['A'],
-            shape=data_shape,
-            values=np.random.uniform(size=data_shape).flatten().tolist(),
-        ),
-        core.CreateOperator(
-            'GivenTensorFill',
-            [],
-            ['B'],
-            shape=(1,),
-            values=np.random.uniform(size=1).flatten().tolist(),
-        ),
-    ])
+    init_net.name = "test_init_net"
+    init_net.external_output[:] = ["A", "B"]
+    init_net.op.extend(
+        [
+            core.CreateOperator(
+                "GivenTensorFill",
+                [],
+                ["A"],
+                shape=data_shape,
+                values=np.random.uniform(size=data_shape).flatten().tolist(),
+            ),
+            core.CreateOperator(
+                "GivenTensorFill",
+                [],
+                ["B"],
+                shape=(1,),
+                values=np.random.uniform(size=1).flatten().tolist(),
+            ),
+        ]
+    )
 
     predict_net = caffe2_pb2.NetDef()
-    predict_net.name = 'test_predict_net'
-    predict_net.external_input[:] = ['A', 'B']
-    predict_net.external_output[:] = ['C']
-    predict_net.op.extend([
-        core.CreateOperator(
-            'Add',
-            ['A', 'B'],
-            ['C'],
-            broadcast=1,
-        )
-    ])
+    predict_net.name = "test_predict_net"
+    predict_net.external_input[:] = ["A", "B"]
+    predict_net.external_output[:] = ["C"]
+    predict_net.op.extend(
+        [
+            core.CreateOperator(
+                "Add",
+                ["A", "B"],
+                ["C"],
+                broadcast=1,
+            )
+        ]
+    )
 
     model = Model(init_net, predict_net)
     verify_caffe2_forward_impl(model, data_shape, data_shape)
@@ -190,49 +193,59 @@ def test_elementwise_add_with_broadcast():
 def test_normalize_yuv():
     data_shape = (1, 3, 96, 96)
     init_net = caffe2_pb2.NetDef()
-    init_net.name = 'test_init_net'
-    init_net.external_output[:] = ['A', 'mean', 'std']
-    init_net.op.extend([
-        core.CreateOperator(
-            'GivenTensorFill',
-            [],
-            ['A'],
-            shape=data_shape,
-            values=np.random.uniform(size=data_shape).flatten().tolist(),
-        ),
-        core.CreateOperator(
-            'GivenTensorFill',
-            [],
-            ['mean'],
-            shape=(1, 3,),
-            values=np.random.uniform(size=3).flatten().tolist(),
-        ),
-        core.CreateOperator(
-            'GivenTensorFill',
-            [],
-            ['std'],
-            shape=(1, 3,),
-            values=np.random.uniform(size=3).flatten().tolist(),
-        ),
-    ])
+    init_net.name = "test_init_net"
+    init_net.external_output[:] = ["A", "mean", "std"]
+    init_net.op.extend(
+        [
+            core.CreateOperator(
+                "GivenTensorFill",
+                [],
+                ["A"],
+                shape=data_shape,
+                values=np.random.uniform(size=data_shape).flatten().tolist(),
+            ),
+            core.CreateOperator(
+                "GivenTensorFill",
+                [],
+                ["mean"],
+                shape=(
+                    1,
+                    3,
+                ),
+                values=np.random.uniform(size=3).flatten().tolist(),
+            ),
+            core.CreateOperator(
+                "GivenTensorFill",
+                [],
+                ["std"],
+                shape=(
+                    1,
+                    3,
+                ),
+                values=np.random.uniform(size=3).flatten().tolist(),
+            ),
+        ]
+    )
 
     predict_net = caffe2_pb2.NetDef()
-    predict_net.name = 'test_predict_net'
-    predict_net.external_input[:] = ['A', 'mean', 'std']
-    predict_net.external_output[:] = ['C']
-    predict_net.op.extend([
-        core.CreateOperator(
-            'NormalizePlanarYUV',
-            ['A', 'mean', 'std'],
-            ['C'],
-        )
-    ])
+    predict_net.name = "test_predict_net"
+    predict_net.external_input[:] = ["A", "mean", "std"]
+    predict_net.external_output[:] = ["C"]
+    predict_net.op.extend(
+        [
+            core.CreateOperator(
+                "NormalizePlanarYUV",
+                ["A", "mean", "std"],
+                ["C"],
+            )
+        ]
+    )
 
     model = Model(init_net, predict_net)
     verify_caffe2_forward_impl(model, data_shape, data_shape)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_forward_squeezenet1_1()
     test_forward_resnet50()
     test_forward_vgg19()
index d64b133..232029c 100644 (file)
@@ -28,13 +28,14 @@ def compare_graph(lhs_mod, rhs_mod):
 
 
 def test_squeeze_net():
-    shape_dict = {'data': (1, 3, 224, 224)}
-    dtype_dict = {'data': 'float32'}
+    shape_dict = {"data": (1, 3, 224, 224)}
+    dtype_dict = {"data": "float32"}
     mod, _, = relay.frontend.from_caffe2(
-        c2_squeezenet.init_net, c2_squeezenet.predict_net, shape_dict, dtype_dict)
+        c2_squeezenet.init_net, c2_squeezenet.predict_net, shape_dict, dtype_dict
+    )
     relay_mod, _ = relay_squeezenet()
     compare_graph(mod, relay_mod)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_squeeze_net()
index b8cee30..bf270b8 100644 (file)
@@ -20,24 +20,27 @@ from PIL import Image
 import numpy as np
 from tvm.contrib.download import download_testdata
 
+
 def get_mobilenet():
-    url = 'https://docs-assets.developer.apple.com/coreml/models/MobileNet.mlmodel'
-    dst = 'mobilenet.mlmodel'
-    real_dst = download_testdata(url, dst, module='coreml')
+    url = "https://docs-assets.developer.apple.com/coreml/models/MobileNet.mlmodel"
+    dst = "mobilenet.mlmodel"
+    real_dst = download_testdata(url, dst, module="coreml")
     return os.path.abspath(real_dst)
 
+
 def get_resnet50():
-    url = 'https://docs-assets.developer.apple.com/coreml/models/Resnet50.mlmodel'
-    dst = 'resnet50.mlmodel'
-    real_dst = download_testdata(url, dst, module='coreml')
+    url = "https://docs-assets.developer.apple.com/coreml/models/Resnet50.mlmodel"
+    dst = "resnet50.mlmodel"
+    real_dst = download_testdata(url, dst, module="coreml")
     return os.path.abspath(real_dst)
 
+
 def get_cat_image():
-    url = 'https://gist.githubusercontent.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/fa7ef0e9c9a5daea686d6473a62aacd1a5885849/cat.png'
-    dst = 'cat.png'
-    real_dst = download_testdata(url, dst, module='data')
+    url = "https://gist.githubusercontent.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/fa7ef0e9c9a5daea686d6473a62aacd1a5885849/cat.png"
+    dst = "cat.png"
+    real_dst = download_testdata(url, dst, module="data")
     img = Image.open(real_dst).resize((224, 224))
     # CoreML's standard model image format is BGR
     img_bgr = np.array(img)[:, :, ::-1]
     img = np.transpose(img_bgr, (2, 0, 1))[np.newaxis, :]
-    return np.asarray(img)
\ No newline at end of file
+    return np.asarray(img)
index d3a31fe..d808469 100644 (file)
@@ -31,8 +31,10 @@ import coremltools as cm
 import model_zoo
 import tvm.testing
 
-def get_tvm_output(func, x, params, target, ctx,
-                   out_shape=(1, 1000), input_name='image', dtype='float32'):
+
+def get_tvm_output(
+    func, x, params, target, ctx, out_shape=(1, 1000), input_name="image", dtype="float32"
+):
     with tvm.transform.PassContext(opt_level=3):
         graph, lib, params = relay.build(func, target, params=params)
     m = graph_runtime.create(graph, lib, ctx)
@@ -44,28 +46,34 @@ def get_tvm_output(func, x, params, target, ctx,
     out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
     return out.asnumpy()
 
-def run_model_checkonly(model_file, model_name='', input_name='image'):
+
+def run_model_checkonly(model_file, model_name="", input_name="image"):
     model = cm.models.MLModel(model_file)
     x = model_zoo.get_cat_image()
-    shape_dict = {input_name : x.shape}
+    shape_dict = {input_name: x.shape}
     # Some Relay passes change operators on the fly. Ensure that we generate a
     # new graph for each target.
     for target, ctx in tvm.testing.enabled_targets():
         mod, params = relay.frontend.from_coreml(model, shape_dict)
         tvm_output = get_tvm_output(mod["main"], x, params, target, ctx)
-        print(target, ctx, model_name, 'prediction id: ', np.argmax(tvm_output.flat))
+        print(target, ctx, model_name, "prediction id: ", np.argmax(tvm_output.flat))
+
 
 @tvm.testing.uses_gpu
 def test_mobilenet_checkonly():
     model_file = model_zoo.get_mobilenet()
-    run_model_checkonly(model_file, 'mobilenet')
+    run_model_checkonly(model_file, "mobilenet")
+
 
 @tvm.testing.uses_gpu
 def test_resnet50_checkonly():
     model_file = model_zoo.get_resnet50()
-    run_model_checkonly(model_file, 'resnet50')
+    run_model_checkonly(model_file, "resnet50")
+
 
-def run_tvm_graph(coreml_model, target, ctx, input_data, input_name, output_shape, output_dtype='float32'):
+def run_tvm_graph(
+    coreml_model, target, ctx, input_data, input_name, output_shape, output_dtype="float32"
+):
     """ Generic function to compile on relay and execute on tvm """
     if isinstance(input_data, list):
         shape_dict = {}
@@ -82,6 +90,7 @@ def run_tvm_graph(coreml_model, target, ctx, input_data, input_name, output_shap
         graph, lib, params = relay.build(mod, target, params=params)
 
     from tvm.contrib import graph_runtime
+
     m = graph_runtime.create(graph, lib, ctx)
     # set inputs
     if isinstance(input_data, list):
@@ -107,115 +116,128 @@ def run_tvm_graph(coreml_model, target, ctx, input_data, input_name, output_shap
             tvm_output = m.get_output(0, tvm.nd.empty((output_shape), output_dtype))
         return tvm_output.asnumpy()
 
+
 def verify_AddLayerParams(input_dim, alpha=2):
-    dtype = 'float32'
+    dtype = "float32"
 
     a_np1 = np.random.uniform(size=input_dim).astype(dtype)
     a_np2 = np.random.uniform(size=input_dim).astype(dtype)
 
     b_np = np.add(a_np1, a_np2) + alpha
-    inputs = [('input1', datatypes.Array(*input_dim)),
-              ('input2', datatypes.Array(*input_dim))]
-    output = [('output', datatypes.Array(*b_np.shape))]
+    inputs = [("input1", datatypes.Array(*input_dim)), ("input2", datatypes.Array(*input_dim))]
+    output = [("output", datatypes.Array(*b_np.shape))]
     builder = NeuralNetworkBuilder(inputs, output)
-    builder.add_elementwise(name='Add',
-                            alpha=alpha,
-                            input_names=['input1', 'input2'],
-                            output_name='output',
-                            mode='ADD')
+    builder.add_elementwise(
+        name="Add", alpha=alpha, input_names=["input1", "input2"], output_name="output", mode="ADD"
+    )
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, [a_np1, a_np2], ['input1', 'input2'], b_np.shape, dtype)
+        out = run_tvm_graph(
+            model, target, ctx, [a_np1, a_np2], ["input1", "input2"], b_np.shape, dtype
+        )
         tvm.testing.assert_allclose(out, b_np, rtol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_forward_AddLayerParams():
     verify_AddLayerParams((1, 2, 2), 0)
     verify_AddLayerParams((1, 2, 2), 1)
     verify_AddLayerParams((1, 3, 3), 2)
 
+
 def verify_MultiplyLayerParams(input_dim, alpha):
-    dtype = 'float32'
+    dtype = "float32"
 
     a_np1 = np.random.uniform(size=input_dim).astype(dtype)
     a_np2 = np.random.uniform(size=input_dim).astype(dtype)
 
     b_np = np.multiply(a_np1, a_np2) * alpha
-    inputs = [('input1', datatypes.Array(*input_dim)),
-              ('input2', datatypes.Array(*input_dim))]
-    output = [('output', datatypes.Array(*b_np.shape))]
+    inputs = [("input1", datatypes.Array(*input_dim)), ("input2", datatypes.Array(*input_dim))]
+    output = [("output", datatypes.Array(*b_np.shape))]
     builder = NeuralNetworkBuilder(inputs, output)
-    builder.add_elementwise(name='Mul',
-                            alpha=alpha,
-                            input_names=['input1', 'input2'],
-                            output_name='output',
-                            mode='MULTIPLY')
+    builder.add_elementwise(
+        name="Mul",
+        alpha=alpha,
+        input_names=["input1", "input2"],
+        output_name="output",
+        mode="MULTIPLY",
+    )
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, [a_np1, a_np2], ['input1', 'input2'], b_np.shape, dtype)
+        out = run_tvm_graph(
+            model, target, ctx, [a_np1, a_np2], ["input1", "input2"], b_np.shape, dtype
+        )
         tvm.testing.assert_allclose(out, b_np, rtol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_forward_MultiplyLayerParams():
     verify_MultiplyLayerParams((1, 2, 2), 0)
     verify_MultiplyLayerParams((1, 2, 2), 1)
     verify_MultiplyLayerParams((1, 3, 3), 2)
 
+
 def verify_ConcatLayerParams(input1_dim, input2_dim):
-    dtype = 'float32'
+    dtype = "float32"
 
     a_np1 = np.random.uniform(size=input1_dim).astype(dtype)
     a_np2 = np.random.uniform(size=input2_dim).astype(dtype)
 
     b_np = np.concatenate((a_np1, a_np2), axis=1)
-    inputs = [('input1', datatypes.Array(*input1_dim)),
-              ('input2', datatypes.Array(*input2_dim))]
-    output = [('output', datatypes.Array(*b_np.shape))]
+    inputs = [("input1", datatypes.Array(*input1_dim)), ("input2", datatypes.Array(*input2_dim))]
+    output = [("output", datatypes.Array(*b_np.shape))]
     builder = NeuralNetworkBuilder(inputs, output)
-    builder.add_elementwise(name='Concate',
-                            input_names=['input1', 'input2'],
-                            output_name='output',
-                            mode='CONCAT')
+    builder.add_elementwise(
+        name="Concate", input_names=["input1", "input2"], output_name="output", mode="CONCAT"
+    )
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, [a_np1, a_np2], ['input1', 'input2'], b_np.shape, dtype)
+        out = run_tvm_graph(
+            model, target, ctx, [a_np1, a_np2], ["input1", "input2"], b_np.shape, dtype
+        )
         tvm.testing.assert_allclose(out, b_np, rtol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_forward_ConcatLayerParams():
     verify_ConcatLayerParams((1, 1, 2, 2), (1, 2, 2, 2))
     verify_ConcatLayerParams((1, 2, 4, 4), (1, 3, 4, 4))
 
+
 def verify_UpsampleLayerParams(input_dim, scale, mode):
     dtype = "float32"
 
     a_np = np.full(input_dim, 1, dtype=dtype)
-    if mode == 'NN':
+    if mode == "NN":
         b_np = tvm.topi.testing.upsampling_python(a_np, (scale, scale))
     else:
         new_h = input_dim[2] * scale
         new_w = input_dim[3] * scale
-        b_np = tvm.topi.testing.bilinear_resize_python(a_np, (new_h, new_w), 'NCHW')
+        b_np = tvm.topi.testing.bilinear_resize_python(a_np, (new_h, new_w), "NCHW")
 
-    input = [('input', datatypes.Array(*input_dim))]
-    output = [('output', datatypes.Array(*b_np.shape))]
+    input = [("input", datatypes.Array(*input_dim))]
+    output = [("output", datatypes.Array(*b_np.shape))]
     builder = NeuralNetworkBuilder(input, output)
-    builder.add_upsample(name='Upsample',
-                         scaling_factor_h=scale,
-                         scaling_factor_w=scale,
-                         mode=mode,
-                         input_name='input',
-                         output_name='output')
+    builder.add_upsample(
+        name="Upsample",
+        scaling_factor_h=scale,
+        scaling_factor_w=scale,
+        mode=mode,
+        input_name="input",
+        output_name="output",
+    )
 
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, a_np, 'input', b_np.shape, dtype)
+        out = run_tvm_graph(model, target, ctx, a_np, "input", b_np.shape, dtype)
         tvm.testing.assert_allclose(out, b_np, rtol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_forward_UpsampleLayerParams():
-    verify_UpsampleLayerParams((1, 16, 32, 32), 2, 'NN')
-    verify_UpsampleLayerParams((1, 4, 6, 6), 3, 'BILINEAR')
+    verify_UpsampleLayerParams((1, 16, 32, 32), 2, "NN")
+    verify_UpsampleLayerParams((1, 4, 6, 6), 3, "BILINEAR")
+
 
 def verify_l2_normalize(input_dim, eps):
     dtype = "float32"
@@ -223,75 +245,83 @@ def verify_l2_normalize(input_dim, eps):
     a_np = np.random.uniform(size=input_dim).astype(dtype)
     b_np = tvm.topi.testing.l2_normalize_python(a_np, eps, 1)
 
-    input = [('input', datatypes.Array(*input_dim))]
-    output = [('output', datatypes.Array(*b_np.shape))]
+    input = [("input", datatypes.Array(*input_dim))]
+    output = [("output", datatypes.Array(*b_np.shape))]
     builder = NeuralNetworkBuilder(input, output)
-    builder.add_l2_normalize(name='L2', epsilon=eps, input_name='input', output_name='output')
+    builder.add_l2_normalize(name="L2", epsilon=eps, input_name="input", output_name="output")
 
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, a_np, 'input', b_np.shape, dtype)
+        out = run_tvm_graph(model, target, ctx, a_np, "input", b_np.shape, dtype)
         tvm.testing.assert_allclose(out, b_np, rtol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_forward_l2_normalize():
     verify_l2_normalize((1, 3, 20, 20), 0.001)
 
+
 def verify_lrn(input_dim, size, bias, alpha, beta):
     dtype = "float32"
-    axis=1
+    axis = 1
     a_np = np.random.uniform(size=input_dim).astype(dtype)
     b_np = tvm.topi.testing.lrn_python(a_np, size, axis, bias, alpha, beta)
 
-    input = [('input', datatypes.Array(*input_dim))]
-    output = [('output', datatypes.Array(*b_np.shape))]
+    input = [("input", datatypes.Array(*input_dim))]
+    output = [("output", datatypes.Array(*b_np.shape))]
     builder = NeuralNetworkBuilder(input, output)
-    builder.add_lrn(name='LRN',
-                    input_name='input',
-                    output_name='output',
-                    alpha=alpha,
-                    beta=beta,
-                    k=bias,
-                    local_size=size)
+    builder.add_lrn(
+        name="LRN",
+        input_name="input",
+        output_name="output",
+        alpha=alpha,
+        beta=beta,
+        k=bias,
+        local_size=size,
+    )
 
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, a_np, 'input', b_np.shape, dtype)
+        out = run_tvm_graph(model, target, ctx, a_np, "input", b_np.shape, dtype)
         tvm.testing.assert_allclose(out, b_np, rtol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_forward_lrn():
     verify_lrn((1, 3, 10, 20), 3, 1.0, 1.0, 0.5)
 
+
 def verify_average(input_dim1, input_dim2, axis=0):
-    dtype = 'float32'
+    dtype = "float32"
 
     a_np1 = np.random.uniform(size=input_dim1).astype(dtype)
     a_np2 = np.random.uniform(size=input_dim2).astype(dtype)
 
     b_np = np.mean((a_np1, a_np2), axis=axis)
 
-    inputs = [('input1', datatypes.Array(*input_dim1)),
-              ('input2', datatypes.Array(*input_dim2))]
-    output = [('output', datatypes.Array(*b_np.shape))]
+    inputs = [("input1", datatypes.Array(*input_dim1)), ("input2", datatypes.Array(*input_dim2))]
+    output = [("output", datatypes.Array(*b_np.shape))]
     builder = NeuralNetworkBuilder(inputs, output)
-    builder.add_elementwise(name='MEAN',
-                            input_names=['input1', 'input2'],
-                            output_name='output',
-                            mode='AVE')
+    builder.add_elementwise(
+        name="MEAN", input_names=["input1", "input2"], output_name="output", mode="AVE"
+    )
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, [a_np1, a_np2], ['input1', 'input2'], b_np.shape, dtype)
+        out = run_tvm_graph(
+            model, target, ctx, [a_np1, a_np2], ["input1", "input2"], b_np.shape, dtype
+        )
         tvm.testing.assert_allclose(out, b_np, rtol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_forward_average():
     verify_average((1, 3, 20, 20), (1, 3, 20, 20))
     verify_average((3, 20, 20), (1, 3, 20, 20))
     verify_average((20, 20), (1, 3, 20, 20))
 
+
 def verify_max(input_dim):
-    dtype = 'float32'
+    dtype = "float32"
 
     a_np1 = np.random.uniform(size=input_dim).astype(dtype)
     a_np2 = np.random.uniform(size=input_dim).astype(dtype)
@@ -299,28 +329,38 @@ def verify_max(input_dim):
 
     b_np = np.max((a_np1, a_np2, a_np3), axis=0)
 
-    inputs = [('input1', datatypes.Array(*input_dim)),
-              ('input2', datatypes.Array(*input_dim)),
-              ('input3', datatypes.Array(*input_dim))]
-    output = [('output', datatypes.Array(*b_np.shape))]
+    inputs = [
+        ("input1", datatypes.Array(*input_dim)),
+        ("input2", datatypes.Array(*input_dim)),
+        ("input3", datatypes.Array(*input_dim)),
+    ]
+    output = [("output", datatypes.Array(*b_np.shape))]
     builder = NeuralNetworkBuilder(inputs, output)
-    builder.add_elementwise(name='Max',
-                            input_names=['input1', 'input2', 'input3'],
-                            output_name='output',
-                            mode='MAX')
+    builder.add_elementwise(
+        name="Max", input_names=["input1", "input2", "input3"], output_name="output", mode="MAX"
+    )
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, [a_np1, a_np2, a_np3],
-                            ['input1', 'input2', 'input3'], b_np.shape, dtype)
+        out = run_tvm_graph(
+            model,
+            target,
+            ctx,
+            [a_np1, a_np2, a_np3],
+            ["input1", "input2", "input3"],
+            b_np.shape,
+            dtype,
+        )
         tvm.testing.assert_allclose(out, b_np, rtol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_forward_max():
     verify_max((1, 3, 20, 20))
     verify_max((20, 20))
 
+
 def verify_min(input_dim):
-    dtype = 'float32'
+    dtype = "float32"
 
     a_np1 = np.random.uniform(size=input_dim).astype(dtype)
     a_np2 = np.random.uniform(size=input_dim).astype(dtype)
@@ -328,21 +368,30 @@ def verify_min(input_dim):
 
     b_np = np.min((a_np1, a_np2, a_np3), axis=0)
 
-    inputs = [('input1', datatypes.Array(*input_dim)),
-              ('input2', datatypes.Array(*input_dim)),
-              ('input3', datatypes.Array(*input_dim))]
-    output = [('output', datatypes.Array(*b_np.shape))]
+    inputs = [
+        ("input1", datatypes.Array(*input_dim)),
+        ("input2", datatypes.Array(*input_dim)),
+        ("input3", datatypes.Array(*input_dim)),
+    ]
+    output = [("output", datatypes.Array(*b_np.shape))]
     builder = NeuralNetworkBuilder(inputs, output)
-    builder.add_elementwise(name='Min',
-                            input_names=['input1', 'input2', 'input3'],
-                            output_name='output',
-                            mode='MIN')
+    builder.add_elementwise(
+        name="Min", input_names=["input1", "input2", "input3"], output_name="output", mode="MIN"
+    )
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, [a_np1, a_np2, a_np3],
-                            ['input1', 'input2', 'input3'], b_np.shape, dtype)
+        out = run_tvm_graph(
+            model,
+            target,
+            ctx,
+            [a_np1, a_np2, a_np3],
+            ["input1", "input2", "input3"],
+            b_np.shape,
+            dtype,
+        )
         tvm.testing.assert_allclose(out, b_np, rtol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_forward_min():
     verify_min((1, 3, 20, 20))
@@ -350,174 +399,146 @@ def test_forward_min():
 
 
 def verify_unary_sqrt(input_dim):
-    dtype = 'float32'
+    dtype = "float32"
 
     a_np = np.random.uniform(size=input_dim).astype(dtype)
     ref_val = np.sqrt(a_np)
 
-    inputs = [('input', datatypes.Array(*input_dim))]
-    output = [('output', datatypes.Array(*ref_val.shape))]
+    inputs = [("input", datatypes.Array(*input_dim))]
+    output = [("output", datatypes.Array(*ref_val.shape))]
     builder = NeuralNetworkBuilder(inputs, output)
-    builder.add_unary(name="sqrt",
-                      input_name='input',
-                      output_name='output',
-                      mode='sqrt')
+    builder.add_unary(name="sqrt", input_name="input", output_name="output", mode="sqrt")
 
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, [a_np],
-                            ['input'], ref_val.shape, dtype)
+        out = run_tvm_graph(model, target, ctx, [a_np], ["input"], ref_val.shape, dtype)
         tvm.testing.assert_allclose(out, ref_val, rtol=1e-5)
 
 
 def verify_unary_rsqrt(input_dim, epsilon=0):
-    dtype = 'float32'
+    dtype = "float32"
 
     a_np = np.random.uniform(size=input_dim).astype(dtype)
     ref_val = 1 / np.sqrt(a_np + epsilon)
 
-    inputs = [('input', datatypes.Array(*input_dim))]
-    output = [('output', datatypes.Array(*ref_val.shape))]
+    inputs = [("input", datatypes.Array(*input_dim))]
+    output = [("output", datatypes.Array(*ref_val.shape))]
     builder = NeuralNetworkBuilder(inputs, output)
-    builder.add_unary(name="rsqrt",
-                      input_name='input',
-                      output_name='output',
-                      mode='rsqrt',
-                      epsilon=epsilon)
+    builder.add_unary(
+        name="rsqrt", input_name="input", output_name="output", mode="rsqrt", epsilon=epsilon
+    )
 
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, [a_np],
-                            ['input'], ref_val.shape, dtype)
+        out = run_tvm_graph(model, target, ctx, [a_np], ["input"], ref_val.shape, dtype)
         tvm.testing.assert_allclose(out, ref_val, rtol=1e-5)
 
 
 def verify_unary_inverse(input_dim, epsilon=0):
-    dtype = 'float32'
+    dtype = "float32"
 
     a_np = np.random.uniform(size=input_dim).astype(dtype)
     ref_val = 1 / (a_np + epsilon)
 
-    inputs = [('input', datatypes.Array(*input_dim))]
-    output = [('output', datatypes.Array(*ref_val.shape))]
+    inputs = [("input", datatypes.Array(*input_dim))]
+    output = [("output", datatypes.Array(*ref_val.shape))]
     builder = NeuralNetworkBuilder(inputs, output)
-    builder.add_unary(name="inverse",
-                      input_name='input',
-                      output_name='output',
-                      mode='inverse',
-                      epsilon=epsilon)
+    builder.add_unary(
+        name="inverse", input_name="input", output_name="output", mode="inverse", epsilon=epsilon
+    )
 
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, [a_np],
-                            ['input'], ref_val.shape, dtype)
+        out = run_tvm_graph(model, target, ctx, [a_np], ["input"], ref_val.shape, dtype)
         tvm.testing.assert_allclose(out, ref_val, rtol=1e-5)
 
 
 def verify_unary_power(input_dim, alpha):
-    dtype = 'float32'
+    dtype = "float32"
 
     a_np = np.random.uniform(size=input_dim).astype(dtype)
     ref_val = np.power(a_np, alpha)
 
-    inputs = [('input', datatypes.Array(*input_dim))]
-    output = [('output', datatypes.Array(*ref_val.shape))]
+    inputs = [("input", datatypes.Array(*input_dim))]
+    output = [("output", datatypes.Array(*ref_val.shape))]
     builder = NeuralNetworkBuilder(inputs, output)
-    builder.add_unary(name="power",
-                      input_name='input',
-                      output_name='output',
-                      mode='power',
-                      alpha=alpha)
+    builder.add_unary(
+        name="power", input_name="input", output_name="output", mode="power", alpha=alpha
+    )
 
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, [a_np],
-                            ['input'], ref_val.shape, dtype)
+        out = run_tvm_graph(model, target, ctx, [a_np], ["input"], ref_val.shape, dtype)
         tvm.testing.assert_allclose(out, ref_val, rtol=1e-5)
 
 
 def verify_unary_exp(input_dim):
-    dtype = 'float32'
+    dtype = "float32"
 
     a_np = np.random.uniform(size=input_dim).astype(dtype)
     ref_val = np.exp(a_np)
 
-    inputs = [('input', datatypes.Array(*input_dim))]
-    output = [('output', datatypes.Array(*ref_val.shape))]
+    inputs = [("input", datatypes.Array(*input_dim))]
+    output = [("output", datatypes.Array(*ref_val.shape))]
     builder = NeuralNetworkBuilder(inputs, output)
-    builder.add_unary(name="exp",
-                      input_name='input',
-                      output_name='output',
-                      mode='exp')
+    builder.add_unary(name="exp", input_name="input", output_name="output", mode="exp")
 
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, [a_np],
-                            ['input'], ref_val.shape, dtype)
+        out = run_tvm_graph(model, target, ctx, [a_np], ["input"], ref_val.shape, dtype)
         tvm.testing.assert_allclose(out, ref_val, rtol=1e-5)
 
 
 def verify_unary_log(input_dim):
-    dtype = 'float32'
+    dtype = "float32"
 
     a_np = np.random.uniform(size=input_dim).astype(dtype)
     ref_val = np.log(a_np)
 
-    inputs = [('input', datatypes.Array(*input_dim))]
-    output = [('output', datatypes.Array(*ref_val.shape))]
+    inputs = [("input", datatypes.Array(*input_dim))]
+    output = [("output", datatypes.Array(*ref_val.shape))]
     builder = NeuralNetworkBuilder(inputs, output)
-    builder.add_unary(name="log",
-                      input_name='input',
-                      output_name='output',
-                      mode='log')
+    builder.add_unary(name="log", input_name="input", output_name="output", mode="log")
 
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, [a_np],
-                            ['input'], ref_val.shape, dtype)
+        out = run_tvm_graph(model, target, ctx, [a_np], ["input"], ref_val.shape, dtype)
         tvm.testing.assert_allclose(out, ref_val, rtol=1e-5)
 
 
 def verify_unary_abs(input_dim):
-    dtype = 'float32'
+    dtype = "float32"
 
     a_np = np.random.uniform(-100.0, 100.0, size=input_dim).astype(dtype)
     ref_val = np.abs(a_np)
 
-    inputs = [('input', datatypes.Array(*input_dim))]
-    output = [('output', datatypes.Array(*ref_val.shape))]
+    inputs = [("input", datatypes.Array(*input_dim))]
+    output = [("output", datatypes.Array(*ref_val.shape))]
     builder = NeuralNetworkBuilder(inputs, output)
-    builder.add_unary(name="abs",
-                      input_name='input',
-                      output_name='output',
-                      mode='abs')
+    builder.add_unary(name="abs", input_name="input", output_name="output", mode="abs")
 
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, [a_np],
-                            ['input'], ref_val.shape, dtype)
+        out = run_tvm_graph(model, target, ctx, [a_np], ["input"], ref_val.shape, dtype)
         tvm.testing.assert_allclose(out, ref_val, rtol=1e-5)
 
 
 def verify_unary_threshold(input_dim, alpha):
-    dtype = 'float32'
+    dtype = "float32"
 
     a_np = np.random.uniform(-100.0, 100.0, size=input_dim).astype(dtype)
     ref_val = np.maximum(a_np, alpha)
 
-    inputs = [('input', datatypes.Array(*input_dim))]
-    output = [('output', datatypes.Array(*ref_val.shape))]
+    inputs = [("input", datatypes.Array(*input_dim))]
+    output = [("output", datatypes.Array(*ref_val.shape))]
     builder = NeuralNetworkBuilder(inputs, output)
-    builder.add_unary(name="threshold",
-                      input_name='input',
-                      output_name='output',
-                      mode='threshold',
-                      alpha=alpha)
+    builder.add_unary(
+        name="threshold", input_name="input", output_name="output", mode="threshold", alpha=alpha
+    )
 
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, [a_np],
-                            ['input'], ref_val.shape, dtype)
+        out = run_tvm_graph(model, target, ctx, [a_np], ["input"], ref_val.shape, dtype)
         tvm.testing.assert_allclose(out, ref_val, rtol=1e-5)
 
 
@@ -540,6 +561,7 @@ def test_forward_unary():
 @tvm.testing.uses_gpu
 def test_forward_reduce():
     from enum import Enum
+
     class ReduceAxis(Enum):
         CHW = 0
         HW = 1
@@ -547,7 +569,7 @@ def test_forward_reduce():
         H = 3
         W = 4
 
-    def _verify_reduce(input_dim, mode, axis, ref_func, dtype='float32'):
+    def _verify_reduce(input_dim, mode, axis, ref_func, dtype="float32"):
         print(input_dim, mode, axis)
         a_np = np.random.uniform(size=input_dim).astype(dtype)
 
@@ -568,19 +590,16 @@ def test_forward_reduce():
         else:
             ref_val = ref_func(a_np, np_axis, keepdims=True)
 
-        inputs = [('input', datatypes.Array(*input_dim))]
-        output = [('output', datatypes.Array(*ref_val.shape))]
+        inputs = [("input", datatypes.Array(*input_dim))]
+        output = [("output", datatypes.Array(*ref_val.shape))]
         builder = NeuralNetworkBuilder(inputs, output)
-        builder.add_reduce(name=mode,
-                          input_name='input',
-                          output_name='output',
-                          axis=axis.name,
-                          mode=mode)
+        builder.add_reduce(
+            name=mode, input_name="input", output_name="output", axis=axis.name, mode=mode
+        )
 
         model = cm.models.MLModel(builder.spec)
         for target, ctx in tvm.testing.enabled_targets():
-            out = run_tvm_graph(model, target, ctx, [a_np],
-                                ['input'], ref_val.shape, dtype)
+            out = run_tvm_graph(model, target, ctx, [a_np], ["input"], ref_val.shape, dtype)
             tvm.testing.assert_allclose(out, ref_val, rtol=1e-5, atol=1e-5)
 
     dshapes = [[10, 10], [1, 10, 10], [1, 3, 10, 10]]
@@ -596,28 +615,29 @@ def test_forward_reduce():
             _verify_reduce(dshape, "max", axis, np.max)
             if axis in [ReduceAxis.C, ReduceAxis.H, ReduceAxis.W]:
                 # For mode ArgMax, axis must be [-1] or [-2] or [-3]
-                _verify_reduce(dshape, "argmax", axis, np.argmax, dtype='int32')
+                _verify_reduce(dshape, "argmax", axis, np.argmax, dtype="int32")
 
 
 def verify_reshape(input_dim, target_shape, mode):
-    dtype = 'float32'
+    dtype = "float32"
 
     a_np = np.random.uniform(-100.0, 100.0, size=input_dim).astype(dtype)
     ref_val = np.reshape(a_np, target_shape)
 
-    inputs = [('input', datatypes.Array(*input_dim))]
-    output = [('output', datatypes.Array(*ref_val.shape))]
+    inputs = [("input", datatypes.Array(*input_dim))]
+    output = [("output", datatypes.Array(*ref_val.shape))]
     builder = NeuralNetworkBuilder(inputs, output)
-    builder.add_reshape(name="reshape",
-                       input_name='input',
-                       output_name='output',
-                       target_shape=target_shape,
-                       mode=mode)
+    builder.add_reshape(
+        name="reshape",
+        input_name="input",
+        output_name="output",
+        target_shape=target_shape,
+        mode=mode,
+    )
 
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, [a_np],
-                            ['input'], ref_val.shape, dtype)
+        out = run_tvm_graph(model, target, ctx, [a_np], ["input"], ref_val.shape, dtype)
         tvm.testing.assert_allclose(out, ref_val, rtol=1e-5)
 
 
@@ -628,12 +648,12 @@ def test_forward_reshape():
 
 
 def verify_split(input_dim, nOutputs):
-    dtype = 'float32'
+    dtype = "float32"
 
     a_np = np.random.uniform(-100.0, 100.0, size=input_dim).astype(dtype)
     ref_val = np.split(a_np, nOutputs, axis=-3)
 
-    inputs = [('input', datatypes.Array(*input_dim))]
+    inputs = [("input", datatypes.Array(*input_dim))]
 
     output_names = []
     outputs = []
@@ -645,24 +665,39 @@ def verify_split(input_dim, nOutputs):
         output_shapes = output_shapes + [out.shape]
 
     builder = NeuralNetworkBuilder(inputs, outputs)
-    builder.add_split(name="split",
-                      input_name='input',
-                      output_names=output_names)
+    builder.add_split(name="split", input_name="input", output_names=output_names)
 
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, [a_np],
-                            ['input'], output_shapes, [dtype] * len(output_shapes))
+        out = run_tvm_graph(
+            model, target, ctx, [a_np], ["input"], output_shapes, [dtype] * len(output_shapes)
+        )
         tvm.testing.assert_allclose(out, ref_val, rtol=1e-5)
 
 
 def test_forward_split():
-    verify_split((1, 4, 4, 4,), 2)
-    verify_split((1, 3, 30, 20,), 3)
+    verify_split(
+        (
+            1,
+            4,
+            4,
+            4,
+        ),
+        2,
+    )
+    verify_split(
+        (
+            1,
+            3,
+            30,
+            20,
+        ),
+        3,
+    )
 
 
 def verify_image_scaler(input_dim, blue_bias=0.0, green_bias=0.0, red_bias=0.0, image_scale=1.0):
-    dtype = 'float32'
+    dtype = "float32"
     a_np = np.random.uniform(size=input_dim).astype(dtype)
     # make sure it is valid image format CHW.
     assert len(a_np.shape) == 3 and a_np.shape[0] == 3
@@ -671,65 +706,83 @@ def verify_image_scaler(input_dim, blue_bias=0.0, green_bias=0.0, red_bias=0.0,
     b_np[1, :, :] = image_scale * a_np[1, :, :] + green_bias
     b_np[2, :, :] = image_scale * a_np[2, :, :] + red_bias
     b_np = np.add(a_np, b_np)
-    inputs = [('input1', datatypes.Array(*input_dim)),
-              ('input2', datatypes.Array(*input_dim))]
-    output = [('output', datatypes.Array(*b_np.shape))]
+    inputs = [("input1", datatypes.Array(*input_dim)), ("input2", datatypes.Array(*input_dim))]
+    output = [("output", datatypes.Array(*b_np.shape))]
     builder = NeuralNetworkBuilder(inputs, output)
-    builder.set_pre_processing_parameters(image_input_names=['input1'],
-                                          is_bgr=True,
-                                          blue_bias=blue_bias,
-                                          green_bias=green_bias,
-                                          red_bias=red_bias,
-                                          image_scale=image_scale)
+    builder.set_pre_processing_parameters(
+        image_input_names=["input1"],
+        is_bgr=True,
+        blue_bias=blue_bias,
+        green_bias=green_bias,
+        red_bias=red_bias,
+        image_scale=image_scale,
+    )
     # add one add layer to make CoreML model format valid
     # add layer has been tested before.
-    builder.add_elementwise(name='add', input_names=['input1', 'input2'],
-                            output_name='output', alpha=0, mode='ADD')
+    builder.add_elementwise(
+        name="add", input_names=["input1", "input2"], output_name="output", alpha=0, mode="ADD"
+    )
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, [a_np, a_np],
-                            ['input1', 'input2'], b_np.shape, dtype)
+        out = run_tvm_graph(
+            model, target, ctx, [a_np, a_np], ["input1", "input2"], b_np.shape, dtype
+        )
         tvm.testing.assert_allclose(out, b_np, rtol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_forward_image_scaler():
     verify_image_scaler((3, 224, 224), image_scale=0.17)
-    verify_image_scaler((3, 224, 224),
-                        blue_bias=-1.7669800519943237,
-                        green_bias=-1.985260009765625,
-                        red_bias=-2.102560043334961,
-                        image_scale=0.379)
+    verify_image_scaler(
+        (3, 224, 224),
+        blue_bias=-1.7669800519943237,
+        green_bias=-1.985260009765625,
+        red_bias=-2.102560043334961,
+        image_scale=0.379,
+    )
+
 
 def verify_convolution(input_dim, filter, padding):
-    dtype = 'float32'
+    dtype = "float32"
     N, C, H, W = input_dim
     OC, _, KH, KW = filter
     a_np = np.random.uniform(size=input_dim).astype(dtype)
     w_np = np.random.uniform(size=(OC, C, KH, KW)).astype(dtype)
     w_np_cm = np.transpose(w_np, axes=(2, 3, 1, 0))
     b_np = conv2d_nchw_python(a_np, w_np, [1, 1], padding)
-    inputs = [('input1', datatypes.Array(C, H, W))]
-    output = [('output', datatypes.Array(*b_np.shape))]
+    inputs = [("input1", datatypes.Array(C, H, W))]
+    output = [("output", datatypes.Array(*b_np.shape))]
     builder = NeuralNetworkBuilder(inputs, output)
-    builder.add_convolution(name='conv', kernel_channels=3, output_channels=OC,
-                            height=KH, width=KW, stride_height=1, stride_width=1,
-                            border_mode=padding.lower(), groups=1,
-                            W=w_np_cm, b=None, has_bias=False,
-                            is_deconv=False,
-                            input_name='input1',
-                            output_name='output')
+    builder.add_convolution(
+        name="conv",
+        kernel_channels=3,
+        output_channels=OC,
+        height=KH,
+        width=KW,
+        stride_height=1,
+        stride_width=1,
+        border_mode=padding.lower(),
+        groups=1,
+        W=w_np_cm,
+        b=None,
+        has_bias=False,
+        is_deconv=False,
+        input_name="input1",
+        output_name="output",
+    )
     model = cm.models.MLModel(builder.spec)
     for target, ctx in tvm.testing.enabled_targets():
-        out = run_tvm_graph(model, target, ctx, [a_np],
-                            ['input1'], output_shape=None)
+        out = run_tvm_graph(model, target, ctx, [a_np], ["input1"], output_shape=None)
         tvm.testing.assert_allclose(out, b_np, rtol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_forward_convolution():
-    verify_convolution((1, 3, 224, 224), filter=(32, 3, 3, 3), padding='VALID')
-    verify_convolution((1, 3, 224, 224), filter=(32, 3, 3, 3), padding='SAME')
+    verify_convolution((1, 3, 224, 224), filter=(32, 3, 3, 3), padding="VALID")
+    verify_convolution((1, 3, 224, 224), filter=(32, 3, 3, 3), padding="SAME")
+
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_forward_AddLayerParams()
     test_forward_ConcatLayerParams()
     test_forward_MultiplyLayerParams()
index fcaeaec..c7bc775 100644 (file)
@@ -26,22 +26,26 @@ import tvm
 from tvm import te
 from tvm.contrib import graph_runtime
 from tvm.contrib.download import download_testdata
+
 download_testdata.__test__ = False
 from tvm.relay.testing.darknet import LAYERTYPE
 from tvm.relay.testing.darknet import __darknetffi__
 from tvm.relay.frontend.darknet import ACTIVATION
 from tvm import relay
 
-REPO_URL = 'https://github.com/dmlc/web-data/blob/master/darknet/'
-DARKNET_LIB = 'libdarknet2.0.so'
-DARKNETLIB_URL = REPO_URL + 'lib/' + DARKNET_LIB + '?raw=true'
-LIB = __darknetffi__.dlopen(download_testdata(DARKNETLIB_URL, DARKNET_LIB, module='darknet'))
+REPO_URL = "https://github.com/dmlc/web-data/blob/master/darknet/"
+DARKNET_LIB = "libdarknet2.0.so"
+DARKNETLIB_URL = REPO_URL + "lib/" + DARKNET_LIB + "?raw=true"
+LIB = __darknetffi__.dlopen(download_testdata(DARKNETLIB_URL, DARKNET_LIB, module="darknet"))
+
+DARKNET_TEST_IMAGE_NAME = "dog.jpg"
+DARKNET_TEST_IMAGE_URL = REPO_URL + "data/" + DARKNET_TEST_IMAGE_NAME + "?raw=true"
+DARKNET_TEST_IMAGE_PATH = download_testdata(
+    DARKNET_TEST_IMAGE_URL, DARKNET_TEST_IMAGE_NAME, module="data"
+)
 
-DARKNET_TEST_IMAGE_NAME = 'dog.jpg'
-DARKNET_TEST_IMAGE_URL = REPO_URL + 'data/' + DARKNET_TEST_IMAGE_NAME +'?raw=true'
-DARKNET_TEST_IMAGE_PATH = download_testdata(DARKNET_TEST_IMAGE_URL, DARKNET_TEST_IMAGE_NAME, module='data')
 
-def _read_memory_buffer(shape, data, dtype='float32'):
+def _read_memory_buffer(shape, data, dtype="float32"):
     length = 1
     for x in shape:
         length *= x
@@ -50,21 +54,20 @@ def _read_memory_buffer(shape, data, dtype='float32'):
         data_np[i] = data[i]
     return data_np.reshape(shape)
 
-def _get_tvm_output(net, data, build_dtype='float32', states=None):
-    '''Compute TVM output'''
-    dtype = 'float32'
+
+def _get_tvm_output(net, data, build_dtype="float32", states=None):
+    """Compute TVM output"""
+    dtype = "float32"
     mod, params = relay.frontend.from_darknet(net, data.shape, dtype)
-    target = 'llvm'
-    shape_dict = {'data': data.shape}
-    graph, library, params = relay.build(mod,
-                                         target,
-                                         params=params)
+    target = "llvm"
+    shape_dict = {"data": data.shape}
+    graph, library, params = relay.build(mod, target, params=params)
 
     # Execute on TVM
     ctx = tvm.cpu(0)
     m = graph_runtime.create(graph, library, ctx)
     # set inputs
-    m.set_input('data', tvm.nd.array(data.astype(dtype)))
+    m.set_input("data", tvm.nd.array(data.astype(dtype)))
     if states:
         for name in states.keys():
             m.set_input(name, tvm.nd.array(states[name].astype(dtype)))
@@ -76,54 +79,64 @@ def _get_tvm_output(net, data, build_dtype='float32', states=None):
         tvm_out.append(m.get_output(i).asnumpy())
     return tvm_out
 
+
 def _load_net(cfg_url, cfg_name, weights_url, weights_name):
-    cfg_path = download_testdata(cfg_url, cfg_name, module='darknet')
-    weights_path = download_testdata(weights_url, weights_name, module='darknet')
-    net = LIB.load_network(cfg_path.encode('utf-8'), weights_path.encode('utf-8'), 0)
+    cfg_path = download_testdata(cfg_url, cfg_name, module="darknet")
+    weights_path = download_testdata(weights_url, weights_name, module="darknet")
+    net = LIB.load_network(cfg_path.encode("utf-8"), weights_path.encode("utf-8"), 0)
     return net
 
-def verify_darknet_frontend(net, build_dtype='float32'):
-    '''Test network with given input image on both darknet and tvm'''
+
+def verify_darknet_frontend(net, build_dtype="float32"):
+    """Test network with given input image on both darknet and tvm"""
+
     def get_darknet_output(net, img):
         LIB.network_predict_image(net, img)
         out = []
         for i in range(net.n):
             layer = net.layers[i]
             if layer.type == LAYERTYPE.REGION:
-                attributes = np.array([layer.n, layer.out_c, layer.out_h,
-                                       layer.out_w, layer.classes,
-                                       layer.coords, layer.background],
-                                      dtype=np.int32)
+                attributes = np.array(
+                    [
+                        layer.n,
+                        layer.out_c,
+                        layer.out_h,
+                        layer.out_w,
+                        layer.classes,
+                        layer.coords,
+                        layer.background,
+                    ],
+                    dtype=np.int32,
+                )
                 out.insert(0, attributes)
-                out.insert(0, _read_memory_buffer((layer.n*2, ), layer.biases))
-                layer_outshape = (layer.batch, layer.out_c,
-                                  layer.out_h, layer.out_w)
+                out.insert(0, _read_memory_buffer((layer.n * 2,), layer.biases))
+                layer_outshape = (layer.batch, layer.out_c, layer.out_h, layer.out_w)
                 out.insert(0, _read_memory_buffer(layer_outshape, layer.output))
             elif layer.type == LAYERTYPE.YOLO:
-                attributes = np.array([layer.n, layer.out_c, layer.out_h,
-                                       layer.out_w, layer.classes,
-                                       layer.total],
-                                      dtype=np.int32)
+                attributes = np.array(
+                    [layer.n, layer.out_c, layer.out_h, layer.out_w, layer.classes, layer.total],
+                    dtype=np.int32,
+                )
                 out.insert(0, attributes)
-                out.insert(0, _read_memory_buffer((layer.total*2, ), layer.biases))
-                out.insert(0, _read_memory_buffer((layer.n, ), layer.mask, dtype='int32'))
-                layer_outshape = (layer.batch, layer.out_c,
-                                  layer.out_h, layer.out_w)
+                out.insert(0, _read_memory_buffer((layer.total * 2,), layer.biases))
+                out.insert(0, _read_memory_buffer((layer.n,), layer.mask, dtype="int32"))
+                layer_outshape = (layer.batch, layer.out_c, layer.out_h, layer.out_w)
                 out.insert(0, _read_memory_buffer(layer_outshape, layer.output))
-            elif i == net.n-1:
+            elif i == net.n - 1:
                 if layer.type == LAYERTYPE.CONNECTED:
                     darknet_outshape = (layer.batch, layer.out_c)
                 elif layer.type in [LAYERTYPE.SOFTMAX]:
                     darknet_outshape = (layer.batch, layer.outputs)
                 else:
-                    darknet_outshape = (layer.batch, layer.out_c,
-                                        layer.out_h, layer.out_w)
+                    darknet_outshape = (layer.batch, layer.out_c, layer.out_h, layer.out_w)
                 out.insert(0, _read_memory_buffer(darknet_outshape, layer.output))
         return out
 
-    dtype = 'float32'
+    dtype = "float32"
 
-    img = LIB.letterbox_image(LIB.load_image_color(DARKNET_TEST_IMAGE_PATH.encode('utf-8'), 0, 0), net.w, net.h)
+    img = LIB.letterbox_image(
+        LIB.load_image_color(DARKNET_TEST_IMAGE_PATH.encode("utf-8"), 0, 0), net.w, net.h
+    )
     darknet_output = get_darknet_output(net, img)
     batch_size = 1
     data = np.empty([batch_size, img.c, img.h, img.w], dtype)
@@ -138,96 +151,106 @@ def verify_darknet_frontend(net, build_dtype='float32'):
     for tvm_outs, darknet_out in zip(tvm_out, darknet_output):
         tvm.testing.assert_allclose(darknet_out, tvm_outs, rtol=1e-3, atol=1e-3)
 
+
 def _test_rnn_network(net, states):
-    '''Test network with given input data on both darknet and tvm'''
+    """Test network with given input data on both darknet and tvm"""
+
     def get_darknet_network_predict(net, data):
         return LIB.network_predict(net, data)
+
     from cffi import FFI
+
     ffi = FFI()
-    np_arr = np.zeros([1, net.inputs], dtype='float32')
+    np_arr = np.zeros([1, net.inputs], dtype="float32")
     np_arr[0, 2] = 1
-    cffi_arr = ffi.cast('float*', np_arr.ctypes.data)
+    cffi_arr = ffi.cast("float*", np_arr.ctypes.data)
     tvm_out = _get_tvm_output(net, np_arr, states=states)[0]
     darknet_output = get_darknet_network_predict(net, cffi_arr)
-    darknet_out = np.zeros(net.outputs, dtype='float32')
+    darknet_out = np.zeros(net.outputs, dtype="float32")
     for i in range(net.outputs):
         darknet_out[i] = darknet_output[i]
-    last_layer = net.layers[net.n-1]
+    last_layer = net.layers[net.n - 1]
     darknet_outshape = (last_layer.batch, last_layer.outputs)
     darknet_out = darknet_out.reshape(darknet_outshape)
     tvm.testing.assert_allclose(darknet_out, tvm_out, rtol=1e-4, atol=1e-4)
 
+
 def test_forward_extraction():
-    '''test extraction model'''
-    model_name = 'extraction'
-    cfg_name = model_name + '.cfg'
-    weights_name = model_name + '.weights'
-    cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
-    weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
+    """test extraction model"""
+    model_name = "extraction"
+    cfg_name = model_name + ".cfg"
+    weights_name = model_name + ".weights"
+    cfg_url = "https://github.com/pjreddie/darknet/blob/master/cfg/" + cfg_name + "?raw=true"
+    weights_url = "http://pjreddie.com/media/files/" + weights_name + "?raw=true"
     net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
     verify_darknet_frontend(net)
     LIB.free_network(net)
 
+
 def test_forward_alexnet():
-    '''test alexnet model'''
-    model_name = 'alexnet'
-    cfg_name = model_name + '.cfg'
-    weights_name = model_name + '.weights'
-    cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
-    weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
+    """test alexnet model"""
+    model_name = "alexnet"
+    cfg_name = model_name + ".cfg"
+    weights_name = model_name + ".weights"
+    cfg_url = "https://github.com/pjreddie/darknet/blob/master/cfg/" + cfg_name + "?raw=true"
+    weights_url = "http://pjreddie.com/media/files/" + weights_name + "?raw=true"
     net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
     verify_darknet_frontend(net)
     LIB.free_network(net)
 
+
 def test_forward_resnet50():
-    '''test resnet50 model'''
-    model_name = 'resnet50'
-    cfg_name = model_name + '.cfg'
-    weights_name = model_name + '.weights'
-    cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
-    weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
+    """test resnet50 model"""
+    model_name = "resnet50"
+    cfg_name = model_name + ".cfg"
+    weights_name = model_name + ".weights"
+    cfg_url = "https://github.com/pjreddie/darknet/blob/master/cfg/" + cfg_name + "?raw=true"
+    weights_url = "http://pjreddie.com/media/files/" + weights_name + "?raw=true"
     net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
     verify_darknet_frontend(net)
     LIB.free_network(net)
 
+
 def test_forward_resnext50():
-    '''test resnet50 model'''
-    model_name = 'resnext50'
-    cfg_name = model_name + '.cfg'
-    weights_name = model_name + '.weights'
-    cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
-    weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
+    """test resnet50 model"""
+    model_name = "resnext50"
+    cfg_name = model_name + ".cfg"
+    weights_name = model_name + ".weights"
+    cfg_url = "https://github.com/pjreddie/darknet/blob/master/cfg/" + cfg_name + "?raw=true"
+    weights_url = "http://pjreddie.com/media/files/" + weights_name + "?raw=true"
     net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
     verify_darknet_frontend(net)
     LIB.free_network(net)
 
 
 def test_forward_yolov2():
-    '''test yolov2 model'''
-    model_name = 'yolov2'
-    cfg_name = model_name + '.cfg'
-    weights_name = model_name + '.weights'
-    cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
-    weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
+    """test yolov2 model"""
+    model_name = "yolov2"
+    cfg_name = model_name + ".cfg"
+    weights_name = model_name + ".weights"
+    cfg_url = "https://github.com/pjreddie/darknet/blob/master/cfg/" + cfg_name + "?raw=true"
+    weights_url = "http://pjreddie.com/media/files/" + weights_name + "?raw=true"
     net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
     build_dtype = {}
     verify_darknet_frontend(net, build_dtype)
     LIB.free_network(net)
 
+
 def test_forward_yolov3():
-    '''test yolov3 model'''
-    model_name = 'yolov3'
-    cfg_name = model_name + '.cfg'
-    weights_name = model_name + '.weights'
-    cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
-    weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
+    """test yolov3 model"""
+    model_name = "yolov3"
+    cfg_name = model_name + ".cfg"
+    weights_name = model_name + ".weights"
+    cfg_url = "https://github.com/pjreddie/darknet/blob/master/cfg/" + cfg_name + "?raw=true"
+    weights_url = "http://pjreddie.com/media/files/" + weights_name + "?raw=true"
     net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
     build_dtype = {}
     verify_darknet_frontend(net, build_dtype)
     LIB.free_network(net)
 
+
 def test_forward_convolutional():
-    '''test convolutional layer'''
+    """test convolutional layer"""
     net = LIB.make_network(1)
     layer = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0)
     net.layers[0] = layer
@@ -236,8 +259,9 @@ def test_forward_convolutional():
     verify_darknet_frontend(net)
     LIB.free_network(net)
 
+
 def test_forward_dense():
-    '''test fully connected layer'''
+    """test fully connected layer"""
     net = LIB.make_network(1)
     layer = LIB.make_connected_layer(1, 75, 20, 1, 0, 0)
     net.layers[0] = layer
@@ -246,8 +270,9 @@ def test_forward_dense():
     verify_darknet_frontend(net)
     LIB.free_network(net)
 
+
 def test_forward_dense_batchnorm():
-    '''test fully connected layer with batchnorm'''
+    """test fully connected layer with batchnorm"""
     net = LIB.make_network(1)
     layer = LIB.make_connected_layer(1, 12, 2, 1, 1, 0)
     for i in range(5):
@@ -260,8 +285,9 @@ def test_forward_dense_batchnorm():
     verify_darknet_frontend(net)
     LIB.free_network(net)
 
+
 def test_forward_maxpooling():
-    '''test maxpooling layer'''
+    """test maxpooling layer"""
     net = LIB.make_network(1)
     layer = LIB.make_maxpool_layer(1, 224, 224, 3, 2, 2, 0)
     net.layers[0] = layer
@@ -270,8 +296,9 @@ def test_forward_maxpooling():
     verify_darknet_frontend(net)
     LIB.free_network(net)
 
+
 def test_forward_avgpooling():
-    '''test avgerage pooling layer'''
+    """test avgerage pooling layer"""
     net = LIB.make_network(1)
     layer = LIB.make_avgpool_layer(1, 224, 224, 3)
     net.layers[0] = layer
@@ -280,8 +307,9 @@ def test_forward_avgpooling():
     verify_darknet_frontend(net)
     LIB.free_network(net)
 
+
 def test_forward_conv_batch_norm():
-    '''test batch normalization layer'''
+    """test batch normalization layer"""
     net = LIB.make_network(1)
     layer = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 1, 0, 0, 0)
     for i in range(32):
@@ -293,8 +321,9 @@ def test_forward_conv_batch_norm():
     verify_darknet_frontend(net)
     LIB.free_network(net)
 
+
 def test_forward_shortcut():
-    '''test shortcut layer'''
+    """test shortcut layer"""
     net = LIB.make_network(3)
     layer_1 = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0)
     layer_2 = LIB.make_convolutional_layer(1, 111, 111, 32, 32, 1, 1, 1, 0, 1, 0, 0, 0, 0)
@@ -310,8 +339,9 @@ def test_forward_shortcut():
     verify_darknet_frontend(net)
     LIB.free_network(net)
 
+
 def test_forward_reorg():
-    '''test reorg layer'''
+    """test reorg layer"""
     net = LIB.make_network(2)
     layer_1 = LIB.make_convolutional_layer(1, 222, 222, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0)
     layer_2 = LIB.make_reorg_layer(1, 110, 110, 32, 2, 0, 0, 0)
@@ -322,8 +352,9 @@ def test_forward_reorg():
     verify_darknet_frontend(net)
     LIB.free_network(net)
 
+
 def test_forward_region():
-    '''test region layer'''
+    """test region layer"""
     net = LIB.make_network(2)
     layer_1 = LIB.make_convolutional_layer(1, 19, 19, 3, 425, 1, 1, 1, 0, 1, 0, 0, 0, 0)
     layer_2 = LIB.make_region_layer(1, 19, 19, 5, 80, 4)
@@ -336,8 +367,9 @@ def test_forward_region():
     verify_darknet_frontend(net, build_dtype)
     LIB.free_network(net)
 
+
 def test_forward_yolo_op():
-    '''test yolo layer'''
+    """test yolo layer"""
     net = LIB.make_network(2)
     layer_1 = LIB.make_convolutional_layer(1, 224, 224, 3, 14, 1, 3, 2, 0, 1, 0, 0, 0, 0)
     layer_2 = LIB.make_yolo_layer(1, 111, 111, 2, 9, __darknetffi__.NULL, 2)
@@ -349,8 +381,9 @@ def test_forward_yolo_op():
     verify_darknet_frontend(net, build_dtype)
     LIB.free_network(net)
 
+
 def test_forward_upsample():
-    '''test upsample layer'''
+    """test upsample layer"""
     net = LIB.make_network(1)
     layer = LIB.make_upsample_layer(1, 19, 19, 3, 3)
     layer.scale = 1
@@ -360,10 +393,11 @@ def test_forward_upsample():
     verify_darknet_frontend(net)
     LIB.free_network(net)
 
+
 def test_forward_l2normalize():
-    '''test l2 normalization layer'''
+    """test l2 normalization layer"""
     net = LIB.make_network(1)
-    layer = LIB.make_l2norm_layer(1, 224*224*3)
+    layer = LIB.make_l2norm_layer(1, 224 * 224 * 3)
     layer.c = layer.out_c = 3
     layer.h = layer.out_h = 224
     layer.w = layer.out_w = 224
@@ -373,8 +407,9 @@ def test_forward_l2normalize():
     verify_darknet_frontend(net)
     LIB.free_network(net)
 
+
 def test_forward_elu():
-    '''test elu activation layer'''
+    """test elu activation layer"""
     net = LIB.make_network(1)
     layer_1 = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0)
     layer_1.activation = ACTIVATION.ELU
@@ -384,8 +419,9 @@ def test_forward_elu():
     verify_darknet_frontend(net)
     LIB.free_network(net)
 
+
 def test_forward_softmax():
-    '''test softmax layer'''
+    """test softmax layer"""
     net = LIB.make_network(1)
     layer_1 = LIB.make_softmax_layer(1, 75, 1)
     layer_1.temperature = 1
@@ -395,8 +431,9 @@ def test_forward_softmax():
     verify_darknet_frontend(net)
     LIB.free_network(net)
 
+
 def test_forward_softmax_temperature():
-    '''test softmax layer'''
+    """test softmax layer"""
     net = LIB.make_network(1)
     layer_1 = LIB.make_softmax_layer(1, 75, 1)
     layer_1.temperature = 0.8
@@ -406,8 +443,9 @@ def test_forward_softmax_temperature():
     verify_darknet_frontend(net)
     LIB.free_network(net)
 
+
 def test_forward_activation_logistic():
-    '''test logistic activation layer'''
+    """test logistic activation layer"""
     net = LIB.make_network(1)
     batch = 1
     h = 224
@@ -423,8 +461,22 @@ def test_forward_activation_logistic():
     binary = 0
     xnor = 0
     adam = 0
-    layer_1 = LIB.make_convolutional_layer(batch, h, w, c, n, groups, size, stride, padding,
-                                           activation, batch_normalize, binary, xnor, adam)
+    layer_1 = LIB.make_convolutional_layer(
+        batch,
+        h,
+        w,
+        c,
+        n,
+        groups,
+        size,
+        stride,
+        padding,
+        activation,
+        batch_normalize,
+        binary,
+        xnor,
+        adam,
+    )
     net.layers[0] = layer_1
     net.w = w
     net.h = h
@@ -432,8 +484,9 @@ def test_forward_activation_logistic():
     verify_darknet_frontend(net)
     LIB.free_network(net)
 
+
 def test_forward_rnn():
-    '''test RNN layer'''
+    """test RNN layer"""
     net = LIB.make_network(1)
     batch = 1
     inputs = 4
@@ -452,7 +505,8 @@ def test_forward_rnn():
     _test_rnn_network(net, states)
     LIB.free_network(net)
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     test_forward_resnet50()
     test_forward_resnext50()
     test_forward_alexnet()
index 9482230..3ba7a03 100644 (file)
@@ -29,13 +29,15 @@ except ImportError:
 
 from tensorflow import keras as tf_keras
 from packaging import version as package_version
+
 # prevent Keras from using up all gpu memory
 if tf.executing_eagerly():
-    gpus = tf.config.experimental.list_physical_devices('GPU')
+    gpus = tf.config.experimental.list_physical_devices("GPU")
     for gpu in gpus:
         tf.config.experimental.set_memory_growth(gpu, True)
 else:
     from keras.backend.tensorflow_backend import set_session
+
     config = tf.ConfigProto()
     config.gpu_options.per_process_gpu_memory_fraction = 0.5
     set_session(tf.Session(config=config))
@@ -63,11 +65,11 @@ using_classic_keras = ("keras", {"keras": keras})
 using_tensorflow_keras = ("tf_keras", {"keras": tf_keras})
 
 
-def verify_keras_frontend(keras_model, need_transpose=True, layout='NCHW'):
+def verify_keras_frontend(keras_model, need_transpose=True, layout="NCHW"):
     # Keras frontend currently supports tensorflow backend only.
-    assert(keras.backend.backend() == 'tensorflow')
+    assert keras.backend.backend() == "tensorflow"
 
-    if layout != 'NCHW':
+    if layout != "NCHW":
         need_transpose = False
 
     in_shapes = []
@@ -75,19 +77,18 @@ def verify_keras_frontend(keras_model, need_transpose=True, layout='NCHW'):
         if tf.executing_eagerly():
             in_shapes.append(tuple(dim if dim is not None else 1 for dim in layer.input.shape))
         else:
-            in_shapes.append(tuple(dim.value if dim.value is not None else 1 for dim in layer.input.shape))
-
+            in_shapes.append(
+                tuple(dim.value if dim.value is not None else 1 for dim in layer.input.shape)
+            )
 
-    def get_keras_output(xs, dtype='float32'):
+    def get_keras_output(xs, dtype="float32"):
         return keras_model.predict(xs)
 
-    def get_tvm_output(xs, target, ctx, dtype='float32'):
+    def get_tvm_output(xs, target, ctx, dtype="float32"):
         shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)}
         mod, params = relay.frontend.from_keras(keras_model, shape_dict, layout=layout)
         with tvm.transform.PassContext(opt_level=2):
-            graph, lib, params = relay.build(mod,
-                                             target,
-                                             params=params)
+            graph, lib, params = relay.build(mod, target, params=params)
         m = graph_runtime.create(graph, lib, ctx)
         for name, x in zip(keras_model.input_names, xs):
             m.set_input(name, tvm.nd.array(x.astype(dtype)))
@@ -122,16 +123,18 @@ class TestKeras:
         x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
         y = keras.layers.Conv2D(8, (3, 3), padding="same")(x)
         z = keras.layers.Conv2D(8, (3, 3), padding="same")(y)
-        merge_funcs = [keras.layers.Add(),
-                    keras.layers.Subtract(),
-                    keras.layers.Multiply(),
-                    keras.layers.Maximum(),
-                    keras.layers.Minimum(),
-                    keras.layers.Average(),
-                    keras.layers.Concatenate()]
+        merge_funcs = [
+            keras.layers.Add(),
+            keras.layers.Subtract(),
+            keras.layers.Multiply(),
+            keras.layers.Maximum(),
+            keras.layers.Minimum(),
+            keras.layers.Average(),
+            keras.layers.Concatenate(),
+        ]
         for merge_func in merge_funcs:
             class_name = type(merge_func).__name__
-            if class_name in ('Subtract', 'Dot'):
+            if class_name in ("Subtract", "Dot"):
                 out = merge_func([x, y])
             else:
                 out = merge_func([x, y, z])
@@ -141,12 +144,14 @@ class TestKeras:
     def test_forward_merge_dot(self, keras):
         data1 = keras.layers.Input(shape=(2, 2))
         data2 = keras.layers.Input(shape=(2, 2))
-        merge_funcs = [keras.layers.Dot(axes=[1, 2]),
-                    keras.layers.Dot(axes=[2, 1]),
-                    keras.layers.Dot(axes=[1, 1]),
-                    keras.layers.Dot(axes=[2, 2]),
-                    keras.layers.Dot(axes=1),
-                    keras.layers.Dot(axes=2)]
+        merge_funcs = [
+            keras.layers.Dot(axes=[1, 2]),
+            keras.layers.Dot(axes=[2, 1]),
+            keras.layers.Dot(axes=[1, 1]),
+            keras.layers.Dot(axes=[2, 2]),
+            keras.layers.Dot(axes=1),
+            keras.layers.Dot(axes=2),
+        ]
         for merge_func in merge_funcs:
             out = merge_func([data1, data2])
             keras_model = keras.models.Model([data1, data2], out)
@@ -154,43 +159,44 @@ class TestKeras:
 
     def test_forward_activations(self, keras):
         data = keras.layers.Input(shape=(32, 32, 3))
-        act_funcs = [keras.layers.Activation('softmax'),
-                    keras.layers.Softmax(),
-                    keras.layers.Softmax(axis=-1),
-                    keras.layers.Softmax(axis=1),
-                    keras.layers.Softmax(axis=2),
-                    keras.layers.Softmax(axis=3),
-                    keras.layers.Activation('softplus'),
-                    keras.layers.Activation('relu'),
-                    keras.layers.Activation('softsign'),
-                    keras.layers.Activation('hard_sigmoid'),
-                    keras.layers.Activation('sigmoid'),
-                    keras.layers.Activation('tanh'),
-                    keras.layers.Activation('linear'),
-                    keras.layers.Activation('selu'),
-                    keras.layers.ReLU(),
-                    keras.layers.ReLU(max_value=6.),
-                    keras.layers.ReLU(max_value=6., threshold=0.),
-                    keras.layers.ReLU(max_value=6., threshold=1.),
-                    keras.layers.ReLU(max_value=6., threshold=1., negative_slope=0.),
-                    keras.layers.ReLU(max_value=6., threshold=1., negative_slope=0.5),
-                    keras.layers.ReLU(max_value=6., threshold=1., negative_slope=1.),
-                    keras.layers.LeakyReLU(alpha=0.3),
-                    keras.layers.PReLU(weights=np.random.rand(1, 32, 32, 3)),
-                    keras.layers.ELU(alpha=0.5),
-                    keras.layers.ThresholdedReLU(theta=0.5)]
+        act_funcs = [
+            keras.layers.Activation("softmax"),
+            keras.layers.Softmax(),
+            keras.layers.Softmax(axis=-1),
+            keras.layers.Softmax(axis=1),
+            keras.layers.Softmax(axis=2),
+            keras.layers.Softmax(axis=3),
+            keras.layers.Activation("softplus"),
+            keras.layers.Activation("relu"),
+            keras.layers.Activation("softsign"),
+            keras.layers.Activation("hard_sigmoid"),
+            keras.layers.Activation("sigmoid"),
+            keras.layers.Activation("tanh"),
+            keras.layers.Activation("linear"),
+            keras.layers.Activation("selu"),
+            keras.layers.ReLU(),
+            keras.layers.ReLU(max_value=6.0),
+            keras.layers.ReLU(max_value=6.0, threshold=0.0),
+            keras.layers.ReLU(max_value=6.0, threshold=1.0),
+            keras.layers.ReLU(max_value=6.0, threshold=1.0, negative_slope=0.0),
+            keras.layers.ReLU(max_value=6.0, threshold=1.0, negative_slope=0.5),
+            keras.layers.ReLU(max_value=6.0, threshold=1.0, negative_slope=1.0),
+            keras.layers.LeakyReLU(alpha=0.3),
+            keras.layers.PReLU(weights=np.random.rand(1, 32, 32, 3)),
+            keras.layers.ELU(alpha=0.5),
+            keras.layers.ThresholdedReLU(theta=0.5),
+        ]
         for act_func in act_funcs:
             x = act_func(data)
             keras_model = keras.models.Model(data, x)
             verify_keras_frontend(keras_model)
-            verify_keras_frontend(keras_model, need_transpose=False, layout='NHWC')
-
+            verify_keras_frontend(keras_model, need_transpose=False, layout="NHWC")
 
     def test_forward_dense(self, keras):
         data = keras.layers.Input(shape=(32, 32, 1))
         x = keras.layers.Flatten()(data)
         x = keras.layers.Dropout(0.5)(x)
-        x = keras.layers.Dense(10, activation='relu', kernel_initializer='uniform')(x)
+        x = keras.layers.Dense(10, activation="relu", kernel_initializer="uniform")(x)
         keras_model = keras.models.Model(data, x)
         verify_keras_frontend(keras_model)
 
@@ -201,38 +207,40 @@ class TestKeras:
         verify_keras_frontend(keras_model, need_transpose=False)
 
     def test_forward_sequential(self, keras):
-        keras_model = keras.models.Sequential([
-            keras.layers.Dense(16, input_dim=32, activation='relu'),
-            keras.layers.Dropout(0.5),
-            keras.layers.Dense(8, activation='relu'),
-            keras.layers.Dropout(0.5),
-            keras.layers.Dense(1, activation='sigmoid')
-        ])
+        keras_model = keras.models.Sequential(
+            [
+                keras.layers.Dense(16, input_dim=32, activation="relu"),
+                keras.layers.Dropout(0.5),
+                keras.layers.Dense(8, activation="relu"),
+                keras.layers.Dropout(0.5),
+                keras.layers.Dense(1, activation="sigmoid"),
+            ]
+        )
         verify_keras_frontend(keras_model)
 
-
     def test_forward_pool(self, keras):
         data = keras.layers.Input(shape=(32, 32, 1))
         # maxpool
-        x = keras.layers.MaxPooling2D((3, 3), strides=(1, 1), padding='same')(data)
+        x = keras.layers.MaxPooling2D((3, 3), strides=(1, 1), padding="same")(data)
         keras_model = keras.models.Model(data, x)
         verify_keras_frontend(keras_model)
         # avgpool
-        y = keras.layers.AveragePooling2D((3, 3), strides=(1, 1), padding='same')(data)
+        y = keras.layers.AveragePooling2D((3, 3), strides=(1, 1), padding="same")(data)
         keras_model = keras.models.Model(data, y)
         verify_keras_frontend(keras_model)
 
-
     def test_forward_conv(self, keras):
         data = keras.layers.Input(shape=(32, 32, 3))
-        conv_funcs = [keras.layers.Conv2D(filters=10, kernel_size=(3, 3),
-                                        strides=(2, 2), padding='same'),
-                    keras.layers.Conv2D(filters=10, kernel_size=(3, 3),
-                                        dilation_rate=(2, 2), padding='same'),
-                    keras.layers.Conv2D(filters=1, kernel_size=(3, 3), padding='same'),
-                    keras.layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same'),
-                    keras.layers.Conv2DTranspose(filters=10, kernel_size=(3, 3), padding='valid'),
-                    keras.layers.SeparableConv2D(filters=10, kernel_size=(3, 3), padding='same')]
+        conv_funcs = [
+            keras.layers.Conv2D(filters=10, kernel_size=(3, 3), strides=(2, 2), padding="same"),
+            keras.layers.Conv2D(
+                filters=10, kernel_size=(3, 3), dilation_rate=(2, 2), padding="same"
+            ),
+            keras.layers.Conv2D(filters=1, kernel_size=(3, 3), padding="same"),
+            keras.layers.DepthwiseConv2D(kernel_size=(3, 3), padding="same"),
+            keras.layers.Conv2DTranspose(filters=10, kernel_size=(3, 3), padding="valid"),
+            keras.layers.SeparableConv2D(filters=10, kernel_size=(3, 3), padding="same"),
+        ]
         for conv_func in conv_funcs:
             x = conv_func(data)
             keras_model = keras.models.Model(data, x)
@@ -240,42 +248,63 @@ class TestKeras:
 
     def test_forward_batch_norm(self, keras):
         data = keras.layers.Input(shape=(32, 32, 3))
-        batch_norm_funcs = [keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
-                                                            center=True, scale=False,
-                                                            beta_initializer='zeros',
-                                                            gamma_initializer='ones',
-                                                            moving_mean_initializer='zeros',
-                                                            moving_variance_initializer='ones'),
-                        keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
-                                                            center=True, scale=True,
-                                                            beta_initializer='zeros',
-                                                            gamma_initializer='ones',
-                                                            moving_mean_initializer='zeros',
-                                                            moving_variance_initializer='ones'),
-                        keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
-                                                            center=False, scale=True,
-                                                            beta_initializer='zeros',
-                                                            gamma_initializer='ones',
-                                                            moving_mean_initializer='zeros',
-                                                            moving_variance_initializer='ones'),
-                        keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
-                                                            center=False, scale=False,
-                                                            beta_initializer='zeros',
-                                                            gamma_initializer='ones',
-                                                            moving_mean_initializer='zeros',
-                                                            moving_variance_initializer='ones')]
+        batch_norm_funcs = [
+            keras.layers.BatchNormalization(
+                axis=-1,
+                momentum=0.99,
+                epsilon=0.001,
+                center=True,
+                scale=False,
+                beta_initializer="zeros",
+                gamma_initializer="ones",
+                moving_mean_initializer="zeros",
+                moving_variance_initializer="ones",
+            ),
+            keras.layers.BatchNormalization(
+                axis=-1,
+                momentum=0.99,
+                epsilon=0.001,
+                center=True,
+                scale=True,
+                beta_initializer="zeros",
+                gamma_initializer="ones",
+                moving_mean_initializer="zeros",
+                moving_variance_initializer="ones",
+            ),
+            keras.layers.BatchNormalization(
+                axis=-1,
+                momentum=0.99,
+                epsilon=0.001,
+                center=False,
+                scale=True,
+                beta_initializer="zeros",
+                gamma_initializer="ones",
+                moving_mean_initializer="zeros",
+                moving_variance_initializer="ones",
+            ),
+            keras.layers.BatchNormalization(
+                axis=-1,
+                momentum=0.99,
+                epsilon=0.001,
+                center=False,
+                scale=False,
+                beta_initializer="zeros",
+                gamma_initializer="ones",
+                moving_mean_initializer="zeros",
+                moving_variance_initializer="ones",
+            ),
+        ]
         for batch_norm_func in batch_norm_funcs:
             x = batch_norm_func(data)
             keras_model = keras.models.Model(data, x)
         verify_keras_frontend(keras_model)
 
-    def test_forward_upsample(self, keras, interpolation='nearest'):
+    def test_forward_upsample(self, keras, interpolation="nearest"):
         data = keras.layers.Input(shape=(32, 32, 3))
         x = keras.layers.UpSampling2D(size=(3, 3), interpolation=interpolation)(data)
         keras_model = keras.models.Model(data, x)
         verify_keras_frontend(keras_model)
 
-
     def test_forward_reshape(self, keras):
         # input_shape len is 3, target_shape len is 3
         data = keras.layers.Input(shape=(32, 32, 3))
@@ -308,7 +337,6 @@ class TestKeras:
         keras_model = keras.models.Model(data, x)
         verify_keras_frontend(keras_model, need_transpose=False)
 
-
     def test_forward_crop(self, keras):
         data = keras.layers.Input(shape=(32, 32, 3))
         x = keras.layers.Cropping2D(cropping=((1, 1), (1, 1)))(data)
@@ -321,7 +349,6 @@ class TestKeras:
         keras_model = keras.models.Model(data, x)
         verify_keras_frontend(keras_model)
 
-
     def test_forward_multi_inputs(self, keras):
         data1 = keras.layers.Input(shape=(32, 32, 3))
         data2 = keras.layers.Input(shape=(32, 32, 3))
@@ -332,7 +359,6 @@ class TestKeras:
         keras_model = keras.models.Model([data1, data2], z)
         verify_keras_frontend(keras_model)
 
-
     def test_forward_multi_outputs(self, keras):
         data = keras.layers.Input(shape=(32, 32, 3))
         x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
@@ -342,7 +368,6 @@ class TestKeras:
         keras_model = keras.models.Model(data, [x, y])
         verify_keras_frontend(keras_model)
 
-
     def test_forward_reuse_layers(self, keras):
         # reuse conv2d
         data = keras.layers.Input(shape=(32, 32, 3))
@@ -363,136 +388,121 @@ class TestKeras:
         keras_model = keras.models.Model(data, z)
         verify_keras_frontend(keras_model)
 
-
-    def test_forward_rnn(self,keras):
+    def test_forward_rnn(self, keras):
         data = keras.layers.Input(shape=(1, 32))
-        rnn_funcs = [keras.layers.LSTM(units=16, return_state=False,
-                        recurrent_activation='sigmoid', activation='tanh'),
-                    keras.layers.SimpleRNN(units=16, return_state=False,
-                        activation='tanh'),
-                    keras.layers.GRU(units=16, return_state=False,
-                        recurrent_activation='sigmoid', activation='tanh', reset_after=False)]
+        rnn_funcs = [
+            keras.layers.LSTM(
+                units=16, return_state=False, recurrent_activation="sigmoid", activation="tanh"
+            ),
+            keras.layers.SimpleRNN(units=16, return_state=False, activation="tanh"),
+            keras.layers.GRU(
+                units=16,
+                return_state=False,
+                recurrent_activation="sigmoid",
+                activation="tanh",
+                reset_after=False,
+            ),
+        ]
         for rnn_func in rnn_funcs:
             x = rnn_func(data)
             keras_model = keras.models.Model(data, x)
             verify_keras_frontend(keras_model, need_transpose=False)
 
-
-    def test_forward_vgg16(self, keras, layout='NCHW'):
-        keras_model = keras.applications.VGG16(include_top=True, weights='imagenet',
-            input_shape=(224, 224, 3), classes=1000)
+    def test_forward_vgg16(self, keras, layout="NCHW"):
+        keras_model = keras.applications.VGG16(
+            include_top=True, weights="imagenet", input_shape=(224, 224, 3), classes=1000
+        )
         verify_keras_frontend(keras_model, layout=layout)
 
-
-    def test_forward_xception(self, keras, layout='NCHW'):
-        keras_model = keras.applications.Xception(include_top=True, weights='imagenet',
-            input_shape=(299, 299, 3), classes=1000)
+    def test_forward_xception(self, keras, layout="NCHW"):
+        keras_model = keras.applications.Xception(
+            include_top=True, weights="imagenet", input_shape=(299, 299, 3), classes=1000
+        )
         verify_keras_frontend(keras_model, layout=layout)
 
-
-    def test_forward_resnet50(self, keras, layout='NCHW'):
-        keras_model = keras.applications.ResNet50(include_top=True, weights='imagenet',
-            input_shape=(224, 224, 3), classes=1000)
+    def test_forward_resnet50(self, keras, layout="NCHW"):
+        keras_model = keras.applications.ResNet50(
+            include_top=True, weights="imagenet", input_shape=(224, 224, 3), classes=1000
+        )
         verify_keras_frontend(keras_model, layout=layout)
 
-
-    def test_forward_mobilenet(self, keras, layout='NCHW'):
-        keras_model = keras.applications.MobileNet(include_top=True, weights='imagenet',
-            input_shape=(224, 224, 3), classes=1000)
+    def test_forward_mobilenet(self, keras, layout="NCHW"):
+        keras_model = keras.applications.MobileNet(
+            include_top=True, weights="imagenet", input_shape=(224, 224, 3), classes=1000
+        )
         verify_keras_frontend(keras_model, layout=layout)
 
     def test_forward_conv3d(self, keras):
         data = keras.layers.Input(shape=(32, 32, 32, 3))
-        conv_funcs = [keras.layers.Conv3D(filters=10,
-                                          kernel_size=(3, 3, 3),
-                                          strides=(2, 2, 2),
-                                          padding='same'),
-                      keras.layers.Conv3D(filters=10,
-                                          kernel_size=(3, 3, 3),
-                                          dilation_rate=(2, 2, 2),
-                                          padding='same'),
-                      keras.layers.Conv3D(filters=1,
-                                          kernel_size=(3, 3, 3),
-                                          padding='valid',
-                                          use_bias=False),
-                      keras.layers.Conv3D(filters=10,
-                                          kernel_size=(2, 2, 2),
-                                          padding='valid'),
-                    ]
+        conv_funcs = [
+            keras.layers.Conv3D(
+                filters=10, kernel_size=(3, 3, 3), strides=(2, 2, 2), padding="same"
+            ),
+            keras.layers.Conv3D(
+                filters=10, kernel_size=(3, 3, 3), dilation_rate=(2, 2, 2), padding="same"
+            ),
+            keras.layers.Conv3D(filters=1, kernel_size=(3, 3, 3), padding="valid", use_bias=False),
+            keras.layers.Conv3D(filters=10, kernel_size=(2, 2, 2), padding="valid"),
+        ]
         for conv_func in conv_funcs:
             x = conv_func(data)
             keras_model = keras.models.Model(data, x)
-            verify_keras_frontend(keras_model, layout='NDHWC')
-
+            verify_keras_frontend(keras_model, layout="NDHWC")
 
     def test_forward_conv3d_transpose(self, keras):
         data = keras.layers.Input(shape=(32, 32, 32, 3))
-        conv_funcs = [keras.layers.Conv3DTranspose(filters=10,
-                                          kernel_size=(3, 3, 3),
-                                          strides=(2, 2, 2),
-                                          padding='same'),
-                      keras.layers.Conv3DTranspose(filters=10,
-                                          kernel_size=(1, 1, 1),
-                                          dilation_rate=(1, 1, 1),
-                                          padding='same'),
-                      keras.layers.Conv3DTranspose(filters=1,
-                                          kernel_size=(3, 3, 3),
-                                          padding='valid',
-                                          use_bias=False),
-                      keras.layers.Conv3DTranspose(filters=10,
-                                          kernel_size=(2, 2, 2),
-                                          padding='valid'),
-                    ]
+        conv_funcs = [
+            keras.layers.Conv3DTranspose(
+                filters=10, kernel_size=(3, 3, 3), strides=(2, 2, 2), padding="same"
+            ),
+            keras.layers.Conv3DTranspose(
+                filters=10, kernel_size=(1, 1, 1), dilation_rate=(1, 1, 1), padding="same"
+            ),
+            keras.layers.Conv3DTranspose(
+                filters=1, kernel_size=(3, 3, 3), padding="valid", use_bias=False
+            ),
+            keras.layers.Conv3DTranspose(filters=10, kernel_size=(2, 2, 2), padding="valid"),
+        ]
         for conv_func in conv_funcs:
             x = conv_func(data)
             keras_model = keras.models.Model(data, x)
-            verify_keras_frontend(keras_model, layout='NDHWC')
-
+            verify_keras_frontend(keras_model, layout="NDHWC")
 
     def test_forward_pool3d(self, keras):
         data = keras.layers.Input(shape=(32, 32, 32, 1))
-        pool_funcs = [# maxpool
-                      keras.layers.MaxPooling3D(pool_size=(2, 2, 2),
-                                                strides=(1, 1, 1),
-                                                padding='same'),
-                      keras.layers.MaxPooling3D(pool_size=(3, 3, 3),
-                                                strides=(2, 2, 2),
-                                                padding='valid'),
-                      # avgpool
-                      keras.layers.AveragePooling3D(pool_size=(3, 3, 3),
-                                                    strides=(2, 2, 2),
-                                                    padding='same'),
-                      keras.layers.AveragePooling3D(pool_size=(2, 2, 2),
-                                                    strides=(1, 1, 1),
-                                                    padding='valid'),
-                     ]
+        pool_funcs = [  # maxpool
+            keras.layers.MaxPooling3D(pool_size=(2, 2, 2), strides=(1, 1, 1), padding="same"),
+            keras.layers.MaxPooling3D(pool_size=(3, 3, 3), strides=(2, 2, 2), padding="valid"),
+            # avgpool
+            keras.layers.AveragePooling3D(pool_size=(3, 3, 3), strides=(2, 2, 2), padding="same"),
+            keras.layers.AveragePooling3D(pool_size=(2, 2, 2), strides=(1, 1, 1), padding="valid"),
+        ]
         for pool_func in pool_funcs:
             x = pool_func(data)
             keras_model = keras.models.Model(data, x)
-            verify_keras_frontend(keras_model, layout='NDHWC')
+            verify_keras_frontend(keras_model, layout="NDHWC")
 
     def test_forward_upsample3d(self, keras):
         data = keras.layers.Input(shape=(32, 32, 32, 3))
         x = keras.layers.UpSampling3D(size=(2, 3, 4))(data)
         keras_model = keras.models.Model(data, x)
-        verify_keras_frontend(keras_model, layout='NDHWC')
+        verify_keras_frontend(keras_model, layout="NDHWC")
 
     def test_forward_zero_padding3d(self, keras):
         data = keras.layers.Input(shape=(32, 32, 32, 3))
-        pad_funcs = [# Integer
-                     keras.layers.ZeroPadding3D(padding=2),
-                     # tuple of 3 ints
-                     keras.layers.ZeroPadding3D(padding=(1, 2, 3)),
-                     # tuple of 3 tuples of 2 ints
-                     keras.layers.ZeroPadding3D(padding=((1,1), (2,2), (2,2))),
-                     # tuple of 3 tuples of 2 ints different values
-                     keras.layers.ZeroPadding3D(padding=((1,2), (2,3), (3,2))),
-                    ]
+        pad_funcs = [  # Integer
+            keras.layers.ZeroPadding3D(padding=2),
+            # tuple of 3 ints
+            keras.layers.ZeroPadding3D(padding=(1, 2, 3)),
+            # tuple of 3 tuples of 2 ints
+            keras.layers.ZeroPadding3D(padding=((1, 1), (2, 2), (2, 2))),
+            # tuple of 3 tuples of 2 ints different values
+            keras.layers.ZeroPadding3D(padding=((1, 2), (2, 3), (3, 2))),
+        ]
         for pad_func in pad_funcs:
             x = pad_func(data)
             keras_model = keras.models.Model(data, x)
-            verify_keras_frontend(keras_model, layout='NDHWC')
-
+            verify_keras_frontend(keras_model, layout="NDHWC")
 
     def test_forward_embedding(self, keras):
         data = keras.layers.Input(shape=(2, 4), dtype="int32")
@@ -510,7 +520,6 @@ class TestKeras:
         keras_model = keras.models.Model(data, x)
         verify_keras_frontend(keras_model, need_transpose=False)
 
-
     def test_forward_repeat_vector(self, keras):
         data = keras.layers.Input(shape=(5,), dtype="float32")
         x = keras.layers.Dense(6)(data)
@@ -529,20 +538,20 @@ class TestKeras:
         keras_model = keras.models.Model(data, x)
         verify_keras_frontend(keras_model, need_transpose=False)
 
-
     def test_forward_global_pool3d(self, keras):
         data = keras.layers.Input(shape=(32, 32, 32, 1))
-        pool_funcs = [# global maxpool
-                      keras.layers.GlobalMaxPooling3D(),
-                      # global avgpool
-                      keras.layers.GlobalAveragePooling3D()
-                     ]
+        pool_funcs = [  # global maxpool
+            keras.layers.GlobalMaxPooling3D(),
+            # global avgpool
+            keras.layers.GlobalAveragePooling3D(),
+        ]
         for pool_func in pool_funcs:
             x = pool_func(data)
             keras_model = keras.models.Model(data, x)
-            verify_keras_frontend(keras_model, layout='NDHWC')
+            verify_keras_frontend(keras_model, layout="NDHWC")
+
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     for k in [keras, tf_keras]:
         sut = TestKeras()
         sut.test_forward_merge_dot(keras=k)
@@ -554,8 +563,8 @@ if __name__ == '__main__':
         sut.test_forward_pool(keras=k)
         sut.test_forward_conv(keras=k)
         sut.test_forward_batch_norm(keras=k)
-        sut.test_forward_upsample(keras=k, interpolation='nearest')
-        sut.test_forward_upsample(keras=k, interpolation='bilinear')
+        sut.test_forward_upsample(keras=k, interpolation="nearest")
+        sut.test_forward_upsample(keras=k, interpolation="bilinear")
         sut.test_forward_reshape(keras=k)
         sut.test_forward_crop(keras=k)
         sut.test_forward_multi_inputs(keras=k)
@@ -563,12 +572,12 @@ if __name__ == '__main__':
         sut.test_forward_reuse_layers(keras=k)
         sut.test_forward_rnn(keras=k)
         sut.test_forward_vgg16(keras=k)
-        sut.test_forward_vgg16(keras=k, layout='NHWC')
+        sut.test_forward_vgg16(keras=k, layout="NHWC")
         sut.test_forward_xception(keras=k)
         sut.test_forward_resnet50(keras=k)
-        sut.test_forward_resnet50(keras=k, layout='NHWC')
+        sut.test_forward_resnet50(keras=k, layout="NHWC")
         sut.test_forward_mobilenet(keras=k)
-        sut.test_forward_mobilenet(keras=k, layout='NHWC')
+        sut.test_forward_mobilenet(keras=k, layout="NHWC")
         sut.test_forward_conv3d(keras=k)
         sut.test_forward_conv3d_transpose(keras=k)
         sut.test_forward_pool3d(keras=k)
index d042728..2c324a0 100644 (file)
@@ -25,52 +25,62 @@ def mx_mlp():
     num_class = 10
     return mlp.get_symbol(num_class)
 
+
 def relay_mlp():
     num_class = 10
     return tvm.relay.testing.mlp.get_workload(1, num_class)[0]
 
+
 # vgg
 def mx_vgg(num_layers):
     num_class = 1000
     return vgg.get_symbol(num_class, num_layers)
 
+
 def relay_vgg(num_layers):
     num_class = 1000
-    return tvm.relay.testing.vgg.get_workload(
-        1, num_class, num_layers=num_layers)[0]
+    return tvm.relay.testing.vgg.get_workload(1, num_class, num_layers=num_layers)[0]
+
 
 # resnet
 def mx_resnet(num_layers):
     num_class = 1000
-    return resnet.get_symbol(num_class, num_layers, '3,224,224')
+    return resnet.get_symbol(num_class, num_layers, "3,224,224")
+
 
 def relay_resnet(num_layers):
     num_class = 1000
-    return tvm.relay.testing.resnet.get_workload(
-        1, num_class, num_layers=num_layers)[0]
+    return tvm.relay.testing.resnet.get_workload(1, num_class, num_layers=num_layers)[0]
 
 
 # dqn
 mx_dqn = dqn.get_symbol
 
+
 def relay_dqn():
     return tvm.relay.testing.dqn.get_workload(1)[0]
 
+
 # squeezenet
 def mx_squeezenet(version):
     return squeezenet.get_symbol(version=version)
 
+
 def relay_squeezenet(version):
     return tvm.relay.testing.squeezenet.get_workload(1, version=version)[0]
 
+
 # inception
 mx_inception_v3 = inception_v3.get_symbol
 
+
 def relay_inception_v3():
     return tvm.relay.testing.inception_v3.get_workload(1)[0]
 
+
 # dcgan generator
 mx_dcgan = dcgan.get_symbol
 
+
 def relay_dcgan(batch_size):
     return tvm.relay.testing.dcgan.get_workload(batch_size=batch_size)[0]
index e606b78..cf086bc 100644 (file)
@@ -29,6 +29,7 @@ arXiv preprint arXiv:1511.06434 (2015).
 
 import mxnet as mx
 
+
 def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)):
     """a deconv layer that enlarges the feature map"""
     target_shape = (oshape[-2], oshape[-1])
@@ -37,46 +38,54 @@ def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)):
     adj_y = (target_shape[0] + 2 * pad_y - kshape[0]) % stride[0]
     adj_x = (target_shape[1] + 2 * pad_x - kshape[1]) % stride[1]
 
-    net = mx.sym.Deconvolution(data,
-                               kernel=kshape,
-                               stride=stride,
-                               pad=(pad_y, pad_x),
-                               adj=(adj_y, adj_x),
-                               num_filter=oshape[0],
-                               no_bias=True,
-                               name=name)
+    net = mx.sym.Deconvolution(
+        data,
+        kernel=kshape,
+        stride=stride,
+        pad=(pad_y, pad_x),
+        adj=(adj_y, adj_x),
+        num_filter=oshape[0],
+        no_bias=True,
+        name=name,
+    )
     return net
 
+
 def deconv2d_bn_relu(data, prefix, **kwargs):
     """a block of deconv + batch norm + relu"""
     eps = 1e-5 + 1e-12
 
     net = deconv2d(data, name="%s_deconv" % prefix, **kwargs)
     net = mx.sym.BatchNorm(net, eps=eps, name="%s_bn" % prefix)
-    net = mx.sym.Activation(net, name="%s_act" % prefix, act_type='relu')
+    net = mx.sym.Activation(net, name="%s_act" % prefix, act_type="relu")
     return net
 
+
 def get_symbol(oshape=(3, 64, 64), ngf=128, code=None):
     """get symbol of dcgan generator"""
     assert oshape[-1] == 64, "Only support 64x64 image"
     assert oshape[-2] == 64, "Only support 64x64 image"
 
     code = mx.sym.Variable("data") if code is None else code
-    net = mx.sym.FullyConnected(code, name="g1", num_hidden=ngf*8*4*4, no_bias=True, flatten=False)
-    net = mx.sym.Activation(net, act_type='relu')
+    net = mx.sym.FullyConnected(
+        code, name="g1", num_hidden=ngf * 8 * 4 * 4, no_bias=True, flatten=False
+    )
+    net = mx.sym.Activation(net, act_type="relu")
     # 4 x 4
     net = mx.sym.reshape(net, shape=(-1, ngf * 8, 4, 4))
     # 8 x 8
     net = deconv2d_bn_relu(
-        net, ishape=(ngf * 8, 4, 4), oshape=(ngf * 4, 8, 8), kshape=(4, 4), prefix="g2")
+        net, ishape=(ngf * 8, 4, 4), oshape=(ngf * 4, 8, 8), kshape=(4, 4), prefix="g2"
+    )
     # 16x16
     net = deconv2d_bn_relu(
-        net, ishape=(ngf * 4, 8, 8), oshape=(ngf * 2, 16, 16), kshape=(4, 4), prefix="g3")
+        net, ishape=(ngf * 4, 8, 8), oshape=(ngf * 2, 16, 16), kshape=(4, 4), prefix="g3"
+    )
     # 32x32
     net = deconv2d_bn_relu(
-        net, ishape=(ngf * 2, 16, 16), oshape=(ngf, 32, 32), kshape=(4, 4), prefix="g4")
+        net, ishape=(ngf * 2, 16, 16), oshape=(ngf, 32, 32), kshape=(4, 4), prefix="g4"
+    )
     # 64x64
-    net = deconv2d(
-        net, ishape=(ngf, 32, 32), oshape=oshape[-3:], kshape=(4, 4), name="g5_deconv")
-    net = mx.sym.Activation(net, act_type='tanh')
+    net = deconv2d(net, ishape=(ngf, 32, 32), oshape=oshape[-3:], kshape=(4, 4), name="g5_deconv")
+    net = mx.sym.Activation(net, act_type="tanh")
     return net
index e661e18..df611c7 100644 (file)
@@ -25,19 +25,17 @@ Nature 518.7540 (2015): 529.
 
 import mxnet as mx
 
+
 def get_symbol(num_action=18):
-    data = mx.sym.Variable(name='data')
-    net = mx.sym.Convolution(data, kernel=(8, 8), stride=(4, 4),
-                             num_filter=32, name='conv1')
-    net = mx.sym.Activation(net, act_type='relu', name='relu1')
-    net = mx.sym.Convolution(net, kernel=(4, 4), stride=(2, 2),
-                             num_filter=64, name='conv2')
-    net = mx.sym.Activation(net, act_type='relu', name='relu2')
-    net = mx.sym.Convolution(net, kernel=(3, 3), stride=(1, 1),
-                             num_filter=64, name='conv3')
-    net = mx.sym.Activation(net, act_type='relu', name='relu3')
-    net = mx.sym.FullyConnected(net, num_hidden=512, name='fc4')
-    net = mx.sym.Activation(net, act_type='relu', name='relu4')
-    net = mx.sym.FullyConnected(net, num_hidden=num_action, name='fc5', flatten=False)
+    data = mx.sym.Variable(name="data")
+    net = mx.sym.Convolution(data, kernel=(8, 8), stride=(4, 4), num_filter=32, name="conv1")
+    net = mx.sym.Activation(net, act_type="relu", name="relu1")
+    net = mx.sym.Convolution(net, kernel=(4, 4), stride=(2, 2), num_filter=64, name="conv2")
+    net = mx.sym.Activation(net, act_type="relu", name="relu2")
+    net = mx.sym.Convolution(net, kernel=(3, 3), stride=(1, 1), num_filter=64, name="conv3")
+    net = mx.sym.Activation(net, act_type="relu", name="relu3")
+    net = mx.sym.FullyConnected(net, num_hidden=512, name="fc4")
+    net = mx.sym.Activation(net, act_type="relu", name="relu4")
+    net = mx.sym.FullyConnected(net, num_hidden=num_action, name="fc5", flatten=False)
 
     return net
index 8e8f36a..bbed9b5 100644 (file)
@@ -26,101 +26,312 @@ Adopted from https://github.com/apache/incubator-mxnet/blob/
 import mxnet as mx
 import numpy as np
 
-def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None, suffix=''):
-    conv = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, no_bias=True, name='%s%s_conv2d' %(name, suffix))
-    bn = mx.sym.BatchNorm(data=conv, eps=2e-5, name='%s%s_batchnorm' % (name, suffix))
-    act = mx.sym.Activation(data=bn, act_type='relu', name='%s%s_relu' %(name, suffix))
+
+def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None, suffix=""):
+    conv = mx.sym.Convolution(
+        data=data,
+        num_filter=num_filter,
+        kernel=kernel,
+        stride=stride,
+        pad=pad,
+        no_bias=True,
+        name="%s%s_conv2d" % (name, suffix),
+    )
+    bn = mx.sym.BatchNorm(data=conv, eps=2e-5, name="%s%s_batchnorm" % (name, suffix))
+    act = mx.sym.Activation(data=bn, act_type="relu", name="%s%s_relu" % (name, suffix))
     return act
 
 
-def Inception7A(data,
-                num_1x1,
-                num_3x3_red, num_3x3_1, num_3x3_2,
-                num_5x5_red, num_5x5,
-                pool, proj,
-                name):
-    tower_1x1 = Conv(data, num_1x1, name=('%s_conv' % name))
-    tower_5x5 = Conv(data, num_5x5_red, name=('%s_tower' % name), suffix='_conv')
-    tower_5x5 = Conv(tower_5x5, num_5x5, kernel=(5, 5), pad=(2, 2), name=('%s_tower' % name), suffix='_conv_1')
-    tower_3x3 = Conv(data, num_3x3_red, name=('%s_tower_1' % name), suffix='_conv')
-    tower_3x3 = Conv(tower_3x3, num_3x3_1, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_1')
-    tower_3x3 = Conv(tower_3x3, num_3x3_2, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_2')
-    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
-    cproj = Conv(pooling, proj, name=('%s_tower_2' %  name), suffix='_conv')
-    concat = mx.sym.Concat(*[tower_1x1, tower_5x5, tower_3x3, cproj], name='ch_concat_%s_chconcat' % name)
+def Inception7A(
+    data, num_1x1, num_3x3_red, num_3x3_1, num_3x3_2, num_5x5_red, num_5x5, pool, proj, name
+):
+    tower_1x1 = Conv(data, num_1x1, name=("%s_conv" % name))
+    tower_5x5 = Conv(data, num_5x5_red, name=("%s_tower" % name), suffix="_conv")
+    tower_5x5 = Conv(
+        tower_5x5, num_5x5, kernel=(5, 5), pad=(2, 2), name=("%s_tower" % name), suffix="_conv_1"
+    )
+    tower_3x3 = Conv(data, num_3x3_red, name=("%s_tower_1" % name), suffix="_conv")
+    tower_3x3 = Conv(
+        tower_3x3,
+        num_3x3_1,
+        kernel=(3, 3),
+        pad=(1, 1),
+        name=("%s_tower_1" % name),
+        suffix="_conv_1",
+    )
+    tower_3x3 = Conv(
+        tower_3x3,
+        num_3x3_2,
+        kernel=(3, 3),
+        pad=(1, 1),
+        name=("%s_tower_1" % name),
+        suffix="_conv_2",
+    )
+    pooling = mx.sym.Pooling(
+        data=data,
+        kernel=(3, 3),
+        stride=(1, 1),
+        pad=(1, 1),
+        pool_type=pool,
+        name=("%s_pool_%s_pool" % (pool, name)),
+    )
+    cproj = Conv(pooling, proj, name=("%s_tower_2" % name), suffix="_conv")
+    concat = mx.sym.Concat(
+        *[tower_1x1, tower_5x5, tower_3x3, cproj], name="ch_concat_%s_chconcat" % name
+    )
     return concat
 
+
 # First Downsample
-def Inception7B(data,
-                num_3x3,
-                num_d3x3_red, num_d3x3_1, num_d3x3_2,
-                pool,
-                name):
-    tower_3x3 = Conv(data, num_3x3, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_conv' % name))
-    tower_d3x3 = Conv(data, num_d3x3_red, name=('%s_tower' % name), suffix='_conv')
-    tower_d3x3 = Conv(tower_d3x3, num_d3x3_1, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_tower' % name), suffix='_conv_1')
-    tower_d3x3 = Conv(tower_d3x3, num_d3x3_2, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_tower' % name), suffix='_conv_2')
-    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pad=(0,0), pool_type="max", name=('max_pool_%s_pool' % name))
-    concat = mx.sym.Concat(*[tower_3x3, tower_d3x3, pooling], name='ch_concat_%s_chconcat' % name)
+def Inception7B(data, num_3x3, num_d3x3_red, num_d3x3_1, num_d3x3_2, pool, name):
+    tower_3x3 = Conv(
+        data, num_3x3, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=("%s_conv" % name)
+    )
+    tower_d3x3 = Conv(data, num_d3x3_red, name=("%s_tower" % name), suffix="_conv")
+    tower_d3x3 = Conv(
+        tower_d3x3,
+        num_d3x3_1,
+        kernel=(3, 3),
+        pad=(1, 1),
+        stride=(1, 1),
+        name=("%s_tower" % name),
+        suffix="_conv_1",
+    )
+    tower_d3x3 = Conv(
+        tower_d3x3,
+        num_d3x3_2,
+        kernel=(3, 3),
+        pad=(0, 0),
+        stride=(2, 2),
+        name=("%s_tower" % name),
+        suffix="_conv_2",
+    )
+    pooling = mx.sym.Pooling(
+        data=data,
+        kernel=(3, 3),
+        stride=(2, 2),
+        pad=(0, 0),
+        pool_type="max",
+        name=("max_pool_%s_pool" % name),
+    )
+    concat = mx.sym.Concat(*[tower_3x3, tower_d3x3, pooling], name="ch_concat_%s_chconcat" % name)
     return concat
 
-def Inception7C(data,
-                num_1x1,
-                num_d7_red, num_d7_1, num_d7_2,
-                num_q7_red, num_q7_1, num_q7_2, num_q7_3, num_q7_4,
-                pool, proj,
-                name):
-    tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name))
-    tower_d7 = Conv(data=data, num_filter=num_d7_red, name=('%s_tower' % name), suffix='_conv')
-    tower_d7 = Conv(data=tower_d7, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3), name=('%s_tower' % name), suffix='_conv_1')
-    tower_d7 = Conv(data=tower_d7, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0), name=('%s_tower' % name), suffix='_conv_2')
-    tower_q7 = Conv(data=data, num_filter=num_q7_red, name=('%s_tower_1' % name), suffix='_conv')
-    tower_q7 = Conv(data=tower_q7, num_filter=num_q7_1, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_1')
-    tower_q7 = Conv(data=tower_q7, num_filter=num_q7_2, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_2')
-    tower_q7 = Conv(data=tower_q7, num_filter=num_q7_3, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_3')
-    tower_q7 = Conv(data=tower_q7, num_filter=num_q7_4, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_4')
-    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
-    cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' %  name), suffix='_conv')
+
+def Inception7C(
+    data,
+    num_1x1,
+    num_d7_red,
+    num_d7_1,
+    num_d7_2,
+    num_q7_red,
+    num_q7_1,
+    num_q7_2,
+    num_q7_3,
+    num_q7_4,
+    pool,
+    proj,
+    name,
+):
+    tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=("%s_conv" % name))
+    tower_d7 = Conv(data=data, num_filter=num_d7_red, name=("%s_tower" % name), suffix="_conv")
+    tower_d7 = Conv(
+        data=tower_d7,
+        num_filter=num_d7_1,
+        kernel=(1, 7),
+        pad=(0, 3),
+        name=("%s_tower" % name),
+        suffix="_conv_1",
+    )
+    tower_d7 = Conv(
+        data=tower_d7,
+        num_filter=num_d7_2,
+        kernel=(7, 1),
+        pad=(3, 0),
+        name=("%s_tower" % name),
+        suffix="_conv_2",
+    )
+    tower_q7 = Conv(data=data, num_filter=num_q7_red, name=("%s_tower_1" % name), suffix="_conv")
+    tower_q7 = Conv(
+        data=tower_q7,
+        num_filter=num_q7_1,
+        kernel=(7, 1),
+        pad=(3, 0),
+        name=("%s_tower_1" % name),
+        suffix="_conv_1",
+    )
+    tower_q7 = Conv(
+        data=tower_q7,
+        num_filter=num_q7_2,
+        kernel=(1, 7),
+        pad=(0, 3),
+        name=("%s_tower_1" % name),
+        suffix="_conv_2",
+    )
+    tower_q7 = Conv(
+        data=tower_q7,
+        num_filter=num_q7_3,
+        kernel=(7, 1),
+        pad=(3, 0),
+        name=("%s_tower_1" % name),
+        suffix="_conv_3",
+    )
+    tower_q7 = Conv(
+        data=tower_q7,
+        num_filter=num_q7_4,
+        kernel=(1, 7),
+        pad=(0, 3),
+        name=("%s_tower_1" % name),
+        suffix="_conv_4",
+    )
+    pooling = mx.sym.Pooling(
+        data=data,
+        kernel=(3, 3),
+        stride=(1, 1),
+        pad=(1, 1),
+        pool_type=pool,
+        name=("%s_pool_%s_pool" % (pool, name)),
+    )
+    cproj = Conv(
+        data=pooling, num_filter=proj, kernel=(1, 1), name=("%s_tower_2" % name), suffix="_conv"
+    )
     # concat
-    concat = mx.sym.Concat(*[tower_1x1, tower_d7, tower_q7, cproj], name='ch_concat_%s_chconcat' % name)
+    concat = mx.sym.Concat(
+        *[tower_1x1, tower_d7, tower_q7, cproj], name="ch_concat_%s_chconcat" % name
+    )
     return concat
 
-def Inception7D(data,
-                num_3x3_red, num_3x3,
-                num_d7_3x3_red, num_d7_1, num_d7_2, num_d7_3x3,
-                pool,
-                name):
-    tower_3x3 = Conv(data=data, num_filter=num_3x3_red, name=('%s_tower' % name), suffix='_conv')
-    tower_3x3 = Conv(data=tower_3x3, num_filter=num_3x3, kernel=(3, 3), pad=(0,0), stride=(2, 2), name=('%s_tower' % name), suffix='_conv_1')
-    tower_d7_3x3 = Conv(data=data, num_filter=num_d7_3x3_red, name=('%s_tower_1' % name), suffix='_conv')
-    tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_1')
-    tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_2')
-    tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_3x3, kernel=(3, 3), stride=(2, 2), name=('%s_tower_1' % name), suffix='_conv_3')
-    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
+
+def Inception7D(
+    data, num_3x3_red, num_3x3, num_d7_3x3_red, num_d7_1, num_d7_2, num_d7_3x3, pool, name
+):
+    tower_3x3 = Conv(data=data, num_filter=num_3x3_red, name=("%s_tower" % name), suffix="_conv")
+    tower_3x3 = Conv(
+        data=tower_3x3,
+        num_filter=num_3x3,
+        kernel=(3, 3),
+        pad=(0, 0),
+        stride=(2, 2),
+        name=("%s_tower" % name),
+        suffix="_conv_1",
+    )
+    tower_d7_3x3 = Conv(
+        data=data, num_filter=num_d7_3x3_red, name=("%s_tower_1" % name), suffix="_conv"
+    )
+    tower_d7_3x3 = Conv(
+        data=tower_d7_3x3,
+        num_filter=num_d7_1,
+        kernel=(1, 7),
+        pad=(0, 3),
+        name=("%s_tower_1" % name),
+        suffix="_conv_1",
+    )
+    tower_d7_3x3 = Conv(
+        data=tower_d7_3x3,
+        num_filter=num_d7_2,
+        kernel=(7, 1),
+        pad=(3, 0),
+        name=("%s_tower_1" % name),
+        suffix="_conv_2",
+    )
+    tower_d7_3x3 = Conv(
+        data=tower_d7_3x3,
+        num_filter=num_d7_3x3,
+        kernel=(3, 3),
+        stride=(2, 2),
+        name=("%s_tower_1" % name),
+        suffix="_conv_3",
+    )
+    pooling = mx.sym.Pooling(
+        data=data,
+        kernel=(3, 3),
+        stride=(2, 2),
+        pool_type=pool,
+        name=("%s_pool_%s_pool" % (pool, name)),
+    )
     # concat
-    concat = mx.sym.Concat(*[tower_3x3, tower_d7_3x3, pooling], name='ch_concat_%s_chconcat' % name)
+    concat = mx.sym.Concat(*[tower_3x3, tower_d7_3x3, pooling], name="ch_concat_%s_chconcat" % name)
     return concat
 
-def Inception7E(data,
-                num_1x1,
-                num_d3_red, num_d3_1, num_d3_2,
-                num_3x3_d3_red, num_3x3, num_3x3_d3_1, num_3x3_d3_2,
-                pool, proj,
-                name):
-    tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name))
-    tower_d3 = Conv(data=data, num_filter=num_d3_red, name=('%s_tower' % name), suffix='_conv')
-    tower_d3_a = Conv(data=tower_d3, num_filter=num_d3_1, kernel=(1, 3), pad=(0, 1), name=('%s_tower' % name), suffix='_mixed_conv')
-    tower_d3_b = Conv(data=tower_d3, num_filter=num_d3_2, kernel=(3, 1), pad=(1, 0), name=('%s_tower' % name), suffix='_mixed_conv_1')
-    tower_3x3_d3 = Conv(data=data, num_filter=num_3x3_d3_red, name=('%s_tower_1' % name), suffix='_conv')
-    tower_3x3_d3 = Conv(data=tower_3x3_d3, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_1')
-    tower_3x3_d3_a = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_1, kernel=(1, 3), pad=(0, 1), name=('%s_tower_1' % name), suffix='_mixed_conv')
-    tower_3x3_d3_b = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_2, kernel=(3, 1), pad=(1, 0), name=('%s_tower_1' % name), suffix='_mixed_conv_1')
-    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
-    cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' %  name), suffix='_conv')
+
+def Inception7E(
+    data,
+    num_1x1,
+    num_d3_red,
+    num_d3_1,
+    num_d3_2,
+    num_3x3_d3_red,
+    num_3x3,
+    num_3x3_d3_1,
+    num_3x3_d3_2,
+    pool,
+    proj,
+    name,
+):
+    tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=("%s_conv" % name))
+    tower_d3 = Conv(data=data, num_filter=num_d3_red, name=("%s_tower" % name), suffix="_conv")
+    tower_d3_a = Conv(
+        data=tower_d3,
+        num_filter=num_d3_1,
+        kernel=(1, 3),
+        pad=(0, 1),
+        name=("%s_tower" % name),
+        suffix="_mixed_conv",
+    )
+    tower_d3_b = Conv(
+        data=tower_d3,
+        num_filter=num_d3_2,
+        kernel=(3, 1),
+        pad=(1, 0),
+        name=("%s_tower" % name),
+        suffix="_mixed_conv_1",
+    )
+    tower_3x3_d3 = Conv(
+        data=data, num_filter=num_3x3_d3_red, name=("%s_tower_1" % name), suffix="_conv"
+    )
+    tower_3x3_d3 = Conv(
+        data=tower_3x3_d3,
+        num_filter=num_3x3,
+        kernel=(3, 3),
+        pad=(1, 1),
+        name=("%s_tower_1" % name),
+        suffix="_conv_1",
+    )
+    tower_3x3_d3_a = Conv(
+        data=tower_3x3_d3,
+        num_filter=num_3x3_d3_1,
+        kernel=(1, 3),
+        pad=(0, 1),
+        name=("%s_tower_1" % name),
+        suffix="_mixed_conv",
+    )
+    tower_3x3_d3_b = Conv(
+        data=tower_3x3_d3,
+        num_filter=num_3x3_d3_2,
+        kernel=(3, 1),
+        pad=(1, 0),
+        name=("%s_tower_1" % name),
+        suffix="_mixed_conv_1",
+    )
+    pooling = mx.sym.Pooling(
+        data=data,
+        kernel=(3, 3),
+        stride=(1, 1),
+        pad=(1, 1),
+        pool_type=pool,
+        name=("%s_pool_%s_pool" % (pool, name)),
+    )
+    cproj = Conv(
+        data=pooling, num_filter=proj, kernel=(1, 1), name=("%s_tower_2" % name), suffix="_conv"
+    )
     # concat
-    concat = mx.sym.Concat(*[tower_1x1, tower_d3_a, tower_d3_b, tower_3x3_d3_a, tower_3x3_d3_b, cproj], name='ch_concat_%s_chconcat' % name)
+    concat = mx.sym.Concat(
+        *[tower_1x1, tower_d3_a, tower_d3_b, tower_3x3_d3_a, tower_3x3_d3_b, cproj],
+        name="ch_concat_%s_chconcat" % name,
+    )
     return concat
 
+
 def get_symbol(num_classes=1000, **kwargs):
     data = mx.sym.Variable(name="data")
     # stage 1
@@ -134,53 +345,24 @@ def get_symbol(num_classes=1000, **kwargs):
     pool1 = mx.sym.Pooling(data=conv_4, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool1")
 
     # # stage 3
-    in3a = Inception7A(pool1, 64,
-                       64, 96, 96,
-                       48, 64,
-                       "avg", 32, "mixed")
-    in3b = Inception7A(in3a, 64,
-                       64, 96, 96,
-                       48, 64,
-                       "avg", 64, "mixed_1")
-    in3c = Inception7A(in3b, 64,
-                       64, 96, 96,
-                       48, 64,
-                       "avg", 64, "mixed_2")
-    in3d = Inception7B(in3c, 384,
-                       64, 96, 96,
-                       "max", "mixed_3")
+    in3a = Inception7A(pool1, 64, 64, 96, 96, 48, 64, "avg", 32, "mixed")
+    in3b = Inception7A(in3a, 64, 64, 96, 96, 48, 64, "avg", 64, "mixed_1")
+    in3c = Inception7A(in3b, 64, 64, 96, 96, 48, 64, "avg", 64, "mixed_2")
+    in3d = Inception7B(in3c, 384, 64, 96, 96, "max", "mixed_3")
     # stage 4
-    in4a = Inception7C(in3d, 192,
-                       128, 128, 192,
-                       128, 128, 128, 128, 192,
-                       "avg", 192, "mixed_4")
-    in4b = Inception7C(in4a, 192,
-                       160, 160, 192,
-                       160, 160, 160, 160, 192,
-                       "avg", 192, "mixed_5")
-    in4c = Inception7C(in4b, 192,
-                       160, 160, 192,
-                       160, 160, 160, 160, 192,
-                       "avg", 192, "mixed_6")
-    in4d = Inception7C(in4c, 192,
-                       192, 192, 192,
-                       192, 192, 192, 192, 192,
-                       "avg", 192, "mixed_7")
-    in4e = Inception7D(in4d, 192, 320,
-                       192, 192, 192, 192,
-                       "max", "mixed_8")
+    in4a = Inception7C(in3d, 192, 128, 128, 192, 128, 128, 128, 128, 192, "avg", 192, "mixed_4")
+    in4b = Inception7C(in4a, 192, 160, 160, 192, 160, 160, 160, 160, 192, "avg", 192, "mixed_5")
+    in4c = Inception7C(in4b, 192, 160, 160, 192, 160, 160, 160, 160, 192, "avg", 192, "mixed_6")
+    in4d = Inception7C(in4c, 192, 192, 192, 192, 192, 192, 192, 192, 192, "avg", 192, "mixed_7")
+    in4e = Inception7D(in4d, 192, 320, 192, 192, 192, 192, "max", "mixed_8")
     # stage 5
-    in5a = Inception7E(in4e, 320,
-                       384, 384, 384,
-                       448, 384, 384, 384,
-                       "avg", 192, "mixed_9")
-    in5b = Inception7E(in5a, 320,
-                       384, 384, 384,
-                       448, 384, 384, 384,
-                       "max", 192, "mixed_10")
+    in5a = Inception7E(in4e, 320, 384, 384, 384, 448, 384, 384, 384, "avg", 192, "mixed_9")
+    in5b = Inception7E(in5a, 320, 384, 384, 384, 448, 384, 384, 384, "max", 192, "mixed_10")
     # pool
-    pool = mx.sym.Pooling(data=in5b, kernel=(8, 8), stride=(1, 1), pool_type="avg", name="global_pool")
+    pool = mx.sym.Pooling(
+        data=in5b, kernel=(8, 8), stride=(1, 1), pool_type="avg", name="global_pool"
+    )
     flatten = mx.sym.Flatten(data=pool, name="flatten")
-    fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1', flatten=False)
-    softmax = mx.sym.SoftmaxOutput(data=fc1, name='softmax')
+    fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=num_classes, name="fc1", flatten=False)
+    softmax = mx.sym.SoftmaxOutput(data=fc1, name="softmax")
     return softmax
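
For orientation while reading the reformatted model-zoo code, a minimal usage sketch follows; it is not part of this commit and assumes MXNet is installed and that the Inception-V3 `get_symbol` above is in scope.

# Reading aid only -- assumes the Inception-V3 get_symbol above is importable.
sym = get_symbol(num_classes=1000)   # build the full Inception-V3 symbol graph
print(sym.list_arguments()[:5])      # a few learnable parameter names
print(sym.list_outputs())            # typically ['softmax_output']
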
index 922b208..45f33f9 100644 (file)
@@ -20,21 +20,22 @@ a simple multilayer perceptron
 """
 import mxnet as mx
 
+
 def get_symbol(num_classes=10, **kwargs):
-    data = mx.symbol.Variable('data')
+    data = mx.symbol.Variable("data")
     data = mx.sym.Flatten(data=data)
     try:
-        fc1  = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128, flatten=False)
-        act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
-        fc2  = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64, flatten=False)
-        act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
-        fc3  = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes, flatten=False)
-        mlp  = mx.symbol.softmax(data = fc3, name = 'softmax')
+        fc1 = mx.symbol.FullyConnected(data=data, name="fc1", num_hidden=128, flatten=False)
+        act1 = mx.symbol.Activation(data=fc1, name="relu1", act_type="relu")
+        fc2 = mx.symbol.FullyConnected(data=act1, name="fc2", num_hidden=64, flatten=False)
+        act2 = mx.symbol.Activation(data=fc2, name="relu2", act_type="relu")
+        fc3 = mx.symbol.FullyConnected(data=act2, name="fc3", num_hidden=num_classes, flatten=False)
+        mlp = mx.symbol.softmax(data=fc3, name="softmax")
     except:
-        fc1  = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
-        act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
-        fc2  = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
-        act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
-        fc3  = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes)
-        mlp  = mx.symbol.softmax(data = fc3, name = 'softmax')
+        fc1 = mx.symbol.FullyConnected(data=data, name="fc1", num_hidden=128)
+        act1 = mx.symbol.Activation(data=fc1, name="relu1", act_type="relu")
+        fc2 = mx.symbol.FullyConnected(data=act1, name="fc2", num_hidden=64)
+        act2 = mx.symbol.Activation(data=fc2, name="relu2", act_type="relu")
+        fc3 = mx.symbol.FullyConnected(data=act2, name="fc3", num_hidden=num_classes)
+        mlp = mx.symbol.softmax(data=fc3, name="softmax")
     return mlp
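
As a reading aid, the following sketch (assumptions: TVM and the `get_symbol` above are importable; the 28x28 input shape is illustrative) shows how such a model-zoo symbol feeds the Relay frontend, mirroring the conversion pattern used by the frontend tests later in this commit.

# Sketch only -- mirrors the relay.frontend.from_mxnet usage from the tests below.
from tvm import relay

mlp = get_symbol(num_classes=10)            # MXNet symbol defined above
shape_dict = {"data": (1, 1, 28, 28)}       # one 1-channel 28x28 image
mod, params = relay.frontend.from_mxnet(mlp, shape_dict)
print(mod)                                  # Relay IR of the converted MLP
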
index 3f9a870..98cdce6 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 
-'''
+"""
 Adapted from https://github.com/tornadomeet/ResNet/blob/master/symbol_resnet.py
 Original author Wei Wu
 
 Implemented the following paper:
 
 Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks"
-'''
+"""
 import mxnet as mx
 import numpy as np
 
-def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, bn_mom=0.9, workspace=256, memonger=False):
+
+def residual_unit(
+    data,
+    num_filter,
+    stride,
+    dim_match,
+    name,
+    bottle_neck=True,
+    bn_mom=0.9,
+    workspace=256,
+    memonger=False,
+):
     """Return ResNet Unit symbol for building ResNet
     Parameters
     ----------
@@ -46,45 +57,121 @@ def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, b
         Workspace used in convolution operator
     """
     if bottle_neck:
-        bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1')
-        act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')
-        conv1 = mx.sym.Convolution(data=act1, num_filter=int(num_filter*0.25), kernel=(1,1), stride=stride, pad=(0,0),
-                                   no_bias=True, workspace=workspace, name=name + '_conv1')
-        bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2')
-        act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2')
-        conv2 = mx.sym.Convolution(data=act2, num_filter=int(num_filter*0.25), kernel=(3,3), stride=(1,1), pad=(1,1),
-                                   no_bias=True, workspace=workspace, name=name + '_conv2')
-        bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3')
-        act3 = mx.sym.Activation(data=bn3, act_type='relu', name=name + '_relu3')
-        conv3 = mx.sym.Convolution(data=act3, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0), no_bias=True,
-                                   workspace=workspace, name=name + '_conv3')
+        bn1 = mx.sym.BatchNorm(
+            data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + "_bn1"
+        )
+        act1 = mx.sym.Activation(data=bn1, act_type="relu", name=name + "_relu1")
+        conv1 = mx.sym.Convolution(
+            data=act1,
+            num_filter=int(num_filter * 0.25),
+            kernel=(1, 1),
+            stride=stride,
+            pad=(0, 0),
+            no_bias=True,
+            workspace=workspace,
+            name=name + "_conv1",
+        )
+        bn2 = mx.sym.BatchNorm(
+            data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + "_bn2"
+        )
+        act2 = mx.sym.Activation(data=bn2, act_type="relu", name=name + "_relu2")
+        conv2 = mx.sym.Convolution(
+            data=act2,
+            num_filter=int(num_filter * 0.25),
+            kernel=(3, 3),
+            stride=(1, 1),
+            pad=(1, 1),
+            no_bias=True,
+            workspace=workspace,
+            name=name + "_conv2",
+        )
+        bn3 = mx.sym.BatchNorm(
+            data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + "_bn3"
+        )
+        act3 = mx.sym.Activation(data=bn3, act_type="relu", name=name + "_relu3")
+        conv3 = mx.sym.Convolution(
+            data=act3,
+            num_filter=num_filter,
+            kernel=(1, 1),
+            stride=(1, 1),
+            pad=(0, 0),
+            no_bias=True,
+            workspace=workspace,
+            name=name + "_conv3",
+        )
         if dim_match:
             shortcut = data
         else:
-            shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
-                                            workspace=workspace, name=name+'_sc')
+            shortcut = mx.sym.Convolution(
+                data=act1,
+                num_filter=num_filter,
+                kernel=(1, 1),
+                stride=stride,
+                no_bias=True,
+                workspace=workspace,
+                name=name + "_sc",
+            )
         if memonger:
-            shortcut._set_attr(mirror_stage='True')
+            shortcut._set_attr(mirror_stage="True")
         return conv3 + shortcut
     else:
-        bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn1')
-        act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')
-        conv1 = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(3,3), stride=stride, pad=(1,1),
-                                      no_bias=True, workspace=workspace, name=name + '_conv1')
-        bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn2')
-        act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2')
-        conv2 = mx.sym.Convolution(data=act2, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1,1),
-                                      no_bias=True, workspace=workspace, name=name + '_conv2')
+        bn1 = mx.sym.BatchNorm(
+            data=data, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + "_bn1"
+        )
+        act1 = mx.sym.Activation(data=bn1, act_type="relu", name=name + "_relu1")
+        conv1 = mx.sym.Convolution(
+            data=act1,
+            num_filter=num_filter,
+            kernel=(3, 3),
+            stride=stride,
+            pad=(1, 1),
+            no_bias=True,
+            workspace=workspace,
+            name=name + "_conv1",
+        )
+        bn2 = mx.sym.BatchNorm(
+            data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + "_bn2"
+        )
+        act2 = mx.sym.Activation(data=bn2, act_type="relu", name=name + "_relu2")
+        conv2 = mx.sym.Convolution(
+            data=act2,
+            num_filter=num_filter,
+            kernel=(3, 3),
+            stride=(1, 1),
+            pad=(1, 1),
+            no_bias=True,
+            workspace=workspace,
+            name=name + "_conv2",
+        )
         if dim_match:
             shortcut = data
         else:
-            shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
-                                            workspace=workspace, name=name+'_sc')
+            shortcut = mx.sym.Convolution(
+                data=act1,
+                num_filter=num_filter,
+                kernel=(1, 1),
+                stride=stride,
+                no_bias=True,
+                workspace=workspace,
+                name=name + "_sc",
+            )
         if memonger:
-            shortcut._set_attr(mirror_stage='True')
+            shortcut._set_attr(mirror_stage="True")
         return conv2 + shortcut
 
-def resnet(units, num_stages, filter_list, num_classes, image_shape, bottle_neck=True, bn_mom=0.9, workspace=256, dtype='float32', memonger=False):
+
+def resnet(
+    units,
+    num_stages,
+    filter_list,
+    num_classes,
+    image_shape,
+    bottle_neck=True,
+    bn_mom=0.9,
+    workspace=256,
+    dtype="float32",
+    memonger=False,
+):
     """Return ResNet symbol of
     Parameters
     ----------
@@ -104,65 +191,101 @@ def resnet(units, num_stages, filter_list, num_classes, image_shape, bottle_neck
         Precision (float32 or float16)
     """
     num_unit = len(units)
-    assert(num_unit == num_stages)
-    data = mx.sym.Variable(name='data')
-    if dtype == 'float32':
+    assert num_unit == num_stages
+    data = mx.sym.Variable(name="data")
+    if dtype == "float32":
         # data = mx.sym.identity(data=data, name='id')
         data = data
     else:
-        if dtype == 'float16':
+        if dtype == "float16":
             data = mx.sym.Cast(data=data, dtype=np.float16)
-    data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data')
+    data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name="bn_data")
     (nchannel, height, width) = image_shape
-    if height <= 32:            # such as cifar10
-        body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(3, 3), stride=(1,1), pad=(1, 1),
-                                  no_bias=True, name="conv0", workspace=workspace)
-    else:                       # often expected to be 224 such as imagenet
-        body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2,2), pad=(3, 3),
-                                  no_bias=True, name="conv0", workspace=workspace)
-        body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0')
-        body = mx.sym.Activation(data=body, act_type='relu', name='relu0')
-        body = mx.sym.Pooling(data=body, kernel=(3, 3), stride=(2,2), pad=(1,1), pool_type='max')
+    if height <= 32:  # such as cifar10
+        body = mx.sym.Convolution(
+            data=data,
+            num_filter=filter_list[0],
+            kernel=(3, 3),
+            stride=(1, 1),
+            pad=(1, 1),
+            no_bias=True,
+            name="conv0",
+            workspace=workspace,
+        )
+    else:  # often expected to be 224 such as imagenet
+        body = mx.sym.Convolution(
+            data=data,
+            num_filter=filter_list[0],
+            kernel=(7, 7),
+            stride=(2, 2),
+            pad=(3, 3),
+            no_bias=True,
+            name="conv0",
+            workspace=workspace,
+        )
+        body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name="bn0")
+        body = mx.sym.Activation(data=body, act_type="relu", name="relu0")
+        body = mx.sym.Pooling(data=body, kernel=(3, 3), stride=(2, 2), pad=(1, 1), pool_type="max")
 
     for i in range(num_stages):
-        body = residual_unit(body, filter_list[i+1], (1 if i==0 else 2, 1 if i==0 else 2), False,
-                             name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, workspace=workspace,
-                             memonger=memonger)
-        for j in range(units[i]-1):
-            body = residual_unit(body, filter_list[i+1], (1,1), True, name='stage%d_unit%d' % (i + 1, j + 2),
-                                 bottle_neck=bottle_neck, workspace=workspace, memonger=memonger)
-    bn1 = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn1')
-    relu1 = mx.sym.Activation(data=bn1, act_type='relu', name='relu1')
+        body = residual_unit(
+            body,
+            filter_list[i + 1],
+            (1 if i == 0 else 2, 1 if i == 0 else 2),
+            False,
+            name="stage%d_unit%d" % (i + 1, 1),
+            bottle_neck=bottle_neck,
+            workspace=workspace,
+            memonger=memonger,
+        )
+        for j in range(units[i] - 1):
+            body = residual_unit(
+                body,
+                filter_list[i + 1],
+                (1, 1),
+                True,
+                name="stage%d_unit%d" % (i + 1, j + 2),
+                bottle_neck=bottle_neck,
+                workspace=workspace,
+                memonger=memonger,
+            )
+    bn1 = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name="bn1")
+    relu1 = mx.sym.Activation(data=bn1, act_type="relu", name="relu1")
     # Although kernel is not used here when global_pool=True, we should put one
-    pool1 = mx.sym.Pooling(data=relu1, global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1')
+    pool1 = mx.sym.Pooling(
+        data=relu1, global_pool=True, kernel=(7, 7), pool_type="avg", name="pool1"
+    )
     flat = mx.sym.Flatten(data=pool1)
     try:
-        fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1', flatten=False)
+        fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name="fc1", flatten=False)
     except:
-        fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1')
-    if dtype == 'float16':
+        fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name="fc1")
+    if dtype == "float16":
         fc1 = mx.sym.Cast(data=fc1, dtype=np.float32)
-    return mx.sym.softmax(data=fc1, name='softmax')
+    return mx.sym.softmax(data=fc1, name="softmax")
+
 
-def get_symbol(num_classes, num_layers, image_shape, conv_workspace=256, dtype='float32', **kwargs):
+def get_symbol(num_classes, num_layers, image_shape, conv_workspace=256, dtype="float32", **kwargs):
     """
     Adapted from https://github.com/tornadomeet/ResNet/blob/master/train_resnet.py
     Original author Wei Wu
     """
-    image_shape = [int(l) for l in image_shape.split(',')]
+    image_shape = [int(l) for l in image_shape.split(",")]
     (nchannel, height, width) = image_shape
     if height <= 28:
         num_stages = 3
-        if (num_layers-2) % 9 == 0 and num_layers >= 164:
-            per_unit = [(num_layers-2)//9]
+        if (num_layers - 2) % 9 == 0 and num_layers >= 164:
+            per_unit = [(num_layers - 2) // 9]
             filter_list = [16, 64, 128, 256]
             bottle_neck = True
-        elif (num_layers-2) % 6 == 0 and num_layers < 164:
-            per_unit = [(num_layers-2)//6]
+        elif (num_layers - 2) % 6 == 0 and num_layers < 164:
+            per_unit = [(num_layers - 2) // 6]
             filter_list = [16, 16, 32, 64]
             bottle_neck = False
         else:
-            raise ValueError("no experiments done on num_layers {}, you can do it yourself".format(num_layers))
+            raise ValueError(
+                "no experiments done on num_layers {}, you can do it yourself".format(num_layers)
+            )
         units = per_unit * num_stages
     else:
         if num_layers >= 50:
@@ -187,13 +310,17 @@ def get_symbol(num_classes, num_layers, image_shape, conv_workspace=256, dtype='
         elif num_layers == 269:
             units = [3, 30, 48, 8]
         else:
-            raise ValueError("no experiments done on num_layers {}, you can do it yourself".format(num_layers))
+            raise ValueError(
+                "no experiments done on num_layers {}, you can do it yourself".format(num_layers)
+            )
 
-    return resnet(units       = units,
-                  num_stages  = num_stages,
-                  filter_list = filter_list,
-                  num_classes = num_classes,
-                  image_shape = image_shape,
-                  bottle_neck = bottle_neck,
-                  workspace   = conv_workspace,
-                  dtype       = dtype)
+    return resnet(
+        units=units,
+        num_stages=num_stages,
+        filter_list=filter_list,
+        num_classes=num_classes,
+        image_shape=image_shape,
+        bottle_neck=bottle_neck,
+        workspace=conv_workspace,
+        dtype=dtype,
+    )
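
A brief usage sketch for the ResNet builder above (not part of the commit; assumes MXNet and this module are in scope). Note that `image_shape` is a comma-separated string, since the code parses it with `split(",")`.

# Sketch only. image_shape is a string because get_symbol splits it on ",".
sym = get_symbol(num_classes=1000, num_layers=18, image_shape="3,224,224")
arg_shapes, out_shapes, aux_shapes = sym.infer_shape(data=(1, 3, 224, 224))
print(out_shapes)  # expected to be [(1, 1000)] for a 1000-class ResNet-18
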
index 093da51..146f7fa 100644 (file)
@@ -35,14 +35,17 @@ def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels):
 
     return net
 
+
 def _make_fire_conv(net, channels, kernel_size, padding=0):
-    net = mx.sym.Convolution(net, num_filter=channels, kernel=(kernel_size, kernel_size),
-                             pad=(padding, padding))
-    net = mx.sym.Activation(net, act_type='relu')
+    net = mx.sym.Convolution(
+        net, num_filter=channels, kernel=(kernel_size, kernel_size), pad=(padding, padding)
+    )
+    net = mx.sym.Activation(net, act_type="relu")
     return net
 
+
 # Net
-def get_symbol(num_classes=1000, version='1.0', **kwargs):
+def get_symbol(num_classes=1000, version="1.0", **kwargs):
     """Get symbol of SqueezeNet
 
     Parameters
@@ -53,40 +56,42 @@ def get_symbol(num_classes=1000, version='1.0', **kwargs):
     version : str, optional
         "1.0" or "1.1" of SqueezeNet
     """
-    assert version in ['1.0', '1.1'], ("Unsupported SqueezeNet version {version}:"
-                                       "1.0 or 1.1 expected".format(version=version))
+    assert version in [
+        "1.0",
+        "1.1",
+    ], "Unsupported SqueezeNet version {version}:" "1.0 or 1.1 expected".format(version=version)
     net = mx.sym.Variable("data")
-    if version == '1.0':
+    if version == "1.0":
         net = mx.sym.Convolution(net, num_filter=96, kernel=(7, 7), stride=(2, 2), pad=(3, 3))
-        net = mx.sym.Activation(net, act_type='relu')
-        net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
+        net = mx.sym.Activation(net, act_type="relu")
+        net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type="max", stride=(2, 2))
         net = _make_fire(net, 16, 64, 64)
         net = _make_fire(net, 16, 64, 64)
         net = _make_fire(net, 32, 128, 128)
-        net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
+        net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type="max", stride=(2, 2))
         net = _make_fire(net, 32, 128, 128)
         net = _make_fire(net, 48, 192, 192)
         net = _make_fire(net, 48, 192, 192)
         net = _make_fire(net, 64, 256, 256)
-        net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
+        net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type="max", stride=(2, 2))
         net = _make_fire(net, 64, 256, 256)
     else:
         net = mx.sym.Convolution(net, num_filter=64, kernel=(3, 3), stride=(2, 2), pad=(1, 1))
-        net = mx.sym.Activation(net, act_type='relu')
-        net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
+        net = mx.sym.Activation(net, act_type="relu")
+        net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type="max", stride=(2, 2))
         net = _make_fire(net, 16, 64, 64)
         net = _make_fire(net, 16, 64, 64)
-        net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max',  stride=(2, 2))
+        net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type="max", stride=(2, 2))
         net = _make_fire(net, 32, 128, 128)
         net = _make_fire(net, 32, 128, 128)
-        net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max',  stride=(2, 2))
+        net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type="max", stride=(2, 2))
         net = _make_fire(net, 48, 192, 192)
         net = _make_fire(net, 48, 192, 192)
         net = _make_fire(net, 64, 256, 256)
         net = _make_fire(net, 64, 256, 256)
     net = mx.sym.Dropout(net, p=0.5)
     net = mx.sym.Convolution(net, num_filter=num_classes, kernel=(1, 1))
-    net = mx.sym.Activation(net, act_type='relu')
-    net = mx.sym.Pooling(data=net, global_pool=True, kernel=(13, 13), pool_type='avg')
+    net = mx.sym.Activation(net, act_type="relu")
+    net = mx.sym.Pooling(data=net, global_pool=True, kernel=(13, 13), pool_type="avg")
     net = mx.sym.flatten(net)
     return mx.sym.softmax(net)
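
For context, a minimal sketch (assuming MXNet is available and the SqueezeNet `get_symbol` above is in scope) of building the 1.1 variant.

# Sketch only -- "1.0" is the other accepted version string.
sym = get_symbol(num_classes=1000, version="1.1")
print(len(sym.list_arguments()))  # count of input/parameter symbols in the graph
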
index 68215bb..1578034 100644 (file)
@@ -24,16 +24,34 @@ large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
 import mxnet as mx
 import numpy as np
 
-def get_feature(internel_layer, layers, filters, batch_norm = False, **kwargs):
+
+def get_feature(internel_layer, layers, filters, batch_norm=False, **kwargs):
     for i, num in enumerate(layers):
         for j in range(num):
-            internel_layer = mx.sym.Convolution(data = internel_layer, kernel=(3, 3), pad=(1, 1), num_filter=filters[i], name="conv%s_%s" %(i + 1, j + 1))
+            internel_layer = mx.sym.Convolution(
+                data=internel_layer,
+                kernel=(3, 3),
+                pad=(1, 1),
+                num_filter=filters[i],
+                name="conv%s_%s" % (i + 1, j + 1),
+            )
             if batch_norm:
-                internel_layer = mx.symbol.BatchNorm(data=internel_layer, name="bn%s_%s" %(i + 1, j + 1))
-            internel_layer = mx.sym.Activation(data=internel_layer, act_type="relu", name="relu%s_%s" %(i + 1, j + 1))
-        internel_layer = mx.sym.Pooling(data=internel_layer, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool%s" %(i + 1))
+                internel_layer = mx.symbol.BatchNorm(
+                    data=internel_layer, name="bn%s_%s" % (i + 1, j + 1)
+                )
+            internel_layer = mx.sym.Activation(
+                data=internel_layer, act_type="relu", name="relu%s_%s" % (i + 1, j + 1)
+            )
+        internel_layer = mx.sym.Pooling(
+            data=internel_layer,
+            pool_type="max",
+            kernel=(2, 2),
+            stride=(2, 2),
+            name="pool%s" % (i + 1),
+        )
     return internel_layer
 
+
 def get_classifier(input_data, num_classes, **kwargs):
     flatten = mx.sym.Flatten(data=input_data, name="flatten")
     try:
@@ -54,7 +72,8 @@ def get_classifier(input_data, num_classes, **kwargs):
         fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8")
     return fc8
 
-def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32', **kwargs):
+
+def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype="float32", **kwargs):
     """
     Parameters
     ----------
@@ -67,19 +86,23 @@ def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32', **
     dtype: str, float32 or float16
         Data precision.
     """
-    vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
-                13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
-                16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
-                19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
+    vgg_spec = {
+        11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
+        13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
+        16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
+        19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512]),
+    }
     if num_layers not in vgg_spec:
-        raise ValueError("Invalide num_layers {}. Possible choices are 11,13,16,19.".format(num_layers))
+        raise ValueError(
+            "Invalide num_layers {}. Possible choices are 11,13,16,19.".format(num_layers)
+        )
     layers, filters = vgg_spec[num_layers]
     data = mx.sym.Variable(name="data")
-    if dtype == 'float16':
+    if dtype == "float16":
         data = mx.sym.Cast(data=data, dtype=np.float16)
     feature = get_feature(data, layers, filters, batch_norm)
     classifier = get_classifier(feature, num_classes)
-    if dtype == 'float16':
+    if dtype == "float16":
         classifier = mx.sym.Cast(data=classifier, dtype=np.float32)
-    symbol = mx.sym.softmax(data=classifier, name='softmax')
+    symbol = mx.sym.softmax(data=classifier, name="softmax")
     return symbol
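
Finally, a short sketch (same assumptions as above) for the VGG builder; `num_layers` must be one of the keys of `vgg_spec`.

# Sketch only. num_layers must be 11, 13, 16, or 19 per vgg_spec above.
sym = get_symbol(num_classes=1000, num_layers=16, batch_norm=True)
print(sym.list_outputs())  # typically ['softmax_output']
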
index bc5cbeb..639f8e2 100644 (file)
@@ -31,43 +31,51 @@ import model_zoo
 
 import tvm.testing
 
-def verify_mxnet_frontend_impl(mx_symbol,
-                               data_shape=(1, 3, 224, 224),
-                               out_shape=(1, 1000),
-                               gluon_impl=False,
-                               name=None,
-                               dtype='float32'):
+
+def verify_mxnet_frontend_impl(
+    mx_symbol,
+    data_shape=(1, 3, 224, 224),
+    out_shape=(1, 1000),
+    gluon_impl=False,
+    name=None,
+    dtype="float32",
+):
     """Use name different from test to avoid pytest picking it up"""
     if gluon_impl:
+
         def get_gluon_output(name, x):
             net = vision.get_model(name)
             net.collect_params().initialize(mx.init.Xavier())
-            net_sym = gluon.nn.SymbolBlock(outputs=net(mx.sym.var('data')),
-                                           inputs=mx.sym.var('data'),
-                                           params=net.collect_params())
+            net_sym = gluon.nn.SymbolBlock(
+                outputs=net(mx.sym.var("data")),
+                inputs=mx.sym.var("data"),
+                params=net.collect_params(),
+            )
             out = net_sym(mx.nd.array(x.astype(dtype))).asnumpy()
             return out, net_sym
+
     else:
-        def get_mxnet_output(symbol, x, dtype='float32'):
+
+        def get_mxnet_output(symbol, x, dtype="float32"):
             from collections import namedtuple
-            Batch = namedtuple('Batch', ['data'])
+
+            Batch = namedtuple("Batch", ["data"])
             mod = mx.mod.Module(symbol, label_names=None)
-            mod.bind(data_shapes=[('data', x.shape)], for_training=False)
+            mod.bind(data_shapes=[("data", x.shape)], for_training=False)
             mod.init_params()
             mod.forward(Batch([mx.nd.array(x.astype(dtype))]))
             out = mod.get_outputs()[0].asnumpy()
             args, auxs = mod.get_params()
             return out, args, auxs
 
-    def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'):
+    def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype="float32"):
         shape_dict = {"data": x.shape}
         if gluon_impl:
             mod, params = relay.frontend.from_mxnet(symbol, shape_dict)
         else:
-            mod, params = relay.frontend.from_mxnet(symbol,
-                                                    shape_dict,
-                                                    arg_params=args,
-                                                    aux_params=auxs)
+            mod, params = relay.frontend.from_mxnet(
+                symbol, shape_dict, arg_params=args, aux_params=auxs
+            )
         with tvm.transform.PassContext(opt_level=3):
             graph, lib, params = relay.build(mod, target, params=params)
         m = graph_runtime.create(graph, lib, ctx)
@@ -93,12 +101,12 @@ def verify_mxnet_frontend_impl(mx_symbol,
             tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
             tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_forward_mlp():
     mlp = model_zoo.mx_mlp()
-    verify_mxnet_frontend_impl(mlp,
-                               data_shape=(1, 1, 28, 28),
-                               out_shape=(1, 10))
+    verify_mxnet_frontend_impl(mlp, data_shape=(1, 1, 28, 28), out_shape=(1, 10))
+
 
 @tvm.testing.uses_gpu
 def test_forward_vgg():
@@ -106,60 +114,68 @@ def test_forward_vgg():
         mx_sym = model_zoo.mx_vgg(n)
         verify_mxnet_frontend_impl(mx_sym)
 
+
 @tvm.testing.uses_gpu
 def test_forward_resnet():
     for n in [18]:
         mx_sym = model_zoo.mx_resnet(18)
         verify_mxnet_frontend_impl(mx_sym)
 
+
 @tvm.testing.uses_gpu
 def test_forward_leaky_relu():
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
     mx_sym = mx.sym.LeakyReLU(data)
     verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
-    mx_sym = mx.sym.LeakyReLU(data, act_type='leaky')
+    mx_sym = mx.sym.LeakyReLU(data, act_type="leaky")
     verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
 
+
 @tvm.testing.uses_gpu
 def test_forward_elu():
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
-    mx_sym = mx.sym.LeakyReLU(data, act_type='elu')
+    mx_sym = mx.sym.LeakyReLU(data, act_type="elu")
     verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
 
+
 @tvm.testing.uses_gpu
 def test_forward_rrelu():
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
-    mx_sym = mx.sym.LeakyReLU(data, act_type='rrelu', lower_bound=0.3, upper_bound=0.7)
+    mx_sym = mx.sym.LeakyReLU(data, act_type="rrelu", lower_bound=0.3, upper_bound=0.7)
     verify_mxnet_frontend_impl(mx_sym[0], (1, 3, 100, 100), (1, 6, 100, 100))
 
+
 @tvm.testing.uses_gpu
 def test_forward_prelu():
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
-    mx_sym = mx.sym.LeakyReLU(data, act_type='prelu')
+    mx_sym = mx.sym.LeakyReLU(data, act_type="prelu")
     verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
 
+
 @tvm.testing.uses_gpu
 def test_forward_gelu():
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
-    mx_sym = mx.sym.LeakyReLU(data, act_type='gelu')
+    mx_sym = mx.sym.LeakyReLU(data, act_type="gelu")
     verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
 
+
 @tvm.testing.uses_gpu
 def test_forward_softrelu():
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
-    mx_sym = mx.sym.Activation(data, act_type='softrelu')
+    mx_sym = mx.sym.Activation(data, act_type="softrelu")
     verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
 
+
 @tvm.testing.uses_gpu
 def test_forward_fc_flatten():
     # test flatten=True option in mxnet 0.11.1
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     try:
         mx_sym = mx.sym.FullyConnected(data, num_hidden=100, flatten=True)
         verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100))
@@ -168,124 +184,141 @@ def test_forward_fc_flatten():
     except:
         pass
 
+
 @tvm.testing.uses_gpu
 def test_forward_clip():
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
     mx_sym = mx.sym.clip(data, a_min=0, a_max=1)
     verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
 
+
 @tvm.testing.uses_gpu
 def test_forward_split():
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=False)
     verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 1, 2, 1))
 
+
 @tvm.testing.uses_gpu
 def test_forward_split_squeeze():
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=True)
     verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 2, 1))
 
+
 @tvm.testing.uses_gpu
 def test_forward_expand_dims():
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     mx_sym = mx.sym.expand_dims(data, axis=1)
     verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 1, 3, 4))
 
+
 @tvm.testing.uses_gpu
 def test_forward_pooling():
-    data = mx.sym.var('data')
-    mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='avg')
+    data = mx.sym.var("data")
+    mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type="avg")
     verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8))
 
-    mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='max')
+    mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type="max")
     verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8))
 
+
 @tvm.testing.uses_gpu
 def test_forward_pooling3d():
-    data = mx.sym.var('data')
-    mx_sym = mx.sym.Pooling(data, kernel=(3, 3, 3), pad=(1, 1, 1), pool_type='avg')
+    data = mx.sym.var("data")
+    mx_sym = mx.sym.Pooling(data, kernel=(3, 3, 3), pad=(1, 1, 1), pool_type="avg")
     verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8, 8), (1, 20, 8, 8, 8))
 
-    mx_sym = mx.sym.Pooling(data, kernel=(3, 3, 3), pad=(1, 1, 1), pool_type='max')
+    mx_sym = mx.sym.Pooling(data, kernel=(3, 3, 3), pad=(1, 1, 1), pool_type="max")
     verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8, 8), (1, 20, 8, 8, 8))
 
+
 @tvm.testing.uses_gpu
 def test_forward_adaptive_pooling():
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     mx_sym = mx.sym.contrib.AdaptiveAvgPooling2D(data, output_size=(1,))
     verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 1, 1))
 
     mx_sym = mx.sym.contrib.AdaptiveAvgPooling2D(data, output_size=(3, 3))
     verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 3, 3))
 
+
 @tvm.testing.uses_gpu
 def test_forward_lrn():
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     mx_sym = mx.sym.LRN(data, alpha=2, beta=2, knorm=1, nsize=5)
     verify_mxnet_frontend_impl(mx_sym, (1, 10, 24, 24), (1, 10, 24, 24))
 
+
 @tvm.testing.uses_gpu
 def test_forward_ones():
-    data = mx.sym.var('data')
-    ones = mx.sym.ones(shape=(2, 3, 4), dtype='float32')
+    data = mx.sym.var("data")
+    ones = mx.sym.ones(shape=(2, 3, 4), dtype="float32")
     mx_sym = mx.sym.elemwise_add(data, ones)
     verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
 
+
 @tvm.testing.uses_gpu
 def test_forward_zeros():
-    data = mx.sym.var('data')
-    zeros = mx.sym.zeros(shape=(2, 3, 4), dtype='float32')
+    data = mx.sym.var("data")
+    zeros = mx.sym.zeros(shape=(2, 3, 4), dtype="float32")
     mx_sym = mx.sym.elemwise_add(data, zeros)
     verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
 
+
 @tvm.testing.uses_gpu
 def test_forward_ones_like():
-    data = mx.sym.var('data')
-    mx_sym = mx.sym.ones_like(data, dtype='float32')
+    data = mx.sym.var("data")
+    mx_sym = mx.sym.ones_like(data, dtype="float32")
     verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
 
+
 @tvm.testing.uses_gpu
 def test_forward_make_loss():
-    data = mx.sym.var('data')
-    ones = mx.sym.ones(shape=(2, 3, 4), dtype='float32')
-    mx_sym = mx.sym.make_loss((data-ones)**2/2, dtype='float32')
+    data = mx.sym.var("data")
+    ones = mx.sym.ones(shape=(2, 3, 4), dtype="float32")
+    mx_sym = mx.sym.make_loss((data - ones) ** 2 / 2, dtype="float32")
     verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
 
+
 @tvm.testing.uses_gpu
 def test_forward_zeros_like():
-    data = mx.sym.var('data')
-    mx_sym = mx.sym.zeros_like(data, dtype='float32')
+    data = mx.sym.var("data")
+    mx_sym = mx.sym.zeros_like(data, dtype="float32")
     verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
 
+
 @tvm.testing.uses_gpu
 def test_forward_argmax():
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     mx_sym = mx.sym.argmax(data, axis=1)
     verify_mxnet_frontend_impl(mx_sym, (5, 3), (5,))
 
+
 @tvm.testing.uses_gpu
 def test_forward_argmin():
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     mx_sym = mx.sym.argmin(data, axis=0)
     verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,))
 
+
 @tvm.testing.uses_gpu
 def test_forward_slice():
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     mx_sym = mx.sym.slice(data, begin=(0, 1), end=(2, 4))
     verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 3))
     mx_sym = mx.sym.slice(data, begin=(-1, 1), end=(-3, 4), step=(-1, 2))
     verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 2))
 
+
 @tvm.testing.uses_gpu
 def test_forward_where():
-    cond = mx.sym.var('cond')
-    x = mx.sym.var('x')
-    y = mx.sym.var('y')
+    cond = mx.sym.var("cond")
+    x = mx.sym.var("x")
+    y = mx.sym.var("y")
     dshape = (2, 2)
-    dtype = 'float32'
+    dtype = "float32"
     mx_sym = mx.sym.where(cond, x, y)
     np_cond = np.array([[0, 1], [-1, 0]]).astype(dtype)
     np_x = np.random.uniform(size=dshape).astype(dtype)
@@ -293,8 +326,8 @@ def test_forward_where():
     mx_cond = mx.nd.array(np_cond)
     mx_x = mx.nd.array(np_x)
     mx_y = mx.nd.array(np_y)
-    shapes = {'cond': dshape, 'x': dshape, 'y': dshape}
-    mod = mx.mod.Module(mx_sym, label_names=None, data_names=['cond', 'x', 'y'])
+    shapes = {"cond": dshape, "x": dshape, "y": dshape}
+    mod = mx.mod.Module(mx_sym, label_names=None, data_names=["cond", "x", "y"])
     mod.bind(data_shapes=shapes.items(), for_training=False)
     mod.init_params()
     args, auxs = mod.get_params()
@@ -330,6 +363,7 @@ def test_forward_arange():
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()()
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
+
     verify(0, 20, None)
     verify(0, 20, 2)
     verify(1, 20, None)
@@ -340,44 +374,48 @@ def test_forward_arange():
     verify(20, 1, -1)
     verify(20, 1, -1.5)
 
+
 def _mx_symbol(F, op_name, inputs):
     op = getattr(F, op_name)
     return op(*inputs)
 
+
 @tvm.testing.uses_gpu
 def test_forward_broadcast_ops():
-    for op in ["broadcast_add",
-               "broadcast_plus",
-               "broadcast_sub",
-               "broadcast_minus",
-               "broadcast_mul",
-               "broadcast_div",
-               "broadcast_mod",
-               "broadcast_maximum",
-               "broadcast_minimum",
-               "broadcast_equal",
-               "broadcast_not_equal",
-               "broadcast_greater",
-               "broadcast_greater_equal",
-               "broadcast_lesser",
-               "broadcast_lesser_equal",
-               "broadcast_power",
-               "broadcast_logical_or",
-               "broadcast_logical_and",
-               "broadcast_logical_xor"]:
+    for op in [
+        "broadcast_add",
+        "broadcast_plus",
+        "broadcast_sub",
+        "broadcast_minus",
+        "broadcast_mul",
+        "broadcast_div",
+        "broadcast_mod",
+        "broadcast_maximum",
+        "broadcast_minimum",
+        "broadcast_equal",
+        "broadcast_not_equal",
+        "broadcast_greater",
+        "broadcast_greater_equal",
+        "broadcast_lesser",
+        "broadcast_lesser_equal",
+        "broadcast_power",
+        "broadcast_logical_or",
+        "broadcast_logical_and",
+        "broadcast_logical_xor",
+    ]:
         a_shape = (3, 4, 5)
         b_shape = (4, 5)
         if op == "broadcast_mod":
-            dtype = 'int32'
+            dtype = "int32"
             a_np = np.random.randint(1, 100, size=a_shape).astype(dtype)
             b_np = np.random.randint(1, 100, size=b_shape).astype(dtype)
         else:
-            dtype = 'float32'
+            dtype = "float32"
             a_np = np.random.uniform(size=a_shape).astype(dtype)
             b_np = np.random.uniform(size=b_shape).astype(dtype)
-        mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')])
+        mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var("a"), mx.sym.var("b")])
         ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)])
-        shapes = {'a': a_shape, 'b': b_shape}
+        shapes = {"a": a_shape, "b": b_shape}
         mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
@@ -385,23 +423,34 @@ def test_forward_broadcast_ops():
                 op_res = intrp.evaluate()(a_np, b_np)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
 
+
 @tvm.testing.uses_gpu
 def test_forward_elemwise_ops():
-    for op in ["elemwise_add", "elemwise_sub", "elemwise_mul",
-               "elemwise_div", "maximum", "minimum",
-               operator.lt, operator.le, operator.eq,
-               operator.ne, operator.gt, operator.ge]:
+    for op in [
+        "elemwise_add",
+        "elemwise_sub",
+        "elemwise_mul",
+        "elemwise_div",
+        "maximum",
+        "minimum",
+        operator.lt,
+        operator.le,
+        operator.eq,
+        operator.ne,
+        operator.gt,
+        operator.ge,
+    ]:
         shape = (3, 4, 5)
-        dtype = 'float32'
+        dtype = "float32"
         a_np = np.random.uniform(size=shape).astype(dtype)
         b_np = np.random.uniform(size=shape).astype(dtype)
         if type(op) == str:
-            mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')])
+            mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var("a"), mx.sym.var("b")])
             ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)])
         else:
-            mx_sym = op(mx.sym.var('a'), mx.sym.var('b'))
+            mx_sym = op(mx.sym.var("a"), mx.sym.var("b"))
             ref_res = op(mx.nd.array(a_np), mx.nd.array(b_np))
-        shapes = {'a': shape, 'b': shape}
+        shapes = {"a": shape, "b": shape}
         mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
@@ -412,7 +461,7 @@ def test_forward_elemwise_ops():
 
 @tvm.testing.uses_gpu
 def test_forward_softmin():
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     mx_sym = mx.sym.softmin(data)
     verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 3, 100, 100))
 
@@ -422,38 +471,67 @@ def test_forward_softmin():
 
 @tvm.testing.uses_gpu
 def test_forward_unary_ops():
-    for op in ["abs", "sqrt", "ceil", "floor", "round", "reciprocal", "trunc",
-               "softsign", "hard_sigmoid",
-               "cos", "sin", "tan",
-               "cosh", "sinh", "tanh",
-               "arccos", "arcsin", "arctan",
-               "arccosh", "arcsinh", "arctanh"]:
+    for op in [
+        "abs",
+        "sqrt",
+        "ceil",
+        "floor",
+        "round",
+        "reciprocal",
+        "trunc",
+        "softsign",
+        "hard_sigmoid",
+        "cos",
+        "sin",
+        "tan",
+        "cosh",
+        "sinh",
+        "tanh",
+        "arccos",
+        "arcsin",
+        "arctan",
+        "arccosh",
+        "arcsinh",
+        "arctanh",
+    ]:
         shape = (1, 3, 4, 5)
-        dtype = 'float32'
+        dtype = "float32"
         a_np = np.random.uniform(size=shape).astype(dtype)
-        mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a')])
+        mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var("a")])
         ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np)])
-        shapes = {'a': shape}
+        shapes = {"a": shape}
         mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(a_np)
-                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5)
+                tvm.testing.assert_allclose(
+                    op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5
+                )
 
 
 @tvm.testing.uses_gpu
 def test_forward_scalar_ops():
-    for op in [operator.add, operator.sub, operator.mul, operator.truediv,
-               operator.pow, operator.lt, operator.le, operator.eq,
-               operator.ne, operator.gt, operator.ge]:
-        dtype='float32'
+    for op in [
+        operator.add,
+        operator.sub,
+        operator.mul,
+        operator.truediv,
+        operator.pow,
+        operator.lt,
+        operator.le,
+        operator.eq,
+        operator.ne,
+        operator.gt,
+        operator.ge,
+    ]:
+        dtype = "float32"
         a_shape = (3, 4, 5)
         a_np = np.random.uniform(size=a_shape).astype(dtype)
         b_scalar = 2.3
-        mx_sym = op(mx.sym.var('a'), b_scalar)
+        mx_sym = op(mx.sym.var("a"), b_scalar)
         ref_res = op(mx.nd.array(a_np), b_scalar)
-        shapes = {'a': a_shape}
+        shapes = {"a": a_shape}
         mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
@@ -461,13 +539,13 @@ def test_forward_scalar_ops():
                 op_res = intrp.evaluate()(a_np)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
     for op in ["maximum", "minimum"]:
-        dtype='float32'
+        dtype = "float32"
         a_shape = (3, 4, 5)
         a_np = np.random.uniform(size=a_shape).astype(dtype)
         b_scalar = 2.3
-        mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), b_scalar])
+        mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var("a"), b_scalar])
         ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), b_scalar])
-        shapes = {'a': a_shape}
+        shapes = {"a": a_shape}
         mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
@@ -475,6 +553,7 @@ def test_forward_scalar_ops():
                 op_res = intrp.evaluate()(a_np)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
 
+
 @tvm.testing.uses_gpu
 def test_forward_slice_axis():
     def verify(shape, axis, begin, end):
@@ -487,12 +566,14 @@ def test_forward_slice_axis():
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(data_np)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
     verify((3, 4), 0, 1, 2)
     verify((3, 4), 0, 1, None)
     verify((3, 4), 1, 0, 2)
     verify((3, 4), 1, -3, -1)
     verify((3, 4), -1, -3, -1)
 
+
 @tvm.testing.uses_gpu
 def test_forward_slice_like():
     def verify(x_shape, y_shape, axes):
@@ -510,11 +591,13 @@ def test_forward_slice_like():
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(x_np, y_np)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
     verify((3, 4), (2, 3), None)
     verify((3, 4), (2, 3), (0, 1))
     verify((3, 4), (2, 3), (0))
     verify((3, 4), (2, 3), (-1))
 
+
 @tvm.testing.uses_gpu
 def test_forward_sequence_reverse():
     def verify(shape, seq_lengths, use_seq_lengths, seq_axis):
@@ -523,7 +606,7 @@ def test_forward_sequence_reverse():
         ref_res_args = [mx.nd.array(data_np), None, use_seq_lengths, seq_axis]
         mx_sym_args = [mx.sym.var("data"), None, use_seq_lengths, seq_axis]
         from_mxnet_args = [{"data": shape}, {"data": "float32"}]
-        in_data= [data_np]
+        in_data = [data_np]
 
         if use_seq_lengths and seq_lengths:
             seq_lengths_np = np.array(seq_lengths).astype("int32")
@@ -549,12 +632,14 @@ def test_forward_sequence_reverse():
     # MXNet accepts axis value as 0 only
     # verify((3, 4, 5, 6), None, False, 2)
 
+
 @tvm.testing.uses_gpu
 def test_forward_l2_normalize():
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     mx_sym = mx.sym.L2Normalization(data, mode="channel")
     verify_mxnet_frontend_impl(mx_sym, (2, 3, 4, 5), (2, 3, 4, 5))
 
+
 @tvm.testing.uses_gpu
 def test_forward_shape_array():
     def verify(shape):
@@ -567,10 +652,12 @@ def test_forward_shape_array():
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(x_np)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
     verify((1,))
     verify((3, 4, 5))
     verify((3, 4, 5, 6))
 
+
 @tvm.testing.uses_gpu
 def test_forward_squeeze():
     def verify(shape, axis):
@@ -587,19 +674,20 @@ def test_forward_squeeze():
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(x_np)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
     verify((1, 3, 1), None)
     verify((1, 3, 1), 0)
     verify((1, 3, 1), 2)
     verify((1, 3, 1), (0, 2))
 
+
 @tvm.testing.uses_gpu
 def test_forward_broadcast_axis():
     def verify(shape, axis, size):
         x_np = np.random.uniform(size=shape).astype("float32")
-        for op in ["broadcast_axis",
-                   "broadcast_axes"]:
-            mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('x'),axis,size])
-            ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(x_np),axis,size])
+        for op in ["broadcast_axis", "broadcast_axes"]:
+            mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var("x"), axis, size])
+            ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(x_np), axis, size])
             mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
             for target, ctx in tvm.testing.enabled_targets():
                 for kind in ["graph", "debug"]:
@@ -631,11 +719,11 @@ def test_forward_broadcast_to():
 @tvm.testing.uses_gpu
 def test_forward_logical_not():
     a_shape = (3, 4, 5)
-    dtype = 'float32'
+    dtype = "float32"
     a_np = np.random.uniform(size=a_shape).astype(dtype)
-    mx_sym = mx.sym.logical_not(mx.sym.var('a'))
+    mx_sym = mx.sym.logical_not(mx.sym.var("a"))
     ref_res = mx.nd.logical_not(mx.nd.array(a_np))
-    shapes = {'a': a_shape}
+    shapes = {"a": a_shape}
     mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
     for target, ctx in tvm.testing.enabled_targets():
         for kind in ["graph", "debug"]:
@@ -658,38 +746,44 @@ def test_forward_full():
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()()
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
     verify(2, (3, 4), "float32")
     verify(2, (3, 4), "int32")
     verify(3.5, (1, 3, 4), "float32")
 
+
 @tvm.testing.uses_gpu
 def test_forward_embedding():
     def verify(data_shape, weight_shape):
         in_dim, out_dim = weight_shape
         x_np = np.random.randint(0, weight_shape[0], size=data_shape).astype("float32")
         w_np = np.random.uniform(size=weight_shape).astype("float32")
-        ref_res = mx.nd.Embedding(mx.nd.array(x_np), mx.nd.array(w_np),
-                                  input_dim=in_dim, output_dim=out_dim)
-        mx_sym = mx.sym.Embedding(mx.sym.var("x"), mx.sym.var("w"),
-                                  input_dim=in_dim, output_dim=out_dim)
-        mod, _ = relay.frontend.from_mxnet(
-            mx_sym, {"x": data_shape, "w": weight_shape})
+        ref_res = mx.nd.Embedding(
+            mx.nd.array(x_np), mx.nd.array(w_np), input_dim=in_dim, output_dim=out_dim
+        )
+        mx_sym = mx.sym.Embedding(
+            mx.sym.var("x"), mx.sym.var("w"), input_dim=in_dim, output_dim=out_dim
+        )
+        mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": data_shape, "w": weight_shape})
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(x=x_np, w=w_np)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
     verify((2, 2), (4, 5))
     verify((2, 3, 4), (4, 5))
 
+
 @tvm.testing.uses_gpu
 def test_forward_smooth_l1():
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     mx_sym = mx.sym.smooth_l1(data)
     verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4))
     mx_sym = mx.sym.smooth_l1(data, scalar=1.0)
     verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4))
 
+
 @tvm.testing.uses_gpu
 def test_forward_take():
     def verify(shape, indices_src, axis, mode="clip"):
@@ -703,13 +797,15 @@ def test_forward_take():
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(x_np, indices_np)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
-    verify((2,2), [[[1,0],[0,1]]], 0)
-    verify((2,2), [[[1,0],[0,1]]], 1)
-    verify((4,3,5,6), [[2,1,0,0]], -2)
-    verify((3,4), [-1, 5], 0)
-    verify((3,4), [-1, 5], 0, mode="wrap")
-    verify((3,4), [-1, 5], 1)
-    verify((3,4), [-1, 5], 1, mode="wrap")
+
+    verify((2, 2), [[[1, 0], [0, 1]]], 0)
+    verify((2, 2), [[[1, 0], [0, 1]]], 1)
+    verify((4, 3, 5, 6), [[2, 1, 0, 0]], -2)
+    verify((3, 4), [-1, 5], 0)
+    verify((3, 4), [-1, 5], 0, mode="wrap")
+    verify((3, 4), [-1, 5], 1)
+    verify((3, 4), [-1, 5], 1, mode="wrap")
+
 
 @tvm.testing.uses_gpu
 def test_forward_gather_nd():
@@ -717,7 +813,9 @@ def test_forward_gather_nd():
         x_data = np.random.uniform(size=xshape).astype("float32")
         ref_res = mx.nd.gather_nd(mx.nd.array(x_data), mx.nd.array(y_data))
         mx_sym = mx.sym.gather_nd(mx.sym.var("x_data"), mx.sym.var("y_data"))
-        mod, _ = relay.frontend.from_mxnet(mx_sym, {"x_data": xshape, "y_data": yshape}, {"x_data": "float32", "y_data": "int32"})
+        mod, _ = relay.frontend.from_mxnet(
+            mx_sym, {"x_data": xshape, "y_data": yshape}, {"x_data": "float32", "y_data": "int32"}
+        )
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
@@ -730,13 +828,15 @@ def test_forward_gather_nd():
     verify((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]])
     verify((1, 4), (1, 1), [[0]])
 
+
 @tvm.testing.uses_gpu
 def test_forward_bilinear_resize():
     # add tests including scale_height and scale_width when mxnet is updated to version 1.5
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     mx_sym = mx.sym.contrib.BilinearResize2D(data, height=5, width=10)
     verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10))
 
+
 @tvm.testing.uses_gpu
 def test_forward_grid_generator():
     def verify(shape, transform_type, target_shape):
@@ -747,14 +847,16 @@ def test_forward_grid_generator():
         mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
-                intrp = relay.create_executor(
-                    kind, mod=mod, ctx=ctx, target=target)
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(x)
                 tvm.testing.assert_allclose(
-                    op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5)
-    verify((4, 6), 'affine', (16, 32))
-    verify((4, 2, 16, 16), 'warp', None)
-    verify((1, 2, 16, 16), 'warp', None)
+                    op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5
+                )
+
+    verify((4, 6), "affine", (16, 32))
+    verify((4, 2, 16, 16), "warp", None)
+    verify((1, 2, 16, 16), "warp", None)
+
 
 @tvm.testing.uses_gpu
 def test_forward_bilinear_sampler():
@@ -767,23 +869,33 @@ def test_forward_bilinear_sampler():
         mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
-                intrp = relay.create_executor(
-                    kind, mod=mod, ctx=ctx, target=target)
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(data, grid)
                 tvm.testing.assert_allclose(
-                    op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5)
+                    op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5
+                )
+
     verify((4, 4, 16, 32), (4, 2, 8, 8))
     verify((4, 4, 16, 32), (4, 2, 32, 32))
 
+
 @tvm.testing.uses_gpu
 def test_forward_rnn_layer():
-    def verify(mode, seq_len, input_size, hidden_size, num_layers,
-               batch=1, init_states=True, bidirectional=False):
+    def verify(
+        mode,
+        seq_len,
+        input_size,
+        hidden_size,
+        num_layers,
+        batch=1,
+        init_states=True,
+        bidirectional=False,
+    ):
         if mode == "rnn":
             layer = gluon.rnn.RNN(hidden_size, num_layers, bidirectional=bidirectional)
         elif mode == "gru":
             layer = gluon.rnn.GRU(hidden_size, num_layers, bidirectional=bidirectional)
-        else: # mode == "lstm"
+        else:  # mode == "lstm"
             layer = gluon.rnn.LSTM(hidden_size, num_layers, bidirectional=bidirectional)
         num_states = 2 if mode == "lstm" else 1
         layer.initialize()
@@ -795,22 +907,22 @@ def test_forward_rnn_layer():
         data_mx = mx.nd.array(data_np)
 
         if init_states:
-            shape_dict = {'data0': data_np.shape}
-            inputs = {'data0': data_np}
-            state_shape = (num_layers*directions, batch, hidden_size)
+            shape_dict = {"data0": data_np.shape}
+            inputs = {"data0": data_np}
+            state_shape = (num_layers * directions, batch, hidden_size)
             states_np = []
             states_mx = []
             for i in range(num_states):
                 s = np.random.uniform(size=state_shape).astype(dtype)
                 states_np.append(s)
                 states_mx.append(mx.nd.array(s))
-                shape_dict['data%s' % (i+1)] = s.shape
-                inputs['data%s' % (i+1)] = s
+                shape_dict["data%s" % (i + 1)] = s.shape
+                inputs["data%s" % (i + 1)] = s
             mx_out, mx_states = layer(data_mx, states_mx)
             mx_res = [mx_out] + mx_states
         else:
-            shape_dict = {'data': data_np.shape}
-            inputs = {'data': data_np}
+            shape_dict = {"data": data_np.shape}
+            inputs = {"data": data_np}
             mx_res = layer(data_mx)
 
         mx_sym = layer._cached_graph[1]
@@ -818,8 +930,7 @@ def test_forward_rnn_layer():
         for name, param in layer.collect_params().items():
             mx_params[name] = param._reduce()
 
-        mod, params = relay.frontend.from_mxnet(
-            mx_sym, shape=shape_dict, arg_params=mx_params)
+        mod, params = relay.frontend.from_mxnet(mx_sym, shape=shape_dict, arg_params=mx_params)
         for target, ctx in tvm.testing.enabled_targets():
             # only test graph runtime because debug runtime is too slow
             for kind in ["graph"]:
@@ -828,11 +939,9 @@ def test_forward_rnn_layer():
                 if init_states:
                     assert len(op_res) == len(mx_res)
                     for i, val in enumerate(op_res):
-                        tvm.testing.assert_allclose(
-                            val.asnumpy(), mx_res[i].asnumpy(), rtol=1e-3)
+                        tvm.testing.assert_allclose(val.asnumpy(), mx_res[i].asnumpy(), rtol=1e-3)
                 else:
-                    tvm.testing.assert_allclose(
-                        op_res.asnumpy(), mx_res.asnumpy(), rtol=1e-3)
+                    tvm.testing.assert_allclose(op_res.asnumpy(), mx_res.asnumpy(), rtol=1e-3)
 
     for mode in ["rnn", "gru", "lstm"]:
         verify(mode, 1, 64, 64, 1)
@@ -844,6 +953,7 @@ def test_forward_rnn_layer():
         # verify(mode, 10, 64, 64, 3, init_states=False)
         # verify(mode, 10, 64, 64, 3, batch=2, bidirectional=True, init_states=False)
 
+
 @tvm.testing.uses_gpu
 def test_forward_Crop():
     def verify(xshape, yshape, offset=None):
@@ -864,12 +974,14 @@ def test_forward_Crop():
                 else:
                     op_res = intrp.evaluate()(x_data)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
     verify((1, 3, 40, 40), (1, 3, 20, 20))
     verify((1, 3, 40, 40), (1, 3, 20, 20), (0, 0))
     verify((1, 3, 40, 40), (1, 3, 20, 20), (10, 10))
     verify((5, 32, 40, 40), (5, 32, 25, 25))
     verify((5, 32, 40, 40), (5, 32, 25, 25), (5, 5))
 
+
 @tvm.testing.uses_gpu
 def test_forward_argsort():
     def verify(shape, axis, is_ascend, dtype="float32"):
@@ -882,18 +994,22 @@ def test_forward_argsort():
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(x_np)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
     verify((2, 3, 4), axis=0, is_ascend=False)
     verify((1, 4, 6), axis=1, is_ascend=True)
     verify((3, 5, 6), axis=-3, is_ascend=False, dtype="int32")
 
+
 @tvm.testing.uses_gpu
 def test_forward_topk():
     def verify(shape, k, axis, ret_type, is_ascend=False, dtype="float32"):
         x_np = np.random.uniform(size=shape).astype("float32")
-        ref_res = mx.nd.topk(mx.nd.array(x_np), k=k, axis=axis, ret_typ=ret_type,
-                             is_ascend=is_ascend, dtype=dtype)
-        mx_sym = mx.sym.topk(mx.sym.var("x"), k=k, axis=axis, ret_typ=ret_type,
-                             is_ascend=is_ascend, dtype=dtype)
+        ref_res = mx.nd.topk(
+            mx.nd.array(x_np), k=k, axis=axis, ret_typ=ret_type, is_ascend=is_ascend, dtype=dtype
+        )
+        mx_sym = mx.sym.topk(
+            mx.sym.var("x"), k=k, axis=axis, ret_typ=ret_type, is_ascend=is_ascend, dtype=dtype
+        )
         mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
@@ -905,45 +1021,53 @@ def test_forward_topk():
                         tvm.testing.assert_allclose(t.asnumpy(), ref_res[i].asnumpy())
                 else:
                     tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
     verify((3, 4), k=1, axis=0, ret_type="both")
     verify((3, 4), k=1, axis=-1, ret_type="indices")
     verify((3, 5, 6), k=2, axis=2, ret_type="value")
     verify((3, 5, 6), k=2, axis=1, ret_type="value", is_ascend=True)
     verify((3, 5, 6), k=0, axis=2, ret_type="both", dtype="int32")
 
+
 @tvm.testing.uses_gpu
 def test_forward_sequence_mask():
     def verify(shape, use_sequence_length, value, axis, dtype, itype):
         data_np = np.random.uniform(size=shape).astype(dtype)
-        valid_length_np = np.random.randint(0, shape[axis], size=shape[1-axis]).astype(itype)
+        valid_length_np = np.random.randint(0, shape[axis], size=shape[1 - axis]).astype(itype)
         if use_sequence_length:
-            ref_res = mx.nd.SequenceMask(mx.nd.array(data_np, dtype=dtype),
-                                         sequence_length=mx.nd.array(valid_length_np, dtype=itype),
-                                         use_sequence_length=use_sequence_length,
-                                         value=value,
-                                         axis=axis)
-            mx_sym = mx.sym.SequenceMask(mx.sym.var('data'),
-                                         sequence_length=mx.sym.var('valid_length'),
-                                         use_sequence_length=use_sequence_length,
-                                         value=value,
-                                         axis=axis)
-            mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": shape,
-                                                        'valid_length': valid_length_np.shape},
-                                               dtype={"data": dtype,
-                                                      "valid_length": itype})
+            ref_res = mx.nd.SequenceMask(
+                mx.nd.array(data_np, dtype=dtype),
+                sequence_length=mx.nd.array(valid_length_np, dtype=itype),
+                use_sequence_length=use_sequence_length,
+                value=value,
+                axis=axis,
+            )
+            mx_sym = mx.sym.SequenceMask(
+                mx.sym.var("data"),
+                sequence_length=mx.sym.var("valid_length"),
+                use_sequence_length=use_sequence_length,
+                value=value,
+                axis=axis,
+            )
+            mod, _ = relay.frontend.from_mxnet(
+                mx_sym,
+                {"data": shape, "valid_length": valid_length_np.shape},
+                dtype={"data": dtype, "valid_length": itype},
+            )
         else:
-            ref_res = mx.nd.SequenceMask(mx.nd.array(data_np, dtype=dtype),
-                                         use_sequence_length=use_sequence_length,
-                                         value=value,
-                                         axis=axis)
-            mx_sym = mx.sym.SequenceMask(mx.sym.var('data'),
-                                         use_sequence_length=use_sequence_length,
-                                         value=value,
-                                         axis=axis)
+            ref_res = mx.nd.SequenceMask(
+                mx.nd.array(data_np, dtype=dtype),
+                use_sequence_length=use_sequence_length,
+                value=value,
+                axis=axis,
+            )
+            mx_sym = mx.sym.SequenceMask(
+                mx.sym.var("data"), use_sequence_length=use_sequence_length, value=value, axis=axis
+            )
             mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": shape}, dtype={"data": dtype})
         for target, ctx in tvm.testing.enabled_targets():
-            for kind in ['graph', 'debug']:
-                if use_sequence_length is False and kind == 'graph':
+            for kind in ["graph", "debug"]:
+                if use_sequence_length is False and kind == "graph":
                     # Disable the test for 'graph' when it's identity.
                     continue
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
@@ -952,10 +1076,12 @@ def test_forward_sequence_mask():
                 else:
                     op_res = intrp.evaluate()(data_np)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
-    verify((5, 10), True, 0.0, 0, 'float32', 'float32')
-    verify((5, 4, 3), True, 1.0, 1, 'float32', 'float32')
-    verify((5, 4, 3), False, 1.0, 1, 'float64', 'float64')
-    verify((5, 4, 3, 2), True, 1.0, 0, 'float32', 'float32')
+
+    verify((5, 10), True, 0.0, 0, "float32", "float32")
+    verify((5, 4, 3), True, 1.0, 1, "float32", "float32")
+    verify((5, 4, 3), False, 1.0, 1, "float64", "float64")
+    verify((5, 4, 3, 2), True, 1.0, 0, "float32", "float32")
+
 
 @tvm.testing.uses_gpu
 def test_forward_contrib_div_sqrt_dim():
@@ -969,9 +1095,11 @@ def test_forward_contrib_div_sqrt_dim():
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(x_np)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
     verify((3, 4))
     verify((3, 4, 5))
 
+
 @tvm.testing.uses_gpu
 def test_forward_batch_norm():
     def verify(shape, axis=1, fix_gamma=False):
@@ -980,23 +1108,42 @@ def test_forward_batch_norm():
         beta = np.random.uniform(size=(shape[axis])).astype("float32")
         moving_mean = np.random.uniform(size=(shape[axis])).astype("float32")
         moving_var = np.abs(np.random.uniform(size=(shape[axis])).astype("float32")) + 0.5
-        ref_res = mx.nd.BatchNorm(mx.nd.array(x), mx.nd.array(gamma), mx.nd.array(beta),
-                                  mx.nd.array(moving_mean), mx.nd.array(moving_var),
-                                  axis=axis, use_global_stats=True, fix_gamma=fix_gamma)
-        mx_sym = mx.sym.BatchNorm(mx.sym.var("x"), mx.sym.var("gamma"),
-                                  mx.sym.var("beta"), mx.sym.var("mean"),
-                                  mx.sym.var("var"), axis=axis, use_global_stats=True,
-                                  fix_gamma=fix_gamma)
-
-        shape_dict = {"x": x.shape, "gamma": gamma.shape, "beta": beta.shape,
-                      "mean": moving_mean.shape, "var": moving_var.shape}
+        ref_res = mx.nd.BatchNorm(
+            mx.nd.array(x),
+            mx.nd.array(gamma),
+            mx.nd.array(beta),
+            mx.nd.array(moving_mean),
+            mx.nd.array(moving_var),
+            axis=axis,
+            use_global_stats=True,
+            fix_gamma=fix_gamma,
+        )
+        mx_sym = mx.sym.BatchNorm(
+            mx.sym.var("x"),
+            mx.sym.var("gamma"),
+            mx.sym.var("beta"),
+            mx.sym.var("mean"),
+            mx.sym.var("var"),
+            axis=axis,
+            use_global_stats=True,
+            fix_gamma=fix_gamma,
+        )
+
+        shape_dict = {
+            "x": x.shape,
+            "gamma": gamma.shape,
+            "beta": beta.shape,
+            "mean": moving_mean.shape,
+            "var": moving_var.shape,
+        }
         mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
-        #print(mod)
+        # print(mod)
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(x, gamma, beta, moving_mean, moving_var)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
+
     verify((2, 3, 4, 5))
     verify((2, 3, 4, 5), axis=0)
     verify((2, 3, 4, 5), axis=-1)
@@ -1010,14 +1157,19 @@ def test_forward_instance_norm():
         gamma = np.random.uniform(size=(shape[axis])).astype("float32")
         beta = np.random.uniform(size=(shape[axis])).astype("float32")
         ref_res = mx.nd.InstanceNorm(mx.nd.array(x), mx.nd.array(gamma), mx.nd.array(beta), epsilon)
-        mx_sym = mx.sym.InstanceNorm(mx.sym.var("x"), mx.sym.var("gamma"), mx.sym.var("beta"), epsilon)
+        mx_sym = mx.sym.InstanceNorm(
+            mx.sym.var("x"), mx.sym.var("gamma"), mx.sym.var("beta"), epsilon
+        )
         shape_dict = {"x": x.shape, "gamma": gamma.shape, "beta": beta.shape}
         mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(x, gamma, beta)
-                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5)
+                tvm.testing.assert_allclose(
+                    op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5
+                )
+
     verify((2, 3, 4, 5))
     verify((32, 64, 80, 64))
     verify((8, 6, 5))
@@ -1030,21 +1182,25 @@ def test_forward_layer_norm():
         x = np.random.uniform(size=shape).astype("float32")
         gamma = np.random.uniform(size=(shape[axis])).astype("float32")
         beta = np.random.uniform(size=(shape[axis])).astype("float32")
-        ref_res = mx.nd.LayerNorm(mx.nd.array(x), mx.nd.array(gamma), mx.nd.array(beta),
-                                  axis=axis)
-        mx_sym = mx.sym.LayerNorm(mx.sym.var("x"), mx.sym.var("gamma"),
-                                  mx.sym.var("beta"), axis=axis)
+        ref_res = mx.nd.LayerNorm(mx.nd.array(x), mx.nd.array(gamma), mx.nd.array(beta), axis=axis)
+        mx_sym = mx.sym.LayerNorm(
+            mx.sym.var("x"), mx.sym.var("gamma"), mx.sym.var("beta"), axis=axis
+        )
         shape_dict = {"x": x.shape, "gamma": gamma.shape, "beta": beta.shape}
         mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(x, gamma, beta)
-                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
+                tvm.testing.assert_allclose(
+                    op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5
+                )
+
     verify((2, 5))
     verify((2, 5), axis=0)
     verify((2, 5, 6))
 
+
 @tvm.testing.uses_gpu
 def test_forward_one_hot():
     def verify(indices_shape, depth, on_value, off_value, dtype):
@@ -1057,7 +1213,10 @@ def test_forward_one_hot():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(x.astype("float32"))
-                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
+                tvm.testing.assert_allclose(
+                    op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5
+                )
+
     verify((3,), 3, 1, 0, "int32")
     verify((3,), 3, 1.0, 0.0, "float32")
     verify((2, 2), 5, 2, -2, "int32")
@@ -1065,40 +1224,77 @@ def test_forward_one_hot():
     verify((3, 2, 4, 5), 6, 1, 0, "int32")
     verify((3, 2, 4, 5), 6, 1.0, 0.0, "float32")
 
+
 @tvm.testing.uses_gpu
 def test_forward_pad():
     def verify(data_shape, out_shape, mode, pad_width, constant_value=0.0):
-        data = mx.sym.var('data')
+        data = mx.sym.var("data")
         mx_sym = mx.sym.pad(data, mode=mode, pad_width=pad_width, constant_value=constant_value)
         verify_mxnet_frontend_impl(mx_sym, data_shape=data_shape, out_shape=out_shape)
 
-    verify(data_shape=(1,1,3,5), out_shape=(1,1,6,12), mode="constant",
-           pad_width=(0,0,0,0,1,2,3,4))
-    verify(data_shape=(1,1,3,5), out_shape=(1,1,6,12), mode="constant",
-           pad_width=(0,0,0,0,1,2,3,4), constant_value=3.0)
-    verify(data_shape=(1,1,3,5), out_shape=(1,1,6,12), mode="edge",
-           pad_width=(0,0,0,0,1,2,3,4))
-    verify(data_shape=(1,1,3,5), out_shape=(1,1,6,12), mode="reflect",
-           pad_width=(0,0,0,0,1,2,3,4))
-    verify(data_shape=(1,1,3,5,7), out_shape=(1,1,6,12,18), mode="constant",
-           pad_width=(0,0,0,0,1,2,3,4,5,6))
-    verify(data_shape=(1,1,3,5,7), out_shape=(1,1,6,12,18), mode="constant",
-           pad_width=(0,0,0,0,1,2,3,4,5,6), constant_value=3.0)
-    verify(data_shape=(1,1,3,5,7), out_shape=(1,1,6,12,18), mode="edge",
-           pad_width=(0,0,0,0,1,2,3,4,5,6))
-    verify(data_shape=(1,1,3,5,7), out_shape=(1,1,6,12,18), mode="reflect",
-           pad_width=(0,0,0,0,1,2,3,4,5,6))
+    verify(
+        data_shape=(1, 1, 3, 5),
+        out_shape=(1, 1, 6, 12),
+        mode="constant",
+        pad_width=(0, 0, 0, 0, 1, 2, 3, 4),
+    )
+    verify(
+        data_shape=(1, 1, 3, 5),
+        out_shape=(1, 1, 6, 12),
+        mode="constant",
+        pad_width=(0, 0, 0, 0, 1, 2, 3, 4),
+        constant_value=3.0,
+    )
+    verify(
+        data_shape=(1, 1, 3, 5),
+        out_shape=(1, 1, 6, 12),
+        mode="edge",
+        pad_width=(0, 0, 0, 0, 1, 2, 3, 4),
+    )
+    verify(
+        data_shape=(1, 1, 3, 5),
+        out_shape=(1, 1, 6, 12),
+        mode="reflect",
+        pad_width=(0, 0, 0, 0, 1, 2, 3, 4),
+    )
+    verify(
+        data_shape=(1, 1, 3, 5, 7),
+        out_shape=(1, 1, 6, 12, 18),
+        mode="constant",
+        pad_width=(0, 0, 0, 0, 1, 2, 3, 4, 5, 6),
+    )
+    verify(
+        data_shape=(1, 1, 3, 5, 7),
+        out_shape=(1, 1, 6, 12, 18),
+        mode="constant",
+        pad_width=(0, 0, 0, 0, 1, 2, 3, 4, 5, 6),
+        constant_value=3.0,
+    )
+    verify(
+        data_shape=(1, 1, 3, 5, 7),
+        out_shape=(1, 1, 6, 12, 18),
+        mode="edge",
+        pad_width=(0, 0, 0, 0, 1, 2, 3, 4, 5, 6),
+    )
+    verify(
+        data_shape=(1, 1, 3, 5, 7),
+        out_shape=(1, 1, 6, 12, 18),
+        mode="reflect",
+        pad_width=(0, 0, 0, 0, 1, 2, 3, 4, 5, 6),
+    )
 
 
 @tvm.testing.uses_gpu
 def test_forward_slice():
     def verify(data_shape, out_shape, begin, end):
-        data = mx.sym.var('data')
+        data = mx.sym.var("data")
         mx_sym = mx.sym.slice(data, begin=begin, end=end)
         verify_mxnet_frontend_impl(mx_sym, data_shape=data_shape, out_shape=out_shape)
 
-    verify(data_shape=(1,1,10), out_shape=(1,1,8), begin=(0, 0, 2), end=(1, 1, 10))
-    verify(data_shape=(1,1,10), out_shape=(1,1,8), begin=(None, None, 2), end=(None, None, None))
+    verify(data_shape=(1, 1, 10), out_shape=(1, 1, 8), begin=(0, 0, 2), end=(1, 1, 10))
+    verify(
+        data_shape=(1, 1, 10), out_shape=(1, 1, 8), begin=(None, None, 2), end=(None, None, None)
+    )
 
 
 @tvm.testing.uses_gpu
@@ -1106,19 +1302,39 @@ def test_forward_convolution():
     def verify(data_shape, kernel_size, stride, pad, num_filter, is_depthwise=False):
         if is_depthwise:
             groups = data_shape[1]
-            weight_shape=(data_shape[1], num_filter // groups,) + kernel_size
+            weight_shape = (
+                data_shape[1],
+                num_filter // groups,
+            ) + kernel_size
         else:
             groups = 1
-            weight_shape=(num_filter, data_shape[1],) + kernel_size
+            weight_shape = (
+                num_filter,
+                data_shape[1],
+            ) + kernel_size
         x = np.random.uniform(size=data_shape).astype("float32")
         weight = np.random.uniform(size=weight_shape).astype("float32")
         bias = np.random.uniform(size=num_filter).astype("float32")
-        ref_res = mx.nd.Convolution(data=mx.nd.array(x), weight=mx.nd.array(weight),
-                                    bias=mx.nd.array(bias), kernel=kernel_size, stride=stride,
-                                    pad=pad, num_filter=num_filter, num_group=groups)
-        mx_sym = mx.sym.Convolution(mx.sym.var("x"), mx.sym.var("weight"), mx.sym.var("bias"),
-                                    kernel=kernel_size, stride=stride,
-                                    pad=pad, num_filter=num_filter, num_group=groups)
+        ref_res = mx.nd.Convolution(
+            data=mx.nd.array(x),
+            weight=mx.nd.array(weight),
+            bias=mx.nd.array(bias),
+            kernel=kernel_size,
+            stride=stride,
+            pad=pad,
+            num_filter=num_filter,
+            num_group=groups,
+        )
+        mx_sym = mx.sym.Convolution(
+            mx.sym.var("x"),
+            mx.sym.var("weight"),
+            mx.sym.var("bias"),
+            kernel=kernel_size,
+            stride=stride,
+            pad=pad,
+            num_filter=num_filter,
+            num_group=groups,
+        )
         shape_dict = {"x": x.shape, "weight": weight.shape, "bias": bias.shape}
         mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
         for target, ctx in tvm.testing.enabled_targets():
@@ -1127,51 +1343,99 @@ def test_forward_convolution():
                 op_res = intrp.evaluate()(x, weight, bias)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
 
-    verify(data_shape=(1,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
-    verify(data_shape=(20,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
-    verify(data_shape=(1,8,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
-    verify(data_shape=(20,8,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
+    verify(data_shape=(1, 1, 1024 * 16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
+    verify(data_shape=(20, 1, 1024 * 16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
+    verify(data_shape=(1, 8, 1024 * 16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
+    verify(data_shape=(20, 8, 1024 * 16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
     verify(data_shape=(1, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
     verify(data_shape=(20, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
     verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
     verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
-    verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=8,
-           is_depthwise=True)
-    verify(data_shape=(1, 1, 16, 16, 16), kernel_size=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1), num_filter=2)
-    verify(data_shape=(20, 1, 16, 16, 16), kernel_size=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1), num_filter=2)
-    verify(data_shape=(1, 8, 16, 16, 16), kernel_size=(3, 3, 3), stride=(2, 2, 2), pad=(1, 1, 1), num_filter=2)
-    verify(data_shape=(20, 8, 16, 16, 16), kernel_size=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1), num_filter=2)
+    verify(
+        data_shape=(1, 8, 32, 32),
+        kernel_size=(3, 3),
+        stride=(1, 1),
+        pad=(1, 1),
+        num_filter=8,
+        is_depthwise=True,
+    )
+    verify(
+        data_shape=(1, 1, 16, 16, 16),
+        kernel_size=(3, 3, 3),
+        stride=(1, 1, 1),
+        pad=(1, 1, 1),
+        num_filter=2,
+    )
+    verify(
+        data_shape=(20, 1, 16, 16, 16),
+        kernel_size=(3, 3, 3),
+        stride=(1, 1, 1),
+        pad=(1, 1, 1),
+        num_filter=2,
+    )
+    verify(
+        data_shape=(1, 8, 16, 16, 16),
+        kernel_size=(3, 3, 3),
+        stride=(2, 2, 2),
+        pad=(1, 1, 1),
+        num_filter=2,
+    )
+    verify(
+        data_shape=(20, 8, 16, 16, 16),
+        kernel_size=(3, 3, 3),
+        stride=(1, 1, 1),
+        pad=(1, 1, 1),
+        num_filter=2,
+    )
+
 
 @tvm.testing.uses_gpu
 def test_forward_deconvolution():
     def verify(data_shape, kernel_size, stride, pad, num_filter):
-        weight_shape=(data_shape[1], num_filter) + kernel_size
+        weight_shape = (data_shape[1], num_filter) + kernel_size
         x = np.random.uniform(size=data_shape).astype("float32")
         weight = np.random.uniform(size=weight_shape).astype("float32")
         bias = np.random.uniform(size=num_filter).astype("float32")
-        ref_res = mx.nd.Deconvolution(data=mx.nd.array(x), weight=mx.nd.array(weight), bias=mx.nd.array(bias),
-                                      kernel=kernel_size, stride=stride,
-                                      pad=pad, num_filter=num_filter, no_bias=False)
-        mx_sym = mx.sym.Deconvolution(mx.sym.var("x"), mx.sym.var("weight"), mx.sym.var("bias"),
-                                      kernel=kernel_size, stride=stride,
-                                      pad=pad, num_filter=num_filter, no_bias=False)
+        ref_res = mx.nd.Deconvolution(
+            data=mx.nd.array(x),
+            weight=mx.nd.array(weight),
+            bias=mx.nd.array(bias),
+            kernel=kernel_size,
+            stride=stride,
+            pad=pad,
+            num_filter=num_filter,
+            no_bias=False,
+        )
+        mx_sym = mx.sym.Deconvolution(
+            mx.sym.var("x"),
+            mx.sym.var("weight"),
+            mx.sym.var("bias"),
+            kernel=kernel_size,
+            stride=stride,
+            pad=pad,
+            num_filter=num_filter,
+            no_bias=False,
+        )
         shape_dict = {"x": x.shape, "weight": weight.shape, "bias": bias.shape}
         mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(x, weight, bias)
-                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
+                tvm.testing.assert_allclose(
+                    op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5
+                )
 
-    verify(data_shape=(1,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
-    verify(data_shape=(20,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
-    verify(data_shape=(1,8,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
-    verify(data_shape=(20,8,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
+    verify(data_shape=(1, 1, 1024 * 16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
+    verify(data_shape=(20, 1, 1024 * 16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
+    verify(data_shape=(1, 8, 1024 * 16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
+    verify(data_shape=(20, 8, 1024 * 16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
     verify(data_shape=(1, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
     verify(data_shape=(20, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
     verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
     verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
 
+
 @tvm.testing.uses_gpu
 def test_forward_cond():
     def verify(a_np, b_np):
@@ -1195,17 +1459,18 @@ def test_forward_cond():
                 op_res = intrp.evaluate()(a_np, b_np)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
 
-    verify(np.asarray([1.0], 'float32'), np.asarray([2.0],'float32'))
-    verify(np.asarray([4.0], 'float32'), np.asarray([3.0],'float32'))
+    verify(np.asarray([1.0], "float32"), np.asarray([2.0], "float32"))
+    verify(np.asarray([4.0], "float32"), np.asarray([3.0], "float32"))
+
 
 @tvm.testing.uses_gpu
 def test_forward_amp_cast():
     def verify(from_dtype, to_dtype):
-        from_np = np.random.uniform(size=(1,3,18)).astype(from_dtype)
-        x_var = mx.sym.var('x', dtype=from_dtype)
+        from_np = np.random.uniform(size=(1, 3, 18)).astype(from_dtype)
+        x_var = mx.sym.var("x", dtype=from_dtype)
         mx_sym = mx.sym.amp_cast(x_var, dtype=to_dtype)
-        shape_dict = {'x': (1,3,18)}
-        dtype_dict = {'x': from_dtype}
+        shape_dict = {"x": (1, 3, 18)}
+        dtype_dict = {"x": from_dtype}
         mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict)
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "vm", "debug"]:
@@ -1214,20 +1479,20 @@ def test_forward_amp_cast():
                 assert op_res.dtype == to_dtype, op_res.dtype
                 tvm.testing.assert_allclose(op_res.asnumpy(), from_np.astype(to_dtype))
 
-    verify('float32', 'float16')
-    verify('float16', 'float32')
+    verify("float32", "float16")
+    verify("float16", "float32")
+
 
 @tvm.testing.uses_gpu
 def test_forward_amp_multicast():
     def verify(dtypes, cast_narrow, expected_dtype):
-        x_nps = [np.random.uniform(size=(1,3,18)).astype(dtype) for dtype in dtypes]
+        x_nps = [np.random.uniform(size=(1, 3, 18)).astype(dtype) for dtype in dtypes]
         x_vars = [mx.sym.var(str(i), dtype=dtype) for i, dtype in enumerate(dtypes)]
-        mx_sym = mx.sym.amp_multicast(*x_vars, cast_narrow=cast_narrow,
-                                      num_outputs=len(dtypes))
+        mx_sym = mx.sym.amp_multicast(*x_vars, cast_narrow=cast_narrow, num_outputs=len(dtypes))
         shape_dict = {}
         dtype_dict = {}
         for i, dtype in enumerate(dtypes):
-            shape_dict[str(i)] = (1,3,18)
+            shape_dict[str(i)] = (1, 3, 18)
             dtype_dict[str(i)] = dtype
         mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict)
         for target, ctx in tvm.testing.enabled_targets():
@@ -1238,21 +1503,21 @@ def test_forward_amp_multicast():
                     assert res.dtype == expected_dtype, res.dtype
                     tvm.testing.assert_allclose(res.asnumpy(), x_nps[i].astype(expected_dtype))
 
-    verify(['float32', 'float16'], False, 'float32')
-    verify(['float32', 'float16'], True,  'float16')
-    verify(['float32', 'float32'], False, 'float32')
-    verify(['float32', 'float32'], True,  'float32')
-    verify(['float16', 'float16'], False, 'float16')
-    verify(['float16', 'float16'], True, 'float16')
+    verify(["float32", "float16"], False, "float32")
+    verify(["float32", "float16"], True, "float16")
+    verify(["float32", "float32"], False, "float32")
+    verify(["float32", "float32"], True, "float32")
+    verify(["float16", "float16"], False, "float16")
+    verify(["float16", "float16"], True, "float16")
 
 
 @tvm.testing.uses_gpu
 def test_forward_unravel_index():
     def verify(x, shape, dtype):
         a_np = np.array(x).astype(dtype)
-        mx_sym = _mx_symbol(mx.sym, 'unravel_index', [mx.sym.var('a'), shape])
-        ref_res = _mx_symbol(mx.nd, 'unravel_index', [mx.nd.array(a_np), shape])
-        shapes = {'a': a_np.shape}
+        mx_sym = _mx_symbol(mx.sym, "unravel_index", [mx.sym.var("a"), shape])
+        ref_res = _mx_symbol(mx.nd, "unravel_index", [mx.nd.array(a_np), shape])
+        shapes = {"a": a_np.shape}
         mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
 
         for target, ctx in tvm.testing.enabled_targets():
@@ -1276,7 +1541,7 @@ def test_forward_unravel_index():
 @tvm.testing.uses_gpu
 def test_forward_swap_axis():
     def _verify_swap_axis(in_shape, out_shape, dim1, dim2):
-        data = mx.sym.var('data')
+        data = mx.sym.var("data")
         mx_sym = mx.sym.swapaxes(data, dim1, dim2)
         verify_mxnet_frontend_impl(mx_sym, in_shape, out_shape)
 
@@ -1292,13 +1557,17 @@ def test_forward_depth_to_space():
         x = np.random.uniform(size=shape).astype("float32")
         ref_res = mx.nd.depth_to_space(mx.nd.array(x), blocksize)
         mx_sym = mx.sym.depth_to_space(mx.sym.var("x"), blocksize)
-        shape_dict = {"x": x.shape, }
+        shape_dict = {
+            "x": x.shape,
+        }
         mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(x)
-                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
+                tvm.testing.assert_allclose(
+                    op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5
+                )
 
     verify((1, 18, 3, 3), 3)
 
@@ -1309,48 +1578,137 @@ def test_forward_space_to_depth():
         x = np.random.uniform(size=shape).astype("float32")
         ref_res = mx.nd.space_to_depth(mx.nd.array(x), blocksize)
         mx_sym = mx.sym.space_to_depth(mx.sym.var("x"), blocksize)
-        shape_dict = {"x": x.shape, }
+        shape_dict = {
+            "x": x.shape,
+        }
         mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(x)
-                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
+                tvm.testing.assert_allclose(
+                    op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5
+                )
 
     verify((1, 1, 9, 9), 3)
 
 
 @tvm.testing.uses_gpu
 def test_forward_correlation():
-    def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size,
-               is_multiply):
+    def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply):
         data1 = np.random.uniform(size=data_shape).astype("float32")
         data2 = np.random.uniform(size=data_shape).astype("float32")
-        ref_res = mx.nd.Correlation(data1=mx.nd.array(data1), data2=mx.nd.array(data2),
-                                    kernel_size=kernel_size, max_displacement=max_displacement,
-                                    stride1=stride1, stride2=stride2, pad_size=pad_size,
-                                    is_multiply=is_multiply)
-        mx_sym = mx.sym.Correlation(data1=mx.sym.var('data1'), data2=mx.sym.var('data2'),
-                                    kernel_size=kernel_size, max_displacement=max_displacement,
-                                    stride1=stride1, stride2=stride2, pad_size=pad_size,
-                                    is_multiply=is_multiply)
+        ref_res = mx.nd.Correlation(
+            data1=mx.nd.array(data1),
+            data2=mx.nd.array(data2),
+            kernel_size=kernel_size,
+            max_displacement=max_displacement,
+            stride1=stride1,
+            stride2=stride2,
+            pad_size=pad_size,
+            is_multiply=is_multiply,
+        )
+        mx_sym = mx.sym.Correlation(
+            data1=mx.sym.var("data1"),
+            data2=mx.sym.var("data2"),
+            kernel_size=kernel_size,
+            max_displacement=max_displacement,
+            stride1=stride1,
+            stride2=stride2,
+            pad_size=pad_size,
+            is_multiply=is_multiply,
+        )
         shape_dict = {"data1": data1.shape, "data2": data2.shape}
         mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(data1, data2)
-                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
-
-    verify((1, 3, 10, 10), kernel_size = 1, max_displacement = 4, stride1 = 1, stride2 = 1, pad_size = 4, is_multiply = False)
-    verify((5, 1, 15, 15), kernel_size = 1, max_displacement = 5, stride1 = 1, stride2 = 1, pad_size = 5, is_multiply = False)
-    verify((5, 1, 15, 15), kernel_size = 1, max_displacement = 5, stride1 = 1, stride2 = 1, pad_size = 5, is_multiply = True)
-    verify((5, 1, 15, 15), kernel_size = 1, max_displacement = 10, stride1 = 1, stride2 = 2, pad_size = 10, is_multiply = True)
-    verify((5, 1, 4, 4), kernel_size = 3, max_displacement = 1, stride1 = 1, stride2 = 1, pad_size = 2, is_multiply = True)
-    verify((5, 1, 4, 4), kernel_size = 3, max_displacement = 1, stride1 = 2, stride2 = 1, pad_size = 2, is_multiply = True)
-    verify((5, 1, 4, 4), kernel_size = 3, max_displacement = 1, stride1 = 2, stride2 = 1, pad_size = 2, is_multiply = False)
-    verify((5, 1, 6, 4), kernel_size = 3, max_displacement = 1, stride1 = 2, stride2 = 1, pad_size = 2, is_multiply = False)
-    verify((5, 1, 11, 11), kernel_size = 5, max_displacement = 1, stride1 = 1, stride2 = 1, pad_size = 2, is_multiply = False)
+                tvm.testing.assert_allclose(
+                    op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5
+                )
+
+    verify(
+        (1, 3, 10, 10),
+        kernel_size=1,
+        max_displacement=4,
+        stride1=1,
+        stride2=1,
+        pad_size=4,
+        is_multiply=False,
+    )
+    verify(
+        (5, 1, 15, 15),
+        kernel_size=1,
+        max_displacement=5,
+        stride1=1,
+        stride2=1,
+        pad_size=5,
+        is_multiply=False,
+    )
+    verify(
+        (5, 1, 15, 15),
+        kernel_size=1,
+        max_displacement=5,
+        stride1=1,
+        stride2=1,
+        pad_size=5,
+        is_multiply=True,
+    )
+    verify(
+        (5, 1, 15, 15),
+        kernel_size=1,
+        max_displacement=10,
+        stride1=1,
+        stride2=2,
+        pad_size=10,
+        is_multiply=True,
+    )
+    verify(
+        (5, 1, 4, 4),
+        kernel_size=3,
+        max_displacement=1,
+        stride1=1,
+        stride2=1,
+        pad_size=2,
+        is_multiply=True,
+    )
+    verify(
+        (5, 1, 4, 4),
+        kernel_size=3,
+        max_displacement=1,
+        stride1=2,
+        stride2=1,
+        pad_size=2,
+        is_multiply=True,
+    )
+    verify(
+        (5, 1, 4, 4),
+        kernel_size=3,
+        max_displacement=1,
+        stride1=2,
+        stride2=1,
+        pad_size=2,
+        is_multiply=False,
+    )
+    verify(
+        (5, 1, 6, 4),
+        kernel_size=3,
+        max_displacement=1,
+        stride1=2,
+        stride2=1,
+        pad_size=2,
+        is_multiply=False,
+    )
+    verify(
+        (5, 1, 11, 11),
+        kernel_size=5,
+        max_displacement=1,
+        stride1=1,
+        stride2=1,
+        pad_size=2,
+        is_multiply=False,
+    )
 
 
 @tvm.testing.uses_gpu
@@ -1358,15 +1716,15 @@ def test_forward_arange_like():
     def verify(data_shape, start=None, step=None, axis=None):
         attrs = {}
         if start is not None:
-            attrs['start'] = start
+            attrs["start"] = start
         if step is not None:
-            attrs['step'] = step
+            attrs["step"] = step
         if axis is not None:
-            attrs['axis'] = axis
-        data = mx.sym.var('data')
+            attrs["axis"] = axis
+        data = mx.sym.var("data")
         data_np = np.random.uniform(size=data_shape).astype("float32")
         ref_res = mx.nd.contrib.arange_like(mx.nd.array(data_np), **attrs)
-        
+
         mx_sym = mx.sym.contrib.arange_like(data, **attrs)
         mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape})
         for target, ctx in tvm.testing.enabled_targets():
@@ -1375,20 +1733,19 @@ def test_forward_arange_like():
                 op_res = intrp.evaluate()()
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
 
-    verify(data_shape=(3,), start=0., step=1.)
-    verify(data_shape=(3, 4, 5), start=0., step=1.)
-    verify(data_shape=(3, 4, 5), start=0., step=1., axis=-1)
-    verify(data_shape=(3, 4, 5), start=2., step=3., axis=1)
+    verify(data_shape=(3,), start=0.0, step=1.0)
+    verify(data_shape=(3, 4, 5), start=0.0, step=1.0)
+    verify(data_shape=(3, 4, 5), start=0.0, step=1.0, axis=-1)
+    verify(data_shape=(3, 4, 5), start=2.0, step=3.0, axis=1)
 
 
 @tvm.testing.uses_gpu
 def test_forward_interleaved_matmul_selfatt_qk():
     def verify(batch, seq_length, num_heads, head_dim):
         data_shape = (seq_length, batch, num_heads * head_dim * 3)
-        data = mx.sym.var('data')
-        data_np = np.random.uniform(size=data_shape).astype('float32')
-        ref_res = mx.nd.contrib.interleaved_matmul_selfatt_qk(
-            mx.nd.array(data_np), heads=num_heads)
+        data = mx.sym.var("data")
+        data_np = np.random.uniform(size=data_shape).astype("float32")
+        ref_res = mx.nd.contrib.interleaved_matmul_selfatt_qk(mx.nd.array(data_np), heads=num_heads)
 
         mx_sym = mx.sym.contrib.interleaved_matmul_selfatt_qk(data, heads=num_heads)
         mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape})
@@ -1407,17 +1764,16 @@ def test_forward_interleaved_matmul_selfatt_valatt():
     def verify(batch, seq_length, num_heads, head_dim):
         data_shape = (seq_length, batch, num_heads * head_dim * 3)
         weight_shape = (batch * num_heads, seq_length, seq_length)
-        data = mx.sym.var('data')
-        weight = mx.sym.var('weight')
-        data_np = np.random.uniform(size=data_shape).astype('float32')
-        weight_np = np.random.uniform(size=weight_shape).astype('float32')
+        data = mx.sym.var("data")
+        weight = mx.sym.var("weight")
+        data_np = np.random.uniform(size=data_shape).astype("float32")
+        weight_np = np.random.uniform(size=weight_shape).astype("float32")
         ref_res = mx.nd.contrib.interleaved_matmul_selfatt_valatt(
-            mx.nd.array(data_np), mx.nd.array(weight_np), heads=num_heads)
+            mx.nd.array(data_np), mx.nd.array(weight_np), heads=num_heads
+        )
 
-        mx_sym = mx.sym.contrib.interleaved_matmul_selfatt_valatt(
-            data, weight, heads=num_heads)
-        mod, _ = relay.frontend.from_mxnet(
-            mx_sym, {"data": data_shape, "weight": weight_shape})
+        mx_sym = mx.sym.contrib.interleaved_matmul_selfatt_valatt(data, weight, heads=num_heads)
+        mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape, "weight": weight_shape})
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph"]:
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
@@ -1434,15 +1790,35 @@ def test_forward_box_decode():
         dtype = "float32"
         data = np.random.uniform(low=-2, high=2, size=data_shape).astype(dtype)
         anchors = np.random.uniform(low=-2, high=2, size=anchor_shape).astype(dtype)
-        ref_res = mx.nd.contrib.box_decode(mx.nd.array(data), mx.nd.array(anchors), stds[0], stds[1], stds[2], stds[3], clip, in_format)
-        mx_sym = mx.sym.contrib.box_decode(mx.sym.var("data"), mx.sym.var("anchors"), stds[0], stds[1], stds[2], stds[3], clip, in_format)
+        ref_res = mx.nd.contrib.box_decode(
+            mx.nd.array(data),
+            mx.nd.array(anchors),
+            stds[0],
+            stds[1],
+            stds[2],
+            stds[3],
+            clip,
+            in_format,
+        )
+        mx_sym = mx.sym.contrib.box_decode(
+            mx.sym.var("data"),
+            mx.sym.var("anchors"),
+            stds[0],
+            stds[1],
+            stds[2],
+            stds[3],
+            clip,
+            in_format,
+        )
         shape_dict = {"data": data_shape, "anchors": anchor_shape}
         mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(data, anchors)
-                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
+                tvm.testing.assert_allclose(
+                    op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5
+                )
 
     verify((1, 10, 4), (1, 10, 4))
     verify((4, 10, 4), (1, 10, 4))
@@ -1457,12 +1833,18 @@ def test_forward_softmax():
         dtype = "float32"
         x = np.random.uniform(low=-100, high=100, size=data_shape).astype(dtype)
         if use_length:
-            ref_res = mx.nd.softmax(data=mx.nd.array(x),
-                                    length=mx.nd.array(length, dtype="int32"),
-                                    axis=axis, use_length=use_length)
-            mx_sym = mx.symbol.softmax(data=mx.sym.var("data"),
-                                       length=mx.sym.var("length"),
-                                       axis=axis, use_length=use_length)
+            ref_res = mx.nd.softmax(
+                data=mx.nd.array(x),
+                length=mx.nd.array(length, dtype="int32"),
+                axis=axis,
+                use_length=use_length,
+            )
+            mx_sym = mx.symbol.softmax(
+                data=mx.sym.var("data"),
+                length=mx.sym.var("length"),
+                axis=axis,
+                use_length=use_length,
+            )
             shape_dict = {"data": data_shape, "length": (length.shape)}
             dtype_dict = {"data": dtype, "length": "int32"}
             mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict)
@@ -1480,33 +1862,39 @@ def test_forward_softmax():
                 else:
                     op_res = intrp.evaluate()(x)
 
-                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
+                tvm.testing.assert_allclose(
+                    op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5
+                )
 
     verify((2, 3, 5), -1, False, None)
     verify((2, 3, 5), 2, False, None)
-    verify((2, 3), -1, True, np.array([2, 1]).astype('int32'))
-    verify((2, 3, 4), -1, True, np.array([[3, 4, 2], [2, 1, 1]]).astype('int32'))
-    verify((2, 3, 4), 2, True, np.array([[3, 4, 2], [1, 2, 1]]).astype('int32'))
+    verify((2, 3), -1, True, np.array([2, 1]).astype("int32"))
+    verify((2, 3, 4), -1, True, np.array([[3, 4, 2], [2, 1, 1]]).astype("int32"))
+    verify((2, 3, 4), 2, True, np.array([[3, 4, 2], [1, 2, 1]]).astype("int32"))
 
 
-@pytest.mark.skipif(not hasattr(mx.sym.np, 'pad'), reason="mx.sym.np.pad hasn't been published yet")
+@pytest.mark.skipif(not hasattr(mx.sym.np, "pad"), reason="mx.sym.np.pad hasn't been published yet")
 @pytest.mark.parametrize(
     "data_shape, pad_width",
-    [((1,1,3,5),(0,0,0,0,1,2,3,4)), ((1,1,3,5,7),(0,0,0,0,1,2,3,4,5,6))]
+    [((1, 1, 3, 5), (0, 0, 0, 0, 1, 2, 3, 4)), ((1, 1, 3, 5, 7), (0, 0, 0, 0, 1, 2, 3, 4, 5, 6))],
 )
 @pytest.mark.parametrize("mode", ["constant", "edge", "reflect"])
-@pytest.mark.parametrize("dtype", ['float64', 'float32', 'int64', 'int32'])
+@pytest.mark.parametrize("dtype", ["float64", "float32", "int64", "int32"])
 @pytest.mark.parametrize("constant_value", [0.0, 3.0])
 @tvm.testing.parametrize_targets
 @pytest.mark.parametrize("kind", ["graph", "vm", "debug"])
-def test_forward_npi_pad(data_shape, pad_width, mode, dtype, constant_value,target, ctx, kind):
+def test_forward_npi_pad(data_shape, pad_width, mode, dtype, constant_value, target, ctx, kind):
     data_np = np.random.uniform(size=data_shape).astype(dtype)
-    data = mx.sym.var('data')
-    if mode == 'constant':
-        ref_res = mx.ndarray.pad(mx.nd.array(data_np), mode=mode,pad_width=pad_width, constant_value=constant_value)
-        mx_sym = mx.sym.np.pad(data.as_np_ndarray(), mode=mode, pad_width=pad_width, constant_values=constant_value)
+    data = mx.sym.var("data")
+    if mode == "constant":
+        ref_res = mx.ndarray.pad(
+            mx.nd.array(data_np), mode=mode, pad_width=pad_width, constant_value=constant_value
+        )
+        mx_sym = mx.sym.np.pad(
+            data.as_np_ndarray(), mode=mode, pad_width=pad_width, constant_values=constant_value
+        )
     else:
-        ref_res = mx.ndarray.pad(mx.nd.array(data_np), mode=mode,pad_width=pad_width)
+        ref_res = mx.ndarray.pad(mx.nd.array(data_np), mode=mode, pad_width=pad_width)
         mx_sym = mx.sym.np.pad(data.as_np_ndarray(), mode=mode, pad_width=pad_width)
     mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype)
     intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
@@ -1514,15 +1902,17 @@ def test_forward_npi_pad(data_shape, pad_width, mode, dtype, constant_value,targ
     tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
 
 
-@pytest.mark.skipif(not hasattr(mx.sym.np, 'pad'), reason="test'll abort with Mxnet 1.x, skip for now")
-@pytest.mark.parametrize("data_shape", [(2,2,2),(2,7,2)])
-@pytest.mark.parametrize("dtype", ['float64', 'float32', 'int64', 'int32', 'bool'])
-@pytest.mark.parametrize("axes", [(1,0,2),None])
+@pytest.mark.skipif(
+    not hasattr(mx.sym.np, "pad"), reason="test will abort with MXNet 1.x, skip for now"
+)
+@pytest.mark.parametrize("data_shape", [(2, 2, 2), (2, 7, 2)])
+@pytest.mark.parametrize("dtype", ["float64", "float32", "int64", "int32", "bool"])
+@pytest.mark.parametrize("axes", [(1, 0, 2), None])
 @tvm.testing.parametrize_targets
 @pytest.mark.parametrize("kind", ["graph", "vm", "debug"])
-def test_forward_npi_transpose(data_shape, axes, dtype,target, ctx, kind):
+def test_forward_npi_transpose(data_shape, axes, dtype, target, ctx, kind):
     data_np = np.random.uniform(size=data_shape).astype(dtype)
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     ref_res = mx.np.transpose(mx.np.array(data_np), axes=axes)
     mx_sym = mx.sym.np.transpose(data.as_np_ndarray(), axes=axes)
     mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype)
@@ -1533,31 +1923,39 @@ def test_forward_npi_transpose(data_shape, axes, dtype,target, ctx, kind):
 
 @pytest.mark.parametrize(
     "data_shape1, data_shape2, axis",
-    [((2,2),(2,2),1),((2,4),(2,3),1),((1,3,2),(1,3,5),2),((1,3,3),(1,3,3),1),((1,3),(1,3),0)]
+    [
+        ((2, 2), (2, 2), 1),
+        ((2, 4), (2, 3), 1),
+        ((1, 3, 2), (1, 3, 5), 2),
+        ((1, 3, 3), (1, 3, 3), 1),
+        ((1, 3), (1, 3), 0),
+    ],
 )
-@pytest.mark.parametrize("dtype", ['float64', 'float32', 'int64', 'int32'])
+@pytest.mark.parametrize("dtype", ["float64", "float32", "int64", "int32"])
 @tvm.testing.parametrize_targets
 @pytest.mark.parametrize("kind", ["graph", "vm", "debug"])
-def test_forward_npi_concatenate(data_shape1, data_shape2, axis, dtype,target, ctx, kind):
+def test_forward_npi_concatenate(data_shape1, data_shape2, axis, dtype, target, ctx, kind):
     data_np1 = np.random.uniform(size=data_shape1).astype(dtype)
     data_np2 = np.random.uniform(size=data_shape2).astype(dtype)
-    data1 = mx.sym.var('data1')
-    data2 = mx.sym.var('data2')
+    data1 = mx.sym.var("data1")
+    data2 = mx.sym.var("data2")
     ref_res = mx.np.concatenate([mx.np.array(data_np1), mx.np.array(data_np2)], axis=axis)
     mx_sym = mx.sym.np.concatenate([data1.as_np_ndarray(), data2.as_np_ndarray()], axis=axis)
-    mod, _ = relay.frontend.from_mxnet(mx_sym, shape={"data1": data_shape1, "data2": data_shape2}, dtype=dtype)
+    mod, _ = relay.frontend.from_mxnet(
+        mx_sym, shape={"data1": data_shape1, "data2": data_shape2}, dtype=dtype
+    )
     intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
     op_res = intrp.evaluate()(data_np1, data_np2)
     tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
 
 
-@pytest.mark.parametrize("data_shape", [(2,2,2),(2,7,2),(2,2,2,1,2,3,1),(1,8)])
-@pytest.mark.parametrize("dtype", ['float64', 'float32', 'int64', 'int32', 'bool'])
+@pytest.mark.parametrize("data_shape", [(2, 2, 2), (2, 7, 2), (2, 2, 2, 1, 2, 3, 1), (1, 8)])
+@pytest.mark.parametrize("dtype", ["float64", "float32", "int64", "int32", "bool"])
 @tvm.testing.parametrize_targets
 @pytest.mark.parametrize("kind", ["graph", "vm", "debug"])
-def test_forward_np_copy(data_shape,dtype,target, ctx, kind):
+def test_forward_np_copy(data_shape, dtype, target, ctx, kind):
     data_np = np.random.uniform(size=data_shape).astype(dtype)
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     ref_res = mx.np.copy(mx.np.array(data_np))
     mx_sym = mx.sym.np.copy(data.as_np_ndarray())
     mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype)
@@ -1566,18 +1964,22 @@ def test_forward_np_copy(data_shape,dtype,target, ctx, kind):
     tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
 
 
-@pytest.mark.parametrize("dtype", ['float64', 'float32', 'int64', 'int32', 'bool'])
+@pytest.mark.parametrize("dtype", ["float64", "float32", "int64", "int32", "bool"])
 @tvm.testing.parametrize_targets
 @pytest.mark.parametrize("kind", ["graph", "vm", "debug"])
-@pytest.mark.parametrize("data_shape,out_shape,reverse",
-                         [((2, 3, 8),(-2, -2, 2, -1),False),
-                          ((8, 3, 3, 3, 4, 4),(-6, 2, -1, -4),False),
-                          ((8, 3, 3, 3, 4, 4),(-5, -4),False),
-                          ((8, 3, 3, 3, 3, 8),(-4, -5),True),
-                          ((8, 3, 2, 4, 8),(-4, -1, 2, -6),True)])
-def test_forward_npx_reshape(data_shape,out_shape,dtype,target,reverse, ctx, kind):
+@pytest.mark.parametrize(
+    "data_shape,out_shape,reverse",
+    [
+        ((2, 3, 8), (-2, -2, 2, -1), False),
+        ((8, 3, 3, 3, 4, 4), (-6, 2, -1, -4), False),
+        ((8, 3, 3, 3, 4, 4), (-5, -4), False),
+        ((8, 3, 3, 3, 3, 8), (-4, -5), True),
+        ((8, 3, 2, 4, 8), (-4, -1, 2, -6), True),
+    ],
+)
+def test_forward_npx_reshape(data_shape, out_shape, dtype, target, reverse, ctx, kind):
     data_np = np.random.uniform(size=data_shape).astype(dtype)
-    data = mx.sym.var('data')
+    data = mx.sym.var("data")
     ref_res = mx.npx.reshape(mx.np.array(data_np), newshape=out_shape, reverse=reverse)
     mx_sym = mx.sym.npx.reshape(data.as_np_ndarray(), newshape=out_shape, reverse=reverse)
     mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype)
@@ -1586,47 +1988,53 @@ def test_forward_npx_reshape(data_shape,out_shape,dtype,target,reverse, ctx, kin
     tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
 
 
-@pytest.mark.parametrize("data_shape", [(2,2,2),(2,7,2),(2,2,2,1,2,3,1),(1,8),(2,2),(1,3)])
-@pytest.mark.parametrize("dtype", ['float64', 'float32', 'int64', 'int32'])
+@pytest.mark.parametrize(
+    "data_shape", [(2, 2, 2), (2, 7, 2), (2, 2, 2, 1, 2, 3, 1), (1, 8), (2, 2), (1, 3)]
+)
+@pytest.mark.parametrize("dtype", ["float64", "float32", "int64", "int32"])
 @tvm.testing.parametrize_targets
 @pytest.mark.parametrize("kind", ["graph", "vm", "debug"])
-def test_forward_npi_binary(data_shape,dtype,target, ctx, kind):
+def test_forward_npi_binary(data_shape, dtype, target, ctx, kind):
     ref_ops = [mx.np.power, mx.np.multiply, mx.np.add, mx.np.less]
     mx_ops = [mx.sym.np.power, mx.sym.np.multiply, mx.sym.np.add, mx.sym.np.less]
     for i in range(len(ref_ops)):
         ref_op = ref_ops[i]
         mx_op = mx_ops[i]
         # mx.np.power only supports float types
-        if ref_op == mx.np.power and dtype not in ['float64', 'float32']:
+        if ref_op == mx.np.power and dtype not in ["float64", "float32"]:
             continue
         data_np1 = np.random.uniform(size=data_shape).astype(dtype)
         data_np2 = np.random.uniform(size=data_shape).astype(dtype)
-        data1 = mx.sym.var('lhs')
-        data2 = mx.sym.var('rhs')
+        data1 = mx.sym.var("lhs")
+        data2 = mx.sym.var("rhs")
         ref_res = ref_op(mx.np.array(data_np1), mx.np.array(data_np2))
         mx_sym = mx_op(data1.as_np_ndarray(), data2.as_np_ndarray())
-        mod, _ = relay.frontend.from_mxnet(mx_sym, shape={"lhs": data_shape, "rhs": data_shape}, dtype=dtype)
+        mod, _ = relay.frontend.from_mxnet(
+            mx_sym, shape={"lhs": data_shape, "rhs": data_shape}, dtype=dtype
+        )
         intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
         op_res = intrp.evaluate()(data_np1, data_np2)
         tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
 
 
-@pytest.mark.parametrize("data_shape", [(2,2,2),(2,7,2),(2,2,2,1,2,3,1),(1,8),(2,2),(1,3)])
-@pytest.mark.parametrize("dtype", ['float64', 'float32', 'int64', 'int32'])
+@pytest.mark.parametrize(
+    "data_shape", [(2, 2, 2), (2, 7, 2), (2, 2, 2, 1, 2, 3, 1), (1, 8), (2, 2), (1, 3)]
+)
+@pytest.mark.parametrize("dtype", ["float64", "float32", "int64", "int32"])
 @tvm.testing.parametrize_targets
-@pytest.mark.parametrize("scalar", [1.0,2.0,3.0,4.0])
+@pytest.mark.parametrize("scalar", [1.0, 2.0, 3.0, 4.0])
 @pytest.mark.parametrize("kind", ["graph", "vm", "debug"])
-def test_forward_npi_binary_scalar(data_shape,dtype,scalar,target, ctx, kind):
+def test_forward_npi_binary_scalar(data_shape, dtype, scalar, target, ctx, kind):
     ref_ops = [mx.np.power, mx.np.multiply, mx.np.add, mx.np.true_divide]
     mx_ops = [mx.sym.np.power, mx.sym.np.multiply, mx.sym.np.add, mx.sym.np.true_divide]
     for i in range(len(ref_ops)):
         ref_op = ref_ops[i]
         mx_op = mx_ops[i]
         # mx.np.power only supports float types
-        if ref_op == mx.np.power and dtype not in ['float64', 'float32']:
+        if ref_op == mx.np.power and dtype not in ["float64", "float32"]:
             continue
         data_np1 = np.random.uniform(size=data_shape).astype(dtype)
-        data1 = mx.sym.var('lhs')
+        data1 = mx.sym.var("lhs")
         ref_res = ref_op(mx.np.array(data_np1), scalar)
         mx_sym = mx_op(data1.as_np_ndarray(), scalar)
         mod, _ = relay.frontend.from_mxnet(mx_sym, shape={"lhs": data_shape}, dtype=dtype)
@@ -1635,13 +2043,15 @@ def test_forward_npi_binary_scalar(data_shape,dtype,scalar,target, ctx, kind):
         tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
 
 
-@pytest.mark.parametrize("data_shape", [(2,2,2),(2,7,2),(2,2,2,1,2,3,1),(1,8),(2,2),(1,3)])
-@pytest.mark.parametrize("dtype", ['float64', 'float32'])
+@pytest.mark.parametrize(
+    "data_shape", [(2, 2, 2), (2, 7, 2), (2, 2, 2, 1, 2, 3, 1), (1, 8), (2, 2), (1, 3)]
+)
+@pytest.mark.parametrize("dtype", ["float64", "float32"])
 @tvm.testing.parametrize_targets
 @pytest.mark.parametrize("kind", ["graph", "vm", "debug"])
-def test_forward_npi_tanh(data_shape,dtype,target, ctx, kind):
+def test_forward_npi_tanh(data_shape, dtype, target, ctx, kind):
     data_np1 = np.random.uniform(size=data_shape).astype(dtype)
-    data1 = mx.sym.var('data')
+    data1 = mx.sym.var("data")
     ref_res = mx.np.tanh(mx.np.array(data_np1))
     mx_sym = mx.sym.np.tanh(data1.as_np_ndarray())
     mod, _ = relay.frontend.from_mxnet(mx_sym, shape={"data": data_shape}, dtype=dtype)
@@ -1650,43 +2060,56 @@ def test_forward_npi_tanh(data_shape,dtype,target, ctx, kind):
     tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
 
 
-@pytest.mark.skipif(not hasattr(mx.np, 'where'), reason="mx.np.where hasn't been publish yet")
-@pytest.mark.parametrize("data_shape", [(2,2,2),(2,7,2),(1,8),(2,2),(1,3)])
-@pytest.mark.parametrize("data_dtype", ['float64', 'float32', 'int64', 'int32', 'bool'])
-@pytest.mark.parametrize("cond_dtype", ['float64', 'float32', 'int64', 'int32', 'bool'])
-@pytest.mark.parametrize("scalar", [1.0,2.0])
+@pytest.mark.skipif(not hasattr(mx.np, "where"), reason="mx.np.where hasn't been published yet")
+@pytest.mark.parametrize("data_shape", [(2, 2, 2), (2, 7, 2), (1, 8), (2, 2), (1, 3)])
+@pytest.mark.parametrize("data_dtype", ["float64", "float32", "int64", "int32", "bool"])
+@pytest.mark.parametrize("cond_dtype", ["float64", "float32", "int64", "int32", "bool"])
+@pytest.mark.parametrize("scalar", [1.0, 2.0])
 @tvm.testing.parametrize_targets
 @pytest.mark.parametrize("kind", ["graph", "vm", "debug"])
-def test_forward_npi_where_rscalar(data_shape,cond_dtype,data_dtype,scalar,target, ctx, kind):
-    if data_dtype == 'bool':
+def test_forward_npi_where_rscalar(data_shape, cond_dtype, data_dtype, scalar, target, ctx, kind):
+    if data_dtype == "bool":
         scalar = scalar == 0.0
     cond_np = np.random.uniform(size=data_shape).astype(cond_dtype)
     data_np = np.random.uniform(size=data_shape).astype(data_dtype)
-    cond = mx.sym.var('condition')
-    data = mx.sym.var('x')
+    cond = mx.sym.var("condition")
+    data = mx.sym.var("x")
     ref_res = mx.np.where(mx.np.array(cond_np), mx.np.array(data_np), scalar)
     mx_sym = mx.sym.np.where(cond.as_np_ndarray(), data.as_np_ndarray(), scalar)
     dtypeDic = {}
     dtypeDic["condition"] = cond_dtype
     dtypeDic["x"] = data_dtype
     mod, _ = relay.frontend.from_mxnet(
-        mx_sym, shape={"condition": data_shape, "x": data_shape}, 
-        dtype=dtypeDic)
+        mx_sym, shape={"condition": data_shape, "x": data_shape}, dtype=dtypeDic
+    )
     intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
     op_res = intrp.evaluate()(cond_np, data_np)
     tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
 
 
-@pytest.mark.parametrize("dtype", ['float64', 'float32', 'int64', 'int32', 'bool'])
+@pytest.mark.parametrize("dtype", ["float64", "float32", "int64", "int32", "bool"])
 @tvm.testing.parametrize_targets
 @pytest.mark.parametrize("kind", ["graph", "vm", "debug"])
-@pytest.mark.parametrize("data_shape, axis, indices_or_sections, squeeze_axis", 
-                         [((3,2,1),1,2,False),((3,2,1),0,3,False),((3,2,1),0,3,True),((3,2,1),0,(1,2),False)])
-def test_forward_split_v2(data_shape, axis, dtype, indices_or_sections, squeeze_axis, target, ctx, kind):
+@pytest.mark.parametrize(
+    "data_shape, axis, indices_or_sections, squeeze_axis",
+    [
+        ((3, 2, 1), 1, 2, False),
+        ((3, 2, 1), 0, 3, False),
+        ((3, 2, 1), 0, 3, True),
+        ((3, 2, 1), 0, (1, 2), False),
+    ],
+)
+def test_forward_split_v2(
+    data_shape, axis, dtype, indices_or_sections, squeeze_axis, target, ctx, kind
+):
     data_np = np.random.uniform(size=data_shape).astype(dtype)
-    data = mx.sym.var('data')
-    ref_res = mx.ndarray.split_v2(mx.nd.array(data_np), indices_or_sections, axis=axis, squeeze_axis=squeeze_axis)
-    mx_sym = mx.sym.split_v2(data.as_nd_ndarray(), indices_or_sections, axis=axis, squeeze_axis=squeeze_axis)
+    data = mx.sym.var("data")
+    ref_res = mx.ndarray.split_v2(
+        mx.nd.array(data_np), indices_or_sections, axis=axis, squeeze_axis=squeeze_axis
+    )
+    mx_sym = mx.sym.split_v2(
+        data.as_nd_ndarray(), indices_or_sections, axis=axis, squeeze_axis=squeeze_axis
+    )
     mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype)
     intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
     op_res = intrp.evaluate()(data_np)
@@ -1694,10 +2117,10 @@ def test_forward_split_v2(data_shape, axis, dtype, indices_or_sections, squeeze_
     for arr in op_res:
         op_res_.append(arr.asnumpy().tolist())
     ref_res_ = []
-    for arr in  ref_res:
+    for arr in ref_res:
         ref_res_.append(arr.asnumpy().tolist())
     tvm.testing.assert_allclose(op_res_, ref_res_, rtol=1e-5)
 
 
-if __name__ == '__main__':
-    pytest.main(['test_forward.py'])
+if __name__ == "__main__":
+    pytest.main(["test_forward.py"])
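The tests in this file all follow one verification pattern: build an mx.sym graph, convert it with relay.frontend.from_mxnet, evaluate the result through a Relay executor, and compare against the MXNet reference with tvm.testing.assert_allclose. A minimal, self-contained sketch of that pattern (the tanh operator, shape, and dtype here are illustrative and assume mxnet, numpy, and tvm are importable):

import mxnet as mx
import numpy as np
import tvm
from tvm import relay


def verify_unary(data_shape=(2, 3, 4), dtype="float32"):
    # Reference result computed directly in MXNet.
    data_np = np.random.uniform(size=data_shape).astype(dtype)
    ref_res = mx.nd.tanh(mx.nd.array(data_np))
    # Same computation expressed symbolically, converted to Relay, and run on CPU.
    mx_sym = mx.sym.tanh(mx.sym.var("data"))
    mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype)
    intrp = relay.create_executor("graph", mod=mod, ctx=tvm.cpu(0), target="llvm")
    op_res = intrp.evaluate()(data_np)
    tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)


verify_unary()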
index b7c01a5..5c009fe 100644 (file)
@@ -22,11 +22,13 @@ from tvm import relay
 from tvm.relay import transform
 import model_zoo
 
+
 def compare_graph(lhs_mod, rhs_mod):
     lhs_mod = transform.InferType()(lhs_mod)
     rhs_mod = transform.InferType()(rhs_mod)
     assert tvm.ir.structural_equal(lhs_mod["main"], rhs_mod["main"])
 
+
 def test_mlp():
     shape = {"data": (1, 1, 28, 28)}
     mx_fun = model_zoo.mx_mlp()
@@ -55,7 +57,7 @@ def test_resnet():
 
 def test_squeezenet():
     shape = {"data": (1, 3, 224, 224)}
-    for version in ['1.0', '1.1']:
+    for version in ["1.0", "1.1"]:
         mx_sym = model_zoo.mx_squeezenet(version)
         mod, _ = relay.frontend.from_mxnet(mx_sym, shape)
         relay_mod = model_zoo.relay_squeezenet(version)
@@ -105,8 +107,7 @@ def test_multi_outputs():
         return tvm.IRModule.from_expr(func)
 
     mx_sym = mx_compose(mx, num_outputs=3, axis=1)
-    mod, _ = relay.frontend.from_mxnet(
-        mx_sym, shape={"x":xshape, "y":yshape})
+    mod, _ = relay.frontend.from_mxnet(mx_sym, shape={"x": xshape, "y": yshape})
     relay_mod = relay_compose(relay, indices_or_sections=3, axis=1)
     compare_graph(mod, relay_mod)
 
index 541162d..c2e2425 100644 (file)
@@ -19,24 +19,24 @@ import numpy as np
 import tvm
 from tvm import relay
 from tvm.contrib import graph_runtime
-from tvm.relay.frontend.mxnet_qnn_op_utils import dequantize_mxnet_min_max, \
-                                                  quantize_mxnet_min_max,   \
-                                                  get_mkldnn_int8_scale,    \
-                                                  get_mkldnn_uint8_scale,   \
-                                                  quantize_conv_bias_mkldnn_from_var
+from tvm.relay.frontend.mxnet_qnn_op_utils import (
+    dequantize_mxnet_min_max,
+    quantize_mxnet_min_max,
+    get_mkldnn_int8_scale,
+    get_mkldnn_uint8_scale,
+    quantize_conv_bias_mkldnn_from_var,
+)
 
 
 def test_mkldnn_dequantize():
-
     def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data):
         shape = in_data.shape
         input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
-        min_range = quant_args['min_range']
-        max_range = quant_args['max_range']
-        dequantized_output = dequantize_mxnet_min_max(input_data,
-                                                      min_range=min_range,
-                                                      max_range=max_range,
-                                                      in_dtype=in_dtype)
+        min_range = quant_args["min_range"]
+        max_range = quant_args["max_range"]
+        dequantized_output = dequantize_mxnet_min_max(
+            input_data, min_range=min_range, max_range=max_range, in_dtype=in_dtype
+        )
         mod = relay.Function(relay.analysis.free_vars(dequantized_output), dequantized_output)
         mod = tvm.IRModule.from_expr(mod)
         with tvm.transform.PassContext(opt_level=3):
@@ -50,32 +50,58 @@ def test_mkldnn_dequantize():
             assert res.dtype == np.float32
 
     def test_uint8_to_float32():
-        data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
-            .astype('uint8') \
-            .reshape((2, 5))
-        output = np.array([0., 0.25048923, 0.50097847, 0.7514677, 1.0019569, 62.8728, 63.123287,
-                           63.373775, 63.624268, 63.874756]) \
-            .astype('float32') \
+        data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]).astype("uint8").reshape((2, 5))
+        output = (
+            np.array(
+                [
+                    0.0,
+                    0.25048923,
+                    0.50097847,
+                    0.7514677,
+                    1.0019569,
+                    62.8728,
+                    63.123287,
+                    63.373775,
+                    63.624268,
+                    63.874756,
+                ]
+            )
+            .astype("float32")
             .reshape((2, 5))
+        )
         quant_args = {"min_range": -63.5, "max_range": 64}
-        dequantize_test_driver(in_dtype='uint8',
-                               quant_args=quant_args,
-                               in_data=data,
-                               verify_output_data=output)
+        dequantize_test_driver(
+            in_dtype="uint8", quant_args=quant_args, in_data=data, verify_output_data=output
+        )
 
     def test_int8_to_float32():
-        data = np.array([-126, -125, -124, -123, -122, 123, 124, 125, 126, 127]) \
-            .astype('int8') \
+        data = (
+            np.array([-126, -125, -124, -123, -122, 123, 124, 125, 126, 127])
+            .astype("int8")
             .reshape((2, 5))
-        output = np.array([-63.247063, -62.745102, -62.24314, -61.74118, -61.23922,
-                           61.74118, 62.24314, 62.745102, 63.247063, 63.749023]) \
-            .astype('float32') \
+        )
+        output = (
+            np.array(
+                [
+                    -63.247063,
+                    -62.745102,
+                    -62.24314,
+                    -61.74118,
+                    -61.23922,
+                    61.74118,
+                    62.24314,
+                    62.745102,
+                    63.247063,
+                    63.749023,
+                ]
+            )
+            .astype("float32")
             .reshape((2, 5))
+        )
         dequantize_args = {"min_range": -63.5, "max_range": 64}
-        dequantize_test_driver(in_dtype='int8',
-                               quant_args=dequantize_args,
-                               in_data=data,
-                               verify_output_data=output)
+        dequantize_test_driver(
+            in_dtype="int8", quant_args=dequantize_args, in_data=data, verify_output_data=output
+        )
 
     test_uint8_to_float32()
     test_int8_to_float32()
@@ -84,13 +110,12 @@ def test_mkldnn_dequantize():
 def test_mkldnn_quantize():
     def quantize_test_driver(out_dtype, quant_args, in_data, verify_output_data):
         shape = in_data.shape
-        input_data = relay.var("input_data", shape=shape, dtype='float32')
-        min_range = quant_args['min_range']
-        max_range = quant_args['max_range']
-        quantized_output, _, _ = quantize_mxnet_min_max(input_data,
-                                                        min_range=min_range,
-                                                        max_range=max_range,
-                                                        out_dtype=out_dtype)
+        input_data = relay.var("input_data", shape=shape, dtype="float32")
+        min_range = quant_args["min_range"]
+        max_range = quant_args["max_range"]
+        quantized_output, _, _ = quantize_mxnet_min_max(
+            input_data, min_range=min_range, max_range=max_range, out_dtype=out_dtype
+        )
         mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
         mod = tvm.IRModule.from_expr(mod)
         with tvm.transform.PassContext(opt_level=3):
@@ -104,34 +129,60 @@ def test_mkldnn_quantize():
             assert res.dtype == verify_output_data.dtype
 
     def test_float32_to_uint8():
-        data = np.array([0., 0.25048923, 0.50097847, 0.7514677, 1.0019569, 62.8728, 63.123287,
-                         63.373775, 63.624268, 63.874756]) \
-            .astype('float32') \
-            .reshape((2, 5))
-        output = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
-            .astype('uint8') \
+        data = (
+            np.array(
+                [
+                    0.0,
+                    0.25048923,
+                    0.50097847,
+                    0.7514677,
+                    1.0019569,
+                    62.8728,
+                    63.123287,
+                    63.373775,
+                    63.624268,
+                    63.874756,
+                ]
+            )
+            .astype("float32")
             .reshape((2, 5))
+        )
+        output = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]).astype("uint8").reshape((2, 5))
 
         quant_args = {"min_range": -63.5, "max_range": 64}
-        quantize_test_driver(out_dtype='uint8',
-                             quant_args=quant_args,
-                             in_data=data,
-                             verify_output_data=output)
+        quantize_test_driver(
+            out_dtype="uint8", quant_args=quant_args, in_data=data, verify_output_data=output
+        )
 
     def test_float32_to_int8():
-        data = np.array([-63.247063, -62.745102, -62.24314, -61.74118, -61.23922,
-                         61.74118, 62.24314, 62.745102, 63.247063, 63.749023]) \
-            .astype('float32') \
+        data = (
+            np.array(
+                [
+                    -63.247063,
+                    -62.745102,
+                    -62.24314,
+                    -61.74118,
+                    -61.23922,
+                    61.74118,
+                    62.24314,
+                    62.745102,
+                    63.247063,
+                    63.749023,
+                ]
+            )
+            .astype("float32")
             .reshape((2, 5))
-        output = np.array([-126, -125, -124, -123, -122, 123, 124, 125, 126, 127]) \
-            .astype('int8') \
+        )
+        output = (
+            np.array([-126, -125, -124, -123, -122, 123, 124, 125, 126, 127])
+            .astype("int8")
             .reshape((2, 5))
+        )
 
         quant_args = {"min_range": -63.5, "max_range": 64}
-        quantize_test_driver(out_dtype='int8',
-                             quant_args=quant_args,
-                             in_data=data,
-                             verify_output_data=output)
+        quantize_test_driver(
+            out_dtype="int8", quant_args=quant_args, in_data=data, verify_output_data=output
+        )
 
     test_float32_to_uint8()
     test_float32_to_int8()
@@ -141,8 +192,7 @@ def test_get_mkldnn_int8_scale():
     range_min = -3.904039
     range_max = 3.904039
     expected = 0.03061991354976495
-    output = get_mkldnn_int8_scale(range_max=range_max,
-                                   range_min=range_min)
+    output = get_mkldnn_int8_scale(range_max=range_max, range_min=range_min)
     assert np.allclose(output, expected)
 
 
@@ -150,20 +200,19 @@ def test_get_mkldnn_uint8_scale():
     range_min = 0.0
     range_max = 55.77269
     expected = 0.21828841189047482
-    output = get_mkldnn_uint8_scale(range_max=range_max,
-                                    range_min=range_min)
+    output = get_mkldnn_uint8_scale(range_max=range_max, range_min=range_min)
     assert np.allclose(output, expected)
 
 
 def test_quantize_conv_bias_mkldnn_from_var():
-    bias_var = relay.var('bias', shape=(3,), dtype='float32')
+    bias_var = relay.var("bias", shape=(3,), dtype="float32")
     bias_scale = tvm.nd.array(np.array([0.5, 0.6, 0.7]))
     output = quantize_conv_bias_mkldnn_from_var(bias_var, bias_scale)
     assert isinstance(output, tvm.relay.expr.Call)
     attrs = output.attrs
     assert attrs.axis == 0
-    assert attrs.out_dtype == 'int32'
-    assert output.op.name == 'qnn.quantize'
+    assert attrs.out_dtype == "int32"
+    assert output.op.name == "qnn.quantize"
     assert output.args[1].data == bias_scale
 
 
index 394c745..6f63cbf 100644 (file)
@@ -51,24 +51,24 @@ def get_tvm_output_with_vm(graph_def, input_data, target, ctx, opset=None):
 
     mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)
 
-    ex = relay.create_executor('vm', mod=mod, ctx=ctx, target=target)
+    ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
     indata = tvm.nd.array(input_data)
     result = ex.evaluate()(indata)
     return result.asnumpy()
 
 
-def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32', opset=None):
+def get_tvm_output(
+    graph_def, input_data, target, ctx, output_shape=None, output_dtype="float32", opset=None
+):
     """ Generic function to execute and get tvm output"""
-    target = 'llvm'
+    target = "llvm"
 
     input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data)
 
     mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)
 
     with tvm.transform.PassContext(opt_level=1):
-        graph, lib, params = relay.build(mod,
-                                         target,
-                                         params=params)
+        graph, lib, params = relay.build(mod, target, params=params)
 
     ctx = tvm.cpu(0)
     m = graph_runtime.create(graph, lib, ctx)
@@ -78,13 +78,11 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output
             # It's possible for some onnx inputs to not be needed in the tvm
             # module, confirm it's present before setting.
             try:
-                m.set_input(input_names[i], tvm.nd.array(
-                    input_data[i].astype(input_data[i].dtype)))
+                m.set_input(input_names[i], tvm.nd.array(input_data[i].astype(input_data[i].dtype)))
             except:
                 continue
     else:
-        m.set_input(input_names, tvm.nd.array(
-            input_data.astype(input_data.dtype)))
+        m.set_input(input_names, tvm.nd.array(input_data.astype(input_data.dtype)))
 
     m.set_input(**params)
     # execute
@@ -101,9 +99,10 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output
         return tvm_output.asnumpy()
 
 
-def get_onnxruntime_output(model, inputs, dtype='float32'):
+def get_onnxruntime_output(model, inputs, dtype="float32"):
     import onnxruntime.backend
-    rep = onnxruntime.backend.prepare(model, 'CPU')
+
+    rep = onnxruntime.backend.prepare(model, "CPU")
     if isinstance(inputs, list) and len(inputs) > 1:
         ort_out = rep.run(inputs)
     else:
@@ -113,7 +112,7 @@ def get_onnxruntime_output(model, inputs, dtype='float32'):
 
 
 def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
-    dtype = 'float32'
+    dtype = "float32"
     x = np.random.uniform(size=data_shape)
     model = onnx.load_model(graph_file)
     c2_out = get_onnxruntime_output(model, x, dtype)
@@ -121,14 +120,15 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
         tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype)
         tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)
 
+
 def make_constant_node(name, data_type, dims, vals):
-    return helper.make_node('Constant',
-                            inputs=[],
-                            outputs=[name],
-                            value=helper.make_tensor(name=name,
-                                                     data_type=data_type,
-                                                     dims=dims,
-                                                     vals=vals))
+    return helper.make_node(
+        "Constant",
+        inputs=[],
+        outputs=[name],
+        value=helper.make_tensor(name=name, data_type=data_type, dims=dims, vals=vals),
+    )
+
 
 @tvm.testing.uses_gpu
 def test_reshape():
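For context, get_tvm_output above wraps the standard compile-and-run path: relay.build, then graph_runtime.create, then set_input, run, and get_output. A compact sketch of that path on a tiny hand-built Relay function (the relu operator, shapes, and the "llvm" target are illustrative assumptions):

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

x = relay.var("x", shape=(1, 4), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.relu(x)))

with tvm.transform.PassContext(opt_level=1):
    graph, lib, params = relay.build(mod, "llvm", params={})

ctx = tvm.cpu(0)
m = graph_runtime.create(graph, lib, ctx)
m.set_input("x", tvm.nd.array(np.random.uniform(size=(1, 4)).astype("float32")))
m.set_input(**params)
m.run()
print(m.get_output(0).asnumpy())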
@@ -136,56 +136,63 @@ def test_reshape():
     ref_shape = (6, 2, 4, 3)
 
     ref_array = np.array(ref_shape)
-    ref_node = onnx.helper.make_node('Constant',
-                                     inputs=[],
-                                     outputs=['ref_in'],
-                                     value=onnx.helper.make_tensor(name='const_tensor',
-                                                                   data_type=onnx.TensorProto.INT32,
-                                                                   dims=ref_array.shape,
-                                                                   vals=ref_array.flatten().astype(int)))
+    ref_node = onnx.helper.make_node(
+        "Constant",
+        inputs=[],
+        outputs=["ref_in"],
+        value=onnx.helper.make_tensor(
+            name="const_tensor",
+            data_type=onnx.TensorProto.INT32,
+            dims=ref_array.shape,
+            vals=ref_array.flatten().astype(int),
+        ),
+    )
     reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"])
 
-    graph = helper.make_graph([ref_node, reshape_node],
-                              "reshape_test",
-                              inputs=[helper.make_tensor_value_info("in",
-                                                                    TensorProto.FLOAT, list(in_shape))],
-                              outputs=[helper.make_tensor_value_info("out",
-                                                                     TensorProto.FLOAT, list(ref_shape))])
+    graph = helper.make_graph(
+        [ref_node, reshape_node],
+        "reshape_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(ref_shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='reshape_test')
+    model = helper.make_model(graph, producer_name="reshape_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        x = np.random.uniform(size=in_shape).astype('int32')
-        tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')
+        x = np.random.uniform(size=in_shape).astype("int32")
+        tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, "float32")
 
     tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
 
 
 @tvm.testing.uses_gpu
 def test_expand():
-
     def _test_expand(name, data, shape, ref_data):
         shape_array = np.array(shape)
-        shape_node = onnx.helper.make_node('Constant',
-                                    inputs=[],
-                                    outputs=['shape'],
-                                    value=onnx.helper.make_tensor(name = 'const_tensor',
-                                                                  data_type = onnx.TensorProto.INT32,
-                                                                  dims = shape_array.shape,
-                                                                  vals = shape_array.flatten().astype('int32')))
+        shape_node = onnx.helper.make_node(
+            "Constant",
+            inputs=[],
+            outputs=["shape"],
+            value=onnx.helper.make_tensor(
+                name="const_tensor",
+                data_type=onnx.TensorProto.INT32,
+                dims=shape_array.shape,
+                vals=shape_array.flatten().astype("int32"),
+            ),
+        )
         expand_node = helper.make_node("Expand", ["in", "shape"], ["out"])
 
-        graph = helper.make_graph([shape_node, expand_node],
-                                "expand_test",
-                                inputs = [helper.make_tensor_value_info("in",
-                                                TensorProto.FLOAT, list(data.shape))],
-                                outputs = [helper.make_tensor_value_info("out",
-                                                TensorProto.FLOAT, list(ref_data.shape))])
+        graph = helper.make_graph(
+            [shape_node, expand_node],
+            "expand_test",
+            inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(data.shape))],
+            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(ref_data.shape))],
+        )
 
         model = helper.make_model(graph, producer_name=name)
 
         for target, ctx in tvm.testing.enabled_targets():
-            tvm_out = get_tvm_output(model, data, target, ctx, ref_data.shape, 'float32')
+            tvm_out = get_tvm_output(model, data, target, ctx, ref_data.shape, "float32")
 
         tvm.testing.assert_allclose(ref_data, tvm_out)
 
@@ -193,32 +200,31 @@ def test_expand():
     shape = (3, 4)
     data = np.random.uniform(size=in_shape).astype(np.float32)
     ref_data = np.tile(data, 4)
-    _test_expand('expand_with_dim_unchanged_test', data, shape, ref_data)
+    _test_expand("expand_with_dim_unchanged_test", data, shape, ref_data)
 
     in_shape = (3, 1)
     shape = (2, 1, 6)
     data = np.random.uniform(size=in_shape).astype(np.float32)
     ref_data = data * np.ones(shape, dtype=np.float32)
-    _test_expand('expand_with_dim_changed_test', data, shape, ref_data)
+    _test_expand("expand_with_dim_changed_test", data, shape, ref_data)
 
 
 def verify_depth_to_space(inshape, outshape, mode, blockSize):
-    node = onnx.helper.make_node('DepthToSpace',
-                                 inputs=['x'],
-                                 outputs=['y'],
-                                 blocksize=blockSize)
+    node = onnx.helper.make_node("DepthToSpace", inputs=["x"], outputs=["y"], blocksize=blockSize)
 
-    graph = helper.make_graph([node],
-                              "depth_to_space_test",
-                              inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))],
-                              outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))])
+    graph = helper.make_graph(
+        [node],
+        "depth_to_space_test",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))],
+    )
 
-    model = helper.make_model(graph, producer_name='depth_to_space_test')
+    model = helper.make_model(graph, producer_name="depth_to_space_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        x = np.random.uniform(size=inshape).astype('float32')
-        tvm_out = get_tvm_output(model, x, target, ctx, outshape, 'float32')
-        onnx_out = get_onnxruntime_output(model, x, 'float32')
+        x = np.random.uniform(size=inshape).astype("float32")
+        tvm_out = get_tvm_output(model, x, target, ctx, outshape, "float32")
+        onnx_out = get_onnxruntime_output(model, x, "float32")
         tvm.testing.assert_allclose(onnx_out, tvm_out)
 
 
@@ -231,22 +237,21 @@ def test_depth_to_space():
 
 
 def verify_space_to_depth(inshape, outshape, blockSize):
-    node = onnx.helper.make_node('SpaceToDepth',
-                                 inputs=['x'],
-                                 outputs=['y'],
-                                 blocksize=blockSize)
+    node = onnx.helper.make_node("SpaceToDepth", inputs=["x"], outputs=["y"], blocksize=blockSize)
 
-    graph = helper.make_graph([node],
-                              "space_to_depth_test",
-                              inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))],
-                              outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))])
+    graph = helper.make_graph(
+        [node],
+        "space_to_depth_test",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))],
+    )
 
-    model = helper.make_model(graph, producer_name='space_to_depth_test')
+    model = helper.make_model(graph, producer_name="space_to_depth_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        x = np.random.uniform(size=inshape).astype('float32')
-        tvm_out = get_tvm_output(model, x, target, ctx, outshape, 'float32')
-        onnx_out = get_onnxruntime_output(model, x, 'float32')
+        x = np.random.uniform(size=inshape).astype("float32")
+        tvm_out = get_tvm_output(model, x, target, ctx, outshape, "float32")
+        onnx_out = get_onnxruntime_output(model, x, "float32")
         tvm.testing.assert_allclose(onnx_out, tvm_out)
 
 
@@ -261,29 +266,33 @@ def test_shape():
     ref_shape = (6, 2, 4, 3)
 
     ref_array = np.array(ref_shape)
-    ref_node = onnx.helper.make_node('Constant',
-                                     inputs=[],
-                                     outputs=['ref_in'],
-                                     value=onnx.helper.make_tensor(name='const_tensor',
-                                                                   data_type=onnx.TensorProto.INT32,
-                                                                   dims=ref_array.shape,
-                                                                   vals=ref_array.flatten().astype(int)))
+    ref_node = onnx.helper.make_node(
+        "Constant",
+        inputs=[],
+        outputs=["ref_in"],
+        value=onnx.helper.make_tensor(
+            name="const_tensor",
+            data_type=onnx.TensorProto.INT32,
+            dims=ref_array.shape,
+            vals=ref_array.flatten().astype(int),
+        ),
+    )
     reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"])
 
-    shape_node = helper.make_node("Shape", ['out'], ['final_out'])
+    shape_node = helper.make_node("Shape", ["out"], ["final_out"])
 
-    graph = helper.make_graph([ref_node, reshape_node, shape_node],
-                              "shape_test",
-                              inputs=[helper.make_tensor_value_info("in",
-                                                                    TensorProto.FLOAT, list(in_shape))],
-                              outputs=[helper.make_tensor_value_info("final_out",
-                                                                     TensorProto.FLOAT, list(ref_shape))])
+    graph = helper.make_graph(
+        [ref_node, reshape_node, shape_node],
+        "shape_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
+        outputs=[helper.make_tensor_value_info("final_out", TensorProto.FLOAT, list(ref_shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='shape_test')
+    model = helper.make_model(graph, producer_name="shape_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        x = np.random.uniform(size=in_shape).astype('int32')
-        tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'int32')
+        x = np.random.uniform(size=in_shape).astype("int32")
+        tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, "int32")
 
     tvm.testing.assert_allclose(ref_shape, tvm_out)
 
@@ -297,18 +306,19 @@ def _test_power_iteration(x_shape, y_shape):
 
     np_res = np.power(x, y).astype(np.float32)
 
-    res = helper.make_node("Pow", ['x', 'y'], ['out'])
+    res = helper.make_node("Pow", ["x", "y"], ["out"])
 
-    graph = helper.make_graph([res],
-                              'power_test',
-                              inputs=[helper.make_tensor_value_info("x",
-                                                                    TensorProto.FLOAT, list(x_shape)),
-                                      helper.make_tensor_value_info("y",
-                                                                    TensorProto.FLOAT, list(y_shape))],
-                              outputs=[helper.make_tensor_value_info("out",
-                                                                     TensorProto.FLOAT, list(np_res.shape))])
+    graph = helper.make_graph(
+        [res],
+        "power_test",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)),
+            helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape)),
+        ],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(np_res.shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='power_test')
+    model = helper.make_model(graph, producer_name="power_test")
 
     for target, ctx in tvm.testing.enabled_targets():
         tvm_out = get_tvm_output(model, [x, y], target, ctx, np_res.shape)
@@ -326,20 +336,20 @@ def test_power():
 def test_squeeze():
     in_shape = (1, 3, 1, 3, 1, 1)
     out_shape = (3, 3)
-    y = helper.make_node("Squeeze", ['in'], ['out'], axes=[0, 2, 4, 5])
+    y = helper.make_node("Squeeze", ["in"], ["out"], axes=[0, 2, 4, 5])
 
-    graph = helper.make_graph([y],
-                              'squeeze_test',
-                              inputs=[helper.make_tensor_value_info("in",
-                                                                    TensorProto.FLOAT, list(in_shape))],
-                              outputs=[helper.make_tensor_value_info("out",
-                                                                     TensorProto.FLOAT, list(out_shape))])
+    graph = helper.make_graph(
+        [y],
+        "squeeze_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='squeeze_test')
+    model = helper.make_model(graph, producer_name="squeeze_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        x = np.random.uniform(size=in_shape).astype('float32')
-        tvm_out = get_tvm_output(model, x, target, ctx, out_shape, 'float32')
+        x = np.random.uniform(size=in_shape).astype("float32")
+        tvm_out = get_tvm_output(model, x, target, ctx, out_shape, "float32")
 
     tvm.testing.assert_allclose(out_shape, tvm_out.shape)
 
@@ -353,18 +363,18 @@ def test_flatten():
 
     flatten_node = helper.make_node("Flatten", ["in"], ["out"], axis=axis)
 
-    graph = helper.make_graph([flatten_node],
-                              "flatten_test",
-                              inputs=[helper.make_tensor_value_info("in",
-                                                                    TensorProto.FLOAT, list(in_shape))],
-                              outputs=[helper.make_tensor_value_info("out",
-                                                                     TensorProto.FLOAT, list(ref_shape))])
+    graph = helper.make_graph(
+        [flatten_node],
+        "flatten_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(ref_shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='flatten_test')
+    model = helper.make_model(graph, producer_name="flatten_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        x = np.random.uniform(size=in_shape).astype('int32')
-        tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')
+        x = np.random.uniform(size=in_shape).astype("int32")
+        tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, "float32")
 
     tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
 
@@ -374,20 +384,20 @@ def test_unsqueeze():
     in_shape = (3, 3)
     axis = (0, 3, 4)
     out_shape = (1, 3, 3, 1, 1)
-    y = helper.make_node("Unsqueeze", ['in'], ['out'], axes=list(axis))
+    y = helper.make_node("Unsqueeze", ["in"], ["out"], axes=list(axis))
 
-    graph = helper.make_graph([y],
-                              'squeeze_test',
-                              inputs=[helper.make_tensor_value_info("in",
-                                                                    TensorProto.FLOAT, list(in_shape))],
-                              outputs=[helper.make_tensor_value_info("out",
-                                                                     TensorProto.FLOAT, list(out_shape))])
+    graph = helper.make_graph(
+        [y],
+        "squeeze_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='squeeze_test')
+    model = helper.make_model(graph, producer_name="squeeze_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        x = np.random.uniform(size=in_shape).astype('float32')
-        tvm_out = get_tvm_output(model, x, target, ctx, out_shape, 'float32')
+        x = np.random.uniform(size=in_shape).astype("float32")
+        tvm_out = get_tvm_output(model, x, target, ctx, out_shape, "float32")
 
     tvm.testing.assert_allclose(out_shape, tvm_out.shape)
 
@@ -397,32 +407,32 @@ def verify_gather(in_shape, indices, axis, dtype):
     indices = np.array(indices, dtype="int32")
     out_np = np.take(x, indices, axis=axis)
 
-    y = helper.make_node("Gather", ['in', 'indices'], ['out'], axis=axis)
+    y = helper.make_node("Gather", ["in", "indices"], ["out"], axis=axis)
 
-    graph = helper.make_graph([y],
-                              'gather_test',
-                              inputs=[helper.make_tensor_value_info("in",
-                                                                    TensorProto.FLOAT, list(in_shape)),
-                                      helper.make_tensor_value_info("indices",
-                                                                    TensorProto.INT32, list(indices.shape))],
-                              outputs=[helper.make_tensor_value_info("out",
-                                                                     TensorProto.FLOAT, list(out_np.shape))])
-    model = helper.make_model(graph, producer_name='gather_test')
+    graph = helper.make_graph(
+        [y],
+        "gather_test",
+        inputs=[
+            helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape)),
+            helper.make_tensor_value_info("indices", TensorProto.INT32, list(indices.shape)),
+        ],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_np.shape))],
+    )
+    model = helper.make_model(graph, producer_name="gather_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, [x, indices], target, ctx, out_np.shape)
+        tvm_out = get_tvm_output(model, [x, indices], target, ctx, out_np.shape)
         tvm.testing.assert_allclose(out_np, tvm_out)
 
 
 @tvm.testing.uses_gpu
 def test_gather():
-    verify_gather((4,), [1], 0, 'int32')
-    verify_gather((1, 4), [0], 0, 'int32')
-    verify_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32')
-    verify_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32')
-    verify_gather((3, 3, 3), [[[1, 0]]], -1, 'int32')
-    verify_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32')
+    verify_gather((4,), [1], 0, "int32")
+    verify_gather((1, 4), [0], 0, "int32")
+    verify_gather((4,), [[[1, 0], [0, 1]]], 0, "float32")
+    verify_gather((2, 2), [[[1, 0], [0, 1]]], 1, "int32")
+    verify_gather((3, 3, 3), [[[1, 0]]], -1, "int32")
+    verify_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, "float32")
 
 
 def verify_scatter(in_shape, indices, axis):
@@ -430,24 +440,23 @@ def verify_scatter(in_shape, indices, axis):
     indices = np.array(indices, dtype="int32")
     updates = np.random.uniform(size=indices.shape).astype("float32")
 
-    y = helper.make_node("ScatterElements", ['data', 'indices', 'updates'], ['output'], axis=axis)
-
-    graph = helper.make_graph([y],
-                              'scatter_test',
-                              inputs=[helper.make_tensor_value_info("data",
-                                                                    TensorProto.FLOAT, list(in_shape)),
-                                      helper.make_tensor_value_info("indices",
-                                                                    TensorProto.INT32, list(indices.shape)),
-                                      helper.make_tensor_value_info("updates",
-                                                                    TensorProto.FLOAT, list(indices.shape))],
-                              outputs=[helper.make_tensor_value_info("output",
-                                                                     TensorProto.FLOAT, list(in_shape))])
-    model = helper.make_model(graph, producer_name='scatter_test')
+    y = helper.make_node("ScatterElements", ["data", "indices", "updates"], ["output"], axis=axis)
+
+    graph = helper.make_graph(
+        [y],
+        "scatter_test",
+        inputs=[
+            helper.make_tensor_value_info("data", TensorProto.FLOAT, list(in_shape)),
+            helper.make_tensor_value_info("indices", TensorProto.INT32, list(indices.shape)),
+            helper.make_tensor_value_info("updates", TensorProto.FLOAT, list(indices.shape)),
+        ],
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(in_shape))],
+    )
+    model = helper.make_model(graph, producer_name="scatter_test")
     onnx_out = get_onnxruntime_output(model, [x, indices, updates])
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, [x, indices, updates], target, ctx, onnx_out[0].shape)
+        tvm_out = get_tvm_output(model, [x, indices, updates], target, ctx, onnx_out[0].shape)
         tvm.testing.assert_allclose(onnx_out[0], tvm_out)
 
 
@@ -463,75 +472,82 @@ def test_scatter():
 
 def _test_slice_iteration_v1(indata, outdata, starts, ends, axes=None):
     if axes:
-        y = helper.make_node(
-            "Slice", ['in'], ['out'], axes=axes, starts=starts, ends=ends)
+        y = helper.make_node("Slice", ["in"], ["out"], axes=axes, starts=starts, ends=ends)
     else:
-        y = helper.make_node(
-            "Slice", ['in'], ['out'], starts=starts, ends=ends)
+        y = helper.make_node("Slice", ["in"], ["out"], starts=starts, ends=ends)
 
-    graph = helper.make_graph([y],
-                              'slice_test',
-                              inputs=[helper.make_tensor_value_info("in",
-                                                                    TensorProto.FLOAT, list(indata.shape))],
-                              outputs=[helper.make_tensor_value_info("out",
-                                                                     TensorProto.FLOAT, list(outdata.shape))])
+    graph = helper.make_graph(
+        [y],
+        "slice_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='slice_test')
+    model = helper.make_model(graph, producer_name="slice_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, indata, target, ctx, outdata.shape, 'float32', opset=1)
+        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, "float32", opset=1)
 
     tvm.testing.assert_allclose(outdata, tvm_out)
 
+
 def _test_slice_iteration_v10(indata, outdata, **attrs):
-    starts = attrs['starts']
-    ends = attrs['ends']
-    axes = None if 'axes' not in attrs else attrs['axes']
+    starts = attrs["starts"]
+    ends = attrs["ends"]
+    axes = None if "axes" not in attrs else attrs["axes"]
     starts = np.asarray(starts)
     ends = np.asarray(ends)
     inputs = [
-        helper.make_tensor_value_info("data", TensorProto.FLOAT,
-                                      list(indata.shape)),
-        helper.make_tensor_value_info("starts", TensorProto.INT64,
-                                      list(starts.shape)),
-        helper.make_tensor_value_info("ends", TensorProto.INT64,
-                                      list(ends.shape))
+        helper.make_tensor_value_info("data", TensorProto.FLOAT, list(indata.shape)),
+        helper.make_tensor_value_info("starts", TensorProto.INT64, list(starts.shape)),
+        helper.make_tensor_value_info("ends", TensorProto.INT64, list(ends.shape)),
     ]
     initializer = [
-        helper.make_tensor("starts", TensorProto.INT64, list(starts.shape),
-                           starts),
-        helper.make_tensor("ends", TensorProto.INT64, list(ends.shape), ends)
+        helper.make_tensor("starts", TensorProto.INT64, list(starts.shape), starts),
+        helper.make_tensor("ends", TensorProto.INT64, list(ends.shape), ends),
     ]
     nodes = []
 
-    if 'add_noop_to_input_attrs' in attrs:
+    if "add_noop_to_input_attrs" in attrs:
+
         def add_noop_to_input_attr(attr_name, attr):
-            output_name = attr_name+"_output"
+            output_name = attr_name + "_output"
 
             ref_shape = list(np.array(attr).shape)
             ref_shape.insert(0, 1)
             ref_shape = tuple(ref_shape)
             ref_array = np.array(ref_shape)
-            ref_node = onnx.helper.make_node('Constant',
-                                             inputs=[],
-                                             outputs=['ref_in_'+attr_name],
-                                             value=onnx.helper.make_tensor(name='const_tensor__1_'+attr_name,
-                                                                           data_type=onnx.TensorProto.INT64,
-                                                                           dims=ref_array.shape,
-                                                                           vals=ref_array.flatten().astype(int)))
+            ref_node = onnx.helper.make_node(
+                "Constant",
+                inputs=[],
+                outputs=["ref_in_" + attr_name],
+                value=onnx.helper.make_tensor(
+                    name="const_tensor__1_" + attr_name,
+                    data_type=onnx.TensorProto.INT64,
+                    dims=ref_array.shape,
+                    vals=ref_array.flatten().astype(int),
+                ),
+            )
             in_shape = np.array(attr).shape
             in_array = np.array(in_shape)
-            ref_node2 = onnx.helper.make_node('Constant',
-                                              inputs=[],
-                                              outputs=['input_shape_'+attr_name],
-                                              value=onnx.helper.make_tensor(name='const_tensor__2_'+attr_name,
-                                                                            data_type=onnx.TensorProto.INT64,
-                                                                            dims=in_array.shape,
-                                                                            vals=in_array.flatten().astype(int)))
-
-            reshape1_node = helper.make_node("Reshape", [attr_name, "ref_in_"+attr_name], ["reshape_"+attr_name])
-            reshape2_node = helper.make_node("Reshape", ["reshape_"+attr_name, "input_shape_"+attr_name], [output_name])
+            ref_node2 = onnx.helper.make_node(
+                "Constant",
+                inputs=[],
+                outputs=["input_shape_" + attr_name],
+                value=onnx.helper.make_tensor(
+                    name="const_tensor__2_" + attr_name,
+                    data_type=onnx.TensorProto.INT64,
+                    dims=in_array.shape,
+                    vals=in_array.flatten().astype(int),
+                ),
+            )
+
+            reshape1_node = helper.make_node(
+                "Reshape", [attr_name, "ref_in_" + attr_name], ["reshape_" + attr_name]
+            )
+            reshape2_node = helper.make_node(
+                "Reshape", ["reshape_" + attr_name, "input_shape_" + attr_name], [output_name]
+            )
             return [ref_node, ref_node2, reshape1_node, reshape2_node]
 
     slice_inputs = []
@@ -546,34 +562,22 @@ def _test_slice_iteration_v10(indata, outdata, **attrs):
 
     if axes:
         axes = np.asarray(axes)
-        inputs.append(
-            helper.make_tensor_value_info("axes", TensorProto.INT32,
-                                          list(axes.shape)))
-        initializer.append(
-            helper.make_tensor("axes", TensorProto.INT32, list(axes.shape),
-                               axes))
+        inputs.append(helper.make_tensor_value_info("axes", TensorProto.INT32, list(axes.shape)))
+        initializer.append(helper.make_tensor("axes", TensorProto.INT32, list(axes.shape), axes))
     y = helper.make_node("Slice", ["data", *slice_inputs], ["out"])
 
     nodes.append(y)
-    graph = helper.make_graph(nodes,
-                              'slice_test',
-                              inputs=inputs,
-                              outputs=[
-                                  helper.make_tensor_value_info(
-                                      "out", TensorProto.FLOAT,
-                                      list(outdata.shape))
-                              ],
-                              initializer=initializer)
-    model = helper.make_model(graph, producer_name='slice_test')
+    graph = helper.make_graph(
+        nodes,
+        "slice_test",
+        inputs=inputs,
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))],
+        initializer=initializer,
+    )
+    model = helper.make_model(graph, producer_name="slice_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model,
-                                 indata,
-                                 target,
-                                 ctx,
-                                 outdata.shape,
-                                 'float32',
-                                 opset=10)
+        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, "float32", opset=10)
 
     tvm.testing.assert_allclose(outdata, tvm_out)
 
@@ -589,123 +593,158 @@ def test_slice():
     _test_slice_iteration_v10(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4))
     _test_slice_iteration_v10(x, x[:, 1:1000], starts=(1,), ends=(1000,), axes=(1,))
     _test_slice_iteration_v10(x, x[:, 0:-1], starts=(0,), ends=(-1,), axes=(1,))
-    _test_slice_iteration_v10(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1), add_noop_to_input_attrs=["starts"])
-    _test_slice_iteration_v10(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4), add_noop_to_input_attrs=["ends"])
-    _test_slice_iteration_v10(x, x[:, 1:1000], starts=(1,), ends=(1000,), axes=(1,), add_noop_to_input_attrs=["axes"])
-    _test_slice_iteration_v10(x, x[:, 0:-1], starts=(0,), ends=(-1,), axes=(1,), add_noop_to_input_attrs=["starts", "ends"])
-    _test_slice_iteration_v10(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1), add_noop_to_input_attrs=["ends", "axes"])
-    _test_slice_iteration_v10(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4), add_noop_to_input_attrs=["starts", "axes"])
-    _test_slice_iteration_v10(x, x[:, 1:1000], starts=(1,), ends=(1000,), axes=(1,), add_noop_to_input_attrs=["starts", "ends", "axes"])
+    _test_slice_iteration_v10(
+        x,
+        x[0:3, 0:10],
+        starts=(0, 0),
+        ends=(3, 10),
+        axes=(0, 1),
+        add_noop_to_input_attrs=["starts"],
+    )
+    _test_slice_iteration_v10(
+        x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4), add_noop_to_input_attrs=["ends"]
+    )
+    _test_slice_iteration_v10(
+        x, x[:, 1:1000], starts=(1,), ends=(1000,), axes=(1,), add_noop_to_input_attrs=["axes"]
+    )
+    _test_slice_iteration_v10(
+        x,
+        x[:, 0:-1],
+        starts=(0,),
+        ends=(-1,),
+        axes=(1,),
+        add_noop_to_input_attrs=["starts", "ends"],
+    )
+    _test_slice_iteration_v10(
+        x,
+        x[0:3, 0:10],
+        starts=(0, 0),
+        ends=(3, 10),
+        axes=(0, 1),
+        add_noop_to_input_attrs=["ends", "axes"],
+    )
+    _test_slice_iteration_v10(
+        x,
+        x[:, :, 3:4],
+        starts=(0, 0, 3),
+        ends=(20, 10, 4),
+        add_noop_to_input_attrs=["starts", "axes"],
+    )
+    _test_slice_iteration_v10(
+        x,
+        x[:, 1:1000],
+        starts=(1,),
+        ends=(1000,),
+        axes=(1,),
+        add_noop_to_input_attrs=["starts", "ends", "axes"],
+    )
     x = np.random.randn(1, 1, 1, 128).astype(np.float32)
-    _test_slice_iteration_v10(x, x, starts=(0, 0), ends=(9223372036854775807, 9223372036854775807), axes=(0, 3))
+    _test_slice_iteration_v10(
+        x, x, starts=(0, 0), ends=(9223372036854775807, 9223372036854775807), axes=(0, 3)
+    )
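
# Illustrative sketch (not from the patch): the 9223372036854775807 passed as
# "ends" in the last case above is INT64_MAX, which ONNX Slice producers commonly
# use to mean "run to the end of the axis", so the expected output equals the
# unsliced input.  A minimal numpy analogue of that identity slice:
import numpy as np

x = np.random.randn(1, 1, 1, 128).astype(np.float32)
starts, ends, axes = (0, 0), (9223372036854775807, 9223372036854775807), (0, 3)
slices = [slice(None)] * x.ndim
for s, e, a in zip(starts, ends, axes):
    slices[a] = slice(s, min(e, x.shape[a]))  # clip INT64_MAX to the axis length
assert np.array_equal(x[tuple(slices)], x)  # identity slice, as the test expects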
 
 
 def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):
     indata = np.random.uniform(-1, 1, size=inshape).astype(dtype)
     outdata = outfunc(indata, **npargs)
 
-    y = helper.make_node(opname, ['in'], ['out'], **kwargs)
+    y = helper.make_node(opname, ["in"], ["out"], **kwargs)
 
-    graph = helper.make_graph([y],
-                              opname+'_test',
-                              inputs=[helper.make_tensor_value_info("in",
-                                                                    TensorProto.FLOAT, list(indata.shape))],
-                              outputs=[helper.make_tensor_value_info("out",
-                                                                     TensorProto.FLOAT, list(outdata.shape))])
+    graph = helper.make_graph(
+        [y],
+        opname + "_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))],
+    )
 
-    model = helper.make_model(graph, producer_name=opname+'_test')
+    model = helper.make_model(graph, producer_name=opname + "_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, indata, target, ctx, outdata.shape, dtype)
+        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype)
 
     tvm.testing.assert_allclose(outdata, tvm_out)
 
 
 @tvm.testing.uses_gpu
 def test_floor():
-    _test_onnx_op_elementwise((2, 4, 5, 6), np.floor,
-                              {}, 'float32', 'Floor', {})
+    _test_onnx_op_elementwise((2, 4, 5, 6), np.floor, {}, "float32", "Floor", {})
 
 
 @tvm.testing.uses_gpu
 def test_ceil():
-    _test_onnx_op_elementwise((2, 4, 5, 6), np.ceil, {}, 'float32', 'Ceil', {})
+    _test_onnx_op_elementwise((2, 4, 5, 6), np.ceil, {}, "float32", "Ceil", {})
 
 
 @tvm.testing.uses_gpu
 def test_clip():
-    _test_onnx_op_elementwise((2, 4, 5, 6),
-                              np.clip,
-                              {'a_min': -1.0, 'a_max': 1.0},
-                              'float32',
-                              'Clip',
-                              {'min': -1.0, 'max': 1.0})
+    _test_onnx_op_elementwise(
+        (2, 4, 5, 6),
+        np.clip,
+        {"a_min": -1.0, "a_max": 1.0},
+        "float32",
+        "Clip",
+        {"min": -1.0, "max": 1.0},
+    )
 
 
 @tvm.testing.uses_gpu
 def test_clip_min_max_as_inputs():
-    input_shape=(2,4,5,6)
+    input_shape = (2, 4, 5, 6)
     nodes = [
-        make_constant_node('min', onnx.TensorProto.FLOAT, (), [0.]),
-        make_constant_node('max', onnx.TensorProto.FLOAT, (), [6.]),
+        make_constant_node("min", onnx.TensorProto.FLOAT, (), [0.0]),
+        make_constant_node("max", onnx.TensorProto.FLOAT, (), [6.0]),
     ]
-    input_names = ['in', 'min', 'max']
-    nodes.append(helper.make_node(
-        'Clip',
-        inputs=input_names,
-        outputs=['out']))
-    graph = helper.make_graph(nodes,
-                              "clip_test",
-                              inputs=[helper.make_tensor_value_info("in",
-                                                                    TensorProto.FLOAT, list(input_shape))],
-                              outputs=[helper.make_tensor_value_info("out",
-                                                                     TensorProto.FLOAT, list(input_shape))])
-    model = helper.make_model(graph, producer_name='clip_test')
-
-    indata = np.random.uniform(-1, 7, size=input_shape).astype('float32')
-    onnx_out = get_onnxruntime_output(model, indata, 'float32')
+    input_names = ["in", "min", "max"]
+    nodes.append(helper.make_node("Clip", inputs=input_names, outputs=["out"]))
+    graph = helper.make_graph(
+        nodes,
+        "clip_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(input_shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(input_shape))],
+    )
+    model = helper.make_model(graph, producer_name="clip_test")
+
+    indata = np.random.uniform(-1, 7, size=input_shape).astype("float32")
+    onnx_out = get_onnxruntime_output(model, indata, "float32")
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, indata, target, ctx, input_shape, 'float32')
+        tvm_out = get_tvm_output(model, indata, target, ctx, input_shape, "float32")
     tvm.testing.assert_allclose(onnx_out, tvm_out)
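
# Hedged aside: test_clip and test_clip_min_max_as_inputs cover the two ONNX
# encodings of Clip -- min/max as node attributes (older opsets) versus extra
# inputs (opset 11+).  Both boil down to the same numpy reference; the bounds
# below are the 0/6 constants of test_clip_min_max_as_inputs:
import numpy as np

x = np.random.uniform(-1, 7, size=(2, 4, 5, 6)).astype("float32")
ref = np.clip(x, 0.0, 6.0)  # equivalent of the "min"/"max" constant nodes above
assert ref.min() >= 0.0 and ref.max() <= 6.0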
 
 
 @tvm.testing.uses_gpu
 def test_round():
-    _test_onnx_op_elementwise((2, 4, 5, 6), np.round, {}, 'float32', 'Round', {})
+    _test_onnx_op_elementwise((2, 4, 5, 6), np.round, {}, "float32", "Round", {})
 
 
 def _test_finite_ops(inshape, outfunc, npargs, dtype, opname, kwargs):
     indata = np.random.choice(a=[np.nan, np.inf, -np.inf, 0.5, 1.0, 0], size=inshape).astype(dtype)
 
     outdata = outfunc(indata, **npargs)
-    y = helper.make_node(opname, ['in'], ['out'], **kwargs)
+    y = helper.make_node(opname, ["in"], ["out"], **kwargs)
 
-    graph = helper.make_graph([y],
-                              opname+'_test',
-                              inputs=[helper.make_tensor_value_info("in",
-                                                                    TensorProto.FLOAT, list(indata.shape))],
-                              outputs=[helper.make_tensor_value_info("out",
-                                                                     TensorProto.BOOL, list(outdata.shape))])
+    graph = helper.make_graph(
+        [y],
+        opname + "_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))],
+    )
 
-    model = helper.make_model(graph, producer_name=opname+'_test')
+    model = helper.make_model(graph, producer_name=opname + "_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, indata, target, ctx, outdata.shape, dtype)
+        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype)
 
     tvm.testing.assert_allclose(outdata, tvm_out)
 
 
 @tvm.testing.uses_gpu
 def test_isinf():
-    _test_finite_ops((2, 4, 5, 6), np.isinf, {}, 'float32', 'IsInf', {})
+    _test_finite_ops((2, 4, 5, 6), np.isinf, {}, "float32", "IsInf", {})
 
 
 @tvm.testing.uses_gpu
 def test_isnan():
-    _test_finite_ops((2, 4, 5, 6), np.isnan, {}, 'float32', 'IsNaN', {})
+    _test_finite_ops((2, 4, 5, 6), np.isnan, {}, "float32", "IsNaN", {})
 
 
 def verify_gather_nd(in_shape, indices, dtype):
@@ -713,60 +752,60 @@ def verify_gather_nd(in_shape, indices, dtype):
     indices = np.array(indices, dtype="int32")
     out_np = tvm.topi.testing.gather_nd_python(x, indices)
 
-    y = helper.make_node("GatherND", ['in', 'indices'], ['out'])
+    y = helper.make_node("GatherND", ["in", "indices"], ["out"])
 
-    graph = helper.make_graph([y],
-                              'gather_test',
-                              inputs=[helper.make_tensor_value_info("in",
-                                                                    TensorProto.FLOAT, list(in_shape)),
-                                      helper.make_tensor_value_info("indices",
-                                                                    TensorProto.INT32, list(indices.shape))],
-                              outputs=[helper.make_tensor_value_info("out",
-                                                                     TensorProto.FLOAT, list(out_np.shape))])
-    model = helper.make_model(graph, producer_name='gather_test')
+    graph = helper.make_graph(
+        [y],
+        "gather_test",
+        inputs=[
+            helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape)),
+            helper.make_tensor_value_info("indices", TensorProto.INT32, list(indices.shape)),
+        ],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_np.shape))],
+    )
+    model = helper.make_model(graph, producer_name="gather_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, [x, indices], target, ctx, out_np.shape)
+        tvm_out = get_tvm_output(model, [x, indices], target, ctx, out_np.shape)
         tvm.testing.assert_allclose(out_np, tvm_out)
 
 
 @tvm.testing.uses_gpu
 def test_gather_nd():
-    verify_gather_nd((2, 2), [[0,0],[1,1]], 'int32')
-    verify_gather_nd((3, 3, 3), [[0,1],[1,0]] , 'float32')
-    verify_gather_nd((4, 3, 5, 6), [[2, 1, 0, 0]], 'float32')
+    verify_gather_nd((2, 2), [[0, 0], [1, 1]], "int32")
+    verify_gather_nd((3, 3, 3), [[0, 1], [1, 0]], "float32")
+    verify_gather_nd((4, 3, 5, 6), [[2, 1, 0, 0]], "float32")
 
 
 @tvm.testing.uses_gpu
 def test_onehot():
     indices_shape = [10]
-    indices_array = np.random.randint(
-        low=0, high=9, size=indices_shape, dtype='int32')
+    indices_array = np.random.randint(low=0, high=9, size=indices_shape, dtype="int32")
     depth = 10
     values = np.asarray([0, 1])
     out_np = np.eye(depth)[indices_array.reshape(-1)]
 
-    onehot_node = helper.make_node(
-        "OneHot", ["indices", "depth", "values"], ["out"])
-
-    graph = helper.make_graph([onehot_node],
-                              "onehot_test",
-                              inputs=[helper.make_tensor_value_info("indices",
-                                                                    TensorProto.INT32, indices_shape),
-                                      helper.make_tensor_value_info("depth",
-                                                                    TensorProto.INT32, [1]),
-                                      helper.make_tensor_value_info("values",
-                                                                    TensorProto.INT32, values.shape)],
-                              initializer=[helper.make_tensor("depth", TensorProto.INT32, [1], [depth]),
-                                           helper.make_tensor("values", TensorProto.INT32, values.shape, values)],
-                              outputs=[helper.make_tensor_value_info("out", TensorProto.INT32, out_np.shape)])
+    onehot_node = helper.make_node("OneHot", ["indices", "depth", "values"], ["out"])
+
+    graph = helper.make_graph(
+        [onehot_node],
+        "onehot_test",
+        inputs=[
+            helper.make_tensor_value_info("indices", TensorProto.INT32, indices_shape),
+            helper.make_tensor_value_info("depth", TensorProto.INT32, [1]),
+            helper.make_tensor_value_info("values", TensorProto.INT32, values.shape),
+        ],
+        initializer=[
+            helper.make_tensor("depth", TensorProto.INT32, [1], [depth]),
+            helper.make_tensor("values", TensorProto.INT32, values.shape, values),
+        ],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.INT32, out_np.shape)],
+    )
 
     model = helper.make_model(graph, producer_name="onehot_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, [indices_array], target, ctx, out_np.shape)
+        tvm_out = get_tvm_output(model, [indices_array], target, ctx, out_np.shape)
         tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
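
# Illustrative sketch (depth=4 is made up, not a value from the test):
# np.eye(depth)[indices], used to build out_np above, is a compact one-hot
# encoding -- row i of the identity matrix is the one-hot vector for class i.
import numpy as np

indices = np.array([2, 0, 3], dtype="int32")
one_hot = np.eye(4)[indices]
# one_hot == [[0, 0, 1, 0],
#             [1, 0, 0, 0],
#             [0, 0, 0, 1]]
assert one_hot.shape == (3, 4) and one_hot[0, 2] == 1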
 
 
@@ -775,57 +814,60 @@ def test_matmul():
     a_shape = (4, 3)
     b_shape = (3, 4)
 
-    a_array = np.random.uniform(size=a_shape).astype('float32')
-    b_array = np.random.uniform(size=b_shape).astype('float32')
+    a_array = np.random.uniform(size=a_shape).astype("float32")
+    b_array = np.random.uniform(size=b_shape).astype("float32")
     out_np = np.matmul(a_array, b_array)
 
     mul_node = helper.make_node("MatMul", ["a", "b"], ["out"])
 
-    graph = helper.make_graph([mul_node],
-                              "matmul_test",
-                              inputs=[helper.make_tensor_value_info("a",
-                                                                    TensorProto.FLOAT, list(a_shape)),
-                                      helper.make_tensor_value_info("b",
-                                                                    TensorProto.FLOAT, list(b_shape))],
-                              outputs=[helper.make_tensor_value_info("out",
-                                                                     TensorProto.FLOAT, list(out_np.shape))])
+    graph = helper.make_graph(
+        [mul_node],
+        "matmul_test",
+        inputs=[
+            helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)),
+            helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)),
+        ],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_np.shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='matmul_test')
+    model = helper.make_model(graph, producer_name="matmul_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, [a_array, b_array], target, ctx, out_np.shape)
+        tvm_out = get_tvm_output(model, [a_array, b_array], target, ctx, out_np.shape)
         tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
 
+
 def verify_batch_matmul(a_shape, b_shape):
-    a_array = np.random.uniform(size=a_shape).astype('float32')
-    b_array = np.random.uniform(size=b_shape).astype('float32')
+    a_array = np.random.uniform(size=a_shape).astype("float32")
+    b_array = np.random.uniform(size=b_shape).astype("float32")
     out_np = np.matmul(a_array, b_array)
 
     mul_node = helper.make_node("MatMul", ["a", "b"], ["out"])
 
-    graph = helper.make_graph([mul_node],
-                              "matmul_test",
-                              inputs=[helper.make_tensor_value_info("a",
-                                                                    TensorProto.FLOAT, list(a_shape)),
-                                      helper.make_tensor_value_info("b",
-                                                                    TensorProto.FLOAT, list(b_shape))],
-                              outputs=[helper.make_tensor_value_info("out",
-                                                                     TensorProto.FLOAT, list(out_np.shape))])
+    graph = helper.make_graph(
+        [mul_node],
+        "matmul_test",
+        inputs=[
+            helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)),
+            helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)),
+        ],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_np.shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='matmul_test')
+    model = helper.make_model(graph, producer_name="matmul_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, [a_array, b_array], target, ctx, out_np.shape)
+        tvm_out = get_tvm_output(model, [a_array, b_array], target, ctx, out_np.shape)
         tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_batch_matmul():
     verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4))
     verify_batch_matmul((2, 4, 3), (3, 4))
     verify_batch_matmul((2, 3, 4, 3), (3, 4))
 
+
 def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):
     in_array = np.random.uniform(size=shape).astype(dtype)
 
@@ -833,46 +875,51 @@ def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):
         alpha = 0.0001
         beta = 0.75
         bias = 1.0
-        node = onnx.helper.make_node(
-            'LRN', inputs=['in'], outputs=['out'], size=nsize)
+        node = onnx.helper.make_node("LRN", inputs=["in"], outputs=["out"], size=nsize)
     else:
-        node = onnx.helper.make_node('LRN', inputs=['in'], outputs=['out'], alpha=alpha,
-                                     beta=beta, bias=bias, size=nsize)
+        node = onnx.helper.make_node(
+            "LRN", inputs=["in"], outputs=["out"], alpha=alpha, beta=beta, bias=bias, size=nsize
+        )
 
-    graph = helper.make_graph([node],
-                              "lrn_test",
-                              inputs=[helper.make_tensor_value_info(
-                                  "in", TensorProto.FLOAT, list(shape))],
-                              outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape))])
-    model = helper.make_model(graph, producer_name='lrn_test')
+    graph = helper.make_graph(
+        [node],
+        "lrn_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape))],
+    )
+    model = helper.make_model(graph, producer_name="lrn_test")
 
     def _get_python_lrn():
         square_sum = np.zeros(shape).astype(dtype)
         for n, c, h, w in np.ndindex(in_array.shape):
-            square_sum[n, c, h, w] = sum(in_array[n,
-                                                  max(0, c - int(math.floor((nsize - 1) / 2))):
-                                                  min(5, c + int(math.ceil((nsize - 1) / 2)) + 1),
-                                                  h,
-                                                  w] ** 2)
+            square_sum[n, c, h, w] = sum(
+                in_array[
+                    n,
+                    max(0, c - int(math.floor((nsize - 1) / 2))) : min(
+                        5, c + int(math.ceil((nsize - 1) / 2)) + 1
+                    ),
+                    h,
+                    w,
+                ]
+                ** 2
+            )
         py_out = in_array / ((bias + (alpha / nsize) * square_sum) ** beta)
         return py_out
 
     for target, ctx in tvm.testing.enabled_targets():
         input_name = model.graph.input[0].name
         py_out = _get_python_lrn()
-        tvm_out = get_tvm_output(
-            model, in_array, target, ctx, py_out.shape, 'float32')
+        tvm_out = get_tvm_output(model, in_array, target, ctx, py_out.shape, "float32")
         tvm.testing.assert_allclose(py_out, tvm_out, rtol=1e-5, atol=1e-5)
 
 
 @tvm.testing.uses_gpu
 def test_lrn():
-    verify_lrn((5, 5, 5, 5), 3, 'float32')
-    verify_lrn((5, 5, 5, 5), 3, 'float32', alpha=0.0002, beta=0.5, bias=2.0)
+    verify_lrn((5, 5, 5, 5), 3, "float32")
+    verify_lrn((5, 5, 5, 5), 3, "float32", alpha=0.0002, beta=0.5, bias=2.0)
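
# Hedged aside: _get_python_lrn above implements the LRN definition
#     y = x / (bias + (alpha / nsize) * sum_window(x ** 2)) ** beta
# with the window taken over neighbouring channels.  A self-contained check of
# the channel window for one output element (shapes and indices are made up):
import math
import numpy as np

x = np.random.uniform(size=(1, 5, 2, 2)).astype("float32")
nsize, alpha, beta, bias = 3, 0.0001, 0.75, 1.0
n, c, h, w = 0, 2, 0, 0
lo = max(0, c - int(math.floor((nsize - 1) / 2)))
hi = min(x.shape[1], c + int(math.ceil((nsize - 1) / 2)) + 1)
square_sum = np.sum(x[n, lo:hi, h, w] ** 2)
y = x[n, c, h, w] / ((bias + (alpha / nsize) * square_sum) ** beta)
assert np.isfinite(y)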
 
 
 def verify_instance_norm(shape, axis=1):
-
     def _get_python_instance_norm(x, gamma, beta, epsilon=1e-5):
         dims_x = len(x.shape)
         axis = tuple(range(2, dims_x))
@@ -890,22 +937,24 @@ def verify_instance_norm(shape, axis=1):
     y = _get_python_instance_norm(x, gamma, beta, epsilon).astype(np.float32)
 
     node = onnx.helper.make_node(
-        'InstanceNormalization',
-        inputs=['x', 'gamma', 'beta'],
-        outputs=['y'],
+        "InstanceNormalization",
+        inputs=["x", "gamma", "beta"],
+        outputs=["y"],
         epsilon=epsilon,
     )
-    graph = helper.make_graph([node],
-                              "instance_norm_test",
-                              inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(shape)),
-                                      helper.make_tensor_value_info(
-                                          "gamma", TensorProto.FLOAT, (shape[1],)),
-                                      helper.make_tensor_value_info("beta", TensorProto.FLOAT, (shape[1],))],
-                              outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(shape))])
-    model = helper.make_model(graph, producer_name='instance_norm_test')
+    graph = helper.make_graph(
+        [node],
+        "instance_norm_test",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, list(shape)),
+            helper.make_tensor_value_info("gamma", TensorProto.FLOAT, (shape[1],)),
+            helper.make_tensor_value_info("beta", TensorProto.FLOAT, (shape[1],)),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(shape))],
+    )
+    model = helper.make_model(graph, producer_name="instance_norm_test")
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, [x, gamma, beta], target, ctx, shape, 'float32')
+        tvm_out = get_tvm_output(model, [x, gamma, beta], target, ctx, shape, "float32")
         tvm.testing.assert_allclose(y, tvm_out, rtol=1e-5, atol=1e-5)
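
# Minimal sketch (shapes chosen for illustration only): _get_python_instance_norm
# normalizes every (sample, channel) slice over its spatial axes,
#     y = gamma * (x - mean) / sqrt(var + eps) + beta,
# with gamma/beta broadcast along the channel axis.
import numpy as np

x = np.random.randn(2, 3, 4, 5).astype(np.float32)
gamma = np.random.randn(3).astype(np.float32)
beta = np.random.randn(3).astype(np.float32)
eps = 1e-5
spatial = tuple(range(2, x.ndim))
mean = x.mean(axis=spatial, keepdims=True)
var = x.var(axis=spatial, keepdims=True)
y = gamma.reshape(1, -1, 1, 1) * (x - mean) / np.sqrt(var + eps) + beta.reshape(1, -1, 1, 1)
assert y.shape == x.shape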
 
 
@@ -920,145 +969,152 @@ def test_instance_norm():
 def _test_upsample_nearest():
     scale = 2
     in_shape = (1, 1, 3, 3)
-    out_shape = (1, 1, 3*scale, 3*scale)
-    y = helper.make_node("Upsample", ['in'], [
-                         'out'], mode='nearest', scales=[1.0, 1.0, 2.0, 2.0])
+    out_shape = (1, 1, 3 * scale, 3 * scale)
+    y = helper.make_node("Upsample", ["in"], ["out"], mode="nearest", scales=[1.0, 1.0, 2.0, 2.0])
 
     in_array = np.random.uniform(size=in_shape).astype(np.float32)
-    out_array = tvm.topi.testing.upsampling_python(
-        in_array, (scale, scale), "NCHW")
+    out_array = tvm.topi.testing.upsampling_python(in_array, (scale, scale), "NCHW")
 
-    graph = helper.make_graph([y],
-                              'upsample_nearest_test',
-                              inputs=[helper.make_tensor_value_info(
-                                  "in", TensorProto.FLOAT, list(in_shape))],
-                              outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
+    graph = helper.make_graph(
+        [y],
+        "upsample_nearest_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='upsample_nearest_test')
+    model = helper.make_model(graph, producer_name="upsample_nearest_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, in_array, target, ctx, out_shape, 'float32')
+        tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, "float32")
         tvm.testing.assert_allclose(out_array, tvm_out)
 
 
 def _test_upsample3d_nearest():
     scale = 2
     in_shape = (1, 1, 3, 3, 3)
-    out_shape = (1, 1, 3*scale, 3*scale, 3*scale)
-    y = helper.make_node("Upsample", ['in'], [
-                         'out'], mode='nearest', scales=[1.0, 1.0, 2.0, 2.0, 2.0])
+    out_shape = (1, 1, 3 * scale, 3 * scale, 3 * scale)
+    y = helper.make_node(
+        "Upsample", ["in"], ["out"], mode="nearest", scales=[1.0, 1.0, 2.0, 2.0, 2.0]
+    )
 
     in_array = np.random.uniform(size=in_shape).astype(np.float32)
-    out_array = tvm.topi.testing.upsampling3d_python(
-        in_array, (scale, scale, scale), "NCDHW")
+    out_array = tvm.topi.testing.upsampling3d_python(in_array, (scale, scale, scale), "NCDHW")
 
-    graph = helper.make_graph([y],
-                              'upsample_nearest_test',
-                              inputs=[helper.make_tensor_value_info(
-                                  "in", TensorProto.FLOAT, list(in_shape))],
-                              outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
+    graph = helper.make_graph(
+        [y],
+        "upsample_nearest_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='upsample_nearest_test')
+    model = helper.make_model(graph, producer_name="upsample_nearest_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, in_array, target, ctx, out_shape, 'float32')
+        tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, "float32")
         tvm.testing.assert_allclose(out_array, tvm_out)
 
+
 def _test_upsample_bilinear():
     scale = 2
     in_shape = (1, 1, 3, 3)
-    out_shape = (1, 1, 3*scale, 3*scale)
-    y = helper.make_node("Upsample", ['in'], [
-                         'out'], mode='linear', scales=[1.0, 1.0, 2.0, 2.0])
+    out_shape = (1, 1, 3 * scale, 3 * scale)
+    y = helper.make_node("Upsample", ["in"], ["out"], mode="linear", scales=[1.0, 1.0, 2.0, 2.0])
 
     in_array = np.random.uniform(size=in_shape).astype(np.float32)
-    out_array = tvm.topi.testing.bilinear_resize_python(
-        in_array, (3*scale, 3*scale), "NCHW")
+    out_array = tvm.topi.testing.bilinear_resize_python(in_array, (3 * scale, 3 * scale), "NCHW")
 
-    graph = helper.make_graph([y],
-                              'upsample_bilinear_test',
-                              inputs=[helper.make_tensor_value_info(
-                                  "in", TensorProto.FLOAT, list(in_shape))],
-                              outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
+    graph = helper.make_graph(
+        [y],
+        "upsample_bilinear_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='upsample_bilinear_test')
+    model = helper.make_model(graph, producer_name="upsample_bilinear_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, in_array, target, ctx, out_shape, 'float32')
+        tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, "float32")
         tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5)
 
 
 def _test_upsample_bilinear_opset9():
     scale = 2
     in_shape = (1, 1, 3, 3)
-    out_shape = (1, 1, 3*scale, 3*scale)
-    y = helper.make_node("Upsample", ['in', 'scales'], ['out'], mode='linear')
+    out_shape = (1, 1, 3 * scale, 3 * scale)
+    y = helper.make_node("Upsample", ["in", "scales"], ["out"], mode="linear")
     scales = [1, 1, 2, 2]
     in_array = np.random.uniform(size=in_shape).astype(np.float32)
-    out_array = tvm.topi.testing.bilinear_resize_python(
-        in_array, (3*scale, 3*scale), "NCHW")
-
-    ref_node = helper.make_node('Constant',
-                                inputs=[],
-                                outputs=['const'],
-                                value=onnx.helper.make_tensor(name='const_tensor',
-                                                              data_type=TensorProto.FLOAT,
-                                                              dims=scales,
-                                                              vals=np.random.random(scales).flatten().astype(float)))
+    out_array = tvm.topi.testing.bilinear_resize_python(in_array, (3 * scale, 3 * scale), "NCHW")
+
+    ref_node = helper.make_node(
+        "Constant",
+        inputs=[],
+        outputs=["const"],
+        value=onnx.helper.make_tensor(
+            name="const_tensor",
+            data_type=TensorProto.FLOAT,
+            dims=scales,
+            vals=np.random.random(scales).flatten().astype(float),
+        ),
+    )
 
-    shape_node = helper.make_node("Shape", ['const'], ['scales'])
+    shape_node = helper.make_node("Shape", ["const"], ["scales"])
 
-    graph = helper.make_graph([ref_node, shape_node, y],
-                              'upsample_bilinear_opset9_test',
-                              inputs=[helper.make_tensor_value_info(
-                                  "in", TensorProto.FLOAT, list(in_shape))],
-                              outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
+    graph = helper.make_graph(
+        [ref_node, shape_node, y],
+        "upsample_bilinear_opset9_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))],
+    )
 
-    model = helper.make_model(
-        graph, producer_name='upsample_bilinear_opset9_test')
+    model = helper.make_model(graph, producer_name="upsample_bilinear_opset9_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, in_array, target, ctx, out_shape, 'float32')
+        tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, "float32")
         tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5)
 
 
 def _test_upsample3d_trilinear():
     scale = 2
     in_shape = (1, 1, 3, 3, 3)
-    out_shape = (1, 1, 3*scale, 3*scale, 3*scale)
-    y = helper.make_node("Upsample", ['in', 'scales'], ['out'], mode='linear')
+    out_shape = (1, 1, 3 * scale, 3 * scale, 3 * scale)
+    y = helper.make_node("Upsample", ["in", "scales"], ["out"], mode="linear")
     scales = [1.0, 1.0, 2.0, 2.0, 2.0]
     in_array = np.random.uniform(size=in_shape).astype(np.float32)
     out_array = tvm.topi.testing.trilinear_resize3d_python(
-        in_array, (3*scale, 3*scale, 3*scale), "NCDHW", coordinate_transformation_mode="half_pixel")
+        in_array,
+        (3 * scale, 3 * scale, 3 * scale),
+        "NCDHW",
+        coordinate_transformation_mode="half_pixel",
+    )
 
     ref_array = np.array(scales)
-    ref_node = helper.make_node('Constant',
-                                inputs=[],
-                                outputs=['scales'],
-                                value=onnx.helper.make_tensor(name='const_tensor',
-                                                              data_type=TensorProto.FLOAT,
-                                                              dims=ref_array.shape,
-                                                              vals=ref_array.flatten().astype(float)))
-
-    graph = helper.make_graph([ref_node, y],
-                              'upsample_trilinear_test',
-                              inputs=[helper.make_tensor_value_info(
-                                  "in", TensorProto.FLOAT, list(in_shape))],
-                              outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
-
-    model = helper.make_model(
-        graph, producer_name='upsample_trilinear_test')
+    ref_node = helper.make_node(
+        "Constant",
+        inputs=[],
+        outputs=["scales"],
+        value=onnx.helper.make_tensor(
+            name="const_tensor",
+            data_type=TensorProto.FLOAT,
+            dims=ref_array.shape,
+            vals=ref_array.flatten().astype(float),
+        ),
+    )
+
+    graph = helper.make_graph(
+        [ref_node, y],
+        "upsample_trilinear_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))],
+    )
+
+    model = helper.make_model(graph, producer_name="upsample_trilinear_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, in_array, target, ctx, out_shape, 'float32')
+        tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, "float32")
         tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_upsample():
     _test_upsample_nearest()
@@ -1067,28 +1123,28 @@ def test_upsample():
     _test_upsample3d_nearest()
     _test_upsample3d_trilinear()
 
+
 def _test_softmax(inshape, axis):
-    opname = 'Softmax'
+    opname = "Softmax"
     indata = np.random.uniform(size=inshape).astype(np.float32)
     outshape = inshape
     outdata = tvm.topi.testing.softmax_python(indata)
     if isinstance(axis, int):
-        y = helper.make_node(opname, ['in'], ['out'], axis=axis)
+        y = helper.make_node(opname, ["in"], ["out"], axis=axis)
     elif axis is None:
-        y = helper.make_node(opname, ['in'], ['out'])
+        y = helper.make_node(opname, ["in"], ["out"])
 
-    graph = helper.make_graph([y],
-                              opname+'_test',
-                              inputs=[helper.make_tensor_value_info("in",
-                                                                    TensorProto.FLOAT, list(indata.shape))],
-                              outputs=[helper.make_tensor_value_info("out",
-                                                                     TensorProto.FLOAT, list(outdata.shape))])
+    graph = helper.make_graph(
+        [y],
+        opname + "_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))],
+    )
 
-    model = helper.make_model(graph, producer_name=opname+'_test')
+    model = helper.make_model(graph, producer_name=opname + "_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, indata, target, ctx, outshape, 'float32')
+        tvm_out = get_tvm_output(model, indata, target, ctx, outshape, "float32")
         tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
 
 
@@ -1099,7 +1155,7 @@ def test_softmax():
 
 
 def verify_min(input_dim):
-    dtype = 'float32'
+    dtype = "float32"
 
     a_np1 = np.random.uniform(size=input_dim).astype(dtype)
     a_np2 = np.random.uniform(size=input_dim).astype(dtype)
@@ -1109,22 +1165,21 @@ def verify_min(input_dim):
 
     min_node = helper.make_node("Min", ["a_np1", "a_np2", "a_np3"], ["out"])
 
-    graph = helper.make_graph([min_node],
-                              "Min_test",
-                              inputs=[helper.make_tensor_value_info("a_np1",
-                                                                    TensorProto.FLOAT, list(input_dim)),
-                                      helper.make_tensor_value_info("a_np2",
-                                                                    TensorProto.FLOAT, list(input_dim)),
-                                      helper.make_tensor_value_info("a_np3",
-                                                                    TensorProto.FLOAT, list(input_dim))],
-                              outputs=[helper.make_tensor_value_info("out",
-                                                                     TensorProto.FLOAT, list(b_np.shape))])
+    graph = helper.make_graph(
+        [min_node],
+        "Min_test",
+        inputs=[
+            helper.make_tensor_value_info("a_np1", TensorProto.FLOAT, list(input_dim)),
+            helper.make_tensor_value_info("a_np2", TensorProto.FLOAT, list(input_dim)),
+            helper.make_tensor_value_info("a_np3", TensorProto.FLOAT, list(input_dim)),
+        ],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(b_np.shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='Min_test')
+    model = helper.make_model(graph, producer_name="Min_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
+        tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
         tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
 
 
@@ -1135,7 +1190,7 @@ def test_forward_min():
 
 
 def verify_max(input_dim):
-    dtype = 'float32'
+    dtype = "float32"
 
     a_np1 = np.random.uniform(size=input_dim).astype(dtype)
     a_np2 = np.random.uniform(size=input_dim).astype(dtype)
@@ -1145,22 +1200,21 @@ def verify_max(input_dim):
 
     max_node = helper.make_node("Max", ["a_np1", "a_np2", "a_np3"], ["out"])
 
-    graph = helper.make_graph([max_node],
-                              "Max_test",
-                              inputs=[helper.make_tensor_value_info("a_np1",
-                                                                    TensorProto.FLOAT, list(input_dim)),
-                                      helper.make_tensor_value_info("a_np2",
-                                                                    TensorProto.FLOAT, list(input_dim)),
-                                      helper.make_tensor_value_info("a_np3",
-                                                                    TensorProto.FLOAT, list(input_dim))],
-                              outputs=[helper.make_tensor_value_info("out",
-                                                                     TensorProto.FLOAT, list(b_np.shape))])
+    graph = helper.make_graph(
+        [max_node],
+        "Max_test",
+        inputs=[
+            helper.make_tensor_value_info("a_np1", TensorProto.FLOAT, list(input_dim)),
+            helper.make_tensor_value_info("a_np2", TensorProto.FLOAT, list(input_dim)),
+            helper.make_tensor_value_info("a_np3", TensorProto.FLOAT, list(input_dim)),
+        ],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(b_np.shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='Max_test')
+    model = helper.make_model(graph, producer_name="Max_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
+        tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
         tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
 
 
@@ -1171,7 +1225,7 @@ def test_forward_max():
 
 
 def verify_mean(input_dim):
-    dtype = 'float32'
+    dtype = "float32"
 
     a_np1 = np.random.uniform(size=input_dim).astype(dtype)
     a_np2 = np.random.uniform(size=input_dim).astype(dtype)
@@ -1181,22 +1235,21 @@ def verify_mean(input_dim):
 
     mean_node = helper.make_node("Mean", ["a_np1", "a_np2", "a_np3"], ["out"])
 
-    graph = helper.make_graph([mean_node],
-                              "Mean_test",
-                              inputs=[helper.make_tensor_value_info("a_np1",
-                                                                    TensorProto.FLOAT, list(input_dim)),
-                                      helper.make_tensor_value_info("a_np2",
-                                                                    TensorProto.FLOAT, list(input_dim)),
-                                      helper.make_tensor_value_info("a_np3",
-                                                                    TensorProto.FLOAT, list(input_dim))],
-                              outputs=[helper.make_tensor_value_info("out",
-                                                                     TensorProto.FLOAT, list(b_np.shape))])
+    graph = helper.make_graph(
+        [mean_node],
+        "Mean_test",
+        inputs=[
+            helper.make_tensor_value_info("a_np1", TensorProto.FLOAT, list(input_dim)),
+            helper.make_tensor_value_info("a_np2", TensorProto.FLOAT, list(input_dim)),
+            helper.make_tensor_value_info("a_np3", TensorProto.FLOAT, list(input_dim)),
+        ],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(b_np.shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='Mean_test')
+    model = helper.make_model(graph, producer_name="Mean_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
+        tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
         tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
 
 
@@ -1207,23 +1260,22 @@ def test_forward_mean():
 
 
 def verify_hardsigmoid(input_dim, alpha, beta):
-    dtype = 'float32'
+    dtype = "float32"
 
     a_np1 = np.random.uniform(size=input_dim).astype(dtype)
 
     b_np = np.clip(a_np1 * alpha + beta, 0, 1)
 
-    hardsigmoid_node = helper.make_node("HardSigmoid", ["a_np1"], [
-                                        "out"], alpha=alpha, beta=beta)
+    hardsigmoid_node = helper.make_node("HardSigmoid", ["a_np1"], ["out"], alpha=alpha, beta=beta)
 
-    graph = helper.make_graph([hardsigmoid_node],
-                              "HardSigmoid_test",
-                              inputs=[helper.make_tensor_value_info("a_np1",
-                                                                    TensorProto.FLOAT, list(input_dim))],
-                              outputs=[helper.make_tensor_value_info("out",
-                                                                     TensorProto.FLOAT, list(b_np.shape))])
+    graph = helper.make_graph(
+        [hardsigmoid_node],
+        "HardSigmoid_test",
+        inputs=[helper.make_tensor_value_info("a_np1", TensorProto.FLOAT, list(input_dim))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(b_np.shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='HardSigmoid_test')
+    model = helper.make_model(graph, producer_name="HardSigmoid_test")
 
     for target, ctx in tvm.testing.enabled_targets():
         tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape)
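
# Hedged aside: the reference b_np in verify_hardsigmoid is the usual
# HardSigmoid definition, y = clip(alpha * x + beta, 0, 1).  With the commonly
# used alpha=0.2, beta=0.5 (illustrative values, not necessarily those passed
# by the test driver):
import numpy as np

x = np.array([-3.0, 0.0, 3.0], dtype="float32")
y = np.clip(0.2 * x + 0.5, 0, 1)
# y == [0.0, 0.5, 1.0]
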
@@ -1239,101 +1291,79 @@ def test_forward_hardsigmoid():
 def verify_argmin(input_dim, axis=None, keepdims=None):
     def _argmin_numpy(data, axis=0, keepdims=True):
         result = np.argmin(data, axis=axis)
-        if (keepdims == 1):
+        if keepdims == 1:
             result = np.expand_dims(result, axis)
         return result.astype(data.dtype)
 
     a_np1 = np.random.uniform(-10, 10, input_dim).astype(np.int32)
     if keepdims is None and axis is None:
         b_np = _argmin_numpy(a_np1)
-        node = onnx.helper.make_node('ArgMin',
-                                     inputs=['a_np1'],
-                                     outputs=['out'])
+        node = onnx.helper.make_node("ArgMin", inputs=["a_np1"], outputs=["out"])
     elif axis is None:
         b_np = _argmin_numpy(a_np1, keepdims=keepdims)
-        node = onnx.helper.make_node('ArgMin',
-                                     inputs=['a_np1'],
-                                     outputs=['out'],
-                                     keepdims=keepdims)
+        node = onnx.helper.make_node("ArgMin", inputs=["a_np1"], outputs=["out"], keepdims=keepdims)
     elif keepdims is None:
         b_np = _argmin_numpy(a_np1, axis=axis)
-        node = onnx.helper.make_node('ArgMin',
-                                     inputs=['a_np1'],
-                                     outputs=['out'],
-                                     axis=axis)
+        node = onnx.helper.make_node("ArgMin", inputs=["a_np1"], outputs=["out"], axis=axis)
     else:
         b_np = _argmin_numpy(a_np1, axis=axis, keepdims=keepdims)
-        node = onnx.helper.make_node('ArgMin',
-                                     inputs=['a_np1'],
-                                     outputs=['out'],
-                                     axis=axis,
-                                     keepdims=keepdims)
-    graph = helper.make_graph([node],
-                              "argmin_test",
-                              inputs=[helper.make_tensor_value_info("a_np1",
-                                                                    TensorProto.INT32, list(a_np1.shape))],
-                              outputs=[helper.make_tensor_value_info("out",
-                                                                     TensorProto.INT32, list(b_np.shape))])
-
-    model = helper.make_model(graph, producer_name='argmin_test')
+        node = onnx.helper.make_node(
+            "ArgMin", inputs=["a_np1"], outputs=["out"], axis=axis, keepdims=keepdims
+        )
+    graph = helper.make_graph(
+        [node],
+        "argmin_test",
+        inputs=[helper.make_tensor_value_info("a_np1", TensorProto.INT32, list(a_np1.shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.INT32, list(b_np.shape))],
+    )
+
+    model = helper.make_model(graph, producer_name="argmin_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, [a_np1], target, ctx, b_np.shape, b_np.dtype)
+        tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape, b_np.dtype)
         tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
 
 
 def verify_argmax(input_dim, axis=None, keepdims=None):
     def _argmax_numpy(data, axis=0, keepdims=True):
         result = np.argmax(data, axis=axis)
-        if (keepdims == 1):
+        if keepdims == 1:
             result = np.expand_dims(result, axis)
         return result.astype(data.dtype)
 
     a_np1 = np.random.uniform(-10, 10, input_dim).astype(np.int32)
     if keepdims is None and axis is None:
         b_np = _argmax_numpy(a_np1)
-        node = onnx.helper.make_node('ArgMax',
-                                     inputs=['a_np1'],
-                                     outputs=['out'])
+        node = onnx.helper.make_node("ArgMax", inputs=["a_np1"], outputs=["out"])
     elif axis is None:
         b_np = _argmax_numpy(a_np1, keepdims=keepdims)
-        node = onnx.helper.make_node('ArgMax',
-                                     inputs=['a_np1'],
-                                     outputs=['out'],
-                                     keepdims=keepdims)
+        node = onnx.helper.make_node("ArgMax", inputs=["a_np1"], outputs=["out"], keepdims=keepdims)
     elif keepdims is None:
         b_np = _argmax_numpy(a_np1, axis=axis)
-        node = onnx.helper.make_node('ArgMax',
-                                     inputs=['a_np1'],
-                                     outputs=['out'],
-                                     axis=axis)
+        node = onnx.helper.make_node("ArgMax", inputs=["a_np1"], outputs=["out"], axis=axis)
     else:
         b_np = _argmax_numpy(a_np1, axis=axis, keepdims=keepdims)
-        node = onnx.helper.make_node('ArgMax',
-                                     inputs=['a_np1'],
-                                     outputs=['out'],
-                                     axis=axis,
-                                     keepdims=keepdims)
+        node = onnx.helper.make_node(
+            "ArgMax", inputs=["a_np1"], outputs=["out"], axis=axis, keepdims=keepdims
+        )
 
-    graph = helper.make_graph([node],
-                              "argmax_test",
-                              inputs=[helper.make_tensor_value_info("a_np1",
-                                                                    TensorProto.INT32, list(a_np1.shape))],
-                              outputs=[helper.make_tensor_value_info("out",
-                                                                     TensorProto.INT32, list(b_np.shape))])
+    graph = helper.make_graph(
+        [node],
+        "argmax_test",
+        inputs=[helper.make_tensor_value_info("a_np1", TensorProto.INT32, list(a_np1.shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.INT32, list(b_np.shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='argmax_test')
+    model = helper.make_model(graph, producer_name="argmax_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, [a_np1], target, ctx, b_np.shape, b_np.dtype)
+        tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape, b_np.dtype)
         tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
 
 
 @tvm.testing.uses_gpu
 def test_forward_arg_min_max():
-    '''Verify argmin and argmax'''
+    """Verify argmin and argmax"""
     verify_argmin([3, 4, 4])
     verify_argmax([3, 4, 4])
     verify_argmin([3, 4, 4], axis=1)
@@ -1350,30 +1380,26 @@ def verify_constantofshape(input_dim, value, dtype):
     out = np.empty(shape=input_dim, dtype=dtype)
     out.fill(value)
 
-    fill_node = helper.make_node("ConstantOfShape", ["input"], ["output"],
-                                 value=helper.make_tensor(
-                                     'value',
-                                     mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)],
-                                     (1, ), (value, )))
+    fill_node = helper.make_node(
+        "ConstantOfShape",
+        ["input"],
+        ["output"],
+        value=helper.make_tensor(
+            "value", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], (1,), (value,)
+        ),
+    )
 
-    inputs = [
-        helper.make_tensor_value_info("input", TensorProto.FLOAT, input_dim)
-    ]
+    inputs = [helper.make_tensor_value_info("input", TensorProto.FLOAT, input_dim)]
 
     graph = helper.make_graph(
         [fill_node],
         "fill_test",
         inputs,
-        outputs=[
-            helper.make_tensor_value_info("output", TensorProto.FLOAT,
-                                          list(out.shape))
-        ],
-        initializer=[
-            helper.make_tensor("input", TensorProto.INT32, (len(input_dim), ),
-                               input_dim)
-        ])
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(out.shape))],
+        initializer=[helper.make_tensor("input", TensorProto.INT32, (len(input_dim),), input_dim)],
+    )
 
-    model = helper.make_model(graph, producer_name='fill_test')
+    model = helper.make_model(graph, producer_name="fill_test")
 
     for target, ctx in tvm.testing.enabled_targets():
         tvm_out = get_tvm_output(model, [], target, ctx, out.shape)
@@ -1383,130 +1409,111 @@ def verify_constantofshape(input_dim, value, dtype):
 
 @tvm.testing.uses_gpu
 def test_constantofshape():
-    verify_constantofshape((2, 3, 4, 5), 10, 'float32')
-    verify_constantofshape((3, 3), 0, 'int32')
-    verify_constantofshape((1, 2, 3), -1, 'float32')
+    verify_constantofshape((2, 3, 4, 5), 10, "float32")
+    verify_constantofshape((3, 3), 0, "int32")
+    verify_constantofshape((1, 2, 3), -1, "float32")
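
# Illustrative sketch: the numpy reference in verify_constantofshape
# (np.empty(...) followed by out.fill(value)) is equivalent to np.full,
# i.e. a tensor of the requested shape filled with a single value:
import numpy as np

out = np.full((2, 3, 4, 5), 10, dtype="float32")
assert out.shape == (2, 3, 4, 5) and np.all(out == 10)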
 
 
-def verify_pad(indata, pads, mode='constant', value=0.0):
+def verify_pad(indata, pads, mode="constant", value=0.0):
     indata = np.array(indata).astype(np.float32)
    #  numpy expected result
     len_dim = len(pads) // 2
-    np_pads = [(pads[i], pads[i+len_dim]) for i in range(len_dim)]
+    np_pads = [(pads[i], pads[i + len_dim]) for i in range(len_dim)]
     #  onnx graph
-    if mode in ['edge', 'reflect']:
+    if mode in ["edge", "reflect"]:
         outdata = np.pad(indata, pad_width=np_pads, mode=mode)
         node = helper.make_node(
-            'Pad',
-            inputs=['input'],
-            outputs=['output'],
+            "Pad",
+            inputs=["input"],
+            outputs=["output"],
             mode=mode,
             pads=pads,
         )
     else:
-        outdata = np.pad(indata, pad_width=np_pads,
-                         mode='constant', constant_values=value)
+        outdata = np.pad(indata, pad_width=np_pads, mode="constant", constant_values=value)
         node = helper.make_node(
-            'Pad',
-            inputs=['input'],
-            outputs=['output'],
-            mode='constant',
-            pads=pads,
-            value=value
+            "Pad", inputs=["input"], outputs=["output"], mode="constant", pads=pads, value=value
         )
-    graph = helper.make_graph([node],
-                              'pad_test',
-                              inputs=[helper.make_tensor_value_info("input",
-                                                                    TensorProto.FLOAT, list(indata.shape))],
-                              outputs=[helper.make_tensor_value_info("output",
-                                                                     TensorProto.FLOAT, list(outdata.shape))])
-    model = helper.make_model(graph, producer_name='pad_test')
+    graph = helper.make_graph(
+        [node],
+        "pad_test",
+        inputs=[helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape))],
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape))],
+    )
+    model = helper.make_model(graph, producer_name="pad_test")
     #  tvm result
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, indata, target, ctx, outdata.shape, 'float32', opset=2)
+        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, "float32", opset=2)
     tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
 
 
-def verify_pad_v11(indata, pads, mode='constant', value=0.0):
+def verify_pad_v11(indata, pads, mode="constant", value=0.0):
     indata = np.array(indata).astype(np.float32)
     #  numpy expected result
     len_dim = len(pads) // 2
-    np_pads = [(pads[i], pads[i+len_dim]) for i in range(len_dim)]
+    np_pads = [(pads[i], pads[i + len_dim]) for i in range(len_dim)]
     pads = np.array(pads)
     #  onnx graph
-    if mode in ['edge', 'reflect']:
+    if mode in ["edge", "reflect"]:
         inputs = [indata, pads]
         outdata = np.pad(indata, pad_width=np_pads, mode=mode)
-        node = helper.make_node(
-            'Pad',
-            inputs=['input', 'pads'],
-            outputs=['output'],
-            mode=mode
+        node = helper.make_node("Pad", inputs=["input", "pads"], outputs=["output"], mode=mode)
+        graph = helper.make_graph(
+            [node],
+            "pad_test",
+            inputs=[
+                helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)),
+                helper.make_tensor_value_info("pads", TensorProto.INT64, (len(pads),)),
+            ],
+            initializer=[helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads)],
+            outputs=[
+                helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape))
+            ],
         )
-        graph = helper.make_graph([node],
-                                  'pad_test',
-                                  inputs=[helper.make_tensor_value_info("input",
-                                                                        TensorProto.FLOAT, list(indata.shape)),
-                                          helper.make_tensor_value_info("pads",
-                                                                        TensorProto.INT64,(len(pads),))],
-                                  initializer=[helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads)],
-                                  outputs=[helper.make_tensor_value_info("output",
-                                                                         TensorProto.FLOAT, list(outdata.shape))])
     else:
         inputs = [indata, pads, np.array([value])]
-        outdata = np.pad(indata, pad_width=np_pads,
-                         mode='constant', constant_values=value)
+        outdata = np.pad(indata, pad_width=np_pads, mode="constant", constant_values=value)
         node = helper.make_node(
-            'Pad',
-            inputs=['input', 'pads', 'constant_value'],
-            outputs=['output'],
-            mode='constant'
+            "Pad", inputs=["input", "pads", "constant_value"], outputs=["output"], mode="constant"
+        )
+        graph = helper.make_graph(
+            [node],
+            "pad_test",
+            inputs=[
+                helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)),
+                helper.make_tensor_value_info("pads", TensorProto.INT64, (len(pads),)),
+                helper.make_tensor_value_info("constant_value", TensorProto.INT64, (1,)),
+            ],
+            initializer=[
+                helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads),
+                helper.make_tensor("constant_value", TensorProto.FLOAT, (1,), [value]),
+            ],
+            outputs=[
+                helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape))
+            ],
         )
-        graph = helper.make_graph([node],
-                                  'pad_test',
-                                  inputs=[helper.make_tensor_value_info("input",
-                                                                        TensorProto.FLOAT, list(indata.shape)),
-                                          helper.make_tensor_value_info("pads",
-                                                                        TensorProto.INT64,(len(pads),)),
-                                          helper.make_tensor_value_info("constant_value",
-                                                                        TensorProto.INT64,(1,)),
-                                          ],
-                                  initializer=[helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads),
-                                               helper.make_tensor("constant_value", TensorProto.FLOAT, (1,), [value])],
-                                  outputs=[helper.make_tensor_value_info("output",
-                                                                         TensorProto.FLOAT, list(outdata.shape))])
-    model = helper.make_model(graph, producer_name='pad_test')
+    model = helper.make_model(graph, producer_name="pad_test")
     #  tvm result
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, inputs, target, ctx, outdata.shape, 'float32', opset=11)
+        tvm_out = get_tvm_output(model, inputs, target, ctx, outdata.shape, "float32", opset=11)
     tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
 
 
 @tvm.testing.uses_gpu
 def test_pad():
-    verify_pad(np.random.randn(2, 2).astype(
-        np.float32), [0, 1, 0, 0], 'constant', 0.0)
-    verify_pad(np.random.randn(2, 3).astype(
-        np.float32), [1, 0, 0, 1], 'constant', 0.0)
-    verify_pad(np.random.randn(3, 2).astype(
-        np.float32), [0, 0, 1, 0], 'constant', 5.0)
-    verify_pad(np.random.randn(1, 3, 4, 5).astype(
-        np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'edge')
-    verify_pad(np.random.randn(1, 3, 4, 5).astype(
-        np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'reflect')
-
-    verify_pad_v11(np.random.randn(2, 2).astype(
-        np.float32), [0, 1, 0, 0], 'constant', 0.0)
-    verify_pad_v11(np.random.randn(2, 3).astype(
-        np.float32), [1, 0, 0, 1], 'constant', 0.0)
-    verify_pad_v11(np.random.randn(3, 2).astype(
-        np.float32), [0, 0, 1, 0], 'constant', 5.0)
-    verify_pad_v11(np.random.randn(1, 3, 4, 5).astype(
-        np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'edge')
-    verify_pad_v11(np.random.randn(1, 3, 4, 5).astype(
-        np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'reflect')
+    verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], "constant", 0.0)
+    verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], "constant", 0.0)
+    verify_pad(np.random.randn(3, 2).astype(np.float32), [0, 0, 1, 0], "constant", 5.0)
+    verify_pad(np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], "edge")
+    verify_pad(np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], "reflect")
+
+    verify_pad_v11(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], "constant", 0.0)
+    verify_pad_v11(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], "constant", 0.0)
+    verify_pad_v11(np.random.randn(3, 2).astype(np.float32), [0, 0, 1, 0], "constant", 5.0)
+    verify_pad_v11(np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], "edge")
+    verify_pad_v11(
+        np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], "reflect"
+    )
 
 
 def verify_reduce_func(func, data, axis, keepdims):
@@ -1514,67 +1521,67 @@ def verify_reduce_func(func, data, axis, keepdims):
     outshape = np.sum(data, axis=axis, keepdims=keepdims == 1).shape
 
     if axis:
-        node = onnx.helper.make_node(func,
-                                     inputs=['x'],
-                                     outputs=['y'],
-                                     axes=axis,
-                                     keepdims=keepdims)
+        node = onnx.helper.make_node(
+            func, inputs=["x"], outputs=["y"], axes=axis, keepdims=keepdims
+        )
     else:
-        node = onnx.helper.make_node(func,
-                                     inputs=['x'],
-                                     outputs=['y'],
-                                     keepdims=keepdims)
+        node = onnx.helper.make_node(func, inputs=["x"], outputs=["y"], keepdims=keepdims)
 
-    graph = helper.make_graph([node],
-                              "reduce_test",
-                              inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))],
-                              outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))])
+    graph = helper.make_graph(
+        [node],
+        "reduce_test",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))],
+    )
 
-    model = helper.make_model(graph, producer_name='reduce_test')
+    model = helper.make_model(graph, producer_name="reduce_test")
 
-    onnx_out = get_onnxruntime_output(model, data, 'float32')
+    onnx_out = get_onnxruntime_output(model, data, "float32")
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, data, target, ctx, outshape, 'float32')
+        tvm_out = get_tvm_output(model, data, target, ctx, outshape, "float32")
         tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_all_reduce_funcs():
-    funcs = ["ReduceMax",
-             "ReduceMean",
-             "ReduceMin",
-             "ReduceProd",
-             "ReduceSum",
-             'ReduceSumSquare',
-             "ReduceLogSum",
-             "ReduceLogSumExp",
-             "ReduceL1",
-             "ReduceL2"]
+    funcs = [
+        "ReduceMax",
+        "ReduceMean",
+        "ReduceMin",
+        "ReduceProd",
+        "ReduceSum",
+        "ReduceSumSquare",
+        "ReduceLogSum",
+        "ReduceLogSumExp",
+        "ReduceL1",
+        "ReduceL2",
+    ]
 
     for func in funcs:
         for keepdims in [True, False]:
-            verify_reduce_func(func,
-                               np.random.randn(3, 2, 2).astype(np.float32),
-                               axis=None, keepdims=keepdims)
+            verify_reduce_func(
+                func, np.random.randn(3, 2, 2).astype(np.float32), axis=None, keepdims=keepdims
+            )
 
-            verify_reduce_func(func,
-                               np.random.randn(3, 2, 3).astype(np.float32),
-                               axis=None, keepdims=keepdims)
+            verify_reduce_func(
+                func, np.random.randn(3, 2, 3).astype(np.float32), axis=None, keepdims=keepdims
+            )
 
-            verify_reduce_func(func,
-                               np.random.randn(3, 3, 3).astype(np.float32),
-                               axis=(1,), keepdims=keepdims)
+            verify_reduce_func(
+                func, np.random.randn(3, 3, 3).astype(np.float32), axis=(1,), keepdims=keepdims
+            )
 
-            verify_reduce_func(func,
-                               np.random.randn(3, 3, 3, 1).astype(np.float32),
-                               axis=(1, 2), keepdims=keepdims)
+            verify_reduce_func(
+                func, np.random.randn(3, 3, 3, 1).astype(np.float32), axis=(1, 2), keepdims=keepdims
+            )
 
-            verify_reduce_func(func,
-                               np.random.randn(3, 3, 3, 1).astype(np.float32),
-                               axis=(1,), keepdims=keepdims)
+            verify_reduce_func(
+                func, np.random.randn(3, 3, 3, 1).astype(np.float32), axis=(1,), keepdims=keepdims
+            )
 
-            verify_reduce_func(func,
-                               np.random.randn(1, 3, 4, 1).astype(np.float32),
-                               axis=(1,), keepdims=keepdims)
+            verify_reduce_func(
+                func, np.random.randn(1, 3, 4, 1).astype(np.float32), axis=(1,), keepdims=keepdims
+            )
 
 
 def verify_split(indata, outdatas, split, axis=0):
@@ -1585,27 +1592,29 @@ def verify_split(indata, outdatas, split, axis=0):
     else:
         split_index = range(len(outdatas))
     node = helper.make_node(
-        'Split',
-        inputs=['input'],
-        outputs=['output_{}'.format(i) for i in range(len(split_index))],
+        "Split",
+        inputs=["input"],
+        outputs=["output_{}".format(i) for i in range(len(split_index))],
         axis=axis,
-        split=split
-    )
-    graph = helper.make_graph([node],
-                              'split_test',
-                              inputs=[helper.make_tensor_value_info("input",
-                                                                    TensorProto.FLOAT, list(indata.shape))],
-                              outputs=[helper.make_tensor_value_info("output_{}".format(i),
-                                                                     TensorProto.FLOAT, list(outdatas[i].shape))
-                                       for i in range(len(split_index))
-                                       ])
-    model = helper.make_model(graph, producer_name='split_test')
+        split=split,
+    )
+    graph = helper.make_graph(
+        [node],
+        "split_test",
+        inputs=[helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape))],
+        outputs=[
+            helper.make_tensor_value_info(
+                "output_{}".format(i), TensorProto.FLOAT, list(outdatas[i].shape)
+            )
+            for i in range(len(split_index))
+        ],
+    )
+    model = helper.make_model(graph, producer_name="split_test")
 
     for target, ctx in tvm.testing.enabled_targets():
         output_shape = [o.shape for o in outdatas]
-        output_type = ['float32', 'float32', 'float32']
-        tvm_out = get_tvm_output(
-            model, indata, target, ctx, output_shape, output_type)
+        output_type = ["float32", "float32", "float32"]
+        tvm_out = get_tvm_output(model, indata, target, ctx, output_shape, output_type)
     for o, t in zip(outdatas, tvm_out):
         tvm.testing.assert_allclose(o, t)
 
@@ -1613,13 +1622,15 @@ def verify_split(indata, outdatas, split, axis=0):
 @tvm.testing.uses_gpu
 def test_split():
     # 1D
-    verify_split([1., 2., 3., 4., 5., 6.], [
-                 [1., 2.], [3., 4.], [5., 6.]], [2, 2, 2], 0)
-    verify_split([1., 2., 3., 4., 5., 6.], [
-                 [1., 2.], [3.], [4., 5., 6.]], [2, 1, 3], 0)
+    verify_split([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [2, 2, 2], 0)
+    verify_split([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]], [2, 1, 3], 0)
     # 2D
-    verify_split([[1., 2., 3., 4.], [7., 8., 9., 10.]],
-                 [[[1., 2.], [7., 8.]], [[3., 4.], [9., 10.]]], [2, 2], 1)
+    verify_split(
+        [[1.0, 2.0, 3.0, 4.0], [7.0, 8.0, 9.0, 10.0]],
+        [[[1.0, 2.0], [7.0, 8.0]], [[3.0, 4.0], [9.0, 10.0]]],
+        [2, 2],
+        1,
+    )
     # Split evenly (unstack)
     verify_split([1, 2, 3], [[1], [2], [3]], False)
 
@@ -1630,20 +1641,21 @@ def test_binary_ops():
     dtype = "float32"
     out_shape = in_shape
 
-    def verify_binary_ops(op, x, y, out_np, x_name='in1', y_name='in2', broadcast=None):
+    def verify_binary_ops(op, x, y, out_np, x_name="in1", y_name="in2", broadcast=None):
         if broadcast is None:
-            z = helper.make_node(op, [x_name, y_name], ['out'])
+            z = helper.make_node(op, [x_name, y_name], ["out"])
         else:
-            z = helper.make_node(op, [x_name, y_name], ['out'], broadcast=1)
-        graph = helper.make_graph([z],
-                                  '_test',
-                                  inputs=[helper.make_tensor_value_info(x_name,
-                                                                        TensorProto.FLOAT, list(in_shape)),
-                                          helper.make_tensor_value_info(y_name,
-                                                                        TensorProto.FLOAT, list(in_shape))],
-                                  outputs=[helper.make_tensor_value_info("out",
-                                                                         TensorProto.FLOAT, list(out_shape))])
-        model = helper.make_model(graph, producer_name='_test')
+            z = helper.make_node(op, [x_name, y_name], ["out"], broadcast=1)
+        graph = helper.make_graph(
+            [z],
+            "_test",
+            inputs=[
+                helper.make_tensor_value_info(x_name, TensorProto.FLOAT, list(in_shape)),
+                helper.make_tensor_value_info(y_name, TensorProto.FLOAT, list(in_shape)),
+            ],
+            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))],
+        )
+        model = helper.make_model(graph, producer_name="_test")
         for target, ctx in tvm.testing.enabled_targets():
             tvm_out = get_tvm_output(model, [x, y], target, ctx)
             tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
@@ -1652,12 +1664,12 @@ def test_binary_ops():
     y = np.random.uniform(size=in_shape).astype(dtype)
     z = np.random.uniform(size=(3,)).astype(dtype)
     verify_binary_ops("Add", x, y, x + y, broadcast=None)
-    verify_binary_ops("Add", x, z,  x + z, broadcast=True)
+    verify_binary_ops("Add", x, z, x + z, broadcast=True)
     verify_binary_ops("Sub", x, y, x - y, broadcast=None)
     verify_binary_ops("Sub", x, z, x - z, broadcast=True)
     verify_binary_ops("Mul", x, y, x * y, broadcast=None)
-    verify_binary_ops("Mul", x, z,  x * z, broadcast=True)
-    verify_binary_ops("Mul", x, x, x * x, x_name='in1', y_name='in1', broadcast=None)
+    verify_binary_ops("Mul", x, z, x * z, broadcast=True)
+    verify_binary_ops("Mul", x, x, x * x, x_name="in1", y_name="in1", broadcast=None)
     verify_binary_ops("Div", x, y, x / y, broadcast=None)
     verify_binary_ops("Div", x, z, x / z, broadcast=True)
     verify_binary_ops("Sum", x, y, x + y, broadcast=None)
@@ -1673,14 +1685,16 @@ def test_single_ops():
     out_shape = in_shape
 
     def verify_single_ops(op, x, out_np, rtol=1e-5, atol=1e-5):
-        z = helper.make_node(op, ['in1'], ['out'])
-        graph = helper.make_graph([z],
-                                  '_test',
-                                  inputs=[helper.make_tensor_value_info("in1",
-                                                                        TensorProto.FLOAT, list(in_shape)), ],
-                                  outputs=[helper.make_tensor_value_info("out",
-                                                                         TensorProto.FLOAT, list(out_shape))])
-        model = helper.make_model(graph, producer_name='_test')
+        z = helper.make_node(op, ["in1"], ["out"])
+        graph = helper.make_graph(
+            [z],
+            "_test",
+            inputs=[
+                helper.make_tensor_value_info("in1", TensorProto.FLOAT, list(in_shape)),
+            ],
+            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))],
+        )
+        model = helper.make_model(graph, producer_name="_test")
         for target, ctx in tvm.testing.enabled_targets():
             tvm_out = get_tvm_output(model, [x], target, ctx)
             tvm.testing.assert_allclose(out_np, tvm_out, rtol=rtol, atol=atol)
@@ -1688,7 +1702,7 @@ def test_single_ops():
     x = np.random.uniform(size=in_shape).astype(dtype)
     verify_single_ops("Neg", x, -x)
     verify_single_ops("Abs", x, np.abs(x))
-    verify_single_ops("Reciprocal", x, 1/x)
+    verify_single_ops("Reciprocal", x, 1 / x)
     verify_single_ops("Sqrt", x, np.sqrt(x))
     verify_single_ops("Relu", x, np.maximum(x, 0))
     verify_single_ops("Exp", x, np.exp(x))
@@ -1715,65 +1729,67 @@ def test_single_ops():
 def test_leaky_relu():
     def leaky_relu_x(x, alpha):
         return np.where(x >= 0, x, x * alpha)
-    _test_onnx_op_elementwise((2, 4, 5, 6),
-                              leaky_relu_x,
-                              {'alpha': 0.25},
-                              'float32',
-                              'LeakyRelu',
-                              {'alpha': 0.25})
+
+    _test_onnx_op_elementwise(
+        (2, 4, 5, 6), leaky_relu_x, {"alpha": 0.25}, "float32", "LeakyRelu", {"alpha": 0.25}
+    )
 
 
 @tvm.testing.uses_gpu
 def test_elu():
     def elu_x(x, alpha):
         return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))
-    _test_onnx_op_elementwise((2, 4, 5, 6),
-                              elu_x,
-                              {'alpha': 0.25},
-                              'float32',
-                              'Elu',
-                              {'alpha': 0.25})
+
+    _test_onnx_op_elementwise(
+        (2, 4, 5, 6), elu_x, {"alpha": 0.25}, "float32", "Elu", {"alpha": 0.25}
+    )
 
 
 @tvm.testing.uses_gpu
 def test_selu():
     def selu_x(x, alpha, gamma):
         return gamma * np.where(x > 0, x, alpha * (np.exp(x) - 1.0))
-    _test_onnx_op_elementwise((2, 4, 5, 6),
-                              selu_x,
-                              {'alpha': 0.25, 'gamma': 0.3},
-                              'float32',
-                              'Selu',
-                              {'alpha': 0.25, 'gamma': 0.3})
+
+    _test_onnx_op_elementwise(
+        (2, 4, 5, 6),
+        selu_x,
+        {"alpha": 0.25, "gamma": 0.3},
+        "float32",
+        "Selu",
+        {"alpha": 0.25, "gamma": 0.3},
+    )
 
 
 @tvm.testing.uses_gpu
 def test_prelu():
     def verify_prelu(x_shape, a_shape):
-        node = helper.make_node('PRelu',
-                                inputs=['X', 'slope'],
-                                outputs=['Y'])
-
-        graph = helper.make_graph([node],
-                                  "prelu_test",
-                                  inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(x_shape)),
-                                          helper.make_tensor_value_info("slope", TensorProto.FLOAT, list(a_shape))],
-                                  outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(x_shape))])
+        node = helper.make_node("PRelu", inputs=["X", "slope"], outputs=["Y"])
+
+        graph = helper.make_graph(
+            [node],
+            "prelu_test",
+            inputs=[
+                helper.make_tensor_value_info("X", TensorProto.FLOAT, list(x_shape)),
+                helper.make_tensor_value_info("slope", TensorProto.FLOAT, list(a_shape)),
+            ],
+            outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(x_shape))],
+        )
 
-        model = helper.make_model(graph, producer_name='prelu_test')
+        model = helper.make_model(graph, producer_name="prelu_test")
 
         indata = np.random.uniform(-10, 10, x_shape).astype(np.float32)
         slopedata = np.random.uniform(-10, 10, a_shape).astype(np.float32)
         onnx_out = get_onnxruntime_output(model, [indata, slopedata])
 
-        for target, ctx in [('llvm', tvm.cpu())]:
-            tvm_out = get_tvm_output(model, [indata, slopedata], target, ctx, list(x_shape),
-                    output_dtype='float32')
+        for target, ctx in [("llvm", tvm.cpu())]:
+            tvm_out = get_tvm_output(
+                model, [indata, slopedata], target, ctx, list(x_shape), output_dtype="float32"
+            )
             tvm.testing.assert_allclose(onnx_out[0], tvm_out, rtol=1e-05, atol=1e-05)
 
-    verify_prelu([3,4,5,6], [1, 4, 1, 1])
-    verify_prelu([1,8,5,6], [1, 8, 1, 1])
-    verify_prelu([2,12,16,16], [1, 12, 1, 1])
+    verify_prelu([3, 4, 5, 6], [1, 4, 1, 1])
+    verify_prelu([1, 8, 5, 6], [1, 8, 1, 1])
+    verify_prelu([2, 12, 16, 16], [1, 12, 1, 1])
 
 
 @tvm.testing.uses_gpu
@@ -1782,69 +1798,72 @@ def test_ThresholdedRelu():
         out_np = np.clip(x, alpha, np.inf)
         out_np[out_np == alpha] = 0
         return out_np
-    _test_onnx_op_elementwise((2, 4, 5, 6),
-                              ThresholdedRelu_x,
-                              {'alpha': 0.25},
-                              'float32',
-                              'ThresholdedRelu',
-                              {'alpha': 0.25})
+
+    _test_onnx_op_elementwise(
+        (2, 4, 5, 6),
+        ThresholdedRelu_x,
+        {"alpha": 0.25},
+        "float32",
+        "ThresholdedRelu",
+        {"alpha": 0.25},
+    )
 
 
 @tvm.testing.uses_gpu
 def test_ScaledTanh():
     def ScaledTanh_x(x, alpha, beta):
         return alpha * np.tanh(beta * x)
-    _test_onnx_op_elementwise((2, 4, 5, 6),
-                              ScaledTanh_x,
-                              {'alpha': 0.25, 'beta': 0.3},
-                              'float32',
-                              'ScaledTanh',
-                              {'alpha': 0.25, 'beta': 0.3})
+
+    _test_onnx_op_elementwise(
+        (2, 4, 5, 6),
+        ScaledTanh_x,
+        {"alpha": 0.25, "beta": 0.3},
+        "float32",
+        "ScaledTanh",
+        {"alpha": 0.25, "beta": 0.3},
+    )
 
 
 @tvm.testing.uses_gpu
 def test_ParametricSoftplus():
     def ParametricSoftplus_x(x, alpha, beta):
         return alpha * np.log(np.exp(beta * x) + 1)
-    _test_onnx_op_elementwise((2, 4, 5, 6),
-                              ParametricSoftplus_x,
-                              {'alpha': 0.25, 'beta': 0.3},
-                              'float32',
-                              'ParametricSoftplus',
-                              {'alpha': 0.25, 'beta': 0.3})
+
+    _test_onnx_op_elementwise(
+        (2, 4, 5, 6),
+        ParametricSoftplus_x,
+        {"alpha": 0.25, "beta": 0.3},
+        "float32",
+        "ParametricSoftplus",
+        {"alpha": 0.25, "beta": 0.3},
+    )
 
 
 @tvm.testing.uses_gpu
 def test_Scale():
     def Scale_x(x, scale):
         return scale * x
-    _test_onnx_op_elementwise((2, 4, 5, 6),
-                              Scale_x,
-                              {'scale': 0.25},
-                              'float32',
-                              'Scale',
-                              {'scale': 0.25})
+
+    _test_onnx_op_elementwise(
+        (2, 4, 5, 6), Scale_x, {"scale": 0.25}, "float32", "Scale", {"scale": 0.25}
+    )
 
 
 @tvm.testing.uses_gpu
 def test_LogSoftmax():
-    _test_onnx_op_elementwise((1, 4),
-                              tvm.topi.testing.log_softmax_python,
-                              {},
-                              'float32',
-                              'LogSoftmax',
-                              {'axis': 1})
+    _test_onnx_op_elementwise(
+        (1, 4), tvm.topi.testing.log_softmax_python, {}, "float32", "LogSoftmax", {"axis": 1}
+    )
 
 
 def check_torch_conversion(model, input_size):
     dummy_input = torch.randn(*input_size)
-    file_name = '{}.onnx'.format(model.__name__)
+    file_name = "{}.onnx".format(model.__name__)
     # Set verbose=True for more output
-    torch.onnx.export(model(), dummy_input, file_name,
-                      export_params=True, verbose=False)
+    torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=False)
     onnx_model = onnx.load(file_name)
     for target, ctx in tvm.testing.enabled_targets():
-        input_data = np.random.uniform(size=input_size).astype('int32')
+        input_data = np.random.uniform(size=input_size).astype("int32")
         c2_out = get_onnxruntime_output(onnx_model, input_data)
         tvm_out = get_tvm_output(onnx_model, input_data, target, ctx)
         tvm.testing.assert_allclose(c2_out, tvm_out)
@@ -1855,6 +1874,7 @@ def test_resnet():
     check_torch_conversion(torchvision.models.resnet18, (1, 3, 224, 224))
     # check_torch_conversion(torchvision.models.resnet101, (1,3,224,224))
 
+
 # def test_alexnet():
 # Torch's ONNX export does not support the adaptive pooling used by AlexNet?
 # check_torch_conversion(torchvision.models.alexnet, (1,3,224,224))
@@ -1878,6 +1898,7 @@ def test_densenet():
 def test_inception():
     check_torch_conversion(torchvision.models.inception_v3, (1, 3, 224, 224))
 
+
 # TODO(@jroesch): Update Torch + ONNX to support this import.
 # def test_googlenet():
 #     check_torch_conversion(torchvision.models.googlenet, (1,3,224,224))
@@ -1891,27 +1912,28 @@ def test_inception():
 def test_sign():
     def Sign_x(x):
         return np.sign(x)
-    _test_onnx_op_elementwise((3, 4, 5, 6),
-                              Sign_x,
-                              {},
-                              'float32',
-                              'Sign',
-                              {})
+
+    _test_onnx_op_elementwise((3, 4, 5, 6), Sign_x, {}, "float32", "Sign", {})
 
 
 def verify_not(indata, dtype):
     x = indata.astype(dtype)
     outdata = np.logical_not(x)
 
-    node = helper.make_node('Not', inputs=['in'], outputs=['out'],)
+    node = helper.make_node(
+        "Not",
+        inputs=["in"],
+        outputs=["out"],
+    )
 
-    graph = helper.make_graph([node],
-                              'not_test',
-                              inputs=[helper.make_tensor_value_info(
-                                  "in", TensorProto.BOOL, list(x.shape))],
-                              outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))])
+    graph = helper.make_graph(
+        [node],
+        "not_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.BOOL, list(x.shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='not_test')
+    model = helper.make_model(graph, producer_name="not_test")
 
     for target, ctx in tvm.testing.enabled_targets():
         tvm_out = get_tvm_output(model, [x], target, ctx, outdata.shape)
@@ -1933,15 +1955,23 @@ def verify_and(indata, dtype):
     y = indata[1].astype(dtype)
     outdata = np.logical_and(x, y)
 
-    node = helper.make_node('And', inputs=['in1', 'in2'], outputs=['out'], )
+    node = helper.make_node(
+        "And",
+        inputs=["in1", "in2"],
+        outputs=["out"],
+    )
 
-    graph = helper.make_graph([node],
-                              'and_test',
-                              inputs=[helper.make_tensor_value_info("in1", TensorProto.BOOL, list(x.shape)),
-                                      helper.make_tensor_value_info("in2", TensorProto.BOOL, list(y.shape))],
-                              outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))])
+    graph = helper.make_graph(
+        [node],
+        "and_test",
+        inputs=[
+            helper.make_tensor_value_info("in1", TensorProto.BOOL, list(x.shape)),
+            helper.make_tensor_value_info("in2", TensorProto.BOOL, list(y.shape)),
+        ],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='and_test')
+    model = helper.make_model(graph, producer_name="and_test")
 
     for target, ctx in tvm.testing.enabled_targets():
         tvm_out = get_tvm_output(model, [x, y], target, ctx, outdata.shape)
@@ -1951,98 +1981,87 @@ def verify_and(indata, dtype):
 @tvm.testing.uses_gpu
 def test_and():
     # 2d
-    x = (np.random.randn(3, 4) > 0)
-    y = (np.random.randn(3, 4) > 0)
+    x = np.random.randn(3, 4) > 0
+    y = np.random.randn(3, 4) > 0
     verify_and(indata=[x, y], dtype=bool)
 
     # 3d
-    x = (np.random.randn(3, 4, 5) > 0)
-    y = (np.random.randn(3, 4, 5) > 0)
+    x = np.random.randn(3, 4, 5) > 0
+    y = np.random.randn(3, 4, 5) > 0
     verify_and(indata=[x, y], dtype=bool)
 
     # 4d
-    x = (np.random.randn(3, 4, 5, 6) > 0)
-    y = (np.random.randn(3, 4, 5, 6) > 0)
+    x = np.random.randn(3, 4, 5, 6) > 0
+    y = np.random.randn(3, 4, 5, 6) > 0
     verify_and(indata=[x, y], dtype=bool)
 
     # 3d vs 1d
-    x = (np.random.randn(3, 4, 5) > 0)
-    y = (np.random.randn(5) > 0)
+    x = np.random.randn(3, 4, 5) > 0
+    y = np.random.randn(5) > 0
     verify_and(indata=[x, y], dtype=bool)
 
     # 3d vs 2d
-    x = (np.random.randn(3, 4, 5) > 0)
-    y = (np.random.randn(4, 5) > 0)
+    x = np.random.randn(3, 4, 5) > 0
+    y = np.random.randn(4, 5) > 0
     verify_and(indata=[x, y], dtype=bool)
 
 
 def verify_tile_v1(indata, outdata, **kwargs):
-    node = helper.make_node('Tile', inputs=['in'], outputs=['out'], **kwargs)
-    graph = helper.make_graph([node],
-                              'tile_test',
-                              inputs=[helper.make_tensor_value_info(
-                                  "in", TensorProto.FLOAT, list(indata.shape))],
-                              outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))])
+    node = helper.make_node("Tile", inputs=["in"], outputs=["out"], **kwargs)
+    graph = helper.make_graph(
+        [node],
+        "tile_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='tile_test')
+    model = helper.make_model(graph, producer_name="tile_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, [indata], target, ctx, outdata.shape, opset=1)
+        tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape, opset=1)
         tvm.testing.assert_allclose(outdata, tvm_out)
 
 
 def verify_tile_v6(indata, repeats, outdata):
-    node = helper.make_node('Tile',
-                            inputs=['input', 'repeats'],
-                            outputs=['out'])
+    node = helper.make_node("Tile", inputs=["input", "repeats"], outputs=["out"])
     graph = helper.make_graph(
         [node],
-        'tile_test',
+        "tile_test",
         inputs=[
-            helper.make_tensor_value_info("input", TensorProto.FLOAT,
-                                          list(indata.shape)),
-            helper.make_tensor_value_info("repeats", TensorProto.INT64,
-                                          list(repeats.shape))
-        ],
-        outputs=[
-            helper.make_tensor_value_info("out", TensorProto.FLOAT,
-                                          list(outdata.shape))
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)),
+            helper.make_tensor_value_info("repeats", TensorProto.INT64, list(repeats.shape)),
         ],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))],
         initializer=[
-            helper.make_tensor("repeats", TensorProto.INT64,
-                               list(repeats.shape), repeats)
-        ])
+            helper.make_tensor("repeats", TensorProto.INT64, list(repeats.shape), repeats)
+        ],
+    )
 
-    model = helper.make_model(graph, producer_name='tile_test')
+    model = helper.make_model(graph, producer_name="tile_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [indata],
-                                 target,
-                                 ctx,
-                                 outdata.shape,
-                                 opset=6)
+        tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape, opset=6)
         tvm.testing.assert_allclose(outdata, tvm_out)
 
 
 @tvm.testing.uses_gpu
 def test_tile():
     x = np.random.rand(2, 3, 4, 5).astype(np.float32)
-    repeats = np.random.randint(
-        low=1, high=10, size=(np.ndim(x),)).astype(np.int64)
+    repeats = np.random.randint(low=1, high=10, size=(np.ndim(x),)).astype(np.int64)
     z = np.tile(x, repeats)
     verify_tile_v1(x, z, repeats=repeats)
     verify_tile_v6(x, repeats, z)
 
 
 def verify_erf(indata, outdata):
-    node = helper.make_node('Erf', inputs=['in'], outputs=['out'])
-    graph = helper.make_graph([node],
-                              'erf_test',
-                              inputs=[helper.make_tensor_value_info(
-                                  'in', TensorProto.FLOAT, list(indata.shape))],
-                              outputs=[helper.make_tensor_value_info('out', TensorProto.FLOAT, list(outdata.shape))])
-    model = helper.make_model(graph, producer_name='erf_test')
+    node = helper.make_node("Erf", inputs=["in"], outputs=["out"])
+    graph = helper.make_graph(
+        [node],
+        "erf_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))],
+    )
+    model = helper.make_model(graph, producer_name="erf_test")
 
     for target, ctx in tvm.testing.enabled_targets():
         tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape)
@@ -2057,14 +2076,18 @@ def test_erf():
 
 
 def verify_where(condition, x, y, dtype, outdata):
-    node = helper.make_node('Where', inputs=['condition', 'x', 'y'], outputs=['out'])
-    graph = helper.make_graph([node],
-                              'where_test',
-                              inputs=[helper.make_tensor_value_info('condition', TensorProto.BOOL, list(condition.shape)),
-                                      helper.make_tensor_value_info('x', dtype, list(x.shape)),
-                                      helper.make_tensor_value_info('y', dtype, list(y.shape))],
-                              outputs=[helper.make_tensor_value_info('out', dtype, list(outdata.shape))])
-    model = helper.make_model(graph, producer_name='where_test')
+    node = helper.make_node("Where", inputs=["condition", "x", "y"], outputs=["out"])
+    graph = helper.make_graph(
+        [node],
+        "where_test",
+        inputs=[
+            helper.make_tensor_value_info("condition", TensorProto.BOOL, list(condition.shape)),
+            helper.make_tensor_value_info("x", dtype, list(x.shape)),
+            helper.make_tensor_value_info("y", dtype, list(y.shape)),
+        ],
+        outputs=[helper.make_tensor_value_info("out", dtype, list(outdata.shape))],
+    )
+    model = helper.make_model(graph, producer_name="where_test")
 
     for target, ctx in tvm.testing.enabled_targets():
         tvm_out = get_tvm_output(model, [condition, x, y], target, ctx, outdata.shape)
@@ -2111,15 +2134,23 @@ def verify_or(indata, dtype):
     y = indata[1].astype(dtype)
     outdata = np.logical_or(x, y)
 
-    node = helper.make_node('Or', inputs=['in1', 'in2'], outputs=['out'], )
+    node = helper.make_node(
+        "Or",
+        inputs=["in1", "in2"],
+        outputs=["out"],
+    )
 
-    graph = helper.make_graph([node],
-                              'or_test',
-                              inputs=[helper.make_tensor_value_info("in1", TensorProto.BOOL, list(x.shape)),
-                                      helper.make_tensor_value_info("in2", TensorProto.BOOL, list(y.shape))],
-                              outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))])
+    graph = helper.make_graph(
+        [node],
+        "or_test",
+        inputs=[
+            helper.make_tensor_value_info("in1", TensorProto.BOOL, list(x.shape)),
+            helper.make_tensor_value_info("in2", TensorProto.BOOL, list(y.shape)),
+        ],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='or_test')
+    model = helper.make_model(graph, producer_name="or_test")
 
     for target, ctx in tvm.testing.enabled_targets():
         tvm_out = get_tvm_output(model, [x, y], target, ctx, outdata.shape)
@@ -2129,64 +2160,63 @@ def verify_or(indata, dtype):
 @tvm.testing.uses_gpu
 def test_or():
     # 2d
-    x = (np.random.randn(3, 4) > 0)
-    y = (np.random.randn(3, 4) > 0)
+    x = np.random.randn(3, 4) > 0
+    y = np.random.randn(3, 4) > 0
     verify_or(indata=[x, y], dtype=bool)
 
     # 3d
-    x = (np.random.randn(3, 4, 5) > 0)
-    y = (np.random.randn(3, 4, 5) > 0)
+    x = np.random.randn(3, 4, 5) > 0
+    y = np.random.randn(3, 4, 5) > 0
     verify_or(indata=[x, y], dtype=bool)
 
     # 4d
-    x = (np.random.randn(3, 4, 5, 6) > 0)
-    y = (np.random.randn(3, 4, 5, 6) > 0)
+    x = np.random.randn(3, 4, 5, 6) > 0
+    y = np.random.randn(3, 4, 5, 6) > 0
     verify_or(indata=[x, y], dtype=bool)
 
     # 3d vs 1d
-    x = (np.random.randn(3, 4, 5) > 0)
-    y = (np.random.randn(5) > 0)
+    x = np.random.randn(3, 4, 5) > 0
+    y = np.random.randn(5) > 0
     verify_or(indata=[x, y], dtype=bool)
 
     # 3d vs 2d
-    x = (np.random.randn(3, 4, 5) > 0)
-    y = (np.random.randn(4, 5) > 0)
+    x = np.random.randn(3, 4, 5) > 0
+    y = np.random.randn(4, 5) > 0
     verify_or(indata=[x, y], dtype=bool)
 
 
 @tvm.testing.uses_gpu
 def test_batch_norm():
     def verify_batch_norm(in_shape):
-        batchnorm = onnx.helper.make_node('BatchNormalization',
-                                          inputs=["x", "scale", "B", "mean", "var"],
-                                          outputs=['Y'])
-
-        graph = helper.make_graph([batchnorm],
-                                  "batchnorm_test",
-                                  inputs=[helper.make_tensor_value_info("x",
-                                                                        TensorProto.FLOAT, list(in_shape)),
-                                          helper.make_tensor_value_info("scale",
-                                                                        TensorProto.FLOAT, [in_shape[1]]),
-                                          helper.make_tensor_value_info("B",
-                                                                        TensorProto.FLOAT, [in_shape[1]]),
-                                          helper.make_tensor_value_info("mean",
-                                                                        TensorProto.FLOAT, [in_shape[1]]),
-                                          helper.make_tensor_value_info("var",
-                                                                        TensorProto.FLOAT, [in_shape[1]]),
-                                         ],
-                                  outputs=[helper.make_tensor_value_info("Y",
-                                                                         TensorProto.FLOAT, list(in_shape))])
-
-        model = helper.make_model(graph, producer_name='batchnorm_test')
+        batchnorm = onnx.helper.make_node(
+            "BatchNormalization", inputs=["x", "scale", "B", "mean", "var"], outputs=["Y"]
+        )
+
+        graph = helper.make_graph(
+            [batchnorm],
+            "batchnorm_test",
+            inputs=[
+                helper.make_tensor_value_info("x", TensorProto.FLOAT, list(in_shape)),
+                helper.make_tensor_value_info("scale", TensorProto.FLOAT, [in_shape[1]]),
+                helper.make_tensor_value_info("B", TensorProto.FLOAT, [in_shape[1]]),
+                helper.make_tensor_value_info("mean", TensorProto.FLOAT, [in_shape[1]]),
+                helper.make_tensor_value_info("var", TensorProto.FLOAT, [in_shape[1]]),
+            ],
+            outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(in_shape))],
+        )
+
+        model = helper.make_model(graph, producer_name="batchnorm_test")
 
         for target, ctx in tvm.testing.enabled_targets():
-            x = np.random.uniform(size=in_shape).astype('float32')
-            scale = np.random.uniform(size=in_shape[1]).astype('float32')
-            b = np.random.uniform(size=in_shape[1]).astype('float32')
-            mean = np.random.uniform(size=in_shape[1]).astype('float32')
-            var = np.random.uniform(size=in_shape[1]).astype('float32')
-            onnx_out = get_onnxruntime_output(model, [x, scale, b, mean, var], 'float32')[0]
-            tvm_out = get_tvm_output(model, [x, scale, b, mean, var], target, ctx, in_shape, 'float32')
+            x = np.random.uniform(size=in_shape).astype("float32")
+            scale = np.random.uniform(size=in_shape[1]).astype("float32")
+            b = np.random.uniform(size=in_shape[1]).astype("float32")
+            mean = np.random.uniform(size=in_shape[1]).astype("float32")
+            var = np.random.uniform(size=in_shape[1]).astype("float32")
+            onnx_out = get_onnxruntime_output(model, [x, scale, b, mean, var], "float32")[0]
+            tvm_out = get_tvm_output(
+                model, [x, scale, b, mean, var], target, ctx, in_shape, "float32"
+            )
             tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
 
     verify_batch_norm([1, 3, 224, 224])
@@ -2199,91 +2229,108 @@ def test_batch_norm():
 @tvm.testing.uses_gpu
 def test_batch_norm_dynamic_subgraph():
     def verify_batch_norm_dynamic_subgraph(in_shape, o_shape):
-        batchnorm = onnx.helper.make_node('BatchNormalization',
-                                          inputs=["x", "scale", "B", "mean", "var"],
-                                          outputs=['Y'])
+        batchnorm = onnx.helper.make_node(
+            "BatchNormalization", inputs=["x", "scale", "B", "mean", "var"], outputs=["Y"]
+        )
 
-        shape_node = helper.make_node("Shape", ['Y'], ['shape'])
+        shape_node = helper.make_node("Shape", ["Y"], ["shape"])
         reshape_node = helper.make_node("Reshape", ["in", "shape"], ["out"])
-        graph = helper.make_graph([batchnorm, shape_node, reshape_node],
-                                  "batchnorm_test",
-                                  inputs=[helper.make_tensor_value_info("x",
-                                                                        TensorProto.FLOAT, list(in_shape)),
-                                          helper.make_tensor_value_info("in",
-                                                                        TensorProto.FLOAT, list(o_shape)),
-                                          helper.make_tensor_value_info("scale",
-                                                                        TensorProto.FLOAT, [in_shape[1]]),
-                                          helper.make_tensor_value_info("B",
-                                                                        TensorProto.FLOAT, [in_shape[1]]),
-                                          helper.make_tensor_value_info("mean",
-                                                                        TensorProto.FLOAT, [in_shape[1]]),
-                                          helper.make_tensor_value_info("var",
-                                                                        TensorProto.FLOAT, [in_shape[1]]),
-                                         ],
-                                  outputs=[helper.make_tensor_value_info("out",
-                                                                         TensorProto.FLOAT, list(in_shape))])
-
-        model = helper.make_model(graph, producer_name='batchnorm_test')
+        graph = helper.make_graph(
+            [batchnorm, shape_node, reshape_node],
+            "batchnorm_test",
+            inputs=[
+                helper.make_tensor_value_info("x", TensorProto.FLOAT, list(in_shape)),
+                helper.make_tensor_value_info("in", TensorProto.FLOAT, list(o_shape)),
+                helper.make_tensor_value_info("scale", TensorProto.FLOAT, [in_shape[1]]),
+                helper.make_tensor_value_info("B", TensorProto.FLOAT, [in_shape[1]]),
+                helper.make_tensor_value_info("mean", TensorProto.FLOAT, [in_shape[1]]),
+                helper.make_tensor_value_info("var", TensorProto.FLOAT, [in_shape[1]]),
+            ],
+            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(in_shape))],
+        )
+
+        model = helper.make_model(graph, producer_name="batchnorm_test")
 
         for target, ctx in tvm.testing.enabled_targets():
-            x = np.random.uniform(size=in_shape).astype('float32')
-            inp = np.random.uniform(size=o_shape).astype('float32')
-            scale = np.random.uniform(size=in_shape[1]).astype('float32')
-            b = np.random.uniform(size=in_shape[1]).astype('float32')
-            mean = np.random.uniform(size=in_shape[1]).astype('float32')
-            var = np.random.uniform(size=in_shape[1]).astype('float32')
-            onnx_out = get_onnxruntime_output(model, [x, inp, scale, b, mean, var], 'float32')[0]
-            tvm_out = get_tvm_output(model, [x, inp, scale, b, mean, var], target, ctx, in_shape, 'float32')
+            x = np.random.uniform(size=in_shape).astype("float32")
+            inp = np.random.uniform(size=o_shape).astype("float32")
+            scale = np.random.uniform(size=in_shape[1]).astype("float32")
+            b = np.random.uniform(size=in_shape[1]).astype("float32")
+            mean = np.random.uniform(size=in_shape[1]).astype("float32")
+            var = np.random.uniform(size=in_shape[1]).astype("float32")
+            onnx_out = get_onnxruntime_output(model, [x, inp, scale, b, mean, var], "float32")[0]
+            tvm_out = get_tvm_output(
+                model, [x, inp, scale, b, mean, var], target, ctx, in_shape, "float32"
+            )
             tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
 
     verify_batch_norm_dynamic_subgraph([16, 16, 10, 10], [160, 160])
 
 
-def verify_conv(x_shape, w_shape, y_shape, padding, kernel_shape, strides, dilations, auto_pad="NOTSET", unset_pad=False):
+def verify_conv(
+    x_shape,
+    w_shape,
+    y_shape,
+    padding,
+    kernel_shape,
+    strides,
+    dilations,
+    auto_pad="NOTSET",
+    unset_pad=False,
+):
     if unset_pad:
-        node = helper.make_node('Conv',
-                                inputs=['x', 'W'],
-                                outputs=['y'],
-                                kernel_shape=kernel_shape,
-                                # Default values for other attributes:
-                                strides=strides,
-                                dilations=dilations,
-                                # groups=1
-                                )
+        node = helper.make_node(
+            "Conv",
+            inputs=["x", "W"],
+            outputs=["y"],
+            kernel_shape=kernel_shape,
+            # Default values for other attributes:
+            strides=strides,
+            dilations=dilations,
+            # groups=1
+        )
     elif padding is None:
-        node = helper.make_node('Conv',
-                                inputs=['x', 'W'],
-                                outputs=['y'],
-                                kernel_shape=kernel_shape,
-                                # Default values for other attributes:
-                                strides=strides,
-                                dilations=dilations,
-                                # groups=1
-                                auto_pad=auto_pad)
+        node = helper.make_node(
+            "Conv",
+            inputs=["x", "W"],
+            outputs=["y"],
+            kernel_shape=kernel_shape,
+            # Default values for other attributes:
+            strides=strides,
+            dilations=dilations,
+            # groups=1
+            auto_pad=auto_pad,
+        )
     else:
-        node = helper.make_node('Conv',
-                                inputs=['x', 'W'],
-                                outputs=['y'],
-                                kernel_shape=kernel_shape,
-                                # Default values for other attributes:
-                                strides=strides,
-                                dilations=dilations,
-                                # groups=1
-                                pads=padding)
-
-    graph = helper.make_graph([node],
-                              'conv_test',
-                              inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)),
-                                      helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape))],
-                              outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))])
-
-    model = helper.make_model(graph, producer_name='conv_test')
+        node = helper.make_node(
+            "Conv",
+            inputs=["x", "W"],
+            outputs=["y"],
+            kernel_shape=kernel_shape,
+            # Default values for other attributes:
+            strides=strides,
+            dilations=dilations,
+            # groups=1
+            pads=padding,
+        )
+
+    graph = helper.make_graph(
+        [node],
+        "conv_test",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)),
+            helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape)),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))],
+    )
+
+    model = helper.make_model(graph, producer_name="conv_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        x = np.random.uniform(size=x_shape).astype('float32')
-        W = np.random.uniform(size=w_shape).astype('float32')
+        x = np.random.uniform(size=x_shape).astype("float32")
+        W = np.random.uniform(size=w_shape).astype("float32")
         tvm_out = get_tvm_output(model, [x, W], target, ctx, y_shape)
-        onnx_out = get_onnxruntime_output(model, [x, W], 'float32')[0]
+        onnx_out = get_onnxruntime_output(model, [x, W], "float32")[0]
         tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
 
 
@@ -2291,90 +2338,112 @@ def verify_conv(x_shape, w_shape, y_shape, padding, kernel_shape, strides, dilat
 def test_conv():
     def repeat(N, D):
         return tuple([N for _ in range(D)])
+
     for D in [1, 2, 3]:
         # Convolution with padding
-        verify_conv((1, 1) + repeat(5, D),
-                    (1, 1) + repeat(3, D),
-                    (1, 1) + repeat(5, D),
-                    2 * repeat(1, D),
-                    repeat(3, D),
-                    repeat(1, D),
-                    repeat(1, D))
+        verify_conv(
+            (1, 1) + repeat(5, D),
+            (1, 1) + repeat(3, D),
+            (1, 1) + repeat(5, D),
+            2 * repeat(1, D),
+            repeat(3, D),
+            repeat(1, D),
+            repeat(1, D),
+        )
         # Convolution without padding
-        verify_conv((1, 1) + repeat(5, D),
-                    (1, 1) + repeat(3, D),
-                    (1, 1) + repeat(3, D),
-                    2 * repeat(0, D),
-                    repeat(3, D),
-                    repeat(1, D),
-                    repeat(1, D))
+        verify_conv(
+            (1, 1) + repeat(5, D),
+            (1, 1) + repeat(3, D),
+            (1, 1) + repeat(3, D),
+            2 * repeat(0, D),
+            repeat(3, D),
+            repeat(1, D),
+            repeat(1, D),
+        )
         # Convolution with autopadding
-        verify_conv((1, 1) + repeat(5, D),
-                    (1, 1) + repeat(3, D),
-                    (1, 1) + repeat(5, D),
-                    None,
-                    repeat(3, D),
-                    repeat(1, D),
-                    repeat(1, D),
-                    auto_pad="SAME_UPPER")
+        verify_conv(
+            (1, 1) + repeat(5, D),
+            (1, 1) + repeat(3, D),
+            (1, 1) + repeat(5, D),
+            None,
+            repeat(3, D),
+            repeat(1, D),
+            repeat(1, D),
+            auto_pad="SAME_UPPER",
+        )
         # Convolution with valid autopadding
-        verify_conv((1, 1) + repeat(5, D),
-                    (1, 1) + repeat(3, D),
-                    (1, 1) + repeat(3, D),
-                    None,
-                    repeat(3, D),
-                    repeat(1, D),
-                    repeat(1, D),
-                    auto_pad="VALID")
+        verify_conv(
+            (1, 1) + repeat(5, D),
+            (1, 1) + repeat(3, D),
+            (1, 1) + repeat(3, D),
+            None,
+            repeat(3, D),
+            repeat(1, D),
+            repeat(1, D),
+            auto_pad="VALID",
+        )
         # Convolution with unset padding
-        verify_conv((1, 1) + repeat(5, D),
-                    (1, 1) + repeat(3, D),
-                    (1, 1) + repeat(3, D),
-                    2 * repeat(0, D),
-                    repeat(3, D),
-                    repeat(1, D),
-                    repeat(1, D),
-                    True)
+        verify_conv(
+            (1, 1) + repeat(5, D),
+            (1, 1) + repeat(3, D),
+            (1, 1) + repeat(3, D),
+            2 * repeat(0, D),
+            repeat(3, D),
+            repeat(1, D),
+            repeat(1, D),
+            True,
+        )
         # Convolution with non-uniform stride
-        verify_conv((1, 1) + repeat(5, D),
-                    (1, 1) + repeat(3, D),
-                    (1, 1) + repeat(3, D),
-                    None,
-                    repeat(3, D),
-                    repeat(2, D),
-                    repeat(1, D),
-                    auto_pad="SAME_UPPER")
+        verify_conv(
+            (1, 1) + repeat(5, D),
+            (1, 1) + repeat(3, D),
+            (1, 1) + repeat(3, D),
+            None,
+            repeat(3, D),
+            repeat(2, D),
+            repeat(1, D),
+            auto_pad="SAME_UPPER",
+        )
         # Convolution with dilation
-        verify_conv((1, 1) + repeat(5, D),
-                    (1, 1) + repeat(3, D),
-                    (1, 1) + repeat(5, D),
-                    2 * repeat(2, D),
-                    repeat(3, D),
-                    repeat(1, D),
-                    repeat(2, D))
+        verify_conv(
+            (1, 1) + repeat(5, D),
+            (1, 1) + repeat(3, D),
+            (1, 1) + repeat(5, D),
+            2 * repeat(2, D),
+            repeat(3, D),
+            repeat(1, D),
+            repeat(2, D),
+        )
+
 
 def verify_convtranspose(x_shape, w_shape, y_shape, p):
-    node = onnx.helper.make_node("ConvTranspose",
-                                 inputs=["x", "W"],
-                                 outputs=['y'],
-                                 strides=[3, 2],
-                                 group=1,
-                                 kernel_shape=[3, 3],
-                                 pads=p)
-
-    graph = helper.make_graph([node],
-                              'verify_convtranspose_test',
-                              inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)),
-                                      helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape))],
-                              outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))])
-
-    model = helper.make_model(graph, producer_name='convtranspose_trest')
+    node = onnx.helper.make_node(
+        "ConvTranspose",
+        inputs=["x", "W"],
+        outputs=["y"],
+        strides=[3, 2],
+        group=1,
+        kernel_shape=[3, 3],
+        pads=p,
+    )
+
+    graph = helper.make_graph(
+        [node],
+        "verify_convtranspose_test",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)),
+            helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape)),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))],
+    )
+
+    model = helper.make_model(graph, producer_name="convtranspose_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        x = np.random.uniform(size=x_shape).astype('float32')
-        W = np.random.uniform(size=w_shape).astype('float32')
+        x = np.random.uniform(size=x_shape).astype("float32")
+        W = np.random.uniform(size=w_shape).astype("float32")
         tvm_out = get_tvm_output(model, [x, W], target, ctx, y_shape)
-        onnx_out = get_onnxruntime_output(model, [x, W], 'float32')[0]
+        onnx_out = get_onnxruntime_output(model, [x, W], "float32")[0]
         tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
 
 
@@ -2391,11 +2460,13 @@ def test_convtranspose():
 @tvm.testing.uses_gpu
 def test_unsqueeze_constant():
     from torch.nn import Linear, Sequential, Module
+
     class Flatten(Module):
         def forward(self, input):
             return input.view(input.size(0), -1)
 
     import tempfile
+
     with tempfile.NamedTemporaryFile() as fp:
         file_name = fp.name
         input_size = (1, 16, 32, 32)
@@ -2404,156 +2475,185 @@ def test_unsqueeze_constant():
         torch.onnx.export(layer, dummy_input, file_name, export_params=True)
 
         onnx_model = onnx.load(file_name)
-        relay.frontend.from_onnx(onnx_model, {'0': input_size})
+        relay.frontend.from_onnx(onnx_model, {"0": input_size})
 
 
 def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_pad="NOTSET"):
-    x_np = np.random.uniform(size=x_shape).astype('float32')
+    x_np = np.random.uniform(size=x_shape).astype("float32")
 
-    if mode == 'max':
+    if mode == "max":
         node_type = "MaxPool"
-    elif mode == 'average':
+    elif mode == "average":
         node_type = "AveragePool"
     else:
         raise ValueError("Pool method {} is not supported.".format(mode))
 
     pool_node = helper.make_node(
-        node_type, inputs=["x"], outputs=["y"], kernel_shape=kernel_shape, strides=strides)
+        node_type, inputs=["x"], outputs=["y"], kernel_shape=kernel_shape, strides=strides
+    )
 
     if pads is None:
-        pad_attr = helper.make_attribute('auto_pad', auto_pad)
+        pad_attr = helper.make_attribute("auto_pad", auto_pad)
     else:
-        pad_attr = helper.make_attribute('pads', pads)
+        pad_attr = helper.make_attribute("pads", pads)
     pool_node.attribute.append(pad_attr)
 
-    if mode == 'max':
-        storage_attr = helper.make_attribute('storage_order', 0)
+    if mode == "max":
+        storage_attr = helper.make_attribute("storage_order", 0)
         pool_node.attribute.append(storage_attr)
 
-    graph = helper.make_graph([pool_node],
-                              "pooling_test",
-                              inputs=[helper.make_tensor_value_info("x",
-                                                                    TensorProto.FLOAT, list(x_shape))],
-                              outputs=[helper.make_tensor_value_info("y",
-                                                                     TensorProto.FLOAT, list(out_shape))])
+    graph = helper.make_graph(
+        [pool_node],
+        "pooling_test",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape))],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(out_shape))],
+    )
 
-    model = helper.make_model(graph, producer_name='pooling_test')
+    model = helper.make_model(graph, producer_name="pooling_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        onnx_out = get_onnxruntime_output(model, x_np, 'float32')
-        tvm_out = get_tvm_output(
-            model, [x_np], target, ctx, out_shape)
+        onnx_out = get_onnxruntime_output(model, x_np, "float32")
+        tvm_out = get_tvm_output(model, [x_np], target, ctx, out_shape)
         tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
 
 
 @tvm.testing.uses_gpu
 def test_pooling():
-    for mode in ['max', 'average']:
+    for mode in ["max", "average"]:
         # Pool1D
-        verify_pooling(x_shape=[1, 1, 32],
-                       kernel_shape=[3],
-                       strides=[1],
-                       pads=[1, 1],
-                       out_shape=[1, 1, 32],
-                       mode=mode)
+        verify_pooling(
+            x_shape=[1, 1, 32],
+            kernel_shape=[3],
+            strides=[1],
+            pads=[1, 1],
+            out_shape=[1, 1, 32],
+            mode=mode,
+        )
         # Pool2D
-        verify_pooling(x_shape=[1, 1, 32, 32],
-                       kernel_shape=[3, 3],
-                       strides=[1, 1],
-                       pads=[1, 1, 1, 1],
-                       out_shape=[1, 1, 32, 32],
-                       mode=mode)
+        verify_pooling(
+            x_shape=[1, 1, 32, 32],
+            kernel_shape=[3, 3],
+            strides=[1, 1],
+            pads=[1, 1, 1, 1],
+            out_shape=[1, 1, 32, 32],
+            mode=mode,
+        )
 
         # Pool1D with stride
-        verify_pooling(x_shape=[1, 1, 32],
-                       kernel_shape=[3],
-                       strides=[2],
-                       pads=[1, 1],
-                       out_shape=[1, 1, 16],
-                       mode=mode)
+        verify_pooling(
+            x_shape=[1, 1, 32],
+            kernel_shape=[3],
+            strides=[2],
+            pads=[1, 1],
+            out_shape=[1, 1, 16],
+            mode=mode,
+        )
         # Pool2D with stride
-        verify_pooling(x_shape=[1, 1, 32, 32],
-                       kernel_shape=[3, 3],
-                       strides=[2, 2],
-                       pads=[1, 1, 1, 1],
-                       out_shape=[1, 1, 16, 16],
-                       mode=mode)
+        verify_pooling(
+            x_shape=[1, 1, 32, 32],
+            kernel_shape=[3, 3],
+            strides=[2, 2],
+            pads=[1, 1, 1, 1],
+            out_shape=[1, 1, 16, 16],
+            mode=mode,
+        )
 
         # Pool1D with stride and autopadding
-        verify_pooling(x_shape=[1, 1, 32],
-                       kernel_shape=[3],
-                       strides=[2],
-                       pads=None,
-                       out_shape=[1, 1, 16],
-                       mode=mode,
-                       auto_pad='SAME_UPPER')
+        verify_pooling(
+            x_shape=[1, 1, 32],
+            kernel_shape=[3],
+            strides=[2],
+            pads=None,
+            out_shape=[1, 1, 16],
+            mode=mode,
+            auto_pad="SAME_UPPER",
+        )
         # Pool2D with stride and autopadding
-        verify_pooling(x_shape=[1, 1, 32, 32],
-                       kernel_shape=[3, 3],
-                       strides=[2, 2],
-                       pads=None,
-                       out_shape=[1, 1, 16, 16],
-                       mode=mode,
-                       auto_pad='SAME_UPPER')
+        verify_pooling(
+            x_shape=[1, 1, 32, 32],
+            kernel_shape=[3, 3],
+            strides=[2, 2],
+            pads=None,
+            out_shape=[1, 1, 16, 16],
+            mode=mode,
+            auto_pad="SAME_UPPER",
+        )
 
         # Pool3D with stride
-        verify_pooling(x_shape=[1, 1, 32, 32, 32],
-                       kernel_shape=[3, 3, 3],
-                       strides=[2, 2, 2],
-                       pads=[1, 1, 1, 1, 1, 1],
-                       out_shape=[1, 1, 16, 16, 16],
-                       mode=mode)
+        verify_pooling(
+            x_shape=[1, 1, 32, 32, 32],
+            kernel_shape=[3, 3, 3],
+            strides=[2, 2, 2],
+            pads=[1, 1, 1, 1, 1, 1],
+            out_shape=[1, 1, 16, 16, 16],
+            mode=mode,
+        )
 
         # Pool3D with stride and autopadding
-        verify_pooling(x_shape=[1, 1, 32, 32, 32],
-                       kernel_shape=[3, 3, 3],
-                       strides=[2, 2, 2],
-                       pads=None,
-                       out_shape=[1, 1, 16, 16, 16],
-                       mode=mode,
-                       auto_pad='SAME_UPPER')
+        verify_pooling(
+            x_shape=[1, 1, 32, 32, 32],
+            kernel_shape=[3, 3, 3],
+            strides=[2, 2, 2],
+            pads=None,
+            out_shape=[1, 1, 16, 16, 16],
+            mode=mode,
+            auto_pad="SAME_UPPER",
+        )
 
 
-def verify_mod(x_shape, y_shape, fmod, out_shape, dtype='float32'):
+def verify_mod(x_shape, y_shape, fmod, out_shape, dtype="float32"):
     x_np = np.random.uniform(-100.0, 100.0, x_shape).astype(dtype)
     y_np = np.random.uniform(-100.0, 100.0, y_shape).astype(dtype)
-    y_np = np.where(y_np==0, 1, y_np) #remove 0's to avoid division by zero error
+    y_np = np.where(y_np == 0, 1, y_np)  # remove 0's to avoid division by zero error
 
-    mod_node = helper.make_node("Mod",
-                                inputs=["x", "y"],
-                                outputs=["z"],
-                                fmod=fmod)
+    mod_node = helper.make_node("Mod", inputs=["x", "y"], outputs=["z"], fmod=fmod)
 
     onnx_dtype = TensorProto.FLOAT if dtype == "float32" else TensorProto.INT32
-    graph = helper.make_graph([mod_node],
-                              "mod_test",
-                              inputs=[helper.make_tensor_value_info("x",
-                                                                    onnx_dtype, list(x_shape)),
-                                      helper.make_tensor_value_info("y",
-                                                                    onnx_dtype, list(y_shape))],
-                              outputs=[helper.make_tensor_value_info("z",
-                                                                    onnx_dtype, list(out_shape))])
-    model = helper.make_model(graph, producer_name='mod_test')
+    graph = helper.make_graph(
+        [mod_node],
+        "mod_test",
+        inputs=[
+            helper.make_tensor_value_info("x", onnx_dtype, list(x_shape)),
+            helper.make_tensor_value_info("y", onnx_dtype, list(y_shape)),
+        ],
+        outputs=[helper.make_tensor_value_info("z", onnx_dtype, list(out_shape))],
+    )
+    model = helper.make_model(graph, producer_name="mod_test")
 
     onnx_out = get_onnxruntime_output(model, [x_np, y_np], dtype)[0]
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, [x_np, y_np], target, ctx, out_shape)
+        tvm_out = get_tvm_output(model, [x_np, y_np], target, ctx, out_shape)
         tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
 
 
 @tvm.testing.uses_gpu
 def test_mod():
     # Mod
-    verify_mod(x_shape=[1, 32, 32], y_shape=[1, 1, 32], fmod=0, out_shape=(1, 32, 32), dtype="int32")
-    verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 32, 32, 32], fmod=0, out_shape=(1, 32, 32, 32), dtype="int32")
+    verify_mod(
+        x_shape=[1, 32, 32], y_shape=[1, 1, 32], fmod=0, out_shape=(1, 32, 32), dtype="int32"
+    )
+    verify_mod(
+        x_shape=[1, 32, 32, 32],
+        y_shape=[1, 32, 32, 32],
+        fmod=0,
+        out_shape=(1, 32, 32, 32),
+        dtype="int32",
+    )
 
     # fmod
-    verify_mod(x_shape=[1, 32, 32], y_shape=[1, 32, 32], fmod=1, out_shape=(1, 32, 32), dtype="int32")
+    verify_mod(
+        x_shape=[1, 32, 32], y_shape=[1, 32, 32], fmod=1, out_shape=(1, 32, 32), dtype="int32"
+    )
     verify_mod(x_shape=[1, 1, 32, 32], y_shape=[1, 32, 32, 32], fmod=1, out_shape=(1, 32, 32, 32))
     verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 1, 32, 32], fmod=1, out_shape=(1, 32, 32, 32))
-    verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 32, 32, 32], fmod=1, out_shape=(1, 32, 32, 32), dtype="int32")
+    verify_mod(
+        x_shape=[1, 32, 32, 32],
+        y_shape=[1, 32, 32, 32],
+        fmod=1,
+        out_shape=(1, 32, 32, 32),
+        dtype="int32",
+    )
     verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 32, 32, 32], fmod=1, out_shape=(1, 32, 32, 32))
 
 
@@ -2564,24 +2664,22 @@ def verify_xor(x_shape, y_shape):
     np_out = np.logical_xor(x_np, y_np)
     out_shape = np_out.shape
 
-    xor_node = helper.make_node("Xor",
-                                inputs=["x", "y"],
-                                outputs=["z"])
+    xor_node = helper.make_node("Xor", inputs=["x", "y"], outputs=["z"])
 
     onnx_dtype = TensorProto.BOOL
-    graph = helper.make_graph([xor_node],
-                              "xor_test",
-                              inputs=[helper.make_tensor_value_info("x",
-                                                                    onnx_dtype, list(x_shape)),
-                                      helper.make_tensor_value_info("y",
-                                                                    onnx_dtype, list(y_shape))],
-                              outputs=[helper.make_tensor_value_info("z",
-                                                                    onnx_dtype, list(out_shape))])
-    model = helper.make_model(graph, producer_name='xor_test')
+    graph = helper.make_graph(
+        [xor_node],
+        "xor_test",
+        inputs=[
+            helper.make_tensor_value_info("x", onnx_dtype, list(x_shape)),
+            helper.make_tensor_value_info("y", onnx_dtype, list(y_shape)),
+        ],
+        outputs=[helper.make_tensor_value_info("z", onnx_dtype, list(out_shape))],
+    )
+    model = helper.make_model(graph, producer_name="xor_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, [x_np, y_np], target, ctx, out_shape)
+        tvm_out = get_tvm_output(model, [x_np, y_np], target, ctx, out_shape)
         tvm.testing.assert_allclose(np_out, tvm_out, rtol=1e-5, atol=1e-5)
 
 
@@ -2595,199 +2693,248 @@ def test_xor():
 
 
 def verify_max_roi_pool(x_shape, rois_shape, pooled_shape, spatial_scale, out_shape):
-    x_np = np.random.uniform(size=x_shape).astype('float32')
-    rois_np = np.random.uniform(size=rois_shape).astype('float32')
+    x_np = np.random.uniform(size=x_shape).astype("float32")
+    rois_np = np.random.uniform(size=rois_shape).astype("float32")
 
     if spatial_scale is None:
-        pool_node = helper.make_node("MaxRoiPool",
-                                     inputs=["x", "rois"],
-                                     outputs=["y"],
-                                     pooled_shape=pooled_shape)
+        pool_node = helper.make_node(
+            "MaxRoiPool", inputs=["x", "rois"], outputs=["y"], pooled_shape=pooled_shape
+        )
     else:
-        pool_node = helper.make_node("MaxRoiPool",
-                                     inputs=["x", "rois"],
-                                     outputs=["y"],
-                                     pooled_shape=pooled_shape,
-                                     spatial_scale=spatial_scale)
-
-    graph = helper.make_graph([pool_node],
-                              "pool_test",
-                              inputs=[helper.make_tensor_value_info("x",
-                                                                    TensorProto.FLOAT, list(x_shape)),
-                                      helper.make_tensor_value_info("rois",
-                                                                    TensorProto.FLOAT, list(rois_shape))],
-                              outputs=[helper.make_tensor_value_info("y",
-                                                                     TensorProto.FLOAT, list(out_shape))])
-
-    model = helper.make_model(graph, producer_name='pool_test')
-
-    onnx_out = get_onnxruntime_output(model, [x_np, rois_np], 'float32')[0]
+        pool_node = helper.make_node(
+            "MaxRoiPool",
+            inputs=["x", "rois"],
+            outputs=["y"],
+            pooled_shape=pooled_shape,
+            spatial_scale=spatial_scale,
+        )
+
+    graph = helper.make_graph(
+        [pool_node],
+        "pool_test",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)),
+            helper.make_tensor_value_info("rois", TensorProto.FLOAT, list(rois_shape)),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(out_shape))],
+    )
+
+    model = helper.make_model(graph, producer_name="pool_test")
+
+    onnx_out = get_onnxruntime_output(model, [x_np, rois_np], "float32")[0]
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, [x_np, rois_np], target, ctx, out_shape)
+        tvm_out = get_tvm_output(model, [x_np, rois_np], target, ctx, out_shape)
         tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
 
 
 @tvm.testing.uses_gpu
 def test_max_roi_pool():
-    verify_max_roi_pool(x_shape=[1, 3, 6, 6],
-                        rois_shape=[3, 5],
-                        pooled_shape=[1, 1],
-                        spatial_scale=None,
-                        out_shape=[3, 3, 1, 1])
+    verify_max_roi_pool(
+        x_shape=[1, 3, 6, 6],
+        rois_shape=[3, 5],
+        pooled_shape=[1, 1],
+        spatial_scale=None,
+        out_shape=[3, 3, 1, 1],
+    )
 
-    verify_max_roi_pool(x_shape=[1, 3, 10, 10],
-                        rois_shape=[4, 5],
-                        pooled_shape=[2, 2],
-                        spatial_scale=2.0,
-                        out_shape=[4, 3, 2, 2])
+    verify_max_roi_pool(
+        x_shape=[1, 3, 10, 10],
+        rois_shape=[4, 5],
+        pooled_shape=[2, 2],
+        spatial_scale=2.0,
+        out_shape=[4, 3, 2, 2],
+    )
 
 
 def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad="NOTSET"):
-    x_np = np.random.uniform(size=x_shape).astype('float32')
+    x_np = np.random.uniform(size=x_shape).astype("float32")
 
     if pads is None:
-        pool_node = helper.make_node("LpPool",
-                                    inputs=["x"],
-                                    outputs=["y"],
-                                    kernel_shape=kernel_shape,
-                                    p = p,
-                                    auto_pad=auto_pad,
-                                    strides=strides)
+        pool_node = helper.make_node(
+            "LpPool",
+            inputs=["x"],
+            outputs=["y"],
+            kernel_shape=kernel_shape,
+            p=p,
+            auto_pad=auto_pad,
+            strides=strides,
+        )
     else:
-        pool_node = helper.make_node("LpPool",
-                                    inputs=["x"],
-                                    outputs=["y"],
-                                    kernel_shape=kernel_shape,
-                                    p = p,
-                                    pads=pads,
-                                    strides=strides)
-
-    graph = helper.make_graph([pool_node],
-                              "lppool_test",
-                              inputs=[helper.make_tensor_value_info("x",
-                                                                    TensorProto.FLOAT, list(x_shape))],
-                              outputs=[helper.make_tensor_value_info("y",
-                                                                     TensorProto.FLOAT, list(out_shape))])
-
-    model = helper.make_model(graph, producer_name='lppool_test')
+        pool_node = helper.make_node(
+            "LpPool",
+            inputs=["x"],
+            outputs=["y"],
+            kernel_shape=kernel_shape,
+            p=p,
+            pads=pads,
+            strides=strides,
+        )
+
+    graph = helper.make_graph(
+        [pool_node],
+        "lppool_test",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape))],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(out_shape))],
+    )
+
+    model = helper.make_model(graph, producer_name="lppool_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        onnx_out = get_onnxruntime_output(model, x_np, 'float32')
-        tvm_out = get_tvm_output(
-            model, [x_np], target, ctx, out_shape)
+        onnx_out = get_onnxruntime_output(model, x_np, "float32")
+        tvm_out = get_tvm_output(model, [x_np], target, ctx, out_shape)
         tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
 
 
 @tvm.testing.uses_gpu
 def test_lppool():
     # Pool1D
-    verify_lppool(x_shape=[1, 1, 32], kernel_shape=[3], p=2, strides=[1], pads=[1, 1],
-                  out_shape=[1, 1, 32])
+    verify_lppool(
+        x_shape=[1, 1, 32], kernel_shape=[3], p=2, strides=[1], pads=[1, 1], out_shape=[1, 1, 32]
+    )
 
     # Pool2D
-    verify_lppool(x_shape=[1, 1, 32, 32], kernel_shape=[3, 3], p=2, strides=[1, 1],
-                  pads=[1, 1, 1, 1], out_shape=[1, 1, 32, 32])
+    verify_lppool(
+        x_shape=[1, 1, 32, 32],
+        kernel_shape=[3, 3],
+        p=2,
+        strides=[1, 1],
+        pads=[1, 1, 1, 1],
+        out_shape=[1, 1, 32, 32],
+    )
 
     # Pool1D with stride
-    verify_lppool(x_shape=[1, 1, 32], kernel_shape=[3], p=2, strides=[2], pads=[1, 1],
-                  out_shape=[1, 1, 16])
+    verify_lppool(
+        x_shape=[1, 1, 32], kernel_shape=[3], p=2, strides=[2], pads=[1, 1], out_shape=[1, 1, 16]
+    )
 
     # Pool2D with stride
-    verify_lppool(x_shape=[1, 1, 32, 32], kernel_shape=[3, 3], p=2, strides=[2, 2],
-                  pads=[1, 1, 1, 1], out_shape=[1, 1, 16, 16])
+    verify_lppool(
+        x_shape=[1, 1, 32, 32],
+        kernel_shape=[3, 3],
+        p=2,
+        strides=[2, 2],
+        pads=[1, 1, 1, 1],
+        out_shape=[1, 1, 16, 16],
+    )
 
     # Pool1D with stride and autopadding
-    verify_lppool(x_shape=[1, 1, 32], kernel_shape=[3], p=2, strides=[2], pads=None,
-                  out_shape=[1, 1, 16], auto_pad='SAME_UPPER')
+    verify_lppool(
+        x_shape=[1, 1, 32],
+        kernel_shape=[3],
+        p=2,
+        strides=[2],
+        pads=None,
+        out_shape=[1, 1, 16],
+        auto_pad="SAME_UPPER",
+    )
 
     # Pool2D with stride and autopadding
-    verify_lppool(x_shape=[1, 1, 32, 32], kernel_shape=[3, 3], p=2, strides=[2, 2],
-                  pads=None, out_shape=[1, 1, 16, 16], auto_pad='SAME_UPPER')
+    verify_lppool(
+        x_shape=[1, 1, 32, 32],
+        kernel_shape=[3, 3],
+        p=2,
+        strides=[2, 2],
+        pads=None,
+        out_shape=[1, 1, 16, 16],
+        auto_pad="SAME_UPPER",
+    )
 
     # Pool3D with stride
-    verify_lppool(x_shape=[1, 1, 32, 32, 32], kernel_shape=[3, 3, 3], p=2, strides=[2, 2, 2],
-                  pads=[1, 1, 1, 1, 1, 1], out_shape=[1, 1, 16, 16, 16])
+    verify_lppool(
+        x_shape=[1, 1, 32, 32, 32],
+        kernel_shape=[3, 3, 3],
+        p=2,
+        strides=[2, 2, 2],
+        pads=[1, 1, 1, 1, 1, 1],
+        out_shape=[1, 1, 16, 16, 16],
+    )
 
     # Pool3D with stride and autopadding
-    verify_lppool(x_shape=[1, 1, 32, 32, 32], kernel_shape=[3, 3, 3], p=2, strides=[2, 2, 2],
-                  pads=None, out_shape=[1, 1, 16, 16, 16], auto_pad='SAME_UPPER')
-
-
-def verify_rnn(seq_length,
-               batch_size,
-               input_size,
-               hidden_size,
-               rnn_type='LSTM',
-               use_bias=False,
-               activations=None,
-               alphas=None,
-               betas=None,
-               use_initial_state=False,
-               use_peep=False,
-               linear_before_reset=False):
-    if rnn_type == 'LSTM':
+    verify_lppool(
+        x_shape=[1, 1, 32, 32, 32],
+        kernel_shape=[3, 3, 3],
+        p=2,
+        strides=[2, 2, 2],
+        pads=None,
+        out_shape=[1, 1, 16, 16, 16],
+        auto_pad="SAME_UPPER",
+    )
+
+
+def verify_rnn(
+    seq_length,
+    batch_size,
+    input_size,
+    hidden_size,
+    rnn_type="LSTM",
+    use_bias=False,
+    activations=None,
+    alphas=None,
+    betas=None,
+    use_initial_state=False,
+    use_peep=False,
+    linear_before_reset=False,
+):
+    if rnn_type == "LSTM":
         multiplier = 4
-    elif rnn_type == 'GRU':
+    elif rnn_type == "GRU":
         multiplier = 3
     else:
         raise NotImplementedError("%s RNNs not yet supported." % rnn_type)
-    x_np = np.random.uniform(size=(seq_length, batch_size,
-                                   input_size)).astype('float32')
-    w_np = np.random.uniform(size=(1, multiplier * hidden_size,
-                                   input_size)).astype('float32')
-    r_np = np.random.uniform(size=(1, multiplier * hidden_size,
-                                   hidden_size)).astype('float32')
+    x_np = np.random.uniform(size=(seq_length, batch_size, input_size)).astype("float32")
+    w_np = np.random.uniform(size=(1, multiplier * hidden_size, input_size)).astype("float32")
+    r_np = np.random.uniform(size=(1, multiplier * hidden_size, hidden_size)).astype("float32")
     input_names = ["X", "W", "R"]
     input_tensors = [
         helper.make_tensor_value_info("X", TensorProto.FLOAT, list(x_np.shape)),
         helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_np.shape)),
-        helper.make_tensor_value_info("R", TensorProto.FLOAT, list(r_np.shape))
+        helper.make_tensor_value_info("R", TensorProto.FLOAT, list(r_np.shape)),
     ]
     input_values = [x_np, w_np, r_np]
 
     if use_bias:
-        b_np = np.random.uniform(size=(1, multiplier * 2 *
-                                       hidden_size)).astype('float32')
+        b_np = np.random.uniform(size=(1, multiplier * 2 * hidden_size)).astype("float32")
         input_names.append("B")
         input_tensors.append(
-            helper.make_tensor_value_info("B", TensorProto.FLOAT,
-                                          [1, multiplier * 2 * hidden_size]))
+            helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, multiplier * 2 * hidden_size])
+        )
         input_values.append(b_np)
 
     if use_initial_state:
         assert use_bias == True, "Initial states must have bias specified."
-        sequence_np = np.repeat(seq_length, batch_size).astype('int32')
+        sequence_np = np.repeat(seq_length, batch_size).astype("int32")
         input_names.append("sequence_lens")
         input_tensors.append(
-            helper.make_tensor_value_info("sequence_lens", TensorProto.INT32,
-                                          [batch_size]))
+            helper.make_tensor_value_info("sequence_lens", TensorProto.INT32, [batch_size])
+        )
         input_values.append(sequence_np)
 
-        initial_h_np = np.random.uniform(size=(1, batch_size,
-                                               hidden_size)).astype('float32')
+        initial_h_np = np.random.uniform(size=(1, batch_size, hidden_size)).astype("float32")
         input_names.append("initial_h")
         input_tensors.append(
-            helper.make_tensor_value_info("initial_h", TensorProto.FLOAT,
-                                          [1, batch_size, hidden_size]))
+            helper.make_tensor_value_info(
+                "initial_h", TensorProto.FLOAT, [1, batch_size, hidden_size]
+            )
+        )
         input_values.append(initial_h_np)
 
-        if rnn_type == 'LSTM':
-            initial_c_np = np.random.uniform(
-                size=(1, batch_size, hidden_size)).astype('float32')
+        if rnn_type == "LSTM":
+            initial_c_np = np.random.uniform(size=(1, batch_size, hidden_size)).astype("float32")
             input_names.append("initial_c")
             input_tensors.append(
-                helper.make_tensor_value_info("initial_c", TensorProto.FLOAT,
-                                              [1, batch_size, hidden_size]))
+                helper.make_tensor_value_info(
+                    "initial_c", TensorProto.FLOAT, [1, batch_size, hidden_size]
+                )
+            )
             input_values.append(initial_c_np)
 
-    if use_peep and rnn_type == 'LSTM':
+    if use_peep and rnn_type == "LSTM":
         assert use_initial_state == True, "Peepholes require initial state to be specified."
-        p_np = np.random.uniform(size=(1, 3 * hidden_size)).astype('float32')
+        p_np = np.random.uniform(size=(1, 3 * hidden_size)).astype("float32")
         input_names.append("P")
         input_tensors.append(
-            helper.make_tensor_value_info("P", TensorProto.FLOAT,
-                                          [1, 3 * hidden_size]))
+            helper.make_tensor_value_info("P", TensorProto.FLOAT, [1, 3 * hidden_size])
+        )
         input_values.append(p_np)
 
     Y_shape = [seq_length, 1, batch_size, hidden_size]
@@ -2795,49 +2942,48 @@ def verify_rnn(seq_length,
     outputs = ["Y", "Y_h"]
     graph_outputs = [
         helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(Y_shape)),
-        helper.make_tensor_value_info("Y_h", TensorProto.FLOAT, list(Y_h_shape))
+        helper.make_tensor_value_info("Y_h", TensorProto.FLOAT, list(Y_h_shape)),
     ]
     output_shapes = [Y_shape, Y_h_shape]
 
-    if rnn_type == 'LSTM':
+    if rnn_type == "LSTM":
         Y_c_shape = [1, batch_size, hidden_size]
         outputs.append("Y_c")
         graph_outputs.append(
-            helper.make_tensor_value_info("Y_c", TensorProto.FLOAT,
-                                          list(Y_c_shape)))
+            helper.make_tensor_value_info("Y_c", TensorProto.FLOAT, list(Y_c_shape))
+        )
         output_shapes.append(Y_c_shape)
 
     rnn_node = helper.make_node(
-        rnn_type, inputs=input_names, outputs=outputs, hidden_size=hidden_size)
+        rnn_type, inputs=input_names, outputs=outputs, hidden_size=hidden_size
+    )
     if activations is not None:
-        activations_attr = helper.make_attribute('activations', activations)
+        activations_attr = helper.make_attribute("activations", activations)
         rnn_node.attribute.append(activations_attr)
     if alphas is not None:
-        alphas_attr = helper.make_attribute('activation_alpha', alphas)
+        alphas_attr = helper.make_attribute("activation_alpha", alphas)
         rnn_node.attribute.append(alphas_attr)
     if betas is not None:
-        betas_attr = helper.make_attribute('activation_beta', betas)
+        betas_attr = helper.make_attribute("activation_beta", betas)
         rnn_node.attribute.append(betas_attr)
-    if linear_before_reset and rnn_type == 'GRU':
-        lbr_attr = helper.make_attribute('linear_before_reset', 1)
+    if linear_before_reset and rnn_type == "GRU":
+        lbr_attr = helper.make_attribute("linear_before_reset", 1)
         rnn_node.attribute.append(lbr_attr)
 
-    graph = helper.make_graph([rnn_node],
-                              "rnn_test",
-                              inputs=input_tensors,
-                              outputs=graph_outputs)
+    graph = helper.make_graph([rnn_node], "rnn_test", inputs=input_tensors, outputs=graph_outputs)
 
-    model = helper.make_model(graph, producer_name='rnn_test')
+    model = helper.make_model(graph, producer_name="rnn_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        onnx_out = get_onnxruntime_output(model, input_values, 'float32')
+        onnx_out = get_onnxruntime_output(model, input_values, "float32")
         tvm_out = get_tvm_output(
             model,
             input_values,
             target,
             ctx,
             output_shapes,
-            output_dtype=['float32'] * len(output_shapes))
+            output_dtype=["float32"] * len(output_shapes),
+        )
         for o_out, t_out in zip(onnx_out, tvm_out):
             tvm.testing.assert_allclose(o_out, t_out, rtol=5e-3, atol=5e-3)
 
@@ -2846,52 +2992,28 @@ def verify_rnn(seq_length,
 def test_lstm():
     # No bias.
     verify_rnn(
-        seq_length=2,
-        batch_size=1,
-        input_size=16,
-        hidden_size=32,
-        use_bias=False,
-        rnn_type='LSTM')
+        seq_length=2, batch_size=1, input_size=16, hidden_size=32, use_bias=False, rnn_type="LSTM"
+    )
     # large batch.
     verify_rnn(
-        seq_length=4,
-        batch_size=8,
-        input_size=16,
-        hidden_size=32,
-        use_bias=True,
-        rnn_type='LSTM')
+        seq_length=4, batch_size=8, input_size=16, hidden_size=32, use_bias=True, rnn_type="LSTM"
+    )
     # Non power of two.
     verify_rnn(
-        seq_length=3,
-        batch_size=3,
-        input_size=16,
-        hidden_size=40,
-        use_bias=True,
-        rnn_type='LSTM')
+        seq_length=3, batch_size=3, input_size=16, hidden_size=40, use_bias=True, rnn_type="LSTM"
+    )
     # Long sequence.
     verify_rnn(
-        seq_length=8,
-        batch_size=1,
-        input_size=16,
-        hidden_size=32,
-        use_bias=True,
-        rnn_type='LSTM')
+        seq_length=8, batch_size=1, input_size=16, hidden_size=32, use_bias=True, rnn_type="LSTM"
+    )
     # Large hidden.
     verify_rnn(
-        seq_length=2,
-        batch_size=1,
-        input_size=16,
-        hidden_size=128,
-        use_bias=True,
-        rnn_type='LSTM')
+        seq_length=2, batch_size=1, input_size=16, hidden_size=128, use_bias=True, rnn_type="LSTM"
+    )
     # Large input.
     verify_rnn(
-        seq_length=2,
-        batch_size=1,
-        input_size=64,
-        hidden_size=32,
-        use_bias=True,
-        rnn_type='LSTM')
+        seq_length=2, batch_size=1, input_size=64, hidden_size=32, use_bias=True, rnn_type="LSTM"
+    )
 
     # Different activation testing.
     # Default value hardsigmoid.
@@ -2901,8 +3023,9 @@ def test_lstm():
         input_size=16,
         hidden_size=32,
         use_bias=False,
-        activations=['HardSigmoid', 'Tanh', 'Tanh'],
-        rnn_type='LSTM')
+        activations=["HardSigmoid", "Tanh", "Tanh"],
+        rnn_type="LSTM",
+    )
     # Multiple parameterized activations.
     verify_rnn(
         seq_length=2,
@@ -2910,10 +3033,11 @@ def test_lstm():
         input_size=16,
         hidden_size=32,
         use_bias=False,
-        activations=['HardSigmoid', 'LeakyRelu', 'Tanh'],
+        activations=["HardSigmoid", "LeakyRelu", "Tanh"],
         alphas=[2.0, 0.5],
-        betas=[.3],
-        rnn_type='LSTM')
+        betas=[0.3],
+        rnn_type="LSTM",
+    )
     # All parameterized with new Affine activation.
     verify_rnn(
         seq_length=2,
@@ -2921,10 +3045,11 @@ def test_lstm():
         input_size=16,
         hidden_size=32,
         use_bias=False,
-        activations=['HardSigmoid', 'LeakyRelu', 'Affine'],
+        activations=["HardSigmoid", "LeakyRelu", "Affine"],
         alphas=[2.0, 0.5, 0.8],
-        betas=[.3, 0.1],
-        rnn_type='LSTM')
+        betas=[0.3, 0.1],
+        rnn_type="LSTM",
+    )
 
     # Testing with initial state and peepholes
     verify_rnn(
@@ -2934,7 +3059,8 @@ def test_lstm():
         hidden_size=32,
         use_bias=True,
         use_initial_state=True,
-        rnn_type='LSTM')
+        rnn_type="LSTM",
+    )
 
     verify_rnn(
         seq_length=2,
@@ -2944,19 +3070,16 @@ def test_lstm():
         use_bias=True,
         use_initial_state=True,
         use_peep=True,
-        rnn_type='LSTM')
+        rnn_type="LSTM",
+    )
 
 
 @tvm.testing.uses_gpu
 def test_gru():
     # No bias.
     verify_rnn(
-        seq_length=2,
-        batch_size=1,
-        input_size=16,
-        hidden_size=32,
-        use_bias=False,
-        rnn_type='GRU')
+        seq_length=2, batch_size=1, input_size=16, hidden_size=32, use_bias=False, rnn_type="GRU"
+    )
     # large batch.
     verify_rnn(
         seq_length=4,
@@ -2964,40 +3087,25 @@ def test_gru():
         input_size=16,
         hidden_size=32,
         use_bias=True,
-        rnn_type='GRU',
-        linear_before_reset=True)
+        rnn_type="GRU",
+        linear_before_reset=True,
+    )
     # Non power of two.
     verify_rnn(
-        seq_length=3,
-        batch_size=3,
-        input_size=16,
-        hidden_size=40,
-        use_bias=True,
-        rnn_type='GRU')
+        seq_length=3, batch_size=3, input_size=16, hidden_size=40, use_bias=True, rnn_type="GRU"
+    )
     # Long sequence.
     verify_rnn(
-        seq_length=8,
-        batch_size=1,
-        input_size=16,
-        hidden_size=32,
-        use_bias=True,
-        rnn_type='GRU')
+        seq_length=8, batch_size=1, input_size=16, hidden_size=32, use_bias=True, rnn_type="GRU"
+    )
     # Large hidden.
     verify_rnn(
-        seq_length=2,
-        batch_size=1,
-        input_size=16,
-        hidden_size=128,
-        use_bias=True,
-        rnn_type='GRU')
+        seq_length=2, batch_size=1, input_size=16, hidden_size=128, use_bias=True, rnn_type="GRU"
+    )
     # Large input.
     verify_rnn(
-        seq_length=2,
-        batch_size=1,
-        input_size=64,
-        hidden_size=32,
-        use_bias=True,
-        rnn_type='GRU')
+        seq_length=2, batch_size=1, input_size=64, hidden_size=32, use_bias=True, rnn_type="GRU"
+    )
 
     # Different activation testing.
     # Default value hardsigmoid.
@@ -3007,8 +3115,9 @@ def test_gru():
         input_size=16,
         hidden_size=32,
         use_bias=False,
-        activations=['HardSigmoid', 'Softsign'],
-        rnn_type='GRU')
+        activations=["HardSigmoid", "Softsign"],
+        rnn_type="GRU",
+    )
     # Multiple parameterized activations.
     verify_rnn(
         seq_length=2,
@@ -3016,10 +3125,11 @@ def test_gru():
         input_size=16,
         hidden_size=32,
         use_bias=False,
-        activations=['HardSigmoid', 'LeakyRelu'],
+        activations=["HardSigmoid", "LeakyRelu"],
         alphas=[2.0, 0.5],
-        betas=[.3],
-        rnn_type='GRU')
+        betas=[0.3],
+        rnn_type="GRU",
+    )
     # All parameterized with new Affine activation.
     verify_rnn(
         seq_length=2,
@@ -3027,10 +3137,11 @@ def test_gru():
         input_size=16,
         hidden_size=32,
         use_bias=False,
-        activations=['HardSigmoid', 'Affine'],
+        activations=["HardSigmoid", "Affine"],
         alphas=[2.0, 0.8],
-        betas=[.3, 0.1],
-        rnn_type='GRU')
+        betas=[0.3, 0.1],
+        rnn_type="GRU",
+    )
 
     # Testing with initial state
     verify_rnn(
@@ -3040,42 +3151,49 @@ def test_gru():
         hidden_size=32,
         use_bias=True,
         use_initial_state=True,
-        rnn_type='GRU')
+        rnn_type="GRU",
+    )
 
 
 @tvm.testing.uses_gpu
 def test_resize():
     def verify(ishape, oshape, scales, mode, coord_trans):
         nodes = [
-            make_constant_node('roi', onnx.TensorProto.FLOAT, (0,), []),
-            make_constant_node('scales', onnx.TensorProto.FLOAT, (len(scales),), scales)
+            make_constant_node("roi", onnx.TensorProto.FLOAT, (0,), []),
+            make_constant_node("scales", onnx.TensorProto.FLOAT, (len(scales),), scales),
         ]
-        input_names = ['X', 'roi', 'scales']
+        input_names = ["X", "roi", "scales"]
         if oshape != []:
-            nodes.append(make_constant_node('sizes', onnx.TensorProto.INT64, (len(oshape),), oshape))
-            input_names.append('sizes')
-        nodes.append(helper.make_node(
-            'Resize',
-            inputs=input_names,
-            outputs=['Y'],
-            mode=mode,
-            coordinate_transformation_mode=coord_trans
-        ))
+            nodes.append(
+                make_constant_node("sizes", onnx.TensorProto.INT64, (len(oshape),), oshape)
+            )
+            input_names.append("sizes")
+        nodes.append(
+            helper.make_node(
+                "Resize",
+                inputs=input_names,
+                outputs=["Y"],
+                mode=mode,
+                coordinate_transformation_mode=coord_trans,
+            )
+        )
 
         if oshape == []:
             oshape = [round(dim * scale) for (dim, scale) in zip(ishape, scales)]
 
-        graph = helper.make_graph(nodes,
-                                  "resize_test",
-                                  inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, ishape)],
-                                  outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, oshape)])
+        graph = helper.make_graph(
+            nodes,
+            "resize_test",
+            inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, ishape)],
+            outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, oshape)],
+        )
 
-        model = helper.make_model(graph, producer_name='resize_test')
+        model = helper.make_model(graph, producer_name="resize_test")
 
         for target, ctx in tvm.testing.enabled_targets():
-            x = np.random.uniform(size=ishape).astype('float32')
-            onnx_out = get_onnxruntime_output(model, x, 'float32')
-            tvm_out = get_tvm_output(model, x, target, ctx, oshape, 'float32', opset=11)
+            x = np.random.uniform(size=ishape).astype("float32")
+            onnx_out = get_onnxruntime_output(model, x, "float32")
+            tvm_out = get_tvm_output(model, x, target, ctx, oshape, "float32", opset=11)
 
             tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05)
 
@@ -3094,22 +3212,25 @@ def test_resize():
 
 @tvm.testing.uses_gpu
 def test_nonzero():
-
     def verify_nonzero(indata, outdata, dtype):
-        node = helper.make_node('NonZero',
-                                inputs=['X'],
-                                outputs=['Y'],)
+        node = helper.make_node(
+            "NonZero",
+            inputs=["X"],
+            outputs=["Y"],
+        )
 
-        graph = helper.make_graph([node],
-                                  "nonzero_test",
-                                  inputs=[helper.make_tensor_value_info("X", TensorProto.INT64, list(indata.shape))],
-                                  outputs=[helper.make_tensor_value_info("Y", TensorProto.INT64, list(outdata.shape))])
+        graph = helper.make_graph(
+            [node],
+            "nonzero_test",
+            inputs=[helper.make_tensor_value_info("X", TensorProto.INT64, list(indata.shape))],
+            outputs=[helper.make_tensor_value_info("Y", TensorProto.INT64, list(outdata.shape))],
+        )
 
-        model = helper.make_model(graph, producer_name='nonzero_test')
+        model = helper.make_model(graph, producer_name="nonzero_test")
 
         onnx_out = get_onnxruntime_output(model, indata, dtype)
 
-        for target, ctx in [('llvm', tvm.cpu())]:
+        for target, ctx in [("llvm", tvm.cpu())]:
             tvm_out = get_tvm_output_with_vm(model, indata, target, ctx, opset=9)
             tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05)
 
@@ -3121,33 +3242,51 @@ def test_nonzero():
     result = np.array((np.nonzero(input_data)))  # expected output [[0, 1, 2, 2], [0, 1, 0, 1]]
     verify_nonzero(input_data, result, dtype=np.int64)
 
+
 @tvm.testing.uses_gpu
 def test_topk():
     def verify_topk(input_dims, K, axis=-1):
         output_dims = list(input_dims)
         output_dims[axis] = K
 
-        node = helper.make_node('TopK',
-                                inputs=['X', 'K'],
-                                outputs=['Values', 'Indicies'],
-                                axis=axis)
+        node = helper.make_node(
+            "TopK", inputs=["X", "K"], outputs=["Values", "Indicies"], axis=axis
+        )
 
-        graph = helper.make_graph([node],
-                                  "topk_test",
-                                  inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)),
-                                          helper.make_tensor_value_info("K", TensorProto.INT64, [1,])],
-                                  initializer=[helper.make_tensor("K", TensorProto.INT64, [1], [K])],
-                                  outputs=[helper.make_tensor_value_info("Values", TensorProto.FLOAT, output_dims),
-                                           helper.make_tensor_value_info("Indicies", TensorProto.INT64, output_dims)])
+        graph = helper.make_graph(
+            [node],
+            "topk_test",
+            inputs=[
+                helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)),
+                helper.make_tensor_value_info(
+                    "K",
+                    TensorProto.INT64,
+                    [
+                        1,
+                    ],
+                ),
+            ],
+            initializer=[helper.make_tensor("K", TensorProto.INT64, [1], [K])],
+            outputs=[
+                helper.make_tensor_value_info("Values", TensorProto.FLOAT, output_dims),
+                helper.make_tensor_value_info("Indicies", TensorProto.INT64, output_dims),
+            ],
+        )
 
-        model = helper.make_model(graph, producer_name='topk_test')
+        model = helper.make_model(graph, producer_name="topk_test")
 
         indata = np.random.uniform(-10, 10, input_dims).astype(np.float32)
         onnx_out = get_onnxruntime_output(model, [indata, k])
 
-        for target, ctx in [('llvm', tvm.cpu())]:
-            tvm_out = get_tvm_output(model, indata, target, ctx, [output_dims, output_dims],
-                    output_dtype=['float32', 'int64'])
+        for target, ctx in [("llvm", tvm.cpu())]:
+            tvm_out = get_tvm_output(
+                model,
+                indata,
+                target,
+                ctx,
+                [output_dims, output_dims],
+                output_dtype=["float32", "int64"],
+            )
             tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05)
 
     for n in [12, 32]:
@@ -3162,68 +3301,70 @@ def test_topk():
 
 @tvm.testing.uses_gpu
 def test_roi_align():
-    def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_ratio=0, spatial_scale=1.0):
+    def verify_roi_align(
+        input_dims, num_roi, output_height, output_width, sampling_ratio=0, spatial_scale=1.0
+    ):
         output_dims = [num_roi, input_dims[1], output_height, output_width]
 
-        node = helper.make_node('RoiAlign',
-                                inputs=['X', 'rois', 'batch_indicies'],
-                                outputs=['Y'],
-                                mode="avg",
-                                output_height=output_height,
-                                output_width=output_width,
-                                sampling_ratio=sampling_ratio,
-                                spatial_scale=spatial_scale,
-                                )
-
-        graph = helper.make_graph([node],
-                                  "roialign_test",
-                                  inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)),
-                                          helper.make_tensor_value_info(
-                                              "rois", TensorProto.FLOAT, [num_roi, 4]),
-                                          helper.make_tensor_value_info(
-                                              "batch_indicies", TensorProto.INT64, [num_roi, ]),
-                                          ],
-                                  outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, output_dims)])
-
-        model = helper.make_model(graph, producer_name='roialign_test')
+        node = helper.make_node(
+            "RoiAlign",
+            inputs=["X", "rois", "batch_indicies"],
+            outputs=["Y"],
+            mode="avg",
+            output_height=output_height,
+            output_width=output_width,
+            sampling_ratio=sampling_ratio,
+            spatial_scale=spatial_scale,
+        )
+
+        graph = helper.make_graph(
+            [node],
+            "roialign_test",
+            inputs=[
+                helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)),
+                helper.make_tensor_value_info("rois", TensorProto.FLOAT, [num_roi, 4]),
+                helper.make_tensor_value_info(
+                    "batch_indicies",
+                    TensorProto.INT64,
+                    [
+                        num_roi,
+                    ],
+                ),
+            ],
+            outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, output_dims)],
+        )
+
+        model = helper.make_model(graph, producer_name="roialign_test")
 
         np_data = np.random.uniform(size=input_dims).astype("float32")
-        np_rois = np.random.uniform(size=[num_roi, 4]).astype(
-            'float32') * input_dims[2]
-        np_batch_indicies = np.random.randint(
-            low=0, high=input_dims[0], size=num_roi)
-
-        onnx_out = get_onnxruntime_output(
-            model, [np_data, np_rois, np_batch_indicies])
-        for target, ctx in [('llvm', tvm.cpu())]:
-            tvm_out = get_tvm_output(model, [np_data, np_rois, np_batch_indicies], target, ctx, output_dims,
-                                     output_dtype='float32')
-            tvm.testing.assert_allclose(
-                onnx_out[0], tvm_out, rtol=1e-05, atol=1e-05)
-
-    verify_roi_align((1, 4, 16, 16), 32, 7, 7,
-                     sampling_ratio=0, spatial_scale=1.0)
-    verify_roi_align((4, 4, 16, 32), 32, 7, 7,
-                     sampling_ratio=0, spatial_scale=1.0)
-    verify_roi_align((1, 8, 16, 16), 32, 7, 7,
-                     sampling_ratio=0, spatial_scale=1.0)
-    verify_roi_align((1, 4, 8, 8), 32, 7, 7,
-                     sampling_ratio=0, spatial_scale=1.0)
-    verify_roi_align((1, 4, 16, 16), 16, 5, 7,
-                     sampling_ratio=0, spatial_scale=1.0)
-    verify_roi_align((1, 4, 16, 12), 8, 7, 3,
-                     sampling_ratio=0, spatial_scale=1.0)
-    verify_roi_align((1, 4, 16, 16), 32, 7, 7,
-                     sampling_ratio=0, spatial_scale=0.5)
-    verify_roi_align((3, 4, 12, 16), 32, 7, 7,
-                     sampling_ratio=0, spatial_scale=1.5)
-    verify_roi_align((5, 4, 16, 14), 32, 7, 7,
-                     sampling_ratio=1, spatial_scale=1.0)
-    verify_roi_align((1, 4, 16, 16), 32, 7, 7,
-                     sampling_ratio=2, spatial_scale=1.0)
-
-
-if __name__ == '__main__':
+        np_rois = np.random.uniform(size=[num_roi, 4]).astype("float32") * input_dims[2]
+        np_batch_indicies = np.random.randint(low=0, high=input_dims[0], size=num_roi)
+
+        onnx_out = get_onnxruntime_output(model, [np_data, np_rois, np_batch_indicies])
+        for target, ctx in [("llvm", tvm.cpu())]:
+            tvm_out = get_tvm_output(
+                model,
+                [np_data, np_rois, np_batch_indicies],
+                target,
+                ctx,
+                output_dims,
+                output_dtype="float32",
+            )
+            tvm.testing.assert_allclose(onnx_out[0], tvm_out, rtol=1e-05, atol=1e-05)
+
+    verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0)
+    verify_roi_align((4, 4, 16, 32), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0)
+    verify_roi_align((1, 8, 16, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0)
+    verify_roi_align((1, 4, 8, 8), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0)
+    verify_roi_align((1, 4, 16, 16), 16, 5, 7, sampling_ratio=0, spatial_scale=1.0)
+    verify_roi_align((1, 4, 16, 12), 8, 7, 3, sampling_ratio=0, spatial_scale=1.0)
+    verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=0.5)
+    verify_roi_align((3, 4, 12, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=1.5)
+    verify_roi_align((5, 4, 16, 14), 32, 7, 7, sampling_ratio=1, spatial_scale=1.0)
+    verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=2, spatial_scale=1.0)
+
+
+if __name__ == "__main__":
     test_flatten()
     test_reshape()
     test_shape()
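The hunk above is typical of the conversion: long call chains are collapsed onto one line when they fit the configured line length, argument lists that do not fit are exploded with trailing commas, and string literals are normalized to double quotes. A minimal sketch of reproducing that transformation with black's Python API, assuming a black release that exposes format_str and Mode; the line length of 100 is an inference from the reformatted lines, not something stated in this section:

# Hypothetical sketch (not part of the commit): reformat one of the old
# statements from the hunk above with black's programmatic API.
import black

src = "np_rois = np.random.uniform(size=[num_roi, 4]).astype(\n    'float32') * input_dims[2]\n"
reformatted = black.format_str(src, mode=black.Mode(line_length=100))
print(reformatted)
# Expected: the single-line, double-quoted form seen in the hunk above:
# np_rois = np.random.uniform(size=[num_roi, 4]).astype("float32") * input_dims[2]

Running black --check with the same settings afterwards reports any file that still deviates from the format.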
index f6c7280..f34dcef 100644
@@ -33,6 +33,7 @@ from tvm.contrib.download import download_testdata
 
 def torch_version_check():
     from packaging import version
+
     return version.parse(torch.__version__) > version.parse("1.4.0")
 
 
@@ -56,11 +57,10 @@ def get_qconfig(per_channel):
     from torch.quantization.observer import default_weight_observer
 
     if per_channel:
-        return torch.quantization.get_default_qconfig('fbgemm')
+        return torch.quantization.get_default_qconfig("fbgemm")
     else:
         act = MovingAverageMinMaxObserver.with_args(reduce_range=False)
-        return torch.quantization.QConfig(activation=act,
-                                          weight=default_weight_observer)
+        return torch.quantization.QConfig(activation=act, weight=default_weight_observer)
 
 
 def quantize_model(model, inp, per_channel=False):
@@ -74,8 +74,7 @@ def quantize_model(model, inp, per_channel=False):
 class ConvBn(nn.Module):
     def __init__(self, with_relu=False):
         super().__init__()
-        layers = [nn.Conv2d(3, 32, 3, bias=True),
-                  nn.BatchNorm2d(32)]
+        layers = [nn.Conv2d(3, 32, 3, bias=True), nn.BatchNorm2d(32)]
         if with_relu:
             layers.append(nn.ReLU())
         self.conv = nn.Sequential(*layers)
@@ -135,8 +134,8 @@ class Hsigmoid(nn.Module):
     def forward(self, x):
         if self.add_stub:
             x = self.quant(x)
-        relu6 = self.relu6(self.float_op.add_scalar(x, 3.))
-        mul = self.float_op.mul_scalar(relu6, 1/6.)
+        relu6 = self.relu6(self.float_op.add_scalar(x, 3.0))
+        mul = self.float_op.mul_scalar(relu6, 1 / 6.0)
         if self.add_stub:
             mul = self.dequant(mul)
         return mul
@@ -174,7 +173,7 @@ class SqueezeExcite(nn.Module):
             nn.Linear(channel, channel // reduction, bias=False),
             nn.ReLU(inplace=True),
             nn.Linear(channel // reduction, channel, bias=False),
-            Hsigmoid(add_stub=False)
+            Hsigmoid(add_stub=False),
         )
         self.fmul = nn.quantized.FloatFunctional()
         self.quant = QuantStub()
@@ -199,7 +198,9 @@ class SqueezeExcite(nn.Module):
 
 # test on quantized::mul_scalar with negative scale
 class MulScalarNegative(nn.Module):
-    def __init__(self, ):
+    def __init__(
+        self,
+    ):
         super().__init__()
         self.float_op = nn.quantized.FloatFunctional()
         self.quant = QuantStub()
@@ -222,9 +223,7 @@ class UpsamplingBilinear(nn.Module):
 
     def forward(self, x):
         x = self.quant(x)
-        upsample = nn.functional.interpolate(x, scale_factor=2,
-                                             mode='bilinear',
-                                             align_corners=True)
+        upsample = nn.functional.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
         return self.dequant(upsample)
 
     def fuse_model(self):
@@ -259,9 +258,9 @@ def test_quantized_modules():
     imagenet_ishape = (1, 3, 224, 224)
 
     qmodules = [
-       ("relu", imagenet_ishape, ReLU(), False),
-       ("upsample bilinear", (1, 3, 64, 64), UpsamplingBilinear(), False),
-       ("avgpool", imagenet_ishape, AvgPool2d(), False),
+        ("relu", imagenet_ishape, ReLU(), False),
+        ("upsample bilinear", (1, 3, 64, 64), UpsamplingBilinear(), False),
+        ("avgpool", imagenet_ishape, AvgPool2d(), False),
     ]
 
     for per_channel in [False, True]:
@@ -271,19 +270,19 @@ def test_quantized_modules():
             postfix = ""
 
         qmodules += [
-           ("conv_bn" + postfix, imagenet_ishape, ConvBn(), per_channel),
-           ("conv_bn_relu" + postfix, imagenet_ishape, ConvBn(with_relu=True), per_channel),
-           ("linear" + postfix, (16, 16), Linear(), per_channel),
-           ("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel)
+            ("conv_bn" + postfix, imagenet_ishape, ConvBn(), per_channel),
+            ("conv_bn_relu" + postfix, imagenet_ishape, ConvBn(with_relu=True), per_channel),
+            ("linear" + postfix, (16, 16), Linear(), per_channel),
+            ("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel),
         ]
 
     if torch_version_check():
         qmodules += [
-           ("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False),
-           ("hswish", imagenet_ishape, Hswish(add_stub=True), False),
-           ("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False),
-           ("semodule, per_channel", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), True),
-           ("mul_scalar negative", imagenet_ishape, MulScalarNegative(), False)
+            ("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False),
+            ("hswish", imagenet_ishape, Hswish(add_stub=True), False),
+            ("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False),
+            ("semodule, per_channel", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), True),
+            ("mul_scalar negative", imagenet_ishape, MulScalarNegative(), False),
         ]
     else:
         print("Skipping tests that require torch > 1.4")
@@ -336,20 +335,22 @@ def test_quantized_modules():
 def test_quantized_imagenet():
     def get_transform():
         import torchvision.transforms as transforms
-        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                         std=[0.229, 0.224, 0.225])
-        return transforms.Compose([
+
+        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        return transforms.Compose(
+            [
                 transforms.Resize(256),
                 transforms.CenterCrop(224),
                 transforms.ToTensor(),
                 normalize,
-            ])
+            ]
+        )
 
     def get_real_image(im_height, im_width):
-        repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
-        img_name = 'elephant-299.jpg'
+        repo_base = "https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/"
+        img_name = "elephant-299.jpg"
         image_url = os.path.join(repo_base, img_name)
-        img_path = download_testdata(image_url, img_name, module='data')
+        img_path = download_testdata(image_url, img_name, module="data")
         return Image.open(img_path).resize((im_height, im_width))
 
     def get_imagenet_input():
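The quantization test file above builds small modules (ConvBn, Hsigmoid, SqueezeExcite, MulScalarNegative, ...) and pushes them through get_qconfig/quantize_model before comparing against TVM. As a hedged sketch of the eager-mode post-training static quantization flow those helpers wrap; TinyConvBn and the fusion list are illustrative and not part of the test file:

# Hedged sketch, assuming the torch.quantization eager-mode API used above.
import torch
import torch.nn as nn
from torch.quantization import (
    QuantStub,
    DeQuantStub,
    get_default_qconfig,
    fuse_modules,
    prepare,
    convert,
)

class TinyConvBn(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # fp32 -> int8 boundary
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()
        self.dequant = DeQuantStub()  # int8 -> fp32 boundary

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.bn(self.conv(x)))
        return self.dequant(x)

model = TinyConvBn().eval()
fuse_modules(model, [["conv", "bn", "relu"]], inplace=True)  # fold bn/relu into conv
model.qconfig = get_default_qconfig("fbgemm")                # per-channel weights on x86
prepare(model, inplace=True)                                 # insert observers
model(torch.rand(1, 3, 32, 32))                              # one calibration batch
convert(model, inplace=True)                                 # swap in quantized modules

get_qconfig above selects between this per-channel "fbgemm" configuration and a per-tensor QConfig so that both quantization schemes are exercised against TVM.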
index 2ce669f..6661aad 100644
@@ -33,24 +33,29 @@ import tvm.testing
 
 sys.setrecursionlimit(10000)
 
+
 def list_ops(expr):
     class OpLister(tvm.relay.ExprVisitor):
         def visit_op(self, expr):
             if expr not in self.node_set:
                 self.node_list.append(expr)
             return super().visit_op(expr)
+
         def list_nodes(self, expr):
             self.node_set = {}
             self.node_list = []
             self.visit(expr)
             return self.node_list
+
     return OpLister().list_nodes(expr)
 
+
 def assert_shapes_match(tru, est):
     if tru.shape != est.shape:
         msg = "Output shapes {} and {} don't match"
         raise AssertionError(msg.format(tru.shape, est.shape))
 
+
 def load_torchvision(model_name):
     """Given a model name, returns a Torchvision model in eval mode as well
     as an example input."""
@@ -76,10 +81,12 @@ def load_torchvision(model_name):
         model = model.float().eval()
         return model, [input_data]
 
+
 def load_pretrainedmodels(model_name):
     """Given a model name, returns a pretrainedmodels.pytorch model in eval
     mode as well as an example input."""
-    import pretrainedmodels # https://github.com/Cadene/pretrained-models.pytorch
+    import pretrainedmodels  # https://github.com/Cadene/pretrained-models.pytorch
+
     model = getattr(pretrainedmodels, model_name)().float().eval()
     input_shape = [1, *model.input_size]
     input_data = torch.rand(input_shape).float() * 256
@@ -88,12 +95,14 @@ def load_pretrainedmodels(model_name):
         input_data[:, channel] /= model.std[channel]
     return model, [input_data]
 
+
 def load_model(model_name):
     """Given a model name, returns a model as well as an example input."""
     if hasattr(torchvision.models, model_name):
         return load_torchvision(model_name)
     try:
         import pretrainedmodels
+
         if hasattr(pretrainedmodels, model_name):
             return load_pretrainedmodels(model_name)
     except ModuleNotFoundError:
@@ -101,13 +110,14 @@ def load_model(model_name):
     raise RuntimeError("Model not supported")
 
 
-def confidence_interval(mean, stdev, count, alpha=.01):
+def confidence_interval(mean, stdev, count, alpha=0.01):
     """Returns the lower and upper bounds of the confidence interval of a random
     variable. Confidence is 1 - alpha (default confidence is 99%)."""
     stdval = tdistr.ppf(1 - alpha / 2, count - 1)
     lower, upper = mean + np.array([-1, 1]) * stdval * stdev / np.sqrt(count)
     return lower, upper
 
+
 def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
     """Compute the latency of the given model"""
     latencies = []
@@ -150,9 +160,8 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
             if err < thresh:
                 return est
 
-def verify_model(model_name, input_data=[],
-                 custom_convert_map={},
-                 rtol=1e-5, atol=1e-5):
+
+def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, atol=1e-5):
     """Assert that the output of a compiled model matches with that of its
     baseline."""
     if isinstance(model_name, str):
@@ -189,12 +198,9 @@ def verify_model(model_name, input_data=[],
             trace = trace.cpu()
 
     input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)]
-    input_shapes = list(zip(input_names,
-                            [inp.shape for inp in baseline_input]))
-    mod, params = relay.frontend.from_pytorch(trace, input_shapes,
-                                              custom_convert_map)
-    compiled_input = dict(zip(input_names,
-                              [inp.cpu().numpy() for inp in baseline_input]))
+    input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input]))
+    mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map)
+    compiled_input = dict(zip(input_names, [inp.cpu().numpy() for inp in baseline_input]))
 
     with tvm.transform.PassContext(opt_level=3):
         for target, ctx in tvm.testing.enabled_targets():
@@ -209,13 +215,13 @@ def verify_model(model_name, input_data=[],
                 compiled_output = relay_model.get_output(i).asnumpy()
 
                 assert_shapes_match(baseline_output, compiled_output)
-                tvm.testing.assert_allclose(baseline_output, compiled_output,
-                                            rtol=rtol, atol=atol)
+                tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol)
 
     del model_name
     del baseline_model
     torch.cuda.empty_cache()
 
+
 # Single operator tests
 @tvm.testing.uses_gpu
 def test_forward_add():
@@ -250,6 +256,7 @@ def test_forward_add():
     verify_model(Add3().float().eval(), input_data=input_data)
     verify_model(Add4().float().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_subtract():
     torch.set_grad_enabled(False)
@@ -283,6 +290,7 @@ def test_forward_subtract():
     verify_model(Subtract3().float().eval(), input_data=input_data)
     verify_model(Subtract4().float().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_multiply():
     torch.set_grad_enabled(False)
@@ -359,6 +367,7 @@ def test_min_max():
 def test_forward_reciprocal():
     torch.set_grad_enabled(False)
     input_shape = [2, 1, 10, 1, 10]
+
     class Reciprocal1(Module):
         def forward(self, *args):
             return args[0].reciprocal()
@@ -366,10 +375,12 @@ def test_forward_reciprocal():
     input_data = torch.rand(input_shape).float()
     verify_model(Reciprocal1().float().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_repeat():
     torch.set_grad_enabled(False)
     input_shape = [1, 3]
+
     class Repeat1(Module):
         def forward(self, *args):
             return args[0].repeat(1, 1)
@@ -387,10 +398,12 @@ def test_forward_repeat():
     verify_model(Repeat2().float().eval(), input_data=input_data)
     verify_model(Repeat3().float().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_repeat_interleave():
     torch.set_grad_enabled(False)
     input_shape = [2, 2, 3]
+
     class RepeatInterleave1(Module):
         def forward(self, *args):
             return args[0].repeat_interleave(2)
@@ -413,6 +426,7 @@ def test_forward_repeat_interleave():
     verify_model(RepeatInterleave3().float().eval(), input_data=input_data)
     verify_model(RepeatInterleave4().float().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_unsqueeze():
     torch.set_grad_enabled(False)
@@ -425,6 +439,7 @@ def test_forward_unsqueeze():
     input_data = torch.rand(input_shape).float()
     verify_model(Unsqueeze1().float().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_squeeze():
     torch.set_grad_enabled(False)
@@ -442,6 +457,7 @@ def test_forward_squeeze():
     verify_model(Squeeze1().float().eval(), input_data=input_data)
     verify_model(Squeeze2().float().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_arange():
     torch.set_grad_enabled(False)
@@ -517,6 +533,7 @@ def test_forward_arange():
     verify_model(Arange11().float().eval())
     verify_model(Arange12().float().eval())
 
+
 @tvm.testing.uses_gpu
 def test_forward_mesh_grid():
     torch.set_grad_enabled(False)
@@ -538,6 +555,7 @@ def test_forward_mesh_grid():
     verify_model(MeshGrid1().float().eval())
     verify_model(MeshGrid2().float().eval())
 
+
 @tvm.testing.uses_gpu
 def test_forward_abs():
     torch.set_grad_enabled(False)
@@ -550,6 +568,7 @@ def test_forward_abs():
     input_data = torch.rand(input_shape).float()
     verify_model(Abs1().float().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_concatenate():
     torch.set_grad_enabled(False)
@@ -570,6 +589,7 @@ def test_forward_concatenate():
     verify_model(Concatenate1().float().eval(), input_data=input_data)
     verify_model(Concatenate2().float().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_relu():
     torch.set_grad_enabled(False)
@@ -577,6 +597,7 @@ def test_forward_relu():
     input_data = torch.rand(input_shape).float()
     verify_model(torch.nn.ReLU().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_prelu():
     torch.set_grad_enabled(False)
@@ -584,6 +605,7 @@ def test_forward_prelu():
     input_data = torch.rand(input_shape).float()
     verify_model(torch.nn.PReLU(num_parameters=3).eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_leakyrelu():
     torch.set_grad_enabled(False)
@@ -592,7 +614,10 @@ def test_forward_leakyrelu():
     verify_model(torch.nn.LeakyReLU().eval(), input_data=input_data)
     verify_model(torch.nn.LeakyReLU(negative_slope=0.05).eval(), input_data=input_data)
     verify_model(torch.nn.LeakyReLU(negative_slope=1.0, inplace=True).eval(), input_data=input_data)
-    verify_model(torch.nn.LeakyReLU(negative_slope=1.25, inplace=True).eval(), input_data=input_data)
+    verify_model(
+        torch.nn.LeakyReLU(negative_slope=1.25, inplace=True).eval(), input_data=input_data
+    )
+
 
 @tvm.testing.uses_gpu
 def test_forward_elu():
@@ -604,6 +629,7 @@ def test_forward_elu():
     verify_model(torch.nn.ELU(alpha=1.0).eval(), input_data=input_data)
     verify_model(torch.nn.ELU(alpha=1.3).eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_celu():
     torch.set_grad_enabled(False)
@@ -614,6 +640,7 @@ def test_forward_celu():
     verify_model(torch.nn.CELU(alpha=1.0).eval(), input_data=input_data)
     verify_model(torch.nn.CELU(alpha=1.3).eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_gelu():
     torch.set_grad_enabled(False)
@@ -621,6 +648,7 @@ def test_forward_gelu():
     input_data = torch.rand(input_shape).float()
     verify_model(torch.nn.GELU().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_selu():
     torch.set_grad_enabled(False)
@@ -628,6 +656,7 @@ def test_forward_selu():
     input_data = torch.rand(input_shape).float()
     verify_model(torch.nn.SELU().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_softplus():
     torch.set_grad_enabled(False)
@@ -637,6 +666,7 @@ def test_forward_softplus():
     verify_model(torch.nn.Softplus(beta=1.5, threshold=20).eval(), input_data=input_data)
     verify_model(torch.nn.Softplus(beta=5, threshold=10).eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_softsign():
     torch.set_grad_enabled(False)
@@ -644,6 +674,7 @@ def test_forward_softsign():
     input_data = torch.rand(input_shape).float()
     verify_model(torch.nn.Softsign().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_log_sigmoid():
     torch.set_grad_enabled(False)
@@ -651,6 +682,7 @@ def test_forward_log_sigmoid():
     input_data = torch.rand(input_shape).float()
     verify_model(torch.nn.LogSigmoid().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_adaptiveavgpool():
     torch.set_grad_enabled(False)
@@ -659,20 +691,16 @@ def test_forward_adaptiveavgpool():
     verify_model(torch.nn.AdaptiveAvgPool2d([1, 1]).eval(), input_data=input_data)
     verify_model(torch.nn.AdaptiveAvgPool2d([10, 10]).eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_maxpool2d():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
     input_data = torch.rand(input_shape).float()
 
-    verify_model(torch.nn.MaxPool2d(kernel_size=[1, 1]).eval(),
-                 input_data)
-    verify_model(torch.nn.MaxPool2d(kernel_size=[10, 10]).eval(),
-                 input_data)
-    verify_model(torch.nn.MaxPool2d(kernel_size=[4, 4],
-                                    padding=2,
-                                    stride=2).eval(),
-                 input_data)
+    verify_model(torch.nn.MaxPool2d(kernel_size=[1, 1]).eval(), input_data)
+    verify_model(torch.nn.MaxPool2d(kernel_size=[10, 10]).eval(), input_data)
+    verify_model(torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2).eval(), input_data)
 
     # A functional variant (default strides = None case)
     class MaxPool2D(Module):
@@ -692,20 +720,16 @@ def test_forward_maxpool2d():
 
     verify_model(MaxPool2DWithIndices().float().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_maxpool1d():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10]
     input_data = torch.rand(input_shape).float()
 
-    verify_model(torch.nn.MaxPool1d(kernel_size=1).eval(),
-                 input_data)
-    verify_model(torch.nn.MaxPool1d(kernel_size=10).eval(),
-                 input_data)
-    verify_model(torch.nn.MaxPool1d(kernel_size=4,
-                                    padding=2,
-                                    stride=2).eval(),
-                 input_data)
+    verify_model(torch.nn.MaxPool1d(kernel_size=1).eval(), input_data)
+    verify_model(torch.nn.MaxPool1d(kernel_size=10).eval(), input_data)
+    verify_model(torch.nn.MaxPool1d(kernel_size=4, padding=2, stride=2).eval(), input_data)
 
     # A functional variant (default strides = None case)
     class MaxPool1D(Module):
@@ -721,14 +745,9 @@ def test_forward_maxpool3d():
     input_shape = [1, 3, 10, 10, 10]
     input_data = torch.rand(input_shape).float()
 
-    verify_model(torch.nn.MaxPool3d(kernel_size=[1, 1, 1]).eval(),
-                 input_data)
-    verify_model(torch.nn.MaxPool3d(kernel_size=[10, 10, 10]).eval(),
-                 input_data)
-    verify_model(torch.nn.MaxPool3d(kernel_size=[4, 4, 4],
-                                    padding=2,
-                                    stride=2).eval(),
-                 input_data)
+    verify_model(torch.nn.MaxPool3d(kernel_size=[1, 1, 1]).eval(), input_data)
+    verify_model(torch.nn.MaxPool3d(kernel_size=[10, 10, 10]).eval(), input_data)
+    verify_model(torch.nn.MaxPool3d(kernel_size=[4, 4, 4], padding=2, stride=2).eval(), input_data)
 
     # A functional variant (default strides = None case)
     class MaxPool3D(Module):
@@ -753,14 +772,11 @@ def test_forward_split():
             return torch.split(args[0], self.split_size_or_sections, self.dim)
 
     input_data = torch.rand(input_shape).float()
-    verify_model(Split(2, 0).float().eval(),
-                 input_data=input_data)
-    verify_model(Split(3, 1).float().eval(),
-                 input_data=input_data)
-    verify_model(Split(4, 1).float().eval(),
-                 input_data=input_data)
-    verify_model(Split([2, 3, 5], 1).float().eval(),
-                 input_data=input_data)
+    verify_model(Split(2, 0).float().eval(), input_data=input_data)
+    verify_model(Split(3, 1).float().eval(), input_data=input_data)
+    verify_model(Split(4, 1).float().eval(), input_data=input_data)
+    verify_model(Split([2, 3, 5], 1).float().eval(), input_data=input_data)
+
 
 @tvm.testing.uses_gpu
 def test_forward_avgpool():
@@ -775,6 +791,7 @@ def test_forward_avgpool():
     verify_model(torch.nn.AvgPool2d(kernel_size=[10, 10]).eval(), input_data=input_data)
     verify_model(AvgPool2D2().float().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_avgpool3d():
     torch.set_grad_enabled(False)
@@ -788,6 +805,7 @@ def test_forward_avgpool3d():
     verify_model(torch.nn.AvgPool3d(kernel_size=[10, 10, 10]).eval(), input_data=input_data)
     verify_model(AvgPool3D1().float().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_hardtanh():
     torch.set_grad_enabled(False)
@@ -795,6 +813,7 @@ def test_forward_hardtanh():
     input_data = torch.rand(input_shape).float()
     verify_model(torch.nn.Hardtanh().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_conv():
     torch.set_grad_enabled(False)
@@ -861,15 +880,17 @@ def test_forward_conv():
     # depth wise conv with channel mult 2
     verify_model(Conv2D3().float().eval(), input_data=conv2d_input_data)
     # group conv
-    verify_model(torch.nn.Conv2d(8, 8, kernel_size=(3, 3),
-                                 stride=(1, 1), groups=2).eval(),
-                 input_data=torch.randn((1, 8, 16, 16)))
+    verify_model(
+        torch.nn.Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), groups=2).eval(),
+        input_data=torch.randn((1, 8, 16, 16)),
+    )
 
     conv1d_input_data = torch.rand(conv1d_input_shape).float()
     verify_model(Conv1D1().float().eval(), input_data=conv1d_input_data)
     verify_model(Conv1D2().float().eval(), input_data=conv1d_input_data)
     verify_model(Conv1D3().float().eval(), input_data=conv1d_input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_conv_transpose():
     torch.set_grad_enabled(False)
@@ -891,6 +912,7 @@ def test_forward_threshold():
     input_data = torch.rand(input_shape).float()
     verify_model(torch.nn.Threshold(0, 0).float().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_contiguous():
     torch.set_grad_enabled(False)
@@ -913,8 +935,7 @@ def test_forward_batchnorm():
     inp_2d = torch.rand((1, 16, 10, 10))
     inp_3d = torch.rand((1, 16, 10, 10, 10))
 
-    for bn, inp in [(torch.nn.BatchNorm2d(16), inp_2d),
-                    (torch.nn.BatchNorm3d(16), inp_3d)]:
+    for bn, inp in [(torch.nn.BatchNorm2d(16), inp_2d), (torch.nn.BatchNorm3d(16), inp_3d)]:
         init_weight(bn.eval())
         verify_model(bn.eval(), input_data=inp)
 
@@ -924,10 +945,13 @@ def test_forward_instancenorm():
     inp_2d = torch.rand((1, 16, 10, 10))
     inp_3d = torch.rand((1, 16, 10, 10, 10))
 
-    for ins_norm, inp in [(torch.nn.InstanceNorm2d(16), inp_2d),
-                          (torch.nn.InstanceNorm3d(16), inp_3d)]:
+    for ins_norm, inp in [
+        (torch.nn.InstanceNorm2d(16), inp_2d),
+        (torch.nn.InstanceNorm3d(16), inp_3d),
+    ]:
         verify_model(ins_norm.eval(), input_data=inp)
 
+
 @tvm.testing.uses_gpu
 def test_forward_layernorm():
     def init_weight(m):
@@ -936,8 +960,7 @@ def test_forward_layernorm():
 
     inp_2d = torch.rand((1, 16, 10, 10))
     inp_3d = torch.rand((1, 16, 10, 10, 10))
-    for ln, inp in [(torch.nn.LayerNorm(10), inp_2d),
-                    (torch.nn.LayerNorm(10), inp_3d)]:
+    for ln, inp in [(torch.nn.LayerNorm(10), inp_2d), (torch.nn.LayerNorm(10), inp_3d)]:
         init_weight(ln.eval())
         verify_model(ln.eval(), input_data=inp)
 
@@ -1019,13 +1042,14 @@ def test_forward_transpose():
 
     class Transpose3(Module):
         def forward(self, *args):
-            return args[0].permute(0,2,3,1)
+            return args[0].permute(0, 2, 3, 1)
 
     input_data = torch.rand(input_shape).float()
     verify_model(Transpose1().float().eval(), input_data=input_data)
     verify_model(Transpose2().float().eval(), input_data=input_data)
     verify_model(Transpose3().float().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_size():
     torch.set_grad_enabled(False)
@@ -1100,6 +1124,7 @@ def test_forward_view():
     verify_model(View2().float().eval(), input_data=input_data)
     verify_model(View3().float().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_select():
     torch.set_grad_enabled(False)
@@ -1175,9 +1200,13 @@ def test_forward_gather():
     verify_model(Gather2().float().eval(), input_data=[input_data, index])
 
     input_data = torch.rand((3, 3, 3)).float()
-    index = torch.tensor([[[1, 0, 0], [1, 0, 1], [0, 1, 1]],
-                          [[1, 1, 1], [1, 2, 1], [1, 0, 1]],
-                          [[1, 2, 1], [1, 2, 1], [1, 2, 1]]])
+    index = torch.tensor(
+        [
+            [[1, 0, 0], [1, 0, 1], [0, 1, 1]],
+            [[1, 1, 1], [1, 2, 1], [1, 0, 1]],
+            [[1, 2, 1], [1, 2, 1], [1, 2, 1]],
+        ]
+    )
     verify_model(Gather3().float().eval(), input_data=[input_data, index])
 
 
@@ -1201,23 +1230,23 @@ def test_forward_norm():
 
     class Norm1(Module):
         def forward(self, *args):
-            return torch.norm(args[0], p=float('inf'), dim=None, keepdim=False)
+            return torch.norm(args[0], p=float("inf"), dim=None, keepdim=False)
 
     class Norm2(Module):
         def forward(self, *args):
-            return torch.norm(args[0], p=float('-inf'), dim=None, keepdim=False)
+            return torch.norm(args[0], p=float("-inf"), dim=None, keepdim=False)
 
     class Norm3(Module):
         def forward(self, *args):
-            return torch.norm(args[0], p=float('-inf'), dim=None, keepdim=True)
+            return torch.norm(args[0], p=float("-inf"), dim=None, keepdim=True)
 
     class Norm4(Module):
         def forward(self, *args):
-            return torch.norm(args[0], p=float('inf'), dim=(1, 2), keepdim=False)
+            return torch.norm(args[0], p=float("inf"), dim=(1, 2), keepdim=False)
 
     class Norm5(Module):
         def forward(self, *args):
-            return torch.norm(args[0], p=float('inf'), dim=(1), keepdim=True)
+            return torch.norm(args[0], p=float("inf"), dim=(1), keepdim=True)
 
     class Norm6(Module):
         def forward(self, *args):
@@ -1263,11 +1292,11 @@ def test_forward_frobenius_norm():
 
     class FroNorm2(Module):
         def forward(self, *args):
-            return torch.norm(args[0], p='fro', dim=None, keepdim=True)
+            return torch.norm(args[0], p="fro", dim=None, keepdim=True)
 
     class FroNorm3(Module):
         def forward(self, *args):
-            return torch.norm(args[0], p='fro', dim=(1), keepdim=True)
+            return torch.norm(args[0], p="fro", dim=(1), keepdim=True)
 
     class FroNorm4(Module):
         def forward(self, *args):
@@ -1287,6 +1316,7 @@ def test_forward_sigmoid():
     input_data = torch.rand(input_shape).float()
     verify_model(torch.nn.Sigmoid().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_dense():
     torch.set_grad_enabled(False)
@@ -1296,6 +1326,7 @@ def test_forward_dense():
         def __init__(self):
             super(Dense1, self).__init__()
             self.linear = torch.nn.Linear(10, 7, bias=True)
+
         def forward(self, *args):
             return self.linear(args[0][0, 0])
 
@@ -1303,6 +1334,7 @@ def test_forward_dense():
         def __init__(self):
             super(Dense2, self).__init__()
             self.linear = torch.nn.Linear(10, 7, bias=False)
+
         def forward(self, *args):
             return self.linear(args[0][0, 0])
 
@@ -1313,9 +1345,10 @@ def test_forward_dense():
     trace = torch.jit.trace(Dense1(), [input_data])
     mod, params = relay.frontend.from_pytorch(
         trace,
-        [('input', input_shape)],
+        [("input", input_shape)],
     )
-    assert not any([op.name == "multiply" for op in list_ops(mod['main'])])
+    assert not any([op.name == "multiply" for op in list_ops(mod["main"])])
+
 
 @tvm.testing.uses_gpu
 def test_forward_dropout():
@@ -1327,6 +1360,7 @@ def test_forward_dropout():
     verify_model(torch.nn.Dropout3d(p=0.5).eval(), input_data=input_data)
     verify_model(torch.nn.AlphaDropout(p=0.5).eval(), input_data=input_data[0, 0])
 
+
 @tvm.testing.uses_gpu
 def test_forward_slice():
     torch.set_grad_enabled(False)
@@ -1374,6 +1408,7 @@ def test_forward_mean():
     input_data = torch.rand(input_shape).float()
     verify_model(Mean1().float().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_expand():
     torch.set_grad_enabled(False)
@@ -1407,6 +1442,7 @@ def test_forward_pow():
     input_data = torch.rand(input_shape).float()
     verify_model(Pow1().float().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_chunk():
     torch.set_grad_enabled(False)
@@ -1420,11 +1456,11 @@ def test_forward_chunk():
     input_data = torch.rand(input_shape).float()
     verify_model(Chunk1().float().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_upsample():
     class Upsample(Module):
-        def __init__(self, size=None, scale=None,
-                     mode="nearest", align_corners=None):
+        def __init__(self, size=None, scale=None, mode="nearest", align_corners=None):
             super().__init__()
             self.size = size
             self.scale = scale
@@ -1432,10 +1468,14 @@ def test_upsample():
             self.align_corners = align_corners
 
         def forward(self, x):
-            return torch.nn.functional.interpolate(x, size=self.size,
-                                                   scale_factor=self.scale,
-                                                   mode=self.mode,
-                                                   align_corners=self.align_corners)
+            return torch.nn.functional.interpolate(
+                x,
+                size=self.size,
+                scale_factor=self.scale,
+                mode=self.mode,
+                align_corners=self.align_corners,
+            )
+
     inp = torch.rand((1, 3, 32, 32))
     verify_model(Upsample(size=(64, 64), mode="nearest"), inp)
     verify_model(Upsample(scale=2, mode="nearest"), inp)
@@ -1444,9 +1484,11 @@ def test_upsample():
     verify_model(Upsample(scale=2, mode="bilinear", align_corners=True), inp)
     verify_model(Upsample(size=(50, 50), mode="bilinear", align_corners=True), inp)
 
+
 @tvm.testing.uses_gpu
 def test_to():
     """ test for aten::to(...) """
+
     class ToCPU(Module):
         def forward(self, x):
             return x.to("cpu")
@@ -1478,9 +1520,7 @@ def test_to():
 
 @tvm.testing.uses_gpu
 def test_adaptive_pool3d():
-    for ishape in [(1, 32, 16, 16, 16),
-                   (1, 32, 9, 15, 15),
-                   (1, 32, 13, 7, 7)]:
+    for ishape in [(1, 32, 16, 16, 16), (1, 32, 9, 15, 15), (1, 32, 13, 7, 7)]:
         inp = torch.rand(ishape)
         verify_model(torch.nn.AdaptiveMaxPool3d((1, 1, 1)).eval(), inp)
         verify_model(torch.nn.AdaptiveMaxPool3d((2, 2, 2)).eval(), inp)
@@ -1494,6 +1534,7 @@ def test_adaptive_pool3d():
 def test_forward_functional_pad():
     torch.set_grad_enabled(False)
     pad = (0, 0)
+
     class Pad1(Module):
         def forward(self, *args):
             return torch.nn.functional.pad(args[0], pad, "constant", 0)
@@ -1592,14 +1633,17 @@ def test_forward_replication_pad3d():
 @tvm.testing.uses_gpu
 def test_forward_upsample3d():
     inp = torch.arange(1, 9, dtype=torch.float32).view(1, 1, 2, 2, 2)
-    verify_model(torch.nn.Upsample(scale_factor=2, mode='nearest').eval(), inp)
-    verify_model(torch.nn.Upsample(scale_factor=2, mode='trilinear').eval(), inp)
-    verify_model(torch.nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True).eval(), inp)
+    verify_model(torch.nn.Upsample(scale_factor=2, mode="nearest").eval(), inp)
+    verify_model(torch.nn.Upsample(scale_factor=2, mode="trilinear").eval(), inp)
+    verify_model(
+        torch.nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True).eval(), inp
+    )
 
 
 def test_forward_nms():
     """dynamic Non-Maximum Suppression"""
     torch.set_grad_enabled(False)
+
     class NonMaxSupression(Module):
         def __init__(self, iou_thres):
             super().__init__()
@@ -1626,49 +1670,42 @@ def test_forward_nms():
 
 @tvm.testing.uses_gpu
 def test_conv3d():
-    for ishape in [(1, 32, 16, 16, 16),
-                   (1, 32, 9, 15, 15),
-                   (1, 32, 13, 7, 7)]:
+    for ishape in [(1, 32, 16, 16, 16), (1, 32, 9, 15, 15), (1, 32, 13, 7, 7)]:
         inp = torch.rand(ishape)
-        verify_model(torch.nn.Conv3d(32, 16, (3, 3, 3),
-                                     padding=(1, 1, 1)).eval(),
-                     inp),
-        verify_model(torch.nn.Conv3d(32, 16, (5, 5, 5),
-                                     padding=(2, 2, 2)).eval(),
-                     inp),
-        verify_model(torch.nn.Conv3d(32, 16, kernel_size=1).eval(),
-                     inp)
+        verify_model(torch.nn.Conv3d(32, 16, (3, 3, 3), padding=(1, 1, 1)).eval(), inp),
+        verify_model(torch.nn.Conv3d(32, 16, (5, 5, 5), padding=(2, 2, 2)).eval(), inp),
+        verify_model(torch.nn.Conv3d(32, 16, kernel_size=1).eval(), inp)
         # downsample
-        verify_model(torch.nn.Conv3d(32, 16, kernel_size=1, stride=2).eval(),
-                     inp)
+        verify_model(torch.nn.Conv3d(32, 16, kernel_size=1, stride=2).eval(), inp)
 
 
 @tvm.testing.uses_gpu
 def test_conv3d_transpose():
-    for ishape in [(1, 8, 10, 5, 10),
-                   (1, 8, 5, 8, 8),
-                   (1, 8, 13, 7, 7)]:
+    for ishape in [(1, 8, 10, 5, 10), (1, 8, 5, 8, 8), (1, 8, 13, 7, 7)]:
         inp = torch.rand(ishape)
-        verify_model(torch.nn.ConvTranspose3d(in_channels=8,
-                                              out_channels=33,
-                                              kernel_size=3,
-                                              stride=2).eval(),
-                     inp),
-        verify_model(torch.nn.ConvTranspose3d(in_channels=8,
-                                              out_channels=20,
-                                              kernel_size=(3, 5, 2),
-                                              stride=(2, 1, 1),
-                                              padding=(0, 4, 2)).eval(),
-                     inp),
-        verify_model(torch.nn.ConvTranspose3d(in_channels=8,
-                                               out_channels=20,
-                                               kernel_size=1).eval(),
-                     inp)
-        verify_model(torch.nn.ConvTranspose3d(in_channels=8,
-                                              out_channels=5,
-                                              kernel_size=1,
-                                              stride=2).eval(),
-                     inp)
+        verify_model(
+            torch.nn.ConvTranspose3d(
+                in_channels=8, out_channels=33, kernel_size=3, stride=2
+            ).eval(),
+            inp,
+        ),
+        verify_model(
+            torch.nn.ConvTranspose3d(
+                in_channels=8,
+                out_channels=20,
+                kernel_size=(3, 5, 2),
+                stride=(2, 1, 1),
+                padding=(0, 4, 2),
+            ).eval(),
+            inp,
+        ),
+        verify_model(
+            torch.nn.ConvTranspose3d(in_channels=8, out_channels=20, kernel_size=1).eval(), inp
+        )
+        verify_model(
+            torch.nn.ConvTranspose3d(in_channels=8, out_channels=5, kernel_size=1, stride=2).eval(),
+            inp,
+        )
 
 
 # Model tests
@@ -1677,41 +1714,49 @@ def test_resnet18():
     torch.set_grad_enabled(False)
     verify_model("resnet18", atol=1e-4, rtol=1e-4)
 
+
 @tvm.testing.uses_gpu
 def test_squeezenet1_0():
     torch.set_grad_enabled(False)
     verify_model("squeezenet1_0", atol=1e-4, rtol=1e-4)
 
+
 @tvm.testing.uses_gpu
 def test_squeezenet1_1():
     torch.set_grad_enabled(False)
     verify_model("squeezenet1_1", atol=1e-4, rtol=1e-4)
 
+
 @tvm.testing.uses_gpu
 def test_densenet121():
     torch.set_grad_enabled(False)
     verify_model("densenet121", atol=1e-4, rtol=1e-4)
 
+
 @tvm.testing.uses_gpu
 def test_inception_v3():
     torch.set_grad_enabled(False)
     verify_model("inception_v3", atol=1e-4, rtol=1e-4)
 
+
 @tvm.testing.uses_gpu
 def test_googlenet():
     torch.set_grad_enabled(False)
     verify_model("googlenet", atol=1e-4, rtol=1e-4)
 
+
 @tvm.testing.uses_gpu
 def test_mnasnet0_5():
     torch.set_grad_enabled(False)
     verify_model("mnasnet0_5", atol=1e-4, rtol=1e-4)
 
+
 @tvm.testing.uses_gpu
 def test_mobilenet_v2():
     torch.set_grad_enabled(False)
     verify_model("mobilenet_v2", atol=1e-4, rtol=1e-4)
 
+
 """
 #TODO: Fix VGG and AlexNet issues (probably due to pooling)
 @tvm.testing.uses_gpu
@@ -1730,18 +1775,23 @@ def test_vgg11_bn():
     verify_model("vgg11_bn")
 """
 
+
 @tvm.testing.uses_gpu
 def test_custom_conversion_map():
     def get_roi_align():
         pool_size = 5
         n_channels = 2 * (pool_size ** 2)
         x = torch.rand(2, n_channels, 10, 10)
-        rois = torch.tensor([[0, 0, 0, 9, 9],  # format is (xyxy)
-                             [0, 0, 5, 4, 9],
-                             [0, 5, 5, 9, 9],
-                             [1, 0, 0, 9, 9]], dtype=torch.float)
-        roi_align = torchvision.ops.RoIAlign(pool_size, spatial_scale=1,
-                                             sampling_ratio=-1)
+        rois = torch.tensor(
+            [
+                [0, 0, 0, 9, 9],  # format is (xyxy)
+                [0, 0, 5, 4, 9],
+                [0, 5, 5, 9, 9],
+                [1, 0, 0, 9, 9],
+            ],
+            dtype=torch.float,
+        )
+        roi_align = torchvision.ops.RoIAlign(pool_size, spatial_scale=1, sampling_ratio=-1)
         return roi_align.eval(), [x, rois]
 
     def convert_roi_align():
@@ -1749,12 +1799,13 @@ def test_custom_conversion_map():
             spatial_scale = inputs[2]
             pooled_size = (inputs[3], inputs[4])
             sampling_ratio = inputs[5]
-            return relay.op.vision.roi_align(inputs[0], inputs[1],
-                                             pooled_size, spatial_scale,
-                                             sampling_ratio)
+            return relay.op.vision.roi_align(
+                inputs[0], inputs[1], pooled_size, spatial_scale, sampling_ratio
+            )
+
         return _impl
 
-    custom_map = {'torchvision::roi_align': convert_roi_align()}
+    custom_map = {"torchvision::roi_align": convert_roi_align()}
     model, inputs = get_roi_align()
 
     verify_model(model, inputs, custom_map)
@@ -1802,12 +1853,10 @@ def verify_trace_model(pt_model, idata, targets):
     verify_model_vm(traced_model, ishapes, idata=idata, targets=targets)
 
 
-def verify_model_vm(input_model, ishapes, idtype=torch.float,
-                    idata=None, targets=["llvm"]):
+def verify_model_vm(input_model, ishapes, idtype=torch.float, idata=None, targets=["llvm"]):
     input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
     input_shapes = list(zip(input_names, ishapes))
-    input_data = idata if idata else [torch.randn(shape, dtype=idtype)
-                                      for shape in ishapes]
+    input_data = idata if idata else [torch.randn(shape, dtype=idtype) for shape in ishapes]
     # Compile via VM
     mod, params = relay.frontend.from_pytorch(input_model, input_shapes)
 
@@ -1832,8 +1881,7 @@ def verify_model_vm(input_model, ishapes, idtype=torch.float,
             tvm_res = vm_res.asnumpy().item()
             assert pt_result == tvm_res
         else:
-            tvm.testing.assert_allclose(vm_res.asnumpy(), pt_result.numpy(),
-                                        rtol=1e-5, atol=1e-5)
+            tvm.testing.assert_allclose(vm_res.asnumpy(), pt_result.numpy(), rtol=1e-5, atol=1e-5)
 
 
 @tvm.testing.uses_gpu
@@ -1844,7 +1892,7 @@ def test_control_flow():
             self.weight = torch.nn.Parameter(torch.rand(N, M))
 
         def forward(self, inp):
-            if inp.sum() > 0.:
+            if inp.sum() > 0.0:
                 output = self.weight + inp
             else:
                 output = self.weight - inp
@@ -1856,13 +1904,13 @@ def test_control_flow():
             self.weight = torch.nn.Parameter(torch.rand(N, M))
 
         def forward(self, inp):
-            if inp.sum() > 0.:
-                if inp.mean() > 0.:
+            if inp.sum() > 0.0:
+                if inp.mean() > 0.0:
                     output = self.weight + inp
                 else:
                     output = self.weight - inp
             else:
-                if inp.mean() >= 0.:
+                if inp.mean() >= 0.0:
                     output = self.weight * inp
                 else:
                     output = self.weight / inp
@@ -1886,7 +1934,7 @@ def test_control_flow():
         def forward(self, inp):
             a = inp
             for i in range(inp.size(0)):
-                b = a * 2.
+                b = a * 2.0
                 c = a + b
                 a += c
             return a
@@ -1895,7 +1943,7 @@ def test_control_flow():
         def forward(self, inp):
             a = inp
             for i in range(inp.size(0)):
-                b = a * 2.
+                b = a * 2.0
                 b = a + b
                 if b.sum() > 0.0:
                     a += b
@@ -2007,11 +2055,11 @@ def test_forward_reduce_sum():
 
     class ReduceSum4(Module):
         def forward(self, *args):
-            return args[0].sum(dim=(2,3), keepdim=True)
+            return args[0].sum(dim=(2, 3), keepdim=True)
 
     class ReduceSum5(Module):
         def forward(self, *args):
-            return args[0].sum(dim=(2,3), keepdim=False)
+            return args[0].sum(dim=(2, 3), keepdim=False)
 
     input_data = torch.rand(input_shape).float()
     verify_model(ReduceSum1().float().eval(), input_data=input_data)
@@ -2109,11 +2157,11 @@ def test_forward_std():
 
     class Std4(Module):
         def forward(self, *args):
-            return args[0].std(dim=(2,3), keepdim=True, unbiased=False)
+            return args[0].std(dim=(2, 3), keepdim=True, unbiased=False)
 
     class Std5(Module):
         def forward(self, *args):
-            return args[0].std(dim=(2,3), keepdim=False, unbiased=False)
+            return args[0].std(dim=(2, 3), keepdim=False, unbiased=False)
 
     class Std6(Module):
         def forward(self, *args):
@@ -2125,7 +2173,7 @@ def test_forward_std():
 
     class Std8(Module):
         def forward(self, *args):
-            return args[0].std(dim=(2,3), keepdim=True, unbiased=True)
+            return args[0].std(dim=(2, 3), keepdim=True, unbiased=True)
 
     class Std9(Module):
         def forward(self, *args):
@@ -2162,11 +2210,11 @@ def test_forward_variance():
 
     class Variance4(Module):
         def forward(self, *args):
-            return args[0].var(dim=(2,3), keepdim=True, unbiased=False)
+            return args[0].var(dim=(2, 3), keepdim=True, unbiased=False)
 
     class Variance5(Module):
         def forward(self, *args):
-            return args[0].var(dim=(2,3), keepdim=False, unbiased=False)
+            return args[0].var(dim=(2, 3), keepdim=False, unbiased=False)
 
     class Variance6(Module):
         def forward(self, *args):
@@ -2178,7 +2226,7 @@ def test_forward_variance():
 
     class Variance8(Module):
         def forward(self, *args):
-            return args[0].var(dim=(2,3), keepdim=True, unbiased=True)
+            return args[0].var(dim=(2, 3), keepdim=True, unbiased=True)
 
     class Variance9(Module):
         def forward(self, *args):
@@ -2258,7 +2306,7 @@ def test_forward_isfinite():
         def forward(self, *args):
             return torch.isfinite(args[0])
 
-    input_data = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).float()
+    input_data = torch.tensor([1, float("inf"), 2, float("-inf"), float("nan")]).float()
     verify_model(IsFinite1().float().eval(), input_data=input_data)
 
 
@@ -2270,7 +2318,7 @@ def test_forward_isnan():
         def forward(self, *args):
             return torch.isnan(args[0])
 
-    input_data = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).float()
+    input_data = torch.tensor([1, float("inf"), 2, float("-inf"), float("nan")]).float()
     verify_model(IsNan1().float().eval(), input_data=input_data)
 
 
@@ -2282,7 +2330,7 @@ def test_forward_isinf():
         def forward(self, *args):
             return torch.isinf(args[0])
 
-    input_data = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).float()
+    input_data = torch.tensor([1, float("inf"), 2, float("-inf"), float("nan")]).float()
     verify_model(IsInf1().float().eval(), input_data=input_data)
 
 
@@ -2315,7 +2363,7 @@ def test_forward_ones():
 
     class Ones1(Module):
         def forward(self, *args):
-            return torch.ones(2,3)
+            return torch.ones(2, 3)
 
     verify_model(Ones1().float().eval(), input_data=[])
 
@@ -2349,7 +2397,7 @@ def test_forward_zeros():
 
     class Zeros1(Module):
         def forward(self, *args):
-            return torch.zeros(2,3)
+            return torch.zeros(2, 3)
 
     verify_model(Zeros1().float().eval(), input_data=[])
 
@@ -2383,11 +2431,11 @@ def test_forward_full():
 
     class Full1(Module):
         def forward(self, *args):
-            return torch.full((2,3), 3.14)
+            return torch.full((2, 3), 3.14)
 
     class Full2(Module):
         def forward(self, *args):
-            return torch.full((1, 2,3), 1.0, dtype=torch.int32)
+            return torch.full((1, 2, 3), 1.0, dtype=torch.int32)
 
     verify_model(Full1().float().eval(), input_data=[])
     verify_model(Full2().float().eval(), input_data=[])
@@ -2415,6 +2463,7 @@ def test_forward_full_like():
     verify_model(FullLike2().float().eval(), input_data=input_data)
     verify_model(FullLike3().float().eval(), input_data=input_data)
 
+
 @tvm.testing.uses_gpu
 def test_forward_linspace():
     torch.set_grad_enabled(False)
@@ -2422,24 +2471,31 @@ def test_forward_linspace():
     class Linspace1(Module):
         def forward(self, *args):
             return torch.linspace(5, 10)
+
     class Linspace2(Module):
         def forward(self, *args):
             return torch.linspace(-10, 10, steps=5)
+
     class Linspace3(Module):
         def forward(self, *args):
             return torch.linspace(start=-10, end=10, steps=5)
+
     class Linspace4(Module):
         def forward(self, *args):
             return torch.linspace(start=-10, end=10, steps=1)
+
     class Linspace5(Module):
         def forward(self, *args):
             return torch.linspace(1, 2, 1, dtype=torch.int32)
+
     class Linspace6(Module):
         def forward(self, *args):
             return torch.linspace(start=1, end=6, steps=2)
+
     class Linspace7(Module):
         def forward(self, *args):
             return torch.linspace(1, 4, dtype=torch.float32)
+
     class Linspace8(Module):
         def forward(self, *args):
             return torch.linspace(1, 2, 1, dtype=torch.int16)
@@ -2457,9 +2513,10 @@ def test_forward_linspace():
 @tvm.testing.uses_gpu
 def test_forward_take():
     torch.set_grad_enabled(False)
+
     class Take1(Module):
         def forward(self, *args):
-            indices = torch.tensor([[0,0],[1,0]])
+            indices = torch.tensor([[0, 0], [1, 0]])
             if torch.cuda.is_available():
                 indices = indices.cuda()
             return torch.take(args[0], indices)
@@ -2468,15 +2525,16 @@ def test_forward_take():
         def forward(self, *args):
             return torch.take(args[0], args[1])
 
-    input_data = torch.tensor([[1,2],[3,4]])
+    input_data = torch.tensor([[1, 2], [3, 4]])
     verify_model(Take1().float().eval(), input_data=input_data)
-    indices = torch.tensor([[0,0],[1,0]])
+    indices = torch.tensor([[0, 0], [1, 0]])
     verify_model(Take2().float().eval(), input_data=[input_data, indices])
 
 
 @tvm.testing.uses_gpu
 def test_forward_topk():
     torch.set_grad_enabled(False)
+
     class Topk1(Module):
         def forward(self, *args):
             return torch.topk(args[0], k=3)
@@ -2525,10 +2583,10 @@ def test_forward_logical_not():
     input_data = torch.tensor([0, 1, -10], dtype=torch.int8)
     verify_model(LogicalNot1().float().eval(), input_data=input_data)
 
-    input_data = torch.tensor([0., 1.5, -10.], dtype=torch.double)
+    input_data = torch.tensor([0.0, 1.5, -10.0], dtype=torch.double)
     verify_model(LogicalNot1().float().eval(), input_data=input_data)
 
-    input_data = torch.tensor([0., 1., -10.], dtype=torch.int32)
+    input_data = torch.tensor([0.0, 1.0, -10.0], dtype=torch.int32)
     verify_model(LogicalNot1().float().eval(), input_data=input_data)
 
 
@@ -2543,7 +2601,7 @@ def test_forward_bitwise_not():
     input_data = torch.tensor([0, 1, -10], dtype=torch.int8)
     verify_model(BitwiseNot1().float().eval(), input_data=input_data)
 
-    input_data = torch.tensor([0., 1., -10.], dtype=torch.int32)
+    input_data = torch.tensor([0.0, 1.0, -10.0], dtype=torch.int32)
     verify_model(BitwiseNot1().float().eval(), input_data=input_data)
 
     input_data = torch.tensor([True, False])
@@ -2795,6 +2853,7 @@ def test_forward_addcmul():
     t2 = torch.rand([1, 3]).float()
     verify_model(Addcmul2().float().eval(), input_data=[input_data, t1, t2])
 
+
 @tvm.testing.uses_gpu
 def test_forward_traced_function():
     def fn(t1, t2):
@@ -2804,6 +2863,7 @@ def test_forward_traced_function():
     tensor2 = torch.randn(3, 4)
     verify_model(fn, input_data=[tensor1, tensor2])
 
+
 @tvm.testing.uses_gpu
 def test_forward_dtypes():
     def fn(t1, t2):
@@ -2831,7 +2891,7 @@ def test_forward_dtypes():
 @tvm.testing.uses_gpu
 def test_weight_names():
     tm = torch.jit.trace(torch.nn.Linear(3, 4), [torch.randn(2, 3)])
-    mod, params = relay.frontend.from_pytorch(tm, [('input', (2, 3))])
+    mod, params = relay.frontend.from_pytorch(tm, [("input", (2, 3))])
     assert set(params.keys()) == set(n for n, p in tm.named_parameters())
 
 
@@ -2980,7 +3040,7 @@ def test_forward_pretrained_bert_base_uncased():
     # -----------------------------------------
 
     # Load pre-trained model tokenizer (vocabulary)
-    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
 
     # Tokenized input
     text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
@@ -2988,9 +3048,23 @@ def test_forward_pretrained_bert_base_uncased():
 
     # Mask a token that we will try to predict back with `BertForMaskedLM`
     masked_index = 8
-    tokenized_text[masked_index] = '[MASK]'
-    assert tokenized_text == ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a', 'puppet',
-                              '##eer', '[SEP]']
+    tokenized_text[masked_index] = "[MASK]"
+    assert tokenized_text == [
+        "[CLS]",
+        "who",
+        "was",
+        "jim",
+        "henson",
+        "?",
+        "[SEP]",
+        "jim",
+        "[MASK]",
+        "was",
+        "a",
+        "puppet",
+        "##eer",
+        "[SEP]",
+    ]
 
     # Convert token to vocabulary indices
     indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
@@ -3006,7 +3080,7 @@ def test_forward_pretrained_bert_base_uncased():
     # -------------------------------------------------
 
     # Bert Model with a language modeling
-    model = BertForMaskedLM.from_pretrained('bert-base-uncased')
+    model = BertForMaskedLM.from_pretrained("bert-base-uncased")
     model.eval()
 
     ######################################################################
@@ -3027,10 +3101,9 @@ def test_forward_pretrained_bert_base_uncased():
     # -------------------------
     # Convert PyTorch graph to Relay graph. The input name can be arbitrary.
 
-    input_1 = 'input_ids'
-    input_2 = 'input.2'
-    shape_list = [(input_1, list(tokens_tensor.shape)),
-                  (input_2, list(segments_tensors.shape))]
+    input_1 = "input_ids"
+    input_2 = "input.2"
+    shape_list = [(input_1, list(tokens_tensor.shape)), (input_2, list(segments_tensors.shape))]
 
     mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
 
@@ -3038,7 +3111,7 @@ def test_forward_pretrained_bert_base_uncased():
     # Compile the model with relay
     # ----------------------------
 
-    target = 'llvm'
+    target = "llvm"
     with tvm.transform.PassContext(opt_level=3):
         relay_graph, relay_lib, relay_params = relay.build(mod, target=target, params=params)
 
@@ -3078,8 +3151,8 @@ def test_forward_pretrained_bert_base_uncased():
     assert torch_pred_token == tvm_pred_token
 
     # Print the outputs
-    print('Torch top-1 id: {}, token: {}'.format(torch_pred_idx, torch_pred_token))
-    print('TVM   top-1 id: {}, token: {}'.format(tvm_pred_idx, tvm_pred_token))
+    print("Torch top-1 id: {}, token: {}".format(torch_pred_idx, torch_pred_token))
+    print("TVM   top-1 id: {}, token: {}".format(tvm_pred_idx, tvm_pred_token))
 
 
 if __name__ == "__main__":
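The PyTorch frontend tests above all follow the same pattern: trace a module, import it with relay.frontend.from_pytorch, build for a target, and compare the runtime output against the Torch baseline with assert_allclose. A condensed, hedged sketch of that flow for a toy torch.nn.Linear; the API names (graph_runtime, three-value relay.build unpacking, asnumpy) follow the usage visible in the hunks, and the input name "input0" is illustrative:

# Hedged sketch, not taken from the test file above.
import torch
import tvm
import tvm.testing
from tvm import relay
from tvm.contrib import graph_runtime

pt_model = torch.nn.Linear(3, 4).eval()
inp = torch.randn(2, 3)
trace = torch.jit.trace(pt_model, inp)

mod, params = relay.frontend.from_pytorch(trace, [("input0", (2, 3))])
with tvm.transform.PassContext(opt_level=3):
    graph, lib, build_params = relay.build(mod, target="llvm", params=params)

runtime = graph_runtime.create(graph, lib, tvm.cpu())
runtime.set_input("input0", inp.numpy())
runtime.set_input(**build_params)
runtime.run()

expected = pt_model(inp).detach().numpy()
tvm.testing.assert_allclose(runtime.get_output(0).asnumpy(), expected, rtol=1e-5, atol=1e-5)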
index dcaa5e1..27dbec3 100644
@@ -100,14 +100,16 @@ class ReverseLSTMLayer(jit.ScriptModule):
 
 
 class BidirLSTMLayer(jit.ScriptModule):
-    __constants__ = ['directions']
+    __constants__ = ["directions"]
 
     def __init__(self, cell, *cell_args):
         super(BidirLSTMLayer, self).__init__()
-        self.directions = nn.ModuleList([
-            LSTMLayer(cell, *cell_args),
-            ReverseLSTMLayer(cell, *cell_args),
-        ])
+        self.directions = nn.ModuleList(
+            [
+                LSTMLayer(cell, *cell_args),
+                ReverseLSTMLayer(cell, *cell_args),
+            ]
+        )
 
     @jit.script_method
     def forward(self, input, states):
@@ -126,18 +128,16 @@ class BidirLSTMLayer(jit.ScriptModule):
 
 
 def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args):
-    layers = [layer(*first_layer_args)] + [layer(*other_layer_args)
-                                           for _ in range(num_layers - 1)]
+    layers = [layer(*first_layer_args)] + [layer(*other_layer_args) for _ in range(num_layers - 1)]
     return nn.ModuleList(layers)
 
 
 class StackedLSTM(jit.ScriptModule):
-    __constants__ = ['layers']  # Necessary for iterating through self.layers
+    __constants__ = ["layers"]  # Necessary for iterating through self.layers
 
     def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
         super().__init__()
-        self.layers = init_stacked_lstm(num_layers, layer, first_layer_args,
-                                        other_layer_args)
+        self.layers = init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args)
 
     @jit.script_method
     def forward(self, input, states):
@@ -153,12 +153,11 @@ class StackedLSTM(jit.ScriptModule):
 
 
 class StackedBidirLSTM(jit.ScriptModule):
-    __constants__ = ['layers']  # Necessary for iterating through self.layers
+    __constants__ = ["layers"]  # Necessary for iterating through self.layers
 
     def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
         super(StackedBidirLSTM, self).__init__()
-        self.layers = init_stacked_lstm(num_layers, layer, first_layer_args,
-                                        other_layer_args)
+        self.layers = init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args)
 
     @jit.script_method
     def forward(self, input, states):
@@ -179,9 +178,12 @@ def lstm(input_size, hidden_size):
 
 
 def stacked_lstm(input_size, hidden_size, num_layers):
-    return StackedLSTM(num_layers, LSTMLayer,
-                       first_layer_args=[LayerNormLSTMCell, input_size, hidden_size],
-                       other_layer_args=[LayerNormLSTMCell, hidden_size, hidden_size])
+    return StackedLSTM(
+        num_layers,
+        LSTMLayer,
+        first_layer_args=[LayerNormLSTMCell, input_size, hidden_size],
+        other_layer_args=[LayerNormLSTMCell, hidden_size, hidden_size],
+    )
 
 
 def bidir_lstm(input_size, hidden_size):
@@ -189,9 +191,12 @@ def bidir_lstm(input_size, hidden_size):
 
 
 def stacked_bidir_lstm(input_size, hidden_size, num_layers):
-    return StackedBidirLSTM(num_layers, BidirLSTMLayer,
-                            first_layer_args=[LayerNormLSTMCell, input_size, hidden_size],
-                            other_layer_args=[LayerNormLSTMCell, hidden_size, hidden_size])
+    return StackedBidirLSTM(
+        num_layers,
+        BidirLSTMLayer,
+        first_layer_args=[LayerNormLSTMCell, input_size, hidden_size],
+        other_layer_args=[LayerNormLSTMCell, hidden_size, hidden_size],
+    )
 
 
 def vmobj_to_list(o, dtype="float32"):
@@ -212,8 +217,9 @@ def assert_equal(tvm_result, torch_result):
         for tvm_res, pt_res in zip(tvm_result, torch_result):
             assert_equal(tvm_res, pt_res)
     elif isinstance(torch_result, torch.Tensor):
-        tvm.testing.assert_allclose(tvm_result.asnumpy(), torch_result.numpy(),
-                                    rtol=1e-4, atol=1e-4)
+        tvm.testing.assert_allclose(
+            tvm_result.asnumpy(), torch_result.numpy(), rtol=1e-4, atol=1e-4
+        )
 
 
 def run_and_compare(mod, params, pt_result, target, ctx):
@@ -272,38 +278,53 @@ def test_custom_lstm():
 
     inp = torch.randn(seq_len, batch, input_size)
 
-    input_shapes = [(input_name, (seq_len, batch, input_size)),
-                    (states_name, (state_tensor_shape, state_tensor_shape))]
+    input_shapes = [
+        (input_name, (seq_len, batch, input_size)),
+        (states_name, (state_tensor_shape, state_tensor_shape)),
+    ]
 
-    input_shapes_stacked = [(input_name, (seq_len, batch, input_size)),
-                            (states_name, [(state_tensor_shape, state_tensor_shape),
-                                           (state_tensor_shape, state_tensor_shape)])]
+    input_shapes_stacked = [
+        (input_name, (seq_len, batch, input_size)),
+        (
+            states_name,
+            [(state_tensor_shape, state_tensor_shape), (state_tensor_shape, state_tensor_shape)],
+        ),
+    ]
 
-    input_shapes_stacked_bidir = [(input_name, (seq_len, batch, input_size)),
-                                  (states_name, [[(state_tensor_shape,
-                                                   state_tensor_shape)
-                                                  for _ in range(2)]
-                                                 for _ in range(num_layers)])]
+    input_shapes_stacked_bidir = [
+        (input_name, (seq_len, batch, input_size)),
+        (
+            states_name,
+            [
+                [(state_tensor_shape, state_tensor_shape) for _ in range(2)]
+                for _ in range(num_layers)
+            ],
+        ),
+    ]
 
-    states = [(torch.randn(state_tensor_shape),
-               torch.randn(state_tensor_shape))
-              for _ in range(num_layers)]
+    states = [
+        (torch.randn(state_tensor_shape), torch.randn(state_tensor_shape))
+        for _ in range(num_layers)
+    ]
 
-    bidir_states = [(torch.randn(state_tensor_shape),
-                     torch.randn(state_tensor_shape))
-                    for _ in range(2)]
+    bidir_states = [
+        (torch.randn(state_tensor_shape), torch.randn(state_tensor_shape)) for _ in range(2)
+    ]
 
-    stacked_bidir_states = [[(torch.randn(state_tensor_shape),
-                              torch.randn(state_tensor_shape))
-                             for _ in range(2)]
-                            for _ in range(num_layers)]
+    stacked_bidir_states = [
+        [(torch.randn(state_tensor_shape), torch.randn(state_tensor_shape)) for _ in range(2)]
+        for _ in range(num_layers)
+    ]
 
     models = [
-      (lstm(input_size, hidden_size).eval(), states[0], input_shapes),
-      (stacked_lstm(input_size, hidden_size, num_layers).eval(), states, input_shapes_stacked),
-      (bidir_lstm(input_size, hidden_size).eval(), bidir_states, input_shapes_stacked),
-      (stacked_bidir_lstm(input_size, hidden_size, num_layers).eval(),
-       stacked_bidir_states, input_shapes_stacked_bidir)
+        (lstm(input_size, hidden_size).eval(), states[0], input_shapes),
+        (stacked_lstm(input_size, hidden_size, num_layers).eval(), states, input_shapes_stacked),
+        (bidir_lstm(input_size, hidden_size).eval(), bidir_states, input_shapes_stacked),
+        (
+            stacked_bidir_lstm(input_size, hidden_size, num_layers).eval(),
+            stacked_bidir_states,
+            input_shapes_stacked_bidir,
+        ),
     ]
 
     for (raw_model, states, input_shapes) in models:
@@ -320,12 +341,12 @@ def test_custom_lstm():
         elif isinstance(states, list) and isinstance(states[0], torch.Tensor):
             states_np = [st.numpy() for st in states]
         elif isinstance(states, list) and isinstance(states[0], tuple):
-            states_np = [tuple(st.numpy() for st in states[i])
-                         for i in range(len(states))]
+            states_np = [tuple(st.numpy() for st in states[i]) for i in range(len(states))]
         elif isinstance(states, list) and isinstance(states[0], list):
-            states_np = [[tuple(st.numpy() for st in states)
-                         for states in states[layer]]
-                         for layer in range(num_layers)]
+            states_np = [
+                [tuple(st.numpy() for st in states) for states in states[layer]]
+                for layer in range(num_layers)
+            ]
         else:
             assert False
 
index 010899c..d7a00b2 100644
@@ -22,6 +22,7 @@ in TensorFlow frontend when mean and variance are not given.
 """
 import tvm
 import numpy as np
+
 try:
     import tensorflow.compat.v1 as tf
 except ImportError:
@@ -29,39 +30,52 @@ except ImportError:
 from tvm import relay
 from tensorflow.python.framework import graph_util
 
+
 def verify_fused_batch_norm(shape):
     g = tf.Graph()
     with g.as_default():
-        input_tensor = tf.placeholder(tf.float32, shape=shape, name='input')
-        alpha = tf.constant(np.random.rand(shape[-1],), dtype=tf.float32, name='alpha')
-        beta = tf.constant(np.random.rand(shape[-1],), dtype=tf.float32, name='beta')
-        bn = tf.nn.fused_batch_norm(x=input_tensor, offset=beta, scale=alpha, name='bn')
-        out = tf.identity(bn[0], name='output')
+        input_tensor = tf.placeholder(tf.float32, shape=shape, name="input")
+        alpha = tf.constant(
+            np.random.rand(
+                shape[-1],
+            ),
+            dtype=tf.float32,
+            name="alpha",
+        )
+        beta = tf.constant(
+            np.random.rand(
+                shape[-1],
+            ),
+            dtype=tf.float32,
+            name="beta",
+        )
+        bn = tf.nn.fused_batch_norm(x=input_tensor, offset=beta, scale=alpha, name="bn")
+        out = tf.identity(bn[0], name="output")
     data = np.random.rand(*shape)
     with tf.Session(graph=out.graph) as sess:
         sess.run([tf.global_variables_initializer()])
-        tf_out = sess.run(out, feed_dict={input_tensor:data})
-        constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['output'])
+        tf_out = sess.run(out, feed_dict={input_tensor: data})
+        constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])
 
     for device in ["llvm"]:
         ctx = tvm.context(device, 0)
         if not tvm.testing.device_enabled(device):
             print("Skip because %s is not enabled" % device)
             continue
-        mod, params = relay.frontend.from_tensorflow(constant_graph,
-                                                     outputs=['output'])
+        mod, params = relay.frontend.from_tensorflow(constant_graph, outputs=["output"])
         with tvm.transform.PassContext(opt_level=3):
-            graph, lib, params = relay.build(mod,
-                                             target=device,
-                                             params=params)
+            graph, lib, params = relay.build(mod, target=device, params=params)
         from tvm.contrib import graph_runtime
+
         m = graph_runtime.create(graph, lib, ctx)
         m.set_input(**params)
-        m.set_input('input', data)
+        m.set_input("input", data)
         m.run()
         tvm_out = m.get_output(0)
-        tvm.testing.assert_allclose(tvm_out.asnumpy(), tf_out.astype(tvm_out.dtype),
-                                    atol=1e-3, rtol=1e-3)
+        tvm.testing.assert_allclose(
+            tvm_out.asnumpy(), tf_out.astype(tvm_out.dtype), atol=1e-3, rtol=1e-3
+        )
+
 
 def test_fused_batch_norm():
     verify_fused_batch_norm(shape=(1, 12, 12, 32))
@@ -71,5 +85,6 @@ def test_fused_batch_norm():
     verify_fused_batch_norm(shape=(16, 12, 12, 32))
     verify_fused_batch_norm(shape=(32, 12, 12, 32))
 
+
 if __name__ == "__main__":
     test_fused_batch_norm()
index 3ec04bf..ebe2ca3 100644
 # under the License.
 """Unit tests for converting TensorFlow control flow op to Relay."""
 import pytest
+
 try:
     import tensorflow.compat.v1 as tf
+
     tf.disable_v2_behavior()
 except ImportError:
     import tensorflow as tf
@@ -32,7 +34,7 @@ def check_equal(graph, tf_out, input_map=None):
     mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
     if input_map is not None:
         params.update(input_map)
-    ex = relay.create_executor('vm', mod=mod)
+    ex = relay.create_executor("vm", mod=mod)
     relay_out = ex.evaluate()(**params)
     if isinstance(relay_out, nd.NDArray):
         np.testing.assert_allclose(tf_out, relay_out.asnumpy())
@@ -48,9 +50,11 @@ def test_vanilla_loop():
     with graph.as_default():
         i = tf.constant(0, name="while/constant")
 
-        def c(i): return tf.less(i, 10)
+        def c(i):
+            return tf.less(i, 10)
 
-        def b(i): return tf.add(i, 1)
+        def b(i):
+            return tf.add(i, 1)
 
         r = tf.while_loop(c, b, [i])
 
@@ -65,9 +69,11 @@ def test_callnode_loop_vars():
     with graph.as_default():
         i = tf.add(tf.constant(0), 1)
 
-        def c(i): return tf.less(i, 10)
+        def c(i):
+            return tf.less(i, 10)
 
-        def b(i): return tf.add(i, 1)
+        def b(i):
+            return tf.add(i, 1)
 
         r = tf.while_loop(c, b, [i])
 
@@ -83,9 +89,11 @@ def test_loop_2_vars():
         i0 = tf.constant(0)
         j0 = tf.ones([2, 2])
 
-        def c(i, j): return i < 10
+        def c(i, j):
+            return i < 10
 
-        def b(i, j): return [tf.add(i, 1), j]
+        def b(i, j):
+            return [tf.add(i, 1), j]
 
         i1, i2 = tf.while_loop(c, b, loop_vars=[i0, j0])
         i1 += tf.constant(1337)
@@ -103,9 +111,12 @@ def test_loop_3_vars():
         j0 = tf.constant(2)
         k0 = tf.constant(4)
 
-        def c(i, j, k): return i < 10
+        def c(i, j, k):
+            return i < 10
+
+        def b(i, j, k):
+            return [i + 1, j * k, k + i]
 
-        def b(i, j, k): return [i+1, j * k, k + i]
         r = tf.while_loop(c, b, loop_vars=[i0, j0, k0])
 
         with tf.Session() as sess:
@@ -121,12 +132,14 @@ def test_loop_conditions():
         j = tf.constant(1)
         k = tf.constant(5)
 
-        def c(i, j, k): return \
-            tf.equal(tf.not_equal(tf.less(i + j, 10),
-                                  tf.less(j * k, 100)),
-                     tf.greater_equal(k, i + j))
+        def c(i, j, k):
+            return tf.equal(
+                tf.not_equal(tf.less(i + j, 10), tf.less(j * k, 100)), tf.greater_equal(k, i + j)
+            )
+
+        def b(i, j, k):
+            return [i + j, j + k, k + 1]
 
-        def b(i, j, k): return [i+j, j+k, k+1]
         r = tf.while_loop(c, b, loop_vars=[i, j, k])
         with tf.Session() as sess:
             tf_out = sess.run(r)
@@ -138,6 +151,7 @@ def test_loop_conditions():
 def test_loop_bodies():
     graph = tf.Graph()
     with graph.as_default():
+
         def body(x):
             a = tf.constant(np.array([[5, 6], [7, 8]]), dtype=tf.int32)
             b = tf.constant(np.array([[1, 2], [3, 4]]), dtype=tf.int32)
@@ -146,6 +160,7 @@ def test_loop_bodies():
 
         def condition(x):
             return tf.reduce_sum(x) < 100
+
         x = tf.constant(0, shape=[2, 2])
         r = tf.while_loop(condition, body, [x])
         with tf.Session() as sess:
@@ -161,13 +176,17 @@ def test_nested_loop():
         def body(x):
             def nest_body(c):
                 return tf.multiply(c, 2)
-            def cd(c): return tf.less(c, 10)
+
+            def cd(c):
+                return tf.less(c, 10)
+
             c = tf.constant(2)
             res = tf.while_loop(cd, nest_body, loop_vars=[c])
             return tf.nn.relu(x + res)
 
         def condition(x):
             return tf.greater(x, 100)
+
         x = tf.constant(3)
         r = tf.while_loop(condition, body, loop_vars=[x])
 
@@ -188,6 +207,7 @@ def test_vanilla_cond():
 
         def f2():
             return tf.add(4, 23)
+
         r = tf.cond(tf.less(i, j), f1, f2)
 
     with tf.Session(graph=graph) as sess:
@@ -202,8 +222,7 @@ def test_multiple_cond_vars():
         x1 = tf.constant(7)
         x2 = tf.constant(12)
         z = tf.constant(20)
-        r = tf.cond(tf.less(tf.add(x1, x2), 10),
-                    lambda: tf.add(10, 2), lambda: tf.square(5))
+        r = tf.cond(tf.less(tf.add(x1, x2), 10), lambda: tf.add(10, 2), lambda: tf.square(5))
 
         with tf.Session() as sess:
             tf_out = sess.run(r)
@@ -214,6 +233,7 @@ def test_multiple_cond_vars():
 def test_cond_fn_parameters():
     graph = tf.Graph()
     with graph.as_default():
+
         def fn1(x, y):
             return tf.multiply(5, 6)
 
@@ -234,6 +254,7 @@ def test_cond_fn_parameters():
 def test_nested_cond():
     graph = tf.Graph()
     with graph.as_default():
+
         def fn1(a, b):
             def nest_fn1():
                 return tf.add(1, 2)
@@ -262,12 +283,16 @@ def test_nested_cond():
 def test_loop_in_cond():
     graph = tf.Graph()
     with graph.as_default():
+
         def fn1(a, b):
             i = tf.constant(0)
 
-            def cd(i): return tf.less(i, 10)
+            def cd(i):
+                return tf.less(i, 10)
+
+            def bd(i):
+                return tf.add(i, 1)
 
-            def bd(i): return tf.add(i, 1)
             res = tf.while_loop(cd, bd, [i])
             return tf.multiply(tf.add(20, res), 10)
 
@@ -289,14 +314,15 @@ def test_loop_in_cond():
 def test_cond_in_loop():
     graph = tf.Graph()
     with graph.as_default():
+
         def body(x):
             x = tf.constant(7)
             z = tf.constant(20)
-            res = tf.cond(tf.less(x, 10), lambda: tf.add(
-                10, 20), lambda: tf.square(10))
+            res = tf.cond(tf.less(x, 10), lambda: tf.add(10, 20), lambda: tf.square(10))
             return tf.multiply(res, x)
 
         x = tf.constant(21)
+
         def condition(x):
             return tf.less(x, 100)
 
@@ -306,6 +332,7 @@ def test_cond_in_loop():
 
     check_equal(graph, tf_out)
 
+
 def test_vanilla_loop_bound():
     graph = tf.Graph()
     with graph.as_default():
@@ -316,14 +343,15 @@ def test_vanilla_loop_bound():
         data = tf.placeholder(shape=dshape, dtype=dtype, name=dname)
         x = tf.slice(data, [1, 4], [1, 4])
         outer = x + 5.0
+
         def body(x, y):
-            res = tf.cond(tf.less(y, 10), lambda: tf.add(
-                10.0, 20.0), lambda: tf.square(10.0))
+            res = tf.cond(tf.less(y, 10), lambda: tf.add(10.0, 20.0), lambda: tf.square(10.0))
             z = tf.constant(7)
             res = tf.cond(tf.less(z, 10), lambda: res * 5, lambda: res + 10)
             return tf.multiply(res, x * outer), y + 1
 
         y = tf.constant(0)
+
         def condition(x, y):
             return tf.less(y, 20)
 
@@ -333,6 +361,7 @@ def test_vanilla_loop_bound():
 
     check_equal(graph, tf_out, {dname: np_data})
 
+
 def test_nested_loop_bound():
     graph = tf.Graph()
     with graph.as_default():
@@ -343,13 +372,16 @@ def test_nested_loop_bound():
         data = tf.placeholder(shape=dshape, dtype=dtype, name=dname)
         x = tf.slice(data, [1, 4], [1, 4])
         outer = x + 5.0
+
         def body(x, y):
-            res = tf.cond(tf.less(y, 10), lambda: tf.add(
-                10.0, 20.0), lambda: tf.square(10.0))
+            res = tf.cond(tf.less(y, 10), lambda: tf.add(10.0, 20.0), lambda: tf.square(10.0))
+
             def nested_body(nx, ny):
                 return nx + 1, res + 2.0
+
             def nested_cond(nx, ny):
                 return tf.less(nx, 15)
+
             nx = tf.constant(0)
             ny = tf.constant(0.0)
             nested_res = tf.while_loop(nested_cond, nested_body, loop_vars=[nx, ny])
@@ -359,6 +391,7 @@ def test_nested_loop_bound():
             return tf.multiply(res, x * outer), y + 1
 
         y = tf.constant(0)
+
         def condition(x, y):
             return tf.less(y, 20)
 
@@ -368,13 +401,14 @@ def test_nested_loop_bound():
 
     check_equal(graph, tf_out, {dname: np_data})
 
+
 def test_switch():
     graph = tf.Graph()
 
     with graph.as_default():
-        data_np = np.random.uniform(0, 5, size=(2, 4, 5, 1)).astype('float32')
-        dname = 'data'
-        flag_name = 'flag'
+        data_np = np.random.uniform(0, 5, size=(2, 4, 5, 1)).astype("float32")
+        dname = "data"
+        flag_name = "flag"
         data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name=dname)
         split = tf.split(data, 2, axis=0)
         flag = tf.placeholder(shape={}, dtype=tf.bool, name=flag_name)
@@ -384,12 +418,13 @@ def test_switch():
 
     check_equal(graph, tf_out, {dname: data_np, flag_name: False})
 
+
 def test_loop_tuple_input():
     graph = tf.Graph()
 
     with graph.as_default():
-        data_np = np.random.uniform(0, 5, size=(2, 4, 5, 1)).astype('float32')
-        dname = 'data'
+        data_np = np.random.uniform(0, 5, size=(2, 4, 5, 1)).astype("float32")
+        dname = "data"
         data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name=dname)
         split = tf.split(data, 2, axis=0)
 
@@ -397,6 +432,7 @@ def test_loop_tuple_input():
             return x + 2, y + 1
 
         start = tf.constant(0)
+
         def condition(x, y):
             return tf.less(y, 20)
 
index a6df6ff..2a5fb60 100644
@@ -17,6 +17,7 @@
 """Unit tests for converting TensorFlow debugging ops to Relay."""
 try:
     import tensorflow.compat.v1 as tf
+
     tf.disable_v2_behavior()
 except ImportError:
     import tensorflow as tf
@@ -24,13 +25,13 @@ import numpy as np
 from tvm import relay
 from tvm.relay.frontend.tensorflow import from_tensorflow
 
+
 def run_relay(graph, shape_dict=None, *vars):
-    mod, params = from_tensorflow(
-        graph.as_graph_def(add_shapes=True),
-        shape=shape_dict)
-    ex = relay.create_executor('debug', mod=mod)
+    mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True), shape=shape_dict)
+    ex = relay.create_executor("debug", mod=mod)
     return ex.evaluate()(*vars)
 
+
 def test_assert_true():
     g = tf.Graph()
     shape = (1, 2)
@@ -51,7 +52,7 @@ def test_assert_true():
         # do that, it's happening in Relay, and that optimization shouldn't
         # affect the arity of the main function. We should have to pass in
         # x_value here.
-        np.testing.assert_allclose(0, run_relay(g, {'input': shape}).asnumpy())
+        np.testing.assert_allclose(0, run_relay(g, {"input": shape}).asnumpy())
 
 
 def test_assert_true_var_capture():
@@ -71,8 +72,8 @@ def test_assert_true_var_capture():
         # TODO: The frontend converter notes the output of
         # the graph as a boolean, which is not correct - as you can see above,
         # TF believes that the value of this graph is None.
-        np.testing.assert_allclose(True,
-                                   run_relay(g, None, x_value).asnumpy())
+        np.testing.assert_allclose(True, run_relay(g, None, x_value).asnumpy())
+
 
 def test_assert_false():
     g = tf.Graph()
@@ -92,6 +93,7 @@ def test_assert_false():
         # argument is false.
         np.testing.assert_allclose(0, run_relay(g).asnumpy())
 
+
 if __name__ == "__main__":
     test_assert_true()
     test_assert_true_var_capture()
index 37a32be..4940b69 100644
@@ -24,6 +24,7 @@ from __future__ import print_function
 import threading
 import numpy as np
 import pytest
+
 try:
     import tensorflow.compat.v1 as tf
 except ImportError:
@@ -61,18 +62,20 @@ def convert_to_list(x):
         x = [x]
     return x
 
+
 tf_dtypes = {
-    'float32': tf.float32,
-    'float16': tf.float16,
-    'float64': tf.float64,
-    'int32': tf.int32,
-    'uint8' : tf.uint8,
-    'int8': tf.int8,
-    'int16': tf.int16,
-    'uint16': tf.uint16,
-    'int64': tf.int64,
+    "float32": tf.float32,
+    "float16": tf.float16,
+    "float64": tf.float64,
+    "int32": tf.int32,
+    "uint8": tf.uint8,
+    "int8": tf.int8,
+    "int16": tf.int16,
+    "uint16": tf.uint16,
+    "int64": tf.int64,
 }
 
+
 def vmobj_to_list(o):
     if isinstance(o, tvm.nd.NDArray):
         return [o.asnumpy()]
@@ -82,27 +85,37 @@ def vmobj_to_list(o):
             result.extend(vmobj_to_list(f))
         return result
     elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
-        if o.constructor.name_hint == 'Cons':
+        if o.constructor.name_hint == "Cons":
             tl = vmobj_to_list(o.fields[1])
             hd = vmobj_to_list(o.fields[0])
             hd.extend(tl)
             return hd
-        elif o.constructor.name_hint == 'Nil':
+        elif o.constructor.name_hint == "Nil":
             return []
-        elif 'tensor_nil' in o.constructor.name_hint:
+        elif "tensor_nil" in o.constructor.name_hint:
             return [0]
-        elif 'tensor' in o.constructor.name_hint:
+        elif "tensor" in o.constructor.name_hint:
             return [o.fields[0].asnumpy()]
         else:
-            raise RuntimeError("Unknown object type: %s" %
-                               o.constructor.name_hint)
+            raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint)
     else:
         raise RuntimeError("Unknown object type: %s" % type(o))
 
 
-def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
-                  target='llvm', out_names=None, opt_level=3, mode='graph_runtime',
-                  cuda_layout="NCHW", layout=None, disabled_pass=None, ignore_in_shape=False):
+def run_tvm_graph(
+    graph_def,
+    input_data,
+    input_node,
+    num_output=1,
+    target="llvm",
+    out_names=None,
+    opt_level=3,
+    mode="graph_runtime",
+    cuda_layout="NCHW",
+    layout=None,
+    disabled_pass=None,
+    ignore_in_shape=False,
+):
     """ Generic function to compile on relay and execute on tvm """
     input_data = convert_to_list(input_data)
     input_node = convert_to_list(input_node)
@@ -112,17 +125,17 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
     if ignore_in_shape:
         shape_dict = None
     else:
-        shape_dict = {e: i.shape if hasattr(i, "shape") else ()
-                      for e, i in zip(input_node, input_data)}
-    mod, params = relay.frontend.from_tensorflow(graph_def,
-                                                 layout=layout,
-                                                 shape=shape_dict,
-                                                 outputs=out_names)
+        shape_dict = {
+            e: i.shape if hasattr(i, "shape") else () for e, i in zip(input_node, input_data)
+        }
+    mod, params = relay.frontend.from_tensorflow(
+        graph_def, layout=layout, shape=shape_dict, outputs=out_names
+    )
     ctx = tvm.context(target, 0)
-    if mode == 'debug':
+    if mode == "debug":
         ex = relay.create_executor(mode, mod=mod, ctx=tvm.cpu(), target="llvm")
         inputs = []
-        for param in mod['main'].params:
+        for param in mod["main"].params:
             found = False
             for i, n in enumerate(input_node):
                 if n == param.name_hint:
@@ -134,7 +147,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
                 inputs.append(tvm.nd.array(params[param.name_hint]))
         result = ex.evaluate()(*inputs)
         return vmobj_to_list(result)
-    elif mode == 'vm':
+    elif mode == "vm":
         with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass):
             vm_exec = relay.vm.compile(mod, target="llvm", params=params)
         vm = VirtualMachine(vm_exec, tvm.cpu())
@@ -147,6 +160,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
         with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass):
             graph, lib, params = relay.build(mod, target, target_host, params)
         from tvm.contrib import graph_runtime
+
         m = graph_runtime.create(graph, lib, ctx)
         # set inputs
         for e, i in zip(input_node, input_data):
@@ -156,10 +170,10 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
         # execute
         m.run()
         # get outputs
-        assert out_names is None or num_output == len(out_names), (
-            "out_names: {} num_output: {}".format(out_names, num_output))
-        tvm_output_list = [m.get_output(i).asnumpy()
-                           for i in range(num_output)]
+        assert out_names is None or num_output == len(
+            out_names
+        ), "out_names: {} num_output: {}".format(out_names, num_output)
+        tvm_output_list = [m.get_output(i).asnumpy() for i in range(num_output)]
         return tvm_output_list
 
 
@@ -169,8 +183,7 @@ def run_tf_graph(sess, input_data, input_node, output_node):
     input_node = convert_to_list(input_node)
     output_node = convert_to_list(output_node)
 
-    tensor = [sess.graph.get_tensor_by_name(
-        output_name) for output_name in output_node]
+    tensor = [sess.graph.get_tensor_by_name(output_name) for output_name in output_node]
 
     input_dict = {e: input_data[i] for i, e in enumerate(input_node)}
 
@@ -178,12 +191,20 @@ def run_tf_graph(sess, input_data, input_node, output_node):
     return output_data
 
 
-def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
-                        no_gpu=False, opt_level=3, mode='graph_runtime',
-                        cuda_layout="NCHW"):
+def compare_tf_with_tvm(
+    in_data,
+    in_name,
+    out_name,
+    init_global_variables=False,
+    no_gpu=False,
+    opt_level=3,
+    mode="graph_runtime",
+    cuda_layout="NCHW",
+):
     """Generic function to generate and compare tensorflow and TVM output"""
+
     def name_without_num(name):
-        return name.split(':')[0] if ":" in name else name
+        return name.split(":")[0] if ":" in name else name
 
     out_name = convert_to_list(out_name)
     out_node = [name_without_num(name) for name in out_name]
@@ -203,34 +224,42 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
             if not tvm.testing.device_enabled(device):
                 print("Skip because %s is not enabled" % device)
                 continue
-            if no_gpu and device == 'cuda':
+            if no_gpu and device == "cuda":
                 continue
 
-            tvm_output = run_tvm_graph(final_graph_def, in_data, in_node,
-                                       target=device, out_names=out_name,
-                                       num_output=len(out_name), opt_level=opt_level, mode=mode,
-                                       cuda_layout=cuda_layout)
+            tvm_output = run_tvm_graph(
+                final_graph_def,
+                in_data,
+                in_node,
+                target=device,
+                out_names=out_name,
+                num_output=len(out_name),
+                opt_level=opt_level,
+                mode=mode,
+                cuda_layout=cuda_layout,
+            )
             # since the names from tensorflow and relay runs are not exactly same,
             # first len(tf_output) will be compared
             for i in range(len(tf_output)):
                 if not isinstance(tf_output[i], np.ndarray):
                     assert len(tvm_output[i].shape) == 0
-                tvm.testing.assert_allclose(
-                    tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
+                tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
 
         sess.close()
 
 
 def is_gpu_available():
     from tensorflow.python.client import device_lib
+
     local_device_protos = device_lib.list_local_devices()
-    gpu_list = [x.name for x in local_device_protos if x.device_type == 'GPU']
+    gpu_list = [x.name for x in local_device_protos if x.device_type == "GPU"]
     if len(gpu_list) > 0:
         print("Tensorflow GPU:", gpu_list)
         return True
     else:
         return False
 
+
 #######################################################################
 # Pooling
 # -------
@@ -239,19 +268,18 @@ def is_gpu_available():
 def _test_pooling_iteration(input_shape, **kwargs):
     """ One iteration of pool operation with given shapes and attributes """
 
-    x = -np.arange(
-        np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
+    x = -np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
 
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=input_shape, dtype='float32')
+        in_data = array_ops.placeholder(shape=input_shape, dtype="float32")
         nn_ops.pool(in_data, **kwargs)
 
-        if kwargs['pooling_type'] == 'MAX':
-            out_name = 'max_pool:0'
+        if kwargs["pooling_type"] == "MAX":
+            out_name = "max_pool:0"
         else:
-            out_name = 'avg_pool:0'
+            out_name = "avg_pool:0"
 
-        compare_tf_with_tvm(x, 'Placeholder:0', out_name)
+        compare_tf_with_tvm(x, "Placeholder:0", out_name)
 
 
 def _test_pooling(input_shape, **kwargs):
@@ -260,7 +288,7 @@ def _test_pooling(input_shape, **kwargs):
     if is_gpu_available():
         if len(input_shape) == 4:
             input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
-            kwargs['data_format'] = 'NCHW'
+            kwargs["data_format"] = "NCHW"
             _test_pooling_iteration(input_shape, **kwargs)
 
 
@@ -268,97 +296,127 @@ def _test_pooling(input_shape, **kwargs):
 def test_forward_pooling():
     """ Pooling """
     # TensorFlow only supports NDHWC for max_pool3d on CPU
-    for pool_type in ['AVG', 'MAX']:
+    for pool_type in ["AVG", "MAX"]:
         # NDHWC is the default layout for max_pool3d and avg_pool3d in TensorFlow
-        _test_pooling(input_shape=[1, 3, 32, 32, 32],
-                      window_shape=[2, 2, 2],
-                      padding='VALID',
-                      pooling_type=pool_type,
-                      dilation_rate=[1, 1, 1],
-                      strides=[2, 2, 2])
-
-        _test_pooling(input_shape=[1, 3, 32, 32, 32],
-                      window_shape=[1, 1, 1],
-                      padding='SAME',
-                      pooling_type=pool_type,
-                      dilation_rate=[1, 1, 1],
-                      strides=[1, 1, 1])
-
-        _test_pooling(input_shape=[1, 3, 32, 32, 32],
-                      window_shape=[2, 2, 2],
-                      padding='SAME',
-                      pooling_type=pool_type,
-                      dilation_rate=[1, 1, 1],
-                      strides=[2, 2, 2])
+        _test_pooling(
+            input_shape=[1, 3, 32, 32, 32],
+            window_shape=[2, 2, 2],
+            padding="VALID",
+            pooling_type=pool_type,
+            dilation_rate=[1, 1, 1],
+            strides=[2, 2, 2],
+        )
+
+        _test_pooling(
+            input_shape=[1, 3, 32, 32, 32],
+            window_shape=[1, 1, 1],
+            padding="SAME",
+            pooling_type=pool_type,
+            dilation_rate=[1, 1, 1],
+            strides=[1, 1, 1],
+        )
+
+        _test_pooling(
+            input_shape=[1, 3, 32, 32, 32],
+            window_shape=[2, 2, 2],
+            padding="SAME",
+            pooling_type=pool_type,
+            dilation_rate=[1, 1, 1],
+            strides=[2, 2, 2],
+        )
 
         # test cases for max_pool3d & avg_pool3d with layout NCDHW
         # TensorFlow pool3d  doesn't support NCDHW on cpu
         if is_gpu_available():
-            _test_pooling(input_shape=[1, 3, 32, 32, 32],
-                          window_shape=[1, 1, 1],
-                          padding='SAME',
-                          pooling_type=pool_type,
-                          dilation_rate=[1, 1, 1],
-                          strides=[1, 1, 1],
-                          data_format='NCDHW')
-
-            _test_pooling(input_shape=[1, 3, 32, 32, 32],
-                          window_shape=[2, 2, 2],
-                          padding='VALID',
-                          pooling_type=pool_type,
-                          dilation_rate=[1, 1, 1],
-                          strides=[2, 2, 2],
-                          data_format='NCDHW')
-
-        _test_pooling(input_shape=[2, 9, 10, 2],
-                      window_shape=[1, 1],
-                      padding='SAME',
-                      pooling_type=pool_type,
-                      dilation_rate=[1, 1],
-                      strides=[1, 1])
-
-        _test_pooling(input_shape=[2, 10, 9, 2],
-                      window_shape=[1, 1],
-                      padding='SAME',
-                      pooling_type=pool_type,
-                      dilation_rate=[1, 1],
-                      strides=[1, 1])
-
-        _test_pooling(input_shape=[2, 9, 10, 2],
-                      window_shape=[2, 1],
-                      padding='SAME',
-                      pooling_type=pool_type,
-                      dilation_rate=[1, 1],
-                      strides=[1, 1])
-
-        _test_pooling(input_shape=[2, 10, 9, 2],
-                      window_shape=[2, 3],
-                      padding='SAME',
-                      pooling_type=pool_type,
-                      dilation_rate=[1, 1],
-                      strides=[2, 1])
+            _test_pooling(
+                input_shape=[1, 3, 32, 32, 32],
+                window_shape=[1, 1, 1],
+                padding="SAME",
+                pooling_type=pool_type,
+                dilation_rate=[1, 1, 1],
+                strides=[1, 1, 1],
+                data_format="NCDHW",
+            )
+
+            _test_pooling(
+                input_shape=[1, 3, 32, 32, 32],
+                window_shape=[2, 2, 2],
+                padding="VALID",
+                pooling_type=pool_type,
+                dilation_rate=[1, 1, 1],
+                strides=[2, 2, 2],
+                data_format="NCDHW",
+            )
+
+        _test_pooling(
+            input_shape=[2, 9, 10, 2],
+            window_shape=[1, 1],
+            padding="SAME",
+            pooling_type=pool_type,
+            dilation_rate=[1, 1],
+            strides=[1, 1],
+        )
+
+        _test_pooling(
+            input_shape=[2, 10, 9, 2],
+            window_shape=[1, 1],
+            padding="SAME",
+            pooling_type=pool_type,
+            dilation_rate=[1, 1],
+            strides=[1, 1],
+        )
+
+        _test_pooling(
+            input_shape=[2, 9, 10, 2],
+            window_shape=[2, 1],
+            padding="SAME",
+            pooling_type=pool_type,
+            dilation_rate=[1, 1],
+            strides=[1, 1],
+        )
+
+        _test_pooling(
+            input_shape=[2, 10, 9, 2],
+            window_shape=[2, 3],
+            padding="SAME",
+            pooling_type=pool_type,
+            dilation_rate=[1, 1],
+            strides=[2, 1],
+        )
 
         # Tests involving SpaceToBatchND
-        _test_pooling(input_shape=[1, 1, 2, 1],
-                      window_shape=[1, 1],
-                      padding='VALID',
-                      pooling_type=pool_type,
-                      dilation_rate=[1, 2])
+        _test_pooling(
+            input_shape=[1, 1, 2, 1],
+            window_shape=[1, 1],
+            padding="VALID",
+            pooling_type=pool_type,
+            dilation_rate=[1, 2],
+        )
+
+        _test_pooling(
+            input_shape=[1, 2, 1],
+            window_shape=[1],
+            padding="VALID",
+            pooling_type=pool_type,
+            dilation_rate=[2],
+        )
 
-        _test_pooling(input_shape=[1, 2, 1],
-                      window_shape=[1],
-                      padding='VALID',
-                      pooling_type=pool_type,
-                      dilation_rate=[2])
 
 #######################################################################
 # Convolution
 # -----------
 
 
-def _test_convolution(opname, tensor_in_sizes, filter_in_sizes,
-                      dilations, strides, padding, data_format,
-                      deconv_output_shape=[]):
+def _test_convolution(
+    opname,
+    tensor_in_sizes,
+    filter_in_sizes,
+    dilations,
+    strides,
+    padding,
+    data_format,
+    deconv_output_shape=[],
+):
     """ One iteration of convolution with given shapes and attributes """
 
     total_size_1 = np.prod(tensor_in_sizes)
@@ -369,126 +427,358 @@ def _test_convolution(opname, tensor_in_sizes, filter_in_sizes,
     filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
 
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
-        in_filter = constant_op.constant(
-            filter_array, shape=filter_in_sizes, dtype='float32')
-        if data_format == 'NHWC':
+        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype="float32")
+        in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype="float32")
+        if data_format == "NHWC":
             strides = [1] + strides + [1]
             dilations = [1] + dilations + [1]
         else:
             strides = [1, 1] + strides
             dilations = [1, 1] + dilations
 
-        if opname == 'conv':
-            nn_ops.conv2d(in_data,
-                          in_filter,
-                          strides=strides,
-                          dilations=dilations,
-                          padding=padding,
-                          data_format=data_format)
-
-            compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
-                                'Placeholder:0', 'Conv2D:0')
-        elif opname == 'conv_transpose':
-            nn_ops.conv2d_transpose(in_data,
-                                    in_filter,
-                                    output_shape=deconv_output_shape,
-                                    strides=strides,
-                                    padding=padding,
-                                    data_format=data_format)
-
-            compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
-                                'Placeholder:0', 'conv2d_transpose:0')
+        if opname == "conv":
+            nn_ops.conv2d(
+                in_data,
+                in_filter,
+                strides=strides,
+                dilations=dilations,
+                padding=padding,
+                data_format=data_format,
+            )
+
+            compare_tf_with_tvm(
+                np.reshape(data_array, tensor_in_sizes).astype("float32"),
+                "Placeholder:0",
+                "Conv2D:0",
+            )
+        elif opname == "conv_transpose":
+            nn_ops.conv2d_transpose(
+                in_data,
+                in_filter,
+                output_shape=deconv_output_shape,
+                strides=strides,
+                padding=padding,
+                data_format=data_format,
+            )
+
+            compare_tf_with_tvm(
+                np.reshape(data_array, tensor_in_sizes).astype("float32"),
+                "Placeholder:0",
+                "conv2d_transpose:0",
+            )
         else:
-            nn_ops.depthwise_conv2d_native(in_data,
-                                           in_filter,
-                                           strides=strides,
-                                           dilations=dilations,
-                                           padding=padding,
-                                           data_format=data_format)
-
-            compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
-                                'Placeholder:0', 'DepthwiseConv2dNative:0')
+            nn_ops.depthwise_conv2d_native(
+                in_data,
+                in_filter,
+                strides=strides,
+                dilations=dilations,
+                padding=padding,
+                data_format=data_format,
+            )
+
+            compare_tf_with_tvm(
+                np.reshape(data_array, tensor_in_sizes).astype("float32"),
+                "Placeholder:0",
+                "DepthwiseConv2dNative:0",
+            )
 
 
 @tvm.testing.uses_gpu
 def test_forward_convolution():
     if is_gpu_available():
-        _test_convolution('conv', [4, 176, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NCHW')
-        _test_convolution('conv', [4, 19, 17, 17], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NCHW')
-        _test_convolution('conv', [4, 124, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NCHW')
-        _test_convolution('conv', [4, 12, 17, 17], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NCHW')
-        _test_convolution('depthwise', [4, 176, 8, 8], [1, 1, 176, 1], [1, 1], [1, 1], 'SAME', 'NCHW')
-        _test_convolution('depthwise', [4, 19, 17, 17], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NCHW')
-        _test_convolution('depthwise', [4, 124, 17, 17], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NCHW')
-        _test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NCHW')
-        _test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NCHW')
-        _test_convolution('conv_transpose', [4, 32, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME',
-                          'NCHW', [4, 176, 8, 8])
-        _test_convolution('conv_transpose', [4, 32, 8, 8], [2, 2, 176, 32], [1, 1], [1, 1], 'SAME',
-                          'NCHW', [4, 176, 8, 8])
-        _test_convolution('conv_transpose', [4, 32, 8, 8], [2, 2, 176, 32], [1, 1], [2, 2], 'SAME',
-                          'NCHW', [4, 176, 15, 15])
-        _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [1, 1], 'SAME',
-                          'NCHW', [4, 176, 8, 8])
-        _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
-                          'NCHW', [4, 176, 15, 15])
-        _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
-                          'NCHW', [4, 176, 16, 16])
-        _test_convolution('conv_transpose', [4, 19, 8, 8], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID',
-                          'NCHW', [4, 19, 17, 17])
-        _test_convolution('conv_transpose', [4, 19, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME',
-                          'NCHW', [4, 124, 17, 17])
-        _test_convolution('conv_transpose', [4, 19, 17, 17], [3, 3, 124, 19], [1, 1], [1, 1], 'SAME',
-                          'NCHW', [4, 124, 17, 17])
-        _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
-                          'NCHW', [4, 12, 17, 17])
+        _test_convolution("conv", [4, 176, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], "SAME", "NCHW")
+        _test_convolution("conv", [4, 19, 17, 17], [3, 3, 19, 19], [1, 1], [2, 2], "VALID", "NCHW")
+        _test_convolution("conv", [4, 124, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], "SAME", "NCHW")
+        _test_convolution("conv", [4, 12, 17, 17], [3, 3, 12, 32], [1, 1], [2, 2], "VALID", "NCHW")
+        _test_convolution(
+            "depthwise", [4, 176, 8, 8], [1, 1, 176, 1], [1, 1], [1, 1], "SAME", "NCHW"
+        )
+        _test_convolution(
+            "depthwise", [4, 19, 17, 17], [3, 3, 19, 1], [1, 1], [2, 2], "VALID", "NCHW"
+        )
+        _test_convolution(
+            "depthwise", [4, 124, 17, 17], [1, 1, 124, 1], [1, 1], [1, 1], "SAME", "NCHW"
+        )
+        _test_convolution(
+            "depthwise", [4, 12, 17, 17], [3, 3, 12, 1], [1, 1], [2, 2], "VALID", "NCHW"
+        )
+        _test_convolution(
+            "depthwise", [4, 12, 17, 17], [3, 3, 12, 2], [1, 1], [2, 2], "VALID", "NCHW"
+        )
+        _test_convolution(
+            "conv_transpose",
+            [4, 32, 8, 8],
+            [1, 1, 176, 32],
+            [1, 1],
+            [1, 1],
+            "SAME",
+            "NCHW",
+            [4, 176, 8, 8],
+        )
+        _test_convolution(
+            "conv_transpose",
+            [4, 32, 8, 8],
+            [2, 2, 176, 32],
+            [1, 1],
+            [1, 1],
+            "SAME",
+            "NCHW",
+            [4, 176, 8, 8],
+        )
+        _test_convolution(
+            "conv_transpose",
+            [4, 32, 8, 8],
+            [2, 2, 176, 32],
+            [1, 1],
+            [2, 2],
+            "SAME",
+            "NCHW",
+            [4, 176, 15, 15],
+        )
+        _test_convolution(
+            "conv_transpose",
+            [4, 32, 8, 8],
+            [3, 3, 176, 32],
+            [1, 1],
+            [1, 1],
+            "SAME",
+            "NCHW",
+            [4, 176, 8, 8],
+        )
+        _test_convolution(
+            "conv_transpose",
+            [4, 32, 8, 8],
+            [3, 3, 176, 32],
+            [1, 1],
+            [2, 2],
+            "SAME",
+            "NCHW",
+            [4, 176, 15, 15],
+        )
+        _test_convolution(
+            "conv_transpose",
+            [4, 32, 8, 8],
+            [3, 3, 176, 32],
+            [1, 1],
+            [2, 2],
+            "SAME",
+            "NCHW",
+            [4, 176, 16, 16],
+        )
+        _test_convolution(
+            "conv_transpose",
+            [4, 19, 8, 8],
+            [3, 3, 19, 19],
+            [1, 1],
+            [2, 2],
+            "VALID",
+            "NCHW",
+            [4, 19, 17, 17],
+        )
+        _test_convolution(
+            "conv_transpose",
+            [4, 19, 17, 17],
+            [1, 1, 124, 19],
+            [1, 1],
+            [1, 1],
+            "SAME",
+            "NCHW",
+            [4, 124, 17, 17],
+        )
+        _test_convolution(
+            "conv_transpose",
+            [4, 19, 17, 17],
+            [3, 3, 124, 19],
+            [1, 1],
+            [1, 1],
+            "SAME",
+            "NCHW",
+            [4, 124, 17, 17],
+        )
+        _test_convolution(
+            "conv_transpose",
+            [4, 32, 8, 8],
+            [3, 3, 12, 32],
+            [1, 1],
+            [2, 2],
+            "VALID",
+            "NCHW",
+            [4, 12, 17, 17],
+        )
         # kernel 2x2, strides (2,2)
-        _test_convolution('conv_transpose', [4, 19, 8, 8], [2, 2, 19, 19], [1, 1], [2, 2], 'VALID',
-                          'NCHW', [4, 19, 16, 16])
-        _test_convolution('conv_transpose', [4, 32, 8, 8], [2, 2, 12, 32], [1, 1], [2, 2], 'VALID',
-                          'NCHW', [4, 12, 16, 16])
+        _test_convolution(
+            "conv_transpose",
+            [4, 19, 8, 8],
+            [2, 2, 19, 19],
+            [1, 1],
+            [2, 2],
+            "VALID",
+            "NCHW",
+            [4, 19, 16, 16],
+        )
+        _test_convolution(
+            "conv_transpose",
+            [4, 32, 8, 8],
+            [2, 2, 12, 32],
+            [1, 1],
+            [2, 2],
+            "VALID",
+            "NCHW",
+            [4, 12, 16, 16],
+        )
         # output channel is 1
-        _test_convolution('conv_transpose', [1, 19, 8, 8], [1, 1, 1, 19], [1, 1], [1, 1], 'VALID',
-                          'NCHW', [1, 1, 8, 8])
-
-    _test_convolution('conv', [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC')
-    _test_convolution('conv', [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')
-    _test_convolution('conv', [4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC')
-    _test_convolution('conv', [4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC')
-    _test_convolution('depthwise', [4, 8, 8, 176], [1, 1, 176, 1], [1, 1], [1, 1], 'SAME', 'NHWC')
-    _test_convolution('depthwise', [4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NHWC')
-    _test_convolution('depthwise', [4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC')
-    _test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC')
-    _test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC')
-    _test_convolution('conv_transpose', [4, 8, 8, 32], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME',
-                      'NHWC', [4, 8, 8, 176])
-    _test_convolution('conv_transpose', [4, 8, 8, 32], [2, 2, 176, 32], [1, 1], [1, 1], 'SAME',
-                      'NHWC', [4, 8, 8, 176])
-    _test_convolution('conv_transpose', [4, 8, 8, 32], [2, 2, 176, 32], [1, 1], [2, 2], 'SAME',
-                      'NHWC', [4, 15, 15, 176])
-    _test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [1, 1], 'SAME',
-                      'NHWC', [4, 8, 8, 176])
-    _test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
-                      'NHWC', [4, 15, 15, 176])
-    _test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
-                      'NHWC', [4, 16, 16, 176])
-    _test_convolution('conv_transpose', [4, 8, 8, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID',
-                      'NHWC', [4, 17, 17, 19])
-    _test_convolution('conv_transpose', [4, 17, 17, 19], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME',
-                      'NHWC', [4, 17, 17, 124])
-    _test_convolution('conv_transpose', [4, 17, 17, 19], [3, 3, 124, 19], [1, 1], [1, 1], 'SAME',
-                      'NHWC', [4, 17, 17, 124])
-    _test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
-                      'NHWC', [4, 17, 17, 12])
+        _test_convolution(
+            "conv_transpose",
+            [1, 19, 8, 8],
+            [1, 1, 1, 19],
+            [1, 1],
+            [1, 1],
+            "VALID",
+            "NCHW",
+            [1, 1, 8, 8],
+        )
+
+    _test_convolution("conv", [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], "SAME", "NHWC")
+    _test_convolution("conv", [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], "VALID", "NHWC")
+    _test_convolution("conv", [4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], "SAME", "NHWC")
+    _test_convolution("conv", [4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], "VALID", "NHWC")
+    _test_convolution("depthwise", [4, 8, 8, 176], [1, 1, 176, 1], [1, 1], [1, 1], "SAME", "NHWC")
+    _test_convolution("depthwise", [4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], "VALID", "NHWC")
+    _test_convolution("depthwise", [4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], "SAME", "NHWC")
+    _test_convolution("depthwise", [4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], "VALID", "NHWC")
+    _test_convolution("depthwise", [4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], "VALID", "NHWC")
+    _test_convolution(
+        "conv_transpose",
+        [4, 8, 8, 32],
+        [1, 1, 176, 32],
+        [1, 1],
+        [1, 1],
+        "SAME",
+        "NHWC",
+        [4, 8, 8, 176],
+    )
+    _test_convolution(
+        "conv_transpose",
+        [4, 8, 8, 32],
+        [2, 2, 176, 32],
+        [1, 1],
+        [1, 1],
+        "SAME",
+        "NHWC",
+        [4, 8, 8, 176],
+    )
+    _test_convolution(
+        "conv_transpose",
+        [4, 8, 8, 32],
+        [2, 2, 176, 32],
+        [1, 1],
+        [2, 2],
+        "SAME",
+        "NHWC",
+        [4, 15, 15, 176],
+    )
+    _test_convolution(
+        "conv_transpose",
+        [4, 8, 8, 32],
+        [3, 3, 176, 32],
+        [1, 1],
+        [1, 1],
+        "SAME",
+        "NHWC",
+        [4, 8, 8, 176],
+    )
+    _test_convolution(
+        "conv_transpose",
+        [4, 8, 8, 32],
+        [3, 3, 176, 32],
+        [1, 1],
+        [2, 2],
+        "SAME",
+        "NHWC",
+        [4, 15, 15, 176],
+    )
+    _test_convolution(
+        "conv_transpose",
+        [4, 8, 8, 32],
+        [3, 3, 176, 32],
+        [1, 1],
+        [2, 2],
+        "SAME",
+        "NHWC",
+        [4, 16, 16, 176],
+    )
+    _test_convolution(
+        "conv_transpose",
+        [4, 8, 8, 19],
+        [3, 3, 19, 19],
+        [1, 1],
+        [2, 2],
+        "VALID",
+        "NHWC",
+        [4, 17, 17, 19],
+    )
+    _test_convolution(
+        "conv_transpose",
+        [4, 17, 17, 19],
+        [1, 1, 124, 19],
+        [1, 1],
+        [1, 1],
+        "SAME",
+        "NHWC",
+        [4, 17, 17, 124],
+    )
+    _test_convolution(
+        "conv_transpose",
+        [4, 17, 17, 19],
+        [3, 3, 124, 19],
+        [1, 1],
+        [1, 1],
+        "SAME",
+        "NHWC",
+        [4, 17, 17, 124],
+    )
+    _test_convolution(
+        "conv_transpose",
+        [4, 8, 8, 32],
+        [3, 3, 12, 32],
+        [1, 1],
+        [2, 2],
+        "VALID",
+        "NHWC",
+        [4, 17, 17, 12],
+    )
     # kernel 2x2, strides (2,2)
-    _test_convolution('conv_transpose', [4, 8, 8, 19], [2, 2, 19, 19], [1, 1], [2, 2], 'VALID',
-                      'NHWC', [4, 16, 16, 19])
-    _test_convolution('conv_transpose', [4, 8, 8, 32], [2, 2, 12, 32], [1, 1], [2, 2], 'VALID',
-                      'NHWC', [4, 16, 16, 12])
+    _test_convolution(
+        "conv_transpose",
+        [4, 8, 8, 19],
+        [2, 2, 19, 19],
+        [1, 1],
+        [2, 2],
+        "VALID",
+        "NHWC",
+        [4, 16, 16, 19],
+    )
+    _test_convolution(
+        "conv_transpose",
+        [4, 8, 8, 32],
+        [2, 2, 12, 32],
+        [1, 1],
+        [2, 2],
+        "VALID",
+        "NHWC",
+        [4, 16, 16, 12],
+    )
     # output channel is 1
-    _test_convolution('conv_transpose', [1, 8, 8, 19], [1, 1, 1, 19], [1, 1], [1, 1], 'VALID',
-                      'NHWC', [1, 8, 8, 1])
+    _test_convolution(
+        "conv_transpose",
+        [1, 8, 8, 19],
+        [1, 1, 1, 19],
+        [1, 1],
+        [1, 1],
+        "VALID",
+        "NHWC",
+        [1, 8, 8, 1],
+    )
 
 
 #######################################################################
@@ -496,9 +786,16 @@ def test_forward_convolution():
 # -------------
 
 
-def _test_convolution3d(opname, tensor_in_sizes, filter_in_sizes,
-                        dilations, strides, padding, data_format,
-                        deconv_output_shape=[]):
+def _test_convolution3d(
+    opname,
+    tensor_in_sizes,
+    filter_in_sizes,
+    dilations,
+    strides,
+    padding,
+    data_format,
+    deconv_output_shape=[],
+):
     """ One iteration of 3D convolution with given shapes and attributes """
 
     total_size_1 = np.prod(tensor_in_sizes)
@@ -509,125 +806,166 @@ def _test_convolution3d(opname, tensor_in_sizes, filter_in_sizes,
     filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
 
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
-        in_filter = constant_op.constant(
-            filter_array, shape=filter_in_sizes, dtype='float32')
-        if data_format == 'NDHWC':
+        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype="float32")
+        in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype="float32")
+        if data_format == "NDHWC":
             strides = [1] + strides + [1]
             dilations = [1] + dilations + [1]
         else:
             strides = [1, 1] + strides
             dilations = [1, 1] + dilations
 
-        if opname == 'conv':
-            nn_ops.conv3d(in_data,
-                          in_filter,
-                          strides=strides,
-                          dilations=dilations,
-                          padding=padding,
-                          data_format=data_format)
+        if opname == "conv":
+            nn_ops.conv3d(
+                in_data,
+                in_filter,
+                strides=strides,
+                dilations=dilations,
+                padding=padding,
+                data_format=data_format,
+            )
+
+            compare_tf_with_tvm(
+                np.reshape(data_array, tensor_in_sizes).astype("float32"),
+                "Placeholder:0",
+                "Conv3D:0",
+                cuda_layout="NCDHW",
+            )
 
-            compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
-                                'Placeholder:0', 'Conv3D:0', cuda_layout="NCDHW")
 
 @tvm.testing.uses_gpu
 def test_forward_convolution3d():
     if is_gpu_available():
-        _test_convolution3d('conv', [4, 176, 8, 8, 8], [1, 1, 1, 176, 32], [1, 1, 1], [1, 1, 1], 'SAME', 'NCDHW')
-        _test_convolution3d('conv', [4, 19, 17, 17, 17], [3, 3, 3, 19, 19], [1, 1, 1], [2, 2, 2], 'VALID', 'NCDHW')
-        _test_convolution3d('conv', [4, 124, 17, 17, 17], [1, 1, 1, 124, 19], [1, 1, 1], [1, 1, 1], 'SAME', 'NCDHW')
-        _test_convolution3d('conv', [4, 12, 17, 17, 17], [3, 3, 3, 12, 32], [1, 1, 1], [2, 2, 2], 'VALID', 'NCDHW')
-    _test_convolution3d('conv', [4, 8, 8, 8, 176], [1, 1, 1, 176, 32], [1, 1, 1], [1, 1, 1], 'SAME', 'NDHWC')
-    _test_convolution3d('conv', [4, 17, 17, 17, 19], [3, 3, 3, 19, 19], [1, 1, 1], [2, 2, 2], 'VALID', 'NDHWC')
-    _test_convolution3d('conv', [4, 17, 17, 17, 124], [1, 1, 1, 124, 19], [1, 1, 1], [1, 1, 1], 'SAME', 'NDHWC')
-    _test_convolution3d('conv', [4, 17, 17, 17, 12], [3, 3, 3, 12, 32], [1, 1, 1], [2, 2, 2], 'VALID', 'NDHWC')
+        _test_convolution3d(
+            "conv", [4, 176, 8, 8, 8], [1, 1, 1, 176, 32], [1, 1, 1], [1, 1, 1], "SAME", "NCDHW"
+        )
+        _test_convolution3d(
+            "conv", [4, 19, 17, 17, 17], [3, 3, 3, 19, 19], [1, 1, 1], [2, 2, 2], "VALID", "NCDHW"
+        )
+        _test_convolution3d(
+            "conv", [4, 124, 17, 17, 17], [1, 1, 1, 124, 19], [1, 1, 1], [1, 1, 1], "SAME", "NCDHW"
+        )
+        _test_convolution3d(
+            "conv", [4, 12, 17, 17, 17], [3, 3, 3, 12, 32], [1, 1, 1], [2, 2, 2], "VALID", "NCDHW"
+        )
+    _test_convolution3d(
+        "conv", [4, 8, 8, 8, 176], [1, 1, 1, 176, 32], [1, 1, 1], [1, 1, 1], "SAME", "NDHWC"
+    )
+    _test_convolution3d(
+        "conv", [4, 17, 17, 17, 19], [3, 3, 3, 19, 19], [1, 1, 1], [2, 2, 2], "VALID", "NDHWC"
+    )
+    _test_convolution3d(
+        "conv", [4, 17, 17, 17, 124], [1, 1, 1, 124, 19], [1, 1, 1], [1, 1, 1], "SAME", "NDHWC"
+    )
+    _test_convolution3d(
+        "conv", [4, 17, 17, 17, 12], [3, 3, 3, 12, 32], [1, 1, 1], [2, 2, 2], "VALID", "NDHWC"
+    )
 
 
 #######################################################################
 # Convolution3D Transpose
 # -----------------------
 
-def _test_convolution3d_transpose(data_shape, filter_shape, strides,
-                                  padding, output_shape, data_format='NCDHW'):
+
+def _test_convolution3d_transpose(
+    data_shape, filter_shape, strides, padding, output_shape, data_format="NCDHW"
+):
     """ One iteration of 3D convolution transpose with given shapes and attributes """
 
-    dtype = 'float32'
+    dtype = "float32"
     data_array = np.random.uniform(size=data_shape).astype(dtype)
     filter_array = np.random.uniform(size=filter_shape).astype(dtype)
-    if data_format == 'NDHWC':
+    if data_format == "NDHWC":
         strides = [1] + strides + [1]
     else:
         strides = [1, 1] + strides
 
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data_shape, dtype=dtype)
-        in_filter = constant_op.constant(
-            filter_array, shape=filter_shape, dtype=dtype)
-
-        nn_ops.conv3d_transpose(in_data,
-                                in_filter,
-                                output_shape=output_shape,
-                                strides=strides,
-                                padding=padding,
-                                data_format=data_format)
+        in_filter = constant_op.constant(filter_array, shape=filter_shape, dtype=dtype)
+
+        nn_ops.conv3d_transpose(
+            in_data,
+            in_filter,
+            output_shape=output_shape,
+            strides=strides,
+            padding=padding,
+            data_format=data_format,
+        )
 
-        compare_tf_with_tvm(data_array, 'Placeholder:0', 'conv3d_transpose:0', cuda_layout="NDHWC")
+        compare_tf_with_tvm(data_array, "Placeholder:0", "conv3d_transpose:0", cuda_layout="NDHWC")
 
 
 @tvm.testing.uses_gpu
 def test_forward_convolution3d_transpose():
     if is_gpu_available():
-        _test_convolution3d_transpose(data_shape=[1, 10, 8, 8, 8],
-                                      filter_shape=[1, 1, 1, 6, 10],
-                                      strides=[1, 1, 1],
-                                      padding='VALID',
-                                      output_shape=[1, 6, 8, 8, 8])
-
-        _test_convolution3d_transpose(data_shape=[4, 9, 8, 8, 8],
-                                      filter_shape=[1, 1, 1, 6, 9],
-                                      strides=[1, 1, 1],
-                                      padding='VALID',
-                                      output_shape=[4, 6, 8, 8, 8])
-
-        _test_convolution3d_transpose(data_shape=[1, 3, 8, 8, 8],
-                                      filter_shape=[1, 1, 1, 6, 3],
-                                      strides=[2, 2, 2],
-                                      padding='SAME',
-                                      output_shape=[1, 6, 15, 15, 15])
-
-        _test_convolution3d_transpose(data_shape=[1, 16, 8, 8, 8],
-                                      filter_shape=[3, 3, 3, 6, 16],
-                                      strides=[3, 3, 3],
-                                      padding='VALID',
-                                      output_shape=[1, 6, 24, 24, 24])
-
-    _test_convolution3d_transpose(data_shape=[1, 8, 8, 8, 10],
-                                  filter_shape=[1, 1, 1, 6, 10],
-                                  strides=[1, 1, 1],
-                                  padding='VALID',
-                                  output_shape=[1, 8, 8, 8, 6],
-                                  data_format='NDHWC')
-
-    _test_convolution3d_transpose(data_shape=[4, 8, 8, 8, 9],
-                                  filter_shape=[1, 1, 1, 6, 9],
-                                  strides=[1, 1, 1],
-                                  padding='VALID',
-                                  output_shape=[4, 8, 8, 8, 6],
-                                  data_format='NDHWC')
-
-    _test_convolution3d_transpose(data_shape=[1, 8, 8, 8, 3],
-                                  filter_shape=[1, 1, 1, 6, 3],
-                                  strides=[2, 2, 2],
-                                  padding='SAME',
-                                  output_shape=[1, 15, 15, 15, 6],
-                                  data_format='NDHWC')
-
-    _test_convolution3d_transpose(data_shape=[1, 8, 8, 8, 16],
-                                  filter_shape=[3, 3, 3, 6, 16],
-                                  strides=[3, 3, 3],
-                                  padding='VALID',
-                                  output_shape=[1, 24, 24, 24, 6],
-                                  data_format='NDHWC')
+        _test_convolution3d_transpose(
+            data_shape=[1, 10, 8, 8, 8],
+            filter_shape=[1, 1, 1, 6, 10],
+            strides=[1, 1, 1],
+            padding="VALID",
+            output_shape=[1, 6, 8, 8, 8],
+        )
+
+        _test_convolution3d_transpose(
+            data_shape=[4, 9, 8, 8, 8],
+            filter_shape=[1, 1, 1, 6, 9],
+            strides=[1, 1, 1],
+            padding="VALID",
+            output_shape=[4, 6, 8, 8, 8],
+        )
+
+        _test_convolution3d_transpose(
+            data_shape=[1, 3, 8, 8, 8],
+            filter_shape=[1, 1, 1, 6, 3],
+            strides=[2, 2, 2],
+            padding="SAME",
+            output_shape=[1, 6, 15, 15, 15],
+        )
+
+        _test_convolution3d_transpose(
+            data_shape=[1, 16, 8, 8, 8],
+            filter_shape=[3, 3, 3, 6, 16],
+            strides=[3, 3, 3],
+            padding="VALID",
+            output_shape=[1, 6, 24, 24, 24],
+        )
+
+    _test_convolution3d_transpose(
+        data_shape=[1, 8, 8, 8, 10],
+        filter_shape=[1, 1, 1, 6, 10],
+        strides=[1, 1, 1],
+        padding="VALID",
+        output_shape=[1, 8, 8, 8, 6],
+        data_format="NDHWC",
+    )
+
+    _test_convolution3d_transpose(
+        data_shape=[4, 8, 8, 8, 9],
+        filter_shape=[1, 1, 1, 6, 9],
+        strides=[1, 1, 1],
+        padding="VALID",
+        output_shape=[4, 8, 8, 8, 6],
+        data_format="NDHWC",
+    )
+
+    _test_convolution3d_transpose(
+        data_shape=[1, 8, 8, 8, 3],
+        filter_shape=[1, 1, 1, 6, 3],
+        strides=[2, 2, 2],
+        padding="SAME",
+        output_shape=[1, 15, 15, 15, 6],
+        data_format="NDHWC",
+    )
+
+    _test_convolution3d_transpose(
+        data_shape=[1, 8, 8, 8, 16],
+        filter_shape=[3, 3, 3, 6, 16],
+        strides=[3, 3, 3],
+        padding="VALID",
+        output_shape=[1, 24, 24, 24, 6],
+        data_format="NDHWC",
+    )
 
 
 #######################################################################
@@ -641,8 +979,7 @@ def _test_biasadd(tensor_in_sizes, data_format):
     total_size_1 = 1
     for s in tensor_in_sizes:
         total_size_1 *= s
-    tensor_bias_sizes = [tensor_in_sizes[1]
-                         ] if data_format == 'NCHW' else [tensor_in_sizes[3]]
+    tensor_bias_sizes = [tensor_in_sizes[1]] if data_format == "NCHW" else [tensor_in_sizes[3]]
     total_size_2 = tensor_bias_sizes[0]
     # Initializes the input tensor with array containing incrementing
     # numbers from 1.
@@ -650,39 +987,38 @@ def _test_biasadd(tensor_in_sizes, data_format):
     bias_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
 
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
-        in_bias = constant_op.constant(
-            bias_array, shape=tensor_bias_sizes, dtype='float32')
-        nn_ops.bias_add(in_data,
-                        in_bias,
-                        data_format=data_format)
+        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype="float32")
+        in_bias = constant_op.constant(bias_array, shape=tensor_bias_sizes, dtype="float32")
+        nn_ops.bias_add(in_data, in_bias, data_format=data_format)
 
-        compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
-                            'Placeholder:0', 'BiasAdd:0')
+        compare_tf_with_tvm(
+            np.reshape(data_array, tensor_in_sizes).astype("float32"), "Placeholder:0", "BiasAdd:0"
+        )
 
 
 @tvm.testing.uses_gpu
 def test_forward_biasadd():
     if is_gpu_available():
-        _test_biasadd([4, 176, 8, 8], 'NCHW')
-        _test_biasadd([1, 100, 1, 1], 'NCHW')
-        _test_biasadd([4, 19, 17, 17], 'NCHW')
-        _test_biasadd([4, 124, 3, 3], 'NCHW')
+        _test_biasadd([4, 176, 8, 8], "NCHW")
+        _test_biasadd([1, 100, 1, 1], "NCHW")
+        _test_biasadd([4, 19, 17, 17], "NCHW")
+        _test_biasadd([4, 124, 3, 3], "NCHW")
 
-    _test_biasadd([4, 8, 8, 176], 'NHWC')
-    _test_biasadd([1, 1, 1, 100], 'NHWC')
-    _test_biasadd([4, 17, 17, 19], 'NHWC')
-    _test_biasadd([4, 3, 3, 124], 'NHWC')
+    _test_biasadd([4, 8, 8, 176], "NHWC")
+    _test_biasadd([1, 1, 1, 100], "NHWC")
+    _test_biasadd([4, 17, 17, 19], "NHWC")
+    _test_biasadd([4, 3, 3, 124], "NHWC")
 
 
 def _test_forward_where(input_shape):
     with tf.Graph().as_default():
         dtype = tf.float32
-        t = tf.constant(np.random.choice([0, 1, -2, 3, -1, 0.1, -0.2],
-                                         size=input_shape).astype(dtype.name))
+        t = tf.constant(
+            np.random.choice([0, 1, -2, 3, -1, 0.1, -0.2], size=input_shape).astype(dtype.name)
+        )
         out = tf.where(t)
-        compare_tf_with_tvm([], [], out.name, mode='debug')
-        compare_tf_with_tvm([], [], out.name, mode='vm')
+        compare_tf_with_tvm([], [], out.name, mode="debug")
+        compare_tf_with_tvm([], [], out.name, mode="vm")
 
 
 def test_forward_argwhere():
@@ -692,12 +1028,13 @@ def test_forward_argwhere():
     _test_forward_where((5, 5, 5, 5))
     _test_forward_where((5, 5, 5, 5, 5))
 
+
 #######################################################################
 # SpaceToBatchND
 # --------------
 
 
-def _test_space_to_batch_nd(input_shape, block_shape, paddings, dtype='int32'):
+def _test_space_to_batch_nd(input_shape, block_shape, paddings, dtype="int32"):
     data = np.random.uniform(0, 5, size=input_shape).astype(dtype)
 
     with tf.Graph().as_default():
@@ -706,7 +1043,8 @@ def _test_space_to_batch_nd(input_shape, block_shape, paddings, dtype='int32'):
 
         compare_tf_with_tvm(data, in_data.name, out.name)
 
-def _test_space_to_batch_nd_infer_paddings(input_shape, block_shape, dtype='int32'):
+
+def _test_space_to_batch_nd_infer_paddings(input_shape, block_shape, dtype="int32"):
     data = np.random.uniform(0, 5, size=input_shape).astype(dtype)
     padding_np = np.array([0, 1]).astype(np.int32).reshape((1, 2))
     with tf.Graph().as_default():
@@ -718,60 +1056,36 @@ def _test_space_to_batch_nd_infer_paddings(input_shape, block_shape, dtype='int3
         out = tf.space_to_batch_nd(in_data, block_shape, paddings)
         compare_tf_with_tvm(data, in_data.name, out.name)
 
+
 def test_forward_space_to_batch_nd():
     # test cases: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/space-to-batch-n-d
-    _test_space_to_batch_nd(
-        input_shape=[1, 2, 2, 1],
-        block_shape=[2, 2],
-        paddings=[[0, 0], [0, 0]]
-    )
+    _test_space_to_batch_nd(input_shape=[1, 2, 2, 1], block_shape=[2, 2], paddings=[[0, 0], [0, 0]])
 
-    _test_space_to_batch_nd(
-        input_shape=[1, 2, 2, 3],
-        block_shape=[2, 2],
-        paddings=[[0, 0], [0, 0]]
-    )
+    _test_space_to_batch_nd(input_shape=[1, 2, 2, 3], block_shape=[2, 2], paddings=[[0, 0], [0, 0]])
 
-    _test_space_to_batch_nd(
-        input_shape=[1, 4, 4, 1],
-        block_shape=[2, 2],
-        paddings=[[0, 0], [0, 0]]
-    )
+    _test_space_to_batch_nd(input_shape=[1, 4, 4, 1], block_shape=[2, 2], paddings=[[0, 0], [0, 0]])
 
     _test_space_to_batch_nd(
-        input_shape=[2, 2, 4, 1],
-        block_shape=[2, 2],
-        paddings=[[0, 0], [2, 0]],
-        dtype='int64'
+        input_shape=[2, 2, 4, 1], block_shape=[2, 2], paddings=[[0, 0], [2, 0]], dtype="int64"
     )
 
     # pylint: disable=line-too-long
     # https://github.com/tensorflow/tensorflow/blob/24f578/tensorflow/python/kernel_tests/spacetobatch_op_test.py
-    _test_space_to_batch_nd(
-        input_shape=[2, 3],
-        block_shape=[2],
-        paddings=[[1, 0]],
-        dtype='float32'
-    )
+    _test_space_to_batch_nd(input_shape=[2, 3], block_shape=[2], paddings=[[1, 0]], dtype="float32")
 
     _test_space_to_batch_nd(
-        input_shape=[2, 3, 2],
-        block_shape=[2],
-        paddings=[[1, 0]],
-        dtype='float64'
+        input_shape=[2, 3, 2], block_shape=[2], paddings=[[1, 0]], dtype="float64"
     )
 
-    _test_space_to_batch_nd_infer_paddings(
-        input_shape=[2, 3, 2],
-        block_shape=[2]
-    )
+    _test_space_to_batch_nd_infer_paddings(input_shape=[2, 3, 2], block_shape=[2])
+
 
 #######################################################################
 # BatchToSpaceND
 # --------------
 
 
-def _test_batch_to_space_nd(input_shape, block_shape, crops, dtype='int32'):
+def _test_batch_to_space_nd(input_shape, block_shape, crops, dtype="int32"):
     data = np.random.uniform(0, 5, size=input_shape).astype(dtype)
 
     with tf.Graph().as_default():
@@ -783,47 +1097,27 @@ def _test_batch_to_space_nd(input_shape, block_shape, crops, dtype='int32'):
 
 def test_forward_batch_to_space_nd():
     # test cases: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d
-    _test_batch_to_space_nd(
-        input_shape=[4, 1, 1, 1],
-        block_shape=[2, 2],
-        crops=[[0, 0], [0, 0]]
-    )
+    _test_batch_to_space_nd(input_shape=[4, 1, 1, 1], block_shape=[2, 2], crops=[[0, 0], [0, 0]])
 
-    _test_batch_to_space_nd(
-        input_shape=[4, 1, 1, 3],
-        block_shape=[2, 2],
-        crops=[[0, 0], [0, 0]]
-    )
+    _test_batch_to_space_nd(input_shape=[4, 1, 1, 3], block_shape=[2, 2], crops=[[0, 0], [0, 0]])
 
-    _test_batch_to_space_nd(
-        input_shape=[4, 2, 2, 1],
-        block_shape=[2, 2],
-        crops=[[0, 0], [0, 0]]
-    )
+    _test_batch_to_space_nd(input_shape=[4, 2, 2, 1], block_shape=[2, 2], crops=[[0, 0], [0, 0]])
 
     _test_batch_to_space_nd(
-        input_shape=[8, 1, 3, 1],
-        block_shape=[2, 2],
-        crops=[[0, 0], [2, 0]],
-        dtype='int64'
+        input_shape=[8, 1, 3, 1], block_shape=[2, 2], crops=[[0, 0], [2, 0]], dtype="int64"
     )
 
     # pylint: disable=line-too-long
     # https://github.com/tensorflow/tensorflow/blob/24f578/tensorflow/python/kernel_tests/batchtospace_op_test.py
     _test_batch_to_space_nd(
-        input_shape=[18, 2, 1, 2],
-        block_shape=[2, 3],
-        crops=[[1, 1], [0, 0]],
-        dtype='float32'
+        input_shape=[18, 2, 1, 2], block_shape=[2, 3], crops=[[1, 1], [0, 0]], dtype="float32"
     )
 
     _test_batch_to_space_nd(
-        input_shape=[20, 5, 8, 7],
-        block_shape=[2, 2],
-        crops=[[1, 1], [1, 1]],
-        dtype='float64'
+        input_shape=[20, 5, 8, 7], block_shape=[2, 2], crops=[[1, 1], [1, 1]], dtype="float64"
     )
 
+
 #######################################################################
 # Reshape
 # -------
@@ -836,7 +1130,8 @@ def _test_reshape(data, out_shape):
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         array_ops.reshape(in_data, out_shape)
 
-        compare_tf_with_tvm(data, 'Placeholder:0', 'Reshape:0')
+        compare_tf_with_tvm(data, "Placeholder:0", "Reshape:0")
+
 
 def _test_reshape_with_call():
     """ relay.expr.Call as shape """
@@ -847,7 +1142,8 @@ def _test_reshape_with_call():
         out_shape = tf.multiply(out_shape, 2)
         array_ops.reshape(in_data, out_shape)
 
-        compare_tf_with_tvm(data, 'Placeholder:0', 'Reshape:0')
+        compare_tf_with_tvm(data, "Placeholder:0", "Reshape:0")
+
 
 def _test_reshape_like(data, shape_like):
     """ A special case for reshape. """
@@ -858,7 +1154,8 @@ def _test_reshape_like(data, shape_like):
         out_shape = array_ops.shape(in_shape_like)
         array_ops.reshape(in_data, out_shape)
 
-        compare_tf_with_tvm(data, 'Placeholder:0', 'Reshape:0')
+        compare_tf_with_tvm(data, "Placeholder:0", "Reshape:0")
+
 
 def _test_reshape_symbolic(data, a_data, b_data):
     with tf.Graph().as_default():
@@ -869,7 +1166,10 @@ def _test_reshape_symbolic(data, a_data, b_data):
         out = array_ops.reshape(in_data, newshape)
 
         for mode in ["debug", "vm"]:
-            compare_tf_with_tvm([data, a_data, b_data], [in_data.name, a.name, b.name], out.name, mode=mode)
+            compare_tf_with_tvm(
+                [data, a_data, b_data], [in_data.name, a.name, b.name], out.name, mode=mode
+            )
+
 
 def test_forward_reshape():
     _test_reshape(np.arange(6.0), [2, 3])
@@ -883,6 +1183,7 @@ def test_forward_reshape():
     _test_reshape_symbolic(np.arange(6), np.array([3, 0]), np.array([3, -1]))
     _test_reshape_symbolic(np.arange(6), np.array([0]), np.array([-1]))
 
+
 #######################################################################
 # DepthToSpace
 # ------------
@@ -895,13 +1196,14 @@ def _test_depthtospace(data, block_size):
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         array_ops.depth_to_space(in_data, block_size)
 
-        compare_tf_with_tvm(data, 'Placeholder:0', 'DepthToSpace:0')
+        compare_tf_with_tvm(data, "Placeholder:0", "DepthToSpace:0")
 
 
 def test_forward_depthtospace():
     _test_depthtospace(np.random.normal(size=[1, 32, 32, 4]), 2)
     _test_depthtospace(np.random.normal(size=[1, 16, 8, 32]), 4)
 
+
 #######################################################################
 # SpaceToDepth
 # ------------
@@ -914,13 +1216,14 @@ def _test_spacetodepth(data, block_size):
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         array_ops.space_to_depth(in_data, block_size)
 
-        compare_tf_with_tvm(data, 'Placeholder:0', 'SpaceToDepth:0')
+        compare_tf_with_tvm(data, "Placeholder:0", "SpaceToDepth:0")
 
 
 def test_forward_spacetodepth():
     _test_spacetodepth(np.random.normal(size=[1, 32, 32, 4]), 2)
     _test_spacetodepth(np.random.normal(size=[1, 16, 8, 32]), 4)
 
+
 #######################################################################
 # Squeeze
 # -------
@@ -940,7 +1243,7 @@ def _test_squeeze(data, squeeze_dims=None):
         else:
             array_ops.squeeze(in_data)
 
-        compare_tf_with_tvm(data, 'Placeholder:0', 'Squeeze:0')
+        compare_tf_with_tvm(data, "Placeholder:0", "Squeeze:0")
 
 
 def test_forward_squeeze():
@@ -978,13 +1281,14 @@ def test_tensor_array_write_read():
             in_data = [np_data, np_data]
             t1 = tf.constant(np_data, dtype=dtype)
             t2 = tf.constant(np_data, dtype=dtype)
-            ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=infer_shape,
-                                 element_shape=element_shape)
+            ta1 = tf.TensorArray(
+                dtype=dtype, size=2, infer_shape=infer_shape, element_shape=element_shape
+            )
             ta2 = ta1.write(0, t1)
             ta3 = ta2.write(1, t2)
             out = ta3.read(0)
             g = tf.get_default_graph()
-            compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='vm')
+            compare_tf_with_tvm([], [], "TensorArrayReadV3:0", mode="vm")
 
     for dtype in ["float32", "int8"]:
         run(dtype, False, None)
@@ -995,24 +1299,25 @@ def test_tensor_array_write_read():
 def test_tensor_array_scatter():
     def run(dtype_str, infer_shape):
         with tf.Graph().as_default():
-            dtype =  tf_dtypes[dtype_str]
+            dtype = tf_dtypes[dtype_str]
             if infer_shape:
                 element_shape = tf.TensorShape([tf.Dimension(None)])
             else:
                 element_shape = None
             t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype)
             indices = tf.constant([2, 1, 0])
-            ta1 = tf.TensorArray(dtype=dtype, size=3,
-                                 infer_shape=infer_shape,
-                                 element_shape=element_shape)
+            ta1 = tf.TensorArray(
+                dtype=dtype, size=3, infer_shape=infer_shape, element_shape=element_shape
+            )
             ta2 = ta1.scatter(indices, t)
             out0 = ta2.read(0)
             out1 = ta2.read(1)
             out2 = ta2.read(2)
             g = tf.get_default_graph()
-            compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='vm')
-            compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='vm')
-            compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='vm')
+            compare_tf_with_tvm([], [], ["TensorArrayReadV3:0"], mode="vm")
+            compare_tf_with_tvm([], [], ["TensorArrayReadV3_1:0"], mode="vm")
+            compare_tf_with_tvm([], [], ["TensorArrayReadV3_2:0"], mode="vm")
+
     for dtype in ["float32", "int8"]:
         run(dtype, False)
         run(dtype, True)
@@ -1021,7 +1326,7 @@ def test_tensor_array_scatter():
 def test_tensor_array_gather():
     def run(dtype_str, infer_shape):
         with tf.Graph().as_default():
-            dtype =  tf_dtypes[dtype_str]
+            dtype = tf_dtypes[dtype_str]
             t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str))
             scatter_indices = tf.constant([2, 1, 0])
             gather_indices = tf.constant([1, 2])
@@ -1029,7 +1334,8 @@ def test_tensor_array_gather():
             ta2 = ta1.scatter(scatter_indices, t)
             t1 = ta2.gather(gather_indices)
             g = tf.get_default_graph()
-            compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='vm')
+            compare_tf_with_tvm([], [], ["TensorArrayGatherV3:0"], mode="vm")
+
     for dtype in ["float32", "int8"]:
         run(dtype, True)
 
@@ -1037,8 +1343,13 @@ def test_tensor_array_gather():
 def test_tensor_array_split():
     def run(dtype_str, infer_shape):
         with tf.Graph().as_default():
-            dtype =  tf_dtypes[dtype_str]
-            t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
+            dtype = tf_dtypes[dtype_str]
+            t = tf.constant(
+                np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(
+                    dtype_str
+                ),
+                dtype=dtype,
+            )
             split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
             ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=infer_shape)
             ta2 = ta1.split(t, split_length)
@@ -1047,10 +1358,11 @@ def test_tensor_array_split():
             out2 = ta2.read(2)
             out3 = ta2.read(3)
             g = tf.get_default_graph()
-            compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
-            compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
-            compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
-            compare_tf_with_tvm([], [], ['TensorArrayReadV3_3:0'], mode='debug')
+            compare_tf_with_tvm([], [], ["TensorArrayReadV3:0"], mode="debug")
+            compare_tf_with_tvm([], [], ["TensorArrayReadV3_1:0"], mode="debug")
+            compare_tf_with_tvm([], [], ["TensorArrayReadV3_2:0"], mode="debug")
+            compare_tf_with_tvm([], [], ["TensorArrayReadV3_3:0"], mode="debug")
+
     for dtype in ["float32", "int8"]:
         run(dtype, False)
         run(dtype, True)
@@ -1060,26 +1372,31 @@ def test_tensor_array_concat():
     def run(dtype_str, infer_shape):
         with tf.Graph().as_default():
             dtype = tf_dtypes[dtype_str]
-            t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
+            t = tf.constant(
+                np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(
+                    dtype_str
+                ),
+                dtype=dtype,
+            )
             split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
-            ta1 = tf.TensorArray(dtype=dtype, size=4,
-                                 infer_shape=infer_shape)
+            ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=infer_shape)
             ta2 = ta1.split(t, split_length)
             t = ta2.concat()
             out = tf.identity(t)
-            compare_tf_with_tvm([], [], ['Identity:0'], mode='debug')
+            compare_tf_with_tvm([], [], ["Identity:0"], mode="debug")
+
     for dtype in ["float32", "int8"]:
         run(dtype, False)
         run(dtype, True)
 
 
 def test_tensor_array_size():
-    if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
-            pytest.skip("Needs fixing for tflite >= 1.15.0")
+    if package_version.parse(tf.VERSION) >= package_version.parse("1.15.0"):
+        pytest.skip("Needs fixing for tflite >= 1.15.0")
 
     def run(dtype_str, infer_shape):
         with tf.Graph().as_default():
-            dtype =  tf_dtypes[dtype_str]
+            dtype = tf_dtypes[dtype_str]
             np_data = np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str)
             in_data = [np_data, np_data]
             t1 = tf.constant(np_data, dtype=dtype)
@@ -1089,7 +1406,8 @@ def test_tensor_array_size():
             ta3 = ta2.write(1, t2)
             out = ta3.size()
             g = tf.get_default_graph()
-            compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
+            compare_tf_with_tvm([], [], "TensorArraySizeV3:0", mode="debug")
+
     for dtype in ["float32", "int8"]:
         run(dtype, False)
         run(dtype, True)
@@ -1097,11 +1415,11 @@ def test_tensor_array_size():
 
 def test_tensor_array_stack():
     def run(dtype_str, infer_shape):
-        if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
+        if package_version.parse(tf.VERSION) >= package_version.parse("1.15.0"):
             pytest.skip("Needs fixing for tflite >= 1.15.0")
 
         with tf.Graph().as_default():
-            dtype =  tf_dtypes[dtype_str]
+            dtype = tf_dtypes[dtype_str]
             t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str))
             scatter_indices = tf.constant([2, 1, 0])
             ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=infer_shape)
@@ -1110,26 +1428,27 @@ def test_tensor_array_stack():
             print(t1)
             g = tf.get_default_graph()
 
-            compare_tf_with_tvm([], [], ['TensorArrayStack/TensorArrayGatherV3:0'], mode='vm')
+            compare_tf_with_tvm([], [], ["TensorArrayStack/TensorArrayGatherV3:0"], mode="vm")
+
     for dtype in ["float32", "int8"]:
         run(dtype, True)
 
 
 def test_tensor_array_unstack():
     def run(dtype_str, input_shape, infer_shape):
-        if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
+        if package_version.parse(tf.VERSION) >= package_version.parse("1.15.0"):
             pytest.skip("Needs fixing for tflite >= 1.15.0")
 
         with tf.Graph().as_default():
             dtype = tf_dtypes[dtype_str]
-            t = tf.constant(np.random.choice([0, 1, 2, 3],
-                                             size=input_shape).astype(dtype.name))
+            t = tf.constant(np.random.choice([0, 1, 2, 3], size=input_shape).astype(dtype.name))
             ta1 = tf.TensorArray(dtype=dtype, infer_shape=infer_shape, size=input_shape[0])
             ta2 = ta1.unstack(t)
             out0 = ta2.size()
             out1 = ta2.read(0)
-            compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
-            compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug')
+            compare_tf_with_tvm([], [], "TensorArraySizeV3:0", mode="debug")
+            compare_tf_with_tvm([], [], "TensorArrayReadV3:0", mode="debug")
+
     for dtype in ["float32", "int8"]:
         run(dtype, (5,), False)
         run(dtype, (5, 5), True)
@@ -1146,20 +1465,19 @@ def _test_concat_v2(shape1, shape2, dim):
     """ One iteration of ConcatV2 """
 
     with tf.Graph().as_default():
-        dtype = 'float32'
-        in1 = tf.placeholder(shape=shape1, dtype=dtype, name='in1')
-        in2 = tf.placeholder(shape=shape2, dtype=dtype, name='in2')
+        dtype = "float32"
+        in1 = tf.placeholder(shape=shape1, dtype=dtype, name="in1")
+        in2 = tf.placeholder(shape=shape2, dtype=dtype, name="in2")
         array_ops.concat_v2([in1, in2], dim)
 
         np_data1 = np.random.uniform(size=shape1).astype(dtype)
         np_data2 = np.random.uniform(size=shape2).astype(dtype)
 
-        compare_tf_with_tvm([np_data1, np_data2], [
-                            'in1:0', 'in2:0'], 'ConcatV2:0')
+        compare_tf_with_tvm([np_data1, np_data2], ["in1:0", "in2:0"], "ConcatV2:0")
 
 
 def test_forward_concat_v2():
-    if tf.__version__ < LooseVersion('1.4.1'):
+    if tf.__version__ < LooseVersion("1.4.1"):
         return
 
     _test_concat_v2([2, 3], [2, 3], 0)
@@ -1168,6 +1486,7 @@ def test_forward_concat_v2():
     _test_concat_v2([5, 8], [5, 4], 1)
     _test_concat_v2([2, 8, 5], [2, 8, 6], -1)
 
+
 #######################################################################
 # Sigmoid
 # -------
@@ -1180,13 +1499,14 @@ def _test_sigmoid(data):
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         sigmoid_out = math_ops.sigmoid(in_data)
 
-        compare_tf_with_tvm(data, 'Placeholder:0', 'Sigmoid:0')
+        compare_tf_with_tvm(data, "Placeholder:0", "Sigmoid:0")
 
 
 def test_forward_sigmoid():
     """ Sigmoid """
 
-    _test_sigmoid(np.random.uniform(size=(3, 4, 4, 3)).astype('float32'))
+    _test_sigmoid(np.random.uniform(size=(3, 4, 4, 3)).astype("float32"))
+
 
 #######################################################################
 # Argmin/Argmax
@@ -1196,16 +1516,15 @@ def test_forward_sigmoid():
 def _test_argx(func, data, **kwargs):
 
     with tf.Graph().as_default():
-        inp = array_ops.placeholder(
-            shape=data.shape, dtype=data.dtype, name="c0")
+        inp = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="c0")
         func(inp, name="argx0", output_type=tf.int32, **kwargs)
 
-        compare_tf_with_tvm(data, 'c0:0', 'argx0:0')
+        compare_tf_with_tvm(data, "c0:0", "argx0:0")
 
 
 def test_forward_argminmax():
     for axis in [None, 0, 1, 2]:
-        data = np.random.uniform(size=(8, 4, 9)).astype('float32')
+        data = np.random.uniform(size=(8, 4, 9)).astype("float32")
         _test_argx(tf.argmax, data=data, axis=axis)
         _test_argx(tf.argmin, data=data, axis=axis)
 
@@ -1214,6 +1533,7 @@ def test_forward_argminmax():
 # Variable
 # --------
 
+
 def _test_variable(data):
     """ One iteration of a variable """
 
@@ -1224,17 +1544,15 @@ def _test_variable(data):
 
         size = input_tensor.shape.dims[1]
         with variable_scope.variable_scope("linear", reuse=None):
-            w = variable_scope.get_variable(
-                "w", shape=[size, size], dtype=input_tensor.dtype)
+            w = variable_scope.get_variable("w", shape=[size, size], dtype=input_tensor.dtype)
         math_ops.matmul(input_tensor, w)
 
-        compare_tf_with_tvm(data, 'Placeholder:0', 'MatMul:0',
-                            init_global_variables=True)
+        compare_tf_with_tvm(data, "Placeholder:0", "MatMul:0", init_global_variables=True)
 
 
 def test_forward_variable():
     """Variable type op test"""
-    _test_variable(np.random.uniform(size=(32, 100)).astype('float32'))
+    _test_variable(np.random.uniform(size=(32, 100)).astype("float32"))
 
 
 @tvm.testing.parametrize_targets("llvm", "cuda")
@@ -1242,18 +1560,18 @@ def test_read_variable_op(target, ctx):
     """ Read Variable op test """
 
     tf.reset_default_graph()
-    data = np.random.uniform(size=(32, 100)).astype('float32')
+    data = np.random.uniform(size=(32, 100)).astype("float32")
     input_tensor = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
 
     size = input_tensor.shape.dims[1]
     var_data = np.random.uniform(-5, 5, size=[size, size]).astype(np.float32)
-    input_var = tf.Variable(var_data, name='var1', use_resource=True)
+    input_var = tf.Variable(var_data, name="var1", use_resource=True)
     math_ops.matmul(input_tensor, input_var)
 
-    out_name = ['MatMul:0']
-    out_node = ['MatMul']
-    in_name = ['Placeholder:0']
-    in_node = ['Placeholder']
+    out_name = ["MatMul:0"]
+    out_node = ["MatMul"]
+    in_name = ["Placeholder:0"]
+    in_node = ["Placeholder"]
     in_data = [data]
 
     with tf.Session() as sess:
@@ -1264,10 +1582,9 @@ def test_read_variable_op(target, ctx):
 
         shape_dict = {e: i.shape for e, i in zip(in_name, in_data)}
         with pytest.raises(Exception) as execinfo:
-            mod, params = relay.frontend.from_tensorflow(final_graph_def,
-                                                         layout=None,
-                                                         shape=shape_dict,
-                                                         outputs=None)
+            mod, params = relay.frontend.from_tensorflow(
+                final_graph_def, layout=None, shape=shape_dict, outputs=None
+            )
 
         assert execinfo.value.args[0].startswith("Graph is not frozen. Provide a frozen graph")
 
@@ -1278,12 +1595,16 @@ def test_read_variable_op(target, ctx):
             out_node,
         )
 
-        tvm_output = run_tvm_graph(final_graph_def, in_data, in_node,
-                                   target=target, out_names=out_name,
-                                   num_output=len(out_name))
+        tvm_output = run_tvm_graph(
+            final_graph_def,
+            in_data,
+            in_node,
+            target=target,
+            out_names=out_name,
+            num_output=len(out_name),
+        )
         for i in range(len(tf_output)):
-            tvm.testing.assert_allclose(
-                tf_output[i], tvm_output[i], atol=1e-4, rtol=1e-5)
+            tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-4, rtol=1e-5)
 
         sess.close()
 
@@ -1292,6 +1613,7 @@ def test_read_variable_op(target, ctx):
 # MatMul, BatchMatMul, BatchMatMulV2
 # ----------------------------------
 
+
 def _test_matmul(i, j, k, dtype, outer=None):
     """ One iteration of matmul """
 
@@ -1301,36 +1623,31 @@ def _test_matmul(i, j, k, dtype, outer=None):
     for transpose_a in [False, True]:
         for transpose_b in [False, True]:
             outer = outer or []
-            A_shape = outer + \
-                (A_shape_init[::-1] if transpose_a else A_shape_init)
-            B_shape = outer + \
-                (B_shape_init[::-1] if transpose_b else B_shape_init)
+            A_shape = outer + (A_shape_init[::-1] if transpose_a else A_shape_init)
+            B_shape = outer + (B_shape_init[::-1] if transpose_b else B_shape_init)
 
             with tf.Graph().as_default():
-                A = tf.placeholder(shape=A_shape, dtype=dtype, name='A')
-                B = tf.placeholder(shape=B_shape, dtype=dtype, name='B')
-                result = tf.matmul(
-                    A, B, transpose_a=transpose_a, transpose_b=transpose_b)
+                A = tf.placeholder(shape=A_shape, dtype=dtype, name="A")
+                B = tf.placeholder(shape=B_shape, dtype=dtype, name="B")
+                result = tf.matmul(A, B, transpose_a=transpose_a, transpose_b=transpose_b)
 
                 A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
                 B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
-                compare_tf_with_tvm(
-                    [A_np, B_np], [A.name, B.name], result.name)
+                compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name)
 
 
 def test_forward_matmul():
     """ MatMul op test"""
-    _test_matmul(1, 3, 6, 'int32')
-    _test_matmul(5, 3, 1, 'float64')
+    _test_matmul(1, 3, 6, "int32")
+    _test_matmul(5, 3, 1, "float64")
 
 
 def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False):
 
     with tf.Graph().as_default():
-        A = tf.placeholder(shape=A_shape, dtype=dtype, name='A')
-        B = tf.placeholder(shape=B_shape, dtype=dtype, name='B')
-        result = tf.matmul(A, B, adjoint_a=adjoint_a,
-                           adjoint_b=adjoint_b, name='batchmatmul')
+        A = tf.placeholder(shape=A_shape, dtype=dtype, name="A")
+        B = tf.placeholder(shape=B_shape, dtype=dtype, name="B")
+        result = tf.matmul(A, B, adjoint_a=adjoint_a, adjoint_b=adjoint_b, name="batchmatmul")
 
         A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
         B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
@@ -1339,91 +1656,147 @@ def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False
 
 def test_forward_batch_matmul():
     """ TF op BatchMatMul, BatchMatMulV2 test"""
-    _test_batch_matmul((3, 5, 4), (3, 4, 5), 'int32')
-    _test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32', True, True)
-    _test_batch_matmul((3, 5, 4), (3, 5, 4), 'int32', True, False)
-    _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', False, True)
-    _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'int32')
-    _test_batch_matmul((1, 2, 3, 4, 5, 6),
-                       (1, 2, 3, 4, 6, 5), 'float32', True, True)
-    _test_batch_matmul((3, 4, 5, 6), (3, 4, 5, 6), 'int32', True, False)
-    _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6),
-                       (2, 3, 4, 2, 3, 4, 5, 6), 'float32', False, True)
+    _test_batch_matmul((3, 5, 4), (3, 4, 5), "int32")
+    _test_batch_matmul((3, 5, 4), (3, 4, 5), "float32", True, True)
+    _test_batch_matmul((3, 5, 4), (3, 5, 4), "int32", True, False)
+    _test_batch_matmul((3, 5, 4), (3, 5, 4), "float32", False, True)
+    _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), "int32")
+    _test_batch_matmul((1, 2, 3, 4, 5, 6), (1, 2, 3, 4, 6, 5), "float32", True, True)
+    _test_batch_matmul((3, 4, 5, 6), (3, 4, 5, 6), "int32", True, False)
+    _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), "float32", False, True)
 
 
 #######################################################################
 # StridedSlice
 # ------------
 
-def _test_stridedslice(ip_shape, begin, end, stride, dtype,
-                       begin_mask=0, end_mask=0, new_axis_mask=0,
-                       shrink_axis_mask=0, ellipsis_mask=0):
+
+def _test_stridedslice(
+    ip_shape,
+    begin,
+    end,
+    stride,
+    dtype,
+    begin_mask=0,
+    end_mask=0,
+    new_axis_mask=0,
+    shrink_axis_mask=0,
+    ellipsis_mask=0,
+):
     """ One iteration of a Stridedslice """
 
     tf.reset_default_graph()
     with tf.Graph().as_default():
         in_data = tf.placeholder(dtype, ip_shape, name="in_data")
-        tf.strided_slice(in_data, begin, end, stride, begin_mask=begin_mask,
-                         end_mask=end_mask, new_axis_mask=new_axis_mask,
-                         shrink_axis_mask=shrink_axis_mask,
-                         ellipsis_mask=ellipsis_mask, name="strided_slice")
+        tf.strided_slice(
+            in_data,
+            begin,
+            end,
+            stride,
+            begin_mask=begin_mask,
+            end_mask=end_mask,
+            new_axis_mask=new_axis_mask,
+            shrink_axis_mask=shrink_axis_mask,
+            ellipsis_mask=ellipsis_mask,
+            name="strided_slice",
+        )
         np_data = np.random.uniform(size=ip_shape).astype(dtype)
 
-        compare_tf_with_tvm(np_data, 'in_data:0', 'strided_slice:0')
+        compare_tf_with_tvm(np_data, "in_data:0", "strided_slice:0")
 
 
 def test_forward_stridedslice():
-    '''test StridedSlice'''
-
-    _test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1)
-    _test_stridedslice((2, 1), [0], [1], [1], 'float32', shrink_axis_mask=1)
-    _test_stridedslice((2, 3, 4), [0], [1], [1], 'float32', shrink_axis_mask=8)
-    _test_stridedslice((3, 4, 3), [1, -1, 0],
-                       [4, -5, 3], [2, -1, 1], 'float32')
-    _test_stridedslice((3, 4, 3), [1, 0], [4, 3], [
-                       2, 1], 'float32', ellipsis_mask=8)
-    _test_stridedslice((3, 4, 3), [1, 0], [4, 2], [
-                       2, 1], 'float32', ellipsis_mask=2)
-    _test_stridedslice((3, 4, 5, 3), [1, 0], [4, 2], [
-                       2, 1], 'float32', ellipsis_mask=2)
-    _test_stridedslice((3, 4, 5, 3), [1, 0, 1], [4, 2, 2], [
-                       2, 1, 1], 'float32', ellipsis_mask=2)
-    _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 2], [
-                       2, 1, 1], 'float32', new_axis_mask=5)
-    _test_stridedslice((3, 4, 3), [1, 1, 1], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2,
-                       new_axis_mask=4)
-    _test_stridedslice((6, 4, 5), [1, 1, 1], [6, 3, 4], [2, 1, 1], 'float32', ellipsis_mask=2,
-                       new_axis_mask=5)
-    _test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=4,
-                       new_axis_mask=2)
-    _test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2,
-                       new_axis_mask=3)
-    _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2,
-                       new_axis_mask=3)
-    _test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2,
-                       new_axis_mask=2)
-    _test_stridedslice((3, 4), [1, 0], [4, 4], [
-                       1, 1], 'float32', shrink_axis_mask=2)
-    _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=2,
-                       new_axis_mask=2)
-    _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=1,
-                       new_axis_mask=2)
-    _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=2,
-                       new_axis_mask=1)
-    _test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0], [2, 3], [1, 1], 'float32', shrink_axis_mask=5,
-                       new_axis_mask=1)
-    _test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1],
-                       'float32', shrink_axis_mask=5, new_axis_mask=1, ellipsis_mask=2,
-                       begin_mask=8, end_mask=8)
-    _test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1],
-                       'float32', shrink_axis_mask=8, new_axis_mask=1, ellipsis_mask=2,
-                       begin_mask=5, end_mask=5)
-    _test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1],
-                       'float32', shrink_axis_mask=16, new_axis_mask=1, ellipsis_mask=2,
-                       begin_mask=5, end_mask=5)
-    _test_stridedslice((3, 4, 5, 4, 5, 6), [1, 2, 0, -3], [4, 5, 3, 3], [2, 2, 1, 1],
-                       'float32', shrink_axis_mask=8, new_axis_mask=1, ellipsis_mask=2,
-                       begin_mask=5, end_mask=8)
+    """test StridedSlice"""
+
+    _test_stridedslice((2), [1], [1], [1], "float32", shrink_axis_mask=1)
+    _test_stridedslice((2, 1), [0], [1], [1], "float32", shrink_axis_mask=1)
+    _test_stridedslice((2, 3, 4), [0], [1], [1], "float32", shrink_axis_mask=8)
+    _test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], "float32")
+    _test_stridedslice((3, 4, 3), [1, 0], [4, 3], [2, 1], "float32", ellipsis_mask=8)
+    _test_stridedslice((3, 4, 3), [1, 0], [4, 2], [2, 1], "float32", ellipsis_mask=2)
+    _test_stridedslice((3, 4, 5, 3), [1, 0], [4, 2], [2, 1], "float32", ellipsis_mask=2)
+    _test_stridedslice((3, 4, 5, 3), [1, 0, 1], [4, 2, 2], [2, 1, 1], "float32", ellipsis_mask=2)
+    _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 2], [2, 1, 1], "float32", new_axis_mask=5)
+    _test_stridedslice(
+        (3, 4, 3), [1, 1, 1], [4, 4, 1], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=4
+    )
+    _test_stridedslice(
+        (6, 4, 5), [1, 1, 1], [6, 3, 4], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=5
+    )
+    _test_stridedslice(
+        (3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], "float32", ellipsis_mask=4, new_axis_mask=2
+    )
+    _test_stridedslice(
+        (3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=3
+    )
+    _test_stridedslice(
+        (3, 4, 3), [1, 1, 0], [4, 4, 1], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=3
+    )
+    _test_stridedslice(
+        (3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=2
+    )
+    _test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], "float32", shrink_axis_mask=2)
+    _test_stridedslice(
+        (3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], "float32", shrink_axis_mask=2, new_axis_mask=2
+    )
+    _test_stridedslice(
+        (3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], "float32", shrink_axis_mask=1, new_axis_mask=2
+    )
+    _test_stridedslice(
+        (3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], "float32", shrink_axis_mask=2, new_axis_mask=1
+    )
+    _test_stridedslice(
+        (3, 4, 5, 4, 5, 6), [0, 0], [2, 3], [1, 1], "float32", shrink_axis_mask=5, new_axis_mask=1
+    )
+    _test_stridedslice(
+        (3, 4, 5, 4, 5, 6),
+        [0, 0, 1, 2, 1],
+        [2, 3, 4, 5, 3],
+        [1, 1, 2, 2, 1],
+        "float32",
+        shrink_axis_mask=5,
+        new_axis_mask=1,
+        ellipsis_mask=2,
+        begin_mask=8,
+        end_mask=8,
+    )
+    _test_stridedslice(
+        (3, 4, 5, 4, 5, 6),
+        [0, 0, 1, 2, 1],
+        [2, 3, 4, 5, 3],
+        [1, 1, 2, 2, 1],
+        "float32",
+        shrink_axis_mask=8,
+        new_axis_mask=1,
+        ellipsis_mask=2,
+        begin_mask=5,
+        end_mask=5,
+    )
+    _test_stridedslice(
+        (3, 4, 5, 4, 5, 6),
+        [0, 0, 1, 2, 1],
+        [2, 3, 4, 5, 3],
+        [1, 1, 2, 2, 1],
+        "float32",
+        shrink_axis_mask=16,
+        new_axis_mask=1,
+        ellipsis_mask=2,
+        begin_mask=5,
+        end_mask=5,
+    )
+    _test_stridedslice(
+        (3, 4, 5, 4, 5, 6),
+        [1, 2, 0, -3],
+        [4, 5, 3, 3],
+        [2, 2, 1, 1],
+        "float32",
+        shrink_axis_mask=8,
+        new_axis_mask=1,
+        ellipsis_mask=2,
+        begin_mask=5,
+        end_mask=8,
+    )
+
 
 #######################################################################
 # FloorDiv, RealDiv
@@ -1435,9 +1808,8 @@ def _test_forward_divide(ip_shape, dtype):
     with tf.Graph().as_default():
         numerator = tf.placeholder(dtype, ip_shape, name="numer")
         denominator = tf.placeholder(dtype, ip_shape, name="denomin")
-        tf.math.divide(numerator, denominator, name='RealDiv')
-        compare_tf_with_tvm([np_numer, np_denomin], [
-                            'numer:0', 'denomin:0'], 'RealDiv:0')
+        tf.math.divide(numerator, denominator, name="RealDiv")
+        compare_tf_with_tvm([np_numer, np_denomin], ["numer:0", "denomin:0"], "RealDiv:0")
 
 
 def _test_forward_floordiv(ip_shape, dtype):
@@ -1445,16 +1817,17 @@ def _test_forward_floordiv(ip_shape, dtype):
     tf.reset_default_graph()
     with tf.Graph().as_default():
         numerator = tf.placeholder(dtype, ip_shape, name="numer")
-        tf.math.floordiv(numerator, tf.constant(5, dtype=dtype), name='FloorDiv')
-        compare_tf_with_tvm([np_numer], ['numer:0'], 'FloorDiv:0')
+        tf.math.floordiv(numerator, tf.constant(5, dtype=dtype), name="FloorDiv")
+        compare_tf_with_tvm([np_numer], ["numer:0"], "FloorDiv:0")
 
 
 def test_forward_divide():
-    '''test FloorDiv, RealDiv'''
-    _test_forward_divide((4,), 'int32')
-    _test_forward_divide((4, 3, 7), 'float32')
-    _test_forward_floordiv((4, 3, 7), 'float32')
-    _test_forward_floordiv((4, 3, 7), 'int32')
+    """test FloorDiv, RealDiv"""
+    _test_forward_divide((4,), "int32")
+    _test_forward_divide((4, 3, 7), "float32")
+    _test_forward_floordiv((4, 3, 7), "float32")
+    _test_forward_floordiv((4, 3, 7), "int32")
+
 
 #######################################################################
 # FloorMod
@@ -1466,15 +1839,16 @@ def _test_forward_floormod(in_shape, if_shape, dtype):
     with tf.Graph().as_default():
         numerator = tf.placeholder(dtype, in_shape, name="numer")
         factor = tf.placeholder(dtype, if_shape, name="factor")
-        tf.floormod(numerator, factor, name='FloorMod')
-        compare_tf_with_tvm([np_numer, np_factor], ['numer:0', 'factor:0'], 'FloorMod:0')
+        tf.floormod(numerator, factor, name="FloorMod")
+        compare_tf_with_tvm([np_numer, np_factor], ["numer:0", "factor:0"], "FloorMod:0")
+
 
 def test_forward_floormod():
-    '''test FloorMod'''
-    _test_forward_floormod((10,), (10,), 'float32')
-    _test_forward_floormod((8, 2), (1,), 'float32')
-    _test_forward_floormod((4, 3, 7), (4, 3, 7), 'float32')
-    _test_forward_floormod((4, 3, 7), (4, 3, 7), 'int32')
+    """test FloorMod"""
+    _test_forward_floormod((10,), (10,), "float32")
+    _test_forward_floormod((8, 2), (1,), "float32")
+    _test_forward_floormod((4, 3, 7), (4, 3, 7), "float32")
+    _test_forward_floormod((4, 3, 7), (4, 3, 7), "int32")
 
 
 #######################################################################
@@ -1487,20 +1861,20 @@ def _test_forward_truncatemod(ip_shape, dtype):
     with tf.Graph().as_default():
         in_data_1 = tf.placeholder(dtype, ip_shape, name="in_data_1")
         in_data_2 = tf.placeholder(dtype, ip_shape, name="in_data_2")
-        tf.truncatemod(in_data_1, in_data_2, name='truncatemod')
-        compare_tf_with_tvm([np_data_1, np_data_2], [
-                            'in_data_1:0', 'in_data_2:0'], 'truncatemod:0')
+        tf.truncatemod(in_data_1, in_data_2, name="truncatemod")
+        compare_tf_with_tvm([np_data_1, np_data_2], ["in_data_1:0", "in_data_2:0"], "truncatemod:0")
 
 
 def test_forward_truncatemod():
-    '''test TruncateMod'''
-    _test_forward_truncatemod((4, 3, 7), 'int32')
+    """test TruncateMod"""
+    _test_forward_truncatemod((4, 3, 7), "int32")
 
 
 #######################################################################
 # Gather, GatherV2
 # --------------------------
 
+
 def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
     """ One iteration of a GatherV2 """
 
@@ -1514,33 +1888,35 @@ def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
         def _fill_indices(indice_value):
             indices = np.array(ip_shape, dtype=dtype)
             if isinstance(indice_value, int):
-                indices = np.array([indice_value], dtype='int32')
+                indices = np.array([indice_value], dtype="int32")
             else:
-                indices = np.asarray(indice_value, dtype='int32')
+                indices = np.asarray(indice_value, dtype="int32")
             return indices
+
         np_indices = _fill_indices(indice_value)
 
-        compare_tf_with_tvm([np_data, np_indices], [
-                            'in_data:0', 'indices:0'], out.name)
+        compare_tf_with_tvm([np_data, np_indices], ["in_data:0", "indices:0"], out.name)
 
 
 def test_forward_gather():
-    '''test Gather/GatherV2 layer'''
-    _test_gather((4,), (1,), 1, 0, 'int32')
-    _test_gather((4,), (1,), 1, 0, 'float32')
-    _test_gather((1, 4), (1,), [0], 0, 'int32')
-    _test_gather((4,), (1, 2, 2), [[[1, 0], [0, 1]]], 0, 'float32')
-    _test_gather((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 0, 'int32')
-    _test_gather((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 1, 'int32')
-    _test_gather((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 0, 'float32')
-    _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 0, 'int32')
-    _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 2, 'int32')
-    _test_gather((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 0, 'float32')
+    """test Gather/GatherV2 layer"""
+    _test_gather((4,), (1,), 1, 0, "int32")
+    _test_gather((4,), (1,), 1, 0, "float32")
+    _test_gather((1, 4), (1,), [0], 0, "int32")
+    _test_gather((4,), (1, 2, 2), [[[1, 0], [0, 1]]], 0, "float32")
+    _test_gather((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 0, "int32")
+    _test_gather((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 1, "int32")
+    _test_gather((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 0, "float32")
+    _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 0, "int32")
+    _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 2, "int32")
+    _test_gather((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 0, "float32")
+
 
 #######################################################################
 # GatherND
 # --------------------------
 
+
 def _test_gather_nd(ip_shape, indice_value, dtype):
     """test operator GatherNd"""
     np_data = np.random.uniform(1, 100, size=ip_shape).astype(dtype)
@@ -1548,27 +1924,30 @@ def _test_gather_nd(ip_shape, indice_value, dtype):
     with tf.Graph().as_default():
         in_data = tf.placeholder(dtype, ip_shape, name="in_data")
         tf.gather_nd(in_data, indices=indice_value, name="gather_nd")
-        compare_tf_with_tvm([np_data], ['in_data:0'], 'gather_nd:0')
+        compare_tf_with_tvm([np_data], ["in_data:0"], "gather_nd:0")
+
 
 def test_forward_gather_nd():
     """test operator GatherNd"""
-    _test_gather_nd((2, 2), [[0, 0], [1, 1]], 'float32')
-    _test_gather_nd((2, 2, 2), [[1, 0, 0], [0, 0, 0]], 'float32')
-    _test_gather_nd((4,), [1], 'float32')
-    _test_gather_nd((4,), [1], 'int32')
-    _test_gather_nd((1, 4), [0, 3], 'int32')
-    _test_gather_nd((2, 2), [[[1, 0], [0, 1]]], 'int32')
-    _test_gather_nd((2, 2), [[[1, 0], [0, 1]]], 'float32')
-    _test_gather_nd((3, 3, 3),  [[[1, 0]]], 'int32')
-    _test_gather_nd((3, 3, 3), [[[1, 0]]], 'int32')
-    _test_gather_nd((4, 3, 5, 6),  [[2, 1, 0, 0]], 'float32')
-    _test_gather_nd((3, 3, 3), [[[2, 1]]], 'int32')
+    _test_gather_nd((2, 2), [[0, 0], [1, 1]], "float32")
+    _test_gather_nd((2, 2, 2), [[1, 0, 0], [0, 0, 0]], "float32")
+    _test_gather_nd((4,), [1], "float32")
+    _test_gather_nd((4,), [1], "int32")
+    _test_gather_nd((1, 4), [0, 3], "int32")
+    _test_gather_nd((2, 2), [[[1, 0], [0, 1]]], "int32")
+    _test_gather_nd((2, 2), [[[1, 0], [0, 1]]], "float32")
+    _test_gather_nd((3, 3, 3), [[[1, 0]]], "int32")
+    _test_gather_nd((3, 3, 3), [[[1, 0]]], "int32")
+    _test_gather_nd((4, 3, 5, 6), [[2, 1, 0, 0]], "float32")
+    _test_gather_nd((3, 3, 3), [[[2, 1]]], "int32")
+
 
 #######################################################################
 # BiasAdd
 # -------
 def test_forward_bias_add():
     """test Op BiasAdd"""
+
     def check_bias_add(lh_shape, rh_shape, dtype):
         tf.reset_default_graph()
         lh_data = np.random.uniform(size=lh_shape).astype(dtype)
@@ -1577,8 +1956,7 @@ def test_forward_bias_add():
             lft_data = tf.placeholder(dtype, name="lft_data")
             rgt_data = tf.placeholder(dtype, name="rgt_data")
             tf.nn.bias_add(lft_data, rgt_data, name="BiasAdd")
-            compare_tf_with_tvm([lh_data, rh_data], [
-                                'lft_data:0', 'rgt_data:0'], 'BiasAdd:0')
+            compare_tf_with_tvm([lh_data, rh_data], ["lft_data:0", "rgt_data:0"], "BiasAdd:0")
 
     check_bias_add((10, 8, 16, 32), (32,), dtype="int32")
     check_bias_add((10, 20), (20,), dtype="float32")
@@ -1588,6 +1966,7 @@ def test_forward_bias_add():
 # Split
 # -----
 
+
 def _test_split(in_shape, axis, num_or_size_splits, dtype):
     np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
 
@@ -1595,12 +1974,13 @@ def _test_split(in_shape, axis, num_or_size_splits, dtype):
     tf.reset_default_graph()
     with tf.Graph().as_default():
         in_data = tf.placeholder(dtype, in_shape, name="in_data")
-        num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list)\
-            else num_or_size_splits
+        num_split = (
+            len(num_or_size_splits) if isinstance(num_or_size_splits, list) else num_or_size_splits
+        )
         split = tf.split(in_data, num_or_size_splits, axis=axis)
         relu = [tf.nn.relu(i) for i in split]
 
-        compare_tf_with_tvm([np_data], ['in_data:0'], [n.name for n in relu])
+        compare_tf_with_tvm([np_data], ["in_data:0"], [n.name for n in relu])
 
     # and now test together with concat
     tf.reset_default_graph()
@@ -1608,48 +1988,49 @@ def _test_split(in_shape, axis, num_or_size_splits, dtype):
         in_data = tf.placeholder(dtype, in_shape, name="in_data")
         splitted = tf.split(in_data, num_or_size_splits, axis=axis)
         concat = tf.concat(splitted, axis)
-        compare_tf_with_tvm([np_data], 'in_data:0', concat.name)
+        compare_tf_with_tvm([np_data], "in_data:0", concat.name)
 
 
 def test_forward_split():
-    '''test split layer'''
+    """test split layer"""
     # rank 1
-    _test_split((3,), 0, 1, 'float32')
-    _test_split((3,), 0, 3, 'float32')
-    _test_split((6,), 0, 3, 'float32')
+    _test_split((3,), 0, 1, "float32")
+    _test_split((3,), 0, 3, "float32")
+    _test_split((6,), 0, 3, "float32")
     # rank 2
-    _test_split((6, 2), 0, 3, 'float32')
-    _test_split((2, 6), 1, 6, 'float32')
+    _test_split((6, 2), 0, 3, "float32")
+    _test_split((2, 6), 1, 6, "float32")
     # rank 3
-    _test_split((6, 2, 4), 0, 2, 'int32')
-    _test_split((2, 6, 4), 1, 3, 'float32')
-    _test_split((2, 4, 6), 2, 1, 'float32')
+    _test_split((6, 2, 4), 0, 2, "int32")
+    _test_split((2, 6, 4), 1, 3, "float32")
+    _test_split((2, 4, 6), 2, 1, "float32")
     # rank 4
-    _test_split((6, 1, 3, 5), 0, 3, 'float32')
-    _test_split((1, 6, 3, 5), 1, 3, 'float32')
-    _test_split((1, 3, 6, 5), 2, 3, 'float32')
-    _test_split((1, 3, 5, 6), 3, 3, 'float32')
+    _test_split((6, 1, 3, 5), 0, 3, "float32")
+    _test_split((1, 6, 3, 5), 1, 3, "float32")
+    _test_split((1, 3, 6, 5), 2, 3, "float32")
+    _test_split((1, 3, 5, 6), 3, 3, "float32")
     # split along negative axis
-    _test_split((6, 1, 3, 5), -4, 3, 'float32')
-    _test_split((1, 6, 3, 5), -3, 3, 'float32')
-    _test_split((1, 3, 6, 5), -2, 3, 'float32')
-    _test_split((1, 3, 5, 6), -1, 3, 'float32')
+    _test_split((6, 1, 3, 5), -4, 3, "float32")
+    _test_split((1, 6, 3, 5), -3, 3, "float32")
+    _test_split((1, 3, 6, 5), -2, 3, "float32")
+    _test_split((1, 3, 5, 6), -1, 3, "float32")
     # size_splits list
-    _test_split((6,), 0, [1, 2, 3], 'int32')
-    _test_split((3, 6, 4), -2, [1, 4, 1], 'float32')
+    _test_split((6,), 0, [1, 2, 3], "int32")
+    _test_split((3, 6, 4), -2, [1, 4, 1], "float32")
 
 
 ######################################################################
 # TopKV2
 # ------
 
+
 def _test_forward_top_k_v2(in_shape, k):
     np_data = np.random.uniform(-100, 100, size=in_shape).astype("float32")
     tf.reset_default_graph()
     with tf.Graph().as_default():
         in_data = tf.placeholder("float32", in_shape, name="in_data")
-        tf.math.top_k(in_data, k, name='TopK')
-        compare_tf_with_tvm([np_data], ['in_data:0'], 'TopK:0')
+        tf.math.top_k(in_data, k, name="TopK")
+        compare_tf_with_tvm([np_data], ["in_data:0"], "TopK:0")
 
 
 def test_forward_top_k_v2():
@@ -1663,6 +2044,7 @@ def test_forward_top_k_v2():
 # Unstack
 # -------
 
+
 def _test_unstack(ip_shape, axis, dtype):
     np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype)
 
@@ -1671,41 +2053,42 @@ def _test_unstack(ip_shape, axis, dtype):
         in_data = tf.placeholder(dtype, ip_shape, name="in_data")
         unstack = tf.unstack(in_data, axis=axis)
 
-        compare_tf_with_tvm([np_data], ['in_data:0'], [n.name for n in unstack])
+        compare_tf_with_tvm([np_data], ["in_data:0"], [n.name for n in unstack])
 
     tf.reset_default_graph()
     with tf.Graph().as_default():
         in_data = tf.placeholder(dtype, ip_shape, name="in_data")
         tf.stack(tf.unstack(in_data, axis=axis), axis=axis)
 
-        compare_tf_with_tvm([np_data], ['in_data:0'], 'stack:0')
+        compare_tf_with_tvm([np_data], ["in_data:0"], "stack:0")
 
 
 def test_forward_unstack():
-    '''test unstack layer'''
-    _test_unstack((6,), 0, 'int32')
-    _test_unstack((2, 6), 1, 'float64')
+    """test unstack layer"""
+    _test_unstack((6,), 0, "int32")
+    _test_unstack((2, 6), 1, "float64")
     # negative axis
-    _test_unstack((1, 4), -1, 'int32')
-    _test_unstack((3, 6, 4), -2, 'float32')
+    _test_unstack((1, 4), -1, "int32")
+    _test_unstack((3, 6, 4), -2, "float32")
 
 
 #######################################################################
 # Tile
 # ----
 
+
 def _test_tile(in_shape, multiples, dtype):
     np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
     tf.reset_default_graph()
     with tf.Graph().as_default():
         in_data = tf.placeholder(dtype, in_shape, name="in_data")
         tf.tile(in_data, multiples=multiples, name="tile")
-        compare_tf_with_tvm([np_data], ['in_data:0'], 'tile:0')
+        compare_tf_with_tvm([np_data], ["in_data:0"], "tile:0")
 
 
 def test_forward_tile():
-    '''test Tile'''
-    _test_tile((2, ), (3, ), "int32")
+    """test Tile"""
+    _test_tile((2,), (3,), "int32")
     _test_tile((2, 2), (2, 3), "float32")
     _test_tile((2, 4, 6), (6, 7, 8), "float64")
 
@@ -1714,21 +2097,22 @@ def test_forward_tile():
 # ClipByValue
 # -----------
 
+
 def _test_forward_clip_by_value(ip_shape, clip_value_min, clip_value_max, dtype):
     tf.reset_default_graph()
     with tf.Graph().as_default():
         in_data = tf.placeholder(dtype, ip_shape, name="in_data")
-        tf.clip_by_value(in_data, clip_value_min,
-                         clip_value_max, name="ClipByValue")
+        tf.clip_by_value(in_data, clip_value_min, clip_value_max, name="ClipByValue")
         np_data = np.random.uniform(-100, 100, size=ip_shape).astype(dtype)
-        compare_tf_with_tvm([np_data], ['in_data:0'], 'ClipByValue:0')
+        compare_tf_with_tvm([np_data], ["in_data:0"], "ClipByValue:0")
 
 
 def test_forward_clip_by_value():
-    '''test ClipByValue op'''
-    if tf.__version__ < LooseVersion('1.9'):
-        _test_forward_clip_by_value((4,), .1, 5., 'float32')
-        _test_forward_clip_by_value((4, 4), 1, 5, 'int32')
+    """test ClipByValue op"""
+    if tf.__version__ < LooseVersion("1.9"):
+        _test_forward_clip_by_value((4,), 0.1, 5.0, "float32")
+        _test_forward_clip_by_value((4, 4), 1, 5, "int32")
+
 
 #######################################################################
 # Multi Input to graph
@@ -1737,18 +2121,20 @@ def test_forward_clip_by_value():
 
 def test_forward_multi_input():
     with tf.Graph().as_default():
-        in1 = tf.placeholder(tf.int32, shape=[3, 3], name='in1')
-        in2 = tf.placeholder(tf.int32, shape=[3, 3], name='in2')
-        in3 = tf.placeholder(tf.int32, shape=[3, 3], name='in3')
-        in4 = tf.placeholder(tf.int32, shape=[3, 3], name='in4')
-
-        out1 = tf.add(in1, in2, name='out1')
-        out2 = tf.subtract(in3, in4, name='out2')
-        out = tf.multiply(out1, out2, name='out')
-        in_data = np.arange(9, dtype='int32').reshape([3, 3])
+        in1 = tf.placeholder(tf.int32, shape=[3, 3], name="in1")
+        in2 = tf.placeholder(tf.int32, shape=[3, 3], name="in2")
+        in3 = tf.placeholder(tf.int32, shape=[3, 3], name="in3")
+        in4 = tf.placeholder(tf.int32, shape=[3, 3], name="in4")
+
+        out1 = tf.add(in1, in2, name="out1")
+        out2 = tf.subtract(in3, in4, name="out2")
+        out = tf.multiply(out1, out2, name="out")
+        in_data = np.arange(9, dtype="int32").reshape([3, 3])
+
+        compare_tf_with_tvm(
+            [in_data, in_data, in_data, in_data], ["in1:0", "in2:0", "in3:0", "in4:0"], "out:0"
+        )
 
-        compare_tf_with_tvm([in_data, in_data, in_data, in_data],
-                            ['in1:0', 'in2:0', 'in3:0', 'in4:0'], 'out:0')
 
 #######################################################################
 # Multi Output to Graph
@@ -1757,29 +2143,33 @@ def test_forward_multi_input():
 
 def test_forward_multi_output():
     with tf.Graph().as_default():
-        in1 = tf.placeholder(tf.int32, shape=[3, 3], name='in1')
-        in2 = tf.placeholder(tf.int32, shape=[3, 3], name='in2')
-        in3 = tf.placeholder(tf.int32, shape=[3, 3], name='in3')
-        in4 = tf.placeholder(tf.int32, shape=[3, 3], name='in4')
-
-        out1 = tf.add(in1, in2, name='out1')
-        out2 = tf.subtract(in3, in4, name='out2')
-        in_data = np.arange(9, dtype='int32').reshape([3, 3])
+        in1 = tf.placeholder(tf.int32, shape=[3, 3], name="in1")
+        in2 = tf.placeholder(tf.int32, shape=[3, 3], name="in2")
+        in3 = tf.placeholder(tf.int32, shape=[3, 3], name="in3")
+        in4 = tf.placeholder(tf.int32, shape=[3, 3], name="in4")
+
+        out1 = tf.add(in1, in2, name="out1")
+        out2 = tf.subtract(in3, in4, name="out2")
+        in_data = np.arange(9, dtype="int32").reshape([3, 3])
         in_data = [in_data] * 4
-        in_name = ['in1:0', 'in2:0', 'in3:0', 'in4:0']
-        out_name = ['out1:0', 'out2:0']
-        out_node = [out.strip(':0') for out in out_name]
-        in_node = [inp.strip(':0') for inp in in_name]
+        in_name = ["in1:0", "in2:0", "in3:0", "in4:0"]
+        out_name = ["out1:0", "out2:0"]
+        out_node = [out.strip(":0") for out in out_name]
+        in_node = [inp.strip(":0") for inp in in_name]
 
         with tf.Session() as sess:
             final_graph_def = tf.graph_util.convert_variables_to_constants(
-                sess, sess.graph.as_graph_def(add_shapes=True), out_node,)
+                sess,
+                sess.graph.as_graph_def(add_shapes=True),
+                out_node,
+            )
             tf_output = run_tf_graph(sess, in_data, in_name, out_name)
-            tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target='llvm',
-                                       out_names=out_node, num_output=2)
+            tvm_output = run_tvm_graph(
+                final_graph_def, in_data, in_node, target="llvm", out_names=out_node, num_output=2
+            )
             for i in range(len(tf_output)):
-                tvm.testing.assert_allclose(
-                    tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
+                tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
+
 
 #######################################################################
 # Resize Bilinear, Nearest_Neighbor
@@ -1789,63 +2179,62 @@ def test_forward_multi_output():
 def _test_resize_bilinear(in_shape, to_shape, align_corners):
     """ One iteration of resize bilinear """
 
-    data = np.random.uniform(size=in_shape).astype('float32')
-    shape_data = np.array(to_shape).astype('int32')
+    data = np.random.uniform(size=in_shape).astype("float32")
+    shape_data = np.array(to_shape).astype("int32")
 
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         shape_data = constant_op.constant(
-            shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
-        tf.image.resize_bilinear(
-            in_data, shape_data, align_corners=align_corners)
+            shape_data, shape=shape_data.shape, dtype=shape_data.dtype
+        )
+        tf.image.resize_bilinear(in_data, shape_data, align_corners=align_corners)
 
-        compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0')
+        compare_tf_with_tvm(data, "Placeholder:0", "ResizeBilinear:0")
 
 
 def _test_resize_bilinear_from_tensor(in_shape, align_corners):
-    """ One iteration of resize bilinear with non-constant output shape, requires
-        value inference to get proper output shape."""
+    """One iteration of resize bilinear with non-constant output shape, requires
+    value inference to get proper output shape."""
 
-    data = np.random.uniform(size=in_shape).astype('float32')
+    data = np.random.uniform(size=in_shape).astype("float32")
 
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(
-            shape=[in_shape[0], None, None, in_shape[3]], dtype=data.dtype)
+            shape=[in_shape[0], None, None, in_shape[3]], dtype=data.dtype
+        )
         to_shape = tf.shape(in_data)[1:3]
-        tf.image.resize_bilinear(
-            in_data, to_shape, align_corners=align_corners)
+        tf.image.resize_bilinear(in_data, to_shape, align_corners=align_corners)
 
-        compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0')
+        compare_tf_with_tvm(data, "Placeholder:0", "ResizeBilinear:0")
 
 
 def _test_resize_nearest_neighbor(in_shape, to_shape):
     """ One iteration of resize nearest neighbor """
 
-    data = np.random.uniform(size=in_shape).astype('float32')
-    shape_data = np.array(to_shape).astype('int32')
+    data = np.random.uniform(size=in_shape).astype("float32")
+    shape_data = np.array(to_shape).astype("int32")
 
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         shape_data = constant_op.constant(
-            shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
-        tf.image.resize_nearest_neighbor(
-            in_data, shape_data, name='resize_nearest_neighbor')
+            shape_data, shape=shape_data.shape, dtype=shape_data.dtype
+        )
+        tf.image.resize_nearest_neighbor(in_data, shape_data, name="resize_nearest_neighbor")
 
-        compare_tf_with_tvm(data, 'Placeholder:0', 'resize_nearest_neighbor:0')
+        compare_tf_with_tvm(data, "Placeholder:0", "resize_nearest_neighbor:0")
 
 
 def _test_resize_nearest_neighbor_dynamic_shape(in_shape, scale):
     """ One iteration of resize nearest neighbor for graph with dynamic input shape """
 
-    data = np.random.uniform(size=in_shape).astype('float32')
+    data = np.random.uniform(size=in_shape).astype("float32")
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=None, dtype=data.dtype)
         # multiply input shape by scale factor
         new_shape = tf.shape(in_data)[1:3] * tf.constant(scale, dtype=tf.int32)
-        tf.image.resize_nearest_neighbor(
-            in_data, new_shape, name='resize_nearest_neighbor')
+        tf.image.resize_nearest_neighbor(in_data, new_shape, name="resize_nearest_neighbor")
 
-        compare_tf_with_tvm(data, 'Placeholder:0', 'resize_nearest_neighbor:0')
+        compare_tf_with_tvm(data, "Placeholder:0", "resize_nearest_neighbor:0")
 
 
 def test_forward_resize():
@@ -1863,35 +2252,35 @@ def test_forward_resize():
 # BroadcastTo
 # -----------
 
+
 def _test_broadcast_to(in_shape, to_shape):
     """ One iteration of broadcast_to"""
 
-    data = np.random.uniform(size=in_shape).astype('float32')
-    shape_data = np.array(to_shape).astype('int32')
+    data = np.random.uniform(size=in_shape).astype("float32")
+    shape_data = np.array(to_shape).astype("int32")
 
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         shape_data = constant_op.constant(
-            shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
+            shape_data, shape=shape_data.shape, dtype=shape_data.dtype
+        )
         tf.broadcast_to(in_data, shape_data)
 
-        compare_tf_with_tvm(data, 'Placeholder:0',
-                            'BroadcastTo:0', opt_level=0)
+        compare_tf_with_tvm(data, "Placeholder:0", "BroadcastTo:0", opt_level=0)
 
 
 def _test_broadcast_to_from_tensor(in_shape):
     """ One iteration of broadcast_to with unknown shape at graph build"""
 
-    data = np.random.uniform(size=in_shape).astype('float32')
+    data = np.random.uniform(size=in_shape).astype("float32")
 
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(
-            shape=[None], dtype=data.dtype)
+        in_data = array_ops.placeholder(shape=[None], dtype=data.dtype)
 
         shape_data = tf.multiply(tf.shape(in_data), 32)
         tf.broadcast_to(in_data, shape_data)
 
-        compare_tf_with_tvm(data, 'Placeholder:0', 'BroadcastTo:0')
+        compare_tf_with_tvm(data, "Placeholder:0", "BroadcastTo:0")
 
 
 def test_forward_broadcast_to():
@@ -1906,28 +2295,30 @@ def test_forward_broadcast_to():
 # Fill
 # ----
 
+
 def _test_fill(in_shape):
     """ Use the fill op to create a tensor of ones with non-constant shape."""
 
     with tf.Graph().as_default():
-        tf.ones(shape=in_shape, dtype='float32')
-        compare_tf_with_tvm(in_shape, [], 'ones:0', opt_level=1)
+        tf.ones(shape=in_shape, dtype="float32")
+        compare_tf_with_tvm(in_shape, [], "ones:0", opt_level=1)
 
 
 def _test_fill_from_tensor(in_shape):
-    """ Use the fill op to create a tensor of ones with non-constant shape.
-        Some extra ops need to be added here to prevent the graph from
-        being fully constant and folded away."""
+    """Use the fill op to create a tensor of ones with non-constant shape.
+    Some extra ops need to be added here to prevent the graph from
+    being fully constant and folded away."""
 
-    data = np.random.uniform(size=in_shape).astype('float32')
+    data = np.random.uniform(size=in_shape).astype("float32")
 
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(
-            shape=[in_shape[0], in_shape[1], None, None], dtype=data.dtype)
+            shape=[in_shape[0], in_shape[1], None, None], dtype=data.dtype
+        )
 
-        x = tf.ones(shape=2*tf.shape(in_data), dtype=data.dtype)
-        y = tf.math.add(in_data, tf.reduce_mean(x), name='out1')
-        compare_tf_with_tvm(data, 'Placeholder:0', 'out1:0')
+        x = tf.ones(shape=2 * tf.shape(in_data), dtype=data.dtype)
+        y = tf.math.add(in_data, tf.reduce_mean(x), name="out1")
+        compare_tf_with_tvm(data, "Placeholder:0", "out1:0")
 
 
 def _test_fill_symbolic_inputs(in_shape_data, in_value_data, dtype):
@@ -1935,8 +2326,10 @@ def _test_fill_symbolic_inputs(in_shape_data, in_value_data, dtype):
         in_shape = tf.placeholder(shape=[in_shape_data.shape[0]], dtype=in_shape_data.dtype)
         in_value = tf.placeholder(shape=(), dtype=dtype)
         out = tf.fill(in_shape, in_value)
-        for mode in ['debug', 'vm']:
-            compare_tf_with_tvm([in_shape_data, in_value_data], [in_shape.name, in_value.name], out.name, mode=mode)
+        for mode in ["debug", "vm"]:
+            compare_tf_with_tvm(
+                [in_shape_data, in_value_data], [in_shape.name, in_value.name], out.name, mode=mode
+            )
 
 
 def test_forward_fill():
@@ -1949,6 +2342,7 @@ def test_forward_fill():
     _test_fill_symbolic_inputs(np.array((2, 3)), 9, tf.int64)
     _test_fill_symbolic_inputs(np.array((2, 3, 4)), np.float32(9.0), tf.float32)
 
+
 #######################################################################
 # Crop to bounding box
 # --------------------
@@ -1956,12 +2350,11 @@ def test_forward_fill():
 
 def _test_crop(in_shape, off_h, off_w, tar_h, tar_w):
     """ Crop to bounding box """
-    data = np.random.uniform(size=in_shape).astype('float32')
+    data = np.random.uniform(size=in_shape).astype("float32")
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         tf.image.crop_to_bounding_box(in_data, off_h, off_w, tar_h, tar_w)
-        compare_tf_with_tvm(data, 'Placeholder:0',
-                            'crop_to_bounding_box/Slice:0')
+        compare_tf_with_tvm(data, "Placeholder:0", "crop_to_bounding_box/Slice:0")
 
 
 def test_forward_crop():
@@ -1973,54 +2366,72 @@ def test_forward_crop():
 # CropAndResize
 # -------------
 
-def _test_forward_crop_and_resize(img_shape, boxes, box_idx, crop_size,
-                                  extrapolation_value=0.0, method='bilinear', dtype="float32"):
+
+def _test_forward_crop_and_resize(
+    img_shape,
+    boxes,
+    box_idx,
+    crop_size,
+    extrapolation_value=0.0,
+    method="bilinear",
+    dtype="float32",
+):
     image = np.random.uniform(0, 10, size=img_shape).astype(dtype)
     tf.reset_default_graph()
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(dtype, image.shape, name="in_data")
-        tf.image.crop_and_resize(in_data, boxes=boxes, box_ind=box_idx,
-                                 crop_size=crop_size, method=method,
-                                 extrapolation_value=extrapolation_value,
-                                 name="crop_and_resize")
-        compare_tf_with_tvm([image], ['in_data:0'], 'crop_and_resize:0')
+        tf.image.crop_and_resize(
+            in_data,
+            boxes=boxes,
+            box_ind=box_idx,
+            crop_size=crop_size,
+            method=method,
+            extrapolation_value=extrapolation_value,
+            name="crop_and_resize",
+        )
+        compare_tf_with_tvm([image], ["in_data:0"], "crop_and_resize:0")
 
 
 def test_forward_crop_and_resize():
     """ CropAndResize """
     _test_forward_crop_and_resize([1, 6, 6, 3], [[0, 0, 1, 1]], [0], [3, 3])
     _test_forward_crop_and_resize([1, 6, 6, 3], [[0, 0, 1, 1]], [0], [3, 3], 0.2)
-    _test_forward_crop_and_resize([1, 6, 6, 3], [[0, 0, 1, 1]], [0], [3, 3], 0.2, 'nearest')
-    _test_forward_crop_and_resize([1, 11, 11, 3], [[.3, .3,  1,  1]], [0], [21, 21])
-    _test_forward_crop_and_resize([1, 41, 41, 3], [[.2, .4, .8, .8]], [0], [21, 11])
-    _test_forward_crop_and_resize([1, 100, 100, 3], [[ 0,  0, .9, .9]], [0], [30, 30])
-    _test_forward_crop_and_resize([1, 224, 224, 3], [[.1, .2,  1,  1]], [0], [9, 9])
-    _test_forward_crop_and_resize([1, 249, 249, 3], [[ 0,  0,  1,  1]], [0], [9, 9])
-    _test_forward_crop_and_resize([1, 201, 301, 3], [[.2, .3, .7, .8]], [0], [51, 51])
-    _test_forward_crop_and_resize(img_shape=[10, 11, 11, 3],
-                                  boxes=[[ 0,  0, .9, .9],
-                                         [.2, .2, .8, .8]],
-                                  box_idx=[0, 1], crop_size=[5, 5])
-    _test_forward_crop_and_resize(img_shape=[20, 576, 576, 3],
-                                  boxes=[[ 0,  0,  1,  1],
-                                         [ 0,  0, .8, .8],
-                                         [.1, .2, .9,  1],
-                                         [.2,  0,  1,  1]],
-                                  box_idx=[1, 0, 2, 3], crop_size=[24, 24],
-                                  extrapolation_value=0.3)
-    _test_forward_crop_and_resize(img_shape=[20, 229, 229, 3],
-                                  boxes=[[ 0,  0, .9, .9],
-                                         [.3, .3,  1,  1],
-                                         [.2, .1, .7, .8],
-                                         [ 0,  0,  1,  1]],
-                                  box_idx=[3, 0, 2, 1], crop_size=[58, 58],
-                                  extrapolation_value=0.2, method='nearest')
+    _test_forward_crop_and_resize([1, 6, 6, 3], [[0, 0, 1, 1]], [0], [3, 3], 0.2, "nearest")
+    _test_forward_crop_and_resize([1, 11, 11, 3], [[0.3, 0.3, 1, 1]], [0], [21, 21])
+    _test_forward_crop_and_resize([1, 41, 41, 3], [[0.2, 0.4, 0.8, 0.8]], [0], [21, 11])
+    _test_forward_crop_and_resize([1, 100, 100, 3], [[0, 0, 0.9, 0.9]], [0], [30, 30])
+    _test_forward_crop_and_resize([1, 224, 224, 3], [[0.1, 0.2, 1, 1]], [0], [9, 9])
+    _test_forward_crop_and_resize([1, 249, 249, 3], [[0, 0, 1, 1]], [0], [9, 9])
+    _test_forward_crop_and_resize([1, 201, 301, 3], [[0.2, 0.3, 0.7, 0.8]], [0], [51, 51])
+    _test_forward_crop_and_resize(
+        img_shape=[10, 11, 11, 3],
+        boxes=[[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8]],
+        box_idx=[0, 1],
+        crop_size=[5, 5],
+    )
+    _test_forward_crop_and_resize(
+        img_shape=[20, 576, 576, 3],
+        boxes=[[0, 0, 1, 1], [0, 0, 0.8, 0.8], [0.1, 0.2, 0.9, 1], [0.2, 0, 1, 1]],
+        box_idx=[1, 0, 2, 3],
+        crop_size=[24, 24],
+        extrapolation_value=0.3,
+    )
+    _test_forward_crop_and_resize(
+        img_shape=[20, 229, 229, 3],
+        boxes=[[0, 0, 0.9, 0.9], [0.3, 0.3, 1, 1], [0.2, 0.1, 0.7, 0.8], [0, 0, 1, 1]],
+        box_idx=[3, 0, 2, 1],
+        crop_size=[58, 58],
+        extrapolation_value=0.2,
+        method="nearest",
+    )
 
 
 #######################################################################
 # Non Max Suppression
 # -------------------
-def _test_forward_nms_v3(bx_shape, score_shape, iou_threshold, score_threshold, out_size, dtype="float32"):
+def _test_forward_nms_v3(
+    bx_shape, score_shape, iou_threshold, score_threshold, out_size, dtype="float32"
+):
     boxes = np.random.uniform(0, 10, size=bx_shape).astype(dtype)
     scores = np.random.uniform(size=score_shape).astype(dtype)
     max_output_size = np.int32(out_size)
@@ -2028,14 +2439,31 @@ def _test_forward_nms_v3(bx_shape, score_shape, iou_threshold, score_threshold,
     in_data_1 = tf.placeholder(dtype, boxes.shape, name="in_data_1")
     in_data_2 = tf.placeholder(dtype, scores.shape, name="in_data_2")
     in_data_3 = tf.placeholder(tf.int32, name="in_data_3")
-    tf.image.non_max_suppression(boxes=in_data_1, scores=in_data_2, max_output_size=in_data_3,
-                                 iou_threshold=iou_threshold, score_threshold=score_threshold, name="nms")
-    compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'],
-                        'nms/NonMaxSuppressionV3:0', mode='vm')
-    compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'],
-                        'nms/NonMaxSuppressionV3:0', mode='debug')
-
-def _test_forward_nms_v4(bx_shape, score_shape, iou_threshold, score_threshold, out_size, dtype="float32"):
+    tf.image.non_max_suppression(
+        boxes=in_data_1,
+        scores=in_data_2,
+        max_output_size=in_data_3,
+        iou_threshold=iou_threshold,
+        score_threshold=score_threshold,
+        name="nms",
+    )
+    compare_tf_with_tvm(
+        [boxes, scores, max_output_size],
+        ["in_data_1:0", "in_data_2:0", "in_data_3:0"],
+        "nms/NonMaxSuppressionV3:0",
+        mode="vm",
+    )
+    compare_tf_with_tvm(
+        [boxes, scores, max_output_size],
+        ["in_data_1:0", "in_data_2:0", "in_data_3:0"],
+        "nms/NonMaxSuppressionV3:0",
+        mode="debug",
+    )
+
+
+def _test_forward_nms_v4(
+    bx_shape, score_shape, iou_threshold, score_threshold, out_size, dtype="float32"
+):
     boxes = np.random.uniform(0, 10, size=bx_shape).astype(dtype)
     scores = np.random.uniform(size=score_shape).astype(dtype)
     max_output_size = np.int32(out_size)
@@ -2043,15 +2471,31 @@ def _test_forward_nms_v4(bx_shape, score_shape, iou_threshold, score_threshold,
     in_data_1 = tf.placeholder(dtype, boxes.shape, name="in_data_1")
     in_data_2 = tf.placeholder(dtype, scores.shape, name="in_data_2")
     in_data_3 = tf.placeholder(tf.int32, name="in_data_3")
-    indices_padded, num_valid = tf.image.non_max_suppression_padded(boxes=in_data_1, scores=in_data_2, max_output_size=in_data_3,
-                                 iou_threshold=iou_threshold, score_threshold=score_threshold, name="nms", pad_to_max_output_size=True)
-    num_valid = tf.reshape(num_valid,shape=(-1,))
+    indices_padded, num_valid = tf.image.non_max_suppression_padded(
+        boxes=in_data_1,
+        scores=in_data_2,
+        max_output_size=in_data_3,
+        iou_threshold=iou_threshold,
+        score_threshold=score_threshold,
+        name="nms",
+        pad_to_max_output_size=True,
+    )
+    num_valid = tf.reshape(num_valid, shape=(-1,))
     indices_padded = tf.reshape(indices_padded, shape=(-1,))
     tf.slice(indices_padded, tf.constant([0]), num_valid, name="SlicedIndices")
-    compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'],
-                        ['nms/NonMaxSuppressionV4:1', "SlicedIndices:0"], mode='vm')
-    compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'],
-                        ['nms/NonMaxSuppressionV4:1',  "SlicedIndices:0"], mode='debug')
+    compare_tf_with_tvm(
+        [boxes, scores, max_output_size],
+        ["in_data_1:0", "in_data_2:0", "in_data_3:0"],
+        ["nms/NonMaxSuppressionV4:1", "SlicedIndices:0"],
+        mode="vm",
+    )
+    compare_tf_with_tvm(
+        [boxes, scores, max_output_size],
+        ["in_data_1:0", "in_data_2:0", "in_data_3:0"],
+        ["nms/NonMaxSuppressionV4:1", "SlicedIndices:0"],
+        mode="debug",
+    )
+
 
 def test_forward_nms():
     """ NonMaxSuppressionV3,4 """
@@ -2066,44 +2510,50 @@ def test_forward_nms():
 # LSTM
 # ----
 
+
 def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
     """ One iteration of a LSTM cell """
 
     tf.reset_default_graph()
     input_size = num_hidden
-    input_data = np.full((batch_size, input_size), 1., dtype=dtype)
-    in_state_c = np.full(
-        (batch_size, num_hidden), 0.1, dtype=dtype)
-    in_state_h = np.full(
-        (batch_size, num_hidden), 0.1, dtype=dtype)
+    input_data = np.full((batch_size, input_size), 1.0, dtype=dtype)
+    in_state_c = np.full((batch_size, num_hidden), 0.1, dtype=dtype)
+    in_state_h = np.full((batch_size, num_hidden), 0.1, dtype=dtype)
 
     def _get_tensorflow_output():
         with tf.Session() as sess:
             with variable_scope.variable_scope(
-                    "root", initializer=init_ops.constant_initializer(0.5)):
+                "root", initializer=init_ops.constant_initializer(0.5)
+            ):
                 m0 = tf.placeholder(dtype, [batch_size, num_hidden], name="m0")
                 m1 = tf.placeholder(dtype, [batch_size, num_hidden], name="m1")
                 x = tf.placeholder(shape=(batch_size, input_size), dtype=dtype, name="input")
-                g, ((out_m0, out_m1)) = \
-                    tensorflow.contrib.rnn.LSTMBlockCell(num_hidden,
-                                                         forget_bias=forget_bias)(x, (m0, m1))
+                g, ((out_m0, out_m1)) = tensorflow.contrib.rnn.LSTMBlockCell(
+                    num_hidden, forget_bias=forget_bias
+                )(x, (m0, m1))
                 sess.run([variables.global_variables_initializer()])
-                res = sess.run([g, out_m0, out_m1], {
-                    x.name: np.array([[1., 1.]]),
-                    m0.name: in_state_c,
-                    m1.name: in_state_h,
-                })
+                res = sess.run(
+                    [g, out_m0, out_m1],
+                    {
+                        x.name: np.array([[1.0, 1.0]]),
+                        m0.name: in_state_c,
+                        m1.name: in_state_h,
+                    },
+                )
             graph_def = sess.graph.as_graph_def(add_shapes=True)
             final_graph_def = graph_util.convert_variables_to_constants(
-                sess,
-                graph_def,
-                ['root/lstm_cell/LSTMBlockCell'])
+                sess, graph_def, ["root/lstm_cell/LSTMBlockCell"]
+            )
 
             return final_graph_def, res
 
     graph_def, tf_out = _get_tensorflow_output()
-    tvm_output = run_tvm_graph(graph_def, [input_data, in_state_c, in_state_h],
-                               ['root/input', "root/m0", "root/m1"], num_output=7)
+    tvm_output = run_tvm_graph(
+        graph_def,
+        [input_data, in_state_c, in_state_h],
+        ["root/input", "root/m0", "root/m1"],
+        num_output=7,
+    )
     assert isinstance(tvm_output, list)
 
     tvm.testing.assert_allclose(tf_out[0], tvm_output[6], rtol=1e-3, atol=1e-3)
@@ -2111,10 +2561,10 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
 
 
 def test_forward_lstm():
-    '''test LSTM block cell'''
-    if package_version.parse(tf.VERSION) < package_version.parse('2.0.0'):
-        #in 2.0, tf.contrib.rnn.LSTMBlockCell is removed
-        _test_lstm_cell(1, 2, 1, 0.5, 'float32')
+    """test LSTM block cell"""
+    if package_version.parse(tf.VERSION) < package_version.parse("2.0.0"):
+        # In TF 2.0, tf.contrib.rnn.LSTMBlockCell was removed
+        _test_lstm_cell(1, 2, 1, 0.5, "float32")
 
 
 #######################################################################
@@ -2126,12 +2576,12 @@ def _test_pack(axis, shape, **kwargs):
     b = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
 
     with tf.Graph().as_default():
-        tf_a = array_ops.placeholder(shape=shape, dtype='float32', name='pl_a')
-        tf_b = array_ops.placeholder(shape=shape, dtype='float32', name='pl_b')
+        tf_a = array_ops.placeholder(shape=shape, dtype="float32", name="pl_a")
+        tf_b = array_ops.placeholder(shape=shape, dtype="float32", name="pl_b")
         tf_c = tf.stack([tf_a, tf_b], axis=axis, **kwargs)
-        assert tf_c.op.op_def.name == 'Pack', "tf.stack() is expected to produce 'Pack' operation"
+        assert tf_c.op.op_def.name == "Pack", "tf.stack() is expected to produce 'Pack' operation"
 
-        compare_tf_with_tvm([a, b], ['pl_a:0', 'pl_b:0'], 'stack:0')
+        compare_tf_with_tvm([a, b], ["pl_a:0", "pl_b:0"], "stack:0")
 
 
 def test_forward_pack():
@@ -2152,13 +2602,14 @@ def _test_forward_unpack(in_shape, axis, dtype):
     with tf.Graph().as_default():
         in_data = tf.placeholder(dtype, in_shape, name="in_data")
         tf.unstack(in_data, axis=axis, name="Unpack")
-        compare_tf_with_tvm([np_data], ['in_data:0'], 'Unpack:0')
+        compare_tf_with_tvm([np_data], ["in_data:0"], "Unpack:0")
 
 
 def test_forward_unpack():
-    _test_forward_unpack((3,), 0, 'int32')
-    _test_forward_unpack((3,), -1, 'int16')
-    _test_forward_unpack((21, 23, 3), 2, 'float32')
+    _test_forward_unpack((3,), 0, "int32")
+    _test_forward_unpack((3,), -1, "int16")
+    _test_forward_unpack((21, 23, 3), 2, "float32")
+
 
 #######################################################################
 # Range
@@ -2170,13 +2621,14 @@ def test_forward_range():
     tf.reset_default_graph()
     with tf.Graph().as_default():
         tf.range(1, 18, 3, name="range")
-        compare_tf_with_tvm([], [], 'range:0')
+        compare_tf_with_tvm([], [], "range:0")
 
     """test type assignment for operator Range"""
     tf.reset_default_graph()
     with tf.Graph().as_default():
         tf.range(1, 256 + 1, 1, dtype=tf.float32)
-        compare_tf_with_tvm([], [], 'range:0')
+        compare_tf_with_tvm([], [], "range:0")
+
 
 #######################################################################
 # Pad
@@ -2189,19 +2641,19 @@ def _test_pad(input_shape, paddings, mode, **kwargs):
     x = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape)
 
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=input_shape, dtype='float32')
+        in_data = array_ops.placeholder(shape=input_shape, dtype="float32")
         pad_values = constant_op.constant(paddings)
         pad = tf.pad(in_data, paddings=pad_values, mode=mode, **kwargs)
 
-        if mode == 'CONSTANT':
-            if 'constant_values' in kwargs:
-                out_name = 'PadV2:0'
+        if mode == "CONSTANT":
+            if "constant_values" in kwargs:
+                out_name = "PadV2:0"
             else:
-                out_name = 'Pad:0'
+                out_name = "Pad:0"
         else:
-            out_name = 'MirrorPad:0'
+            out_name = "MirrorPad:0"
 
-        compare_tf_with_tvm(x, 'Placeholder:0', out_name)
+        compare_tf_with_tvm(x, "Placeholder:0", out_name)
 
 
 def test_forward_pad():
@@ -2211,6 +2663,7 @@ def test_forward_pad():
     _test_pad((2, 3), [[1, 1], [2, 2]], mode="SYMMETRIC")
     _test_pad((2, 3), [[1, 1], [2, 2]], mode="REFLECT")
 
+
 #######################################################################
 # Logical operators
 # --------------------
@@ -2218,47 +2671,40 @@ def test_forward_pad():
 
 def test_logical_and():
     with tf.Graph().as_default():
-        in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
-        in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
-        out = tf.logical_and(in1, in2, name='out')
-        in_data1 = np.random.choice(
-            a=[False, True], size=(1, 4, 4, 3)).astype('bool')
-        in_data2 = np.random.choice(
-            a=[False, True], size=(1, 4, 4, 3)).astype('bool')
-        compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
+        in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name="in1")
+        in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name="in2")
+        out = tf.logical_and(in1, in2, name="out")
+        in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype("bool")
+        in_data2 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype("bool")
+        compare_tf_with_tvm([in_data1, in_data2], ["in1:0", "in2:0"], "out:0")
 
 
 def test_logical_or():
     with tf.Graph().as_default():
-        in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
-        in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
-        out = tf.logical_or(in1, in2, name='out')
-        in_data1 = np.random.choice(
-            a=[False, True], size=(1, 4, 4, 3)).astype('bool')
-        in_data2 = np.random.choice(
-            a=[False, True], size=(1, 4, 4, 3)).astype('bool')
-        compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
+        in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name="in1")
+        in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name="in2")
+        out = tf.logical_or(in1, in2, name="out")
+        in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype("bool")
+        in_data2 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype("bool")
+        compare_tf_with_tvm([in_data1, in_data2], ["in1:0", "in2:0"], "out:0")
 
 
 def test_logical_xor():
     with tf.Graph().as_default():
-        in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
-        in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
-        out = tf.logical_xor(in1, in2, name='out')
-        in_data1 = np.random.choice(
-            a=[False, True], size=(1, 4, 4, 3)).astype('bool')
-        in_data2 = np.random.choice(
-            a=[False, True], size=(1, 4, 4, 3)).astype('bool')
-        compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
+        in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name="in1")
+        in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name="in2")
+        out = tf.logical_xor(in1, in2, name="out")
+        in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype("bool")
+        in_data2 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype("bool")
+        compare_tf_with_tvm([in_data1, in_data2], ["in1:0", "in2:0"], "out:0")
 
 
 def test_logical_not():
     with tf.Graph().as_default():
-        in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
-        out = tf.logical_not(in1, name='out')
-        in_data1 = np.random.choice(
-            a=[False, True], size=(1, 4, 4, 3)).astype('bool')
-        compare_tf_with_tvm(in_data1, 'in1:0', 'out:0')
+        in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name="in1")
+        out = tf.logical_not(in1, name="out")
+        in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype("bool")
+        compare_tf_with_tvm(in_data1, "in1:0", "out:0")
 
 
 def test_forward_logical():
@@ -2272,42 +2718,37 @@ def test_forward_logical():
 # Where, Select
 # -------------
 def test_forward_where():
-    ''' Where: return elements depending on conditions'''
+    """ Where: return elements depending on conditions"""
     with tf.Graph().as_default():
         with tf.Session() as sess:
-            input1 = tf.placeholder(
-                tf.int32, shape=[1, 4, 4, 3], name='input1')
-            input2 = tf.placeholder(
-                tf.int32, shape=[1, 4, 4, 3], name='input2')
+            input1 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name="input1")
+            input2 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name="input2")
             mask = input1 > input2
             tf.where(mask, input1 + 1, input2 * 2)
-            in_data1 = np.random.uniform(
-                0, 10, size=(1, 4, 4, 3)).astype("uint32")
-            in_data2 = np.random.uniform(
-                0, 10, size=(1, 4, 4, 3)).astype("uint32")
-            compare_tf_with_tvm([in_data1, in_data2], [
-                                'input1:0', 'input2:0'], 'Select:0')
+            in_data1 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("uint32")
+            in_data2 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("uint32")
+            compare_tf_with_tvm([in_data1, in_data2], ["input1:0", "input2:0"], "Select:0")
 
 
 #######################################################################
 # Inception V3
 # ------------
 def test_forward_inception_v3():
-    '''test inception V3 model'''
+    """test inception V3 model"""
     with tf.Graph().as_default():
         graph_def = tf_testing.get_workload(
-            'InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb')
+            "InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb"
+        )
         # Call the utility to import the graph definition into default graph.
         graph_def = tf_testing.ProcessGraphDefParam(graph_def)
 
-        data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
+        data = np.random.uniform(size=(1, 299, 299, 3)).astype("float32")
 
         with tf.Session() as sess:
-            tf_output = run_tf_graph(
-                sess, data, 'input:0', 'InceptionV3/Predictions/Reshape_1:0')
-            tvm_output = run_tvm_graph(graph_def, data, 'input')
-            tvm.testing.assert_allclose(
-                tf_output[0], tvm_output[0], rtol=1e-5, atol=1e-5)
+            tf_output = run_tf_graph(sess, data, "input:0", "InceptionV3/Predictions/Reshape_1:0")
+            tvm_output = run_tvm_graph(graph_def, data, "input")
+            tvm.testing.assert_allclose(tf_output[0], tvm_output[0], rtol=1e-5, atol=1e-5)
+
 
 #######################################################################
 # Inception V1
@@ -2315,10 +2756,9 @@ def test_forward_inception_v3():
 
 
 def test_forward_inception_v1():
-    '''test inception V1 model'''
+    """test inception V1 model"""
     with tf.Graph().as_default():
-        graph_def = tf_testing.get_workload(
-            "InceptionV1/classify_image_graph_def-with_shapes.pb")
+        graph_def = tf_testing.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb")
         # Call the utility to import the graph definition into default graph.
         graph_def = tf_testing.ProcessGraphDefParam(graph_def)
 
@@ -2327,31 +2767,28 @@ def test_forward_inception_v1():
         from tvm.contrib import util
 
         img_array = np.random.uniform(size=(1, 600, 600, 3)).astype("uint8")
-        img = Image.frombuffer(
-            'RGB', (600, 600), img_array.tostring(), 'raw', 'RGB', 0, 1)
+        img = Image.frombuffer("RGB", (600, 600), img_array.tostring(), "raw", "RGB", 0, 1)
         temp = util.tempdir()
         img_path = temp.relpath("tf-test.jpg")
         img.save(img_path)
 
         import os.path
+
         if not tf.gfile.Exists(os.path.join(img_path)):
-            tf.logging.fatal('File does not exist %s', img_path)
-        data = tf.gfile.FastGFile(os.path.join(img_path), 'rb').read()
+            tf.logging.fatal("File does not exist %s", img_path)
+        data = tf.gfile.FastGFile(os.path.join(img_path), "rb").read()
 
         temp.remove()
 
         # Extract tensorflow decoded image frame for tvm input
         with tf.Session() as sess:
-            tvm_data = run_tf_graph(
-                sess, data, 'DecodeJpeg/contents:0', 'DecodeJpeg:0')
+            tvm_data = run_tf_graph(sess, data, "DecodeJpeg/contents:0", "DecodeJpeg:0")
 
         with tf.Session() as sess:
-            tf_output = run_tf_graph(
-                sess, data, 'DecodeJpeg/contents:0', 'softmax:0')
-            tvm_output = run_tvm_graph(
-                graph_def, tvm_data, 'DecodeJpeg/contents')
-            tvm.testing.assert_allclose(
-                tf_output[0], tvm_output[0], rtol=1e-5, atol=1e-5)
+            tf_output = run_tf_graph(sess, data, "DecodeJpeg/contents:0", "softmax:0")
+            tvm_output = run_tvm_graph(graph_def, tvm_data, "DecodeJpeg/contents")
+            tvm.testing.assert_allclose(tf_output[0], tvm_output[0], rtol=1e-5, atol=1e-5)
+
 
 #######################################################################
 # Mobilenet
@@ -2359,25 +2796,28 @@ def test_forward_inception_v1():
 
 
 def test_forward_mobilenet():
-    '''test mobilenet model'''
+    """test mobilenet model"""
     # MobilenetV2
     with tf.Graph().as_default():
         graph_def = tf_testing.get_workload(
             "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz",
-            "mobilenet_v2_1.4_224_frozen.pb")
+            "mobilenet_v2_1.4_224_frozen.pb",
+        )
         # Call the utility to import the graph definition into default graph.
         graph_def = tf_testing.ProcessGraphDefParam(graph_def)
 
-        data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
-        out_node = 'MobilenetV2/Predictions/Reshape_1'
+        data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32")
+        out_node = "MobilenetV2/Predictions/Reshape_1"
 
         with tf.Session() as sess:
             # Add shapes to the graph.
             graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
-            tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')
-            tvm_output = run_tvm_graph(graph_def, data, 'input')
-            tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]),
-                                        rtol=1e-5, atol=1e-5)
+            tf_output = run_tf_graph(sess, data, "input:0", out_node + ":0")
+            tvm_output = run_tvm_graph(graph_def, data, "input")
+            tvm.testing.assert_allclose(
+                np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5
+            )
+
 
 #######################################################################
 # ResnetV2
@@ -2386,29 +2826,32 @@ def test_forward_mobilenet():
 
 @tvm.testing.requires_gpu
 def test_forward_resnetv2():
-    '''test resnet model'''
+    """test resnet model"""
     if is_gpu_available():
         with tf.Graph().as_default():
             graph_def = tf_testing.get_workload(
-                "ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb")
+                "ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb"
+            )
             # Call the utility to import the graph definition into default graph.
             graph_def = tf_testing.ProcessGraphDefParam(graph_def)
 
-            data = np.random.uniform(size=(128, 224, 224, 3)).astype('float32')
-            out_node = 'ArgMax'
+            data = np.random.uniform(size=(128, 224, 224, 3)).astype("float32")
+            out_node = "ArgMax"
 
             with tf.Session() as sess:
-                tf_output = run_tf_graph(
-                    sess, data, 'input_tensor:0', out_node + ':0')
+                tf_output = run_tf_graph(sess, data, "input_tensor:0", out_node + ":0")
                 for device in ["llvm", "cuda"]:
                     ctx = tvm.context(device, 0)
                     if not tvm.testing.device_enabled(device):
                         print("Skip because %s is not enabled" % device)
                         continue
-                    tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', len(tf_output),
-                                               target=device)
-                    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]),
-                                                rtol=1e-5, atol=1e-5)
+                    tvm_output = run_tvm_graph(
+                        graph_def, data, "input_tensor", len(tf_output), target=device
+                    )
+                    tvm.testing.assert_allclose(
+                        np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5
+                    )
+
 
 #######################################################################
 # SSD
@@ -2416,33 +2859,43 @@ def test_forward_resnetv2():
 
 
 def _test_ssd_impl():
-    '''Test SSD with backbone MobileNet V1'''
+    """Test SSD with backbone MobileNet V1"""
     with tf.Graph().as_default():
         graph_def = tf_testing.get_workload(
             "object_detection/ssd_mobilenet_v1_ppn_shared_"
-            "box_predictor_300x300_coco14_sync_2018_07_03.pb")
+            "box_predictor_300x300_coco14_sync_2018_07_03.pb"
+        )
         # Call the utility to import the graph definition into default graph.
         graph_def = tf_testing.ProcessGraphDefParam(graph_def)
 
-        data = np.random.uniform(0.0, 255.0, size=(1, 512, 512, 3)).astype('uint8')
+        data = np.random.uniform(0.0, 255.0, size=(1, 512, 512, 3)).astype("uint8")
         in_node = "image_tensor"
-        out_node = ['detection_boxes', "detection_scores", "detection_classes"]
+        out_node = ["detection_boxes", "detection_scores", "detection_classes"]
 
         with tf.Session() as sess:
             tf_output = run_tf_graph(
-                sess, data, '{}:0'.format(in_node), ["{}:0".format(oname) for oname in out_node])
+                sess, data, "{}:0".format(in_node), ["{}:0".format(oname) for oname in out_node]
+            )
             # TODO(kevinthesun): enable gpu test when VM heterogeneous execution is ready.
             for device in ["llvm"]:
                 ctx = tvm.context(device, 0)
                 if not tvm.testing.device_enabled(device):
                     print("Skip because %s is not enabled" % device)
                     continue
-                tvm_output = run_tvm_graph(graph_def, data, in_node, len(out_node),
-                                           target=device, layout="NCHW", out_names=out_node,
-                                           mode="vm", disabled_pass=["FoldScaleAxis"])
+                tvm_output = run_tvm_graph(
+                    graph_def,
+                    data,
+                    in_node,
+                    len(out_node),
+                    target=device,
+                    layout="NCHW",
+                    out_names=out_node,
+                    mode="vm",
+                    disabled_pass=["FoldScaleAxis"],
+                )
                 for i in range(len(out_node)):
-                    tvm.testing.assert_allclose(tvm_output[i], tf_output[i],
-                                                rtol=1e-3, atol=1e-3)
+                    tvm.testing.assert_allclose(tvm_output[i], tf_output[i], rtol=1e-3, atol=1e-3)
+
 
 def test_forward_ssd():
     run_thread = threading.Thread(target=_test_ssd_impl, args=())
@@ -2458,36 +2911,37 @@ def test_forward_ssd():
 
 
 def test_forward_placeholder():
-    '''test a simple pb with Placeholder node in the end of GraphDef'''
+    """test a simple pb with Placeholder node in the end of GraphDef"""
     with tf.Graph().as_default():
         graph_def = tf_testing.get_workload("Custom/placeholder.pb")
         # Call the utility to import the graph definition into default graph.
         graph_def = tf_testing.ProcessGraphDefParam(graph_def)
 
-        data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
-        out_node = 'mul'
+        data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32")
+        out_node = "mul"
 
         with tf.Session() as sess:
             # Add shapes to the graph.
             graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
-            tf_output = run_tf_graph(
-                sess, data, 'Placeholder:0', out_node + ':0')
-            tvm_output = run_tvm_graph(graph_def, data, 'Placeholder')
-            tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]),
-                                        rtol=1e-5, atol=1e-5)
+            tf_output = run_tf_graph(sess, data, "Placeholder:0", out_node + ":0")
+            tvm_output = run_tvm_graph(graph_def, data, "Placeholder")
+            tvm.testing.assert_allclose(
+                np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5
+            )
 
 
 #######################################################################
 # PTB
 # ---
 try:
-    #Load contrib for running ptb model in tf version before 2.0
+    # Load contrib for running the PTB model on TF versions before 2.0
     import tensorflow.contrib
 except:
     pass
 
+
 def test_forward_ptb():
-    '''test ptb model'''
+    """test ptb model"""
     config = tf_testing.get_config()
     num_steps = config.num_steps
     num_hidden = config.hidden_size
@@ -2502,29 +2956,31 @@ def test_forward_ptb():
 
     def _pretty_print(items, is_char_model, id2word):
         if not is_char_model:
-            return ' '.join([id2word[x] for x in items])
+            return " ".join([id2word[x] for x in items])
         else:
-            return ''.join([id2word[x] for x in items]).replace('_', ' ')
+            return "".join([id2word[x] for x in items]).replace("_", " ")
 
     def _get_tvm_graph_module(graph_def):
         # Cell inputs 'c' and 'h' consist of all layers' values
-        shape_dict = {'Model/Placeholder': (batch_size, num_steps)}
+        shape_dict = {"Model/Placeholder": (batch_size, num_steps)}
 
         mod, params = relay.frontend.from_tensorflow(
-            graph_def, shape=shape_dict,
-            outputs=['Model/Softmax:0',
-                     'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:1',
-                     'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:6',
-                     'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:1',
-                     'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:6',
-                    ])
-
-        target = 'llvm'
+            graph_def,
+            shape=shape_dict,
+            outputs=[
+                "Model/Softmax:0",
+                "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:1",
+                "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:6",
+                "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:1",
+                "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:6",
+            ],
+        )
+
+        target = "llvm"
         with tvm.transform.PassContext(opt_level=0):
-            graph, lib, params = relay.build(mod,
-                                             target,
-                                             params=params)
+            graph, lib, params = relay.build(mod, target, params=params)
         from tvm.contrib import graph_runtime
+
         ctx = tvm.cpu(0)
         return params, graph_runtime.create(graph, lib, ctx)
 
@@ -2537,25 +2993,32 @@ def test_forward_ptb():
         def _get_sample(data, state):
             input_data = np.full((batch_size, num_steps), data, dtype="int32")
 
-            model.set_input('Model/Placeholder',
-                            tvm.nd.array(input_data.astype("int32")))
-            model.set_input('Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros',
-                            tvm.nd.array(state[0].astype("float32")))
-            model.set_input('Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros_1',
-                            tvm.nd.array(state[1].astype("float32")))
-            model.set_input('Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros',
-                            tvm.nd.array(state[2].astype("float32")))
-            model.set_input('Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros_1',
-                            tvm.nd.array(state[3].astype("float32")))
+            model.set_input("Model/Placeholder", tvm.nd.array(input_data.astype("int32")))
+            model.set_input(
+                "Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros",
+                tvm.nd.array(state[0].astype("float32")),
+            )
+            model.set_input(
+                "Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros_1",
+                tvm.nd.array(state[1].astype("float32")),
+            )
+            model.set_input(
+                "Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros",
+                tvm.nd.array(state[2].astype("float32")),
+            )
+            model.set_input(
+                "Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros_1",
+                tvm.nd.array(state[3].astype("float32")),
+            )
             model.set_input(**params)
             model.run()
-            tvm_output = model.get_output(0, tvm.nd.empty(out_sample_shape,
-                                                          "float32")).asnumpy()
+            tvm_output = model.get_output(0, tvm.nd.empty(out_sample_shape, "float32")).asnumpy()
 
             state_output = []
             for i in range(4):
-                state_output.append(model.get_output(i+1, tvm.nd.empty(out_state_shape,
-                                                            "float32")).asnumpy())
+                state_output.append(
+                    model.get_output(i + 1, tvm.nd.empty(out_state_shape, "float32")).asnumpy()
+                )
             sample = tf_testing.pick_from_weight(tvm_output[0])
 
             return sample, state_output
@@ -2591,20 +3054,19 @@ def test_forward_ptb():
         cnt_stm += 1
         in_state = [np.full((batch_size, num_hidden), 0, dtype="float32")] * 2 * num_layers
         seed_for_sample = inpt.split()
-        tvm_samples, tvm_state = _do_tvm_sample(m, [word_to_id[word]
-                                                    for word in seed_for_sample],
-                                                in_state, params, cnt_sample)
+        tvm_samples, tvm_state = _do_tvm_sample(
+            m, [word_to_id[word] for word in seed_for_sample], in_state, params, cnt_sample
+        )
         tvm_sample_str = _pretty_print(tvm_samples, False, id_to_word)
         tf_samples, tf_state = tf_testing.do_tf_sample(
-            sess,
-            [word_to_id[word] for word in seed_for_sample],
-            in_state, cnt_sample)
+            sess, [word_to_id[word] for word in seed_for_sample], in_state, cnt_sample
+        )
         tf_sample_str = _pretty_print(tf_samples, False, id_to_word)
         inpt = tvm_sample_str
-        tvm.testing.assert_allclose(
-            tf_samples, tvm_samples, rtol=1e-5, atol=1e-5)
+        tvm.testing.assert_allclose(tf_samples, tvm_samples, rtol=1e-5, atol=1e-5)
         assert tvm_sample_str == tf_sample_str
 
+
 #######################################################################
 # LRN (Local Response Normalization)
 # ----------------------------------
@@ -2617,21 +3079,18 @@ def _test_lrn(ishape, size, axis, bias, alpha, beta):
     inp_array = np.random.uniform(size=ishape).astype(np.float32)
 
     with tf.Graph().as_default():
-        in1 = tf.placeholder(shape=inp_array.shape,
-                             dtype=inp_array.dtype, name="lrn0_data")
-        nn_ops.local_response_normalization(in1,
-                                            name="lrn",
-                                            depth_radius=lrn_depth_radius,
-                                            bias=bias,
-                                            alpha=alpha,
-                                            beta=beta)
+        in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype, name="lrn0_data")
+        nn_ops.local_response_normalization(
+            in1, name="lrn", depth_radius=lrn_depth_radius, bias=bias, alpha=alpha, beta=beta
+        )
 
-        compare_tf_with_tvm(inp_array, 'lrn0_data:0', 'lrn:0')
+        compare_tf_with_tvm(inp_array, "lrn0_data:0", "lrn:0")
 
 
 def test_forward_lrn():
     _test_lrn((1, 3, 20, 20), 3, 1, 1.0, 1.0, 0.5)
 
+
 #######################################################################
 # l2_normalize
 # ------------
@@ -2644,18 +3103,15 @@ def _test_l2_normalize(ishape, eps, axis):
 
     with tf.Graph().as_default():
         in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
-        nn.l2_normalize(in1,
-                        axis=axis,
-                        epsilon=eps,
-                        name=None,
-                        dim=None)
+        nn.l2_normalize(in1, axis=axis, epsilon=eps, name=None, dim=None)
 
-        compare_tf_with_tvm(inp_array, 'Placeholder:0', 'l2_normalize:0')
+        compare_tf_with_tvm(inp_array, "Placeholder:0", "l2_normalize:0")
 
 
 def test_forward_l2_normalize():
     _test_l2_normalize((1, 3, 20, 20), 0.001, (0,))
 
+
 #######################################################################
 # transpose
 # ---------
@@ -2665,23 +3121,22 @@ def _test_forward_transpose(ishape, axes=None):
     data = np.random.uniform(size=ishape).astype(np.float32)
 
     with tf.Graph().as_default():
-        in1 = tf.placeholder(
-            shape=data.shape, dtype=data.dtype, name="transpose_data")
+        in1 = tf.placeholder(shape=data.shape, dtype=data.dtype, name="transpose_data")
 
         if axes is None:
             tf.transpose(in1)
         else:
             tf.transpose(in1, perm=axes)
 
-        compare_tf_with_tvm(data, 'transpose_data:0', 'transpose:0')
+        compare_tf_with_tvm(data, "transpose_data:0", "transpose:0")
+
 
 def _test_forward_tranapose_axes_input(ishape, axes):
     data = np.random.uniform(size=ishape).astype(np.float32)
     axes_np = np.array(axes).astype(np.int32)
 
     with tf.Graph().as_default():
-        in1 = tf.placeholder(
-            shape=data.shape, dtype=data.dtype, name="transpose_data")
+        in1 = tf.placeholder(shape=data.shape, dtype=data.dtype, name="transpose_data")
 
         const1 = tf.constant(axes_np, dtype=tf.int32)
 
@@ -2690,7 +3145,8 @@ def _test_forward_tranapose_axes_input(ishape, axes):
         axes = tf.reverse(const1, axis=[-1])
         tf.transpose(in1, axes)
 
-        compare_tf_with_tvm([data], ['transpose_data:0'], 'transpose:0')
+        compare_tf_with_tvm([data], ["transpose_data:0"], "transpose:0")
+
 
 def test_forward_transpose():
     _test_forward_transpose((2, 3, 4), (1, 2, 0))
@@ -2706,17 +3162,18 @@ def test_forward_transpose():
 def _test_forward_slice_operation_input(input_value, begin_value, size_value):
     input_data = np.array(input_value, dtype=np.float32)
     with tf.Graph().as_default():
-        input_tensor = tf.placeholder(
-            shape=input_data.shape, dtype=input_data.dtype, name="input")
-        tf.slice(input_tensor, begin_value, size_value, name='slice_output')
-        compare_tf_with_tvm([input_data], ['input:0'], 'slice_output:0')
+        input_tensor = tf.placeholder(shape=input_data.shape, dtype=input_data.dtype, name="input")
+        tf.slice(input_tensor, begin_value, size_value, name="slice_output")
+        compare_tf_with_tvm([input_data], ["input:0"], "slice_output:0")
 
 
 def test_forward_slice():
     _test_forward_slice_operation_input([1, 1], [0], [2])
     _test_forward_slice_operation_input([0, 1, 2, 3], [3], [-1])
-    _test_forward_slice_operation_input([[0, 1, 2, 3], [4, 5, 6, 7]],
-                                        begin_value=[0, 1], size_value=[-1, -1])
+    _test_forward_slice_operation_input(
+        [[0, 1, 2, 3], [4, 5, 6, 7]], begin_value=[0, 1], size_value=[-1, -1]
+    )
+
 
 def test_forward_ceil():
     ishape = (1, 3, 10, 10)
@@ -2724,7 +3181,7 @@ def test_forward_ceil():
     with tf.Graph().as_default():
         in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
         tf.ceil(in1)
-        compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Ceil:0')
+        compare_tf_with_tvm(inp_array, "Placeholder:0", "Ceil:0")
 
 
 def test_forward_floor():
@@ -2733,26 +3190,27 @@ def test_forward_floor():
     with tf.Graph().as_default():
         in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
         tf.floor(in1)
-        compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Floor:0')
+        compare_tf_with_tvm(inp_array, "Placeholder:0", "Floor:0")
 
 
 def test_forward_relu():
     ishape = (1, 3, 10, 10)
     inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
-    for mode in ['graph_runtime', 'vm']:
+    for mode in ["graph_runtime", "vm"]:
         with tf.Graph().as_default():
             in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
             tf.nn.relu(in1)
-            compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Relu:0', mode=mode)
+            compare_tf_with_tvm(inp_array, "Placeholder:0", "Relu:0", mode=mode)
+
 
 def test_forward_leaky_relu():
     ishape = (1, 3, 10, 10)
     inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
-    for mode in ['graph_runtime', 'vm']:
+    for mode in ["graph_runtime", "vm"]:
         with tf.Graph().as_default():
             in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
             tf.nn.leaky_relu(in1, alpha=0.4)
-            compare_tf_with_tvm(inp_array, 'Placeholder:0', 'LeakyRelu:0', mode=mode)
+            compare_tf_with_tvm(inp_array, "Placeholder:0", "LeakyRelu:0", mode=mode)
 
 
 def test_forward_elu():
@@ -2761,7 +3219,7 @@ def test_forward_elu():
     with tf.Graph().as_default():
         in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
         tf.nn.elu(in1)
-        compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Elu:0')
+        compare_tf_with_tvm(inp_array, "Placeholder:0", "Elu:0")
 
 
 def test_forward_selu():
@@ -2770,7 +3228,7 @@ def test_forward_selu():
     with tf.Graph().as_default():
         in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
         tf.nn.selu(in1)
-        compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Selu:0')
+        compare_tf_with_tvm(inp_array, "Placeholder:0", "Selu:0")
 
 
 def test_forward_tanh():
@@ -2779,7 +3237,7 @@ def test_forward_tanh():
     with tf.Graph().as_default():
         in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
         tf.nn.tanh(in1)
-        compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Tanh:0')
+        compare_tf_with_tvm(inp_array, "Placeholder:0", "Tanh:0")
 
 
 #######################################################################
@@ -2787,13 +3245,15 @@ def test_forward_tanh():
 # -------
 def test_forward_softmax():
     """test operator Softmax """
+
     def check_softmax(in_shape, axis, dtype):
         np_data = np.random.uniform(-100, 100, size=in_shape).astype(dtype)
         tf.reset_default_graph()
         with tf.Graph().as_default():
             in_data = tf.placeholder(dtype, in_shape, name="in_data")
             tf.nn.softmax(in_data, axis=axis, name="Softmax")
-            compare_tf_with_tvm([np_data], ['in_data:0'], 'Softmax:0')
+            compare_tf_with_tvm([np_data], ["in_data:0"], "Softmax:0")
+
     check_softmax((2, 3, 5), 2, "float32")
     check_softmax((2, 3, 5), -1, "float32")
 
@@ -2802,6 +3262,7 @@ def test_forward_softmax():
 # Tensor
 # ------
 
+
 def test_forward_round():
     """test Round"""
     np_data = np.random.uniform(-10, 10, size=(5, 7)).astype(np.float32)
@@ -2809,7 +3270,7 @@ def test_forward_round():
     with tf.Graph().as_default():
         in_data = tf.placeholder(tf.float32, (5, 7), name="in_data")
         tf.round(in_data, name="round")
-        compare_tf_with_tvm([np_data], ['in_data:0'], 'round:0')
+        compare_tf_with_tvm([np_data], ["in_data:0"], "round:0")
 
 
 def test_forward_abs():
@@ -2819,7 +3280,7 @@ def test_forward_abs():
     with tf.Graph().as_default():
         in_data = tf.placeholder(tf.float32, (9, 11), name="in_data")
         tf.math.abs(in_data, name="abs")
-        compare_tf_with_tvm([np_data], ['in_data:0'], 'abs:0')
+        compare_tf_with_tvm([np_data], ["in_data:0"], "abs:0")
 
 
 def _test_forward_zeros_like(in_shape, dtype):
@@ -2828,11 +3289,11 @@ def _test_forward_zeros_like(in_shape, dtype):
     with tf.Graph().as_default():
         in_data = tf.placeholder(dtype, in_shape, name="in_data")
         tf.zeros_like(in_data, name="zeros_like")
-        compare_tf_with_tvm([np_data], ['in_data:0'], 'zeros_like:0')
+        compare_tf_with_tvm([np_data], ["in_data:0"], "zeros_like:0")
 
 
 def test_forward_zeros_like():
-    if tf.__version__ < LooseVersion('1.2'):
+    if tf.__version__ < LooseVersion("1.2"):
         _test_forward_zeros_like((2, 3), "int32")
         _test_forward_zeros_like((2, 3, 5), "int8")
         _test_forward_zeros_like((2, 3, 5, 7), "uint16")
@@ -2845,13 +3306,10 @@ def test_forward_squared_difference():
     inp_array_a = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
     inp_array_b = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
     with tf.Graph().as_default():
-        in1 = tf.placeholder(shape=inp_array_a.shape,
-                             dtype=inp_array_a.dtype, name="in1")
-        in2 = tf.placeholder(shape=inp_array_b.shape,
-                             dtype=inp_array_b.dtype, name="in2")
+        in1 = tf.placeholder(shape=inp_array_a.shape, dtype=inp_array_a.dtype, name="in1")
+        in2 = tf.placeholder(shape=inp_array_b.shape, dtype=inp_array_b.dtype, name="in2")
         out = tf.math.squared_difference(in1, in2)
-        compare_tf_with_tvm([inp_array_a, inp_array_b], [
-                            in1.name, in2.name], out.name)
+        compare_tf_with_tvm([inp_array_a, inp_array_b], [in1.name, in2.name], out.name)
 
 
 def _test_forward_reverse_v2(in_shape, axis, dtype):
@@ -2860,7 +3318,7 @@ def _test_forward_reverse_v2(in_shape, axis, dtype):
     with tf.Graph().as_default():
         in_data = tf.placeholder(dtype, in_shape, name="in_data")
         tf.reverse(in_data, axis=[axis], name="reverse")
-        compare_tf_with_tvm([np_data], ['in_data:0'], 'reverse:0')
+        compare_tf_with_tvm([np_data], ["in_data:0"], "reverse:0")
 
 
 def test_forward_reverse_v2():
@@ -2879,7 +3337,7 @@ def test_forward_sign():
     with tf.Graph().as_default():
         in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data")
         tf.sign(in_data, name="sign")
-        compare_tf_with_tvm([np_data], ['in_data:0'], 'sign:0')
+        compare_tf_with_tvm([np_data], ["in_data:0"], "sign:0")
 
 
 def test_forward_square():
@@ -2889,7 +3347,7 @@ def test_forward_square():
     with tf.Graph().as_default():
         in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
         tf.square(in_data, name="square")
-        compare_tf_with_tvm([np_data], ['in_data:0'], 'square:0')
+        compare_tf_with_tvm([np_data], ["in_data:0"], "square:0")
 
 
 def test_forward_pow_exp():
@@ -2901,9 +3359,9 @@ def test_forward_pow_exp():
         in1 = tf.placeholder(tf.float32, (5, 7, 11), name="in1")
         in2 = tf.placeholder(tf.float32, (5, 7, 11), name="in2")
         out1 = tf.pow(in1, in2, name="pow")
-        out = tf.exp(in1, name='exp')
-        compare_tf_with_tvm([np_in1, np_in2], ['in1:0', 'in2:0'], 'pow:0')
-        compare_tf_with_tvm([np_in1], ['in1:0'], 'exp:0')
+        out = tf.exp(in1, name="exp")
+        compare_tf_with_tvm([np_in1, np_in2], ["in1:0", "in2:0"], "pow:0")
+        compare_tf_with_tvm([np_in1], ["in1:0"], "exp:0")
 
 
 def test_forward_unary():
@@ -2914,7 +3372,7 @@ def test_forward_unary():
         with tf.Graph().as_default():
             in_data = tf.placeholder(dtype, (2, 3, 5), name="in_data")
             out = op(in_data)
-            compare_tf_with_tvm([np_data], ['in_data:0'], out.name)
+            compare_tf_with_tvm([np_data], ["in_data:0"], out.name)
 
     _test_forward_unary(tf.acos, -1, 1)
     _test_forward_unary(tf.asin, -1, 1)
@@ -2942,18 +3400,17 @@ def test_forward_atan2():
     in_data_1 = tf.placeholder(tf.float32, (2, 3, 5), name="in_data_1")
     in_data_2 = tf.placeholder(tf.float32, (2, 3, 5), name="in_data_2")
     tf.atan2(in_data_1, in_data_2, name="atan2")
-    compare_tf_with_tvm([np_data_1, np_data_2], ['in_data_1:0', 'in_data_2:0'], 'atan2:0')
+    compare_tf_with_tvm([np_data_1, np_data_2], ["in_data_1:0", "in_data_2:0"], "atan2:0")
 
 
 def test_forward_negative():
     """test tf operator Neg """
-    np_data = np.random.uniform(-100, 255,
-                                size=(224, 224, 3)).astype(np.float32)
+    np_data = np.random.uniform(-100, 255, size=(224, 224, 3)).astype(np.float32)
     tf.reset_default_graph()
     with tf.Graph().as_default():
         in_data = tf.placeholder(tf.float32, (224, 224, 3), name="in_data")
         tf.negative(in_data, name="negative")
-        compare_tf_with_tvm([np_data], ['in_data:0'], 'negative:0')
+        compare_tf_with_tvm([np_data], ["in_data:0"], "negative:0")
 
 
 def test_forward_log_softmax():
@@ -2963,7 +3420,7 @@ def test_forward_log_softmax():
     with tf.Graph().as_default():
         in_data = tf.placeholder(tf.float32, (9, 11), name="in_data")
         tf.math.log_softmax(in_data, name="LogSoftmax")
-        compare_tf_with_tvm([np_data], ['in_data:0'], 'LogSoftmax:0')
+        compare_tf_with_tvm([np_data], ["in_data:0"], "LogSoftmax:0")
 
 
 def test_forward_softplus():
@@ -2973,7 +3430,7 @@ def test_forward_softplus():
     with tf.Graph().as_default():
         in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
         tf.nn.softplus(in_data, name="softplus")
-        compare_tf_with_tvm([np_data], ['in_data:0'], 'softplus:0')
+        compare_tf_with_tvm([np_data], ["in_data:0"], "softplus:0")
 
 
 def test_forward_rsqrt():
@@ -2983,7 +3440,7 @@ def test_forward_rsqrt():
     with tf.Graph().as_default():
         in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data")
         tf.rsqrt(in_data, name="rsqrt")
-        compare_tf_with_tvm([np_data], ['in_data:0'], 'rsqrt:0')
+        compare_tf_with_tvm([np_data], ["in_data:0"], "rsqrt:0")
 
 
 def test_forward_sqrt():
@@ -2993,7 +3450,7 @@ def test_forward_sqrt():
     with tf.Graph().as_default():
         in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data")
         tf.sqrt(in_data, name="sqrt")
-        compare_tf_with_tvm([np_data], ['in_data:0'], 'sqrt:0')
+        compare_tf_with_tvm([np_data], ["in_data:0"], "sqrt:0")
 
 
 def _test_forward_right_shift(in_shape, dtype):
@@ -3005,13 +3462,12 @@ def _test_forward_right_shift(in_shape, dtype):
         lft_data = tf.placeholder(dtype, in_shape, name="lft_data")
         rgt_data = tf.placeholder(dtype, in_shape, name="rgt_data")
         tf.bitwise.right_shift(lft_data, rgt_data, name="RightShift")
-        compare_tf_with_tvm([lh_data, rh_data], [
-                            'lft_data:0', 'rgt_data:0'], 'RightShift:0')
+        compare_tf_with_tvm([lh_data, rh_data], ["lft_data:0", "rgt_data:0"], "RightShift:0")
 
 
 def test_forward_right_shift():
-    _test_forward_right_shift((7,), 'int32')
-    _test_forward_right_shift((3, 11), 'int16')
+    _test_forward_right_shift((7,), "int32")
+    _test_forward_right_shift((3, 11), "int16")
 
 
 def _test_forward_left_shift(in_shape, dtype):
@@ -3023,13 +3479,13 @@ def _test_forward_left_shift(in_shape, dtype):
         lft_data = tf.placeholder(dtype, in_shape, name="lft_data")
         rgt_data = tf.placeholder(dtype, in_shape, name="rgt_data")
         tf.bitwise.left_shift(lft_data, rgt_data, name="LeftShift")
-        compare_tf_with_tvm([lh_data, rh_data], [
-                            'lft_data:0', 'rgt_data:0'], 'LeftShift:0')
+        compare_tf_with_tvm([lh_data, rh_data], ["lft_data:0", "rgt_data:0"], "LeftShift:0")
 
 
 def test_forward_left_shift():
-    _test_forward_left_shift((10,), 'int32')
-    _test_forward_left_shift((224, 224, 3), 'int16')
+    _test_forward_left_shift((10,), "int32")
+    _test_forward_left_shift((224, 224, 3), "int16")
+
 
 #######################################################################
 # Mean
@@ -3042,13 +3498,13 @@ def test_forward_mean():
         with tf.Graph().as_default():
             in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
             tf.keras.backend.mean(in1, **kwargs)
-            compare_tf_with_tvm(inp_array, 'Placeholder:0',
-                                'Mean:0', no_gpu=True)
+            compare_tf_with_tvm(inp_array, "Placeholder:0", "Mean:0", no_gpu=True)
 
     check_mean((10, 8, 16, 32))
     check_mean((10, 8, 16, 32), axis=(2, 3))
     check_mean((10, 8, 16, 32), axis=(1, 2), keepdims=True)
 
+
 #######################################################################
 # Size
 # ----
@@ -3063,18 +3519,19 @@ def test_forward_size():
         tf_input_shape[0] = None
 
         with tf.Graph().as_default():
-            input = tf.placeholder(shape=tf_input_shape,
-                                   dtype=np_input.dtype, name='input')
-            tf.size(input, name='size')
-            compare_tf_with_tvm([np_input], ['input:0'], 'size:0')
+            input = tf.placeholder(shape=tf_input_shape, dtype=np_input.dtype, name="input")
+            tf.size(input, name="size")
+            compare_tf_with_tvm([np_input], ["input:0"], "size:0")
 
     check_size((10, 8, 16, 32))
     check_size((10,))
 
+
 #######################################################################
 # All, Any, Max, Min, Prod, variance, std, logsumexp, euclidean_norm
 # ------------------------------------------------------------------
 
+
 def test_forward_reduce():
     def _check_op(tf_op, ishape, axis, keepdims, dtype="float32"):
         tf.reset_default_graph()
@@ -3087,9 +3544,8 @@ def test_forward_reduce():
             np_data = np_data.reshape(1, -1)
         with tf.Graph().as_default():
             in_data = tf.placeholder(dtype, name="in_data")
-            reduce_op = tf_op(in_data, axis=axis,
-                               keepdims=keepdims, name="reduce_std")
-            compare_tf_with_tvm([np_data], ['in_data:0'], reduce_op.name)
+            reduce_op = tf_op(in_data, axis=axis, keepdims=keepdims, name="reduce_std")
+            compare_tf_with_tvm([np_data], ["in_data:0"], reduce_op.name)
 
     def _test_math_op(op, dtypes=["int32", "float32"]):
         for dtype in dtypes:
@@ -3106,9 +3562,10 @@ def test_forward_reduce():
     _test_math_op(tf.math.reduce_variance)
     _test_math_op(tf.math.reduce_std, dtypes=["float32"])
     _test_math_op(tf.math.reduce_logsumexp, dtypes=["float32"])
-    if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("1.15.0"):
         _test_math_op(tf.math.reduce_euclidean_norm)
 
+
 #######################################################################
 # Relational operators
 # --------------------
@@ -3116,13 +3573,11 @@ def test_forward_reduce():
 
 def _test_forward_rel_op(data, func):
     with tf.Graph().as_default():
-        in1 = tf.placeholder(
-            shape=data[0].shape, dtype=data[0].dtype, name='in1')
-        in2 = tf.placeholder(
-            shape=data[1].shape, dtype=data[1].dtype, name='in2')
-        op = func(in1, in2, name='op')
-        out = tf.cast(op, tf.int32, name='out1')
-        compare_tf_with_tvm([data[0], data[1]], ['in1:0', 'in2:0'], 'out1:0')
+        in1 = tf.placeholder(shape=data[0].shape, dtype=data[0].dtype, name="in1")
+        in2 = tf.placeholder(shape=data[1].shape, dtype=data[1].dtype, name="in2")
+        op = func(in1, in2, name="op")
+        out = tf.cast(op, tf.int32, name="out1")
+        compare_tf_with_tvm([data[0], data[1]], ["in1:0", "in2:0"], "out1:0")
 
 
 def test_forward_rel_ops():
@@ -3135,6 +3590,7 @@ def test_forward_rel_ops():
     _test_forward_rel_op([t1, t2], math_ops.equal)
     _test_forward_rel_op([t1, t2], math_ops.not_equal)
 
+
 #######################################################################
 # ExpandDims
 # ----------
@@ -3142,7 +3598,7 @@ def test_forward_rel_ops():
 
 def _test_forward_expand_dims(data, axis):
     with tf.Graph().as_default():
-        in1 = tf.placeholder(shape=data.shape, dtype=data.dtype, name='in1')
+        in1 = tf.placeholder(shape=data.shape, dtype=data.dtype, name="in1")
         out = tf.expand_dims(in1, axis)
         compare_tf_with_tvm([data], [in1.name], out.name)
 
@@ -3161,6 +3617,7 @@ def test_forward_expand_dims():
 # ----------------
 def test_forward_maximum():
     """test Op Maximum"""
+
     def check_maximum(lh_shape, rh_shape, dtype):
         tf.reset_default_graph()
         lh_data = np.random.uniform(size=lh_shape).astype(dtype)
@@ -3169,8 +3626,7 @@ def test_forward_maximum():
             lft_data = tf.placeholder(dtype, name="lft_data")
             rgt_data = tf.placeholder(dtype, name="rgt_data")
             tf.math.maximum(lft_data, rgt_data, name="maximum")
-            compare_tf_with_tvm([lh_data, rh_data], [
-                                'lft_data:0', 'rgt_data:0'], 'maximum:0')
+            compare_tf_with_tvm([lh_data, rh_data], ["lft_data:0", "rgt_data:0"], "maximum:0")
 
     check_maximum((10, 8, 16, 32), (1,), dtype="int32")
     check_maximum((10, 8, 16, 32), (10, 8, 16, 32), dtype="float32")
@@ -3178,6 +3634,7 @@ def test_forward_maximum():
 
 def test_forward_minimum():
     """test Op Minimum"""
+
     def check_minimum(lh_shape, rh_shape, dtype):
         tf.reset_default_graph()
         lh_data = np.random.uniform(size=lh_shape).astype(dtype)
@@ -3186,8 +3643,7 @@ def test_forward_minimum():
             lft_data = tf.placeholder(dtype, name="lft_data")
             rgt_data = tf.placeholder(dtype, name="rgt_data")
             tf.math.minimum(lft_data, rgt_data, name="minimum")
-            compare_tf_with_tvm([lh_data, rh_data], [
-                                'lft_data:0', 'rgt_data:0'], 'minimum:0')
+            compare_tf_with_tvm([lh_data, rh_data], ["lft_data:0", "rgt_data:0"], "minimum:0")
 
     check_minimum((10, 8, 16, 32), (1,), dtype="int32")
     check_minimum((10, 8, 16, 32), (10, 8, 16, 32), dtype="float32")
@@ -3199,18 +3655,19 @@ def test_forward_minimum():
 def test_placeholder():
     with tf.Graph().as_default():
         in_data1 = np.random.uniform(-5, 5, size=(3, 4, 5)).astype(np.float32)
-        var1 = tf.Variable(in_data1, name='in1')
-        var2 = array_ops.placeholder_with_default(var1, None, name='place1')
+        var1 = tf.Variable(in_data1, name="in1")
+        var2 = array_ops.placeholder_with_default(var1, None, name="place1")
 
         in_data2 = np.random.uniform(-5, 5, size=(3, 4, 5)).astype(np.float32)
-        place1 = array_ops.placeholder(
-            shape=in_data1.shape, dtype=in_data1.dtype, name='in2')
+        place1 = array_ops.placeholder(shape=in_data1.shape, dtype=in_data1.dtype, name="in2")
 
-        out1 = tf.math.add(var1, var2, name='out1')
-        out2 = tf.math.add(out1, place1, name='out2')
+        out1 = tf.math.add(var1, var2, name="out1")
+        out2 = tf.math.add(out1, place1, name="out2")
+
+        compare_tf_with_tvm(
+            [in_data1, in_data2], ["place1:0", "in2:0"], "out2:0", init_global_variables=True
+        )
 
-        compare_tf_with_tvm([in_data1, in_data2], ['place1:0', 'in2:0'], 'out2:0',
-                            init_global_variables=True)
 
 #######################################################################
 # OneHot
@@ -3221,8 +3678,7 @@ def _test_forward_one_hot(indices_shape, depth, on_value, off_value, axis, out_d
     inp_array1 = np.random.randint(0, 5, size=indices_shape)
     with tf.Graph().as_default():
         in1 = tf.placeholder(shape=inp_array1.shape, dtype=inp_array1.dtype)
-        out = tf.one_hot(in1, depth, on_value, off_value,
-                         axis, dtype=out_dtype)
+        out = tf.one_hot(in1, depth, on_value, off_value, axis, dtype=out_dtype)
         compare_tf_with_tvm(inp_array1, in1.name, out.name)
 
 
@@ -3234,6 +3690,7 @@ def test_forward_one_hot():
     _test_forward_one_hot((3, 2, 4, 5), 6, 1, 0, 1, "int32")
     _test_forward_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
 
+
 #######################################################################
 # AddN
 # ----------------------
@@ -3246,8 +3703,7 @@ def _test_forward_add_n(inputs):
         for each in inputs:
             temp.append(tf.placeholder(shape=each.shape, dtype=each.dtype))
         output = tf.add_n(temp)
-        compare_tf_with_tvm([each for each in inputs], [
-                            each.name for each in temp], output.name)
+        compare_tf_with_tvm([each for each in inputs], [each.name for each in temp], output.name)
 
 
 def test_forward_add_n():
@@ -3268,6 +3724,7 @@ def test_forward_add_n():
     _test_forward_add_n(in4)
     _test_forward_add_n(in5)
 
+
 #######################################################################
 # Sharing params case
 # ----------------------
@@ -3275,14 +3732,15 @@ def test_forward_add_n():
 
 def test_sharing_node():
     """Test the sharing params case."""
-    np_data = np.random.uniform(size=(2,2,2)).astype('float32')
+    np_data = np.random.uniform(size=(2, 2, 2)).astype("float32")
     with tf.Graph().as_default():
-        in_data = tf.placeholder(tf.float32, shape=(2, 2, 2), name='in_data')
-        axis = tf.constant([-1], dtype=tf.int32, name='axis')
-        mean0 = tf.reduce_mean(in_data, axis=axis, keepdims=False, name='mean0')
-        mean1 = tf.reduce_mean(in_data, axis=axis, keepdims=False, name='mean1')
-        out = tf.add(mean0, mean1, name='out')
-        compare_tf_with_tvm([np_data], ['in_data:0'], 'out:0')
+        in_data = tf.placeholder(tf.float32, shape=(2, 2, 2), name="in_data")
+        axis = tf.constant([-1], dtype=tf.int32, name="axis")
+        mean0 = tf.reduce_mean(in_data, axis=axis, keepdims=False, name="mean0")
+        mean1 = tf.reduce_mean(in_data, axis=axis, keepdims=False, name="mean1")
+        out = tf.add(mean0, mean1, name="out")
+        compare_tf_with_tvm([np_data], ["in_data:0"], "out:0")
+
 
 #######################################################################
 # Unravel Index
@@ -3294,8 +3752,7 @@ def _test_forward_unravel_index(inputs):
         for each in inputs:
             temp.append(tf.placeholder(shape=each.shape, dtype=each.dtype))
         output = tf.unravel_index(temp[0], temp[1])
-        compare_tf_with_tvm([each for each in inputs], [
-            each.name for each in temp], output.name)
+        compare_tf_with_tvm([each for each in inputs], [each.name for each in temp], output.name)
 
 
 def _test_forward_unravel_index_scalar(x, y, dtype="int32"):
@@ -3335,8 +3792,7 @@ def test_forward_unravel_index():
 #######################################################################
 # Dilation2d
 # ----------------------
-def _test_dilation2d(tensor_in_sizes, filter_in_sizes,
-                     strides, dilations, padding):
+def _test_dilation2d(tensor_in_sizes, filter_in_sizes, strides, dilations, padding):
     """ One iteration of dilation2d with given shapes and attributes """
 
     total_size_1 = np.prod(tensor_in_sizes)
@@ -3347,18 +3803,17 @@ def _test_dilation2d(tensor_in_sizes, filter_in_sizes,
     filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
 
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
-        in_filter = constant_op.constant(
-            filter_array, shape=filter_in_sizes, dtype='float32')
+        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype="float32")
+        in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype="float32")
 
-        nn_ops.dilation2d(in_data,
-                          in_filter,
-                          strides=strides,
-                          rates=dilations,
-                          padding=padding)
+        nn_ops.dilation2d(in_data, in_filter, strides=strides, rates=dilations, padding=padding)
 
-        compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
-                            'Placeholder:0', 'Dilation2D:0', no_gpu=True)
+        compare_tf_with_tvm(
+            np.reshape(data_array, tensor_in_sizes).astype("float32"),
+            "Placeholder:0",
+            "Dilation2D:0",
+            no_gpu=True,
+        )
 
 
 def test_forward_dilation():
@@ -3398,7 +3853,7 @@ def _verify_infiniteness_ops(tf_op, name):
         tf.reset_default_graph()
         in_data = tf.placeholder(tf_dtype, shape, name="in_data")
         tf_op(in_data, name=name)
-        compare_tf_with_tvm([data], ['in_data:0'], '{}:0'.format(name))
+        compare_tf_with_tvm([data], ["in_data:0"], "{}:0".format(name))
 
 
 def test_forward_isinf():
@@ -3412,37 +3867,47 @@ def test_forward_isfinite():
 def _test_spop_placeholder_without_shape_info():
     with tf.Graph().as_default():
 
-        @function.Defun(*[tf.int32]*2)
-        def Forward(x,y):
+        @function.Defun(*[tf.int32] * 2)
+        def Forward(x, y):
             print(x.name)
             print(y.name)
             b = tf.add(x, y)
             return b
-        pl1 = tf.placeholder(tf.int32,name="pl1")
-        pl2 = tf.placeholder(tf.int32,name="pl2")
+
+        pl1 = tf.placeholder(tf.int32, name="pl1")
+        pl2 = tf.placeholder(tf.int32, name="pl2")
         pl3 = tf.placeholder(tf.int32, name="pl3")
         data = np.array([[-1, 1], [2, -2]], dtype=np.int32)
         data2 = np.array([[-2, 3], [4, -6]], dtype=np.int32)
         data3 = np.array([[-2, 3], [4, -6]], dtype=np.int32)
-        z1 = gen_functional_ops.StatefulPartitionedCall(args=[pl1,pl2], Tout=[tf.int32],f=Forward)
+        z1 = gen_functional_ops.StatefulPartitionedCall(args=[pl1, pl2], Tout=[tf.int32], f=Forward)
         z2 = z1 + pl3
-        compare_tf_with_tvm([data, data2, data3], ['pl1:0', 'pl2:0', 'pl3:0'],
-                            ['StatefulPartitionedCall:0',z2.name],  mode='vm', init_global_variables=True)
+        compare_tf_with_tvm(
+            [data, data2, data3],
+            ["pl1:0", "pl2:0", "pl3:0"],
+            ["StatefulPartitionedCall:0", z2.name],
+            mode="vm",
+            init_global_variables=True,
+        )
 
 
 def _test_spop_placeholder_with_shape_and_default_value():
     with tf.Graph().as_default():
         data = np.ones([1], dtype=int).astype(np.int32)
         dataVar = tf.Variable(data, shape=data.shape)
-        pl1 = array_ops.placeholder_with_default(dataVar,shape=data.shape,name="pl1")
+        pl1 = array_ops.placeholder_with_default(dataVar, shape=data.shape, name="pl1")
         tpl = tf.convert_to_tensor(pl1, dtype=tf.int32)
 
         @function.Defun(*[tf.int32])
         def pl_with_default(pl):
             return tf.expand_dims(tf.multiply(pl, pl), 0)
 
-        z = gen_functional_ops.StatefulPartitionedCall(args=[tpl], Tout=[tf.int32], f=pl_with_default)
-        compare_tf_with_tvm(data, ['pl1:0'], 'StatefulPartitionedCall:0', mode='vm', init_global_variables=True)
+        z = gen_functional_ops.StatefulPartitionedCall(
+            args=[tpl], Tout=[tf.int32], f=pl_with_default
+        )
+        compare_tf_with_tvm(
+            data, ["pl1:0"], "StatefulPartitionedCall:0", mode="vm", init_global_variables=True
+        )
 
 
 def _test_spop_placeholder_numpy_arange_feed():
@@ -3457,7 +3922,9 @@ def _test_spop_placeholder_numpy_arange_feed():
             return tf.add(x, y, "add_t1_t2")
 
         t3 = add(t1, t2)
-        compare_tf_with_tvm([t1_data, t2_data], ['t1:0', 't2:0'], [t3.name], mode='vm', init_global_variables=True)
+        compare_tf_with_tvm(
+            [t1_data, t2_data], ["t1:0", "t2:0"], [t3.name], mode="vm", init_global_variables=True
+        )
 
 
 def _test_spop_placeholder_numpy_array_feed():
@@ -3472,28 +3939,30 @@ def _test_spop_placeholder_numpy_array_feed():
             return tf.add(x, y, "add_t1_t2")
 
         t3 = add(t1, t2)
-        compare_tf_with_tvm([t1_data, t2_data], ['t1:0', 't2:0'], [t3.name], mode='vm', init_global_variables=True)
+        compare_tf_with_tvm(
+            [t1_data, t2_data], ["t1:0", "t2:0"], [t3.name], mode="vm", init_global_variables=True
+        )
 
 
 def _test_spop_function_invocation_basic():
     with tf.Graph().as_default():
 
         def fun1(a):
-            return tf.multiply(a,a)
+            return tf.multiply(a, a)
 
         def fun2(b):
-            return tf.multiply(b,10)
+            return tf.multiply(b, 10)
 
         @tf.function
-        def fun3(x,y):
+        def fun3(x, y):
             x = fun2(x)
             y = fun1(y)
-            z = tf.add(x,y)
+            z = tf.add(x, y)
             return z
 
         t3 = fun3(tf.constant(10.5), tf.constant(20.4))
 
-        compare_tf_with_tvm([], [], [t3.name], mode='vm', init_global_variables=True)
+        compare_tf_with_tvm([], [], [t3.name], mode="vm", init_global_variables=True)
 
 
 def _test_spop_function_invocation_nested():
@@ -3511,13 +3980,15 @@ def _test_spop_function_invocation_nested():
         def myfunc2(x, y):
             z = myfunc(x, y)
             l = myfunc(z, y)
-            m = myfunc(l,z)
+            m = myfunc(l, z)
             return tf.add(l, m, "myfunc2")
 
         res1 = myfunc(t1, t2)
         res2 = myfunc2(res1, t1)
 
-        compare_tf_with_tvm([t1_data, t2_data], ['t1:0', 't2:0'], [res2.name], mode='vm', init_global_variables=True)
+        compare_tf_with_tvm(
+            [t1_data, t2_data], ["t1:0", "t2:0"], [res2.name], mode="vm", init_global_variables=True
+        )
 
 
 def _test_spop_function_invocation_no_autograph():
@@ -3525,58 +3996,67 @@ def _test_spop_function_invocation_no_autograph():
 
         @tf.function(autograph=False)
         def fun1(a):
-            return tf.multiply(a,a)
+            return tf.multiply(a, a)
 
         @tf.function(autograph=False)
         def fun2(b):
-            return tf.multiply(b,10)
+            return tf.multiply(b, 10)
 
         @tf.function
-        def fun3(x,y):
+        def fun3(x, y):
             x = fun2(x)
             y = fun1(y)
-            z = tf.add(x,y)
+            z = tf.add(x, y)
             return z
 
         t3 = fun3(tf.constant(10.5), tf.constant(20.4))
 
-        compare_tf_with_tvm([], [], [t3.name], mode='vm', init_global_variables=True)
+        compare_tf_with_tvm([], [], [t3.name], mode="vm", init_global_variables=True)
 
 
 def _test_spop_function_invocation_defun():
     with tf.Graph().as_default():
 
         def fun1(a):
-            return tf.multiply(a,a)
+            return tf.multiply(a, a)
 
         def fun2(b):
-            return tf.multiply(b,b)
+            return tf.multiply(b, b)
 
         @function.Defun(dtypes.float32, dtypes.float32, func_name="Fun3")
-        def fun3(x,y):
+        def fun3(x, y):
             x = fun2(x)
             y = fun1(y)
-            z = tf.add(x,y)
+            z = tf.add(x, y)
             return z
 
-        op = gen_functional_ops.StatefulPartitionedCall(args=[tf.constant(10.5),tf.constant(20.4)],
-                                                        Tout=[dtypes.float32], f=fun3, name="SpopFnInvocation")
-        compare_tf_with_tvm([],[], 'SpopFnInvocation:0', mode='vm', init_global_variables=True)
+        op = gen_functional_ops.StatefulPartitionedCall(
+            args=[tf.constant(10.5), tf.constant(20.4)],
+            Tout=[dtypes.float32],
+            f=fun3,
+            name="SpopFnInvocation",
+        )
+        compare_tf_with_tvm([], [], "SpopFnInvocation:0", mode="vm", init_global_variables=True)
 
 
 def _test_spop_arithmetic():
     with tf.Graph().as_default():
-        @function.Defun(*[dtypes.int32]*3)
-        def arithmetic(m,x,c):
+
+        @function.Defun(*[dtypes.int32] * 3)
+        def arithmetic(m, x, c):
             z = tf.add(tf.multiply(m, x), c)
             return z
 
         m = tf.constant(10)
         x = tf.constant(20)
         c = tf.constant(2)
-        spopFn = gen_functional_ops.StatefulPartitionedCall(args=[m,x,c],Tout=[tf.int32], f=arithmetic)
+        spopFn = gen_functional_ops.StatefulPartitionedCall(
+            args=[m, x, c], Tout=[tf.int32], f=arithmetic
+        )
 
-        compare_tf_with_tvm([],[],'StatefulPartitionedCall:0', mode='vm', init_global_variables=True)
+        compare_tf_with_tvm(
+            [], [], "StatefulPartitionedCall:0", mode="vm", init_global_variables=True
+        )
 
 
 def _test_spop_control_flow():
@@ -3587,17 +4067,21 @@ def _test_spop_control_flow():
             with ops.device("/job:localhost/replica:0/task:0/device:CPU:0"):
                 z = math_ops.multiply(x, y)
                 i = 0
-                while i<10 :
-                    i +=1
+                while i < 10:
+                    i += 1
                     if i == 5:
                         continue
-                    z = math_ops.multiply(x, y*i)
+                    z = math_ops.multiply(x, y * i)
             return z
 
         op = gen_functional_ops.StatefulPartitionedCall(
-            args=[constant_op.constant(32.), constant_op.constant(100.)],
-            Tout=[dtypes.float32], f=Body1)
-        compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0', mode='vm', init_global_variables=True)
+            args=[constant_op.constant(32.0), constant_op.constant(100.0)],
+            Tout=[dtypes.float32],
+            f=Body1,
+        )
+        compare_tf_with_tvm(
+            [], [], "StatefulPartitionedCall:0", mode="vm", init_global_variables=True
+        )
 
 
 def _test_spop_variables():
@@ -3607,27 +4091,36 @@ def _test_spop_variables():
         var1 = tf.Variable(const1, dtype=tf.int32)
         var2 = tf.Variable(const2, dtype=tf.int32)
 
-        @function.Defun(tf.int32,tf.int32)
-        def Forward(x,y):
-            return tf.multiply(x,y)
+        @function.Defun(tf.int32, tf.int32)
+        def Forward(x, y):
+            return tf.multiply(x, y)
 
-        z = gen_functional_ops.StatefulPartitionedCall(args=[var1,var2],Tout=[tf.int32], f=Forward)
-        compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0', init_global_variables=True, mode="vm")
+        z = gen_functional_ops.StatefulPartitionedCall(
+            args=[var1, var2], Tout=[tf.int32], f=Forward
+        )
+        compare_tf_with_tvm(
+            [], [], "StatefulPartitionedCall:0", init_global_variables=True, mode="vm"
+        )
 
 
 def _test_spop_constants():
     with tf.Graph().as_default():
+
         @function.Defun(*[dtypes.int32] * 2)
         def constantsFn(x, y):
             vv = tf.constant([2, 3, 4], name="vv")
             z = tf.add(vv + x, y)
             return z
 
-        a = tf.constant(20000, name = "a")
-        b = tf.constant(40000, name = "b")
-        spopFn = gen_functional_ops.StatefulPartitionedCall(args=[a, b], Tout=[tf.int32], f=constantsFn)
+        a = tf.constant(20000, name="a")
+        b = tf.constant(40000, name="b")
+        spopFn = gen_functional_ops.StatefulPartitionedCall(
+            args=[a, b], Tout=[tf.int32], f=constantsFn
+        )
 
-        compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0', mode='vm', init_global_variables=True)
+        compare_tf_with_tvm(
+            [], [], "StatefulPartitionedCall:0", mode="vm", init_global_variables=True
+        )
 
 
 def _test_spop_stateful():
@@ -3646,17 +4139,16 @@ def _test_spop_stateful():
 
         @tf.function
         def FunctionWithStatefulOp(m, n):
-            a = tf.random.uniform(shape=[2, 4], maxval=10, dtype=tf.float32, seed = 10)
-            x = tf.multiply(a,m)
+            a = tf.random.uniform(shape=[2, 4], maxval=10, dtype=tf.float32, seed=10)
+            x = tf.multiply(a, m)
             y = FunctionWithStatefulOp_One(n)
-            z = tf.multiply(x,y)
+            z = tf.multiply(x, y)
             return z
 
-        op = FunctionWithStatefulOp(constant_op.constant(1.), constant_op.constant(2.))
+        op = FunctionWithStatefulOp(constant_op.constant(1.0), constant_op.constant(2.0))
         with pytest.raises(Exception) as execinfo:
             compare_tf_with_tvm([], [], [op.name], init_global_variables=True, mode="vm")
-        assert execinfo.value.args[0].startswith(
-            "The following operators are not implemented")
+        assert execinfo.value.args[0].startswith("The following operators are not implemented")
 
 
 def _test_spop_device_assignment():
@@ -3669,27 +4161,29 @@ def _test_spop_device_assignment():
 
         def fun1(a):
             with ops.device("/GPU:0"):
-                return tf.multiply(a,a)
+                return tf.multiply(a, a)
 
         def fun2(b):
             with ops.device("/job:localhost/replica:0/task:0/device:CPU:1"):
-                return tf.multiply(b,b)
+                return tf.multiply(b, b)
 
         @function.Defun(dtypes.float32, dtypes.float32, func_name="Fun3")
-        def fun3(x,y):
+        def fun3(x, y):
             with ops.device("/CPU:0"):
                 x = fun2(x)
             with ops.device("/job:localhost/replica:0/task:0/device:CPU:2"):
                 y = fun1(y)
             with ops.device("/job:localhost/replica:0/task:0/device:CPU:3"):
-                z = tf.add(x,y)
+                z = tf.add(x, y)
                 return z
 
-        op = gen_functional_ops.StatefulPartitionedCall(args=[tf.constant(10.5),tf.constant(20.4)],
-                                                        Tout=[dtypes.float32], f=fun3)
+        op = gen_functional_ops.StatefulPartitionedCall(
+            args=[tf.constant(10.5), tf.constant(20.4)], Tout=[dtypes.float32], f=fun3
+        )
         with pytest.raises(Exception) as execinfo:
-            compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0',
-                                mode='vm', init_global_variables=True)
+            compare_tf_with_tvm(
+                [], [], "StatefulPartitionedCall:0", mode="vm", init_global_variables=True
+            )
         assert execinfo.value.args[0].startswith("Found inconsistent Device assignment")
 
 
@@ -3709,36 +4203,38 @@ def _test_spop_resource_variables():
         def resourceVariablesTest(x, y):
             return tf.multiply(x, y)
 
-        op = resourceVariablesTest(var1,var2)
+        op = resourceVariablesTest(var1, var2)
         with pytest.raises(Exception) as execinfo:
-            compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0',
-                                mode='vm', init_global_variables=True)
-        assert execinfo.value.args[0].startswith("Graph is not frozen."
-                                                 " Provide a frozen graph")
+            compare_tf_with_tvm(
+                [], [], "StatefulPartitionedCall:0", mode="vm", init_global_variables=True
+            )
+        assert execinfo.value.args[0].startswith("Graph is not frozen." " Provide a frozen graph")
+
 
 def test_forward_spop():
     _test_spop_stateful()
     _test_spop_device_assignment()
     _test_spop_resource_variables()
 
-    #Placeholder test cases
+    # Placeholder test cases
     _test_spop_placeholder_without_shape_info()
     _test_spop_placeholder_with_shape_and_default_value()
     _test_spop_placeholder_numpy_arange_feed()
     _test_spop_placeholder_numpy_array_feed()
 
-    #Function Invocation test cases
+    # Function Invocation test cases
     _test_spop_function_invocation_basic()
     _test_spop_function_invocation_nested()
     _test_spop_function_invocation_no_autograph()
     _test_spop_function_invocation_defun()
 
-    #Test cases for various other TF constructs
+    # Test cases for various other TF constructs
     _test_spop_arithmetic()
     _test_spop_control_flow()
     _test_spop_variables()
     _test_spop_constants()
 
+
 #######################################################################
 # Dynamic input shape
 # -------------------
@@ -3746,28 +4242,36 @@ def test_forward_dynamic_input_shape():
     tf.reset_default_graph()
 
     with tf.Graph().as_default():
-        data = tf.placeholder(tf.float32, name='data', shape=(None,))
+        data = tf.placeholder(tf.float32, name="data", shape=(None,))
         out = data + 1
         np_data = np.random.uniform(size=(2,)).astype("float32")
         out_name = "add"
 
         with tf.Session() as sess:
             graph_def = tf_testing.AddShapesToGraphDef(sess, out_name)
-            tf_output = run_tf_graph(sess, np_data, 'data:0', ['{}:0'.format(out_name)])
+            tf_output = run_tf_graph(sess, np_data, "data:0", ["{}:0".format(out_name)])
             # TODO(kevinthesun): enable gpu test when VM heterogeneous execution is ready.
             for device in ["llvm"]:
                 ctx = tvm.context(device, 0)
                 if not tvm.testing.device_enabled(device):
                     print("Skip because %s is not enabled" % device)
                     continue
-                tvm_output = run_tvm_graph(graph_def, np_data, ["data"], 1,
-                                           target=device, layout="NCHW", out_names=[out_name],
-                                           mode="vm", ignore_in_shape=True)
-                tvm.testing.assert_allclose(tvm_output[0], tf_output[0],
-                                            rtol=1e-5, atol=1e-5)
+                tvm_output = run_tvm_graph(
+                    graph_def,
+                    np_data,
+                    ["data"],
+                    1,
+                    target=device,
+                    layout="NCHW",
+                    out_names=[out_name],
+                    mode="vm",
+                    ignore_in_shape=True,
+                )
+                tvm.testing.assert_allclose(tvm_output[0], tf_output[0], rtol=1e-5, atol=1e-5)
+
 
 def test_forward_dynmaic_rnn_lstmblockcell():
-    if package_version.parse(tf.VERSION) >= package_version.parse('2.0.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("2.0.0"):
         return
 
     total_series_length = 50000
@@ -3793,16 +4297,24 @@ def test_forward_dynmaic_rnn_lstmblockcell():
 
     state_per_layer_list = tf.unstack(init_state, axis=0)
     rnn_tuple_state = tuple(
-        [tf.nn.rnn_cell.LSTMStateTuple(state_per_layer_list[idx][0], state_per_layer_list[idx][1])
-         for idx in range(num_layers)]
+        [
+            tf.nn.rnn_cell.LSTMStateTuple(
+                state_per_layer_list[idx][0], state_per_layer_list[idx][1]
+            )
+            for idx in range(num_layers)
+        ]
     )
 
     # Forward passes
     def lstm_cell():
         return tensorflow.contrib.rnn.LSTMBlockCell(state_size)
-    cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell() for _ in range(num_layers)], state_is_tuple=True)
-    states_series, current_state = tf.nn.dynamic_rnn(cell, tf.expand_dims(batchX_placeholder, -1),
-                                                     initial_state=rnn_tuple_state)
+
+    cell = tf.nn.rnn_cell.MultiRNNCell(
+        [lstm_cell() for _ in range(num_layers)], state_is_tuple=True
+    )
+    states_series, current_state = tf.nn.dynamic_rnn(
+        cell, tf.expand_dims(batchX_placeholder, -1), initial_state=rnn_tuple_state
+    )
 
     with tf.Session() as sess:
         sess.run(tf.global_variables_initializer())
@@ -3819,10 +4331,8 @@ def test_forward_dynmaic_rnn_lstmblockcell():
 
         _current_state, _states_series = sess.run(
             [current_state, states_series],
-            feed_dict={
-               batchX_placeholder: batchX,
-               init_state: _current_state
-            })
+            feed_dict={batchX_placeholder: batchX, init_state: _current_state},
+        )
 
         # Organize results and corresponding names
         tf_output = [_states_series]
@@ -3831,29 +4341,30 @@ def test_forward_dynmaic_rnn_lstmblockcell():
             tf_output.append(c.c)
             tf_output.append(c.h)
 
-        name = [states_series.name.split(':')[0]]
+        name = [states_series.name.split(":")[0]]
 
         for t in current_state:
-            name.append(t.c.name.split(':')[0])
-            name.append(t.h.name.split(':')[0])
+            name.append(t.c.name.split(":")[0])
+            name.append(t.h.name.split(":")[0])
 
         graph_def = sess.graph.as_graph_def(add_shapes=True)
 
-        final_graph_def = graph_util.convert_variables_to_constants(
-            sess,
-            graph_def,
-            name)
+        final_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, name)
 
-        tvm_output = run_tvm_graph(final_graph_def,
-                      [batchX.astype('float32'), current_state_tvm.astype('float32')],
-                      ["Placeholder", "Placeholder_1"], out_names=name,
-                      num_output=len(name), mode='vm', disabled_pass=["FoldScaleAxis"])
+        tvm_output = run_tvm_graph(
+            final_graph_def,
+            [batchX.astype("float32"), current_state_tvm.astype("float32")],
+            ["Placeholder", "Placeholder_1"],
+            out_names=name,
+            num_output=len(name),
+            mode="vm",
+            disabled_pass=["FoldScaleAxis"],
+        )
 
         # Compare result
         for i in range(len(tf_output)):
-            tvm.testing.assert_allclose(
-                tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
+            tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     pytest.main([__file__])
index 3178863..a84e254 100644
@@ -23,11 +23,13 @@ import numpy as np
 from tvm import relay
 from tvm.relay.frontend.tensorflow import from_tensorflow
 
+
 def run_relay(graph):
     mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
-    ex = relay.create_executor('debug', mod=mod)
+    ex = relay.create_executor("debug", mod=mod)
     return ex.evaluate()(**params)
 
+
 def test_no_op():
     g = tf.Graph()
     with g.as_default():
@@ -43,4 +45,3 @@ def test_no_op():
 
 if __name__ == "__main__":
     test_no_op()
-
index 8ce858e..e706f2a 100644
@@ -27,6 +27,6 @@ def test_key_is_not_present():
     assert not attrs.has_attr("b")
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_key_is_present()
     test_key_is_not_present()
index 89296a6..e8b225c 100644
@@ -27,8 +27,10 @@ import numpy as np
 import tvm
 from tvm import te
 from tvm import relay
+
 try:
     import tensorflow.compat.v1 as tf
+
     # tensorflow.python.framework.ops module itself is not part of
     # TensorFlow's public API: the precise contents of that module
     # may vary from one version to the next
@@ -44,6 +46,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import nn_impl
 from tensorflow.python.ops import variables
+
 try:
     from tensorflow import lite as interpreter_wrapper
 except ImportError:
@@ -69,44 +72,44 @@ def convert_to_list(x):
 # Get a real image for e2e testing
 # --------------------------------
 def get_real_image(im_height, im_width):
-    repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
-    img_name = 'elephant-299.jpg'
+    repo_base = "https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/"
+    img_name = "elephant-299.jpg"
     image_url = os.path.join(repo_base, img_name)
-    img_path = download_testdata(image_url, img_name, module='data')
+    img_path = download_testdata(image_url, img_name, module="data")
     image = Image.open(img_path).resize((im_height, im_width))
-    x = np.array(image).astype('uint8')
+    x = np.array(image).astype("uint8")
     data = np.reshape(x, (1, im_height, im_width, 3))
     return data
 
 
 def pre_processed_image(height, width):
-    repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
-    img_name = 'elephant-299.jpg'
+    repo_base = "https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/"
+    img_name = "elephant-299.jpg"
     image_url = os.path.join(repo_base, img_name)
-    img_path = download_testdata(image_url, img_name, module='data')
+    img_path = download_testdata(image_url, img_name, module="data")
     image = tf.io.read_file(img_path)
     image = tf.image.decode_jpeg(image, channels=3)
-    with tf.name_scope('eval_image'):
+    with tf.name_scope("eval_image"):
         if image.dtype != tf.float32:
             image = tf.image.convert_image_dtype(image, dtype=tf.float32)
         image = tf.image.central_crop(image, central_fraction=0.875)
     # Resize the image to the specified height and width.
-    image = tf.image.resize(image, [height, width],
-                            align_corners=False)
+    image = tf.image.resize(image, [height, width], align_corners=False)
     image = tf.expand_dims(image, axis=0)
     return image
 
 
 def get_real_image_object_detection(im_height, im_width):
-    repo_base = 'https://github.com/dmlc/web-data/raw/master/gluoncv/detection/'
-    img_name = 'street_small.jpg'
+    repo_base = "https://github.com/dmlc/web-data/raw/master/gluoncv/detection/"
+    img_name = "street_small.jpg"
     image_url = os.path.join(repo_base, img_name)
-    img_path = download_testdata(image_url, img_name, module='data')
+    img_path = download_testdata(image_url, img_name, module="data")
     image = Image.open(img_path).resize((im_height, im_width))
-    x = np.array(image).astype('uint8')
+    x = np.array(image).astype("uint8")
     data = np.reshape(x, (1, im_height, im_width, 3))
     return data
 
+
 def vmobj_to_list(o):
     if isinstance(o, tvm.nd.NDArray):
         return [o.asnumpy().tolist()]
@@ -116,20 +119,19 @@ def vmobj_to_list(o):
             result.extend(vmobj_to_list(f))
         return result
     elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
-        if o.constructor.name_hint == 'Cons':
+        if o.constructor.name_hint == "Cons":
             tl = vmobj_to_list(o.fields[1])
             hd = vmobj_to_list(o.fields[0])
             hd.extend(tl)
             return hd
-        elif o.constructor.name_hint == 'Nil':
+        elif o.constructor.name_hint == "Nil":
             return []
-        elif 'tensor_nil' in o.constructor.name_hint:
+        elif "tensor_nil" in o.constructor.name_hint:
             return [0]
-        elif 'tensor' in o.constructor.name_hint:
+        elif "tensor" in o.constructor.name_hint:
             return [o.fields[0].asnumpy()]
         else:
-            raise RuntimeError("Unknown object type: %s" %
-                               o.constructor.name_hint)
+            raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint)
     else:
         raise RuntimeError("Unknown object type: %s" % type(o))
 
@@ -145,15 +147,24 @@ def _quantize_keras_model(keras_model, representative_data_gen):
     return converter.convert()
 
 
-def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm',
-                  out_names=None, mode='graph_runtime'):
+def run_tvm_graph(
+    tflite_model_buf,
+    input_data,
+    input_node,
+    num_output=1,
+    target="llvm",
+    out_names=None,
+    mode="graph_runtime",
+):
     """ Generic function to compile on relay and execute on tvm """
     # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
     try:
         import tflite.Model
+
         tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
     except AttributeError:
         import tflite
+
         tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
     except ImportError:
         raise ImportError("The tflite package must be installed")
@@ -167,14 +178,14 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target
         shape_dict[e] = input_data[i].shape
         dtype_dict[e] = input_data[i].dtype.name
 
-    mod, params = relay.frontend.from_tflite(tflite_model,
-                                             shape_dict=shape_dict,
-                                             dtype_dict=dtype_dict)
+    mod, params = relay.frontend.from_tflite(
+        tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict
+    )
 
-    if mode in ['debug', 'vm']:
+    if mode in ["debug", "vm"]:
         ex = relay.create_executor(mode, mod=mod, ctx=tvm.cpu(), target="llvm")
         inputs = []
-        for param in mod['main'].params:
+        for param in mod["main"].params:
             found = False
             for i, n in enumerate(input_node):
                 if n == param.name_hint:
@@ -192,6 +203,7 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target
 
         ctx = tvm.context(target, 0)
         from tvm.contrib import graph_runtime
+
         m = graph_runtime.create(graph, lib, ctx)
         # set inputs
         for i, e in enumerate(input_node):
@@ -201,8 +213,9 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target
         # execute
         m.run()
         # get outputs
-        assert out_names is None or num_output == len(out_names), "out_names: {} num_output: {}".format(
-            out_names, num_output)
+        assert out_names is None or num_output == len(
+            out_names
+        ), "out_names: {} num_output: {}".format(out_names, num_output)
         tvm_output_list = []
         for i in range(0, num_output):
             tvm_output = m.get_output(i)
@@ -219,13 +232,13 @@ def run_tflite_graph(tflite_model_buf, input_data):
     output_details = interpreter.get_output_details()
 
     for i in range(len(input_details)):
-        interpreter.resize_tensor_input(input_details[i]['index'], input_data[i].shape)
+        interpreter.resize_tensor_input(input_details[i]["index"], input_data[i].shape)
     interpreter.allocate_tensors()
 
     # set input
     assert len(input_data) == len(input_details)
     for i in range(len(input_details)):
-        interpreter.set_tensor(input_details[i]['index'], input_data[i])
+        interpreter.set_tensor(input_details[i]["index"], input_data[i])
 
     # Run
     interpreter.invoke()
@@ -233,30 +246,37 @@ def run_tflite_graph(tflite_model_buf, input_data):
     # get output
     tflite_output = list()
     for i in range(len(output_details)):
-        tflite_output.append(interpreter.get_tensor(output_details[i]['index']))
+        tflite_output.append(interpreter.get_tensor(output_details[i]["index"]))
 
     return tflite_output
 
 
-def compare_tflite_with_tvm(in_data, in_name, input_tensors,
-                            output_tensors, init_global_variables=False,
-                            out_names=None, quantized=False, input_range=None,
-                            mode='graph_runtime', experimental_new_converter=False):
+def compare_tflite_with_tvm(
+    in_data,
+    in_name,
+    input_tensors,
+    output_tensors,
+    init_global_variables=False,
+    out_names=None,
+    quantized=False,
+    input_range=None,
+    mode="graph_runtime",
+    experimental_new_converter=False,
+):
     """Generic function to generate and compare TFLite and TVM output"""
     in_data = convert_to_list(in_data)
     in_name = convert_to_list(in_name)
     out_names = convert_to_list(out_names)
     in_node = [0] * len(in_name)
     for i in range(len(in_name)):
-        in_node[i] = in_name[i].split(':')[0] if ":" in in_name[i] else in_name[i]
+        in_node[i] = in_name[i].split(":")[0] if ":" in in_name[i] else in_name[i]
 
     with tf.Session() as sess:
         if init_global_variables:
             sess.run(variables.global_variables_initializer())
         # convert to tflite model
-        converter = tf.lite.TFLiteConverter.from_session(
-            sess, input_tensors, output_tensors)
-        converter.experimental_new_converter=experimental_new_converter
+        converter = tf.lite.TFLiteConverter.from_session(sess, input_tensors, output_tensors)
+        converter.experimental_new_converter = experimental_new_converter
         if quantized:
             converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
             input_arrays = converter.get_input_arrays()
@@ -268,8 +288,10 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
                 try:
                     quant_scale = 255 / (input_range[i][1] - input_range[i][0])
                 except ZeroDivisionError:
-                    raise ZeroDivisionError('Min and max of the input range for tensor ' + i + ' can\'t be equal')
-                mean = - input_range[i][0] * quant_scale
+                    raise ZeroDivisionError(
+                        "Min and max of the input range for tensor " + i + " can't be equal"
+                    )
+                mean = -input_range[i][0] * quant_scale
                 input_stats[i] = (mean, quant_scale)
             converter.quantized_input_stats = input_stats
 
@@ -282,8 +304,15 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
                 print("Skip because %s is not enabled" % device)
                 continue
 
-            tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device,
-                                       num_output=len(out_names), out_names=out_names, mode=mode)
+            tvm_output = run_tvm_graph(
+                tflite_model_buffer,
+                in_data,
+                in_node,
+                target=device,
+                num_output=len(out_names),
+                out_names=out_names,
+                mode=mode,
+            )
 
             # WARNING: the results could well be random values clipped to 0 or 255 because of a badly tuned output
             # range for the specific operator. When adding a test, ensure that we aren't getting only clipped values.
@@ -294,7 +323,9 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
                     tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1, rtol=1e-5)
             else:
                 for i in range(len(tflite_output)):
-                    tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
+                    tvm.testing.assert_allclose(
+                        tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5
+                    )
 
 
 def with_fused_activation_function(input_tensor, fn_name):
@@ -313,77 +344,81 @@ def with_fused_activation_function(input_tensor, fn_name):
 
 def _test_split(in_shape, axis, num_splits, dtype):
     """internal split tester taking as parameters in_shape, number of tensors to split into
-       and dtype (data type)"""
+    and dtype (data type)"""
 
     np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=in_shape, dtype=dtype, name="in_data")
         out = array_ops.split(in_data, num_splits, axis=axis)
-        num_splits = len(num_splits) if isinstance(num_splits, list) \
-            else num_splits
-        out_names = ['out_' + str(n) + ':0' for n in range(num_splits)]
-        compare_tflite_with_tvm([np_data], ['in_data'],  [in_data], out,
-                                out_names=out_names)
+        num_splits = len(num_splits) if isinstance(num_splits, list) else num_splits
+        out_names = ["out_" + str(n) + ":0" for n in range(num_splits)]
+        compare_tflite_with_tvm([np_data], ["in_data"], [in_data], out, out_names=out_names)
+
 
 def test_forward_split():
-    '''test split layer'''
+    """test split layer"""
     # rank 1
-    _test_split((3,), 0, 1, 'float32')
-    _test_split((3,), 0, 3, 'float32')
-    _test_split((6,), 0, 3, 'float32')
+    _test_split((3,), 0, 1, "float32")
+    _test_split((3,), 0, 3, "float32")
+    _test_split((6,), 0, 3, "float32")
     # rank 2
-    _test_split((6, 2), 0, 3, 'float32')
-    _test_split((2, 6), 1, 6, 'float32')
+    _test_split((6, 2), 0, 3, "float32")
+    _test_split((2, 6), 1, 6, "float32")
     # rank 3
-    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
-        _test_split((6, 2, 4), 0, 2, 'int32')
+    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
+        _test_split((6, 2, 4), 0, 2, "int32")
 
-    _test_split((2, 6, 4), 1, 3, 'float32')
-    _test_split((2, 4, 6), 2, 1, 'float32')
+    _test_split((2, 6, 4), 1, 3, "float32")
+    _test_split((2, 4, 6), 2, 1, "float32")
     # rank 4
-    _test_split((6, 1, 3, 5), 0, 3, 'float32')
-    _test_split((1, 6, 3, 5), 1, 3, 'float32')
-    _test_split((1, 3, 6, 5), 2, 3, 'float32')
-    _test_split((1, 3, 5, 6), 3, 3, 'float32')
+    _test_split((6, 1, 3, 5), 0, 3, "float32")
+    _test_split((1, 6, 3, 5), 1, 3, "float32")
+    _test_split((1, 3, 6, 5), 2, 3, "float32")
+    _test_split((1, 3, 5, 6), 3, 3, "float32")
     # split along negative axis
-    _test_split((6, 1, 3, 5), -4, 3, 'float32')
-    _test_split((1, 6, 3, 5), -3, 3, 'float32')
-    _test_split((1, 3, 6, 5), -2, 3, 'float32')
-    _test_split((1, 3, 5, 6), -1, 3, 'float32')
+    _test_split((6, 1, 3, 5), -4, 3, "float32")
+    _test_split((1, 6, 3, 5), -3, 3, "float32")
+    _test_split((1, 3, 6, 5), -2, 3, "float32")
+    _test_split((1, 3, 5, 6), -1, 3, "float32")
     # size_splits split
-    _test_split((6,), 0, [1, 2, 3], 'float32')
-    _test_split((3, 6, 4), -2, [1, 4, 1], 'float32')
+    _test_split((6,), 0, [1, 2, 3], "float32")
+    _test_split((3, 6, 4), -2, [1, 4, 1], "float32")
+
 
 #######################################################################
 # slice
 # -----
 
+
 def _test_slice(data, begin, size):
     """ One iteration of SLICE """
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         out = array_ops.slice(in_data, begin, size)
-        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
+
 
 def test_forward_slice():
     """ SLICE """
-    _test_slice(np.arange(4, dtype=np.float32).reshape((4, )), begin=[0], size=[2])
+    _test_slice(np.arange(4, dtype=np.float32).reshape((4,)), begin=[0], size=[2])
     _test_slice(np.arange(18, dtype=np.int32).reshape((3, 2, 3)), begin=[1, 0, 0], size=[1, 1, 3])
     # tflite 1.13 outputs nonsense values if size[i] == -1
-    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
         _test_slice(np.arange(8, dtype=np.int32).reshape((2, 4)), begin=[0, 1], size=[-1, -1])
-        _test_slice(np.arange(5, dtype=np.int32).reshape((5, )), begin=[4], size=[-1])
+        _test_slice(np.arange(5, dtype=np.int32).reshape((5,)), begin=[4], size=[-1])
+
 
 #######################################################################
 # Topk
 # ----
 def _test_topk(in_shape, k=1):
     """ One iteration of TOPK """
-    data = np.random.uniform(size=in_shape).astype('float32')
+    data = np.random.uniform(size=in_shape).astype("float32")
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
-        out = nn_ops.top_k(in_data, k, name='TopK')
-        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out[0]])
+        out = nn_ops.top_k(in_data, k, name="TopK")
+        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out[0]])
+
 
 def test_forward_topk():
     """ TOPK """
@@ -392,19 +427,23 @@ def test_forward_topk():
     _test_topk((3, 5, 7), 3)
     _test_topk((3, 5, 7), 3)
 
+
 #######################################################################
 # Gather
 # ------
 
+
 def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False, wrap_idx=False):
     """ One iteration of Gather """
-    indices = np.asarray(indices).astype('int32')
+    indices = np.asarray(indices).astype("int32")
     data = np.random.uniform(1, 10, size=dshape)
     data = data.astype(np.uint8) if quantized else data.astype(dtype)
     with tf.Graph().as_default():
         if wrap_idx:
             in_name = "in_indices"
-            indices_expr = array_ops.placeholder(shape=indices.shape, dtype=indices.dtype, name=in_name)
+            indices_expr = array_ops.placeholder(
+                shape=indices.shape, dtype=indices.dtype, name=in_name
+            )
             in_tensor_name = [in_name + ":0"]
             in_indices = [indices_expr]
         else:
@@ -417,108 +456,140 @@ def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False, wrap_
         if axis:
             out = array_ops.gather(in_data, indices_expr, axis=axis)
         else:
-            out = array_ops.gather(in_data, indices_expr) #tflite conversion fails for None axis
-        input_range = {'in_data': (-100, 100)} if quantized else None
+            out = array_ops.gather(in_data, indices_expr)  # tflite conversion fails for None axis
+        input_range = {"in_data": (-100, 100)} if quantized else None
         try:
-            compare_tflite_with_tvm([data] + indices, ['in_data:0'] + in_tensor_name, [in_data] + in_indices, [out],
-                                      quantized=quantized, input_range=input_range)
+            compare_tflite_with_tvm(
+                [data] + indices,
+                ["in_data:0"] + in_tensor_name,
+                [in_data] + in_indices,
+                [out],
+                quantized=quantized,
+                input_range=input_range,
+            )
         except ValueError as e:
             if not oob:
                 raise e
         except Exception as e:
             raise e
 
+
 def test_forward_gather():
     """ GATHER """
     for quantized in [False, True]:
         for wrap_idx in [False, True]:
-            _test_gather((4,), [1], 0, 'float32', quantized, wrap_idx)
-            _test_gather((4,), [1], None, 'int32', quantized, wrap_idx)
-            _test_gather((1, 4), [0], 0, 'int32', quantized, wrap_idx)
-            _test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32', quantized, wrap_idx)
-            _test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32', quantized, wrap_idx)
-            _test_gather((2, 2), [[[1, 0], [0, 1]]], None, 'float32', quantized, wrap_idx)
-            _test_gather((3, 3, 3),  [[[1, 0]]], 0, 'int32', quantized, wrap_idx)
-            _test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32', quantized, wrap_idx)
-            _test_gather((4, 3, 5, 6),  [[2, 1, 0, 0]], 0, 'float32', quantized, wrap_idx)
-            _test_gather((3, 3, 3), [[[2, 1]]], -1, 'int32', quantized, wrap_idx)
+            _test_gather((4,), [1], 0, "float32", quantized, wrap_idx)
+            _test_gather((4,), [1], None, "int32", quantized, wrap_idx)
+            _test_gather((1, 4), [0], 0, "int32", quantized, wrap_idx)
+            _test_gather((4,), [[[1, 0], [0, 1]]], 0, "float32", quantized, wrap_idx)
+            _test_gather((2, 2), [[[1, 0], [0, 1]]], 1, "int32", quantized, wrap_idx)
+            _test_gather((2, 2), [[[1, 0], [0, 1]]], None, "float32", quantized, wrap_idx)
+            _test_gather((3, 3, 3), [[[1, 0]]], 0, "int32", quantized, wrap_idx)
+            _test_gather((3, 3, 3), [[[1, 0]]], 2, "int32", quantized, wrap_idx)
+            _test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, "float32", quantized, wrap_idx)
+            _test_gather((3, 3, 3), [[[2, 1]]], -1, "int32", quantized, wrap_idx)
         # Out of boundary error cannot be tested with wrapped index
-        _test_gather((4,), [16], 0, 'float32', quantized, oob=True)
-        _test_gather((1, 3, 3), [12], 0, 'int32', quantized, oob=True)
-        _test_gather((1, 3, 3), [20], 1, 'float32', quantized, oob=True)
-        _test_gather((1, 3, 3), [20, 20], 2, 'float32', quantized, oob=True)
+        _test_gather((4,), [16], 0, "float32", quantized, oob=True)
+        _test_gather((1, 3, 3), [12], 0, "int32", quantized, oob=True)
+        _test_gather((1, 3, 3), [20], 1, "float32", quantized, oob=True)
+        _test_gather((1, 3, 3), [20, 20], 2, "float32", quantized, oob=True)
+
 
 #######################################################################
 # Gather_ND
 # ---------
 
+
 def _test_gather_nd(data, indices):
     """ One iteration of GATHER_ND """
     with tf.Graph().as_default():
         in_data = tf.placeholder(shape=data.shape, dtype=data.dtype, name="data")
-        indices_data = tf.placeholder(shape=indices.shape, dtype=indices.dtype,
-                                        name="indices")
+        indices_data = tf.placeholder(shape=indices.shape, dtype=indices.dtype, name="indices")
         out = tf.gather_nd(in_data, indices_data)
 
-        compare_tflite_with_tvm([data, indices], ['data:0', 'indices:0'],
-                                  [in_data, indices_data], [out])
+        compare_tflite_with_tvm(
+            [data, indices], ["data:0", "indices:0"], [in_data, indices_data], [out]
+        )
+
 
 def test_forward_gather_nd():
     """ GATHER_ND """
     _test_gather_nd(
-        np.array([[[1.2, 2.0], [3.1, 4.1]], [[5.1, 6.1], [7.1, 8.1]]]).astype('float32'),
-        np.asarray([[0, 1], [1, 0]]).astype('int32')
+        np.array([[[1.2, 2.0], [3.1, 4.1]], [[5.1, 6.1], [7.1, 8.1]]]).astype("float32"),
+        np.asarray([[0, 1], [1, 0]]).astype("int32"),
     )
     _test_gather_nd(
-        np.reshape(np.arange(30), [5, 6]).astype('int32'),
-        np.asarray([[1, 2]]).astype('int32')
+        np.reshape(np.arange(30), [5, 6]).astype("int32"), np.asarray([[1, 2]]).astype("int32")
     )
     _test_gather_nd(
-        np.reshape(np.arange(12), [2, 3, 2]).astype('int32'),
-        np.asarray([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]).astype('int32')
+        np.reshape(np.arange(12), [2, 3, 2]).astype("int32"),
+        np.asarray([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]).astype("int32"),
     )
     _test_gather_nd(
-        np.reshape(np.arange(4), [4]).astype('float32'),
-        np.asarray([1]).astype('int32')
+        np.reshape(np.arange(4), [4]).astype("float32"), np.asarray([1]).astype("int32")
     )
     _test_gather_nd(
-        np.reshape(np.arange(4), [1, 4]).astype('float32'),
-        np.asarray([0]).astype('int32')
+        np.reshape(np.arange(4), [1, 4]).astype("float32"), np.asarray([0]).astype("int32")
     )
     _test_gather_nd(
-        np.reshape(np.arange(4), [1, 4]).astype('float32'),
-        np.asarray([0, 3]).astype('int32')
+        np.reshape(np.arange(4), [1, 4]).astype("float32"), np.asarray([0, 3]).astype("int32")
     )
 
+
 #######################################################################
 # StridedSlice
 # ------------
 
-def _test_stridedslice(ip_shape, begin, end, stride, dtype,
-                       begin_mask=0, end_mask=0, new_axis_mask=0,
-                       shrink_axis_mask=0, ellipsis_mask=0, quantized=False):
+
+def _test_stridedslice(
+    ip_shape,
+    begin,
+    end,
+    stride,
+    dtype,
+    begin_mask=0,
+    end_mask=0,
+    new_axis_mask=0,
+    shrink_axis_mask=0,
+    ellipsis_mask=0,
+    quantized=False,
+):
     """ One iteration of a Stridedslice """
     data = np.random.uniform(size=ip_shape).astype(dtype)
     data = data.astype(np.uint8) if quantized else data.astype(dtype)
     with tf.Graph().as_default():
         in_data = tf.placeholder(dtype, ip_shape, name="in_data")
-        out = array_ops.strided_slice(in_data, begin, end, stride,
-                                      begin_mask=begin_mask,
-                                      end_mask=end_mask,
-                                      new_axis_mask=new_axis_mask,
-                                      shrink_axis_mask=shrink_axis_mask,
-                                      ellipsis_mask=ellipsis_mask)
-        input_range = {'in_data': (-100, 100)} if quantized else None
-        compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], quantized=quantized,
-                                  input_range=input_range)
+        out = array_ops.strided_slice(
+            in_data,
+            begin,
+            end,
+            stride,
+            begin_mask=begin_mask,
+            end_mask=end_mask,
+            new_axis_mask=new_axis_mask,
+            shrink_axis_mask=shrink_axis_mask,
+            ellipsis_mask=ellipsis_mask,
+        )
+        input_range = {"in_data": (-100, 100)} if quantized else None
+        compare_tflite_with_tvm(
+            [data], ["in_data:0"], [in_data], [out], quantized=quantized, input_range=input_range
+        )
+
 
 def test_forward_stridedslice():
-    '''test StridedSlice'''
+    """test StridedSlice"""
     for quantized in [False, True]:
-        _test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1, quantized=quantized)
-        _test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32', quantized=quantized)
-        _test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=0, quantized=quantized)
-        _test_stridedslice((4, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=2, quantized=quantized)
+        _test_stridedslice((2), [1], [1], [1], "float32", shrink_axis_mask=1, quantized=quantized)
+        _test_stridedslice(
+            (3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], "float32", quantized=quantized
+        )
+        _test_stridedslice(
+            (3, 4), [1, 0], [4, 4], [1, 1], "float32", shrink_axis_mask=0, quantized=quantized
+        )
+        _test_stridedslice(
+            (4, 4), [1, 0], [4, 4], [1, 1], "float32", shrink_axis_mask=2, quantized=quantized
+        )
+
 
 #######################################################################
 # transpose
@@ -536,7 +607,7 @@ def _test_forward_transpose(ishape, axes=()):
         else:
             out = array_ops.transpose(in_data, axes)
 
-        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
 
 
 def test_forward_transpose():
@@ -548,16 +619,18 @@ def test_forward_transpose():
     _test_forward_transpose((2, 3, 4, 5), (3, 0, 1, 2))
     _test_forward_transpose((2, 3, 4, 5), ())
 
+
 #######################################################################
 # Cast
 # ----
 
+
 def _test_cast(data, cast_dtype):
     """ One iteration of CAST """
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         out = math_ops.cast(in_data, cast_dtype)
-        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
 
 
 def test_forward_cast():
@@ -566,15 +639,15 @@ def test_forward_cast():
     _test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.uint8)
     _test_cast(np.arange(6.0, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64)
 
+
 #######################################################################
 # Batch Mat Mul
 # ----
 def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False):
     with tf.Graph().as_default():
-        A = array_ops.placeholder(shape=A_shape, dtype=dtype, name='A')
-        B = array_ops.placeholder(shape=B_shape, dtype=dtype, name='B')
-        result = math_ops.matmul(A, B, adjoint_a=adjoint_a,
-                           adjoint_b=adjoint_b, name='batchmatmul')
+        A = array_ops.placeholder(shape=A_shape, dtype=dtype, name="A")
+        B = array_ops.placeholder(shape=B_shape, dtype=dtype, name="B")
+        result = math_ops.matmul(A, B, adjoint_a=adjoint_a, adjoint_b=adjoint_b, name="batchmatmul")
 
         A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
         B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
@@ -583,11 +656,12 @@ def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False
 
 def test_forward_batch_matmul():
     """ BATCH_MAT_MUL """
-    _test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32')
-    _test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32', True, True)
-    _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', True, False)
-    _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', False, True)
-    _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'float32')
+    _test_batch_matmul((3, 5, 4), (3, 4, 5), "float32")
+    _test_batch_matmul((3, 5, 4), (3, 4, 5), "float32", True, True)
+    _test_batch_matmul((3, 5, 4), (3, 5, 4), "float32", True, False)
+    _test_batch_matmul((3, 5, 4), (3, 5, 4), "float32", False, True)
+    _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), "float32")
+
 
 #######################################################################
 # Tile
@@ -602,19 +676,20 @@ def _test_forward_tile(in_shape, reps, dtype):
 
         out = array_ops.tile(in_data, reps)
 
-        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
 
 
 def test_forward_tile():
-    _test_forward_tile((2, ), (3, ), "int32")
+    _test_forward_tile((2,), (3,), "int32")
     _test_forward_tile((2, 2), (2, 3), "float32")
 
+
 ######################################################################
 # BatchToSpaceND
 # --------------
 
 
-def _test_batch_to_space_nd(input_shape, block_shape, crops, dtype='int32'):
+def _test_batch_to_space_nd(input_shape, block_shape, crops, dtype="int32"):
     data = np.random.uniform(0, 5, size=input_shape).astype(dtype)
 
     with tf.Graph().as_default():
@@ -622,35 +697,24 @@ def _test_batch_to_space_nd(input_shape, block_shape, crops, dtype='int32'):
 
         out = array_ops.batch_to_space_nd(in_data, block_shape, crops)
 
-        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
 
 
 def test_forward_batch_to_space_nd():
     # test cases: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d
-    _test_batch_to_space_nd(
-        input_shape=[4, 1, 1, 1],
-        block_shape=[2, 2],
-        crops=[[0, 0], [0, 0]]
-    )
+    _test_batch_to_space_nd(input_shape=[4, 1, 1, 1], block_shape=[2, 2], crops=[[0, 0], [0, 0]])
 
-    _test_batch_to_space_nd(
-        input_shape=[4, 1, 1, 3],
-        block_shape=[2, 2],
-        crops=[[0, 0], [0, 0]]
-    )
+    _test_batch_to_space_nd(input_shape=[4, 1, 1, 3], block_shape=[2, 2], crops=[[0, 0], [0, 0]])
+
+    _test_batch_to_space_nd(input_shape=[4, 2, 2, 1], block_shape=[2, 2], crops=[[0, 0], [0, 0]])
 
-    _test_batch_to_space_nd(
-        input_shape=[4, 2, 2, 1],
-        block_shape=[2, 2],
-        crops=[[0, 0], [0, 0]]
-    )
 
 ######################################################################
 # SpaceToBatchND
 # --------------
 
 
-def _test_space_to_batch_nd(input_shape, block_shape, paddings, dtype='int32'):
+def _test_space_to_batch_nd(input_shape, block_shape, paddings, dtype="int32"):
     data = np.random.uniform(0, 5, size=input_shape).astype(dtype)
 
     with tf.Graph().as_default():
@@ -658,34 +722,19 @@ def _test_space_to_batch_nd(input_shape, block_shape, paddings, dtype='int32'):
 
         out = array_ops.space_to_batch_nd(in_data, block_shape, paddings)
 
-        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
 
 
 def test_forward_space_to_batch_nd():
     # test cases: https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd
-    _test_space_to_batch_nd(
-        input_shape=[1, 2, 2, 1],
-        block_shape=[2, 2],
-        paddings=[[0, 0], [0, 0]]
-    )
+    _test_space_to_batch_nd(input_shape=[1, 2, 2, 1], block_shape=[2, 2], paddings=[[0, 0], [0, 0]])
 
-    _test_space_to_batch_nd(
-        input_shape=[1, 2, 2, 3],
-        block_shape=[2, 2],
-        paddings=[[0, 0], [0, 0]]
-    )
+    _test_space_to_batch_nd(input_shape=[1, 2, 2, 3], block_shape=[2, 2], paddings=[[0, 0], [0, 0]])
 
-    _test_space_to_batch_nd(
-        input_shape=[1, 4, 4, 1],
-        block_shape=[2, 2],
-        paddings=[[0, 0], [0, 0]]
-    )
+    _test_space_to_batch_nd(input_shape=[1, 4, 4, 1], block_shape=[2, 2], paddings=[[0, 0], [0, 0]])
+
+    _test_space_to_batch_nd(input_shape=[2, 2, 4, 1], block_shape=[2, 2], paddings=[[0, 0], [2, 0]])
 
-    _test_space_to_batch_nd(
-        input_shape=[2, 2, 4, 1],
-        block_shape=[2, 2],
-        paddings=[[0, 0], [2, 0]]
-    )
 
 #######################################################################
 # Pooling
@@ -693,14 +742,13 @@ def test_forward_space_to_batch_nd():
 def _test_pooling_iteration(input_shape, **kwargs):
     """ One iteration of pool operation with given shapes and attributes """
 
-    x = -np.arange(
-        np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
+    x = -np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
 
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=input_shape, dtype='float32')
+        in_data = array_ops.placeholder(shape=input_shape, dtype="float32")
         out = nn_ops.pool(in_data, **kwargs)
 
-        compare_tflite_with_tvm(x,'Placeholder:0', [in_data], [out])
+        compare_tflite_with_tvm(x, "Placeholder:0", [in_data], [out])
 
 
 def _test_pooling(input_shape, **kwargs):
@@ -710,59 +758,72 @@ def _test_pooling(input_shape, **kwargs):
 def test_forward_pooling():
     """ Pooling """
 
-    for pool_type in ['AVG', 'MAX']:
-        _test_pooling(input_shape=[2, 9, 10, 2],
-                      window_shape=[1, 1],
-                      padding='SAME',
-                      pooling_type=pool_type,
-                      dilation_rate=[1, 1],
-                      strides=[1, 1])
-
-        _test_pooling(input_shape=[2, 10, 9, 2],
-                      window_shape=[1, 1],
-                      padding='SAME',
-                      pooling_type=pool_type,
-                      dilation_rate=[1, 1],
-                      strides=[1, 1])
-
-        _test_pooling(input_shape=[2, 9, 10, 2],
-                      window_shape=[2, 1],
-                      padding='SAME',
-                      pooling_type=pool_type,
-                      dilation_rate=[1, 1],
-                      strides=[1, 1])
-
-        _test_pooling(input_shape=[2, 10, 9, 2],
-                      window_shape=[2, 3],
-                      padding='SAME',
-                      pooling_type=pool_type,
-                      dilation_rate=[1, 1],
-                      strides=[2, 1])
+    for pool_type in ["AVG", "MAX"]:
+        _test_pooling(
+            input_shape=[2, 9, 10, 2],
+            window_shape=[1, 1],
+            padding="SAME",
+            pooling_type=pool_type,
+            dilation_rate=[1, 1],
+            strides=[1, 1],
+        )
+
+        _test_pooling(
+            input_shape=[2, 10, 9, 2],
+            window_shape=[1, 1],
+            padding="SAME",
+            pooling_type=pool_type,
+            dilation_rate=[1, 1],
+            strides=[1, 1],
+        )
+
+        _test_pooling(
+            input_shape=[2, 9, 10, 2],
+            window_shape=[2, 1],
+            padding="SAME",
+            pooling_type=pool_type,
+            dilation_rate=[1, 1],
+            strides=[1, 1],
+        )
+
+        _test_pooling(
+            input_shape=[2, 10, 9, 2],
+            window_shape=[2, 3],
+            padding="SAME",
+            pooling_type=pool_type,
+            dilation_rate=[1, 1],
+            strides=[2, 1],
+        )
 
 
 def _test_l2_pool2d(input_shape, ksize, strides, padding, data_format, fused_func_name=None):
     x = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
 
     with tf.Graph().as_default():
-        in_data = tf.placeholder(
-            dtype=tf.float32, name="input", shape=input_shape)
-        out = tf.sqrt(tf.nn.avg_pool(
-            tf.square(in_data), ksize=ksize, strides=strides,
-            padding=padding, data_format=data_format))
+        in_data = tf.placeholder(dtype=tf.float32, name="input", shape=input_shape)
+        out = tf.sqrt(
+            tf.nn.avg_pool(
+                tf.square(in_data),
+                ksize=ksize,
+                strides=strides,
+                padding=padding,
+                data_format=data_format,
+            )
+        )
         out = with_fused_activation_function(out, fused_func_name)
 
-        compare_tflite_with_tvm(x, 'input', [in_data], [out])
+        compare_tflite_with_tvm(x, "input", [in_data], [out])
 
 
 def test_forward_l2_pool2d():
-    _test_l2_pool2d([1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], 'SAME', "NHWC", "RELU6")
-    _test_l2_pool2d([2, 9, 10, 2], [1, 1, 1, 1], [1, 1, 1, 1], 'SAME', "NHWC", "RELU6")
-    _test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 1, 1], 'SAME', "NHWC")
-    _test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 2, 1], 'SAME', "NHWC")
-    _test_l2_pool2d([1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], 'VALID', "NHWC", "RELU")
-    _test_l2_pool2d([2, 9, 10, 2], [1, 1, 1, 1], [1, 1, 1, 1], 'VALID', "NHWC")
-    _test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 1, 1], 'VALID', "NHWC")
-    _test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 2, 1], 'VALID', "NHWC", "RELU6")
+    _test_l2_pool2d([1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], "SAME", "NHWC", "RELU6")
+    _test_l2_pool2d([2, 9, 10, 2], [1, 1, 1, 1], [1, 1, 1, 1], "SAME", "NHWC", "RELU6")
+    _test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 1, 1], "SAME", "NHWC")
+    _test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 2, 1], "SAME", "NHWC")
+    _test_l2_pool2d([1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], "VALID", "NHWC", "RELU")
+    _test_l2_pool2d([2, 9, 10, 2], [1, 1, 1, 1], [1, 1, 1, 1], "VALID", "NHWC")
+    _test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 1, 1], "VALID", "NHWC")
+    _test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 2, 1], "VALID", "NHWC", "RELU6")
 
 
 #######################################################################
@@ -770,21 +831,24 @@ def test_forward_l2_pool2d():
 # -----------
 
 
-def _test_tflite2_quantized_convolution(input_shape, kernel_shape,
-        dilations, strides, padding, data_format):
+def _test_tflite2_quantized_convolution(
+    input_shape, kernel_shape, dilations, strides, padding, data_format
+):
     """ One iteration of TFLite2 quantized convolution with given shapes and attributes """
     data_format = "channels_last" if "NHWC" else "channels_first"
-    data = np.random.uniform(0, 1, input_shape).astype('float32')
-    kernel = np.random.uniform(0, 1, kernel_shape).astype('float32')
+    data = np.random.uniform(0, 1, input_shape).astype("float32")
+    kernel = np.random.uniform(0, 1, kernel_shape).astype("float32")
 
     data_in = tf.keras.layers.Input(shape=data.shape[1:])
-    conv = tf.keras.layers.Conv2D(filters=kernel_shape[3],
-                                  kernel_size=(kernel_shape[0], kernel_shape[1]),
-                                  strides=strides,
-                                  padding=padding,
-                                  data_format=data_format,
-                                  activation='relu',
-                                  use_bias=False)(data_in)
+    conv = tf.keras.layers.Conv2D(
+        filters=kernel_shape[3],
+        kernel_size=(kernel_shape[0], kernel_shape[1]),
+        strides=strides,
+        padding=padding,
+        data_format=data_format,
+        activation="relu",
+        use_bias=False,
+    )(data_in)
     keras_model = tf.keras.models.Model(data_in, conv)
     keras_model.layers[1].set_weights([kernel])
 
@@ -796,31 +860,34 @@ def _test_tflite2_quantized_convolution(input_shape, kernel_shape,
     tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)
 
     tflite_output = run_tflite_graph(tflite_model_quant, data)
-    tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0",""))
-    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
-                                rtol=1e-2, atol=1e-2)
+    tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0", ""))
+    tvm.testing.assert_allclose(
+        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-2, atol=1e-2
+    )
 
 
-def _test_tflite2_quantized_depthwise_convolution(input_shape, kernel_shape,
-        dilations, strides, padding, data_format, depth_multiplier):
+def _test_tflite2_quantized_depthwise_convolution(
+    input_shape, kernel_shape, dilations, strides, padding, data_format, depth_multiplier
+):
     """One iteration of TFLite2 quantized depthwise convolution with given shapes and attributes"""
 
     data_format = "channels_last" if "NHWC" else "channels_first"
-    data = np.random.uniform(0, 1, input_shape).astype('float32')
-    kernel = np.random.uniform(0, 1, kernel_shape).astype('float32')
+    data = np.random.uniform(0, 1, input_shape).astype("float32")
+    kernel = np.random.uniform(0, 1, kernel_shape).astype("float32")
 
     data_in = tf.keras.layers.Input(shape=data.shape[1:])
-    conv = tf.keras.layers.DepthwiseConv2D(kernel_size=(kernel_shape[0], kernel_shape[1]),
-                                           strides=strides,
-                                           padding=padding,
-                                           data_format=data_format,
-                                           activation='relu',
-                                           use_bias=False,
-                                           depth_multiplier=depth_multiplier)(data_in)
+    conv = tf.keras.layers.DepthwiseConv2D(
+        kernel_size=(kernel_shape[0], kernel_shape[1]),
+        strides=strides,
+        padding=padding,
+        data_format=data_format,
+        activation="relu",
+        use_bias=False,
+        depth_multiplier=depth_multiplier,
+    )(data_in)
     keras_model = tf.keras.models.Model(data_in, conv)
     keras_model.layers[1].set_weights([kernel])
 
-
     # To create quantized values with a dynamic range of activations, a representative dataset is needed
     def representative_data_gen():
         for i in range(1):
@@ -829,14 +896,22 @@ def _test_tflite2_quantized_depthwise_convolution(input_shape, kernel_shape,
     tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)
 
     tflite_output = run_tflite_graph(tflite_model_quant, data)
-    tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0",""))
-    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
-                                rtol=1e-2, atol=1e-2)
+    tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0", ""))
+    tvm.testing.assert_allclose(
+        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-2, atol=1e-2
+    )
 
 
-def _test_convolution(tensor_in_sizes, filter_in_sizes,
-                      dilations, strides, padding, data_format,
-                      is_depthwise=False, quantized=False):
+def _test_convolution(
+    tensor_in_sizes,
+    filter_in_sizes,
+    dilations,
+    strides,
+    padding,
+    data_format,
+    is_depthwise=False,
+    quantized=False,
+):
     """ One iteration of convolution with given shapes and attributes """
 
     total_size_1 = 1
@@ -848,91 +923,174 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
     # Initializes the input tensor with an array containing incrementing
     # numbers from 1.
     if quantized:
-        data_array = np.random.uniform(0, 255, tensor_in_sizes).astype('uint8')
-        filter_array = np.random.uniform(0, 255, filter_in_sizes).astype('uint8')
+        data_array = np.random.uniform(0, 255, tensor_in_sizes).astype("uint8")
+        filter_array = np.random.uniform(0, 255, filter_in_sizes).astype("uint8")
     else:
         data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
         filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
 
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32', name='in_data')
-        in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32', name='in_filter')
+        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype="float32", name="in_data")
+        in_filter = constant_op.constant(
+            filter_array, shape=filter_in_sizes, dtype="float32", name="in_filter"
+        )
         strides = [1] + strides + [1]
         dilations = [1] + dilations + [1]
 
         if is_depthwise:
-            out = nn_ops.depthwise_conv2d_native(in_data,
-                                                 in_filter,
-                                                 strides=strides,
-                                                 padding=padding,
-                                                 data_format=data_format)
+            out = nn_ops.depthwise_conv2d_native(
+                in_data, in_filter, strides=strides, padding=padding, data_format=data_format
+            )
         else:
-            out = nn_ops.conv2d(in_data,
-                                in_filter,
-                                strides=strides,
-                                padding=padding,
-                                data_format=data_format)
+            out = nn_ops.conv2d(
+                in_data, in_filter, strides=strides, padding=padding, data_format=data_format
+            )
 
         if quantized:
             if is_depthwise:
                 # Quantized the inputs and feed them to the convolution
-                inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-100, max=100, name='inq_data')
-                inq_filter = tf.quantization.fake_quant_with_min_max_args(in_filter, min=-100, max=100, name='inq_filter')
-                out = nn_ops.depthwise_conv2d_native(inq_data,
-                                                     inq_filter,
-                                                     strides=strides,
-                                                     padding=padding,
-                                                     data_format=data_format)
-                out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out")
+                inq_data = tf.quantization.fake_quant_with_min_max_args(
+                    in_data, min=-100, max=100, name="inq_data"
+                )
+                inq_filter = tf.quantization.fake_quant_with_min_max_args(
+                    in_filter, min=-100, max=100, name="inq_filter"
+                )
+                out = nn_ops.depthwise_conv2d_native(
+                    inq_data, inq_filter, strides=strides, padding=padding, data_format=data_format
+                )
+                out = tf.quantization.fake_quant_with_min_max_args(
+                    out, min=-200, max=200, name="out"
+                )
 
                 # Set the input quantization range
-                input_range = {'in_data': (-100, 100)} if quantized else None
+                input_range = {"in_data": (-100, 100)} if quantized else None
 
                 # Compare
-                compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out], quantized=quantized, input_range=input_range)
+                compare_tflite_with_tvm(
+                    data_array,
+                    "in_data",
+                    [in_data],
+                    [out],
+                    quantized=quantized,
+                    input_range=input_range,
+                )
             else:
                 # Quantized the inputs and feed them to the convolution
-                inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-100, max=100, name='inq_data')
-                inq_filter = tf.quantization.fake_quant_with_min_max_args(in_filter, min=-100, max=100, name='inq_filter')
-                out = nn_ops.conv2d(inq_data,
-                                    inq_filter,
-                                    strides=strides,
-                                    padding=padding,
-                                    data_format=data_format)
-                out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out")
+                inq_data = tf.quantization.fake_quant_with_min_max_args(
+                    in_data, min=-100, max=100, name="inq_data"
+                )
+                inq_filter = tf.quantization.fake_quant_with_min_max_args(
+                    in_filter, min=-100, max=100, name="inq_filter"
+                )
+                out = nn_ops.conv2d(
+                    inq_data, inq_filter, strides=strides, padding=padding, data_format=data_format
+                )
+                out = tf.quantization.fake_quant_with_min_max_args(
+                    out, min=-200, max=200, name="out"
+                )
 
                 # Set the input quantization range
-                input_range = {'in_data': (-100, 100)} if quantized else None
+                input_range = {"in_data": (-100, 100)} if quantized else None
 
                 # Compare
-                compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out], quantized=quantized, input_range=input_range)
+                compare_tflite_with_tvm(
+                    data_array,
+                    "in_data",
+                    [in_data],
+                    [out],
+                    quantized=quantized,
+                    input_range=input_range,
+                )
         else:
-            data_array = np.reshape(data_array, tensor_in_sizes).astype('float32')
-            compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out])
+            data_array = np.reshape(data_array, tensor_in_sizes).astype("float32")
+            compare_tflite_with_tvm(data_array, "in_data", [in_data], [out])
 
 
 def test_forward_convolution():
     for quantized in [False, True]:
-        _test_convolution([4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC', quantized=quantized)
-        _test_convolution([4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC', quantized=quantized)
-        _test_convolution([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC', quantized=quantized)
-        _test_convolution([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC', quantized=quantized)
+        _test_convolution(
+            [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], "SAME", "NHWC", quantized=quantized
+        )
+        _test_convolution(
+            [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], "VALID", "NHWC", quantized=quantized
+        )
+        _test_convolution(
+            [4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], "SAME", "NHWC", quantized=quantized
+        )
+        _test_convolution(
+            [4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], "VALID", "NHWC", quantized=quantized
+        )
 
         # depthwise convolution
-        _test_convolution([4, 8, 8, 176], [1, 1, 176, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True, quantized=quantized)
-        _test_convolution([4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True, quantized=quantized)
-        _test_convolution([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True, quantized=quantized)
-        _test_convolution([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True, quantized=quantized)
-        _test_convolution([4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC', True, quantized=quantized)
+        _test_convolution(
+            [4, 8, 8, 176],
+            [1, 1, 176, 1],
+            [1, 1],
+            [1, 1],
+            "SAME",
+            "NHWC",
+            True,
+            quantized=quantized,
+        )
+        _test_convolution(
+            [4, 17, 17, 19],
+            [3, 3, 19, 1],
+            [1, 1],
+            [2, 2],
+            "VALID",
+            "NHWC",
+            True,
+            quantized=quantized,
+        )
+        _test_convolution(
+            [4, 17, 17, 124],
+            [1, 1, 124, 1],
+            [1, 1],
+            [1, 1],
+            "SAME",
+            "NHWC",
+            True,
+            quantized=quantized,
+        )
+        _test_convolution(
+            [4, 17, 17, 12],
+            [3, 3, 12, 1],
+            [1, 1],
+            [2, 2],
+            "VALID",
+            "NHWC",
+            True,
+            quantized=quantized,
+        )
+        _test_convolution(
+            [4, 17, 17, 12],
+            [3, 3, 12, 2],
+            [1, 1],
+            [2, 2],
+            "VALID",
+            "NHWC",
+            True,
+            quantized=quantized,
+        )
         # depthwise convolution with single input channel
-        _test_convolution([1, 76, 64, 1], [9, 5, 1, 96], [1, 1], [1, 1], 'SAME', 'NHWC', True, quantized=quantized)
+        _test_convolution(
+            [1, 76, 64, 1], [9, 5, 1, 96], [1, 1], [1, 1], "SAME", "NHWC", True, quantized=quantized
+        )
 
     # TFLite2 quantized convolution testing
-    if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
-        _test_tflite2_quantized_convolution([1, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC')
-        _test_tflite2_quantized_convolution([1, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC')
-        _test_tflite2_quantized_convolution([1, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')
-        _test_tflite2_quantized_convolution([1, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC')
+    if package_version.parse(tf.VERSION) >= package_version.parse("2.1.0"):
+        _test_tflite2_quantized_convolution(
+            [1, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], "SAME", "NHWC"
+        )
+        _test_tflite2_quantized_convolution(
+            [1, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], "VALID", "NHWC"
+        )
+        _test_tflite2_quantized_convolution(
+            [1, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], "VALID", "NHWC"
+        )
+        _test_tflite2_quantized_convolution(
+            [1, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], "SAME", "NHWC"
+        )
 
         # Disable as tests are flaky - https://github.com/apache/incubator-tvm/issues/6064
         # depthwise convolution
@@ -944,11 +1102,11 @@ def test_forward_convolution():
         #                                               'SAME', 'NHWC', 8)
 
 
-
 #######################################################################
 # Transpose Convolution
 # ---------------------
 
+
 def _test_transpose_conv(tensor_in_sizes, filter_in_sizes, output_shape, strides, padding):
     """ One iteration of transpose convolution with given shapes and attributes """
 
@@ -964,76 +1122,78 @@ def _test_transpose_conv(tensor_in_sizes, filter_in_sizes, output_shape, strides
     filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
 
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
-        in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')
+        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype="float32")
+        in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype="float32")
         strides = [1] + strides + [1]
         # in_filter layout is HWOI
-        out = nn_ops.conv2d_transpose(in_data,
-                                      in_filter,
-                                      output_shape=output_shape,
-                                      strides=strides,
-                                      padding=padding)
-        data_array = np.reshape(data_array, tensor_in_sizes).astype('float32')
-        compare_tflite_with_tvm(data_array, 'Placeholder:0', [in_data], [out])
+        out = nn_ops.conv2d_transpose(
+            in_data, in_filter, output_shape=output_shape, strides=strides, padding=padding
+        )
+        data_array = np.reshape(data_array, tensor_in_sizes).astype("float32")
+        compare_tflite_with_tvm(data_array, "Placeholder:0", [in_data], [out])
 
 
 def test_forward_transpose_conv():
     # kernel 3x3, padding VALID
-    _test_transpose_conv([4, 32, 32, 16], [3, 3, 5, 16], [4, 34, 34, 5], [1, 1], 'VALID')
-    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 65, 5], [2, 2], 'VALID')
-    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 34, 5], [2, 1], 'VALID')
+    _test_transpose_conv([4, 32, 32, 16], [3, 3, 5, 16], [4, 34, 34, 5], [1, 1], "VALID")
+    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 65, 5], [2, 2], "VALID")
+    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 34, 5], [2, 1], "VALID")
 
     # kernel 3x3, padding SAME
-    _test_transpose_conv([4, 32, 32, 16], [3, 3, 5, 16], [4, 32, 32, 5], [1, 1], 'SAME')
-    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 64, 64, 5], [2, 2], 'SAME')
-    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 64, 32, 5], [2, 1], 'SAME')
+    _test_transpose_conv([4, 32, 32, 16], [3, 3, 5, 16], [4, 32, 32, 5], [1, 1], "SAME")
+    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 64, 64, 5], [2, 2], "SAME")
+    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 64, 32, 5], [2, 1], "SAME")
 
     # kernel 2x2, padding VALID
-    _test_transpose_conv([4, 32, 32, 16], [2, 2, 5, 16], [4, 33, 33, 5], [1, 1], 'VALID')
-    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 64, 5], [2, 2], 'VALID')
-    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 33, 5], [2, 1], 'VALID')
+    _test_transpose_conv([4, 32, 32, 16], [2, 2, 5, 16], [4, 33, 33, 5], [1, 1], "VALID")
+    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 64, 5], [2, 2], "VALID")
+    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 33, 5], [2, 1], "VALID")
 
     # kernel 2x2, padding SAME
-    _test_transpose_conv([4, 32, 32, 16], [2, 2, 5, 16], [4, 32, 32, 5], [1, 1], 'SAME')
-    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 64, 5], [2, 2], 'SAME')
-    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 32, 5], [2, 1], 'SAME')
+    _test_transpose_conv([4, 32, 32, 16], [2, 2, 5, 16], [4, 32, 32, 5], [1, 1], "SAME")
+    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 64, 5], [2, 2], "SAME")
+    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 32, 5], [2, 1], "SAME")
 
     # kernel 1x1, padding VALID
-    _test_transpose_conv([4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], 'VALID')
-    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], 'VALID')
-    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], 'VALID')
+    _test_transpose_conv([4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], "VALID")
+    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], "VALID")
+    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], "VALID")
 
     # kernel 1x1, padding SAME
-    _test_transpose_conv([4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], 'SAME')
-    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], 'SAME')
-    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], 'SAME')
+    _test_transpose_conv([4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], "SAME")
+    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], "SAME")
+    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], "SAME")
 
 
 #######################################################################
 # Reshape
 # -------
 
+
 def _test_reshape(data, out_shape, wrap_shape):
     """ One iteration of reshape operation with given data and out shape """
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
 
-        out_shape = out_shape if not wrap_shape\
-            else np.array(out_shape, dtype=np.int32)
+        out_shape = out_shape if not wrap_shape else np.array(out_shape, dtype=np.int32)
 
-        in_shape = out_shape if not wrap_shape\
-            else array_ops.placeholder(shape=out_shape.shape,\
-                                        dtype=out_shape.dtype,\
-                                        name="Newshape")
+        in_shape = (
+            out_shape
+            if not wrap_shape
+            else array_ops.placeholder(
+                shape=out_shape.shape, dtype=out_shape.dtype, name="Newshape"
+            )
+        )
 
         out = array_ops.reshape(in_data, in_shape)
 
         compare_tflite_with_tvm(
-            [data, out_shape]               if wrap_shape else [data],\
-            ['Placeholder:0', 'Newshape:0'] if wrap_shape else ['Placeholder:0'],\
-            [in_data, in_shape]             if wrap_shape else [in_data],\
+            [data, out_shape] if wrap_shape else [data],
+            ["Placeholder:0", "Newshape:0"] if wrap_shape else ["Placeholder:0"],
+            [in_data, in_shape] if wrap_shape else [in_data],
             [out],
-            mode='vm')
+            mode="vm",
+        )
 
 
 def test_forward_reshape():
@@ -1048,6 +1208,7 @@ def test_forward_reshape():
 # Resize
 # ------
 
+
 def _test_resize(tf_resize_op, data, align_corners):
     """ One iteration of Resize """
 
@@ -1055,10 +1216,10 @@ def _test_resize(tf_resize_op, data, align_corners):
 
     # Test with tensor and constant
     with tf.Graph().as_default():
-        images_tensor = array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')
+        images_tensor = array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name="in")
         size = ops.convert_to_tensor(data[1], dtype=data[1].dtype)
         out_tensor = tf_resize_op(images=images_tensor, size=size, align_corners=align_corners)
-        compare_tflite_with_tvm([data[0]], ['in:0'], [images_tensor], [out_tensor])
+        compare_tflite_with_tvm([data[0]], ["in:0"], [images_tensor], [out_tensor])
 
 
 def test_all_resize():
@@ -1071,21 +1232,24 @@ def test_all_resize():
     # According to topi resize.h
     # Align corners not supported for nearest neighbour
     from tflite.BuiltinOperator import BuiltinOperator
-    if 'RESIZE_NEAREST_NEIGHBOR' in dir(BuiltinOperator()):
+
+    if "RESIZE_NEAREST_NEIGHBOR" in dir(BuiltinOperator()):
         _test_resize(tf.image.resize_nearest_neighbor, data, align_corners=False)
 
+
 #######################################################################
 # Range
 # -----
 def _test_range(start, limit, delta):
     # tflite 1.13 convert method does not accept empty shapes
-    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
         tf.reset_default_graph()
         with tf.Graph().as_default():
-            start_scalar, limit_scalar, delta_scalar = \
-                tf.placeholder(dtype=start.dtype, shape=(), name="start"), \
-                tf.placeholder(dtype=limit.dtype, shape=(), name="limit"), \
-                tf.placeholder(dtype=delta.dtype, shape=(), name="delta")
+            start_scalar, limit_scalar, delta_scalar = (
+                tf.placeholder(dtype=start.dtype, shape=(), name="start"),
+                tf.placeholder(dtype=limit.dtype, shape=(), name="limit"),
+                tf.placeholder(dtype=delta.dtype, shape=(), name="delta"),
+            )
 
             out = tf.range(start_scalar, limit_scalar, delta_scalar, name="range")
 
@@ -1095,43 +1259,44 @@ def _test_range(start, limit, delta):
                 [start_scalar, limit_scalar, delta_scalar],
                 [out],
                 mode="vm",
-                quantized=False
-        )
+                quantized=False,
+            )
+
 
 def _test_range_default():
     # tflite 1.13 convert method does not accept empty shapes
-    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
         tf.reset_default_graph()
         with tf.Graph().as_default():
             inputs = [
                 tf.placeholder(dtype=tf.int32, shape=(), name="p1"),
-                tf.placeholder(dtype=tf.int32, shape=(), name="p2")
+                tf.placeholder(dtype=tf.int32, shape=(), name="p2"),
             ]
             outputs = [
-                tf.range(start = inputs[0], limit = inputs[1]), # use default delta
-                tf.range(start = inputs[1]) # use start as limit with 0 as the first item in the range
+                tf.range(start=inputs[0], limit=inputs[1]),  # use default delta
+                tf.range(
+                    start=inputs[1]
+                ),  # use start as limit with 0 as the first item in the range
             ]
 
             compare_tflite_with_tvm(
-                [np.int32(1), np.int32(18)],
-                ["p1", "p2"],
-                inputs,
-                outputs,
-                mode="vm"
-        )
+                [np.int32(1), np.int32(18)], ["p1", "p2"], inputs, outputs, mode="vm"
+            )
+
 
 def test_forward_range():
-   _test_range(np.int32(1), np.int32(18), np.int32(3))
-   _test_range(np.int32(1), np.int32(18), np.float32(3.1)) # increment is of type float
-   _test_range(np.float32(1.0), np.int32(18), np.int32(3.1)) # start is of type float
-   _test_range_default()
+    _test_range(np.int32(1), np.int32(18), np.int32(3))
+    _test_range(np.int32(1), np.int32(18), np.float32(3.1))  # increment is of type float
+    _test_range(np.float32(1.0), np.int32(18), np.int32(3.1))  # start is of type float
+    _test_range_default()
+
 
 #######################################################################
 # Shape
 # -----
 def test_forward_shape():
     # tflite 1.13 convert method does not accept empty shapes
-    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
         tf.reset_default_graph()
         with tf.Graph().as_default():
             data = np.array([1, 18, 3], dtype=np.int32)
@@ -1145,13 +1310,15 @@ def test_forward_shape():
                 ["start", "limit", "delta"],
                 [start, limit, delta],
                 [out],
-                mode="vm"
+                mode="vm",
             )
 
+
 #######################################################################
 # Concatenation
 # -------------
 
+
 def _test_concatenation(data, axis):
     """ One iteration of concatenation """
 
@@ -1160,7 +1327,8 @@ def _test_concatenation(data, axis):
     with tf.Graph().as_default():
         in_data = [
             array_ops.placeholder(shape=tensor.shape, dtype=tensor.dtype, name="in_{}".format(idx))
-            for idx, tensor in enumerate(data)]
+            for idx, tensor in enumerate(data)
+        ]
         out = array_ops.concat(in_data, axis=axis)
         name = ["in_{}:0".format(idx) for idx in range(len(data))]
 
@@ -1169,140 +1337,182 @@ def _test_concatenation(data, axis):
 
 def test_forward_concatenation():
 
-    _test_concatenation(
-        [np.arange(6).reshape((1, 2, 1, 3)),
-        np.arange(6).reshape((1, 2, 1, 3))], 1)
+    _test_concatenation([np.arange(6).reshape((1, 2, 1, 3)), np.arange(6).reshape((1, 2, 1, 3))], 1)
 
-    _test_concatenation(
-        [np.arange(6).reshape((3, 2)),
-         np.arange(6).reshape((3, 2))], 1)
+    _test_concatenation([np.arange(6).reshape((3, 2)), np.arange(6).reshape((3, 2))], 1)
 
     _test_concatenation(
-        [np.arange(6).reshape((2, 1, 1, 3)),
-         np.arange(6).reshape((2, 1, 1, 3)),
-         np.arange(6).reshape((2, 1, 1, 3))], 1)
+        [
+            np.arange(6).reshape((2, 1, 1, 3)),
+            np.arange(6).reshape((2, 1, 1, 3)),
+            np.arange(6).reshape((2, 1, 1, 3)),
+        ],
+        1,
+    )
+
 
 #######################################################################
 # Unary elemwise
 # --------------
 
+
 def _test_unary_elemwise(math_op, data):
     """ One iteration of unary elemwise """
 
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name='in')
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in")
         out = math_op(in_data)
-        compare_tflite_with_tvm(data, ['in:0'], [in_data], [out])
+        compare_tflite_with_tvm(data, ["in:0"], [in_data], [out])
+
 
 #######################################################################
 # Abs
 # ---
 
+
 def _test_abs(data):
     """ One iteration of abs """
     return _test_unary_elemwise(math_ops.abs, data)
+
+
 #######################################################################
 # Ceil
 # ----
 
+
 def _test_ceil(data):
     """ One iteration of ceil """
     return _test_unary_elemwise(math_ops.ceil, data)
+
+
 #######################################################################
 # Floor
 # -----
 
+
 def _test_floor(data):
     """ One iteration of floor """
     return _test_unary_elemwise(math_ops.floor, data)
 
+
 #######################################################################
 # Round
 # -----
 
+
 def _test_round(data):
     """ One iteration of round """
     return _test_unary_elemwise(math_ops.round, data)
 
+
 #######################################################################
 # Exp
 # ---
 
+
 def _test_exp(data):
     """ One iteration of exp """
     return _test_unary_elemwise(math_ops.exp, data)
+
+
 #######################################################################
 # Log
 # ---
 
+
 def _test_log(data):
     """ One iteration of log """
     return _test_unary_elemwise(math_ops.log, data)
+
+
 #######################################################################
 # Sin
 # ---
 
+
 def _test_sin(data):
     """ One iteration of sin """
     return _test_unary_elemwise(math_ops.sin, data)
+
+
 #######################################################################
 # Cos
 # ---
 
+
 def _test_cos(data):
     """ One iteration of cos """
     return _test_unary_elemwise(math_ops.cos, data)
+
+
 #######################################################################
 # Tan
 # ---
 
+
 def _test_tan(data):
     """ One iteration of tan """
     return _test_unary_elemwise(math_ops.tan, data)
+
+
 #######################################################################
 # Sqrt
 # ----
 
+
 def _test_sqrt(data):
     """ One iteration of sqrt """
     return _test_unary_elemwise(math_ops.sqrt, data)
+
+
 #######################################################################
 # Rsqrt
 # -----
 
+
 def _test_rsqrt(data):
     """ One iteration of rsqrt """
     return _test_unary_elemwise(math_ops.rsqrt, data)
+
+
 #######################################################################
 # Neg
 # ---
 
+
 def _test_neg(data):
     """ One iteration of neg """
     return _test_unary_elemwise(math_ops.neg, data)
+
+
 #######################################################################
 # Square
 # ------
 
+
 def _test_square(data):
     """ One iteration of square """
     return _test_unary_elemwise(math_ops.square, data)
 
+
 #######################################################################
 # Elu
 # ---
 
+
 def _test_elu(data):
     """ One iteration of elu """
     return _test_unary_elemwise(nn_ops.elu, data)
 
+
 def _test_forward_unary_elemwise(test_op):
     # functions that need positive input
-    if test_op.__name__ in {'_test_log', '_test_sqrt', '_test_rsqrt'}:
+    if test_op.__name__ in {"_test_log", "_test_sqrt", "_test_rsqrt"}:
         test_op(np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)))
     else:
         test_op(np.random.uniform(-10, 10, (3, 2)).astype(np.float32))
 
+
 def test_all_unary_elemwise():
     _test_forward_unary_elemwise(_test_abs)
     _test_forward_unary_elemwise(_test_floor)
@@ -1314,7 +1524,7 @@ def test_all_unary_elemwise():
     _test_forward_unary_elemwise(_test_neg)
     _test_forward_unary_elemwise(_test_square)
     # ceil and cos come with TFLite 1.14.0.post1 fbs schema
-    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
         _test_forward_unary_elemwise(_test_ceil)
         _test_forward_unary_elemwise(_test_cos)
         _test_forward_unary_elemwise(_test_round)
@@ -1322,20 +1532,29 @@ def test_all_unary_elemwise():
         # in CI or anywhere else. The failure mode is that we see a backtrace
         # from the converter that we need to provide a custom Tan operator
         # implementation.
-        #_test_forward_unary_elemwise(_test_tan)
+        # _test_forward_unary_elemwise(_test_tan)
         _test_forward_unary_elemwise(_test_elu)
 
+
 #######################################################################
 # Element-wise
 # ------------
 
-def _test_elemwise(math_op, data, fused_activation_function=None, quantized=False, qnn_op=None, same_qnn_params=False):
+
+def _test_elemwise(
+    math_op,
+    data,
+    fused_activation_function=None,
+    quantized=False,
+    qnn_op=None,
+    same_qnn_params=False,
+):
     """ One iteration of elemwise """
 
     assert len(data) == 2
 
-    def __test_elemwise( in_data ):
-        assert 2 == len( in_data )
+    def __test_elemwise(in_data):
+        assert 2 == len(in_data)
         if quantized:
             # set the fp32 output range with respect to the operation
             out_min, out_max = _test_elemwise_qnn_out_range(qnn_op)
@@ -1348,192 +1567,294 @@ def _test_elemwise(math_op, data, fused_activation_function=None, quantized=Fals
                 inq1_min, inq1_max = (out_min, out_max)
 
             # fake_quant will keep the tensors in float32 until the conversion in the session
-            inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=out_min, max=out_max, name="inq_0")\
-                        if None != in_data[0]\
-                        else tf.quantization.fake_quant_with_min_max_args(data[0], min=out_min, max=out_max, name="const_tensor0"),
-                        tf.quantization.fake_quant_with_min_max_args(in_data[1], min=out_min, max=out_max, name="inq_1")\
-                        if None != in_data[1]\
-                        else tf.quantization.fake_quant_with_min_max_args(data[1], min=out_min, max=out_max, name="const_tensor1")]
+            inq_data = [
+                tf.quantization.fake_quant_with_min_max_args(
+                    in_data[0], min=out_min, max=out_max, name="inq_0"
+                )
+                if None != in_data[0]
+                else tf.quantization.fake_quant_with_min_max_args(
+                    data[0], min=out_min, max=out_max, name="const_tensor0"
+                ),
+                tf.quantization.fake_quant_with_min_max_args(
+                    in_data[1], min=out_min, max=out_max, name="inq_1"
+                )
+                if None != in_data[1]
+                else tf.quantization.fake_quant_with_min_max_args(
+                    data[1], min=out_min, max=out_max, name="const_tensor1"
+                ),
+            ]
 
-            input_range = {x[1][0]:x[1][1] for x in zip(in_data, (('inq_0', (inq0_min, inq0_max)),\
-                                                                  ('inq_1', (inq1_min, inq1_max)))) if None != x[0]}
+            input_range = {
+                x[1][0]: x[1][1]
+                for x in zip(
+                    in_data, (("inq_0", (inq0_min, inq0_max)), ("inq_1", (inq1_min, inq1_max)))
+                )
+                if None != x[0]
+            }
 
             out = math_op(inq_data[0], inq_data[1])
             out = with_fused_activation_function(out, fused_activation_function)
-            out = tf.quantization.fake_quant_with_min_max_args(out, min=out_min, max=out_max, name="out")
+            out = tf.quantization.fake_quant_with_min_max_args(
+                out, min=out_min, max=out_max, name="out"
+            )
 
             # Note same_qnn_params uses experimental_new_converter as toco failed
-            compare_tflite_with_tvm([x[1] for x in zip(in_data, data) if None != x[0]],
+            compare_tflite_with_tvm(
+                [x[1] for x in zip(in_data, data) if None != x[0]],
                 [x + ":0" for x in input_range.keys()],
                 [x[1] for x in zip(in_data, inq_data) if None != x[0]],
                 [out],
                 quantized=True,
                 input_range=input_range,
-                experimental_new_converter=same_qnn_params)
+                experimental_new_converter=same_qnn_params,
+            )
         else:
-            out = math_op(in_data[0] if None != in_data[0] else ops.convert_to_tensor(data[0], dtype=data[0].dtype),
-                          in_data[1] if None != in_data[1] else ops.convert_to_tensor(data[1], dtype=data[1].dtype))
+            out = math_op(
+                in_data[0]
+                if None != in_data[0]
+                else ops.convert_to_tensor(data[0], dtype=data[0].dtype),
+                in_data[1]
+                if None != in_data[1]
+                else ops.convert_to_tensor(data[1], dtype=data[1].dtype),
+            )
             out = with_fused_activation_function(out, fused_activation_function)
-            compare_tflite_with_tvm([x[1] for x in zip( in_data, data ) if None != x[0]],
-                    [x[1] for x in zip( in_data, ('in_0:0', 'in_1:0') ) if None != x[0]],
-                    [x for x in in_data if None != x],
-                    [out])
+            compare_tflite_with_tvm(
+                [x[1] for x in zip(in_data, data) if None != x[0]],
+                [x[1] for x in zip(in_data, ("in_0:0", "in_1:0")) if None != x[0]],
+                [x for x in in_data if None != x],
+                [out],
+            )
 
     # Test with two tensors
     with tf.Graph().as_default():
-        __test_elemwise( in_data = [array_ops.placeholder(shape=data[0].shape, dtype='float32', name='in_0'),
-                                    array_ops.placeholder(shape=data[1].shape, dtype='float32', name='in_1')])
+        __test_elemwise(
+            in_data=[
+                array_ops.placeholder(shape=data[0].shape, dtype="float32", name="in_0"),
+                array_ops.placeholder(shape=data[1].shape, dtype="float32", name="in_1"),
+            ]
+        )
     # Test with tensor and constant
     with tf.Graph().as_default():
-        __test_elemwise( in_data = [array_ops.placeholder(shape=data[0].shape, dtype='float32', name='in_0'),
-                                    None])
+        __test_elemwise(
+            in_data=[array_ops.placeholder(shape=data[0].shape, dtype="float32", name="in_0"), None]
+        )
     # Test with constant and tensor
     with tf.Graph().as_default():
-        __test_elemwise( in_data = [None,
-                                    array_ops.placeholder(shape=data[1].shape, dtype='float32', name='in_1')])
+        __test_elemwise(
+            in_data=[None, array_ops.placeholder(shape=data[1].shape, dtype="float32", name="in_1")]
+        )
+
 
 #######################################################################
 # Add
 # ---
 
+
 def _test_add(data, fused_activation_function=None, quantized=False, qnn_op=None):
     """ One iteration of add """
     return _test_elemwise(math_ops.add, data, fused_activation_function, quantized, qnn_op)
 
+
 #######################################################################
 # Subtract
 # --------
 
+
 def _test_sub(data, fused_activation_function=None, quantized=False, qnn_op=None):
     """ One iteration of subtract """
     return _test_elemwise(math_ops.subtract, data, fused_activation_function, quantized, qnn_op)
+
+
 #######################################################################
 # Mul
 # ---
 
+
 def _test_mul(data, fused_activation_function=None, quantized=False, qnn_op=None):
     """ One iteration of mul """
     return _test_elemwise(math_ops.multiply, data, fused_activation_function, quantized, qnn_op)
 
+
 #######################################################################
 # Divide
 # ------
 
+
 def _test_div(data, fused_activation_function=None):
     """ One iteration of divide """
     return _test_elemwise(math_ops.divide, data, fused_activation_function)
+
+
 #######################################################################
 # Power
 # -----
 
+
 def _test_pow(data):
     """ One iteration of power """
     return _test_elemwise(math_ops.pow, data)
+
+
 #######################################################################
 # Maximum
 # -------
 
+
 def _test_maximum(data, fused_activation_function=None, quantized=False, qnn_op=None):
     """ One iteration of maximum """
-    return _test_elemwise(math_ops.maximum, data, fused_activation_function, quantized, qnn_op, same_qnn_params=True)
+    return _test_elemwise(
+        math_ops.maximum, data, fused_activation_function, quantized, qnn_op, same_qnn_params=True
+    )
+
+
 #######################################################################
 # Minimum
 # -------
 
+
 def _test_minimum(data, fused_activation_function=None, quantized=False, qnn_op=None):
     """ One iteration of minimum """
-    return _test_elemwise(math_ops.minimum, data, fused_activation_function, quantized, qnn_op, same_qnn_params=True)
+    return _test_elemwise(
+        math_ops.minimum, data, fused_activation_function, quantized, qnn_op, same_qnn_params=True
+    )
+
+
 #######################################################################
 # Greater
 # -------
 
+
 def _test_greater(data):
     """ One iteration of greater """
     return _test_elemwise(math_ops.greater, data)
+
+
 #######################################################################
 # Greater_equal
 # -------------
 
+
 def _test_greater_equal(data):
     """ One iteration of greater_equal """
     return _test_elemwise(math_ops.greater_equal, data)
+
+
 #######################################################################
 # Less
 # ----
 
+
 def _test_less(data):
     """ One iteration of less """
     return _test_elemwise(math_ops.less, data)
+
+
 #######################################################################
 # Less_equal
 # ----------
 
+
 def _test_less_equal(data):
     """ One iteration of less_equal """
     return _test_elemwise(math_ops.less_equal, data)
+
+
 #######################################################################
 # Equal
 # -----
 
+
 def _test_equal(data):
     """ One iteration of equal """
     return _test_elemwise(math_ops.equal, data)
+
+
 #######################################################################
 # Not_equal
 # ---------
 
+
 def _test_not_equal(data):
     """ One iteration of not_equal"""
     return _test_elemwise(math_ops.not_equal, data)
+
+
 #######################################################################
 # Squared_difference
 # ------------------
 
+
 def _test_squared_difference(data):
     """ One iteration of squared difference """
     return _test_elemwise(math_ops.squared_difference, data)
 
+
 #######################################################################
 # Floor_divide
 # ------------
 
+
 def _test_floor_divide(data):
     """ One iteration of floor_div"""
     return _test_elemwise(math_ops.floordiv, data)
 
+
 #######################################################################
 # Floor_mod
 # ---------
 
+
 def _test_floor_mod(data):
     """ One iteration of floor_mod"""
     return _test_elemwise(math_ops.floormod, data)
 
+
 def _test_forward_elemwise(testop):
     """ Elewise"""
-    testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
-              np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3))])
-    testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)),
-               np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3))])
-    testop([np.arange(3.0, dtype=np.float32).reshape((1, 3)),
-               np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3))])
+    testop(
+        [
+            np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
+            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3)),
+        ]
+    )
+    testop(
+        [
+            np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)),
+            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)),
+        ]
+    )
+    testop(
+        [
+            np.arange(3.0, dtype=np.float32).reshape((1, 3)),
+            np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3)),
+        ]
+    )
+
 
 def _test_forward_elemwise_quantized(testop):
-    testop([np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
-            np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8)], quantized=True, qnn_op=testop)
+    testop(
+        [
+            np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
+            np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
+        ],
+        quantized=True,
+        qnn_op=testop,
+    )
+
 
 def _test_elemwise_qnn_out_range(qnn_op):
     # set the fake_quant output range with respect to the input tensors' float32 range
     qnn_out_range = {
         _test_add: (-150, 150),
         _test_sub: (-150, 150),
-        _test_mul: (-5e+3, 5e+3),
+        _test_mul: (-5e3, 5e3),
         _test_maximum: (-112, 111),
-        _test_minimum: (-128, 127)
+        _test_minimum: (-128, 127),
     }
 
     return qnn_out_range[qnn_op]
 
+
 def test_all_elemwise():
     _test_forward_elemwise(_test_add)
     _test_forward_elemwise_quantized(_test_add)
@@ -1564,7 +1885,7 @@ def test_all_elemwise():
     _test_forward_elemwise(_test_less_equal)
     _test_forward_elemwise(_test_equal)
     _test_forward_elemwise(_test_not_equal)
-    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
         _test_forward_elemwise(_test_floor_divide)
         _test_forward_elemwise(_test_floor_mod)
 
@@ -1581,12 +1902,16 @@ def _test_forward_add_n(inputs):
         for each in inputs:
             temp.append(tf.placeholder(shape=each.shape, dtype=each.dtype))
         output = tf.add_n(temp)
-        compare_tflite_with_tvm([each for each in inputs], [each.name for each in temp],
-                                [each for each in temp], [output])
+        compare_tflite_with_tvm(
+            [each for each in inputs],
+            [each.name for each in temp],
+            [each for each in temp],
+            [output],
+        )
 
 
 def test_forward_add_n():
-    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
         x = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32)
         y = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32)
         z = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32)
@@ -1609,50 +1934,62 @@ def test_forward_add_n():
 # Logical operators
 # -----------------
 
+
 def _test_logical_binary(logical_bin_op, data):
 
     with tf.Graph().as_default():
-        in_data = [array_ops.placeholder(shape=data[0].shape, dtype='bool', name='in_0'),
-                   array_ops.placeholder(shape=data[1].shape, dtype='bool', name='in_1')]
+        in_data = [
+            array_ops.placeholder(shape=data[0].shape, dtype="bool", name="in_0"),
+            array_ops.placeholder(shape=data[1].shape, dtype="bool", name="in_1"),
+        ]
         if logical_bin_op == math_ops.logical_not:
-            out = math_ops.logical_or(in_data[0], in_data[1], name='out1')
-            out = logical_bin_op(out, name='out')
+            out = math_ops.logical_or(in_data[0], in_data[1], name="out1")
+            out = logical_bin_op(out, name="out")
         else:
-            out = logical_bin_op(in_data[0], in_data[1], name='out')
+            out = logical_bin_op(in_data[0], in_data[1], name="out")
+
+        compare_tflite_with_tvm(data, ["in_0:0", "in_1:0"], in_data, [out])
 
-        compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out])
 
 def _test_forward_logical_and(data):
     """ One iteration of logical and """
     return _test_logical_binary(math_ops.logical_and, data)
 
+
 def _test_forward_logical_or(data):
     """ One iteration of logical or """
     return _test_logical_binary(math_ops.logical_or, data)
 
+
 def _test_forward_logical_not(data):
     """ One iteration of logical not """
     return _test_logical_binary(math_ops.logical_not, data)
 
+
 def test_all_logical():
-    data = [np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool'),
-            np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool')]
+    data = [
+        np.random.choice(a=[False, True], size=(2, 3, 4)).astype("bool"),
+        np.random.choice(a=[False, True], size=(2, 3, 4)).astype("bool"),
+    ]
     # boolean dtype is not supported by TFLite versions older than 1.15.0
-    if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("1.15.0"):
         _test_forward_logical_and(data)
         _test_forward_logical_or(data)
         _test_forward_logical_not(data)
 
+
 #######################################################################
 # Zeros like
 # ----------
 
+
 def _test_zeros_like(data):
     """ One iteration of ZEROS LIKE """
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         out = gen_array_ops.zeros_like(in_data)
-        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
+
 
 def test_forward_zeros_like():
     """ ZEROS LIKE """
@@ -1663,21 +2000,22 @@ def test_forward_zeros_like():
 # Fill
 # ----
 
+
 def _test_fill(dims, value_data, value_dtype):
     """ Use the fill op to create a tensor of value_data with constant dims."""
 
     value_data = np.array(value_data, dtype=value_dtype)
     # TF 1.13 TFLite convert method does not accept empty shapes
-    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
         with tf.Graph().as_default():
             value = array_ops.placeholder(dtype=value_dtype, name="value", shape=[])
-            out = tf.fill(dims,  value)
+            out = tf.fill(dims, value)
             compare_tflite_with_tvm([value_data], ["value"], [value], [out])
 
     with tf.Graph().as_default():
         input1 = array_ops.placeholder(dtype=value_dtype, name="input1", shape=dims)
         # Fill op gets converted to static tensor during conversion
-        out = tf.fill(dims,  value_data)
+        out = tf.fill(dims, value_data)
         out1 = tf.add(out, input1)
         input1_data = np.random.uniform(0, 5, size=dims).astype(value_dtype)
         compare_tflite_with_tvm([input1_data], ["input1"], [input1], [out1])
@@ -1688,13 +2026,14 @@ def test_forward_fill():
 
     _test_fill((1, 2, 2, 4), 5, "int32")
     _test_fill((1, 2, 2, 4), 5, "float32")
-    _test_fill((5, ), 5, "int32")
+    _test_fill((5,), 5, "int32")
 
 
 #######################################################################
 # Reduce
 # ------
 
+
 def _test_reduce(math_op, data, keep_dims=None):
     """ One iteration of reduce """
 
@@ -1702,9 +2041,10 @@ def _test_reduce(math_op, data, keep_dims=None):
 
     # Test with tensor and constant
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')
+        in_data = array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name="in")
         out = math_op(in_data, data[1], keep_dims)
-        compare_tflite_with_tvm([data[0]], ['in:0'], [in_data], [out])
+        compare_tflite_with_tvm([data[0]], ["in:0"], [in_data], [out])
+
 
 def _test_reduce_quantize(math_op, data, keep_dims=None):
     """ One iteration of reduce """
@@ -1713,34 +2053,45 @@ def _test_reduce_quantize(math_op, data, keep_dims=None):
 
     # Test with tensor and constant
     with tf.Graph().as_default():
-        in_data = [array_ops.placeholder(shape=data[0].shape, dtype="float32", name='in')]
-        inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-100, max=100, name="inq_0")]
-        input_range = {'inq_0': (-100, 100)}
+        in_data = [array_ops.placeholder(shape=data[0].shape, dtype="float32", name="in")]
+        inq_data = [
+            tf.quantization.fake_quant_with_min_max_args(
+                in_data[0], min=-100, max=100, name="inq_0"
+            )
+        ]
+        input_range = {"inq_0": (-100, 100)}
         out = math_op(inq_data, data[1], keep_dims)
         out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out")
-        compare_tflite_with_tvm([data[0]], ['inq_0:0'], [inq_data[0]], [out], quantized=True, input_range=input_range)
+        compare_tflite_with_tvm(
+            [data[0]], ["inq_0:0"], [inq_data[0]], [out], quantized=True, input_range=input_range
+        )
 
 
 #######################################################################
 # Reduce_min
 # ----------
 
+
 def _test_reduce_min(data, keep_dims=None):
     """ One iteration of reduce_min """
     return _test_reduce(math_ops.reduce_min, data, keep_dims)
 
+
 #######################################################################
 # Reduce_max
 # ----------
 
+
 def _test_reduce_max(data, keep_dims=None):
     """ One iteration of reduce_max """
     return _test_reduce(math_ops.reduce_max, data, keep_dims)
 
+
 #######################################################################
 # Reduce_mean
 # -----------
 
+
 def _test_reduce_mean(data, keep_dims=None, quantized=False):
     """ One iteration of reduce_mean """
     if quantized:
@@ -1748,37 +2099,45 @@ def _test_reduce_mean(data, keep_dims=None, quantized=False):
     else:
         return _test_reduce(math_ops.reduce_mean, data, keep_dims)
 
+
 #######################################################################
 # Reduce_prod
 # -----------
 
+
 def _test_reduce_prod(data, keep_dims=None):
     """ One iteration of reduce_prod """
     return _test_reduce(math_ops.reduce_prod, data, keep_dims)
 
+
 #######################################################################
 # Reduce_sum
 # -----------
 
+
 def _test_reduce_sum(data, keep_dims=None):
     """ One iteration of reduce_sum """
     return _test_reduce(math_ops.reduce_sum, data, keep_dims)
 
+
 #######################################################################
 # Reduce_any
 # ----------
 
+
 def _test_reduce_any(data, keep_dims=None):
     """ One iteration of reduce_any """
     return _test_reduce(math_ops.reduce_any, data, keep_dims)
 
+
 def _test_forward_reduce(testop, dtype="float32"):
     """ Reduce """
-    if dtype == 'bool':
-        data0 = [np.random.choice(a=[False, True], size=(16, 16, 16, 16)).astype(dtype),
-                 None]
-        data1 = [np.random.choice(a=[False, True], size=(16, 16, 16, 16)).astype(dtype),
-                 np.array([1, 2], dtype=np.int32)]
+    if dtype == "bool":
+        data0 = [np.random.choice(a=[False, True], size=(16, 16, 16, 16)).astype(dtype), None]
+        data1 = [
+            np.random.choice(a=[False, True], size=(16, 16, 16, 16)).astype(dtype),
+            np.array([1, 2], dtype=np.int32),
+        ]
     else:
         data0 = [np.random.rand(16, 16, 16, 16).astype(dtype), None]
         data1 = [np.random.rand(16, 16, 16, 16).astype(dtype), np.array([1, 2], dtype=np.int32)]
@@ -1789,12 +2148,17 @@ def _test_forward_reduce(testop, dtype="float32"):
     testop(data1, keep_dims=False)
     testop(data1, keep_dims=True)
 
+
 def _test_forward_reduce_quantized(testop):
-    data0 = [np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8), np.array([1, 2], dtype=np.int32)]
+    data0 = [
+        np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
+        np.array([1, 2], dtype=np.int32),
+    ]
     testop(data0, quantized=True)
     testop(data0, keep_dims=False, quantized=True)
     testop(data0, keep_dims=True, quantized=True)
 
+
 def test_all_reduce():
     _test_forward_reduce(_test_reduce_min)
     _test_forward_reduce(_test_reduce_max)
@@ -1802,30 +2166,37 @@ def test_all_reduce():
     _test_forward_reduce_quantized(_test_reduce_mean)
     _test_forward_reduce(_test_reduce_prod)
     _test_forward_reduce(_test_reduce_sum)
-    if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("1.15.0"):
         _test_forward_reduce(_test_reduce_any, dtype="bool")
 
+
 #######################################################################
 # Arg_min_max
 # -----------
 
+
 def _test_arg_min_max(math_op, data, axis, quantized=False):
     """ One iteration of arg_min_max"""
 
     with tf.Graph().as_default():
-        t_name="in"
-        in_data = array_ops.placeholder(shape=data.shape, dtype=np.float32, name=t_name )
-        input_range=None
+        t_name = "in"
+        in_data = array_ops.placeholder(shape=data.shape, dtype=np.float32, name=t_name)
+        input_range = None
         qmin, qmax = -100, 102
         if quantized:
-            inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=qmin, max=qmax, name= 'q' + t_name )
-            input_range = { inq_data.name.split(':')[0]: (qmin, qmax)}
+            inq_data = tf.quantization.fake_quant_with_min_max_args(
+                in_data, min=qmin, max=qmax, name="q" + t_name
+            )
+            input_range = {inq_data.name.split(":")[0]: (qmin, qmax)}
             out = math_op(input=inq_data, axis=axis)
-            compare_tflite_with_tvm([data], [inq_data.name], [inq_data], [out], quantized=True, input_range=input_range)
+            compare_tflite_with_tvm(
+                [data], [inq_data.name], [inq_data], [out], quantized=True, input_range=input_range
+            )
         else:
             out = math_op(input=in_data, axis=axis)
             compare_tflite_with_tvm([data], [in_data.name], [in_data], [out])
 
+
 def test_forward_arg_min_max():
     # test quantized
     for data in [np.array(np.random.uniform(-100, 100, (3, 4)), dtype=np.uint8)]:
@@ -1843,27 +2214,26 @@ def test_forward_arg_min_max():
 # Select, Where
 # -------------
 
+
 def test_forward_select():
     with tf.Graph().as_default():
         with tf.Session() as sess:
-            input1 = tf.placeholder(
-                tf.int32, shape=[1, 4, 4, 3], name='input1')
-            input2 = tf.placeholder(
-                tf.int32, shape=[1, 4, 4, 3], name='input2')
+            input1 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name="input1")
+            input2 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name="input2")
             mask = input1 > input2
             out = tf.where(mask, input1 + 1, input2 * 2)
-            in_data1 = np.random.uniform(
-                0, 10, size=(1, 4, 4, 3)).astype("int32")
-            in_data2 = np.random.uniform(
-                0, 10, size=(1, 4, 4, 3)).astype("int32")
+            in_data1 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("int32")
+            in_data2 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("int32")
 
-            compare_tflite_with_tvm([in_data1, in_data2], [
-                                'input1:0', 'input2:0'], [input1, input2], [out])
+            compare_tflite_with_tvm(
+                [in_data1, in_data2], ["input1:0", "input2:0"], [input1, input2], [out]
+            )
 
 
 # Squeeze
 # -------
 
+
 def _test_squeeze(data, squeeze_dims=None):
     """ One iteration of squeeze """
 
@@ -1878,7 +2248,7 @@ def _test_squeeze(data, squeeze_dims=None):
         else:
             out = array_ops.squeeze(in_data)
 
-        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
 
 
 def test_forward_squeeze():
@@ -1891,6 +2261,7 @@ def test_forward_squeeze():
 # Quantize/DeQuantize
 # -------------------
 
+
 def _test_quantize_dequantize(data):
     """ One iteration of quantize and dequantize """
 
@@ -1913,8 +2284,9 @@ def _test_quantize_dequantize(data):
 
     tflite_output = run_tflite_graph(tflite_model_quant, data)
     tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
-    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
-                                rtol=1e-5, atol=1e-2)
+    tvm.testing.assert_allclose(
+        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
+    )
 
 
 def _test_quantize_dequantize_const(data):
@@ -1939,14 +2311,15 @@ def _test_quantize_dequantize_const(data):
 
     tflite_output = run_tflite_graph(tflite_model_quant, data)
     tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
-    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
-                                rtol=1e-5, atol=1e-2)
+    tvm.testing.assert_allclose(
+        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
+    )
 
 
 def test_forward_quantize_dequantize():
     """ Quantize Dequantize """
     data = np.random.uniform(0, 1, (1, 4, 4, 3)).astype("float32")
-    if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("2.1.0"):
         _test_quantize_dequantize(data)
         _test_quantize_dequantize_const(data)
 
@@ -1955,6 +2328,7 @@ def test_forward_quantize_dequantize():
 # Pad
 # ---
 
+
 def _test_pad(data, mode="CONSTANT", quantized=False):
     """ One iteration of PAD """
 
@@ -1962,198 +2336,302 @@ def _test_pad(data, mode="CONSTANT", quantized=False):
 
     # Test with tensor and constant
     with tf.Graph().as_default():
-        in_data = [array_ops.placeholder(shape=data[0].shape, dtype='float32', name='in')]
+        in_data = [array_ops.placeholder(shape=data[0].shape, dtype="float32", name="in")]
 
         if quantized:
             # fake_quant will keep the tensors in float32 until the conversion in the session
-            input_range = {'inq_0': (-100, 100)}
-            inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0],
-                                                                     min=-100,
-                                                                     max=100,
-                                                                     name="inq_0")]
-            out = array_ops.pad(inq_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode)
-            compare_tflite_with_tvm([data[0]], ['inq_0:0'], inq_data, [out], quantized=True,
-                                    input_range=input_range)
+            input_range = {"inq_0": (-100, 100)}
+            inq_data = [
+                tf.quantization.fake_quant_with_min_max_args(
+                    in_data[0], min=-100, max=100, name="inq_0"
+                )
+            ]
+            out = array_ops.pad(
+                inq_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode
+            )
+            compare_tflite_with_tvm(
+                [data[0]], ["inq_0:0"], inq_data, [out], quantized=True, input_range=input_range
+            )
         else:
-            out = array_ops.pad(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode)
-            compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])
+            out = array_ops.pad(
+                in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode
+            )
+            compare_tflite_with_tvm([data[0]], ["in:0"], in_data, [out])
 
 
 def test_forward_pad():
     """ Pad """
-    _test_pad([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3)),
-               np.array([[1, 1], [2, 2], [1, 1], [2, 2]], dtype=np.int32)])
-    _test_pad([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)),
-               np.array([[2, 2], [1, 1], [1, 1]], dtype=np.int32)])
-    _test_pad([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
-               np.array([[1, 1], [2, 2]], dtype=np.int32)])
-    _test_pad([np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3)),
-               np.array([[1, 1], [2, 2]], dtype=np.int32)])
-    _test_pad([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
-               np.array([[1, 1], [2, 2]], dtype=np.int32)], mode="REFLECT")
-    _test_pad([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
-               np.array([[1, 1], [2, 2]], dtype=np.int32)], mode="SYMMETRIC")
-    _test_pad([np.arange(0, 256, dtype=np.uint8).reshape((1, 256)),
-               np.array([[1, 1], [2, 2]], dtype=np.int32)], quantized=True)
+    _test_pad(
+        [
+            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3)),
+            np.array([[1, 1], [2, 2], [1, 1], [2, 2]], dtype=np.int32),
+        ]
+    )
+    _test_pad(
+        [
+            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)),
+            np.array([[2, 2], [1, 1], [1, 1]], dtype=np.int32),
+        ]
+    )
+    _test_pad(
+        [
+            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
+            np.array([[1, 1], [2, 2]], dtype=np.int32),
+        ]
+    )
+    _test_pad(
+        [
+            np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3)),
+            np.array([[1, 1], [2, 2]], dtype=np.int32),
+        ]
+    )
+    _test_pad(
+        [
+            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
+            np.array([[1, 1], [2, 2]], dtype=np.int32),
+        ],
+        mode="REFLECT",
+    )
+    _test_pad(
+        [
+            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
+            np.array([[1, 1], [2, 2]], dtype=np.int32),
+        ],
+        mode="SYMMETRIC",
+    )
+    _test_pad(
+        [
+            np.arange(0, 256, dtype=np.uint8).reshape((1, 256)),
+            np.array([[1, 1], [2, 2]], dtype=np.int32),
+        ],
+        quantized=True,
+    )
 
 
 #######################################################################
 # PADV2
 # -----
 
+
 def _test_padv2(data, mode="CONSTANT", quantized=False):
     """ One iteration of PADV2 """
 
-    assert (len(data) == 2 or len(data) == 3)
+    assert len(data) == 2 or len(data) == 3
 
     with_constant_values = len(data) == 3
 
     # Test with tensor and constant
     with tf.Graph().as_default():
-        in_data = [array_ops.placeholder(shape=data[0].shape, dtype='float32', name='in')]
+        in_data = [array_ops.placeholder(shape=data[0].shape, dtype="float32", name="in")]
 
         if quantized:
             # fake_quant will keep the tensors in float32 until the conversion in the session
-            input_range = {'inq_0': (-100, 100)}
-            inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0],
-                                                                     min=-100,
-                                                                     max=100,
-                                                                     name="inq_0")]
+            input_range = {"inq_0": (-100, 100)}
+            inq_data = [
+                tf.quantization.fake_quant_with_min_max_args(
+                    in_data[0], min=-100, max=100, name="inq_0"
+                )
+            ]
             if with_constant_values:
-                in_constant_values = constant_op.constant(data[2], shape=data[2].shape, dtype='float32', name='in_constant_values')
-                inq_constant_values = tf.quantization.fake_quant_with_min_max_args(in_constant_values,
-                                                                                     min=-100,
-                                                                                   max=100,
-                                                                                   name='inq_constant_values')
-                out = array_ops.pad_v2(inq_data[0],
-                                          ops.convert_to_tensor(data[1], dtype=data[1].dtype),
-                                       constant_values=inq_constant_values,
-                                       mode=mode)
-                out = tf.quantization.fake_quant_with_min_max_args(out, min=-100, max=100, name="out")
+                in_constant_values = constant_op.constant(
+                    data[2], shape=data[2].shape, dtype="float32", name="in_constant_values"
+                )
+                inq_constant_values = tf.quantization.fake_quant_with_min_max_args(
+                    in_constant_values, min=-100, max=100, name="inq_constant_values"
+                )
+                out = array_ops.pad_v2(
+                    inq_data[0],
+                    ops.convert_to_tensor(data[1], dtype=data[1].dtype),
+                    constant_values=inq_constant_values,
+                    mode=mode,
+                )
+                out = tf.quantization.fake_quant_with_min_max_args(
+                    out, min=-100, max=100, name="out"
+                )
             else:
-                out = array_ops.pad_v2(inq_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode)
-            compare_tflite_with_tvm([data[0]], ['inq_0:0'], inq_data, [out], quantized=True, input_range=input_range)
+                out = array_ops.pad_v2(
+                    inq_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode
+                )
+            compare_tflite_with_tvm(
+                [data[0]], ["inq_0:0"], inq_data, [out], quantized=True, input_range=input_range
+            )
         else:
             if with_constant_values:
-                out = array_ops.pad_v2(in_data[0],
-                                       ops.convert_to_tensor(data[1], dtype=data[1].dtype),
-                                       constant_values= ops.convert_to_tensor(data[2], dtype=data[2].dtype),
-                                       mode=mode)
+                out = array_ops.pad_v2(
+                    in_data[0],
+                    ops.convert_to_tensor(data[1], dtype=data[1].dtype),
+                    constant_values=ops.convert_to_tensor(data[2], dtype=data[2].dtype),
+                    mode=mode,
+                )
             else:
-                out = array_ops.pad_v2(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode)
-            compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])
+                out = array_ops.pad_v2(
+                    in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode
+                )
+            compare_tflite_with_tvm([data[0]], ["in:0"], in_data, [out])
 
 
 def test_forward_padv2():
     """ PADV2 """
     # Tests without Constant_values
-    _test_padv2([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3)),
-               np.array([[1, 1], [2, 2], [1, 1], [2, 2]], dtype=np.int32)])
-    _test_padv2([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)),
-               np.array([[2, 2], [1, 1], [1, 1]], dtype=np.int32)])
-    _test_padv2([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
-               np.array([[1, 1], [2, 2]], dtype=np.int32)])
-    _test_padv2([np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3)),
-               np.array([[1, 1], [2, 2]], dtype=np.int32)])
-    _test_padv2([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
-               np.array([[1, 1], [2, 2]], dtype=np.int32)], mode="REFLECT")
-    _test_padv2([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
-               np.array([[1, 1], [2, 2]], dtype=np.int32)], mode="SYMMETRIC")
-    _test_padv2([np.arange(0, 256, dtype=np.uint8).reshape((1, 256)),
-               np.array([[1, 1], [2, 2]], dtype=np.int32)], quantized=True)
+    _test_padv2(
+        [
+            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3)),
+            np.array([[1, 1], [2, 2], [1, 1], [2, 2]], dtype=np.int32),
+        ]
+    )
+    _test_padv2(
+        [
+            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)),
+            np.array([[2, 2], [1, 1], [1, 1]], dtype=np.int32),
+        ]
+    )
+    _test_padv2(
+        [
+            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
+            np.array([[1, 1], [2, 2]], dtype=np.int32),
+        ]
+    )
+    _test_padv2(
+        [
+            np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3)),
+            np.array([[1, 1], [2, 2]], dtype=np.int32),
+        ]
+    )
+    _test_padv2(
+        [
+            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
+            np.array([[1, 1], [2, 2]], dtype=np.int32),
+        ],
+        mode="REFLECT",
+    )
+    _test_padv2(
+        [
+            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
+            np.array([[1, 1], [2, 2]], dtype=np.int32),
+        ],
+        mode="SYMMETRIC",
+    )
+    _test_padv2(
+        [
+            np.arange(0, 256, dtype=np.uint8).reshape((1, 256)),
+            np.array([[1, 1], [2, 2]], dtype=np.int32),
+        ],
+        quantized=True,
+    )
 
     # Tests with Constant_values
-    _test_padv2([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3)),
-               np.array([[1, 1], [2, 2], [1, 1], [2, 2]], dtype=np.int32),
-            np.array([2], dtype=np.float32)])
-    _test_padv2([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)),
-               np.array([[2, 2], [1, 1], [1, 1]], dtype=np.int32),
-            np.array([1], dtype=np.float32)])
-    _test_padv2([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
-               np.array([[1, 1], [2, 2]], dtype=np.int32),
-            np.array([-1], dtype=np.float32)])
-    _test_padv2([np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3)),
-               np.array([[1, 1], [2, 2]], dtype=np.int32),
-            np.array([2], dtype=np.float32)])
-    _test_padv2([np.arange(0, 256, dtype=np.uint8).reshape((1, 256)),
-               np.array([[1, 1], [2, 2]], dtype=np.int32),
-               np.array([2], dtype=np.uint8)], quantized=True)
+    _test_padv2(
+        [
+            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3)),
+            np.array([[1, 1], [2, 2], [1, 1], [2, 2]], dtype=np.int32),
+            np.array([2], dtype=np.float32),
+        ]
+    )
+    _test_padv2(
+        [
+            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)),
+            np.array([[2, 2], [1, 1], [1, 1]], dtype=np.int32),
+            np.array([1], dtype=np.float32),
+        ]
+    )
+    _test_padv2(
+        [
+            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
+            np.array([[1, 1], [2, 2]], dtype=np.int32),
+            np.array([-1], dtype=np.float32),
+        ]
+    )
+    _test_padv2(
+        [
+            np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3)),
+            np.array([[1, 1], [2, 2]], dtype=np.int32),
+            np.array([2], dtype=np.float32),
+        ]
+    )
+    _test_padv2(
+        [
+            np.arange(0, 256, dtype=np.uint8).reshape((1, 256)),
+            np.array([[1, 1], [2, 2]], dtype=np.int32),
+            np.array([2], dtype=np.uint8),
+        ],
+        quantized=True,
+    )
 
     # Constant Values input can be scalar
-    _test_padv2([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3)),
-               np.array([[1, 1], [2, 2], [1, 1], [2, 2]], dtype=np.int32),
-               np.float32(2)])
-    _test_padv2([np.arange(0, 256, dtype=np.uint8).reshape((1, 256)),
-                np.array([[1, 1], [2, 2]], dtype=np.int32),
-                np.uint8(10)], quantized=True)
+    _test_padv2(
+        [
+            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3)),
+            np.array([[1, 1], [2, 2], [1, 1], [2, 2]], dtype=np.int32),
+            np.float32(2),
+        ]
+    )
+    _test_padv2(
+        [
+            np.arange(0, 256, dtype=np.uint8).reshape((1, 256)),
+            np.array([[1, 1], [2, 2]], dtype=np.int32),
+            np.uint8(10),
+        ],
+        quantized=True,
+    )
 
 
 #######################################################################
 # EXPAND_DIMS
 # -----------
 
+
 def _test_expand_dims(input_shape, input_type, axis, quantized=False):
     """ One iteration of EXPAND_DIMS """
     with tf.Graph().as_default():
-        axis= ops.convert_to_tensor(axis, dtype=axis.dtype)
+        axis = ops.convert_to_tensor(axis, dtype=axis.dtype)
 
         if quantized:
             # ignoring input_type as quantized requires uint8
-            input = np.random.uniform(0, 256, input_shape).astype('uint8')
-            in_input = tf.placeholder(dtype='float32', shape=input.shape, name="input")
+            input = np.random.uniform(0, 256, input_shape).astype("uint8")
+            in_input = tf.placeholder(dtype="float32", shape=input.shape, name="input")
 
-            input_range = {'q_input': (-100, 100)}
+            input_range = {"q_input": (-100, 100)}
             inq_input = tf.quantization.fake_quant_with_min_max_args(
-                in_input,
-                min=-100,
-                max=100,
-                name="q_input")
+                in_input, min=-100, max=100, name="q_input"
+            )
 
             out = array_ops.expand_dims(inq_input, axis=axis)
-            out = tf.quantization.fake_quant_with_min_max_args(
-                out,
-                min=-100,
-                max=100,
-                name="out")
+            out = tf.quantization.fake_quant_with_min_max_args(out, min=-100, max=100, name="out")
 
             compare_tflite_with_tvm(
-                [input],
-                ["q_input"],
-                [inq_input],
-                [out],
-                quantized=True,
-                input_range=input_range)
+                [input], ["q_input"], [inq_input], [out], quantized=True, input_range=input_range
+            )
         else:
             input = np.random.uniform(-100, 100, input_shape).astype(input_type)
             in_input = tf.placeholder(dtype=input.dtype, shape=input.shape, name="input")
 
             out = array_ops.expand_dims(in_input, axis=axis)
 
-            compare_tflite_with_tvm(
-                [input],
-                ["input"],
-                [in_input],
-                [out])
+            compare_tflite_with_tvm([input], ["input"], [in_input], [out])
+
 
 def test_forward_expand_dims():
     """ EXPAND_DIMS """
     for quantized in [False, True]:
-        _test_expand_dims((6, 2, 7, 5), 'float32', np.int32(0), quantized=quantized)
-        _test_expand_dims((1, 2, 3), 'int32', np.int32(-2), quantized=quantized)
-        _test_expand_dims((2, 4, 5), 'float32', np.array([1], dtype=np.int32), quantized=quantized)
+        _test_expand_dims((6, 2, 7, 5), "float32", np.int32(0), quantized=quantized)
+        _test_expand_dims((1, 2, 3), "int32", np.int32(-2), quantized=quantized)
+        _test_expand_dims((2, 4, 5), "float32", np.array([1], dtype=np.int32), quantized=quantized)
 
 
 #######################################################################
 # ONE_HOT
 # -------
 
-def _test_one_hot(indices, depth, on_value, off_value, axis = None):
+
+def _test_one_hot(indices, depth, on_value, off_value, axis=None):
     """ One iteration of One_Hot """
     with tf.Graph().as_default():
         in_indices = tf.placeholder(dtype=indices.dtype, shape=indices.shape, name="indices")
         in_depth = ops.convert_to_tensor(depth, dtype=depth.dtype)
         in_on_value = tf.placeholder(dtype=on_value.dtype, shape=on_value.shape, name="on_value")
-        in_off_value = tf.placeholder(dtype=off_value.dtype, shape=off_value.shape, name="off_value")
+        in_off_value = tf.placeholder(
+            dtype=off_value.dtype, shape=off_value.shape, name="off_value"
+        )
         if axis is not None:
             out = array_ops.one_hot(in_indices, in_depth, in_on_value, in_off_value, axis=axis)
         else:
@@ -2162,20 +2640,25 @@ def _test_one_hot(indices, depth, on_value, off_value, axis = None):
             [indices, on_value, off_value],
             ["indices", "on_value", "off_value"],
             [in_indices, in_on_value, in_off_value],
-            [out])
+            [out],
+        )
+
 
 def test_forward_one_hot():
     """ One_Hot """
     _test_one_hot(np.int32(2), np.int32(8), np.int32(1), np.int32(0))
     _test_one_hot(np.int32(4), np.int32(8), np.float32(1), np.float32(0))
     _test_one_hot(np.array([1, 2, 3], dtype=np.int32), np.int32(8), np.int32(3), np.int32(-1))
-    _test_one_hot(np.array([1, 2, 3], dtype=np.int32), np.int32(8), np.int32(3), np.int32(-1), axis=0)
+    _test_one_hot(
+        np.array([1, 2, 3], dtype=np.int32), np.int32(8), np.int32(3), np.int32(-1), axis=0
+    )
 
 
 #######################################################################
 # Pack
 # ----
 
+
 def _test_pack(data, axis):
     """ One iteration of pack """
 
@@ -2184,7 +2667,8 @@ def _test_pack(data, axis):
     with tf.Graph().as_default():
         in_data = [
             array_ops.placeholder(shape=tensor.shape, dtype=tensor.dtype, name="in_{}".format(idx))
-            for idx, tensor in enumerate(data)]
+            for idx, tensor in enumerate(data)
+        ]
         out = array_ops.pack(in_data, axis=axis)
         name = ["in_{}:0".format(idx) for idx in range(len(data))]
 
@@ -2193,58 +2677,68 @@ def _test_pack(data, axis):
 
 def test_forward_pack():
     """ Pack """
-    _test_pack(
-        [np.arange(6).reshape((1, 2, 1, 3)),
-        np.arange(6).reshape((1, 2, 1, 3))], 1)
+    _test_pack([np.arange(6).reshape((1, 2, 1, 3)), np.arange(6).reshape((1, 2, 1, 3))], 1)
 
-    _test_pack(
-        [np.arange(6).reshape((3, 2)),
-         np.arange(6).reshape((3, 2))], 1)
+    _test_pack([np.arange(6).reshape((3, 2)), np.arange(6).reshape((3, 2))], 1)
 
     _test_pack(
-        [np.arange(6).reshape((2, 1, 1, 3)),
-         np.arange(6).reshape((2, 1, 1, 3)),
-         np.arange(6).reshape((2, 1, 1, 3))], 1)
+        [
+            np.arange(6).reshape((2, 1, 1, 3)),
+            np.arange(6).reshape((2, 1, 1, 3)),
+            np.arange(6).reshape((2, 1, 1, 3)),
+        ],
+        1,
+    )
 
 
 #######################################################################
 # Unpack
 # ------
 
+
 def _test_unpack(data, axis, num_unpacks):
     """ One iteration of UNPACK """
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
-        out = gen_array_ops.unpack(in_data, num=num_unpacks, axis=axis, name='unpack')
-        out_names = ['out_' + str(n) + ':0' for n in range(num_unpacks)]
-        compare_tflite_with_tvm([data], 'Placeholder:0',  [in_data], out, out_names=out_names)
+        out = gen_array_ops.unpack(in_data, num=num_unpacks, axis=axis, name="unpack")
+        out_names = ["out_" + str(n) + ":0" for n in range(num_unpacks)]
+        compare_tflite_with_tvm([data], "Placeholder:0", [in_data], out, out_names=out_names)
+
 
 def test_forward_unpack():
     """ UNPACK """
     _test_unpack(np.array(np.random.uniform(0, 5, (3, 1)), dtype=np.int32), axis=1, num_unpacks=1)
     _test_unpack(np.array(np.random.uniform(0, 5, (3, 4)), dtype=np.float32), axis=0, num_unpacks=3)
     # tflite 1.13 doesn't accept negative axis
-    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
-        _test_unpack(np.array(np.random.uniform(0, 5, (3, 6)), dtype=np.int32), axis=-2, num_unpacks=3)
-        _test_unpack(np.array(np.random.uniform(0, 5, (2, 3, 4)), dtype=np.int32), axis=-3, num_unpacks=2)
+    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
+        _test_unpack(
+            np.array(np.random.uniform(0, 5, (3, 6)), dtype=np.int32), axis=-2, num_unpacks=3
+        )
+        _test_unpack(
+            np.array(np.random.uniform(0, 5, (2, 3, 4)), dtype=np.int32), axis=-3, num_unpacks=2
+        )
 
 
 #######################################################################
 # Local response normalization
 # ----------------------------
 
+
 def _test_local_response_normalization(data, depth_radius, bias, alpha, beta):
     """ One iteration of LOCAL_RESPONSE_NORMALIZATION """
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=data.shape, dtype='float32', name='in_0')
-        out = nn_ops.local_response_normalization(in_data, depth_radius=depth_radius, bias=bias, alpha=alpha, beta=beta)
-        compare_tflite_with_tvm(data, 'in_0:0', [in_data], [out])
+        in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0")
+        out = nn_ops.local_response_normalization(
+            in_data, depth_radius=depth_radius, bias=bias, alpha=alpha, beta=beta
+        )
+        compare_tflite_with_tvm(data, "in_0:0", [in_data], [out])
+
 
 def test_forward_local_response_normalization():
     """ LOCAL_RESPONSE_NORMALIZATION """
-    data = np.random.uniform(size=(1, 6, 4, 3)).astype('float32')
+    data = np.random.uniform(size=(1, 6, 4, 3)).astype("float32")
     # LOCAL_RESPONSE_NORMALIZATION comes with TFLite >= 1.14.0 fbs schema
-    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
         _test_local_response_normalization(data, depth_radius=5, bias=1, alpha=1, beta=0.5)
 
 
@@ -2252,78 +2746,96 @@ def test_forward_local_response_normalization():
 # L2 normalization
 # ----------------
 
+
 def _test_l2_normalization(data, axis, fused_activation_function=None):
     """ One iteration of L2_NORMALIZATION """
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         out = nn_impl.l2_normalize(in_data, axis)
         out = with_fused_activation_function(out, fused_activation_function)
-        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
+
 
 def test_forward_l2_normalization():
     """ L2_NORMALIZATION """
-    data = np.random.uniform(size=(3, 6, 4)).astype('float32')
+    data = np.random.uniform(size=(3, 6, 4)).astype("float32")
     _test_l2_normalization(data, axis=2)
     _test_l2_normalization(data, axis=2, fused_activation_function="RELU")
 
+
 #######################################################################
 # Logistic
 # --------
 
+
 def _test_logistic(data, quantized=False):
     """ One iteration of LOGISTIC """
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=data.shape, dtype='float32', name='in_0')
+        in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0")
 
         if quantized:
-            inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-5, max=5, name="inq_0")
-            input_range = {'inq_0': (-5, 5)}
+            inq_data = tf.quantization.fake_quant_with_min_max_args(
+                in_data, min=-5, max=5, name="inq_0"
+            )
+            input_range = {"inq_0": (-5, 5)}
             out = math_ops.sigmoid(inq_data)
             out = tf.quantization.fake_quant_with_min_max_args(out, min=0, max=1, name="out")
-            compare_tflite_with_tvm(data, 'inq_0:0', [inq_data], [out], quantized=True, input_range=input_range)
+            compare_tflite_with_tvm(
+                data, "inq_0:0", [inq_data], [out], quantized=True, input_range=input_range
+            )
         else:
             out = math_ops.sigmoid(in_data)
-            compare_tflite_with_tvm(data, 'in_0:0', [in_data], [out])
+            compare_tflite_with_tvm(data, "in_0:0", [in_data], [out])
+
 
 def test_forward_logistic():
     """ LOGISTIC """
     _test_logistic(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
     _test_logistic(np.random.uniform(0, 255, (3, 6)).astype(np.uint8), quantized=True)
 
+
 #######################################################################
 # Softmax
 # -------
 
+
 def _test_softmax(data):
     """ One iteration of softmax """
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         out = nn_ops.softmax(in_data)
-        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
+
 
 def test_forward_softmax():
     """ Softmax """
     _test_softmax(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
 
+
 ######################################################################
 # Log_softmax
 # -----------
 
+
 def _test_log_softmax(data, quantized=False):
     """ One iteration of log_softmax """
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=data.shape, dtype='float32', name='in_0')
+        in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0")
 
         if quantized:
-            inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-10, max=10, name="inq_0")
-            input_range = {'inq_0': (-10, 10)}
+            inq_data = tf.quantization.fake_quant_with_min_max_args(
+                in_data, min=-10, max=10, name="inq_0"
+            )
+            input_range = {"inq_0": (-10, 10)}
             # tflite log_softmax supports only the case when axis is not specified
             out = nn_ops.log_softmax(inq_data)
             out = tf.quantization.fake_quant_with_min_max_args(out, min=-20, max=0, name="out")
-            compare_tflite_with_tvm(data, 'inq_0:0', [inq_data], [out], quantized=True, input_range=input_range)
+            compare_tflite_with_tvm(
+                data, "inq_0:0", [inq_data], [out], quantized=True, input_range=input_range
+            )
         else:
             out = nn_ops.log_softmax(in_data)
-            compare_tflite_with_tvm(data, 'in_0:0', [in_data], [out])
+            compare_tflite_with_tvm(data, "in_0:0", [in_data], [out])
 
 
 def test_forward_log_softmax():
@@ -2331,33 +2843,38 @@ def test_forward_log_softmax():
     _test_log_softmax(np.random.uniform(-10, 10, size=(3, 6)).astype(np.float32))
     _test_log_softmax(np.random.uniform(0, 255, (3, 6)).astype(np.uint8), quantized=True)
 
+
 #######################################################################
 # Tanh
 # ----
 
+
 def _test_tanh(data):
     """ One iteration of TANH """
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         out = math_ops.tanh(in_data)
-        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
+
 
 def test_forward_tanh():
     """ TANH """
     _test_tanh(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
 
+
 #######################################################################
 # ReLu
 # ----
 
+
 def _test_relu(data, quantized=False):
     """ One iteration of ReLU """
 
     if quantized:
-        if package_version.parse(tf.VERSION) < package_version.parse('2.1.0'):
+        if package_version.parse(tf.VERSION) < package_version.parse("2.1.0"):
             pytest.skip("Testcase requires tflite version >= 2.1.0")
         data_in = tf.keras.layers.Input(shape=data.shape[1:])
-        relu =  tf.keras.layers.ReLU()(data_in)
+        relu = tf.keras.layers.ReLU()(data_in)
         keras_model = tf.keras.models.Model(inputs=data_in, outputs=relu)
         input_name = data_in.name.split(":")[0]
 
@@ -2370,140 +2887,180 @@ def _test_relu(data, quantized=False):
 
         tflite_output = run_tflite_graph(tflite_model_quant, data)
         tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
-        tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
-                                    rtol=1e-5, atol=1e-5)
+        tvm.testing.assert_allclose(
+            np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5
+        )
     else:
         with tf.Graph().as_default():
             in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
             out = nn_ops.relu(in_data)
-            compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+            compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
+
 
 def test_forward_relu():
     """ ReLU """
     _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
     _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)), quantized=True)
 
+
 #######################################################################
 # ReLU6
 # -----
 
+
 def _test_relu6(data, quantized=False):
     """ One iteration of ReLU6 """
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=data.shape, dtype='float32', name='in_0')
+        in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0")
 
         if quantized:
-            inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-10, max=10, name="inq_0")
-            input_range = {'inq_0': (-10, 10)}
+            inq_data = tf.quantization.fake_quant_with_min_max_args(
+                in_data, min=-10, max=10, name="inq_0"
+            )
+            input_range = {"inq_0": (-10, 10)}
             out = nn_ops.relu6(inq_data)
             out = tf.quantization.fake_quant_with_min_max_args(out, min=0, max=6, name="out")
-            compare_tflite_with_tvm(data, 'inq_0:0', [inq_data], [out], quantized=True, input_range=input_range)
+            compare_tflite_with_tvm(
+                data, "inq_0:0", [inq_data], [out], quantized=True, input_range=input_range
+            )
         else:
             out = nn_ops.relu6(in_data)
-            compare_tflite_with_tvm(data, 'in_0:0', [in_data], [out])
+            compare_tflite_with_tvm(data, "in_0:0", [in_data], [out])
+
 
 def test_forward_relu6():
     """ ReLU6 """
     _test_relu6(np.random.uniform(-10, 10, size=(3, 6)).astype(np.float32))
     _test_relu6(np.random.uniform(0, 255, (3, 6)).astype(np.uint8), quantized=True)
 
+
 #######################################################################
 # Leaky_ReLU
 # ----------
 
+
 def _test_leaky_relu(data, alpha, quantized=False):
     """ One iteration of Leaky_ReLU """
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=data.shape, dtype='float32', name='in_0')
+        in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0")
 
         if quantized:
-            inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-3, max=2, name="inq_0")
-            input_range = {'inq_0': (-3, 2)}
+            inq_data = tf.quantization.fake_quant_with_min_max_args(
+                in_data, min=-3, max=2, name="inq_0"
+            )
+            input_range = {"inq_0": (-3, 2)}
             out = nn_ops.leaky_relu(inq_data, alpha)
             out = tf.quantization.fake_quant_with_min_max_args(out, min=-3, max=2, name="out")
-            compare_tflite_with_tvm(data, 'inq_0:0', [inq_data], [out], quantized=True, input_range=input_range)
+            compare_tflite_with_tvm(
+                data, "inq_0:0", [inq_data], [out], quantized=True, input_range=input_range
+            )
         else:
             out = nn_ops.leaky_relu(in_data, alpha)
-            compare_tflite_with_tvm(data, 'in_0:0', [in_data], [out])
+            compare_tflite_with_tvm(data, "in_0:0", [in_data], [out])
+
 
 def test_forward_leaky_relu():
     """ Leaky_ReLU """
     _test_leaky_relu(np.random.uniform(-5, 5, (1, 6)).astype(np.float32), alpha=0.2)
-    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
-        _test_leaky_relu(np.random.uniform(0, 255, (2, 3)).astype(np.uint8), alpha=0.3, quantized=True)
+    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
+        _test_leaky_relu(
+            np.random.uniform(0, 255, (2, 3)).astype(np.uint8), alpha=0.3, quantized=True
+        )
+
 
 #######################################################################
 # ReLU_n1_to_1
 # ------------
 
+
 def _test_relu_n1_to_1(data, quantized=False):
     """ One iteration of ReLU_n1_to_1 """
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=data.shape, dtype='float32', name='in_0')
+        in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0")
 
         if quantized:
-            inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-3, max=3, name="inq_0")
-            input_range = {'inq_0': (-3, 3)}
+            inq_data = tf.quantization.fake_quant_with_min_max_args(
+                in_data, min=-3, max=3, name="inq_0"
+            )
+            input_range = {"inq_0": (-3, 3)}
             # There is no such tf operation. The specific pattern will be replaced into RELU_N1_TO_1 by tflite
             out = math_ops.maximum(-1.0, math_ops.minimum(inq_data, 1.0))
             out = tf.quantization.fake_quant_with_min_max_args(out, min=-1, max=1, name="out")
-            compare_tflite_with_tvm(data, 'inq_0:0', [inq_data], [out], quantized=True, input_range=input_range)
+            compare_tflite_with_tvm(
+                data, "inq_0:0", [inq_data], [out], quantized=True, input_range=input_range
+            )
         else:
             out = math_ops.maximum(-1.0, math_ops.minimum(in_data, 1.0))
-            compare_tflite_with_tvm(data, 'in_0:0', [in_data], [out])
+            compare_tflite_with_tvm(data, "in_0:0", [in_data], [out])
+
 
 def test_forward_relu_n1_to_1():
     """ ReLU_n1_to_1 """
     _test_relu_n1_to_1(np.random.uniform(-3, 3, (1, 6)).astype(np.float32))
-    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
         _test_relu_n1_to_1(np.random.uniform(0, 255, (3, 6)).astype(np.uint8), quantized=True)
 
+
 #######################################################################
 # PReLU
 # -----
 
+
 def _test_prelu(data, alpha):
     """ One iteration of PReLU """
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         # This specific pattern will be replaced into PRelu by tflite
         out = nn_ops.relu(in_data) + (-alpha * nn_ops.relu(-in_data))
-        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
+
 
 def test_forward_prelu():
     """ PReLU """
-    _test_prelu(np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32"), np.full((3,), 0.2, dtype="float32"))
-    _test_prelu(np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32"), np.full((1, 1, 3), 0.2, dtype="float32"))
+    _test_prelu(
+        np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32"),
+        np.full((3,), 0.2, dtype="float32"),
+    )
+    _test_prelu(
+        np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32"),
+        np.full((1, 1, 3), 0.2, dtype="float32"),
+    )
+
 
 #######################################################################
 # DepthToSpace
 # ------------
 
+
 def _test_depthtospace(data, block_size):
     """ One iteration of depth_to_space operation with given data and block size """
 
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         out = array_ops.depth_to_space(in_data, block_size)
-        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
+
 
 def test_forward_depthtospace():
     # DEPTH_TO_SPACE comes with TFLite >= 1.15.0 fbs schema
-    if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("1.15.0"):
         _test_depthtospace(np.random.normal(size=[1, 32, 32, 4]).astype("float32"), 2)
         _test_depthtospace(np.random.normal(size=[1, 16, 8, 32]).astype("float32"), 4)
 
+
 #######################################################################
 # SpaceToDepth
 # ------------
 
+
 def _test_spacetodepth(data, block_size):
     """ One iteration of space_to_depth operation with given data and block size """
 
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         out = array_ops.space_to_depth(in_data, block_size)
-        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
+
 
 def test_forward_spacetodepth():
     _test_spacetodepth(np.random.normal(size=[1, 32, 32, 4]).astype("float32"), 2)
@@ -2514,20 +3071,22 @@ def test_forward_spacetodepth():
 # ReverseSequence
 # ---------------
 
+
 def _test_reverse_sequence(shape, dtype, seq_lengths, batch_axis, seq_axis):
     """ One iteration of reverse_sequence operation with given data and attributes """
 
     data = np.random.uniform(0, 100, size=shape).astype(dtype)
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(dtype=dtype, name="input", shape=shape)
-        out = tf.reverse_sequence(in_data, seq_lengths=seq_lengths, batch_axis=batch_axis,
-                                   seq_axis=seq_axis)
+        out = tf.reverse_sequence(
+            in_data, seq_lengths=seq_lengths, batch_axis=batch_axis, seq_axis=seq_axis
+        )
 
-        compare_tflite_with_tvm(data, 'input', [in_data], [out])
+        compare_tflite_with_tvm(data, "input", [in_data], [out])
 
 
 def test_forward_reverse_sequence():
-    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
         _test_reverse_sequence([4, 3], "float32", [3, 2, 1], 1, 0)
         _test_reverse_sequence([4, 3], "float32", [3, 2, 1, 3], 0, 1)
         _test_reverse_sequence([2, 3, 3, 3], "float32", [2, 3, 2], 2, 1)
@@ -2540,11 +3099,17 @@ def test_forward_reverse_sequence():
 # ---------------
 def _test_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape):
     # tflite 1.13 convert method does not accept empty shapes
-    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
         with tf.Graph().as_default():
-            indices = tf.placeholder(shape=sparse_indices.shape, dtype=str(sparse_indices.dtype), name="indices")
-            values = tf.placeholder(shape=sparse_values.shape, dtype=str(sparse_values.dtype), name="values")
-            oshape = tf.constant(output_shape, shape=output_shape.shape, dtype=str(output_shape.dtype))
+            indices = tf.placeholder(
+                shape=sparse_indices.shape, dtype=str(sparse_indices.dtype), name="indices"
+            )
+            values = tf.placeholder(
+                shape=sparse_values.shape, dtype=str(sparse_values.dtype), name="values"
+            )
+            oshape = tf.constant(
+                output_shape, shape=output_shape.shape, dtype=str(output_shape.dtype)
+            )
 
             if default_value == None:
                 output = tf.sparse_to_dense(indices, oshape, values)
@@ -2552,7 +3117,7 @@ def _test_sparse_to_dense(sparse_indices, sparse_values, default_value, output_s
                     [sparse_indices, sparse_values],
                     ["indices", "values"],
                     [indices, values],
-                    [output]
+                    [output],
                 )
             else:
                 dv = tf.placeholder(shape=(), dtype=str(default_value.dtype), name="default_value")
@@ -2561,11 +3126,12 @@ def _test_sparse_to_dense(sparse_indices, sparse_values, default_value, output_s
                     [sparse_indices, sparse_values, default_value],
                     ["indices", "values", "default_value"],
                     [indices, values, dv],
-                    [output]
+                    [output],
                 )
 
+
 def test_forward_sparse_to_dense():
-    '''
+    """
     Works in tvm/topi/tensorflow. But tflite converter breaks this test case
     _test_sparse_to_dense(
         np.int32(1),
@@ -2573,54 +3139,57 @@ def test_forward_sparse_to_dense():
         np.int32(0),
         np.array([5]).astype("int32")
     )
-    '''
+    """
     # vector
     _test_sparse_to_dense(
         np.array([0, 1, 4]).astype("int32"),
         np.array([3, 3, 3]).astype("int32"),
         np.int32(0),
-        np.array([5]).astype("int32")
+        np.array([5]).astype("int32"),
     )
     # vector nXd
     _test_sparse_to_dense(
         np.array([[0, 0], [1, 2]]).astype("int32"),
         np.array([1, 2]).astype("int32"),
         np.int32(0),
-        np.array([3, 4]).astype("int32")
+        np.array([3, 4]).astype("int32"),
     )
     _test_sparse_to_dense(
         np.array([[0, 0, 0], [1, 2, 3]]).astype("int32"),
         np.array([1, 2]).astype("int32"),
         np.int32(4),
-        np.array([2, 3, 4]).astype("int32")
+        np.array([2, 3, 4]).astype("int32"),
     )
     # floats
     _test_sparse_to_dense(
         np.array([0, 1, 4]).astype("int32"),
         np.array([3.1, 3.1, 3.1]).astype("float32"),
         np.float32(3.5),
-        np.array([5]).astype("int32")
+        np.array([5]).astype("int32"),
     )
     # default value not specified
     _test_sparse_to_dense(
         np.array([0, 1, 4]).astype("int32"),
         np.array([3.1, 3.1, 3.1]).astype("float32"),
         None,
-        np.array([5]).astype("int32")
+        np.array([5]).astype("int32"),
     )
 
+
 #######################################################################
 # Fully Connected
 # ---------------
 
+
 def _test_fully_connected(tensor_in_sizes, const_input, filter_in_sizes, bias_in_size=None):
     """ One iteration of fully connected """
 
     total_size_1 = np.prod(tensor_in_sizes)
     total_size_2 = np.prod(filter_in_sizes)
 
-    assert int(total_size_1 / tensor_in_sizes[0]) == filter_in_sizes[0], \
-        "input size and filter size are mismatched"
+    assert (
+        int(total_size_1 / tensor_in_sizes[0]) == filter_in_sizes[0]
+    ), "input size and filter size are mismatched"
 
     # Initializes the input tensor with array containing incrementing
     # numbers from 1.
@@ -2628,10 +3197,12 @@ def _test_fully_connected(tensor_in_sizes, const_input, filter_in_sizes, bias_in
     filter_array = np.arange(1, total_size_2 + 1, dtype=np.float32)
 
     with tf.Graph().as_default():
-        in_name="input"
-        in_data = constant_op.constant(data_array, shape=tensor_in_sizes, dtype=np.float32, name=in_name) \
-            if const_input \
+        in_name = "input"
+        in_data = (
+            constant_op.constant(data_array, shape=tensor_in_sizes, dtype=np.float32, name=in_name)
+            if const_input
             else array_ops.placeholder(shape=tensor_in_sizes, dtype=np.float32, name=in_name)
+        )
 
         in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype=np.float32)
 
@@ -2648,10 +3219,7 @@ def _test_fully_connected(tensor_in_sizes, const_input, filter_in_sizes, bias_in
             out = nn_ops.bias_add(out, in_bias)
 
         data_array = np.reshape(data_array, tensor_in_sizes).astype(np.float32)
-        compare_tflite_with_tvm(data_array,
-                                [] if const_input else in_data.name,
-                                [in_data],
-                                [out])
+        compare_tflite_with_tvm(data_array, [] if const_input else in_data.name, [in_data], [out])
 
 
 def test_forward_fully_connected():
@@ -2667,6 +3235,7 @@ def test_forward_fully_connected():
 # REVERSE_V2
 # ----------
 
+
 def _test_reverse_v2(input_shape, axis, dtype):
     """ One iteration of REVERSE_V2 """
     with tf.Graph().as_default():
@@ -2676,23 +3245,21 @@ def _test_reverse_v2(input_shape, axis, dtype):
 
         out = array_ops.reverse(in_input, in_axis)
 
-        compare_tflite_with_tvm(
-            [input],
-            ["input"],
-            [in_input],
-            [out])
+        compare_tflite_with_tvm([input], ["input"], [in_input], [out])
+
 
 def test_forward_reverse_v2():
     """ REVERSE_V2 """
-    for dtype in ['float32', 'int32']:
-        _test_reverse_v2((5), np.array([0], dtype='int32'), dtype)
-        _test_reverse_v2((5, 6, 4, 2), np.array([2], dtype='int32'), dtype)
+    for dtype in ["float32", "int32"]:
+        _test_reverse_v2((5), np.array([0], dtype="int32"), dtype)
+        _test_reverse_v2((5, 6, 4, 2), np.array([2], dtype="int32"), dtype)
 
 
 #######################################################################
 # MATRIX_SET_DIAG
 # ---------------
 
+
 def _test_matrix_set_diag(input_shape, input_type, quantized=False):
     """ One iteration of MATRIX_SET_DIAG """
     with tf.Graph().as_default():
@@ -2701,30 +3268,22 @@ def _test_matrix_set_diag(input_shape, input_type, quantized=False):
 
         if quantized:
             # ignoring input_type as quantized requires uint8
-            input = np.random.uniform(0, 256, input_shape).astype('uint8')
-            in_input = tf.placeholder(dtype='float32', shape=input.shape, name="input")
+            input = np.random.uniform(0, 256, input_shape).astype("uint8")
+            in_input = tf.placeholder(dtype="float32", shape=input.shape, name="input")
             inq_input = tf.quantization.fake_quant_with_min_max_args(
-                in_input,
-                min=-100,
-                max=100,
-                name="q_input")
+                in_input, min=-100, max=100, name="q_input"
+            )
 
-            diagonal = np.random.uniform(0, 256, diagonal_shape).astype('uint8')
-            in_diagonal = tf.placeholder(dtype='float32', shape=diagonal.shape, name="diagonal")
+            diagonal = np.random.uniform(0, 256, diagonal_shape).astype("uint8")
+            in_diagonal = tf.placeholder(dtype="float32", shape=diagonal.shape, name="diagonal")
             inq_diagonal = tf.quantization.fake_quant_with_min_max_args(
-                in_diagonal,
-                min=-100,
-                max=100,
-                name="q_diagonal")
+                in_diagonal, min=-100, max=100, name="q_diagonal"
+            )
 
-            input_range = {'q_input': (-100, 100), 'q_diagonal': (-100, 100)}
+            input_range = {"q_input": (-100, 100), "q_diagonal": (-100, 100)}
 
             out = array_ops.matrix_set_diag(inq_input, inq_diagonal)
-            out = tf.quantization.fake_quant_with_min_max_args(
-                out,
-                min=-100,
-                max=100,
-                name="out")
+            out = tf.quantization.fake_quant_with_min_max_args(out, min=-100, max=100, name="out")
 
             compare_tflite_with_tvm(
                 [input, diagonal],
@@ -2732,21 +3291,23 @@ def _test_matrix_set_diag(input_shape, input_type, quantized=False):
                 [inq_input, inq_diagonal],
                 [out],
                 quantized=True,
-                input_range=input_range)
+                input_range=input_range,
+            )
         else:
             input = np.random.uniform(0, 100, input_shape).astype(input_type)
             diagonal = np.random.uniform(0, 100, diagonal_shape).astype(input_type)
 
             in_input = tf.placeholder(dtype=input.dtype, shape=input.shape, name="input")
-            in_diagonal = tf.placeholder(dtype=diagonal.dtype, shape=diagonal.shape, name="diagonal")
+            in_diagonal = tf.placeholder(
+                dtype=diagonal.dtype, shape=diagonal.shape, name="diagonal"
+            )
 
             out = array_ops.matrix_set_diag(in_input, in_diagonal)
 
             compare_tflite_with_tvm(
-                    [input, diagonal],
-                    ["input", "diagonal"],
-                    [in_input, in_diagonal],
-                    [out])
+                [input, diagonal], ["input", "diagonal"], [in_input, in_diagonal], [out]
+            )
+
 
 def test_forward_matrix_set_diag():
     """ MATRIX_SET_DIAG """
@@ -2764,6 +3325,7 @@ def test_forward_matrix_set_diag():
 # MATRIX_DIAG
 # -----------
 
+
 def _test_matrix_diag(diagonal_shape, dtype):
     """ One iteration of MATRIX_DIAG """
     with tf.Graph().as_default():
@@ -2773,11 +3335,9 @@ def _test_matrix_diag(diagonal_shape, dtype):
         out = array_ops.matrix_diag(in_diagonal)
 
         compare_tflite_with_tvm(
-                [diagonal],
-                ["diagonal"],
-                [in_diagonal],
-                [out],
-                experimental_new_converter=True)
+            [diagonal], ["diagonal"], [in_diagonal], [out], experimental_new_converter=True
+        )
+
 
 def test_forward_matrix_diag():
     """ MATRIX_DIAG """
@@ -2791,11 +3351,12 @@ def test_forward_matrix_diag():
 # Custom Operators
 # ----------------
 
+
 def test_detection_postprocess():
     tf_model_file = tf_testing.get_workload_official(
         "http://download.tensorflow.org/models/object_detection/"
         "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
-        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/tflite_graph.pb"
+        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/tflite_graph.pb",
     )
     converter = tf.lite.TFLiteConverter.from_frozen_graph(
         tf_model_file,
@@ -2804,7 +3365,7 @@ def test_detection_postprocess():
             "TFLite_Detection_PostProcess",
             "TFLite_Detection_PostProcess:1",
             "TFLite_Detection_PostProcess:2",
-            "TFLite_Detection_PostProcess:3"
+            "TFLite_Detection_PostProcess:3",
         ],
         input_shapes={
             "raw_outputs/box_encodings": (1, 1917, 4),
@@ -2815,15 +3376,23 @@ def test_detection_postprocess():
     converter.inference_type = tf.lite.constants.FLOAT
     tflite_model = converter.convert()
     np.random.seed(0)
-    box_encodings = np.random.uniform(size=(1, 1917, 4)).astype('float32')
-    class_predictions = np.random.uniform(size=(1, 1917, 91)).astype('float32')
+    box_encodings = np.random.uniform(size=(1, 1917, 4)).astype("float32")
+    class_predictions = np.random.uniform(size=(1, 1917, 91)).astype("float32")
     tflite_output = run_tflite_graph(tflite_model, [box_encodings, class_predictions])
-    tvm_output = run_tvm_graph(tflite_model, [box_encodings, class_predictions],
-                               ["raw_outputs/box_encodings", "raw_outputs/class_predictions"], num_output=4)
+    tvm_output = run_tvm_graph(
+        tflite_model,
+        [box_encodings, class_predictions],
+        ["raw_outputs/box_encodings", "raw_outputs/class_predictions"],
+        num_output=4,
+    )
 
     # Check all output shapes are equal
-    assert all([tvm_tensor.shape == tflite_tensor.shape \
-                for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)])
+    assert all(
+        [
+            tvm_tensor.shape == tflite_tensor.shape
+            for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)
+        ]
+    )
 
     # Check valid count is the same
     assert tvm_output[3] == tflite_output[3]
@@ -2833,123 +3402,152 @@ def test_detection_postprocess():
     # tflite and tvm tensors for only valid boxes.
     for i in range(0, valid_count):
         # Check bounding box co-ords
-        tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]),
-                                    rtol=1e-5, atol=1e-5)
+        tvm.testing.assert_allclose(
+            np.squeeze(tvm_output[0][0][i]),
+            np.squeeze(tflite_output[0][0][i]),
+            rtol=1e-5,
+            atol=1e-5,
+        )
 
         # Check the class
         # Stricter check to ensure class remains same
-        np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]),
-                                np.squeeze(tflite_output[1][0][i]))
+        np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i]))
 
         # Check the score
-        tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]),
-                                    rtol=1e-5, atol=1e-5)
+        tvm.testing.assert_allclose(
+            np.squeeze(tvm_output[2][0][i]),
+            np.squeeze(tflite_output[2][0][i]),
+            rtol=1e-5,
+            atol=1e-5,
+        )
 
 
 #######################################################################
 # Mobilenet
 # ---------
 
+
 def test_forward_mobilenet_v1():
     """Test the Mobilenet V1 TF Lite model."""
     # MobilenetV1
     tflite_model_file = tf_testing.get_workload_official(
         "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
-        "mobilenet_v1_1.0_224.tflite")
+        "mobilenet_v1_1.0_224.tflite",
+    )
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
-    data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
+    data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32")
     tflite_output = run_tflite_graph(tflite_model_buf, data)
-    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
-    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
-                                rtol=1e-5, atol=1e-5)
+    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
+    tvm.testing.assert_allclose(
+        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5
+    )
+
 
 def test_forward_mobilenet_v2():
     """Test the Mobilenet V2 TF Lite model."""
     # MobilenetV2
     tflite_model_file = tf_testing.get_workload_official(
         "http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz",
-        "mobilenet_v2_1.0_224.tflite")
+        "mobilenet_v2_1.0_224.tflite",
+    )
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
-    data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
+    data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32")
     tflite_output = run_tflite_graph(tflite_model_buf, data)
-    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
-    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
-                                rtol=1e-5, atol=1e-5)
+    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
+    tvm.testing.assert_allclose(
+        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5
+    )
+
 
 #######################################################################
 # Mobilenet V3
 # ------------
 
+
 def test_forward_mobilenet_v3():
     """Test the Mobilenet V3 TF Lite model."""
     # In MobilenetV3, some ops are not supported before tf 1.15 fbs schema
-    if package_version.parse(tf.VERSION) < package_version.parse('1.15.0'):
+    if package_version.parse(tf.VERSION) < package_version.parse("1.15.0"):
         return
     tflite_model_file = tf_testing.get_workload_official(
         "https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_float.tgz",
-        "v3-large_224_1.0_float/v3-large_224_1.0_float.tflite")
+        "v3-large_224_1.0_float/v3-large_224_1.0_float.tflite",
+    )
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
-    data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
+    data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32")
     tflite_output = run_tflite_graph(tflite_model_buf, data)
-    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
-    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
-                                rtol=1e-5, atol=1e-5)
+    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
+    tvm.testing.assert_allclose(
+        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5
+    )
+
 
 #######################################################################
 # Inception
 # ---------
 
+
 def test_forward_inception_v3_net():
     """Test the Inception V3 TF Lite model."""
     # InceptionV3
     tflite_model_file = tf_testing.get_workload_official(
         "https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz",
-        "inception_v3.tflite")
+        "inception_v3.tflite",
+    )
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
-    data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
+    data = np.random.uniform(size=(1, 299, 299, 3)).astype("float32")
     tflite_output = run_tflite_graph(tflite_model_buf, data)
-    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
-    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
-                                rtol=1e-5, atol=1e-5)
+    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
+    tvm.testing.assert_allclose(
+        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5
+    )
+
 
 def test_forward_inception_v4_net():
     """Test the Inception V4 TF Lite model."""
     # InceptionV4
     tflite_model_file = tf_testing.get_workload_official(
         "https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz",
-        "inception_v4.tflite")
+        "inception_v4.tflite",
+    )
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
-    data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
+    data = np.random.uniform(size=(1, 299, 299, 3)).astype("float32")
     tflite_output = run_tflite_graph(tflite_model_buf, data)
-    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
-    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
-                                rtol=1e-5, atol=1e-5)
+    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
+    tvm.testing.assert_allclose(
+        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5
+    )
+
 
 def test_forward_inception_v4_net_batched():
     """Test the Inception V4 TF Lite model."""
     # InceptionV4
     tflite_model_file = tf_testing.get_workload_official(
         "https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz",
-        "inception_v4.tflite")
+        "inception_v4.tflite",
+    )
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
-    data = np.random.uniform(size=(4, 299, 299, 3)).astype('float32')
+    data = np.random.uniform(size=(4, 299, 299, 3)).astype("float32")
     tflite_output = run_tflite_graph(tflite_model_buf, data)
-    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
-    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
-                                rtol=1e-5, atol=1e-5)
+    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
+    tvm.testing.assert_allclose(
+        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5
+    )
+
 
 def test_forward_qnn_inception_v1_net():
     """Test the Quantized TFLite Inception model."""
     # InceptionV1
     tflite_model_file = tf_testing.get_workload_official(
         "https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_224_quant_20181026.tgz",
-        "inception_v1_224_quant.tflite")
+        "inception_v1_224_quant.tflite",
+    )
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
 
@@ -2961,17 +3559,19 @@ def test_forward_qnn_inception_v1_net():
     tflite_output = run_tflite_graph(tflite_model_buf, data)
     tflite_predictions = np.squeeze(tflite_output)
     tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
-    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
+    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
     tvm_predictions = np.squeeze(tvm_output)
     tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
     tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
 
+
 def test_forward_qnn_mobilenet_v1_net():
     """Test the Quantized TFLite Mobilenet V1 model."""
     # MobilenetV1
     tflite_model_file = tf_testing.get_workload_official(
         "https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
-        "mobilenet_v1_1.0_224_quant.tflite")
+        "mobilenet_v1_1.0_224_quant.tflite",
+    )
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
 
@@ -2983,17 +3583,19 @@ def test_forward_qnn_mobilenet_v1_net():
     tflite_output = run_tflite_graph(tflite_model_buf, data)
     tflite_predictions = np.squeeze(tflite_output)
     tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
-    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
+    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
     tvm_predictions = np.squeeze(tvm_output)
     tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
     tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
 
+
 def test_forward_qnn_mobilenet_v2_net():
     """Test the Quantized TFLite Mobilenet V2 model."""
     # MobilenetV2
     tflite_model_file = tf_testing.get_workload_official(
         "https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz",
-        "mobilenet_v2_1.0_224_quant.tflite")
+        "mobilenet_v2_1.0_224_quant.tflite",
+    )
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
 
@@ -3005,26 +3607,29 @@ def test_forward_qnn_mobilenet_v2_net():
     tflite_output = run_tflite_graph(tflite_model_buf, data)
     tflite_predictions = np.squeeze(tflite_output)
     tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
-    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
+    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
     tvm_predictions = np.squeeze(tvm_output)
     tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
     tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
 
+
 #######################################################################
 # Mobilenet V3 Quantized
 # ----------------------
 
+
 def test_forward_qnn_mobilenet_v3_net():
     """Test the Quantized TFLite Mobilenet V3 model."""
     # In MobilenetV3, some ops are not supported before tf 1.15 fbs schema
-    if package_version.parse(tf.VERSION) < package_version.parse('1.15.0'):
+    if package_version.parse(tf.VERSION) < package_version.parse("1.15.0"):
         pytest.skip("Unsupported in tflite < 1.15.0")
     else:
         pytest.skip("This segfaults with tensorflow 1.15.2 and above")
 
     tflite_model_file = tf_testing.get_workload_official(
         "https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_uint8.tgz",
-        "v3-large_224_1.0_uint8/v3-large_224_1.0_uint8.tflite")
+        "v3-large_224_1.0_uint8/v3-large_224_1.0_uint8.tflite",
+    )
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
 
@@ -3036,7 +3641,7 @@ def test_forward_qnn_mobilenet_v3_net():
     tflite_output = run_tflite_graph(tflite_model_buf, data)
     tflite_predictions = np.squeeze(tflite_output)
     tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
-    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
+    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
     tvm_predictions = np.squeeze(tvm_output)
     tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
     tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
@@ -3044,10 +3649,11 @@ def test_forward_qnn_mobilenet_v3_net():
 
 def test_forward_tflite2_qnn_resnet50():
     """Test the Quantized TFLite version 2.1.0 Resnet50 model."""
-    if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("2.1.0"):
         tflite_model_file = download_testdata(
             "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/Quantized/resnet_50_quantized.tflite",
-            "resnet_50_quantized.tflite")
+            "resnet_50_quantized.tflite",
+        )
         with open(tflite_model_file, "rb") as f:
             tflite_model_buf = f.read()
 
@@ -3056,7 +3662,7 @@ def test_forward_tflite2_qnn_resnet50():
         tflite_output = run_tflite_graph(tflite_model_buf, data)
         tflite_predictions = np.squeeze(tflite_output)
         tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
-        tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1')
+        tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), "input_1")
         tvm_predictions = np.squeeze(tvm_output)
         tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
         tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
@@ -3064,10 +3670,11 @@ def test_forward_tflite2_qnn_resnet50():
 
 def test_forward_tflite2_qnn_inception_v1():
     """Test the Quantized TFLite version 2.1.0 Inception V1 model."""
-    if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("2.1.0"):
         tflite_model_file = download_testdata(
             "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/Quantized/inception_v1_quantized.tflite",
-            "inception_v1_quantized.tflite")
+            "inception_v1_quantized.tflite",
+        )
         with open(tflite_model_file, "rb") as f:
             tflite_model_buf = f.read()
 
@@ -3076,7 +3683,7 @@ def test_forward_tflite2_qnn_inception_v1():
         tflite_output = run_tflite_graph(tflite_model_buf, data)
         tflite_predictions = np.squeeze(tflite_output)
         tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
-        tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1')
+        tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), "input_1")
         tvm_predictions = np.squeeze(tvm_output)
         tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
         tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
@@ -3084,10 +3691,11 @@ def test_forward_tflite2_qnn_inception_v1():
 
 def test_forward_tflite2_qnn_mobilenet_v2():
     """Test the Quantized TFLite version 2.1.0 Mobilenet V2 model."""
-    if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse("2.1.0"):
         tflite_model_file = download_testdata(
             "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/Quantized/mobilenet_v2_quantized.tflite",
-            "mobilenet_v2_quantized.tflite")
+            "mobilenet_v2_quantized.tflite",
+        )
         with open(tflite_model_file, "rb") as f:
             tflite_model_buf = f.read()
 
@@ -3096,7 +3704,7 @@ def test_forward_tflite2_qnn_mobilenet_v2():
         tflite_output = run_tflite_graph(tflite_model_buf, data)
         tflite_predictions = np.squeeze(tflite_output)
         tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
-        tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1')
+        tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), "input_1")
         tvm_predictions = np.squeeze(tvm_output)
         tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
         tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
@@ -3106,26 +3714,36 @@ def test_forward_tflite2_qnn_mobilenet_v2():
 # Quantized SSD Mobilenet
 # -----------------------
 
+
 def test_forward_qnn_coco_ssd_mobilenet_v1():
     """Test the quantized Coco SSD Mobilenet V1 TF Lite model."""
-    pytest.skip("LLVM bug - getExtendedVectorNumElements - "
-                + "https://discuss.tvm.ai/t/segfault-in-llvm/3567. The workaround is to use a "
-                + "specific target, for example, llvm -mpcu=core-avx2")
+    pytest.skip(
+        "LLVM bug - getExtendedVectorNumElements - "
+        + "https://discuss.tvm.ai/t/segfault-in-llvm/3567. The workaround is to use a "
+        + "specific target, for example, llvm -mpcu=core-avx2"
+    )
 
     tflite_model_file = tf_testing.get_workload_official(
         "https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip",
-        "detect.tflite")
+        "detect.tflite",
+    )
 
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
 
     data = get_real_image_object_detection(300, 300)
     tflite_output = run_tflite_graph(tflite_model_buf, data)
-    tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=4)
+    tvm_output = run_tvm_graph(
+        tflite_model_buf, data, "normalized_input_image_tensor", num_output=4
+    )
 
     # Check all output shapes are equal
-    assert all([tvm_tensor.shape == tflite_tensor.shape \
-                for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)])
+    assert all(
+        [
+            tvm_tensor.shape == tflite_tensor.shape
+            for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)
+        ]
+    )
 
     # Check valid count is the same
     assert tvm_output[3] == tflite_output[3]
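
The skip message in the hunk above points to a workaround rather than a fix: build for an explicit CPU target instead of plain "llvm". A minimal sketch of what that looks like, using a tiny stand-in Relay function (the real test imports a quantized TFLite model instead):

import tvm
from tvm import relay

# Stand-in module only; the real test converts a TFLite flatbuffer to Relay.
x = relay.var("x", shape=(1, 4), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.relu(x)))

# The workaround named in the skip message: an explicit CPU target
# (note the flag is spelled -mcpu) rather than the generic "llvm".
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm -mcpu=core-avx2")
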
@@ -3142,42 +3760,57 @@ def test_forward_qnn_coco_ssd_mobilenet_v1():
         if tvm_output[2][0][i] > 0.6:
             # Check bounding box co-ords. The tolerances have to be adjusted, from 1e-5 to 1e-2,
             # because of differences in the requantize operator between TFLite and TVM.
-            tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]),
-                                        np.squeeze(tflite_output[0][0][i]),
-                                        rtol=1e-2, atol=1e-2)
+            tvm.testing.assert_allclose(
+                np.squeeze(tvm_output[0][0][i]),
+                np.squeeze(tflite_output[0][0][i]),
+                rtol=1e-2,
+                atol=1e-2,
+            )
 
             # Check the class
             # Stricter check to ensure class remains same
-            np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]),
-                                    np.squeeze(tflite_output[1][0][i]))
+            np.testing.assert_equal(
+                np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i])
+            )
 
             # Check the score
-            tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]),
-                                        np.squeeze(tflite_output[2][0][i]),
-                                        rtol=1e-5, atol=1e-5)
+            tvm.testing.assert_allclose(
+                np.squeeze(tvm_output[2][0][i]),
+                np.squeeze(tflite_output[2][0][i]),
+                rtol=1e-5,
+                atol=1e-5,
+            )
 
 
 #######################################################################
 # SSD Mobilenet
 # -------------
 
+
 def test_forward_coco_ssd_mobilenet_v1():
     """Test the FP32 Coco SSD Mobilenet V1 TF Lite model."""
     tflite_model_file = tf_testing.get_workload_official(
         "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tgz",
-        "ssd_mobilenet_v1_coco_2018_01_28.tflite")
+        "ssd_mobilenet_v1_coco_2018_01_28.tflite",
+    )
 
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
 
     np.random.seed(0)
-    data = np.random.uniform(size=(1, 300, 300, 3)).astype('float32')
+    data = np.random.uniform(size=(1, 300, 300, 3)).astype("float32")
     tflite_output = run_tflite_graph(tflite_model_buf, data)
-    tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=4)
+    tvm_output = run_tvm_graph(
+        tflite_model_buf, data, "normalized_input_image_tensor", num_output=4
+    )
 
     # Check all output shapes are equal
-    assert all([tvm_tensor.shape == tflite_tensor.shape \
-                for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)])
+    assert all(
+        [
+            tvm_tensor.shape == tflite_tensor.shape
+            for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)
+        ]
+    )
 
     # Check valid count is the same
     assert tvm_output[3] == tflite_output[3]
@@ -3187,14 +3820,23 @@ def test_forward_coco_ssd_mobilenet_v1():
     # tflite and tvm tensors for only valid boxes.
     for i in range(0, valid_count):
         # Check bounding box co-ords
-        tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]),
-                                    rtol=1e-5, atol=1e-5)
+        tvm.testing.assert_allclose(
+            np.squeeze(tvm_output[0][0][i]),
+            np.squeeze(tflite_output[0][0][i]),
+            rtol=1e-5,
+            atol=1e-5,
+        )
         # Check the class
         np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i]))
 
         # Check the score
-        tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]),
-                                    rtol=1e-5, atol=1e-5)
+        tvm.testing.assert_allclose(
+            np.squeeze(tvm_output[2][0][i]),
+            np.squeeze(tflite_output[2][0][i]),
+            rtol=1e-5,
+            atol=1e-5,
+        )
+
 
 #######################################################################
 # MediaPipe
@@ -3204,21 +3846,23 @@ def test_forward_mediapipe_hand_landmark():
     # MediaPipe 2D hand landmark TF
     tflite_model_file = download_testdata(
         "https://github.com/google/mediapipe/raw/v0.7.4/mediapipe/models/hand_landmark.tflite",
-        "hand_landmark.tflite")
+        "hand_landmark.tflite",
+    )
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
-    data = np.random.uniform(size=(1, 256, 256, 3)).astype('float32')
+    data = np.random.uniform(size=(1, 256, 256, 3)).astype("float32")
     tflite_output = run_tflite_graph(tflite_model_buf, data)
-    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input_1', num_output=2)
+    tvm_output = run_tvm_graph(tflite_model_buf, data, "input_1", num_output=2)
     for i in range(2):
-        tvm.testing.assert_allclose(np.squeeze(tvm_output[i]), np.squeeze(tflite_output[i]),
-                                    rtol=1e-5, atol=1e-5)
+        tvm.testing.assert_allclose(
+            np.squeeze(tvm_output[i]), np.squeeze(tflite_output[i]), rtol=1e-5, atol=1e-5
+        )
 
 
 #######################################################################
 # Main
 # ----
-if __name__ == '__main__':
+if __name__ == "__main__":
     # BatchToSpaceND
     test_forward_batch_to_space_nd()
 
@@ -3322,8 +3966,8 @@ if __name__ == '__main__':
     test_forward_qnn_inception_v1_net()
     test_forward_qnn_mobilenet_v1_net()
     test_forward_qnn_mobilenet_v2_net()
-    #This also fails with a segmentation fault in my run
-    #with Tflite 1.15.2
+    # This also fails with a segmentation fault in my run
+    # with Tflite 1.15.2
     test_forward_qnn_mobilenet_v3_net()
     test_forward_qnn_coco_ssd_mobilenet_v1()
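
Two comparison idioms recur in the quantized-model tests above: classifiers compare top-3 labels obtained with argsort, and the quantized SSD test loosens assert_allclose from rtol=1e-5 to 1e-2. A small self-contained numpy sketch of both, with made-up scores standing in for the real model outputs:

import numpy as np

# Made-up scores standing in for tflite_output / tvm_output.
tflite_scores = np.array([0.02, 0.70, 0.08, 0.20])
tvm_scores = np.array([0.03, 0.68, 0.07, 0.22])

# argsort()[-3:][::-1] yields the indices of the three largest scores, best first.
np.testing.assert_array_equal(
    tvm_scores.argsort()[-3:][::-1], tflite_scores.argsort()[-3:][::-1]
)

# rtol=1e-2 accepts a ~0.7% drift (0.005 <= 1e-2 * 0.70) that rtol=1e-5 would reject,
# which is why the requantize-related box checks are loosened.
np.testing.assert_allclose(0.705, 0.70, rtol=1e-2)
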
 
index 0bfe61a..d4364c8 100644 (file)
@@ -24,10 +24,10 @@ import numpy as np
 def test_dot():
     nn = 12
     n = tvm.runtime.convert(nn)
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
-    k = te.reduce_axis((0, n), 'k')
-    C = te.compute((1,), lambda _: te.sum(A[k] * B[k], axis=k), name='C')
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
+    k = te.reduce_axis((0, n), "k")
+    C = te.compute((1,), lambda _: te.sum(A[k] * B[k], axis=k), name="C")
     s = te.create_schedule(C.op)
 
     def verify(target):
@@ -36,12 +36,12 @@ def test_dot():
         ctx = tvm.cpu(0)
         a = tvm.nd.array(np.random.uniform(size=(nn,)).astype(A.dtype), ctx)
         b = tvm.nd.array(np.random.uniform(size=(nn,)).astype(B.dtype), ctx)
-        c  = tvm.nd.array(np.zeros((1,), dtype=C.dtype), ctx)
+        c = tvm.nd.array(np.zeros((1,), dtype=C.dtype), ctx)
         f(a, b, c)
-        tvm.testing.assert_allclose(
-            c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-4)
+        tvm.testing.assert_allclose(c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-4)
 
     verify("llvm")
 
+
 if __name__ == "__main__":
     test_dot()
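
Beyond the individual tests, every hunk in this commit is the same mechanical rewrite by black: single quotes become double quotes, short hand-wrapped calls are joined onto one line, and calls that exceed the configured line length are re-wrapped into an indented argument block, with a trailing comma when the arguments go one per line. A before/after sketch with a hypothetical helper:

def run_graph_sketch(model_buf, data, input_name, num_output=1):
    # Hypothetical helper, only here so the sketch is runnable.
    return [data] * num_output

# Before: single quotes, hand-wrapped arguments.
out = run_graph_sketch(b"model", [1.0], 'normalized_input_image_tensor',
                       num_output=4)

# After black: double quotes, and the over-long call is re-wrapped into an
# indented argument block.
out = run_graph_sketch(
    b"model", [1.0], "normalized_input_image_tensor", num_output=4
)
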
index 6195bf6..dda494d 100644 (file)
@@ -21,12 +21,13 @@ import numpy as np
 import time
 import tvm.testing
 
+
 @tvm.testing.requires_gpu
 def test_exp():
     # graph
     n = tvm.runtime.convert(1024)
-    A = te.placeholder((n,), name='A')
-    B = te.compute(A.shape, lambda *i: te.exp(A(*i)), name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda *i: te.exp(A(*i)), name="B")
     s = te.create_schedule(B.op)
     # create iter var and assign them tags.
     num_thread = 8
@@ -39,30 +40,28 @@ def test_exp():
         if not tvm.testing.device_enabled(host):
             return
         ctx = tvm.context(device, 0)
-        fexp = tvm.build(s, [A, B],
-                         device, host,
-                         name="myexp")
+        fexp = tvm.build(s, [A, B], device, host, name="myexp")
         ctx = tvm.context(device, 0)
         # launch the kernel.
         n = 1024
         a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
         b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
         fexp(a, b)
-        tvm.testing.assert_allclose(
-            b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5)
+        tvm.testing.assert_allclose(b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5)
 
     check_device("opencl -device=intel_graphics")
     check_device("cuda", "llvm")
     check_device("vulkan")
 
+
 @tvm.testing.requires_gpu
 def test_fmod():
     # graph
     def run(dtype):
-        n = te.size_var('n')
-        A = te.placeholder((n,), name='A', dtype=dtype)
-        B = te.placeholder((n,), name='B', dtype=dtype)
-        C = te.compute(A.shape, lambda *i: te.fmod(A(*i), B(*i)), name='C')
+        n = te.size_var("n")
+        A = te.placeholder((n,), name="A", dtype=dtype)
+        B = te.placeholder((n,), name="B", dtype=dtype)
+        C = te.compute(A.shape, lambda *i: te.fmod(A(*i), B(*i)), name="C")
         s = te.create_schedule(C.op)
         # create iter var and assign them tags.
         num_thread = 8
@@ -85,7 +84,7 @@ def test_fmod():
             b_np = (np.random.uniform(size=n) * 256).astype(B.dtype)
 
             # "fix" the values in a and b to avoid the result being too small
-            b_np += ((b_np < 2.0) * 2)
+            b_np += (b_np < 2.0) * 2
             a_np[np.abs(np.fmod(a_np, b_np)) < 1] += 1
 
             a = tvm.nd.array(a_np, ctx)
@@ -93,9 +92,8 @@ def test_fmod():
             c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
             ftimer = fmod.time_evaluator(fmod.entry_name, ctx, number=1)
             tcost = ftimer(a, b, c).mean
-            #fmod(a, b, c)
-            np.testing.assert_allclose(
-                c.asnumpy(), np.mod(a.asnumpy(), b.asnumpy()), rtol=1e-5)
+            # fmod(a, b, c)
+            np.testing.assert_allclose(c.asnumpy(), np.mod(a.asnumpy(), b.asnumpy()), rtol=1e-5)
 
         check_device("cuda")
         check_device("opencl -device=intel_graphics")
@@ -103,17 +101,15 @@ def test_fmod():
 
     run("float32")
 
+
 @tvm.testing.requires_gpu
 def test_multiple_cache_write():
     # graph
     n = tvm.runtime.convert(1024)
-    A0 = te.placeholder((n,), name='A0', dtype = "float32")
-    A1 = te.placeholder((n,), name='A1', dtype = "float32")
-    B0, B1 = te.compute((n,),
-            lambda *i: (A0(*i) + A1(*i), A0(*i) * A1(*i)),
-            name='B')
-    C = te.compute((n,), lambda *i: B0(*i) + B1(*i),
-            name='C')
+    A0 = te.placeholder((n,), name="A0", dtype="float32")
+    A1 = te.placeholder((n,), name="A1", dtype="float32")
+    B0, B1 = te.compute((n,), lambda *i: (A0(*i) + A1(*i), A0(*i) * A1(*i)), name="B")
+    C = te.compute((n,), lambda *i: B0(*i) + B1(*i), name="C")
     s = te.create_schedule(C.op)
     # create iter var and assign them tags.
     num_thread = 8
@@ -130,9 +126,7 @@ def test_multiple_cache_write():
         ctx = tvm.context(device, 0)
         if not tvm.testing.device_enabled(device):
             return
-        func = tvm.build(s, [A0, A1, C],
-                         device, host,
-                         name="multiple_cache_write")
+        func = tvm.build(s, [A0, A1, C], device, host, name="multiple_cache_write")
         ctx = tvm.context(device, 0)
         # launch the kernel.
         n = 1024
@@ -141,18 +135,19 @@ def test_multiple_cache_write():
         c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
         func(a0, a1, c)
         tvm.testing.assert_allclose(
-            c.asnumpy(), a0.asnumpy() + a1.asnumpy() + (a0.asnumpy() * a1.asnumpy()),
-            rtol=1e-5)
+            c.asnumpy(), a0.asnumpy() + a1.asnumpy() + (a0.asnumpy() * a1.asnumpy()), rtol=1e-5
+        )
 
     check_device("cuda", "llvm")
     check_device("vulkan")
     check_device("opencl")
 
+
 def test_log_pow_llvm():
     # graph
-    n = te.size_var('n')
-    A = te.placeholder((n,), name='A')
-    B = te.compute(A.shape, lambda *i: te.power(te.log(A(*i)), 2.0), name='B')
+    n = te.size_var("n")
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda *i: te.power(te.log(A(*i)), 2.0), name="B")
     s = te.create_schedule(B.op)
     # create iter var and assign them tags.
     bx, tx = s[B].split(B.op.axis[0], factor=32)
@@ -160,8 +155,7 @@ def test_log_pow_llvm():
     if not tvm.testing.device_enabled("llvm"):
         return
 
-    flog = tvm.build(s, [A, B],
-                     "llvm", name="mylog")
+    flog = tvm.build(s, [A, B], "llvm", name="mylog")
     ctx = tvm.cpu(0)
     # launch the kernel.
     n = 1028
@@ -170,9 +164,8 @@ def test_log_pow_llvm():
     repeat = 10
     ftimer = flog.time_evaluator(flog.entry_name, ctx, number=1, repeat=repeat)
     res = ftimer(a, b)
-    assert(len(res.results) == repeat)
-    tvm.testing.assert_allclose(
-        b.asnumpy(), np.power(np.log(a.asnumpy()), 2.0), rtol=1e-5)
+    assert len(res.results) == repeat
+    tvm.testing.assert_allclose(b.asnumpy(), np.power(np.log(a.asnumpy()), 2.0), rtol=1e-5)
 
 
 @tvm.testing.uses_gpu
@@ -180,8 +173,8 @@ def test_popcount():
     def run(dtype):
         # graph
         n = tvm.runtime.convert(1024)
-        A = te.placeholder((n,), name='A', dtype=dtype)
-        B = te.compute(A.shape, lambda *i: tvm.tir.popcount(A(*i)), name='B')
+        A = te.placeholder((n,), name="A", dtype=dtype)
+        B = te.compute(A.shape, lambda *i: tvm.tir.popcount(A(*i)), name="B")
         s = te.create_schedule(B.op)
         # simple schedule
         num_thread = 8
@@ -203,7 +196,8 @@ def test_popcount():
             b = tvm.nd.array(np.zeros(shape=n, dtype=B.dtype), ctx)
             func(a, b)
             tvm.testing.assert_allclose(
-                b.asnumpy(), list(map(lambda x: bin(x).count('1'), a.asnumpy())), rtol=1e-5)
+                b.asnumpy(), list(map(lambda x: bin(x).count("1"), a.asnumpy())), rtol=1e-5
+            )
 
         check_device("llvm")
         check_device("cuda")
@@ -211,25 +205,26 @@ def test_popcount():
         if dtype == "uint32":
             check_device("metal")
             check_device("vulkan")
-    run('uint32')
-    run('uint64')
+
+    run("uint32")
+    run("uint64")
 
 
 @tvm.testing.requires_gpu
 def test_add():
     def run(dtype):
         # graph
-        n = te.size_var('n')
-        A = te.placeholder((n,), name='A', dtype=dtype)
-        B = te.placeholder((n,), name='B', dtype=dtype)
+        n = te.size_var("n")
+        A = te.placeholder((n,), name="A", dtype=dtype)
+        B = te.placeholder((n,), name="B", dtype=dtype)
         bias = te.var("bias", dtype=dtype)
         scale = te.var("scale", dtype=dtype)
-        C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+        C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C")
         # schedule
         s = te.create_schedule(C.op)
         # create iter var and assign them tags.
         num_thread = 16
-        bx, x = s[C].split(C.op.axis[0], factor=num_thread*4)
+        bx, x = s[C].split(C.op.axis[0], factor=num_thread * 4)
         tx, x = s[C].split(x, nparts=num_thread)
         _, x = s[C].split(x, factor=4)
         s[C].bind(bx, te.thread_axis("blockIdx.x"))
@@ -242,9 +237,7 @@ def test_add():
             if not tvm.testing.device_enabled(device):
                 print("skip because %s is not enabled.." % device)
                 return
-            fadd = tvm.build(s, [A, B, C],
-                             device,
-                             name="myadd")
+            fadd = tvm.build(s, [A, B, C], device, name="myadd")
 
             # launch the kernel.
             n = 1024
@@ -253,8 +246,7 @@ def test_add():
             c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
             ftimer = fadd.time_evaluator(fadd.entry_name, ctx, number=1)
             tcost = ftimer(a, b, c).mean
-            tvm.testing.assert_allclose(
-                c.asnumpy(), a.asnumpy() + b.asnumpy(), rtol=1e-6)
+            tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy(), rtol=1e-6)
 
         check_device("opencl")
         check_device("cuda")
@@ -272,8 +264,8 @@ def test_add():
 def try_warp_memory():
     """skip this in default test because it require higher arch"""
     m = 128
-    A = te.placeholder((m,), name='A')
-    B = te.compute((m,), lambda i: A[i] + 3, name='B')
+    A = te.placeholder((m,), name="A")
+    B = te.compute((m,), lambda i: A[i] + 3, name="B")
     warp_size = 32
     s = te.create_schedule(B.op)
     AA = s.cache_read(A, "warp", [B])
@@ -288,7 +280,7 @@ def try_warp_memory():
 
     @tvm.register_func
     def tvm_callback_cuda_compile(code):
-        ptx =  nvcc.compile_cuda(code, target="ptx")
+        ptx = nvcc.compile_cuda(code, target="ptx")
         return ptx
 
     # one line to build the function.
@@ -301,8 +293,8 @@ def try_warp_memory():
         a = tvm.nd.array((np.random.uniform(size=m) * 256).astype(A.dtype), ctx)
         b = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
         f(a, b)
-        tvm.testing.assert_allclose(
-            b.asnumpy(), a.asnumpy() + 3, rtol=1e-6)
+        tvm.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 3, rtol=1e-6)
+
     check_device("cuda")
 
 
index abcddc4..ab05f7f 100644 (file)
@@ -22,17 +22,19 @@ import os
 os.environ["XCL_EMULATION_MODE"] = "1"
 os.environ["CL_CONTEXT_EMULATOR_DEVICE_INTELFPGA"] = "1"
 
+
 @tvm.register_func
 def tvm_callback_vhls_postproc(code):
     """Hook to inspect the Vivado HLS code before actually run it"""
     print(code)
     return code
 
+
 def test_exp():
     # graph
     n = tvm.runtime.convert(1024)
-    A = te.placeholder((n,), name='A')
-    B = te.compute(A.shape, lambda *i: te.exp(A(*i)), name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda *i: te.exp(A(*i)), name="B")
     s = te.create_schedule(B.op)
     # create iter var and assign them tags.
     px, x = s[B].split(B.op.axis[0], nparts=1)
@@ -43,17 +45,14 @@ def test_exp():
         if not tvm.testing.device_enabled(device):
             return
         ctx = tvm.context(device, 0)
-        fexp = tvm.build(s, [A, B],
-                         device, host,
-                         name="myexp")
+        fexp = tvm.build(s, [A, B], device, host, name="myexp")
         ctx = tvm.context(device, 0)
         # launch the kernel.
         n = 1024
         a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
         b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
         fexp(a, b)
-        tvm.testing.assert_allclose(
-            b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5)
+        tvm.testing.assert_allclose(b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5)
 
     check_device("sdaccel")
     if "AWS_PLATFORM" in os.environ:
@@ -61,13 +60,14 @@ def test_exp():
 
     check_device("aocl_sw_emu")
 
+
 def test_multi_kernel():
     # graph
     n = tvm.runtime.convert(1024)
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
-    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
-    D = te.compute(A.shape, lambda *i: A(*i) + C(*i), name='D')
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
+    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C")
+    D = te.compute(A.shape, lambda *i: A(*i) + C(*i), name="D")
     s = te.create_schedule(D.op)
     # create iter var and assign them tags.
     px, x = s[C].split(C.op.axis[0], nparts=1)
@@ -80,9 +80,7 @@ def test_multi_kernel():
         if not tvm.testing.device_enabled(device):
             return
         ctx = tvm.context(device, 0)
-        fadd = tvm.build(s, [A, B, C, D],
-                         device, host,
-                         name="myadd")
+        fadd = tvm.build(s, [A, B, C, D], device, host, name="myadd")
         ctx = tvm.context(device, 0)
         # launch the kernel.
         n = 1024
@@ -91,8 +89,7 @@ def test_multi_kernel():
         c = tvm.nd.array(np.random.uniform(size=n).astype(C.dtype), ctx)
         d = tvm.nd.array(np.random.uniform(size=n).astype(D.dtype), ctx)
         fadd(a, b, c, d)
-        tvm.testing.assert_allclose(
-            d.asnumpy(), a.asnumpy() * 2 + b.asnumpy(), rtol=1e-5)
+        tvm.testing.assert_allclose(d.asnumpy(), a.asnumpy() * 2 + b.asnumpy(), rtol=1e-5)
 
     check_device("sdaccel")
     check_device("aocl_sw_emu")
index b2698f3..42612c2 100644 (file)
@@ -28,13 +28,10 @@ def test_gemm():
     n = tvm.runtime.convert(nn)
     m = n
     l = n
-    A = te.placeholder((n, l), name='A')
-    B = te.placeholder((m, l), name='B')
-    k = te.reduce_axis((0, l), name='k')
-    C = te.compute(
-        (n, m),
-        lambda ii, jj: te.sum(A[ii, k] * B[jj, k], axis=k),
-        name='CC')
+    A = te.placeholder((n, l), name="A")
+    B = te.placeholder((m, l), name="B")
+    k = te.reduce_axis((0, l), name="k")
+    C = te.compute((n, m), lambda ii, jj: te.sum(A[ii, k] * B[jj, k], axis=k), name="CC")
     # schedule
     s = te.create_schedule(C.op)
     xtile, ytile = 32, 32
@@ -62,7 +59,6 @@ def test_gemm():
     yo, xo = CC.op.axis
     s[CC].reorder(k, yo, xo)
 
-
     s[CC].compute_at(s[C], tx)
     s[AA].compute_at(s[CC], k)
     s[BB].compute_at(s[CC], k)
@@ -103,8 +99,7 @@ def test_gemm():
         ftimer = f.time_evaluator(f.entry_name, ctx, number=1)
         tcost = ftimer(a, b, c).mean
         print("%s: exec=%g sec/op" % (ctx, tcost))
-        tvm.testing.assert_allclose(
-            c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
+        tvm.testing.assert_allclose(c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
 
     check_device("vulkan")
     check_device("nvptx -mcpu=sm_20")
@@ -113,5 +108,6 @@ def test_gemm():
     check_device("opencl")
     check_device("cuda")
 
+
 if __name__ == "__main__":
     test_gemm()
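
test_gemm above reports seconds per call from time_evaluator; the usual follow-up is to convert that into a throughput figure. A numpy-only sketch of the arithmetic for the same A * B^T contraction (n = 1024 is an assumption here, since nn is set outside the hunks shown):

import time
import numpy as np

n = 1024  # assumed size; the test's nn is defined before the hunks above
a = np.random.rand(n, n).astype("float32")
b = np.random.rand(n, n).astype("float32")

start = time.time()
c = a @ b.T  # matches the test's np.dot(a_np, b_np.T) reference
tcost = time.time() - start

gflops = 2.0 * n ** 3 / tcost / 1e9  # a square GEMM performs 2 * n^3 floating-point ops
print("numpy reference: %.2f GFLOPS" % gflops)
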
index 35980ed..b02b798 100644 (file)
@@ -24,12 +24,12 @@ import tvm.testing
 def test_reduce_prims():
     def test_prim(reducer, np_reducer):
         # graph
-        n = tvm.te.size_var('n')
-        m = tvm.te.size_var('m')
-        A = te.placeholder((n, m), name='A')
-        R = te.compute((n, ), lambda i: tvm.tir.Select((i > 1), 1, 0), name='R')
+        n = tvm.te.size_var("n")
+        m = tvm.te.size_var("m")
+        A = te.placeholder((n, m), name="A")
+        R = te.compute((n,), lambda i: tvm.tir.Select((i > 1), 1, 0), name="R")
         k = te.reduce_axis((0, m))
-        B = te.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(R[i]==1)), name='B')
+        B = te.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(R[i] == 1)), name="B")
         # schedule
         s = te.create_schedule(B.op)
         # create iter var and assign them tags.
@@ -45,10 +45,7 @@ def test_reduce_prims():
             if not tvm.testing.device_enabled(device):
                 print("skip because %s is not enabled.." % device)
                 return
-            freduce = tvm.build(s,
-                             args=[A, B],
-                             target=device, target_host=host,
-                             name="myreduce")
+            freduce = tvm.build(s, args=[A, B], target=device, target_host=host, name="myreduce")
             # launch the kernel.
             n = 1028
             m = 129
@@ -66,15 +63,17 @@ def test_reduce_prims():
         check_device("cuda")
         check_device("opencl")
         check_device("rocm")
+
     test_prim(te.sum, np.sum)
     test_prim(tvm.te.min, np.amin)
     test_prim(tvm.te.max, np.amax)
 
+
 def test_init_imm():
     n = tvm.runtime.convert(1027)
-    A = te.placeholder((n,), name='A')
+    A = te.placeholder((n,), name="A")
     k = te.reduce_axis((0, n))
-    B = te.compute((1,), lambda i: te.sum(A[k], axis=k, init=10.0), name='B')
+    B = te.compute((1,), lambda i: te.sum(A[k], axis=k, init=10.0), name="B")
     # schedule
     s = te.create_schedule(B.op)
     # one line to build the function.
@@ -83,27 +82,25 @@ def test_init_imm():
             return
         ctx = tvm.cpu(0)
         fapi = tvm.lower(s, args=[A, B])
-        fsum = tvm.build(fapi,
-                         target=target,
-                         name="mysum")
+        fsum = tvm.build(fapi, target=target, name="mysum")
         # launch the kernel.
         n = 1027
         a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
-        b  = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
+        b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
         fsum(a, b)
         res = 10.0 + np.sum(a.asnumpy(), axis=0)
-        tvm.testing.assert_allclose(
-            b.asnumpy(), res, rtol=1e-4)
+        tvm.testing.assert_allclose(b.asnumpy(), res, rtol=1e-4)
 
     check_target()
 
+
 def test_init():
     n = tvm.runtime.convert(1027)
-    A = te.placeholder((n,n), name='A')
-    C = te.placeholder((n,n), name='C')
-    I = te.placeholder((n,n), name='I')
+    A = te.placeholder((n, n), name="A")
+    C = te.placeholder((n, n), name="C")
+    I = te.placeholder((n, n), name="I")
     k = te.reduce_axis((0, n))
-    B = te.compute((n,n), lambda i,j: te.sum(A[i,k]*C[k,j], axis=k, init=I[i,j]), name='B')
+    B = te.compute((n, n), lambda i, j: te.sum(A[i, k] * C[k, j], axis=k, init=I[i, j]), name="B")
 
     # schedule
     s = te.create_schedule(B.op)
@@ -117,22 +114,22 @@ def test_init():
         mmult = tvm.build(fapi, target=target, name="mmult")
         # launch the kernel.
         n = 1027
-        a = tvm.nd.array(np.random.uniform(size=(n,n)).astype(A.dtype), ctx)
-        c = tvm.nd.array(np.random.uniform(size=(n,n)).astype(C.dtype), ctx)
-        ii = tvm.nd.array(np.random.uniform(size=(n,n)).astype(B.dtype), ctx)
-        b  = tvm.nd.array(np.zeros((n,n), dtype=B.dtype), ctx)
+        a = tvm.nd.array(np.random.uniform(size=(n, n)).astype(A.dtype), ctx)
+        c = tvm.nd.array(np.random.uniform(size=(n, n)).astype(C.dtype), ctx)
+        ii = tvm.nd.array(np.random.uniform(size=(n, n)).astype(B.dtype), ctx)
+        b = tvm.nd.array(np.zeros((n, n), dtype=B.dtype), ctx)
         mmult(a, c, ii, b)
-        res = ii.asnumpy() + np.matmul(a.asnumpy(),c.asnumpy())
-        tvm.testing.assert_allclose(
-            b.asnumpy(), res, rtol=1e-4)
+        res = ii.asnumpy() + np.matmul(a.asnumpy(), c.asnumpy())
+        tvm.testing.assert_allclose(b.asnumpy(), res, rtol=1e-4)
 
     check_target()
 
+
 def test_rfactor():
     n = tvm.runtime.convert(1027)
-    A = te.placeholder((n,), name='A')
+    A = te.placeholder((n,), name="A")
     k = te.reduce_axis((0, n))
-    B = te.compute((1,), lambda i: te.sum(A[k], axis=k), name='B')
+    B = te.compute((1,), lambda i: te.sum(A[k], axis=k), name="B")
     # schedule
     s = te.create_schedule(B.op)
     kf, ki = s[B].split(k, nparts=4)
@@ -144,27 +141,25 @@ def test_rfactor():
             return
         ctx = tvm.cpu(0)
         fapi = tvm.lower(s, args=[A, B])
-        fsum = tvm.build(fapi,
-                         target=target,
-                         name="mysum")
+        fsum = tvm.build(fapi, target=target, name="mysum")
         # launch the kernel.
         n = 1027
         a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
-        b  = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
+        b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
         fsum(a, b)
         res = np.sum(a.asnumpy(), axis=0)
-        tvm.testing.assert_allclose(
-            b.asnumpy(), res, rtol=1e-4)
+        tvm.testing.assert_allclose(b.asnumpy(), res, rtol=1e-4)
 
     check_target()
 
+
 def test_rfactor_init():
     n = tvm.runtime.convert(1027)
-    A = te.placeholder((n,n), name='A')
-    C = te.placeholder((n,n), name='C')
-    I = te.placeholder((n,n), name='I')
+    A = te.placeholder((n, n), name="A")
+    C = te.placeholder((n, n), name="C")
+    I = te.placeholder((n, n), name="I")
     k = te.reduce_axis((0, n))
-    B = te.compute((n,n), lambda i,j: te.sum(A[i,k]*C[k,j], axis=k, init=I[i,j]), name='B')
+    B = te.compute((n, n), lambda i, j: te.sum(A[i, k] * C[k, j], axis=k, init=I[i, j]), name="B")
 
     # schedule
     s = te.create_schedule(B.op)
@@ -181,22 +176,22 @@ def test_rfactor_init():
         mmult = tvm.build(fapi, target=target, name="mmult")
         # launch the kernel.
         n = 1027
-        a = tvm.nd.array(np.random.uniform(size=(n,n)).astype(A.dtype), ctx)
-        c = tvm.nd.array(np.random.uniform(size=(n,n)).astype(C.dtype), ctx)
-        ii = tvm.nd.array(np.random.uniform(size=(n,n)).astype(B.dtype), ctx)
-        b  = tvm.nd.array(np.zeros((n,n), dtype=B.dtype), ctx)
+        a = tvm.nd.array(np.random.uniform(size=(n, n)).astype(A.dtype), ctx)
+        c = tvm.nd.array(np.random.uniform(size=(n, n)).astype(C.dtype), ctx)
+        ii = tvm.nd.array(np.random.uniform(size=(n, n)).astype(B.dtype), ctx)
+        b = tvm.nd.array(np.zeros((n, n), dtype=B.dtype), ctx)
         mmult(a, c, ii, b)
-        res = ii.asnumpy() + np.matmul(a.asnumpy(),c.asnumpy())
-        tvm.testing.assert_allclose(
-            b.asnumpy(), res, rtol=1e-4)
+        res = ii.asnumpy() + np.matmul(a.asnumpy(), c.asnumpy())
+        tvm.testing.assert_allclose(b.asnumpy(), res, rtol=1e-4)
 
     check_target()
 
+
 def test_rfactor_factor_axis():
     n = tvm.runtime.convert(1027)
-    A = te.placeholder((n,), name='A')
+    A = te.placeholder((n,), name="A")
     k = te.reduce_axis((0, n))
-    B = te.compute((1,), lambda i: te.sum(A[k], axis=k), name='B')
+    B = te.compute((1,), lambda i: te.sum(A[k], axis=k), name="B")
     # schedule
     s = te.create_schedule(B.op)
     kf, ki = s[B].split(k, nparts=4)
@@ -208,17 +203,14 @@ def test_rfactor_factor_axis():
             return
         ctx = tvm.cpu(0)
         fapi = tvm.lower(s, args=[A, B])
-        fsum = tvm.build(fapi,
-                         target=target,
-                         name="mysum")
+        fsum = tvm.build(fapi, target=target, name="mysum")
         # launch the kernel.
         n = 1027
         a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
-        b  = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
+        b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
         fsum(a, b)
         res = np.sum(a.asnumpy(), axis=0)
-        tvm.testing.assert_allclose(
-            b.asnumpy(), res, rtol=1e-4)
+        tvm.testing.assert_allclose(b.asnumpy(), res, rtol=1e-4)
 
     check_target()
 
@@ -229,10 +221,10 @@ def test_rfactor_threads():
     mm = 10
     n = tvm.runtime.convert(nn)
     m = tvm.runtime.convert(mm)
-    A = te.placeholder((m, n), name='A')
+    A = te.placeholder((m, n), name="A")
     k = te.reduce_axis((0, n))
     nthread = 16
-    B = te.compute((m,), lambda i: te.sum(A[i, k], axis=k, where=(i>1)), name='B')
+    B = te.compute((m,), lambda i: te.sum(A[i, k], axis=k, where=(i > 1)), name="B")
     # schedule
     s = te.create_schedule(B.op)
     ko, kf = s[B].split(k, factor=nthread)
@@ -254,19 +246,16 @@ def test_rfactor_threads():
             return
 
         fapi = tvm.lower(s, args=[A, B])
-        fsum = tvm.build(fapi,
-                         target=device,
-                         name="mysum")
+        fsum = tvm.build(fapi, target=device, name="mysum")
         # launch the kernel.
         n = nn
         m = mm
         a = tvm.nd.array(np.random.uniform(size=(m, n)).astype(A.dtype), ctx)
-        b  = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
+        b = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
         fsum(a, b)
         res = np.sum(a.asnumpy(), axis=1)
         res[:2] = 0
-        tvm.testing.assert_allclose(
-            b.asnumpy(), res, rtol=1e-4)
+        tvm.testing.assert_allclose(b.asnumpy(), res, rtol=1e-4)
 
     check_target("vulkan")
     check_target("cuda")
@@ -274,16 +263,17 @@ def test_rfactor_threads():
     check_target("opencl")
     check_target("rocm")
 
+
 @tvm.testing.requires_gpu
 def test_rfactor_elemwise_threads():
     n = 1025
     m = 10
-    A = te.placeholder((m, n), name='A')
+    A = te.placeholder((m, n), name="A")
     k = te.reduce_axis((0, n))
     nthread = 16
-    B = te.compute((m,), lambda i: te.sum(A[i, k], axis=k), name='B')
-    BB = te.compute((m,), lambda i: B[i] + 1, name='BB')
-    C = te.compute((m,), lambda i: BB[i] + 1, name='C')
+    B = te.compute((m,), lambda i: te.sum(A[i, k], axis=k), name="B")
+    BB = te.compute((m,), lambda i: B[i] + 1, name="BB")
+    C = te.compute((m,), lambda i: BB[i] + 1, name="C")
     # schedule
     s = te.create_schedule(C.op)
     s[BB].compute_inline()
@@ -309,16 +299,13 @@ def test_rfactor_elemwise_threads():
             print("skip because %s is not enabled.." % device)
             return
         fapi = tvm.lower(s, args=[A, C])
-        fsum = tvm.build(fapi,
-                         target=device,
-                         name="mysum")
+        fsum = tvm.build(fapi, target=device, name="mysum")
         # launch the kernel.
         a = tvm.nd.array(np.random.uniform(size=(m, n)).astype(A.dtype), ctx)
-        b  = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
+        b = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
         fsum(a, b)
         res = np.sum(a.asnumpy(), axis=1) + 2
-        tvm.testing.assert_allclose(
-            b.asnumpy(), res, rtol=1e-4)
+        tvm.testing.assert_allclose(b.asnumpy(), res, rtol=1e-4)
 
     check_target("vulkan")
     check_target("cuda")
@@ -326,6 +313,7 @@ def test_rfactor_elemwise_threads():
     check_target("opencl")
     check_target("rocm")
 
+
 def test_argmax():
     def fcombine(x, y):
         lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
@@ -335,38 +323,34 @@ def test_argmax():
     def fidentity(t0, t1):
         return tvm.tir.const(-1, t0), tvm.te.min_value(t1)
 
-    argmax = te.comm_reducer(fcombine,
-                              fidentity,
-                              name='argmax')
-    m = te.size_var('m')
-    n = te.size_var('n')
-    idx = te.placeholder((m, n), name='idx', dtype='int32')
-    val = te.placeholder((m, n), name='val', dtype='float32')
-    k = te.reduce_axis((0, n), 'k')
-    T0, T1 = te.compute((m,), lambda i: argmax((idx[i,k], val[i,k]), axis=k), name='T')
+    argmax = te.comm_reducer(fcombine, fidentity, name="argmax")
+    m = te.size_var("m")
+    n = te.size_var("n")
+    idx = te.placeholder((m, n), name="idx", dtype="int32")
+    val = te.placeholder((m, n), name="val", dtype="float32")
+    k = te.reduce_axis((0, n), "k")
+    T0, T1 = te.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="T")
     s = te.create_schedule(T0.op)
 
     def check_target():
-        device = 'cpu'
+        device = "cpu"
         if not tvm.testing.device_enabled(device):
             print("skip because %s is not enabled.." % device)
             return
         ctx = tvm.context(device, 0)
         fapi = tvm.lower(s, args=[idx, val, T0, T1])
-        fargmax = tvm.build(fapi,
-                            target='llvm',
-                            name="argmax")
+        fargmax = tvm.build(fapi, target="llvm", name="argmax")
 
         mm = 12
         nn = 16
-        np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0)
-        np_val = np.random.uniform(size=(mm, nn)).astype('float32')
+        np_idx = np.repeat(np.arange(nn, dtype="int32").reshape(1, nn), mm, axis=0)
+        np_val = np.random.uniform(size=(mm, nn)).astype("float32")
         np_res = np.argmax(np_val, axis=1)
 
-        nd_idx  = tvm.nd.array(np_idx, ctx)
-        nd_val  = tvm.nd.array(np_val, ctx)
-        nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
-        nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
+        nd_idx = tvm.nd.array(np_idx, ctx)
+        nd_val = tvm.nd.array(np_val, ctx)
+        nd_res0 = tvm.nd.array(np.zeros(mm, dtype="int32"), ctx)
+        nd_res1 = tvm.nd.array(np.zeros(mm, dtype="float32"), ctx)
         fargmax(nd_idx, nd_val, nd_res0, nd_res1)
         tvm.testing.assert_allclose(np_res, nd_res0.asnumpy())
 
@@ -383,18 +367,16 @@ def test_rfactor_argmax():
     def fidentity(t0, t1):
         return tvm.tir.const(-1, t0), tvm.te.min_value(t1)
 
-    argmax = te.comm_reducer(fcombine,
-                              fidentity,
-                              name='argmax')
+    argmax = te.comm_reducer(fcombine, fidentity, name="argmax")
 
     nn = 1027
     mm = 10
     n = tvm.runtime.convert(nn)
     m = tvm.runtime.convert(mm)
-    A0 = te.placeholder((m, n), name='A0', dtype='int32')
-    A1 = te.placeholder((m, n), name='A1', dtype='float32')
+    A0 = te.placeholder((m, n), name="A0", dtype="int32")
+    A1 = te.placeholder((m, n), name="A1", dtype="float32")
     k = te.reduce_axis((0, n))
-    B0, B1 = te.compute((m,), lambda i: argmax((A0[i, k], A1[i, k]), axis=k), name='B')
+    B0, B1 = te.compute((m,), lambda i: argmax((A0[i, k], A1[i, k]), axis=k), name="B")
 
     # schedule
     s = te.create_schedule(B0.op)
@@ -416,18 +398,16 @@ def test_rfactor_argmax():
             print("skip because %s is not enabled.." % device)
             return
         fapi = tvm.lower(s, args=[A0, A1, B0, B1])
-        fargmax = tvm.build(fapi,
-                            target=device,
-                            name="argmax")
+        fargmax = tvm.build(fapi, target=device, name="argmax")
 
-        np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0)
-        np_val = np.random.uniform(size=(mm, nn)).astype('float32')
+        np_idx = np.repeat(np.arange(nn, dtype="int32").reshape(1, nn), mm, axis=0)
+        np_val = np.random.uniform(size=(mm, nn)).astype("float32")
         np_res = np.argmax(np_val, axis=1)
 
-        nd_idx  = tvm.nd.array(np_idx, ctx)
-        nd_val  = tvm.nd.array(np_val, ctx)
-        nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
-        nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
+        nd_idx = tvm.nd.array(np_idx, ctx)
+        nd_val = tvm.nd.array(np_val, ctx)
+        nd_res0 = tvm.nd.array(np.zeros(mm, dtype="int32"), ctx)
+        nd_res1 = tvm.nd.array(np.zeros(mm, dtype="float32"), ctx)
         fargmax(nd_idx, nd_val, nd_res0, nd_res1)
         tvm.testing.assert_allclose(np_res, nd_res0.asnumpy())
 
@@ -435,6 +415,7 @@ def test_rfactor_argmax():
     check_target("vulkan")
     check_target("rocm")
 
+
 @tvm.testing.requires_gpu
 def test_warp_reduction1():
     nthx = 32
@@ -450,9 +431,9 @@ def test_warp_reduction1():
             return
 
         # compute
-        A = te.placeholder((m, n), name='A')
+        A = te.placeholder((m, n), name="A")
         k = te.reduce_axis((0, n))
-        B = te.compute((m,), lambda i: te.max(A[i][k], axis=k), name='B')
+        B = te.compute((m,), lambda i: te.max(A[i][k], axis=k), name="B")
         s = te.create_schedule(B.op)
 
         # schedule
@@ -467,7 +448,7 @@ def test_warp_reduction1():
 
         # validation
         func = tvm.build(s, [A, B], device, name="warp_reduction")
-        a_np = np.random.uniform(size=(m,n)).astype(A.dtype)
+        a_np = np.random.uniform(size=(m, n)).astype(A.dtype)
         b_np = np.zeros((m,), dtype=A.dtype)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(b_np, ctx)
@@ -482,6 +463,7 @@ def test_warp_reduction1():
     # This is a bug in normal reduction.
     # check_target("cuda", m=10, n=37)
 
+
 @tvm.testing.requires_gpu
 def test_warp_reduction2():
     def fcombine(x, y):
@@ -490,16 +472,15 @@ def test_warp_reduction2():
     def fidentity(t0, t1):
         return tvm.tir.const(0, t0), tvm.tir.const(1, t1)
 
-    add_mul_reducer = te.comm_reducer(fcombine, fidentity, name='add_mul_reducer')
+    add_mul_reducer = te.comm_reducer(fcombine, fidentity, name="add_mul_reducer")
 
     # compute
     m = 16
     n = 256
-    A0 = te.placeholder((m, n), name='A0', dtype='float32')
-    A1 = te.placeholder((m, n), name='Al', dtype='float32')
-    k = te.reduce_axis((0, n), 'k')
-    T0, T1 = te.compute((m, ), lambda i: \
-        add_mul_reducer((A0[i, k], A1[i, k]), axis=k), name='T')
+    A0 = te.placeholder((m, n), name="A0", dtype="float32")
+    A1 = te.placeholder((m, n), name="Al", dtype="float32")
+    k = te.reduce_axis((0, n), "k")
+    T0, T1 = te.compute((m,), lambda i: add_mul_reducer((A0[i, k], A1[i, k]), axis=k), name="T")
 
     nthdx, nthdy = 32, 2
     block_x = te.thread_axis("blockIdx.x")
@@ -522,8 +503,8 @@ def test_warp_reduction2():
 
         # validation
         ctx = tvm.context(device, 0)
-        a0_np = np.random.uniform(size=(m,n)).astype(A0.dtype)
-        a1_np = np.random.uniform(size=(m,n)).astype(A1.dtype)
+        a0_np = np.random.uniform(size=(m, n)).astype(A0.dtype)
+        a1_np = np.random.uniform(size=(m, n)).astype(A1.dtype)
         t0_np = np.zeros((m,), dtype=A0.dtype)
         t1_np = np.zeros((m,), dtype=A1.dtype)
         a0 = tvm.nd.array(a0_np, ctx)
@@ -540,6 +521,7 @@ def test_warp_reduction2():
     check_target("cuda")
     check_target("rocm")
 
+
 if __name__ == "__main__":
     test_rfactor_elemwise_threads()
     test_rfactor_threads()
index 9a61e60..73be68c 100644 (file)
@@ -19,6 +19,7 @@ from tvm import te
 import numpy as np
 import tvm.testing
 
+
 @tvm.testing.requires_gpu
 def test_scan():
     m = te.size_var("m")
@@ -26,7 +27,7 @@ def test_scan():
     X = te.placeholder((m, n), name="X")
     s_state = te.placeholder((m, n))
     s_init = te.compute((1, n), lambda _, i: X[0, i])
-    s_update = te.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
+    s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
     scan = tvm.te.scan(s_init, s_update, s_state)
     # test scan + compute case
     res = te.compute((m, n), lambda i, j: scan[i, j])
@@ -52,9 +53,7 @@ def test_scan():
         if not tvm.testing.device_enabled(device):
             print("skip because %s is not enabled.." % device)
             return
-        fscan = tvm.build(s, [X, res],
-                          device,
-                          name="myscan")
+        fscan = tvm.build(s, [X, res], device, name="myscan")
         # launch the kernel.
         n = 1024
         m = 10
@@ -62,8 +61,7 @@ def test_scan():
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros((m, n), dtype=res.dtype), ctx)
         fscan(a, b)
-        tvm.testing.assert_allclose(
-            b.asnumpy(), np.cumsum(a_np, axis=0))
+        tvm.testing.assert_allclose(b.asnumpy(), np.cumsum(a_np, axis=0))
 
     check_device("vulkan")
     check_device("cuda")
index 5f45119..64b2c16 100644 (file)
@@ -28,34 +28,37 @@ from tvm.autotvm.tuner import RandomTuner
 
 import tvm.testing
 
+
 @autotvm.template("testing/conv2d_no_batching")
 def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
     """An example template for testing"""
     assert N == 1, "Only consider batch_size = 1 in this template"
 
-    data = te.placeholder((N, CI, H, W), name='data')
-    kernel = te.placeholder((CO, CI, KH, KW), name='kernel')
+    data = te.placeholder((N, CI, H, W), name="data")
+    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
 
-    rc = te.reduce_axis((0, CI), name='rc')
-    ry = te.reduce_axis((0, KH), name='ry')
-    rx = te.reduce_axis((0, KW), name='rx')
+    rc = te.reduce_axis((0, CI), name="rc")
+    ry = te.reduce_axis((0, KH), name="ry")
+    rx = te.reduce_axis((0, KW), name="rx")
 
     conv = te.compute(
         (N, CO, H - KH + 1, W - KW + 1),
         lambda nn, ff, yy, xx: te.sum(
-            data[nn, rc, yy + ry, xx + rx] * kernel[ff, rc, ry, rx],
-            axis=[rc, ry, rx]), tag="conv2d_nchw")
+            data[nn, rc, yy + ry, xx + rx] * kernel[ff, rc, ry, rx], axis=[rc, ry, rx]
+        ),
+        tag="conv2d_nchw",
+    )
 
     s = te.create_schedule([conv.op])
 
     output = conv
-    OL = s.cache_write(conv, 'local')
+    OL = s.cache_write(conv, "local")
 
     # create cache stage
-    AA = s.cache_read(data, 'shared', [OL])
-    WW = s.cache_read(kernel, 'shared', [OL])
-    AL = s.cache_read(AA, 'local', [OL])
-    WL = s.cache_read(WW, 'local', [OL])
+    AA = s.cache_read(data, "shared", [OL])
+    WW = s.cache_read(kernel, "shared", [OL])
+    AL = s.cache_read(AA, "local", [OL])
+    WL = s.cache_read(WW, "local", [OL])
 
     # tile and bind spatial axes
     n, f, y, x = s[output].op.axis
@@ -86,9 +89,9 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
     cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
     cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=3)
     cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=3)
-    rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
-    ryo, rym, ryi = cfg['tile_rx'].apply(s, OL, ry)
-    rxo, rxm, rxi = cfg['tile_ry'].apply(s, OL, rx)
+    rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
+    ryo, rym, ryi = cfg["tile_rx"].apply(s, OL, ry)
+    rxo, rxm, rxi = cfg["tile_ry"].apply(s, OL, rx)
     s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x)
 
     s[AA].compute_at(s[OL], rxo)
@@ -110,31 +113,35 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
     # tune unroll
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
     cfg.define_knob("unroll_explicit", [0, 1])
-    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
-    s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
 
     return s, [data, kernel, conv]
 
+
 def get_sample_task(target=tvm.target.cuda(), target_host=None):
     """return a sample task for testing"""
-    task = autotvm.task.create("testing/conv2d_no_batching",
-                               args=(1, 7, 7, 512, 512, 3, 3),
-                               target=target, target_host=target_host)
+    task = autotvm.task.create(
+        "testing/conv2d_no_batching",
+        args=(1, 7, 7, 512, 512, 3, 3),
+        target=target,
+        target_host=target_host,
+    )
     return task, target
 
+
 @tvm.testing.parametrize_targets("cuda", "opencl")
 def test_tuning(target, ctx):
     # init task
     task, target = get_sample_task(target, None)
     logging.info("%s", task.config_space)
 
-    measure_option = autotvm.measure_option(
-        autotvm.LocalBuilder(),
-        autotvm.LocalRunner())
+    measure_option = autotvm.measure_option(autotvm.LocalBuilder(), autotvm.LocalRunner())
 
     tuner = RandomTuner(task)
     tuner.tune(n_trial=20, measure_option=measure_option)
 
+
 if __name__ == "__main__":
     # only print log when invoked from main
     logging.basicConfig(level=logging.DEBUG)
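
test_tuning above runs RandomTuner for 20 trials and then discards the measurements; a short sketch of the usual next step, logging records to a file and rebuilding with the best one. It reuses task, measure_option, and conv2d_no_batching from the hunks above, and the log file name is arbitrary:

tuner = RandomTuner(task)
tuner.tune(
    n_trial=20,
    measure_option=measure_option,
    callbacks=[autotvm.callback.log_to_file("conv2d_no_batching.log")],
)

# Apply the best measured config when instantiating and building the schedule.
with autotvm.apply_history_best("conv2d_no_batching.log"):
    with tvm.target.cuda():
        s, arg_bufs = conv2d_no_batching(1, 7, 7, 512, 512, 3, 3)
        func = tvm.build(s, arg_bufs)
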
index d85e529..f29d1fb 100644 (file)
@@ -28,15 +28,29 @@ from pytest import skip
 import tvm.testing
 
 
-def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False,
-        devices=['cuda', 'llvm -device=arm_cpu', 'opencl -device=mali']):
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
+def verify_conv2d_nchw(
+    batch,
+    in_channel,
+    in_size,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    dilation=1,
+    add_bias=False,
+    add_relu=False,
+    devices=["cuda", "llvm -device=arm_cpu", "opencl -device=mali"],
+):
+    print(
+        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+        % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)
+    )
 
     in_height = in_width = in_size
 
-    A = te.placeholder((batch, in_channel, in_height, in_width), name='A')
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name='W')
-    bias = te.placeholder((num_filter, 1, 1), name='bias')
+    A = te.placeholder((batch, in_channel, in_height, in_width), name="A")
+    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W")
+    bias = te.placeholder((num_filter, 1, 1), name="bias")
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -65,7 +79,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
             print("Skipping %s becuase it is not enabled" % device)
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
-            C = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NCHW', out_dtype=dtype)
+            C = topi.nn.conv2d(A, W, stride, padding, dilation, layout="NCHW", out_dtype=dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
@@ -77,14 +91,25 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         if add_bias:
-            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
+            func = tvm.build(
+                s,
+                [A, W, bias, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation),
+            )
             func(a, w, b, c)
         else:
-            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
+            func = tvm.build(
+                s,
+                [A, W, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation),
+            )
             func(a, w, c)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)
 
-
     for device in devices:
         check_device(device)
 
@@ -95,18 +120,21 @@ class WinogradFallback(autotvm.FallbackContext):
         if key in self.memory:
             return self.memory[key]
         cfg = FallbackConfigEntity()
-        cfg.template_key = 'winograd_nnpack_fp32'
+        cfg.template_key = "winograd_nnpack_fp32"
         self.memory[key] = cfg
         return cfg
 
+
 def test_conv2d_nchw():
-    if not tvm.get_global_func("tvm.contrib.nnpack.convolution_inference_without_weight_transform", True):
+    if not tvm.get_global_func(
+        "tvm.contrib.nnpack.convolution_inference_without_weight_transform", True
+    ):
         skip("extern function is not available")
 
     if not nnpack.is_available():
         skip("nnpack is not available")
 
-    devices = ['llvm -device=arm_cpu']
+    devices = ["llvm -device=arm_cpu"]
     autotvm.GLOBAL_SCOPE.silent = True
     with WinogradFallback():
         # resnet 18 workloads
@@ -143,4 +171,5 @@ def test_conv2d_nchw():
 
 if __name__ == "__main__":
     import pytest
+
     pytest.main([__file__])
index ada9a96..55971de 100644 (file)
@@ -27,58 +27,67 @@ import tvm.testing
 
 logging.basicConfig(level=logging.INFO)
 
-Config = namedtuple('Config', ['model', 'nbit_input',  'dtype_input', 'nbit_output', 'dtype_output', 'global_scale', 'expected_acc'])
-
-
-def get_val_data(model_name,
-                 rec_val,
-                 batch_size,
-                 num_workers=4):
+Config = namedtuple(
+    "Config",
+    [
+        "model",
+        "nbit_input",
+        "dtype_input",
+        "nbit_output",
+        "dtype_output",
+        "global_scale",
+        "expected_acc",
+    ],
+)
+
+
+def get_val_data(model_name, rec_val, batch_size, num_workers=4):
     rec_val = os.path.expanduser(rec_val)
     mean_rgb = [123.68, 116.779, 103.939]
     std_rgb = [58.393, 57.12, 57.375]
+
     def batch_fn(batch, ctx):
         data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
         label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
         return data, label
 
-    img_size = 299 if model_name == 'inceptionv3' else 224
+    img_size = 299 if model_name == "inceptionv3" else 224
     val_data = mx.io.ImageRecordIter(
-        path_imgrec         = rec_val,
-        preprocess_threads  = num_workers,
-        shuffle             = False,
-        batch_size          = batch_size,
-        resize              = 256,
-        data_shape          = (3, img_size, img_size),
-        mean_r              = mean_rgb[0],
-        mean_g              = mean_rgb[1],
-        mean_b              = mean_rgb[2],
-        std_r               = std_rgb[0],
-        std_g               = std_rgb[1],
-        std_b               = std_rgb[2],
+        path_imgrec=rec_val,
+        preprocess_threads=num_workers,
+        shuffle=False,
+        batch_size=batch_size,
+        resize=256,
+        data_shape=(3, img_size, img_size),
+        mean_r=mean_rgb[0],
+        mean_g=mean_rgb[1],
+        mean_b=mean_rgb[2],
+        std_r=std_rgb[0],
+        std_g=std_rgb[1],
+        std_b=std_rgb[2],
     )
     return val_data, batch_fn
 
 
 def get_model(model_name, batch_size, qconfig, target=None, original=False, simulated=False):
     gluon_model = gluon.model_zoo.vision.get_model(model_name, pretrained=True)
-    img_size = 299 if model_name == 'inceptionv3' else 224
+    img_size = 299 if model_name == "inceptionv3" else 224
     data_shape = (batch_size, 3, img_size, img_size)
     mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
-    net = mod['main']
+    net = mod["main"]
 
     with tvm.transform.PassContext(opt_level=3):
         qfunc = relay.quantize.prerequisite_optimize(net, params=params)
-    logging.debug('original')
+    logging.debug("original")
     logging.debug(qfunc.astext(show_meta_data=False))
     if original:
         return qfunc
 
     with qconfig:
-        logging.debug('current quantize config')
+        logging.debug("current quantize config")
         logging.debug(qtz.current_qconfig())
         qfunc = qtz.quantize(qfunc)
-        logging.debug('after quantize')
+        logging.debug("after quantize")
         logging.debug(qfunc.astext(show_meta_data=False))
     return qfunc
 
@@ -109,20 +118,23 @@ def eval_acc(model, dataset, batch_fn, target=tvm.target.cuda(), ctx=tvm.gpu(),
             _, top1 = acc_top1.get()
             _, top5 = acc_top5.get()
             nsamples = (i + 1) * batch_size
-            logging.info('[%d samples] validation: acc-top1=%f acc-top5=%f', nsamples, top1, top5)
-    logging.info('[final] validation: acc-top1=%f acc-top5=%f', top1, top5)
+            logging.info("[%d samples] validation: acc-top1=%f acc-top5=%f", nsamples, top1, top5)
+    logging.info("[final] validation: acc-top1=%f acc-top5=%f", top1, top5)
     return top1
 
+
 @tvm.testing.requires_gpu
 def test_quantize_acc(cfg, rec_val):
-    qconfig = qtz.qconfig(skip_conv_layers=[0],
-                          nbit_input=cfg.nbit_input,
-                          nbit_weight=cfg.nbit_input,
-                          global_scale=cfg.global_scale,
-                          dtype_input=cfg.dtype_input,
-                          dtype_weight=cfg.dtype_input,
-                          dtype_activation=cfg.dtype_output,
-                          debug_enabled_ops=None)
+    qconfig = qtz.qconfig(
+        skip_conv_layers=[0],
+        nbit_input=cfg.nbit_input,
+        nbit_weight=cfg.nbit_input,
+        global_scale=cfg.global_scale,
+        dtype_input=cfg.dtype_input,
+        dtype_weight=cfg.dtype_input,
+        dtype_activation=cfg.dtype_output,
+        debug_enabled_ops=None,
+    )
 
     model = get_model(cfg.model, 32, qconfig, tvm.target.cuda())
     val_data, batch_fn = get_val_data(cfg.model, rec_val=rec_val, batch_size=32)
@@ -133,18 +145,65 @@ def test_quantize_acc(cfg, rec_val):
 
 
 if __name__ == "__main__":
-    #TODO(for user): replace the line with the path to imagenet validation dataset
+    # TODO(for user): replace the line with the path to imagenet validation dataset
     rec_val = "/scratch/tqchen/imagenet/val.rec"
 
     results = []
     configs = [
-        Config('mobilenetv2_1.0', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.666),
-
-        Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=8.0, expected_acc=0.692),
-        Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.692),
-        Config('resnet34_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.733),
-        Config('resnet50_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.747),
-        Config('resnet101_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.756),
+        Config(
+            "mobilenetv2_1.0",
+            nbit_input=8,
+            dtype_input="int8",
+            nbit_output=32,
+            dtype_output="int32",
+            global_scale=4.0,
+            expected_acc=0.666,
+        ),
+        Config(
+            "resnet18_v1",
+            nbit_input=8,
+            dtype_input="int8",
+            nbit_output=16,
+            dtype_output="int16",
+            global_scale=8.0,
+            expected_acc=0.692,
+        ),
+        Config(
+            "resnet18_v1",
+            nbit_input=8,
+            dtype_input="int8",
+            nbit_output=32,
+            dtype_output="int32",
+            global_scale=8.0,
+            expected_acc=0.692,
+        ),
+        Config(
+            "resnet34_v1",
+            nbit_input=8,
+            dtype_input="int8",
+            nbit_output=32,
+            dtype_output="int32",
+            global_scale=8.0,
+            expected_acc=0.733,
+        ),
+        Config(
+            "resnet50_v1",
+            nbit_input=8,
+            dtype_input="int8",
+            nbit_output=32,
+            dtype_output="int32",
+            global_scale=8.0,
+            expected_acc=0.747,
+        ),
+        Config(
+            "resnet101_v1",
+            nbit_input=8,
+            dtype_input="int8",
+            nbit_output=32,
+            dtype_output="int32",
+            global_scale=8.0,
+            expected_acc=0.756,
+        ),
         # TODO: need to fix accuracy
         # Config('mobilenetv2_1.0', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=4.0),
     ]
index 4fcf39d..c91a554 100644 (file)
@@ -27,15 +27,18 @@ from tvm.relay import testing
 from tvm.relay import vm
 
 
-def benchmark_execution(mod,
-                        params,
-                        measure=True,
-                        data_shape=(1, 3, 224, 224),
-                        out_shape=(1, 1000),
-                        dtype='float32',
-                        model="unknown"):
-    def get_graph_runtime_output(mod, data, params, target, ctx,
-                                 dtype='float32', number=2, repeat=20):
+def benchmark_execution(
+    mod,
+    params,
+    measure=True,
+    data_shape=(1, 3, 224, 224),
+    out_shape=(1, 1000),
+    dtype="float32",
+    model="unknown",
+):
+    def get_graph_runtime_output(
+        mod, data, params, target, ctx, dtype="float32", number=2, repeat=20
+    ):
         with tvm.transform.PassContext(opt_level=3):
             graph, lib, params = relay.build(mod, target, params=params)
 
@@ -47,32 +50,32 @@ def benchmark_execution(mod,
         out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
 
         if measure:
-            print("Evaluate graph runtime inference cost of {} on "
-                  "{}".format(model, repr(ctx)))
+            print("Evaluate graph runtime inference cost of {} on " "{}".format(model, repr(ctx)))
             ftimer = m.module.time_evaluator("run", ctx, number=1, repeat=20)
             # Measure in millisecond.
             prof_res = np.array(ftimer().results) * 1000
-            print("Mean graph runtime inference time (std dev): %.2f ms (%.2f ms)" %
-                  (np.mean(prof_res), np.std(prof_res)))
+            print(
+                "Mean graph runtime inference time (std dev): %.2f ms (%.2f ms)"
+                % (np.mean(prof_res), np.std(prof_res))
+            )
 
         return out.asnumpy()
 
-    def get_vm_output(mod, data, params, target, ctx, dtype='float32',
-                      number=2, repeat=20):
+    def get_vm_output(mod, data, params, target, ctx, dtype="float32", number=2, repeat=20):
         with tvm.transform.PassContext(opt_level=3):
             exe = vm.compile(mod, target, params=params)
             rly_vm = vm_rt.VirtualMachine(exe, ctx)
             result = rly_vm.run(data)
 
         if measure:
-            print("Evaluate vm inference cost of {} on {}".format(model,
-                                                                  repr(ctx)))
-            ftimer = rly_vm.module.time_evaluator("invoke", ctx, number=number,
-                                                  repeat=repeat)
+            print("Evaluate vm inference cost of {} on {}".format(model, repr(ctx)))
+            ftimer = rly_vm.module.time_evaluator("invoke", ctx, number=number, repeat=repeat)
             # Measure in millisecond.
             prof_res = np.array(ftimer("main", data).results) * 1000
-            print("Mean vm inference time (std dev): %.2f ms (%.2f ms)" %
-                  (np.mean(prof_res), np.std(prof_res)))
+            print(
+                "Mean vm inference time (std dev): %.2f ms (%.2f ms)"
+                % (np.mean(prof_res), np.std(prof_res))
+            )
 
         return result.asnumpy().astype(dtype)
 
@@ -80,18 +83,17 @@ def benchmark_execution(mod,
     data = np.random.uniform(size=data_shape).astype(dtype)
 
     for target, ctx in testing.enabled_targets():
-        tvm_out = get_graph_runtime_output(mod, tvm.nd.array(data.astype(dtype)),
-                                           params, target, ctx, dtype)
-        vm_out = get_vm_output(mod, tvm.nd.array(data.astype(dtype)), params,
-                               target, ctx, dtype)
+        tvm_out = get_graph_runtime_output(
+            mod, tvm.nd.array(data.astype(dtype)), params, target, ctx, dtype
+        )
+        vm_out = get_vm_output(mod, tvm.nd.array(data.astype(dtype)), params, target, ctx, dtype)
         tvm.testing.assert_allclose(vm_out, tvm_out, rtol=1e-5, atol=1e-5)
 
 
 def test_mlp():
     image_shape = (1, 1, 28, 28)
     mod, params = testing.mlp.get_workload(1)
-    benchmark_execution(mod, params, data_shape=image_shape, out_shape=(1, 10),
-                       model="mlp")
+    benchmark_execution(mod, params, data_shape=image_shape, out_shape=(1, 10), model="mlp")
 
 
 def test_vgg():
@@ -109,7 +111,7 @@ def test_resnet():
 
 
 def test_squeezenet():
-    for version in ['1.0', '1.1']:
+    for version in ["1.0", "1.1"]:
         mod, params = testing.squeezenet.get_workload(version=version)
         model = "squeezenet" + version
         benchmark_execution(mod, params, model=model)
@@ -118,14 +120,12 @@ def test_squeezenet():
 def test_inception_v3():
     image_shape = (3, 299, 299)
     mod, params = testing.inception_v3.get_workload(image_shape=image_shape)
-    benchmark_execution(mod, params, data_shape=(1, 3, 299, 299),
-                        model="inception_v3")
+    benchmark_execution(mod, params, data_shape=(1, 3, 299, 299), model="inception_v3")
 
 
 def test_dqn():
     image_shape = (1, 4, 84, 84)
-    mod, params = testing.dqn.get_workload(
-        batch_size=1, image_shape=image_shape)
+    mod, params = testing.dqn.get_workload(batch_size=1, image_shape=image_shape)
     benchmark_execution(mod, params, data_shape=image_shape, out_shape=(1, 18))
 
 
@@ -139,20 +139,22 @@ def test_mobilenet():
     mod, params = testing.mobilenet.get_workload(batch_size=1)
     benchmark_execution(mod, params, model="mobilenet")
 
+
 # TODO: enable when the slow build performance (several minutes) is fixed.
 def test_mobilenet_nhwc():
     image_shape = (1, 224, 224, 3)
-    mod, params = testing.mobilenet.get_workload(batch_size=1,
-                                                 image_shape=image_shape[1:],
-                                                 layout='NHWC')
+    mod, params = testing.mobilenet.get_workload(
+        batch_size=1, image_shape=image_shape[1:], layout="NHWC"
+    )
     benchmark_execution(mod, params, measure=False, data_shape=image_shape)
 
+
 def test_densenet():
     mod, params = testing.densenet.get_workload(batch_size=1)
     benchmark_execution(mod, params, model="densenet")
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_resnet()
     test_vgg()
     test_squeezenet()
index e3c8c9e..622e291 100644 (file)
@@ -30,21 +30,21 @@ import tvm.testing
 # TODO(mbrookhart): Enable when VM supports heterogeneous execution
 # @tvm.testing.uses_gpu
 def test_dyn_broadcast_to():
-    dtype = 'uint8'
+    dtype = "uint8"
     rank = 3
-    shape_type = 'int64'
-    dyn_shape = relay.Var("shape", relay.ty.TensorType((rank, ), shape_type))
-    x_shape = (1, )
+    shape_type = "int64"
+    dyn_shape = relay.Var("shape", relay.ty.TensorType((rank,), shape_type))
+    x_shape = (1,)
     x = relay.Var("x", relay.ty.TensorType(x_shape, dtype))
     z = relay.broadcast_to(x, dyn_shape)
     zz = run_infer_type(z)
 
-    assert zz.checked_type == relay.ty.TensorType((relay.Any(), ) * rank, dtype)
+    assert zz.checked_type == relay.ty.TensorType((relay.Any(),) * rank, dtype)
 
     func = relay.Function([x, dyn_shape], z)
 
     x = np.random.uniform(size=x_shape).astype(dtype)
-    dyn_shape = (1, ) * rank
+    dyn_shape = (1,) * rank
     ref_res = np.broadcast_to(x, dyn_shape)
     for target, ctx in tvm.testing.enabled_targets():
         for kind in ["vm", "debug"]:
@@ -53,6 +53,7 @@ def test_dyn_broadcast_to():
             op_res = intrp.evaluate(func)(x, np.array(dyn_shape).astype(shape_type))
             tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
 
+
 # TODO(mbrookhart): Enable when VM supports heterogeneous execution
 # @tvm.testing.uses_gpu
 def test_dyn_one_hot():
@@ -86,8 +87,8 @@ def test_dyn_one_hot():
                 out_relay = intrp.evaluate()(indices_np, np.array(depth).astype("int32"))
                 tvm.testing.assert_allclose(out_relay.asnumpy(), out_np)
 
-    _verify((3, ), 3, 1, 0, -1, "int32")
-    _verify((3, ), 3, 1.0, 0.0, -1, "float32")
+    _verify((3,), 3, 1, 0, -1, "int32")
+    _verify((3,), 3, 1.0, 0.0, -1, "float32")
     _verify((2, 2), 5, 2, -2, 0, "int32")
     _verify((2, 2), 5, 0.5, -0.5, 1, "float32")
     _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32")
index 63dfd10..37cc124 100644 (file)
@@ -43,29 +43,35 @@ def test_dyn_upsampling_run():
         if method == "nearest_neighbor":
             ref_res = tvm.topi.testing.upsampling_python(x_data, (scale_h, scale_w), layout)
         else:
-            ref_res = tvm.topi.testing.bilinear_resize_python(x_data, (int(round(h*scale_h)),
-                                                  int(round(w*scale_w))), layout)
+            ref_res = tvm.topi.testing.bilinear_resize_python(
+                x_data, (int(round(h * scale_h)), int(round(w * scale_w))), layout
+            )
         x = relay.Var("x", relay.TensorType(dshape, "float32"))
         scale_h_var = relay.var("scale_h", relay.TensorType((), "float32"))
         scale_w_var = relay.var("scale_h", relay.TensorType((), "float32"))
 
-        z = relay.nn.upsampling(x, scale_h_var, scale_w_var, method=method, layout=layout, align_corners=align_corners)
+        z = relay.nn.upsampling(
+            x, scale_h_var, scale_w_var, method=method, layout=layout, align_corners=align_corners
+        )
         zz = run_infer_type(z)
         func = relay.Function([x, scale_h_var, scale_w_var], z)
 
         for target, ctx in tvm.testing.enabled_targets():
-             for kind in ["vm", "debug"]:
-                 mod = tvm.ir.IRModule.from_expr(func)
-                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
-                 op_res = intrp.evaluate()(x_data, np.array(scale_h).astype("float32"), np.array(scale_w).astype("float32"))
-                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6)
+            for kind in ["vm", "debug"]:
+                mod = tvm.ir.IRModule.from_expr(func)
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(
+                    x_data, np.array(scale_h).astype("float32"), np.array(scale_w).astype("float32")
+                )
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6)
 
     verify_upsampling((1, 16, 32, 32), 3, 2.0, "NCHW", "nearest_neighbor")
     verify_upsampling((1, 16, 32, 32), 5, 2.0, "NCHW", "bilinear", True)
     verify_upsampling((1, 16, 32, 32), 2.0, 6, "NHWC", "nearest_neighbor")
-    verify_upsampling((1, 16, 32, 32), 2.0, 2.0,"NHWC", "bilinear", True)
+    verify_upsampling((1, 16, 32, 32), 2.0, 2.0, "NHWC", "bilinear", True)
 
-#tests upsampling type inference with scale_h passed in as a constant and scale_w as a variable
+
+# tests upsampling type inference with scale_h passed in as a constant and scale_w as a variable
 # TODO(mbrookhart): Enable when VM supports heterogeneous execution
 # @tvm.testing.uses_gpu
 def test_dyn_upsampling_infer_type_const():
@@ -78,10 +84,13 @@ def test_dyn_upsampling_infer_type_const():
     zz = run_infer_type(z)
     assert zz.checked_type == relay.TensorType((n, c, relay.Any(), relay.Any()), "int8")
 
+
 # TODO(mbrookhart): Enable when VM supports heterogeneous execution
 # @tvm.testing.uses_gpu
 def test_dyn_upsampling3d_run():
-    def verify_upsampling3d(dshape, scale_d, scale_h, scale_w, layout, method, coord_trans="half_pixel"):
+    def verify_upsampling3d(
+        dshape, scale_d, scale_h, scale_w, layout, method, coord_trans="half_pixel"
+    ):
 
         if layout == "NCDHW":
             (n, c, d, h, w) = dshape
@@ -92,18 +101,29 @@ def test_dyn_upsampling3d_run():
             x_data = np.random.uniform(size=(n, d, h, w, c)).astype("float32")
 
         if method == "nearest_neighbor":
-            ref_res = tvm.topi.testing.upsampling3d_python(x_data, (scale_d, scale_h, scale_w), layout)
+            ref_res = tvm.topi.testing.upsampling3d_python(
+                x_data, (scale_d, scale_h, scale_w), layout
+            )
         else:
-            ref_res = tvm.topi.testing.trilinear_resize3d_python(x_data, (int(round(d*scale_d)),
-                                                                 int(round(h*scale_h)),
-                                                                 int(round(w*scale_w))), layout)
+            ref_res = tvm.topi.testing.trilinear_resize3d_python(
+                x_data,
+                (int(round(d * scale_d)), int(round(h * scale_h)), int(round(w * scale_w))),
+                layout,
+            )
         x = relay.Var("x", relay.TensorType(dshape, "float32"))
         scale_d_var = relay.var("scale_d", relay.TensorType((), "float32"))
         scale_h_var = relay.var("scale_h", relay.TensorType((), "float32"))
         scale_w_var = relay.var("scale_h", relay.TensorType((), "float32"))
 
-        z = relay.nn.upsampling3d(x, scale_d_var, scale_h_var, scale_w_var, method=method, layout=layout,
-                                coordinate_transformation_mode=coord_trans)
+        z = relay.nn.upsampling3d(
+            x,
+            scale_d_var,
+            scale_h_var,
+            scale_w_var,
+            method=method,
+            layout=layout,
+            coordinate_transformation_mode=coord_trans,
+        )
         zz = run_infer_type(z)
         func = relay.Function([x, scale_d_var, scale_h_var, scale_w_var], z)
 
@@ -111,18 +131,30 @@ def test_dyn_upsampling3d_run():
             for kind in ["vm", "debug"]:
                 mod = tvm.ir.IRModule.from_expr(func)
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
-                op_res = intrp.evaluate()(x_data, np.array(scale_d).astype("float32"), np.array(scale_h).astype("float32"), np.array(scale_w).astype("float32"))
+                op_res = intrp.evaluate()(
+                    x_data,
+                    np.array(scale_d).astype("float32"),
+                    np.array(scale_h).astype("float32"),
+                    np.array(scale_w).astype("float32"),
+                )
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6)
 
     verify_upsampling3d((1, 1, 1, 1, 1), 2, 3, 4, "NCDHW", "nearest_neighbor")
     verify_upsampling3d((1, 8, 16, 16, 16), 2.0, 3.0, 4.0, "NCDHW", "nearest_neighbor")
     verify_upsampling3d((1, 8, 16, 16, 16), 2.0, 5.0, 1.0, "NCDHW", "trilinear", "align_corners")
     verify_upsampling3d((1, 20, 3, 4, 16), 2.0, 2.0, 2.0, "NDHWC", "nearest_neighbor")
-    verify_upsampling3d((1, 8, 4, 16, 15), 2.0, 2.0, 2.0,"NDHWC", "trilinear", "align_corners")
+    verify_upsampling3d((1, 8, 4, 16, 15), 2.0, 2.0, 2.0, "NDHWC", "trilinear", "align_corners")
 
-#tests upsampling type inference with scale_h passed in as a constant and scale_w as a variable
+
+# tests upsampling type inference with scale_h passed in as a constant and scale_w as a variable
 def test_dyn_upsampling3d_infer_type_const():
-    n, c, d, h, w = te.size_var("n"), te.size_var("c"), te.size_var("d"), te.size_var("h"), te.size_var("w")
+    n, c, d, h, w = (
+        te.size_var("n"),
+        te.size_var("c"),
+        te.size_var("d"),
+        te.size_var("h"),
+        te.size_var("w"),
+    )
 
     data = relay.var("data", relay.TensorType((n, c, d, h, w), "int8"))
     scale_d = relay.Var("scale_h", relay.TensorType((), "float32"))
@@ -130,7 +162,9 @@ def test_dyn_upsampling3d_infer_type_const():
 
     z = relay.nn.upsampling3d(data, scale_d, 2.0, scale_w, layout="NCDHW", method="trilinear")
     zz = run_infer_type(z)
-    assert zz.checked_type == relay.TensorType((n, c, relay.Any(), relay.Any(), relay.Any()), "int8")
+    assert zz.checked_type == relay.TensorType(
+        (n, c, relay.Any(), relay.Any(), relay.Any()), "int8"
+    )
 
 
 # TODO(mbrookhart): Enable when VM supports heterogeneous execution
@@ -139,7 +173,7 @@ def test_dyn_pad():
     def verify_pad(dshape, pad_width, pad_val, dtype):
         x = relay.var("x", relay.TensorType(dshape, dtype))
         ndim = len(dshape)
-        pad_width_var = relay.var("pad_width_var", relay.TensorType((ndim, 2), 'int64'))
+        pad_width_var = relay.var("pad_width_var", relay.TensorType((ndim, 2), "int64"))
         pad_val_var = relay.var("pad_val_var", relay.TensorType((), dtype))
         y = relay.nn.pad(x, pad_width_var, pad_val_var)
         yy = run_infer_type(y)
@@ -147,15 +181,15 @@ def test_dyn_pad():
         assert yy.checked_type == relay.ty.TensorType((relay.Any(),) * ndim, dtype)
         func = relay.Function([x, pad_width_var, pad_val_var], y)
         data = np.random.uniform(size=dshape).astype(dtype)
-        ref_res = np.pad(data, pad_width, 'constant', constant_values=(((pad_val,)*2),) * ndim)
-        pad_width = np.array(pad_width).astype('int64')
+        ref_res = np.pad(data, pad_width, "constant", constant_values=(((pad_val,) * 2),) * ndim)
+        pad_width = np.array(pad_width).astype("int64")
 
         verify_func(func, [data, pad_width, np.array(pad_val).astype(dtype)], ref_res)
 
     def verify_pad_default_fill(dshape, pad_width, dtype):
         x = relay.var("x", relay.TensorType(dshape, dtype))
         ndim = len(dshape)
-        pad_width_var = relay.var("pad_width_var", relay.TensorType((ndim, 2), 'int64'))
+        pad_width_var = relay.var("pad_width_var", relay.TensorType((ndim, 2), "int64"))
         y = relay.nn.pad(x, pad_width_var)
         yy = run_infer_type(y)
 
@@ -163,7 +197,7 @@ def test_dyn_pad():
         func = relay.Function([x, pad_width_var], y)
         data = np.random.uniform(size=dshape).astype(dtype)
         ref_res = np.pad(data, pad_width)
-        pad_width = np.array(pad_width).astype('int64')
+        pad_width = np.array(pad_width).astype("int64")
 
         verify_func(func, [data, pad_width], ref_res)
 
@@ -172,6 +206,7 @@ def test_dyn_pad():
     verify_pad_default_fill((4, 10, 7, 7), ((1, 1), (2, 2), (3, 3), (4, 4)), "float64")
     verify_pad_default_fill((2, 7), ((1, 4), (2, 2)), "int32")
 
+
 if __name__ == "__main__":
     test_dyn_pad()
     test_dyn_upsampling_infer_type_const()
index 74b4e10..301e722 100644 (file)
@@ -25,6 +25,7 @@ from tvm.relay import create_executor, transform
 from tvm.relay.testing import check_grad, run_infer_type
 import tvm.testing
 
+
 def verify_func(func, data, ref_res):
     assert isinstance(data, list)
     for target, ctx in tvm.testing.enabled_targets():
@@ -35,22 +36,27 @@ def verify_func(func, data, ref_res):
             tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
             relay.backend.compile_engine.get().clear()
 
+
 # TODO(mbrookhart): Enable when VM supports heterogeneous execution
 # @tvm.testing.uses_gpu
 def test_dyn_reshape():
     def verify_reshape(shape, newshape, oshape):
         x = relay.var("x", relay.TensorType(shape, "float32"))
-        y = relay.var("y", relay.TensorType((len(newshape), ), "int64"))
+        y = relay.var("y", relay.TensorType((len(newshape),), "int64"))
         z = relay.reshape(x, y)
 
         func = relay.Function([x, y], z)
         x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
         x_data = np.ones(shape).astype("float32")
         ref_res = np.reshape(x_data, oshape)
-        check_grad(run_infer_type(func),
-                   inputs=[x_data, np.array(newshape).astype("int64")],
-                   test_inputs=[x_data], eps=1e-3)
+        check_grad(
+            run_infer_type(func),
+            inputs=[x_data, np.array(newshape).astype("int64")],
+            test_inputs=[x_data],
+            eps=1e-3,
+        )
         verify_func(func, [x_data, np.array(newshape).astype("int64")], ref_res)
+
     verify_reshape((2, 3, 4), (8, 3), (8, 3))
     verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
     verify_reshape((2, 3, 4), (4, 0, 2), (4, 3, 2))
@@ -61,6 +67,7 @@ def test_dyn_reshape():
     verify_reshape((2, 3, 4, 5), (-3, -3), (6, 20))
     verify_reshape((2, 3, 4), (0, -3), (2, 12))
 
+
 # TODO(mbrookhart): Enable when VM supports heterogeneous execution
 # @tvm.testing.uses_gpu
 def test_dyn_shape_reshape():
@@ -73,18 +80,19 @@ def test_dyn_shape_reshape():
         x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
         y_data = np.random.uniform(low=-1, high=1, size=newshape).astype("float32")
         ref_res = np.reshape(x_data, oshape)
-        check_grad(run_infer_type(func),
-                   inputs=[x_data, y_data], eps=1e-3)
+        check_grad(run_infer_type(func), inputs=[x_data, y_data], eps=1e-3)
         verify_func(func, [x_data, y_data], ref_res)
+
     verify_reshape((2, 3, 4), (8, 3), (8, 3))
     verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
 
+
 # TODO(mbrookhart): Enable when VM supports heterogeneous execution
 # @tvm.testing.uses_gpu
 def test_dyn_tile():
     def verify_tile(dshape, reps):
         x = relay.var("x", relay.TensorType(dshape, "float32"))
-        r = relay.var("reps", relay.TensorType((len(reps), ), "float32"))
+        r = relay.var("reps", relay.TensorType((len(reps),), "float32"))
         z = relay.tile(x, r)
 
         func = relay.Function([x, r], z)
@@ -92,6 +100,7 @@ def test_dyn_tile():
         ref_res = np.tile(x_data, reps=reps)
         reps_data = np.array(reps).astype("float32")
         verify_func(func, [x_data, np.array(reps).astype("float32")], ref_res)
+
     verify_tile((2, 3, 4), (3, 2, 1))
     verify_tile((2, 3, 4), (1, 2))
     verify_tile((2, 3), (3, 2, 1))
@@ -103,16 +112,18 @@ def test_dyn_zeros_ones():
     def verify_zeros_ones(shape, dtype):
         for op, ref in [(relay.zeros, np.zeros), (relay.ones, np.ones)]:
             rank = len(shape)
-            dyn_shape = relay.Var("shape", relay.ty.TensorType((rank,), 'int64'))
+            dyn_shape = relay.Var("shape", relay.ty.TensorType((rank,), "int64"))
             y = op(dyn_shape, dtype)
             yy = run_infer_type(y)
             assert yy.checked_type == relay.ty.TensorType((relay.Any(),) * rank, dtype)
 
             func = relay.Function([dyn_shape], y)
             ref_res = ref(shape, dtype)
-            verify_func(func, [np.array(shape).astype('int64')], ref_res.astype('int64'))
-    verify_zeros_ones((1, 3), 'int64')
-    verify_zeros_ones((8, 9, 1, 2), 'float32')
+            verify_func(func, [np.array(shape).astype("int64")], ref_res.astype("int64"))
+
+    verify_zeros_ones((1, 3), "int64")
+    verify_zeros_ones((8, 9, 1, 2), "float32")
+
 
 # TODO(mbrookhart): Enable when VM supports heterogeneous execution
 # @tvm.testing.uses_gpu
@@ -120,15 +131,19 @@ def test_dyn_full():
     def verify_full(fill_value, src_shape, dtype):
         x = relay.var("x", relay.scalar_type(dtype))
         rank = len(src_shape)
-        dyn_src_shape = relay.var("dyn_scr_shape", relay.ty.TensorType((rank,), 'int64'))
+        dyn_src_shape = relay.var("dyn_scr_shape", relay.ty.TensorType((rank,), "int64"))
         z = relay.full(x, dyn_src_shape, dtype)
         func = relay.Function([x, dyn_src_shape], z)
         ref_res = np.full(src_shape, fill_value).astype(dtype)
 
-        verify_func(func, [np.array(fill_value).astype(dtype), np.array(src_shape).astype('int64')], ref_res)
-    verify_full(4, (1, 3, 4, 4), 'int32')
-    verify_full(4, (1, 3, 4, 4), 'int64')
-    verify_full(4.0, (2, 50), 'float32')
+        verify_func(
+            func, [np.array(fill_value).astype(dtype), np.array(src_shape).astype("int64")], ref_res
+        )
+
+    verify_full(4, (1, 3, 4, 4), "int32")
+    verify_full(4, (1, 3, 4, 4), "int64")
+    verify_full(4.0, (2, 50), "float32")
+
 
 if __name__ == "__main__":
     test_dyn_reshape()
index b739a0e..b8b2486 100644 (file)
@@ -26,8 +26,7 @@ import tvm.topi.testing
 # TODO(mbrookhart): Enable when VM supports heterogeneous execution
 # @tvm.testing.uses_gpu
 def test_dynamic_strided_slice():
-    def verify(dshape, begin, end, strides, output, slice_mode="end",
-               test_ref=True, dtype="int32"):
+    def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, dtype="int32"):
         x = relay.var("x", relay.TensorType(dshape, "float32"))
         ndim = len(dshape)
         begin = begin if begin else [0] * ndim
@@ -40,27 +39,18 @@ def test_dynamic_strided_slice():
 
         # target numpy result
         x_data = np.random.uniform(size=dshape).astype("float32")
-        ref_res = tvm.topi.testing.strided_slice_python(
-            x_data, begin, end, strides, slice_mode)
+        ref_res = tvm.topi.testing.strided_slice_python(x_data, begin, end, strides, slice_mode)
         data = [x_data, np.array(begin), np.array(end)]
-        
+
         begin = relay.const(begin, dtype=dtype)
         end = relay.const(end, dtype=dtype)
 
-        
         if strides:
             data.append(np.array(strides))
             strides = relay.const(strides, dtype=dtype)
-            z = relay.strided_slice(x,
-                                    begin=begin,
-                                    end=end,
-                                    strides=strides,
-                                    slice_mode=slice_mode)
+            z = relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=slice_mode)
         else:
-            z = relay.strided_slice(x,
-                                    begin=begin,
-                                    end=end,
-                                    slice_mode=slice_mode)
+            z = relay.strided_slice(x, begin=begin, end=end, slice_mode=slice_mode)
         func = relay.Function([x], z)
 
         func = run_infer_type(func)
@@ -75,8 +65,14 @@ def test_dynamic_strided_slice():
             tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
 
     verify((1, 3, 10, 10), [0, 0, 0, 0], [-1, 3, 10, 10], [1], (0, 3, 10, 10), dtype="int64")
-    verify((1, 224, 224, 3), [0, 20, 20, 0], [1, 140, 140, 3],
-           [1, 1, 1, 1], (1, 120, 120, 3), dtype="int64")
+    verify(
+        (1, 224, 224, 3),
+        [0, 20, 20, 0],
+        [1, 140, 140, 3],
+        [1, 1, 1, 1],
+        (1, 120, 120, 3),
+        dtype="int64",
+    )
     verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3), dtype="int16")
     verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2))
     verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3))
@@ -85,10 +81,10 @@ def test_dynamic_strided_slice():
     verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3))
     verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
     verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
-    verify((3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1],
-           (2, 4, 3), slice_mode="size", test_ref=False)
-    verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1],
-           (2, 2, 3), slice_mode="size", test_ref=True)
+    verify(
+        (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False
+    )
+    verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], (2, 2, 3), slice_mode="size", test_ref=True)
 
 
 if __name__ == "__main__":
index a6e5b61..de199dd 100644 (file)
@@ -52,10 +52,11 @@ def test_resize():
             ref_res = tvm.topi.testing.upsampling_python(x_data, (scale, scale), layout)
         x = relay.var("x", relay.TensorType(dshape, "float32"))
         size_var = relay.var("size", relay.TensorType((2,), "int64"))
-        
+
         coord_trans = "asymmetric" if method == "nearest_neighbor" else "align_corners"
-        z = relay.image.resize(x, size_var, layout, method,
-                              coordinate_transformation_mode=coord_trans)
+        z = relay.image.resize(
+            x, size_var, layout, method, coordinate_transformation_mode=coord_trans
+        )
 
         zz = run_infer_type(z)
         func = relay.Function([x, size_var], z)
@@ -72,6 +73,7 @@ def test_resize():
             verify_resize((1, 4, 4, 4), 2, method, layout)
             verify_resize((2, 8, 17, 20), 7, method, layout)
 
+
 if __name__ == "__main__":
     test_resize_infer_type()
     test_resize()
index 58bf53c..bab8b9c 100644 (file)
@@ -1,4 +1,3 @@
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -65,6 +64,7 @@ def test_dynamic_topk():
                     tvm.testing.assert_allclose(op_res.asnumpy(), np_values)
                 else:
                     tvm.testing.assert_allclose(op_res.asnumpy(), np_indices)
+
     np.random.seed(0)
     for k in [0, 1, 5]:
         for axis in [0, -1, 1]:
index da72429..784abcb 100644 (file)
@@ -28,9 +28,11 @@ mod = tvm.IRModule()
 p = Prelude(mod)
 add_nat_definitions(p)
 
+
 def count(e):
     return count_(p, e)
 
+
 ctx = tvm.context("llvm", 0)
 intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
 
@@ -82,6 +84,7 @@ def make_nat(n):
     else:
         return ConstructorValue(z, [])
 
+
 def make_nat_expr(n):
     assert n >= 0
     ret = z()
@@ -90,6 +93,7 @@ def make_nat_expr(n):
         n = n - 1
     return ret
 
+
 def to_list(l):
     assert isinstance(l, ConstructorValue)
     val = l
@@ -103,15 +107,16 @@ def to_list(l):
             break
     return ret
 
+
 def tree_to_dict(t):
     assert isinstance(t, ConstructorValue)
     ret = {}
     assert t.tag == p.rose.tag
-    ret['member'] = t.fields[0]
-    ret['children'] = []
+    ret["member"] = t.fields[0]
+    ret["children"] = []
     for subtree in to_list(t.fields[1]):
         l = tree_to_dict(subtree)
-        ret['children'].append(l)
+        ret["children"].append(l)
     return ret
 
 
@@ -130,16 +135,16 @@ def vmobj_to_list(o, dtype="float32"):
             result.extend(vmobj_to_list(f, dtype))
         return result
     elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
-        if o.constructor.name_hint == 'Cons':
+        if o.constructor.name_hint == "Cons":
             tl = vmobj_to_list(o.fields[1], dtype)
             hd = vmobj_to_list(o.fields[0], dtype)
             hd.extend(tl)
             return hd
-        elif o.constructor.name_hint == 'Nil':
+        elif o.constructor.name_hint == "Nil":
             return []
-        elif 'tensor_nil' in o.constructor.name_hint:
+        elif "tensor_nil" in o.constructor.name_hint:
             return [0]
-        elif 'tensor' in o.constructor.name_hint:
+        elif "tensor" in o.constructor.name_hint:
             return [o.fields[0].asnumpy()]
         else:
             raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint)
@@ -241,7 +246,7 @@ def test_update():
 @tvm.testing.uses_gpu
 def test_length():
     a = relay.TypeVar("a")
-    assert mod[length].checked_type == relay.FuncType([l(a)], relay.scalar_type('int32'), [a])
+    assert mod[length].checked_type == relay.FuncType([l(a)], relay.scalar_type("int32"), [a])
     res = intrp.evaluate(length(cons(z(), cons(z(), cons(z(), nil())))))
     assert get_scalar(res) == 3
 
@@ -273,10 +278,13 @@ def test_foldl():
     x = relay.Var("x")
     y = relay.Var("y")
     rev_dup = relay.Function([y, x], cons(x, cons(x, y)))
-    res = intrp.evaluate(foldl(rev_dup, nil(),
-                               cons(make_nat_expr(1),
-                                    cons(make_nat_expr(2),
-                                         cons(make_nat_expr(3), nil())))))
+    res = intrp.evaluate(
+        foldl(
+            rev_dup,
+            nil(),
+            cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil()))),
+        )
+    )
     reversed = to_list(res)
     assert len(reversed) == 6
     assert count(reversed[0]) == 3 and count(reversed[1]) == 3
@@ -295,10 +303,13 @@ def test_foldr():
     x = relay.Var("x")
     y = relay.Var("y")
     identity = relay.Function([x, y], cons(x, y))
-    res = intrp.evaluate(foldr(identity, nil(),
-                               cons(make_nat_expr(1),
-                                    cons(make_nat_expr(2),
-                                         cons(make_nat_expr(3), nil())))))
+    res = intrp.evaluate(
+        foldr(
+            identity,
+            nil(),
+            cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil()))),
+        )
+    )
     same = to_list(res)
     assert len(same) == 3
     assert count(same[0]) == 1 and count(same[1]) == 2 and count(same[2]) == 3
@@ -314,17 +325,18 @@ def test_foldr1():
     x = relay.Var("x")
     y = relay.Var("y")
     f = relay.Function([x, y], add(x, y))
-    res = intrp.evaluate(foldr1(f,
-                                cons(make_nat_expr(1),
-                                    cons(make_nat_expr(2),
-                                         cons(make_nat_expr(3), nil())))))
+    res = intrp.evaluate(
+        foldr1(f, cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil()))))
+    )
 
     assert count(res) == 6
 
 
 @tvm.testing.uses_gpu
 def test_sum():
-    assert mod[sum].checked_type == relay.FuncType([l(relay.scalar_type('int32'))], relay.scalar_type('int32'))
+    assert mod[sum].checked_type == relay.FuncType(
+        [l(relay.scalar_type("int32"))], relay.scalar_type("int32")
+    )
     res = intrp.evaluate(sum(cons(relay.const(1), cons(relay.const(2), nil()))))
     assert get_scalar(res) == 3
 
@@ -349,32 +361,44 @@ def test_concat():
 @tvm.testing.uses_gpu
 def test_filter():
     a = relay.TypeVar("a")
-    expected_type = relay.FuncType([
-        relay.FuncType([a], relay.scalar_type("bool")), l(a)
-    ], l(a), [a])
+    expected_type = relay.FuncType(
+        [relay.FuncType([a], relay.scalar_type("bool")), l(a)], l(a), [a]
+    )
     assert mod[filter].checked_type == expected_type
 
     x = relay.Var("x", nat())
     greater_than_one = relay.Function(
         [x],
-        relay.Match(x, [
-            relay.Clause(
-                relay.PatternConstructor(s, [
+        relay.Match(
+            x,
+            [
+                relay.Clause(
                     relay.PatternConstructor(
-                        s, [relay.PatternWildcard()])
-                ]),
-                relay.const(True)),
-            relay.Clause(relay.PatternWildcard(), relay.const(False))
-        ]))
+                        s, [relay.PatternConstructor(s, [relay.PatternWildcard()])]
+                    ),
+                    relay.const(True),
+                ),
+                relay.Clause(relay.PatternWildcard(), relay.const(False)),
+            ],
+        ),
+    )
     res = intrp.evaluate(
-        filter(greater_than_one,
-               cons(make_nat_expr(1),
-                    cons(make_nat_expr(1),
-                         cons(make_nat_expr(3),
-                              cons(make_nat_expr(1),
-                                   cons(make_nat_expr(5),
-                                        cons(make_nat_expr(1),
-                                             nil()))))))))
+        filter(
+            greater_than_one,
+            cons(
+                make_nat_expr(1),
+                cons(
+                    make_nat_expr(1),
+                    cons(
+                        make_nat_expr(3),
+                        cons(
+                            make_nat_expr(1), cons(make_nat_expr(5), cons(make_nat_expr(1), nil()))
+                        ),
+                    ),
+                ),
+            ),
+        )
+    )
     filtered = to_list(res)
     assert len(filtered) == 2
     assert count(filtered[0]) == 3
@@ -385,15 +409,11 @@ def test_filter():
 def test_zip():
     a = relay.TypeVar("a")
     b = relay.TypeVar("b")
-    expected_type = relay.FuncType([l(a), l(b)],
-                                   l(relay.TupleType([a, b])), [a, b])
+    expected_type = relay.FuncType([l(a), l(b)], l(relay.TupleType([a, b])), [a, b])
     assert mod[zip].checked_type == expected_type
 
     l1 = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil())))
-    l2 = cons(nil(),
-              cons(cons(nil(), nil()),
-                   cons(cons(nil(), cons(nil(), nil())),
-                        nil())))
+    l2 = cons(nil(), cons(cons(nil(), nil()), cons(cons(nil(), cons(nil(), nil())), nil())))
 
     res = intrp.evaluate(zip(l1, l2))
     zipped = to_list(res)
@@ -428,9 +448,9 @@ def test_rev():
     a = relay.TypeVar("a")
     assert mod[rev].checked_type == relay.FuncType([l(a)], l(a), [a])
 
-    res = intrp.evaluate(rev(cons(make_nat_expr(1),
-                                  cons(make_nat_expr(2),
-                                       cons(make_nat_expr(3), nil())))))
+    res = intrp.evaluate(
+        rev(cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil()))))
+    )
     reversed = to_list(res)
 
     assert len(reversed) == 3
@@ -443,20 +463,24 @@ def test_rev():
 def test_unfoldr():
     a = relay.TypeVar("a")
     b = relay.TypeVar("b")
-    expected_type = relay.FuncType([
-        relay.FuncType([a], optional(relay.TupleType([a, b]))), a],
-                                   l(b), [a, b])
+    expected_type = relay.FuncType(
+        [relay.FuncType([a], optional(relay.TupleType([a, b]))), a], l(b), [a, b]
+    )
 
     x = relay.Var("x", nat())
     n = relay.Var("n", nat())
     count_down = relay.Function(
         [x],
-        relay.Match(x, [
-            relay.Clause(relay.PatternConstructor(
-                s, [relay.PatternVar(n)]),
-                         some(relay.Tuple([n, x]))),
-            relay.Clause(relay.PatternConstructor(z, []), none())
-        ]))
+        relay.Match(
+            x,
+            [
+                relay.Clause(
+                    relay.PatternConstructor(s, [relay.PatternVar(n)]), some(relay.Tuple([n, x]))
+                ),
+                relay.Clause(relay.PatternConstructor(z, []), none()),
+            ],
+        ),
+    )
 
     res = intrp.evaluate(unfoldr(count_down, make_nat_expr(3)))
     unfolded = to_list(res)
@@ -471,20 +495,24 @@ def test_unfoldr():
 def test_unfoldl():
     a = relay.TypeVar("a")
     b = relay.TypeVar("b")
-    expected_type = relay.FuncType([
-        relay.FuncType([a], optional(relay.TupleType([a, b]))), a],
-                                   l(b), [a, b])
+    expected_type = relay.FuncType(
+        [relay.FuncType([a], optional(relay.TupleType([a, b]))), a], l(b), [a, b]
+    )
 
     x = relay.Var("x", nat())
     n = relay.Var("n", nat())
     count_down = relay.Function(
         [x],
-        relay.Match(x, [
-            relay.Clause(relay.PatternConstructor(
-                s, [relay.PatternVar(n)]),
-                         some(relay.Tuple([n, x]))),
-            relay.Clause(relay.PatternConstructor(z, []), none())
-        ]))
+        relay.Match(
+            x,
+            [
+                relay.Clause(
+                    relay.PatternConstructor(s, [relay.PatternVar(n)]), some(relay.Tuple([n, x]))
+                ),
+                relay.Clause(relay.PatternConstructor(z, []), none()),
+            ],
+        ),
+    )
 
     res = intrp.evaluate(unfoldl(count_down, make_nat_expr(3)))
     unfolded = to_list(res)
@@ -500,17 +528,16 @@ def test_map_accumr():
     a = relay.TypeVar("a")
     b = relay.TypeVar("b")
     c = relay.TypeVar("c")
-    expected_type = relay.FuncType([
-        relay.FuncType([a, b], relay.TupleType([a, c])),
-        a, l(b)
-    ], relay.TupleType([a, l(c)]), [a, b, c])
+    expected_type = relay.FuncType(
+        [relay.FuncType([a, b], relay.TupleType([a, c])), a, l(b)],
+        relay.TupleType([a, l(c)]),
+        [a, b, c],
+    )
     assert mod[map_accumr].checked_type == expected_type
 
     acc = relay.Var("acc", nat())
     x = relay.Var("x", nat())
-    add_acc_to_each = relay.Function([acc, x],
-                                     relay.Tuple([add(x, acc),
-                                                  add(x, acc)]))
+    add_acc_to_each = relay.Function([acc, x], relay.Tuple([add(x, acc), add(x, acc)]))
 
     vals = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil())))
     res = intrp.evaluate(map_accumr(add_acc_to_each, z(), vals))
@@ -530,16 +557,16 @@ def test_map_accuml():
     a = relay.TypeVar("a")
     b = relay.TypeVar("b")
     c = relay.TypeVar("c")
-    expected_type = relay.FuncType([
-        relay.FuncType([a, b], relay.TupleType([a, c])),
-        a, l(b)
-    ], relay.TupleType([a, l(c)]), [a, b, c])
+    expected_type = relay.FuncType(
+        [relay.FuncType([a, b], relay.TupleType([a, c])), a, l(b)],
+        relay.TupleType([a, l(c)]),
+        [a, b, c],
+    )
     assert mod[map_accuml].checked_type == expected_type
 
     acc = relay.Var("acc", nat())
     x = relay.Var("x", nat())
-    add_to_acc = relay.Function([acc, x],
-                                relay.Tuple([add(x, acc), x]))
+    add_to_acc = relay.Function([acc, x], relay.Tuple([add(x, acc), x]))
 
     vals = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil())))
     res = intrp.evaluate(map_accuml(add_to_acc, z(), vals))
@@ -556,19 +583,27 @@ def test_map_accuml():
 
 @tvm.testing.uses_gpu
 def test_optional_matching():
-    x = relay.Var('x')
-    y = relay.Var('y')
-    v = relay.Var('v')
+    x = relay.Var("x")
+    y = relay.Var("y")
+    v = relay.Var("v")
     condense = relay.Function(
         [x, y],
-        relay.Match(x, [
-            relay.Clause(relay.PatternConstructor(some, [relay.PatternVar(v)]), cons(v, y)),
-            relay.Clause(relay.PatternConstructor(none), y)
-        ]))
+        relay.Match(
+            x,
+            [
+                relay.Clause(relay.PatternConstructor(some, [relay.PatternVar(v)]), cons(v, y)),
+                relay.Clause(relay.PatternConstructor(none), y),
+            ],
+        ),
+    )
 
-    res = intrp.evaluate(foldr(condense, nil(), cons(
-        some(make_nat_expr(3)),
-        cons(none(), cons(some(make_nat_expr(1)), nil())))))
+    res = intrp.evaluate(
+        foldr(
+            condense,
+            nil(),
+            cons(some(make_nat_expr(3)), cons(none(), cons(some(make_nat_expr(1)), nil()))),
+        )
+    )
 
     reduced = to_list(res)
     assert len(reduced) == 2
@@ -586,30 +621,26 @@ def test_tmap():
 
     x = relay.Var("x")
     add_one = relay.Function([x], s(x))
-    res = intrp.evaluate(tmap(add_one,
-                              rose(z(),
-                                   cons(rose(z(), nil()),
-                                        cons(rose(z(), nil()),
-                                             nil())))))
+    res = intrp.evaluate(
+        tmap(add_one, rose(z(), cons(rose(z(), nil()), cons(rose(z(), nil()), nil()))))
+    )
 
     tree_dict = tree_to_dict(res)
-    assert count(tree_dict['member']) == 1
-    assert len(tree_dict['children']) == 2
-    for subtree in tree_dict['children']:
-        assert count(subtree['member']) == 1
-        assert len(subtree['children']) == 0
+    assert count(tree_dict["member"]) == 1
+    assert len(tree_dict["children"]) == 2
+    for subtree in tree_dict["children"]:
+        assert count(subtree["member"]) == 1
+        assert len(subtree["children"]) == 0
 
 
 @tvm.testing.uses_gpu
 def test_size():
     a = relay.TypeVar("a")
     lhs = mod[size].checked_type
-    rhs = relay.FuncType([tree(a)], relay.scalar_type('int32'), [a])
+    rhs = relay.FuncType([tree(a)], relay.scalar_type("int32"), [a])
     assert lhs == rhs
 
-    root = rose(z(), cons(rose(z(), nil()),
-                                  cons(rose(z(), nil()),
-                                       nil())))
+    root = rose(z(), cons(rose(z(), nil()), cons(rose(z(), nil()), nil())))
     t = rose(z(), cons(root, cons(root, cons(root, nil()))))
     res = intrp.evaluate(size(t))
     assert get_scalar(res) == 10
@@ -617,10 +648,8 @@ def test_size():
 
 @tvm.testing.uses_gpu
 def test_wildcard_match_solo():
-    x = relay.Var('x', nat())
-    copy = relay.Function([x],
-                          relay.Match(x, [relay.Clause(relay.PatternWildcard(), x)]),
-                          nat())
+    x = relay.Var("x", nat())
+    copy = relay.Function([x], relay.Match(x, [relay.Clause(relay.PatternWildcard(), x)]), nat())
 
     res = intrp.evaluate(copy(s(s(s(z())))))
     assert count(res) == 3
@@ -628,20 +657,23 @@ def test_wildcard_match_solo():
 
 @tvm.testing.uses_gpu
 def test_wildcard_match_order():
-    x = relay.Var('x', l(nat()))
-    y = relay.Var('y')
-    a = relay.Var('a')
+    x = relay.Var("x", l(nat()))
+    y = relay.Var("y")
+    a = relay.Var("a")
     return_zero = relay.Function(
         [x],
-        relay.Match(x, [
-            relay.Clause(relay.PatternWildcard(), z()),
-            relay.Clause(
-                relay.PatternConstructor(
-                    cons, [relay.PatternVar(y), relay.PatternVar(a)]),
-                y),
-            relay.Clause(relay.PatternConstructor(nil), s(z()))
-        ]),
-        nat())
+        relay.Match(
+            x,
+            [
+                relay.Clause(relay.PatternWildcard(), z()),
+                relay.Clause(
+                    relay.PatternConstructor(cons, [relay.PatternVar(y), relay.PatternVar(a)]), y
+                ),
+                relay.Clause(relay.PatternConstructor(nil), s(z())),
+            ],
+        ),
+        nat(),
+    )
 
     res = intrp.evaluate(return_zero(cons(s(z()), nil())))
     # wildcard pattern is evaluated first
@@ -650,36 +682,44 @@ def test_wildcard_match_order():
 
 @tvm.testing.uses_gpu
 def test_nested_matches():
-    a = relay.TypeVar('a')
-    x = relay.Var('x')
-    y = relay.Var('y')
-    w = relay.Var('w')
-    h = relay.Var('h')
-    t = relay.Var('t')
-    flatten = relay.GlobalVar('flatten')
+    a = relay.TypeVar("a")
+    x = relay.Var("x")
+    y = relay.Var("y")
+    w = relay.Var("w")
+    h = relay.Var("h")
+    t = relay.Var("t")
+    flatten = relay.GlobalVar("flatten")
 
     # flatten could be written using a fold, but this way has nested matches
     inner_match = relay.Match(
-        y, [
+        y,
+        [
             relay.Clause(relay.PatternConstructor(nil), flatten(w)),
-            relay.Clause(relay.PatternConstructor(
-                cons, [relay.PatternVar(h), relay.PatternVar(t)]),
-                cons(h, flatten(cons(t, w))))
-        ])
+            relay.Clause(
+                relay.PatternConstructor(cons, [relay.PatternVar(h), relay.PatternVar(t)]),
+                cons(h, flatten(cons(t, w))),
+            ),
+        ],
+    )
 
     mod[flatten] = relay.Function(
         [x],
-        relay.Match(x, [
-            relay.Clause(relay.PatternConstructor(nil), nil()),
-            relay.Clause(relay.PatternConstructor(
-                cons, [relay.PatternVar(y), relay.PatternVar(w)]),
-                         inner_match)
-        ]), l(a), [a])
-
-    first_list = cons(make_nat_expr(1), cons(make_nat_expr(2),
-                                         cons(make_nat_expr(3), nil())))
-    second_list = cons(make_nat_expr(4), cons(make_nat_expr(5),
-                                          cons(make_nat_expr(6), nil())))
+        relay.Match(
+            x,
+            [
+                relay.Clause(relay.PatternConstructor(nil), nil()),
+                relay.Clause(
+                    relay.PatternConstructor(cons, [relay.PatternVar(y), relay.PatternVar(w)]),
+                    inner_match,
+                ),
+            ],
+        ),
+        l(a),
+        [a],
+    )
+
+    first_list = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil())))
+    second_list = cons(make_nat_expr(4), cons(make_nat_expr(5), cons(make_nat_expr(6), nil())))
     final_list = cons(first_list, cons(second_list, nil()))
 
     res = intrp.evaluate(flatten(final_list))
@@ -692,12 +732,9 @@ def test_nested_matches():
 
 @tvm.testing.uses_gpu
 def test_match_full_var():
-    x = relay.Var('x')
-    v = relay.Var('v')
-    id_func = relay.Function([x],
-                             relay.Match(x,
-                                         [relay.Clause(relay.PatternVar(v),
-                                                       v)]))
+    x = relay.Var("x")
+    v = relay.Var("v")
+    id_func = relay.Function([x], relay.Match(x, [relay.Clause(relay.PatternVar(v), v)]))
 
     res1 = intrp.evaluate(id_func(nil()))
     res2 = intrp.evaluate(id_func(cons(z(), cons(z(), nil()))))
@@ -713,36 +750,38 @@ def test_match_full_var():
 
 @tvm.testing.uses_gpu
 def test_nested_pattern_match():
-    x = relay.Var('x', l(nat()))
-    h1 = relay.Var('h1')
-    h2 = relay.Var('h2')
-    t = relay.Var('t')
+    x = relay.Var("x", l(nat()))
+    h1 = relay.Var("h1")
+    h2 = relay.Var("h2")
+    t = relay.Var("t")
     match = relay.Match(
         x,
-        [relay.Clause(
-            relay.PatternConstructor(
-                cons,
-                [relay.PatternVar(h1),
-                 relay.PatternConstructor(
+        [
+            relay.Clause(
+                relay.PatternConstructor(
                     cons,
-                     [relay.PatternVar(h2), relay.PatternVar(t)])]),
-            h2),
-         relay.Clause(relay.PatternWildcard(), z())
-        ])
+                    [
+                        relay.PatternVar(h1),
+                        relay.PatternConstructor(cons, [relay.PatternVar(h2), relay.PatternVar(t)]),
+                    ],
+                ),
+                h2,
+            ),
+            relay.Clause(relay.PatternWildcard(), z()),
+        ],
+    )
     get_second = relay.Function([x], match)
 
-    res = intrp.evaluate(get_second(cons(s(z()),
-                                         cons(s(s(z())),
-                                              nil()))))
+    res = intrp.evaluate(get_second(cons(s(z()), cons(s(s(z())), nil()))))
 
     assert count(res) == 2
 
 
 @tvm.testing.uses_gpu
 def test_compose():
-    n = relay.Var('n')
+    n = relay.Var("n")
     inc = relay.Function([n], s(n))
-    x = relay.Var('x')
+    x = relay.Var("x")
     res = intrp.evaluate(relay.Call(compose(inc, double), [s(s(z()))]))
     assert count(res) == 5
 
@@ -768,31 +807,33 @@ def check_tensor_array(ta_mod, ref_res, *args, dtype="float32", rtol=1e-5):
 @tvm.testing.uses_gpu
 def test_tensor_expand_dims():
     def run(dtype):
-        x = relay.var('x')
+        x = relay.var("x")
         mod = tvm.IRModule()
         p = Prelude(mod)
-        expand_dims_func = p.get_var('tensor_expand_dims', dtype)
-        tensor1 = p.get_var('tensor1', dtype)
+        expand_dims_func = p.get_var("tensor_expand_dims", dtype)
+        tensor1 = p.get_var("tensor1", dtype)
         mod["main"] = relay.Function([x], expand_dims_func(tensor1(x)))
         x_np = np.random.uniform(low=0.0, high=8.0, size=(1,)).astype(dtype)
         expected = [np.expand_dims(x_np, axis=0)]
         check_tensor_array(mod, expected, x_np)
-    run('float32')
-    run('int32')
+
+    run("float32")
+    run("int32")
 
 
 @tvm.testing.uses_gpu
 def test_tensor_array_constructor():
     def run(dtype):
-        x = relay.var('x')
+        x = relay.var("x")
         mod = tvm.IRModule()
         p = Prelude(mod)
-        tensor_array = p.get_var('tensor_array', dtype)
+        tensor_array = p.get_var("tensor_array", dtype)
         mod["main"] = relay.Function([x], tensor_array(x))
         expected = np.array([0, 0, 0, 0, 0])
         check_tensor_array(mod, expected, 5, dtype=dtype)
-    run('float32')
-    run('int32')
+
+    run("float32")
+    run("int32")
 
 
 @tvm.testing.uses_gpu
@@ -800,16 +841,17 @@ def test_tensor_array_read():
     def run(dtype):
         mod = tvm.IRModule()
         p = Prelude(mod)
-        l = relay.var('l')
-        i = relay.var('i')
-        read_func = p.get_var('tensor_array_read', dtype)
-        tensor_array = p.get_var('tensor_array', dtype)
+        l = relay.var("l")
+        i = relay.var("i")
+        read_func = p.get_var("tensor_array_read", dtype)
+        tensor_array = p.get_var("tensor_array", dtype)
         mod["main"] = relay.Function([l, i], read_func(tensor_array(l), i))
         expected = [0]
         check_tensor_array(mod, expected, *(1, 0), dtype=dtype)
         check_tensor_array(mod, expected, *(5, 1), dtype=dtype)
-    run('float32')
-    run('int32')
+
+    run("float32")
+    run("int32")
 
 
 @tvm.testing.uses_gpu
@@ -817,20 +859,20 @@ def test_tensor_array_write():
     def run(dtype):
         mod = tvm.IRModule()
         p = Prelude(mod)
-        v1 = relay.var('v1')
-        v2 = relay.var('v2')
-        tensor_array = p.get_var('tensor_array', dtype)
+        v1 = relay.var("v1")
+        v2 = relay.var("v2")
+        tensor_array = p.get_var("tensor_array", dtype)
         init_tensor_array = tensor_array(relay.const(2))
-        write_func = p.get_var('tensor_array_write', dtype)
-        tensor1 = p.get_var('tensor1', dtype)
-        tensor_array1 = write_func(init_tensor_array, relay.const(0),
-                                   tensor1(v1))
+        write_func = p.get_var("tensor_array_write", dtype)
+        tensor1 = p.get_var("tensor1", dtype)
+        tensor_array1 = write_func(init_tensor_array, relay.const(0), tensor1(v1))
         tensor_array2 = write_func(tensor_array1, relay.const(1), tensor1(v2))
         mod["main"] = relay.Function([v1, v2], tensor_array2)
         expected = [3, 7]
         check_tensor_array(mod, expected, *(3, 7), dtype=dtype)
-    run('float32')
-    run('int32')
+
+    run("float32")
+    run("int32")
 
 
 @tvm.testing.uses_gpu
@@ -838,11 +880,11 @@ def test_tensor_array_stack():
     def run(dtype):
         mod = tvm.IRModule()
         p = Prelude(mod)
-        tensor_array = p.get_var('tensor_array', dtype)
-        tensor1 = p.get_var('tensor1', dtype)
-        write = p.get_var('tensor_array_write', dtype)
-        stack = p.get_var('tensor_array_stack', dtype)
-        v = relay.var('v')
+        tensor_array = p.get_var("tensor_array", dtype)
+        tensor1 = p.get_var("tensor1", dtype)
+        write = p.get_var("tensor_array_write", dtype)
+        stack = p.get_var("tensor_array_stack", dtype)
+        v = relay.var("v")
         init_tensor_array = tensor_array(relay.const(3))
         tensor_array1 = write(init_tensor_array, relay.const(0), tensor1(v))
         tensor_array2 = write(tensor_array1, relay.const(1), tensor1(v))
@@ -852,8 +894,9 @@ def test_tensor_array_stack():
         t = np.random.uniform(low=0.0, high=8.0, size=(1,)).astype(dtype)
         expected = [np.stack([t, t, t])]
         check_tensor_array(mod, expected, t, dtype=dtype)
-    run('float32')
-    run('int32')
+
+    run("float32")
+    run("int32")
 
 
 @tvm.testing.uses_gpu
@@ -861,13 +904,14 @@ def test_tensor_array_unstack():
     def run(dtype):
         mod = tvm.IRModule()
         p = Prelude(mod)
-        unstack_tensor1 = p.get_var('tensor_array_unstack_tensor1', dtype)
-        v = relay.var('v')
+        unstack_tensor1 = p.get_var("tensor_array_unstack_tensor1", dtype)
+        v = relay.var("v")
         mod["main"] = relay.Function([v], unstack_tensor1(v))
         t = np.random.uniform(low=0.0, high=8.0, size=(1,)).astype(dtype)
         check_tensor_array(mod, t, t, dtype=dtype)
-    run('float32')
-    run('int32')
+
+    run("float32")
+    run("int32")
 
 
 @tvm.testing.uses_gpu
@@ -875,19 +919,20 @@ def test_tensor_take():
     def run(dtype):
         mod = tvm.IRModule()
         p = Prelude(mod)
-        take = p.get_var('tensor_take', dtype)
-        tensor2 = p.get_var('tensor2', dtype)
-        v = relay.var('v')
-        lower = relay.var('lower')
-        upper = relay.var('upper')
+        take = p.get_var("tensor_take", dtype)
+        tensor2 = p.get_var("tensor2", dtype)
+        v = relay.var("v")
+        lower = relay.var("lower")
+        upper = relay.var("upper")
         mod["main"] = relay.Function([v, lower, upper], take(tensor2(v), lower, upper))
         v_data = np.random.uniform(low=0.0, high=8.0, size=(10, 10)).astype(dtype)
         expected = [np.take(v_data, range(2, 5), axis=0)]
         check_tensor_array(mod, expected, *(v_data, 2, 5), dtype=dtype)
         expected = [np.take(v_data, range(0, 9), axis=0)]
         check_tensor_array(mod, expected, *(v_data, 0, 9), dtype=dtype)
-    run('float32')
-    run('int32')
+
+    run("float32")
+    run("int32")
 
 
 @tvm.testing.uses_gpu
@@ -895,18 +940,18 @@ def test_tensor_concatenate():
     def run(dtype):
         mod = tvm.IRModule()
         p = Prelude(mod)
-        concat = p.get_var('tensor_concatenate', dtype)
-        tensor1 = p.get_var('tensor1', dtype)
-        v1 = relay.var('v1')
-        v2 = relay.var('v2')
-        mod["main"] = relay.Function([v1, v2], concat(tensor1(v1),
-                                                      tensor1(v2)))
+        concat = p.get_var("tensor_concatenate", dtype)
+        tensor1 = p.get_var("tensor1", dtype)
+        v1 = relay.var("v1")
+        v2 = relay.var("v2")
+        mod["main"] = relay.Function([v1, v2], concat(tensor1(v1), tensor1(v2)))
         v1_data = np.random.uniform(low=0.0, high=8.0, size=(5,)).astype(dtype)
         v2_data = np.random.uniform(low=0.0, high=8.0, size=(5,)).astype(dtype)
         expected = [np.concatenate((v1_data, v2_data))]
         check_tensor_array(mod, expected, *(v1_data, v2_data), dtype=dtype)
-    run('float32')
-    run('int32')
+
+    run("float32")
+    run("int32")
 
 
 @tvm.testing.uses_gpu
@@ -914,13 +959,13 @@ def test_tensor_array_concat():
     def run(dtype):
         mod = tvm.IRModule()
         p = Prelude(mod)
-        v1 = relay.var('v1')
-        v2 = relay.var('v2')
-        tensor_array = p.get_var('tensor_array', dtype)
+        v1 = relay.var("v1")
+        v2 = relay.var("v2")
+        tensor_array = p.get_var("tensor_array", dtype)
         tensor_array1 = tensor_array(relay.const(2))
-        write_func = p.get_var('tensor_array_write', dtype)
-        concat_func = p.get_var('tensor_array_concat', dtype)
-        tensor1 = p.get_var('tensor2', dtype)
+        write_func = p.get_var("tensor_array_write", dtype)
+        concat_func = p.get_var("tensor_array_concat", dtype)
+        tensor1 = p.get_var("tensor2", dtype)
         tensor_array1 = write_func(tensor_array1, relay.const(0), tensor1(v1))
         tensor_array1 = write_func(tensor_array1, relay.const(1), tensor1(v2))
         tensor_array_concat = concat_func(tensor_array1)
@@ -929,8 +974,9 @@ def test_tensor_array_concat():
         v2_data = np.random.uniform(low=0.0, high=8.0, size=(1, 3)).astype(dtype)
         expected = [np.concatenate((v1_data, v2_data), axis=0)]
         check_tensor_array(mod, expected, *(v1_data, v2_data), dtype=dtype)
-    run('float32')
-    run('int32')
+
+    run("float32")
+    run("int32")
 
 
 @tvm.testing.uses_gpu
@@ -940,34 +986,31 @@ def test_tensor_array_scatter():
         p = Prelude(mod)
 
         # tensor array
-        v1 = relay.var('v1')
-        v2 = relay.var('v2')
-        v3 = relay.var('v2')
-        tensor_array = p.get_var('tensor_array', dtype)
+        v1 = relay.var("v1")
+        v2 = relay.var("v2")
+        v3 = relay.var("v2")
+        tensor_array = p.get_var("tensor_array", dtype)
         tensor_array1 = tensor_array(relay.const(3))
-        write_func = p.get_var('tensor_array_write', dtype)
-        scatter_func = p.get_var('tensor_array_scatter', dtype)
-        tensor2 = p.get_var('tensor2', dtype)
+        write_func = p.get_var("tensor_array_write", dtype)
+        scatter_func = p.get_var("tensor_array_scatter", dtype)
+        tensor2 = p.get_var("tensor2", dtype)
         tensor_array1 = write_func(tensor_array1, relay.const(0), tensor2(v1))
         tensor_array1 = write_func(tensor_array1, relay.const(1), tensor2(v2))
         tensor_array1 = write_func(tensor_array1, relay.const(2), tensor2(v3))
 
         # indices array
-        index = relay.var('index')
+        index = relay.var("index")
 
         # values array
-        value_0 = relay.var('value_0')
-        value_1 = relay.var('value_1')
+        value_0 = relay.var("value_0")
+        value_1 = relay.var("value_1")
         values_array = tensor_array(relay.const(2))
-        values_array = write_func(values_array, relay.const(0),
-                                  tensor2(value_0))
-        values_array = write_func(values_array, relay.const(1),
-                                  tensor2(value_1))
+        values_array = write_func(values_array, relay.const(0), tensor2(value_0))
+        values_array = write_func(values_array, relay.const(1), tensor2(value_1))
 
         # create the scatter function
         tensor_array_scatter = scatter_func(tensor_array1, index, values_array)
-        mod["main"] = relay.Function([v1, v2, v3, index, value_0, value_1],
-                                     tensor_array_scatter)
+        mod["main"] = relay.Function([v1, v2, v3, index, value_0, value_1], tensor_array_scatter)
 
         # initialize and check
         v1_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype)
@@ -977,11 +1020,15 @@ def test_tensor_array_scatter():
         val1_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype)
         val2_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype)
         expected = [val1_data, val2_data, v3_data]
-        check_tensor_array(mod, expected, *(v1_data, v2_data, v3_data,
-                                            index_data, val1_data,
-                                            val2_data), dtype=dtype)
-    run('float32')
-    run('int32')
+        check_tensor_array(
+            mod,
+            expected,
+            *(v1_data, v2_data, v3_data, index_data, val1_data, val2_data),
+            dtype=dtype,
+        )
+
+    run("float32")
+    run("int32")
 
 
 @tvm.testing.uses_gpu
@@ -991,28 +1038,27 @@ def test_tensor_array_split():
         p = Prelude(mod)
 
         # tensor array
-        v1 = relay.var('v1')
-        v2 = relay.var('v2')
-        v3 = relay.var('v2')
-        tensor_array = p.get_var('tensor_array', dtype)
+        v1 = relay.var("v1")
+        v2 = relay.var("v2")
+        v3 = relay.var("v2")
+        tensor_array = p.get_var("tensor_array", dtype)
         tensor_array1 = tensor_array(relay.const(3))
-        write_func = p.get_var('tensor_array_write', dtype)
-        split_func = p.get_var('tensor_array_split', dtype)
-        tensor2 = p.get_var('tensor2', dtype)
+        write_func = p.get_var("tensor_array_write", dtype)
+        split_func = p.get_var("tensor_array_split", dtype)
+        tensor2 = p.get_var("tensor2", dtype)
         tensor_array1 = write_func(tensor_array1, relay.const(0), tensor2(v1))
         tensor_array1 = write_func(tensor_array1, relay.const(1), tensor2(v2))
         tensor_array1 = write_func(tensor_array1, relay.const(2), tensor2(v3))
 
         # value tensor
-        value = relay.var('value')
+        value = relay.var("value")
 
         # lengths tensor
-        ta_len = relay.var('length')
+        ta_len = relay.var("length")
 
         # create the split function
         tensor_array_split = split_func(tensor_array1, tensor2(value), ta_len)
-        mod["main"] = relay.Function([v1, v2, v3, value, ta_len],
-                                     tensor_array_split)
+        mod["main"] = relay.Function([v1, v2, v3, value, ta_len], tensor_array_split)
 
         # initialize and check
         v1_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype)
@@ -1022,11 +1068,12 @@ def test_tensor_array_split():
         length_data = np.array([2, 2], dtype="int32")
         expected = np.concatenate([value_data, v3_data])
         expected = np.split(expected, indices_or_sections=[2, 4])
-        check_tensor_array(mod, expected, *(v1_data, v2_data, v3_data,
-                                            value_data, length_data),
-                           dtype=dtype)
-    run('float32')
-    run('int32')
+        check_tensor_array(
+            mod, expected, *(v1_data, v2_data, v3_data, value_data, length_data), dtype=dtype
+        )
+
+    run("float32")
+    run("int32")
 
 
 @tvm.testing.uses_gpu
@@ -1037,19 +1084,20 @@ def test_static_tensor_take():
         static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
         static_tensor_array_ops.register()
 
-        take = p.get_var_static('tensor_take', dtype, shape)
-        tensor_constructor = p.get_var_static('tensor_constructor', dtype, shape)
-        v = relay.var('v')
-        lower = relay.var('lower')
-        upper = relay.var('upper')
+        take = p.get_var_static("tensor_take", dtype, shape)
+        tensor_constructor = p.get_var_static("tensor_constructor", dtype, shape)
+        v = relay.var("v")
+        lower = relay.var("lower")
+        upper = relay.var("upper")
         mod["main"] = relay.Function([v, lower, upper], take(tensor_constructor(v), lower, upper))
         v_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
         expected = [np.take(v_data, range(2, 5), axis=0)]
         check_tensor_array(mod, expected, *(v_data, 2, 5), dtype=dtype)
         expected = [np.take(v_data, range(0, 9), axis=0)]
         check_tensor_array(mod, expected, *(v_data, 0, 9), dtype=dtype)
-    run('float32', [10, 10])
-    run('int32', [15, 11])
+
+    run("float32", [10, 10])
+    run("int32", [15, 11])
 
 
 @tvm.testing.uses_gpu
@@ -1060,37 +1108,48 @@ def test_static_tensor_concatenate():
         static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
         static_tensor_array_ops.register()
 
-        concat = p.get_var_static('tensor_concatenate', dtype, shape)
-        tensor = p.get_var_static('tensor_constructor', dtype, shape)
-        v1 = relay.var('v1')
-        v2 = relay.var('v2')
-        mod["main"] = relay.Function([v1, v2], concat(tensor(v1),
-                                                      tensor(v2)))
+        concat = p.get_var_static("tensor_concatenate", dtype, shape)
+        tensor = p.get_var_static("tensor_constructor", dtype, shape)
+        v1 = relay.var("v1")
+        v2 = relay.var("v2")
+        mod["main"] = relay.Function([v1, v2], concat(tensor(v1), tensor(v2)))
         v1_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
         v2_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
         expected = [np.concatenate((v1_data, v2_data))]
         check_tensor_array(mod, expected, *(v1_data, v2_data), dtype=dtype)
-    run('float32', [5,])
-    run('int32', [2, 3])
+
+    run(
+        "float32",
+        [
+            5,
+        ],
+    )
+    run("int32", [2, 3])
 
 
 @tvm.testing.uses_gpu
 def test_static_tensor_expand_dims():
     def run(dtype, shape):
-        x = relay.var('x')
+        x = relay.var("x")
         mod = tvm.IRModule()
         p = Prelude(mod)
         static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
         static_tensor_array_ops.register()
 
-        expand_dims_func = p.get_var_static('tensor_expand_dims', dtype, shape)
-        tensor = p.get_var_static('tensor_constructor', dtype, shape)
+        expand_dims_func = p.get_var_static("tensor_expand_dims", dtype, shape)
+        tensor = p.get_var_static("tensor_constructor", dtype, shape)
         mod["main"] = relay.Function([x], expand_dims_func(tensor(x)))
         x_np = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
         expected = [np.expand_dims(x_np, axis=0)]
         check_tensor_array(mod, expected, x_np)
-    run('float32', [])
-    run('int32', [2,])
+
+    run("float32", [])
+    run(
+        "int32",
+        [
+            2,
+        ],
+    )
 
 
 @tvm.testing.uses_gpu
@@ -1100,9 +1159,10 @@ def test_static_tensor_array_constructor():
         p = Prelude(mod)
         static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
         static_tensor_array_ops.register()
-        tensor_constructor = p.get_name_static('tensor_constructor', dtype, shape)
+        tensor_constructor = p.get_name_static("tensor_constructor", dtype, shape)
         assert tensor_constructor != None
-    run('float32', [1, 1])
+
+    run("float32", [1, 1])
 
 
 @tvm.testing.uses_gpu
@@ -1118,21 +1178,18 @@ def test_static_tensor_array_read():
         for _ in range(ta_length):
             np_data_list.append(np.random.uniform(0, 10, size=shape).astype(dtype))
 
-        v0 = relay.var('v0')
-        v1 = relay.var('v1')
-        v2 = relay.var('v2')
-        n = relay.var('n')
-        tensor = p.get_var_static('tensor_constructor', dtype, shape)
-        tensor_array = p.get_var_static('tensor_array', dtype, shape)
+        v0 = relay.var("v0")
+        v1 = relay.var("v1")
+        v2 = relay.var("v2")
+        n = relay.var("n")
+        tensor = p.get_var_static("tensor_constructor", dtype, shape)
+        tensor_array = p.get_var_static("tensor_array", dtype, shape)
         init_tensor_array = tensor_array(relay.const(ta_length))
-        read_func = p.get_var_static('tensor_array_read', dtype, shape)
-        write_func = p.get_var_static('tensor_array_write', dtype, shape)
-        tensor_array0 = write_func(init_tensor_array, relay.const(0),
-                                   tensor(v0))
-        tensor_array1 = write_func(tensor_array0, relay.const(1),
-                                   tensor(v1))
-        tensor_array2 = write_func(tensor_array1, relay.const(2),
-                                   tensor(v2))
+        read_func = p.get_var_static("tensor_array_read", dtype, shape)
+        write_func = p.get_var_static("tensor_array_write", dtype, shape)
+        tensor_array0 = write_func(init_tensor_array, relay.const(0), tensor(v0))
+        tensor_array1 = write_func(tensor_array0, relay.const(1), tensor(v1))
+        tensor_array2 = write_func(tensor_array1, relay.const(2), tensor(v2))
 
         mod["main"] = relay.Function([v0, v1, v2, n], read_func(tensor_array2, n))
         expected = [np_data_list[0]]
@@ -1141,8 +1198,9 @@ def test_static_tensor_array_read():
         check_tensor_array(mod, expected, *list(np_data_list + [1]), dtype=dtype)
         expected = [np_data_list[2]]
         check_tensor_array(mod, expected, *list(np_data_list + [2]), dtype=dtype)
-    run('float32', [])
-    run('int32', [2, 3])
+
+    run("float32", [])
+    run("int32", [2, 3])
 
 
 @tvm.testing.uses_gpu
@@ -1154,22 +1212,24 @@ def test_static_tensor_array_write():
         static_tensor_array_ops.register()
 
         ta_length = 2
-        np_data_list = [np.random.uniform(0, 10, size=shape).astype(dtype) for _ in range(ta_length)]
+        np_data_list = [
+            np.random.uniform(0, 10, size=shape).astype(dtype) for _ in range(ta_length)
+        ]
 
-        v0 = relay.var('v0')
-        v1 = relay.var('v1')
-        tensor_array = p.get_var_static('tensor_array', dtype, shape)
+        v0 = relay.var("v0")
+        v1 = relay.var("v1")
+        tensor_array = p.get_var_static("tensor_array", dtype, shape)
         init_tensor_array = tensor_array(relay.const(ta_length))
-        write_func = p.get_var_static('tensor_array_write', dtype, shape)
-        tensor = p.get_var_static('tensor_constructor', dtype, shape)
-        tensor_array0 = write_func(init_tensor_array, relay.const(0),
-                                   tensor(v0))
+        write_func = p.get_var_static("tensor_array_write", dtype, shape)
+        tensor = p.get_var_static("tensor_constructor", dtype, shape)
+        tensor_array0 = write_func(init_tensor_array, relay.const(0), tensor(v0))
         tensor_array1 = write_func(tensor_array0, relay.const(1), tensor(v1))
         mod["main"] = relay.Function([v0, v1], tensor_array1)
         expected = np_data_list
         check_tensor_array(mod, expected, *np_data_list, dtype=dtype)
-    run('float32', [])
-    run('int32', [2, 3])
+
+    run("float32", [])
+    run("int32", [2, 3])
 
 
 @tvm.testing.uses_gpu
@@ -1180,14 +1240,15 @@ def test_static_tensor_array_unstack():
         static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
         static_tensor_array_ops.register()
 
-        unstack_tensor = p.get_var_static('tensor_array_unstack', dtype, shape)
-        v = relay.var('v')
+        unstack_tensor = p.get_var_static("tensor_array_unstack", dtype, shape)
+        v = relay.var("v")
         mod["main"] = relay.Function([v], unstack_tensor(v))
         t = np.random.uniform(low=0, high=10, size=shape).astype(dtype)
-        *expected, = t
+        (*expected,) = t
         check_tensor_array(mod, expected, t, dtype=dtype)
-    run('float32', [4])
-    run('int32', [2, 3])
+
+    run("float32", [4])
+    run("int32", [2, 3])
 
 
 @tvm.testing.uses_gpu
@@ -1201,34 +1262,31 @@ def test_static_tensor_array_scatter():
             static_tensor_array_ops.define_tensor_array_scatter(indices_shape, True)
 
         # tensor array
-        v1 = relay.var('v1')
-        v2 = relay.var('v2')
-        v3 = relay.var('v2')
-        tensor_array = p.get_var_static('tensor_array', dtype, shape)
+        v1 = relay.var("v1")
+        v2 = relay.var("v2")
+        v3 = relay.var("v2")
+        tensor_array = p.get_var_static("tensor_array", dtype, shape)
         tensor_array0 = tensor_array(relay.const(3))
-        write_func = p.get_var_static('tensor_array_write', dtype, shape)
-        scatter_func = p.get_var_static('tensor_array_scatter', dtype, shape)
-        tensor = p.get_var_static('tensor_constructor', dtype, shape)
+        write_func = p.get_var_static("tensor_array_write", dtype, shape)
+        scatter_func = p.get_var_static("tensor_array_scatter", dtype, shape)
+        tensor = p.get_var_static("tensor_constructor", dtype, shape)
         tensor_array1 = write_func(tensor_array0, relay.const(0), tensor(v1))
         tensor_array1 = write_func(tensor_array1, relay.const(1), tensor(v2))
         tensor_array1 = write_func(tensor_array1, relay.const(2), tensor(v3))
 
         # indices array
-        index = relay.var('index')
+        index = relay.var("index")
 
         # values array
-        value_0 = relay.var('value_0')
-        value_1 = relay.var('value_1')
+        value_0 = relay.var("value_0")
+        value_1 = relay.var("value_1")
         values_array = tensor_array(relay.const(2))
-        values_array = write_func(values_array, relay.const(0),
-                                  tensor(value_0))
-        values_array = write_func(values_array, relay.const(1),
-                                  tensor(value_1))
+        values_array = write_func(values_array, relay.const(0), tensor(value_0))
+        values_array = write_func(values_array, relay.const(1), tensor(value_1))
 
         # create the scatter function
         tensor_array_scatter = scatter_func(tensor_array1, index, values_array)
-        mod["main"] = relay.Function([v1, v2, v3, index, value_0, value_1],
-                                     tensor_array_scatter)
+        mod["main"] = relay.Function([v1, v2, v3, index, value_0, value_1], tensor_array_scatter)
 
         # initialize and check
         v1_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
@@ -1238,12 +1296,22 @@ def test_static_tensor_array_scatter():
         val1_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
         val2_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
         expected = [val1_data, val2_data, v3_data]
-        check_tensor_array(mod, expected, *(v1_data, v2_data, v3_data,
-                                            index_data, val1_data,
-                                            val2_data), dtype=dtype)
-    run('float32', [2, 3])
-    run('int32', [2, 3])
-    run('float32', [2, 3], [2,])
+        check_tensor_array(
+            mod,
+            expected,
+            *(v1_data, v2_data, v3_data, index_data, val1_data, val2_data),
+            dtype=dtype,
+        )
+
+    run("float32", [2, 3])
+    run("int32", [2, 3])
+    run(
+        "float32",
+        [2, 3],
+        [
+            2,
+        ],
+    )
 
 
 @tvm.testing.uses_gpu
@@ -1257,57 +1325,64 @@ def test_static_tensor_array_split():
             static_tensor_array_ops.define_tensor_array_split(value_shape, lengths_shape, True)
 
         # tensor array
-        v1 = relay.var('v1')
-        v2 = relay.var('v2')
-        v3 = relay.var('v2')
+        v1 = relay.var("v1")
+        v2 = relay.var("v2")
+        v3 = relay.var("v2")
 
-        adt_shape = [relay.Any(),] + shape[1:]
+        adt_shape = [
+            relay.Any(),
+        ] + shape[1:]
         origin_shape = static_tensor_array_ops.shape
         static_tensor_array_ops.shape = adt_shape
         static_tensor_array_ops.define_tensor_array()
-        tensor_array = p.get_var_static('tensor_array', dtype, adt_shape)
+        tensor_array = p.get_var_static("tensor_array", dtype, adt_shape)
         static_tensor_array_ops.shape = origin_shape
         tensor_array1 = tensor_array(relay.const(3))
-        write_func = p.get_var_static('tensor_array_write', dtype, adt_shape)
-        split_func = p.get_var_static('tensor_array_split', dtype, shape)
-        tensor = p.get_var_static('tensor_constructor', dtype, adt_shape)
+        write_func = p.get_var_static("tensor_array_write", dtype, adt_shape)
+        split_func = p.get_var_static("tensor_array_split", dtype, shape)
+        tensor = p.get_var_static("tensor_constructor", dtype, adt_shape)
         tensor_array1 = write_func(tensor_array1, relay.const(0), tensor(v1))
         tensor_array1 = write_func(tensor_array1, relay.const(1), tensor(v2))
         tensor_array1 = write_func(tensor_array1, relay.const(2), tensor(v3))
 
         # value tensor
-        value = relay.var('value')
+        value = relay.var("value")
 
         # lengths tensor
-        ta_len = relay.var('length')
+        ta_len = relay.var("length")
 
         # create the split function
         if value_shape is None:
-            tensor1 = p.get_var_static('tensor_constructor', dtype, shape)
+            tensor1 = p.get_var_static("tensor_constructor", dtype, shape)
         else:
             static_tensor_array_ops = StaticTensorArrayOps(p, dtype, value_shape)
             static_tensor_array_ops.register()
-            tensor1 = p.get_var_static('tensor_constructor', dtype, value_shape)
+            tensor1 = p.get_var_static("tensor_constructor", dtype, value_shape)
         tensor_array_split = split_func(tensor_array1, tensor1(value), ta_len)
-        mod["main"] = relay.Function([v1, v2, v3, value, ta_len],
-                                     tensor_array_split)
+        mod["main"] = relay.Function([v1, v2, v3, value, ta_len], tensor_array_split)
 
         # initialize and check
         v1_data = np.random.uniform(low=0.0, high=8.0, size=[2, 3]).astype(dtype)
         v2_data = np.random.uniform(low=0.0, high=8.0, size=[2, 3]).astype(dtype)
         v3_data = np.random.uniform(low=0.0, high=8.0, size=[2, 3]).astype(dtype)
-        value_data = np.random.uniform(low=0.0, high=8.0,
-                                       size=value_shape or shape).astype(dtype)
+        value_data = np.random.uniform(low=0.0, high=8.0, size=value_shape or shape).astype(dtype)
         length_data = np.array([2, 2], dtype="int32")
         expected = np.concatenate([value_data, v3_data])
         expected = np.split(expected, indices_or_sections=[2, 4])
-        check_tensor_array(mod, expected, *(v1_data, v2_data, v3_data,
-                                            value_data, length_data),
-                           dtype=dtype)
-
-    run('float32', [4, 3])
-    run('int32', [4, 3])
-    run('int32', [relay.Any(), 3], [4, 3], [2,])
+        check_tensor_array(
+            mod, expected, *(v1_data, v2_data, v3_data, value_data, length_data), dtype=dtype
+        )
+
+    run("float32", [4, 3])
+    run("int32", [4, 3])
+    run(
+        "int32",
+        [relay.Any(), 3],
+        [4, 3],
+        [
+            2,
+        ],
+    )
 
 
 @tvm.testing.uses_gpu
@@ -1318,13 +1393,13 @@ def test_static_tensor_array_concat():
         static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
         static_tensor_array_ops.register()
 
-        v1 = relay.var('v1')
-        v2 = relay.var('v2')
-        tensor_array = p.get_var_static('tensor_array', dtype, shape)
+        v1 = relay.var("v1")
+        v2 = relay.var("v2")
+        tensor_array = p.get_var_static("tensor_array", dtype, shape)
         tensor_array1 = tensor_array(relay.const(2))
-        write_func = p.get_var_static('tensor_array_write', dtype, shape)
-        concat_func = p.get_var_static('tensor_array_concat', dtype, shape)
-        tensor = p.get_var_static('tensor_constructor', dtype, shape)
+        write_func = p.get_var_static("tensor_array_write", dtype, shape)
+        concat_func = p.get_var_static("tensor_array_concat", dtype, shape)
+        tensor = p.get_var_static("tensor_constructor", dtype, shape)
         tensor_array1 = write_func(tensor_array1, relay.const(0), tensor(v1))
         tensor_array1 = write_func(tensor_array1, relay.const(1), tensor(v2))
         tensor_array_concat = concat_func(tensor_array1)
@@ -1333,8 +1408,9 @@ def test_static_tensor_array_concat():
         v2_data = np.random.uniform(low=0.0, high=8.0, size=(1, 3)).astype(dtype)
         expected = [np.concatenate((v1_data, v2_data), axis=0)]
         check_tensor_array(mod, expected, *(v1_data, v2_data), dtype=dtype)
-    run('float32', [relay.Any(), 3])
-    run('int32', [relay.Any(), 3])
+
+    run("float32", [relay.Any(), 3])
+    run("int32", [relay.Any(), 3])
 
 
 @tvm.testing.uses_gpu
@@ -1345,12 +1421,12 @@ def test_static_tensor_array_gather():
         static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
         static_tensor_array_ops.register()
 
-        tensor_array = p.get_var_static('tensor_array', dtype, shape)
-        tensor = p.get_var_static('tensor_constructor', dtype, shape)
-        write = p.get_var_static('tensor_array_write', dtype, shape)
-        gather = p.get_var_static('tensor_array_gather', dtype, shape)
-        v = relay.var('v')
-        indice = relay.var('indice')
+        tensor_array = p.get_var_static("tensor_array", dtype, shape)
+        tensor = p.get_var_static("tensor_constructor", dtype, shape)
+        write = p.get_var_static("tensor_array_write", dtype, shape)
+        gather = p.get_var_static("tensor_array_gather", dtype, shape)
+        v = relay.var("v")
+        indice = relay.var("indice")
         init_tensor_array = tensor_array(relay.const(3))
         tensor_array1 = write(init_tensor_array, relay.const(0), tensor(v))
         tensor_array2 = write(tensor_array1, relay.const(1), tensor(v))
@@ -1361,8 +1437,9 @@ def test_static_tensor_array_gather():
         indice_data = np.array([0, 2], dtype="int32")
         expected = [np.stack([t, t])]
         check_tensor_array(mod, expected, *(t, indice_data), dtype=dtype)
-    run('float32', [])
-    run('int32', [2, 3])
+
+    run("float32", [])
+    run("int32", [2, 3])
 
 
 @tvm.testing.uses_gpu
@@ -1373,11 +1450,11 @@ def test_static_tensor_array_stack():
         static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
         static_tensor_array_ops.register()
 
-        tensor_array = p.get_var_static('tensor_array', dtype, shape)
-        tensor = p.get_var_static('tensor_constructor', dtype, shape)
-        write = p.get_var_static('tensor_array_write', dtype, shape)
-        stack = p.get_var_static('tensor_array_stack', dtype, shape)
-        v = relay.var('v')
+        tensor_array = p.get_var_static("tensor_array", dtype, shape)
+        tensor = p.get_var_static("tensor_constructor", dtype, shape)
+        write = p.get_var_static("tensor_array_write", dtype, shape)
+        stack = p.get_var_static("tensor_array_stack", dtype, shape)
+        v = relay.var("v")
         init_tensor_array = tensor_array(relay.const(3))
         tensor_array1 = write(init_tensor_array, relay.const(0), tensor(v))
         tensor_array2 = write(tensor_array1, relay.const(1), tensor(v))
@@ -1387,8 +1464,9 @@ def test_static_tensor_array_stack():
         t = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype)
         expected = [np.stack([t, t, t])]
         check_tensor_array(mod, expected, t, dtype=dtype)
-    run('float32', [])
-    run('int32', [2, 3])
+
+    run("float32", [])
+    run("int32", [2, 3])
 
 
 @tvm.testing.uses_gpu
@@ -1404,22 +1482,19 @@ def test_static_tensor_get_data():
         for _ in range(ta_length):
             np_data_list.append(np.random.uniform(0, 10, size=shape).astype(dtype))
 
-        v0 = relay.var('v0')
-        v1 = relay.var('v1')
-        v2 = relay.var('v2')
-        n = relay.var('n')
-        tensor = p.get_var_static('tensor_constructor', dtype, shape)
-        tensor_array = p.get_var_static('tensor_array', dtype, shape)
+        v0 = relay.var("v0")
+        v1 = relay.var("v1")
+        v2 = relay.var("v2")
+        n = relay.var("n")
+        tensor = p.get_var_static("tensor_constructor", dtype, shape)
+        tensor_array = p.get_var_static("tensor_array", dtype, shape)
         init_tensor_array = tensor_array(relay.const(ta_length))
-        read_func = p.get_var_static('tensor_array_read', dtype, shape)
-        write_func = p.get_var_static('tensor_array_write', dtype, shape)
-        get_data_func = p.get_var_static('tensor_get_data', dtype, shape)
-        tensor_array0 = write_func(init_tensor_array, relay.const(0),
-                                   tensor(v0))
-        tensor_array1 = write_func(tensor_array0, relay.const(1),
-                                   tensor(v1))
-        tensor_array2 = write_func(tensor_array1, relay.const(2),
-                                   tensor(v2))
+        read_func = p.get_var_static("tensor_array_read", dtype, shape)
+        write_func = p.get_var_static("tensor_array_write", dtype, shape)
+        get_data_func = p.get_var_static("tensor_get_data", dtype, shape)
+        tensor_array0 = write_func(init_tensor_array, relay.const(0), tensor(v0))
+        tensor_array1 = write_func(tensor_array0, relay.const(1), tensor(v1))
+        tensor_array2 = write_func(tensor_array1, relay.const(2), tensor(v2))
 
         mod["main"] = relay.Function([v0, v1, v2, n], get_data_func(read_func(tensor_array2, n)))
         expected = [np_data_list[0]]
@@ -1428,8 +1503,10 @@ def test_static_tensor_get_data():
         check_tensor_array(mod, expected, *list(np_data_list + [1]), dtype=dtype)
         expected = [np_data_list[2]]
         check_tensor_array(mod, expected, *list(np_data_list + [2]), dtype=dtype)
-    run('float32', [])
-    run('int32', [2, 3])
+
+    run("float32", [])
+    run("int32", [2, 3])
+
 
 if __name__ == "__main__":
     pytest.main([__file__])
index dfd7dd1..5395be3 100644
@@ -20,22 +20,25 @@ import tvm
 from tvm import relay
 from tvm.relay.analysis import check_basic_block_normal_form
 
+
 def test_one_block():
-    x = relay.var('x')
+    x = relay.var("x")
     y = relay.add(x, x)
     z = relay.add(x, y)
     check_basic_block_normal_form(z)
 
+
 def test_let():
-    x = relay.var('x')
-    y = relay.var('y')
+    x = relay.var("x")
+    y = relay.var("y")
     body = relay.Let(y, x, y)
     check_basic_block_normal_form(body)
 
+
 @pytest.mark.xfail(raises=tvm.error.TVMError)
 def test_invalid_if():
-    cond = relay.var('cond', dtype='bool', shape=())
-    shared = relay.var('shared')
+    cond = relay.var("cond", dtype="bool", shape=())
+    shared = relay.var("shared")
     true_branch = shared
     false_branch = relay.add(shared, shared)
     body = relay.If(cond, true_branch, false_branch)
@@ -53,13 +56,14 @@ def test_invalid_if():
     """
     check_basic_block_normal_form(body)
 
+
 def test_valid_if():
-    cond = relay.var('cond', dtype='bool', shape=())
-    shared = relay.var('shared')
+    cond = relay.var("cond", dtype="bool", shape=())
+    shared = relay.var("shared")
     true_branch = shared
     false_branch = relay.add(shared, shared)
     body = relay.If(cond, true_branch, false_branch)
-    shared_bound = relay.var('shared_bound', shape=(1,), dtype='float32')
+    shared_bound = relay.var("shared_bound", shape=(1,), dtype="float32")
     body = relay.Let(shared, shared_bound, body)
     """
     The program below uses let binding to control the scope of %shared, which
@@ -76,6 +80,7 @@ def test_valid_if():
     """
     check_basic_block_normal_form(body)
 
+
 @pytest.mark.xfail(raises=tvm.error.TVMError)
 def test_invalid_if2():
     """
@@ -89,9 +94,9 @@ def test_invalid_if2():
       }
     }
     """
-    x = relay.var('x', shape=(), dtype='float32')
-    one = relay.const(1, dtype='float32')
-    two = relay.const(2, dtype='float32')
+    x = relay.var("x", shape=(), dtype="float32")
+    one = relay.const(1, dtype="float32")
+    two = relay.const(2, dtype="float32")
     v1 = relay.add(x, one)
     v2 = relay.equal(x, two)
     true_branch = relay.multiply(v1, two)
@@ -100,6 +105,7 @@ def test_invalid_if2():
     func = relay.Function([x], body)
     check_basic_block_normal_form(func)
 
+
 def test_valid_if2():
     """
     fn (%x: float32) {
@@ -112,10 +118,10 @@ def test_valid_if2():
       }
     }
     """
-    x = relay.var('x', shape=(), dtype='float32')
-    one = relay.const(1, dtype='float32')
-    two = relay.const(2, dtype='float32')
-    v1 = relay.var('v1')
+    x = relay.var("x", shape=(), dtype="float32")
+    one = relay.const(1, dtype="float32")
+    two = relay.const(2, dtype="float32")
+    v1 = relay.var("v1")
     v2 = relay.equal(x, two)
     true_branch = relay.multiply(v1, two)
     false_branch = relay.multiply(v1, one)
@@ -124,14 +130,15 @@ def test_valid_if2():
     func = relay.Function([x], body)
     check_basic_block_normal_form(func)
 
+
 @pytest.mark.xfail(raises=tvm.error.TVMError)
 def test_func():
-    x = relay.var('x', shape=(1,), dtype='float32')#, a)
-    y = relay.var('y', shape=(1,), dtype='float32')#, a)
-    z = relay.var('z', shape=(1,), dtype='float32')#, a)
+    x = relay.var("x", shape=(1,), dtype="float32")  # , a)
+    y = relay.var("y", shape=(1,), dtype="float32")  # , a)
+    z = relay.var("z", shape=(1,), dtype="float32")  # , a)
     x2 = relay.add(x, x)
-    func_a = relay.Function([y], relay.add(x2, y)) #, a, [a])
-    func_b = relay.Function([z], relay.add(x2, z)) #, a, [a])
+    func_a = relay.Function([y], relay.add(x2, y))  # , a, [a])
+    func_b = relay.Function([z], relay.add(x2, z))  # , a, [a])
     body = relay.Tuple([func_a, func_b])
     body = relay.Function([x], body)
     """
@@ -148,14 +155,15 @@ def test_func():
     """
     check_basic_block_normal_form(body)
 
+
 @pytest.mark.xfail(raises=tvm.error.TVMError)
 def test_higher_order_return():
-    x = relay.var('x', shape=(1,), dtype='float32')#, a)
-    y = relay.var('y', shape=(1,), dtype='float32')#, a)
-    z = relay.var('z', shape=(1,), dtype='float32')#, a)
+    x = relay.var("x", shape=(1,), dtype="float32")  # , a)
+    y = relay.var("y", shape=(1,), dtype="float32")  # , a)
+    z = relay.var("z", shape=(1,), dtype="float32")  # , a)
     x2 = relay.add(x, x)
-    func_a = relay.Function([y], relay.add(x2, y)) #, a, [a])
-    func_b = relay.Function([z], relay.add(x2, z)) #, a, [a])
+    func_a = relay.Function([y], relay.add(x2, y))  # , a, [a])
+    func_b = relay.Function([z], relay.add(x2, z))  # , a, [a])
     body = relay.Tuple([func_a, func_b])
     body = relay.Function([x], body)
     """
@@ -175,13 +183,13 @@ def test_higher_order_return():
 
 @pytest.mark.xfail(raises=tvm.error.TVMError)
 def test_higher_order_nested():
-    x = relay.var('x', dtype='float32', shape=(1,))
-    s = relay.var('s', dtype='float32', shape=(1,))
+    x = relay.var("x", dtype="float32", shape=(1,))
+    s = relay.var("s", dtype="float32", shape=(1,))
     shared = relay.add(s, s)
     func_true = relay.Function([x], relay.add(x, shared))
-    choice_t = relay.FuncType([], relay.scalar_type('bool'))
-    f = relay.Var('f', choice_t)
-    z = relay.Var('z')
+    choice_t = relay.FuncType([], relay.scalar_type("bool"))
+    f = relay.Var("f", choice_t)
+    z = relay.Var("z")
     body = relay.If(f(), func_true, relay.Function([z], relay.add(z, shared)))
     top = relay.Function([f, s], body)
     """
@@ -202,5 +210,5 @@ def test_higher_order_nested():
     check_basic_block_normal_form(top)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     pytest.main([__file__])
index 41f6ca0..5734f4e 100644
@@ -34,23 +34,11 @@ def get_conv_net():
     """
     dshape = (1, 1, 5, 1)
     x = relay.var("x", shape=dshape)
-    y = relay.nn.conv2d(x, relay.var("w1"),
-                        kernel_size=(3, 3),
-                        padding=(1, 1),
-                        channels=1)
-
-    x1 = relay.nn.conv2d(y, relay.var("w2"),
-                         kernel_size=(3, 3),
-                         padding=(1, 1),
-                         channels=1)
-    x2 = relay.nn.conv2d(y, relay.var("w3"),
-                         kernel_size=(3, 3),
-                         padding=(1, 1),
-                         channels=1)
-    x3 = relay.nn.conv2d(y, relay.var("w4"),
-                         kernel_size=(3, 3),
-                         padding=(1, 1),
-                         channels=1)
+    y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+
+    x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+    x2 = relay.nn.conv2d(y, relay.var("w3"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+    x3 = relay.nn.conv2d(y, relay.var("w4"), kernel_size=(3, 3), padding=(1, 1), channels=1)
 
     z = relay.add(x1, x2)
     z = relay.add(x3, z)
@@ -60,13 +48,16 @@ def get_conv_net():
 
 def get_conv2d():
     x = relay.var("x", shape=(1, 56, 56, 64))
-    weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
-    y = relay.nn.conv2d(x, weight1,
-                        channels=32,
-                        kernel_size=(3, 3),
-                        padding=(1, 1),
-                        data_layout='NHWC',
-                        kernel_layout='HWIO')
+    weight1 = relay.var("weight1", shape=(3, 3, 64, 32))
+    y = relay.nn.conv2d(
+        x,
+        weight1,
+        channels=32,
+        kernel_size=(3, 3),
+        padding=(1, 1),
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+    )
     return tvm.IRModule.from_expr(y)
 
 
@@ -75,8 +66,7 @@ def test_extract_identity():
     items = relay.analysis.extract_fused_functions(mod)
     assert len(items) == 1
 
-    mod["main"] = mod["main"].with_attr(
-        "Primitive", tvm.tir.IntImm("int32", 1))
+    mod["main"] = mod["main"].with_attr("Primitive", tvm.tir.IntImm("int32", 1))
     tvm.ir.structural_equal(list(items.values())[0], mod["main"])
 
 
@@ -109,7 +99,7 @@ def test_extract_resnet():
     assert len(items) == 6
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_extract_identity()
     test_extract_conv_net()
     test_extract_resnet()
index 2b32376..6ac7085 100644
@@ -23,28 +23,31 @@ from tvm.relay.transform import gradient
 from tvm.relay.prelude import Prelude
 from tvm.relay.testing import run_infer_type
 
+
 def test_prelude():
     p = Prelude()
     feats = detect_feature(p.mod)
-    assert feats == set([
-        Feature.fVar,
-        Feature.fGlobalVar,
-        Feature.fConstant,
-        Feature.fTuple,
-        Feature.fTupleGetItem,
-        Feature.fFunction,
-        Feature.fOp,
-        Feature.fCall,
-        Feature.fLet,
-        Feature.fIf,
-        Feature.fConstructor,
-        Feature.fMatch,
-    ])
+    assert feats == set(
+        [
+            Feature.fVar,
+            Feature.fGlobalVar,
+            Feature.fConstant,
+            Feature.fTuple,
+            Feature.fTupleGetItem,
+            Feature.fFunction,
+            Feature.fOp,
+            Feature.fCall,
+            Feature.fLet,
+            Feature.fIf,
+            Feature.fConstructor,
+            Feature.fMatch,
+        ]
+    )
 
 
 def test_ad():
     shape = (10, 10)
-    dtype = 'float32'
+    dtype = "float32"
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     func = relay.Function([x], x + x)
@@ -53,20 +56,22 @@ def test_ad():
     mod = relay.transform.InferType()(mod)
     back_func = mod["main"]
     feats = detect_feature(back_func)
-    assert feats == set([
-        Feature.fVar,
-        Feature.fTuple,
-        Feature.fTupleGetItem,
-        Feature.fFunction,
-        Feature.fOp,
-        Feature.fCall,
-        Feature.fLet,
-        Feature.fRefCreate,
-        Feature.fRefRead,
-        Feature.fRefWrite,
-    ])
+    assert feats == set(
+        [
+            Feature.fVar,
+            Feature.fTuple,
+            Feature.fTupleGetItem,
+            Feature.fFunction,
+            Feature.fOp,
+            Feature.fCall,
+            Feature.fLet,
+            Feature.fRefCreate,
+            Feature.fRefRead,
+            Feature.fRefWrite,
+        ]
+    )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_prelude()
     test_ad()
index 9a29f2e..72c1c81 100644
@@ -34,12 +34,13 @@ def check_data_size(mod, data):
             else:
                 assert len(data[key]["outputs"]) == 1
 
+
 def test_simple_graph():
     # A module with two subgraphs
     mod = tvm.IRModule()
 
-    x0 = relay.var('x0', shape=(8, 8))
-    y0 = relay.var('y0', shape=(8, 8))
+    x0 = relay.var("x0", shape=(8, 8))
+    y0 = relay.var("y0", shape=(8, 8))
     z0 = x0 + y0
     z1 = x0 - y0
     z2 = relay.Tuple((z0, z1))
@@ -48,26 +49,25 @@ def test_simple_graph():
     g0 = relay.GlobalVar("g0")
     mod[g0] = f0
 
-    x1 = relay.var('x1', shape=(8, 8))
-    y1 = relay.var('y1', shape=(8, 8))
+    x1 = relay.var("x1", shape=(8, 8))
+    y1 = relay.var("y1", shape=(8, 8))
     z1 = x1 - y1
     f1 = relay.Function([x1, y1], z1)
     f1 = f1.with_attr("Compiler", "test_graph")
     g1 = relay.GlobalVar("g1")
     mod[g1] = f1
 
-
-    x = relay.var('x', shape=(8, 8))
-    y = relay.var('y', shape=(8, 8))
-    z = relay.var('z', shape=(8, 8))
+    x = relay.var("x", shape=(8, 8))
+    y = relay.var("y", shape=(8, 8))
+    z = relay.var("z", shape=(8, 8))
     c0 = relay.Call(g0, [x, y])
     c1 = relay.Call(g1, [relay.TupleGetItem(c0, 0), z])
     fm = relay.Function([x, y, z], c1)
     mod["main"] = fm
 
-    x_data = np.random.rand(8, 8).astype('float32')
-    y_data = np.random.rand(8, 8).astype('float32')
-    z_data = np.random.rand(8, 8).astype('float32')
+    x_data = np.random.rand(8, 8).astype("float32")
+    y_data = np.random.rand(8, 8).astype("float32")
+    z_data = np.random.rand(8, 8).astype("float32")
     data = get_calibration_data(mod, {"x": x_data, "y": y_data, "z": z_data})
 
     # Check the number and orders
@@ -80,15 +80,15 @@ def test_simple_graph():
     tvm.testing.assert_allclose(data[g1]["inputs"][1].asnumpy(), z_data)
     tvm.testing.assert_allclose(data[g1]["outputs"][0].asnumpy(), x_data + y_data - z_data)
 
+
 def test_mobilenet_dnnl():
     if not tvm.get_global_func("relay.ext.dnnl", True):
         print("skip because DNNL codegen is not available")
         return
 
-    dtype = 'float32'
+    dtype = "float32"
     ishape = (1, 3, 224, 224)
-    mod, params = relay.testing.mobilenet.get_workload(
-        batch_size=1, dtype='float32')
+    mod, params = relay.testing.mobilenet.get_workload(batch_size=1, dtype="float32")
 
     mod = transform.AnnotateTarget(["dnnl"])(mod)
     mod = transform.MergeCompilerRegions()(mod)
@@ -100,6 +100,7 @@ def test_mobilenet_dnnl():
     # Check the number and orders
     check_data_size(mod, data)
 
+
 if __name__ == "__main__":
     test_simple_graph()
     test_mobilenet_dnnl()
index f3c157d..17c4ff2 100644
@@ -30,51 +30,51 @@ def check_region(region_set, target, args, nodes, rets):
 
 
 def test_region_set_creator_diamond():
-    data = relay.var('data', shape=(10, 10))
-    cb_1 = compiler_begin(data, 'test_target')
+    data = relay.var("data", shape=(10, 10))
+    cb_1 = compiler_begin(data, "test_target")
     O_1 = relay.abs(cb_1)
-    ce_1 = compiler_end(O_1, 'test_target')
-    ce_2 = compiler_end(O_1, 'test_target')
-    cb_2 = compiler_begin(ce_1, 'test_target')
+    ce_1 = compiler_end(O_1, "test_target")
+    ce_2 = compiler_end(O_1, "test_target")
+    cb_2 = compiler_begin(ce_1, "test_target")
     O_2 = relay.nn.relu(cb_2)
-    ce_3 = compiler_end(O_2, 'test_target')
+    ce_3 = compiler_end(O_2, "test_target")
     cb_d = compiler_begin(ce_2, "default")
     X = relay.tanh(cb_d)
-    ce_d = compiler_end(X, 'default')
-    cb_3 = compiler_begin(ce_3, 'test_target')
-    cb_4 = compiler_begin(ce_d, 'test_target')
+    ce_d = compiler_end(X, "default")
+    cb_3 = compiler_begin(ce_3, "test_target")
+    cb_4 = compiler_begin(ce_d, "test_target")
     O_3 = relay.add(cb_3, cb_4)
-    ce_4 = compiler_end(O_3, 'test_target')
+    ce_4 = compiler_end(O_3, "test_target")
     diamond = relay.Function([data], ce_4)
 
-    region_set = relay.analysis.AnnotatedRegionSet(diamond,
-                                                   relay.op.get("annotation.compiler_begin"),
-                                                   relay.op.get("annotation.compiler_end"))
+    region_set = relay.analysis.AnnotatedRegionSet(
+        diamond, relay.op.get("annotation.compiler_begin"), relay.op.get("annotation.compiler_end")
+    )
     assert len(region_set) == 4
     check_region(
         region_set,
-        'test_target',
+        "test_target",
         [cb_1],
         [cb_1, O_1, ce_1, ce_2],
         [ce_1, ce_2],
     )
     check_region(
         region_set,
-        'test_target',
+        "test_target",
         [cb_2],
         [cb_2, O_2, ce_3],
         [ce_3],
     )
     check_region(
         region_set,
-        'default',
+        "default",
         [cb_d],
         [cb_d, X, ce_d],
         [ce_d],
     )
     check_region(
         region_set,
-        'test_target',
+        "test_target",
         [cb_3, cb_4],
         [cb_3, cb_4, O_3, ce_4],
         [ce_4],
@@ -82,44 +82,44 @@ def test_region_set_creator_diamond():
 
 
 def test_region_set_creator_merged():
-    data = relay.var('data', shape=(10, 10))
-    cb_1 = compiler_begin(data, 'test_target')
+    data = relay.var("data", shape=(10, 10))
+    cb_1 = compiler_begin(data, "test_target")
     O_1 = relay.abs(cb_1)
-    ce_2 = compiler_end(O_1, 'test_target')
+    ce_2 = compiler_end(O_1, "test_target")
     O_2 = relay.nn.relu(O_1)
-    ce_3 = compiler_end(O_2, 'test_target')
+    ce_3 = compiler_end(O_2, "test_target")
     cb_d = compiler_begin(ce_2, "default")
     X = relay.tanh(cb_d)
-    ce_d = compiler_end(X, 'default')
-    cb_3 = compiler_begin(ce_3, 'test_target')
-    cb_4 = compiler_begin(ce_d, 'test_target')
+    ce_d = compiler_end(X, "default")
+    cb_3 = compiler_begin(ce_3, "test_target")
+    cb_4 = compiler_begin(ce_d, "test_target")
     O_3 = relay.add(cb_3, cb_4)
     O_4 = relay.add(cb_3, cb_4)
     O_5 = relay.Tuple([O_3, O_4])
-    ce_4 = compiler_end(O_5, 'test_target')
+    ce_4 = compiler_end(O_5, "test_target")
     merged = relay.Function([data], ce_4)
 
-    region_set = relay.analysis.AnnotatedRegionSet(merged,
-                                                   relay.op.get("annotation.compiler_begin"),
-                                                   relay.op.get("annotation.compiler_end"))
+    region_set = relay.analysis.AnnotatedRegionSet(
+        merged, relay.op.get("annotation.compiler_begin"), relay.op.get("annotation.compiler_end")
+    )
     assert len(region_set) == 3
     check_region(
         region_set,
-        'test_target',
+        "test_target",
         [cb_1],
         [cb_1, O_1, O_2, ce_2, ce_3],
         [ce_2, ce_3],
     )
     check_region(
         region_set,
-        'default',
+        "default",
         [cb_d],
         [cb_d, X, ce_d],
         [ce_d],
     )
     check_region(
         region_set,
-        'test_target',
+        "test_target",
         [cb_3, cb_4],
         [cb_3, cb_4, O_3, O_4, O_5, ce_4],
         [ce_4],
index 3a46fdd..b2b0c19 100644
@@ -24,8 +24,10 @@ from tvm.relay.loops import while_loop
 from tvm.relay.testing import run_infer_type as infer_type
 import tvm.topi.testing
 
+
 def int32(val):
-    return relay.const(val, 'int32')
+    return relay.const(val, "int32")
+
 
 def any_dims(ndim):
     shape = []
@@ -33,20 +35,20 @@ def any_dims(ndim):
         shape.append(relay.Any())
     return tuple(shape)
 
-def check_result(args, mod, expected, flatten=False, assert_shape=False,
-                 only_vm=False):
+
+def check_result(args, mod, expected, flatten=False, assert_shape=False, only_vm=False):
     for kind in ["debug", "vm"]:
         for tgt, ctx in tvm.testing.enabled_targets():
-            if kind == "debug" and (only_vm or ctx.device_type !=
-                                    tvm.cpu().device_type):
+            if kind == "debug" and (only_vm or ctx.device_type != tvm.cpu().device_type):
                 continue
             ex = relay.create_executor(kind, mod=mod, ctx=ctx, target=tgt)
             result = ex.evaluate()(*args)
             result = result.asnumpy()
             if assert_shape:
-                assert result.shape == expected, \
-                        "Shape mismatch: expect %s but got %s." \
-                        % (str(expected), str(result.shape))
+                assert result.shape == expected, "Shape mismatch: expect %s but got %s." % (
+                    str(expected),
+                    str(result.shape),
+                )
                 return
 
             if flatten:
@@ -54,10 +56,11 @@ def check_result(args, mod, expected, flatten=False, assert_shape=False,
                 expected = expected.flatten()
             tvm.testing.assert_allclose(result, expected)
 
+
 def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
-    dtype = 'float32'
-    x = relay.var('x', shape=x_shape, dtype=dtype)
-    y = relay.var('y', shape=y_shape, dtype=dtype)
+    dtype = "float32"
+    x = relay.var("x", shape=x_shape, dtype=dtype)
+    y = relay.var("y", shape=y_shape, dtype=dtype)
     mod = tvm.IRModule()
     mod["main"] = relay.Function([x, y], op(x, y))
     x_np = np.random.uniform(size=x_np_shape).astype(dtype)
@@ -65,6 +68,7 @@ def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
     res_np = np_op(x_np, y_np)
     check_result([x_np, y_np], mod, res_np)
 
+
 @tvm.testing.uses_gpu
 def test_any_broadcast():
     # Test broadcast with 1s
@@ -78,28 +82,30 @@ def test_any_broadcast():
     verify_any_broadcast((relay.Any(),), (3, 2), (2,), (3, 2), relay.add, np.add)
     verify_any_broadcast((relay.Any(), 2), (3, 2), (3, 2), (3, 2), relay.add, np.add)
 
+
 def verify_any_elemwise(x_shape, x_np_shape, op, np_op):
-    dtype = 'float32'
-    x = relay.var('x', shape=x_shape, dtype=dtype)
+    dtype = "float32"
+    x = relay.var("x", shape=x_shape, dtype=dtype)
     mod = tvm.IRModule()
     mod["main"] = relay.Function([x], op(x))
     x_np = np.random.uniform(size=x_np_shape).astype(dtype)
     res_np = np_op(x_np)
     check_result([x_np], mod, res_np)
 
+
 @tvm.testing.uses_gpu
 def test_any_elemwise():
     verify_any_elemwise((relay.Any(),), (3,), relay.sqrt, np.sqrt)
     verify_any_elemwise((relay.Any(), 2), (5, 2), relay.negative, np.negative)
     verify_any_elemwise((relay.Any(), relay.Any()), (5, 4), relay.exp, np.exp)
 
+
 @tvm.testing.uses_gpu
 def test_any_broadcast_fail():
     # Test broadcast with incompatible values at runtime
     def check_fail(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
         try:
-            verify_any_broadcast(
-                x_shape, y_shape, x_np_shape, y_np_shape, op, np_op)
+            verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op)
         except tvm._ffi.base.TVMError:
             pass
         else:
@@ -112,33 +118,38 @@ def test_any_broadcast_fail():
     check_fail((relay.Any(),), (3, 2), (2), (4, 2), relay.add, np.add)
 
 
-def verify_any_full_like(x_shape, x_np_shape, relay_op, np_op, dtype='float32'):
-    x = relay.var('x', shape=x_shape, dtype=dtype)
+def verify_any_full_like(x_shape, x_np_shape, relay_op, np_op, dtype="float32"):
+    x = relay.var("x", shape=x_shape, dtype=dtype)
     mod = tvm.IRModule()
-    mod['main'] = relay.Function([x], relay_op(x))
+    mod["main"] = relay.Function([x], relay_op(x))
     x_np = np.random.uniform(size=x_np_shape).astype(dtype)
     res_np = np_op(x_np)
     check_result([x_np], mod, res_np)
 
+
 @tvm.testing.uses_gpu
 def test_any_full_like():
     # zeros_like, ones_like
     verify_any_full_like(any_dims(3), (2, 3, 5), relay.zeros_like, np.zeros_like, "float32")
     verify_any_full_like(any_dims(3), (225, 115, 15), relay.zeros_like, np.zeros_like, "float32")
-    verify_any_full_like(any_dims(5), (10, 11, 12, 13, 14), relay.zeros_like, np.zeros_like, "int32")
+    verify_any_full_like(
+        any_dims(5), (10, 11, 12, 13, 14), relay.zeros_like, np.zeros_like, "int32"
+    )
     verify_any_full_like(any_dims(3), (2, 3, 5), relay.ones_like, np.ones_like, "float32")
     verify_any_full_like(any_dims(3), (225, 115, 15), relay.ones_like, np.ones_like, "float32")
     verify_any_full_like(any_dims(5), (10, 11, 12, 13, 14), relay.ones_like, np.ones_like, "int32")
 
-def verify_any_full(x_np_shape, relay_op, np_op, dtype='float32', value=None):
-    x = relay.var('x', shape=(len(x_np_shape),), dtype="int32")
+
+def verify_any_full(x_np_shape, relay_op, np_op, dtype="float32", value=None):
+    x = relay.var("x", shape=(len(x_np_shape),), dtype="int32")
     mod = tvm.IRModule()
     out = relay_op(x, dtype) if value is None else relay_op(relay.expr.const(value), x, dtype)
-    mod['main'] = relay.Function([x], out)
+    mod["main"] = relay.Function([x], out)
     res_np = np_op(x_np_shape) if value is None else np_op(x_np_shape, value)
     x_np = np.array(x_np_shape).astype("int32")
     check_result([x_np], mod, res_np)
 
+
 @tvm.testing.uses_gpu
 def test_any_full():
     # zeros, ones, full
@@ -151,31 +162,33 @@ def test_any_full():
     verify_any_full((10, 11, 12, 13, 14), relay.full, np.full, "float32", 2.0)
     verify_any_full((1, 2, 3, 4), relay.full, np.full, "int32", -2)
 
+
 @tvm.testing.uses_gpu
 def test_any_concat():
-    x = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
-    y = relay.var('y', shape=(1, 2), dtype="float32")
+    x = relay.var("x", shape=(relay.Any(), 2), dtype="float32")
+    y = relay.var("y", shape=(1, 2), dtype="float32")
     xx = x - relay.expr.const(3.0)
     yy = y * relay.expr.const(5.0)
     z = relay.op.concatenate([xx, yy], axis=0)
     mod = tvm.IRModule()
     mod["main"] = relay.Function([x, y], z)
-    x_np = np.random.uniform(size=(3, 2)).astype('float32')
-    y_np = np.random.uniform(size=(1, 2)).astype('float32')
+    x_np = np.random.uniform(size=(3, 2)).astype("float32")
+    y_np = np.random.uniform(size=(1, 2)).astype("float32")
     ref = np.concatenate([x_np - 3.0, y_np * 5.0], axis=0)
     check_result([x_np, y_np], mod, ref)
 
+
 def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newshape=False):
-    x = relay.var('x', shape=x_shape, dtype="float32")
+    x = relay.var("x", shape=x_shape, dtype="float32")
     relu_x = relay.nn.relu(x)
-    data = np.random.uniform(size=x_np_shape).astype('float32')
+    data = np.random.uniform(size=x_np_shape).astype("float32")
     params = [x]
     args = [data]
 
     if variable_newshape:
-        newshape_var = relay.var('newshape', shape=(len(newshape),), dtype='int64')
+        newshape_var = relay.var("newshape", shape=(len(newshape),), dtype="int64")
         params.append(newshape_var)
-        args.append(np.array(newshape, dtype='int64'))
+        args.append(np.array(newshape, dtype="int64"))
         newshape = newshape_var
 
     y = relay.reshape(relu_x, newshape=newshape)
@@ -183,6 +196,7 @@ def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newsha
     mod["main"] = relay.Function(params, y)
     check_result(args, mod, data, flatten=True)
 
+
 @tvm.testing.uses_gpu
 def test_any_reshape():
     for variable_newshape in [False, True]:
@@ -193,8 +207,9 @@ def test_any_reshape():
     verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12))
     verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4))
 
+
 def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"):
-    x = relay.var('x', shape=x_shape, dtype=dtype)
+    x = relay.var("x", shape=x_shape, dtype=dtype)
     y = relay.argwhere(x)
     mod = tvm.IRModule()
     mod["main"] = relay.Function([x], y)
@@ -209,6 +224,7 @@ def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"):
     # TODO(@zhiics) argwhere gpu schedule is currently not available
     # check_result([data], mod, expected, flatten=True)
 
+
 @tvm.testing.uses_gpu
 def test_any_argwhere():
     verify_any_argwhere(any_dims(1), (5,))
@@ -227,21 +243,23 @@ def test_any_argwhere():
     verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int8")
     verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int8")
 
+
 def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_shape):
     mod = tvm.IRModule()
-    data = relay.var('data', shape=data_shape, dtype='float32')
-    indices = relay.var('indices', shape=indices_shape, dtype='int32')
+    data = relay.var("data", shape=data_shape, dtype="float32")
+    indices = relay.var("indices", shape=indices_shape, dtype="int32")
     y = relay.take(data, indices, axis=axis)
     mod["main"] = relay.Function([data, indices], y)
-    data_np = np.random.uniform(size=data_np_shape).astype('float32')
+    data_np = np.random.uniform(size=data_np_shape).astype("float32")
     if axis is None:
         max_index = data_np.size
     else:
         max_index = data_np.shape[axis]
-    indices_np = np.random.randint(max_index, size=indices_np_shape).astype('int32')
+    indices_np = np.random.randint(max_index, size=indices_np_shape).astype("int32")
     ref = np.take(data_np, indices_np, axis=axis)
     check_result([data_np, indices_np], mod, ref)
 
+
 @tvm.testing.uses_gpu
 def test_any_take():
     verify_any_take(any_dims(2), (1,), 0, (4, 5), (1,))
@@ -251,6 +269,7 @@ def test_any_take():
     verify_any_take(any_dims(2), any_dims(3), None, (4, 5), (2, 3, 4))
     verify_any_take(any_dims(2), any_dims(4), -1, (4, 5), (2, 3, 4, 5))
 
+
 def verify_any_tile(dshape, reps, np_dshape, np_reps):
     mod = tvm.IRModule()
     x = relay.var("x", shape=dshape, dtype="float32")
@@ -260,6 +279,7 @@ def verify_any_tile(dshape, reps, np_dshape, np_reps):
     ref_res = np.tile(x_data, reps=np_reps)
     check_result([x_data], mod, ref_res)
 
+
 @tvm.testing.uses_gpu
 def test_any_tile():
     verify_any_tile(any_dims(3), (3, 2, 1), (2, 3, 4), (3, 2, 1))
@@ -267,33 +287,37 @@ def test_any_tile():
     verify_any_tile(any_dims(2), (3, 2, 1), (2, 3), (3, 2, 1))
     verify_any_tile(any_dims(3), (1,), (2, 3, 4), (1,))
 
+
 @tvm.testing.uses_gpu
 def test_any_shape_of():
-    x = relay.var('x', shape=any_dims(2), dtype='float32')
+    x = relay.var("x", shape=any_dims(2), dtype="float32")
     y = relay.shape_of(x)
     mod = tvm.IRModule()
     mod["main"] = relay.Function([x], y)
-    data = np.random.uniform(size=(3, 4)).astype('float32')
-    check_result([data], mod, np.array([3,4]).astype("int64"))
+    data = np.random.uniform(size=(3, 4)).astype("float32")
+    check_result([data], mod, np.array([3, 4]).astype("int64"))
 
-    x = relay.var('x', shape=any_dims(3), dtype='float32')
+    x = relay.var("x", shape=any_dims(3), dtype="float32")
     y0 = relay.shape_of(x)
-    y1 = relay.take(y0, relay.const(1, 'int32'))
+    y1 = relay.take(y0, relay.const(1, "int32"))
     mod = tvm.IRModule()
     mod["main"] = relay.Function([x], y1)
-    data = np.random.uniform(size=(2, 3, 4)).astype('float32')
+    data = np.random.uniform(size=(2, 3, 4)).astype("float32")
     check_result([data], mod, np.array(3).astype("int64"))
 
-def verify_any_reduce(reduce_op, data_shape, axis, exclude, keepdims,
-                      static_data_shape, ref_out_shape):
+
+def verify_any_reduce(
+    reduce_op, data_shape, axis, exclude, keepdims, static_data_shape, ref_out_shape
+):
     mod = tvm.IRModule()
     dtype = "bool" if reduce_op == relay.all else "float32"
-    data = relay.var('data', shape=data_shape, dtype=dtype)
+    data = relay.var("data", shape=data_shape, dtype=dtype)
     y = reduce_op(data, axis, keepdims, exclude)
     mod["main"] = relay.Function([data], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     check_result([data_np], mod, ref_out_shape, assert_shape=True)
 
+
 @tvm.testing.uses_gpu
 def test_any_reduce():
     verify_any_reduce(relay.argmax, any_dims(3), None, False, False, (3, 4, 5), ())
@@ -305,47 +329,57 @@ def test_any_reduce():
     verify_any_reduce(relay.mean, any_dims(2), 0, False, False, (1, 2), (2,))
     verify_any_reduce(relay.variance, any_dims(5), (2, 4), False, False, (3, 4, 5, 6, 7), (3, 4, 6))
 
-def verify_any_layout_transform(data_shape, src_layout, dst_layout, static_data_shape, ref_out_shape):
+
+def verify_any_layout_transform(
+    data_shape, src_layout, dst_layout, static_data_shape, ref_out_shape
+):
     mod = tvm.IRModule()
     dtype = "float32"
-    data = relay.var('data', shape=data_shape, dtype=dtype)
+    data = relay.var("data", shape=data_shape, dtype=dtype)
     y = relay.layout_transform(data, src_layout, dst_layout)
     mod["main"] = relay.Function([data], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     check_result([data_np], mod, ref_out_shape, assert_shape=True)
 
+
 @tvm.testing.uses_gpu
 def test_any_layout_transform():
     verify_any_layout_transform(any_dims(4), "NCHW", "NHWC", (3, 4, 5, 6), (3, 5, 6, 4))
-    verify_any_layout_transform(any_dims(5), "NCHW16c", "NCHW2c", (1, 2, 8, 8, 16), (1, 16, 8, 8, 2))
+    verify_any_layout_transform(
+        any_dims(5), "NCHW16c", "NCHW2c", (1, 2, 8, 8, 16), (1, 16, 8, 8, 2)
+    )
     verify_any_layout_transform(any_dims(5), "NCHW6n", "NHWC", (3, 4, 5, 6, 6), (18, 5, 6, 4))
     verify_any_layout_transform(any_dims(4), "NCHW", "NCHW4c", (3, 4, 5, 6), (3, 1, 5, 6, 4))
     verify_any_layout_transform((16, 1), "CH", "C4cH", (16, 1), (4, 4, 1))
 
+
 def verify_any_expand_dims(data_shape, axis, num_newaxis, static_data_shape, ref_out_shape):
     mod = tvm.IRModule()
     dtype = "float32"
-    data = relay.var('data', shape=data_shape, dtype=dtype)
+    data = relay.var("data", shape=data_shape, dtype=dtype)
     y = relay.expand_dims(data, axis=axis, num_newaxis=num_newaxis)
     mod["main"] = relay.Function([data], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     check_result([data_np], mod, ref_out_shape, assert_shape=True)
 
+
 @tvm.testing.uses_gpu
 def test_any_expand_dims():
     verify_any_expand_dims(any_dims(3), 1, 2, (1, 2, 3), (1, 1, 1, 2, 3))
     verify_any_expand_dims(any_dims(3), -1, 2, (1, 2, 3), (1, 2, 3, 1, 1))
 
+
 def verify_any_transpose(data_shape, axes, static_data_shape):
     mod = tvm.IRModule()
     dtype = "float32"
-    data = relay.var('data', shape=data_shape, dtype=dtype)
+    data = relay.var("data", shape=data_shape, dtype=dtype)
     y = relay.transpose(data, axes=axes)
     mod["main"] = relay.Function([data], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     ref_out = np.transpose(data_np, axes)
     check_result([data_np], mod, ref_out)
 
+
 @tvm.testing.uses_gpu
 def test_any_transpose():
     verify_any_transpose(any_dims(3), (1, 0, 2), (10, 3, 2))
@@ -353,101 +387,181 @@ def test_any_transpose():
     verify_any_transpose(any_dims(6), (0, 1, 3, 2, 5, 4), (11, 12, 2, 1, 9, 17))
     verify_any_transpose(any_dims(2), (-1, 0), (3, 2))
 
+
 def verify_any_squeeze(data_shape, axis, static_data_shape):
     mod = tvm.IRModule()
     dtype = "float32"
-    data = relay.var('data', shape=data_shape, dtype=dtype)
+    data = relay.var("data", shape=data_shape, dtype=dtype)
     y = relay.squeeze(data, axis=axis)
     mod["main"] = relay.Function([data], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     ref_out = np.squeeze(data_np, axis)
     check_result([data_np], mod, ref_out)
 
+
 @tvm.testing.uses_gpu
 def test_any_squeeze():
     verify_any_squeeze((1, relay.Any(), relay.Any()), (0,), (1, 9, 8))
-    verify_any_squeeze((1, relay.Any(), relay.Any(), 1, relay.Any(), relay.Any()), (0, 3), (1, 12, 2, 1, 9, 17))
+    verify_any_squeeze(
+        (1, relay.Any(), relay.Any(), 1, relay.Any(), relay.Any()), (0, 3), (1, 12, 2, 1, 9, 17)
+    )
+
 
 @tvm.testing.uses_gpu
 def test_any_reshape_like():
     mod = tvm.IRModule()
     dtype = "float32"
-    data = relay.var('data', shape=(relay.Any(), 3, 10), dtype=dtype)
-    shape_like = relay.var('data', shape=(relay.Any(), 5, 6), dtype=dtype)
+    data = relay.var("data", shape=(relay.Any(), 3, 10), dtype=dtype)
+    shape_like = relay.var("data", shape=(relay.Any(), 5, 6), dtype=dtype)
     y = relay.reshape_like(data, shape_like)
     mod["main"] = relay.Function([data, shape_like], y)
     data_np = np.random.uniform(size=(3, 3, 10)).astype(dtype)
     shape_like_np = np.random.uniform(size=(3, 5, 6)).astype(dtype)
     check_result([data_np, shape_like_np], mod, shape_like_np.shape, assert_shape=True)
 
-def verify_any_conv2d_NCHWc(data_shape, kernel_shape, strides, padding, dilation,
-                            data_layout, kernel_layout, out_layout,
-                            static_data_shape, ref_out_shape):
+
+def verify_any_conv2d_NCHWc(
+    data_shape,
+    kernel_shape,
+    strides,
+    padding,
+    dilation,
+    data_layout,
+    kernel_layout,
+    out_layout,
+    static_data_shape,
+    ref_out_shape,
+):
     mod = tvm.IRModule()
     dtype = "float32"
-    data = relay.var('data', shape=data_shape, dtype=dtype)
-    kernel = relay.var('kernel', shape=kernel_shape, dtype=dtype)
-    y = relay.nn.contrib_conv2d_nchwc(data, kernel, strides, padding, dilation,
-                                      kernel_size=kernel_shape[2:4],
-                                      channels=kernel_shape[0]*kernel_shape[-1],
-                                      data_layout=data_layout, kernel_layout=kernel_layout,
-                                      out_layout=out_layout)
+    data = relay.var("data", shape=data_shape, dtype=dtype)
+    kernel = relay.var("kernel", shape=kernel_shape, dtype=dtype)
+    y = relay.nn.contrib_conv2d_nchwc(
+        data,
+        kernel,
+        strides,
+        padding,
+        dilation,
+        kernel_size=kernel_shape[2:4],
+        channels=kernel_shape[0] * kernel_shape[-1],
+        data_layout=data_layout,
+        kernel_layout=kernel_layout,
+        out_layout=out_layout,
+    )
     mod["main"] = relay.Function([data, kernel], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     kernel_np = np.random.uniform(size=kernel_shape).astype(dtype)
     check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True)
 
+
 # TODO(@kevinthesun): Need to fix the compute in conv2d_NCHWc to support any
 @pytest.mark.skip
 def test_any_conv2d_NCHWc():
-    verify_any_conv2d_NCHWc((relay.Any(), 8, relay.Any(), relay.Any(), 8), (8, 8, 3, 3, 8, 8), (1, 1), (1, 1), (1, 1),
-                            "NCHW8c", "OIHW8i8o", "NCHW8c", (1, 8, 224, 224, 8), (1, 8, 224, 224, 8))
-    verify_any_conv2d_NCHWc((relay.Any(), 8, relay.Any(), relay.Any(), 8), (8, 8, 3, 3, 8, 8), (1, 1), (1, 1), (2, 2),
-                            "NCHW8c", "OIHW8i8o", "NCHW8c", (1, 8, 224, 224, 8), (1, 8, 222, 222, 8))
-
-def verify_any_pool2d(pool_type, data_shape, pool_size, strides, padding,
-                      layout, static_data_shape, ref_out_shape):
+    verify_any_conv2d_NCHWc(
+        (relay.Any(), 8, relay.Any(), relay.Any(), 8),
+        (8, 8, 3, 3, 8, 8),
+        (1, 1),
+        (1, 1),
+        (1, 1),
+        "NCHW8c",
+        "OIHW8i8o",
+        "NCHW8c",
+        (1, 8, 224, 224, 8),
+        (1, 8, 224, 224, 8),
+    )
+    verify_any_conv2d_NCHWc(
+        (relay.Any(), 8, relay.Any(), relay.Any(), 8),
+        (8, 8, 3, 3, 8, 8),
+        (1, 1),
+        (1, 1),
+        (2, 2),
+        "NCHW8c",
+        "OIHW8i8o",
+        "NCHW8c",
+        (1, 8, 224, 224, 8),
+        (1, 8, 222, 222, 8),
+    )
+
+
+def verify_any_pool2d(
+    pool_type, data_shape, pool_size, strides, padding, layout, static_data_shape, ref_out_shape
+):
     mod = tvm.IRModule()
     dtype = "float32"
     pool_func = relay.nn.max_pool2d if pool_type == "max" else relay.nn.avg_pool2d
-    data = relay.var('data', shape=data_shape, dtype=dtype)
+    data = relay.var("data", shape=data_shape, dtype=dtype)
     y = pool_func(data, pool_size, strides, padding, layout)
     mod["main"] = relay.Function([data], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     check_result([data_np], mod, ref_out_shape, assert_shape=True)
 
+
 @tvm.testing.uses_gpu
 def test_any_pool2d():
-    verify_any_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any()),
-                      (3, 3), (1, 1), (1, 1), "NCHW", (2, 3, 220, 220), (2, 3, 220, 220))
-    verify_any_pool2d("avg", (relay.Any(), relay.Any(), relay.Any(), 4),
-                      (1, 1), (2, 2), (0, 0), "NHWC", (3, 220, 220, 4), (3, 110, 110, 4))
-    verify_any_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any(), 4),
-                      (3, 3), (2, 2), (1, 1), "NCHW4c", (2, 3, 220, 220, 4), (2, 3, 110, 110, 4))
+    verify_any_pool2d(
+        "max",
+        (relay.Any(), 3, relay.Any(), relay.Any()),
+        (3, 3),
+        (1, 1),
+        (1, 1),
+        "NCHW",
+        (2, 3, 220, 220),
+        (2, 3, 220, 220),
+    )
+    verify_any_pool2d(
+        "avg",
+        (relay.Any(), relay.Any(), relay.Any(), 4),
+        (1, 1),
+        (2, 2),
+        (0, 0),
+        "NHWC",
+        (3, 220, 220, 4),
+        (3, 110, 110, 4),
+    )
+    verify_any_pool2d(
+        "max",
+        (relay.Any(), 3, relay.Any(), relay.Any(), 4),
+        (3, 3),
+        (2, 2),
+        (1, 1),
+        "NCHW4c",
+        (2, 3, 220, 220, 4),
+        (2, 3, 110, 110, 4),
+    )
+
 
 def verify_any_global_pool2d(pool_type, data_shape, layout, static_data_shape, ref_out_shape):
     mod = tvm.IRModule()
     dtype = "float32"
     pool_func = relay.nn.global_max_pool2d if pool_type == "max" else relay.nn.global_avg_pool2d
-    data = relay.var('data', shape=data_shape, dtype=dtype)
+    data = relay.var("data", shape=data_shape, dtype=dtype)
     y = pool_func(data, layout)
     mod["main"] = relay.Function([data], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     check_result([data_np], mod, ref_out_shape, assert_shape=True)
 
+
 @tvm.testing.uses_gpu
 def test_any_global_pool2d():
-    verify_any_global_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any()),
-                      "NCHW", (2, 3, 220, 220), (2, 3, 1, 1))
-    verify_any_global_pool2d("avg", (relay.Any(), relay.Any(), relay.Any(), 4),
-                      "NHWC", (3, 220, 220, 4), (3, 1, 1, 4))
-    verify_any_global_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any(), 4),
-                      "NCHW4c", (2, 3, 220, 220, 4), (2, 3, 1, 1, 4))
+    verify_any_global_pool2d(
+        "max", (relay.Any(), 3, relay.Any(), relay.Any()), "NCHW", (2, 3, 220, 220), (2, 3, 1, 1)
+    )
+    verify_any_global_pool2d(
+        "avg", (relay.Any(), relay.Any(), relay.Any(), 4), "NHWC", (3, 220, 220, 4), (3, 1, 1, 4)
+    )
+    verify_any_global_pool2d(
+        "max",
+        (relay.Any(), 3, relay.Any(), relay.Any(), 4),
+        "NCHW4c",
+        (2, 3, 220, 220, 4),
+        (2, 3, 1, 1, 4),
+    )
+
 
 def verify_any_split(data_shape, indices_or_sections, axis, static_data_shape, ref_out_shape):
     mod = tvm.IRModule()
     dtype = "float32"
-    data = relay.var('data', shape=data_shape, dtype=dtype)
+    data = relay.var("data", shape=data_shape, dtype=dtype)
     y = relay.split(data, indices_or_sections, axis)
     mod["main"] = relay.Function([data], y.astuple())
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
@@ -455,8 +569,11 @@ def verify_any_split(data_shape, indices_or_sections, axis, static_data_shape, r
         ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
         result = ex.evaluate()(data_np)
         for ret, ref_ret in zip(result, ref_out_shape):
-            assert ret.asnumpy().shape == ref_ret, \
-                "Shape mismatch: expect %s but got %s." % (str(ref_ret), str(ret.asnumpy().shape))
+            assert ret.asnumpy().shape == ref_ret, "Shape mismatch: expect %s but got %s." % (
+                str(ref_ret),
+                str(ret.asnumpy().shape),
+            )
+
 
 @tvm.testing.uses_gpu
 def test_any_split():
@@ -465,65 +582,74 @@ def test_any_split():
     verify_any_split((relay.Any(), 12), (1, 4, 8), 1, (7, 12), [(7, 1), (7, 3), (7, 4)])
     verify_any_split((relay.Any(), relay.Any()), (1, 4, 8), 1, (7, 12), [(7, 1), (7, 3), (7, 4)])
 
+
 @tvm.testing.uses_gpu
 def test_any_batch_flatten():
     mod = tvm.IRModule()
     dtype = "float32"
-    data = relay.var('data', shape=any_dims(3), dtype=dtype)
+    data = relay.var("data", shape=any_dims(3), dtype=dtype)
     y = relay.nn.batch_flatten(data)
     mod["main"] = relay.Function([data], y)
     data_np = np.random.uniform(size=(3, 3, 10)).astype(dtype)
     ref_out_shape = (3, 30)
     check_result([data_np], mod, ref_out_shape, assert_shape=True)
 
-def verify_any_dense(data_shape, weight_shape, units, static_data_shape,
-                     static_weight_shape, ref_out_shape):
+
+def verify_any_dense(
+    data_shape, weight_shape, units, static_data_shape, static_weight_shape, ref_out_shape
+):
     mod = tvm.IRModule()
     dtype = "float32"
-    data = relay.var('data', shape=data_shape, dtype=dtype)
-    weight = relay.var('weight', shape=weight_shape, dtype=dtype)
+    data = relay.var("data", shape=data_shape, dtype=dtype)
+    weight = relay.var("weight", shape=weight_shape, dtype=dtype)
     y = relay.nn.dense(data, weight, units)
     mod["main"] = relay.Function([data, weight], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     weight_np = np.random.uniform(size=static_weight_shape).astype(dtype)
     check_result([data_np, weight_np], mod, ref_out_shape, assert_shape=True)
 
+
 # TODO(tvm-team) Fix dense schedule
 # @tvm.testing.uses_gpu
 def test_any_dense():
     verify_any_dense(any_dims(2), any_dims(2), None, (4, 16), (8, 16), (4, 8))
     verify_any_dense(any_dims(2), (50, relay.Any()), 50, (4, 40), (50, 40), (4, 50))
 
+
 @tvm.testing.uses_gpu
 def verify_any_pad(data_shape, pad_width, static_data_shape):
     mod = tvm.IRModule()
     dtype = "float32"
-    data = relay.var('data', shape=data_shape, dtype=dtype)
+    data = relay.var("data", shape=data_shape, dtype=dtype)
     y = relay.nn.pad(data, pad_width)
     mod["main"] = relay.Function([data], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     ref_out = np.pad(data_np, pad_width)
     check_result([data_np], mod, ref_out)
 
+
 @tvm.testing.uses_gpu
 def test_any_pad():
     verify_any_pad(any_dims(3), ((0, 0), (1, 1), (2, 2)), (1, 2, 3))
     verify_any_pad(any_dims(4), ((1, 0), (1, 3), (0, 2), (9, 0)), (13, 11, 3, 1))
 
+
 def verify_any_dilate(data_shape, strides, static_data_shape):
     assert len(data_shape) == len(strides)
     mod = tvm.IRModule()
     dtype = "float32"
-    data = relay.var('data', shape=data_shape, dtype=dtype)
+    data = relay.var("data", shape=data_shape, dtype=dtype)
     y = relay.nn.dilate(data, strides)
     mod["main"] = relay.Function([data], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
-    ref_shape = tuple((static_data_shape[i] - 1) * strides[i] + 1
-                      for i in range(len(static_data_shape)))
+    ref_shape = tuple(
+        (static_data_shape[i] - 1) * strides[i] + 1 for i in range(len(static_data_shape))
+    )
     ref_out = np.zeros(shape=ref_shape, dtype=dtype)
     ref_out[tuple(slice(None, None, strides[i]) for i in range(len(data_shape)))] = data_np
     check_result([data_np], mod, ref_out)
 
+
 @tvm.testing.uses_gpu
 def test_any_dilate():
     verify_any_dilate(any_dims(1), (1,), (1,))
@@ -535,30 +661,33 @@ def test_any_dilate():
     verify_any_dilate(any_dims(3), (3, 7, 5), (1, 2, 3))
     verify_any_dilate(any_dims(4), (3, 7, 1, 5), (1, 2, 3, 4))
 
+
 def verify_any_softmax(data_shape, axis, static_data_shape, ref_out_shape):
     mod = tvm.IRModule()
     dtype = "float32"
-    data = relay.var('data', shape=data_shape, dtype=dtype)
+    data = relay.var("data", shape=data_shape, dtype=dtype)
     y = relay.nn.softmax(data, axis)
     mod["main"] = relay.Function([data], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     check_result([data_np], mod, ref_out_shape, assert_shape=True)
 
+
 @tvm.testing.uses_gpu
 def test_any_softmax():
     verify_any_softmax(any_dims(3), -1, (1, 2, 3), (1, 2, 3))
     verify_any_softmax(any_dims(4), 2, (13, 11, 3, 1), (13, 11, 3, 1))
 
+
 def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False):
     mod = tvm.IRModule()
-    data = relay.var('data', shape=data_shape, dtype=dtype)
+    data = relay.var("data", shape=data_shape, dtype=dtype)
     np_data = np.random.uniform(size=np_dshape).astype(dtype)
     if const_k:
         k = relay.const(kval)
         args = [data]
         in_vals = [np_data]
     else:
-        k = relay.var('k', shape=(), dtype="int32")
+        k = relay.var("k", shape=(), dtype="int32")
         args = [data, k]
         in_vals = [np_data, kval]
     out = relay.topk(data, k, ret_type="indices")
@@ -578,68 +707,83 @@ def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False):
     # TODO(@zhiics) Fix topk cuda schedule for dynamic inputs
     # check_result(in_vals, mod, ref_out)
 
+
 def test_any_topk():
     verify_any_topk(any_dims(1), 5, (10,), "float32")
     verify_any_topk(any_dims(2), 2, (6, 3), "int32")
     verify_any_topk(any_dims(2), 3, (6, 3), "float32", True)
 
+
 @tvm.testing.uses_gpu
 def test_fused_ops():
-    x = relay.var('x', shape=(relay.Any(), relay.Any()), dtype='float32')
-    y0 = x + relay.const(1.0, 'float32')
-    y1 = y0 * relay.const(2.0, 'float32')
+    x = relay.var("x", shape=(relay.Any(), relay.Any()), dtype="float32")
+    y0 = x + relay.const(1.0, "float32")
+    y1 = y0 * relay.const(2.0, "float32")
     mod = tvm.IRModule()
     mod["main"] = relay.Function([x], y1)
-    data = np.random.uniform(size=(5, 4)).astype('float32')
+    data = np.random.uniform(size=(5, 4)).astype("float32")
     check_result([data], mod, (data + 1) * 2)
 
+
 @tvm.testing.uses_gpu
 def test_arange_with_dynamic_shape():
     # m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k')
     m, n, k = relay.Any(), relay.Any(), relay.Any()
-    x = relay.var('x', shape=(m, n, k), dtype='float32')
+    x = relay.var("x", shape=(m, n, k), dtype="float32")
     y0 = relay.shape_of(x)
-    y1 = relay.take(y0, relay.const(0, 'int32'))
+    y1 = relay.take(y0, relay.const(0, "int32"))
     y2 = relay.op.arange(y1, dtype="int32")
     y3 = y2 + relay.const(1, dtype="int32")
-    data = np.random.rand(10, 5, 3).astype('float32')
+    data = np.random.rand(10, 5, 3).astype("float32")
     mod = tvm.IRModule()
     mod["main"] = relay.Function([x], y3)
-    check_result([data], mod, np.array(range(10)).astype("int32")+1)
-
-def verify_any_strided_slice(data_shape, begin_shape, end_shape, strides_shape,
-                             data_np_shape, slice_mode="end", const_attrs=False):
+    check_result([data], mod, np.array(range(10)).astype("int32") + 1)
+
+
+def verify_any_strided_slice(
+    data_shape,
+    begin_shape,
+    end_shape,
+    strides_shape,
+    data_np_shape,
+    slice_mode="end",
+    const_attrs=False,
+):
     # Generate random numpy input data
-    np_data = np.random.uniform(size=data_np_shape).astype('float32')
+    np_data = np.random.uniform(size=data_np_shape).astype("float32")
     np_begin = np.random.randint(2, size=begin_shape, dtype="int32")
     np_end = np.random.randint(5, 10, size=end_shape, dtype="int32")
-    np_strides = np.random.randint(1, 2 if slice_mode == "size" else 3, size=strides_shape, dtype="int32")
+    np_strides = np.random.randint(
+        1, 2 if slice_mode == "size" else 3, size=strides_shape, dtype="int32"
+    )
     # target numpy result
-    ref_res = tvm.topi.testing.strided_slice_python(np_data, np_begin, np_end, np_strides, slice_mode)
+    ref_res = tvm.topi.testing.strided_slice_python(
+        np_data, np_begin, np_end, np_strides, slice_mode
+    )
 
     # Relay Module
     mod = tvm.IRModule()
-    data = relay.var('data', shape=data_shape, dtype='float32')
+    data = relay.var("data", shape=data_shape, dtype="float32")
     if const_attrs:
-        data = relay.var('data', shape=data_np_shape, dtype='float32')
+        data = relay.var("data", shape=data_np_shape, dtype="float32")
         begin = relay.const(np_begin)
         end = relay.const(np_end)
         strides = relay.const(np_strides)
         args = [data]
         np_inputs = [np_data]
     else:
-        begin = relay.var('begin', shape=begin_shape, dtype="int32")
-        end = relay.var('end', shape=end_shape, dtype="int32")
-        strides = relay.var('strides', shape=strides_shape, dtype="int32")
+        begin = relay.var("begin", shape=begin_shape, dtype="int32")
+        end = relay.var("end", shape=end_shape, dtype="int32")
+        strides = relay.var("strides", shape=strides_shape, dtype="int32")
         args = [data, begin, end, strides]
         np_inputs = [np_data, np_begin, np_end, np_strides]
 
-    y = relay.strided_slice(data, begin=begin, end=end,
-                            strides=strides, slice_mode=slice_mode)
+    y = relay.strided_slice(data, begin=begin, end=end, strides=strides, slice_mode=slice_mode)
     mod["main"] = relay.Function(args, y)
 
     check_result(np_inputs, mod, ref_res)
 
+
 @tvm.testing.uses_gpu
 def test_any_strided_slice():
     verify_any_strided_slice(any_dims(2), (2,), (2,), (2,), (15, 21))
@@ -649,6 +793,7 @@ def test_any_strided_slice():
     verify_any_strided_slice(any_dims(3), (3,), (3,), (3,), (15, 17, 21), slice_mode="size")
     verify_any_strided_slice(any_dims(2), (2,), (2,), (2,), (15, 21), const_attrs=True)
 
+
 @tvm.testing.uses_gpu
 def test_recursive_concat():
     """
@@ -663,27 +808,28 @@ def test_recursive_concat():
     }
     """
     # Initial Values.
-    i = relay.var('i', shape=(), dtype='int32')
-    st = relay.var('st', shape=(relay.Any(), 1), dtype='int32')
+    i = relay.var("i", shape=(), dtype="int32")
+    st = relay.var("st", shape=(relay.Any(), 1), dtype="int32")
 
     def _cond(i, st):
         return relay.op.min(relay.op.less(i, int32(10)))
 
     def _body(i, st):
-        i_vec = relay.op.reshape(i, (1,1))
+        i_vec = relay.op.reshape(i, (1, 1))
         ret = relay.op.concatenate([st, i_vec], axis=0)
         return i + int32(1), ret
 
     loop = while_loop(_cond, [i, st], _body)
-    start = relay.var('start', shape=(), dtype='int32')
+    start = relay.var("start", shape=(), dtype="int32")
     body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
     func = relay.Function([start], relay.TupleGetItem(body, 1))
     mod = tvm.IRModule()
     mod["main"] = func
-    data = np.array(0.0, dtype='int32')
+    data = np.array(0.0, dtype="int32")
     ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32")
     check_result([data], mod, ref)
 
+
 @tvm.testing.uses_gpu
 def test_recursive_concat_with_wrong_annotation():
     """
@@ -711,19 +857,19 @@ def test_recursive_concat_with_wrong_annotation():
     }
     """
     # Initial Values.
-    i = relay.var('i', shape=(), dtype='int32')
-    st = relay.var('st', shape=(1, 1), dtype='int32')
+    i = relay.var("i", shape=(), dtype="int32")
+    st = relay.var("st", shape=(1, 1), dtype="int32")
 
     def _cond(i, st):
         return relay.op.min(relay.op.less(i, int32(10)))
 
     def _body(i, st):
-        i_vec = relay.op.reshape(i, (1,1))
+        i_vec = relay.op.reshape(i, (1, 1))
         ret = relay.op.concatenate([st, i_vec], axis=0)
         return i + int32(1), ret
 
     loop = while_loop(_cond, [i, st], _body)
-    start = relay.var('start', shape=(), dtype='int32')
+    start = relay.var("start", shape=(), dtype="int32")
     body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
     func = relay.Function([start], relay.TupleGetItem(body, 1))
     try:
@@ -732,6 +878,7 @@ def test_recursive_concat_with_wrong_annotation():
     except Exception as e:
         assert "in particular dimension 0 conflicts 2 does not match 1" in str(e)
 
+
 @tvm.testing.uses_gpu
 def test_tuple_get_item():
     mod = tvm.IRModule()
@@ -740,7 +887,7 @@ def test_tuple_get_item():
     data_shape = (relay.Any(), 4)
     indices_or_sections = 2
     axis = 1
-    data = relay.var('data', shape=data_shape, dtype=dtype)
+    data = relay.var("data", shape=data_shape, dtype=dtype)
     y = relay.split(data, indices_or_sections, axis)
     y = relay.expr.TupleGetItem(y.astuple(), 0)
     mod["main"] = relay.Function([data], y)
@@ -748,6 +895,7 @@ def test_tuple_get_item():
     ref_out_shape = (9, 2)
     check_result([data_np], mod, ref_out_shape, assert_shape=True)
 
+
 @tvm.testing.uses_gpu
 def test_mixed_input_type():
     mod = tvm.IRModule()
@@ -765,24 +913,39 @@ def test_mixed_input_type():
     data_np0 = np.random.uniform(size=static_data_shape).astype(dtype)
     data_np1 = np.random.uniform(size=static_data_shape).astype(dtype)
     ref_out_shape = (9, 4)
-    check_result([[[data_np0, data_np0], data_np0], data_np1], mod,
-                 ref_out_shape, assert_shape=True, only_vm=True)
-
-def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, crop_size,
-                               layout, static_boxes, static_box_indices_shape, ref_out_shape):
+    check_result(
+        [[[data_np0, data_np0], data_np0], data_np1],
+        mod,
+        ref_out_shape,
+        assert_shape=True,
+        only_vm=True,
+    )
+
+
+def verify_any_crop_and_resize(
+    data_shape,
+    boxes_shape,
+    box_indices_shape,
+    crop_size,
+    layout,
+    static_boxes,
+    static_box_indices_shape,
+    ref_out_shape,
+):
     mod = tvm.IRModule()
     dtype = "float32"
     indices_dtype = "int32"
-    data = relay.var('data', shape=data_shape, dtype=dtype)
-    boxes = relay.var('boxes', shape=boxes_shape, dtype=dtype)
-    box_indices = relay.var('box_indices', shape=box_indices_shape, dtype=indices_dtype)
+    data = relay.var("data", shape=data_shape, dtype=dtype)
+    boxes = relay.var("boxes", shape=boxes_shape, dtype=dtype)
+    box_indices = relay.var("box_indices", shape=box_indices_shape, dtype=indices_dtype)
     y = relay.image.crop_and_resize(data, boxes, box_indices, crop_size, layout)
     mod["main"] = relay.Function([data, boxes, box_indices], y)
     data_np = np.random.uniform(size=data_shape).astype(dtype)
     boxes_np = np.random.uniform(size=static_boxes).astype(dtype)
-    box_indices_np = np.random.uniform(size=static_box_indices_shape).astype(indices_dtype)    
+    box_indices_np = np.random.uniform(size=static_box_indices_shape).astype(indices_dtype)
     check_result([data_np, boxes_np, box_indices_np], mod, ref_out_shape, assert_shape=True)
 
+
 @tvm.testing.uses_gpu
 def test_any_crop_and_resize():
     verify_any_crop_and_resize(
@@ -790,55 +953,62 @@ def test_any_crop_and_resize():
         boxes_shape=(relay.Any(), 4),
         box_indices_shape=(relay.Any(),),
         crop_size=(14, 14),
-        layout='NHWC',
+        layout="NHWC",
         static_boxes=(128, 4),
         static_box_indices_shape=(128,),
-        ref_out_shape=(128, 14, 14, 256))
+        ref_out_shape=(128, 14, 14, 256),
+    )
     verify_any_crop_and_resize(
         data_shape=(1, 256, 234, 234),
         boxes_shape=(relay.Any(), 4),
         box_indices_shape=(relay.Any(),),
         crop_size=(14, 14),
-        layout='NCHW',
+        layout="NCHW",
         static_boxes=(128, 4),
         static_box_indices_shape=(128,),
-        ref_out_shape=(128, 256, 14, 14)
-        )
+        ref_out_shape=(128, 256, 14, 14),
+    )
+
 
 def verify_any_mirror_pad(data_shape, pad_width, static_data_shape, ref_out_shape):
     mod = tvm.IRModule()
     dtype = "float32"
-    data = relay.var('data', shape=data_shape, dtype=dtype)
+    data = relay.var("data", shape=data_shape, dtype=dtype)
     y = relay.nn.mirror_pad(data, pad_width)
     mod["main"] = relay.Function([data], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     check_result([data_np], mod, ref_out_shape, assert_shape=True)
 
+
 @tvm.testing.uses_gpu
 def test_any_mirror_pad():
     verify_any_mirror_pad(
         data_shape=(1, 256, 232, 232),
         pad_width=((0, 0), (0, 0), (1, 1), (1, 1)),
         static_data_shape=(1, 256, 232, 232),
-        ref_out_shape=(1, 256, 234, 234))
+        ref_out_shape=(1, 256, 234, 234),
+    )
+
 
 def verify_any_ndarray_size(data_np_shape):
-    v = relay.var("v", shape=any_dims(len(data_np_shape)), dtype='float32')
-    n = relay.ndarray_size(v, dtype='int32')
+    v = relay.var("v", shape=any_dims(len(data_np_shape)), dtype="float32")
+    n = relay.ndarray_size(v, dtype="int32")
     mod = tvm.IRModule()
-    mod['main'] = relay.Function([v], n)
-    np_data = np.zeros(data_np_shape, dtype='float32')
+    mod["main"] = relay.Function([v], n)
+    np_data = np.zeros(data_np_shape, dtype="float32")
     ref_res = np.size(np_data)
     check_result([np_data], mod, ref_res)
 
+
 @tvm.testing.uses_gpu
 def test_any_ndarray_size():
     verify_any_ndarray_size((2,))
     verify_any_ndarray_size((2, 2))
     verify_any_ndarray_size((1, 2, 3, 4))
 
+
 def test_any_consecutive_broadcast():
-    dtype = 'float32'
+    dtype = "float32"
     data0 = relay.var("data0", shape=any_dims(2), dtype=dtype)
     data1 = relay.var("data1", shape=any_dims(2), dtype=dtype)
     data2 = relay.var("data2", shape=any_dims(2), dtype=dtype)
@@ -855,23 +1025,25 @@ def test_any_consecutive_broadcast():
     out6 = out2 * out5
 
     mod = tvm.IRModule()
-    mod['main'] = relay.Function([data0, data1, data2, data3], out6)
+    mod["main"] = relay.Function([data0, data1, data2, data3], out6)
 
     np_data0 = np.random.uniform(size=(1, 4)).astype(dtype)
     np_data1 = np.random.uniform(size=(2, 4)).astype(dtype)
     np_data2 = np.random.uniform(size=(1, 4)).astype(dtype)
     np_data3 = np.random.uniform(size=(2, 4)).astype(dtype)
-    ref_res = ((np_data0 + np_data1) - (np_data0 * np_data1)) * \
-              ((np_data2 + np_data3) - (np_data2 * np_data3))
+    ref_res = ((np_data0 + np_data1) - (np_data0 * np_data1)) * (
+        (np_data2 + np_data3) - (np_data2 * np_data3)
+    )
     check_result([np_data0, np_data1, np_data2, np_data3], mod, ref_res)
 
+
 def test_reshape_concat():
     dtype = "float32"
     d0 = relay.var("d0", shape=any_dims(2), dtype=dtype)
     d1 = relay.var("d1", shape=any_dims(3), dtype=dtype)
     out = relay.op.concatenate([relay.op.reshape(d0, [-1]), relay.op.reshape(d1, [-1])], axis=0)
     mod = tvm.IRModule()
-    mod['main'] = relay.Function([d0, d1], out)
+    mod["main"] = relay.Function([d0, d1], out)
     np_data0 = np.random.uniform(size=(4, 5)).astype(dtype)
     np_data1 = np.random.uniform(size=(2, 5, 2)).astype(dtype)
     ref_res = np.concatenate([np.reshape(np_data0, [-1]), np.reshape(np_data1, [-1])], axis=0)
@@ -881,28 +1053,33 @@ def test_reshape_concat():
     d1 = relay.var("d1", shape=any_dims(2), dtype=dtype)
     s0 = relay.var("s0", shape=any_dims(3), dtype=dtype)
     s1 = relay.var("s1", shape=any_dims(3), dtype=dtype)
-    out = relay.op.concatenate([relay.op.reshape_like(d0, s0), relay.op.reshape_like(d1, s1)], axis=0)
+    out = relay.op.concatenate(
+        [relay.op.reshape_like(d0, s0), relay.op.reshape_like(d1, s1)], axis=0
+    )
     mod = tvm.IRModule()
-    mod['main'] = relay.Function([d0, d1, s0, s1], out)
+    mod["main"] = relay.Function([d0, d1, s0, s1], out)
     np_data0 = np.random.uniform(size=(4, 5)).astype(dtype)
     np_data1 = np.random.uniform(size=(8, 5)).astype(dtype)
     np_shape_like0 = np.random.uniform(size=(2, 2, 5)).astype(dtype)
     np_shape_like1 = np.random.uniform(size=(4, 2, 5)).astype(dtype)
-    ref_res = np.concatenate([np.reshape(np_data0, np_shape_like0.shape),
-                              np.reshape(np_data1, np_shape_like1.shape)], axis=0)
+    ref_res = np.concatenate(
+        [np.reshape(np_data0, np_shape_like0.shape), np.reshape(np_data1, np_shape_like1.shape)],
+        axis=0,
+    )
     check_result([np_data0, np_data1, np_shape_like0, np_shape_like1], mod, ref_res)
 
+
 def test_any_adv_index():
-    data = relay.var("data", shape=(5, relay.Any(), relay.Any()), dtype='float32')
-    index0 = relay.var("index0", shape=(1, relay.Any()), dtype='int64')
-    index1 = relay.var("index1", shape=(1, relay.Any()), dtype='int64')
+    data = relay.var("data", shape=(5, relay.Any(), relay.Any()), dtype="float32")
+    index0 = relay.var("index0", shape=(1, relay.Any()), dtype="int64")
+    index1 = relay.var("index1", shape=(1, relay.Any()), dtype="int64")
     out = relay.adv_index([data, index0, index1])
     mod = tvm.IRModule()
-    mod['main'] = relay.Function([data, index0, index1], out)
+    mod["main"] = relay.Function([data, index0, index1], out)
     np_data_shape = (5, 5, 10)
     np_index_shape = (1, 4)
-    np_data = np.random.uniform(size=np_data_shape).astype('float32')
-    np_index = np.random.uniform(0, np_data_shape[0], size=np_index_shape).astype('int64')
+    np_data = np.random.uniform(size=np_data_shape).astype("float32")
+    np_index = np.random.uniform(0, np_data_shape[0], size=np_index_shape).astype("int64")
     ref_res = np_data[tuple([np_index, np_index])]
     check_result([np_data, np_index, np_index], mod, ref_res)
 
index f0208c2..da71ac3 100644 (file)
@@ -19,17 +19,18 @@ import tvm.relay.testing
 from tvm import relay
 from tvm import autotvm
 
+
 def get_network(name, batch_size):
     """Get the symbol definition and random weight of a network"""
     input_shape = (batch_size, 3, 224, 224)
 
-    if name == 'resnet-18':
+    if name == "resnet-18":
         mod, params = relay.testing.resnet.get_workload(num_layers=18, batch_size=batch_size)
-    elif name == 'resnet3d-18':
+    elif name == "resnet3d-18":
         mod, params = relay.testing.resnet_3d.get_workload(num_layers=18, batch_size=batch_size)
-    elif name == 'mobilenet':
+    elif name == "mobilenet":
         mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size)
-    elif name == 'dcgan':
+    elif name == "dcgan":
         mod, params = relay.testing.dcgan.get_workload(batch_size=batch_size)
         input_shape = (batch_size, 100)
     else:
@@ -37,8 +38,9 @@ def get_network(name, batch_size):
 
     return mod, params, input_shape
 
+
 def test_task_extraction():
-    target = 'llvm'
+    target = "llvm"
     mod_list = []
     params_list = []
     conv2d = relay.op.get("nn.conv2d")
@@ -46,65 +48,59 @@ def test_task_extraction():
     conv2d_transpose = relay.op.get("nn.conv2d_transpose")
     dense = relay.op.get("nn.dense")
 
-    mod, params, _ = get_network('resnet-18', batch_size=1)
-    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
-                                              params=params,
-                                              ops=(conv2d,))
+    mod, params, _ = get_network("resnet-18", batch_size=1)
+    tasks = autotvm.task.extract_from_program(
+        mod["main"], target=target, params=params, ops=(conv2d,)
+    )
     assert len(tasks) == 12
-    tasks = autotvm.task.extract_from_program(mod, target=target,
-                                              params=params,
-                                              ops=(conv2d,))
+    tasks = autotvm.task.extract_from_program(mod, target=target, params=params, ops=(conv2d,))
     assert len(tasks) == 12
 
-    mod, params, _ = get_network('resnet-18', batch_size=1)
-    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
-                                              params=params,
-                                              ops=(dense,))
+    mod, params, _ = get_network("resnet-18", batch_size=1)
+    tasks = autotvm.task.extract_from_program(
+        mod["main"], target=target, params=params, ops=(dense,)
+    )
     assert len(tasks) == 1
-    tasks = autotvm.task.extract_from_program(mod, target=target,
-                                              params=params,
-                                              ops=(dense,))
+    tasks = autotvm.task.extract_from_program(mod, target=target, params=params, ops=(dense,))
     assert len(tasks) == 1
 
-    mod, params, _ = get_network('resnet-18', batch_size=1)
+    mod, params, _ = get_network("resnet-18", batch_size=1)
     mod_list.append(mod)
     params_list.append(params)
-    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
-                                              params=params,
-                                              ops=(conv2d, dense))
+    tasks = autotvm.task.extract_from_program(
+        mod["main"], target=target, params=params, ops=(conv2d, dense)
+    )
     assert len(tasks) == 13
-    tasks = autotvm.task.extract_from_program(mod, target=target,
-                                              params=params,
-                                              ops=(conv2d, dense))
+    tasks = autotvm.task.extract_from_program(
+        mod, target=target, params=params, ops=(conv2d, dense)
+    )
     assert len(tasks) == 13
-    tasks = autotvm.task.extract_from_program(mod, target=target,
-                                              params=params)
+    tasks = autotvm.task.extract_from_program(mod, target=target, params=params)
     assert len(tasks) == 13
 
-    mod, params, _ = get_network('resnet3d-18', batch_size=1)
-    tasks = autotvm.task.extract_from_program(mod, target=target,
-                                              params=params,
-                                              ops=(conv3d,))
+    mod, params, _ = get_network("resnet3d-18", batch_size=1)
+    tasks = autotvm.task.extract_from_program(mod, target=target, params=params, ops=(conv3d,))
     assert len(tasks) == 12
 
-    mod, params, _ = get_network('mobilenet', batch_size=1)
+    mod, params, _ = get_network("mobilenet", batch_size=1)
     mod_list.append(mod)
     params_list.append(params)
-    tasks = autotvm.task.extract_from_program(mod, target=target,
-                                              params=params,
-                                              ops=(conv2d, dense))
+    tasks = autotvm.task.extract_from_program(
+        mod, target=target, params=params, ops=(conv2d, dense)
+    )
     assert len(tasks) == 20
 
-    mod, params, _ = get_network('dcgan', batch_size=1)
-    tasks = autotvm.task.extract_from_program(mod, target=target,
-                                              params=params,
-                                              ops=(conv2d_transpose,))
+    mod, params, _ = get_network("dcgan", batch_size=1)
+    tasks = autotvm.task.extract_from_program(
+        mod, target=target, params=params, ops=(conv2d_transpose,)
+    )
     assert len(tasks) == 4
 
-    tasks = autotvm.task.extract_from_multiple_program(mod_list, params_list,
-                                                       target=target,
-                                                       ops=(conv2d,))
+    tasks = autotvm.task.extract_from_multiple_program(
+        mod_list, params_list, target=target, ops=(conv2d,)
+    )
     assert len(tasks) == 31
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     test_task_extraction()
index 0b0fd58..bf53dc5 100644 (file)
@@ -30,24 +30,30 @@ import tvm.testing
 def _compute_conv2d_1(cfg, input, filter, strides, padding, dilation, out_dtype):
     return topi.nn.conv2d_nchw(input, filter, strides, padding, dilation, out_dtype)
 
+
 @autotvm.register_topi_schedule("test/conv2d_1")
 def _schedule_conv2d_1(cfg, outs):
     return topi.generic.schedule_conv2d_nchw(outs)
 
+
 @autotvm.register_topi_compute("test/conv2d_2")
 def _compute_conv2d_2(cfg, input, filter, strides, padding, dilation, out_dtype):
     return topi.nn.conv2d_nchw(input, filter, strides, padding, dilation, out_dtype)
 
+
 @autotvm.register_topi_schedule("test/conv2d_2")
 def _schedule_conv2d_2(cfg, outs):
     return topi.generic.schedule_conv2d_nchw(outs)
 
+
 def _compute_conv2d_3(input, filter, strides, padding, dilation, out_dtype):
     return topi.nn.conv2d_nchw(input, filter, strides, padding, dilation, out_dtype)
 
+
 def _schedule_conv2d_3(outs):
     return topi.generic.schedule_conv2d_nchw(outs)
 
+
 @tvm.target.override_native_generic_func("test_conv2d_strategy")
 def _tmp_strategy(attrs, inputs, out_type, target):
     strategy = relay.op.OpStrategy()
@@ -55,24 +61,27 @@ def _tmp_strategy(attrs, inputs, out_type, target):
         relay.op.strategy.wrap_compute_conv2d(_compute_conv2d_1),
         relay.op.strategy.wrap_topi_schedule(_schedule_conv2d_1),
         name="conv2d_1",
-        plevel=10)
+        plevel=10,
+    )
     strategy.add_implementation(
         relay.op.strategy.wrap_compute_conv2d(_compute_conv2d_2),
         relay.op.strategy.wrap_topi_schedule(_schedule_conv2d_2),
         name="conv2d_2",
-        plevel=15)
+        plevel=15,
+    )
     ic = inputs[0].shape[1]
     with tvm.te.SpecializedCondition(ic >= 16):
         strategy.add_implementation(
             relay.op.strategy.wrap_compute_conv2d(_compute_conv2d_3),
             relay.op.strategy.wrap_topi_schedule(_schedule_conv2d_3),
             name="conv2d_3",
-            plevel=20)
+            plevel=20,
+        )
     return strategy
 
+
 def _create_record(task_name, dshape, wshape, target, cost):
-    args = [te.placeholder(dshape), te.placeholder(wshape), (1, 1), (1, 1, 1, 1),
-            (1, 1), 'float32']
+    args = [te.placeholder(dshape), te.placeholder(wshape), (1, 1), (1, 1, 1, 1), (1, 1), "float32"]
     task = autotvm.task.create(task_name, args, target)
     cfg = autotvm.ConfigEntity(0, None, {}, [])
     cfg.cost = cost
@@ -80,6 +89,7 @@ def _create_record(task_name, dshape, wshape, target, cost):
     result = autotvm.MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
     return (inp, result)
 
+
 def test_get_valid_implementations():
     target = tvm.target.Target("llvm")
 
@@ -93,7 +103,8 @@ def test_get_valid_implementations():
             out.attrs,
             [te.placeholder(dshape), te.placeholder(wshape)],
             out.checked_type,
-            target)
+            target,
+        )
 
     with TempOpAttr("nn.conv2d", "FTVMStrategy", _tmp_strategy):
         impls = _get_impls((1, 8, 7, 7), (32, 8, 3, 3))
@@ -101,6 +112,7 @@ def test_get_valid_implementations():
         impls = _get_impls((1, 16, 7, 7), (32, 16, 3, 3))
         assert len(impls) == 3
 
+
 def test_select_implementation():
     target = tvm.target.Target("llvm")
 
@@ -115,7 +127,8 @@ def test_select_implementation():
             [te.placeholder(dshape), te.placeholder(wshape)],
             out.checked_type,
             target,
-            use_autotvm)
+            use_autotvm,
+        )
 
     with TempOpAttr("nn.conv2d", "FTVMStrategy", _tmp_strategy):
         impl, _ = _select_impl((1, 8, 7, 7), (32, 8, 3, 3))
@@ -147,8 +160,10 @@ def test_select_implementation():
                 impl, _ = _select_impl((1, 16, 7, 7), (32, 16, 3, 3), True)
                 assert impl.name == "conv2d_1"
 
+
 def test_compile_engine():
     engine = relay.backend.compile_engine.get()
+
     def get_func(shape):
         x = relay.var("x", shape=shape)
         y = relay.add(x, x)
@@ -157,6 +172,7 @@ def test_compile_engine():
         mod = tvm.IRModule.from_expr(f)
         mod = relay.transform.InferType()(mod)
         return mod["main"]
+
     z1 = engine.lower(get_func((10,)), "llvm")
     z2 = engine.lower(get_func((10,)), "llvm")
     z3 = engine.lower(get_func(()), "llvm")
@@ -174,10 +190,10 @@ def test_compile_engine():
             x = tvm.nd.array(np.ones(10).astype("float32"), ctx=ctx)
             y = tvm.nd.empty((10,), ctx=ctx)
             f(x, y)
-            tvm.testing.assert_allclose(
-                y.asnumpy(), x.asnumpy() * 3)
+            tvm.testing.assert_allclose(y.asnumpy(), x.asnumpy() * 3)
     engine.dump()
 
+
 def test_compile_placeholder_bypass():
     engine = relay.backend.compile_engine.get()
     x = relay.var("x", shape=(2, 3))
@@ -186,7 +202,7 @@ def test_compile_placeholder_bypass():
     result = relay.Tuple([x, relay.op.concatenate([y, z], axis=0)])
     func = relay.Function(relay.analysis.free_vars(result), result)
     with tvm.transform.PassContext(opt_level=0):
-       graph, lib, params = relay.build(tvm.IRModule.from_expr(func), 'llvm')
+        graph, lib, params = relay.build(tvm.IRModule.from_expr(func), "llvm")
 
 
 def test_compile_injective_with_tuple():
@@ -195,7 +211,7 @@ def test_compile_injective_with_tuple():
     x_transpose = relay.transpose(x)
     output = relay.Tuple([x_transpose, y])
     func = relay.Function([x, y], output)
-    relay.build(tvm.IRModule.from_expr(func), 'llvm')
+    relay.build(tvm.IRModule.from_expr(func), "llvm")
 
 
 def test_compile_tuple_dup():
@@ -203,30 +219,38 @@ def test_compile_tuple_dup():
     log = relay.log(x)
     output = relay.Tuple([log, log])
     f = relay.Function([x], output)
-    relay.build(tvm.IRModule.from_expr(f), 'llvm')
+    relay.build(tvm.IRModule.from_expr(f), "llvm")
 
 
 def test_compile_full():
     # Shape calculations can happen in int64. The test checks that the full operator
     # can handle shapes that are not int32.
-    shape = (tvm.tir.IntImm('int32', 1),
-             tvm.tir.IntImm('int64', 16),
-             tvm.tir.IntImm('int64', 16),
-             tvm.tir.IntImm('int32', 64))
-    output = relay.full(relay.const(0, 'int32'), shape=shape, dtype='int32')
+    shape = (
+        tvm.tir.IntImm("int32", 1),
+        tvm.tir.IntImm("int64", 16),
+        tvm.tir.IntImm("int64", 16),
+        tvm.tir.IntImm("int32", 64),
+    )
+    output = relay.full(relay.const(0, "int32"), shape=shape, dtype="int32")
     f = relay.Function([], output)
     mod = tvm.IRModule.from_expr(f)
     mod = relay.qnn.transform.CanonicalizeOps()(mod)
-    relay.build(mod, 'llvm')
+    relay.build(mod, "llvm")
 
 
 def test_compile_nhwc_pack():
     data = relay.var("data", shape=(1, 1, 1, 1024), dtype="uint8")
     weight = relay.var("weight", shape=(1, 1, 1024, 1001), dtype="int8")
     p2 = relay.var("p2", shape=(1, 1, 1, 1), dtype="int32")
-    conv = relay.nn.conv2d(data, weight, kernel_size=(1, 1), data_layout="NHWC",
-                           kernel_layout="HWIO", out_dtype="int32")
-    multiply = relay.multiply(relay.const(-22, dtype='int32'), p2)
+    conv = relay.nn.conv2d(
+        data,
+        weight,
+        kernel_size=(1, 1),
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        out_dtype="int32",
+    )
+    multiply = relay.multiply(relay.const(-22, dtype="int32"), p2)
     tile = relay.tile(multiply, reps=(1, 1, 1, 1001))
     subtract = relay.subtract(conv, tile)
 
index 70a6fb1..5550caa 100644 (file)
@@ -39,13 +39,14 @@ def check_rts(expr, args, expected_result, mod=None):
     expected_result:
         The expected result of running the expression.
     """
-    intrp = relay.create_executor('debug', mod=mod)
-    graph = relay.create_executor('graph', mod=mod)
+    intrp = relay.create_executor("debug", mod=mod)
+    graph = relay.create_executor("graph", mod=mod)
     eval_result = intrp.evaluate(expr)(*args)
     rts_result = graph.evaluate(expr)(*args)
     tvm.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy())
     tvm.testing.assert_allclose(eval_result.asnumpy(), expected_result)
 
+
 def test_add_op_scalar():
     """
     Program:
@@ -53,13 +54,14 @@ def test_add_op_scalar():
             return x + y;
         }
     """
-    x = relay.var('x', shape=())
-    y = relay.var('y', shape=())
+    x = relay.var("x", shape=())
+    y = relay.var("y", shape=())
     func = relay.Function([x, y], add(x, y))
-    x_data = np.array(10.0, dtype='float32')
-    y_data = np.array(1.0, dtype='float32')
+    x_data = np.array(10.0, dtype="float32")
+    y_data = np.array(1.0, dtype="float32")
     check_rts(func, [x_data, y_data], x_data + y_data)
 
+
 def test_add_op_tensor():
     """
     Program:
@@ -67,13 +69,14 @@ def test_add_op_tensor():
             return x + y;
         }
     """
-    x = relay.var('x', shape=(10, 5))
-    y = relay.var('y', shape=(10, 5))
+    x = relay.var("x", shape=(10, 5))
+    y = relay.var("y", shape=(10, 5))
     func = relay.Function([x, y], add(x, y))
-    x_data = np.random.rand(10, 5).astype('float32')
-    y_data = np.random.rand(10, 5).astype('float32')
+    x_data = np.random.rand(10, 5).astype("float32")
+    y_data = np.random.rand(10, 5).astype("float32")
     check_rts(func, [x_data, y_data], x_data + y_data)
 
+
 def test_add_op_broadcast():
     """
     Program:
@@ -81,22 +84,22 @@ def test_add_op_broadcast():
             return x + y;
         }
     """
-    x = relay.var('x', shape=(10, 5))
-    y = relay.var('y', shape=(1, 5))
+    x = relay.var("x", shape=(10, 5))
+    y = relay.var("y", shape=(1, 5))
     func = relay.Function([x, y], add(x, y))
-    x_data = np.random.rand(10, 5).astype('float32')
-    y_data = np.random.rand(1, 5).astype('float32')
+    x_data = np.random.rand(10, 5).astype("float32")
+    y_data = np.random.rand(1, 5).astype("float32")
     check_rts(func, [x_data, y_data], x_data + y_data)
 
 
 def test_with_params():
-    x = relay.var('x', shape=(10, 5))
-    y = relay.var('y', shape=(1, 5))
+    x = relay.var("x", shape=(10, 5))
+    y = relay.var("y", shape=(1, 5))
     z = relay.add(x, y)
     z = relay.exp(z)
     func = relay.Function([x, y], z)
-    x_data = np.random.rand(10, 5).astype('float32')
-    y_data = np.random.rand(1, 5).astype('float32')
+    x_data = np.random.rand(10, 5).astype("float32")
+    y_data = np.random.rand(1, 5).astype("float32")
     params = {"y": y_data}
     graph, lib, params = relay.build(tvm.IRModule.from_expr(func), "llvm", params=params)
     mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
@@ -162,7 +165,7 @@ def test_gru_like():
     dtype = "float32"
     rnn_dim = 1000
     x = np.random.rand(1, rnn_dim).astype(dtype)
-    y = np.random.rand(3*rnn_dim, rnn_dim).astype(dtype) * 0.01 - 0.005
+    y = np.random.rand(3 * rnn_dim, rnn_dim).astype(dtype) * 0.01 - 0.005
     out_shape = (1, rnn_dim)
     z = unit(rnn_dim)
 
index 39786fd..08a8401 100644 (file)
@@ -35,13 +35,11 @@ def check_eval(expr, args, expected_result, mod=None, rtol=1e-07):
         intrp = create_executor(mod=mod, ctx=ctx, target=target)
         result = intrp.evaluate(expr)(*args)
         # use tvm.testing which also sets atol
-        tvm.testing.assert_allclose(
-            result.asnumpy(), expected_result, rtol=rtol)
+        tvm.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
 
 
 def test_tuple_value():
-    tv = container.tuple_object([relay.const(1), relay.const(2),
-                                 relay.const(3)])
+    tv = container.tuple_object([relay.const(1), relay.const(2), relay.const(3)])
     np.testing.assert_allclose(tv[0].data.asnumpy(), 1)
     np.testing.assert_allclose(tv[1].data.asnumpy(), 2)
     np.testing.assert_allclose(tv[2].data.asnumpy(), 3)
@@ -54,9 +52,9 @@ def test_tuple_getitem():
 
 
 def test_id():
-    x = relay.var('x', 'float32')
+    x = relay.var("x", "float32")
     ident = relay.Function([x], x)
-    one = np.array(1.0, 'float32')
+    one = np.array(1.0, "float32")
     check_eval(ident, [one], one)
 
 
@@ -67,75 +65,75 @@ def test_add_const():
 
 
 def test_mul_param():
-    x = relay.var('x', shape=(10, 10))
-    y = relay.var('y', shape=(1, 10))
+    x = relay.var("x", shape=(10, 10))
+    y = relay.var("y", shape=(1, 10))
     func = relay.Function([x, y], relay.multiply(x, y))
-    x_data = np.random.rand(10, 10).astype('float32')
-    y_data = np.random.rand(1, 10).astype('float32')
+    x_data = np.random.rand(10, 10).astype("float32")
+    y_data = np.random.rand(1, 10).astype("float32")
     check_eval(func, [x_data, y_data], x_data * y_data)
 
 
 def test_equal():
-    i = relay.var('i', shape=[], dtype='int32')
-    j = relay.var('i', shape=[], dtype='int32')
+    i = relay.var("i", shape=[], dtype="int32")
+    j = relay.var("i", shape=[], dtype="int32")
     z = relay.equal(i, j)
-    func = relay.Function([i, j], z, ret_type=relay.TensorType([], 'bool'))
-    i_data = relay.const(0, 'int32')
-    j_data = relay.const(0, 'int32')
+    func = relay.Function([i, j], z, ret_type=relay.TensorType([], "bool"))
+    i_data = relay.const(0, "int32")
+    j_data = relay.const(0, "int32")
     check_eval(func, [i_data, j_data], True)
 
 
 def test_subtract():
-    i = relay.var('i', shape=[], dtype='int32')
-    sub = relay.subtract(i, relay.const(1, dtype='int32'))
-    func = relay.Function([i], sub, ret_type=relay.TensorType([], 'int32'))
-    i_data = np.array(1, dtype='int32')
+    i = relay.var("i", shape=[], dtype="int32")
+    sub = relay.subtract(i, relay.const(1, dtype="int32"))
+    func = relay.Function([i], sub, ret_type=relay.TensorType([], "int32"))
+    i_data = np.array(1, dtype="int32")
     check_eval(func, [i_data], 0)
 
 
 def test_simple_loop():
     mod = tvm.IRModule({})
-    sum_up = relay.GlobalVar('sum_up')
-    i = relay.var('i', shape=[], dtype='int32')
+    sum_up = relay.GlobalVar("sum_up")
+    i = relay.var("i", shape=[], dtype="int32")
     sb = ScopeBuilder()
-    with sb.if_scope(relay.equal(i, relay.const(0, dtype='int32'))):
+    with sb.if_scope(relay.equal(i, relay.const(0, dtype="int32"))):
         sb.ret(i)
     with sb.else_scope():
-        one_less = relay.subtract(i, relay.const(1, dtype='int32'))
+        one_less = relay.subtract(i, relay.const(1, dtype="int32"))
         rec_call = relay.Call(sum_up, [one_less])
         sb.ret(relay.add(rec_call, i))
-    func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
+    func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], "int32"))
     mod[sum_up] = func
-    i_data = np.array(10, dtype='int32')
+    i_data = np.array(10, dtype="int32")
     check_eval(sum_up, [i_data], sum(range(1, 11)), mod=mod)
 
 
 def test_loop():
     mod = tvm.IRModule({})
-    sum_up = relay.GlobalVar('sum_up')
-    i = relay.var('i', shape=[], dtype='int32')
-    accum = relay.var('accum', shape=[], dtype='int32')
+    sum_up = relay.GlobalVar("sum_up")
+    i = relay.var("i", shape=[], dtype="int32")
+    accum = relay.var("accum", shape=[], dtype="int32")
     sb = ScopeBuilder()
-    with sb.if_scope(relay.equal(i, relay.const(0, 'int32'))):
+    with sb.if_scope(relay.equal(i, relay.const(0, "int32"))):
         sb.ret(accum)
     with sb.else_scope():
-        one_less = relay.subtract(i, relay.const(1, 'int32'))
+        one_less = relay.subtract(i, relay.const(1, "int32"))
         new_accum = relay.add(accum, i)
         sb.ret(relay.Call(sum_up, [one_less, new_accum]))
     func = relay.Function([i, accum], sb.get())
     mod[sum_up] = func
-    i_data = np.array(10, dtype='int32')
-    accum_data = np.array(0, dtype='int32')
+    i_data = np.array(10, dtype="int32")
+    accum_data = np.array(0, dtype="int32")
     check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), mod=mod)
 
 
 def test_ref():
     mod = tvm.IRModule()
-    three_with_ref = relay.GlobalVar('three_with_ref')
-    i = relay.Var('i')
-    iv = relay.Var('iv')
-    u = relay.Var('u')
-    uv = relay.Var('uv')
+    three_with_ref = relay.GlobalVar("three_with_ref")
+    i = relay.Var("i")
+    iv = relay.Var("iv")
+    u = relay.Var("u")
+    uv = relay.Var("uv")
     body = relay.add(iv, uv)
     body = relay.Let(uv, relay.RefRead(i), body)
     body = relay.Let(u, relay.RefWrite(i, relay.const(2)), body)
@@ -159,10 +157,10 @@ def test_kwargs_params():
     y = relay.var("y", shape=(1, 10))
     z = relay.var("z", shape=(1, 10))
     f = relay.Function([x, y, z], x + y + z)
-    x_data = np.random.rand(1, 10).astype('float32')
-    y_data = np.random.rand(1, 10).astype('float32')
-    z_data = np.random.rand(1, 10).astype('float32')
-    params = { 'y': y_data, 'z': z_data }
+    x_data = np.random.rand(1, 10).astype("float32")
+    y_data = np.random.rand(1, 10).astype("float32")
+    z_data = np.random.rand(1, 10).astype("float32")
+    params = {"y": y_data, "z": z_data}
     intrp = create_executor("debug")
     res = intrp.evaluate(f)(x_data, **params)
     tvm.testing.assert_allclose(res.asnumpy(), x_data + y_data + z_data)
@@ -174,15 +172,16 @@ def test_function_taking_adt_ref_tuple():
     intrp = create_executor("debug", mod)
 
     nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil)
-    cons_value = ConstructorValue(prelude.cons.tag, [
-        nd.array(np.random.rand(1, 10).astype('float32')),
-        nil_value
-    ], prelude.cons)
+    cons_value = ConstructorValue(
+        prelude.cons.tag,
+        [nd.array(np.random.rand(1, 10).astype("float32")), nil_value],
+        prelude.cons,
+    )
 
-    ref_value = RefValue(nd.array(np.random.rand(1, 10).astype('float32')))
-    tuple_value = container.tuple_object([
-        nd.array(np.random.rand(1, 10).astype('float32')) for _ in range(10)
-    ])
+    ref_value = RefValue(nd.array(np.random.rand(1, 10).astype("float32")))
+    tuple_value = container.tuple_object(
+        [nd.array(np.random.rand(1, 10).astype("float32")) for _ in range(10)]
+    )
 
     id_func = intrp.evaluate(prelude.id)
 
@@ -193,8 +192,7 @@ def test_function_taking_adt_ref_tuple():
     res_cons = id_func(cons_value)
     assert res_cons.tag == cons_value.tag
     assert len(res_cons.fields) == len(cons_value.fields)
-    tvm.testing.assert_allclose(res_cons.fields[0].asnumpy(),
-                                cons_value.fields[0].asnumpy())
+    tvm.testing.assert_allclose(res_cons.fields[0].asnumpy(), cons_value.fields[0].asnumpy())
     assert isinstance(res_cons.fields[1], ConstructorValue)
     assert res_cons.fields[1].tag == prelude.nil.tag
     assert len(res_cons.fields[1].fields) == 0
@@ -204,33 +202,36 @@ def test_function_taking_adt_ref_tuple():
 
     res_tuple = id_func(tuple_value)
     for i in range(10):
-        tvm.testing.assert_allclose(res_tuple[i].asnumpy(),
-                                    tuple_value[i].asnumpy())
+        tvm.testing.assert_allclose(res_tuple[i].asnumpy(), tuple_value[i].asnumpy())
+
 
 def test_tuple_passing():
-    x = relay.var('x', type_annotation=relay.ty.TupleType([
-        relay.ty.TensorType((), 'int64'),
-        relay.ty.TensorType((), 'int64')]))
+    x = relay.var(
+        "x",
+        type_annotation=relay.ty.TupleType(
+            [relay.ty.TensorType((), "int64"), relay.ty.TensorType((), "int64")]
+        ),
+    )
 
     fn = relay.Function([x], relay.expr.TupleGetItem(x, 0))
     mod = tvm.IRModule({})
-    gv = relay.GlobalVar('main')
+    gv = relay.GlobalVar("main")
     mod[gv] = fn
     mod = relay.transform.InferType()(mod)
 
     ctx = tvm.cpu()
-    target = tvm.target.Target('llvm')
+    target = tvm.target.Target("llvm")
     exec = relay.create_executor(mod=mod, ctx=ctx, target=target)
     f = exec.evaluate(gv)
     # First use a Python tuple.
     out = f((10, 8))
     tvm.testing.assert_allclose(out.asnumpy(), np.array(10))
     # Second use a tuple value.
-    value_tuple = container.tuple_object([nd.array(np.array(11)),
-                                          nd.array(np.array(12))])
+    value_tuple = container.tuple_object([nd.array(np.array(11)), nd.array(np.array(12))])
     out = f(value_tuple)
     tvm.testing.assert_allclose(out.asnumpy(), np.array(11))
 
+
 if __name__ == "__main__":
     test_id()
     test_add_const()
index bae077c..a597824 100644 (file)
@@ -116,27 +116,25 @@ def test_nested_ref():
 def test_recursive_func():
     mod = tvm.IRModule({})
 
-    x = relay.var('x', shape=[], dtype='int32')
+    x = relay.var("x", shape=[], dtype="int32")
     fn0 = relay.Function([x], x)
     gx = relay.GlobalVar("gx")
     mod[gx] = fn0
 
-    sum_up = relay.GlobalVar('sum_up')
-    i = relay.var('i', shape=[], dtype='int32')
+    sum_up = relay.GlobalVar("sum_up")
+    i = relay.var("i", shape=[], dtype="int32")
     sb = relay.ScopeBuilder()
-    with sb.if_scope(relay.equal(i, relay.const(0, dtype='int32'))):
+    with sb.if_scope(relay.equal(i, relay.const(0, dtype="int32"))):
         sb.ret(i)
     with sb.else_scope():
-        one_less = relay.subtract(i, relay.const(1, dtype='int32'))
+        one_less = relay.subtract(i, relay.const(1, dtype="int32"))
         global_call = gx(i)
         rec_call = relay.Call(sum_up, [one_less]) + global_call
         sb.ret(relay.add(rec_call, i))
-    func = relay.Function([i],
-                          sb.get(),
-                          ret_type=relay.TensorType([], 'int32'))
+    func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], "int32"))
     func = func.with_attr("Compiler", "a")
     mod[sum_up] = func
-    iarg = relay.var('i', shape=[], dtype='int32')
+    iarg = relay.var("i", shape=[], dtype="int32")
     mod["main"] = relay.Function([iarg], sum_up(iarg))
     call_graph = relay.analysis.CallGraph(mod)
 
index 4237641..8b4c6ba 100644 (file)
@@ -20,10 +20,12 @@ from tvm import relay
 from tvm.relay.testing import synthetic
 from tvm.relay import transform
 
+
 def test_change_batch_synthetic():
     net, params = synthetic.get_workload()
     new_net = transform.ChangeBatch({net["main"].params[0]: 0}, batch_size=123)(net)
     assert new_net["main"].checked_type.ret_type.shape[0] == 123
 
+
 if __name__ == "__main__":
     test_change_batch_synthetic()
index d096eec..b826466 100644 (file)
 # under the License.
 
 from tvm import relay
+
 a = relay.Var("a")
-b = relay.expr.const (1.0, dtype='float32')
+b = relay.expr.const(1.0, dtype="float32")
 
 c = a < b
-d = relay.less (a, b)
-assert (c.astext() == d.astext())
+d = relay.less(a, b)
+assert c.astext() == d.astext()
 
 c = a > b
-d = relay.greater (a, b)
-assert (c.astext() == d.astext())
+d = relay.greater(a, b)
+assert c.astext() == d.astext()
 
-c = (a >= b)
+c = a >= b
 d = relay.greater_equal(a, b)
-assert (c.astext() == d.astext())
+assert c.astext() == d.astext()
 
-c = (a <= b)
+c = a <= b
 d = relay.less_equal(a, b)
-assert (c.astext() == d.astext())
+assert c.astext() == d.astext()
index faf6867..fe44eb2 100644 (file)
@@ -37,14 +37,9 @@ def test_basic_build():
     A = tvm.nd.array(np.random.uniform(-1, 1, (16, 8)).astype("float32"), ctx=ctx)
     B = tvm.nd.array(np.random.uniform(-1, 1, (8, 8)).astype("float32"), ctx=ctx)
     C = tvm.nd.array(np.random.uniform(-1, 1, (16, 8)).astype("float32"), ctx=ctx)
-    params = {
-        "b" : B,
-        "c" : C
-    }
+    params = {"b": B, "c": C}
     # build
-    targets = {
-        tvm.tir.IntImm("int32", ctx.device_type): tgt
-    }
+    targets = {tvm.tir.IntImm("int32", ctx.device_type): tgt}
     mod = tvm.IRModule.from_expr(func)
     func_in_mod = mod["main"]
     assert mod["main"] == func_in_mod, "cannot compare function to itself"
@@ -59,10 +54,12 @@ def test_basic_build():
     rt.run()
     out = rt.get_output(0)
 
-    np.testing.assert_allclose(out.asnumpy(), np.maximum(np.dot(A.asnumpy(),
-                                                                B.asnumpy().T),
-                                                         0) + C.asnumpy(),
-                               atol=1e-5, rtol=1e-5)
+    np.testing.assert_allclose(
+        out.asnumpy(),
+        np.maximum(np.dot(A.asnumpy(), B.asnumpy().T), 0) + C.asnumpy(),
+        atol=1e-5,
+        rtol=1e-5,
+    )
 
 
 @tvm.testing.requires_cuda
@@ -94,8 +91,7 @@ def test_fp16_build():
     rt.run()
     out = rt.get_output(0)
 
-    np.testing.assert_allclose(out.asnumpy(), X.asnumpy() + Y.asnumpy(),
-                               atol=1e-5, rtol=1e-5)
+    np.testing.assert_allclose(out.asnumpy(), X.asnumpy() + Y.asnumpy(), atol=1e-5, rtol=1e-5)
 
 
 @tvm.testing.parametrize_targets("llvm", "cuda")
@@ -106,7 +102,7 @@ def test_fp16_conversion(target, ctx):
 
     n = 10
 
-    for (src, dst) in [('float32', 'float16'), ('float16', 'float32')]:
+    for (src, dst) in [("float32", "float16"), ("float16", "float32")]:
         x = relay.var("x", relay.TensorType((n,), src))
         y = x.astype(dst)
         func = relay.Function([x], y)
@@ -124,8 +120,7 @@ def test_fp16_conversion(target, ctx):
         rt.run()
         out = rt.get_output(0)
 
-        np.testing.assert_allclose(out.asnumpy(), X.asnumpy().astype(dst),
-                                   atol=1e-5, rtol=1e-5)
+        np.testing.assert_allclose(out.asnumpy(), X.asnumpy().astype(dst), atol=1e-5, rtol=1e-5)
 
 
 if __name__ == "__main__":
index 34a0987..23c0f93 100644 (file)
@@ -32,7 +32,7 @@ K_BROADCAST = 1
 
 ## NODE TESTS
 def test_expr_pattern():
-    ep = is_expr(relay.var('x', shape=(4, 1)))
+    ep = is_expr(relay.var("x", shape=(4, 1)))
     assert isinstance(ep, ExprPattern)
     assert isinstance(ep.expr, relay.Var)
 
@@ -83,7 +83,7 @@ def test_TupleGetItemPattern():
 
 
 def test_AltPattern():
-    is_add_or_sub = is_op('add') | is_op('subtract')
+    is_add_or_sub = is_op("add") | is_op("subtract")
     assert isinstance(is_add_or_sub, AltPattern)
 
 
@@ -109,7 +109,7 @@ def test_ShapePattern():
 
 
 def test_AttrPattern():
-    op = is_op('add').has_attr({"TOpPattern": K_ELEMWISE})
+    op = is_op("add").has_attr({"TOpPattern": K_ELEMWISE})
     assert isinstance(op, AttrPattern)
     assert op.attrs["TOpPattern"] == K_ELEMWISE
 
@@ -118,61 +118,64 @@ def test_AttrPattern():
 
 
 def test_match_op():
-    assert is_op('add').match(relay.op.op.get("add"))
+    assert is_op("add").match(relay.op.op.get("add"))
 
 
 def test_no_match_op():
-    assert not is_op('add').match(relay.op.op.get("subtract"))
+    assert not is_op("add").match(relay.op.op.get("subtract"))
 
 
 def test_match_op_or():
-    is_add_or_sub = is_op('add') | is_op('subtract')
+    is_add_or_sub = is_op("add") | is_op("subtract")
     assert is_add_or_sub.match(relay.op.op.get("add"))
     assert is_add_or_sub.match(relay.op.op.get("subtract"))
 
 
 def test_match_call_commutive():
-    x = relay.var('x')
-    y = relay.var('y')
-    add_pattern = is_op('add')(is_var("x"), is_var("y"))
+    x = relay.var("x")
+    y = relay.var("y")
+    add_pattern = is_op("add")(is_var("x"), is_var("y"))
     assert add_pattern.match(x + y)
     assert add_pattern.match(y + x)
-    mul_pattern = is_op('multiply')(is_var("x"), is_var("y"))
+    mul_pattern = is_op("multiply")(is_var("x"), is_var("y"))
     assert mul_pattern.match(x * y)
     assert mul_pattern.match(y * x)
 
 
 def test_no_match_call_commutive():
-    x = relay.var('x')
-    y = relay.var('y')
-    add_pattern = is_op('subtract')(is_var("x"), is_var("y"))
+    x = relay.var("x")
+    y = relay.var("y")
+    add_pattern = is_op("subtract")(is_var("x"), is_var("y"))
     assert add_pattern.match(x - y)
     assert not add_pattern.match(y - x)
-    add_pattern = is_op('divide')(is_var("x"), is_var("y"))
+    add_pattern = is_op("divide")(is_var("x"), is_var("y"))
     assert add_pattern.match(x / y)
     assert not add_pattern.match(y / x)
 
 
 def test_match_call():
-    x = relay.var('x')
-    y = relay.var('y')
-    add_pattern = is_op('add')(wildcard(), wildcard())
+    x = relay.var("x")
+    y = relay.var("y")
+    add_pattern = is_op("add")(wildcard(), wildcard())
     assert add_pattern.match(x + y)
 
 
 def test_no_match_call():
-    x = relay.var('x')
-    y = relay.var('y')
-    add_pattern = is_op('add')(wildcard(), wildcard())
+    x = relay.var("x")
+    y = relay.var("y")
+    add_pattern = is_op("add")(wildcard(), wildcard())
     assert not add_pattern.match(x - y)
 
 
 def test_match_option():
-    x = relay.var('x')
-    w = relay.var('w')
-    b = relay.var('b')
-    pattern = is_op("nn.relu")(is_op("nn.conv2d")(
-        wildcard(), wildcard()).optional(lambda x: is_op("nn.bias_add")(x, wildcard())))
+    x = relay.var("x")
+    w = relay.var("w")
+    b = relay.var("b")
+    pattern = is_op("nn.relu")(
+        is_op("nn.conv2d")(wildcard(), wildcard()).optional(
+            lambda x: is_op("nn.bias_add")(x, wildcard())
+        )
+    )
 
     conv2d = relay.op.nn.conv2d(x, w)
     relu = relay.op.nn.relu(conv2d)
@@ -184,7 +187,7 @@ def test_match_option():
     assert pattern.match(relu)
 
     pattern = is_op("nn.conv2d")(wildcard(), wildcard())
-    pattern = pattern.optional(is_op('nn.relu')).optional(is_op("tanh"))
+    pattern = pattern.optional(is_op("nn.relu")).optional(is_op("tanh"))
 
     conv2d = relay.op.nn.conv2d(x, w)
     relu = relay.op.nn.relu(conv2d)
@@ -199,11 +202,14 @@ def test_match_option():
 
 
 def test_no_match_option():
-    x = relay.var('x')
-    w = relay.var('w')
-    b = relay.var('b')
-    pattern = is_op("nn.relu")(is_op("nn.conv2d")(
-        wildcard(), wildcard()).optional(lambda x: is_op("nn.bias_add")(x, wildcard())))
+    x = relay.var("x")
+    w = relay.var("w")
+    b = relay.var("b")
+    pattern = is_op("nn.relu")(
+        is_op("nn.conv2d")(wildcard(), wildcard()).optional(
+            lambda x: is_op("nn.bias_add")(x, wildcard())
+        )
+    )
 
     conv2d = relay.op.nn.conv2d(x, w)
     relu = relay.op.tanh(conv2d)
@@ -225,26 +231,25 @@ def test_no_match_option():
 
 
 def test_match_const():
-    conv2d = is_op('nn.conv2d')(wildcard(), is_constant())
-    pattern = is_op('nn.bias_add')(conv2d, wildcard())
+    conv2d = is_op("nn.conv2d")(wildcard(), is_constant())
+    pattern = is_op("nn.bias_add")(conv2d, wildcard())
 
-    x = relay.var('x', shape=(1, 3, 224, 224))
-    w = relay.var('w', shape=(3, 3, 3, 3))
-    b = relay.var('b', shape=(3, ))
+    x = relay.var("x", shape=(1, 3, 224, 224))
+    w = relay.var("w", shape=(3, 3, 3, 3))
+    b = relay.var("b", shape=(3,))
     conv2d = relay.op.nn.conv2d(x, w)
     out = relay.op.nn.bias_add(conv2d, b)
     func = relay.Function([x, w, b], out)
     mod = tvm.IRModule.from_expr(func)
 
-    assert not pattern.match(mod['main'].body)
-    mod["main"] = bind_params_by_name(mod["main"],
-                                      {'w': tvm.nd.array(np.ones(shape=(3, 3, 3, 3)))})
-    assert pattern.match(mod['main'].body)
+    assert not pattern.match(mod["main"].body)
+    mod["main"] = bind_params_by_name(mod["main"], {"w": tvm.nd.array(np.ones(shape=(3, 3, 3, 3)))})
+    assert pattern.match(mod["main"].body)
 
 
 def test_match_tuple():
-    x = relay.var('x')
-    y = relay.var('y')
+    x = relay.var("x")
+    y = relay.var("y")
     z = relay.op.op.get("add")
     tuple_pattern = is_tuple((is_var("x"), wildcard(), is_op("add")))
     assert tuple_pattern.match(relay.expr.Tuple((x, y, z)))
@@ -253,94 +258,93 @@ def test_match_tuple():
     tuple_get_item_pattern = is_tuple_get_item(tuple_pattern, 1)
     assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1))
 
-    tuple_get_item_pattern = is_tuple_get_item(tuple_pattern) # Match any index
+    tuple_get_item_pattern = is_tuple_get_item(tuple_pattern)  # Match any index
     assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 0))
     assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1))
     assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 2))
 
 
 def test_no_match_tuple():
-    x = relay.var('x')
-    y = relay.var('y')
+    x = relay.var("x")
+    y = relay.var("y")
     z = relay.op.op.get("add")
-    tuple_pattern = is_tuple((is_var('x'), wildcard(), is_op("add"), wildcard()))
+    tuple_pattern = is_tuple((is_var("x"), wildcard(), is_op("add"), wildcard()))
     assert not tuple_pattern.match(relay.expr.Tuple((x, y, z)))
 
-    tuple_pattern = is_tuple((is_var('x'), wildcard(), is_op("add")))
+    tuple_pattern = is_tuple((is_var("x"), wildcard(), is_op("add")))
     tuple_get_item_pattern = is_tuple_get_item(tuple_pattern, 1)
-    assert not tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple(
-        (x, y, z)), 2))
+    assert not tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 2))
 
 
 def test_match_type():
-    x = relay.var('x', shape=(10, 10), dtype="float32")
+    x = relay.var("x", shape=(10, 10), dtype="float32")
     ty_pat = has_type(relay.TensorType((10, 10), "float32"))
     assert ty_pat.match(x)
 
 
 def test_no_match_type():
-    x = relay.var('x', shape=(10, 10), dtype="int32")
+    x = relay.var("x", shape=(10, 10), dtype="int32")
     ty_pat = has_type(relay.TensorType((10, 10), "float32"))
     assert not ty_pat.match(x)
 
 
 def test_match_dtype():
-    x = relay.var('x', shape=(10, 10), dtype="float32")
+    x = relay.var("x", shape=(10, 10), dtype="float32")
     ty_pat = has_dtype("float32")
     assert ty_pat.match(x)
 
 
 def test_no_match_dtype():
-    x = relay.var('x', shape=(10, 10), dtype="int32")
+    x = relay.var("x", shape=(10, 10), dtype="int32")
     ty_pat = has_dtype("float32")
     assert not ty_pat.match(x)
 
 
 def test_match_shape():
-    x = relay.var('x', shape=(10, 10), dtype="float32")
+    x = relay.var("x", shape=(10, 10), dtype="float32")
     ty_pat = has_shape((10, 10))
     assert ty_pat.match(x)
 
 
 def test_no_match_shape():
-    x = relay.var('x', shape=(10, 10), dtype="int32")
+    x = relay.var("x", shape=(10, 10), dtype="int32")
     ty_pat = has_shape((10, 5))
     assert not ty_pat.match(x)
 
 
 def test_match_op_attr():
-    op = is_op('add').has_attr({"TOpPattern": K_BROADCAST})
+    op = is_op("add").has_attr({"TOpPattern": K_BROADCAST})
     op_pat = op(wildcard(), wildcard())
-    x = relay.var('x')
-    y = relay.var('y')
+    x = relay.var("x")
+    y = relay.var("y")
     assert op_pat.match(x + y)
 
 
 def test_no_match_op_attr():
-    op = is_op('nn.dense').has_attr({"TOpPattern": K_ELEMWISE})
+    op = is_op("nn.dense").has_attr({"TOpPattern": K_ELEMWISE})
     op_pat = op(wildcard(), wildcard())
-    x = relay.var('x')
-    y = relay.var('y')
+    x = relay.var("x")
+    y = relay.var("y")
     assert not op_pat.match(relay.op.nn.dense(x, y))
-    op = is_op('add').has_attr({"TOpPattern": K_BROADCAST})
+    op = is_op("add").has_attr({"TOpPattern": K_BROADCAST})
     op_pat = op(wildcard(), wildcard())
-    x = relay.var('x')
-    y = relay.var('y')
+    x = relay.var("x")
+    y = relay.var("y")
     assert not op_pat.match(x - y)
 
 
 def test_match_func_attr():
     pattern = wildcard().has_attr({"Composite": "add"})
-    x = relay.var('x')
-    y = relay.var('y')
+    x = relay.var("x")
+    y = relay.var("y")
     f = relay.Function([x, y], x + y).with_attr("Composite", "add")
     assert pattern.match(f)
 
 
 def test_no_match_func_attr():
     pattern = wildcard().has_attr({"Composite": "add"})
-    x = relay.var('x')
-    y = relay.var('y')
+    x = relay.var("x")
+    y = relay.var("y")
 
     f = relay.Function([x, y], x + y).with_attr("RandomTest", "add")
     assert not pattern.match(f)
@@ -349,33 +353,33 @@ def test_no_match_func_attr():
 
 
 def test_match_call_attr():
-    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NCHW"})
-    x = relay.var('x')
-    y = relay.var('y')
+    is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"data_layout": "NCHW"})
+    x = relay.var("x")
+    y = relay.var("y")
     assert is_conv2d.match(relay.op.nn.conv2d(x, y))
 
 
 def test_no_match_call_attr():
-    x = relay.var('x')
-    y = relay.var('y')
+    x = relay.var("x")
+    y = relay.var("y")
 
-    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NHWC"})
+    is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"data_layout": "NHWC"})
     assert not is_conv2d.match(relay.op.nn.conv2d(x, y))
 
-    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"RandomAttr": "NCHW"})
+    is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"RandomAttr": "NCHW"})
     assert not is_conv2d.match(relay.op.nn.conv2d(x, y))
 
 
 def test_match_diamond():
     # Pattern
-    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    path1 = is_op('nn.relu')(is_conv2d)
-    path2 = is_op('nn.leaky_relu')(is_conv2d)
-    diamond = is_op('add')(path1, path2)
+    is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
+    path1 = is_op("nn.relu")(is_conv2d)
+    path2 = is_op("nn.leaky_relu")(is_conv2d)
+    diamond = is_op("add")(path1, path2)
 
     # Expr
-    inp = relay.var('input')
-    weight = relay.var('weight')
+    inp = relay.var("input")
+    weight = relay.var("weight")
     conv2d = relay.op.nn.conv2d(inp, weight)
     relu = relay.op.nn.relu(conv2d)
     leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
@@ -387,14 +391,14 @@ def test_match_diamond():
 
 def test_no_match_diamond():
     # Pattern
-    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    path1 = is_op('nn.relu')(is_conv2d)
-    path2 = is_op('nn.leaky_relu')(is_conv2d)
-    diamond = is_op('add')(path1, path2)
+    is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
+    path1 = is_op("nn.relu")(is_conv2d)
+    path2 = is_op("nn.leaky_relu")(is_conv2d)
+    diamond = is_op("add")(path1, path2)
 
     # Expr
-    inp = relay.var('input')
-    weight = relay.var('weight')
+    inp = relay.var("input")
+    weight = relay.var("weight")
     conv2d = relay.op.nn.conv2d(inp, weight)
     relu = relay.op.nn.relu(conv2d)
     leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
@@ -406,17 +410,17 @@ def test_no_match_diamond():
 
 def test_match_fake_diamond():
     # Pattern
-    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    path1 = is_op('nn.relu')(is_conv2d)
-    path2 = is_op('nn.leaky_relu')(is_conv2d)
-    diamond = is_op('add')(path1, path2)
+    is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
+    path1 = is_op("nn.relu")(is_conv2d)
+    path2 = is_op("nn.leaky_relu")(is_conv2d)
+    diamond = is_op("add")(path1, path2)
 
     # Expr
-    input1 = relay.var('input1')
-    weight1 = relay.var('weight1')
+    input1 = relay.var("input1")
+    weight1 = relay.var("weight1")
     conv2d1 = relay.op.nn.conv2d(input1, weight1)
-    inp2 = relay.var('input2')
-    weight2 = relay.var('weight2')
+    inp2 = relay.var("input2")
+    weight2 = relay.var("weight2")
     conv2d2 = relay.op.nn.conv2d(inp2, weight2)
     relu = relay.op.nn.relu(conv2d1)
     leaky_relu = relay.op.nn.leaky_relu(conv2d2, alpha=0)
@@ -428,14 +432,14 @@ def test_match_fake_diamond():
 
 def test_match_dominator():
     # Pattern
-    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
     is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
-    reduction = is_op('add')(wildcard(), wildcard())
+    reduction = is_op("add")(wildcard(), wildcard())
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
 
     # Classic Diamond
-    inp = relay.var('input')
-    weight = relay.var('weight')
+    inp = relay.var("input")
+    weight = relay.var("weight")
     conv2d = relay.op.nn.conv2d(inp, weight)
     relu = relay.op.nn.relu(conv2d)
     relu = relay.op.nn.relu(relu)
@@ -446,8 +450,8 @@ def test_match_dominator():
     assert diamond.match(out)
 
     # Deeper Branch
-    inp = relay.var('input')
-    weight = relay.var('weight')
+    inp = relay.var("input")
+    weight = relay.var("weight")
     conv2d = relay.op.nn.conv2d(inp, weight)
     relu = relay.op.nn.relu(conv2d)
     relu = relay.op.nn.relu(relu)
@@ -459,8 +463,8 @@ def test_match_dominator():
     assert diamond.match(out)
 
     # Single Branch
-    inp = relay.var('input')
-    weight = relay.var('weight')
+    inp = relay.var("input")
+    weight = relay.var("weight")
     conv2d = relay.op.nn.conv2d(inp, weight)
     relu = relay.op.nn.relu(conv2d)
     relu = relay.op.nn.relu(relu)
@@ -471,14 +475,15 @@ def test_match_dominator():
     assert diamond.match(out)
 
     # Fuzzy path/nested Diamond
-    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(
-        wildcard()) | is_op('add')(wildcard(), wildcard())
-    reduction = is_op('add')(wildcard(), wildcard())
+    is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op(
+        "add"
+    )(wildcard(), wildcard())
+    reduction = is_op("add")(wildcard(), wildcard())
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
 
-    inp = relay.var('input')
-    weight = relay.var('weight')
+    inp = relay.var("input")
+    weight = relay.var("weight")
     conv2d = relay.op.nn.conv2d(inp, weight)
     relu = relay.op.nn.relu(conv2d)
     relu = relu + relu
@@ -490,17 +495,17 @@ def test_match_dominator():
 
 
 def test_not_match_dominator():
-    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
     is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
-    reduction = is_op('add')(wildcard(), wildcard())
+    reduction = is_op("add")(wildcard(), wildcard())
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
 
     # Fake Diamond
-    input1 = relay.var('input1')
-    weight1 = relay.var('weight1')
+    input1 = relay.var("input1")
+    weight1 = relay.var("weight1")
     conv2d1 = relay.op.nn.conv2d(input1, weight1)
-    inp2 = relay.var('input2')
-    weight2 = relay.var('weight2')
+    inp2 = relay.var("input2")
+    weight2 = relay.var("weight2")
     conv2d2 = relay.op.nn.conv2d(inp2, weight2)
     relu = relay.op.nn.relu(conv2d1)
     leaky_relu = relay.op.nn.leaky_relu(conv2d2, alpha=0)
@@ -510,8 +515,8 @@ def test_not_match_dominator():
     assert not diamond.match(out)
 
     # Add op that doesn't match K_ELEMWISE
-    inp = relay.var('input')
-    weight = relay.var('weight')
+    inp = relay.var("input")
+    weight = relay.var("weight")
     conv2d = relay.op.nn.conv2d(inp, weight)
     relu = relay.op.nn.relu(conv2d)
     relu = relu + relu
@@ -522,8 +527,8 @@ def test_not_match_dominator():
     assert not diamond.match(out)
 
     # Relu on the input instead of the conv
-    inp = relay.var('input')
-    weight = relay.var('weight')
+    inp = relay.var("input")
+    weight = relay.var("weight")
     conv2d = relay.op.nn.conv2d(inp, weight)
     relu = relay.op.nn.relu(inp)
     leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
@@ -533,7 +538,7 @@ def test_not_match_dominator():
     assert not diamond.match(out)
 
     # No conv
-    inp = relay.var('input')
+    inp = relay.var("input")
     relu = relay.op.nn.relu(inp)
     relu = relay.op.nn.relu(relu)
     tanh = relay.op.tanh(relu)
@@ -545,14 +550,16 @@ def test_not_match_dominator():
 
 def test_match_typed_dominator():
     # Pattern
-    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype("float32")
-    reduction = is_op('add')(wildcard(), wildcard()).has_shape([1, 3, 10, 10])
+    is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype(
+        "float32"
+    )
+    reduction = is_op("add")(wildcard(), wildcard()).has_shape([1, 3, 10, 10])
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
 
     # Classic Diamond
-    inp = relay.var('input',relay.TensorType((1, 3, 12, 12), "float32"))
-    weight = relay.var('weight', relay.TensorType((3, 3, 3, 3), "float32"))
+    inp = relay.var("input", relay.TensorType((1, 3, 12, 12), "float32"))
+    weight = relay.var("weight", relay.TensorType((3, 3, 3, 3), "float32"))
     conv2d = relay.op.nn.conv2d(inp, weight)
     relu = relay.op.nn.relu(conv2d)
     relu = relay.op.nn.relu(relu)
@@ -562,10 +569,11 @@ def test_match_typed_dominator():
     # Check
     assert diamond.match(out)
 
+
 def test_no_match_typed_dominator():
     # Classic Diamond
-    inp = relay.var('input',relay.TensorType((1, 3, 12, 12), "float32"))
-    weight = relay.var('weight', relay.TensorType((3, 3, 3, 3), "float32"))
+    inp = relay.var("input", relay.TensorType((1, 3, 12, 12), "float32"))
+    weight = relay.var("weight", relay.TensorType((3, 3, 3, 3), "float32"))
     conv2d = relay.op.nn.conv2d(inp, weight)
     relu = relay.op.nn.relu(conv2d)
     relu = relay.op.nn.relu(relu)
@@ -573,18 +581,22 @@ def test_no_match_typed_dominator():
     out = relu + leaky_relu
 
     # Pattern
-    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype("float32")
-    reduction = is_op('add')(wildcard(), wildcard()).has_shape([1, 1, 10, 10])
+    is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype(
+        "float32"
+    )
+    reduction = is_op("add")(wildcard(), wildcard()).has_shape([1, 1, 10, 10])
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
 
     # Check
     assert not diamond.match(out)
 
     # Pattern
-    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype("float16")
-    reduction = is_op('add')(wildcard(), wildcard()).has_shape([1, 3, 10, 10])
+    is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype(
+        "float16"
+    )
+    reduction = is_op("add")(wildcard(), wildcard()).has_shape([1, 3, 10, 10])
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
 
     # Check
@@ -592,10 +604,10 @@ def test_no_match_typed_dominator():
 
 
 def test_rewrite():
-    x = relay.var('x')
-    y = relay.var('y')
-    add_pattern = is_op('add')(wildcard(), wildcard())
-    sub_pattern = is_op('subtract')(wildcard(), wildcard())
+    x = relay.var("x")
+    y = relay.var("y")
+    add_pattern = is_op("add")(wildcard(), wildcard())
+    sub_pattern = is_op("subtract")(wildcard(), wildcard())
 
     class TestRewrite(DFPatternCallback):
         def __init__(self):
@@ -610,11 +622,11 @@ def test_rewrite():
 
 
 def test_rewrite_func():
-    x = relay.var('x')
-    w = relay.var('w')
-    y = relay.var('y')
-    add_pattern = is_op('add')(wildcard(), wildcard())
-    sub_pattern = is_op('subtract')(wildcard(), wildcard())
+    x = relay.var("x")
+    w = relay.var("w")
+    y = relay.var("y")
+    add_pattern = is_op("add")(wildcard(), wildcard())
+    sub_pattern = is_op("subtract")(wildcard(), wildcard())
 
     class TestRewrite(DFPatternCallback):
         def __init__(self):
@@ -626,9 +638,9 @@ def test_rewrite_func():
 
     inpf = relay.var("input")
     weightf = relay.var("weight")
-    func = relay.Function([inpf, weightf],
-                          relay.op.nn.relu(relay.op.nn.conv2d(inpf, weightf)),
-                          attrs=None)
+    func = relay.Function(
+        [inpf, weightf], relay.op.nn.relu(relay.op.nn.conv2d(inpf, weightf)), attrs=None
+    )
     out = rewrite(TestRewrite(), func(x, w) + y)
     assert sub_pattern.match(out)
 
@@ -643,8 +655,8 @@ def test_nested_rewrite():
             return post
 
     def gen():
-        x = relay.var('x')
-        y = relay.var('y')
+        x = relay.var("x")
+        y = relay.var("y")
         y_add = relay.add(y, y)
         n0 = relay.add(x, y_add)
         n1 = relay.add(x, n0)
@@ -653,9 +665,9 @@ def test_nested_rewrite():
     def pattern():
         a = wildcard()
         b = wildcard()
-        n0 = is_op('add')(a, b)
-        n1 = is_op('add')(n0, a)
-        return is_op('add')(n0, n1)
+        n0 = is_op("add")(a, b)
+        n1 = is_op("add")(n0, a)
+        return is_op("add")(n0, n1)
 
     out = gen()
     pat = pattern()
@@ -666,14 +678,14 @@ def test_nested_rewrite():
 
 def test_not_fuse_multi_diamond():
     # Pattern
-    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    path1 = is_op('nn.relu')(is_conv2d)
-    path2 = is_op('nn.leaky_relu')(is_conv2d)
-    diamond = is_op('add')(path1, path2)
+    is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
+    path1 = is_op("nn.relu")(is_conv2d)
+    path2 = is_op("nn.leaky_relu")(is_conv2d)
+    diamond = is_op("add")(path1, path2)
 
     # Expr
-    inp = relay.var('input')
-    weight = relay.var('weight')
+    inp = relay.var("input")
+    weight = relay.var("weight")
     conv2d = relay.op.nn.conv2d(inp, weight)
     relu = relay.op.nn.relu(conv2d)
     leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
@@ -693,8 +705,9 @@ class BatchnormCallback(DFPatternCallback):
         self.gamma = wildcard()
         self.eps = is_constant()
 
-        self.pattern = self.gamma * (self.x - self.mean) / is_op("sqrt")(self.var + self.eps) + \
-                       self.beta
+        self.pattern = (
+            self.gamma * (self.x - self.mean) / is_op("sqrt")(self.var + self.eps) + self.beta
+        )
 
     def callback(self, pre, post, node_map):
         x = node_map[self.x][0]
@@ -703,31 +716,32 @@ class BatchnormCallback(DFPatternCallback):
         beta = node_map[self.beta][0]
         gamma = node_map[self.gamma][0]
         eps = node_map[self.eps][0]
-        return relay.op.nn.batch_norm(x, gamma, beta, mean, var,
-                                      epsilon=eps.data.asnumpy().item())[0]
+        return relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=eps.data.asnumpy().item())[
+            0
+        ]
 
 
 def test_fuse_batchnorm():
-    x = relay.var('x')
-    var = relay.var('var')
-    mean = relay.var('mean')
-    beta = relay.var('beta')
-    gamma = relay.var('gamma')
+    x = relay.var("x")
+    var = relay.var("var")
+    mean = relay.var("mean")
+    beta = relay.var("beta")
+    gamma = relay.var("gamma")
 
     BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta
 
     out = rewrite(BatchnormCallback(), BN)
     assert tvm.ir.structural_equal(
-        out,
-        relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0])
+        out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]
+    )
 
 
 def test_no_fuse_batchnorm():
-    x = relay.var('x')
-    var = relay.var('var')
-    mean = relay.var('mean')
-    beta = relay.var('beta')
-    gamma = relay.var('gamma')
+    x = relay.var("x")
+    var = relay.var("var")
+    mean = relay.var("mean")
+    beta = relay.var("beta")
+    gamma = relay.var("gamma")
 
     fake_BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) - beta
 
@@ -736,11 +750,11 @@ def test_no_fuse_batchnorm():
 
 
 def test_fuse_double_batchnorm():
-    x = relay.var('x')
-    var = relay.var('var')
-    mean = relay.var('mean')
-    beta = relay.var('beta')
-    gamma = relay.var('gamma')
+    x = relay.var("x")
+    var = relay.var("var")
+    mean = relay.var("mean")
+    beta = relay.var("beta")
+    gamma = relay.var("gamma")
 
     BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta
     BN2 = gamma * (BN - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta
@@ -754,11 +768,11 @@ def test_fuse_double_batchnorm():
 
 
 def test_partial_fuse_double_batchnorm():
-    x = relay.var('x')
-    var = relay.var('var')
-    mean = relay.var('mean')
-    beta = relay.var('beta')
-    gamma = relay.var('gamma')
+    x = relay.var("x")
+    var = relay.var("var")
+    mean = relay.var("mean")
+    beta = relay.var("beta")
+    gamma = relay.var("gamma")
 
     BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) - beta
     BN2 = gamma * (BN - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta
@@ -771,32 +785,32 @@ def test_partial_fuse_double_batchnorm():
 
 
 def test_fuse_batchnorm_commutation():
-    x = relay.var('x')
-    var = relay.var('var')
-    mean = relay.var('mean')
-    beta = relay.var('beta')
-    gamma = relay.var('gamma')
+    x = relay.var("x")
+    var = relay.var("var")
+    mean = relay.var("mean")
+    beta = relay.var("beta")
+    gamma = relay.var("gamma")
 
-    #commute add
+    # commute add
     BN = beta + gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5))
     out = rewrite(BatchnormCallback(), BN)
     assert tvm.ir.structural_equal(
-        out,
-        relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0])
+        out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]
+    )
 
     # associate divide/multiply
     BN = (gamma * (x - mean)) / relay.op.sqrt(var + relay.const(1e-5)) + beta
     out = rewrite(BatchnormCallback(), BN)
     assert tvm.ir.structural_equal(
-        out,
-        relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0])
+        out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]
+    )
 
     # associate multiply/divide
     BN = gamma * ((x - mean) / relay.op.sqrt(var + relay.const(1e-5))) + beta
     out = rewrite(BatchnormCallback(), BN)
     assert tvm.ir.structural_equal(
-        out,
-        relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0])
+        out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]
+    )
 
 
 def test_quadruple_rewrite_dominator():
@@ -805,10 +819,11 @@ def test_quadruple_rewrite_dominator():
             super(DominatorRemovalCallback, self).__init__()
             self.inp = wildcard()
             self.weight = wildcard()
-            is_conv2d = is_op('nn.conv2d')(self.inp, self.weight)
+            is_conv2d = is_op("nn.conv2d")(self.inp, self.weight)
             is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(
-                wildcard()) | is_op('add')(wildcard(), wildcard())
-            reduction = is_op('add')(wildcard(), wildcard())
+                wildcard()
+            ) | is_op("add")(wildcard(), wildcard())
+            reduction = is_op("add")(wildcard(), wildcard())
             self.pattern = dominates(is_conv2d, is_unary_elemwise, reduction)
 
         def callback(self, pre, post, node_map):
@@ -816,8 +831,8 @@ def test_quadruple_rewrite_dominator():
             weight = node_map[self.weight][0]
             return relay.op.nn.conv2d(inp, weight)
 
-    inp = relay.var('input')
-    weight = relay.var('weight')
+    inp = relay.var("input")
+    weight = relay.var("weight")
     # Classic Diamond
     conv2d = relay.op.nn.conv2d(inp, weight)
     relu = relay.op.nn.relu(conv2d)
@@ -856,8 +871,8 @@ def test_quadruple_rewrite_dominator():
 
 
 def algebraic_simplify(expr):
-    zero = (is_expr(relay.const(0)) | is_expr(relay.const(0.0)))
-    one = (is_expr(relay.const(1)) | is_expr(relay.const(1.0)))
+    zero = is_expr(relay.const(0)) | is_expr(relay.const(0.0))
+    one = is_expr(relay.const(1)) | is_expr(relay.const(1.0))
 
     class ElwiseNullCallback(DFPatternCallback):
         def callback(self, pre, post, node_map):
@@ -899,19 +914,22 @@ def algebraic_simplify(expr):
             self.x = zero
             self.pattern = self.x / wildcard()
 
-    return rewrite([
-        AddCallback(),
-        SubCallback(),
-        MulCallback(),
-        DivCallback(),
-        MulZeroCallback(),
-        ZeroDivCallback()
-    ], expr)
+    return rewrite(
+        [
+            AddCallback(),
+            SubCallback(),
+            MulCallback(),
+            DivCallback(),
+            MulZeroCallback(),
+            ZeroDivCallback(),
+        ],
+        expr,
+    )
 
 
 def test_algebraic_simplify():
-    x = relay.Var('x')
-    y = relay.Var('y')
+    x = relay.Var("x")
+    y = relay.Var("y")
 
     one = relay.const(1)
     zero = relay.const(0)
@@ -938,22 +956,23 @@ def test_algebraic_simplify():
     assert algebraic_simplify(zero / x) == zero
     assert algebraic_simplify(zerof / x) == zerof
 
-    assert tvm.ir.structural_equal(algebraic_simplify((x + zero * y) / one + (y * one) - zero / x),
-                                   x + y)
+    assert tvm.ir.structural_equal(
+        algebraic_simplify((x + zero * y) / one + (y * one) - zero / x), x + y
+    )
 
 
 def test_double_partition():
     # Pattern 1
-    conv2d_p = is_op('nn.conv2d')(wildcard(), wildcard())
+    conv2d_p = is_op("nn.conv2d")(wildcard(), wildcard())
     bias_add_p = is_op("nn.bias_add")(conv2d_p, wildcard())
-    relu_p = is_op('nn.relu')(bias_add_p)
+    relu_p = is_op("nn.relu")(bias_add_p)
 
     # Graph
-    x = relay.var('input')
-    w = relay.var('weight')
-    b = relay.var('bias')
-    w2 = relay.var('weight')
-    b2 = relay.var('bias')
+    x = relay.var("input")
+    w = relay.var("weight")
+    b = relay.var("bias")
+    w2 = relay.var("weight")
+    b2 = relay.var("bias")
     conv2d = relay.op.nn.conv2d(x, w)
     bias_add = relay.op.nn.bias_add(conv2d, b)
     relu = relay.op.nn.relu(bias_add)
@@ -967,22 +986,24 @@ def test_double_partition():
     inpf = relay.var("input")
     weightf = relay.var("weight")
     biasf = relay.var("bias")
-    func0 = relay.Function(
-        [inpf, weightf, biasf],
-        relay.op.nn.relu(relay.op.nn.bias_add(
-            relay.op.nn.conv2d(inpf, weightf),
-            biasf))).with_attr("Composite",
-                               "conv_bias_relu").with_attr("PartitionedFromPattern",
-                                                           "nn.conv2d_nn.bias_add_nn.relu_")
+    func0 = (
+        relay.Function(
+            [inpf, weightf, biasf],
+            relay.op.nn.relu(relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf), biasf)),
+        )
+        .with_attr("Composite", "conv_bias_relu")
+        .with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_")
+    )
     inpf = relay.var("input")
     weightf = relay.var("weight")
     biasf = relay.var("bias")
-    func1 = relay.Function([inpf, weightf, biasf],
-                           relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf),
-                                                biasf)).with_attr("Composite",
-                                                                  "conv_bias").with_attr(
-                                                                      "PartitionedFromPattern",
-                                                                      "nn.conv2d_nn.bias_add_")
+    func1 = (
+        relay.Function(
+            [inpf, weightf, biasf], relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf), biasf)
+        )
+        .with_attr("Composite", "conv_bias")
+        .with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_")
+    )
 
     expected = func1(func0(x, w, b), w2, b2)
     assert tvm.ir.structural_equal(partitioned, expected)
@@ -990,14 +1011,14 @@ def test_double_partition():
 
 def test_partition_dominator():
     # Pattern
-    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
     is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
-    reduction = is_op('add')(wildcard(), wildcard())
+    reduction = is_op("add")(wildcard(), wildcard())
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
 
     # Classic Diamond
-    inp = relay.var('input')
-    weight = relay.var('weight')
+    inp = relay.var("input")
+    weight = relay.var("weight")
 
     def generate_diamond(inp, weight):
         conv2d = relay.op.nn.conv2d(inp, weight)
@@ -1013,20 +1034,22 @@ def test_partition_dominator():
     i = relay.Var("input")
     w = relay.Var("weight")
     f = relay.Function([i, w], generate_diamond(i, w)).with_attr(
-        "PartitionedFromPattern", "nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_")
+        "PartitionedFromPattern", "nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_"
+    )
     assert tvm.ir.structural_equal(partitioned, f(inp * inp, weight * weight))
 
 
 def test_quadruple_partition_dominator():
     # Pattern
-    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(
-        wildcard()) | is_op('add')(wildcard(), wildcard())
-    reduction = is_op('add')(wildcard(), wildcard())
+    is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op(
+        "add"
+    )(wildcard(), wildcard())
+    reduction = is_op("add")(wildcard(), wildcard())
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
 
-    inp = relay.var('input')
-    weight = relay.var('weight')
+    inp = relay.var("input")
+    weight = relay.var("weight")
 
     # Classic Diamond
     def classic_diamond(inp, weight):
@@ -1063,25 +1086,30 @@ def test_quadruple_partition_dominator():
         return tanh + leaky_relu
 
     partitioned = diamond.partition(
-        nested_diamond(single_branch(deeper_diamond(classic_diamond(inp, weight), weight), weight),
-                       weight))
+        nested_diamond(
+            single_branch(deeper_diamond(classic_diamond(inp, weight), weight), weight), weight
+        )
+    )
 
     functions = []
     partition_names = [
         "nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_",
-        "nn.conv2d_nn.relu_nn.relu_tanh_nn.leaky_relu_add_", "nn.conv2d_nn.relu_nn.relu_tanh_add_",
-        "nn.conv2d_nn.relu_add_tanh_nn.leaky_relu_add_"
+        "nn.conv2d_nn.relu_nn.relu_tanh_nn.leaky_relu_add_",
+        "nn.conv2d_nn.relu_nn.relu_tanh_add_",
+        "nn.conv2d_nn.relu_add_tanh_nn.leaky_relu_add_",
     ]
     for i, f in enumerate([classic_diamond, deeper_diamond, single_branch, nested_diamond]):
         inpf = relay.var("input")
         weightf = relay.var("weight")
         functions.append(
-            relay.Function([inpf, weightf], f(inpf,
-                                              weightf)).with_attr("PartitionedFromPattern",
-                                                                  partition_names[i]))
-
-    reference = functions[3](functions[2](functions[1](functions[0](inp, weight), weight), weight),
-                             weight)
+            relay.Function([inpf, weightf], f(inpf, weightf)).with_attr(
+                "PartitionedFromPattern", partition_names[i]
+            )
+        )
+
+    reference = functions[3](
+        functions[2](functions[1](functions[0](inp, weight), weight), weight), weight
+    )
     assert tvm.ir.structural_equal(partitioned, reference)
 
 
@@ -1090,24 +1118,23 @@ def get_BN(x, var, mean, beta, gamma, eps):
 
 
 def test_partition_batchnorm():
-    x = relay.var('x')
-    var = relay.var('var')
-    mean = relay.var('mean')
-    beta = relay.var('beta')
-    gamma = relay.var('gamma')
+    x = relay.var("x")
+    var = relay.var("var")
+    mean = relay.var("mean")
+    beta = relay.var("beta")
+    gamma = relay.var("gamma")
     eps = relay.const(1e-5)
     BN = get_BN(x, var, mean, beta, gamma, eps)
 
-    xf = relay.var('xf')
-    varf = relay.var('varf')
-    meanf = relay.var('meanf')
-    betaf = relay.var('betaf')
-    gammaf = relay.var('gammaf')
+    xf = relay.var("xf")
+    varf = relay.var("varf")
+    meanf = relay.var("meanf")
+    betaf = relay.var("betaf")
+    gammaf = relay.var("gammaf")
     # Put the arguments in topological order for the reference
-    f = relay.Function([gammaf, xf, meanf, varf, betaf],
-                       get_BN(xf, varf, meanf, betaf, gammaf,
-                              eps)).with_attr("PartitionedFromPattern",
-                                              "subtract_multiply_add_sqrt_divide_add_")
+    f = relay.Function(
+        [gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf, eps)
+    ).with_attr("PartitionedFromPattern", "subtract_multiply_add_sqrt_divide_add_")
 
     partitioned = BatchnormCallback().pattern.partition(BN)
     reference = f(gamma, x, mean, var, beta)
@@ -1115,54 +1142,53 @@ def test_partition_batchnorm():
 
 
 def test_partition_double_batchnorm():
-    x = relay.var('x')
-    var = relay.var('var')
-    mean = relay.var('mean')
-    beta = relay.var('beta')
-    gamma = relay.var('gamma')
+    x = relay.var("x")
+    var = relay.var("var")
+    mean = relay.var("mean")
+    beta = relay.var("beta")
+    gamma = relay.var("gamma")
     eps = relay.const(1e-5)
 
     BN = gamma * (x - mean) / relay.op.sqrt(var + eps) + beta
     BN2 = gamma * (BN - mean) / relay.op.sqrt(var + eps) + beta
 
-    xf = relay.var('xf')
-    varf = relay.var('varf')
-    meanf = relay.var('meanf')
-    betaf = relay.var('betaf')
-    gammaf = relay.var('gammaf')
-    f1 = relay.Function([gammaf, xf, meanf, varf, betaf],
-                        get_BN(xf, varf, meanf, betaf, gammaf,
-                               eps)).with_attr("PartitionedFromPattern",
-                                               "subtract_multiply_add_sqrt_divide_add_")
+    xf = relay.var("xf")
+    varf = relay.var("varf")
+    meanf = relay.var("meanf")
+    betaf = relay.var("betaf")
+    gammaf = relay.var("gammaf")
+    f1 = relay.Function(
+        [gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf, eps)
+    ).with_attr("PartitionedFromPattern", "subtract_multiply_add_sqrt_divide_add_")
     # The partitioner doesn't replace duplicates, so we use two copies of the function
-    xf2 = relay.var('xf2')
-    varf2 = relay.var('varf2')
-    meanf2 = relay.var('meanf2')
-    betaf2 = relay.var('betaf2')
-    gammaf2 = relay.var('gammaf2')
-    f2 = relay.Function([gammaf2, xf2, meanf2, varf2, betaf2],
-                        get_BN(xf2, varf2, meanf2, betaf2, gammaf2,
-                               eps)).with_attr("PartitionedFromPattern",
-                                               "subtract_multiply_add_sqrt_divide_add_")
+    xf2 = relay.var("xf2")
+    varf2 = relay.var("varf2")
+    meanf2 = relay.var("meanf2")
+    betaf2 = relay.var("betaf2")
+    gammaf2 = relay.var("gammaf2")
+    f2 = relay.Function(
+        [gammaf2, xf2, meanf2, varf2, betaf2], get_BN(xf2, varf2, meanf2, betaf2, gammaf2, eps)
+    ).with_attr("PartitionedFromPattern", "subtract_multiply_add_sqrt_divide_add_")
 
     partitioned = BatchnormCallback().pattern.partition(BN2)
     reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
     assert tvm.ir.structural_equal(partitioned, reference)
 
+
 def test_overlappting_partitions():
     x = wildcard()
     gamma = wildcard()
     beta = wildcard()
     moving_mean = wildcard()
     moving_var = wildcard()
-    bn_node = is_op('nn.batch_norm')(x, gamma, beta, moving_mean, moving_var)
+    bn_node = is_op("nn.batch_norm")(x, gamma, beta, moving_mean, moving_var)
     tuple_get_item_node = TupleGetItemPattern(bn_node, 0)
 
-    x = relay.var('x')
-    var = relay.var('var')
-    mean = relay.var('mean')
-    beta = relay.var('beta')
-    gamma = relay.var('gamma')
+    x = relay.var("x")
+    var = relay.var("var")
+    mean = relay.var("mean")
+    beta = relay.var("beta")
+    gamma = relay.var("gamma")
     BN = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)
     T1 = BN[0]
     T2 = BN[0]
@@ -1170,34 +1196,35 @@ def test_overlappting_partitions():
 
     assert tuple_get_item_node.partition(add) == add
 
+
 def test_partition_overused():
     pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))
 
-    x = relay.var('input')
-    w = relay.var('weight')
+    x = relay.var("input")
+    w = relay.var("weight")
     conv2d = relay.op.nn.conv2d(x, w)
     relu = relay.op.nn.relu(conv2d)
     out = relu + conv2d
-    
+
     assert pattern.partition(out) == out
 
+
 def test_partition_check():
     pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))
 
     def check(pre):
         return pre.args[0].attrs.data_layout == "NCHW"
 
-    x = relay.var('input')
-    w = relay.var('weight')
+    x = relay.var("input")
+    w = relay.var("weight")
     conv2d = relay.op.nn.conv2d(x, w)
     relu = relay.op.nn.relu(conv2d)
 
-    xf = relay.var('input')
-    wf = relay.var('weight')
+    xf = relay.var("input")
+    wf = relay.var("weight")
     conv2df = relay.op.nn.conv2d(xf, wf)
     reluf = relay.op.nn.relu(conv2df)
-    func = relay.Function([xf, wf], reluf).with_attr("PartitionedFromPattern",
-                                                     "nn.conv2d_nn.relu_")
+    func = relay.Function([xf, wf], reluf).with_attr("PartitionedFromPattern", "nn.conv2d_nn.relu_")
 
     reference = func(x, w)
     partitioned = pattern.partition(relu, check=check)
@@ -1215,8 +1242,8 @@ def test_partition_check_types():
         conv = pre.args[0]
         return (conv.attrs.data_layout == "NCHW") and bool(conv.checked_type.shape[0] == 1)
 
-    x = relay.var('input', shape=(1, 10, 10, 10))
-    w = relay.var('weight', shape=(10, 10, 3, 3))
+    x = relay.var("input", shape=(1, 10, 10, 10))
+    w = relay.var("weight", shape=(10, 10, 3, 3))
     conv2d = relay.op.nn.conv2d(x, w)
     relu = relay.op.nn.relu(conv2d)
     relu = run_opt_pass(relu, relay.transform.InferType())
@@ -1229,8 +1256,8 @@ def test_partition_check_types():
     relu = run_opt_pass(relu, relay.transform.InferType())
     assert relu == pattern.partition(relu, check=check)
 
-    x = relay.var('input', shape=(2, 10, 10, 10))
-    w = relay.var('weight', shape=(10, 10, 3, 3))
+    x = relay.var("input", shape=(2, 10, 10, 10))
+    w = relay.var("weight", shape=(10, 10, 3, 3))
     conv2d = relay.op.nn.conv2d(x, w)
     relu = relay.op.nn.relu(conv2d)
     relu = run_opt_pass(relu, relay.transform.InferType())
@@ -1245,26 +1272,26 @@ def conv_bias_relu(x, w, b):
 
 
 def test_partition_option():
-    x = relay.var('x')
-    w = relay.var('w')
-    b = relay.var('b')
+    x = relay.var("x")
+    w = relay.var("w")
+    b = relay.var("b")
 
-    conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    bias = conv2d.optional(lambda x: is_op('nn.bias_add')(x, wildcard()))
-    pattern1 = is_op('nn.relu')(bias)
+    conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
+    bias = conv2d.optional(lambda x: is_op("nn.bias_add")(x, wildcard()))
+    pattern1 = is_op("nn.relu")(bias)
 
-    conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    bias = is_op('nn.bias_add')(conv2d, wildcard())
-    pattern2 = bias.optional(lambda x: is_op('nn.relu')(x))
+    conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
+    bias = is_op("nn.bias_add")(conv2d, wildcard())
+    pattern2 = bias.optional(lambda x: is_op("nn.relu")(x))
 
     relu = conv_bias_relu(x, w, b)
 
-    xf = relay.var('x')
-    wf = relay.var('w')
-    bf = relay.var('b')
-    func = relay.Function([xf, wf, bf],
-                          conv_bias_relu(xf, wf, bf)).with_attr("PartitionedFromPattern",
-                                                                "nn.conv2d_nn.bias_add_nn.relu_")
+    xf = relay.var("x")
+    wf = relay.var("w")
+    bf = relay.var("b")
+    func = relay.Function([xf, wf, bf], conv_bias_relu(xf, wf, bf)).with_attr(
+        "PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_"
+    )
 
     assert pattern1.match(relu)
     assert tvm.ir.structural_equal(func(x, w, b), pattern1.partition(relu))
@@ -1272,76 +1299,85 @@ def test_partition_option():
     assert pattern2.match(relu)
     assert tvm.ir.structural_equal(func(x, w, b), pattern2.partition(relu))
 
+
 def test_match_match():
-    add_pattern = is_op('add')(wildcard(), wildcard())
+    add_pattern = is_op("add")(wildcard(), wildcard())
+
     class TestRewrite(DFPatternCallback):
         def __init__(self):
             super(TestRewrite, self).__init__()
             self.pattern = add_pattern
+
         def callback(self, pre, post, node_map):
             return post.args[0] - post.args[1]
+
     mod = tvm.IRModule({})
     tvm.relay.prelude.Prelude(mod)
     # Apply rewrite on IR including relay.Match
-    out = rewrite(TestRewrite(), mod['tensor_concatenate_int64'])
-    assert tvm.ir.structural_equal(mod['tensor_concatenate_int64'], out)
+    out = rewrite(TestRewrite(), mod["tensor_concatenate_int64"])
+    assert tvm.ir.structural_equal(mod["tensor_concatenate_int64"], out)
+
 
 def test_partition_constant_embedding():
-    x = relay.var('x')
-    w = relay.var('w')
+    x = relay.var("x")
+    w = relay.var("w")
     wc = relay.const(1)
-    b = relay.var('b')
-
-    xf = relay.var('x')
-    wf = relay.var('w')
-    bf = relay.var('b')
-    embeded_func = relay.Function([xf, bf],
-                                  conv_bias_relu(xf, wc,
-                                                 bf)).with_attr("PartitionedFromPattern",
-                                                                "nn.conv2d_nn.bias_add_nn.relu_")
-    xf = relay.var('x')
-    wf = relay.var('w')
-    bf = relay.var('b')
-    lifted_func = relay.Function([xf, wf, bf],
-                                 conv_bias_relu(xf, wf,
-                                                bf)).with_attr("PartitionedFromPattern",
-                                                               "nn.conv2d_nn.bias_add_nn.relu_")
+    b = relay.var("b")
+
+    xf = relay.var("x")
+    wf = relay.var("w")
+    bf = relay.var("b")
+    embeded_func = relay.Function([xf, bf], conv_bias_relu(xf, wc, bf)).with_attr(
+        "PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_"
+    )
+    xf = relay.var("x")
+    wf = relay.var("w")
+    bf = relay.var("b")
+    lifted_func = relay.Function([xf, wf, bf], conv_bias_relu(xf, wf, bf)).with_attr(
+        "PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_"
+    )
     relu = conv_bias_relu(x, w, b)
     reluc = conv_bias_relu(x, wc, b)
 
     # Check lifting of wildcard matches
-    pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), wildcard()),
-                                                    wildcard()))
+    pattern = is_op("nn.relu")(
+        is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), wildcard()), wildcard())
+    )
     assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
     assert tvm.ir.structural_equal(lifted_func(x, wc, b), pattern.partition(reluc))
 
     # Check lifting of input matches
-    pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_var()),
-                                                    wildcard()))
+    pattern = is_op("nn.relu")(
+        is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_var()), wildcard())
+    )
     assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
-    assert tvm.ir.structural_equal(reluc, pattern.partition(reluc))  #Constants are not Inputs
+    assert tvm.ir.structural_equal(reluc, pattern.partition(reluc))  # Constants are not Inputs
 
     # Check embedding of constant matches
-    pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_constant()),
-                                                    wildcard()))
+    pattern = is_op("nn.relu")(
+        is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_constant()), wildcard())
+    )
     assert tvm.ir.structural_equal(relu, pattern.partition(relu))
     assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))
 
     # Check embedding of constant ExprPatterns
-    pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_expr(wc)),
-                                                    wildcard()))
+    pattern = is_op("nn.relu")(
+        is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_expr(wc)), wildcard())
+    )
     assert tvm.ir.structural_equal(relu, pattern.partition(relu))
     assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))
 
     # Check lifting/embedding of Alt matches
-    pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(
-        wildcard(), is_var() | is_constant()), wildcard()))
+    pattern = is_op("nn.relu")(
+        is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_var() | is_constant()), wildcard())
+    )
     assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
     assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))
 
     # Check lifting/embedding of Alt matches with the other ordering
-    pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(
-        wildcard(), is_constant() | is_var()), wildcard()))
+    pattern = is_op("nn.relu")(
+        is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_constant() | is_var()), wildcard())
+    )
     assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
     assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))
 
index fc21f68..60feaf1 100644 (file)
@@ -20,16 +20,19 @@ from tvm.relay.op import debug
 
 _test_debug_hit = False
 
+
 def test_debug():
     global _test_debug_hit
     ex = create_executor()
-    x = var('x', shape=(), dtype='int32')
+    x = var("x", shape=(), dtype="int32")
     _test_debug_hit = False
+
     def did_exec(x):
         global _test_debug_hit
         _test_debug_hit = True
+
     prog = debug(x, debug_func=did_exec)
-    result = ex.evaluate(prog, { x: const(1, 'int32') })
+    result = ex.evaluate(prog, {x: const(1, "int32")})
     assert _test_debug_hit
     assert result.asnumpy() == 1
 
@@ -38,12 +41,14 @@ def test_debug_with_expr():
     global _test_debug_hit
     _test_debug_hit = False
     ex = create_executor()
-    x = var('x', shape=(), dtype='int32')
+    x = var("x", shape=(), dtype="int32")
     _test_debug_hit = False
+
     def did_exec(x):
         global _test_debug_hit
         _test_debug_hit = True
+
     prog = debug(x + x * x, debug_func=did_exec)
-    result = ex.evaluate(prog, { x: const(2, 'int32') })
+    result = ex.evaluate(prog, {x: const(2, "int32")})
     assert _test_debug_hit
     assert result.asnumpy() == 6
index d697448..fc5c743 100644 (file)
@@ -18,6 +18,7 @@ import tvm
 from tvm import te
 from tvm import relay
 
+
 def check_type_err(expr, msg):
     try:
         mod = tvm.IRModule.from_expr(expr)
@@ -28,33 +29,37 @@ def check_type_err(expr, msg):
     except tvm.error.TVMError as err:
         assert msg in str(err)
 
+
 def test_wellformed():
-    x = relay.var('x', shape=(10, 10))
+    x = relay.var("x", shape=(10, 10))
     f = relay.Function([x], x)
-    check_type_err(
-        f(x),
-        "Check failed: WellFormed")
+    check_type_err(f(x), "Check failed: WellFormed")
+
 
 def test_too_many_args():
-    x = relay.var('x', shape=(10, 10))
+    x = relay.var("x", shape=(10, 10))
     f = relay.Function([x], x)
-    y = relay.var('y', shape=(10, 10))
-    check_type_err(
-        f(y, y),
-        "the function is provided too many arguments expected 1, found 2;")
+    y = relay.var("y", shape=(10, 10))
+    check_type_err(f(y, y), "the function is provided too many arguments expected 1, found 2;")
+
 
 def test_too_few_args():
-    x = relay.var('x', shape=(10, 10))
-    y = relay.var('y', shape=(10, 10))
-    z = relay.var('z', shape=(10, 10))
+    x = relay.var("x", shape=(10, 10))
+    y = relay.var("y", shape=(10, 10))
+    z = relay.var("z", shape=(10, 10))
     f = relay.Function([x, y], x)
     check_type_err(f(z), "the function is provided too few arguments expected 2, found 1;")
 
+
 def test_rel_fail():
-    x = relay.var('x', shape=(10, 10))
-    y = relay.var('y', shape=(11, 10))
+    x = relay.var("x", shape=(10, 10))
+    y = relay.var("y", shape=(11, 10))
     f = relay.Function([x, y], x + y)
-    check_type_err(f, "Incompatible broadcast type TensorType([10, 10], float32) and TensorType([11, 10], float32);")
+    check_type_err(
+        f,
+        "Incompatible broadcast type TensorType([10, 10], float32) and TensorType([11, 10], float32);",
+    )
+
 
 if __name__ == "__main__":
     test_wellformed()
index ea7f8f6..f8ae7a9 100644 (file)
@@ -19,6 +19,7 @@ from tvm import te
 from tvm import relay
 from tvm.relay import ExprFunctor, ExprMutator, ExprVisitor
 
+
 def check_visit(expr):
     try:
         ef = ExprFunctor()
@@ -39,47 +40,41 @@ def test_constant():
 
 
 def test_tuple():
-    t = relay.Tuple([relay.var('x', shape=())])
+    t = relay.Tuple([relay.var("x", shape=())])
     check_visit(t)
 
 
 def test_var():
-    v = relay.var('x', shape=())
+    v = relay.var("x", shape=())
     check_visit(v)
 
 
 def test_global():
-    v = relay.GlobalVar('f')
+    v = relay.GlobalVar("f")
     check_visit(v)
 
 
 def test_function():
-    x = relay.var('x', shape=())
-    y = relay.var('y', shape=())
+    x = relay.var("x", shape=())
+    y = relay.var("y", shape=())
     params = [x, y]
     body = x + y
     ret_type = relay.TensorType(())
     type_params = []
-    attrs = None # How to build?
-    f = relay.Function(
-        params,
-        body,
-        ret_type,
-        type_params,
-        attrs
-    )
+    attrs = None  # How to build?
+    f = relay.Function(params, body, ret_type, type_params, attrs)
     check_visit(f)
 
 
 def test_call():
-    x = relay.var('x', shape=())
-    y = relay.var('y', shape=())
+    x = relay.var("x", shape=())
+    y = relay.var("y", shape=())
     call = relay.op.add(x, y)
     check_visit(call)
 
 
 def test_let():
-    x = relay.var('x', shape=())
+    x = relay.var("x", shape=())
     value = relay.const(2.0)
     body = x + x
     l = relay.Let(x, value, body)
@@ -87,13 +82,13 @@ def test_let():
 
 
 def test_ite():
-    cond = relay.var('x', shape=(), dtype='bool')
+    cond = relay.var("x", shape=(), dtype="bool")
     ite = relay.If(cond, cond, cond)
     check_visit(ite)
 
 
 def test_get_item():
-    t = relay.Tuple([relay.var('x', shape=())])
+    t = relay.Tuple([relay.var("x", shape=())])
     t = relay.TupleGetItem(t, 0)
     check_visit(t)
 
index 216d23e..c919e7c 100644 (file)
@@ -27,8 +27,8 @@ from tvm import relay
 from tvm import runtime
 from tvm.contrib import util
 
-def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
-                 ctx=tvm.cpu()):
+
+def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ctx=tvm.cpu()):
     if sys.platform == "win32":
         print("Skip test on Windows for now")
         return
@@ -41,7 +41,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
         kwargs = {}
         kwargs["options"] = ["-O2", "-std=c++14", "-I" + contrib_path]
         tmp_path = util.tempdir()
-        lib_name = 'lib.so'
+        lib_name = "lib.so"
         lib_path = tmp_path.relpath(lib_name)
         lib.export_library(lib_path, fcompile=False, **kwargs)
         lib = tvm.runtime.load_module(lib_path)
@@ -49,8 +49,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
         return lib
 
     def check_vm_result():
-        with tvm.transform.PassContext(opt_level=3,
-                                       disabled_pass=["AlterOpLayout"]):
+        with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
             exe = relay.vm.compile(mod, target=target)
         code, lib = exe.save()
         lib = update_lib(lib)
@@ -60,8 +59,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
         tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
 
     def check_graph_runtime_result():
-        with tvm.transform.PassContext(opt_level=3,
-                                       disabled_pass=["AlterOpLayout"]):
+        with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
             json, lib, _ = relay.build(mod, target=target)
         lib = update_lib(lib)
         rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
@@ -86,21 +84,21 @@ def set_external_func_attr(func, compiler, ext_symbol):
 
 
 def test_multi_node_subgraph():
-    x = relay.var('x', shape=(10, 10))
-    w0 = relay.var('w0', shape=(10, 10))
-    w1 = relay.var('w1', shape=(10, 10))
-    w2 = relay.var('w2', shape=(10, 10))
-    w3 = relay.var('w3', shape=(10, 10))
-    w4 = relay.var('w4', shape=(10, 10))
-    w5 = relay.var('w5', shape=(10, 10))
-    w6 = relay.var('w6', shape=(10, 10))
-    w7 = relay.var('w7', shape=(10, 10))
+    x = relay.var("x", shape=(10, 10))
+    w0 = relay.var("w0", shape=(10, 10))
+    w1 = relay.var("w1", shape=(10, 10))
+    w2 = relay.var("w2", shape=(10, 10))
+    w3 = relay.var("w3", shape=(10, 10))
+    w4 = relay.var("w4", shape=(10, 10))
+    w5 = relay.var("w5", shape=(10, 10))
+    w6 = relay.var("w6", shape=(10, 10))
+    w7 = relay.var("w7", shape=(10, 10))
 
     # subgraph0
-    x0 = relay.var('x0', shape=(10, 10))
-    w00 = relay.var('w00', shape=(10, 10))
-    w01 = relay.var('w01', shape=(10, 10))
-    w02 = relay.var('w02', shape=(10, 10))
+    x0 = relay.var("x0", shape=(10, 10))
+    w00 = relay.var("w00", shape=(10, 10))
+    w01 = relay.var("w01", shape=(10, 10))
+    w02 = relay.var("w02", shape=(10, 10))
     z00 = relay.add(x0, w00)
     p00 = relay.subtract(z00, w01)
     q00 = relay.multiply(p00, w02)
@@ -109,10 +107,10 @@ def test_multi_node_subgraph():
     call0 = relay.Call(subgraph0, [x, w0, w1, w2])
 
     # subgraph1
-    x1 = relay.var('x1', shape=(10, 10))
-    w10 = relay.var('w10', shape=(10, 10))
-    w11 = relay.var('w11', shape=(10, 10))
-    w12 = relay.var('w12', shape=(10, 10))
+    x1 = relay.var("x1", shape=(10, 10))
+    w10 = relay.var("w10", shape=(10, 10))
+    w11 = relay.var("w11", shape=(10, 10))
+    w12 = relay.var("w12", shape=(10, 10))
     z10 = relay.add(x1, w10)
     p10 = relay.subtract(z10, w11)
     q10 = relay.multiply(p10, w12)
@@ -120,7 +118,6 @@ def test_multi_node_subgraph():
     subgraph1 = set_external_func_attr(subgraph1, "ccompiler", "ccompiler_1")
     call1 = relay.Call(subgraph1, [x, w3, w4, w5])
 
-
     # Other parts on TVM
     z2 = relay.add(x, w6)
     q2 = relay.subtract(z2, w7)
@@ -131,86 +128,93 @@ def test_multi_node_subgraph():
     mod["main"] = f
     mod = relay.transform.InferType()(mod)
 
-    x_data = np.random.rand(10, 10).astype('float32')
+    x_data = np.random.rand(10, 10).astype("float32")
     w_data = []
     for _ in range(8):
-        w_data.append(np.random.rand(10, 10).astype('float32'))
+        w_data.append(np.random.rand(10, 10).astype("float32"))
 
     map_inputs = {"w{}".format(i): w_data[i] for i in range(8)}
     map_inputs["x"] = x_data
     check_result(
-        mod, map_inputs, (30, 10),
-        np.concatenate((((x_data + w_data[0]) - w_data[1]) * w_data[2],
-                        ((x_data + w_data[3]) - w_data[4]) * w_data[5],
-                        x_data + w_data[6] - w_data[7]),
-                       axis=0))
+        mod,
+        map_inputs,
+        (30, 10),
+        np.concatenate(
+            (
+                ((x_data + w_data[0]) - w_data[1]) * w_data[2],
+                ((x_data + w_data[3]) - w_data[4]) * w_data[5],
+                x_data + w_data[6] - w_data[7],
+            ),
+            axis=0,
+        ),
+    )
 
 
 def test_extern_gcc_single_op():
-    x = relay.var('x', shape=(8, 8))
-    y = relay.var('y', shape=(8, 8))
+    x = relay.var("x", shape=(8, 8))
+    y = relay.var("y", shape=(8, 8))
 
-    x0 = relay.var('x0', shape=(8, 8))
-    y0 = relay.var('y0', shape=(8, 8))
+    x0 = relay.var("x0", shape=(8, 8))
+    y0 = relay.var("y0", shape=(8, 8))
     z = x0 + y0
     f = relay.Function([x0, y0], z)
     f = set_external_func_attr(f, "ccompiler", "ccompiler_0")
     call = relay.Call(f, [x, y])
     mod = tvm.IRModule.from_expr(call)
-    x_data = np.random.rand(8, 8).astype('float32')
-    y_data = np.random.rand(8, 8).astype('float32')
+    x_data = np.random.rand(8, 8).astype("float32")
+    y_data = np.random.rand(8, 8).astype("float32")
 
     check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data)
 
 
 def test_extern_gcc_single_op_int():
-    x = relay.var('x', shape=(8, 8), dtype="int32")
-    y = relay.var('y', shape=(8, 8), dtype="int32")
+    x = relay.var("x", shape=(8, 8), dtype="int32")
+    y = relay.var("y", shape=(8, 8), dtype="int32")
 
-    x0 = relay.var('x0', shape=(8, 8), dtype="int32")
-    y0 = relay.var('y0', shape=(8, 8), dtype="int32")
+    x0 = relay.var("x0", shape=(8, 8), dtype="int32")
+    y0 = relay.var("y0", shape=(8, 8), dtype="int32")
     z = x0 + y0
     f = relay.Function([x0, y0], z)
     f = set_external_func_attr(f, "ccompiler", "ccompiler_0")
     call = relay.Call(f, [x, y])
     mod = tvm.IRModule.from_expr(call)
-    x_data = np.random.rand(8, 8).astype('int32')
-    y_data = np.random.rand(8, 8).astype('int32')
+    x_data = np.random.rand(8, 8).astype("int32")
+    y_data = np.random.rand(8, 8).astype("int32")
 
     check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data)
 
 
 def test_extern_gcc():
-    x = relay.var('x', shape=(2, 2))
-    y = relay.var('y', shape=(2, 2))
+    x = relay.var("x", shape=(2, 2))
+    y = relay.var("y", shape=(2, 2))
 
     # subgraph for mul
-    x0 = relay.var('x0', shape=(2, 2))
-    y0 = relay.var('y0', shape=(2, 2))
+    x0 = relay.var("x0", shape=(2, 2))
+    y0 = relay.var("y0", shape=(2, 2))
     mul = x0 * y0
     mul = relay.Function([x0, y0], mul)
     mul = set_external_func_attr(mul, "ccompiler", "ccompiler_2")
     call_mul = relay.Call(mul, [y, y])
 
     # subgraph for add
-    x1 = relay.var('x1', shape=(2, 2))
-    y1 = relay.var('y1', shape=(2, 2))
+    x1 = relay.var("x1", shape=(2, 2))
+    y1 = relay.var("y1", shape=(2, 2))
     add = x1 + y1
     add = relay.Function([x1, y1], add)
     add = set_external_func_attr(add, "ccompiler", "ccompiler_1")
     call_add = relay.Call(add, [x, x])
 
     # subgraph for sub
-    x2 = relay.var('x2', shape=(2, 2))
-    y2 = relay.var('y2', shape=(2, 2))
+    x2 = relay.var("x2", shape=(2, 2))
+    y2 = relay.var("y2", shape=(2, 2))
     sub = x2 - y2
     sub = relay.Function([x2, y2], sub)
     sub = set_external_func_attr(sub, "ccompiler", "ccompiler_0")
     call_sub = relay.Call(sub, [call_mul, call_add])
     mod = tvm.IRModule.from_expr(call_sub)
 
-    x_data = np.random.rand(2, 2).astype('float32')
-    y_data = np.random.rand(2, 2).astype('float32')
+    x_data = np.random.rand(2, 2).astype("float32")
+    y_data = np.random.rand(2, 2).astype("float32")
 
     check_result(mod, {"x": x_data, "y": y_data}, (2, 2), (y_data * y_data) - (x_data + x_data))
 
@@ -220,30 +224,26 @@ def test_extern_dnnl():
         print("skip because DNNL codegen is not available")
         return
 
-    dtype = 'float32'
+    dtype = "float32"
     ishape = (1, 32, 14, 14)
     w1shape = (32, 1, 3, 3)
-    data0 = relay.var('data0', shape=(ishape), dtype=dtype)
-    weight0 = relay.var('weight0', shape=(w1shape), dtype=dtype)
-
-    data1 = relay.var('data0', shape=(ishape), dtype=dtype)
-    weight1 = relay.var('weight0', shape=(w1shape), dtype=dtype)
-    weight2 = relay.var('weight1', shape=(w1shape), dtype=dtype)
-    depthwise_conv2d_1 = relay.nn.conv2d(data1,
-                                         weight1,
-                                         kernel_size=(3, 3),
-                                         padding=(1, 1),
-                                         groups=32)
-    depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
-                                         weight2,
-                                         kernel_size=(3, 3),
-                                         padding=(1, 1),
-                                         groups=32)
+    data0 = relay.var("data0", shape=(ishape), dtype=dtype)
+    weight0 = relay.var("weight0", shape=(w1shape), dtype=dtype)
+
+    data1 = relay.var("data0", shape=(ishape), dtype=dtype)
+    weight1 = relay.var("weight0", shape=(w1shape), dtype=dtype)
+    weight2 = relay.var("weight1", shape=(w1shape), dtype=dtype)
+    depthwise_conv2d_1 = relay.nn.conv2d(
+        data1, weight1, kernel_size=(3, 3), padding=(1, 1), groups=32
+    )
+    depthwise_conv2d_2 = relay.nn.conv2d(
+        depthwise_conv2d_1, weight2, kernel_size=(3, 3), padding=(1, 1), groups=32
+    )
     out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
 
     f = relay.Function([data1, weight1, weight2], out)
     ref_mod = tvm.IRModule()
-    ref_mod['main'] = f
+    ref_mod["main"] = f
 
     f = set_external_func_attr(f, "dnnl", "dnnl_0")
     call = relay.Call(f, [data0, weight0, weight0])
@@ -254,8 +254,9 @@ def test_extern_dnnl():
 
     ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu())
     ref_res = ref_ex.evaluate()(i_data, w_data, w_data)
-    check_result(mod, {"data0": i_data, "weight0": w_data},
-                 (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
+    check_result(
+        mod, {"data0": i_data, "weight0": w_data}, (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5
+    )
 
 
 def test_extern_dnnl_const():
@@ -263,30 +264,26 @@ def test_extern_dnnl_const():
         print("skip because DNNL codegen is not available")
         return
 
-    dtype = 'float32'
+    dtype = "float32"
     ishape = (1, 32, 14, 14)
     w1shape = (32, 1, 3, 3)
-    data0 = relay.var('data0', shape=(ishape), dtype=dtype)
+    data0 = relay.var("data0", shape=(ishape), dtype=dtype)
     w_data = np.random.uniform(0, 1, w1shape).astype(dtype)
 
-    data1 = relay.var('data0', shape=(ishape), dtype=dtype)
+    data1 = relay.var("data0", shape=(ishape), dtype=dtype)
     weight1 = relay.const(w_data, dtype=dtype)
     weight2 = relay.const(w_data, dtype=dtype)
-    depthwise_conv2d_1 = relay.nn.conv2d(data1,
-                                         weight1,
-                                         kernel_size=(3, 3),
-                                         padding=(1, 1),
-                                         groups=32)
-    depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
-                                         weight2,
-                                         kernel_size=(3, 3),
-                                         padding=(1, 1),
-                                         groups=32)
+    depthwise_conv2d_1 = relay.nn.conv2d(
+        data1, weight1, kernel_size=(3, 3), padding=(1, 1), groups=32
+    )
+    depthwise_conv2d_2 = relay.nn.conv2d(
+        depthwise_conv2d_1, weight2, kernel_size=(3, 3), padding=(1, 1), groups=32
+    )
     out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
 
     f = relay.Function([data1], out)
     ref_mod = tvm.IRModule()
-    ref_mod['main'] = f
+    ref_mod["main"] = f
 
     f = set_external_func_attr(f, "dnnl", "dnnl_0")
     call = relay.Call(f, [data0])
@@ -296,8 +293,7 @@ def test_extern_dnnl_const():
 
     ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu())
     ref_res = ref_ex.evaluate()(i_data)
-    check_result(mod, {"data0": i_data},
-                 (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
+    check_result(mod, {"data0": i_data}, (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
 
 
 if __name__ == "__main__":
index 8ba4644..b179096 100644 (file)
@@ -25,10 +25,8 @@ def test_bind_params():
     y = relay.var("y")
     z = relay.add(x, y)
     f = relay.Function([x, y], z)
-    fbinded = relay.bind(f, {x : relay.const(1, "float32")})
-    fexpected =relay.Function(
-        [y],
-        relay.add(relay.const(1, "float32"),  y))
+    fbinded = relay.bind(f, {x: relay.const(1, "float32")})
+    fexpected = relay.Function([y], relay.add(relay.const(1, "float32"), y))
     assert tvm.ir.structural_equal(fbinded, fexpected)
 
     zbinded = relay.bind(z, {y: x})
index bab8247..d3f8f2c 100644 (file)
@@ -21,6 +21,7 @@ from tvm import relay
 from tvm.relay.prelude import Prelude
 from tvm.relay.testing import add_nat_definitions
 
+
 def constructor_list(p):
     return [p.nil, p.cons, p.rose, p.some, p.none, p.z, p.s]
 
index b53423a..1ba39b0 100644 (file)
@@ -23,6 +23,7 @@ from tvm.tir.expr import *
 from tvm.relay import op
 import numpy as np
 
+
 def check_json_roundtrip(node):
     json_str = tvm.ir.save_json(node)
     back = tvm.ir.load_json(json_str)
@@ -71,7 +72,7 @@ def test_tuple():
 
 
 def test_local_var():
-    name_hint = 's'
+    name_hint = "s"
     lv = relay.Var(name_hint)
     assert lv.name_hint == name_hint
     assert lv.type_annotation is None
@@ -86,15 +87,16 @@ def test_local_var():
 
 
 def test_global_var():
-    name_hint = 'g'
+    name_hint = "g"
     gv = relay.GlobalVar(name_hint)
     assert gv.name_hint == name_hint
     # assert lv.span == None todo(@jroesch): what do we do about spans
     str(gv)
     check_json_roundtrip(gv)
 
+
 def test_function():
-    param_names = ['a', 'b', 'c', 'd']
+    param_names = ["a", "b", "c", "d"]
     params = tvm.runtime.convert([relay.Var(n) for n in param_names])
     ret_type = relay.TupleType(tvm.runtime.convert([]))
     body = relay.Tuple(tvm.runtime.convert([]))
@@ -113,7 +115,7 @@ def test_function():
 
 
 def test_function_attrs():
-    param_names = ['a', 'b', 'c', 'd']
+    param_names = ["a", "b", "c", "d"]
     params = tvm.runtime.convert([relay.var(n, shape=(5, 2)) for n in param_names])
     ret_type = relay.TupleType(tvm.runtime.convert([]))
     body = relay.Tuple(tvm.runtime.convert([]))
@@ -143,9 +145,10 @@ def test_function_attrs():
         p2 = model_params_after[key2]
         np.testing.assert_allclose(p1.data.asnumpy(), p2.data.asnumpy())
 
+
 def test_call():
-    op = relay.Var('f')
-    arg_names = ['a', 'b', 'c', 'd']
+    op = relay.Var("f")
+    arg_names = ["a", "b", "c", "d"]
     args = tvm.runtime.convert([relay.Var(n) for n in arg_names])
     call = relay.Call(op, args, None, None)
     assert call.op == op
@@ -156,7 +159,7 @@ def test_call():
 
 
 def test_let():
-    lv = relay.Var('x')
+    lv = relay.Var("x")
     ty = None
     arr = tvm.nd.array(10)
     value = relay.Constant(arr)
@@ -172,9 +175,9 @@ def test_let():
 
 
 def test_if():
-    cond = relay.Var('cond')
-    left = relay.Var('left')
-    right = relay.Var('right')
+    cond = relay.Var("cond")
+    left = relay.Var("left")
+    right = relay.Var("right")
     ife = relay.If(cond, left, right)
     assert ife.cond == cond
     assert ife.true_branch == left
@@ -199,15 +202,9 @@ def test_op():
 
 
 def test_conv2d_attrs():
-    data = relay.var('data', shape=(1, 3, 224, 224))
-    param = relay.var('param', shape=(64, 3, 7, 7))
-    out = op.nn.conv2d(
-        data,
-        param,
-        strides=(2, 2),
-        padding=(3, 3),
-        channels=64,
-        kernel_size=(7, 7))
+    data = relay.var("data", shape=(1, 3, 224, 224))
+    param = relay.var("param", shape=(64, 3, 7, 7))
+    out = op.nn.conv2d(data, param, strides=(2, 2), padding=(3, 3), channels=64, kernel_size=(7, 7))
     check_json_roundtrip(out)
 
 
index 46e4b02..34c0000 100644 (file)
@@ -18,6 +18,7 @@ import tvm
 from tvm import relay
 from tvm.relay.testing.temp_op_attr import TempOpAttr
 
+
 def test_op_attr():
     log_op = relay.op.get("log")
 
@@ -25,12 +26,14 @@ def test_op_attr():
     def test(x):
         return x + 1
 
-    assert log_op.num_inputs  == 1
+    assert log_op.num_inputs == 1
     assert log_op.get_attr("ftest") is None
     assert relay.op.get("exp").get_attr("ftest")(1) == 2
 
+
 def test_op_reset_attr():
     """ Tests reset_attr functionality. """
+
     def add1(x):
         return x + 1
 
@@ -55,8 +58,10 @@ def test_op_reset_attr():
     # Check that other attrs of the log op are intact.
     assert relay.op.get("log").get_attr("fadd2")(1) == 3
 
+
 def test_op_temp_attr():
     """ Tests reset_attr functionality. """
+
     def add1(x):
         return x + 1
 
@@ -73,15 +78,17 @@ def test_op_temp_attr():
     # Check that the attr value is recovered to add1.
     assert relay.op.get("sqrt").get_attr("ftest")(1) == 2
 
+
 def test_op_level1():
     x = relay.Var("x")
 
-    for op_name in ["log", "exp", "sqrt", "rsqrt","tanh"]:
+    for op_name in ["log", "exp", "sqrt", "rsqrt", "tanh"]:
         y = getattr(relay, op_name)(x)
         assert y.op.name == op_name
         assert y.op.support_level == 1
         assert y.args[0] == x
 
+
 def test_op_level3():
     x = relay.Var("x")
 
@@ -91,6 +98,7 @@ def test_op_level3():
         assert y.op.support_level == 3
         assert y.args[0] == x
 
+
 if __name__ == "__main__":
     test_op_attr()
     test_op_reset_attr()
index 6d581b6..a95ae7a 100644 (file)
@@ -24,8 +24,7 @@ from typing import Union
 from functools import wraps
 
 
-
-SEMVER = "#[version = \"0.0.5\"]\n"
+SEMVER = '#[version = "0.0.5"]\n'
 
 BINARY_OPS = {
     "*": relay.multiply,
@@ -45,18 +44,14 @@ TYPES = {
     "int16",
     "int32",
     "int64",
-
     "uint8",
     "uint16",
     "uint32",
     "uint64",
-
     "float16",
     "float32",
     "float64",
-
     "bool",
-
     "int8x4",
     "uint1x4",
     "float16x4",
@@ -69,51 +64,62 @@ type List[A] {
 }
 """
 
+
 def assert_graph_equal(lhs, rhs):
     tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars=True)
 
+
 def graph_equal(lhs, rhs):
     return tvm.ir.structural_equal(lhs, rhs, map_free_vars=True)
 
+
 def roundtrip_expr(expr):
     text = tvm.relay.Expr.astext(expr, show_meta_data=False)
     x = tvm.parser.parse_expr(text)
     assert_graph_equal(x, expr)
 
+
 # Testing Utilities for expressions.
 def roundtrip(expr):
     x = tvm.parser.fromtext(expr.astext())
     assert_graph_equal(x, expr)
 
+
 def parse_text(code):
     expr = tvm.parser.parse_expr(code)
     roundtrip_expr(expr)
     return expr
 
+
 def parses_as(code, expr):
     # type: (str, relay.Expr) -> bool
     parsed = parse_text(code)
     result = graph_equal(parsed, expr)
     return result
 
+
 # Testing Utilities for full modules.
 def parse_module(code):
     mod = tvm.parser.parse(SEMVER + code)
     roundtrip(mod)
     return mod
 
+
 def assert_parses_as(code, expr):
     parsed = parse_text(code)
     assert_graph_equal(parsed, expr)
 
+
 def assert_parse_module_as(code, mod):
     parsed = parse_module(code)
     assert_graph_equal(parsed, mod)
 
+
 def get_scalar(x):
     # type: (relay.Constant) -> (Union[float, int, bool])
     return x.data.asnumpy().item()
 
+
 int32 = relay.scalar_type("int32")
 
 _ = relay.Var("_")
@@ -131,7 +137,7 @@ def test_comments():
         // This is a line comment!
         ()
         """,
-        UNIT
+        UNIT,
     )
 
     assert_parses_as(
@@ -141,7 +147,7 @@ def test_comments():
         */
         ()
         """,
-        UNIT
+        UNIT,
     )
 
     assert_parses_as(
@@ -151,7 +157,7 @@ def test_comments():
         */
         ()
         """,
-        UNIT
+        UNIT,
     )
 
 
@@ -174,13 +180,13 @@ def test_float_literal():
 
     # scientific notation
     assert isclose(get_scalar(parse_text("1e-1f")), 1e-1)
-    assert get_scalar(parse_text("1e+1f")) == 1e+1
-    assert isclose(get_scalar(parse_text("1E-1f")), 1E-1)
-    assert get_scalar(parse_text("1E+1f")) == 1E+1
+    assert get_scalar(parse_text("1e+1f")) == 1e1
+    assert isclose(get_scalar(parse_text("1E-1f")), 1e-1)
+    assert get_scalar(parse_text("1E+1f")) == 1e1
     assert isclose(get_scalar(parse_text("1.0e-1f")), 1.0e-1)
-    assert get_scalar(parse_text("1.0e+1f")) == 1.0e+1
-    assert isclose(get_scalar(parse_text("1.0E-1f")), 1.0E-1)
-    assert get_scalar(parse_text("1.0E+1f")) == 1.0E+1
+    assert get_scalar(parse_text("1.0e+1f")) == 1.0e1
+    assert isclose(get_scalar(parse_text("1.0E-1f")), 1.0e-1)
+    assert get_scalar(parse_text("1.0E+1f")) == 1.0e1
 
 
 def test_bool_literal():
@@ -198,8 +204,7 @@ def test_negative():
 def test_bin_op():
     for bin_op in BINARY_OPS.keys():
         assert_parses_as(
-            "1 {} 1".format(bin_op),
-            BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1))
+            "1 {} 1".format(bin_op), BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1))
         )
 
 
@@ -212,6 +217,7 @@ def test_op_assoc():
     assert graph_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1"))
     assert graph_equal(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))"))
 
+
 def test_vars():
     # var
     var = parse_text("let %foo = (); %foo")
@@ -233,6 +239,7 @@ def test_vars():
     assert isinstance(op, tvm.ir.Op)
     assert op.name == "nn.global_avg_pool2d"
 
+
 def test_meta_ref():
     with pytest.raises(tvm.error.DiagnosticError):
         meta_op = parse_text("meta[type_key][1337]")
@@ -241,14 +248,7 @@ def test_meta_ref():
 
 
 def test_let():
-    assert_parses_as(
-        "let %x = 1; ()",
-        relay.Let(
-            X,
-            relay.const(1),
-            UNIT
-        )
-    )
+    assert_parses_as("let %x = 1; ()", relay.Let(X, relay.const(1), UNIT))
 
     assert_parses_as(
         """
@@ -256,48 +256,25 @@ def test_let():
         let %y = 2;
         ()
         """,
-        relay.Let(
-            X,
-            relay.const(1),
-            relay.Let(
-                Y,
-                relay.const(2),
-                UNIT
-            )
-        )
+        relay.Let(X, relay.const(1), relay.Let(Y, relay.const(2), UNIT)),
     )
 
 
 def test_seq():
-    assert_parses_as(
-        "(); ()",
-        relay.Let(
-            _,
-            UNIT,
-            UNIT)
-    )
+    assert_parses_as("(); ()", relay.Let(_, UNIT, UNIT))
 
-    assert_parses_as(
-        "let %_ = 1; ()",
-        relay.Let(
-            X,
-            relay.const(1),
-            UNIT
-        )
-    )
+    assert_parses_as("let %_ = 1; ()", relay.Let(X, relay.const(1), UNIT))
 
 
 def test_graph():
     code = "%0 = (); %1 = 1; (%0, %0, %1)"
-    assert_parses_as(
-        code,
-        relay.Tuple([UNIT, UNIT, relay.const(1)])
-    )
+    assert_parses_as(code, relay.Tuple([UNIT, UNIT, relay.const(1)]))
 
 
 def test_graph_single():
     assert_parses_as("%1 = (); %1", relay.Tuple([]))
 
+
 def test_let_global_var():
     with pytest.raises(tvm.error.DiagnosticError):
         parse_text("let @x = 1; ()")
@@ -320,48 +297,16 @@ def test_tuple():
 
 def test_func():
     # 0 args
-    assert_parses_as(
-        "fn () { 0 }",
-        relay.Function(
-            [],
-            relay.const(0),
-            None,
-            []
-        )
-    )
+    assert_parses_as("fn () { 0 }", relay.Function([], relay.const(0), None, []))
 
     # 1 arg
-    assert_parses_as(
-        "fn (%x) { %x }",
-        relay.Function(
-            [X],
-            X,
-            None,
-            []
-        )
-    )
+    assert_parses_as("fn (%x) { %x }", relay.Function([X], X, None, []))
 
     # 2 args
-    assert_parses_as(
-        "fn (%x, %y) { %x + %y }",
-        relay.Function(
-            [X, Y],
-            relay.add(X, Y),
-            None,
-            []
-        )
-    )
+    assert_parses_as("fn (%x, %y) { %x + %y }", relay.Function([X, Y], relay.add(X, Y), None, []))
 
     # annotations
-    assert_parses_as(
-        "fn (%x: int32) -> int32 { %x }",
-        relay.Function(
-            [X_ANNO],
-            X_ANNO,
-            int32,
-            []
-        )
-    )
+    assert_parses_as("fn (%x: int32) -> int32 { %x }", relay.Function([X_ANNO], X_ANNO, int32, []))
 
     # Refactor the attribute syntax and printing.
     #
@@ -379,7 +324,8 @@ def test_defn():
         def @id(%x: int32) -> int32 {
             %x
         }
-        """)
+        """
+    )
     assert isinstance(id_defn, tvm.IRModule)
 
 
@@ -389,7 +335,8 @@ def test_recursive_call():
         def @id(%x: int32) -> int32 {
             @id(%x)
         }
-        """)
+        """
+    )
     assert isinstance(id_defn, tvm.IRModule)
 
 
@@ -402,11 +349,7 @@ def test_ifelse():
             1
         }
         """,
-        relay.If(
-            relay.const(True),
-            relay.const(0),
-            relay.const(1)
-        )
+        relay.If(relay.const(True), relay.const(0), relay.const(1)),
     )
 
 
@@ -435,8 +378,8 @@ def test_call():
         relay.Let(
             id_func,
             relay.Function([X], X, None, []),
-            relay.multiply(relay.const(10), relay.Call(id_func, [relay.const(10)]))
-        )
+            relay.multiply(relay.const(10), relay.Call(id_func, [relay.const(10)])),
+        ),
     )
 
     # 0 args
@@ -449,8 +392,8 @@ def test_call():
         relay.Let(
             constant,
             relay.Function([], relay.const(0), None, []),
-            relay.Call(constant, [], None, None)
-        )
+            relay.Call(constant, [], None, None),
+        ),
     )
 
     # 1 arg
@@ -463,8 +406,8 @@ def test_call():
         relay.Let(
             id_var,
             relay.Function([X], X, None, []),
-            relay.Call(id_var, [relay.const(1)], None, None)
-        )
+            relay.Call(id_var, [relay.const(1)], None, None),
+        ),
     )
 
     # 2 args
@@ -476,14 +419,9 @@ def test_call():
         """,
         relay.Let(
             multiply,
-            relay.Function(
-                [X, Y],
-                relay.multiply(X, Y),
-                None,
-                []
-            ),
-            relay.Call(multiply, [relay.const(0), relay.const(0)], None, None)
-        )
+            relay.Function([X, Y], relay.multiply(X, Y), None, []),
+            relay.Call(multiply, [relay.const(0), relay.const(0)], None, None),
+        ),
     )
 
     # anonymous function
@@ -491,17 +429,7 @@ def test_call():
         """
         (fn (%x) { %x })(0)
         """,
-        relay.Call(
-            relay.Function(
-                [X],
-                X,
-                None,
-                []
-            ),
-            [relay.const(0)],
-            None,
-            None
-        )
+        relay.Call(relay.Function([X], X, None, []), [relay.const(0)], None, None),
     )
 
     # curried function
@@ -519,43 +447,29 @@ def test_call():
         """,
         relay.Let(
             curried_mult,
-            relay.Function(
-                [X],
-                relay.Function(
-                    [Y],
-                    relay.multiply(X, Y),
-                    None,
-                    []
-                ),
-                None,
-                []
-            ),
+            relay.Function([X], relay.Function([Y], relay.multiply(X, Y), None, []), None, []),
             relay.Let(
                 _,
                 relay.Call(curried_mult, [relay.const(0)], None, None),
-                relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None)
-            )
-        )
+                relay.Call(
+                    relay.Call(curried_mult, [relay.const(0)], None, None),
+                    [relay.const(0)],
+                    None,
+                    None,
+                ),
+            ),
+        ),
     )
 
     # op
-    assert_parses_as(
-        "abs(1)",
-        relay.Call(relay.op.get("abs"), [relay.const(1)], None, None)
-    )
+    assert_parses_as("abs(1)", relay.Call(relay.op.get("abs"), [relay.const(1)], None, None))
+
 
 # Types
 
 
 def test_incomplete_type():
-    assert_parses_as(
-        "let %_ : _ = (); ()",
-        relay.Let(
-            _,
-            UNIT,
-            UNIT
-        )
-    )
+    assert_parses_as("let %_ : _ = (); ()", relay.Let(_, UNIT, UNIT))
 
 
 def test_builtin_types():
@@ -566,42 +480,25 @@ def test_builtin_types():
 def test_tensor_type():
     assert_parses_as(
         "let %_ : Tensor[(), float32] = (); ()",
-        relay.Let(
-            relay.Var("_", relay.TensorType((), "float32")),
-            UNIT,
-            UNIT
-        )
+        relay.Let(relay.Var("_", relay.TensorType((), "float32")), UNIT, UNIT),
     )
 
     assert_parses_as(
         "let %_ : Tensor[(1), float32] = (); ()",
-        relay.Let(
-            relay.Var("_", relay.TensorType((1,), "float32")),
-            UNIT,
-            UNIT
-        )
+        relay.Let(relay.Var("_", relay.TensorType((1,), "float32")), UNIT, UNIT),
     )
 
     assert_parses_as(
         "let %_ : Tensor[(1, 1), float32] = (); ()",
-        relay.Let(
-            relay.Var("_", relay.TensorType((1, 1), "float32")),
-            UNIT,
-            UNIT
-        )
+        relay.Let(relay.Var("_", relay.TensorType((1, 1), "float32")), UNIT, UNIT),
     )
 
     assert_parses_as(
         "let %_ : Tensor[(?, 1), float32] = (); ()",
-        relay.Let(
-            relay.Var("_", relay.TensorType((tvm.tir.Any(), 1), "float32")),
-            UNIT,
-            UNIT
-        )
+        relay.Let(relay.Var("_", relay.TensorType((tvm.tir.Any(), 1), "float32")), UNIT, UNIT),
     )
 
 
-
 def test_function_type():
     assert_parses_as(
         """
@@ -610,8 +507,8 @@ def test_function_type():
         relay.Let(
             relay.Var("_", relay.FuncType([], int32, [], [])),
             relay.Function([], relay.const(0), int32, []),
-            UNIT
-        )
+            UNIT,
+        ),
     )
 
     assert_parses_as(
@@ -621,8 +518,8 @@ def test_function_type():
         relay.Let(
             relay.Var("_", relay.FuncType([int32], int32, [], [])),
             relay.Function([relay.Var("x", int32)], relay.const(0), int32, []),
-            UNIT
-        )
+            UNIT,
+        ),
     )
 
     assert_parses_as(
@@ -631,9 +528,11 @@ def test_function_type():
         """,
         relay.Let(
             relay.Var("_", relay.FuncType([int32, int32], int32, [], [])),
-            relay.Function([relay.Var("x", int32), relay.Var("y", int32)], relay.const(0), int32, []),
-            UNIT
-        )
+            relay.Function(
+                [relay.Var("x", int32), relay.Var("y", int32)], relay.const(0), int32, []
+            ),
+            UNIT,
+        ),
     )
 
 
@@ -642,22 +541,14 @@ def test_tuple_type():
         """
         let %_: () = (); ()
         """,
-        relay.Let(
-            relay.Var("_", relay.TupleType([])),
-            UNIT,
-            UNIT
-        )
+        relay.Let(relay.Var("_", relay.TupleType([])), UNIT, UNIT),
     )
 
     assert_parses_as(
         """
         let %_: (int32,) = (0,); ()
         """,
-        relay.Let(
-            relay.Var("_", relay.TupleType([int32])),
-            relay.Tuple([relay.const(0)]),
-            UNIT
-        )
+        relay.Let(relay.Var("_", relay.TupleType([int32])), relay.Tuple([relay.const(0)]), UNIT),
     )
 
     assert_parses_as(
@@ -667,8 +558,8 @@ def test_tuple_type():
         relay.Let(
             relay.Var("_", relay.TupleType([int32, int32])),
             relay.Tuple([relay.const(0), relay.const(1)]),
-            UNIT
-        )
+            UNIT,
+        ),
     )
 
 
@@ -676,18 +567,16 @@ def test_adt_defn():
     mod = tvm.IRModule()
 
     glob_typ_var = relay.GlobalTypeVar("Ayy")
-    prog = relay.TypeData(
-            glob_typ_var,
-            [],
-            [relay.Constructor("Nil", [], glob_typ_var)])
+    prog = relay.TypeData(glob_typ_var, [], [relay.Constructor("Nil", [], glob_typ_var)])
     mod[glob_typ_var] = prog
     assert_parse_module_as(
         """
         type Ayy { Nil }
         """,
-        mod
+        mod,
     )
 
+
 def test_adt_any():
     code = """
     type my_dtype {
@@ -717,7 +606,7 @@ def test_empty_adt_defn():
         """
         type Ayy { }
         """,
-        mod
+        mod,
     )
 
 
@@ -727,12 +616,13 @@ def test_multiple_cons_defn():
     list_var = relay.GlobalTypeVar("List")
     typ_var = relay.TypeVar("A")
     prog = relay.TypeData(
-            list_var,
-            [typ_var],
-            [
-                relay.Constructor("Cons", [typ_var, list_var(typ_var)], list_var),
-                relay.Constructor("Nil", [], list_var),
-            ])
+        list_var,
+        [typ_var],
+        [
+            relay.Constructor("Cons", [typ_var, list_var(typ_var)], list_var),
+            relay.Constructor("Nil", [], list_var),
+        ],
+    )
     mod[list_var] = prog
     assert_parse_module_as(LIST_DEFN, mod)
 
@@ -742,12 +632,13 @@ def test_multiple_type_param_defn():
     typ_var_a = relay.TypeVar("A")
     typ_var_b = relay.TypeVar("B")
     prog = relay.TypeData(
-            glob_typ_var,
-            [typ_var_a, typ_var_b],
-            [
-                relay.Constructor("Left", [typ_var_a], glob_typ_var),
-                relay.Constructor("Right", [typ_var_b], glob_typ_var),
-            ])
+        glob_typ_var,
+        [typ_var_a, typ_var_b],
+        [
+            relay.Constructor("Left", [typ_var_a], glob_typ_var),
+            relay.Constructor("Right", [typ_var_b], glob_typ_var),
+        ],
+    )
     mod = tvm.IRModule()
     mod[glob_typ_var] = prog
     assert_parse_module_as(
@@ -757,7 +648,7 @@ def test_multiple_type_param_defn():
           Right(B),
         }
         """,
-        mod
+        mod,
     )
 
 
@@ -769,13 +660,9 @@ def test_match():
 
         list_var = relay.GlobalTypeVar("List")
         typ_var = relay.TypeVar("A")
-        cons_constructor = relay.Constructor(
-            "Cons", [typ_var, list_var(typ_var)], list_var)
+        cons_constructor = relay.Constructor("Cons", [typ_var, list_var(typ_var)], list_var)
         nil_constructor = relay.Constructor("Nil", [], list_var)
-        list_def = relay.TypeData(
-            list_var,
-            [typ_var],
-            [cons_constructor, nil_constructor])
+        list_def = relay.TypeData(list_var, [typ_var], [cons_constructor, nil_constructor])
         mod[list_var] = list_def
 
         length_var = relay.GlobalVar("length")
@@ -786,24 +673,22 @@ def test_match():
         cons_case = relay.Let(
             relay.var("", type_annotation=None),
             UNIT,
-            relay.add(relay.const(1), relay.Call(length_var, [rest_var])))
-        body = relay.Match(input_var,
-            [relay.Clause(
-                relay.PatternConstructor(
-                    cons_constructor,
-                    [relay.PatternWildcard(), relay.PatternVar(rest_var)]),
-                cons_case),
-            relay.Clause(
-                relay.PatternConstructor(nil_constructor, []),
-                relay.const(0))],
-            complete=is_complete
+            relay.add(relay.const(1), relay.Call(length_var, [rest_var])),
         )
-        length_func = relay.Function(
-            [input_var],
-            body,
-            int32,
-            [typ_var]
+        body = relay.Match(
+            input_var,
+            [
+                relay.Clause(
+                    relay.PatternConstructor(
+                        cons_constructor, [relay.PatternWildcard(), relay.PatternVar(rest_var)]
+                    ),
+                    cons_case,
+                ),
+                relay.Clause(relay.PatternConstructor(nil_constructor, []), relay.const(0)),
+            ],
+            complete=is_complete,
         )
+        length_func = relay.Function([input_var], body, int32, [typ_var])
         mod[length_var] = length_func
 
         assert_parse_module_as(
@@ -819,8 +704,9 @@ def test_match():
                 Nil => 0,
               }
             }
-            """ % (LIST_DEFN, match_keyword),
-            mod
+            """
+            % (LIST_DEFN, match_keyword),
+            mod,
         )
 
 
@@ -829,21 +715,15 @@ def test_adt_cons_expr():
 
     list_var = relay.GlobalTypeVar("List")
     typ_var = relay.TypeVar("A")
-    cons_constructor = relay.Constructor(
-        "Cons", [typ_var, list_var(typ_var)], list_var)
+    cons_constructor = relay.Constructor("Cons", [typ_var, list_var(typ_var)], list_var)
     nil_constructor = relay.Constructor("Nil", [], list_var)
-    list_def = relay.TypeData(
-        list_var,
-        [typ_var],
-        [cons_constructor, nil_constructor])
+    list_def = relay.TypeData(list_var, [typ_var], [cons_constructor, nil_constructor])
     mod[list_var] = list_def
 
     make_singleton_var = relay.GlobalVar("make_singleton")
     input_var = relay.Var("x", int32)
     make_singleton_func = relay.Function(
-        [input_var],
-        cons_constructor(input_var, nil_constructor()),
-        list_var(int32)
+        [input_var], cons_constructor(input_var, nil_constructor()), list_var(int32)
     )
     mod[make_singleton_var] = make_singleton_func
 
@@ -854,8 +734,9 @@ def test_adt_cons_expr():
         def @make_singleton(%%x: int32) -> List[int32] {
           Cons(%%x, Nil)
         }
-        """ % LIST_DEFN,
-        mod
+        """
+        % LIST_DEFN,
+        mod,
     )
 
 
@@ -869,7 +750,8 @@ def test_duplicate_adt_defn():
             Cons(A, List[A]),
             Nil,
             }
-            """ % LIST_DEFN
+            """
+            % LIST_DEFN
         )
 
 
@@ -915,19 +797,22 @@ def test_extern_adt_defn():
         """
         extern type T[A]
         """,
-        mod
+        mod,
     )
 
+
 def test_import_grad():
     mod = tvm.IRModule()
     mod.import_from_std("gradient.rly")
 
+
 def test_resnet():
     mod, _ = relay.testing.resnet.get_workload()
     text = mod.astext()
     parsed_mod = tvm.parser.parse(text)
     tvm.ir.assert_structural_equal(mod, parsed_mod)
 
+
 def inline_params(mod, params):
     main_fn = mod["main"]
     str_to_var = {}
@@ -943,6 +828,7 @@ def inline_params(mod, params):
     mod["main_fn"] = main_fn
     return mod
 
+
 def test_resnet_inlined_params():
     mod, params = relay.testing.resnet.get_workload()
     mod = inline_params(mod, params)
@@ -950,6 +836,8 @@ def test_resnet_inlined_params():
     parsed_mod = tvm.parser.parse(text)
     tvm.ir.assert_structural_equal(mod, parsed_mod)
 
+
 if __name__ == "__main__":
     import sys
+
     pytest.main(sys.argv)
index ecdb293..65919e0 100644 (file)
@@ -31,14 +31,18 @@ def consistent_equal(x, y, map_free_vars=False):
     if struct_equal0 != struct_equal1:
         raise ValueError(
             "Non-communicative {} vs {}, sequal0={}, sequal1={}".format(
-                x, y, struct_equal0, struct_equal1))
+                x, y, struct_equal0, struct_equal1
+            )
+        )
 
     # NOTE: hash collision can happen but should be rare.
     # we can confirm that hash collisions don't happen for our test cases
     if struct_equal0 != (xhash == yhash):
         raise ValueError(
             "Inconsistent {} vs {}, sequal={}, xhash={}, yhash={}".format(
-                x, y, struct_equal0, xhash, yhash))
+                x, y, struct_equal0, xhash, yhash
+            )
+        )
     return struct_equal0
 
 
@@ -74,18 +78,24 @@ def test_type_param_sequal():
     # only pointer equality and eq_map allow equal params
     assert t1 == t1
     assert t2 == t2
-    assert t1 != t2 # different kind
-    assert t1 != t3 # not in eq_map
+    assert t1 != t2  # different kind
+    assert t1 != t3  # not in eq_map
 
     # function types are the only way to put type params
     # in eq map
-    ft1 = relay.FuncType(tvm.runtime.convert([]), t1, tvm.runtime.convert([t1]), tvm.runtime.convert([]))
-    ft2 = relay.FuncType(tvm.runtime.convert([]), t3, tvm.runtime.convert([t3]), tvm.runtime.convert([]))
+    ft1 = relay.FuncType(
+        tvm.runtime.convert([]), t1, tvm.runtime.convert([t1]), tvm.runtime.convert([])
+    )
+    ft2 = relay.FuncType(
+        tvm.runtime.convert([]), t3, tvm.runtime.convert([t3]), tvm.runtime.convert([])
+    )
     # actually an invalid type because t2 is the wrong kind
-    ft3 = relay.FuncType(tvm.runtime.convert([]), t2, tvm.runtime.convert([t2]), tvm.runtime.convert([]))
+    ft3 = relay.FuncType(
+        tvm.runtime.convert([]), t2, tvm.runtime.convert([t2]), tvm.runtime.convert([])
+    )
 
     assert ft1 == ft2
-    assert ft1 != ft3 # kinds still do not match
+    assert ft1 != ft3  # kinds still do not match
 
 
 def test_func_type_sequal():
@@ -104,47 +114,68 @@ def test_func_type_sequal():
     tr2 = relay.TypeRelation(broadcast, tvm.runtime.convert([tp2, tp4]), 1, None)
     tr3 = relay.TypeRelation(identity, tvm.runtime.convert([tp1, tp3]), 1, None)
 
-    ft = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1,
-                         tvm.runtime.convert([tp1, tp3]),
-                         tvm.runtime.convert([tr1]))
-    translate_vars = relay.FuncType(tvm.runtime.convert([t1, t2]), tp2,
-                         tvm.runtime.convert([tp2, tp4]),
-                         tvm.runtime.convert([tr2]))
+    ft = relay.FuncType(
+        tvm.runtime.convert([t1, t2]),
+        tp1,
+        tvm.runtime.convert([tp1, tp3]),
+        tvm.runtime.convert([tr1]),
+    )
+    translate_vars = relay.FuncType(
+        tvm.runtime.convert([t1, t2]),
+        tp2,
+        tvm.runtime.convert([tp2, tp4]),
+        tvm.runtime.convert([tr2]),
+    )
     assert ft == translate_vars
 
-    different_args = relay.FuncType(tvm.runtime.convert([t1]), tp1,
-                         tvm.runtime.convert([tp1, tp3]),
-                         tvm.runtime.convert([tr1]))
+    different_args = relay.FuncType(
+        tvm.runtime.convert([t1]), tp1, tvm.runtime.convert([tp1, tp3]), tvm.runtime.convert([tr1])
+    )
     assert ft != different_args
 
-    different_order = relay.FuncType(tvm.runtime.convert([t2, t1]), tp1,
-                         tvm.runtime.convert([tp1, tp3]),
-                         tvm.runtime.convert([tr1]))
+    different_order = relay.FuncType(
+        tvm.runtime.convert([t2, t1]),
+        tp1,
+        tvm.runtime.convert([tp1, tp3]),
+        tvm.runtime.convert([tr1]),
+    )
     assert ft != different_order
 
-    no_rel = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1,
-                         tvm.runtime.convert([tp1, tp3]),
-                         tvm.runtime.convert([]))
+    no_rel = relay.FuncType(
+        tvm.runtime.convert([t1, t2]), tp1, tvm.runtime.convert([tp1, tp3]), tvm.runtime.convert([])
+    )
     assert ft != no_rel
 
-    more_vars = relay.FuncType(tvm.runtime.convert([t1, t2]), tp2,
-                         tvm.runtime.convert([tp1, tp2, tp3]),
-                         tvm.runtime.convert([tr1]))
+    more_vars = relay.FuncType(
+        tvm.runtime.convert([t1, t2]),
+        tp2,
+        tvm.runtime.convert([tp1, tp2, tp3]),
+        tvm.runtime.convert([tr1]),
+    )
     assert ft != more_vars
 
-    all_the_vars = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1,
-                         tvm.runtime.convert([tp1, tp2, tp3, tp4]),
-                         tvm.runtime.convert([tr1, tr2]))
+    all_the_vars = relay.FuncType(
+        tvm.runtime.convert([t1, t2]),
+        tp1,
+        tvm.runtime.convert([tp1, tp2, tp3, tp4]),
+        tvm.runtime.convert([tr1, tr2]),
+    )
     assert ft != all_the_vars
 
-    different_rel = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1,
-                                   tvm.runtime.convert([tp1, tp3]),
-                                   tvm.runtime.convert([tr3]))
+    different_rel = relay.FuncType(
+        tvm.runtime.convert([t1, t2]),
+        tp1,
+        tvm.runtime.convert([tp1, tp3]),
+        tvm.runtime.convert([tr3]),
+    )
     assert ft != different_rel
 
-    more_rels = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1,
-                                   tvm.runtime.convert([tp1, tp3]),
-                                   tvm.runtime.convert([tr1, tr3]))
+    more_rels = relay.FuncType(
+        tvm.runtime.convert([t1, t2]),
+        tp1,
+        tvm.runtime.convert([tp1, tp3]),
+        tvm.runtime.convert([tr1, tr3]),
+    )
     assert ft != more_rels
 
 
@@ -176,9 +207,9 @@ def test_type_relation_sequal():
     broadcast = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
     identity = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Identity")
 
-    attr1 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
-    attr1_same = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
-    attr2 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4,4))
+    attr1 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3, 4))
+    attr1_same = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3, 4))
+    attr2 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3, 4, 4))
 
     tr = relay.TypeRelation(broadcast, tvm.runtime.convert([t1, t2]), 1, attr1)
     same = relay.TypeRelation(broadcast, tvm.runtime.convert([t1, t2]), 1, attr1)
@@ -202,6 +233,7 @@ def test_type_relation_sequal():
 
     assert bigger != diff_num_inputs
 
+
 def test_type_call_sequal():
     h1 = relay.GlobalTypeVar("h1")
     h2 = relay.GlobalTypeVar("h2")
@@ -233,25 +265,29 @@ def test_constant_sequal():
     assert not consistent_equal(x, y)
     assert consistent_equal(x, relay.const(1))
 
+
 def test_type_node_sequal():
-    v1 = relay.TypeVar('v1', 6)
-    v2 = relay.TypeVar('v2', 6)
+    v1 = relay.TypeVar("v1", 6)
+    v2 = relay.TypeVar("v2", 6)
     assert not consistent_equal(v1, v2)
 
-    v1 = relay.TypeVar('v1', 0)
-    v2 = relay.TypeVar('v2', 6)
+    v1 = relay.TypeVar("v1", 0)
+    v2 = relay.TypeVar("v2", 6)
     assert not consistent_equal(v1, v2)
 
+
 def test_type_node_incompatible_sequal():
-    v1 = relay.TypeVar('v1', 6)
+    v1 = relay.TypeVar("v1", 6)
     v2 = relay.Var("v2")
     assert not consistent_equal(v1, v2)
 
+
 def test_expr_node_incompatible_sequal():
     v1 = relay.Var("v1")
     v2 = relay.PatternVar(relay.Var("v2"))
     assert not consistent_equal(v1, v2)
 
+
 def test_var_sequal():
     v1 = relay.Var("v1")
     v2 = relay.Var("v2")
@@ -312,56 +348,58 @@ def test_tuple_sequal():
 
     # use the eq_map
 
-
     let_tup = relay.Let(v1, tup, v1)
-    let_mapped = relay.Let(v2, relay.Tuple([v0, relay.const(2), relay.const(3),
-                                            relay.Tuple([relay.const(4)])]),
-                           v2)
+    let_mapped = relay.Let(
+        v2, relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])]), v2
+    )
 
     assert consistent_equal(let_tup, let_mapped)
 
-    more_fields = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)]), v2])
+    more_fields = relay.Tuple(
+        [v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)]), v2]
+    )
     assert not consistent_equal(tup, more_fields)
 
     fewer_fields = relay.Tuple([v1, relay.const(2), relay.const(3)])
     assert not consistent_equal(tup, fewer_fields)
 
-    different_end = relay.Tuple([v1, relay.const(2), relay.const(3),
-                           relay.Tuple([relay.const(5)])])
+    different_end = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(5)])])
     assert not consistent_equal(tup, different_end)
 
-    different_start = relay.Tuple([v2, relay.const(2), relay.const(3),
-                                 relay.Tuple([relay.const(4)])])
+    different_start = relay.Tuple(
+        [v2, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])]
+    )
     assert not consistent_equal(tup, different_start)
 
-    longer_at_end = relay.Tuple([v1, relay.const(2), relay.const(3),
-                                 relay.Tuple([relay.const(4), relay.const(5)])])
+    longer_at_end = relay.Tuple(
+        [v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4), relay.const(5)])]
+    )
     assert not consistent_equal(tup, longer_at_end)
 
 
 def test_tuple_get_item_sequal():
-    x = relay.Var('x')
-    y = relay.Var('y')
+    x = relay.Var("x")
+    y = relay.Var("y")
     assert not consistent_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(y, 1))
     assert not consistent_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 2))
     assert consistent_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
 
 
 def test_function_attr():
-    x0 = relay.var('x0', shape=(10, 10))
-    w00 = relay.var('w00', shape=(10, 10))
-    w01 = relay.var('w01', shape=(10, 10))
-    w02 = relay.var('w02', shape=(10, 10))
+    x0 = relay.var("x0", shape=(10, 10))
+    w00 = relay.var("w00", shape=(10, 10))
+    w01 = relay.var("w01", shape=(10, 10))
+    w02 = relay.var("w02", shape=(10, 10))
     z00 = relay.add(x0, w00)
     p00 = relay.subtract(z00, w01)
     q00 = relay.multiply(p00, w02)
     func0 = relay.Function([x0, w00, w01, w02], q00)
     func0 = func0.with_attr("FuncName", "a")
 
-    x1 = relay.var('x1', shape=(10, 10))
-    w10 = relay.var('w10', shape=(10, 10))
-    w11 = relay.var('w11', shape=(10, 10))
-    w12 = relay.var('w12', shape=(10, 10))
+    x1 = relay.var("x1", shape=(10, 10))
+    w10 = relay.var("w10", shape=(10, 10))
+    w11 = relay.var("w11", shape=(10, 10))
+    w12 = relay.var("w12", shape=(10, 10))
     z10 = relay.add(x1, w10)
     p10 = relay.subtract(z10, w11)
     q10 = relay.multiply(p10, w12)
@@ -389,25 +427,22 @@ def test_function_sequal():
     basic_args = [relay.Var("v3", tt1), relay.Var("v4", tt2)]
     basic_tps = [tp1, tp2]
 
-    func = relay.Function([v1, v2], v1,
-                          tt2, basic_tps)
+    func = relay.Function([v1, v2], v1, tt2, basic_tps)
     mapped = relay.Function(basic_args, basic_args[0], tt2, basic_tps)
     assert consistent_equal(func, mapped)
 
     fewer_params = relay.Function([relay.Var("v4", tt2)], v4, tt2, basic_tps)
     assert not consistent_equal(func, fewer_params)
 
-    more_params = relay.Function([relay.Var("v3", tt1),
-                                  relay.Var("v4", tt2),
-                                  relay.Var("v2", tt2)], v4, tt2, basic_tps)
+    more_params = relay.Function(
+        [relay.Var("v3", tt1), relay.Var("v4", tt2), relay.Var("v2", tt2)], v4, tt2, basic_tps
+    )
     assert not consistent_equal(func, more_params)
 
-    params_unordered = relay.Function([v2, v1], v1,
-                                      tt2, basic_tps)
+    params_unordered = relay.Function([v2, v1], v1, tt2, basic_tps)
     assert not consistent_equal(func, params_unordered)
 
-    params_mismatch = relay.Function([v1, v3], v1,
-                                     tt2, basic_tps)
+    params_mismatch = relay.Function([v1, v3], v1, tt2, basic_tps)
     assert not consistent_equal(func, params_mismatch)
 
     # also would not typecheck
@@ -447,9 +482,9 @@ def test_call_sequal():
     v1 = relay.Var("v1")
     v2 = relay.Var("v2")
 
-    attr1 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
-    attr1_same = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
-    attr2 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4,4))
+    attr1 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3, 4))
+    attr1_same = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3, 4))
+    attr2 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3, 4, 4))
 
     tt1 = relay.TensorType((1, 2, 3), "float32")
     tt2 = relay.TensorType((), "int8")
@@ -458,8 +493,7 @@ def test_call_sequal():
 
     # manually writing out args to ensure that args does not rely on
     # pointer equality
-    call = relay.Call(v1, [relay.const(1), relay.const(2), v2, relay.Tuple([])],
-                      attr1, [tt1])
+    call = relay.Call(v1, [relay.const(1), relay.const(2), v2, relay.Tuple([])], attr1, [tt1])
     same = relay.Call(v1, basic_args, attr1, [tt1])
     assert consistent_equal(call, same)
 
@@ -469,16 +503,20 @@ def test_call_sequal():
     fewer_args = relay.Call(v1, [relay.const(1), relay.const(2), v2], attr1, [tt1])
     assert not consistent_equal(call, fewer_args)
 
-    reordered_args = relay.Call(v1, [relay.const(2), relay.const(1),
-                                     relay.Tuple([]), v2], attr1, [tt1])
+    reordered_args = relay.Call(
+        v1, [relay.const(2), relay.const(1), relay.Tuple([]), v2], attr1, [tt1]
+    )
     assert not consistent_equal(call, reordered_args)
 
-    different_args = relay.Call(v1, [relay.const(1), relay.const(2), relay.const(3)],
-                                attr1, [tt1])
+    different_args = relay.Call(v1, [relay.const(1), relay.const(2), relay.const(3)], attr1, [tt1])
     assert not consistent_equal(call, different_args)
 
-    more_args = relay.Call(v1, [relay.const(1), relay.const(2), v2, relay.Tuple([]),
-                                relay.const(3), relay.const(4)], attr1, [tt1])
+    more_args = relay.Call(
+        v1,
+        [relay.const(1), relay.const(2), v2, relay.Tuple([]), relay.const(3), relay.const(4)],
+        attr1,
+        [tt1],
+    )
     assert not consistent_equal(call, more_args)
 
     different_attrs = relay.Call(v1, basic_args, attr2, [tt1])
@@ -561,20 +599,18 @@ def test_match_sequal():
     mod = tvm.IRModule()
     p = relay.prelude.Prelude(mod)
 
-    x = relay.Var('x')
-    y = relay.Var('y')
+    x = relay.Var("x")
+    y = relay.Var("y")
     nil_case = relay.Clause(relay.PatternConstructor(p.nil), p.nil())
-    cons_case = relay.Clause(relay.PatternConstructor(p.cons,
-                                                      [relay.PatternVar(x),
-                                                       relay.PatternVar(y)]),
-                       p.cons(x, y))
-
-    z = relay.Var('z')
-    a = relay.Var('a')
-    equivalent_cons = relay.Clause(relay.PatternConstructor(p.cons,
-                                                            [relay.PatternVar(z),
-                                                             relay.PatternVar(a)]),
-                                   p.cons(z, a))
+    cons_case = relay.Clause(
+        relay.PatternConstructor(p.cons, [relay.PatternVar(x), relay.PatternVar(y)]), p.cons(x, y)
+    )
+
+    z = relay.Var("z")
+    a = relay.Var("a")
+    equivalent_cons = relay.Clause(
+        relay.PatternConstructor(p.cons, [relay.PatternVar(z), relay.PatternVar(a)]), p.cons(z, a)
+    )
 
     data = p.cons(relay.const(1), p.cons(relay.const(2), p.nil()))
 
@@ -585,27 +621,33 @@ def test_match_sequal():
     no_nil = relay.Match(data, [cons_case])
     different_data = relay.Match(p.nil(), [nil_case, cons_case])
     different_order = relay.Match(data, [cons_case, nil_case])
-    different_nil = relay.Match(data, [
-        relay.Clause(relay.PatternConstructor(p.nil), p.cons(p.nil(), p.nil())),
-        cons_case
-    ])
-    different_cons = relay.Match(data, [
-        nil_case,
-        relay.Clause(relay.PatternConstructor(p.cons,
-                                              [relay.PatternWildcard(),
-                                               relay.PatternWildcard()]),
-                     p.nil())
-    ])
-    another_case = relay.Match(data, [
-        nil_case,
-        cons_case,
-        relay.Clause(relay.PatternWildcard(), p.nil())
-    ])
-    wrong_constructors = relay.Match(data, [
-        relay.Clause(relay.PatternConstructor(p.none), p.nil()),
-        relay.Clause(relay.PatternConstructor(p.some, [relay.PatternVar(x)]),
-                     p.cons(x, p.nil()))
-    ])
+    different_nil = relay.Match(
+        data, [relay.Clause(relay.PatternConstructor(p.nil), p.cons(p.nil(), p.nil())), cons_case]
+    )
+    different_cons = relay.Match(
+        data,
+        [
+            nil_case,
+            relay.Clause(
+                relay.PatternConstructor(
+                    p.cons, [relay.PatternWildcard(), relay.PatternWildcard()]
+                ),
+                p.nil(),
+            ),
+        ],
+    )
+    another_case = relay.Match(
+        data, [nil_case, cons_case, relay.Clause(relay.PatternWildcard(), p.nil())]
+    )
+    wrong_constructors = relay.Match(
+        data,
+        [
+            relay.Clause(relay.PatternConstructor(p.none), p.nil()),
+            relay.Clause(
+                relay.PatternConstructor(p.some, [relay.PatternVar(x)]), p.cons(x, p.nil())
+            ),
+        ],
+    )
 
     tvm.ir.assert_structural_equal(match, match)
     assert consistent_equal(match, match)
@@ -651,6 +693,7 @@ def test_graph_equal():
     # Check the difference in the text format.
     assert not consistent_equal(z0, z3)
 
+
 def test_hash_unequal():
     x1 = relay.var("x1", shape=(10, 10), dtype="float32")
     y1 = relay.var("y1", shape=(10, 10), dtype="float32")
@@ -671,7 +714,6 @@ def test_hash_unequal():
     assert not consistent_equal(func1, func3)
 
 
-
 def test_tuple_match():
     a = relay.Var("a")
     b = relay.Var("b")
@@ -687,15 +729,15 @@ def test_tuple_match():
 
 def test_fn_attribute():
     # create function that performs add
-    a = relay.var('a', shape=(10, 10))
-    b = relay.var('b', shape=(10, 10))
+    a = relay.var("a", shape=(10, 10))
+    b = relay.var("b", shape=(10, 10))
     add = relay.add(a, b)
     add_fn = relay.Function([a, b], add)
     add_fn = run_opt_pass(add_fn, relay.transform.InferType())
 
     # create function that performs add with test attribute
-    c = relay.var('c', shape=(10, 10))
-    d = relay.var('d', shape=(10, 10))
+    c = relay.var("c", shape=(10, 10))
+    d = relay.var("d", shape=(10, 10))
     add_1 = relay.add(c, d)
     add_1_fn = relay.Function([c, d], add_1)
     add_1_fn = add_1_fn.with_attr("TestAttribute", "test")
@@ -708,8 +750,7 @@ def test_fn_attribute():
 def test_fn_vid_map():
     def get_fn(with_vid):
         x = relay.var("x", shape=(10,), dtype="float32")
-        f = relay.Function([x], x).with_attr(
-            "dict", {x.vid: 1} if with_vid else {x : 1})
+        f = relay.Function([x], x).with_attr("dict", {x.vid: 1} if with_vid else {x: 1})
         return f
 
     assert consistent_equal(get_fn(True), get_fn(True))
index 52551bf..fd0853a 100644 (file)
@@ -24,7 +24,8 @@ from tvm.relay.analysis import free_vars
 
 DEBUG_PRINT = False
 
-SEMVER = "#[version = \"0.0.5\"]\n"
+SEMVER = '#[version = "0.0.5"]\n'
+
 
 def astext(program, unify_free_vars=False):
     text = program.astext()
@@ -38,11 +39,13 @@ def astext(program, unify_free_vars=False):
 
     return text
 
+
 def show(text):
     if DEBUG_PRINT:
         print("---------------------------")
         print(text)
 
+
 def test_func():
     x = relay.var("x", shape=(3, 2))
     y = relay.var("y")
@@ -75,10 +78,7 @@ def test_meta_data():
     n, c, h, w = te.size_var("n"), 10, 224, 224
     x = relay.var("x", shape=(n, c, h, w))
     w = relay.var("w")
-    z = relay.nn.conv2d(x, w,
-                        kernel_size=(3, 3),
-                        padding=(1, 1),
-                        channels=2)
+    z = relay.nn.conv2d(x, w, kernel_size=(3, 3), padding=(1, 1), channels=2)
     f = relay.Function([x, w], z)
     text = astext(f, unify_free_vars=True)
     text_no_meta = str(f)
@@ -89,7 +89,7 @@ def test_meta_data():
     assert "type_key" in text
     assert "type_key" not in text_no_meta
 
-    text = astext(relay.const([1,2,3]))
+    text = astext(relay.const([1, 2, 3]))
     assert "meta[relay.Constant][0]" in text
 
 
@@ -167,13 +167,14 @@ def test_lstm():
     net, _ = tvm.relay.testing.lstm.get_workload(4, 4)
     astext(net)
 
+
 def test_inception_v3():
     net, _ = tvm.relay.testing.inception_v3.get_workload(batch_size=1)
     astext(net)
 
 
 def test_squeezenet():
-    for version in ['1.0', '1.1']:
+    for version in ["1.0", "1.1"]:
         net, _ = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version)
         astext(net)
 
@@ -191,29 +192,27 @@ def test_densenet():
 def test_call_node_order():
     x = relay.var("x")
     y = relay.var("y")
-    prog = relay.Call(relay.Function([x], x), [relay.Call(relay.Function([y], y), [relay.const(1)])])
-    assert astext(prog) == SEMVER + \
-        ("%0 = fn (%y) {\n"
-         "  %y\n"
-         "};\n"
-         "%1 = %0(1);\n"
-         "%2 = fn (%x) {\n"
-         "  %x\n"
-         "};\n"
-         "%2(%1)")
+    prog = relay.Call(
+        relay.Function([x], x), [relay.Call(relay.Function([y], y), [relay.const(1)])]
+    )
+    assert astext(prog) == SEMVER + (
+        "%0 = fn (%y) {\n"
+        "  %y\n"
+        "};\n"
+        "%1 = %0(1);\n"
+        "%2 = fn (%x) {\n"
+        "  %x\n"
+        "};\n"
+        "%2(%1)"
+    )
 
 
 def test_let_inlining():
     tup = relay.Tuple([relay.const(0), relay.const(0)])
     x = relay.var("x")
-    assert astext(relay.Let(x, tup, tup)) == SEMVER + \
-        ("%0 = (0, 0);\n"
-         "let %x = %0;\n"
-         "%0")
+    assert astext(relay.Let(x, tup, tup)) == SEMVER + ("%0 = (0, 0);\n" "let %x = %0;\n" "%0")
 
-    assert astext(relay.Let(x, tup, x)) == SEMVER + \
-        ("let %x = (0, 0);\n"
-         "%x")
+    assert astext(relay.Let(x, tup, x)) == SEMVER + ("let %x = (0, 0);\n" "%x")
 
 
 def test_zeros():
@@ -252,4 +251,5 @@ def test_null_attribute():
 
 if __name__ == "__main__":
     import sys
+
     pytest.main(sys.argv)
index db953d5..53333d1 100644 (file)
@@ -20,6 +20,7 @@ from tvm import relay
 from tvm.relay.analysis import well_formed
 from tvm.relay.prelude import Prelude
 
+
 def test_let():
     x = relay.Var("x")
     assert well_formed(x)
@@ -30,9 +31,7 @@ def test_let():
     assert not well_formed(relay.Let(x, v, let))
     f = relay.Function([x], x, ty)
     assert well_formed(f)
-    assert well_formed(
-        relay.Let(relay.Var("y"), f,
-                  relay.Let(relay.Var("z"), f, v)))
+    assert well_formed(relay.Let(relay.Var("y"), f, relay.Let(relay.Var("z"), f, v)))
 
 
 def test_tuple():
@@ -54,15 +53,14 @@ def test_adt():
     mod = tvm.IRModule()
     p = Prelude(mod)
     x = relay.Var("x")
-    some_case = relay.Clause(relay.PatternConstructor(p.some,
-                                                      [relay.PatternVar(x)]),
-                             x)
+    some_case = relay.Clause(relay.PatternConstructor(p.some, [relay.PatternVar(x)]), x)
     default_case = relay.Clause(relay.PatternVar(x), x)
     m0 = relay.Match(p.none(), [default_case])
     m1 = relay.Match(p.none(), [some_case, default_case])
     assert well_formed(m0)
     assert not well_formed(m1)
 
+
 if __name__ == "__main__":
     test_let()
     test_tuple()
index be3e2a0..2b519a7 100644 (file)
@@ -20,17 +20,16 @@ from tvm import relay
 from tvm import te
 import json
 
+
 def test_type_var():
     # type var in 0.6
     nodes = [
         {"type_key": ""},
-        {"type_key": "relay.TypeVar",
-         "attrs": {"kind": "0", "span": "0", "var": "2"}},
-        {"type_key": "Variable",
-         "attrs": {"dtype": "int32", "name": "in0"}},
-        ]
+        {"type_key": "relay.TypeVar", "attrs": {"kind": "0", "span": "0", "var": "2"}},
+        {"type_key": "Variable", "attrs": {"dtype": "int32", "name": "in0"}},
+    ]
     data = {
-        "root" : 1,
+        "root": 1,
         "nodes": nodes,
         "attrs": {"tvm_version": "0.6.0"},
         "b64ndarrays": [],
@@ -43,37 +42,23 @@ def test_type_var():
     assert isinstance(tvar, tvm.ir.GlobalTypeVar)
     assert tvar.name_hint == "in0"
 
+
 def test_var():
     # relay.Var in 0.6
     nodes = [
         {"type_key": ""},
-        {"type_key": "relay.Var",
-         "attrs": {
-             "_checked_type_": "0",
-             "span": "0",
-             "type_annotation": "0",
-             "vid": "2"
-         }
-        },
-        {"type_key": "relay.Id",
-         "attrs": {"name_hint": "a3"}},
-        {"type_key": "relay.TensorType",
-         "attrs": {
-             "dtype": "float32",
-             "shape": "4",
-             "span": "0"
-         }
+        {
+            "type_key": "relay.Var",
+            "attrs": {"_checked_type_": "0", "span": "0", "type_annotation": "0", "vid": "2"},
         },
-        {"type_key": "Array",
-         "data": [5, 6]
-        },
-        {"type_key": "IntImm",
-         "attrs": {"dtype": "int32", "value": "16"}},
-        {"type_key": "IntImm",
-         "attrs": {"dtype": "int32", "value": "8"}}
-        ]
+        {"type_key": "relay.Id", "attrs": {"name_hint": "a3"}},
+        {"type_key": "relay.TensorType", "attrs": {"dtype": "float32", "shape": "4", "span": "0"}},
+        {"type_key": "Array", "data": [5, 6]},
+        {"type_key": "IntImm", "attrs": {"dtype": "int32", "value": "16"}},
+        {"type_key": "IntImm", "attrs": {"dtype": "int32", "value": "8"}},
+    ]
     data = {
-        "root" : 1,
+        "root": 1,
         "nodes": nodes,
         "attrs": {"tvm_version": "0.6.0"},
         "b64ndarrays": [],
@@ -82,13 +67,14 @@ def test_var():
     assert isinstance(tvar, relay.Var)
     assert tvar.name_hint == "a3"
 
+
 def test_incomplete_type():
     nodes = [
         {"type_key": ""},
-        {"type_key": "relay.IncompleteType",
-         "attrs": {"kind": "0", "span": "0"}}]
+        {"type_key": "relay.IncompleteType", "attrs": {"kind": "0", "span": "0"}},
+    ]
     data = {
-        "root" : 1,
+        "root": 1,
         "nodes": nodes,
         "attrs": {"tvm_version": "0.6.0"},
         "b64ndarrays": [],
@@ -100,24 +86,24 @@ def test_incomplete_type():
 def test_func_tuple_type():
     nodes = [
         {"type_key": ""},
-        {"type_key": "relay.FuncType",
-         "attrs": {
-             "arg_types": "2",
-             "ret_type": "3",
-             "span": "0",
-             "type_constraints": "6",
-             "type_params": "5"
-         }
+        {
+            "type_key": "relay.FuncType",
+            "attrs": {
+                "arg_types": "2",
+                "ret_type": "3",
+                "span": "0",
+                "type_constraints": "6",
+                "type_params": "5",
+            },
         },
         {"type_key": "Array"},
-        {"type_key": "relay.TupleType",
-         "attrs": { "fields": "4", "span": "0" }},
+        {"type_key": "relay.TupleType", "attrs": {"fields": "4", "span": "0"}},
+        {"type_key": "Array"},
         {"type_key": "Array"},
         {"type_key": "Array"},
-        {"type_key": "Array"}
     ]
     data = {
-        "root" : 1,
+        "root": 1,
         "nodes": nodes,
         "attrs": {"tvm_version": "0.6.0"},
         "b64ndarrays": [],
@@ -129,16 +115,13 @@ def test_func_tuple_type():
 def test_global_var():
     nodes = [
         {"type_key": ""},
-        {"type_key": "relay.GlobalVar",
-         "attrs": {
-             "_checked_type_": "0",
-             "name_hint": "x",
-             "span": "0"
-         }
-        }
+        {
+            "type_key": "relay.GlobalVar",
+            "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0"},
+        },
     ]
     data = {
-        "root" : 1,
+        "root": 1,
         "nodes": nodes,
         "attrs": {"tvm_version": "0.6.0"},
         "b64ndarrays": [],
@@ -147,16 +130,10 @@ def test_global_var():
     assert isinstance(tvar, tvm.ir.GlobalVar)
     nodes = [
         {"type_key": ""},
-        {"type_key": "GlobalVar",
-         "attrs": {
-             "_checked_type_": "0",
-             "name_hint": "x",
-             "span": "0"
-         }
-        }
+        {"type_key": "GlobalVar", "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0"}},
     ]
     data = {
-        "root" : 1,
+        "root": 1,
         "nodes": nodes,
         "attrs": {"tvm_version": "0.6.0"},
         "b64ndarrays": [],
@@ -166,13 +143,9 @@ def test_global_var():
 
 
 def test_op():
-    nodes = [
-        {"type_key": ""},
-        {"type_key": "relay.Op",
-         "global_key": "nn.conv2d"}
-    ]
+    nodes = [{"type_key": ""}, {"type_key": "relay.Op", "global_key": "nn.conv2d"}]
     data = {
-        "root" : 1,
+        "root": 1,
         "nodes": nodes,
         "attrs": {"tvm_version": "0.6.0"},
         "b64ndarrays": [],
@@ -184,13 +157,11 @@ def test_op():
 def test_tir_var():
     nodes = [
         {"type_key": ""},
-        {"type_key": "Variable",
-         "attrs": {"dtype": "int32", "name": "x"}},
-        {"type_key": "SizeVar",
-         "attrs": {"dtype": "int32", "name": "y"}},
+        {"type_key": "Variable", "attrs": {"dtype": "int32", "name": "x"}},
+        {"type_key": "SizeVar", "attrs": {"dtype": "int32", "name": "y"}},
     ]
     data = {
-        "root" : 1,
+        "root": 1,
         "nodes": nodes,
         "attrs": {"tvm_version": "0.6.0"},
         "b64ndarrays": [],
@@ -206,30 +177,30 @@ def test_tir_var():
 
 def test_str_map():
     nodes = [
-        {'type_key': ''},
-        {'type_key': 'StrMap', 'keys': ['z', 'x'], 'data': [2, 3]},
-        {'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '2'}},
-        {'type_key': 'Max', 'attrs': {'a': '4', 'b': '10', 'dtype': 'int32'}},
-        {'type_key': 'Add', 'attrs': {'a': '5', 'b': '9', 'dtype': 'int32'}},
-        {'type_key': 'Add', 'attrs': {'a': '6', 'b': '8', 'dtype': 'int32'}},
-        {'type_key': 'tir.Var', 'attrs': {'dtype': 'int32', 'name': '7', 'type_annotation': '0'}},
-        {'type_key': 'runtime.String', 'repr_str': 'x'},
-        {'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '1'}},
-        {'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '2'}},
-        {'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '100'}}
+        {"type_key": ""},
+        {"type_key": "StrMap", "keys": ["z", "x"], "data": [2, 3]},
+        {"type_key": "IntImm", "attrs": {"dtype": "int32", "value": "2"}},
+        {"type_key": "Max", "attrs": {"a": "4", "b": "10", "dtype": "int32"}},
+        {"type_key": "Add", "attrs": {"a": "5", "b": "9", "dtype": "int32"}},
+        {"type_key": "Add", "attrs": {"a": "6", "b": "8", "dtype": "int32"}},
+        {"type_key": "tir.Var", "attrs": {"dtype": "int32", "name": "7", "type_annotation": "0"}},
+        {"type_key": "runtime.String", "repr_str": "x"},
+        {"type_key": "IntImm", "attrs": {"dtype": "int32", "value": "1"}},
+        {"type_key": "IntImm", "attrs": {"dtype": "int32", "value": "2"}},
+        {"type_key": "IntImm", "attrs": {"dtype": "int32", "value": "100"}},
     ]
     data = {
-        "root" : 1,
+        "root": 1,
         "nodes": nodes,
         "attrs": {"tvm_version": "0.6.0"},
         "b64ndarrays": [],
     }
     x = tvm.ir.load_json(json.dumps(data))
-    assert(isinstance(x, tvm.ir.container.Map))
-    assert(len(x) == 2)
-    assert('x' in x)
-    assert('z' in x)
-    assert(bool(x['z'] == 2))
+    assert isinstance(x, tvm.ir.container.Map)
+    assert len(x) == 2
+    assert "x" in x
+    assert "z" in x
+    assert bool(x["z"] == 2)
 
 
 if __name__ == "__main__":
index cf3b2b2..ef56716 100644 (file)
@@ -39,14 +39,9 @@ def set_func_attr(func, compile_name, symbol_name):
     return func
 
 
-def check_result(mod,
-                 ref_mod,
-                 map_inputs,
-                 out_shape,
-                 tol=1e-5,
-                 target="llvm",
-                 ctx=tvm.cpu(),
-                 params=None):
+def check_result(
+    mod, ref_mod, map_inputs, out_shape, tol=1e-5, target="llvm", ctx=tvm.cpu(), params=None
+):
     if sys.platform == "win32":
         print("Skip test on Windows for now")
         return
@@ -100,7 +95,7 @@ def test_conv2d():
         return
 
     def conv2d_direct():
-        dtype = 'float32'
+        dtype = "float32"
         ishape = (1, 32, 14, 14)
         w1shape = (32, 32, 3, 3)
 
@@ -124,7 +119,7 @@ def test_conv2d():
         out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1))
         main_f = relay.Function([data0, weight0], out)
         ref_mod = tvm.IRModule()
-        ref_mod['main'] = main_f
+        ref_mod["main"] = main_f
 
         i_data = np.random.uniform(0, 1, ishape).astype(dtype)
         w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)
@@ -132,7 +127,7 @@ def test_conv2d():
         return mod, ref_mod, {"data": i_data, "weight": w1_data}, (1, 32, 14, 14)
 
     def group_conv2d():
-        dtype = 'float32'
+        dtype = "float32"
         ishape = (1, 32, 14, 14)
         w2shape = (32, 1, 3, 3)
 
@@ -156,7 +151,7 @@ def test_conv2d():
         out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1), groups=32)
         main_f = relay.Function([data0, weight0], out)
         ref_mod = tvm.IRModule()
-        ref_mod['main'] = main_f
+        ref_mod["main"] = main_f
 
         i_data = np.random.uniform(0, 1, ishape).astype(dtype)
         w_data = np.random.uniform(0, 1, w2shape).astype(dtype)
@@ -173,7 +168,7 @@ def test_add():
         print("skip because DNNL codegen is not available")
         return
 
-    dtype = 'float32'
+    dtype = "float32"
     shape = (10, 10)
 
     def gen_add():
@@ -214,7 +209,7 @@ def test_relu():
         print("skip because DNNL codegen is not available")
         return
 
-    dtype = 'float32'
+    dtype = "float32"
     shape = (1, 32, 14, 14)
 
     def gen_relu():
@@ -242,7 +237,15 @@ def test_relu():
     mod, ref_mod = gen_relu()
 
     data0 = np.random.uniform(-1, 1, shape).astype(dtype)
-    check_result(mod, ref_mod, {"data0": data0,}, (1, 32, 14, 14), tol=1e-5)
+    check_result(
+        mod,
+        ref_mod,
+        {
+            "data0": data0,
+        },
+        (1, 32, 14, 14),
+        tol=1e-5,
+    )
 
 
 def test_dense():
@@ -251,7 +254,7 @@ def test_dense():
         print("skip because DNNL codegen is not available")
         return
 
-    dtype = 'float32'
+    dtype = "float32"
     a_shape = (1, 512)
     b_shape = (1024, 512)
 
@@ -293,12 +296,12 @@ def test_bn():
         print("skip because DNNL codegen is not available")
         return
 
-    dtype = 'float32'
+    dtype = "float32"
     d_shape = (1, 8)
-    c_shape = (8, )
+    c_shape = (8,)
 
     def gen_bn():
-        data = relay.var('data', shape=d_shape)
+        data = relay.var("data", shape=d_shape)
         gamma = relay.var("gamma", shape=c_shape)
         beta = relay.var("beta", shape=c_shape)
         moving_mean = relay.var("moving_mean", shape=c_shape)
@@ -312,16 +315,18 @@ def test_bn():
         mod = tvm.IRModule()
         mod[glb_var] = func
 
-        data = relay.var('data', shape=d_shape)
+        data = relay.var("data", shape=d_shape)
         gamma = relay.var("gamma", shape=c_shape)
         beta = relay.var("beta", shape=c_shape)
         moving_mean = relay.var("moving_mean", shape=c_shape)
         moving_var = relay.var("moving_var", shape=c_shape)
-        main_f = relay.Function([data, gamma, beta, moving_mean, moving_var],
-                                glb_var(data, gamma, beta, moving_mean, moving_var))
+        main_f = relay.Function(
+            [data, gamma, beta, moving_mean, moving_var],
+            glb_var(data, gamma, beta, moving_mean, moving_var),
+        )
         mod["main"] = main_f
 
-        data = relay.var('data', shape=d_shape)
+        data = relay.var("data", shape=d_shape)
         gamma = relay.var("gamma", shape=c_shape)
         beta = relay.var("beta", shape=c_shape)
         moving_mean = relay.var("moving_mean", shape=c_shape)
@@ -341,16 +346,19 @@ def test_bn():
     beta = np.random.uniform(-1, 1, c_shape).astype(dtype)
     moving_mean = np.random.uniform(-1, 1, c_shape).astype(dtype)
     moving_var = np.random.uniform(-1, 1, c_shape).astype(dtype)
-    check_result(mod,
-                 ref_mod, {
-                     "data": data,
-                     "gamma": gamma,
-                     "beta": beta,
-                     "moving_mean": moving_mean,
-                     "moving_var": moving_var
-                 },
-                 d_shape,
-                 tol=1e-5)
+    check_result(
+        mod,
+        ref_mod,
+        {
+            "data": data,
+            "gamma": gamma,
+            "beta": beta,
+            "moving_mean": moving_mean,
+            "moving_var": moving_var,
+        },
+        d_shape,
+        tol=1e-5,
+    )
 
 
 def test_multiple_ops():
@@ -359,7 +367,7 @@ def test_multiple_ops():
         print("skip because DNNL codegen is not available")
         return
 
-    dtype = 'float32'
+    dtype = "float32"
     ishape = (1, 32, 14, 14)
     w1shape = (32, 32, 3, 3)
     w2shape = (64, 32, 5, 5)
@@ -380,18 +388,22 @@ def test_multiple_ops():
         return mod
 
     def get_partitoned_mod(mod):
-        remove_bn_pass = tvm.transform.Sequential([
-            transform.InferType(),
-            transform.SimplifyInference(),
-            transform.FoldConstant(),
-            transform.FoldScaleAxis(),
-        ])
-        byoc_pass = tvm.transform.Sequential([
-            remove_bn_pass,
-            transform.AnnotateTarget("dnnl"),
-            transform.MergeCompilerRegions(),
-            transform.PartitionGraph()
-        ])
+        remove_bn_pass = tvm.transform.Sequential(
+            [
+                transform.InferType(),
+                transform.SimplifyInference(),
+                transform.FoldConstant(),
+                transform.FoldScaleAxis(),
+            ]
+        )
+        byoc_pass = tvm.transform.Sequential(
+            [
+                remove_bn_pass,
+                transform.AnnotateTarget("dnnl"),
+                transform.MergeCompilerRegions(),
+                transform.PartitionGraph(),
+            ]
+        )
 
         with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
             return byoc_pass(mod)
@@ -402,11 +414,17 @@ def test_multiple_ops():
     data = np.random.uniform(0, 1, ishape).astype(dtype)
     w1 = np.random.uniform(0, 1, w1shape).astype(dtype)
     w2 = np.random.uniform(0, 1, w2shape).astype(dtype)
-    check_result(mod, ref_mod, {
-        "data": data,
-        "w1": w1,
-        "w2": w2,
-    }, (1, 64, 14, 14), tol=1e-5)
+    check_result(
+        mod,
+        ref_mod,
+        {
+            "data": data,
+            "w1": w1,
+            "w2": w2,
+        },
+        (1, 64, 14, 14),
+        tol=1e-5,
+    )
 
 
 def test_composite():
@@ -415,7 +433,7 @@ def test_composite():
         print("skip because DNNL codegen is not available")
         return
 
-    dtype = 'float32'
+    dtype = "float32"
 
     def conv2d_relu():
         ishape = (1, 32, 14, 14)
@@ -427,8 +445,8 @@ def test_composite():
         conv2d = relay.nn.conv2d(in_1, in_2, kernel_size=(3, 3), padding=(1, 1))
         relu = relay.nn.relu(conv2d)
         func = relay.Function([in_1, in_2], relu)
-        func = func.with_attr('Composite', 'dnnl.conv2d_relu')
-        func = func.with_attr('PartitionedFromPattern', 'nn.conv2d_nn.relu_')
+        func = func.with_attr("Composite", "dnnl.conv2d_relu")
+        func = func.with_attr("PartitionedFromPattern", "nn.conv2d_nn.relu_")
 
         # Partition function
         arg_1 = relay.var("arg_1", shape=ishape, dtype=dtype)
@@ -458,7 +476,7 @@ def test_composite():
         i_data = np.random.uniform(0, 1, ishape).astype(dtype)
         w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)
 
-        return mod, ref_mod, {'data': i_data, 'weight': w1_data}, (1, 32, 14, 14)
+        return mod, ref_mod, {"data": i_data, "weight": w1_data}, (1, 32, 14, 14)
 
     def conv2d_bias_relu():
         ishape = (1, 32, 14, 14)
@@ -473,8 +491,8 @@ def test_composite():
         add = relay.add(conv2d, in_3)
         relu = relay.nn.relu(add)
         func = relay.Function([in_1, in_2, in_3], relu)
-        func = func.with_attr('Composite', 'dnnl.conv2d_bias_relu')
-        func = func.with_attr('PartitionedFromPattern', 'nn.conv2d_add_nn.relu_')
+        func = func.with_attr("Composite", "dnnl.conv2d_bias_relu")
+        func = func.with_attr("PartitionedFromPattern", "nn.conv2d_add_nn.relu_")
 
         # Partition function
         arg_1 = relay.var("arg_1", shape=ishape, dtype=dtype)
@@ -490,14 +508,14 @@ def test_composite():
         # Main function
         data = relay.var("data", shape=ishape, dtype=dtype)
         weight = relay.var("weight", shape=w1shape, dtype=dtype)
-        bias = relay.var('bias', shape=bshape, dtype=dtype)
+        bias = relay.var("bias", shape=bshape, dtype=dtype)
         main_func = relay.Function([data, weight, bias], glb_var(data, weight, bias))
         mod["main"] = main_func
 
         # Reference module
         data = relay.var("data", shape=ishape, dtype=dtype)
         weight = relay.var("weight", shape=w1shape, dtype=dtype)
-        bias = relay.var('bias', shape=bshape, dtype=dtype)
+        bias = relay.var("bias", shape=bshape, dtype=dtype)
         conv2d = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1))
         add = relay.add(conv2d, bias)
         relu = relay.nn.relu(add)
@@ -509,7 +527,7 @@ def test_composite():
         w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)
         b_data = np.random.uniform(0, 1, bshape).astype(dtype)
 
-        return mod, ref_mod, {'data': i_data, 'weight': w1_data, 'bias': b_data}, (1, 32, 14, 14)
+        return mod, ref_mod, {"data": i_data, "weight": w1_data, "bias": b_data}, (1, 32, 14, 14)
 
     for mod, ref_mod, input_maps, out_shape in [conv2d_relu(), conv2d_bias_relu()]:
         check_result(mod, ref_mod, input_maps, out_shape, tol=1e-5)
@@ -521,7 +539,7 @@ def test_constant():
         print("skip because DNNL codegen is not available")
         return
 
-    dtype = 'float32'
+    dtype = "float32"
     ishape = (1, 32, 14, 14)
     wshape = (32, 32, 3, 3)
 
@@ -541,26 +559,31 @@ def test_constant():
     ref_mod, params = tvm.relay.testing.create_workload(func)
     ref_mod["main"] = bind_params_by_name(ref_mod["main"], params)
 
-    remove_bn_pass = tvm.transform.Sequential([
-        transform.InferType(),
-        transform.SimplifyInference(),
-        transform.FoldConstant(),
-        transform.FoldScaleAxis(),
-    ])
+    remove_bn_pass = tvm.transform.Sequential(
+        [
+            transform.InferType(),
+            transform.SimplifyInference(),
+            transform.FoldConstant(),
+            transform.FoldScaleAxis(),
+        ]
+    )
 
     dnnl_patterns = get_pattern_table("dnnl")
-    composite_partition = tvm.transform.Sequential([
-        transform.MergeComposite(dnnl_patterns),
-        transform.AnnotateTarget("dnnl"),
-        transform.PartitionGraph()
-    ])
+    composite_partition = tvm.transform.Sequential(
+        [
+            transform.MergeComposite(dnnl_patterns),
+            transform.AnnotateTarget("dnnl"),
+            transform.PartitionGraph(),
+        ]
+    )
 
     with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
         ref_mod = remove_bn_pass(ref_mod)
         mod = composite_partition(ref_mod)
 
     i_data = np.random.uniform(0, 1, ishape).astype(dtype)
-    check_result(mod, ref_mod, {'data': i_data}, (1, 32, 14, 14), tol=1e-5)
+    check_result(mod, ref_mod, {"data": i_data}, (1, 32, 14, 14), tol=1e-5)
+
 
 def test_partial_constant():
     """Test the subgraph with (const, var, const, var) arguments."""
@@ -568,7 +591,7 @@ def test_partial_constant():
         print("skip because DNNL codegen is not available")
         return
 
-    dtype = 'float32'
+    dtype = "float32"
     ishape = (10, 10)
 
     in_1 = relay.var("in_1", shape=ishape, dtype=dtype)
@@ -589,27 +612,29 @@ def test_partial_constant():
     data3 = np.random.uniform(0, 1, ishape).astype(dtype)
 
     params = {
-        'in_1': tvm.nd.array(data1, ctx=tvm.cpu(0)),
-        'in_3': tvm.nd.array(data3, ctx=tvm.cpu(0))
+        "in_1": tvm.nd.array(data1, ctx=tvm.cpu(0)),
+        "in_3": tvm.nd.array(data3, ctx=tvm.cpu(0)),
     }
     ref_mod["main"] = bind_params_by_name(ref_mod["main"], params)
 
-    opt_pass = tvm.transform.Sequential([
-        transform.InferType(),
-        transform.SimplifyInference(),
-        transform.FoldConstant(),
-        transform.FoldScaleAxis(),
-        transform.AnnotateTarget("dnnl"),
-        transform.MergeCompilerRegions(),
-        transform.PartitionGraph()
-    ])
+    opt_pass = tvm.transform.Sequential(
+        [
+            transform.InferType(),
+            transform.SimplifyInference(),
+            transform.FoldConstant(),
+            transform.FoldScaleAxis(),
+            transform.AnnotateTarget("dnnl"),
+            transform.MergeCompilerRegions(),
+            transform.PartitionGraph(),
+        ]
+    )
 
     with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
         mod = opt_pass(ref_mod)
 
     data2 = np.random.uniform(0, 1, ishape).astype(dtype)
     data4 = np.random.uniform(0, 1, ishape).astype(dtype)
-    check_result(mod, ref_mod, {'in_2': data2, 'in_4': data4}, (10, 10), tol=1e-5)
+    check_result(mod, ref_mod, {"in_2": data2, "in_4": data4}, (10, 10), tol=1e-5)
 
 
 if __name__ == "__main__":
index dc16865..c960d1f 100644 (file)
@@ -20,6 +20,7 @@ import numpy as np
 from tvm import relay
 from tvm.relay import memory_alloc
 
+
 def check_memory_plan(func, check_fn):
     # Build Module
     mod = tvm.IRModule().from_expr(func)
@@ -33,48 +34,55 @@ def check_memory_plan(func, check_fn):
         args.append(tvm.nd.array(data))
 
     # Compute without memory planning.
-    ex = relay.create_executor('vm', mod)
-    no_plan_result = ex.evaluate(mod['main'])(*args)
+    ex = relay.create_executor("vm", mod)
+    no_plan_result = ex.evaluate(mod["main"])(*args)
 
     # Compute with memory planning.
     with tvm.transform.PassContext(opt_level=1, disabled_pass=["MemoryPlan"]):
-        plan_result = ex.evaluate(mod['main'])(*args)
+        plan_result = ex.evaluate(mod["main"])(*args)
 
     # Compute Python result.
     py_res = check_fn(*[arg.asnumpy() for arg in args])
 
     # First check that the two VM results agree.
-    np.testing.assert_allclose(
-        no_plan_result.asnumpy(),
-        plan_result.asnumpy())
+    np.testing.assert_allclose(no_plan_result.asnumpy(), plan_result.asnumpy())
 
     # Finally check that the results match the Python result.
     np.testing.assert_allclose(plan_result.asnumpy(), py_res)
 
+
 def storage_type(mod):
     return relay.TypeCall(mod.get_global_type_var("Storage"), [])
 
+
 def test_tyck_alloc_storage():
     mod = tvm.IRModule()
     mod.import_from_std("core.rly")
 
+
 def test_tyck_alloc_tensor():
     mod = tvm.IRModule()
     mod.import_from_std("core.rly")
     sto = relay.Var("x", storage_type(mod))
     sh = relay.const(np.array([1, 2]), dtype="int64")
     at = relay.op.memory.alloc_tensor(sto, relay.const(0, dtype="int64"), sh)
-    mod['main'] = relay.Function([sto], at)
+    mod["main"] = relay.Function([sto], at)
     relay.transform.InferType()(mod)
 
 
 def check_add(x):
     return x + x
 
+
 def test_add():
-    x = relay.var('x', shape=(2,))
+    x = relay.var("x", shape=(2,))
     z = x + x
-    func = relay.Function([x,], z)
+    func = relay.Function(
+        [
+            x,
+        ],
+        z,
+    )
     check_memory_plan(func, check_add)
 
 
@@ -84,26 +92,29 @@ def check_add_sub(x, y):
 
 
 def test_add_sub():
-    x = relay.var('x', shape=(10,))
-    y = relay.var('y', shape=(10,))
+    x = relay.var("x", shape=(10,))
+    y = relay.var("y", shape=(10,))
     z = x + x
     z = z - y
     func = relay.Function([x, y], z)
     check_memory_plan(func, check_add_sub)
 
+
 def check_no_fuse(x, y, w):
     z = x + y
     return np.matmul(z, np.transpose(w))
 
+
 def test_no_fuse():
-    x = relay.var('x', shape=(5, 1))
-    y = relay.var('y', shape=(5, 1))
-    w = relay.var('w', shape=(5, 1))
+    x = relay.var("x", shape=(5, 1))
+    y = relay.var("y", shape=(5, 1))
+    w = relay.var("w", shape=(5, 1))
     z = x + y
     out = relay.op.nn.dense(z, w)
     func = relay.Function([x, y, w], out)
     check_memory_plan(func, check_no_fuse)
 
+
 if __name__ == "__main__":
     test_tyck_alloc_tensor()
     test_add()
index cb95955..d067ab5 100644 (file)
@@ -34,7 +34,7 @@ def test_fastmath():
         func = relay.Function([x], y)
         mod = tvm.IRModule.from_expr(func)
 
-        with tvm.transform.PassContext(opt_level=3, required_pass=['FastMath']):
+        with tvm.transform.PassContext(opt_level=3, required_pass=["FastMath"]):
             graph, lib, params = relay.build(mod, target="llvm", params=None)
 
         # Check that the ops related to fast math have been converted to functions in the lib
@@ -44,14 +44,13 @@ def test_fastmath():
         ctx = tvm.cpu(0)
         m = graph_runtime.create(graph, lib, ctx)
         # Set inputs
-        m.set_input('x', tvm.nd.array(a_np, ctx))
+        m.set_input("x", tvm.nd.array(a_np, ctx))
         m.set_input(**params)
         # Execute
         m.run()
         # Get outputs
         tvm_output = m.get_output(0)
-        tvm.testing.assert_allclose(tvm_output.asnumpy(), b_np,
-                                    rtol=1e-5, atol=1e-5)
+        tvm.testing.assert_allclose(tvm_output.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
 
     test_apply(relay.exp, "fast_exp", np.exp, low=-88, high=88, step=0.01)
     test_apply(relay.erf, "fast_erf", scipy.special.erf, low=-10, high=10, step=0.01)
index 3847e18..c0270ea 100644 (file)
@@ -53,31 +53,33 @@ def test_unary_op():
 
             for target, ctx in tvm.testing.enabled_targets():
                 intrp = relay.create_executor(ctx=ctx, target=target)
-                op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
+                op_res, (op_grad,) = intrp.evaluate(bwd_func)(data)
                 np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
 
-    for opfunc, ref in [(tvm.relay.log, lambda x: 1 / x),
-                        (tvm.relay.exp, np.exp),
-                        (tvm.relay.sigmoid, lambda x: sigmoid(x) * (1 - sigmoid(x))),
-                        (tvm.relay.tanh, lambda x: 1 - np.tanh(x) * np.tanh(x)),
-                        (tvm.relay.sqrt, lambda x: 0.5 * np.power(x, -0.5)),
-                        (tvm.relay.abs, lambda x: np.where(x < 0, -np.ones_like(x), np.ones_like(x))),
-                        (relay.nn.relu, lambda x: np.where(x < 0, np.zeros_like(x), np.ones_like(x))),
-                        (tvm.relay.erf, lambda x: 2.0 / (np.pi**(0.5)) * np.exp(-x * x)),
-                        (tvm.relay.cos, lambda x: -1.0 * np.sin(x)),
-                        (tvm.relay.sin, lambda x: np.cos(x)),
-                        (tvm.relay.tan, lambda x: 1.0 / (np.cos(x) ** 2)),
-                        (tvm.relay.atan, lambda x: 1 / (1 + np.power(x, 2.0))),
-                        (tvm.relay.log2, lambda x: 1 / (np.log(2) * x)),
-                        (tvm.relay.log10, lambda x: 1 / (np.log(10) * x)),
-                        (tvm.relay.cosh, lambda x: np.sinh(x)),
-                        (tvm.relay.sinh, lambda x: np.cosh(x)),
-                        (tvm.relay.asin, lambda x: 1. / (1. - x**2) ** (1./2.)),
-                        (tvm.relay.acos, lambda x: -1. / (1. - x**2.) ** (1./2.)),
-                        (tvm.relay.acosh, lambda x: 1./ (x**2 - 1.)**(1./2.)),
-                        (tvm.relay.asinh, lambda x: 1./ (x**2 + 1.)**(1./2.)),
-                        (tvm.relay.atanh, lambda x: -1./ (x**2 - 1.))]:
-        for dtype in ('float32', 'float64'):
+    for opfunc, ref in [
+        (tvm.relay.log, lambda x: 1 / x),
+        (tvm.relay.exp, np.exp),
+        (tvm.relay.sigmoid, lambda x: sigmoid(x) * (1 - sigmoid(x))),
+        (tvm.relay.tanh, lambda x: 1 - np.tanh(x) * np.tanh(x)),
+        (tvm.relay.sqrt, lambda x: 0.5 * np.power(x, -0.5)),
+        (tvm.relay.abs, lambda x: np.where(x < 0, -np.ones_like(x), np.ones_like(x))),
+        (relay.nn.relu, lambda x: np.where(x < 0, np.zeros_like(x), np.ones_like(x))),
+        (tvm.relay.erf, lambda x: 2.0 / (np.pi ** (0.5)) * np.exp(-x * x)),
+        (tvm.relay.cos, lambda x: -1.0 * np.sin(x)),
+        (tvm.relay.sin, lambda x: np.cos(x)),
+        (tvm.relay.tan, lambda x: 1.0 / (np.cos(x) ** 2)),
+        (tvm.relay.atan, lambda x: 1 / (1 + np.power(x, 2.0))),
+        (tvm.relay.log2, lambda x: 1 / (np.log(2) * x)),
+        (tvm.relay.log10, lambda x: 1 / (np.log(10) * x)),
+        (tvm.relay.cosh, lambda x: np.sinh(x)),
+        (tvm.relay.sinh, lambda x: np.cosh(x)),
+        (tvm.relay.asin, lambda x: 1.0 / (1.0 - x ** 2) ** (1.0 / 2.0)),
+        (tvm.relay.acos, lambda x: -1.0 / (1.0 - x ** 2.0) ** (1.0 / 2.0)),
+        (tvm.relay.acosh, lambda x: 1.0 / (x ** 2 - 1.0) ** (1.0 / 2.0)),
+        (tvm.relay.asinh, lambda x: 1.0 / (x ** 2 + 1.0) ** (1.0 / 2.0)),
+        (tvm.relay.atanh, lambda x: -1.0 / (x ** 2 - 1.0)),
+    ]:
+        for dtype in ("float32", "float64"):
             check_single_op(opfunc, ref, dtype)
 
 
@@ -106,11 +108,13 @@ def test_binary_op():
             np.testing.assert_allclose(op_grad0.asnumpy(), ref_grad0, rtol=0.01)
             np.testing.assert_allclose(op_grad1.asnumpy(), ref_grad1, rtol=0.01)
 
-    for opfunc, ref in [(relay.add, lambda x, y: [np.ones_like(x), np.ones_like(y)]),
-                        (relay.subtract, lambda x, y: [np.ones_like(x), -np.ones_like(y)]),
-                        (relay.multiply, lambda x, y: [y, x]),
-                        (relay.divide, lambda x, y: [1 / y, - x / (y**2)])]:
-        for dtype in ('float32', 'float64'):
+    for opfunc, ref in [
+        (relay.add, lambda x, y: [np.ones_like(x), np.ones_like(y)]),
+        (relay.subtract, lambda x, y: [np.ones_like(x), -np.ones_like(y)]),
+        (relay.multiply, lambda x, y: [y, x]),
+        (relay.divide, lambda x, y: [1 / y, -x / (y ** 2)]),
+    ]:
+        for dtype in ("float32", "float64"):
             check_binary_op(opfunc, ref, dtype)
 
 
index b8624b4..462a752 100644
@@ -21,31 +21,42 @@ from tvm.relay.testing import check_grad
 
 
 def test_cross_entropy_grad():
-    for dtype in ('float32', 'float64'):
+    for dtype in ("float32", "float64"):
         x = relay.var("x", shape=(2, 5), dtype=dtype)
         y = relay.var("y", shape=(2, 5), dtype=dtype)
-        check_grad(relay.Function([x, y], relay.op.nn.cross_entropy(x, y)), eps=0.01, scale=0.1, mean=1)
+        check_grad(
+            relay.Function([x, y], relay.op.nn.cross_entropy(x, y)), eps=0.01, scale=0.1, mean=1
+        )
 
 
 def test_cross_entropy_with_logits_grad():
-    for dtype in ('float32', 'float64'):
+    for dtype in ("float32", "float64"):
         x = relay.var("x", shape=(2, 5), dtype=dtype)
         y = relay.var("y", shape=(2, 5), dtype=dtype)
-        check_grad(relay.Function([x, y], relay.op.nn.cross_entropy_with_logits(x, y)), eps=0.01, scale=0.1, mean=1)
+        check_grad(
+            relay.Function([x, y], relay.op.nn.cross_entropy_with_logits(x, y)),
+            eps=0.01,
+            scale=0.1,
+            mean=1,
+        )
 
 
 def test_checkpoint():
     inputs = [relay.var("x{}".format(i), shape=(1,)) for i in range(4)]
-    output = relay.multiply(relay.add(inputs[0], inputs[1]),
-                            relay.add(inputs[2], inputs[3]))
+    output = relay.multiply(relay.add(inputs[0], inputs[1]), relay.add(inputs[2], inputs[3]))
     check_grad(relay.Function(inputs, relay.annotation.checkpoint(output)))
 
     scope = relay.ScopeBuilder()
-    out_tuple = scope.let("out_tuple",
-                          relay.Tuple([relay.add(inputs[0], inputs[1]),
-                                       relay.multiply(inputs[2], inputs[3])]))
-    scope.ret(relay.subtract(relay.annotation.checkpoint(relay.TupleGetItem(out_tuple, 0)),
-                                relay.TupleGetItem(out_tuple, 1)))
+    out_tuple = scope.let(
+        "out_tuple",
+        relay.Tuple([relay.add(inputs[0], inputs[1]), relay.multiply(inputs[2], inputs[3])]),
+    )
+    scope.ret(
+        relay.subtract(
+            relay.annotation.checkpoint(relay.TupleGetItem(out_tuple, 0)),
+            relay.TupleGetItem(out_tuple, 1),
+        )
+    )
     out_single = scope.get()
     check_grad(relay.Function(inputs, out_single))
 
index 396e43d..34bbf9e 100644
@@ -28,8 +28,9 @@ import tvm.testing
 
 def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode):
     x = relay.var("x", relay.TensorType(x_shape, "float32"))
-    y = tvm.relay.nn.max_pool2d(x, pool_size=pool_size, strides=strides, padding=padding,
-                                ceil_mode=ceil_mode)
+    y = tvm.relay.nn.max_pool2d(
+        x, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode
+    )
 
     fwd_func = relay.Function([x], y)
     fwd_func = run_infer_type(fwd_func)
@@ -40,26 +41,41 @@ def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode):
     y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
     out_grad = np.ones(shape=y_shape)
     ref_grad = tvm.topi.testing.pool_grad_nchw(
-        data, out_grad, pool_size=pool_size, strides=strides,
+        data,
+        out_grad,
+        pool_size=pool_size,
+        strides=strides,
         padding=[ph, pw, ph, pw],
-        pool_type='max', ceil_mode=ceil_mode)
+        pool_type="max",
+        ceil_mode=ceil_mode,
+    )
 
     for target, ctx in tvm.testing.enabled_targets():
         intrp = relay.create_executor(ctx=ctx, target=target)
-        op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
+        op_res, (op_grad,) = intrp.evaluate(bwd_func)(data)
         np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
 
 
 @tvm.testing.uses_gpu
 def test_max_pool2d_grad():
-    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), ceil_mode=False)
-    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1), ceil_mode=False)
+    verify_max_pool2d_grad(
+        (1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), ceil_mode=False
+    )
+    verify_max_pool2d_grad(
+        (1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1), ceil_mode=False
+    )
 
 
 def verify_avg_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode, count_include_pad):
     x = relay.var("x", relay.TensorType(x_shape, "float32"))
-    y = tvm.relay.nn.avg_pool2d(x, pool_size=pool_size, strides=strides, padding=padding,
-                                ceil_mode=ceil_mode, count_include_pad=count_include_pad)
+    y = tvm.relay.nn.avg_pool2d(
+        x,
+        pool_size=pool_size,
+        strides=strides,
+        padding=padding,
+        ceil_mode=ceil_mode,
+        count_include_pad=count_include_pad,
+    )
 
     fwd_func = relay.Function([x], y)
     fwd_func = run_infer_type(fwd_func)
@@ -70,21 +86,39 @@ def verify_avg_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode, coun
     y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
     out_grad = np.ones(shape=y_shape)
     ref_grad = tvm.topi.testing.pool_grad_nchw(
-        data, out_grad, pool_size=pool_size, strides=strides,
+        data,
+        out_grad,
+        pool_size=pool_size,
+        strides=strides,
         padding=[ph, pw, ph, pw],
-        pool_type='avg', ceil_mode=ceil_mode)
+        pool_type="avg",
+        ceil_mode=ceil_mode,
+    )
 
     for target, ctx in tvm.testing.enabled_targets():
         intrp = relay.create_executor(ctx=ctx, target=target)
-        op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
+        op_res, (op_grad,) = intrp.evaluate(bwd_func)(data)
         np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
 
+
 @tvm.testing.uses_gpu
 def test_avg_pool2d_grad():
-    verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
-                           ceil_mode=False, count_include_pad=True)
-    verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1),
-                           ceil_mode=False, count_include_pad=False)
+    verify_avg_pool2d_grad(
+        (1, 4, 16, 16),
+        pool_size=(2, 2),
+        strides=(2, 2),
+        padding=(0, 0),
+        ceil_mode=False,
+        count_include_pad=True,
+    )
+    verify_avg_pool2d_grad(
+        (1, 4, 16, 16),
+        pool_size=(1, 1),
+        strides=(1, 1),
+        padding=(1, 1),
+        ceil_mode=False,
+        count_include_pad=False,
+    )
 
 
 def verify_global_avg_pool2d_grad(x_shape):
@@ -99,49 +133,77 @@ def verify_global_avg_pool2d_grad(x_shape):
     y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
     out_grad = np.ones(shape=y_shape)
     ref_grad = tvm.topi.testing.pool_grad_nchw(
-        data, out_grad, pool_size=(x_shape[2], x_shape[3]),
-        strides=(1, 1), padding=[0, 0, 0, 0], pool_type='avg',
-        ceil_mode=False)
+        data,
+        out_grad,
+        pool_size=(x_shape[2], x_shape[3]),
+        strides=(1, 1),
+        padding=[0, 0, 0, 0],
+        pool_type="avg",
+        ceil_mode=False,
+    )
 
     for target, ctx in tvm.testing.enabled_targets():
         intrp = relay.create_executor(ctx=ctx, target=target)
-        op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
+        op_res, (op_grad,) = intrp.evaluate(bwd_func)(data)
         np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
 
+
 @tvm.testing.uses_gpu
 def test_global_avg_pool2d_grad():
     verify_global_avg_pool2d_grad((1, 4, 16, 16))
     verify_global_avg_pool2d_grad((1, 8, 8, 24))
 
-def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mode='higher_order'):
+
+def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mode="higher_order"):
     try:
         import torch
         import torch.nn.functional as F
     except ImportError:
-        print('Skip because pytorch is not installed')
+        print("Skip because pytorch is not installed")
         return
 
-    dtype = 'float32'
-    data = relay.var('data', shape=dshape, dtype=dtype)
-    weight = relay.var('weight', shape=wshape, dtype=dtype)
-    conv = relay.nn.conv2d(data, weight, strides=strides, padding=padding, dilation=dilation,
-                           groups=groups)
+    dtype = "float32"
+    data = relay.var("data", shape=dshape, dtype=dtype)
+    weight = relay.var("weight", shape=wshape, dtype=dtype)
+    conv = relay.nn.conv2d(
+        data, weight, strides=strides, padding=padding, dilation=dilation, groups=groups
+    )
     fwd_func = relay.Function([data, weight], conv)
     fwd_func = run_infer_type(fwd_func)
     bwd_func = run_infer_type(gradient(fwd_func, mode=mode))
 
     data_pt = torch.randn(*dshape, dtype=torch.float32, requires_grad=True)
     weight_pt = torch.randn(*wshape, dtype=torch.float32, requires_grad=True)
-    out_pt = F.conv2d(data_pt, weight_pt, stride=strides, padding=padding, dilation=dilation,
-                      groups=groups)
+    out_pt = F.conv2d(
+        data_pt, weight_pt, stride=strides, padding=padding, dilation=dilation, groups=groups
+    )
     grad_output_pt = torch.ones(out_pt.shape)
-    grad_input_pt = F.grad.conv2d_input(dshape, weight_pt, grad_output_pt, stride=strides,
-                                        padding=padding, dilation=dilation, groups=groups) \
-                          .detach().numpy()
-    grad_weight_pt = F.grad.conv2d_weight(data_pt, wshape, grad_output_pt, stride=strides,
-                                          padding=padding, dilation=dilation, groups=groups) \
-                           .detach().numpy()
-
+    grad_input_pt = (
+        F.grad.conv2d_input(
+            dshape,
+            weight_pt,
+            grad_output_pt,
+            stride=strides,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+        )
+        .detach()
+        .numpy()
+    )
+    grad_weight_pt = (
+        F.grad.conv2d_weight(
+            data_pt,
+            wshape,
+            grad_output_pt,
+            stride=strides,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+        )
+        .detach()
+        .numpy()
+    )
 
     for target, ctx in tvm.testing.enabled_targets():
         data = tvm.nd.array(data_pt.detach().numpy(), ctx)
@@ -157,7 +219,7 @@ def test_conv2d_grad():
     verify_conv2d_grad((1, 4, 16, 16), (16, 4, 3, 3), [1, 1], [1, 1], [1, 1])
     verify_conv2d_grad((1, 4, 16, 16), (16, 4, 1, 1), [1, 1], [0, 0], [1, 1])
     verify_conv2d_grad((1, 4, 16, 16), (16, 4, 1, 1), [2, 2], [0, 0], [1, 1])
-    verify_conv2d_grad((1, 4, 16, 16), (16, 4, 3, 3), [1, 1], [1, 1], [1, 1], mode='first_order')
+    verify_conv2d_grad((1, 4, 16, 16), (16, 4, 3, 3), [1, 1], [1, 1], [1, 1], mode="first_order")
 
 
 def verify_dense_grad(d_shape, w_shape):
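
Most of the pooling and conv2d hunks above reflow keyword-heavy calls: when a call no longer fits on one line, black places one argument per line and keeps a trailing comma, which pins the exploded layout in place. A self-contained sketch of that shape; textwrap.fill is only a stand-in for the TVM test helpers being reformatted here:

    import textwrap

    # One keyword argument per line with a trailing comma, as in the
    # verify_avg_pool2d_grad and F.grad.conv2d_input calls above.
    print(
        textwrap.fill(
            "average pooling gradient check",
            width=20,
            initial_indent="  ",
            subsequent_indent="  ",
            break_long_words=False,
            drop_whitespace=True,
        )
    )
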
index a63ec6e..0b4f892 100644
@@ -27,9 +27,10 @@ import tvm.testing
 
 @tvm.testing.uses_gpu
 def test_clip():
-    for dtype in ('float32', 'float64'):
-        ref = (lambda x: np.where(x > 10.0, np.zeros_like(x),
-                         np.where(x < 1.0, np.zeros_like(x), np.ones_like(x))))
+    for dtype in ("float32", "float64"):
+        ref = lambda x: np.where(
+            x > 10.0, np.zeros_like(x), np.where(x < 1.0, np.zeros_like(x), np.ones_like(x))
+        )
         x = relay.var("x", relay.TensorType((10, 4), dtype))
         y = tvm.relay.clip(x, 1.0, 10.0)
 
@@ -41,7 +42,7 @@ def test_clip():
 
         for target, ctx in tvm.testing.enabled_targets():
             intrp = relay.create_executor(ctx=ctx, target=target)
-            op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
+            op_res, (op_grad,) = intrp.evaluate(bwd_func)(data)
             np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
 
 
index b35ffe9..d479221 100644
@@ -38,7 +38,9 @@ def test_reduction_grad():
 
 def verify_max_grad(d_shape, axis=None, keepdims=False, exclude=False):
     data = relay.var("data", relay.TensorType(d_shape, "float32"))
-    fwd_func = relay.Function([data], relay.max(data, axis=axis, keepdims=keepdims, exclude=exclude))
+    fwd_func = relay.Function(
+        [data], relay.max(data, axis=axis, keepdims=keepdims, exclude=exclude)
+    )
     check_grad(fwd_func, scale=1e-3)
 
 
index 086a880..8c724da 100644
@@ -31,15 +31,18 @@ def sigmoid(x):
     one = np.ones_like(x)
     return one / (one + np.exp(-x))
 
+
 def relu(x):
     x_copy = np.copy(x)
     np.maximum(x_copy, 0, x_copy)
     return x_copy
 
+
 def rsqrt(x):
     one = np.ones_like(x)
     return one / np.sqrt(x)
 
+
 @tvm.testing.uses_gpu
 def test_unary_op():
     def check_single_op(opfunc, ref, dtype):
@@ -61,26 +64,31 @@ def test_unary_op():
             for target, ctx in tvm.testing.enabled_targets():
                 # use the graph executor by default for testing, as we need to
                 # create the function explicitly to avoid constant-folding.
-                if dtype ==  'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version):
+                if (
+                    dtype == "float16"
+                    and target == "cuda"
+                    and not have_fp16(tvm.gpu(0).compute_version)
+                ):
                     continue
                 intrp = relay.create_executor("graph", ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(data)
                 np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
 
-
-    for opfunc, ref in [(tvm.relay.log, np.log),
-                        (tvm.relay.exp, np.exp),
-                        (tvm.relay.erf, scipy.special.erf),
-                        (tvm.relay.sqrt, np.sqrt),
-                        (tvm.relay.rsqrt, rsqrt),
-                        (tvm.relay.sigmoid, sigmoid),
-                        (tvm.relay.tanh, np.tanh),
-                        (relay.nn.relu, relu),
-                        (tvm.relay.cos, np.cos),
-                        (tvm.relay.sin, np.sin),
-                        (tvm.relay.tan, np.tan),
-                        (tvm.relay.atan, np.arctan)]:
-        for dtype in ['float16', 'float32']:
+    for opfunc, ref in [
+        (tvm.relay.log, np.log),
+        (tvm.relay.exp, np.exp),
+        (tvm.relay.erf, scipy.special.erf),
+        (tvm.relay.sqrt, np.sqrt),
+        (tvm.relay.rsqrt, rsqrt),
+        (tvm.relay.sigmoid, sigmoid),
+        (tvm.relay.tanh, np.tanh),
+        (relay.nn.relu, relu),
+        (tvm.relay.cos, np.cos),
+        (tvm.relay.sin, np.sin),
+        (tvm.relay.tan, np.tan),
+        (tvm.relay.atan, np.arctan),
+    ]:
+        for dtype in ["float16", "float32"]:
             check_single_op(opfunc, ref, dtype)
 
 
@@ -118,19 +126,25 @@ def test_binary_op():
             for target, ctx in tvm.testing.enabled_targets():
                 # use the graph executor by default for testing, as we need to
                 # create the function explicitly to avoid constant-folding.
-                if dtype ==  'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version):
+                if (
+                    dtype == "float16"
+                    and target == "cuda"
+                    and not have_fp16(tvm.gpu(0).compute_version)
+                ):
                     continue
                 intrp = relay.create_executor("graph", ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(x_data, y_data)
                 np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
 
-    for opfunc, ref in [(relay.add, np.add),
-                        (relay.subtract, np.subtract),
-                        (relay.multiply, np.multiply),
-                        (relay.divide, np.divide),
-                        (relay.floor_divide, np.floor_divide),
-                        (relay.floor_mod, np.fmod)]:
-        for dtype in ['float16', 'float32']:
+    for opfunc, ref in [
+        (relay.add, np.add),
+        (relay.subtract, np.subtract),
+        (relay.multiply, np.multiply),
+        (relay.divide, np.divide),
+        (relay.floor_divide, np.floor_divide),
+        (relay.floor_mod, np.fmod),
+    ]:
+        for dtype in ["float16", "float32"]:
             check_binary_op(opfunc, ref, dtype)
 
 
@@ -141,24 +155,29 @@ def test_expand_dims():
         x = relay.Var("x", relay.TensorType(dshape, dtype))
         func = relay.Function([x], relay.expand_dims(x, axis, num_newaxis))
         for target, ctx in tvm.testing.enabled_targets():
-            if dtype ==  'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version):
+            if (
+                dtype == "float16"
+                and target == "cuda"
+                and not have_fp16(tvm.gpu(0).compute_version)
+            ):
                 continue
             data = np.random.uniform(size=dshape).astype(dtype)
             ref_res = data.reshape(oshape)
             intrp = relay.create_executor("graph", ctx=ctx, target=target)
             op_res = intrp.evaluate(func)(data)
             np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
-    for dtype in ['float16', 'float32']:
+
+    for dtype in ["float16", "float32"]:
         verify_expand_dims((3, 10), dtype, (3, 10, 1, 1), 2, 2)
         verify_expand_dims((3, 10), dtype, (1, 3, 10), -3, 1)
 
 
 @tvm.testing.uses_gpu
 def test_bias_add():
-    for dtype in ['float16', 'float32']:
-        xshape=(10, 2, 3, 4)
-        bshape=(2,)
-        rtol = 1e-2 if dtype == 'float16' else 1e-5
+    for dtype in ["float16", "float32"]:
+        xshape = (10, 2, 3, 4)
+        bshape = (2,)
+        rtol = 1e-2 if dtype == "float16" else 1e-5
         x = relay.var("x", shape=xshape, dtype=dtype)
         bias = relay.var("bias", dtype=dtype)
         z = relay.nn.bias_add(x, bias)
@@ -171,7 +190,11 @@ def test_bias_add():
         y_data = np.random.uniform(size=bshape).astype(dtype)
         ref_res = x_data + y_data.reshape((2, 1, 1))
         for target, ctx in tvm.testing.enabled_targets():
-            if dtype ==  'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version):
+            if (
+                dtype == "float16"
+                and target == "cuda"
+                and not have_fp16(tvm.gpu(0).compute_version)
+            ):
                 continue
             intrp = relay.create_executor("graph", ctx=ctx, target=target)
             op_res = intrp.evaluate(func)(x_data, y_data)
@@ -179,7 +202,7 @@ def test_bias_add():
 
 
 def test_expand_dims_infer_type():
-    for dtype in ['float16', 'float32']:
+    for dtype in ["float16", "float32"]:
         n, t, d = te.size_var("n"), te.size_var("t"), 100
         x = relay.var("x", shape=(n, t, d), dtype=dtype)
         y = relay.expand_dims(x, axis=2)
@@ -190,9 +213,9 @@ def test_expand_dims_infer_type():
 
 @tvm.testing.uses_gpu
 def test_softmax():
-    for dtype in ['float16', 'float32']:
+    for dtype in ["float16", "float32"]:
         # Softmax accuracy for float16 is poor
-        if dtype == 'float16':
+        if dtype == "float16":
             return
         shape = (10, 4)
         x = relay.var("x", shape=shape, dtype=dtype)
@@ -211,9 +234,9 @@ def test_softmax():
 
 @tvm.testing.uses_gpu
 def test_log_softmax():
-    for dtype in ['float16', 'float32']:
+    for dtype in ["float16", "float32"]:
         # Softmax accuracy for float16 is poor
-        if dtype == 'float16':
+        if dtype == "float16":
             return
         shape = (10, 4)
         x = relay.var("x", shape=shape, dtype=dtype)
@@ -232,7 +255,7 @@ def test_log_softmax():
 
 @tvm.testing.uses_gpu
 def test_concatenate():
-    for dtype in ['float16', 'float32']:
+    for dtype in ["float16", "float32"]:
         n, t, d = te.size_var("n"), te.size_var("t"), 100
         x = relay.var("x", shape=(n, t, d))
         y = relay.var("y", shape=(n, t, d))
@@ -252,8 +275,8 @@ def test_concatenate():
 
         # check shape mismatches (the following case is expected to raise tvm._ffi.base.TVMError).
         try:
-            x = relay.var('p1', shape=(2, 5))
-            y = relay.var('p2', shape=(2, 3))
+            x = relay.var("p1", shape=(2, 5))
+            y = relay.var("p2", shape=(2, 3))
             c = relay.concatenate([x, y], axis=0)
             func = relay.Function([x, y], c)
             zz = run_infer_type(func)
@@ -275,7 +298,11 @@ def test_concatenate():
         ref_res = np.concatenate((x_data, y_data), axis=1) + t_data
 
         for target, ctx in tvm.testing.enabled_targets():
-            if dtype ==  'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version):
+            if (
+                dtype == "float16"
+                and target == "cuda"
+                and not have_fp16(tvm.gpu(0).compute_version)
+            ):
                 continue
             intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
             intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
@@ -284,8 +311,9 @@ def test_concatenate():
             op_res2 = intrp2.evaluate(func)(x_data, y_data, t_data)
             tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=0.01)
 
+
 def test_dropout():
-    for dtype in ['float16', 'float32']:
+    for dtype in ["float16", "float32"]:
         n, t, d = te.size_var("n"), te.size_var("t"), te.size_var("d")
         input_ty = relay.TensorType((n, t, d), dtype)
         x = relay.var("x", input_ty)
@@ -296,36 +324,46 @@ def test_dropout():
 
 
 def test_batch_norm():
-    for dtype in ['float16', 'float32']:
+    for dtype in ["float16", "float32"]:
         # beta and gamma ignored
         data = relay.var("data", relay.TensorType((3, 2, 1), dtype))
         beta = relay.var("beta", relay.TensorType((2,), dtype))
         gamma = relay.var("gamma", relay.TensorType((2,), dtype))
         moving_mean = relay.var("moving_mean", relay.TensorType((2,), dtype))
         moving_var = relay.var("moving_var", relay.TensorType((2,), dtype))
-        y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
-                                center=False, scale=False)
+        y = relay.nn.batch_norm(
+            data, gamma, beta, moving_mean, moving_var, center=False, scale=False
+        )
         yy = run_infer_type(y.astuple())
         assert "center=" in yy.astext()
-        assert yy.checked_type == relay.ty.TupleType(tvm.runtime.convert([
-            relay.TensorType((3, 2, 1), dtype),
-            relay.TensorType((2,), dtype),
-            relay.TensorType((2,), dtype)
-        ]))
+        assert yy.checked_type == relay.ty.TupleType(
+            tvm.runtime.convert(
+                [
+                    relay.TensorType((3, 2, 1), dtype),
+                    relay.TensorType((2,), dtype),
+                    relay.TensorType((2,), dtype),
+                ]
+            )
+        )
 
         beta = relay.var("beta", relay.TensorType((3,), dtype))
         gamma = relay.var("gamma", relay.TensorType((3,), dtype))
         moving_mean = relay.var("moving_mean", relay.TensorType((3,), dtype))
         moving_var = relay.var("moving_var", relay.TensorType((3,), dtype))
 
-        y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
-                                axis=0, center=False, scale=False)
+        y = relay.nn.batch_norm(
+            data, gamma, beta, moving_mean, moving_var, axis=0, center=False, scale=False
+        )
         yy = run_infer_type(y.astuple())
-        assert yy.checked_type == relay.ty.TupleType(tvm.runtime.convert([
-            relay.ty.TensorType((3, 2, 1), dtype),
-            relay.ty.TensorType((3,), dtype),
-            relay.ty.TensorType((3,), dtype)
-        ]))
+        assert yy.checked_type == relay.ty.TupleType(
+            tvm.runtime.convert(
+                [
+                    relay.ty.TensorType((3, 2, 1), dtype),
+                    relay.ty.TensorType((3,), dtype),
+                    relay.ty.TensorType((3,), dtype),
+                ]
+            )
+        )
 
         # axis=-1
         data = relay.var("data", relay.TensorType((1, 2, 3), dtype))
@@ -333,19 +371,25 @@ def test_batch_norm():
         gamma = relay.var("gamma", relay.TensorType((3,), dtype))
         moving_mean = relay.var("moving_mean", relay.TensorType((3,), dtype))
         moving_var = relay.var("moving_var", relay.TensorType((3,), dtype))
-        y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
-                                axis=-1, center=False, scale=False)
+        y = relay.nn.batch_norm(
+            data, gamma, beta, moving_mean, moving_var, axis=-1, center=False, scale=False
+        )
         yy = run_infer_type(y.astuple())
-        assert yy.checked_type == relay.ty.TupleType(tvm.runtime.convert([
-            relay.ty.TensorType((1, 2, 3), dtype),
-            relay.ty.TensorType((3,), dtype),
-            relay.ty.TensorType((3,), dtype)
-        ]))
+        assert yy.checked_type == relay.ty.TupleType(
+            tvm.runtime.convert(
+                [
+                    relay.ty.TensorType((1, 2, 3), dtype),
+                    relay.ty.TensorType((3,), dtype),
+                    relay.ty.TensorType((3,), dtype),
+                ]
+            )
+        )
+
 
 @pytest.mark.xfail
 def test_dense_type_check():
-    dtype = 'float16'
-    n, c , h, w = 2, 2 , 2 ,2
+    dtype = "float16"
+    n, c, h, w = 2, 2, 2, 2
     x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
     # it should fail since it does not match with m(2)
     mismatch_w = 3
@@ -353,13 +397,14 @@ def test_dense_type_check():
     y = relay.nn.dense(x, w)
     yy = run_infer_type(y)
 
+
 @tvm.testing.uses_gpu
 def test_dense():
-    for dtype in ['float16', 'float32']:
+    for dtype in ["float16", "float32"]:
         # Dense accuracy for float16 is poor
-        if dtype == 'float16':
+        if dtype == "float16":
             return
-        n, c , h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
+        n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
         x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
         w = relay.var("w", relay.TensorType((2, w), dtype))
         y = relay.nn.dense(x, w, units=2)
@@ -367,7 +412,7 @@ def test_dense():
         yy = run_infer_type(y)
         assert yy.checked_type == relay.TensorType((n, c, h, 2), dtype)
 
-        n, c , h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), 2
+        n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), 2
         x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
         wh, ww = te.size_var("wh"), te.size_var("ww")
         w = relay.var("w", relay.TensorType((ww, wh), dtype))
@@ -375,7 +420,7 @@ def test_dense():
         yy = run_infer_type(y)
         assert yy.checked_type == relay.TensorType((n, c, h, ww), dtype)
 
-        n, c , h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), 2
+        n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), 2
         x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
         w = relay.var("w", relay.IncompleteType())
         y = relay.nn.dense(x, w, units=2)
@@ -402,18 +447,18 @@ def test_dense():
 
 
 def test_dense_dtype():
-    data_dtype = 'uint8'
-    weight_dtype = 'int8'
-    out_dtype = 'uint8'
-    n, c , h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
+    data_dtype = "uint8"
+    weight_dtype = "int8"
+    out_dtype = "uint8"
+    n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
     x = relay.var("x", relay.TensorType((n, c, h, w), data_dtype))
     w = relay.var("w", relay.TensorType((2, w), weight_dtype))
     y = relay.nn.dense(x, w, units=2, out_dtype=out_dtype)
     assert "units=2" in y.astext()
     yy = run_infer_type(y)
     assert yy.checked_type == relay.TensorType((n, c, h, 2), out_dtype)
-    assert run_infer_type(yy.args[0]).checked_type.dtype == 'uint8'
-    assert run_infer_type(yy.args[1]).checked_type.dtype == 'int8'
+    assert run_infer_type(yy.args[0]).checked_type.dtype == "uint8"
+    assert run_infer_type(yy.args[1]).checked_type.dtype == "int8"
 
 
 def test_bitserial_dense():
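
Many "+" lines in the file above wrap the float16/CUDA guard in parentheses with one clause per line; that is black's standard treatment of an over-long boolean condition. A minimal sketch, assuming black is installed and using its default 88-column limit so the split actually triggers; the undefined names are placeholders, since black only parses the source it is given:

    import black

    src = (
        "for target, ctx in enabled_targets():\n"
        "    if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version):\n"
        "        continue\n"
    )
    # Prints the condition wrapped in parentheses, one "and" clause per line,
    # matching the guards rewritten in the hunks above.
    print(black.format_str(src, mode=black.FileMode()))
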
index 3aaa76d..55fba49 100644
@@ -47,17 +47,20 @@ def test_checkpoint():
             f_checkpoint_res = intrp.evaluate(f_checkpoint)(*inputs)
             tvm.testing.assert_allclose(f_res.asnumpy(), f_checkpoint_res.asnumpy(), 0, 0)
 
+
 def test_checkpoint_alpha_equal():
     xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)]
-    f = relay.Function(xs, relay.annotation.checkpoint(
-        relay.multiply(relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3]))
-    ))
+    f = relay.Function(
+        xs,
+        relay.annotation.checkpoint(
+            relay.multiply(relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3]))
+        ),
+    )
     df = transform.gradient(run_infer_type(f))
 
     # run PE and DCE
     with tvm.transform.PassContext(opt_level=3):
-        passes = [transform.PartialEvaluate(),
-                  transform.DeadCodeElimination(inline_once=True)]
+        passes = [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]
         mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df))
         df = mod["main"]
 
@@ -103,17 +106,20 @@ def test_checkpoint_alpha_equal():
 
     tvm.ir.assert_structural_equal(df, df_parsed)
 
+
 def test_checkpoint_alpha_equal_tuple():
     xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)]
-    f = relay.Function(xs, relay.annotation.checkpoint(
-        relay.Tuple([relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3])])
-    ))
+    f = relay.Function(
+        xs,
+        relay.annotation.checkpoint(
+            relay.Tuple([relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3])])
+        ),
+    )
     df = transform.gradient(run_infer_type(f))
 
     # run PE and DCE
     with tvm.transform.PassContext(opt_level=3):
-        passes = [transform.PartialEvaluate(),
-                  transform.DeadCodeElimination(inline_once=True)]
+        passes = [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]
         mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df))
         df = mod["main"]
 
@@ -150,12 +156,13 @@ def test_checkpoint_alpha_equal_tuple():
 
     tvm.ir.assert_structural_equal(df, df_parsed)
 
+
 @tvm.testing.uses_gpu
 def test_collapse_sum_like():
     shape = (3, 4, 5, 6)
     shape_like = (4, 5, 6)
     dtype = "float32"
-    x = relay.Var("x", relay.ty.TensorType(shape , dtype))
+    x = relay.Var("x", relay.ty.TensorType(shape, dtype))
     y = relay.Var("y", relay.ty.TensorType(shape_like, dtype))
     z = relay.collapse_sum_like(x, y)
     zz = run_infer_type(z)
@@ -177,7 +184,7 @@ def test_collapse_sum_to():
     shape = (3, 4, 5, 6)
     shape_to = (4, 5, 6)
     dtype = "float32"
-    x = relay.Var("x", relay.ty.TensorType(shape , dtype))
+    x = relay.Var("x", relay.ty.TensorType(shape, dtype))
     z = relay.collapse_sum_to(x, shape_to)
     zz = run_infer_type(z)
     assert zz.checked_type == relay.ty.TensorType(shape_to, dtype)
@@ -197,7 +204,7 @@ def test_broadcast_to():
     shape = (4, 1, 6)
     shape_like = (3, 4, 5, 6)
     dtype = "float32"
-    x = relay.Var("x", relay.ty.TensorType(shape , dtype))
+    x = relay.Var("x", relay.ty.TensorType(shape, dtype))
     z = relay.broadcast_to(x, shape=shape_like)
     zz = run_infer_type(z)
     assert zz.checked_type == relay.ty.TensorType(shape_like, dtype)
@@ -211,12 +218,13 @@ def test_broadcast_to():
             op_res = intrp.evaluate(func)(x)
             tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_broadcast_to_like():
     shape = (4, 1, 6)
     shape_like = (3, 4, 5, 6)
     dtype = "float32"
-    x = relay.Var("x", relay.ty.TensorType(shape , dtype))
+    x = relay.Var("x", relay.ty.TensorType(shape, dtype))
     y = relay.Var("y", relay.ty.TensorType(shape_like, dtype))
     z = relay.broadcast_to_like(x, y)
 
@@ -263,8 +271,9 @@ def verify_slice_like(data, slice_like, axes, output, dtype="float32"):
         assert "axes" in z.astext()
     assert zz.checked_type == relay.ty.TensorType(output, dtype)
 
-    if all(isinstance(v, int) == 0 for v in data) or \
-        all(isinstance(v, int) == 0 for v in slice_like):
+    if all(isinstance(v, int) == 0 for v in data) or all(
+        isinstance(v, int) == 0 for v in slice_like
+    ):
         return
 
     func = relay.Function([x, y], z)
@@ -278,20 +287,21 @@ def verify_slice_like(data, slice_like, axes, output, dtype="float32"):
             op_res = intrp.evaluate(func)(x_data, y_data)
             tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_slice_like():
     d1, d2, d3, d4 = te.var("d1"), te.var("d2"), te.var("d3"), te.var("d4")
     verify_slice_like(data=(d1, d2, d3), slice_like=(1, 2, 3), axes=None, output=(1, 2, 3))
     verify_slice_like(data=(1, 2, 3), slice_like=(d1, d2, d3), axes=None, output=(d1, d2, d3))
-    verify_slice_like(data=(d2, d3, d4), slice_like=(d1, d2, d3), axes=(1,2), output=(d2, d2, d3))
+    verify_slice_like(data=(d2, d3, d4), slice_like=(d1, d2, d3), axes=(1, 2), output=(d2, d2, d3))
     verify_slice_like(data=(3, 4, 5), slice_like=(1, 2, 3), axes=None, output=(1, 2, 3))
     verify_slice_like(data=(3, 4, 5), slice_like=(1, 2), axes=None, output=(1, 2, 5))
     verify_slice_like(data=(3, 4, 5), slice_like=(1, 2, 3), axes=(1, 2), output=(3, 2, 3))
     verify_slice_like(data=(3, 4, 5), slice_like=(1, 2, 3), axes=(-1, -3), output=(1, 4, 3))
-    verify_slice_like(data=(1, 3, 224, 224),
-                      slice_like=(1, 3, 112, 112),
-                      axes=(2, 3),
-                      output=(1, 3, 112, 112))
+    verify_slice_like(
+        data=(1, 3, 224, 224), slice_like=(1, 3, 112, 112), axes=(2, 3), output=(1, 3, 112, 112)
+    )
+
 
 @tvm.testing.uses_gpu
 def test_reverse_reshape():
@@ -310,12 +320,14 @@ def test_reverse_reshape():
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(x_data)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
     verify_reverse_reshape((2, 3, 4), (4, 0, 2), (4, 3, 2))
     verify_reverse_reshape((2, 3, 4), (2, 0, 0), (2, 3, 4))
     verify_reverse_reshape((2, 3, 4), (0, -1), (3, 8))
     verify_reverse_reshape((2, 3, 4), (-1, 0), (6, 4))
     verify_reverse_reshape((2, 3, 4), (0, -3), (2, 12))
 
+
 def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"):
     x = relay.var("x", relay.TensorType(x_shape, dtype))
     y = relay.var("y", relay.TensorType(y_shape, dtype))
@@ -334,6 +346,7 @@ def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"):
             z = intrp.evaluate(func)(x_np, y_np)
             tvm.testing.assert_allclose(z.asnumpy(), z_np, rtol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_batch_matmul():
     b, m, n, k = te.size_var("b"), te.size_var("m"), te.size_var("n"), te.size_var("k")
@@ -348,21 +361,22 @@ def test_batch_matmul():
     verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20))
     verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20))
 
+
 @tvm.testing.uses_gpu
 def test_shape_of():
     shape = (10, 5, 12)
     x = relay.var("x", shape=shape)
     func = relay.Function([x], relay.op.shape_of(x))
     func = run_infer_type(func)
-    x_data = np.random.rand(*shape).astype('float32')
+    x_data = np.random.rand(*shape).astype("float32")
     for target, ctx in tvm.testing.enabled_targets():
         # Because this op is optimized away by the constant-folding pass when using
         # the graph executor, we only test it with the interpreter here
         for kind in ["debug"]:
             intrp = relay.create_executor(kind, ctx=ctx, target=target)
             op_res = intrp.evaluate(func)(x_data)
-            tvm.testing.assert_allclose(op_res.asnumpy(),
-                                        np.array(shape).astype('int32'))
+            tvm.testing.assert_allclose(op_res.asnumpy(), np.array(shape).astype("int32"))
+
 
 @tvm.testing.uses_gpu
 def test_ndarray_size():
@@ -377,8 +391,8 @@ def test_ndarray_size():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(x_data)
-                tvm.testing.assert_allclose(op_res.asnumpy(),
-                                            ref_res)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
+
     verify_ndarray_size((2, 3, 5))
     verify_ndarray_size((2, 3, 5, 7))
 
@@ -441,9 +455,11 @@ def test_sequence_mask():
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 out_relay = intrp.evaluate(func)(data_np, valid_length_np)
                 tvm.testing.assert_allclose(out_relay.asnumpy(), gt_out_np)
-    _verify((5, 10), 0.0, 1, 'float32', 'int32')
-    _verify((2, 3, 5, 3), 0.0, 0, 'float32', 'int64')
-    _verify((5, 8, 3), 0.1, 1, 'float64', 'float32')
+
+    _verify((5, 10), 0.0, 1, "float32", "int32")
+    _verify((2, 3, 5, 3), 0.0, 0, "float32", "int64")
+    _verify((5, 8, 3), 0.1, 1, "float64", "float32")
+
 
 @tvm.testing.uses_gpu
 def test_one_hot():
@@ -467,7 +483,9 @@ def test_one_hot():
         off_value_const = relay.const(off_value)
         out = relay.one_hot(indices, on_value_const, off_value_const, depth, axis, dtype)
         checked = run_infer_type(out)
-        assert checked.checked_type == relay.ty.TensorType(_get_oshape(indices_shape, depth, axis), dtype)
+        assert checked.checked_type == relay.ty.TensorType(
+            _get_oshape(indices_shape, depth, axis), dtype
+        )
         func = relay.Function([indices], out)
         indices_np = np.random.randint(0, depth, size=indices_shape).astype("int32")
         out_np = tvm.topi.testing.one_hot(indices_np, on_value, off_value, depth, axis, dtype)
@@ -485,6 +503,7 @@ def test_one_hot():
     _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32")
     _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
 
+
 @tvm.testing.uses_gpu
 def test_matrix_set_diag():
     def _verify(input_shape, dtype):
@@ -509,9 +528,10 @@ def test_matrix_set_diag():
                 out_relay = intrp.evaluate(func)(input_np, diagonal_np)
                 tvm.testing.assert_allclose(out_relay.asnumpy(), out_np)
 
-    _verify((2, 2), 'float32')
-    _verify((4, 3, 3), 'int32')
-    _verify((2, 3, 4), 'float32')
+    _verify((2, 2), "float32")
+    _verify((4, 3, 3), "int32")
+    _verify((2, 3, 4), "float32")
+
 
 if __name__ == "__main__":
     test_adaptive_pool()
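
A noticeable share of the "+" lines in the file above add nothing but blank lines: black enforces the PEP 8 rule of two blank lines between top-level definitions, including decorated test functions. A stand-alone sketch of that spacing, with functools.lru_cache standing in for the tvm.testing decorators:

    import functools


    @functools.lru_cache(maxsize=None)
    def square(x):
        return x * x


    def test_square():
        assert square(3) == 9


    if __name__ == "__main__":
        test_square()
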
index 93eecfc..c25c2bf 100644
@@ -35,73 +35,66 @@ def test_conv1d_infer_type():
     n, c, w = te.var("n"), 10, 224
     x = relay.var("x", relay.ty.TensorType((n, c, w), "float32"))
     w = relay.var("w")
-    y = relay.nn.conv1d(x, w,
-                        kernel_size=3,
-                        padding=(1, 1),
-                        channels=2)
+    y = relay.nn.conv1d(x, w, kernel_size=3, padding=(1, 1), channels=2)
     yy = run_infer_type(y)
-    assert yy.checked_type ==  relay.TensorType(
-        (n, 2, 224), "float32")
-    assert yy.args[1].checked_type == relay.TensorType(
-        (2, 10, 3), "float32")
+    assert yy.checked_type == relay.TensorType((n, 2, 224), "float32")
+    assert yy.args[1].checked_type == relay.TensorType((2, 10, 3), "float32")
 
     # infer by shape of w, mixed precision
     n, c, w = te.var("n"), 10, 224
     x = relay.var("x", relay.TensorType((n, c, w), "int8"))
     w = relay.var("w", relay.TensorType((2, 10, 3), "int8"))
     y = relay.nn.conv1d(x, w, out_dtype="int32")
-    assert "out_dtype=\"int32\"" in y.astext()
+    assert 'out_dtype="int32"' in y.astext()
     yy = run_infer_type(y)
-    assert yy.checked_type ==  relay.TensorType(
-        (n, 2, 222), "int32")
+    assert yy.checked_type == relay.TensorType((n, 2, 222), "int32")
 
     # infer shape in case of different dtypes for input and weight.
     n, c, w = te.var("n"), 10, 224
     x = relay.var("x", relay.TensorType((n, c, w), "uint8"))
     w = relay.var("w", relay.TensorType((2, 10, 3), "int8"))
     y = relay.nn.conv1d(x, w, out_dtype="int32")
-    assert "out_dtype=\"int32\"" in y.astext()
+    assert 'out_dtype="int32"' in y.astext()
     yy = run_infer_type(y)
-    assert yy.checked_type ==  relay.TensorType(
-        (n, 2, 222), "int32")
+    assert yy.checked_type == relay.TensorType((n, 2, 222), "int32")
 
     # Infer with NWC
     n, c, w = 4, 32, 224
     x = relay.var("x", relay.TensorType((n, w, c), "int8"))
     wt = relay.var("w")
-    y = relay.nn.conv1d(x, wt,
-                        kernel_size=3,
-                        padding=(1, 1),
-                        channels=16,
-                        data_layout="NWC",
-                        out_dtype="int32")
+    y = relay.nn.conv1d(
+        x, wt, kernel_size=3, padding=(1, 1), channels=16, data_layout="NWC", out_dtype="int32"
+    )
     yy = run_infer_type(y)
-    assert yy.checked_type ==  relay.TensorType(
-        (n, w, 16), "int32")
+    assert yy.checked_type == relay.TensorType((n, w, 16), "int32")
 
 
 @tvm.testing.uses_gpu
 def test_conv1d_run():
-    def run_test_conv1d(dtype, out_dtype, scale, dshape, kshape,
-                        padding=(1, 1),
-                        fref=None,
-                        dilation=1,
-                        except_targets=None,
-                        **attrs):
+    def run_test_conv1d(
+        dtype,
+        out_dtype,
+        scale,
+        dshape,
+        kshape,
+        padding=(1, 1),
+        fref=None,
+        dilation=1,
+        except_targets=None,
+        **attrs,
+    ):
         if except_targets is None:
             except_targets = []
 
         x = relay.var("x", shape=dshape, dtype=dtype)
         w = relay.var("w", dtype=dtype)
-        y = relay.nn.conv1d(x, w,
-                            padding=padding,
-                            dilation=dilation,
-                            **attrs)
+        y = relay.nn.conv1d(x, w, padding=padding, dilation=dilation, **attrs)
         func = relay.Function([x, w], y)
         data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
         kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
         ref_res = tvm.topi.testing.conv1d_ncw_python(
-            data.astype(out_dtype), kernel.astype(out_dtype), 1, padding, dilation)
+            data.astype(out_dtype), kernel.astype(out_dtype), 1, padding, dilation
+        )
 
         for target, ctx in tvm.testing.enabled_targets():
             if target in except_targets:
@@ -114,16 +107,25 @@ def test_conv1d_run():
     # normal conv1d
     dshape = (1, 3, 224)
     kshape = (10, 3, 3)
-    run_test_conv1d("float32", "float32", 1, dshape, kshape,
-                    padding=(1, 1), channels=10, kernel_size=3)
+    run_test_conv1d(
+        "float32", "float32", 1, dshape, kshape, padding=(1, 1), channels=10, kernel_size=3
+    )
     # mixed precision
-    run_test_conv1d("int8", "int32", 1, dshape, kshape,
-                    padding=(1, 1), channels=10, kernel_size=3)
+    run_test_conv1d("int8", "int32", 1, dshape, kshape, padding=(1, 1), channels=10, kernel_size=3)
     # dilated conv1d
     dshape = (1, 3, 18)
     kshape = (10, 3, 3)
-    run_test_conv1d("float32", "float32", 1, dshape, kshape,
-                    padding=(1, 1), channels=10, kernel_size=3, dilation=3)
+    run_test_conv1d(
+        "float32",
+        "float32",
+        1,
+        dshape,
+        kshape,
+        padding=(1, 1),
+        channels=10,
+        kernel_size=3,
+        dilation=3,
+    )
 
 
 @tvm.testing.uses_gpu
@@ -132,99 +134,96 @@ def test_conv2d_infer_type():
     n, c, h, w = te.size_var("n"), 10, 224, 224
     x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
     w = relay.var("w")
-    y = relay.nn.conv2d(x, w,
-                        kernel_size=(3, 3),
-                        padding=(1, 1),
-                        channels=2)
+    y = relay.nn.conv2d(x, w, kernel_size=(3, 3), padding=(1, 1), channels=2)
     yy = run_infer_type(y)
-    assert yy.checked_type ==  relay.TensorType(
-        (n, 2, 224, 224), "float32")
-    assert yy.args[1].checked_type == relay.TensorType(
-        (2, 10, 3, 3), "float32")
+    assert yy.checked_type == relay.TensorType((n, 2, 224, 224), "float32")
+    assert yy.args[1].checked_type == relay.TensorType((2, 10, 3, 3), "float32")
 
     # infer by shape of w, mixed precision
     n, c, h, w = te.size_var("n"), 10, 224, 224
     x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
     w = relay.var("w", relay.TensorType((2, 10, 3, 3), "int8"))
     y = relay.nn.conv2d(x, w, out_dtype="int32")
-    assert "out_dtype=\"int32\"" in y.astext()
+    assert 'out_dtype="int32"' in y.astext()
     yy = run_infer_type(y)
-    assert yy.checked_type ==  relay.TensorType(
-        (n, 2, 222, 222), "int32")
+    assert yy.checked_type == relay.TensorType((n, 2, 222, 222), "int32")
 
     # infer shape in case of different dtypes for input and weight.
     n, c, h, w = te.size_var("n"), 10, 224, 224
     x = relay.var("x", relay.TensorType((n, c, h, w), "uint8"))
     w = relay.var("w", relay.TensorType((2, 10, 3, 3), "int8"))
     y = relay.nn.conv2d(x, w, out_dtype="int32")
-    assert "out_dtype=\"int32\"" in y.astext()
+    assert 'out_dtype="int32"' in y.astext()
     yy = run_infer_type(y)
-    assert yy.checked_type ==  relay.TensorType(
-        (n, 2, 222, 222), "int32")
+    assert yy.checked_type == relay.TensorType((n, 2, 222, 222), "int32")
 
     # Infer with a different layout
     n, c, h, w = 4, 32, 224, 224
-    x = relay.var("x", relay.TensorType((n//4, c//4, h, w, 4, 4), "int8"))
+    x = relay.var("x", relay.TensorType((n // 4, c // 4, h, w, 4, 4), "int8"))
     wt = relay.var("w")
-    y = relay.nn.conv2d(x, wt,
-                        kernel_size=(3, 3),
-                        padding=(1, 1),
-                        channels=16,
-                        data_layout="NCHW4n4c",
-                        kernel_layout="OIHW4o4i",
-                        out_dtype="int32")
+    y = relay.nn.conv2d(
+        x,
+        wt,
+        kernel_size=(3, 3),
+        padding=(1, 1),
+        channels=16,
+        data_layout="NCHW4n4c",
+        kernel_layout="OIHW4o4i",
+        out_dtype="int32",
+    )
     yy = run_infer_type(y)
-    assert yy.checked_type ==  relay.TensorType(
-        (1, 4, 224, 224, 4, 4), "int32")
-    assert yy.args[1].checked_type == relay.TensorType(
-        (4, 8, 3, 3, 4, 4), "int8")
+    assert yy.checked_type == relay.TensorType((1, 4, 224, 224, 4, 4), "int32")
+    assert yy.args[1].checked_type == relay.TensorType((4, 8, 3, 3, 4, 4), "int8")
 
     # Infer with NHWC
     n, c, h, w = 4, 32, 224, 224
     x = relay.var("x", relay.TensorType((n, h, w, c), "int8"))
     wt = relay.var("w")
-    y = relay.nn.conv2d(x, wt,
-                        kernel_size=(3, 3),
-                        padding=(1, 1),
-                        channels=16,
-                        data_layout="NHWC",
-                        out_dtype="int32")
+    y = relay.nn.conv2d(
+        x,
+        wt,
+        kernel_size=(3, 3),
+        padding=(1, 1),
+        channels=16,
+        data_layout="NHWC",
+        out_dtype="int32",
+    )
     yy = run_infer_type(y)
-    assert yy.checked_type ==  relay.TensorType(
-        (n, h, w, 16), "int32")
+    assert yy.checked_type == relay.TensorType((n, h, w, 16), "int32")
 
 
 @tvm.testing.uses_gpu
 def test_conv2d_run():
-    def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape,
-                        padding=(1, 1),
-                        fref=None,
-                        groups=1,
-                        dilation=(1, 1),
-                        except_targets=None,
-                        **attrs):
+    def run_test_conv2d(
+        dtype,
+        out_dtype,
+        scale,
+        dshape,
+        kshape,
+        padding=(1, 1),
+        fref=None,
+        groups=1,
+        dilation=(1, 1),
+        except_targets=None,
+        **attrs,
+    ):
         if except_targets is None:
             except_targets = []
 
         x = relay.var("x", shape=dshape, dtype=dtype)
         w = relay.var("w", shape=kshape, dtype=dtype)
-        y = relay.nn.conv2d(x, w,
-                            padding=padding,
-                            dilation=dilation,
-                            groups=groups,
-                            **attrs)
+        y = relay.nn.conv2d(x, w, padding=padding, dilation=dilation, groups=groups, **attrs)
         func = relay.Function([x, w], y)
         data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
         kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
         dkernel = tvm.topi.testing.dilate_python(kernel, (1, 1) + dilation)
         if fref is None:
             ref_res = tvm.topi.testing.conv2d_nchw_python(
-                data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding,
-                groups=groups)
+                data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding, groups=groups
+            )
         else:
             ref_res = fref(data.astype(out_dtype), dkernel.astype(out_dtype))
 
-
         for target, ctx in tvm.testing.enabled_targets():
             if target in except_targets:
                 continue
@@ -233,23 +232,17 @@ def test_conv2d_run():
             op_res1 = intrp1.evaluate(func)(data, kernel)
             tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-4, atol=1e-4)
 
-    def compile_test_conv2d_arm_cpu(dtype, out_dtype, scale, dshape, kshape,
-                        padding=(1, 1),
-                        groups=1,
-                        dilation=(1, 1),
-                        **attrs):
+    def compile_test_conv2d_arm_cpu(
+        dtype, out_dtype, scale, dshape, kshape, padding=(1, 1), groups=1, dilation=(1, 1), **attrs
+    ):
         x = relay.var("x", shape=dshape, dtype=dtype)
         w = relay.var("w", shape=kshape, dtype=dtype)
-        y = relay.nn.conv2d(x, w,
-                            padding=padding,
-                            dilation=dilation,
-                            groups=groups,
-                            **attrs)
+        y = relay.nn.conv2d(x, w, padding=padding, dilation=dilation, groups=groups, **attrs)
         func = relay.Function([x, w], y)
         mod = tvm.IRModule()
         mod["main"] = func
 
-        test_schedule='{"i": ["llvm -device=arm_cpu", "depthwise_conv2d_nchw_spatial_pack.arm_cpu", \
+        test_schedule = '{"i": ["llvm -device=arm_cpu", "depthwise_conv2d_nchw_spatial_pack.arm_cpu", \
                         [["TENSOR", [1, 512, 32, 32], "float32"], \
                         ["TENSOR", [512, 1, 3, 3], "float32"], \
                         [1, 1], [1, 1], [1, 1], "float32"], {}, \
@@ -270,56 +263,103 @@ def test_conv2d_run():
             log_file.write(test_schedule)
         with autotvm.apply_history_best(temp.relpath("temp.log")):
             with tvm.transform.PassContext(opt_level=3):
-                print('Compiling...')
+                print("Compiling...")
                 graph_json, mod, params = tvm.relay.build(mod, target="llvm -device=arm_cpu")
 
     # depthwise conv2d
     dshape = (1, 32, 18, 18)
     kshape = (32, 1, 3, 3)
-    run_test_conv2d("float32", "float32", 1, dshape, kshape,
-                    padding=(1, 1), channels=32, groups=32, kernel_size=(3 ,3),
-                    fref=lambda x, w: tvm.topi.testing.depthwise_conv2d_python_nchw(
-                        x, w, (1, 1), "SAME"))
+    run_test_conv2d(
+        "float32",
+        "float32",
+        1,
+        dshape,
+        kshape,
+        padding=(1, 1),
+        channels=32,
+        groups=32,
+        kernel_size=(3, 3),
+        fref=lambda x, w: tvm.topi.testing.depthwise_conv2d_python_nchw(x, w, (1, 1), "SAME"),
+    )
 
     # depthwise conv2d for arm_cpu
     dshape = (1, 512, 32, 32)
     kshape = (512, 1, 3, 3)
-    compile_test_conv2d_arm_cpu("float32", "float32", 1, dshape, kshape,
-                                padding=(1, 1), channels=512,
-                                groups=512, kernel_size=(3 ,3))
+    compile_test_conv2d_arm_cpu(
+        "float32",
+        "float32",
+        1,
+        dshape,
+        kshape,
+        padding=(1, 1),
+        channels=512,
+        groups=512,
+        kernel_size=(3, 3),
+    )
 
     # CUDA is disabled for 'direct' schedule:
     # https://github.com/apache/incubator-tvm/pull/3070#issuecomment-486597553
     # group conv2d
     dshape = (1, 32, 18, 18)
     kshape = (32, 4, 3, 3)
-    run_test_conv2d("float32", "float32", 1, dshape, kshape,
-                    padding=(1, 1), channels=32, groups=8, kernel_size=(3 ,3),
-                    except_targets=['cuda'])
+    run_test_conv2d(
+        "float32",
+        "float32",
+        1,
+        dshape,
+        kshape,
+        padding=(1, 1),
+        channels=32,
+        groups=8,
+        kernel_size=(3, 3),
+        except_targets=["cuda"],
+    )
     # also group conv2d
     dshape = (1, 32, 18, 18)
     kshape = (64, 1, 3, 3)
-    run_test_conv2d("float32", "float32", 1, dshape, kshape,
-                    padding=(1, 1), channels=64, groups=32, kernel_size=(3 ,3),
-                    except_targets=['cuda'])
+    run_test_conv2d(
+        "float32",
+        "float32",
+        1,
+        dshape,
+        kshape,
+        padding=(1, 1),
+        channels=64,
+        groups=32,
+        kernel_size=(3, 3),
+        except_targets=["cuda"],
+    )
 
     # normal conv2d
     dshape = (1, 3, 224, 224)
     kshape = (10, 3, 3, 3)
-    run_test_conv2d("float32", "float32", 1, dshape, kshape,
-                    padding=(1, 1), channels=10, kernel_size=(3 ,3))
+    run_test_conv2d(
+        "float32", "float32", 1, dshape, kshape, padding=(1, 1), channels=10, kernel_size=(3, 3)
+    )
     # mixed precision
-    run_test_conv2d("int8", "int32", 1, dshape, kshape,
-                    padding=(1, 1), channels=10, kernel_size=(3 ,3))
+    run_test_conv2d(
+        "int8", "int32", 1, dshape, kshape, padding=(1, 1), channels=10, kernel_size=(3, 3)
+    )
     kshape = (10, 3, 1, 3)
     # mixed precision.
-    run_test_conv2d("int8", "int32", 1, dshape, kshape,
-                    padding=(0, 1), channels=10, kernel_size=(1 ,3))
+    run_test_conv2d(
+        "int8", "int32", 1, dshape, kshape, padding=(0, 1), channels=10, kernel_size=(1, 3)
+    )
     # dilated conv2d
     dshape = (1, 3, 18, 18)
     kshape = (10, 3, 3, 3)
-    run_test_conv2d("float32", "float32", 1, dshape, kshape,
-                    padding=(1, 1), channels=10, kernel_size=(3 ,3), dilation=(3, 3))
+    run_test_conv2d(
+        "float32",
+        "float32",
+        1,
+        dshape,
+        kshape,
+        padding=(1, 1),
+        channels=10,
+        kernel_size=(3, 3),
+        dilation=(3, 3),
+    )
+
 
 @tvm.testing.uses_gpu
 def test_conv2d_winograd():
@@ -330,49 +370,43 @@ def test_conv2d_winograd():
                 return self.memory[key]
             cfg = autotvm.task.space.FallbackConfigEntity()
             cfg.is_fallback = False
-            cfg.cost = 0.1 if 'winograd' in workload[0] else 1
-            cfg['tile_b'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
-            cfg['tile_y'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
-            cfg['tile_x'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
-            cfg['tile_rc'] = autotvm.task.space.SplitEntity([-1, 1])
-            cfg['auto_unroll_max_step'] = autotvm.task.space.OtherOptionEntity(1500)
-            cfg['unroll_explicit'] = autotvm.task.space.OtherOptionEntity(1)
+            cfg.cost = 0.1 if "winograd" in workload[0] else 1
+            cfg["tile_b"] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
+            cfg["tile_y"] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
+            cfg["tile_x"] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
+            cfg["tile_rc"] = autotvm.task.space.SplitEntity([-1, 1])
+            cfg["auto_unroll_max_step"] = autotvm.task.space.OtherOptionEntity(1500)
+            cfg["unroll_explicit"] = autotvm.task.space.OtherOptionEntity(1)
             self.memory[key] = cfg
             return cfg
 
-    def run_test_conv2d_cuda(dtype, out_dtype, scale, dshape, kshape,
-                             padding=(1, 1),
-                             groups=1,
-                             dilation=(1, 1),
-                             **attrs):
+    def run_test_conv2d_cuda(
+        dtype, out_dtype, scale, dshape, kshape, padding=(1, 1), groups=1, dilation=(1, 1), **attrs
+    ):
 
         x = relay.var("x", shape=dshape, dtype=dtype)
         w = relay.var("w", shape=kshape, dtype=dtype)
-        y = relay.nn.conv2d(x, w,
-                            padding=padding,
-                            dilation=dilation,
-                            groups=groups,
-                            **attrs)
+        y = relay.nn.conv2d(x, w, padding=padding, dilation=dilation, groups=groups, **attrs)
         func = relay.Function([x, w], y)
         mod = tvm.IRModule()
-        mod['main'] = func
+        mod["main"] = func
         mod = relay.transform.InferType()(mod)
 
         data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
         kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
         ref_res = tvm.topi.testing.conv2d_nchw_python(
-            data.astype(out_dtype), kernel.astype(out_dtype), 1, padding,
-            groups=groups)
+            data.astype(out_dtype), kernel.astype(out_dtype), 1, padding, groups=groups
+        )
 
         with WinogradFallback(), tvm.transform.PassContext(opt_level=3):
             for target, ctx in tvm.testing.enabled_targets():
-                if target != 'cuda':
+                if target != "cuda":
                     continue
                 ctx = tvm.context(target, 0)
-                params = {'w': tvm.nd.array(kernel)}
+                params = {"w": tvm.nd.array(kernel)}
                 graph, lib, params = relay.build_module.build(mod, target=target, params=params)
                 module = tvm.contrib.graph_runtime.create(graph, lib, ctx)
-                module.set_input('x', tvm.nd.array(data))
+                module.set_input("x", tvm.nd.array(data))
                 module.set_input(**params)
                 module.run()
                 op_res1 = module.get_output(0)
@@ -381,17 +415,21 @@ def test_conv2d_winograd():
     # normal winograd: stride 1, padding 1, kernel 3x3
     dshape = (1, 80, 73, 73)
     kshape = (192, 80, 3, 3)
-    run_test_conv2d_cuda("float32", "float32", 1, dshape, kshape,
-                         padding=(1, 1), channels=192, kernel_size=(3, 3))
+    run_test_conv2d_cuda(
+        "float32", "float32", 1, dshape, kshape, padding=(1, 1), channels=192, kernel_size=(3, 3)
+    )
     # extended winograd: stride 1, padding N, kernel 3x3
-    run_test_conv2d_cuda("float32", "float32", 1, dshape, kshape,
-                         padding=(0, 0), channels=192, kernel_size=(3, 3))
-    run_test_conv2d_cuda("float32", "float32", 1, dshape, kshape,
-                         padding=(2, 2), channels=192, kernel_size=(3, 3))
+    run_test_conv2d_cuda(
+        "float32", "float32", 1, dshape, kshape, padding=(0, 0), channels=192, kernel_size=(3, 3)
+    )
+    run_test_conv2d_cuda(
+        "float32", "float32", 1, dshape, kshape, padding=(2, 2), channels=192, kernel_size=(3, 3)
+    )
     # extended winograd: stride 1, padding N, kernel NxN
     kshape = (192, 80, 7, 7)
-    run_test_conv2d_cuda("float32", "float32", 1, dshape, kshape,
-                         padding=(2, 2), channels=192, kernel_size=(7, 7))
+    run_test_conv2d_cuda(
+        "float32", "float32", 1, dshape, kshape, padding=(2, 2), channels=192, kernel_size=(7, 7)
+    )
 
 
 @tvm.testing.uses_gpu
@@ -400,82 +438,78 @@ def test_conv3d_infer_type():
     n, c, d, h, w = te.size_var("n"), 10, 224, 224, 224
     x = relay.var("x", relay.ty.TensorType((n, c, d, h, w), "float32"))
     w = relay.var("w")
-    y = relay.nn.conv3d(x, w,
-                        kernel_size=(3, 3, 3),
-                        padding=(1, 1, 1),
-                        channels=2)
+    y = relay.nn.conv3d(x, w, kernel_size=(3, 3, 3), padding=(1, 1, 1), channels=2)
     yy = run_infer_type(y)
-    assert yy.checked_type ==  relay.TensorType(
-        (n, 2, 224, 224, 224), "float32")
-    assert yy.args[1].checked_type == relay.TensorType(
-        (2, 10, 3, 3, 3), "float32")
+    assert yy.checked_type == relay.TensorType((n, 2, 224, 224, 224), "float32")
+    assert yy.args[1].checked_type == relay.TensorType((2, 10, 3, 3, 3), "float32")
 
     # infer by shape of w, mixed precision
     n, c, d, h, w = te.size_var("n"), 10, 224, 224, 224
     x = relay.var("x", relay.TensorType((n, c, d, h, w), "int8"))
     w = relay.var("w", relay.TensorType((2, 10, 3, 3, 3), "int8"))
     y = relay.nn.conv3d(x, w, out_dtype="int32")
-    assert "out_dtype=\"int32\"" in y.astext()
+    assert 'out_dtype="int32"' in y.astext()
     yy = run_infer_type(y)
-    assert yy.checked_type ==  relay.TensorType(
-        (n, 2, 222, 222, 222), "int32")
+    assert yy.checked_type == relay.TensorType((n, 2, 222, 222, 222), "int32")
 
     # infer shape in case of different dtypes for input and weight.
     n, c, d, h, w = te.size_var("n"), 10, 224, 224, 224
     x = relay.var("x", relay.TensorType((n, c, d, h, w), "uint8"))
     w = relay.var("w", relay.TensorType((2, 10, 3, 3, 3), "int8"))
     y = relay.nn.conv3d(x, w, out_dtype="int32")
-    assert "out_dtype=\"int32\"" in y.astext()
+    assert 'out_dtype="int32"' in y.astext()
     yy = run_infer_type(y)
-    assert yy.checked_type ==  relay.TensorType(
-        (n, 2, 222, 222, 222), "int32")
+    assert yy.checked_type == relay.TensorType((n, 2, 222, 222, 222), "int32")
 
     # Infer with NDHWC
     n, c, d, h, w = 4, 32, 224, 224, 224
     x = relay.var("x", relay.TensorType((n, d, h, w, c), "int8"))
     wt = relay.var("w")
-    y = relay.nn.conv3d(x, wt,
-                        kernel_size=(3, 3, 3),
-                        padding=(1, 1, 1),
-                        channels=16,
-                        data_layout="NDHWC",
-                        out_dtype="int32")
+    y = relay.nn.conv3d(
+        x,
+        wt,
+        kernel_size=(3, 3, 3),
+        padding=(1, 1, 1),
+        channels=16,
+        data_layout="NDHWC",
+        out_dtype="int32",
+    )
     yy = run_infer_type(y)
-    assert yy.checked_type ==  relay.TensorType(
-        (n, d, h, w, 16), "int32")
+    assert yy.checked_type == relay.TensorType((n, d, h, w, 16), "int32")
 
 
 @tvm.testing.uses_gpu
 def test_conv3d_run():
-    def run_test_conv3d(dtype, out_dtype, scale, dshape, kshape,
-                        padding=(1, 1, 1),
-                        fref=None,
-                        groups=1,
-                        dilation=(1, 1, 1),
-                        except_targets=None,
-                        **attrs):
+    def run_test_conv3d(
+        dtype,
+        out_dtype,
+        scale,
+        dshape,
+        kshape,
+        padding=(1, 1, 1),
+        fref=None,
+        groups=1,
+        dilation=(1, 1, 1),
+        except_targets=None,
+        **attrs,
+    ):
         if except_targets is None:
             except_targets = []
 
         x = relay.var("x", shape=dshape, dtype=dtype)
         w = relay.var("w", dtype=dtype)
-        y = relay.nn.conv3d(x, w,
-                            padding=padding,
-                            dilation=dilation,
-                            groups=groups,
-                            **attrs)
+        y = relay.nn.conv3d(x, w, padding=padding, dilation=dilation, groups=groups, **attrs)
         func = relay.Function([x, w], y)
         data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
         kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
         dkernel = tvm.topi.testing.dilate_python(kernel, (1, 1) + dilation)
         if fref is None:
             ref_res = tvm.topi.testing.conv3d_ncdhw_python(
-                data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding,
-                groups=groups)
+                data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding, groups=groups
+            )
         else:
             ref_res = fref(data.astype(out_dtype), dkernel.astype(out_dtype))
 
-
         for target, ctx in tvm.testing.enabled_targets():
             if target in except_targets:
                 continue
@@ -488,40 +522,59 @@ def test_conv3d_run():
     # normal conv3d
     dshape = (1, 3, 5, 224, 224)
     kshape = (10, 3, 3, 3, 3)
-    run_test_conv3d("float32", "float32", 1, dshape, kshape,
-            padding=(1, 1, 1), channels=10, kernel_size=(3, 3 ,3))
+    run_test_conv3d(
+        "float32",
+        "float32",
+        1,
+        dshape,
+        kshape,
+        padding=(1, 1, 1),
+        channels=10,
+        kernel_size=(3, 3, 3),
+    )
+
 
 @tvm.testing.uses_gpu
 def test_conv3d_ndhwc_run():
-    def run_test_conv3d(dtype, out_dtype, scale, dshape, kshape,
-                        padding=(1, 1, 1),
-                        fref=None,
-                        groups=1,
-                        dilation=(1, 1, 1),
-                        except_targets=None,
-                        **attrs):
+    def run_test_conv3d(
+        dtype,
+        out_dtype,
+        scale,
+        dshape,
+        kshape,
+        padding=(1, 1, 1),
+        fref=None,
+        groups=1,
+        dilation=(1, 1, 1),
+        except_targets=None,
+        **attrs,
+    ):
         if except_targets is None:
             except_targets = []
 
         x = relay.var("x", shape=dshape, dtype=dtype)
         w = relay.var("w", dtype=dtype)
-        y = relay.nn.conv3d(x, w,
-                            padding=padding,
-                            dilation=dilation,
-                            groups=groups,
-                            data_layout="NDHWC", kernel_layout="DHWIO",
-                            **attrs)
+        y = relay.nn.conv3d(
+            x,
+            w,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            data_layout="NDHWC",
+            kernel_layout="DHWIO",
+            **attrs,
+        )
         func = relay.Function([x, w], y)
         data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
         kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
         dkernel = tvm.topi.testing.dilate_python(kernel, (1, 1) + dilation)
         if fref is None:
             ref_res = tvm.topi.testing.conv3d_ndhwc_python(
-                data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding)
+                data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding
+            )
         else:
             ref_res = fref(data.astype(out_dtype), dkernel.astype(out_dtype))
 
-
         for target, ctx in tvm.testing.enabled_targets():
             if target in except_targets:
                 continue
@@ -534,8 +587,18 @@ def test_conv3d_ndhwc_run():
     # normal conv3d
     dshape = (1, 5, 224, 224, 6)
     kshape = (3, 3, 3, 6, 10)
-    run_test_conv3d("float32", "float32", 1, dshape, kshape,
-            padding=(1, 1, 1), channels=10, kernel_size=(3, 3 ,3), except_targets=["cuda"])
+    run_test_conv3d(
+        "float32",
+        "float32",
+        1,
+        dshape,
+        kshape,
+        padding=(1, 1, 1),
+        channels=10,
+        kernel_size=(3, 3, 3),
+        except_targets=["cuda"],
+    )
+
 
 @tvm.testing.uses_gpu
 def test_conv3d_winograd():
@@ -546,22 +609,28 @@ def test_conv3d_winograd():
                 return self.memory[key]
             cfg = autotvm.task.space.FallbackConfigEntity()
             cfg.is_fallback = False
-            cfg.cost = 0.1 if 'winograd' in workload[0] else 1
-            cfg['tile_b'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
-            cfg['tile_y'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
-            cfg['tile_x'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
-            cfg['tile_rc'] = autotvm.task.space.SplitEntity([-1, 1])
-            cfg['auto_unroll_max_step'] = autotvm.task.space.OtherOptionEntity(0)
-            cfg['unroll_explicit'] = autotvm.task.space.OtherOptionEntity(1)
+            cfg.cost = 0.1 if "winograd" in workload[0] else 1
+            cfg["tile_b"] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
+            cfg["tile_y"] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
+            cfg["tile_x"] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
+            cfg["tile_rc"] = autotvm.task.space.SplitEntity([-1, 1])
+            cfg["auto_unroll_max_step"] = autotvm.task.space.OtherOptionEntity(0)
+            cfg["unroll_explicit"] = autotvm.task.space.OtherOptionEntity(1)
             self.memory[key] = cfg
             return cfg
 
-    def run_test_conv3d_cuda(dtype, out_dtype, scale, dshape, kshape,
-                             padding=(1, 1, 1),
-                             groups=1,
-                             dilation=(1, 1, 1),
-                             prepack=False,
-                             **attrs):
+    def run_test_conv3d_cuda(
+        dtype,
+        out_dtype,
+        scale,
+        dshape,
+        kshape,
+        padding=(1, 1, 1),
+        groups=1,
+        dilation=(1, 1, 1),
+        prepack=False,
+        **attrs,
+    ):
 
         x = relay.var("x", shape=dshape, dtype=dtype)
         w = relay.var("w", shape=kshape, dtype=dtype)
@@ -570,38 +639,37 @@ def test_conv3d_winograd():
             w_packed = relay.nn.contrib_conv3d_winograd_weight_transform(w, tile_size)
 
             y = relay.nn.contrib_conv3d_winograd_without_weight_transform(
-                x, w_packed, tile_size,
+                x,
+                w_packed,
+                tile_size,
                 padding=padding,
                 dilation=dilation,
                 groups=groups,
                 channels=kshape[0],
-                **attrs)
+                **attrs,
+            )
         else:
-            y = relay.nn.conv3d(x, w,
-                                padding=padding,
-                                dilation=dilation,
-                                groups=groups,
-                                **attrs)
+            y = relay.nn.conv3d(x, w, padding=padding, dilation=dilation, groups=groups, **attrs)
         func = relay.Function([x, w], y)
         mod = tvm.IRModule()
-        mod['main'] = func
+        mod["main"] = func
         mod = relay.transform.InferType()(mod)
 
         data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
         kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
         ref_res = tvm.topi.testing.conv3d_ncdhw_python(
-            data.astype(out_dtype), kernel.astype(out_dtype), 1, padding,
-            groups=groups)
+            data.astype(out_dtype), kernel.astype(out_dtype), 1, padding, groups=groups
+        )
 
         with WinogradFallback(), tvm.transform.PassContext(opt_level=3):
             for target, ctx in tvm.testing.enabled_targets():
-                if target != 'cuda':
+                if target != "cuda":
                     continue
                 ctx = tvm.context(target, 0)
-                params = {'w': tvm.nd.array(kernel)}
+                params = {"w": tvm.nd.array(kernel)}
                 graph, lib, params = relay.build_module.build(mod, target=target, params=params)
                 module = tvm.contrib.graph_runtime.create(graph, lib, ctx)
-                module.set_input('x', tvm.nd.array(data))
+                module.set_input("x", tvm.nd.array(data))
                 module.set_input(**params)
                 module.run()
                 op_res1 = module.get_output(0)
@@ -610,22 +678,40 @@ def test_conv3d_winograd():
     # normal winograd: stride 1, padding 1, kernel 3x3x3
     dshape = (1, 32, 16, 16, 16)
     kshape = (64, 32, 3, 3, 3)
-    run_test_conv3d_cuda("float32", "float32", 1, dshape, kshape,
-                         padding=(1, 1, 1), kernel_size=(3, 3, 3))
+    run_test_conv3d_cuda(
+        "float32", "float32", 1, dshape, kshape, padding=(1, 1, 1), kernel_size=(3, 3, 3)
+    )
     # Without depth transform using 1x3x3 kernel.
     kshape = (64, 32, 1, 3, 3)
-    run_test_conv3d_cuda("float32", "float32", 1, dshape, kshape,
-                         padding=(0, 1, 1), kernel_size=(1, 3, 3))
+    run_test_conv3d_cuda(
+        "float32", "float32", 1, dshape, kshape, padding=(0, 1, 1), kernel_size=(1, 3, 3)
+    )
 
     # extended winograd: stride 1, padding N, kernel NxNxN
     dshape = (1, 61, 20, 20, 20)
     kshape = (120, 61, 5, 5, 5)
-    run_test_conv3d_cuda("float32", "float32", 1, dshape, kshape,
-                         padding=(2, 2, 2), channels=120, kernel_size=(5, 5, 5))
+    run_test_conv3d_cuda(
+        "float32",
+        "float32",
+        1,
+        dshape,
+        kshape,
+        padding=(2, 2, 2),
+        channels=120,
+        kernel_size=(5, 5, 5),
+    )
     # Without depth transform
     kshape = (120, 61, 1, 5, 5)
-    run_test_conv3d_cuda("float32", "float32", 1, dshape, kshape,
-                         padding=(0, 2, 2), channels=120, kernel_size=(1, 5, 5))
+    run_test_conv3d_cuda(
+        "float32",
+        "float32",
+        1,
+        dshape,
+        kshape,
+        padding=(0, 2, 2),
+        channels=120,
+        kernel_size=(1, 5, 5),
+    )
 
 
 @tvm.testing.uses_gpu
@@ -634,36 +720,29 @@ def test_conv3d_transpose_infer_type():
     n, c, d, h, w = te.size_var("n"), 10, 224, 224, 224
     x = relay.var("x", relay.ty.TensorType((n, c, d, h, w), "float32"))
     w = relay.var("w")
-    y = relay.nn.conv3d_transpose(x, w,
-                                   kernel_size=(3, 3, 3),
-                                   padding=(1, 1, 1),
-                                   channels=2)
+    y = relay.nn.conv3d_transpose(x, w, kernel_size=(3, 3, 3), padding=(1, 1, 1), channels=2)
     yy = run_infer_type(y)
-    assert yy.checked_type ==  relay.TensorType(
-        (n, 2, 224, 224, 224), "float32")
+    assert yy.checked_type == relay.TensorType((n, 2, 224, 224, 224), "float32")
 
-    assert yy.args[1].checked_type == relay.TensorType(
-        (10, 2, 3, 3, 3), "float32")
+    assert yy.args[1].checked_type == relay.TensorType((10, 2, 3, 3, 3), "float32")
 
     # infer by shape of w, mixed precision
     n, c, d, h, w = te.size_var("n"), 10, 224, 224, 224
     x = relay.var("x", relay.TensorType((n, c, d, h, w), "int8"))
     w = relay.var("w", relay.TensorType((10, 12, 3, 3, 3), "int8"))
     y = relay.nn.conv3d_transpose(x, w, out_dtype="int32")
-    assert "out_dtype=\"int32\"" in y.astext()
+    assert 'out_dtype="int32"' in y.astext()
     yy = run_infer_type(y)
-    assert yy.checked_type ==  relay.TensorType(
-        (n, 12, 226, 226, 226), "int32")
+    assert yy.checked_type == relay.TensorType((n, 12, 226, 226, 226), "int32")
 
     # infer shape in case of different dtypes for input and weight.
     n, c, d, h, w = te.size_var("n"), 10, 224, 224, 224
     x = relay.var("x", relay.TensorType((n, c, d, h, w), "uint8"))
     w = relay.var("w", relay.TensorType((10, 12, 3, 3, 3), "int8"))
     y = relay.nn.conv3d_transpose(x, w, out_dtype="int32")
-    assert "out_dtype=\"int32\"" in y.astext()
+    assert 'out_dtype="int32"' in y.astext()
     yy = run_infer_type(y)
-    assert yy.checked_type ==  relay.TensorType(
-        (n, 12, 226, 226, 226), "int32")
+    assert yy.checked_type == relay.TensorType((n, 12, 226, 226, 226), "int32")
 
 
 @tvm.testing.uses_gpu
@@ -673,9 +752,9 @@ def test_conv3d_transpose_ncdhw_run():
 
     x = relay.var("x", shape=dshape)
     w = relay.var("w")
-    y = relay.nn.conv3d_transpose(x, w,
-                                  channels=4, kernel_size=(2, 2, 2), strides=(1, 1, 1),
-                                  padding=(1, 1, 1))
+    y = relay.nn.conv3d_transpose(
+        x, w, channels=4, kernel_size=(2, 2, 2), strides=(1, 1, 1), padding=(1, 1, 1)
+    )
     func = relay.Function([x, w], y)
     dtype = "float32"
 
@@ -695,28 +774,19 @@ def test_conv2d_transpose_infer_type():
     n, c, h, w = te.size_var("n"), 10, 10, 12
     x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
     w = relay.var("w", relay.IncompleteType())
-    y = relay.nn.conv2d_transpose(x, w,
-                                  kernel_size=(3, 3),
-                                  padding=(1, 1),
-                                  channels=15)
+    y = relay.nn.conv2d_transpose(x, w, kernel_size=(3, 3), padding=(1, 1), channels=15)
     assert "channels=15" in y.astext()
     yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType(
-        (n, 15, 10, 12), "float32")
-    assert yy.args[1].checked_type == relay.TensorType(
-        (10, 15, 3, 3), "float32")
+    assert yy.checked_type == relay.TensorType((n, 15, 10, 12), "float32")
+    assert yy.args[1].checked_type == relay.TensorType((10, 15, 3, 3), "float32")
 
     # infer by shape of w, mixed precision
     n, h, w, c = te.size_var("n"), 10, 10, 12
     x = relay.var("x", relay.TensorType((n, h, w, c), "float32"))
     w = relay.var("w", relay.TensorType((12, 11, 5, 5), "float32"))
-    y = relay.nn.conv2d_transpose(x, w,
-                                  output_padding=(1, 1),
-                                  channels=11,
-                                  data_layout="NHWC")
+    y = relay.nn.conv2d_transpose(x, w, output_padding=(1, 1), channels=11, data_layout="NHWC")
     yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType(
-        (n, 15, 15, 11), "float32")
+    assert yy.checked_type == relay.TensorType((n, 15, 15, 11), "float32")
 
 
 @tvm.testing.uses_gpu
@@ -726,15 +796,14 @@ def test_conv2d_transpose_nchw_run():
     oshape = (1, 10, 36, 36)
     x = relay.var("x", shape=dshape)
     w = relay.var("w")
-    y = relay.nn.conv2d_transpose(x, w,
-                                  channels=10, kernel_size=(3,3), strides=(2,2),
-                                  padding=(1,1), output_padding=(1, 1))
+    y = relay.nn.conv2d_transpose(
+        x, w, channels=10, kernel_size=(3, 3), strides=(2, 2), padding=(1, 1), output_padding=(1, 1)
+    )
     func = relay.Function([x, w], y)
     dtype = "float32"
     data = np.random.uniform(size=dshape).astype(dtype)
     kernel = np.random.uniform(size=kshape).astype(dtype)
-    ref_res = tvm.topi.testing.conv2d_transpose_nchw_python(
-        data, kernel, 2, 1, (1, 1))
+    ref_res = tvm.topi.testing.conv2d_transpose_nchw_python(data, kernel, 2, 1, (1, 1))
 
     for target, ctx in tvm.testing.enabled_targets():
         intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
@@ -751,18 +820,26 @@ def test_conv2d_transpose_nhwc_run():
     w = relay.var("w")
     # kshape and kernel_layout should have swapped IO.
     # kshape is HWOI and kernel_layout is HWIO
-    y = relay.nn.conv2d_transpose(x, w,
-                                  channels=10, kernel_size=(3, 3), strides=(2, 2),
-                                  padding=(1, 1), output_padding=(1, 1),
-                                  data_layout="NHWC", kernel_layout="HWIO")
+    y = relay.nn.conv2d_transpose(
+        x,
+        w,
+        channels=10,
+        kernel_size=(3, 3),
+        strides=(2, 2),
+        padding=(1, 1),
+        output_padding=(1, 1),
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+    )
     func = relay.Function([x, w], y)
     dtype = "float32"
     data = np.random.uniform(size=dshape_nhwc).astype(dtype)
     kernel = np.random.uniform(size=kshape_hwoi).astype(dtype)
     # use true kshape layout here - HWOI
 
-    ref_res = tvm.topi.testing.conv2d_transpose_nhwc_python(data, kernel, 'HWOI',
-                                                        2, 1, output_padding=(1, 1))
+    ref_res = tvm.topi.testing.conv2d_transpose_nhwc_python(
+        data, kernel, "HWOI", 2, 1, output_padding=(1, 1)
+    )
 
     for target, ctx in tvm.testing.enabled_targets():
         intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
@@ -777,15 +854,14 @@ def test_conv1d_transpose_ncw_run():
     oshape = (1, 10, 36)
     x = relay.var("x", shape=dshape)
     w = relay.var("w")
-    y = relay.nn.conv1d_transpose(x, w,
-                                  channels=10, kernel_size=(3,), strides=(2,),
-                                  padding=(1,), output_padding=(1,))
+    y = relay.nn.conv1d_transpose(
+        x, w, channels=10, kernel_size=(3,), strides=(2,), padding=(1,), output_padding=(1,)
+    )
     func = relay.Function([x, w], y)
     dtype = "float32"
     data = np.random.uniform(size=dshape).astype(dtype)
     kernel = np.random.uniform(size=kshape).astype(dtype)
-    ref_res = tvm.topi.testing.conv1d_transpose_ncw_python(
-        data, kernel, 2, 1, output_padding=(1,))
+    ref_res = tvm.topi.testing.conv1d_transpose_ncw_python(data, kernel, 2, 1, output_padding=(1,))
 
     for target, ctx in tvm.testing.enabled_targets():
         intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
@@ -795,40 +871,63 @@ def test_conv1d_transpose_ncw_run():
 
 @tvm.testing.uses_gpu
 def test_upsampling_infer_type():
-    n, c , h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
+    n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
     scale = tvm.tir.const(2.0, "float64")
     x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
     y = relay.nn.upsampling(x, scale_h=2, scale_w=2, layout="NCHW", method="bilinear")
-    "method=\"BINLINEAR\"" in y.astext()
+    'method="BINLINEAR"' in y.astext()
     yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType((n, c, tvm.tir.Cast("int32", te.round(h*scale)),
-                                                tvm.tir.Cast("int32", te.round(w*scale))),
-                                                "float32")
+    assert yy.checked_type == relay.TensorType(
+        (
+            n,
+            c,
+            tvm.tir.Cast("int32", te.round(h * scale)),
+            tvm.tir.Cast("int32", te.round(w * scale)),
+        ),
+        "float32",
+    )
     n, c = te.size_var("n"), te.size_var("c")
     x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32"))
     y = relay.nn.upsampling(x, scale_h=2, scale_w=2, layout="NCHW", method="bilinear")
     yy = run_infer_type(y)
     assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32")
 
+
 @tvm.testing.uses_gpu
 def test_upsampling3d_infer_type():
-    n, c, d, h, w = te.size_var("n"), te.size_var("c"),\
-                    te.size_var("d"), te.size_var("h"), te.size_var("w")
+    n, c, d, h, w = (
+        te.size_var("n"),
+        te.size_var("c"),
+        te.size_var("d"),
+        te.size_var("h"),
+        te.size_var("w"),
+    )
     scale = tvm.tir.const(2.0, "float64")
     x = relay.var("x", relay.TensorType((n, c, d, h, w), "float32"))
-    y = relay.nn.upsampling3d(x, scale_d=2, scale_h=2, scale_w=2, layout="NCDHW", method="trilinear")
+    y = relay.nn.upsampling3d(
+        x, scale_d=2, scale_h=2, scale_w=2, layout="NCDHW", method="trilinear"
+    )
 
     yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType((n, c, tvm.tir.Cast("int32", te.round(d*scale)),
-                                                tvm.tir.Cast("int32", te.round(h*scale)),
-                                                tvm.tir.Cast("int32", te.round(w*scale))),
-                                                "float32")
+    assert yy.checked_type == relay.TensorType(
+        (
+            n,
+            c,
+            tvm.tir.Cast("int32", te.round(d * scale)),
+            tvm.tir.Cast("int32", te.round(h * scale)),
+            tvm.tir.Cast("int32", te.round(w * scale)),
+        ),
+        "float32",
+    )
     n, c = te.size_var("n"), te.size_var("c")
     x = relay.var("x", relay.TensorType((n, c, 100, 100, 200), "float32"))
-    y = relay.nn.upsampling3d(x, scale_d=2, scale_h=2, scale_w=2, layout="NCDHW", method="trilinear")
+    y = relay.nn.upsampling3d(
+        x, scale_d=2, scale_h=2, scale_w=2, layout="NCDHW", method="trilinear"
+    )
     yy = run_infer_type(y)
     assert yy.checked_type == relay.TensorType((n, c, 200, 200, 400), "float32")
 
+
 def _test_pool2d(opfunc, reffunc, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)):
     n, c, h, w = te.size_var("n"), 10, 224, 224
     x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
@@ -849,6 +948,7 @@ def _test_pool2d(opfunc, reffunc, pool_size=(2, 2), strides=(2, 2), padding=(0,
         op_res1 = intrp1.evaluate(func)(data)
         tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
 
+
 def _test_pool2d_int(opfunc, reffunc, dtype):
     n, c, h, w = te.size_var("n"), 10, 224, 224
     x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
@@ -863,12 +963,13 @@ def _test_pool2d_int(opfunc, reffunc, dtype):
     y = opfunc(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
     func = relay.Function([x], y)
     data = np.random.randint(low=-128, high=128, size=dshape)
-    ref_res = reffunc(data.reshape(1,3,14,2,14,2), axis=(3,5)).astype(dtype)
+    ref_res = reffunc(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5)).astype(dtype)
     for target, ctx in tvm.testing.enabled_targets():
         intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
         op_res1 = intrp1.evaluate(func)(data)
         tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
 
+
 def _test_global_pool2d(opfunc, reffunc):
     n, c, h, w = te.size_var("n"), te.size_var("c"), 224, 224
     x = relay.var("x", relay.TensorType((n, h, w, c), "float32"))
@@ -888,7 +989,7 @@ def _test_global_pool2d(opfunc, reffunc):
     y = opfunc(x)
     func = relay.Function([x], y)
     data = np.random.uniform(size=dshape).astype(dtype)
-    ref_res = reffunc(data, axis=(2,3), keepdims=True)
+    ref_res = reffunc(data, axis=(2, 3), keepdims=True)
     for target, ctx in tvm.testing.enabled_targets():
         intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
         op_res1 = intrp1.evaluate(func)(data)
@@ -901,15 +1002,14 @@ def test_pool2d():
     _test_pool2d(relay.nn.max_pool2d, np.max, pool_size=2, strides=2, padding=0)
     _test_pool2d(relay.nn.avg_pool2d, np.mean)
     _test_pool2d(relay.nn.avg_pool2d, np.mean, pool_size=2, strides=2, padding=0)
-    _test_pool2d_int(relay.nn.avg_pool2d, np.mean, 'int32')
-    _test_pool2d_int(relay.nn.avg_pool2d, np.mean, 'uint16')
+    _test_pool2d_int(relay.nn.avg_pool2d, np.mean, "int32")
+    _test_pool2d_int(relay.nn.avg_pool2d, np.mean, "uint16")
     _test_global_pool2d(relay.nn.global_max_pool2d, np.max)
     _test_global_pool2d(relay.nn.global_avg_pool2d, np.mean)
 
 
 @tvm.testing.uses_gpu
 def test_pool1d():
-
     def _test_pool1d(opfunc, pool_size=(2,), strides=(2,), padding=(0, 0)):
         n, c, w = te.var("n"), 10, 224
         x = relay.var("x", relay.TensorType((n, c, w), "float32"))
@@ -921,12 +1021,13 @@ def test_pool1d():
         dtype = "float32"
         dshape = (1, 3, 32)
         x = relay.var("x", shape=dshape)
-        pool_type = 'max' if 'max' in str(opfunc) else 'avg'
+        pool_type = "max" if "max" in str(opfunc) else "avg"
         y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding)
         func = relay.Function([x], y)
         data = np.random.uniform(size=dshape).astype(dtype)
-        ref_res = tvm.topi.testing.pool1d_ncw_python(data, (2,), (2,),
-                                                 (0, 0), (1, 3, 16), pool_type, False)
+        ref_res = tvm.topi.testing.pool1d_ncw_python(
+            data, (2,), (2,), (0, 0), (1, 3, 16), pool_type, False
+        )
         for target, ctx in tvm.testing.enabled_targets():
             intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
             op_res1 = intrp1.evaluate(func)(data)
@@ -940,12 +1041,13 @@ def test_pool1d():
 
 @tvm.testing.uses_gpu
 def test_pool3d():
-
-    def _test_pool3d(opfunc,
-                     pool_size=(2, 2, 2),
-                     strides=(2, 2, 2),
-                     padding=(0, 0, 0, 0, 0, 0),
-                     out_shape=(1, 3, 16, 16, 16)):
+    def _test_pool3d(
+        opfunc,
+        pool_size=(2, 2, 2),
+        strides=(2, 2, 2),
+        padding=(0, 0, 0, 0, 0, 0),
+        out_shape=(1, 3, 16, 16, 16),
+    ):
         n, c, d, h, w = te.size_var("n"), 10, 5, 224, 224
         x = relay.var("x", relay.TensorType((n, c, d, h, w), "float32"))
         y = opfunc(x, pool_size=(1, 1, 1))
@@ -956,16 +1058,18 @@ def test_pool3d():
         dtype = "float32"
         dshape = (1, 3, 32, 32, 32)
         x = relay.var("x", shape=dshape)
-        pool_type = 'max' if 'max' in str(opfunc) else 'avg'
+        pool_type = "max" if "max" in str(opfunc) else "avg"
         y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding)
         func = relay.Function([x], y)
         # check output shape
         f_out_shape = tuple(map(lambda x: int(x), run_infer_type(func).ret_type.shape))
-        assert out_shape == f_out_shape, \
-            "Output shape mismatch. expected {}, actual {}".format(out_shape, f_out_shape)
+        assert out_shape == f_out_shape, "Output shape mismatch. expected {}, actual {}".format(
+            out_shape, f_out_shape
+        )
         data = np.random.uniform(size=dshape).astype(dtype)
-        ref_res = tvm.topi.testing.pool3d_ncdhw_python(data, pool_size, strides,
-                                                   padding, out_shape, pool_type, False)
+        ref_res = tvm.topi.testing.pool3d_ncdhw_python(
+            data, pool_size, strides, padding, out_shape, pool_type, False
+        )
         for target, ctx in tvm.testing.enabled_targets():
             intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
             op_res1 = intrp1.evaluate(func)(data)
@@ -993,23 +1097,24 @@ def test_avg_pool2d_no_count_pad():
     (oc, oh, ow) = (3, 15, 15)
     dshape = (n, ic, ih, iw)
     x = relay.var("x", shape=dshape)
-    y = relay.nn.avg_pool2d(x,
-                            pool_size=(kh, kw),
-                            strides=(sw, sw),
-                            padding=(ph, pw),
-                            count_include_pad=False)
+    y = relay.nn.avg_pool2d(
+        x, pool_size=(kh, kw), strides=(sw, sw), padding=(ph, pw), count_include_pad=False
+    )
     func = relay.Function([x], y)
     dtype = "float32"
     a_np = np.random.uniform(low=0.001, size=(n, ic, ih, iw)).astype(dtype)
-    pad_np = np.zeros(shape=(n, ic, ih+2*ph, iw+2*pw)).astype(dtype)
-    no_zero = (range(n), range(ic), (range(ph, ih+ph)), (range(pw, iw+pw)))
+    pad_np = np.zeros(shape=(n, ic, ih + 2 * ph, iw + 2 * pw)).astype(dtype)
+    no_zero = (range(n), range(ic), (range(ph, ih + ph)), (range(pw, iw + pw)))
     pad_np[np.ix_(*no_zero)] = a_np
     b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype)
     for i in range(oh):
         for j in range(ow):
-            pad_count = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2,3))
-            b_np[:,:,i,j] = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw],
-                                   axis=(2,3)) / np.maximum(pad_count, 1)
+            pad_count = np.sum(
+                pad_np[:, :, i * sh : i * sh + kh, j * sw : j * sw + kw] > 0, axis=(2, 3)
+            )
+            b_np[:, :, i, j] = np.sum(
+                pad_np[:, :, i * sh : i * sh + kh, j * sw : j * sw + kw], axis=(2, 3)
+            ) / np.maximum(pad_count, 1)
     ref_res = np.maximum(b_np, 0.0)
     data = a_np
 
@@ -1018,13 +1123,14 @@ def test_avg_pool2d_no_count_pad():
         op_res1 = intrp1.evaluate(func)(data)
         tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_flatten_infer_type():
     d1, d2, d3, d4 = te.size_var("d1"), te.size_var("d2"), te.size_var("d3"), te.size_var("d4")
     x = relay.var("x", relay.TensorType((d1, d2, d3, d4), "float32"))
     y = relay.nn.batch_flatten(x)
     yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType((d1, ((d2*d3)*d4)), "float32")
+    assert yy.checked_type == relay.TensorType((d1, ((d2 * d3) * d4)), "float32")
 
     x = relay.var("x", relay.TensorType((3, 2, 4, 3), "float32"))
     y = relay.nn.batch_flatten(x)
@@ -1034,7 +1140,7 @@ def test_flatten_infer_type():
     x = relay.var("x", relay.TensorType((d1, 2, d3, 3), "float32"))
     y = relay.nn.batch_flatten(x)
     yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType((d1, ((2*d3)*3)), "float32")
+    assert yy.checked_type == relay.TensorType((d1, ((2 * d3) * 3)), "float32")
 
     shape = (1, 5, 10, 10)
     o_shape = (1, 500)
@@ -1055,6 +1161,7 @@ def test_flatten_infer_type():
         op_res2 = intrp2.evaluate(func)(x_data)
         tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_pad_infer_type():
     # entirely concrete case
@@ -1072,6 +1179,7 @@ def test_pad_infer_type():
     yy = run_infer_type(y)
     assert yy.checked_type == relay.TensorType((n + 2, 6, 9, w + 8), "float32")
 
+
 @tvm.testing.uses_gpu
 def test_pad_run():
     def _test_run(dtype):
@@ -1080,32 +1188,33 @@ def test_pad_run():
         y = relay.nn.pad(x, ((1, 1), (2, 2), (3, 3), (4, 4)))
         func = relay.Function([x], y)
         data = np.random.uniform(size=dshape).astype(dtype)
-        ref_res = np.pad(data, ((1, 1), (2, 2), (3, 3), (4, 4)), 'constant')
+        ref_res = np.pad(data, ((1, 1), (2, 2), (3, 3), (4, 4)), "constant")
         for target, ctx in tvm.testing.enabled_targets():
             intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
             op_res1 = intrp1.evaluate(func)(data)
             tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
 
-    _test_run('float32')
-    _test_run('int32')
+    _test_run("float32")
+    _test_run("int32")
+
 
 @tvm.testing.uses_gpu
 def test_lrn():
-    n, c , h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
-    x = relay.var("x", shape=(n, c , h, w))
-    y = relay.nn.lrn(x, size=10, axis=2, bias=0.5, alpha=.00001, beta=0.75)
+    n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
+    x = relay.var("x", shape=(n, c, h, w))
+    y = relay.nn.lrn(x, size=10, axis=2, bias=0.5, alpha=0.00001, beta=0.75)
     "alpha=" in y.astext()
     yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType((n, c , h, w))
+    assert yy.checked_type == relay.TensorType((n, c, h, w))
 
     shape = (1, 5, 10, 10)
     dtype = "float32"
     x = relay.var("x", relay.TensorType(shape, dtype))
-    size=5
-    axis=1
-    bias=0.5
-    alpha=.00001
-    beta=0.75
+    size = 5
+    axis = 1
+    bias = 0.5
+    alpha = 0.00001
+    beta = 0.75
     z = relay.nn.lrn(x, size=size, axis=axis, bias=bias, alpha=alpha, beta=beta)
     yy = run_infer_type(z)
     assert yy.checked_type == relay.TensorType(shape, dtype)
@@ -1121,20 +1230,21 @@ def test_lrn():
         op_res2 = intrp2.evaluate(func)(x_data)
         tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_l2_normalize():
-    n, c , h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
-    x = relay.var("x", shape=(n, c , h, w))
+    n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
+    x = relay.var("x", shape=(n, c, h, w))
     y = relay.nn.l2_normalize(x, eps=0.001, axis=[1])
     "axis=" in y.astext()
     yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType((n, c , h, w))
+    assert yy.checked_type == relay.TensorType((n, c, h, w))
 
     shape = (1, 5, 10, 10)
     dtype = "float32"
     x = relay.var("x", relay.TensorType(shape, dtype))
-    eps=0.001
-    axis=1
+    eps = 0.001
+    axis = 1
     z = relay.nn.l2_normalize(x, eps=0.001, axis=[axis])
     yy = run_infer_type(z)
     assert yy.checked_type == relay.TensorType(shape, dtype)
@@ -1178,28 +1288,43 @@ def _test_upsampling(layout, method, align_corners=False):
     scale_h = 2.0
     scale_w = 2.0
     dtype = "float32"
+
     def get_shape():
         if layout == "NCHW":
-            return (c, h, w), (c, int(round(h*scale_h)), int(round(w*scale_w)))
+            return (c, h, w), (c, int(round(h * scale_h)), int(round(w * scale_w)))
         else:
-            return (h, w, c), (int(round(h*scale_h)), int(round(w*scale_w)), c)
+            return (h, w, c), (int(round(h * scale_h)), int(round(w * scale_w)), c)
+
     ishape, oshape = get_shape()
     x = relay.var("x", relay.TensorType((n,) + ishape, dtype))
-    y = relay.nn.upsampling(x, scale_h=scale_h, scale_w=scale_w, layout=layout,
-                            method=method, align_corners=align_corners)
+    y = relay.nn.upsampling(
+        x,
+        scale_h=scale_h,
+        scale_w=scale_w,
+        layout=layout,
+        method=method,
+        align_corners=align_corners,
+    )
     yy = run_infer_type(y)
     assert yy.checked_type == relay.TensorType((n,) + oshape, dtype)
     dshape = (1,) + ishape
     x = relay.var("x", shape=dshape)
-    y = relay.nn.upsampling(x, scale_h=scale_h, scale_w=scale_w, layout=layout,
-                            method=method, align_corners=align_corners)
+    y = relay.nn.upsampling(
+        x,
+        scale_h=scale_h,
+        scale_w=scale_w,
+        layout=layout,
+        method=method,
+        align_corners=align_corners,
+    )
     func = relay.Function([x], y)
     data = np.random.uniform(size=dshape).astype(dtype)
     if method == "nearest_neighbor":
         ref = tvm.topi.testing.upsampling_python(data, (scale_h, scale_w), layout)
     else:
-        ref = tvm.topi.testing.bilinear_resize_python(data, (int(round(h*scale_h)),
-                                                  int(round(w*scale_w))), layout)
+        ref = tvm.topi.testing.bilinear_resize_python(
+            data, (int(round(h * scale_h)), int(round(w * scale_w))), layout
+        )
     for target, ctx in tvm.testing.enabled_targets():
         executor = relay.create_executor("graph", ctx=ctx, target=target)
         out = executor.evaluate(func)(data)
@@ -1213,45 +1338,71 @@ def test_upsampling():
     _test_upsampling("NHWC", "nearest_neighbor")
     _test_upsampling("NHWC", "bilinear", True)
 
+
 def _test_upsampling3d(layout, method, coordinate_transformation_mode="half_pixel"):
     n, c, d, h, w = te.size_var("n"), 8, 16, 16, 16
     scale_d = 2.0
     scale_h = 2.0
     scale_w = 2.0
     dtype = "float32"
+
     def get_shape():
         if layout == "NCDHW":
-            return (c, d, h, w), (c, int(round(d*scale_d)), int(round(h*scale_h)),\
-                                  int(round(w*scale_w)))
+            return (c, d, h, w), (
+                c,
+                int(round(d * scale_d)),
+                int(round(h * scale_h)),
+                int(round(w * scale_w)),
+            )
         else:
-            return (d, h, w, c), (int(round(d*scale_d)), int(round(h*scale_h)),\
-                                  int(round(w*scale_w)), c)
+            return (d, h, w, c), (
+                int(round(d * scale_d)),
+                int(round(h * scale_h)),
+                int(round(w * scale_w)),
+                c,
+            )
+
     ishape, oshape = get_shape()
     x = relay.var("x", relay.TensorType((n,) + ishape, dtype))
-    y = relay.nn.upsampling3d(x, scale_d=scale_d, scale_h=scale_h, scale_w=scale_w,\
-                              layout=layout, method=method,\
-                              coordinate_transformation_mode=coordinate_transformation_mode)
+    y = relay.nn.upsampling3d(
+        x,
+        scale_d=scale_d,
+        scale_h=scale_h,
+        scale_w=scale_w,
+        layout=layout,
+        method=method,
+        coordinate_transformation_mode=coordinate_transformation_mode,
+    )
 
     yy = run_infer_type(y)
     assert yy.checked_type == relay.TensorType((n,) + oshape, dtype)
     dshape = (1,) + ishape
     x = relay.var("x", shape=dshape)
-    y = relay.nn.upsampling3d(x, scale_d=scale_d, scale_h=scale_h, scale_w=scale_w,\
-                            layout=layout, method=method,\
-                            coordinate_transformation_mode=coordinate_transformation_mode)
+    y = relay.nn.upsampling3d(
+        x,
+        scale_d=scale_d,
+        scale_h=scale_h,
+        scale_w=scale_w,
+        layout=layout,
+        method=method,
+        coordinate_transformation_mode=coordinate_transformation_mode,
+    )
     func = relay.Function([x], y)
     data = np.random.uniform(size=dshape).astype(dtype)
     if method == "nearest_neighbor":
         ref = tvm.topi.testing.upsampling3d_python(data, (scale_d, scale_h, scale_w), layout)
     else:
-        ref = tvm.topi.testing.trilinear_resize3d_python(data, (int(round(d*scale_d)),\
-                                                     int(round(h*scale_h)),\
-                                                     int(round(w*scale_w))), layout)
+        ref = tvm.topi.testing.trilinear_resize3d_python(
+            data,
+            (int(round(d * scale_d)), int(round(h * scale_h)), int(round(w * scale_w))),
+            layout,
+        )
     for target, ctx in tvm.testing.enabled_targets():
         executor = relay.create_executor("graph", ctx=ctx, target=target)
         out = executor.evaluate(func)(data)
         tvm.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5, atol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_upsampling3d():
     _test_upsampling3d("NCDHW", "nearest_neighbor")
@@ -1259,37 +1410,41 @@ def test_upsampling3d():
     _test_upsampling3d("NDHWC", "nearest_neighbor")
     _test_upsampling3d("NDHWC", "trilinear", "align_corners")
 
+
 @tvm.testing.uses_gpu
 def test_conv2d_int8_intrinsics():
     def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):
         input_dtype, weight_dtype, output_dtype = dtypes
 
         n, h, w, ch, cw = 1, 64, 64, 3, 3
-        if data_layout == 'NCHW':
+        if data_layout == "NCHW":
             data_shape = (n, ic, h, w)
             x = relay.var("x", relay.TensorType(data_shape, input_dtype))
-        elif data_layout == 'NHWC':
+        elif data_layout == "NHWC":
             data_shape = (n, h, w, ic)
             x = relay.var("x", relay.TensorType(data_shape, input_dtype))
         else:
-            raise ValueError('Not supported')
+            raise ValueError("Not supported")
 
-        if kernel_layout == 'OIHW':
+        if kernel_layout == "OIHW":
             kernel_shape = (oc, ic, ch, cw)
-        elif kernel_layout == 'HWIO':
+        elif kernel_layout == "HWIO":
             kernel_shape = (ch, cw, ic, oc)
         else:
-            raise ValueError('Not supported')
+            raise ValueError("Not supported")
 
         weight = relay.var("weight", relay.TensorType(kernel_shape, weight_dtype))
-        y = relay.nn.conv2d(x, weight,
-                            kernel_size=(ch, cw),
-                            channels=oc,
-                            padding=(1, 1),
-                            dilation=(1, 1),
-                            data_layout=data_layout,
-                            kernel_layout=kernel_layout,
-                            out_dtype=output_dtype)
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            kernel_size=(ch, cw),
+            channels=oc,
+            padding=(1, 1),
+            dilation=(1, 1),
+            data_layout=data_layout,
+            kernel_layout=kernel_layout,
+            out_dtype=output_dtype,
+        )
         func = relay.Function([x, weight], y)
         wdata = np.random.rand(*kernel_shape) * 10
         parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))}
@@ -1301,9 +1456,9 @@ def test_conv2d_int8_intrinsics():
         return assembly
 
     def _has_fast_int8_instructions(asm, target):
-        if 'skylake-avx512' in target:
+        if "skylake-avx512" in target:
             return "pmaddubs" in asm
-        elif 'cascadelake' in target:
+        elif "cascadelake" in target:
             return "vpdpbusd" in asm
         else:
             assert False, "Target should be Skylake or Cascadelake"
@@ -1316,13 +1471,18 @@ def test_conv2d_int8_intrinsics():
     llvm_version = tvm.target.codegen.llvm_version_major()
     for target in targets:
         if llvm_version >= 8:
-            dtypes = ('uint8', 'int8', 'int32')
+            dtypes = ("uint8", "int8", "int32")
             # Sweep the input channels to check int8 robustness
             # Input channels should be a multiple of 4 internally.
             for ic in [1, 4, 6]:
-                asm = _compile(ic=ic, oc=16, target=target, data_layout="NCHW",
-                               kernel_layout='OIHW',
-                               dtypes=dtypes)
+                asm = _compile(
+                    ic=ic,
+                    oc=16,
+                    target=target,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    dtypes=dtypes,
+                )
                 assert _has_fast_int8_instructions(asm, target)
 
             # for ic in [1, 4, 6]:
@@ -1334,9 +1494,14 @@ def test_conv2d_int8_intrinsics():
             # Sweep the output channels to check int8 robustness
             # Output channels should be a multiple of 16 internally.
             for oc in [4, 16, 20]:
-                asm = _compile(ic=8, oc=oc, target=target, data_layout="NCHW",
-                               kernel_layout='OIHW',
-                               dtypes=dtypes)
+                asm = _compile(
+                    ic=8,
+                    oc=oc,
+                    target=target,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    dtypes=dtypes,
+                )
                 assert _has_fast_int8_instructions(asm, target)
 
             # for oc in [4, 16, 20]:
@@ -1346,8 +1511,9 @@ def test_conv2d_int8_intrinsics():
             #     assert _has_fast_int8_instructions(asm, target)
 
             # Check that both non-divisible oc and ic work
-            asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW',
-                           dtypes=dtypes)
+            asm = _compile(
+                ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout="OIHW", dtypes=dtypes
+            )
             assert _has_fast_int8_instructions(asm, target)
 
             # asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
@@ -1357,10 +1523,11 @@ def test_conv2d_int8_intrinsics():
     # Check that int8 x int8 goes through legalization so that fast instructions can be picked up.
     for target in targets:
         if llvm_version >= 8:
-            dtypes = (('int8', 'int8', 'int32'))
+            dtypes = ("int8", "int8", "int32")
             # Check that both non-divisible oc and ic work
-            asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW',
-                           dtypes=dtypes)
+            asm = _compile(
+                ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout="OIHW", dtypes=dtypes
+            )
             assert _has_fast_int8_instructions(asm, target)
 
             # asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
@@ -1377,18 +1544,24 @@ def test_conv2d_int8_intrinsics():
     # Check that a vectorized instruction is generated for older Intel
     # generations, because we default to NCHWc layout.
     target = "llvm -mcpu=core-avx2"
-    fast_int8_dtypes = ('uint8', 'int8', 'int32')
-    asm = _compile(ic=16, oc=32, target=target, data_layout="NCHW", kernel_layout='OIHW',
-                   dtypes=fast_int8_dtypes)
+    fast_int8_dtypes = ("uint8", "int8", "int32")
+    asm = _compile(
+        ic=16,
+        oc=32,
+        target=target,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        dtypes=fast_int8_dtypes,
+    )
     # Check that vector int mult and add instructions are generated.
     assert "vpmulld" in asm and "vpadd" in asm
 
 
 @tvm.testing.uses_gpu
 def test_depthwise_conv2d_int8():
-    input_dtype = 'uint8'
-    weight_dtype = 'int8'
-    output_dtype = 'int32'
+    input_dtype = "uint8"
+    weight_dtype = "int8"
+    output_dtype = "int32"
 
     data_shape = (1, 64, 56, 56)
     x = relay.var("x", relay.TensorType(data_shape, input_dtype))
@@ -1396,12 +1569,15 @@ def test_depthwise_conv2d_int8():
     kernel_shape = (64, 1, 3, 3)
     weight = relay.var("weight", relay.TensorType(kernel_shape, weight_dtype))
 
-    y = relay.nn.conv2d(x, weight,
-                        kernel_size=(3, 3),
-                        groups=64,
-                        padding=(1, 1),
-                        dilation=(1, 1),
-                        out_dtype=output_dtype)
+    y = relay.nn.conv2d(
+        x,
+        weight,
+        kernel_size=(3, 3),
+        groups=64,
+        padding=(1, 1),
+        dilation=(1, 1),
+        out_dtype=output_dtype,
+    )
     func = relay.Function([x, weight], y)
     wdata = np.random.rand(*kernel_shape) * 10
     parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))}
@@ -1420,11 +1596,9 @@ def test_bitserial_conv2d_infer_type():
     n, c, h, w = te.size_var("n"), 32, 224, 224
     x = relay.var("x", relay.ty.TensorType((n, c, h, w), "int16"))
     w = relay.var("w", relay.ty.TensorType((32, 32, 3, 3), "int16"))
-    y = relay.nn.bitserial_conv2d(
-        x, w, kernel_size=(3, 3), padding=(0, 0), channels=32)
+    y = relay.nn.bitserial_conv2d(x, w, kernel_size=(3, 3), padding=(0, 0), channels=32)
     yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType(
-        (n, 32, 222, 222), "int16")
+    assert yy.checked_type == relay.TensorType((n, 32, 222, 222), "int16")
 
 
 @tvm.testing.uses_gpu
@@ -1432,21 +1606,39 @@ def test_bitpack_infer_type():
     # Test axis packing shape inference.
     o, i, h, w = 32, 32, 128, 128
     x = relay.var("x", relay.ty.TensorType((o, i, h, w), "int16"))
-    y = relay.nn.bitpack(x, bit_axis=4, pack_axis=1, pack_type='uint16', bits=1)
+    y = relay.nn.bitpack(x, bit_axis=4, pack_axis=1, pack_type="uint16", bits=1)
     yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType(
-        (32, 2, 128, 128, 1), "uint16")
+    assert yy.checked_type == relay.TensorType((32, 2, 128, 128, 1), "uint16")
+
 
 # TODO(@jwfromm): Need to add bitserial_conv2d & bitpack run test cases
 
 
 @tvm.testing.uses_gpu
 def test_correlation():
-    def _test_correlation(data_shape, kernel_size, max_displacement, stride1, stride2, padding, is_multiply, dtype='float32'):
+    def _test_correlation(
+        data_shape,
+        kernel_size,
+        max_displacement,
+        stride1,
+        stride2,
+        padding,
+        is_multiply,
+        dtype="float32",
+    ):
         data1 = relay.var("data1", relay.ty.TensorType(data_shape, dtype))
         data2 = relay.var("data2", relay.ty.TensorType(data_shape, dtype))
-        y = relay.nn.correlation(data1, data2, kernel_size, max_displacement, stride1, stride2,
-                                 padding, is_multiply, "NCHW")
+        y = relay.nn.correlation(
+            data1,
+            data2,
+            kernel_size,
+            max_displacement,
+            stride1,
+            stride2,
+            padding,
+            is_multiply,
+            "NCHW",
+        )
         yy = run_infer_type(y)
         padded_height = data_shape[2] + 2 * padding
         padded_width = data_shape[3] + 2 * padding
@@ -1461,23 +1653,67 @@ def test_correlation():
         func = relay.Function([data1, data2], y)
         data1_np = np.random.uniform(size=data_shape).astype(dtype)
         data2_np = np.random.uniform(size=data_shape).astype(dtype)
-        ref_res = tvm.topi.testing.correlation_nchw_python(data1_np, data2_np, kernel_size, max_displacement, stride1, stride2, padding, is_multiply)
+        ref_res = tvm.topi.testing.correlation_nchw_python(
+            data1_np,
+            data2_np,
+            kernel_size,
+            max_displacement,
+            stride1,
+            stride2,
+            padding,
+            is_multiply,
+        )
 
         for target, ctx in tvm.testing.enabled_targets():
             intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
             op_res1 = intrp1.evaluate(func)(data1_np, data2_np)
             tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
 
-    _test_correlation((1, 3, 10, 10), kernel_size=1, max_displacement=4,
-                      stride1=1, stride2=1, padding=4, is_multiply=True)
-    _test_correlation((1, 3, 10, 10), kernel_size=1, max_displacement=5,
-                      stride1=1, stride2=1, padding=5, is_multiply=True)
-    _test_correlation((5, 1, 4, 4), kernel_size=3, max_displacement=1,
-                      stride1=2, stride2=1, padding=2, is_multiply=True)
-    _test_correlation((5, 1, 6, 4), kernel_size=3, max_displacement=1,
-                      stride1=2, stride2=2, padding=2, is_multiply=False)
-    _test_correlation((5, 1, 11, 11), kernel_size=5, max_displacement=1,
-                      stride1=1, stride2=1, padding=2, is_multiply=False)
+    _test_correlation(
+        (1, 3, 10, 10),
+        kernel_size=1,
+        max_displacement=4,
+        stride1=1,
+        stride2=1,
+        padding=4,
+        is_multiply=True,
+    )
+    _test_correlation(
+        (1, 3, 10, 10),
+        kernel_size=1,
+        max_displacement=5,
+        stride1=1,
+        stride2=1,
+        padding=5,
+        is_multiply=True,
+    )
+    _test_correlation(
+        (5, 1, 4, 4),
+        kernel_size=3,
+        max_displacement=1,
+        stride1=2,
+        stride2=1,
+        padding=2,
+        is_multiply=True,
+    )
+    _test_correlation(
+        (5, 1, 6, 4),
+        kernel_size=3,
+        max_displacement=1,
+        stride1=2,
+        stride2=2,
+        padding=2,
+        is_multiply=False,
+    )
+    _test_correlation(
+        (5, 1, 11, 11),
+        kernel_size=5,
+        max_displacement=1,
+        stride1=1,
+        stride2=1,
+        padding=2,
+        is_multiply=False,
+    )
 
 
 if __name__ == "__main__":
index 98ef38d..b019777 100644 (file)
@@ -34,19 +34,22 @@ def test_zeros_ones():
         assert yy.checked_type == relay.TensorType((124, 50), "float64")
         intrp = create_executor()
         intrp_res = intrp.evaluate(y).asnumpy()
-        np.testing.assert_allclose(intrp_res, ref((124, 50), 'float64'))
+        np.testing.assert_allclose(intrp_res, ref((124, 50), "float64"))
+
 
 def test_unary_identity():
-    for op, ref in [(relay.zeros_like, np.zeros_like),
-               (relay.ones_like, np.ones_like),
-               (relay.ceil, np.ceil),
-               (relay.floor, np.floor),
-               (relay.trunc, np.trunc),
-               (relay.round, np.round),
-               (relay.abs, np.abs),
-               (relay.copy, None), # np.copy
-               (relay.negative, np.negative),
-               (relay.sign, np.sign)]:
+    for op, ref in [
+        (relay.zeros_like, np.zeros_like),
+        (relay.ones_like, np.ones_like),
+        (relay.ceil, np.ceil),
+        (relay.floor, np.floor),
+        (relay.trunc, np.trunc),
+        (relay.round, np.round),
+        (relay.abs, np.abs),
+        (relay.copy, None),  # np.copy
+        (relay.negative, np.negative),
+        (relay.sign, np.sign),
+    ]:
         shape = (8, 9, 4)
         x = relay.var("x", relay.TensorType(shape, "float32"))
         y = op(x)
@@ -54,12 +57,13 @@ def test_unary_identity():
         assert yy.checked_type == relay.TensorType(shape, "float32")
 
         if ref is not None:
-            data = np.random.rand(*shape).astype('float32')
+            data = np.random.rand(*shape).astype("float32")
             intrp = create_executor()
-            op_res = intrp.evaluate(y, { x: relay.const(data) })
+            op_res = intrp.evaluate(y, {x: relay.const(data)})
             ref_res = ref(data)
             np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
 
+
 def test_cast():
     x = relay.var("x", relay.TensorType((8, 9, 4), "float32"))
     y = x.astype("int32")
@@ -76,16 +80,17 @@ def test_cast():
 
 def test_clip():
     a = relay.var("a", relay.TensorType((10, 4), "float32"))
-    y = relay.clip(a, 1., 4.)
+    y = relay.clip(a, 1.0, 4.0)
     yy = run_infer_type(y)
     assert yy.checked_type == relay.TensorType((10, 4), "float32")
 
-    data = np.random.rand(10, 4).astype('float32')
+    data = np.random.rand(10, 4).astype("float32")
     intrp = create_executor()
-    op_res = intrp.evaluate(y, { a: relay.const(data) })
-    ref_res = np.clip(data, 1., 4.)
+    op_res = intrp.evaluate(y, {a: relay.const(data)})
+    ref_res = np.clip(data, 1.0, 4.0)
     np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
 
+
 def test_fixed_point_multiply():
     # Test 23 * 1/16
     # [m,s] = [0.5, -3] = frexp(1/16)
@@ -97,19 +102,20 @@ def test_fixed_point_multiply():
     yy = run_infer_type(y)
     assert yy.checked_type == relay.TensorType((10, 4), "int32")
 
-    data = 23*np.ones((10, 4)).astype('int32')
+    data = 23 * np.ones((10, 4)).astype("int32")
     intrp = create_executor()
-    op_res = intrp.evaluate(y, { a: relay.const(data) })
-    ref_res = np.ones((10, 4)).astype('int32')
+    op_res = intrp.evaluate(y, {a: relay.const(data)})
+    ref_res = np.ones((10, 4)).astype("int32")
     np.testing.assert_allclose(op_res.asnumpy(), ref_res, atol=1)
 
+
 def test_reinterpret():
     a = relay.var("a", relay.TensorType((1000, 4), "float32"))
     y = relay.reinterpret(a, "int32")
     yy = run_infer_type(y)
     assert yy.checked_type == relay.TensorType((1000, 4), "int32")
 
-    data = np.random.randn(1000, 4).astype('float32') * 1000
+    data = np.random.randn(1000, 4).astype("float32") * 1000
     intrp = create_executor()
     op_res = intrp.evaluate(y, {a: relay.const(data)})
     ref_res = data.view("int32")
@@ -152,6 +158,7 @@ def test_approximate_transcendental():
 
     def reference_sigmoid(x):
         return np.exp(-np.logaddexp(0, -x))
+
     np.testing.assert_allclose(op_res.asnumpy(), reference_sigmoid(data), atol=2e-5, rtol=1e-9)
 
     y = approximate_tanh(a)
@@ -163,6 +170,7 @@ def test_approximate_transcendental():
 
     def reference_tanh(x):
         return np.tanh(x)
+
     np.testing.assert_allclose(op_res.asnumpy(), reference_tanh(data), atol=4e-5, rtol=1e-9)
 
 
@@ -175,7 +183,7 @@ def test_squeeze():
 
         data = np.random.random_sample(shape).astype(dtype)
         intrp = create_executor()
-        op_res = intrp.evaluate(squeeze, { x : relay.const(data) })
+        op_res = intrp.evaluate(squeeze, {x: relay.const(data)})
         ref_res = np.squeeze(data, axis=np_axis)
         np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
 
@@ -190,14 +198,12 @@ def test_transpose_infer_type():
     y = relay.transpose(x, axes=(1, 0, 2))
     assert "axes=" in y.astext()
     yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType(
-        (t, n, 100), "float32")
+    assert yy.checked_type == relay.TensorType((t, n, 100), "float32")
 
     y = relay.transpose(x)
     assert "axes=" in y.astext()
     yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType(
-        (100, t, n), "float32")
+    assert yy.checked_type == relay.TensorType((100, t, n), "float32")
 
 
 @tvm.testing.uses_gpu
@@ -215,6 +221,7 @@ def test_transpose():
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(x_data)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
     verify_transpose((2, 3, 4), (0, 2, 1))
 
 
@@ -224,16 +231,15 @@ def test_squeeze_infer_type():
     y = relay.squeeze(x, axis=(2,))
     assert "axis=" in y.astext()
     yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType(
-        (1, 4), "float32")
+    assert yy.checked_type == relay.TensorType((1, 4), "float32")
 
     n, t, d = 1, 4, 1
     x = relay.var("x", relay.TensorType((n, t, d), "float32"))
     y = relay.squeeze(x)
     assert "axis=" not in y.astext()
     yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType(
-        (4,), "float32")
+    assert yy.checked_type == relay.TensorType((4,), "float32")
+
 
 @pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
 def test_squeeze_bad_axes_infer_type():
@@ -249,8 +255,8 @@ def test_reshape_infer_type():
     y = relay.reshape(x, newshape=(n, t, 2000))
     assert "newshape=" in y.astext()
     yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType(
-        (n, t, 2000), "float32")
+    assert yy.checked_type == relay.TensorType((n, t, 2000), "float32")
+
 
 @tvm.testing.uses_gpu
 def test_reshape():
@@ -270,6 +276,7 @@ def test_reshape():
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(x_data)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
     verify_reshape((2, 3, 4), (8, 3), (8, 3))
     verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
     verify_reshape((2, 3, 4), (4, 0, 2), (4, 3, 2))
@@ -288,7 +295,7 @@ def test_reshape():
 
 def test_reshape_fail():
     with pytest.raises(TVMError) as reshape_err:
-        x = relay.var("x", relay.TensorType([2,3], "float32"))
+        x = relay.var("x", relay.TensorType([2, 3], "float32"))
         z = relay.reshape(x, [7])
         zz = run_infer_type(z)
 
@@ -296,7 +303,7 @@ def test_reshape_fail():
 def test_reshape_like_infer_type():
     # concrete shape
     x = relay.var("x", relay.TensorType((1, 2, 3), "float32"))
-    y = relay.var("y", relay.TensorType((1,6), "float32"))
+    y = relay.var("y", relay.TensorType((1, 6), "float32"))
     z = relay.reshape_like(x, y)
     zz = run_infer_type(z)
     assert zz.checked_type == relay.TensorType((1, 6), "float32")
@@ -334,6 +341,7 @@ def test_reshape_like():
     verify_reshape_like((2, 3, 4), (1, 8, 3))
     verify_reshape_like((4, 7), (2, 7, 2))
 
+
 def test_take_infer_type():
     def verify_take(dshape, indices_shape, oshape, axis=None):
         x = relay.var("x", relay.TensorType(dshape, "float32"))
@@ -351,6 +359,7 @@ def test_take_infer_type():
     verify_take((d1, d2), (d3, d4, d5), (d1, d3, d4, d5), 1)
     verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2)
 
+
 @tvm.testing.uses_gpu
 def test_take():
     def verify_take(src_shape, indices_src, axis=None, mode="clip"):
@@ -373,22 +382,22 @@ def test_take():
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
 
     verify_take((4,), [1])
-    verify_take((4,), [[0,1,2,3]])
-    verify_take((3,3,3), [[11,25]])
-    verify_take((4,), [[0,1],[2,3]])
+    verify_take((4,), [[0, 1, 2, 3]])
+    verify_take((3, 3, 3), [[11, 25]])
+    verify_take((4,), [[0, 1], [2, 3]])
     verify_take((4,), [1], 0)
-    verify_take((2,2), [[[1,0],[0,1]]], 0)
-    verify_take((2,2), [[[1,0],[0,1]]], 1)
-    verify_take((4,3,5,6), [[2,1,0,0]], -2)
-    verify_take((3,4), [-5, 20])
-    verify_take((3,4), [-5, 20], mode="wrap")
-    verify_take((3,4), [-1, 2], axis=0)
-    verify_take((3,4), [-1, 2], axis=0, mode="wrap")
-    verify_take((3,4), [-1, 2], axis=1)
-    verify_take((3,4), [-1, 2], axis=1, mode="wrap")
-    verify_take((3,3,3), [[11,25]], mode="fast")
-    verify_take((3,4), [0, 2], axis=0, mode="fast")
-    verify_take((3,4), [0, 2], axis=1, mode="fast")
+    verify_take((2, 2), [[[1, 0], [0, 1]]], 0)
+    verify_take((2, 2), [[[1, 0], [0, 1]]], 1)
+    verify_take((4, 3, 5, 6), [[2, 1, 0, 0]], -2)
+    verify_take((3, 4), [-5, 20])
+    verify_take((3, 4), [-5, 20], mode="wrap")
+    verify_take((3, 4), [-1, 2], axis=0)
+    verify_take((3, 4), [-1, 2], axis=0, mode="wrap")
+    verify_take((3, 4), [-1, 2], axis=1)
+    verify_take((3, 4), [-1, 2], axis=1, mode="wrap")
+    verify_take((3, 3, 3), [[11, 25]], mode="fast")
+    verify_take((3, 4), [0, 2], axis=0, mode="fast")
+    verify_take((3, 4), [0, 2], axis=1, mode="fast")
 
 
 def test_split_infer_type():
@@ -402,41 +411,82 @@ def test_split_infer_type():
 
     d1, d2, d3, d4 = te.var("d1"), te.var("d2"), te.var("d3"), te.var("d4")
     axis = te.var("axis")
-    verify_split((5, 5, 2, 2), 5,
-                 relay.ty.TupleType(tvm.runtime.convert([
-                     relay.ty.TensorType((5, 1, 2, 2), "float32"),
-                     relay.ty.TensorType((5, 1, 2, 2), "float32"),
-                     relay.ty.TensorType((5, 1, 2, 2), "float32"),
-                     relay.ty.TensorType((5, 1, 2, 2), "float32"),
-                     relay.ty.TensorType((5, 1, 2, 2), "float32")])),
-                  axis=1)
-    verify_split((5, 5, 2, 2), 5,
-                 relay.ty.TupleType(tvm.runtime.convert([
-                     relay.ty.TensorType((1, 5, 2, 2), "float32"),
-                     relay.ty.TensorType((1, 5, 2, 2), "float32"),
-                     relay.ty.TensorType((1, 5, 2, 2), "float32"),
-                     relay.ty.TensorType((1, 5, 2, 2), "float32"),
-                     relay.ty.TensorType((1, 5, 2, 2), "float32")])),
-                  axis=0)
-    verify_split((d1, d2, d3, d4), 4,
-                 relay.ty.TupleType(tvm.runtime.convert([
-                     relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
-                     relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
-                     relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
-                     relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32")])),
-                  axis=2)
-    verify_split((d1, d2, d3, d4), 2,
-                 relay.ty.TupleType(tvm.runtime.convert([
-                     relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32"),
-                     relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32")])),
-                  axis=0)
-    verify_split((d1, d2, d3, d4), (2, 4, 7),
-                 relay.ty.TupleType(tvm.runtime.convert([
-                     relay.ty.TensorType((d1, 2, d3, d4), "float32"),
-                     relay.ty.TensorType((d1, 2, d3, d4), "float32"),
-                     relay.ty.TensorType((d1, 3, d3, d4), "float32"),
-                     relay.ty.TensorType((d1, (d2-7), d3, d4), "float32")])),
-                  axis=1)
+    verify_split(
+        (5, 5, 2, 2),
+        5,
+        relay.ty.TupleType(
+            tvm.runtime.convert(
+                [
+                    relay.ty.TensorType((5, 1, 2, 2), "float32"),
+                    relay.ty.TensorType((5, 1, 2, 2), "float32"),
+                    relay.ty.TensorType((5, 1, 2, 2), "float32"),
+                    relay.ty.TensorType((5, 1, 2, 2), "float32"),
+                    relay.ty.TensorType((5, 1, 2, 2), "float32"),
+                ]
+            )
+        ),
+        axis=1,
+    )
+    verify_split(
+        (5, 5, 2, 2),
+        5,
+        relay.ty.TupleType(
+            tvm.runtime.convert(
+                [
+                    relay.ty.TensorType((1, 5, 2, 2), "float32"),
+                    relay.ty.TensorType((1, 5, 2, 2), "float32"),
+                    relay.ty.TensorType((1, 5, 2, 2), "float32"),
+                    relay.ty.TensorType((1, 5, 2, 2), "float32"),
+                    relay.ty.TensorType((1, 5, 2, 2), "float32"),
+                ]
+            )
+        ),
+        axis=0,
+    )
+    verify_split(
+        (d1, d2, d3, d4),
+        4,
+        relay.ty.TupleType(
+            tvm.runtime.convert(
+                [
+                    relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
+                    relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
+                    relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
+                    relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
+                ]
+            )
+        ),
+        axis=2,
+    )
+    verify_split(
+        (d1, d2, d3, d4),
+        2,
+        relay.ty.TupleType(
+            tvm.runtime.convert(
+                [
+                    relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32"),
+                    relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32"),
+                ]
+            )
+        ),
+        axis=0,
+    )
+    verify_split(
+        (d1, d2, d3, d4),
+        (2, 4, 7),
+        relay.ty.TupleType(
+            tvm.runtime.convert(
+                [
+                    relay.ty.TensorType((d1, 2, d3, d4), "float32"),
+                    relay.ty.TensorType((d1, 2, d3, d4), "float32"),
+                    relay.ty.TensorType((d1, 3, d3, d4), "float32"),
+                    relay.ty.TensorType((d1, (d2 - 7), d3, d4), "float32"),
+                ]
+            )
+        ),
+        axis=1,
+    )
+
 
 def test_full_infer_type():
     # default settings: match input dtype
@@ -465,8 +515,9 @@ def test_full():
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(np.array(fill_value, dtype))
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
     verify_full(4, (1, 3, 4, 4), "int32")
-    #verify_full(4, (1, 3, 4, 4), "int64") # This does not pass, python int32 is not upcast to int64, not sure how to fix it.
+    # verify_full(4, (1, 3, 4, 4), "int64")  # This does not pass: Python int32 is not upcast to int64; not sure how to fix it.
     verify_full(4.0, (1, 4), "float32")
 
 
@@ -503,13 +554,14 @@ def test_full_like():
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(x_data, np.array(fill_value, dtype))
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
     verify_full_like((1, 3, 4, 4), 4, "int32")
     verify_full_like((1, 1), 44.0, "float32")
 
 
 @tvm.testing.uses_gpu
 def test_infer_type_leaky_relu():
-    n, c , h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
+    n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
     x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
     y = relay.nn.leaky_relu(x, alpha=0.1)
     "alpha=0.1" in y.astext()
@@ -535,6 +587,7 @@ def test_infer_type_leaky_relu():
         op_res2 = intrp2.evaluate(func)(x_data)
         tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
 
+
 def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"):
     x = relay.var("data", relay.TensorType(data, dtype))
     if alpha:
@@ -559,9 +612,9 @@ def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"):
     a_data = np.random.uniform(low=-1, high=1, size=alpha).astype(dtype)
 
     if axis == 1:
-        ref_res = (x_data < 0) * (x_data * a_data.reshape(3, 1, 1)) + (x_data>=0) * x_data
+        ref_res = (x_data < 0) * (x_data * a_data.reshape(3, 1, 1)) + (x_data >= 0) * x_data
     else:
-        ref_res = (x_data < 0) * (x_data * a_data.reshape(1, 1, 3)) + (x_data>=0) * x_data
+        ref_res = (x_data < 0) * (x_data * a_data.reshape(1, 1, 3)) + (x_data >= 0) * x_data
 
     for target, ctx in tvm.testing.enabled_targets():
         intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
@@ -574,7 +627,7 @@ def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"):
 
 @tvm.testing.uses_gpu
 def test_infer_type_prelu():
-    n, c , h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
+    n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
     verify_infer_type_prelu((n, c, h, w), (c,), 1, (n, c, h, w))
     verify_infer_type_prelu((n, h, w, c), (c,), 3, (n, h, w, c))
     verify_infer_type_prelu((n, c, h, w), None, 1, (n, c, h, w))
@@ -602,7 +655,8 @@ def test_arange():
             x = relay.arange(
                 relay.const(start, dtype=dtype),
                 relay.const(stop, dtype=dtype),
-                relay.const(step, dtype=dtype))
+                relay.const(step, dtype=dtype),
+            )
             ref_res = np.arange(start, stop, step).astype(dtype)
 
         func = relay.Function([], x)
@@ -611,6 +665,7 @@ def test_arange():
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)()
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
     verify_arange(None, 20, None)
     verify_arange(None, 20, 2)
     verify_arange(1, 20, None)
@@ -623,6 +678,7 @@ def test_arange():
     # arange doesn't support floating point right now, see type relation
     # verify_arange(20, 1, -1.5)
 
+
 @tvm.testing.uses_gpu
 def test_meshgrid():
     def verify_meshgrid(lengths, indexing="ij"):
@@ -650,6 +706,7 @@ def test_meshgrid():
                 assert len(op_res) == len(ref_res)
                 for i in range(len(op_res)):
                     tvm.testing.assert_allclose(op_res[i].asnumpy(), ref_res[i], rtol=1e-5)
+
     verify_meshgrid([3, 5])
     verify_meshgrid([4, 2], indexing="xy")
     verify_meshgrid([3, 5, 2])
@@ -657,6 +714,7 @@ def test_meshgrid():
     # Length 0 signifies scalar.
     verify_meshgrid([3, 5, 0])
 
+
 @tvm.testing.uses_gpu
 def test_tile():
     def verify_tile(dshape, reps):
@@ -672,10 +730,12 @@ def test_tile():
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(x_data)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
     verify_tile((2, 3, 4), (3, 2, 1))
     verify_tile((2, 3, 4), (1, 2))
     verify_tile((2, 3), (3, 2, 1))
 
+
 @tvm.testing.uses_gpu
 def test_repeat():
     def verify_repeat(dshape, repeats, axis):
@@ -688,10 +748,12 @@ def test_repeat():
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(data)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
     verify_repeat((3,), 2, 0)
     verify_repeat((3, 10), 2, -1)
     verify_repeat((3, 2, 4), 3, 1)
 
+
 @tvm.testing.uses_gpu
 def test_stack():
     def verify_stack(dshapes, axis):
@@ -710,6 +772,7 @@ def test_stack():
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(*x_data)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
     verify_stack([(2,), (2,), (2,)], -1)
     verify_stack([(2,), (2,), (2,)], 0)
     verify_stack([(2, 2, 4), (2, 2, 4), (2, 2, 4)], 1)
@@ -732,6 +795,7 @@ def test_reverse():
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(x_data)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
     verify_reverse((2, 3, 4), 1)
     verify_reverse((4, 7), 0)
     verify_reverse((2, 3, 4), -1)
@@ -754,46 +818,53 @@ def test_reverse_sequence():
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
 
     indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
-    result = [[0, 5, 10, 15],
-              [4, 1, 6, 11],
-              [8, 9, 2, 7],
-              [12, 13, 14, 3]]
+    result = [[0, 5, 10, 15], [4, 1, 6, 11], [8, 9, 2, 7], [12, 13, 14, 3]]
     verify_reverse_sequence(indata, [1, 2, 3, 4], 1, 0, np.array(result))
     verify_reverse_sequence(indata, [1, 2, 3, 4], -1, 0, np.array(result))
-    verify_reverse_sequence(indata.astype("float32"), [1, 2, 3, 4], 1, 0, np.array(result).astype("float32"))
+    verify_reverse_sequence(
+        indata.astype("float32"), [1, 2, 3, 4], 1, 0, np.array(result).astype("float32")
+    )
 
     indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
-    result = [[0, 1, 2, 3],
-              [5, 4, 6, 7],
-              [10, 9, 8, 11],
-              [15, 14, 13, 12]]
+    result = [[0, 1, 2, 3], [5, 4, 6, 7], [10, 9, 8, 11], [15, 14, 13, 12]]
     verify_reverse_sequence(indata, [1, 2, 3, 4], 0, 1, np.array(result))
     verify_reverse_sequence(indata, [1, 2, 3, 4], 0, -1, np.array(result))
-    verify_reverse_sequence(indata.astype("float32"), [1, 2, 3, 4], 0, 1, np.array(result).astype("float32"))
+    verify_reverse_sequence(
+        indata.astype("float32"), [1, 2, 3, 4], 0, 1, np.array(result).astype("float32")
+    )
 
     indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
-    result = [[0, 1, 2, 3],
-              [4, 5, 6, 7],
-              [8, 9, 10, 11],
-              [15, 14, 13, 12]]
+    result = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [15, 14, 13, 12]]
     verify_reverse_sequence(indata, [-1, 0, 1, 5], 0, 1, np.array(result))
 
     indata = np.array(np.arange(0, 54)).reshape([2, 3, 3, 3]).astype("int32")
-    result = [[[[18, 19, 20], [21, 22, 23], [24, 25, 26]],
-               [[9, 10, 11], [12, 13, 14], [15, 16, 17]],
-               [[0,  1,  2], [3,  4,  5], [6,  7,  8]]],
-              [[[45, 46, 47], [48, 49, 50], [51, 52, 53]],
-               [[36, 37, 38], [39, 40, 41], [42, 43, 44]],
-               [[27, 28, 29], [30, 31, 32], [33, 34, 35]]]]
+    result = [
+        [
+            [[18, 19, 20], [21, 22, 23], [24, 25, 26]],
+            [[9, 10, 11], [12, 13, 14], [15, 16, 17]],
+            [[0, 1, 2], [3, 4, 5], [6, 7, 8]],
+        ],
+        [
+            [[45, 46, 47], [48, 49, 50], [51, 52, 53]],
+            [[36, 37, 38], [39, 40, 41], [42, 43, 44]],
+            [[27, 28, 29], [30, 31, 32], [33, 34, 35]],
+        ],
+    ]
     verify_reverse_sequence(indata, [3, 3], 0, 1, np.array(result))
 
     indata = np.array(np.arange(0, 54)).reshape([2, 3, 3, 3]).astype("int32")
-    result = [[[[9, 10, 11], [21, 22, 23], [15, 16, 17]],
-               [[0, 1, 2], [12, 13, 14], [6, 7, 8]],
-               [[18, 19, 20], [3, 4, 5], [24, 25, 26]]],
-              [[[36, 37, 38], [48, 49, 50], [42, 43, 44]],
-               [[27, 28, 29], [39, 40, 41], [33, 34, 35]],
-               [[45, 46, 47], [30, 31, 32], [51, 52, 53]]]]
+    result = [
+        [
+            [[9, 10, 11], [21, 22, 23], [15, 16, 17]],
+            [[0, 1, 2], [12, 13, 14], [6, 7, 8]],
+            [[18, 19, 20], [3, 4, 5], [24, 25, 26]],
+        ],
+        [
+            [[36, 37, 38], [48, 49, 50], [42, 43, 44]],
+            [[27, 28, 29], [39, 40, 41], [33, 34, 35]],
+            [[45, 46, 47], [30, 31, 32], [51, 52, 53]],
+        ],
+    ]
     verify_reverse_sequence(indata, [2, 3, 2], 2, 1, np.array(result))
 
     indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
@@ -801,12 +872,13 @@ def test_reverse_sequence():
     with pytest.raises(Exception) as execinfo:
         verify_reverse_sequence(indata, [2, 3, 2, 4, 5], 1, 0, np.array(result))
 
-    assert "For reverse_sequnece seq_lengths size should match with dimension of batch axis," \
-           " but got dimension of batch_axis = 4, and seq_length size = 5" in execinfo.value.args[0]
+    assert (
+        "For reverse_sequnece seq_lengths size should match with dimension of batch axis,"
+        " but got dimension of batch_axis = 4, and seq_length size = 5" in execinfo.value.args[0]
+    )
 
 
 def test_scatter():
-
     def ref_scatter(data, indices, updates, axis=0):
         idx = np.indices(indices.shape).reshape(indices.ndim, -1)
 
@@ -836,10 +908,9 @@ def test_scatter():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
-                tvm.testing.assert_allclose(
-                    op_res.asnumpy(), ref_res, rtol=1e-5)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
 
-    verify_scatter((10, ), (10, ), 0)
+    verify_scatter((10,), (10,), 0)
     verify_scatter((10, 5), (10, 5), -2)
     verify_scatter((10, 5), (10, 5), -1)
     verify_scatter((10, 5), (3, 5), 0)
@@ -854,7 +925,6 @@ def test_scatter():
 
 
 def test_scatter_add():
-
     def ref_scatter_add(data, indices, updates, axis=0):
         output = np.copy(data)
         for index in np.ndindex(*indices.shape):
@@ -881,10 +951,9 @@ def test_scatter_add():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
-                tvm.testing.assert_allclose(
-                    op_res.asnumpy(), ref_res, rtol=1e-5)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
 
-    verify_scatter_add((10, ), (10, ), 0)
+    verify_scatter_add((10,), (10,), 0)
     verify_scatter_add((10, 5), (10, 5), -2)
     verify_scatter_add((10, 5), (10, 5), -1)
     verify_scatter_add((10, 5), (3, 5), 0)
@@ -901,8 +970,8 @@ def test_scatter_add():
 @tvm.testing.uses_gpu
 def test_gather():
     def verify_gather(data, axis, indices, ref_res):
-        data = np.asarray(data, dtype='float32')
-        indices = np.asarray(indices, dtype='int32')
+        data = np.asarray(data, dtype="float32")
+        indices = np.asarray(indices, dtype="int32")
         ref_res = np.asarray(ref_res)
 
         d = relay.var("x", relay.TensorType(data.shape, "float32"))
@@ -915,40 +984,72 @@ def test_gather():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(data, indices)
-                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res,
-                                            rtol=1e-5)
-
-    verify_gather([[1, 2], [3, 4]],
-                  1,
-                  [[0, 0], [1, 0]],
-                  [[1, 1], [4, 3]])
-    verify_gather([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]],
-                  0,
-                  [[[1, 0, 1], [1, 1, 0]]],
-                  [[[6, 1, 8], [9, 10, 5]]])
-    verify_gather([[[-0.2321, -0.2024, -1.7624], [-0.3829, -0.4246, 0.2448],
-                    [0.1822, 0.2360, -0.8965], [0.4497, -0.2224, 0.6103]],
-                   [[0.0408, -0.7667, -0.4303], [-0.3216, 0.7489, -0.1502],
-                    [0.0144, -0.4699, -0.0064], [-0.0768, -1.6064, 1.3390]]],
-                  1,
-                  [[[2, 2, 0], [1, 0, 3]], [[3, 2, 0], [1, 0, 0]]],
-                  [[[0.1822, 0.2360, -1.7624], [-0.3829, -0.2024, 0.6103]],
-                   [[-0.0768, -0.4699, -0.4303], [-0.3216, -0.7667, -0.4303]]])
-    verify_gather([[[0.3050, 1.6986, 1.1034], [0.7020, -0.6960, -2.1818],
-                    [0.3116, -0.5773, -0.9912], [0.0835, -1.3915, -1.0720]],
-                   [[0.1694, -0.6091, -0.6539], [-0.5234, -0.1218, 0.5084],
-                    [0.2374, -1.9537, -2.0078], [-0.5700, -1.0302, 0.1558]]],
-                  2,
-                  [[[1, 1, 0, 1], [0, 0, 2, 2], [1, 2, 1, 2], [2, 2, 1, 0]],
-                   [[0, 0, 1, 2], [2, 2, 1, 0], [1, 2, 0, 0], [0, 2, 0, 2]]],
-                  [[[1.6986, 1.6986, 0.3050, 1.6986],
-                    [0.7020, 0.7020, -2.1818, -2.1818],
-                    [-0.5773, -0.9912, -0.5773, -0.9912],
-                    [-1.0720, -1.0720, -1.3915, 0.0835]],
-                   [[0.1694, 0.1694, -0.6091, -0.6539],
-                    [0.5084, 0.5084, -0.1218, -0.5234],
-                    [-1.9537, -2.0078, 0.2374, 0.2374],
-                    [-0.5700, 0.1558, -0.5700, 0.1558]]])
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
+    verify_gather([[1, 2], [3, 4]], 1, [[0, 0], [1, 0]], [[1, 1], [4, 3]])
+    verify_gather(
+        [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]],
+        0,
+        [[[1, 0, 1], [1, 1, 0]]],
+        [[[6, 1, 8], [9, 10, 5]]],
+    )
+    verify_gather(
+        [
+            [
+                [-0.2321, -0.2024, -1.7624],
+                [-0.3829, -0.4246, 0.2448],
+                [0.1822, 0.2360, -0.8965],
+                [0.4497, -0.2224, 0.6103],
+            ],
+            [
+                [0.0408, -0.7667, -0.4303],
+                [-0.3216, 0.7489, -0.1502],
+                [0.0144, -0.4699, -0.0064],
+                [-0.0768, -1.6064, 1.3390],
+            ],
+        ],
+        1,
+        [[[2, 2, 0], [1, 0, 3]], [[3, 2, 0], [1, 0, 0]]],
+        [
+            [[0.1822, 0.2360, -1.7624], [-0.3829, -0.2024, 0.6103]],
+            [[-0.0768, -0.4699, -0.4303], [-0.3216, -0.7667, -0.4303]],
+        ],
+    )
+    verify_gather(
+        [
+            [
+                [0.3050, 1.6986, 1.1034],
+                [0.7020, -0.6960, -2.1818],
+                [0.3116, -0.5773, -0.9912],
+                [0.0835, -1.3915, -1.0720],
+            ],
+            [
+                [0.1694, -0.6091, -0.6539],
+                [-0.5234, -0.1218, 0.5084],
+                [0.2374, -1.9537, -2.0078],
+                [-0.5700, -1.0302, 0.1558],
+            ],
+        ],
+        2,
+        [
+            [[1, 1, 0, 1], [0, 0, 2, 2], [1, 2, 1, 2], [2, 2, 1, 0]],
+            [[0, 0, 1, 2], [2, 2, 1, 0], [1, 2, 0, 0], [0, 2, 0, 2]],
+        ],
+        [
+            [
+                [1.6986, 1.6986, 0.3050, 1.6986],
+                [0.7020, 0.7020, -2.1818, -2.1818],
+                [-0.5773, -0.9912, -0.5773, -0.9912],
+                [-1.0720, -1.0720, -1.3915, 0.0835],
+            ],
+            [
+                [0.1694, 0.1694, -0.6091, -0.6539],
+                [0.5084, 0.5084, -0.1218, -0.5234],
+                [-1.9537, -2.0078, 0.2374, 0.2374],
+                [-0.5700, 0.1558, -0.5700, 0.1558],
+            ],
+        ],
+    )
 
 
 @tvm.testing.uses_gpu
@@ -967,6 +1068,7 @@ def test_gather_nd():
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(x_data, y_data)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
     verify_gather_nd((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]])
     verify_gather_nd((2, 2, 2), (2, 2), [[0, 1], [1, 0]])
     verify_gather_nd((3, 2, 2), (2, 2), [[0, 1], [1, 0]])
@@ -974,7 +1076,7 @@ def test_gather_nd():
 
 
 def _verify_infiniteness_ops(relay_op, ref_op):
-    for dtype in ['float32', 'float16', 'float16', 'int32', 'int16']:
+    for dtype in ["float32", "float16", "float16", "int32", "int16"]:
         shape = (2, 8, 8)
         x = relay.var("x", relay.TensorType(shape, dtype))
         y = relay_op(x)
@@ -982,8 +1084,10 @@ def _verify_infiniteness_ops(relay_op, ref_op):
         assert yy.checked_type == relay.TensorType(shape, "bool")
 
         data = np.random.uniform(size=shape).astype(dtype)
-        if dtype.startswith('float'):
-            data.ravel()[np.random.choice(data.size, int(data.size * 0.5), replace=False)] = np.infty
+        if dtype.startswith("float"):
+            data.ravel()[
+                np.random.choice(data.size, int(data.size * 0.5), replace=False)
+            ] = np.infty
             data.ravel()[np.random.choice(data.size, int(data.size * 0.5), replace=False)] = np.nan
 
         intrp = create_executor()
@@ -1037,6 +1141,7 @@ def test_unravel_index():
         # output which is in line with TensorFlow
         # verify_unravel_index([0, 1, 2, 5], [2, 2], dtype)
 
+
 @tvm.testing.uses_gpu
 def test_sparse_to_dense():
     def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape, xpected):
@@ -1044,13 +1149,19 @@ def test_sparse_to_dense():
         sparse_values_data = np.array(sparse_values)
         default_value_data = np.array(default_value)
 
-        a = relay.var("a", relay.TensorType(sparse_indices_data.shape, str(sparse_indices_data.dtype)))
-        b = relay.var("b", relay.TensorType(sparse_values_data.shape, str(sparse_values_data.dtype)))
+        a = relay.var(
+            "a", relay.TensorType(sparse_indices_data.shape, str(sparse_indices_data.dtype))
+        )
+        b = relay.var(
+            "b", relay.TensorType(sparse_values_data.shape, str(sparse_values_data.dtype))
+        )
         if default_value is None:
             args = [a, b]
             d = relay.sparse_to_dense(a, output_shape, b)
         else:
-            c = relay.var("c", relay.TensorType(default_value_data.shape, str(default_value_data.dtype)))
+            c = relay.var(
+                "c", relay.TensorType(default_value_data.shape, str(default_value_data.dtype))
+            )
             args = [a, b, c]
             d = relay.sparse_to_dense(a, output_shape, b, c)
 
@@ -1069,27 +1180,31 @@ def test_sparse_to_dense():
                     )
                 tvm.testing.assert_allclose(op_res.asnumpy(), xpected, rtol=1e-5)
 
-
     verify_sparse_to_dense(1, 3, 0, [5], [0, 3, 0, 0, 0])  # scalar
     verify_sparse_to_dense([0, 1, 4], [3, 3, 3], 0, [5], [3, 3, 0, 0, 3])  # vector
-    verify_sparse_to_dense([[0, 0], [1, 2]], [1, 2], 0, [3, 4], [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]])  # nXd
+    verify_sparse_to_dense(
+        [[0, 0], [1, 2]], [1, 2], 0, [3, 4], [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]]
+    )  # nXd
     verify_sparse_to_dense(
         [[0, 0, 0], [1, 2, 3]],
         [1, 2],
         4,
         [2, 3, 4],
-        [[[1, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4]], [[4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 2]]]
+        [[[1, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4]], [[4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 2]]],
     )  # nXd
-    verify_sparse_to_dense([0, 1, 4], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])  # floats
+    verify_sparse_to_dense(
+        [0, 1, 4], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]
+    )  # floats
     verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0])  # default value not specified
 
-    #negative test cases
-    #sparse indices should be ints
-    #verify_sparse_to_dense([[0.1, 1.1, 4.1], [0,2,4]], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
-    #sparse_values should be 0d or 1d only
-    #verify_sparse_to_dense([[0, 1, 4], [0, 2, 4]], [[[3.1, 3.1, 3.1]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
-    #sparse_indices should not be > 2d tensor
-    #verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[[3.1, 3.1, 3.1]]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
+    # negative test cases
+    # sparse indices should be ints
+    # verify_sparse_to_dense([[0.1, 1.1, 4.1], [0,2,4]], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
+    # sparse_values should be 0d or 1d only
+    # verify_sparse_to_dense([[0, 1, 4], [0, 2, 4]], [[[3.1, 3.1, 3.1]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
+    # sparse_indices should not be > 2d tensor
+    # verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[[3.1, 3.1, 3.1]]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
+
 
 def test_adv_index():
     def verify_adv_index(data_shape, index_shapes):
@@ -1113,9 +1228,15 @@ def test_adv_index():
                 tvm.testing.assert_allclose(op_res.asnumpy(), np_out, rtol=1e-5)
 
     verify_adv_index((10, 5), [(3, 4), (3, 1)])
-    verify_adv_index((10, 5), [(2,),])
+    verify_adv_index(
+        (10, 5),
+        [
+            (2,),
+        ],
+    )
     verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)])
 
+
 if __name__ == "__main__":
     test_cast()
     test_zeros_ones()
index 4f74d72..0df5a28 100644 (file)
@@ -60,12 +60,14 @@ def test_binary_op():
 
 @tvm.testing.uses_gpu
 def test_cmp_type():
-    for op, ref in ((relay.greater, np.greater),
-                    (relay.greater_equal, np.greater_equal),
-                    (relay.less, np.less),
-                    (relay.less_equal, np.less_equal),
-                    (relay.equal, np.equal),
-                    (relay.not_equal, np.not_equal)):
+    for op, ref in (
+        (relay.greater, np.greater),
+        (relay.greater_equal, np.greater_equal),
+        (relay.less, np.less),
+        (relay.less_equal, np.less_equal),
+        (relay.equal, np.equal),
+        (relay.not_equal, np.not_equal),
+    ):
         x = relay.var("x", relay.TensorType((10, 4), "float32"))
         y = relay.var("y", relay.TensorType((5, 10, 1), "float32"))
         z = op(x, y)
@@ -93,8 +95,7 @@ def test_cmp_type():
 
 @tvm.testing.uses_gpu
 def test_binary_int_broadcast_1():
-    for op, ref in [(relay.right_shift, np.right_shift),
-                    (relay.left_shift, np.left_shift)]:
+    for op, ref in [(relay.right_shift, np.right_shift), (relay.left_shift, np.left_shift)]:
         x = relay.var("x", relay.TensorType((10, 4), "int32"))
         y = relay.var("y", relay.TensorType((5, 10, 1), "int32"))
         z = op(x, y)
@@ -104,8 +105,8 @@ def test_binary_int_broadcast_1():
         if ref is not None:
             x_shape = (10, 4)
             y_shape = (5, 10, 1)
-            t1 = relay.TensorType(x_shape, 'int32')
-            t2 = relay.TensorType(y_shape, 'int32')
+            t1 = relay.TensorType(x_shape, "int32")
+            t2 = relay.TensorType(y_shape, "int32")
             x_data = np.random.randint(1, 10000, size=(x_shape)).astype(t1.dtype)
             y_data = np.random.randint(1, 31, size=(y_shape)).astype(t2.dtype)
             func = relay.Function([x, y], z)
@@ -116,11 +117,10 @@ def test_binary_int_broadcast_1():
                 op_res = intrp.evaluate(func)(x_data, y_data)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
 
+
 @tvm.testing.uses_gpu
 def test_binary_int_broadcast_2():
-    for op, ref in [(relay.maximum, np.maximum),
-                    (relay.minimum, np.minimum),
-                    (relay.mod, np.mod)]:
+    for op, ref in [(relay.maximum, np.maximum), (relay.minimum, np.minimum), (relay.mod, np.mod)]:
         x = relay.var("x", relay.TensorType((10, 4), "int32"))
         y = relay.var("y", relay.TensorType((5, 10, 1), "int32"))
         z = op(x, y)
@@ -130,8 +130,8 @@ def test_binary_int_broadcast_2():
         if ref is not None:
             x_shape = (10, 4)
             y_shape = (5, 10, 1)
-            t1 = relay.TensorType(x_shape, 'int32')
-            t2 = relay.TensorType(y_shape, 'int32')
+            t1 = relay.TensorType(x_shape, "int32")
+            t2 = relay.TensorType(y_shape, "int32")
             x_data = np.random.randint(1, 10000, size=(x_shape)).astype(t1.dtype)
             y_data = np.random.randint(1, 10000, size=(y_shape)).astype(t2.dtype)
             func = relay.Function([x, y], z)
@@ -142,6 +142,7 @@ def test_binary_int_broadcast_2():
                 op_res = intrp.evaluate(func)(x_data, y_data)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
 
+
 @tvm.testing.uses_gpu
 def test_where():
     def run(func, inputs, ref_res):
@@ -206,15 +207,18 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32")
         return
 
     func = relay.Function([x], z)
-    x_data = np.random.choice([True, False], size=data) if ref_func in [np.all] \
+    x_data = (
+        np.random.choice([True, False], size=data)
+        if ref_func in [np.all]
         else np.random.uniform(size=data).astype(dtype)
+    )
 
     if ref_func in [np.sum]:
         ref_res = ref_func(x_data + 0, axis=axis, dtype=dtype, keepdims=keepdims)
     elif ref_func in [np.max, np.min, np.mean, np.prod]:
         ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims)
-    else: #argmin/argmax
-        if axis and not isinstance(axis, int) and len(axis) > 1 :
+    else:  # argmin/argmax
+        if axis and not isinstance(axis, int) and len(axis) > 1:
             return
         ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims)
 
@@ -226,6 +230,7 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32")
         op_res2 = intrp2.evaluate(func)(x_data)
         tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_reduce_functions():
     def _with_keepdims(func):
@@ -240,6 +245,7 @@ def test_reduce_functions():
                 else:
                     out_shape = [1 for _ in range(len(data.shape))]
                 return func(data, axis=axis).reshape(out_shape)
+
         return _wrapper
 
     def _np_log_sum_exp(x, axis, keepdims=False):
@@ -253,28 +259,32 @@ def test_reduce_functions():
     def _unbiased_relay_wrapper(f):
         def _unbiased_func(x, axis=None, keepdims=False, exclude=False):
             return f(x, axis=axis, keepdims=keepdims, exclude=exclude, unbiased=True)
+
         return _unbiased_func
 
     def _unbiased_np_wrapper(f):
         def _unbiased_func(a, axis=None, dtype=None, keepdims=None):
             return f(a, axis=axis, dtype=dtype, ddof=1, keepdims=keepdims)
+
         return _unbiased_func
 
     d1, d2, d3, d4 = te.var("d1"), te.var("d2"), te.var("d3"), te.var("d4")
-    for func in [[relay.sum, np.sum],
-                 [relay.max, np.max],
-                 [relay.min, np.min],
-                 [relay.mean, np.mean],
-                 [relay.variance, np.var],
-                 [_unbiased_relay_wrapper(relay.variance), _unbiased_np_wrapper(np.var)],
-                 [relay.std, np.std],
-                 [_unbiased_relay_wrapper(relay.std), _unbiased_np_wrapper(np.std)],
-                 [relay.prod, np.prod],
-                 [relay.all, np.all],
-                 [relay.any, np.any],
-                 [relay.logsumexp, _np_log_sum_exp],
-                 [relay.argmin, _with_keepdims(np.argmin)],
-                 [relay.argmax, _with_keepdims(np.argmax)]]:
+    for func in [
+        [relay.sum, np.sum],
+        [relay.max, np.max],
+        [relay.min, np.min],
+        [relay.mean, np.mean],
+        [relay.variance, np.var],
+        [_unbiased_relay_wrapper(relay.variance), _unbiased_np_wrapper(np.var)],
+        [relay.std, np.std],
+        [_unbiased_relay_wrapper(relay.std), _unbiased_np_wrapper(np.std)],
+        [relay.prod, np.prod],
+        [relay.all, np.all],
+        [relay.any, np.any],
+        [relay.logsumexp, _np_log_sum_exp],
+        [relay.argmin, _with_keepdims(np.argmin)],
+        [relay.argmax, _with_keepdims(np.argmax)],
+    ]:
         verify_reduce(func, (d1, d2, d3, d4), None, False, False, ())
         verify_reduce(func, (d1, d2, d3, d4), 2, True, False, (d1, d2, 1, d4))
         verify_reduce(func, (d1, d2, d3, d4), 0, True, False, (1, d2, d3, d4))
@@ -316,10 +326,10 @@ def verify_mean_var_std(funcs, shape, axis, keepdims):
         tvm.testing.assert_allclose(op_res2[0].asnumpy(), ref_mean, rtol=1e-5)
         tvm.testing.assert_allclose(op_res2[1].asnumpy(), ref_res, rtol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_mean_var_std():
-    for func in [[relay.mean_variance, np.var],
-                 [relay.mean_std, np.std]]:
+    for func in [[relay.mean_variance, np.var], [relay.mean_std, np.std]]:
         verify_mean_var_std(func, (2, 3, 4), 1, True)
         verify_mean_var_std(func, (2, 3, 4), (1,), True)
         verify_mean_var_std(func, (2, 3, 4), -1, True)
@@ -334,8 +344,7 @@ def test_mean_var_std():
 
 @tvm.testing.uses_gpu
 def test_strided_slice():
-    def verify(dshape, begin, end, strides, output, slice_mode="end",
-               test_ref=True, dtype="int32"):
+    def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, dtype="int32"):
         x = relay.var("x", relay.TensorType(dshape, "float32"))
         ndim = len(dshape)
         begin = begin if begin else [0] * ndim
@@ -343,20 +352,12 @@ def test_strided_slice():
 
         # target numpy result
         x_data = np.random.uniform(size=dshape).astype("float32")
-        ref_res = tvm.topi.testing.strided_slice_python(
-            x_data, begin, end, strides, slice_mode)
+        ref_res = tvm.topi.testing.strided_slice_python(x_data, begin, end, strides, slice_mode)
 
         if strides:
-            z = relay.strided_slice(x,
-                                    begin=begin,
-                                    end=end,
-                                    strides=strides,
-                                    slice_mode=slice_mode)
+            z = relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=slice_mode)
         else:
-            z = relay.strided_slice(x,
-                                    begin=begin,
-                                    end=end,
-                                    slice_mode=slice_mode)
+            z = relay.strided_slice(x, begin=begin, end=end, slice_mode=slice_mode)
         func = relay.Function([x], z)
 
         func = run_infer_type(func)
@@ -375,8 +376,14 @@ def test_strided_slice():
             tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
 
     verify((1, 3, 10, 10), [0, 0, 0, 0], [-1, 3, 10, 10], [1], (0, 3, 10, 10), dtype="int64")
-    verify((1, 224, 224, 3), [0, 20, 20, 0], [1, 140, 140, 3],
-           [1, 1, 1, 1], (1, 120, 120, 3), dtype="int64")
+    verify(
+        (1, 224, 224, 3),
+        [0, 20, 20, 0],
+        [1, 140, 140, 3],
+        [1, 1, 1, 1],
+        (1, 120, 120, 3),
+        dtype="int64",
+    )
     verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3), dtype="int16")
     verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2))
     verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3))
@@ -385,37 +392,29 @@ def test_strided_slice():
     verify((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3))
     verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
     verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
-    verify((3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1],
-           (2, 4, 3), slice_mode="size", test_ref=False)
-    verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1],
-           (2, 2, 3), slice_mode="size", test_ref=True)
+    verify(
+        (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False
+    )
+    verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], (2, 2, 3), slice_mode="size", test_ref=True)
+
 
-#TODO(mbrookhart): enable once vm supports heterogenous execution
-#@tvm.testing.uses_gpu
+# TODO(mbrookhart): enable once vm supports heterogeneous execution
+# @tvm.testing.uses_gpu
 def test_dyn_strided_slice():
-    def verify(dshape, begin, end, strides, output, slice_mode="end",
-               test_ref=True, dtype="int32"):
+    def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, dtype="int32"):
         ndim = len(dshape)
         begin = begin if begin else [0] * ndim
         end = end if end else list(dshape)
 
         # target numpy result
         x_data = np.random.uniform(size=dshape).astype("float32")
-        ref_res = tvm.topi.testing.strided_slice_python(
-            x_data, begin, end, strides, slice_mode)
+        ref_res = tvm.topi.testing.strided_slice_python(x_data, begin, end, strides, slice_mode)
 
-        x = relay.var("x", relay.TensorType((relay.Any(), ) * ndim, "float32"))
+        x = relay.var("x", relay.TensorType((relay.Any(),) * ndim, "float32"))
         if strides:
-            z = relay.strided_slice(x,
-                                    begin=begin,
-                                    end=end,
-                                    strides=strides,
-                                    slice_mode=slice_mode)
+            z = relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=slice_mode)
         else:
-            z = relay.strided_slice(x,
-                                    begin=begin,
-                                    end=end,
-                                    slice_mode=slice_mode)
+            z = relay.strided_slice(x, begin=begin, end=end, slice_mode=slice_mode)
         func = relay.Function([x], z)
 
         func = run_infer_type(func)
@@ -432,21 +431,27 @@ def test_dyn_strided_slice():
             tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
 
     verify((1, 3, 10, 10), [0, 0, 0, 0], [-1, 3, 10, 10], [1], (0, 3, 10, 10), dtype="int64")
-    verify((1, 224, 224, 3), [0, 20, 20, 0], [1, 140, 140, 3],
-           [1, 1, 1, 1], (1, 120, 120, 3), dtype="int64")
+    verify(
+        (1, 224, 224, 3),
+        [0, 20, 20, 0],
+        [1, 140, 140, 3],
+        [1, 1, 1, 1],
+        (1, 120, 120, 3),
+        dtype="int64",
+    )
     verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3), dtype="int16")
     verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2))
     verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3))
     verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3))
     verify((3, 4, 3), [1, 1, 0], [4, 4, 4], None, (2, 3, 3))
     verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3))
-    #TODO(mbrookhart): fix static strided_slice with dynamic input and negative begin
-    #verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
-    #verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
-    verify((3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1],
-           (2, 4, 3), slice_mode="size", test_ref=False)
-    verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1],
-           (2, 2, 3), slice_mode="size", test_ref=True)
+    # TODO(mbrookhart): fix static strided_slice with dynamic input and negative begin
+    # verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
+    # verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
+    verify(
+        (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False
+    )
+    verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], (2, 2, 3), slice_mode="size", test_ref=True)
 
 
 @tvm.testing.uses_gpu
@@ -471,8 +476,7 @@ def test_strided_set():
             return
         x_data = np.random.uniform(size=dshape).astype("float32")
         v_data = np.random.uniform(size=vshape).astype("float32")
-        ref_res = tvm.topi.testing.strided_set_python(
-            x_data, v_data, begin, end, strides)
+        ref_res = tvm.topi.testing.strided_set_python(x_data, v_data, begin, end, strides)
         for target, ctx in tvm.testing.enabled_targets():
             intrp = relay.create_executor("graph", ctx=ctx, target=target)
             op_res = intrp.evaluate(func)(x_data, v_data)
index 4d5b17f..cfb85b6 100644 (file)
@@ -36,11 +36,12 @@ def test_resize_infer_type():
     assert zz.checked_type == relay.TensorType((n, c, th, tw), "int8")
 
     x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
-    z= relay.image.resize(x, (100, 200), "NCHW", "bilinear", "align_corners")
+    z = relay.image.resize(x, (100, 200), "NCHW", "bilinear", "align_corners")
     assert "size=" in z.astext()
     zz = run_infer_type(z)
     assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8")
 
+
 @tvm.testing.uses_gpu
 def test_resize():
     def verify_resize(dshape, scale, method, layout, coord_trans):
@@ -56,8 +57,7 @@ def test_resize():
         else:
             ref_res = tvm.topi.testing.upsampling_python(x_data, (scale, scale), layout)
         x = relay.var("x", relay.TensorType(dshape, "float32"))
-        z = relay.image.resize(x, size, layout, method,
-                              coordinate_transformation_mode=coord_trans)
+        z = relay.image.resize(x, size, layout, method, coordinate_transformation_mode=coord_trans)
         assert "size=" in z.astext()
         zz = run_infer_type(z)
         assert zz.checked_type == relay.TensorType(ref_res.shape, "float32")
@@ -75,8 +75,15 @@ def test_resize():
         verify_resize((2, 8, 17, 20), 3, "bilinear", layout, "asymmetric")
         verify_resize((3, 4, 5, 6), 5, "nearest_neighbor", layout, "asymmetric")
 
+
 def test_resize3d_infer_type():
-    n, c, d, h, w = te.size_var("n"), te.size_var("c"), te.size_var("d"), te.size_var("h"), te.size_var("w")
+    n, c, d, h, w = (
+        te.size_var("n"),
+        te.size_var("c"),
+        te.size_var("d"),
+        te.size_var("h"),
+        te.size_var("w"),
+    )
     x = relay.var("x", relay.TensorType((n, c, d, h, w), "int8"))
     td, th, tw = te.var("td"), te.var("th"), te.var("tw")
     z = relay.image.resize3d(x, (td, th, tw))
@@ -84,11 +91,12 @@ def test_resize3d_infer_type():
     assert zz.checked_type == relay.TensorType((n, c, td, th, tw), "int8")
 
     x = relay.var("x", relay.TensorType((n, c, d, h, w), "int8"))
-    z= relay.image.resize3d(x, (10, 10, 20), "NCDHW", "trilinear", "align_corners")
+    z = relay.image.resize3d(x, (10, 10, 20), "NCDHW", "trilinear", "align_corners")
     assert "size=" in z.astext()
     zz = run_infer_type(z)
     assert zz.checked_type == relay.TensorType((n, c, 10, 10, 20), "int8")
 
+
 @tvm.testing.parametrize_targets
 def test_resize3d(target, ctx):
     def verify_resize(dshape, scale, method, layout):
@@ -113,30 +121,31 @@ def test_resize3d(target, ctx):
             intrp = relay.create_executor(kind, ctx=ctx, target=target)
             op_res = intrp.evaluate(func)(x_data)
             tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6)
+
     for method in ["trilinear", "nearest_neighbor"]:
         for layout in ["NDHWC", "NCDHW"]:
             verify_resize((1, 4, 4, 4, 4), 2, method, layout)
 
+
 @tvm.testing.uses_gpu
 def test_crop_and_resize():
-    def verify_crop_and_resize(img_shape, boxes, box_indices, crop_size,
-                               layout, method, extrapolation_value=0.0):
+    def verify_crop_and_resize(
+        img_shape, boxes, box_indices, crop_size, layout, method, extrapolation_value=0.0
+    ):
 
         image_data = np.random.uniform(size=img_shape).astype("float32")
 
-        ref_res = tvm.topi.testing.crop_and_resize_python(image_data,
-                                                      boxes,
-                                                      box_indices,
-                                                      crop_size,
-                                                      layout, method,
-                                                      extrapolation_value)
+        ref_res = tvm.topi.testing.crop_and_resize_python(
+            image_data, boxes, box_indices, crop_size, layout, method, extrapolation_value
+        )
 
-        img = relay.var("img", relay.TensorType(img_shape, 'float32'))
-        bx = relay.var('bx', relay.TensorType(boxes.shape, 'float32'))
-        bx_idx = relay.var('bx_idx', relay.TensorType(box_indices.shape, 'int32'))
+        img = relay.var("img", relay.TensorType(img_shape, "float32"))
+        bx = relay.var("bx", relay.TensorType(boxes.shape, "float32"))
+        bx_idx = relay.var("bx_idx", relay.TensorType(box_indices.shape, "int32"))
 
-        z = relay.image.crop_and_resize(img, bx, bx_idx, list(crop_size),
-                                        layout, method, extrapolation_value)
+        z = relay.image.crop_and_resize(
+            img, bx, bx_idx, list(crop_size), layout, method, extrapolation_value
+        )
         zz = run_infer_type(z)
         assert zz.checked_type == relay.TensorType(ref_res.shape, "float32")
         func = relay.Function([img, bx, bx_idx], z)
@@ -147,24 +156,27 @@ def test_crop_and_resize():
                 op_res = intrp.evaluate(func)(image_data, boxes, box_indices)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-3, atol=1e-04)
 
-    boxes_nhwc = np.array([[.1, .2, .8, .7], [.2, 0, 1, .6]]).astype("float32")
+    boxes_nhwc = np.array([[0.1, 0.2, 0.8, 0.7], [0.2, 0, 1, 0.6]]).astype("float32")
     indices_nhwc = np.array([1, 0]).astype("int32")
     size_nhwc = np.array([20, 30]).astype("int32")
-    boxes_nchw = np.array([[0, 0, 1, 1], [.2, .1, 1, .9]]).astype("float32")
+    boxes_nchw = np.array([[0, 0, 1, 1], [0.2, 0.1, 1, 0.9]]).astype("float32")
     indices_nchw = np.array([0, 1]).astype("int32")
     size_nchw = np.array([30, 30]).astype("int32")
 
     for method in ["bilinear", "nearest_neighbor"]:
-        verify_crop_and_resize((10, 224, 224, 3), boxes_nhwc, indices_nhwc,
-                               size_nhwc, 'NHWC', method)
-        verify_crop_and_resize((5, 3, 255, 255), boxes_nchw, indices_nchw,
-                               size_nchw, 'NCHW', method, 0.1)
+        verify_crop_and_resize(
+            (10, 224, 224, 3), boxes_nhwc, indices_nhwc, size_nhwc, "NHWC", method
+        )
+        verify_crop_and_resize(
+            (5, 3, 255, 255), boxes_nchw, indices_nchw, size_nchw, "NCHW", method, 0.1
+        )
+
 
 @tvm.testing.uses_gpu
 def test_multibox_prior():
-    def get_ref_result(dshape, sizes=(1.0,),
-                       ratios=(1.0,), steps=(-1.0, -1.0),
-                       offsets=(0.5, 0.5), clip=True):
+    def get_ref_result(
+        dshape, sizes=(1.0,), ratios=(1.0,), steps=(-1.0, -1.0), offsets=(0.5, 0.5), clip=True
+    ):
         in_height = dshape[2]
         in_width = dshape[3]
         num_sizes = len(sizes)
@@ -184,11 +196,25 @@ def test_multibox_prior():
             for j in range(in_width):
                 center_w = (j + offset_w) * steps_w
                 for k in range(num_sizes + num_ratios - 1):
-                    w = size_ratio_concat[k] * in_height / in_width / 2.0 if k < num_sizes else \
-                        size_ratio_concat[0] * in_height / in_width * math.sqrt(size_ratio_concat[k + 1]) / 2.0
-                    h = size_ratio_concat[k] / 2.0 if k < num_sizes else \
-                        size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0
-                    count = i * in_width * (num_sizes + num_ratios - 1) + j * (num_sizes + num_ratios - 1) + k
+                    w = (
+                        size_ratio_concat[k] * in_height / in_width / 2.0
+                        if k < num_sizes
+                        else size_ratio_concat[0]
+                        * in_height
+                        / in_width
+                        * math.sqrt(size_ratio_concat[k + 1])
+                        / 2.0
+                    )
+                    h = (
+                        size_ratio_concat[k] / 2.0
+                        if k < num_sizes
+                        else size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0
+                    )
+                    count = (
+                        i * in_width * (num_sizes + num_ratios - 1)
+                        + j * (num_sizes + num_ratios - 1)
+                        + k
+                    )
                     np_out[0][count][0] = center_w - w
                     np_out[0][count][1] = center_h - h
                     np_out[0][count][2] = center_w + w
@@ -198,18 +224,26 @@ def test_multibox_prior():
 
         return np_out
 
-    def verify_multibox_prior(x, dshape, ref_res, sizes=(1.0,),
-                              ratios=(1.0,), steps=(-1.0, -1.0),
-                              offsets=(0.5, 0.5), clip=True, check_size=False,
-                              check_type_only=False):
+    def verify_multibox_prior(
+        x,
+        dshape,
+        ref_res,
+        sizes=(1.0,),
+        ratios=(1.0,),
+        steps=(-1.0, -1.0),
+        offsets=(0.5, 0.5),
+        clip=True,
+        check_size=False,
+        check_type_only=False,
+    ):
 
         z = relay.vision.multibox_prior(x, sizes, ratios, steps, offsets, clip)
         zz = run_infer_type(z)
         if check_size:
             assert "sizes=" in z.astext()
         assert zz.checked_type == relay.TensorType(
-            (1, dshape[2] * dshape[3] * (len(sizes) + len(ratios) - 1), 4),
-            "float32")
+            (1, dshape[2] * dshape[3] * (len(sizes) + len(ratios) - 1), 4), "float32"
+        )
 
         if check_type_only:
             return
@@ -232,11 +266,11 @@ def test_multibox_prior():
     dshape = (1, 3, 56, 56)
     ref_res = get_ref_result(dshape, sizes, ratios, steps, offsets)
     x = relay.var("x", relay.TensorType(dshape, "float32"))
-    verify_multibox_prior(x, dshape, ref_res, sizes, ratios, steps, offsets,
-                          check_size=True)
+    verify_multibox_prior(x, dshape, ref_res, sizes, ratios, steps, offsets, check_size=True)
     y = relay.var("y", relay.TensorType((te.size_var("n"), 3, 56, 56), "float32"))
-    verify_multibox_prior(x, dshape, ref_res, sizes, ratios, steps, offsets,
-                          check_size=True, check_type_only=True)
+    verify_multibox_prior(
+        x, dshape, ref_res, sizes, ratios, steps, offsets, check_size=True, check_type_only=True
+    )
 
     dshape = (1, 24, 32, 32)
     ref_res = get_ref_result(dshape, clip=False)
@@ -281,7 +315,7 @@ def test_get_valid_counts():
             out = intrp.evaluate(func)(np_data)
             tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04)
             # get_valid_counts on cuda and opencl doesn't do data rearrangement
-            if target in ['cuda', 'opencl']:
+            if target in ["cuda", "opencl"]:
                 return
             tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04)
             tvm.testing.assert_allclose(out[2].asnumpy(), np_out3, rtol=1e-3, atol=1e-04)
@@ -294,19 +328,43 @@ def test_get_valid_counts():
 
 @tvm.testing.uses_gpu
 def test_non_max_suppression():
-    def verify_nms(x0_data, x1_data, x2_data, x3_data, dshape, ref_res,
-                   ref_indices_res, iou_threshold=0.5, force_suppress=False,
-                   top_k=-1, check_type_only=False):
+    def verify_nms(
+        x0_data,
+        x1_data,
+        x2_data,
+        x3_data,
+        dshape,
+        ref_res,
+        ref_indices_res,
+        iou_threshold=0.5,
+        force_suppress=False,
+        top_k=-1,
+        check_type_only=False,
+    ):
         x0 = relay.var("x0", relay.ty.TensorType(dshape, "float32"))
         x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int32"))
         x2 = relay.var("x2", relay.ty.TensorType((dshape[0], dshape[1]), "int32"))
         x3 = relay.var("x3", relay.ty.TensorType((), "int32"))
-        z = relay.vision.non_max_suppression(x0, x1, x2, x3, \
-            iou_threshold=iou_threshold, force_suppress=force_suppress, \
-            top_k=top_k, return_indices=False)
-        z_indices = relay.vision.non_max_suppression(x0, x1, x2, x3, \
-                    iou_threshold=iou_threshold, force_suppress=force_suppress, \
-                    top_k=top_k, return_indices=True)
+        z = relay.vision.non_max_suppression(
+            x0,
+            x1,
+            x2,
+            x3,
+            iou_threshold=iou_threshold,
+            force_suppress=force_suppress,
+            top_k=top_k,
+            return_indices=False,
+        )
+        z_indices = relay.vision.non_max_suppression(
+            x0,
+            x1,
+            x2,
+            x3,
+            iou_threshold=iou_threshold,
+            force_suppress=force_suppress,
+            top_k=top_k,
+            return_indices=True,
+        )
         if isinstance(z_indices, relay.expr.TupleWrapper):
             z_indices = z_indices.astuple()
         assert "iou_threshold" in z.astext()
@@ -315,8 +373,11 @@ def test_non_max_suppression():
         zz_indices = run_infer_type(z_indices)
         assert zz.checked_type == relay.ty.TensorType(dshape, "float32")
         assert zz_indices.checked_type == relay.ty.TupleType(
-            [relay.ty.TensorType((dshape[0], dshape[1]), "int32"),
-             relay.ty.TensorType((dshape[0], 1), "int32")])
+            [
+                relay.ty.TensorType((dshape[0], dshape[1]), "int32"),
+                relay.ty.TensorType((dshape[0], 1), "int32"),
+            ]
+        )
 
         if check_type_only:
             return
@@ -332,44 +393,104 @@ def test_non_max_suppression():
             intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
             op_res2 = intrp2.evaluate(func)(x0_data, x1_data, x2_data, x3_data)
             tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
-            if target == 'cuda':
+            if target == "cuda":
                 return
             op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data)
             tvm.testing.assert_allclose(op_indices_res1[0].asnumpy(), ref_indices_res, rtol=1e-5)
             op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data)
             tvm.testing.assert_allclose(op_indices_res2[0].asnumpy(), ref_indices_res, rtol=1e-5)
 
-    np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80],
-                         [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79],
-                         [1, 0.5, 100, 60, 70, 110]]]).astype("float32")
+    np_data = np.array(
+        [
+            [
+                [0, 0.8, 1, 20, 25, 45],
+                [1, 0.7, 30, 60, 50, 80],
+                [0, 0.4, 4, 21, 19, 40],
+                [2, 0.9, 35, 61, 52, 79],
+                [1, 0.5, 100, 60, 70, 110],
+            ]
+        ]
+    ).astype("float32")
     np_valid_count = np.array([4]).astype("int32")
     np_indices = np.array([[0, 1, 3, 4, -1]]).astype("int32")
     np_max_output_size = -1
 
-    np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
-                           [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
-                           [-1, -1, -1, -1, -1, -1]]])
+    np_result = np.array(
+        [
+            [
+                [2, 0.9, 35, 61, 52, 79],
+                [0, 0.8, 1, 20, 25, 45],
+                [-1, -1, -1, -1, -1, -1],
+                [-1, -1, -1, -1, -1, -1],
+                [-1, -1, -1, -1, -1, -1],
+            ]
+        ]
+    )
     np_indices_result = np.array([[4, 0, -1, -1, -1]])
     num_anchors = 5
 
     dshape = (te.size_var("n"), num_anchors, 6)
-    verify_nms(np_data, np_valid_count, np_indices, np_max_output_size, dshape, np_result, np_indices_result,
-               force_suppress=True, top_k=2, check_type_only=True)
+    verify_nms(
+        np_data,
+        np_valid_count,
+        np_indices,
+        np_max_output_size,
+        dshape,
+        np_result,
+        np_indices_result,
+        force_suppress=True,
+        top_k=2,
+        check_type_only=True,
+    )
     dshape = (1, num_anchors, 6)
-    verify_nms(np_data, np_valid_count, np_indices, np_max_output_size, dshape, np_result, np_indices_result,
-               force_suppress=True, top_k=2, check_type_only=False)
-
-    np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
-                           [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
-                           [-1, -1, -1, -1, -1, -1]]])
+    verify_nms(
+        np_data,
+        np_valid_count,
+        np_indices,
+        np_max_output_size,
+        dshape,
+        np_result,
+        np_indices_result,
+        force_suppress=True,
+        top_k=2,
+        check_type_only=False,
+    )
+
+    np_result = np.array(
+        [
+            [
+                [2, 0.9, 35, 61, 52, 79],
+                [0, 0.8, 1, 20, 25, 45],
+                [-1, -1, -1, -1, -1, -1],
+                [-1, -1, -1, -1, -1, -1],
+                [-1, -1, -1, -1, -1, -1],
+            ]
+        ]
+    )
     np_indices_result = np.array([[4, 0, -1, -1, -1]])
     np_max_output_size = 2
     dshape = (te.size_var("n"), num_anchors, 6)
-    verify_nms(np_data, np_valid_count, np_indices, np_max_output_size, dshape, np_result,
-               np_indices_result, check_type_only=True)
+    verify_nms(
+        np_data,
+        np_valid_count,
+        np_indices,
+        np_max_output_size,
+        dshape,
+        np_result,
+        np_indices_result,
+        check_type_only=True,
+    )
     dshape = (1, num_anchors, 6)
-    verify_nms(np_data, np_valid_count, np_indices, np_max_output_size, dshape, np_result,
-               np_indices_result, top_k=2)
+    verify_nms(
+        np_data,
+        np_valid_count,
+        np_indices,
+        np_max_output_size,
+        dshape,
+        np_result,
+        np_indices_result,
+        top_k=2,
+    )
 
 
 @tvm.testing.uses_gpu
@@ -378,37 +499,44 @@ def test_multibox_transform_loc():
         num_anchors = 3
         num_classes = 3
 
-        np_cls_prob = np.array(
-            [[[0.2, 0.5, 0.3], [0.25, 0.3, 0.45],
-              [0.7, 0.1, 0.2]]]).astype("float32")
+        np_cls_prob = np.array([[[0.2, 0.5, 0.3], [0.25, 0.3, 0.45], [0.7, 0.1, 0.2]]]).astype(
+            "float32"
+        )
         np_loc_preds = np.array(
-            [[0.1, -0.2, 0.3, 0.2, 0.2, 0.4, 0.5, -0.3, 0.7, -0.2, -0.4,
-              -0.8]]).astype("float32")
+            [[0.1, -0.2, 0.3, 0.2, 0.2, 0.4, 0.5, -0.3, 0.7, -0.2, -0.4, -0.8]]
+        ).astype("float32")
         np_anchors = np.array(
-            [[[-0.1, -0.1, 0.1, 0.1], [-0.2, -0.2, 0.2, 0.2],
-              [1.2, 1.2, 1.5, 1.5]]]).astype("float32")
-
-        expected_np_out = np.array([[[1, 0.69999999, 0, 0, 0.10818365, 0.10008108],
-                                     [0, 0.44999999, 1, 1, 1, 1],
-                                     [0, 0.30000001, 0, 0, 0.22903419, 0.20435292]]])
-
+            [[[-0.1, -0.1, 0.1, 0.1], [-0.2, -0.2, 0.2, 0.2], [1.2, 1.2, 1.5, 1.5]]]
+        ).astype("float32")
+
+        expected_np_out = np.array(
+            [
+                [
+                    [1, 0.69999999, 0, 0, 0.10818365, 0.10008108],
+                    [0, 0.44999999, 1, 1, 1, 1],
+                    [0, 0.30000001, 0, 0, 0.22903419, 0.20435292],
+                ]
+            ]
+        )
 
         cls_prob = relay.var(
-            "cls_prob",
-            relay.ty.TensorType((1, num_anchors, num_classes), "float32"))
-        loc_pred = relay.var(
-            "loc_pred", relay.ty.TensorType((1, num_anchors * 4), "float32"))
-        anchors = relay.var(
-            "anchors", relay.ty.TensorType((1, num_anchors, 4), "float32"))
+            "cls_prob", relay.ty.TensorType((1, num_anchors, num_classes), "float32")
+        )
+        loc_pred = relay.var("loc_pred", relay.ty.TensorType((1, num_anchors * 4), "float32"))
+        anchors = relay.var("anchors", relay.ty.TensorType((1, num_anchors, 4), "float32"))
 
         mtl = relay.vision.multibox_transform_loc(
-            cls_prob=cls_prob, loc_pred=loc_pred, anchor=anchors)
+            cls_prob=cls_prob, loc_pred=loc_pred, anchor=anchors
+        )
         ret = run_infer_type(mtl.astuple())
         ref_type = relay.ty.TupleType(
-            tvm.runtime.convert([
-                relay.ty.TensorType((1, num_anchors, 6), "float32"),
-                relay.ty.TensorType((1, ), "int")
-            ]))
+            tvm.runtime.convert(
+                [
+                    relay.ty.TensorType((1, num_anchors, 6), "float32"),
+                    relay.ty.TensorType((1,), "int"),
+                ]
+            )
+        )
 
         assert ret.checked_type == ref_type
 
@@ -417,12 +545,10 @@ def test_multibox_transform_loc():
         func = run_infer_type(func)
         for target, ctx in tvm.testing.enabled_targets():
             intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
-            op_res1 = intrp1.evaluate(func)(np_cls_prob, np_loc_preds,
-                                            np_anchors)
+            op_res1 = intrp1.evaluate(func)(np_cls_prob, np_loc_preds, np_anchors)
             tvm.testing.assert_allclose(op_res1.asnumpy(), expected_np_out, rtol=1e-5)
             intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
-            op_res2 = intrp2.evaluate(func)(np_cls_prob, np_loc_preds,
-                                            np_anchors)
+            op_res2 = intrp2.evaluate(func)(np_cls_prob, np_loc_preds, np_anchors)
             tvm.testing.assert_allclose(op_res2.asnumpy(), expected_np_out, rtol=1e-5)
 
     def test_threshold():
@@ -430,12 +556,10 @@ def test_multibox_transform_loc():
         num_classes = 5
         n = te.size_var("n")
         cls_prob = relay.var(
-            "cls_prob",
-            relay.ty.TensorType((n, num_anchors, num_classes), "float32"))
-        loc_pred = relay.var(
-            "loc_pred", relay.ty.TensorType((n, num_anchors * 4), "float32"))
-        anchors = relay.var(
-            "anchors", relay.ty.TensorType((1, num_anchors, 4), "float32"))
+            "cls_prob", relay.ty.TensorType((n, num_anchors, num_classes), "float32")
+        )
+        loc_pred = relay.var("loc_pred", relay.ty.TensorType((n, num_anchors * 4), "float32"))
+        anchors = relay.var("anchors", relay.ty.TensorType((1, num_anchors, 4), "float32"))
         threshold = 0.02
         variances = (0.2, 0.2, 0.3, 0.3)
 
@@ -444,13 +568,17 @@ def test_multibox_transform_loc():
             loc_pred=loc_pred,
             anchor=anchors,
             threshold=threshold,
-            variances=variances)
+            variances=variances,
+        )
         ret = run_infer_type(ret.astuple())
         ref_type = relay.ty.TupleType(
-            tvm.runtime.convert([
-                relay.ty.TensorType((n, num_anchors, 6), "float32"),
-                relay.ty.TensorType((n, ), "int")
-            ]))
+            tvm.runtime.convert(
+                [
+                    relay.ty.TensorType((n, num_anchors, 6), "float32"),
+                    relay.ty.TensorType((n,), "int"),
+                ]
+            )
+        )
         assert ret.checked_type == ref_type
 
     test_default_value()
@@ -462,23 +590,33 @@ def test_roi_align():
     def verify_roi_align(data_shape, rois_shape, pooled_size, spatial_scale, sample_ratio):
         data = relay.var("data", relay.ty.TensorType(data_shape, "float32"))
         rois = relay.var("rois", relay.ty.TensorType(rois_shape, "float32"))
-        z = relay.vision.roi_align(data, rois, pooled_size=(pooled_size, pooled_size),
-                                   spatial_scale=spatial_scale, sample_ratio=sample_ratio,
-                                   layout="NCHW")
+        z = relay.vision.roi_align(
+            data,
+            rois,
+            pooled_size=(pooled_size, pooled_size),
+            spatial_scale=spatial_scale,
+            sample_ratio=sample_ratio,
+            layout="NCHW",
+        )
         zz = run_infer_type(z)
         batch, channel, in_size, _ = data_shape
         num_roi = rois_shape[0]
         assert zz.checked_type == relay.ty.TensorType(
-                (num_roi, channel, pooled_size, pooled_size), "float32")
+            (num_roi, channel, pooled_size, pooled_size), "float32"
+        )
 
         func = relay.Function([data, rois], z)
         func = run_infer_type(func)
         np_data = np.random.uniform(size=data_shape).astype("float32")
-        np_rois = np.random.uniform(size=rois_shape).astype('float32') * in_size
-        np_rois[:, 0] = np.random.randint(low = 0, high = batch, size = num_roi)
-        ref_res = tvm.topi.testing.roi_align_nchw_python(np_data, np_rois, pooled_size=pooled_size,
-                                                     spatial_scale=spatial_scale,
-                                                     sample_ratio=sample_ratio)
+        np_rois = np.random.uniform(size=rois_shape).astype("float32") * in_size
+        np_rois[:, 0] = np.random.randint(low=0, high=batch, size=num_roi)
+        ref_res = tvm.topi.testing.roi_align_nchw_python(
+            np_data,
+            np_rois,
+            pooled_size=pooled_size,
+            spatial_scale=spatial_scale,
+            sample_ratio=sample_ratio,
+        )
         for target, ctx in tvm.testing.enabled_targets():
             intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
             op_res1 = intrp1.evaluate(func)(np_data, np_rois)
@@ -496,21 +634,28 @@ def test_roi_pool():
     def verify_roi_pool(data_shape, rois_shape, pooled_size, spatial_scale):
         data = relay.var("data", relay.ty.TensorType(data_shape, "float32"))
         rois = relay.var("rois", relay.ty.TensorType(rois_shape, "float32"))
-        z = relay.vision.roi_pool(data, rois, pooled_size=(pooled_size, pooled_size),
-                                   spatial_scale=spatial_scale, layout="NCHW")
+        z = relay.vision.roi_pool(
+            data,
+            rois,
+            pooled_size=(pooled_size, pooled_size),
+            spatial_scale=spatial_scale,
+            layout="NCHW",
+        )
         zz = run_infer_type(z)
         batch, channel, in_size, _ = data_shape
         num_roi = rois_shape[0]
         assert zz.checked_type == relay.ty.TensorType(
-                (num_roi, channel, pooled_size, pooled_size), "float32")
+            (num_roi, channel, pooled_size, pooled_size), "float32"
+        )
 
         func = relay.Function([data, rois], z)
         func = run_infer_type(func)
         np_data = np.random.uniform(size=data_shape).astype("float32")
-        np_rois = np.random.uniform(size=rois_shape).astype('float32') * in_size
-        np_rois[:, 0] = np.random.randint(low = 0, high = batch, size = num_roi).astype('float32')
-        ref_res = tvm.topi.testing.roi_pool_nchw_python(np_data, np_rois, pooled_size=pooled_size,
-                                                     spatial_scale=spatial_scale)
+        np_rois = np.random.uniform(size=rois_shape).astype("float32") * in_size
+        np_rois[:, 0] = np.random.randint(low=0, high=batch, size=num_roi).astype("float32")
+        ref_res = tvm.topi.testing.roi_pool_nchw_python(
+            np_data, np_rois, pooled_size=pooled_size, spatial_scale=spatial_scale
+        )
         for target, ctx in tvm.testing.enabled_targets():
             intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
             op_res1 = intrp1.evaluate(func)(np_data, np_rois)
@@ -535,7 +680,7 @@ def test_proposal():
 
         func = relay.Function([cls_prob, bbox_pred, im_info], z)
         func = run_infer_type(func)
-        for target in ['llvm', 'cuda']:
+        for target in ["llvm", "cuda"]:
             if not tvm.testing.device_enabled(target):
                 print("Skip test because %s is not enabled." % target)
                 continue
@@ -548,44 +693,59 @@ def test_proposal():
             tvm.testing.assert_allclose(op_res2.asnumpy(), np_out, rtol=1e-4)
 
     attrs = {
-        'scales': (0.5,),
-        'ratios': (0.5,),
-        'feature_stride': 16,
-        'iou_loss': False,
-        'rpn_min_size': 16,
-        'threshold': 0.7,
-        'rpn_pre_nms_top_n': 200,
-        'rpn_post_nms_top_n': 4,
+        "scales": (0.5,),
+        "ratios": (0.5,),
+        "feature_stride": 16,
+        "iou_loss": False,
+        "rpn_min_size": 16,
+        "threshold": 0.7,
+        "rpn_pre_nms_top_n": 200,
+        "rpn_post_nms_top_n": 4,
     }
 
-    np_cls_prob = np.array([[
-        [[0.3, 0.6, 0.2], [0.4, 0.7, 0.5], [0.1, 0.4, 0.3]],
-        [[0.7, 0.5, 0.3], [0.6, 0.4, 0.8], [0.9, 0.2, 0.5]]
-    ]], dtype='float32')
-    np_bbox_pred = np.array([[
-        [[0.5, 1.0, 0.6], [0.8,  1.2, 2.0], [0.9, 1.0, 0.8]],
-        [[0.5, 1.0, 0.7], [0.8,  1.2, 1.6], [2.1, 1.5, 0.7]],
-        [[1.0, 0.5, 0.7], [1.5,  0.9, 1.6], [1.4, 1.5, 0.8]],
-        [[1.0, 0.5, 0.6], [1.5,  0.9, 2.0], [1.8, 1.0, 0.9]],
-    ]], dtype='float32')
-    np_im_info = np.array([[48., 48., 1.]], dtype='float32')
-    np_out = np.array([
-        [0., 0., 2.8451548,28.38012, 18.154846],
-        [0., 0., 15.354933, 41.96971, 41.245064],
-        [0., 18.019852, 1.0538368, 51.98015, 25.946163],
-        [0., 27.320923, -1.266357, 55., 24.666357]
-    ], dtype='float32')
-
+    np_cls_prob = np.array(
+        [
+            [
+                [[0.3, 0.6, 0.2], [0.4, 0.7, 0.5], [0.1, 0.4, 0.3]],
+                [[0.7, 0.5, 0.3], [0.6, 0.4, 0.8], [0.9, 0.2, 0.5]],
+            ]
+        ],
+        dtype="float32",
+    )
+    np_bbox_pred = np.array(
+        [
+            [
+                [[0.5, 1.0, 0.6], [0.8, 1.2, 2.0], [0.9, 1.0, 0.8]],
+                [[0.5, 1.0, 0.7], [0.8, 1.2, 1.6], [2.1, 1.5, 0.7]],
+                [[1.0, 0.5, 0.7], [1.5, 0.9, 1.6], [1.4, 1.5, 0.8]],
+                [[1.0, 0.5, 0.6], [1.5, 0.9, 2.0], [1.8, 1.0, 0.9]],
+            ]
+        ],
+        dtype="float32",
+    )
+    np_im_info = np.array([[48.0, 48.0, 1.0]], dtype="float32")
+    np_out = np.array(
+        [
+            [0.0, 0.0, 2.8451548, 28.38012, 18.154846],
+            [0.0, 0.0, 15.354933, 41.96971, 41.245064],
+            [0.0, 18.019852, 1.0538368, 51.98015, 25.946163],
+            [0.0, 27.320923, -1.266357, 55.0, 24.666357],
+        ],
+        dtype="float32",
+    )
 
     verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs)
 
-    np_out = np.array([
-        [ 0., -5.25, -2.5, 21.75, 19.],
-        [ 0., 11.25, -2., 37.25, 18.5],
-        [ 0., 26.849998, -2.3000002, 53.45, 18.6],
-        [ 0., -4.95, 13.799999, 22.25, 35.5]
-    ], dtype='float32')
-    attrs['iou_loss'] = True
+    np_out = np.array(
+        [
+            [0.0, -5.25, -2.5, 21.75, 19.0],
+            [0.0, 11.25, -2.0, 37.25, 18.5],
+            [0.0, 26.849998, -2.3000002, 53.45, 18.6],
+            [0.0, -4.95, 13.799999, 22.25, 35.5],
+        ],
+        dtype="float32",
+    )
+    attrs["iou_loss"] = True
     verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs)
 
 
@@ -599,8 +759,9 @@ def test_yolo_reorg_infer_shape():
 
     n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
     idxd = tvm.tir.indexdiv
-    verify_yolo_reorg((n, c, 20, 20), 10, (n, c*10*10, 2, 2))
-    verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, idxd(h, 2), idxd(w, 2)))
+    verify_yolo_reorg((n, c, 20, 20), 10, (n, c * 10 * 10, 2, 2))
+    verify_yolo_reorg((n, c, h, w), 2, (n, c * 2 * 2, idxd(h, 2), idxd(w, 2)))
+
 
 @tvm.testing.uses_gpu
 def test_yolo_reorg():
@@ -634,17 +795,26 @@ def test_deformable_conv2d():
         offset = relay.var("offset")
         kernel = relay.var("kernel")
         kernel_size = (3, 3)
-        y = relay.nn.deformable_conv2d(data, offset, kernel,
+        y = relay.nn.deformable_conv2d(
+            data,
+            offset,
+            kernel,
             strides=(1, 1),
             padding=(1, 1),
             dilation=(1, 1),
             kernel_size=kernel_size,
             deformable_groups=deformable_groups,
             groups=groups,
-            channels=out_channel)
+            channels=out_channel,
+        )
         weight_shape = (out_channel, in_channel // groups, kernel_size[0], kernel_size[1])
         out_shape = (batch, out_channel, size, size)
-        offset_shape = (batch, 2 * kernel_size[0] * kernel_size[1] * deformable_groups, out_shape[2], out_shape[3])
+        offset_shape = (
+            batch,
+            2 * kernel_size[0] * kernel_size[1] * deformable_groups,
+            out_shape[2],
+            out_shape[3],
+        )
         yy = run_infer_type(y)
         assert yy.checked_type == relay.TensorType(out_shape)
         assert yy.args[1].checked_type == relay.TensorType(offset_shape), yy.args[1].checked_type
@@ -653,35 +823,48 @@ def test_deformable_conv2d():
     test_infer_type(1, 4, 16, 4, 4, 1)
     test_infer_type(2, 4, 16, 4, 1, 2)
 
-
     def test_run(batch, in_channel, size, out_channel, deformable_groups, groups):
         kernel_size = (3, 3)
         data_shape = (batch, in_channel, size, size)
         offset_shape = (batch, 2 * kernel_size[0] * kernel_size[1] * deformable_groups, size, size)
         kernel_shape = (out_channel, in_channel // groups, kernel_size[0], kernel_size[1])
-        dtype = 'float32'
+        dtype = "float32"
         data = relay.var("data", shape=data_shape, dtype=dtype)
         offset = relay.var("offset")
         kernel = relay.var("kernel")
-        y = relay.nn.deformable_conv2d(data, offset, kernel,
+        y = relay.nn.deformable_conv2d(
+            data,
+            offset,
+            kernel,
             strides=(1, 1),
             padding=(1, 1),
             dilation=(1, 1),
             kernel_size=kernel_size,
             deformable_groups=deformable_groups,
             groups=groups,
-            channels=out_channel)
+            channels=out_channel,
+        )
         func = relay.Function([data, offset, kernel], y)
         data = np.random.uniform(size=data_shape).astype(dtype)
         offset = np.random.uniform(size=offset_shape).astype(dtype)
         kernel = np.random.uniform(size=kernel_shape).astype(dtype)
-        ref_res = tvm.topi.testing.deformable_conv2d_nchw_python(data, offset, kernel, stride=(1, 1), padding=(1, 1), dilation=(1, 1), deformable_groups=deformable_groups, groups=groups)
+        ref_res = tvm.topi.testing.deformable_conv2d_nchw_python(
+            data,
+            offset,
+            kernel,
+            stride=(1, 1),
+            padding=(1, 1),
+            dilation=(1, 1),
+            deformable_groups=deformable_groups,
+            groups=groups,
+        )
 
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
                 intrp1 = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res1 = intrp1.evaluate(func)(data, offset, kernel)
                 tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+
     test_run(1, 4, 16, 4, 1, 1)
     test_run(2, 4, 16, 4, 4, 1)
 
@@ -690,9 +873,19 @@ def test_deformable_conv2d():
 def test_depth_to_space():
     def verify_depth_to_space(dshape, block_size, layout, mode):
         if layout == "NHWC":
-            out_shape = [dshape[0], dshape[1] * block_size, dshape[2] * block_size, dshape[3] / (block_size * block_size)]
+            out_shape = [
+                dshape[0],
+                dshape[1] * block_size,
+                dshape[2] * block_size,
+                dshape[3] / (block_size * block_size),
+            ]
         else:
-            out_shape = [dshape[0], dshape[1] / (block_size * block_size), dshape[2] * block_size, dshape[3] * block_size]
+            out_shape = [
+                dshape[0],
+                dshape[1] / (block_size * block_size),
+                dshape[2] * block_size,
+                dshape[3] * block_size,
+            ]
 
         x_data = np.random.uniform(size=dshape).astype("float32")
         if layout == "NHWC":
@@ -714,6 +907,7 @@ def test_depth_to_space():
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(x_data)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4)
+
     for layout in ["NHWC", "NCHW"]:
         for mode in ["DCR", "CDR"]:
             verify_depth_to_space((1, 4, 4, 4), 2, layout, mode)
@@ -723,9 +917,19 @@ def test_depth_to_space():
 def test_space_to_depth():
     def verify_space_to_depth(dshape, block_size, layout):
         if layout == "NHWC":
-            out_shape = [dshape[0], dshape[1] / block_size, dshape[2] / block_size, dshape[3] * (block_size * block_size)]
+            out_shape = [
+                dshape[0],
+                dshape[1] / block_size,
+                dshape[2] / block_size,
+                dshape[3] * (block_size * block_size),
+            ]
         else:
-            out_shape = [dshape[0], dshape[1] * (block_size * block_size), dshape[2] / block_size, dshape[3] / block_size]
+            out_shape = [
+                dshape[0],
+                dshape[1] * (block_size * block_size),
+                dshape[2] / block_size,
+                dshape[3] / block_size,
+            ]
 
         x_data = np.random.uniform(size=dshape).astype("float32")
         if layout == "NHWC":
@@ -747,6 +951,7 @@ def test_space_to_depth():
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(x_data)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4)
+
     for layout in ["NHWC", "NCHW"]:
         verify_space_to_depth((1, 4, 4, 4), 2, layout)
 
@@ -757,25 +962,31 @@ def test_dilation2d_infer_type():
     x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
     kc, kh, kw = 10, 8, 8
     w = relay.var("w", relay.ty.TensorType((kc, kw, kh), "float32"))
-    y = relay.image.dilation2d(x, w,
-                               # kernel_size=(3, 3),
-                               strides=[1, 1, 1, 1],
-                               dilations=[1, 1, 1, 1],
-                               padding=[0, 0, 0, 0])
+    y = relay.image.dilation2d(
+        x,
+        w,
+        # kernel_size=(3, 3),
+        strides=[1, 1, 1, 1],
+        dilations=[1, 1, 1, 1],
+        padding=[0, 0, 0, 0],
+    )
     yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType(
-        (n, 10, 217, 217), "float32")
+    assert yy.checked_type == relay.TensorType((n, 10, 217, 217), "float32")
 
 
 @tvm.testing.uses_gpu
 def test_dilation2d_run():
-    def run_test_dilation2d(indata, kernel, out,
-                            dtype='float32',
-                            strides=[1, 1],
-                            padding=[0, 0],
-                            dilations=[1, 1],
-                            except_targets=['cuda'],
-                            **attrs):
+    def run_test_dilation2d(
+        indata,
+        kernel,
+        out,
+        dtype="float32",
+        strides=[1, 1],
+        padding=[0, 0],
+        dilations=[1, 1],
+        except_targets=["cuda"],
+        **attrs,
+    ):
 
         dshape = indata.shape
         kshape = kernel.shape
@@ -785,11 +996,9 @@ def test_dilation2d_run():
 
         x = relay.var("x", shape=dshape, dtype=dtype)
         w = relay.var("w", shape=kshape, dtype=dtype)
-        y = relay.image.dilation2d(x, w,
-                                   strides=strides,
-                                   dilations=dilations,
-                                   padding=padding,
-                                   **attrs)
+        y = relay.image.dilation2d(
+            x, w, strides=strides, dilations=dilations, padding=padding, **attrs
+        )
         func = relay.Function([x, w], y)
 
         for target, ctx in tvm.testing.enabled_targets():
@@ -803,71 +1012,93 @@ def test_dilation2d_run():
         indata = np.asarray(indata)
         kernel = np.asarray(kernel)
         out = np.asarray(out)
-        if layout == 'NCHW':
+        if layout == "NCHW":
             indata = indata.transpose([0, 3, 1, 2])
             kernel = kernel.transpose([2, 0, 1])
             out = out.transpose([0, 3, 1, 2])
         return indata, kernel, out
 
-    image = [[[[.1], [.2]], [[.3], [.4]]]]
-    kernel = [[[.4], [.3]], [[.1], [.0]]]
-    out = [[[[.5]]]]
-    run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'))
-    run_test_dilation2d(*_convert_data(image, kernel, out), data_layout='NHWC', kernel_layout='HWI')
-
-    image = [[[[.1], [.2]], [[.3], [.4]]]]
-    kernel = [[[.4], [.3]], [[.1], [.0]]]
-    out = [[[[.5], [.6]], [[.7], [.8]]]]
-    run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'), padding=[0, 0, 1, 1])
-    run_test_dilation2d(*_convert_data(image, kernel, out), padding=[0, 0, 1, 1],
-                        data_layout='NHWC', kernel_layout='HWI')
-
-    image = [[[[.1, .2, .0], [.2, .3, .1]], [[.3, .4, .2], [.4, .5, .3]]]]
-    kernel = [[[.4, .5, .3], [.3, .4, .2]], [[.1, .2, .0], [.0, .1, -.1]]]
-    out = [[[[.5, .7, .3], [.6, .8, .4]], [[.7, .9, .5], [.8, 1., .6]]]]
-    run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'), padding=[0, 0, 1, 1])
-    run_test_dilation2d(*_convert_data(image, kernel, out), padding=[0, 0, 1, 1],
-                        data_layout='NHWC', kernel_layout='HWI')
-
-    image = [[[[.1], [.2]], [[.3], [.4]]], [[[.2], [.3]], [[.4], [.5]]]]
-    kernel = [[[.4], [.3]], [[.1], [.0]]]
-    out = [[[[.5], [.6]], [[.7], [.8]]], [[[.6], [.7]], [[.8], [.9]]]]
-    run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'), padding=[0, 0, 1, 1])
-    run_test_dilation2d(*_convert_data(image, kernel, out), padding=[0, 0, 1, 1],
-                        data_layout='NHWC', kernel_layout='HWI')
-
-    image = [[[[.1], [.2]], [[.3], [.4]]]]
-    kernel = [[[.4], [.3]]]
-    out = [[[[.5]], [[.7]]]]
-    run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'))
-    run_test_dilation2d(*_convert_data(image, kernel, out),
-                        data_layout='NHWC', kernel_layout='HWI')
-
-    image = [[[[.1], [.2], [.3]], [[.4], [.5], [.6]], [[.7], [.8], [.9]]]]
-    kernel = [[[.4], [.3]], [[.1], [.2]]]
-    out = [[[[.7], [.8], [.6]], [[1.0], [1.1], [.9]], [[.8], [.9], [.9]]]]
-    run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'), padding=[1, 1], dilations=[2, 2])
-    run_test_dilation2d(*_convert_data(image, kernel, out), padding=[1, 1], dilations=[2, 2],
-                        data_layout='NHWC', kernel_layout='HWI')
-
-    image = [[[[.1], [.2], [.3], [.4]], [[.5], [.6], [.7], [.8]],
-              [[.9], [1.0], [1.1], [1.2]]]]
-    kernel = [[[.4], [.3]], [[.1], [.2]]]
-    out = [[[[.8], [1.0]], [[1.2], [1.4]]]]
-    run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'), strides=[1, 2])
-    run_test_dilation2d(*_convert_data(image, kernel, out), strides=[1, 2],
-                        data_layout='NHWC', kernel_layout='HWI')
+    image = [[[[0.1], [0.2]], [[0.3], [0.4]]]]
+    kernel = [[[0.4], [0.3]], [[0.1], [0.0]]]
+    out = [[[[0.5]]]]
+    run_test_dilation2d(*_convert_data(image, kernel, out, layout="NCHW"))
+    run_test_dilation2d(*_convert_data(image, kernel, out), data_layout="NHWC", kernel_layout="HWI")
+
+    image = [[[[0.1], [0.2]], [[0.3], [0.4]]]]
+    kernel = [[[0.4], [0.3]], [[0.1], [0.0]]]
+    out = [[[[0.5], [0.6]], [[0.7], [0.8]]]]
+    run_test_dilation2d(*_convert_data(image, kernel, out, layout="NCHW"), padding=[0, 0, 1, 1])
+    run_test_dilation2d(
+        *_convert_data(image, kernel, out),
+        padding=[0, 0, 1, 1],
+        data_layout="NHWC",
+        kernel_layout="HWI",
+    )
+
+    image = [[[[0.1, 0.2, 0.0], [0.2, 0.3, 0.1]], [[0.3, 0.4, 0.2], [0.4, 0.5, 0.3]]]]
+    kernel = [[[0.4, 0.5, 0.3], [0.3, 0.4, 0.2]], [[0.1, 0.2, 0.0], [0.0, 0.1, -0.1]]]
+    out = [[[[0.5, 0.7, 0.3], [0.6, 0.8, 0.4]], [[0.7, 0.9, 0.5], [0.8, 1.0, 0.6]]]]
+    run_test_dilation2d(*_convert_data(image, kernel, out, layout="NCHW"), padding=[0, 0, 1, 1])
+    run_test_dilation2d(
+        *_convert_data(image, kernel, out),
+        padding=[0, 0, 1, 1],
+        data_layout="NHWC",
+        kernel_layout="HWI",
+    )
+
+    image = [[[[0.1], [0.2]], [[0.3], [0.4]]], [[[0.2], [0.3]], [[0.4], [0.5]]]]
+    kernel = [[[0.4], [0.3]], [[0.1], [0.0]]]
+    out = [[[[0.5], [0.6]], [[0.7], [0.8]]], [[[0.6], [0.7]], [[0.8], [0.9]]]]
+    run_test_dilation2d(*_convert_data(image, kernel, out, layout="NCHW"), padding=[0, 0, 1, 1])
+    run_test_dilation2d(
+        *_convert_data(image, kernel, out),
+        padding=[0, 0, 1, 1],
+        data_layout="NHWC",
+        kernel_layout="HWI",
+    )
+
+    image = [[[[0.1], [0.2]], [[0.3], [0.4]]]]
+    kernel = [[[0.4], [0.3]]]
+    out = [[[[0.5]], [[0.7]]]]
+    run_test_dilation2d(*_convert_data(image, kernel, out, layout="NCHW"))
+    run_test_dilation2d(*_convert_data(image, kernel, out), data_layout="NHWC", kernel_layout="HWI")
+
+    image = [[[[0.1], [0.2], [0.3]], [[0.4], [0.5], [0.6]], [[0.7], [0.8], [0.9]]]]
+    kernel = [[[0.4], [0.3]], [[0.1], [0.2]]]
+    out = [[[[0.7], [0.8], [0.6]], [[1.0], [1.1], [0.9]], [[0.8], [0.9], [0.9]]]]
+    run_test_dilation2d(
+        *_convert_data(image, kernel, out, layout="NCHW"), padding=[1, 1], dilations=[2, 2]
+    )
+    run_test_dilation2d(
+        *_convert_data(image, kernel, out),
+        padding=[1, 1],
+        dilations=[2, 2],
+        data_layout="NHWC",
+        kernel_layout="HWI",
+    )
+
+    image = [
+        [[[0.1], [0.2], [0.3], [0.4]], [[0.5], [0.6], [0.7], [0.8]], [[0.9], [1.0], [1.1], [1.2]]]
+    ]
+    kernel = [[[0.4], [0.3]], [[0.1], [0.2]]]
+    out = [[[[0.8], [1.0]], [[1.2], [1.4]]]]
+    run_test_dilation2d(*_convert_data(image, kernel, out, layout="NCHW"), strides=[1, 2])
+    run_test_dilation2d(
+        *_convert_data(image, kernel, out), strides=[1, 2], data_layout="NHWC", kernel_layout="HWI"
+    )
 
 
 @tvm.testing.uses_gpu
 def test_affine_grid():
     def verify_affine_grid(num_batch, target_shape):
-        dtype = 'float32'
+        dtype = "float32"
         data_shape = (num_batch, 2, 3)
         data = relay.var("data", relay.ty.TensorType(data_shape, dtype))
         y = relay.image.affine_grid(data, target_shape)
         yy = run_infer_type(y)
-        assert yy.checked_type == relay.ty.TensorType((num_batch, len(target_shape), *target_shape), dtype)
+        assert yy.checked_type == relay.ty.TensorType(
+            (num_batch, len(target_shape), *target_shape), dtype
+        )
 
         func = relay.Function([data], y)
         data_np = np.random.uniform(size=data_shape).astype(dtype)
@@ -886,26 +1117,25 @@ def test_affine_grid():
 @tvm.testing.uses_gpu
 def test_grid_sample():
     def verify_grid_sample(data_shape, grid_shape):
-        dtype = 'float32'
+        dtype = "float32"
         batch, channel, _, _ = data_shape
         _, _, out_height, out_width = grid_shape
         data = relay.var("data", relay.ty.TensorType(data_shape, dtype))
         grid = relay.var("grid", relay.ty.TensorType(grid_shape, dtype))
-        y = relay.image.grid_sample(data, grid, method='bilinear', layout='NCHW')
+        y = relay.image.grid_sample(data, grid, method="bilinear", layout="NCHW")
         yy = run_infer_type(y)
         assert yy.checked_type == relay.TensorType((batch, channel, out_height, out_width), dtype)
         func = relay.Function([data, grid], y)
 
         data_np = np.random.uniform(size=data_shape).astype(dtype)
         grid_np = np.random.uniform(size=grid_shape, low=-1.5, high=1.5).astype(dtype)
-        ref_res = tvm.topi.testing.grid_sample_nchw_python(data_np, grid_np, method='bilinear')
+        ref_res = tvm.topi.testing.grid_sample_nchw_python(data_np, grid_np, method="bilinear")
 
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
                 intrp1 = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res1 = intrp1.evaluate(func)(data_np, grid_np)
-                tvm.testing.assert_allclose(
-                    op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+                tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
 
     verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8))
     verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32))
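
The same conventions are applied file by file, so keeping a file in this state reduces to an idempotence check against black itself. A sketch of that check through subprocess; the path is only an example inferred from the tests above, while --check, --diff, and --line-length are standard black command-line options:

import subprocess
import sys

result = subprocess.run(
    [
        sys.executable,
        "-m",
        "black",
        "--check",
        "--diff",
        "--line-length",
        "100",
        "tests/python/relay/test_op_level5.py",  # example path, not taken from this diff
    ],
    capture_output=True,
    text=True,
)
# Exit status 0 means the file already matches black's output; 1 means it would be reformatted.
print(result.stdout or result.stderr)
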
index e683224..de51c15 100644
@@ -22,6 +22,7 @@ from tvm import te
 from tvm import relay
 import tvm.testing
 
+
 @tvm.testing.uses_gpu
 def test_argsort():
     def verify_argsort(shape, axis, is_ascend, dtype):
@@ -39,6 +40,7 @@ def test_argsort():
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(x_data)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.astype(dtype), rtol=1e-5)
+
     for dtype in ["int32", "int64", "float32", "float64"]:
         verify_argsort((2, 3, 4), axis=0, is_ascend=False, dtype=dtype)
         verify_argsort((1, 4, 6), axis=1, is_ascend=True, dtype=dtype)
@@ -83,6 +85,7 @@ def test_topk():
                     tvm.testing.assert_allclose(op_res.asnumpy(), np_values)
                 else:
                     tvm.testing.assert_allclose(op_res.asnumpy(), np_indices)
+
     np.random.seed(0)
     for k in [0, 1, 5]:
         for axis in [0, -1, 1]:
index aa9de68..41ebfea 100644
@@ -21,32 +21,41 @@ from tvm import relay
 
 
 def test_tflite_same_io_qnn_params():
-    data_dtype = 'uint8'
+    data_dtype = "uint8"
 
     x = relay.var("x", shape=(1, 4), dtype=data_dtype)
     y = relay.var("y", shape=(1, 4), dtype=data_dtype)
-    z = relay.qnn.op.add(lhs=x, rhs=y,
-                         lhs_scale=relay.const(0.00784314, 'float32'),
-                         lhs_zero_point=relay.const(127, 'int32'),
-                         rhs_scale=relay.const(0.00784314, 'float32'),
-                         rhs_zero_point=relay.const(127, 'int32'),
-                         output_scale=relay.const(0.00784314, 'float32'),
-                         output_zero_point=relay.const(127, 'int32'))
+    z = relay.qnn.op.add(
+        lhs=x,
+        rhs=y,
+        lhs_scale=relay.const(0.00784314, "float32"),
+        lhs_zero_point=relay.const(127, "int32"),
+        rhs_scale=relay.const(0.00784314, "float32"),
+        rhs_zero_point=relay.const(127, "int32"),
+        output_scale=relay.const(0.00784314, "float32"),
+        output_zero_point=relay.const(127, "int32"),
+    )
 
     func = relay.Function([x, y], z)
     mod = tvm.IRModule.from_expr(func)
     mod = relay.qnn.transform.CanonicalizeOps()(mod)
     func = mod["main"]
 
-    x_datas = [np.array((140, 153, 165, 178)).reshape((1, 4)),
-               np.array((25, 153, 178, 216)).reshape((1, 4)),
-               np.array((25, 153, 216, 165)).reshape((1, 4))]
-    y_datas = [np.array((204, 178, 165, 140)).reshape((1, 4)),
-               np.array((204, 178, 191, 25)).reshape((1, 4)),
-               np.array((204, 178, 25, 191)).reshape((1, 4))]
-    golden_outputs = [np.array((217, 204, 203, 191)).reshape((1, 4)),
-                      np.array((102, 204, 242, 114)).reshape((1, 4)),
-                      np.array((102, 204, 114, 229)).reshape((1, 4))]
+    x_datas = [
+        np.array((140, 153, 165, 178)).reshape((1, 4)),
+        np.array((25, 153, 178, 216)).reshape((1, 4)),
+        np.array((25, 153, 216, 165)).reshape((1, 4)),
+    ]
+    y_datas = [
+        np.array((204, 178, 165, 140)).reshape((1, 4)),
+        np.array((204, 178, 191, 25)).reshape((1, 4)),
+        np.array((204, 178, 25, 191)).reshape((1, 4)),
+    ]
+    golden_outputs = [
+        np.array((217, 204, 203, 191)).reshape((1, 4)),
+        np.array((102, 204, 242, 114)).reshape((1, 4)),
+        np.array((102, 204, 114, 229)).reshape((1, 4)),
+    ]
 
     for i in range(0, 3):
         x_data = x_datas[i]
@@ -59,32 +68,41 @@ def test_tflite_same_io_qnn_params():
 
 
 def test_tflite_different_io_qnn_params():
-    data_dtype = 'uint8'
+    data_dtype = "uint8"
 
     x = relay.var("x", shape=(1, 4), dtype=data_dtype)
     y = relay.var("y", shape=(1, 4), dtype=data_dtype)
-    z = relay.qnn.op.add(lhs=x, rhs=y,
-                         lhs_scale=relay.const(0.0156863, 'float32'),
-                         lhs_zero_point=relay.const(127, 'int32'),
-                         rhs_scale=relay.const(0.0117647, 'float32'),
-                         rhs_zero_point=relay.const(85, 'int32'),
-                         output_scale=relay.const(0.0235294, 'float32'),
-                         output_zero_point=relay.const(128, 'int32'))
+    z = relay.qnn.op.add(
+        lhs=x,
+        rhs=y,
+        lhs_scale=relay.const(0.0156863, "float32"),
+        lhs_zero_point=relay.const(127, "int32"),
+        rhs_scale=relay.const(0.0117647, "float32"),
+        rhs_zero_point=relay.const(85, "int32"),
+        output_scale=relay.const(0.0235294, "float32"),
+        output_zero_point=relay.const(128, "int32"),
+    )
 
     func = relay.Function([x, y], z)
     mod = tvm.IRModule.from_expr(func)
     mod = relay.qnn.transform.CanonicalizeOps()(mod)
     func = mod["main"]
 
-    x_datas = [np.array((76, 140, 153, 172)).reshape((1, 4)),
-               np.array((133, 140, 146, 153)).reshape((1, 4)),
-               np.array((76, 140, 172, 146)).reshape((1, 4))]
-    y_datas = [np.array((136, 119, 128, 17)).reshape((1, 4)),
-               np.array((136, 119, 111, 94)).reshape((1, 4)),
-               np.array((136, 119, 17, 128)).reshape((1, 4))]
-    golden_outputs = [np.array((120, 154, 167, 124)).reshape((1, 4)),
-                      np.array((158, 154, 154, 150)).reshape((1, 4)),
-                      np.array((120, 154, 124, 163)).reshape((1, 4))]
+    x_datas = [
+        np.array((76, 140, 153, 172)).reshape((1, 4)),
+        np.array((133, 140, 146, 153)).reshape((1, 4)),
+        np.array((76, 140, 172, 146)).reshape((1, 4)),
+    ]
+    y_datas = [
+        np.array((136, 119, 128, 17)).reshape((1, 4)),
+        np.array((136, 119, 111, 94)).reshape((1, 4)),
+        np.array((136, 119, 17, 128)).reshape((1, 4)),
+    ]
+    golden_outputs = [
+        np.array((120, 154, 167, 124)).reshape((1, 4)),
+        np.array((158, 154, 154, 150)).reshape((1, 4)),
+        np.array((120, 154, 124, 163)).reshape((1, 4)),
+    ]
 
     for i in range(0, 3):
         x_data = x_datas[i]
@@ -98,16 +116,19 @@ def test_tflite_different_io_qnn_params():
 
 def test_saturation():
     # Same params
-    data_dtype = 'uint8'
+    data_dtype = "uint8"
     x = relay.var("x", shape=(1, 4), dtype=data_dtype)
     y = relay.var("y", shape=(1, 4), dtype=data_dtype)
-    z = relay.qnn.op.add(lhs=x, rhs=y,
-                         lhs_scale=relay.const(0.125, 'float32'),
-                         lhs_zero_point=relay.const(0, 'int32'),
-                         rhs_scale=relay.const(0.125, 'float32'),
-                         rhs_zero_point=relay.const(0, 'int32'),
-                         output_scale=relay.const(0.125, 'float32'),
-                         output_zero_point=relay.const(0, 'int32'))
+    z = relay.qnn.op.add(
+        lhs=x,
+        rhs=y,
+        lhs_scale=relay.const(0.125, "float32"),
+        lhs_zero_point=relay.const(0, "int32"),
+        rhs_scale=relay.const(0.125, "float32"),
+        rhs_zero_point=relay.const(0, "int32"),
+        output_scale=relay.const(0.125, "float32"),
+        output_zero_point=relay.const(0, "int32"),
+    )
 
     func = relay.Function([x, y], z)
     mod = tvm.IRModule.from_expr(func)
@@ -123,13 +144,16 @@ def test_saturation():
     np.testing.assert_equal(op_res.asnumpy(), golden_output)
 
     # Same params, different scale
-    z = relay.qnn.op.add(lhs=x, rhs=y,
-                         lhs_scale=relay.const(0.125, 'float32'),
-                         lhs_zero_point=relay.const(0, 'int32'),
-                         rhs_scale=relay.const(0.125, 'float32'),
-                         rhs_zero_point=relay.const(0, 'int32'),
-                         output_scale=relay.const(0.25, 'float32'),
-                         output_zero_point=relay.const(0, 'int32'))
+    z = relay.qnn.op.add(
+        lhs=x,
+        rhs=y,
+        lhs_scale=relay.const(0.125, "float32"),
+        lhs_zero_point=relay.const(0, "int32"),
+        rhs_scale=relay.const(0.125, "float32"),
+        rhs_zero_point=relay.const(0, "int32"),
+        output_scale=relay.const(0.25, "float32"),
+        output_zero_point=relay.const(0, "int32"),
+    )
 
     func = relay.Function([x, y], z)
     mod = tvm.IRModule.from_expr(func)
@@ -145,13 +169,16 @@ def test_saturation():
     np.testing.assert_equal(op_res.asnumpy(), golden_output)
 
     # Same io params, different output scale
-    z = relay.qnn.op.add(lhs=x, rhs=y,
-                         lhs_scale=relay.const(0.125, 'float32'),
-                         lhs_zero_point=relay.const(0, 'int32'),
-                         rhs_scale=relay.const(0.125, 'float32'),
-                         rhs_zero_point=relay.const(0, 'int32'),
-                         output_scale=relay.const(0.25, 'float32'),
-                         output_zero_point=relay.const(0, 'int32'))
+    z = relay.qnn.op.add(
+        lhs=x,
+        rhs=y,
+        lhs_scale=relay.const(0.125, "float32"),
+        lhs_zero_point=relay.const(0, "int32"),
+        rhs_scale=relay.const(0.125, "float32"),
+        rhs_zero_point=relay.const(0, "int32"),
+        output_scale=relay.const(0.25, "float32"),
+        output_zero_point=relay.const(0, "int32"),
+    )
 
     func = relay.Function([x, y], z)
     mod = tvm.IRModule.from_expr(func)
@@ -167,13 +194,16 @@ def test_saturation():
     np.testing.assert_equal(op_res.asnumpy(), golden_output)
 
     # All params different
-    z = relay.qnn.op.add(lhs=x, rhs=y,
-                         lhs_scale=relay.const(0.5, 'float32'),
-                         lhs_zero_point=relay.const(0, 'int32'),
-                         rhs_scale=relay.const(0.25, 'float32'),
-                         rhs_zero_point=relay.const(0, 'int32'),
-                         output_scale=relay.const(0.125, 'float32'),
-                         output_zero_point=relay.const(0, 'int32'))
+    z = relay.qnn.op.add(
+        lhs=x,
+        rhs=y,
+        lhs_scale=relay.const(0.5, "float32"),
+        lhs_zero_point=relay.const(0, "int32"),
+        rhs_scale=relay.const(0.25, "float32"),
+        rhs_zero_point=relay.const(0, "int32"),
+        output_scale=relay.const(0.125, "float32"),
+        output_zero_point=relay.const(0, "int32"),
+    )
 
     func = relay.Function([x, y], z)
     mod = tvm.IRModule.from_expr(func)
@@ -189,7 +219,7 @@ def test_saturation():
     np.testing.assert_equal(op_res.asnumpy(), golden_output)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_tflite_same_io_qnn_params()
     test_tflite_different_io_qnn_params()
     test_saturation()
index 19025c7..230e8a8 100644
@@ -22,23 +22,26 @@ from tvm import relay
 from tvm.contrib import graph_runtime
 import tvm.topi.testing
 
+
 def test_same_io_qnn_params():
-    data_dtype = 'int32'
+    data_dtype = "int32"
     axis = 0
     x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype)
     y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype)
-    x_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
-    y_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
-    zero = relay.const(0, 'int32')
+    x_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), "float32")
+    y_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), "float32")
+    zero = relay.const(0, "int32")
 
     x = relay.var("x", shape=(1, 64), dtype=data_dtype)
     y = relay.var("y", shape=(1, 64), dtype=data_dtype)
-    z = relay.qnn.op.concatenate((x, y),
-                                 input_scales=(x_scale, y_scale),
-                                 input_zero_points=(zero, zero),
-                                 output_scale=y_scale,
-                                 output_zero_point=zero,
-                                 axis=axis)
+    z = relay.qnn.op.concatenate(
+        (x, y),
+        input_scales=(x_scale, y_scale),
+        input_zero_points=(zero, zero),
+        output_scale=y_scale,
+        output_zero_point=zero,
+        axis=axis,
+    )
 
     func = relay.Function([x, y], z)
     mod = tvm.IRModule.from_expr(func)
@@ -51,25 +54,28 @@ def test_same_io_qnn_params():
     op_res = intrp.evaluate(func)(x_data, y_data)
     np.testing.assert_equal(op_res.asnumpy(), golden_output)
 
+
 def test_different_io_qnn_params():
-    data_dtype = 'int32'
+    data_dtype = "int32"
     axis = 0
     x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype)
     y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype)
 
-    x_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
-    y_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
-    x_zero_point = relay.const(3, 'int32')
-    y_zero_point = relay.const(4, 'int32')
+    x_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), "float32")
+    y_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), "float32")
+    x_zero_point = relay.const(3, "int32")
+    y_zero_point = relay.const(4, "int32")
 
     x = relay.var("x", shape=(1, 64), dtype=data_dtype)
     y = relay.var("y", shape=(1, 64), dtype=data_dtype)
-    z = relay.qnn.op.concatenate((x, y),
-                                 input_scales=(x_scale, y_scale),
-                                 input_zero_points=(x_zero_point, y_zero_point),
-                                 output_scale=y_scale,
-                                 output_zero_point=relay.const(1, 'int32'),
-                                 axis=axis)
+    z = relay.qnn.op.concatenate(
+        (x, y),
+        input_scales=(x_scale, y_scale),
+        input_zero_points=(x_zero_point, y_zero_point),
+        output_scale=y_scale,
+        output_zero_point=relay.const(1, "int32"),
+        axis=axis,
+    )
 
     func = relay.Function([x, y], z)
     mod = tvm.IRModule.from_expr(func)
@@ -82,25 +88,28 @@ def test_different_io_qnn_params():
     op_res = intrp.evaluate(func)(x_data, y_data)
     np.testing.assert_equal(op_res.asnumpy(), golden_output)
 
+
 def test_few_same_io_qnn_params():
-    data_dtype = 'int32'
+    data_dtype = "int32"
     axis = 0
     x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype)
     y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype)
 
-    x_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
-    y_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
-    x_zero_point = relay.const(0, 'int32')
-    y_zero_point = relay.const(1, 'int32')
+    x_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), "float32")
+    y_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), "float32")
+    x_zero_point = relay.const(0, "int32")
+    y_zero_point = relay.const(1, "int32")
 
     x = relay.var("x", shape=(1, 64), dtype=data_dtype)
     y = relay.var("y", shape=(1, 64), dtype=data_dtype)
-    z = relay.qnn.op.concatenate((x, y),
-                                 input_scales=(x_scale, y_scale),
-                                 input_zero_points=(x_zero_point, y_zero_point),
-                                 output_scale=y_scale,
-                                 output_zero_point=relay.const(1, 'int32'),
-                                 axis=axis)
+    z = relay.qnn.op.concatenate(
+        (x, y),
+        input_scales=(x_scale, y_scale),
+        input_zero_points=(x_zero_point, y_zero_point),
+        output_scale=y_scale,
+        output_zero_point=relay.const(1, "int32"),
+        axis=axis,
+    )
 
     func = relay.Function([x, y], z)
     mod = tvm.IRModule.from_expr(func)
@@ -113,25 +122,28 @@ def test_few_same_io_qnn_params():
     op_res = intrp.evaluate(func)(x_data, y_data)
     np.testing.assert_equal(op_res.asnumpy(), golden_output)
 
+
 def test_same_i_qnn_params():
-    data_dtype = 'int32'
+    data_dtype = "int32"
     axis = 0
     x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype)
     y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype)
 
-    x_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
-    y_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
-    x_zero_point = relay.const(0, 'int32')
-    y_zero_point = relay.const(0, 'int32')
+    x_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), "float32")
+    y_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), "float32")
+    x_zero_point = relay.const(0, "int32")
+    y_zero_point = relay.const(0, "int32")
 
     x = relay.var("x", shape=(1, 64), dtype=data_dtype)
     y = relay.var("y", shape=(1, 64), dtype=data_dtype)
-    z = relay.qnn.op.concatenate((x, y),
-                                 input_scales=(x_scale, y_scale),
-                                 input_zero_points=(x_zero_point, y_zero_point),
-                                 output_scale=y_scale,
-                                 output_zero_point=relay.const(1, 'int32'),
-                                 axis=axis)
+    z = relay.qnn.op.concatenate(
+        (x, y),
+        input_scales=(x_scale, y_scale),
+        input_zero_points=(x_zero_point, y_zero_point),
+        output_scale=y_scale,
+        output_zero_point=relay.const(1, "int32"),
+        axis=axis,
+    )
 
     func = relay.Function([x, y], z)
     mod = tvm.IRModule.from_expr(func)
@@ -144,31 +156,35 @@ def test_same_i_qnn_params():
     op_res = intrp.evaluate(func)(x_data, y_data)
     np.testing.assert_equal(op_res.asnumpy(), golden_output)
 
+
 def test_call_input():
     # This tests the case where the input to concatenate is not explicitly a
     # tuple node but is instead a call node.
-    x_data = np.ones(shape=(64,)).astype('uint8')
+    x_data = np.ones(shape=(64,)).astype("uint8")
 
-    x = relay.var("x", shape=(64,), dtype='uint8')
-    x_scale = relay.const(1, 'float32')
-    y_scale = relay.const(1, 'float32')
-    x_zero_point = relay.const(0, 'int32')
-    y_zero_point = relay.const(0, 'int32')
+    x = relay.var("x", shape=(64,), dtype="uint8")
+    x_scale = relay.const(1, "float32")
+    y_scale = relay.const(1, "float32")
+    x_zero_point = relay.const(0, "int32")
+    y_zero_point = relay.const(0, "int32")
 
     tup = relay.split(x, 2, axis=0)
-    z = relay.qnn.op.concatenate(tup,
-                                 input_scales=(x_scale, y_scale),
-                                 input_zero_points=(x_zero_point, y_zero_point),
-                                 output_scale=y_scale,
-                                 output_zero_point=relay.const(0, 'int32'),
-                                 axis=0)
+    z = relay.qnn.op.concatenate(
+        tup,
+        input_scales=(x_scale, y_scale),
+        input_zero_points=(x_zero_point, y_zero_point),
+        output_scale=y_scale,
+        output_zero_point=relay.const(0, "int32"),
+        axis=0,
+    )
     func = relay.Function([x], z)
 
     intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
     op_res = intrp.evaluate(func)(x_data)
     np.testing.assert_equal(op_res.asnumpy(), x_data)
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     test_call_input()
     test_same_io_qnn_params()
     test_different_io_qnn_params()
index bb848e9..e14bba2 100644
@@ -31,136 +31,149 @@ from tvm.relay.testing.temp_op_attr import TempOpAttr
 def legalize_qnn_conv2d(attrs, inputs, types):
     return None
 
-def get_ref_func(data,
-                 kernel,
-                 input_zero_point,
-                 kernel_zero_point,
-                 input_scale,
-                 kernel_scale,
-                 kernel_size,
-                 padding,
-                 strides,
-                 dilation,
-                 data_layout,
-                 kernel_layout,
-                 out_dtype,
-                 groups,
-                 channels=None):
+
+def get_ref_func(
+    data,
+    kernel,
+    input_zero_point,
+    kernel_zero_point,
+    input_scale,
+    kernel_scale,
+    kernel_size,
+    padding,
+    strides,
+    dilation,
+    data_layout,
+    kernel_layout,
+    out_dtype,
+    groups,
+    channels=None,
+):
     casted_data = relay.op.cast(data, "int32")
     casted_kernel = relay.op.cast(kernel, "int32")
-    shifted_data = relay.op.subtract(casted_data,
-            relay.const(input_zero_point, "int32"))
-    shifted_kernel = relay.op.subtract(casted_kernel,
-            relay.const(kernel_zero_point, "int32"))
-    func = relay.op.nn.conv2d(shifted_data,
-                              shifted_kernel,
-                              padding=padding,
-                              strides=strides,
-                              dilation=dilation,
-                              groups=groups,
-                              channels=channels,
-                              kernel_size=kernel_size,
-                              out_dtype=out_dtype,
-                              data_layout=data_layout,
-                              kernel_layout=kernel_layout)
+    shifted_data = relay.op.subtract(casted_data, relay.const(input_zero_point, "int32"))
+    shifted_kernel = relay.op.subtract(casted_kernel, relay.const(kernel_zero_point, "int32"))
+    func = relay.op.nn.conv2d(
+        shifted_data,
+        shifted_kernel,
+        padding=padding,
+        strides=strides,
+        dilation=dilation,
+        groups=groups,
+        channels=channels,
+        kernel_size=kernel_size,
+        out_dtype=out_dtype,
+        data_layout=data_layout,
+        kernel_layout=kernel_layout,
+    )
 
     func = relay.Function(relay.analysis.free_vars(func), func)
     return func
 
-def get_qnn_func(data,
-                 kernel,
-                 input_zero_point,
-                 kernel_zero_point,
-                 input_scale,
-                 kernel_scale,
-                 kernel_size,
-                 padding,
-                 strides,
-                 dilation,
-                 data_layout,
-                 kernel_layout,
-                 out_dtype,
-                 channels,
-                 groups):
+
+def get_qnn_func(
+    data,
+    kernel,
+    input_zero_point,
+    kernel_zero_point,
+    input_scale,
+    kernel_scale,
+    kernel_size,
+    padding,
+    strides,
+    dilation,
+    data_layout,
+    kernel_layout,
+    out_dtype,
+    channels,
+    groups,
+):
     func = relay.qnn.op.conv2d(
-            data, kernel,
-            input_zero_point=relay.const(input_zero_point, 'int32'),
-            kernel_zero_point=relay.const(kernel_zero_point, 'int32'),
-            input_scale=relay.const(input_scale, 'float32'),
-            kernel_scale=relay.const(kernel_scale, 'float32'),
-            kernel_size=kernel_size,
-            strides=strides,
-            dilation=dilation,
-            padding=padding,
-            out_dtype=out_dtype,
-            groups=groups,
-            channels=channels,
-            data_layout=data_layout,
-            kernel_layout=kernel_layout)
+        data,
+        kernel,
+        input_zero_point=relay.const(input_zero_point, "int32"),
+        kernel_zero_point=relay.const(kernel_zero_point, "int32"),
+        input_scale=relay.const(input_scale, "float32"),
+        kernel_scale=relay.const(kernel_scale, "float32"),
+        kernel_size=kernel_size,
+        strides=strides,
+        dilation=dilation,
+        padding=padding,
+        out_dtype=out_dtype,
+        groups=groups,
+        channels=channels,
+        data_layout=data_layout,
+        kernel_layout=kernel_layout,
+    )
 
     mod = relay.Function(relay.analysis.free_vars(func), func)
     mod = tvm.IRModule.from_expr(mod)
     return mod
 
-def get_funcs(data_shape,
-              data_dtype,
-              kernel_shape,
-              kernel_dtype,
-              input_zero_point,
-              kernel_zero_point,
-              input_scale,
-              kernel_scale,
-              kernel_size,
-              padding,
-              strides,
-              dilation,
-              data_layout,
-              kernel_layout,
-              out_dtype,
-              groups=1,
-              channels=None):
-    data = relay.var("data", shape=data_shape,
-            dtype=data_dtype)
-    kernel = relay.var("kernel", shape=kernel_shape,
-            dtype=kernel_dtype)
-
-    ref_func = get_ref_func(data,
-                            kernel,
-                            input_zero_point,
-                            kernel_zero_point,
-                            input_scale,
-                            kernel_scale,
-                            kernel_size,
-                            padding,
-                            strides,
-                            dilation,
-                            data_layout,
-                            kernel_layout,
-                            out_dtype,
-                            groups,
-                            channels)
+
+def get_funcs(
+    data_shape,
+    data_dtype,
+    kernel_shape,
+    kernel_dtype,
+    input_zero_point,
+    kernel_zero_point,
+    input_scale,
+    kernel_scale,
+    kernel_size,
+    padding,
+    strides,
+    dilation,
+    data_layout,
+    kernel_layout,
+    out_dtype,
+    groups=1,
+    channels=None,
+):
+    data = relay.var("data", shape=data_shape, dtype=data_dtype)
+    kernel = relay.var("kernel", shape=kernel_shape, dtype=kernel_dtype)
+
+    ref_func = get_ref_func(
+        data,
+        kernel,
+        input_zero_point,
+        kernel_zero_point,
+        input_scale,
+        kernel_scale,
+        kernel_size,
+        padding,
+        strides,
+        dilation,
+        data_layout,
+        kernel_layout,
+        out_dtype,
+        groups,
+        channels,
+    )
     ref_func = run_infer_type(ref_func)
     ref_func = tvm.IRModule.from_expr(ref_func)
-    qnn_func = get_qnn_func(data,
-                            kernel,
-                            input_zero_point,
-                            kernel_zero_point,
-                            input_scale,
-                            kernel_scale,
-                            kernel_size,
-                            padding,
-                            strides,
-                            dilation,
-                            data_layout,
-                            kernel_layout,
-                            out_dtype,
-                            channels,
-                            groups)
+    qnn_func = get_qnn_func(
+        data,
+        kernel,
+        input_zero_point,
+        kernel_zero_point,
+        input_scale,
+        kernel_scale,
+        kernel_size,
+        padding,
+        strides,
+        dilation,
+        data_layout,
+        kernel_layout,
+        out_dtype,
+        channels,
+        groups,
+    )
 
     return (ref_func, qnn_func)
 
-def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape,
-        kernel_dtype):
+
+def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype):
     def get_inputs(data_shape, data_dtype, kernel_shape, kernel_dtype):
         # Keeping inputs multiple of 4 because of a bug in Average Pool2d
         # https://discuss.tvm.ai/t/pool2d-gives-bad-output-for-integer-inputs/3377
@@ -169,22 +182,21 @@ def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape,
         if data_dtype == "uint8":
             low = 0
             high = 255
-        golden_data = np.random.randint(low=low, high=high,
-                size=data_shape).astype(data_dtype)
+        golden_data = np.random.randint(low=low, high=high, size=data_shape).astype(data_dtype)
         low = -128
         high = 127
         if kernel_dtype == "uint8":
             low = 0
             high = 255
-        golden_weight = np.random.randint(low=low, high=high,
-                size=kernel_shape).astype(kernel_dtype)
+        golden_weight = np.random.randint(low=low, high=high, size=kernel_shape).astype(
+            kernel_dtype
+        )
         return (golden_data, golden_weight)
 
-
     def get_output(func, golden_inputs):
         with tvm.transform.PassContext(opt_level=2):
             golden_data, golden_weight = golden_inputs
-            params = {'kernel': golden_weight}
+            params = {"kernel": golden_weight}
             graph, lib, params = relay.build(func, "llvm", params=params)
             mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
             mod.set_input("data", golden_data)
@@ -192,109 +204,115 @@ def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape,
             mod.run()
             res = mod.get_output(0).asnumpy()
             return res
-    golden_inputs = get_inputs(data_shape, data_dtype,
-            kernel_shape, kernel_dtype)
+
+    golden_inputs = get_inputs(data_shape, data_dtype, kernel_shape, kernel_dtype)
     golden_output = get_output(ref_func, golden_inputs)
     qnn_output = get_output(qnn_func, golden_inputs)
     np.testing.assert_equal(qnn_output, golden_output)
 
+
 def test_no_zero_point():
     with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
 
         # uint8 input
         data_shape = (2, 1, 2, 4)
-        data_dtype = 'uint8'
+        data_dtype = "uint8"
         kernel_shape = (3, 1, 2, 2)
-        kernel_dtype = 'uint8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=0,
-                                       kernel_zero_point=0,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(2, 2),
-                                       padding=(0, 0),
-                                       strides=(1, 1),
-                                       dilation=(1, 1),
-                                       data_layout="NCHW",
-                                       kernel_layout="OIHW",
-                                       out_dtype="int32")
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
+        kernel_dtype = "uint8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=0,
+            kernel_zero_point=0,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(2, 2),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+        )
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
 
         # int8 input
         data_shape = (2, 1, 2, 4)
-        data_dtype = 'int8'
+        data_dtype = "int8"
         kernel_shape = (3, 1, 2, 2)
-        kernel_dtype = 'int8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=0,
-                                       kernel_zero_point=0,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(2, 2),
-                                       padding=(0, 0),
-                                       strides=(1, 1),
-                                       dilation=(1, 1),
-                                       data_layout="NCHW",
-                                       kernel_layout="OIHW",
-                                       out_dtype="int32")
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
+        kernel_dtype = "int8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=0,
+            kernel_zero_point=0,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(2, 2),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+        )
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
+
 
 def test_kernel_zero_point():
     with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
 
         # uint8 input
         data_shape = (2, 4, 2, 4)
-        data_dtype = 'uint8'
+        data_dtype = "uint8"
         kernel_shape = (3, 4, 2, 2)
-        kernel_dtype = 'uint8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=0,
-                                       kernel_zero_point=1,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(2, 2),
-                                       padding=(0, 0),
-                                       strides=(1, 1),
-                                       dilation=(1, 1),
-                                       data_layout="NCHW",
-                                       kernel_layout="OIHW",
-                                       out_dtype="int32")
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
+        kernel_dtype = "uint8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=0,
+            kernel_zero_point=1,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(2, 2),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+        )
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
 
         # int8 input
         data_shape = (2, 1, 2, 4)
-        data_dtype = 'int8'
+        data_dtype = "int8"
         kernel_shape = (3, 1, 2, 2)
-        kernel_dtype = 'int8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=0,
-                                       kernel_zero_point=5,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(2, 2),
-                                       padding=(0, 0),
-                                       strides=(1, 1),
-                                       dilation=(1, 1),
-                                       data_layout="NCHW",
-                                       kernel_layout="OIHW",
-                                       out_dtype="int32")
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
+        kernel_dtype = "int8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=0,
+            kernel_zero_point=5,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(2, 2),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+        )
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
 
 
 def test_input_zero_point():
@@ -302,149 +320,156 @@ def test_input_zero_point():
 
         # uint8 input
         data_shape = (2, 4, 2, 4)
-        data_dtype = 'uint8'
+        data_dtype = "uint8"
         kernel_shape = (3, 4, 2, 2)
-        kernel_dtype = 'uint8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=5,
-                                       kernel_zero_point=0,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(2, 2),
-                                       padding=(0, 0),
-                                       strides=(1, 1),
-                                       dilation=(1, 1),
-                                       data_layout="NCHW",
-                                       kernel_layout="OIHW",
-                                       out_dtype="int32")
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
+        kernel_dtype = "uint8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=5,
+            kernel_zero_point=0,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(2, 2),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+        )
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
 
         # int8 input
         data_shape = (2, 4, 2, 4)
-        data_dtype = 'int8'
+        data_dtype = "int8"
         kernel_shape = (3, 4, 2, 2)
-        kernel_dtype = 'int8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=5,
-                                       kernel_zero_point=0,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(2, 2),
-                                       padding=(0, 0),
-                                       strides=(1, 1),
-                                       dilation=(1, 1),
-                                       data_layout="NCHW",
-                                       kernel_layout="OIHW",
-                                       out_dtype="int32")
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
+        kernel_dtype = "int8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=5,
+            kernel_zero_point=0,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(2, 2),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+        )
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
+
 
 def test_both_zero_point():
     with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
 
         # uint8 input
         data_shape = (2, 4, 2, 4)
-        data_dtype = 'uint8'
+        data_dtype = "uint8"
         kernel_shape = (3, 4, 2, 2)
-        kernel_dtype = 'uint8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=5,
-                                       kernel_zero_point=3,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(2, 2),
-                                       padding=(0, 0),
-                                       strides=(1, 1),
-                                       dilation=(1, 1),
-                                       data_layout="NCHW",
-                                       kernel_layout="OIHW",
-                                       out_dtype="int32")
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
+        kernel_dtype = "uint8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=5,
+            kernel_zero_point=3,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(2, 2),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+        )
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
 
         # int8 input
         data_shape = (2, 4, 2, 4)
-        data_dtype = 'int8'
+        data_dtype = "int8"
         kernel_shape = (3, 4, 2, 2)
-        kernel_dtype = 'int8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=5,
-                                       kernel_zero_point=3,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(2, 2),
-                                       padding=(0, 0),
-                                       strides=(1, 1),
-                                       dilation=(1, 1),
-                                       data_layout="NCHW",
-                                       kernel_layout="OIHW",
-                                       out_dtype="int32")
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
+        kernel_dtype = "int8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=5,
+            kernel_zero_point=3,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(2, 2),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+        )
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
+
 
 def test_layout():
     with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
 
         # uint8 input
-        data_shape = (2, 2, 4, 4) # NHWC
-        data_dtype = 'uint8'
-        kernel_shape = (2, 2, 4, 3) # HWIO
-        kernel_dtype = 'uint8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=5,
-                                       kernel_zero_point=3,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(2, 2),
-                                       padding=(0, 0),
-                                       strides=(1, 1),
-                                       dilation=(1, 1),
-                                       data_layout="NHWC",
-                                       kernel_layout="HWIO",
-                                       out_dtype="int32")
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
+        data_shape = (2, 2, 4, 4)  # NHWC
+        data_dtype = "uint8"
+        kernel_shape = (2, 2, 4, 3)  # HWIO
+        kernel_dtype = "uint8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=5,
+            kernel_zero_point=3,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(2, 2),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+            out_dtype="int32",
+        )
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
 
         # NHWC and HWOI layout. Used in depthwise conv.
-        data_shape = (2, 2, 4, 3) # NHWC
-        data_dtype = 'uint8'
-        kernel_shape = (2, 2, 3, 1) # HWOI
-        kernel_dtype = 'uint8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=5,
-                                       kernel_zero_point=3,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(2, 2),
-                                       padding=(0, 0),
-                                       strides=(1, 1),
-                                       dilation=(1, 1),
-                                       groups=3,
-                                       data_layout="NHWC",
-                                       kernel_layout="HWOI",
-                                       out_dtype="int32")
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
-
+        data_shape = (2, 2, 4, 3)  # NHWC
+        data_dtype = "uint8"
+        kernel_shape = (2, 2, 3, 1)  # HWOI
+        kernel_dtype = "uint8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=5,
+            kernel_zero_point=3,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(2, 2),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            groups=3,
+            data_layout="NHWC",
+            kernel_layout="HWOI",
+            out_dtype="int32",
+        )
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
 
 
 def test_padding():
@@ -452,72 +477,75 @@ def test_padding():
 
         # uint8 input
         data_shape = (1, 4, 2, 2)
-        data_dtype = 'uint8'
+        data_dtype = "uint8"
         kernel_shape = (3, 4, 2, 2)
-        kernel_dtype = 'uint8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=8,
-                                       kernel_zero_point=5,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(2, 2),
-                                       padding=(1, 1),
-                                       strides=(1, 1),
-                                       dilation=(1, 1),
-                                       data_layout="NCHW",
-                                       kernel_layout="OIHW",
-                                       out_dtype="int32")
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
+        kernel_dtype = "uint8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=8,
+            kernel_zero_point=5,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(2, 2),
+            padding=(1, 1),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+        )
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
 
         # Try different layout
-        data_shape = (2, 2, 4, 4) # NHWC
-        data_dtype = 'uint8'
-        kernel_shape = (2, 2, 4, 3) # HWIO
-        kernel_dtype = 'uint8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=8,
-                                       kernel_zero_point=3,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(2, 2),
-                                       padding=(1, 1),
-                                       strides=(1, 1),
-                                       dilation=(1, 1),
-                                       data_layout="NHWC",
-                                       kernel_layout="HWIO",
-                                       out_dtype="int32")
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
+        data_shape = (2, 2, 4, 4)  # NHWC
+        data_dtype = "uint8"
+        kernel_shape = (2, 2, 4, 3)  # HWIO
+        kernel_dtype = "uint8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=8,
+            kernel_zero_point=3,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(2, 2),
+            padding=(1, 1),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+            out_dtype="int32",
+        )
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
 
         # Try asymmetric padding
-        data_shape = (2, 2, 4, 4) # NHWC
-        data_dtype = 'uint8'
-        kernel_shape = (2, 2, 4, 3) # HWIO
-        kernel_dtype = 'uint8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=8,
-                                       kernel_zero_point=3,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(2, 2),
-                                       padding=(1, 1, 2, 2),
-                                       strides=(1, 1),
-                                       dilation=(1, 1),
-                                       data_layout="NHWC",
-                                       kernel_layout="HWIO",
-                                       out_dtype="int32")
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
+        data_shape = (2, 2, 4, 4)  # NHWC
+        data_dtype = "uint8"
+        kernel_shape = (2, 2, 4, 3)  # HWIO
+        kernel_dtype = "uint8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=8,
+            kernel_zero_point=3,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(2, 2),
+            padding=(1, 1, 2, 2),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+            out_dtype="int32",
+        )
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
 
 
 def test_dilation():
@@ -525,254 +553,282 @@ def test_dilation():
 
         # Non-zero kernel point - fall back to simpler lowering.
         data_shape = (2, 4, 4, 4)
-        data_dtype = 'uint8'
+        data_dtype = "uint8"
         kernel_shape = (3, 4, 2, 2)
-        kernel_dtype = 'uint8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=5,
-                                       kernel_zero_point=3,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(2, 2),
-                                       padding=(0, 0),
-                                       strides=(1, 1),
-                                       dilation=(2, 2),
-                                       data_layout="NCHW",
-                                       kernel_layout="OIHW",
-                                       out_dtype="int32")
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
+        kernel_dtype = "uint8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=5,
+            kernel_zero_point=3,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(2, 2),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(2, 2),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+        )
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
 
         # Zero kernel point
         data_shape = (2, 4, 4, 4)
-        data_dtype = 'uint8'
+        data_dtype = "uint8"
         kernel_shape = (3, 4, 2, 2)
-        kernel_dtype = 'uint8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=0,
-                                       kernel_zero_point=0,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(2, 2),
-                                       padding=(0, 0),
-                                       strides=(1, 1),
-                                       dilation=(2, 2),
-                                       data_layout="NCHW",
-                                       kernel_layout="OIHW",
-                                       out_dtype="int32")
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
+        kernel_dtype = "uint8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=0,
+            kernel_zero_point=0,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(2, 2),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(2, 2),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+        )
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
 
 
 def test_const_folding():
     with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
 
         data_shape = (2, 4, 2, 4)
-        data_dtype = 'uint8'
+        data_dtype = "uint8"
         kernel_shape = (3, 4, 2, 2)
-        kernel_dtype = 'uint8'
+        kernel_dtype = "uint8"
 
-        golden_weight = np.random.randint(low=0, high=255,
-                size=kernel_shape).astype(kernel_dtype)
-        data = relay.var("data", shape=data_shape,
-                dtype=data_dtype)
+        golden_weight = np.random.randint(low=0, high=255, size=kernel_shape).astype(kernel_dtype)
+        data = relay.var("data", shape=data_shape, dtype=data_dtype)
         kernel = relay.const(golden_weight)
-        qnn_func = get_qnn_func(data,
-                                kernel,
-                                input_zero_point=8,
-                                kernel_zero_point=3,
-                                kernel_size=(2, 2),
-                                input_scale=1.0,
-                                kernel_scale=1.0,
-                                padding=(0, 0),
-                                strides=(1, 1),
-                                dilation=(1, 1),
-                                data_layout="NCHW",
-                                kernel_layout="OIHW",
-                                out_dtype="int32",
-                                channels=kernel_shape[0],
-                                groups=1)
+        qnn_func = get_qnn_func(
+            data,
+            kernel,
+            input_zero_point=8,
+            kernel_zero_point=3,
+            kernel_size=(2, 2),
+            input_scale=1.0,
+            kernel_scale=1.0,
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+            channels=kernel_shape[0],
+            groups=1,
+        )
         folded_mod = transform.FoldConstant()(qnn_func)
         folded_func = folded_mod["main"]
         assert "reshape" not in folded_func.astext()
 
+
 def test_kernel_size_1x1():
     with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
 
         # uint8 input
         data_shape = (2, 4, 2, 4)
-        data_dtype = 'uint8'
+        data_dtype = "uint8"
         kernel_shape = (3, 4, 1, 1)
-        kernel_dtype = 'uint8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=5,
-                                       kernel_zero_point=3,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(1, 1),
-                                       padding=(0, 0),
-                                       strides=(1, 1),
-                                       dilation=(1, 1),
-                                       data_layout="NCHW",
-                                       kernel_layout="OIHW",
-                                       out_dtype="int32")
-        assert 'avg_pool2d' not in qnn_func.astext()
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
+        kernel_dtype = "uint8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=5,
+            kernel_zero_point=3,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(1, 1),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+        )
+        assert "avg_pool2d" not in qnn_func.astext()
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
+
 
 def test_kernel_size_1x1_strides_2():
     with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
 
         # uint8 input
         data_shape = (2, 4, 2, 4)
-        data_dtype = 'uint8'
+        data_dtype = "uint8"
         kernel_shape = (3, 4, 1, 1)
-        kernel_dtype = 'uint8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=5,
-                                       kernel_zero_point=3,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(1, 1),
-                                       padding=(0, 0),
-                                       strides=(2, 2),
-                                       dilation=(1, 1),
-                                       data_layout="NCHW",
-                                       kernel_layout="OIHW",
-                                       out_dtype="int32")
-        assert 'avg_pool2d' not in qnn_func.astext()
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
+        kernel_dtype = "uint8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=5,
+            kernel_zero_point=3,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(1, 1),
+            padding=(0, 0),
+            strides=(2, 2),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+        )
+        assert "avg_pool2d" not in qnn_func.astext()
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
+
 
 def test_tflite_large_irregular():
     with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
 
         # uint8 input
         data_shape = (1, 1024, 1, 1)
-        data_dtype = 'uint8'
+        data_dtype = "uint8"
         kernel_shape = (1001, 1024, 1, 1)
-        kernel_dtype = 'uint8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=127,
-                                       kernel_zero_point=127,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(1, 1),
-                                       padding=(0, 0),
-                                       strides=(1, 1),
-                                       dilation=(1, 1),
-                                       data_layout="NCHW",
-                                       kernel_layout="OIHW",
-                                       out_dtype="int32")
-        golden_data = np.full(data_shape, 127).astype('uint8')
-        golden_weight = np.full(kernel_shape, 127).astype('uint8')
+        kernel_dtype = "uint8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=127,
+            kernel_zero_point=127,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(1, 1),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+        )
+        golden_data = np.full(data_shape, 127).astype("uint8")
+        golden_weight = np.full(kernel_shape, 127).astype("uint8")
 
         with tvm.transform.PassContext(opt_level=2):
-            params = {'kernel': golden_weight}
+            params = {"kernel": golden_weight}
             graph, lib, params = relay.build(qnn_func, "llvm", params=params)
             mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
             mod.set_input("data", golden_data)
             mod.set_input(**params)
             mod.run()
             qnn_output = mod.get_output(0).asnumpy()
-        golden_output = np.full((1, 1001, 1, 1), 0).astype('uint8')
+        golden_output = np.full((1, 1001, 1, 1), 0).astype("uint8")
         np.testing.assert_equal(qnn_output, golden_output)
 
+
 def test_tflite_output_multiplier_greater_than_one():
     with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
 
         # uint8 input
         data_shape = (2, 1, 2, 4)
-        data_dtype = 'uint8'
+        data_dtype = "uint8"
         kernel_shape = (3, 1, 2, 2)
-        kernel_dtype = 'uint8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       input_zero_point=128,
-                                       kernel_zero_point=128,
-                                       kernel_size=(2, 2),
-                                       padding=(0, 0),
-                                       strides=(2, 2),
-                                       dilation=(1, 1),
-                                       data_layout="NCHW",
-                                       kernel_layout="OIHW",
-                                       out_dtype="int32")
-        golden_data = 128 + np.array((1, 1, 1, 1,
-                                      2, 2, 2, 2,
-                                      1, 2, 3, 4,
-                                      1, 2, 3, 4)).reshape(data_shape).astype('uint8')
-        golden_weight = 128 + np.array((1, 2, 3, 4,
-                                        -1, 1, -1, 1,
-                                        -1, -1, 1, 1)).reshape(kernel_shape)
-        golden_weight = golden_weight.astype('uint8')
+        kernel_dtype = "uint8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            input_zero_point=128,
+            kernel_zero_point=128,
+            kernel_size=(2, 2),
+            padding=(0, 0),
+            strides=(2, 2),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+        )
+        golden_data = 128 + np.array((1, 1, 1, 1, 2, 2, 2, 2, 1, 2, 3, 4, 1, 2, 3, 4)).reshape(
+            data_shape
+        ).astype("uint8")
+        golden_weight = 128 + np.array((1, 2, 3, 4, -1, 1, -1, 1, -1, -1, 1, 1)).reshape(
+            kernel_shape
+        )
+        golden_weight = golden_weight.astype("uint8")
 
         with tvm.transform.PassContext(opt_level=2):
-            params = {'kernel': golden_weight}
+            params = {"kernel": golden_weight}
             graph, lib, params = relay.build(qnn_func, "llvm", params=params)
             mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
             mod.set_input("data", golden_data)
             mod.set_input(**params)
             mod.run()
             qnn_output = mod.get_output(0).asnumpy()
-        golden_output = np.array((17, 17,
-                                  0, 0,
-                                  2, 2,
-                                  16, 36,
-                                  2, 2,
-                                  0, 0)).reshape(2, 3, 1, 2)
+        golden_output = np.array((17, 17, 0, 0, 2, 2, 16, 36, 2, 2, 0, 0)).reshape(2, 3, 1, 2)
         np.testing.assert_equal(qnn_output, golden_output)
 
+
 def test_tflite_anistropic_strides():
     with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
 
         # uint8 input
         data_shape = (1, 1, 3, 6)
-        data_dtype = 'uint8'
+        data_dtype = "uint8"
         kernel_shape = (1, 1, 2, 2)
-        kernel_dtype = 'uint8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=127,
-                                       kernel_zero_point=127,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(2, 2),
-                                       padding=(0, 0),
-                                       strides=(1, 3),
-                                       dilation=(1, 1),
-                                       data_layout="NCHW",
-                                       kernel_layout="OIHW",
-                                       out_dtype="int32")
-        golden_data = np.array((133, 131, 129, 125, 123, 121,
-                                135, 133, 131, 123, 121, 119,
-                                137, 135, 133, 121, 119, 117)).reshape(data_shape)
-        golden_data = golden_data.astype('uint8')
+        kernel_dtype = "uint8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=127,
+            kernel_zero_point=127,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(2, 2),
+            padding=(0, 0),
+            strides=(1, 3),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+        )
+        golden_data = np.array(
+            (
+                133,
+                131,
+                129,
+                125,
+                123,
+                121,
+                135,
+                133,
+                131,
+                123,
+                121,
+                119,
+                137,
+                135,
+                133,
+                121,
+                119,
+                117,
+            )
+        ).reshape(data_shape)
+        golden_data = golden_data.astype("uint8")
         golden_weight = np.array((129, 131, 133, 135)).reshape(kernel_shape)
-        golden_weight = golden_weight.astype('uint8')
+        golden_weight = golden_weight.astype("uint8")
 
         with tvm.transform.PassContext(opt_level=2):
-            params = {'kernel': golden_weight}
+            params = {"kernel": golden_weight}
             graph, lib, params = relay.build(qnn_func, "llvm", params=params)
             mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
             mod.set_input("data", golden_data)
@@ -782,30 +838,33 @@ def test_tflite_anistropic_strides():
         golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2)
         np.testing.assert_equal(qnn_output, golden_output)
 
+
 def test_broadcast_layout():
     with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
 
         # Test broadcast support for NHWC layout.
-        data_shape = (1, 229, 229, 3) # NHWC
-        data_dtype = 'uint8'
-        kernel_shape = (7, 7, 3, 64) # HWIO
-        kernel_dtype = 'int8'
-        _, qnn_func = get_funcs(data_shape=data_shape,
-                                data_dtype=data_dtype,
-                                kernel_shape=kernel_shape,
-                                kernel_dtype=kernel_dtype,
-                                input_zero_point=8,
-                                kernel_zero_point=3,
-                                input_scale=1.0,
-                                kernel_scale=1.0,
-                                kernel_size=(7, 7),
-                                padding=(1, 1),
-                                strides=(1, 1),
-                                dilation=(1, 1),
-                                data_layout="NHWC",
-                                kernel_layout="HWIO",
-                                out_dtype="int32")
-        func = qnn_func['main'].body
+        data_shape = (1, 229, 229, 3)  # NHWC
+        data_dtype = "uint8"
+        kernel_shape = (7, 7, 3, 64)  # HWIO
+        kernel_dtype = "int8"
+        _, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=8,
+            kernel_zero_point=3,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(7, 7),
+            padding=(1, 1),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+            out_dtype="int32",
+        )
+        func = qnn_func["main"].body
         bias = relay.var("bias", shape=(64,), dtype="int32")
         bias2 = relay.var("bias2", shape=(1, 225, 225, 1), dtype="int32")
 
@@ -819,141 +878,147 @@ def test_broadcast_layout():
         with tvm.transform.PassContext(opt_level=3):
             graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512")
 
+
 def test_depthwise_depth_multiplier():
     with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
 
         # uint8 input, NCHW and OIHW
         # Depthwise multiplier = 1
         data_shape = (2, 4, 16, 16)
-        data_dtype = 'uint8'
+        data_dtype = "uint8"
         kernel_shape = (4, 1, 3, 3)
-        kernel_dtype = 'uint8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=5,
-                                       kernel_zero_point=3,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(3, 3),
-                                       padding=(0, 0),
-                                       strides=(1, 1),
-                                       dilation=(1, 1),
-                                       data_layout="NCHW",
-                                       kernel_layout="OIHW",
-                                       out_dtype="int32",
-                                       groups=4)
-
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
-
+        kernel_dtype = "uint8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=5,
+            kernel_zero_point=3,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(3, 3),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+            groups=4,
+        )
+
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
 
         # Depthwise multiplier = 2
         data_shape = (10, 4, 16, 16)
-        data_dtype = 'uint8'
+        data_dtype = "uint8"
         kernel_shape = (4, 2, 3, 3)
-        kernel_dtype = 'uint8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=5,
-                                       kernel_zero_point=3,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(3, 3),
-                                       padding=(0, 0),
-                                       strides=(1, 1),
-                                       dilation=(1, 1),
-                                       data_layout="NCHW",
-                                       kernel_layout="OIHW",
-                                       out_dtype="int32",
-                                       groups=4,
-                                       channels=8)
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
+        kernel_dtype = "uint8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=5,
+            kernel_zero_point=3,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(3, 3),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+            groups=4,
+            channels=8,
+        )
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
 
         # uint8 input, NHWC and HWOI
         # Depthwise multiplier = 1
         data_shape = (2, 16, 16, 4)
-        data_dtype = 'uint8'
+        data_dtype = "uint8"
         kernel_shape = (3, 3, 4, 1)
-        kernel_dtype = 'uint8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=5,
-                                       kernel_zero_point=3,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(3, 3),
-                                       padding=(0, 0),
-                                       strides=(1, 1),
-                                       dilation=(1, 1),
-                                       data_layout="NHWC",
-                                       kernel_layout="HWOI",
-                                       out_dtype="int32",
-                                       groups=4)
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
+        kernel_dtype = "uint8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=5,
+            kernel_zero_point=3,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(3, 3),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWOI",
+            out_dtype="int32",
+            groups=4,
+        )
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
 
         # Depthwise multiplier = 2
         data_shape = (2, 16, 16, 4)
-        data_dtype = 'uint8'
+        data_dtype = "uint8"
         kernel_shape = (3, 3, 4, 2)
-        kernel_dtype = 'uint8'
-        ref_func, qnn_func = get_funcs(data_shape=data_shape,
-                                       data_dtype=data_dtype,
-                                       kernel_shape=kernel_shape,
-                                       kernel_dtype=kernel_dtype,
-                                       input_zero_point=5,
-                                       kernel_zero_point=3,
-                                       input_scale=1.0,
-                                       kernel_scale=1.0,
-                                       kernel_size=(3, 3),
-                                       padding=(0, 0),
-                                       strides=(1, 1),
-                                       dilation=(1, 1),
-                                       data_layout="NHWC",
-                                       kernel_layout="HWOI",
-                                       out_dtype="int32",
-                                       groups=4,
-                                       channels=8)
-        verify(ref_func, qnn_func, data_shape, data_dtype,
-                kernel_shape, kernel_dtype)
+        kernel_dtype = "uint8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=5,
+            kernel_zero_point=3,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(3, 3),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWOI",
+            out_dtype="int32",
+            groups=4,
+            channels=8,
+        )
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
+
 
 def test_per_channel_kernel_scale():
     with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
         data_shape = (2, 1, 2, 4)
-        data_dtype = 'uint8'
+        data_dtype = "uint8"
         kernel_shape = (3, 1, 2, 2)
-        kernel_dtype = 'uint8'
-        data = relay.var("data", shape=data_shape,
-                dtype=data_dtype)
-        kernel = relay.var("kernel", shape=kernel_shape,
-                dtype=kernel_dtype)
+        kernel_dtype = "uint8"
+        data = relay.var("data", shape=data_shape, dtype=data_dtype)
+        kernel = relay.var("kernel", shape=kernel_shape, dtype=kernel_dtype)
         kernel_scales = [2, 2, 2]
-        kernel_scales = relay.const(np.array(kernel_scales).astype('float32'))
+        kernel_scales = relay.const(np.array(kernel_scales).astype("float32"))
         func = relay.qnn.op.conv2d(
-                data, kernel,
-                input_zero_point=relay.const(0, 'int32'),
-                kernel_zero_point=relay.const(0, 'int32'),
-                input_scale=relay.const(2.0, 'float32'),
-                kernel_scale=kernel_scales,
-                kernel_size=(2, 2),
-                channels=kernel_shape[0],
-                padding=(0, 0),
-                strides=(1, 1),
-                dilation=(1, 1),
-                data_layout="NCHW",
-                kernel_layout="OIHW",
-                out_dtype="int32")
+            data,
+            kernel,
+            input_zero_point=relay.const(0, "int32"),
+            kernel_zero_point=relay.const(0, "int32"),
+            input_scale=relay.const(2.0, "float32"),
+            kernel_scale=kernel_scales,
+            kernel_size=(2, 2),
+            channels=kernel_shape[0],
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+        )
 
         mod = relay.Function(relay.analysis.free_vars(func), func)
         mod = tvm.IRModule.from_expr(mod)
 
+
 if __name__ == "__main__":
     test_no_zero_point()
     test_input_zero_point()
index 0ba3210..a05ee3f 100644 (file)
@@ -33,45 +33,47 @@ def legalize_qnn_dense(attrs, inputs, types):
 
 def make_requantize_params(input_scale, output_scale, output_zero_point, out_dtype):
     config = {
-        'input_scale': input_scale,
-        'output_scale': output_scale,
-        'output_zero_point': output_zero_point,
-        'out_dtype': out_dtype
+        "input_scale": input_scale,
+        "output_scale": output_scale,
+        "output_zero_point": output_zero_point,
+        "out_dtype": out_dtype,
     }
     return config
 
 
-def make_configuration(quantized_data,
-                       quantized_kernel,
-                       dtype,
-                       input_shape,
-                       kernel_shape,
-                       input_zero_point,
-                       kernel_zero_point,
-                       input_scale,
-                       kernel_scale,
-                       units,
-                       output,
-                       out_dtype='int32',
-                       bias=None,
-                       requantize=None):
+def make_configuration(
+    quantized_data,
+    quantized_kernel,
+    dtype,
+    input_shape,
+    kernel_shape,
+    input_zero_point,
+    kernel_zero_point,
+    input_scale,
+    kernel_scale,
+    units,
+    output,
+    out_dtype="int32",
+    bias=None,
+    requantize=None,
+):
     if requantize is not None:
         assert bias is not None
     config = {
-        'quantized_data': quantized_data,
-        'quantized_kernel': quantized_kernel,
-        'dtype': dtype,
-        'input_shape': input_shape,
-        'kernel_shape': kernel_shape,
-        'input_zero_point': input_zero_point,
-        'kernel_zero_point': kernel_zero_point,
-        'input_scale': input_scale,
-        'kernel_scale': kernel_scale,
-        'units': units,
-        'output': output,
-        'out_dtype': out_dtype,
-        'bias': bias,
-        'requantize': requantize
+        "quantized_data": quantized_data,
+        "quantized_kernel": quantized_kernel,
+        "dtype": dtype,
+        "input_shape": input_shape,
+        "kernel_shape": kernel_shape,
+        "input_zero_point": input_zero_point,
+        "kernel_zero_point": kernel_zero_point,
+        "input_scale": input_scale,
+        "kernel_scale": kernel_scale,
+        "units": units,
+        "output": output,
+        "out_dtype": out_dtype,
+        "bias": bias,
+        "requantize": requantize,
     }
     return config
 
@@ -79,22 +81,56 @@ def make_configuration(quantized_data,
 def make_int_configuration(use_bias=False, requantize_output=False, per_channel=False):
     input_shape, kernel_shape, output_shape = (2, 10), (3, 10), (2, 3)
     input_zero_point, kernel_zero_point = -1, -1
-    in_dtype = 'int8'
-    out_dtype = 'int32' if not requantize_output else 'int8'
+    in_dtype = "int8"
+    out_dtype = "int32" if not requantize_output else "int8"
     units = 3
-    quantized_data_np = np.array([1, 3, 5, 7, 9, 11, 13, 15, -19, -21,
-                                  1, 3, 5, 7, 9, 11, 13, -17, 17, -21]) \
-        .astype(in_dtype) \
+    quantized_data_np = (
+        np.array([1, 3, 5, 7, 9, 11, 13, 15, -19, -21, 1, 3, 5, 7, 9, 11, 13, -17, 17, -21])
+        .astype(in_dtype)
         .reshape(input_shape)
-    quantized_kernel_np = np.array([1, 3, 5, 7, 9, 11, 13, 15, 17, 19,
-                                    1, 3, 5, 7, 9, 11, 13, 15, 17, 19,
-                                    1, 3, 5, 7, 9, 11, 13, 15, 17, 19]) \
-        .astype(in_dtype) \
+    )
+    quantized_kernel_np = (
+        np.array(
+            [
+                1,
+                3,
+                5,
+                7,
+                9,
+                11,
+                13,
+                15,
+                17,
+                19,
+                1,
+                3,
+                5,
+                7,
+                9,
+                11,
+                13,
+                15,
+                17,
+                19,
+                1,
+                3,
+                5,
+                7,
+                9,
+                11,
+                13,
+                15,
+                17,
+                19,
+            ]
+        )
+        .astype(in_dtype)
         .reshape(kernel_shape)
+    )
     input_scale = 0.5
     kernel_scale = 0.5
     output_scale = 1.0
-    bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None
+    bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units,)) if use_bias else None
 
     if per_channel:
         assert use_bias and requantize_output
@@ -108,61 +144,66 @@ def make_int_configuration(use_bias=False, requantize_output=False, per_channel=
     else:
         output = np.array([92, 92, 92, 228, 228, 228])
 
-    requant_params = make_requantize_params(input_scale * kernel_scale,
-                                            output_scale, -1, 'int8') if requantize_output else None
+    requant_params = (
+        make_requantize_params(input_scale * kernel_scale, output_scale, -1, "int8")
+        if requantize_output
+        else None
+    )
 
     output = output.astype(out_dtype).reshape(output_shape)
-    return make_configuration(quantized_data=quantized_data_np,
-                              quantized_kernel=quantized_kernel_np,
-                              dtype=in_dtype,
-                              input_shape=input_shape,
-                              kernel_shape=kernel_shape,
-                              input_zero_point=input_zero_point,
-                              kernel_zero_point=kernel_zero_point,
-                              input_scale=input_scale,
-                              kernel_scale=kernel_scale,
-                              units=units,
-                              output=output,
-                              bias=bias,
-                              requantize=requant_params)
+    return make_configuration(
+        quantized_data=quantized_data_np,
+        quantized_kernel=quantized_kernel_np,
+        dtype=in_dtype,
+        input_shape=input_shape,
+        kernel_shape=kernel_shape,
+        input_zero_point=input_zero_point,
+        kernel_zero_point=kernel_zero_point,
+        input_scale=input_scale,
+        kernel_scale=kernel_scale,
+        units=units,
+        output=output,
+        bias=bias,
+        requantize=requant_params,
+    )
 
 
 def qnn_dense_driver(test_configuration):
-    in_dtype = test_configuration['dtype']
-    out_dtype = test_configuration['out_dtype']
+    in_dtype = test_configuration["dtype"]
+    out_dtype = test_configuration["out_dtype"]
     quantized_data_name = "quantized_data"
     quantized_kernel_name = "quantized_kernel"
-    expected_out_dtype = test_configuration['out_dtype']
-    bias_name = 'bias'
-    quantized_data = relay.var(quantized_data_name,
-                               shape=test_configuration['input_shape'],
-                               dtype=in_dtype)
-    quantized_kernel = relay.var(quantized_kernel_name,
-                                 shape=test_configuration['kernel_shape'],
-                                 dtype=in_dtype)
+    expected_out_dtype = test_configuration["out_dtype"]
+    bias_name = "bias"
+    quantized_data = relay.var(
+        quantized_data_name, shape=test_configuration["input_shape"], dtype=in_dtype
+    )
+    quantized_kernel = relay.var(
+        quantized_kernel_name, shape=test_configuration["kernel_shape"], dtype=in_dtype
+    )
     mod = relay.qnn.op.dense(
         quantized_data,
         quantized_kernel,
-        relay.const(test_configuration['input_zero_point'], 'int32'),
-        relay.const(test_configuration['kernel_zero_point'], 'int32'),
-        relay.const(test_configuration['input_scale'], 'float32'),
-        relay.const(test_configuration['kernel_scale'], 'float32'),
-        test_configuration['units'])
+        relay.const(test_configuration["input_zero_point"], "int32"),
+        relay.const(test_configuration["kernel_zero_point"], "int32"),
+        relay.const(test_configuration["input_scale"], "float32"),
+        relay.const(test_configuration["kernel_scale"], "float32"),
+        test_configuration["units"],
+    )
     if test_configuration[bias_name] is not None:
-        bias = relay.var(bias_name,
-                         shape=test_configuration['bias'].shape,
-                         dtype=out_dtype)
+        bias = relay.var(bias_name, shape=test_configuration["bias"].shape, dtype=out_dtype)
         mod = relay.nn.bias_add(mod, bias)
-    if test_configuration['requantize'] is not None:
-        requantize_config = test_configuration['requantize']
+    if test_configuration["requantize"] is not None:
+        requantize_config = test_configuration["requantize"]
         mod = relay.qnn.op.requantize(
             mod,
-            input_scale=relay.const(requantize_config['input_scale'], 'float32'),
-            input_zero_point=relay.const(0, 'int32'),
-            output_scale=relay.const(requantize_config['output_scale'], 'float32'),
-            output_zero_point=relay.const(requantize_config['output_zero_point'], 'int32'),
-            out_dtype=requantize_config['out_dtype'])
-        expected_out_dtype = requantize_config['out_dtype']
+            input_scale=relay.const(requantize_config["input_scale"], "float32"),
+            input_zero_point=relay.const(0, "int32"),
+            output_scale=relay.const(requantize_config["output_scale"], "float32"),
+            output_zero_point=relay.const(requantize_config["output_zero_point"], "int32"),
+            out_dtype=requantize_config["out_dtype"],
+        )
+        expected_out_dtype = requantize_config["out_dtype"]
 
     mod = relay.Function(relay.analysis.free_vars(mod), mod)
     mod = tvm.IRModule.from_expr(mod)
@@ -177,38 +218,36 @@ def qnn_dense_driver(test_configuration):
         mod.set_input(**params)
         mod.run()
         res = mod.get_output(0).asnumpy()
-        np.testing.assert_equal(res, test_configuration['output'])
+        np.testing.assert_equal(res, test_configuration["output"])
         assert res.dtype == expected_out_dtype
 
 
 def test_qnn_dense_without_bias():
     with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_dense):
 
-        int32_output_without_bias_params = \
-            make_int_configuration(use_bias=False)
+        int32_output_without_bias_params = make_int_configuration(use_bias=False)
         qnn_dense_driver(int32_output_without_bias_params)
 
 
 def test_qnn_dense_with_bias():
     with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_dense):
 
-        int32_output_with_bias_params = \
-            make_int_configuration(use_bias=True)
+        int32_output_with_bias_params = make_int_configuration(use_bias=True)
         qnn_dense_driver(int32_output_with_bias_params)
 
 
 def test_qnn_dense_with_requantized_output():
     with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_dense):
 
-        int8_requantized_output_with_bias_params = \
-            make_int_configuration(use_bias=True, requantize_output=True)
+        int8_requantized_output_with_bias_params = make_int_configuration(
+            use_bias=True, requantize_output=True
+        )
         qnn_dense_driver(int8_requantized_output_with_bias_params)
 
 
 def test_per_channel_weight_scale():
     with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_dense):
-        config = make_int_configuration(use_bias=True, requantize_output=True,
-                                        per_channel=True)
+        config = make_int_configuration(use_bias=True, requantize_output=True, per_channel=True)
         qnn_dense_driver(config)
 
 
index 361d6f0..6598e2b 100644 (file)
@@ -21,14 +21,15 @@ import numpy as np
 from tvm import relay
 from tvm.contrib import graph_runtime
 
+
 def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data, axis):
     shape = in_data.shape
     input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
-    input_zero_point = relay.const(quant_args['in_zero_point'], 'int32')
-    input_scale = relay.const(quant_args['in_scale'], 'float32')
-    quantized_output = relay.qnn.op.dequantize(input_data, input_scale=input_scale,
-                                               input_zero_point=input_zero_point,
-                                               axis=axis)
+    input_zero_point = relay.const(quant_args["in_zero_point"], "int32")
+    input_scale = relay.const(quant_args["in_scale"], "float32")
+    quantized_output = relay.qnn.op.dequantize(
+        input_data, input_scale=input_scale, input_zero_point=input_zero_point, axis=axis
+    )
     mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
     mod = tvm.IRModule.from_expr(mod)
     with tvm.transform.PassContext(opt_level=3):
@@ -41,46 +42,63 @@ def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data, ax
         np.testing.assert_equal(res, verify_output_data)
         assert res.dtype == np.float32
 
+
 def test_uint8_to_float32():
-    data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
-        .astype('uint8') \
-        .reshape((2, 5))
-    output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
-        .astype('float32') \
+    data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]).astype("uint8").reshape((2, 5))
+    output = (
+        np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64])
+        .astype("float32")
         .reshape((2, 5))
-    quant_args = {"in_zero_point":127, "in_scale":0.5}
-    dequantize_test_driver(in_dtype='uint8', quant_args=quant_args, in_data=data,
-                           verify_output_data=output, axis=-1)
+    )
+    quant_args = {"in_zero_point": 127, "in_scale": 0.5}
+    dequantize_test_driver(
+        in_dtype="uint8", quant_args=quant_args, in_data=data, verify_output_data=output, axis=-1
+    )
+
 
 def test_int8_to_float32():
-    data = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \
-        .astype('int8') \
+    data = (
+        np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127])
+        .astype("int8")
         .reshape((2, 5))
-    output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
-        .astype('float32') \
+    )
+    output = (
+        np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64])
+        .astype("float32")
         .reshape((2, 5))
+    )
     quant_args = {"in_zero_point": -1, "in_scale": 0.5}
-    dequantize_test_driver(in_dtype='int8', quant_args=quant_args, in_data=data,
-                           verify_output_data=output, axis=-1)
+    dequantize_test_driver(
+        in_dtype="int8", quant_args=quant_args, in_data=data, verify_output_data=output, axis=-1
+    )
+
 
 def test_int32_to_float32():
-    data = np.array([113, 29, -1052]).astype('int32')
-    output = np.array([0.6550452, 0.16810896, -6.098297]).astype('float32')
+    data = np.array([113, 29, -1052]).astype("int32")
+    output = np.array([0.6550452, 0.16810896, -6.098297]).astype("float32")
     quant_args = {"in_zero_point": 0, "in_scale": 0.0057968604}
-    dequantize_test_driver(in_dtype='int32', quant_args=quant_args, in_data=data,
-                           verify_output_data=output, axis=-1)
+    dequantize_test_driver(
+        in_dtype="int32", quant_args=quant_args, in_data=data, verify_output_data=output, axis=-1
+    )
 
 
 def test_channelwise_axis_1():
-    data = np.transpose(np.array([0, 1, 2, 3, 4, 243, 247, 249, 250, 251]) \
-                        .astype('uint8').reshape((2,5)))
-    output = np.transpose(np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32]) \
-                         .astype('float32').reshape((2,5)))
-    quant_args = {"in_zero_point" : np.array([127, 123]).astype('int32'),
-                  "in_scale"      : np.array([0.5, 0.25]).astype('float32')}
+    data = np.transpose(
+        np.array([0, 1, 2, 3, 4, 243, 247, 249, 250, 251]).astype("uint8").reshape((2, 5))
+    )
+    output = np.transpose(
+        np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32])
+        .astype("float32")
+        .reshape((2, 5))
+    )
+    quant_args = {
+        "in_zero_point": np.array([127, 123]).astype("int32"),
+        "in_scale": np.array([0.5, 0.25]).astype("float32"),
+    }
 
-    dequantize_test_driver(in_dtype='uint8', quant_args=quant_args, in_data=data,
-                           verify_output_data=output, axis=1)
+    dequantize_test_driver(
+        in_dtype="uint8", quant_args=quant_args, in_data=data, verify_output_data=output, axis=1
+    )
 
 
 if __name__ == "__main__":
index 4fbb4e9..17ec137 100644 (file)
@@ -44,13 +44,16 @@ def test_tflite_same_io_qnn_params():
 
     x = relay.var("x", shape=(1, 4), dtype=data_dtype)
     y = relay.var("y", shape=(1, 4), dtype=data_dtype)
-    z = relay.qnn.op.mul(lhs=x, rhs=y,
-                         lhs_scale=relay.const(lhs_scale, 'float32'),
-                         lhs_zero_point=relay.const(lhs_zero_point, 'int32'),
-                         rhs_scale=relay.const(rhs_scale, 'float32'),
-                         rhs_zero_point=relay.const(rhs_zero_point, 'int32'),
-                         output_scale=relay.const(output_scale, 'float32'),
-                         output_zero_point=relay.const(output_zero_point, 'int32'))
+    z = relay.qnn.op.mul(
+        lhs=x,
+        rhs=y,
+        lhs_scale=relay.const(lhs_scale, "float32"),
+        lhs_zero_point=relay.const(lhs_zero_point, "int32"),
+        rhs_scale=relay.const(rhs_scale, "float32"),
+        rhs_zero_point=relay.const(rhs_zero_point, "int32"),
+        output_scale=relay.const(output_scale, "float32"),
+        output_zero_point=relay.const(output_zero_point, "int32"),
+    )
 
     func = relay.Function([x, y], z)
     mod = tvm.IRModule.from_expr(func)
@@ -74,8 +77,7 @@ def test_tflite_same_io_qnn_params():
 
         x_rec = recover(x_data, lhs_scale, lhs_zero_point)
         y_rec = recover(y_data, rhs_scale, rhs_zero_point)
-        golden = generate_golden_output(x_rec, y_rec, output_scale,
-            output_zero_point)
+        golden = generate_golden_output(x_rec, y_rec, output_scale, output_zero_point)
 
         intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
         op_res = intrp.evaluate(func)(x_data, y_data)
@@ -95,13 +97,16 @@ def test_tflite_different_io_qnn_params():
 
     x = relay.var("x", shape=(1, 4), dtype=data_dtype)
     y = relay.var("y", shape=(1, 4), dtype=data_dtype)
-    z = relay.qnn.op.mul(lhs=x, rhs=y,
-                         lhs_scale=relay.const(lhs_scale, 'float32'),
-                         lhs_zero_point=relay.const(lhs_zero_point, 'int32'),
-                         rhs_scale=relay.const(rhs_scale, 'float32'),
-                         rhs_zero_point=relay.const(rhs_zero_point, 'int32'),
-                         output_scale=relay.const(output_scale, 'float32'),
-                         output_zero_point=relay.const(output_zero_point, 'int32'))
+    z = relay.qnn.op.mul(
+        lhs=x,
+        rhs=y,
+        lhs_scale=relay.const(lhs_scale, "float32"),
+        lhs_zero_point=relay.const(lhs_zero_point, "int32"),
+        rhs_scale=relay.const(rhs_scale, "float32"),
+        rhs_zero_point=relay.const(rhs_zero_point, "int32"),
+        output_scale=relay.const(output_scale, "float32"),
+        output_zero_point=relay.const(output_zero_point, "int32"),
+    )
 
     func = relay.Function([x, y], z)
     mod = tvm.IRModule.from_expr(func)
@@ -125,8 +130,7 @@ def test_tflite_different_io_qnn_params():
 
         x_rec = recover(x_data, lhs_scale, lhs_zero_point)
         y_rec = recover(y_data, rhs_scale, rhs_zero_point)
-        golden = generate_golden_output(x_rec, y_rec, output_scale,
-            output_zero_point)
+        golden = generate_golden_output(x_rec, y_rec, output_scale, output_zero_point)
 
         intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
         op_res = intrp.evaluate(func)(x_data, y_data)
@@ -141,13 +145,16 @@ def test_saturation():
 
     x = relay.var("x", shape=(1, 4), dtype=data_dtype)
     y = relay.var("y", shape=(1, 4), dtype=data_dtype)
-    z = relay.qnn.op.mul(lhs=x, rhs=y,
-                         lhs_scale=relay.const(lhs_scale, 'float32'),
-                         lhs_zero_point=relay.const(lhs_zero_point, 'int32'),
-                         rhs_scale=relay.const(rhs_scale, 'float32'),
-                         rhs_zero_point=relay.const(rhs_zero_point, 'int32'),
-                         output_scale=relay.const(output_scale, 'float32'),
-                         output_zero_point=relay.const(output_zero_point, 'int32'))
+    z = relay.qnn.op.mul(
+        lhs=x,
+        rhs=y,
+        lhs_scale=relay.const(lhs_scale, "float32"),
+        lhs_zero_point=relay.const(lhs_zero_point, "int32"),
+        rhs_scale=relay.const(rhs_scale, "float32"),
+        rhs_zero_point=relay.const(rhs_zero_point, "int32"),
+        output_scale=relay.const(output_scale, "float32"),
+        output_zero_point=relay.const(output_zero_point, "int32"),
+    )
 
     func = relay.Function([x, y], z)
     mod = tvm.IRModule.from_expr(func)
@@ -160,8 +167,7 @@ def test_saturation():
     x_rec = recover(x_data, lhs_scale, lhs_zero_point)
     y_rec = recover(y_data, rhs_scale, rhs_zero_point)
 
-    golden = generate_golden_output(x_rec, y_rec, output_scale,
-        output_zero_point)
+    golden = generate_golden_output(x_rec, y_rec, output_scale, output_zero_point)
 
     intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
     op_res = intrp.evaluate(func)(x_data, y_data)
@@ -172,13 +178,16 @@ def test_saturation():
     lhs_scale = rhs_scale = 0.125
     output_scale = 0.25
 
-    z = relay.qnn.op.mul(lhs=x, rhs=y,
-                         lhs_scale=relay.const(lhs_scale, 'float32'),
-                         lhs_zero_point=relay.const(lhs_zero_point, 'int32'),
-                         rhs_scale=relay.const(rhs_scale, 'float32'),
-                         rhs_zero_point=relay.const(rhs_zero_point, 'int32'),
-                         output_scale=relay.const(output_scale, 'float32'),
-                         output_zero_point=relay.const(output_zero_point, 'int32'))
+    z = relay.qnn.op.mul(
+        lhs=x,
+        rhs=y,
+        lhs_scale=relay.const(lhs_scale, "float32"),
+        lhs_zero_point=relay.const(lhs_zero_point, "int32"),
+        rhs_scale=relay.const(rhs_scale, "float32"),
+        rhs_zero_point=relay.const(rhs_zero_point, "int32"),
+        output_scale=relay.const(output_scale, "float32"),
+        output_zero_point=relay.const(output_zero_point, "int32"),
+    )
 
     func = relay.Function([x, y], z)
     mod = tvm.IRModule.from_expr(func)
@@ -191,8 +200,7 @@ def test_saturation():
     x_rec = recover(x_data, lhs_scale, lhs_zero_point)
     y_rec = recover(y_data, rhs_scale, rhs_zero_point)
 
-    golden = generate_golden_output(x_rec, y_rec, output_scale,
-        output_zero_point)
+    golden = generate_golden_output(x_rec, y_rec, output_scale, output_zero_point)
 
     intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
     op_res = intrp.evaluate(func)(x_data, y_data)
@@ -204,13 +212,16 @@ def test_saturation():
     rhs_scale = 0.25
     output_scale = 0.125
 
-    z = relay.qnn.op.mul(lhs=x, rhs=y,
-                         lhs_scale=relay.const(lhs_scale, 'float32'),
-                         lhs_zero_point=relay.const(lhs_zero_point, 'int32'),
-                         rhs_scale=relay.const(rhs_scale, 'float32'),
-                         rhs_zero_point=relay.const(rhs_zero_point, 'int32'),
-                         output_scale=relay.const(output_scale, 'float32'),
-                         output_zero_point=relay.const(output_zero_point, 'int32'))
+    z = relay.qnn.op.mul(
+        lhs=x,
+        rhs=y,
+        lhs_scale=relay.const(lhs_scale, "float32"),
+        lhs_zero_point=relay.const(lhs_zero_point, "int32"),
+        rhs_scale=relay.const(rhs_scale, "float32"),
+        rhs_zero_point=relay.const(rhs_zero_point, "int32"),
+        output_scale=relay.const(output_scale, "float32"),
+        output_zero_point=relay.const(output_zero_point, "int32"),
+    )
 
     func = relay.Function([x, y], z)
     mod = tvm.IRModule.from_expr(func)
@@ -223,8 +234,7 @@ def test_saturation():
     x_rec = recover(x_data, lhs_scale, lhs_zero_point)
     y_rec = recover(y_data, rhs_scale, rhs_zero_point)
 
-    golden = generate_golden_output(x_rec, y_rec, output_scale,
-        output_zero_point)
+    golden = generate_golden_output(x_rec, y_rec, output_scale, output_zero_point)
 
     intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
     op_res = intrp.evaluate(func)(x_data, y_data)
index a284e8b..a22c25f 100644 (file)
@@ -21,15 +21,19 @@ import numpy as np
 from tvm import relay
 from tvm.contrib import graph_runtime
 
+
 def quantize_test_driver(in_dtype, quant_args, axis, out_dtype, in_data, verify_output_data):
     shape = in_data.shape
     input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
-    output_zero_point = relay.const(quant_args['out_zero_point'])
-    output_scale = relay.const(quant_args['out_scale'])
-    quantized_output = relay.qnn.op.quantize(input_data, output_scale=output_scale,
-                                             output_zero_point=output_zero_point,
-                                             axis=axis,
-                                             out_dtype=out_dtype)
+    output_zero_point = relay.const(quant_args["out_zero_point"])
+    output_scale = relay.const(quant_args["out_scale"])
+    quantized_output = relay.qnn.op.quantize(
+        input_data,
+        output_scale=output_scale,
+        output_zero_point=output_zero_point,
+        axis=axis,
+        out_dtype=out_dtype,
+    )
     mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
     mod = tvm.IRModule.from_expr(mod)
     with tvm.transform.PassContext(opt_level=3):
@@ -42,51 +46,91 @@ def quantize_test_driver(in_dtype, quant_args, axis, out_dtype, in_data, verify_
         np.testing.assert_equal(res, verify_output_data)
         assert res.dtype == out_dtype
 
+
 def test_float32_to_uint8():
-    data = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
-        .astype('float32') \
-        .reshape((2,5))
-    output = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
-        .astype('uint8') \
-        .reshape((2,5))
-    quant_args = {"out_zero_point":np.int32(127), "out_scale": np.float32(0.5)}
-    quantize_test_driver(in_dtype='float32', quant_args=quant_args, axis=-1, out_dtype='uint8',
-                         in_data=data, verify_output_data=output)
+    data = (
+        np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64])
+        .astype("float32")
+        .reshape((2, 5))
+    )
+    output = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]).astype("uint8").reshape((2, 5))
+    quant_args = {"out_zero_point": np.int32(127), "out_scale": np.float32(0.5)}
+    quantize_test_driver(
+        in_dtype="float32",
+        quant_args=quant_args,
+        axis=-1,
+        out_dtype="uint8",
+        in_data=data,
+        verify_output_data=output,
+    )
+
 
 def test_float32_to_int8():
-    data = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
-        .astype('float32') \
-        .reshape((2,5))
-    output = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \
-        .astype('int8') \
-        .reshape((2,5))
-    quant_args = {"out_zero_point":np.int32(-1), "out_scale":np.float32(0.5)}
-    quantize_test_driver(in_dtype='float32', quant_args=quant_args, axis=-1, out_dtype='int8',
-                         in_data=data, verify_output_data=output)
+    data = (
+        np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64])
+        .astype("float32")
+        .reshape((2, 5))
+    )
+    output = (
+        np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127])
+        .astype("int8")
+        .reshape((2, 5))
+    )
+    quant_args = {"out_zero_point": np.int32(-1), "out_scale": np.float32(0.5)}
+    quantize_test_driver(
+        in_dtype="float32",
+        quant_args=quant_args,
+        axis=-1,
+        out_dtype="int8",
+        in_data=data,
+        verify_output_data=output,
+    )
+
 
 def test_channelwise_axis_0():
-    data = np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32]) \
-        .astype('float32') \
-        .reshape((2,5))
-    output = np.array([0, 1, 2, 3, 4, 243, 247, 249, 250, 251]) \
-        .astype('uint8') \
-        .reshape((2,5))
-    quant_args = {"out_zero_point" : np.array([127, 123]).astype('int32'),
-                  "out_scale"      : np.array([0.5, 0.25]).astype('float32')}
+    data = (
+        np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32])
+        .astype("float32")
+        .reshape((2, 5))
+    )
+    output = np.array([0, 1, 2, 3, 4, 243, 247, 249, 250, 251]).astype("uint8").reshape((2, 5))
+    quant_args = {
+        "out_zero_point": np.array([127, 123]).astype("int32"),
+        "out_scale": np.array([0.5, 0.25]).astype("float32"),
+    }
+
+    quantize_test_driver(
+        in_dtype="float32",
+        quant_args=quant_args,
+        axis=0,
+        out_dtype="uint8",
+        in_data=data,
+        verify_output_data=output,
+    )
 
-    quantize_test_driver(in_dtype='float32', quant_args=quant_args, axis=0, out_dtype='uint8',
-                         in_data=data, verify_output_data=output)
 
 def test_channelwise_axis_1():
-    data = np.transpose(np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32]) \
-                        .astype('float32').reshape((2,5)))
-    output = np.transpose(np.array([0, 1, 2, 3, 4, 243, 247, 249, 250, 251]) \
-                          .astype('uint8').reshape((2,5)))
-    quant_args = {"out_zero_point" : np.array([127, 123]).astype('int32'),
-                  "out_scale"      : np.array([0.5, 0.25]).astype('float32')}
+    data = np.transpose(
+        np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32])
+        .astype("float32")
+        .reshape((2, 5))
+    )
+    output = np.transpose(
+        np.array([0, 1, 2, 3, 4, 243, 247, 249, 250, 251]).astype("uint8").reshape((2, 5))
+    )
+    quant_args = {
+        "out_zero_point": np.array([127, 123]).astype("int32"),
+        "out_scale": np.array([0.5, 0.25]).astype("float32"),
+    }
 
-    quantize_test_driver(in_dtype='float32', quant_args=quant_args, axis=1, out_dtype='uint8',
-                         in_data=data, verify_output_data=output)
+    quantize_test_driver(
+        in_dtype="float32",
+        quant_args=quant_args,
+        axis=1,
+        out_dtype="uint8",
+        in_data=data,
+        verify_output_data=output,
+    )
 
 
 if __name__ == "__main__":
index fb52b30..f152a4e 100644 (file)
@@ -23,79 +23,95 @@ from tvm.contrib import graph_runtime
 
 roundings = ["UPWARD", "TONEAREST"]
 
+
 def verify(mod, goldens):
     with tvm.transform.PassContext(opt_level=3):
         graph, lib, params = relay.build(mod, "llvm", params=None)
         golden_data, golden_output = goldens
         rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
-        rt_mod.set_input("quantized_data",golden_data)
+        rt_mod.set_input("quantized_data", golden_data)
         rt_mod.set_input(**params)
         rt_mod.run()
         res = rt_mod.get_output(0).asnumpy()
         np.testing.assert_equal(res, golden_output)
 
-def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale,
-        input_zero_point=0, output_zero_point=0, rounding="TONEAREST",
-        axis=0):
-    quantized_data = relay.var("quantized_data", shape=data_shape,
-            dtype=data_dtype)
+
+def get_mod(
+    data_shape,
+    data_dtype,
+    out_dtype,
+    input_scale,
+    output_scale,
+    input_zero_point=0,
+    output_zero_point=0,
+    rounding="TONEAREST",
+    axis=0,
+):
+    quantized_data = relay.var("quantized_data", shape=data_shape, dtype=data_dtype)
     if isinstance(input_scale, float):
-        input_scale_expr = relay.const(input_scale, 'float32')
+        input_scale_expr = relay.const(input_scale, "float32")
     else:
-        input_scale_expr = relay.const(np.array(input_scale).astype('float32'))
+        input_scale_expr = relay.const(np.array(input_scale).astype("float32"))
 
     if isinstance(input_zero_point, float):
-        input_zero_point_expr = relay.const(input_zero_point, 'int32')
+        input_zero_point_expr = relay.const(input_zero_point, "int32")
     else:
-        input_zero_point_expr = relay.const(np.array(input_zero_point).astype('int32'))
+        input_zero_point_expr = relay.const(np.array(input_zero_point).astype("int32"))
 
     mod = relay.qnn.op.requantize(
-            quantized_data,
-            input_scale=input_scale_expr,
-            input_zero_point=input_zero_point_expr,
-            output_scale=relay.const(output_scale, 'float32'),
-            output_zero_point=relay.const(output_zero_point, 'int32'),
-            axis=axis,
-            rounding=rounding,
-            out_dtype=out_dtype)
+        quantized_data,
+        input_scale=input_scale_expr,
+        input_zero_point=input_zero_point_expr,
+        output_scale=relay.const(output_scale, "float32"),
+        output_zero_point=relay.const(output_zero_point, "int32"),
+        axis=axis,
+        rounding=rounding,
+        out_dtype=out_dtype,
+    )
 
     mod = relay.Function(relay.analysis.free_vars(mod), mod)
     mod = tvm.IRModule.from_expr(mod)
     return mod
 
+
 def test_same_scale():
     # Have same scales, everything within range
-    golden_data = np.arange(-100, 100, 1).astype('int32')
+    golden_data = np.arange(-100, 100, 1).astype("int32")
     golden_output = golden_data
 
     for rounding in roundings:
-        mod = get_mod(data_shape=(200, ),
-                      data_dtype='int32',
-                      out_dtype="int8",
-                      input_scale=0.5,
-                      output_scale=0.5,
-                      rounding=rounding)
-        assert 'right_shift' not in mod.astext()
+        mod = get_mod(
+            data_shape=(200,),
+            data_dtype="int32",
+            out_dtype="int8",
+            input_scale=0.5,
+            output_scale=0.5,
+            rounding=rounding,
+        )
+        assert "right_shift" not in mod.astext()
         verify(mod, (golden_data, golden_output))
 
+
 def test_downscale():
     for rounding in roundings:
-        mod = get_mod(data_shape=(32, ),
-                      data_dtype='int32',
-                      out_dtype='int8',
-                      input_scale=1,
-                      output_scale=16,
-                      rounding=rounding)
+        mod = get_mod(
+            data_shape=(32,),
+            data_dtype="int32",
+            out_dtype="int8",
+            input_scale=1,
+            output_scale=16,
+            rounding=rounding,
+        )
 
         # Try positive values
         # 8 corresponds to 0.5, resulting in 1
-        golden_data = np.arange(0, 32, 1).astype('int32')
+        golden_data = np.arange(0, 32, 1).astype("int32")
         golden_output = np.repeat([0, 1, 2], [8, 16, 8])
         verify(mod, (golden_data, golden_output))
 
         # Try negative values
         # -8 corresponds to -0.5. For UPWARD, this is 0
-        golden_data = np.arange(0, -32, -1).astype('int32')
+        golden_data = np.arange(0, -32, -1).astype("int32")
         if rounding == "UPWARD":
             golden_output = np.repeat([0, -1, -2], [9, 16, 7])
         else:
@@ -103,124 +119,159 @@ def test_downscale():
         verify(mod, (golden_data, golden_output))
 
         # Try a different scale
-        mod = get_mod(data_shape=(32, ),
-                      data_dtype='int32',
-                      out_dtype="int8",
-                      input_scale=1,
-                      output_scale=4,
-                      rounding=rounding)
+        mod = get_mod(
+            data_shape=(32,),
+            data_dtype="int32",
+            out_dtype="int8",
+            input_scale=1,
+            output_scale=4,
+            rounding=rounding,
+        )
 
         # Try positive values
         # 2 corresponds to 0.5, resulting in 1
-        golden_data = np.arange(0, 32, 1).astype('int32')
-        golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8],
-                                  [2, 4, 4, 4, 4, 4, 4, 4, 2])
+        golden_data = np.arange(0, 32, 1).astype("int32")
+        golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8], [2, 4, 4, 4, 4, 4, 4, 4, 2])
         verify(mod, (golden_data, golden_output))
 
         # Try negative values
         # -8 corresponds to -0.5. For UPWARD, this is 0
-        golden_data = np.arange(0, -32, -1).astype('int32')
+        golden_data = np.arange(0, -32, -1).astype("int32")
         if rounding == "UPWARD":
-            golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
-                                      [3, 4, 4, 4, 4, 4, 4, 4, 1])
+            golden_output = np.repeat(
+                [0, -1, -2, -3, -4, -5, -6, -7, -8], [3, 4, 4, 4, 4, 4, 4, 4, 1]
+            )
         else:
-            golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
-                                      [2, 4, 4, 4, 4, 4, 4, 4, 2])
+            golden_output = np.repeat(
+                [0, -1, -2, -3, -4, -5, -6, -7, -8], [2, 4, 4, 4, 4, 4, 4, 4, 2]
+            )
         verify(mod, (golden_data, golden_output))
 
         # Try uint8 out_dtype
-        mod = get_mod(data_shape=(32, ),
-                      data_dtype='int32',
-                      out_dtype='uint8',
-                      input_scale=1,
-                      output_scale=16,
-                      rounding=rounding)
+        mod = get_mod(
+            data_shape=(32,),
+            data_dtype="int32",
+            out_dtype="uint8",
+            input_scale=1,
+            output_scale=16,
+            rounding=rounding,
+        )
 
         # Try positive values
         # 8 corresponds to 0.5, resulting in 1
-        golden_data = np.arange(0, 32, 1).astype('int32')
+        golden_data = np.arange(0, 32, 1).astype("int32")
         golden_output = np.repeat([0, 1, 2], [8, 16, 8])
         verify(mod, (golden_data, golden_output))
 
         # Try uint8 in_dtype and uint8 out_dtype
-        mod = get_mod(data_shape=(32, ),
-                      data_dtype='uint8',
-                      out_dtype='uint8',
-                      input_scale=1,
-                      output_scale=16,
-                      rounding=rounding)
+        mod = get_mod(
+            data_shape=(32,),
+            data_dtype="uint8",
+            out_dtype="uint8",
+            input_scale=1,
+            output_scale=16,
+            rounding=rounding,
+        )
 
         # Try positive values
         # 8 corresponds to 0.5, resulting in 1
-        golden_data = np.arange(0, 32, 1).astype('int32')
+        golden_data = np.arange(0, 32, 1).astype("int32")
         golden_output = np.repeat([0, 1, 2], [8, 16, 8])
         verify(mod, (golden_data, golden_output))
 
+
 def test_upscale():
     for rounding in roundings:
-        mod = get_mod(data_shape=(32, ),
-                      data_dtype='int32',
-                      out_dtype="int8",
-                      input_scale=2,
-                      output_scale=1,
-                      rounding=rounding)
+        mod = get_mod(
+            data_shape=(32,),
+            data_dtype="int32",
+            out_dtype="int8",
+            input_scale=2,
+            output_scale=1,
+            rounding=rounding,
+        )
 
         # Try positive values
         # Each value is doubled, since input_scale / output_scale = 2
-        golden_data = np.arange(0, 32, 1).astype('int32')
+        golden_data = np.arange(0, 32, 1).astype("int32")
         golden_output = np.multiply(2, golden_data)
         verify(mod, (golden_data, golden_output))
 
         # Try negative values
         # Negative values are doubled as well
-        golden_data = np.arange(0, -32, -1).astype('int32')
+        golden_data = np.arange(0, -32, -1).astype("int32")
         golden_output = np.multiply(2, golden_data)
         verify(mod, (golden_data, golden_output))
 
+
 def test_saturation():
     for rounding in roundings:
-        mod = get_mod(data_shape=(16, ),
-                      data_dtype='int32',
-                      out_dtype="int8",
-                      input_scale=0.5,
-                      output_scale=0.5,
-                      rounding=rounding)
-        golden_data = np.arange(0, 16, 1).astype('int32')
+        mod = get_mod(
+            data_shape=(16,),
+            data_dtype="int32",
+            out_dtype="int8",
+            input_scale=0.5,
+            output_scale=0.5,
+            rounding=rounding,
+        )
+        golden_data = np.arange(0, 16, 1).astype("int32")
         golden_data = np.add(120, golden_data)
-        output = np.array([120, 121, 122, 123, 124, 125, 126, 127,
-                           127, 127, 127, 127, 127, 127, 127, 127])
+        output = np.array(
+            [120, 121, 122, 123, 124, 125, 126, 127, 127, 127, 127, 127, 127, 127, 127, 127]
+        )
         golden_output = output
         verify(mod, (golden_data, golden_output))
 
         # Try negative numbers
-        golden_data = np.arange(0, -16, -1).astype('int32')
+        golden_data = np.arange(0, -16, -1).astype("int32")
         golden_data = np.add(-120, golden_data)
-        output = np.array([-120, -121, -122, -123, -124, -125, -126, -127,
-                           -128, -128, -128, -128, -128, -128, -128, -128])
+        output = np.array(
+            [
+                -120,
+                -121,
+                -122,
+                -123,
+                -124,
+                -125,
+                -126,
+                -127,
+                -128,
+                -128,
+                -128,
+                -128,
+                -128,
+                -128,
+                -128,
+                -128,
+            ]
+        )
         golden_output = output
         verify(mod, (golden_data, golden_output))
 
+
 def test_zero_point():
     # Output zero point
     for rounding in roundings:
-        mod = get_mod(data_shape=(32, ),
-                      data_dtype='int32',
-                      out_dtype='int8',
-                      input_scale=1,
-                      output_scale=16,
-                      output_zero_point=1,
-                      rounding=rounding)
+        mod = get_mod(
+            data_shape=(32,),
+            data_dtype="int32",
+            out_dtype="int8",
+            input_scale=1,
+            output_scale=16,
+            output_zero_point=1,
+            rounding=rounding,
+        )
 
         # Try positive values
         # 8 corresponds to 0.5, resulting in 1
-        golden_data = np.arange(0, 32, 1).astype('int32')
+        golden_data = np.arange(0, 32, 1).astype("int32")
         golden_output = np.repeat([0, 1, 2], [8, 16, 8])
         golden_output = np.add(1, golden_output)
         verify(mod, (golden_data, golden_output))
 
         # Try negative values
         # -40 corresponds to -2.5. For UPWARD, this rounds to -2
-        golden_data = np.arange(-32, -64, -1).astype('int32')
+        golden_data = np.arange(-32, -64, -1).astype("int32")
         if rounding == "UPWARD":
             golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
         else:
@@ -230,22 +281,24 @@ def test_zero_point():
 
     # Input zero point
     for rounding in roundings:
-        mod = get_mod(data_shape=(32, ),
-                      data_dtype='int32',
-                      out_dtype='int8',
-                      input_scale=1,
-                      output_scale=16,
-                      input_zero_point=16,
-                      rounding=rounding)
+        mod = get_mod(
+            data_shape=(32,),
+            data_dtype="int32",
+            out_dtype="int8",
+            input_scale=1,
+            output_scale=16,
+            input_zero_point=16,
+            rounding=rounding,
+        )
 
         # Try positive values
-        golden_data = np.arange(32, 64, 1).astype('int32')
+        golden_data = np.arange(32, 64, 1).astype("int32")
         golden_output = np.repeat([2, 3, 4], [8, 16, 8])
         golden_output = np.subtract(golden_output, 1)
         verify(mod, (golden_data, golden_output))
 
         # Try negative values
-        golden_data = np.arange(-32, -64, -1).astype('int32')
+        golden_data = np.arange(-32, -64, -1).astype("int32")
         if rounding == "UPWARD":
             golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
         else:
@@ -253,77 +306,90 @@ def test_zero_point():
         golden_output = np.subtract(golden_output, 1)
         verify(mod, (golden_data, golden_output))
 
+
 def test_per_channel_same_scale():
     # Have same scales, everything within range
-    golden_data = np.arange(-5, 5, 1).astype('int32').reshape((5,2))
+    golden_data = np.arange(-5, 5, 1).astype("int32").reshape((5, 2))
     golden_output = golden_data
 
     for rounding in roundings:
-        mod = get_mod(data_shape=(5, 2),
-                      data_dtype='int32',
-                      out_dtype="int8",
-                      input_scale=[0.5, 0.5],
-                      output_scale=0.5,
-                      axis=1,
-                      rounding=rounding)
+        mod = get_mod(
+            data_shape=(5, 2),
+            data_dtype="int32",
+            out_dtype="int8",
+            input_scale=[0.5, 0.5],
+            output_scale=0.5,
+            axis=1,
+            rounding=rounding,
+        )
         verify(mod, (golden_data, golden_output))
 
     # Change axis
-    golden_data = np.arange(-10, 10, 1).astype('int32').reshape((2,2,5))
+    golden_data = np.arange(-10, 10, 1).astype("int32").reshape((2, 2, 5))
     golden_output = golden_data
 
     for rounding in roundings:
-        mod = get_mod(data_shape=(2, 2, 5),
-                      data_dtype='int32',
-                      out_dtype="int8",
-                      input_scale=[0.5, 0.5],
-                      output_scale=0.5,
-                      axis=1,
-                      rounding=rounding)
+        mod = get_mod(
+            data_shape=(2, 2, 5),
+            data_dtype="int32",
+            out_dtype="int8",
+            input_scale=[0.5, 0.5],
+            output_scale=0.5,
+            axis=1,
+            rounding=rounding,
+        )
         verify(mod, (golden_data, golden_output))
 
+
 def test_per_channel_different_scale():
     # Input scales differ across channels, everything within range
-    golden_data = np.arange(-5, 5, 1).astype('int32').reshape((5,2))
+    golden_data = np.arange(-5, 5, 1).astype("int32").reshape((5, 2))
     golden_output = np.array([-5, -2, -3, -1, -1, 0, 1, 1, 3, 2]).reshape((5, 2))
 
     for rounding in roundings:
-        mod = get_mod(data_shape=(5, 2),
-                      data_dtype='int32',
-                      out_dtype="int8",
-                      input_scale=[0.5, 0.25],
-                      output_scale=0.5,
-                      axis=1,
-                      rounding=rounding)
+        mod = get_mod(
+            data_shape=(5, 2),
+            data_dtype="int32",
+            out_dtype="int8",
+            input_scale=[0.5, 0.25],
+            output_scale=0.5,
+            axis=1,
+            rounding=rounding,
+        )
         verify(mod, (golden_data, golden_output))
 
     # Change axis
-    golden_data = np.arange(-20, 20, 2).astype('int32').reshape((2,2,5))
-    golden_output = np.array([-20, -18, -16, -14, -12, -5, -4, -3, -2, -1, 0, 2, 4, 6, 8, 5, 6, 7,
-        8, 9]).reshape((2, 2, 5))
+    golden_data = np.arange(-20, 20, 2).astype("int32").reshape((2, 2, 5))
+    golden_output = np.array(
+        [-20, -18, -16, -14, -12, -5, -4, -3, -2, -1, 0, 2, 4, 6, 8, 5, 6, 7, 8, 9]
+    ).reshape((2, 2, 5))
 
     for rounding in roundings:
-        mod = get_mod(data_shape=(2, 2, 5),
-                      data_dtype='int32',
-                      out_dtype="int8",
-                      input_scale=[0.5, 0.25],
-                      output_scale=0.5,
-                      axis=1,
-                      rounding=rounding)
+        mod = get_mod(
+            data_shape=(2, 2, 5),
+            data_dtype="int32",
+            out_dtype="int8",
+            input_scale=[0.5, 0.25],
+            output_scale=0.5,
+            axis=1,
+            rounding=rounding,
+        )
         verify(mod, (golden_data, golden_output))
 
     # Have input scale > output scale
-    golden_data = np.arange(-5, 5, 1).astype('int32').reshape((5,2))
+    golden_data = np.arange(-5, 5, 1).astype("int32").reshape((5, 2))
     golden_output = np.array([-10, -2, -6, -1, -2, 0, 2, 1, 6, 2]).reshape((5, 2))
 
     for rounding in roundings:
-        mod = get_mod(data_shape=(5, 2),
-                      data_dtype='int32',
-                      out_dtype="int8",
-                      input_scale=[1.0, 0.25],
-                      output_scale=0.5,
-                      axis=1,
-                      rounding=rounding)
+        mod = get_mod(
+            data_shape=(5, 2),
+            data_dtype="int32",
+            out_dtype="int8",
+            input_scale=[1.0, 0.25],
+            output_scale=0.5,
+            axis=1,
+            rounding=rounding,
+        )
         verify(mod, (golden_data, golden_output))
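
The golden outputs in test_downscale above encode a simple fixed-point rule that is easy to cross-check outside TVM. Below is a minimal NumPy sketch, assuming requantize computes round(data * input_scale / output_scale); the helper name requantize_ref is ours (not a TVM API), zero points and output-dtype saturation are omitted, and the non-UPWARD mode is modeled as ties-away-from-zero, which is what the non-UPWARD golden data above follow.

import numpy as np

def requantize_ref(data, input_scale, output_scale, rounding="UPWARD"):
    # Hypothetical reference helper (not the TVM kernel): rescale, then round.
    scaled = data.astype("float64") * (input_scale / output_scale)
    if rounding == "UPWARD":
        # Ties (x.5) round toward +infinity: -0.5 -> 0, 0.5 -> 1.
        return np.floor(scaled + 0.5).astype("int32")
    # Otherwise: round to nearest with ties away from zero: -0.5 -> -1, 0.5 -> 1.
    return np.where(scaled >= 0, np.floor(scaled + 0.5), np.ceil(scaled - 0.5)).astype("int32")

x = np.arange(0, 32, 1)
print(requantize_ref(x, 1, 16))                    # np.repeat([0, 1, 2], [8, 16, 8])
print(requantize_ref(-x, 1, 16))                   # np.repeat([0, -1, -2], [9, 16, 7])  (UPWARD)
print(requantize_ref(-x, 1, 16, rounding="AWAY"))  # np.repeat([0, -1, -2], [8, 16, 8])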
 
 
index be39803..6a1501c 100644 (file)
@@ -20,27 +20,29 @@ import numpy as np
 from tvm import relay
 
 
-def qnn_subtract_driver(x_datas, y_datas, golden_outputs,
-                        scale_and_zp, data_dtype='uint8'):
+def qnn_subtract_driver(x_datas, y_datas, golden_outputs, scale_and_zp, data_dtype="uint8"):
     # all x, y and golden outputs should be of the same length
     assert len(x_datas) == len(y_datas)
     assert len(y_datas) == len(golden_outputs)
 
     x = relay.var("x", shape=(1, 4), dtype=data_dtype)
     y = relay.var("y", shape=(1, 4), dtype=data_dtype)
-    lhs_scale = relay.const(scale_and_zp['lhs_scale'], 'float32')
-    lhs_zp = relay.const(scale_and_zp['lhs_zp'], 'int32')
-    rhs_scale = relay.const(scale_and_zp['rhs_scale'], 'float32')
-    rhs_zp = relay.const(scale_and_zp['rhs_zp'], 'int32')
-    output_scale = relay.const(scale_and_zp['output_scale'], 'float32')
-    output_zp = relay.const(scale_and_zp['output_zp'], 'int32')
-    z = relay.qnn.op.subtract(lhs=x, rhs=y,
-                              lhs_scale=lhs_scale,
-                              lhs_zero_point=lhs_zp,
-                              rhs_scale=rhs_scale,
-                              rhs_zero_point=rhs_zp,
-                              output_scale=output_scale,
-                              output_zero_point=output_zp)
+    lhs_scale = relay.const(scale_and_zp["lhs_scale"], "float32")
+    lhs_zp = relay.const(scale_and_zp["lhs_zp"], "int32")
+    rhs_scale = relay.const(scale_and_zp["rhs_scale"], "float32")
+    rhs_zp = relay.const(scale_and_zp["rhs_zp"], "int32")
+    output_scale = relay.const(scale_and_zp["output_scale"], "float32")
+    output_zp = relay.const(scale_and_zp["output_zp"], "int32")
+    z = relay.qnn.op.subtract(
+        lhs=x,
+        rhs=y,
+        lhs_scale=lhs_scale,
+        lhs_zero_point=lhs_zp,
+        rhs_scale=rhs_scale,
+        rhs_zero_point=rhs_zp,
+        output_scale=output_scale,
+        output_zero_point=output_zp,
+    )
     func = relay.Function([x, y], z)
     mod = tvm.IRModule.from_expr(func)
     mod = relay.qnn.transform.CanonicalizeOps()(mod)
@@ -55,82 +57,104 @@ def qnn_subtract_driver(x_datas, y_datas, golden_outputs,
 
 
 def test_tflite_same_io_qnn_params():
-    scale_and_zp = {'lhs_scale': 0.00784314,
-                    'lhs_zp': 127,
-                    'rhs_scale': 0.00784314,
-                    'rhs_zp': 127,
-                    'output_scale': 0.00784314,
-                    'output_zp': 127}
-    x_datas = [np.array((140, 153, 165, 178)).reshape((1, 4)),
-               np.array((25, 153, 178, 216)).reshape((1, 4)),
-               np.array((25, 153, 216, 165)).reshape((1, 4))]
-    y_datas = [np.array((204, 178, 165, 140)).reshape((1, 4)),
-               np.array((204, 178, 191, 25)).reshape((1, 4)),
-               np.array((204, 178, 25, 191)).reshape((1, 4))]
-    golden_outputs = [np.array((63, 102, 127, 165)).reshape((1, 4)),
-                      np.array((0, 102, 114, 255)).reshape((1, 4)),
-                      np.array((0, 102, 255, 101)).reshape((1, 4))]
+    scale_and_zp = {
+        "lhs_scale": 0.00784314,
+        "lhs_zp": 127,
+        "rhs_scale": 0.00784314,
+        "rhs_zp": 127,
+        "output_scale": 0.00784314,
+        "output_zp": 127,
+    }
+    x_datas = [
+        np.array((140, 153, 165, 178)).reshape((1, 4)),
+        np.array((25, 153, 178, 216)).reshape((1, 4)),
+        np.array((25, 153, 216, 165)).reshape((1, 4)),
+    ]
+    y_datas = [
+        np.array((204, 178, 165, 140)).reshape((1, 4)),
+        np.array((204, 178, 191, 25)).reshape((1, 4)),
+        np.array((204, 178, 25, 191)).reshape((1, 4)),
+    ]
+    golden_outputs = [
+        np.array((63, 102, 127, 165)).reshape((1, 4)),
+        np.array((0, 102, 114, 255)).reshape((1, 4)),
+        np.array((0, 102, 255, 101)).reshape((1, 4)),
+    ]
     qnn_subtract_driver(x_datas, y_datas, golden_outputs, scale_and_zp)
 
 
 def test_tflite_different_io_qnn_params():
-    scale_and_zp = {'lhs_scale': 0.0156863,
-                    'lhs_zp': 127,
-                    'rhs_scale': 0.0117647,
-                    'rhs_zp': 85,
-                    'output_scale': 0.0235294,
-                    'output_zp': 128}
-    x_datas = [np.array((76, 140, 153, 172)).reshape((1, 4)),
-               np.array((133, 140, 146, 153)).reshape((1, 4)),
-               np.array((76, 140, 172, 146)).reshape((1, 4))]
-    y_datas = [np.array((136, 119, 128, 17)).reshape((1, 4)),
-               np.array((136, 119, 111, 94)).reshape((1, 4)),
-               np.array((136, 119, 17, 128)).reshape((1, 4))]
-    golden_outputs = [np.array((68, 120, 123, 192)).reshape((1, 4)),
-                      np.array((106, 120, 128, 140)).reshape((1, 4)),
-                      np.array((68, 120, 192, 119)).reshape((1, 4))]
+    scale_and_zp = {
+        "lhs_scale": 0.0156863,
+        "lhs_zp": 127,
+        "rhs_scale": 0.0117647,
+        "rhs_zp": 85,
+        "output_scale": 0.0235294,
+        "output_zp": 128,
+    }
+    x_datas = [
+        np.array((76, 140, 153, 172)).reshape((1, 4)),
+        np.array((133, 140, 146, 153)).reshape((1, 4)),
+        np.array((76, 140, 172, 146)).reshape((1, 4)),
+    ]
+    y_datas = [
+        np.array((136, 119, 128, 17)).reshape((1, 4)),
+        np.array((136, 119, 111, 94)).reshape((1, 4)),
+        np.array((136, 119, 17, 128)).reshape((1, 4)),
+    ]
+    golden_outputs = [
+        np.array((68, 120, 123, 192)).reshape((1, 4)),
+        np.array((106, 120, 128, 140)).reshape((1, 4)),
+        np.array((68, 120, 192, 119)).reshape((1, 4)),
+    ]
     qnn_subtract_driver(x_datas, y_datas, golden_outputs, scale_and_zp)
 
 
 def test_saturation():
     # Same params
-    scale_and_zp = {'lhs_scale': 0.125,
-                    'lhs_zp': 0,
-                    'rhs_scale': 0.125,
-                    'rhs_zp': 0,
-                    'output_scale': 0.125,
-                    'output_zp': 0}
+    scale_and_zp = {
+        "lhs_scale": 0.125,
+        "lhs_zp": 0,
+        "rhs_scale": 0.125,
+        "rhs_zp": 0,
+        "output_scale": 0.125,
+        "output_zp": 0,
+    }
     x_data = [np.array((255, 1, 1, 0)).reshape((1, 4))]
     y_data = [np.array((255, 255, 128, 0)).reshape((1, 4))]
     golden_output = [np.array((0, 0, 0, 0)).reshape((1, 4))]
     qnn_subtract_driver(x_data, y_data, golden_output, scale_and_zp)
 
     # Same params, different scale
-    scale_and_zp = {'lhs_scale': 0.125,
-                    'lhs_zp': 0,
-                    'rhs_scale': 0.125,
-                    'rhs_zp': 0,
-                    'output_scale': 0.25,
-                    'output_zp': 0}
+    scale_and_zp = {
+        "lhs_scale": 0.125,
+        "lhs_zp": 0,
+        "rhs_scale": 0.125,
+        "rhs_zp": 0,
+        "output_scale": 0.25,
+        "output_zp": 0,
+    }
     x_data = [np.array((255, 1, 200, 0)).reshape((1, 4))]
     y_data = [np.array((255, 255, 127, 0)).reshape((1, 4))]
     golden_output = [np.array((0, 0, 36, 0)).reshape((1, 4))]
     qnn_subtract_driver(x_data, y_data, golden_output, scale_and_zp)
 
     # All params different
-    scale_and_zp = {'lhs_scale': 0.5,
-                    'lhs_zp': 0,
-                    'rhs_scale': 0.25,
-                    'rhs_zp': 0,
-                    'output_scale': 0.125,
-                    'output_zp': 0}
+    scale_and_zp = {
+        "lhs_scale": 0.5,
+        "lhs_zp": 0,
+        "rhs_scale": 0.25,
+        "rhs_zp": 0,
+        "output_scale": 0.125,
+        "output_zp": 0,
+    }
     x_data = [np.array((255, 0, 1, 0)).reshape((1, 4))]
     y_data = [np.array((0, 128, 64, 0)).reshape((1, 4))]
     golden_output = [np.array((255, 0, 0, 0)).reshape((1, 4))]
     qnn_subtract_driver(x_data, y_data, golden_output, scale_and_zp)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_tflite_same_io_qnn_params()
     test_tflite_different_io_qnn_params()
     test_saturation()
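
As a cross-check on the golden rows above, the uint8 arithmetic can be reproduced directly in NumPy. This is a sketch under the assumption that qnn.subtract dequantizes both operands, subtracts, then requantizes with saturation to the uint8 range; qnn_subtract_ref is our own name, not a TVM API.

import numpy as np

def qnn_subtract_ref(x, y, p):
    # Dequantize both operands, subtract in real numbers, then requantize.
    real = (x.astype("float64") - p["lhs_zp"]) * p["lhs_scale"] - (
        y.astype("float64") - p["rhs_zp"]
    ) * p["rhs_scale"]
    q = np.round(real / p["output_scale"]) + p["output_zp"]
    return np.clip(q, 0, 255).astype("uint8")  # saturate to uint8

p = {"lhs_scale": 0.00784314, "lhs_zp": 127, "rhs_scale": 0.00784314, "rhs_zp": 127,
     "output_scale": 0.00784314, "output_zp": 127}
x = np.array((140, 153, 165, 178))
y = np.array((204, 178, 165, 140))
print(qnn_subtract_ref(x, y, p))  # [ 63 102 127 165], the first golden row of test_tflite_same_io_qnn_params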
index 497a818..68e4b50 100644 (file)
@@ -43,14 +43,14 @@ def test_ndarray_reflection():
     # Make two `NDArrayWrapper`s that point to the same underlying array.
     np_array = np.random.uniform(size=(10, 2)).astype("float32")
     tvm_array = tvm.nd.array(np_array)
-    param_dict = {'x': tvm_array, 'y': tvm_array}
-    assert param_dict['x'].same_as(param_dict['y'])
+    param_dict = {"x": tvm_array, "y": tvm_array}
+    assert param_dict["x"].same_as(param_dict["y"])
     # Serialize then deserialize `param_dict`.
     deser_param_dict = relay.load_param_dict(relay.save_param_dict(param_dict))
     # Make sure the data matches the original data and `x` and `y` contain the same data.
-    np.testing.assert_equal(deser_param_dict['x'].asnumpy(), tvm_array.asnumpy())
+    np.testing.assert_equal(deser_param_dict["x"].asnumpy(), tvm_array.asnumpy())
     # Make sure `x` and `y` contain the same data.
-    np.testing.assert_equal(deser_param_dict['x'].asnumpy(), deser_param_dict['y'].asnumpy())
+    np.testing.assert_equal(deser_param_dict["x"].asnumpy(), deser_param_dict["y"].asnumpy())
 
 
 def test_bigendian_rpc_param():
@@ -61,13 +61,13 @@ def test_bigendian_rpc_param():
         return
 
     def verify_graph_runtime(remote, target, shape, dtype):
-        x = relay.var('x')
+        x = relay.var("x")
         y = relay.const(1)
         z = relay.add(x, y)
         func = relay.Function([x], z)
 
         x_in = np.ones(shape).astype(dtype)
-        params = {'x': x_in}
+        params = {"x": x_in}
         graph, lib, params = relay.build(func, target=target, params=params)
 
         temp = util.tempdir()
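
The round trip checked in test_ndarray_reflection above is the plain parameter-serialization API: relay.save_param_dict returns a serialized byte blob and relay.load_param_dict rebuilds the NDArray dict from it. A minimal standalone usage sketch (outside the test harness; shapes chosen arbitrarily):

import numpy as np
import tvm
from tvm import relay

arr = tvm.nd.array(np.random.uniform(size=(10, 2)).astype("float32"))
blob = relay.save_param_dict({"x": arr, "y": arr})  # serialize to a byte blob
params = relay.load_param_dict(blob)                # back to a dict of NDArrays
np.testing.assert_equal(params["x"].asnumpy(), arr.asnumpy())
np.testing.assert_equal(params["x"].asnumpy(), params["y"].asnumpy())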
index 2a1b983..7b242c4 100644 (file)
@@ -25,6 +25,7 @@ from tvm.relay.testing import run_infer_type
 import numpy as np
 import tvm.testing
 
+
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
     mod = tvm.IRModule.from_expr(expr)
@@ -34,15 +35,14 @@ def run_opt_pass(expr, passes):
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
+
 def test_alter_op():
     """Test directly replacing an operator with a new one"""
+
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight = relay.var('weight', shape=(64, 64, 3, 3))
-        y = relay.nn.conv2d(x, weight,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
         y = relay.nn.relu(y)
         y = relay.Function([x, weight], y)
         return y
@@ -54,11 +54,14 @@ def test_alter_op():
 
     def expected():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight = relay.var('weight', shape=(64, 64, 3, 3))
-        y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")),
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(
+            x,
+            relay.multiply(weight, relay.const(2.0, "float32")),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+        )
         y = relay.nn.relu(y)
         y = relay.Function([x, weight], y)
         return y
@@ -73,6 +76,7 @@ def test_alter_op():
 
 def test_alter_return_none():
     """Test doing nothing by returning 'None' """
+
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))
         y = relay.nn.global_max_pool2d(x)
@@ -91,13 +95,14 @@ def test_alter_return_none():
         b = run_opt_pass(before(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
-    assert(called[0])
+    assert called[0]
 
 
 def test_alter_layout():
     """Test alternating the layout of a conv2d.
     The layout of broadcast operators and the weight should be changed accordingly.
     """
+
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))
         bias = relay.var("bias")
@@ -108,7 +113,7 @@ def test_alter_layout():
         y = relay.Tuple([y])[0]
         y = relay.nn.relu(y)
         y = relay.nn.max_pool2d(y, pool_size=(2, 2))
-        y = relay.cast(y, 'int32')
+        y = relay.cast(y, "int32")
         y = relay.nn.batch_flatten(y)
         y = relay.Function(analysis.free_vars(y), y)
         return y
@@ -116,11 +121,10 @@ def test_alter_layout():
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
-        new_attrs['data_layout'] = 'NCHW16c'
-        new_attrs['kernel_layout'] = 'OIHW16i'
+        new_attrs["data_layout"] = "NCHW16c"
+        new_attrs["kernel_layout"] = "OIHW16i"
         return relay.nn.conv2d(data, weight, **new_attrs)
 
-
     def expected():
         x = relay.var("x", shape=(1, 64, 56, 56))
         bias = relay.var("bias", shape=(64,))
@@ -128,12 +132,15 @@ def test_alter_layout():
 
         y = relay.layout_transform(x, "NCHW", "NCHW16c")
         w = relay.layout_transform(weight, "OIHW", "OIHW16i")
-        y = relay.nn.conv2d(y, w,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            kernel_layout="OIHW16i",
-                            data_layout="NCHW16c")
+        y = relay.nn.conv2d(
+            y,
+            w,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            kernel_layout="OIHW16i",
+            data_layout="NCHW16c",
+        )
         b = relay.expand_dims(bias, axis=1, num_newaxis=2)
         b = relay.expand_dims(b, axis=0, num_newaxis=1)
         b = relay.layout_transform(b, "NCHW", "NCHW16c")
@@ -141,7 +148,7 @@ def test_alter_layout():
 
         y = relay.nn.relu(y)
         y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NCHW16c")
-        y = relay.cast(y, 'int32')
+        y = relay.cast(y, "int32")
         y = relay.layout_transform(y, "NCHW16c", "NCHW")
         y = relay.nn.batch_flatten(y)
         y = relay.Function(analysis.free_vars(y), y)
@@ -149,16 +156,17 @@ def test_alter_layout():
 
     with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
         a = before()
-        a = run_opt_pass(a, [transform.CanonicalizeOps(),
-                             transform.AlterOpLayout()])
+        a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()])
         b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
+
 def test_alter_layout_lrn():
     """Test alternating the layout of a conv2d.
     The layout of broadcast operators and the weight should be changed accordingly.
     """
+
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))
         bias = relay.var("bias")
@@ -172,11 +180,10 @@ def test_alter_layout_lrn():
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
-        new_attrs['data_layout'] = 'NCHW16c'
-        new_attrs['kernel_layout'] = 'OIHW16i'
+        new_attrs["data_layout"] = "NCHW16c"
+        new_attrs["kernel_layout"] = "OIHW16i"
         return relay.nn.conv2d(data, weight, **new_attrs)
 
-
     def expected():
         x = relay.var("x", shape=(1, 64, 56, 56))
         bias = relay.var("bias", shape=(64,))
@@ -184,12 +191,15 @@ def test_alter_layout_lrn():
 
         y = relay.layout_transform(x, "NCHW", "NCHW16c")
         w = relay.layout_transform(weight, "OIHW", "OIHW16i")
-        y = relay.nn.conv2d(y, w,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            kernel_layout="OIHW16i",
-                            data_layout="NCHW16c")
+        y = relay.nn.conv2d(
+            y,
+            w,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            kernel_layout="OIHW16i",
+            data_layout="NCHW16c",
+        )
         y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NCHW16c")
         y = relay.layout_transform(y, "NCHW16c", "NCHW")
         y = relay.nn.lrn(y)
@@ -198,32 +208,25 @@ def test_alter_layout_lrn():
 
     with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
         a = before()
-        a = run_opt_pass(a, [transform.CanonicalizeOps(),
-                             transform.AlterOpLayout()])
+        a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()])
         b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
-
 def test_alter_layout_dual_path():
     """
     Test altering the layout with two outputs.
     One path continues to use the new layout while the other falls back to the old layout.
     """
+
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight1 = relay.var('weight1')
-        weight2 = relay.var('weight2')
-        y = relay.nn.conv2d(x, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        weight1 = relay.var("weight1")
+        weight2 = relay.var("weight2")
+        y = relay.nn.conv2d(x, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1))
         y = relay.nn.relu(y)
-        y1 = relay.nn.conv2d(y, weight2,
-                             channels=32,
-                             kernel_size=(3, 3),
-                             padding=(1, 1))
+        y1 = relay.nn.conv2d(y, weight2, channels=32, kernel_size=(3, 3), padding=(1, 1))
         y1 = relay.nn.relu(y1)
         y2 = relay.nn.batch_flatten(y)
         ret = relay.Tuple([y1, y2])
@@ -233,26 +236,21 @@ def test_alter_layout_dual_path():
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
-        new_attrs['data_layout'] = 'NCHW16c'
+        new_attrs["data_layout"] = "NCHW16c"
         return relay.nn.conv2d(data, weight, **new_attrs)
 
-
     def expected():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight1 = relay.var('weight1')
-        weight2 = relay.var('weight2')
+        weight1 = relay.var("weight1")
+        weight2 = relay.var("weight2")
         y = relay.layout_transform(x, "NCHW", "NCHW16c")
-        y = relay.nn.conv2d(y, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW16c")
+        y = relay.nn.conv2d(
+            y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
         y = relay.nn.relu(y)
-        y1 = relay.nn.conv2d(y, weight2,
-                             channels=32,
-                             kernel_size=(3, 3),
-                             padding=(1, 1),
-                             data_layout='NCHW16c')
+        y1 = relay.nn.conv2d(
+            y, weight2, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
         y1 = relay.nn.relu(y1)
         y1 = relay.layout_transform(y1, "NCHW16c", "NCHW")
         y2 = relay.layout_transform(y, "NCHW16c", "NCHW")
@@ -268,23 +266,20 @@ def test_alter_layout_dual_path():
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
+
 def test_alter_layout_resnet():
     """Test alternating the layout of a residual block
     This also tests the elimination of duplicated transformation.
     If a same transformation applies to a same node twice, only one transformation will be created.
     """
+
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight1 = relay.var('weight1')
-        weight2 = relay.var('weight2')
-        y = relay.nn.conv2d(x, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        weight1 = relay.var("weight1")
+        weight2 = relay.var("weight2")
+        y = relay.nn.conv2d(x, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1))
         y = relay.nn.relu(y)
-        y2 = relay.nn.conv2d(x, weight2,
-                             channels=32,
-                             kernel_size=(1, 1))
+        y2 = relay.nn.conv2d(x, weight2, channels=32, kernel_size=(1, 1))
         y2 = relay.nn.relu(y2)
         y = y + y2
         y = relay.nn.global_max_pool2d(y)
@@ -293,25 +288,19 @@ def test_alter_layout_resnet():
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
-        new_attrs['data_layout'] = 'NCHW16c'
+        new_attrs["data_layout"] = "NCHW16c"
         return relay.nn.conv2d(data, weight, **new_attrs)
 
-
     def expected():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight1 = relay.var('weight1')
-        weight2 = relay.var('weight2')
+        weight1 = relay.var("weight1")
+        weight2 = relay.var("weight2")
         x = relay.layout_transform(x, "NCHW", "NCHW16c")
-        y = relay.nn.conv2d(x, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW16c")
+        y = relay.nn.conv2d(
+            x, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
         y = relay.nn.relu(y)
-        y2 = relay.nn.conv2d(x, weight2,
-                             channels=32,
-                             kernel_size=(1, 1),
-                             data_layout='NCHW16c')
+        y2 = relay.nn.conv2d(x, weight2, channels=32, kernel_size=(1, 1), data_layout="NCHW16c")
         y2 = relay.nn.relu(y2)
         y = y + y2
         y = relay.nn.global_max_pool2d(y, layout="NCHW16c")
@@ -328,21 +317,22 @@ def test_alter_layout_resnet():
 
 def test_alter_layout_broadcast_op():
     """Test boradcast operators """
+
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))
         bias = relay.var("bias", shape=(64,))
         scale = relay.var("scale", shape=(64, 1, 1))
         weight = relay.var("weight")
         y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
-        y = relay.nn.bias_add(y, bias) # test broadcasting to lhs
-        y = relay.multiply(scale, y)         # test broadcasting to rhs
+        y = relay.nn.bias_add(y, bias)  # test broadcasting to lhs
+        y = relay.multiply(scale, y)  # test broadcasting to rhs
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
-        new_attrs['data_layout'] = 'NCHW16c'
+        new_attrs["data_layout"] = "NCHW16c"
         return relay.nn.conv2d(data, weight, **new_attrs)
 
     def expected():
@@ -356,18 +346,18 @@ def test_alter_layout_broadcast_op():
         bias = relay.layout_transform(bias, "NCHW", "NCHW16c")
         scale = relay.expand_dims(scale, 0, 1)
         scale = relay.layout_transform(scale, "NCHW", "NCHW16c")
-        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
-                            data_layout="NCHW16c")
-        y = relay.add(y, bias)          # test broadcasting to lhs
-        y = relay.multiply(scale, y)      # test broadcasting to rhs
+        y = relay.nn.conv2d(
+            x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
+        y = relay.add(y, bias)  # test broadcasting to lhs
+        y = relay.multiply(scale, y)  # test broadcasting to rhs
         y = relay.layout_transform(y, "NCHW16c", "NCHW")
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
     with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
         a = before()
-        a = run_opt_pass(a, [transform.CanonicalizeOps(),
-                             transform.AlterOpLayout()])
+        a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()])
         b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -377,17 +367,15 @@ def test_alter_layout_broadcast_scalar_op():
     """Test alternating the layout of a conv2d.
     The layout of broadcast operators and the weight should be changed accordingly.
     """
+
     def before():
         x = relay.var("x", shape=(1, 500, 500, 64))
-        kernel = relay.var('kernel', shape=(3, 3, 64, 64), dtype='float32')
+        kernel = relay.var("kernel", shape=(3, 3, 64, 64), dtype="float32")
         bias = relay.var("bias", shape=(64,))
-        multiplier1 = relay.var('multiplier1', shape=(1, ), dtype='float32')
-        multiplier2 = relay.var('multiplier2', shape=(1, 1), dtype='float32')
+        multiplier1 = relay.var("multiplier1", shape=(1,), dtype="float32")
+        multiplier2 = relay.var("multiplier2", shape=(1, 1), dtype="float32")
 
-        y = relay.nn.conv2d(x, kernel,
-                            data_layout='NHWC',
-                            kernel_layout="HWIO",
-                            kernel_size=(3, 3))
+        y = relay.nn.conv2d(x, kernel, data_layout="NHWC", kernel_layout="HWIO", kernel_size=(3, 3))
         y = relay.add(bias, y)
         y = relay.nn.relu(y)
 
@@ -399,24 +387,23 @@ def test_alter_layout_broadcast_scalar_op():
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
-        new_attrs['data_layout'] = 'NCHW16c'
+        new_attrs["data_layout"] = "NCHW16c"
         return relay.nn.conv2d(data, weight, **new_attrs)
 
     def expected():
         x = relay.var("x", shape=(1, 500, 500, 64))
-        kernel = relay.var('kernel', shape=(3, 3, 64, 64), dtype='float32')
+        kernel = relay.var("kernel", shape=(3, 3, 64, 64), dtype="float32")
         bias = relay.var("bias", shape=(64,))
-        multiplier1 = relay.var('multiplier1', shape=(1, ), dtype='float32')
-        multiplier2 = relay.var('multiplier2', shape=(1, 1), dtype='float32')
+        multiplier1 = relay.var("multiplier1", shape=(1,), dtype="float32")
+        multiplier2 = relay.var("multiplier2", shape=(1, 1), dtype="float32")
 
         b = relay.expand_dims(bias, axis=0, num_newaxis=3)
         b = relay.layout_transform(b, "NHWC", "NCHW16c")
 
         y = relay.layout_transform(x, "NHWC", "NCHW16c")
-        y = relay.nn.conv2d(y, kernel,
-                            data_layout='NCHW16c',
-                            kernel_layout="HWIO",
-                            kernel_size=(3, 3))
+        y = relay.nn.conv2d(
+            y, kernel, data_layout="NCHW16c", kernel_layout="HWIO", kernel_size=(3, 3)
+        )
 
         y = relay.add(b, y)
         y = relay.nn.relu(y)
@@ -429,8 +416,7 @@ def test_alter_layout_broadcast_scalar_op():
 
     with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
         a = before()
-        a = run_opt_pass(a, [transform.CanonicalizeOps(),
-                             transform.AlterOpLayout()])
+        a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()])
         b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -440,6 +426,7 @@ def test_alter_layout_scalar():
     """Test alternating the layout of a conv2d.
     The layout of broadcast operators and the weight should be changed accordingly.
     """
+
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))
         weight = relay.var("weight")
@@ -451,7 +438,7 @@ def test_alter_layout_scalar():
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
-        new_attrs['data_layout'] = 'NCHW16c'
+        new_attrs["data_layout"] = "NCHW16c"
         return relay.nn.conv2d(data, weight, **new_attrs)
 
     def expected():
@@ -459,11 +446,9 @@ def test_alter_layout_scalar():
         w = relay.var("weight")
 
         y = relay.layout_transform(x, "NCHW", "NCHW16c")
-        y = relay.nn.conv2d(y, w,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW16c")
+        y = relay.nn.conv2d(
+            y, w, channels=64, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
         y = relay.add(y, relay.const(1.0, "float32"))
 
         y = relay.layout_transform(y, "NCHW16c", "NCHW")
@@ -472,8 +457,7 @@ def test_alter_layout_scalar():
 
     with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
         a = before()
-        a = run_opt_pass(a, [transform.CanonicalizeOps(),
-                             transform.AlterOpLayout()])
+        a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()])
         b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -481,45 +465,35 @@ def test_alter_layout_scalar():
 
 def test_alter_layout_concatenate():
     """ NCHW, NHWC and corner case concatenate layout transform."""
+
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
-        new_attrs['data_layout'] = 'NCHW16c'
+        new_attrs["data_layout"] = "NCHW16c"
         return relay.nn.conv2d(data, weight, **new_attrs)
 
-
     # NCHW layout transformation.
     def before_nchw():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight1 = relay.var('weight1')
-        weight2 = relay.var('weight2')
-        y = relay.nn.conv2d(x, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
-        y1 = relay.nn.conv2d(y, weight2,
-                             channels=32,
-                             kernel_size=(3, 3),
-                             padding=(1, 1))
+        weight1 = relay.var("weight1")
+        weight2 = relay.var("weight2")
+        y = relay.nn.conv2d(x, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1))
+        y1 = relay.nn.conv2d(y, weight2, channels=32, kernel_size=(3, 3), padding=(1, 1))
         ret = relay.concatenate([y, y1], axis=1)
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
     def expected_nchw():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight1 = relay.var('weight1')
-        weight2 = relay.var('weight2')
+        weight1 = relay.var("weight1")
+        weight2 = relay.var("weight2")
         y = relay.layout_transform(x, "NCHW", "NCHW16c")
-        y = relay.nn.conv2d(y, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW16c")
-        y1 = relay.nn.conv2d(y, weight2,
-                             channels=32,
-                             kernel_size=(3, 3),
-                             padding=(1, 1),
-                             data_layout='NCHW16c')
+        y = relay.nn.conv2d(
+            y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
+        y1 = relay.nn.conv2d(
+            y, weight2, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
         ret = relay.concatenate([y, y1], axis=1)
         ret = relay.layout_transform(ret, "NCHW16c", "NCHW")
         y = relay.Function(analysis.free_vars(ret), ret)
@@ -535,37 +509,29 @@ def test_alter_layout_concatenate():
     # NHWC layout transformation.
     def before_nhwc():
         x = relay.var("x", shape=(1, 56, 56, 64))
-        weight1 = relay.var('weight1')
-        weight2 = relay.var('weight2')
-        y = relay.nn.conv2d(x, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout='NHWC')
-        y1 = relay.nn.conv2d(y, weight2,
-                             channels=32,
-                             kernel_size=(3, 3),
-                             padding=(1, 1),
-                             data_layout='NHWC')
+        weight1 = relay.var("weight1")
+        weight2 = relay.var("weight2")
+        y = relay.nn.conv2d(
+            x, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NHWC"
+        )
+        y1 = relay.nn.conv2d(
+            y, weight2, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NHWC"
+        )
         ret = relay.concatenate([y, y1], axis=3)
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
     def expected_nhwc():
         x = relay.var("x", shape=(1, 56, 56, 64))
-        weight1 = relay.var('weight1')
-        weight2 = relay.var('weight2')
+        weight1 = relay.var("weight1")
+        weight2 = relay.var("weight2")
         y = relay.layout_transform(x, "NHWC", "NCHW16c")
-        y = relay.nn.conv2d(y, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW16c")
-        y1 = relay.nn.conv2d(y, weight2,
-                             channels=32,
-                             kernel_size=(3, 3),
-                             padding=(1, 1),
-                             data_layout='NCHW16c')
+        y = relay.nn.conv2d(
+            y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
+        y1 = relay.nn.conv2d(
+            y, weight2, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
         ret = relay.concatenate([y, y1], axis=1)
         ret = relay.layout_transform(ret, "NCHW16c", "NHWC")
         y = relay.Function(analysis.free_vars(ret), ret)
@@ -581,9 +547,10 @@ def test_alter_layout_concatenate():
 
 def test_alter_layout_nchw_upsamping_op():
     """Test upsamping operators """
+
     def before():
         x = relay.var("x", shape=(1, 32, 28, 28))
-        weight = relay.var('weight', shape=(32, 32, 3, 3))
+        weight = relay.var("weight", shape=(32, 32, 3, 3))
         y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
         y = relay.nn.upsampling(y, scale_h=2, scale_w=2)
         y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2))
@@ -593,17 +560,18 @@ def test_alter_layout_nchw_upsamping_op():
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
-        new_attrs['data_layout'] = 'NCHW16c'
+        new_attrs["data_layout"] = "NCHW16c"
         return relay.nn.conv2d(data, weight, **new_attrs)
 
     def expected():
         x = relay.var("x", shape=(1, 32, 28, 28))
         weight = relay.var("weight")
         x = relay.layout_transform(x, "NCHW", "NCHW16c")
-        y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1),
-                            data_layout="NCHW16c")
+        y = relay.nn.conv2d(
+            x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
         y = relay.nn.upsampling(y, scale_h=2, scale_w=2, layout="NCHW16c")
-        y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2), layout='NCHW16c')
+        y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2), layout="NCHW16c")
         y = relay.layout_transform(y, "NCHW16c", "NCHW")
         y = relay.Function(analysis.free_vars(y), y)
         return y
@@ -619,21 +587,19 @@ def test_alter_layout_nchw_upsamping_op():
 @tvm.testing.uses_gpu
 def test_alter_layout_strided_slice():
     """Test rewriting strided_slice during alter_iop_layout"""
+
     def before():
         x = relay.var("x", shape=(1, 32, 28, 28))
-        weight = relay.var('weight', shape=(32, 32, 3, 3))
+        weight = relay.var("weight", shape=(32, 32, 3, 3))
         y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
-        y = relay.strided_slice(y,
-                                begin=[0, 16],
-                                end=[1, 33],
-                                strides=[1, 1])
+        y = relay.strided_slice(y, begin=[0, 16], end=[1, 33], strides=[1, 1])
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
-        new_attrs['data_layout'] = 'NCHW4c'
+        new_attrs["data_layout"] = "NCHW4c"
         return relay.nn.conv2d(data, weight, **new_attrs)
 
     def expected():
@@ -641,13 +607,11 @@ def test_alter_layout_strided_slice():
         weight = relay.var("weight", shape=(32, 32, 3, 3))
         weight = relay.layout_transform(weight, "OIHW", "OIHW4i4o")
         x = relay.layout_transform(x, "NCHW", "NCHW4c")
-        y = relay.op.nn.contrib_conv2d_nchwc(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1),
-                                             data_layout="NCHW4c")
+        y = relay.op.nn.contrib_conv2d_nchwc(
+            x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW4c"
+        )
 
-        y = relay.strided_slice(y,
-                                begin=[0, 4],
-                                end=[1, 21],
-                                strides=[1, 1])
+        y = relay.strided_slice(y, begin=[0, 4], end=[1, 21], strides=[1, 1])
 
         y = relay.layout_transform(y, "NCHW4c", "NCHW")
         y = relay.Function(analysis.free_vars(y), y)
@@ -660,8 +624,8 @@ def test_alter_layout_strided_slice():
     # Verify inference result
     mod_before = tvm.IRModule()
     mod_new = tvm.IRModule()
-    mod_before['main'] = a
-    mod_new['main'] = b
+    mod_before["main"] = a
+    mod_new["main"] = b
     with relay.build_config(opt_level=3):
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug", "vm"]:
@@ -671,11 +635,14 @@ def test_alter_layout_strided_slice():
                 np_weight = np.random.uniform(size=(32, 32, 3, 3)).astype("float32")
                 result_before = ex_before.evaluate()(np_data, np_weight)
                 result_new = ex_new.evaluate()(np_data, np_weight)
-                tvm.testing.assert_allclose(result_before.asnumpy(), result_new.asnumpy(), rtol=1e-5, atol=1e-5)
+                tvm.testing.assert_allclose(
+                    result_before.asnumpy(), result_new.asnumpy(), rtol=1e-5, atol=1e-5
+                )
 
 
 def test_alter_layout_depthwise_conv2d():
     """Test depthwise_conv2d operator"""
+
     def before():
         x = relay.var("x", shape=(1, 32, 56, 56))
         w = relay.var("w", shape=(32, 1, 3, 3))
@@ -684,33 +651,42 @@ def test_alter_layout_depthwise_conv2d():
         return y
 
     from tvm import topi
+
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         with tvm.target.Target("llvm"):
             return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)
 
-
     def expected():
         x = relay.var("x", shape=(1, 32, 56, 56))
         w = relay.var("w", shape=(32, 1, 3, 3))
         x = relay.layout_transform(x, "NCHW", "NCHW8c")
         w = relay.layout_transform(w, "OIHW", "OIHW1i8o")
-        y = relay.nn.contrib_depthwise_conv2d_nchwc(x, w, padding=(1, 1, 1, 1), channels=32, kernel_size=(3, 3),
-                                                    groups=32, data_layout="NCHW8c", kernel_layout="OIHW1i8o",
-                                                    out_layout="NCHW8c")
+        y = relay.nn.contrib_depthwise_conv2d_nchwc(
+            x,
+            w,
+            padding=(1, 1, 1, 1),
+            channels=32,
+            kernel_size=(3, 3),
+            groups=32,
+            data_layout="NCHW8c",
+            kernel_layout="OIHW1i8o",
+            out_layout="NCHW8c",
+        )
         y = relay.layout_transform(y, "NCHW8c", "NCHW")
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
     with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
         a = before()
-        a = run_opt_pass(a, [transform.CanonicalizeOps(),
-                             transform.AlterOpLayout()])
+        a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()])
         b = run_opt_pass(expected(), transform.InferType())
 
-    assert(tvm.ir.structural_equal(a, b))
+    assert tvm.ir.structural_equal(a, b)
+
 
 def test_alter_layout_prelu():
     """Test PRelu operator"""
+
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))
         weight = relay.var("weight")
@@ -723,7 +699,7 @@ def test_alter_layout_prelu():
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
-        new_attrs['data_layout'] = 'NCHW16c'
+        new_attrs["data_layout"] = "NCHW16c"
         return relay.nn.conv2d(data, weight, **new_attrs)
 
     def expected():
@@ -732,11 +708,9 @@ def test_alter_layout_prelu():
         alpha = relay.var("alpha", relay.IncompleteType())
 
         y = relay.layout_transform(x, "NCHW", "NCHW16c")
-        y = relay.nn.conv2d(y, w,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW16c")
+        y = relay.nn.conv2d(
+            y, w, channels=64, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
         y = relay.layout_transform(y, "NCHW16c", "NCHW")
         y = relay.nn.prelu(y, alpha)
         y = relay.Function(analysis.free_vars(y), y)
@@ -747,39 +721,34 @@ def test_alter_layout_prelu():
         a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()])
         b = run_opt_pass(expected(), transform.InferType())
 
-    assert(tvm.ir.structural_equal(a, b))
+    assert tvm.ir.structural_equal(a, b)
 
 
 def test_alter_layout_pad():
     """ Check NCHW, NHWC and corner case for pad layout conversion"""
+
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
-        new_attrs['data_layout'] = 'NCHW16c'
+        new_attrs["data_layout"] = "NCHW16c"
         return relay.nn.conv2d(data, weight, **new_attrs)
 
-
     # Check NCHW conversion.
     def before_nchw():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight1 = relay.var('weight1')
-        y = relay.nn.conv2d(x, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        weight1 = relay.var("weight1")
+        y = relay.nn.conv2d(x, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1))
         ret = relay.nn.pad(y, pad_width=((0, 0), (0, 0), (1, 1), (1, 1)))
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
     def expected_nchw():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight1 = relay.var('weight1')
+        weight1 = relay.var("weight1")
         y = relay.layout_transform(x, "NCHW", "NCHW16c")
-        y = relay.nn.conv2d(y, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW16c")
+        y = relay.nn.conv2d(
+            y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
         ret = relay.nn.pad(y, pad_width=((0, 0), (0, 0), (1, 1), (1, 1), (0, 0)))
         ret = relay.layout_transform(ret, "NCHW16c", "NCHW")
         y = relay.Function(analysis.free_vars(ret), ret)
@@ -795,25 +764,21 @@ def test_alter_layout_pad():
     # Check NHWC conversion.
     def before_nhwc():
         x = relay.var("x", shape=(1, 56, 56, 64))
-        weight1 = relay.var('weight1')
-        y = relay.nn.conv2d(x, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout='NHWC')
+        weight1 = relay.var("weight1")
+        y = relay.nn.conv2d(
+            x, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NHWC"
+        )
         ret = relay.nn.pad(y, pad_width=((0, 0), (1, 1), (1, 1), (0, 0)))
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
     def expected_nhwc():
         x = relay.var("x", shape=(1, 56, 56, 64))
-        weight1 = relay.var('weight1')
+        weight1 = relay.var("weight1")
         y = relay.layout_transform(x, "NHWC", "NCHW16c")
-        y = relay.nn.conv2d(y, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW16c")
+        y = relay.nn.conv2d(
+            y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
         ret = relay.nn.pad(y, pad_width=((0, 0), (0, 0), (1, 1), (1, 1), (0, 0)))
         ret = relay.layout_transform(ret, "NCHW16c", "NHWC")
         y = relay.Function(analysis.free_vars(ret), ret)
@@ -829,24 +794,19 @@ def test_alter_layout_pad():
     # Check that conversion does not happen when padding along split axis.
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight1 = relay.var('weight1')
-        y = relay.nn.conv2d(x, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        weight1 = relay.var("weight1")
+        y = relay.nn.conv2d(x, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1))
         ret = relay.nn.pad(y, pad_width=((0, 0), (1, 1), (1, 1), (1, 1)))
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
     def expected():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight1 = relay.var('weight1')
+        weight1 = relay.var("weight1")
         y = relay.layout_transform(x, "NCHW", "NCHW16c")
-        y = relay.nn.conv2d(y, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW16c")
+        y = relay.nn.conv2d(
+            y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
         ret = relay.layout_transform(y, "NCHW16c", "NCHW")
         ret = relay.nn.pad(ret, pad_width=((0, 0), (1, 1), (1, 1), (1, 1)))
         y = relay.Function(analysis.free_vars(ret), ret)
@@ -862,35 +822,30 @@ def test_alter_layout_pad():
 
 def test_alter_layout_pool():
     """ Check NCHW, NHWC pool layout conversion"""
+
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
-        new_attrs['data_layout'] = 'NCHW16c'
+        new_attrs["data_layout"] = "NCHW16c"
         return relay.nn.conv2d(data, weight, **new_attrs)
 
-
     # Check NCHW conversion.
     def before_nchw():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight1 = relay.var('weight1')
-        y = relay.nn.conv2d(x, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        weight1 = relay.var("weight1")
+        y = relay.nn.conv2d(x, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1))
         ret = relay.nn.avg_pool2d(y, pool_size=(1, 1))
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
     def expected_nchw():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight1 = relay.var('weight1')
+        weight1 = relay.var("weight1")
         y = relay.layout_transform(x, "NCHW", "NCHW16c")
-        y = relay.nn.conv2d(y, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW16c")
-        ret = relay.nn.avg_pool2d(y, pool_size=(1, 1), layout='NCHW16c')
+        y = relay.nn.conv2d(
+            y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
+        ret = relay.nn.avg_pool2d(y, pool_size=(1, 1), layout="NCHW16c")
         ret = relay.layout_transform(ret, "NCHW16c", "NCHW")
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
@@ -905,26 +860,22 @@ def test_alter_layout_pool():
     # Check NHWC conversion.
     def before_nhwc():
         x = relay.var("x", shape=(1, 56, 56, 64))
-        weight1 = relay.var('weight1')
-        y = relay.nn.conv2d(x, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout='NHWC')
-        ret = relay.nn.avg_pool2d(y, pool_size=(1, 1), layout='NHWC')
+        weight1 = relay.var("weight1")
+        y = relay.nn.conv2d(
+            x, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NHWC"
+        )
+        ret = relay.nn.avg_pool2d(y, pool_size=(1, 1), layout="NHWC")
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
     def expected_nhwc():
         x = relay.var("x", shape=(1, 56, 56, 64))
-        weight1 = relay.var('weight1')
+        weight1 = relay.var("weight1")
         y = relay.layout_transform(x, "NHWC", "NCHW16c")
-        y = relay.nn.conv2d(y, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW16c")
-        ret = relay.nn.avg_pool2d(y, pool_size=(1, 1), layout='NCHW16c')
+        y = relay.nn.conv2d(
+            y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
+        ret = relay.nn.avg_pool2d(y, pool_size=(1, 1), layout="NCHW16c")
         ret = relay.layout_transform(ret, "NCHW16c", "NHWC")
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
@@ -939,34 +890,29 @@ def test_alter_layout_pool():
 
 def test_alter_layout_sum():
     """ Check NCHW, NHWC sum layout conversion"""
+
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
-        new_attrs['data_layout'] = 'NCHW16c'
+        new_attrs["data_layout"] = "NCHW16c"
         return relay.nn.conv2d(data, weight, **new_attrs)
 
-
     # Check NCHW conversion.
     def before_nchw():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight1 = relay.var('weight1')
-        y = relay.nn.conv2d(x, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        weight1 = relay.var("weight1")
+        y = relay.nn.conv2d(x, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1))
         ret = relay.sum(y, axis=1, keepdims=True)
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
     def expected_nchw():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight1 = relay.var('weight1')
+        weight1 = relay.var("weight1")
         y = relay.layout_transform(x, "NCHW", "NCHW16c")
-        y = relay.nn.conv2d(y, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW16c")
+        y = relay.nn.conv2d(
+            y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
         ret = relay.layout_transform(y, "NCHW16c", "NCHW")
         ret = relay.sum(ret, axis=[1], keepdims=True)
         y = relay.Function(analysis.free_vars(ret), ret)
@@ -982,25 +928,21 @@ def test_alter_layout_sum():
     # Check NHWC conversion.
     def before_nhwc():
         x = relay.var("x", shape=(1, 56, 56, 64))
-        weight1 = relay.var('weight1')
-        y = relay.nn.conv2d(x, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout='NHWC')
+        weight1 = relay.var("weight1")
+        y = relay.nn.conv2d(
+            x, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NHWC"
+        )
         ret = relay.sum(y, axis=3, keepdims=True)
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
     def expected_nhwc():
         x = relay.var("x", shape=(1, 56, 56, 64))
-        weight1 = relay.var('weight1')
+        weight1 = relay.var("weight1")
         y = relay.layout_transform(x, "NHWC", "NCHW16c")
-        y = relay.nn.conv2d(y, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW16c")
+        y = relay.nn.conv2d(
+            y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
         ret = relay.layout_transform(y, "NCHW16c", "NCHW")
         ret = relay.sum(ret, axis=[1], keepdims=True)
         ret = relay.layout_transform(ret, "NCHW", "NHWC")
@@ -1017,30 +959,26 @@ def test_alter_layout_sum():
 
 def test_alter_layout_nhwc_arm():
     """ Check that AlterOpLayout does not alter NHWC data layout. """
+
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         from tvm import topi
+
         with tvm.target.Target("llvm -device=arm_cpu"):
             return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)
 
     # Check NHWC conversion.
     def before_nhwc():
         x = relay.var("x", shape=(1, 56, 56, 64))
-        weight1 = relay.var('weight1', shape=(3, 3, 64, 64))
-        weight2 = relay.var('weight2', shape=(3, 3, 64, 64))
-        y = relay.nn.conv2d(x, weight1,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            data_layout='NHWC',
-                            kernel_layout='HWIO')
+        weight1 = relay.var("weight1", shape=(3, 3, 64, 64))
+        weight2 = relay.var("weight2", shape=(3, 3, 64, 64))
+        y = relay.nn.conv2d(
+            x, weight1, channels=64, kernel_size=(3, 3), data_layout="NHWC", kernel_layout="HWIO"
+        )
         y = relay.nn.relu(y)
-        y = relay.nn.avg_pool2d(y,
-                                pool_size=(1,1),
-                                layout='NHWC')
-        y = relay.nn.conv2d(y, weight2,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            data_layout='NHWC',
-                            kernel_layout='HWIO')
+        y = relay.nn.avg_pool2d(y, pool_size=(1, 1), layout="NHWC")
+        y = relay.nn.conv2d(
+            y, weight2, channels=64, kernel_size=(3, 3), data_layout="NHWC", kernel_layout="HWIO"
+        )
         y = relay.nn.relu(y)
         y = relay.Function(analysis.free_vars(y), y)
         return y
@@ -1055,9 +993,11 @@ def test_alter_layout_nhwc_arm():
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
+
+def test_alter_layout_nhwc_int8_aarch64():
     """ Check that AlterOpLayout does not alter NHWC data layout. """
     from tvm import autotvm
+
     expected_workload_shape = (20, 42, 4, 16)
 
     # We use Int8Fallback to disable the fallback flag
@@ -1072,6 +1012,7 @@ def test_alter_layout_nhwc_int8_aarch64():
             cfg.cost = 0
             self.memory[key] = cfg
             return cfg
+
         def update(self, target, workload, cfg):
             key = (str(target), workload)
             assert workload[2][1] == expected_workload_shape
@@ -1080,38 +1021,48 @@ def test_alter_layout_nhwc_int8_aarch64():
 
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         from tvm import topi
+
         with tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu"):
             with Int8Fallback():
-                tmp =  topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)
+                tmp = topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)
                 return tmp
 
     # Check NHWC conversion.
     def before_nhwc_int8():
-        x = relay.var("x", shape=(1, 56, 56, 73), dtype='int8')
-        weight = relay.var('weight1', shape=(3, 3, 73, 79), dtype='int8')
-        y = relay.nn.conv2d(x, weight,
-                            channels=79,
-                            kernel_size=(3, 3),
-                            data_layout='NHWC',
-                            kernel_layout='HWIO',
-                            out_dtype='int32')
+        x = relay.var("x", shape=(1, 56, 56, 73), dtype="int8")
+        weight = relay.var("weight1", shape=(3, 3, 73, 79), dtype="int8")
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=79,
+            kernel_size=(3, 3),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+            out_dtype="int32",
+        )
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
     def expected_nhwc_int8():
-        x = relay.var("x", shape=(1, 56, 56, 73), dtype='int8')
-        weight = relay.var('weight1', shape=(3, 3, 73, 79), dtype='int8')
+        x = relay.var("x", shape=(1, 56, 56, 73), dtype="int8")
+        weight = relay.var("weight1", shape=(3, 3, 73, 79), dtype="int8")
         tile_rows = 4
         tile_cols = 16
-        weight_transformed = relay.nn.contrib_conv2d_gemm_weight_transform(weight, tile_rows, tile_cols)
-        y = relay.nn.contrib_conv2d_gemm_without_weight_transform(x, weight_transformed,
-                            channels=79,
-                            kernel_size=(3, 3),
-                            data_layout='NHWC',
-                            kernel_layout='HWIO',
-                            out_dtype='int32')
+        weight_transformed = relay.nn.contrib_conv2d_gemm_weight_transform(
+            weight, tile_rows, tile_cols
+        )
+        y = relay.nn.contrib_conv2d_gemm_without_weight_transform(
+            x,
+            weight_transformed,
+            channels=79,
+            kernel_size=(3, 3),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+            out_dtype="int32",
+        )
         y = relay.Function(analysis.free_vars(y), y)
         return y
+
     with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
         a = before_nhwc_int8()
         a = run_opt_pass(a, transform.AlterOpLayout())
@@ -1119,18 +1070,17 @@ def test_alter_layout_nhwc_int8_aarch64():
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
+
 def test_alter_op_with_global_var():
     """Test directly replacing an operator with a new one"""
+
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight = relay.var('weight', shape=(64, 64, 3, 3))
-        y = relay.nn.conv2d(x, weight,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
         y = relay.nn.relu(y)
         mod = tvm.IRModule()
-        foo = relay.GlobalVar('foo')
+        foo = relay.GlobalVar("foo")
         mod[foo] = relay.Function([x, weight], y)
         mod["main"] = relay.Function([x, weight], foo(x, weight))
         return mod
@@ -1142,14 +1092,17 @@ def test_alter_op_with_global_var():
 
     def expected():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight = relay.var('weight', shape=(64, 64, 3, 3))
-        y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")),
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(
+            x,
+            relay.multiply(weight, relay.const(2.0, "float32")),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+        )
         y = relay.nn.relu(y)
         mod = tvm.IRModule()
-        foo = relay.GlobalVar('foo')
+        foo = relay.GlobalVar("foo")
         mod[foo] = relay.Function([x, weight], y)
         mod["main"] = relay.Function([x, weight], foo(x, weight))
         return mod
@@ -1161,6 +1114,7 @@ def test_alter_op_with_global_var():
 
     assert tvm.ir.structural_equal(a, b, map_free_vars=True), "Actual = \n" + str(a)
 
+
 if __name__ == "__main__":
     test_alter_op()
     test_alter_return_none()
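
Every hunk in this file applies the same mechanical rewrite: multi-line keyword-argument calls are re-wrapped against the configured line length, single-quoted strings become double-quoted, and assert(expr) calls become bare assert expr statements. As an illustrative sketch only (not part of the commit), the snippet below reproduces one such call rewrite through Black's Python API; the line length of 100 and the black.Mode entry point are assumptions inferred from how the reformatted lines wrap, not values read from this diff.

import black

# One pre-Black call site from the hunks above, wrapped in a function so the
# indentation is comparable to the test bodies; the names are deliberately
# left unresolved, since Black only needs the source to parse.
SRC = '''def expected():
    y = relay.nn.conv2d(y, weight1,
                        channels=32,
                        kernel_size=(3, 3),
                        padding=(1, 1),
                        data_layout="NCHW16c")
'''

print(black.format_str(SRC, mode=black.Mode(line_length=100)), end="")
# Expected output (indentation depth aside, this is the shape of the "+" lines above):
#   def expected():
#       y = relay.nn.conv2d(
#           y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
#       )
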
index 46989da..b7c4349 100644 (file)
@@ -28,22 +28,22 @@ from tvm import runtime
 from tvm.contrib import util
 
 
-def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
-                 ctx=tvm.cpu(), params=None):
+def check_result(
+    mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ctx=tvm.cpu(), params=None
+):
     if sys.platform == "win32":
         print("Skip test on Windows for now")
         return
 
     def update_lib(lib):
-        test_dir = os.path.dirname(
-            os.path.realpath(os.path.expanduser(__file__)))
+        test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
         source_dir = os.path.join(test_dir, "..", "..", "..")
         contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
 
         kwargs = {}
         kwargs["options"] = ["-O2", "-std=c++14", "-I" + contrib_path]
         tmp_path = util.tempdir()
-        lib_name = 'lib.so'
+        lib_name = "lib.so"
         lib_path = tmp_path.relpath(lib_name)
         lib.export_library(lib_path, fcompile=False, **kwargs)
         lib = runtime.load_module(lib_path)
@@ -81,18 +81,14 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
 
 def test_extern_dnnl():
     def annotated(dtype, ishape, w1shape):
-        data = relay.var('data', shape=(ishape), dtype=dtype)
-        weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype)
-        depthwise_conv2d_1 = relay.nn.conv2d(data,
-                                             weight1,
-                                             kernel_size=(3, 3),
-                                             padding=(1, 1),
-                                             groups=32)
-        depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
-                                             weight1,
-                                             kernel_size=(3, 3),
-                                             padding=(1, 1),
-                                             groups=32)
+        data = relay.var("data", shape=(ishape), dtype=dtype)
+        weight1 = relay.var("weight1", shape=(w1shape), dtype=dtype)
+        depthwise_conv2d_1 = relay.nn.conv2d(
+            data, weight1, kernel_size=(3, 3), padding=(1, 1), groups=32
+        )
+        depthwise_conv2d_2 = relay.nn.conv2d(
+            depthwise_conv2d_1, weight1, kernel_size=(3, 3), padding=(1, 1), groups=32
+        )
         out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
 
         f = relay.Function([data, weight1], out)
@@ -101,25 +97,21 @@ def test_extern_dnnl():
         return mod
 
     def expected(dtype, ishape, w1shape):
-        data = relay.var('data', shape=(ishape), dtype=dtype)
-        weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype)
+        data = relay.var("data", shape=(ishape), dtype=dtype)
+        weight1 = relay.var("weight1", shape=(w1shape), dtype=dtype)
         begin0 = relay.annotation.compiler_begin(data, "dnnl")
         begin1 = relay.annotation.compiler_begin(weight1, "dnnl")
-        depthwise_conv2d_1 = relay.nn.conv2d(begin0,
-                                             begin1,
-                                             kernel_size=(3, 3),
-                                             padding=(1, 1),
-                                             groups=32)
+        depthwise_conv2d_1 = relay.nn.conv2d(
+            begin0, begin1, kernel_size=(3, 3), padding=(1, 1), groups=32
+        )
         end0 = relay.annotation.compiler_end(depthwise_conv2d_1, "dnnl")
         end1 = relay.annotation.compiler_end(depthwise_conv2d_1, "dnnl")
         begin2 = relay.annotation.compiler_begin(end1, "dnnl")
         begin3 = relay.annotation.compiler_begin(end0, "dnnl")
         begin4 = relay.annotation.compiler_begin(weight1, "dnnl")
-        depthwise_conv2d_2 = relay.nn.conv2d(begin3,
-                                             begin4,
-                                             kernel_size=(3, 3),
-                                             padding=(1, 1),
-                                             groups=32)
+        depthwise_conv2d_2 = relay.nn.conv2d(
+            begin3, begin4, kernel_size=(3, 3), padding=(1, 1), groups=32
+        )
         end2 = relay.annotation.compiler_end(depthwise_conv2d_2, "dnnl")
         begin5 = relay.annotation.compiler_begin(end2, "dnnl")
         out = relay.add(begin2, begin5)
@@ -153,35 +145,34 @@ def test_extern_dnnl():
         ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu())
         ref_res = ref_ex.evaluate()(i_data, w1_data)
 
-        check_result(mod, {"data": i_data, "weight1": w1_data},
-                     (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
+        check_result(
+            mod, {"data": i_data, "weight1": w1_data}, (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5
+        )
 
     test_annotate()
     test_run()
 
+
 @pytest.mark.skip(reason="fix constant node before opening this case")
 def test_extern_dnnl_mobilenet():
     if not tvm.get_global_func("relay.ext.dnnl", True):
         print("skip because DNNL codegen is not available")
         return
 
-    dtype = 'float32'
+    dtype = "float32"
     ishape = (1, 3, 224, 224)
-    mod, params = relay.testing.mobilenet.get_workload(
-        batch_size=1, dtype='float32')
+    mod, params = relay.testing.mobilenet.get_workload(batch_size=1, dtype="float32")
 
     mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
     mod = transform.AnnotateTarget("dnnl")(mod)
     mod = transform.PartitionGraph()(mod)
     i_data = np.random.uniform(0, 1, ishape).astype(dtype)
 
-    ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1,
-                                                           dtype='float32')
+    ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1, dtype="float32")
     ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu(0))
     ref_res = ref_ex.evaluate()(i_data, **params)
 
-    check_result(mod, {"data": i_data},
-                 (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
+    check_result(mod, {"data": i_data}, (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
 
 
 def test_multiple_ends():
@@ -228,7 +219,7 @@ def test_type_propagation():
     target = "test_type_propagation"
 
     @tvm.ir.register_op_attr("nn.relu", "target." + target)
-    def relu(attrs, args): # pylint: disable=unused-variable
+    def relu(attrs, args):  # pylint: disable=unused-variable
         return args[0].checked_type.dtype == "float32"
 
     def before():
@@ -247,7 +238,7 @@ def test_tuple():
     target = "test_tuple"
 
     @tvm.ir.register_op_attr("nn.relu", "target." + target)
-    def relu(attrs, args): # pylint: disable=unused-variable
+    def relu(attrs, args):  # pylint: disable=unused-variable
         return True
 
     @tvm.ir.register_op_attr("concatenate", "target." + target)
@@ -255,6 +246,7 @@ def test_tuple():
         return True
 
     """Test that TupleNode is included in annotation when surrounded by supported nodes."""
+
     def before():
         x = relay.var("x", shape=(10, 5))
         y = relay.var("y", shape=(10, 5))
@@ -292,12 +284,12 @@ def test_tuple():
 
 def test_composite_function():
     def before():
-        a = relay.var('a', shape=(10, 10))
-        b = relay.var('b', shape=(10, 10))
+        a = relay.var("a", shape=(10, 10))
+        b = relay.var("b", shape=(10, 10))
 
         # add_relu function
-        in_1 = relay.var('in_1', shape=(10, 10))
-        in_2 = relay.var('in_2', shape=(10, 10))
+        in_1 = relay.var("in_1", shape=(10, 10))
+        in_2 = relay.var("in_2", shape=(10, 10))
         add_node = relay.add(in_1, in_2)
         relu_node = relay.nn.relu(add_node)
         add_relu = relay.Function([in_1, in_2], relu_node)
@@ -310,12 +302,12 @@ def test_composite_function():
         return mod
 
     def after():
-        a = relay.var('a', shape=(10, 10))
-        b = relay.var('b', shape=(10, 10))
+        a = relay.var("a", shape=(10, 10))
+        b = relay.var("b", shape=(10, 10))
 
         # add_relu function
-        in_1 = relay.var('in_1', shape=(10, 10))
-        in_2 = relay.var('in_2', shape=(10, 10))
+        in_1 = relay.var("in_1", shape=(10, 10))
+        in_2 = relay.var("in_2", shape=(10, 10))
         add_node = relay.add(in_1, in_2)
         relu_node = relay.nn.relu(add_node)
         add_relu = relay.Function([in_1, in_2], relu_node)
@@ -364,7 +356,7 @@ def test_multiple_runs():
 if __name__ == "__main__":
     test_extern_dnnl()
     test_composite_function()
-    #test_extern_dnnl_mobilenet()
+    # test_extern_dnnl_mobilenet()
     test_multiple_ends()
     test_type_propagation()
     test_tuple()
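
Aside from call re-wrapping, the hunks in this file reduce to two normalizations: single-quoted strings become double-quoted, and comments gain a space after the hash plus two spaces before an inline hash. A minimal sketch of both, under the same assumed configuration as the earlier example (black>=20.8b1, line length 100):

import black

SRC = "a = relay.var('a', shape=(10, 10)) #test comment\n"
print(black.format_str(SRC, mode=black.Mode(line_length=100)), end="")
# Expected output, matching the quote and comment changes in the "+" lines above:
#   a = relay.var("a", shape=(10, 10))  # test comment
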
index c55120e..ff68d48 100644 (file)
@@ -25,18 +25,17 @@ from tvm.relay.expr_functor import ExprMutator
 from tvm.relay import transform
 import tvm.testing
 
+
 def _trace(module, metadata, _):
-    if metadata.name == 'ManifestAlloc':
-        pass # import pdb; pdb.set_trace()
+    if metadata.name == "ManifestAlloc":
+        pass  # import pdb; pdb.set_trace()
 
 
-def check_graph_runtime(target, ref_res, device, func, params, config,
-                        opt_level, expected_index=None):
+def check_graph_runtime(
+    target, ref_res, device, func, params, config, opt_level, expected_index=None
+):
     with tvm.transform.PassContext(opt_level=opt_level, config=config):
-        graph, lib, new_params = relay.build(
-            func,
-            target,
-            params=params)
+        graph, lib, new_params = relay.build(func, target, params=params)
         contexts = [tvm.cpu(0), tvm.context(device)]
         graph_json = json.loads(graph)
         if "device_index" in graph_json["attrs"]:
@@ -49,8 +48,7 @@ def check_graph_runtime(target, ref_res, device, func, params, config,
         tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5)
 
 
-def check_vm_runtime(target, ref_res, device, func, params, config,
-                     opt_level, expected_index=None):
+def check_vm_runtime(target, ref_res, device, func, params, config, opt_level, expected_index=None):
     with tvm.transform.PassContext(opt_level=opt_level, trace=_trace, config=config):
         mod = tvm.IRModule()
         mod["main"] = func
@@ -60,6 +58,7 @@ def check_vm_runtime(target, ref_res, device, func, params, config,
         res = vm.invoke("main", **params)
         tvm.testing.assert_allclose(res.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
 
+
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
     mod = tvm.IRModule.from_expr(expr)
@@ -84,8 +83,7 @@ def test_redundant_annotation():
         sub2 = relay.subtract(_add2, z)
 
         func = relay.Function([x, y, z], relay.Tuple([sub1, sub2]))
-        func = run_opt_pass(func,
-                            transform.RewriteAnnotatedOps(ctx1.device_type))
+        func = run_opt_pass(func, transform.RewriteAnnotatedOps(ctx1.device_type))
         return func
 
     def expected():
@@ -114,8 +112,7 @@ def test_annotate_expr():
         _add = relay.annotation.on_device(add, ctx1)
         sub = relay.subtract(_add, z)
         _sub = relay.annotation.on_device(sub, ctx2)
-        expr = run_opt_pass(_sub,
-                            transform.RewriteAnnotatedOps(ctx1.device_type))
+        expr = run_opt_pass(_sub, transform.RewriteAnnotatedOps(ctx1.device_type))
         return expr
 
     def expected():
@@ -143,8 +140,7 @@ def test_annotate_all():
         _sub = relay.annotation.on_device(sub, ctx2)
 
         func = relay.Function([x, y, z], _sub)
-        func = run_opt_pass(func,
-                            transform.RewriteAnnotatedOps(ctx1.device_type))
+        func = run_opt_pass(func, transform.RewriteAnnotatedOps(ctx1.device_type))
         return func
 
     def expected():
@@ -169,8 +165,7 @@ def test_annotate_none():
         add = relay.add(x, y)
         sub = relay.subtract(add, z)
         func = relay.Function([x, y, z], sub)
-        func = run_opt_pass(func,
-                            transform.RewriteAnnotatedOps(ctx1.device_type))
+        func = run_opt_pass(func, transform.RewriteAnnotatedOps(ctx1.device_type))
         return func
 
     def expected():
@@ -191,14 +186,14 @@ def check_annotated_graph(annotated_func, expected_func):
 
 
 def test_conv_network():
-    R""" The network is as following:
-             data1     data2
-               |         |
-             conv2d    conv2d
-                \       /
-                   add
-                    |
-                  conv2d
+    R"""The network is as follows:
+    data1     data2
+      |         |
+    conv2d    conv2d
+       \       /
+          add
+           |
+         conv2d
     """
     batch_size = 1
     dshape = (batch_size, 64, 56, 56)
@@ -209,60 +204,27 @@ def test_conv_network():
     dev2 = tvm.context(2)
 
     def original():
-        conv2d_1 = relay.nn.conv2d(
-            data1,
-            weight,
-            channels=64,
-            kernel_size=(3, 3),
-            padding=(1, 1))
-        conv2d_2 = relay.nn.conv2d(
-            data2,
-            weight,
-            channels=64,
-            kernel_size=(3, 3),
-            padding=(1, 1))
+        conv2d_1 = relay.nn.conv2d(data1, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
+        conv2d_2 = relay.nn.conv2d(data2, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
         add = relay.add(conv2d_1, conv2d_2)
-        conv2d_3 = relay.nn.conv2d(
-            add,
-            weight,
-            channels=64,
-            kernel_size=(3, 3),
-            padding=(1, 1))
+        conv2d_3 = relay.nn.conv2d(add, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
 
         func = relay.Function([data1, data2, weight], conv2d_3)
-        func = run_opt_pass(
-            func, transform.RewriteAnnotatedOps(tvm.context(3).device_type))
+        func = run_opt_pass(func, transform.RewriteAnnotatedOps(tvm.context(3).device_type))
         return func
 
-
     def annotated():
-        conv2d_1 = relay.nn.conv2d(
-            data1,
-            weight,
-            channels=64,
-            kernel_size=(3, 3),
-            padding=(1, 1))
+        conv2d_1 = relay.nn.conv2d(data1, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
         _conv2d_1 = relay.annotation.on_device(conv2d_1, dev2)
-        conv2d_2 = relay.nn.conv2d(
-            data2,
-            weight,
-            channels=64,
-            kernel_size=(3, 3),
-            padding=(1, 1))
+        conv2d_2 = relay.nn.conv2d(data2, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
         _conv2d_2 = relay.annotation.on_device(conv2d_2, dev2)
         add = relay.add(_conv2d_1, _conv2d_2)
         _add = relay.annotation.on_device(add, dev1)
-        conv2d_3 = relay.nn.conv2d(
-            _add,
-            weight,
-            channels=64,
-            kernel_size=(3, 3),
-            padding=(1, 1))
+        conv2d_3 = relay.nn.conv2d(_add, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
         _conv2d_3 = relay.annotation.on_device(conv2d_3, dev2)
 
         func = relay.Function([data1, data2, weight], _conv2d_3)
-        func = run_opt_pass(
-            func, transform.RewriteAnnotatedOps(tvm.context(3).device_type))
+        func = run_opt_pass(func, transform.RewriteAnnotatedOps(tvm.context(3).device_type))
         return func
 
     class ScheduleConv2d(ExprMutator):
@@ -280,41 +242,26 @@ def test_conv_network():
     def annotate_with_visitor(func):
         sched = ScheduleConv2d(dev2)
         func = sched.visit(func)
-        func = run_opt_pass(
-            func, transform.RewriteAnnotatedOps(dev1.device_type))
+        func = run_opt_pass(func, transform.RewriteAnnotatedOps(dev1.device_type))
         return func
 
     def expected():
-        conv2d_1 = relay.nn.conv2d(
-            data1,
-            weight,
-            channels=64,
-            kernel_size=(3, 3),
-            padding=(1, 1))
+        conv2d_1 = relay.nn.conv2d(data1, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
         device_copy1 = relay.device_copy(conv2d_1, dev2, dev1)
-        conv2d_2 = relay.nn.conv2d(
-            data2,
-            weight,
-            channels=64,
-            kernel_size=(3, 3),
-            padding=(1, 1))
+        conv2d_2 = relay.nn.conv2d(data2, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
         device_copy2 = relay.device_copy(conv2d_2, dev2, dev1)
         add = relay.add(device_copy1, device_copy2)
         device_copy3 = relay.device_copy(add, dev1, dev2)
         conv2d_3 = relay.nn.conv2d(
-            device_copy3,
-            weight,
-            channels=64,
-            kernel_size=(3, 3),
-            padding=(1, 1))
+            device_copy3, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)
+        )
 
         func = relay.Function([data1, data2, weight], conv2d_3)
         return func
 
     def check_storage_and_device_types():
         func = annotated()
-        func = run_opt_pass(func, [transform.RewriteAnnotatedOps(3),
-                                   transform.FuseOps(2)])
+        func = run_opt_pass(func, [transform.RewriteAnnotatedOps(3), transform.FuseOps(2)])
         smap = relay.backend._backend.GraphPlanMemory(func)
         storage_ids = []
         device_types = []
@@ -359,27 +306,21 @@ def test_propogation():
     ctx1 = tvm.context(1)
     ctx2 = tvm.context(2)
 
-    expected_dev_type = {
-        'log': ctx1,
-        'log2': ctx2,
-        'log10': ctx2,
-        'add': ctx2,
-        'tan': ctx1
-    }
+    expected_dev_type = {"log": ctx1, "log2": ctx2, "log10": ctx2, "add": ctx2, "tan": ctx1}
 
     x = relay.var("x", shape=(3,))
 
     def annotated():
         log = relay.log(x)
-        _log = relay.annotation.on_device(log, expected_dev_type['log'])
+        _log = relay.annotation.on_device(log, expected_dev_type["log"])
         log2 = relay.log2(_log)
-        _log2 = relay.annotation.on_device(log2, expected_dev_type['log2'])
+        _log2 = relay.annotation.on_device(log2, expected_dev_type["log2"])
         log10 = relay.log10(_log)
-        _log10 = relay.annotation.on_device(log10, expected_dev_type['log10'])
+        _log10 = relay.annotation.on_device(log10, expected_dev_type["log10"])
         add = relay.add(_log2, _log10)
-        _add = relay.annotation.on_device(add, expected_dev_type['add'])
+        _add = relay.annotation.on_device(add, expected_dev_type["add"])
         tan = relay.tan(_add)
-        _tan = relay.annotation.on_device(tan, expected_dev_type['tan'])
+        _tan = relay.annotation.on_device(tan, expected_dev_type["tan"])
 
         func = run_opt_pass(_tan, transform.RewriteAnnotatedOps(ctx1.device_type))
         return func
@@ -428,8 +369,8 @@ def run_fusible_network(dev, tgt):
     """
     x = relay.var("x", shape=(1, 10))
     y = relay.var("y", shape=(10, 10))
-    x_data = np.random.rand(1, 10).astype('float32')
-    y_data = np.random.rand(10, 10).astype('float32')
+    x_data = np.random.rand(1, 10).astype("float32")
+    y_data = np.random.rand(10, 10).astype("float32")
     tmp_add = x_data + y_data
     tmp_sqrt = np.sqrt(tmp_add)
     tmp_log = np.log(tmp_add)
@@ -464,8 +405,7 @@ def run_fusible_network(dev, tgt):
             _exp = relay.annotation.on_device(exp, dev_ctx)
 
             func = relay.Function([x, y], _exp)
-            func = run_opt_pass(
-                func, transform.RewriteAnnotatedOps(cpu_ctx.device_type))
+            func = run_opt_pass(func, transform.RewriteAnnotatedOps(cpu_ctx.device_type))
             return func
 
         def expected():
@@ -489,11 +429,13 @@ def run_fusible_network(dev, tgt):
         check_annotated_graph(annotated_func, expected_func)
         opt_level = 1
         config = {"relay.fallback_device_type": fallback_device.device_type}
-        check_graph_runtime(target, ref_res, device, annotated_func, params,
-                            config, opt_level, expected_index)
+        check_graph_runtime(
+            target, ref_res, device, annotated_func, params, config, opt_level, expected_index
+        )
         opt_level = 2
-        check_vm_runtime(target, ref_res, device, annotated_func, params,
-                         config, opt_level, expected_index)
+        check_vm_runtime(
+            target, ref_res, device, annotated_func, params, config, opt_level, expected_index
+        )
 
     def test_fuse_all(device, tgt):
         """Fuse all operators."""
@@ -515,8 +457,7 @@ def run_fusible_network(dev, tgt):
             _exp = relay.annotation.on_device(exp, dev_ctx)
 
             func = relay.Function([x, y], _exp)
-            func = run_opt_pass(
-                func, transform.RewriteAnnotatedOps(cpu_ctx.device_type))
+            func = run_opt_pass(func, transform.RewriteAnnotatedOps(cpu_ctx.device_type))
             return func
 
         annotated_func = annotated()
@@ -524,11 +465,9 @@ def run_fusible_network(dev, tgt):
         check_annotated_graph(annotated_func, expected_func)
         opt_level = 1
         config = {"relay.fallback_device_type": fallback_device.device_type}
-        check_graph_runtime(target, ref_res, device, annotated_func, params,
-                            config, opt_level)
+        check_graph_runtime(target, ref_res, device, annotated_func, params, config, opt_level)
         opt_level = 2
-        check_vm_runtime(target, ref_res, device, annotated_func, params,
-                         config, opt_level)
+        check_vm_runtime(target, ref_res, device, annotated_func, params, config, opt_level)
 
     def test_fallback_exp(device, tgt):
         fallback_device = tvm.context("cpu")
@@ -545,8 +484,7 @@ def run_fusible_network(dev, tgt):
             _exp = relay.annotation.on_device(exp, cpu_ctx)
 
             func = relay.Function([x, y], _exp)
-            func = run_opt_pass(
-                func, transform.RewriteAnnotatedOps(dev_ctx.device_type))
+            func = run_opt_pass(func, transform.RewriteAnnotatedOps(dev_ctx.device_type))
             return func
 
         def expected():
@@ -568,11 +506,13 @@ def run_fusible_network(dev, tgt):
         opt_level = 1
         config = {"relay.fallback_device_type": fallback_device.device_type}
         check_annotated_graph(annotated_func, expected_func)
-        check_graph_runtime(target, ref_res, device, annotated_func, params, config,
-                            opt_level, expected_index)
+        check_graph_runtime(
+            target, ref_res, device, annotated_func, params, config, opt_level, expected_index
+        )
         opt_level = 2
-        check_vm_runtime(target, ref_res, device, annotated_func, params, config,
-                         opt_level, expected_index)
+        check_vm_runtime(
+            target, ref_res, device, annotated_func, params, config, opt_level, expected_index
+        )
 
     def test_fallback_all_operators(device, tgt):
         target = {device: tgt, "cpu": "llvm"}
@@ -580,11 +520,8 @@ def run_fusible_network(dev, tgt):
         expected_func = get_func()
         check_annotated_graph(annotated_func, expected_func)
         opt_level = 2
-        check_graph_runtime(target, ref_res, device, annotated_func, params, {},
-                            opt_level)
-        check_vm_runtime(target, ref_res, device, annotated_func, params, {},
-                         opt_level)
-
+        check_graph_runtime(target, ref_res, device, annotated_func, params, {}, opt_level)
+        check_vm_runtime(target, ref_res, device, annotated_func, params, {}, opt_level)
 
     test_fuse_log_add(dev, tgt)
     test_fuse_all(dev, tgt)
@@ -593,22 +530,22 @@ def run_fusible_network(dev, tgt):
 
 
 def run_unpropagatable_graph(dev, tgt):
-    R""" The network is as following:
-            a     b  c     d
-             \   /    \   /
-              add      mul
-                \      /
-                subtract
+    R"""The network is as follows:
+    a     b  c     d
+     \   /    \   /
+      add      mul
+        \      /
+        subtract
     """
 
     a = relay.var("a", shape=(10, 10))
     b = relay.var("b", shape=(10, 10))
     c = relay.var("c", shape=(10, 10))
     d = relay.var("d", shape=(10, 10))
-    a_data = np.random.rand(10, 10).astype('float32')
-    b_data = np.random.rand(10, 10).astype('float32')
-    c_data = np.random.rand(10, 10).astype('float32')
-    d_data = np.random.rand(10, 10).astype('float32')
+    a_data = np.random.rand(10, 10).astype("float32")
+    b_data = np.random.rand(10, 10).astype("float32")
+    c_data = np.random.rand(10, 10).astype("float32")
+    d_data = np.random.rand(10, 10).astype("float32")
     tmp_add = a_data + b_data
     tmp_mul = np.multiply(c_data, d_data)
     ref_res = np.subtract(tmp_add, tmp_mul)
@@ -626,8 +563,7 @@ def run_unpropagatable_graph(dev, tgt):
         sub = relay.subtract(_add, _mul)
         _sub = relay.annotation.on_device(sub, dev_ctx)
         func = relay.Function([a, b, c, d], _sub)
-        func = run_opt_pass(
-            func, transform.RewriteAnnotatedOps(dev_ctx.device_type))
+        func = run_opt_pass(func, transform.RewriteAnnotatedOps(dev_ctx.device_type))
         return func
 
     def expected():
@@ -646,12 +582,12 @@ def run_unpropagatable_graph(dev, tgt):
     opt_level = 0
     config = {"relay.fallback_device_type": fallback_device.device_type}
 
-    check_graph_runtime(target, ref_res, dev, annotated_func, params, config,
-                        opt_level, expected_index)
+    check_graph_runtime(
+        target, ref_res, dev, annotated_func, params, config, opt_level, expected_index
+    )
 
     opt_level = 2
-    check_vm_runtime(target, ref_res, dev, annotated_func, params, config,
-                     opt_level)
+    check_vm_runtime(target, ref_res, dev, annotated_func, params, config, opt_level)
 
 
 @tvm.testing.requires_opencl
@@ -701,8 +637,7 @@ def test_tuple_get_item():
         split = relay.TupleWrapper(split, 3)
         sub = split[0] - split[1]
         func = relay.Function(relay.analysis.free_vars(sub), sub)
-        func = run_opt_pass(
-            func, transform.RewriteAnnotatedOps(cpu_ctx.device_type))
+        func = run_opt_pass(func, transform.RewriteAnnotatedOps(cpu_ctx.device_type))
         return func
 
     annotated_func = annotated()
index 3558ebc..034cb48 100644 (file)
@@ -36,14 +36,14 @@ def quantize_and_build(out):
 
     return qmod
 
+
 def test_mul_rewrite():
     """a test case where rhs of mul is not constant"""
     data = relay.var("data", shape=(1, 16, 64, 64))
     multiplier = relay.sigmoid(relay.var("data", shape=(1, 16, 1, 1)))
-    conv = relay.nn.conv2d(data, relay.var("weight"),
-                           kernel_size=(3, 3),
-                           padding=(1, 1),
-                           channels=16)
+    conv = relay.nn.conv2d(
+        data, relay.var("weight"), kernel_size=(3, 3), padding=(1, 1), channels=16
+    )
     act = relay.nn.relu(data=conv)
 
     quantize_and_build(act * multiplier)
@@ -52,14 +52,14 @@ def test_mul_rewrite():
 
     quantize_and_build(act * pool)
 
+
 def test_batch_flatten_rewrite():
 
     data = relay.var("data", shape=(1, 16, 64, 64), dtype="float32")
 
-    out = relay.nn.conv2d(data, relay.var("weight"),
-                          kernel_size=(3, 3),
-                          padding=(1, 1),
-                          channels=16)
+    out = relay.nn.conv2d(
+        data, relay.var("weight"), kernel_size=(3, 3), padding=(1, 1), channels=16
+    )
 
     out = relay.nn.batch_flatten(out)
 
@@ -67,12 +67,13 @@ def test_batch_flatten_rewrite():
 
     def _check_batch_flatten(node):
         if isinstance(node, Call):
-            if(node.op.name == "nn.batch_flatten"):
-               assert node.checked_type.dtype == "int8"
+            if node.op.name == "nn.batch_flatten":
+                assert node.checked_type.dtype == "int8"
 
     # check if batch_flatten is quantized
     relay.analysis.post_order_visit(qmod["main"], _check_batch_flatten)
 
+
 def get_calibration_dataset(mod, input_name):
     dataset = []
     input_shape = [int(x) for x in mod["main"].checked_type.arg_types[0].shape]
@@ -99,9 +100,9 @@ def test_calibrate_memory_bound():
     mod, params = testing.synthetic.get_workload()
     dataset = get_calibration_dataset(mod, "data")
     import multiprocessing
+
     num_cpu = multiprocessing.cpu_count()
-    with relay.quantize.qconfig(calibrate_mode="kl_divergence",
-                                calibrate_chunk_by=num_cpu):
+    with relay.quantize.qconfig(calibrate_mode="kl_divergence", calibrate_chunk_by=num_cpu):
         relay.quantize.quantize(mod, params, dataset)
 
 
@@ -110,73 +111,78 @@ def test_calibrate_memory_bound():
 ####################################
 
 BASE_CFG = {
-  'skip_conv_layers': [],
-  'skip_dense_layers': False,
-  'dtype_input': "int8",
-  'dtype_weight': "int8",
-  'dtype_activation': "int32",
+    "skip_conv_layers": [],
+    "skip_dense_layers": False,
+    "dtype_input": "int8",
+    "dtype_weight": "int8",
+    "dtype_activation": "int32",
 }
 
+
 def gen_rand_tvm(tt, low, high):
-    if 'int' in tt.dtype:
+    if "int" in tt.dtype:
         data_np = np.random.randint(low, high, size=get_const_tuple(tt.shape), dtype=tt.dtype)
-    elif 'float' in tt.dtype:
+    elif "float" in tt.dtype:
         data_np = np.random.uniform(low, high, size=get_const_tuple(tt.shape)).astype(tt.dtype)
     else:
-        assert False, 'unknown dtype'
+        assert False, "unknown dtype"
     return tvm.nd.array(data_np, ctx=tvm.cpu(0))
 
 
 def verify_partition_fails(mod, params):
     # standard partition should always succeed
-    with relay.quantize.qconfig(**BASE_CFG, partition_conversions='enabled'):
+    with relay.quantize.qconfig(**BASE_CFG, partition_conversions="enabled"):
         partitioned_mod = relay.quantize.quantize(mod, params)
 
     try:
-        with relay.quantize.qconfig(**BASE_CFG, partition_conversions='fully_integral'):
+        with relay.quantize.qconfig(**BASE_CFG, partition_conversions="fully_integral"):
             partitioned_mod = relay.quantize.quantize(mod, params)
-        raise RuntimeError('partitioning should have failed')
+        raise RuntimeError("partitioning should have failed")
     except AssertionError:
         pass
 
 
 def verify_partition(mod, params):
-    with relay.quantize.qconfig(**BASE_CFG, paritition_conversions='disabled'):
+    with relay.quantize.qconfig(**BASE_CFG, partition_conversions="disabled"):
         unpartitioned_mod = relay.quantize.quantize(mod, params)
-        assert len(unpartitioned_mod.get_global_vars()) == 1, \
-            'unpartitioned module should only have one function'
-    with relay.quantize.qconfig(**BASE_CFG, partition_conversions='fully_integral'):
+        assert (
+            len(unpartitioned_mod.get_global_vars()) == 1
+        ), "unpartitioned module should only have one function"
+    with relay.quantize.qconfig(**BASE_CFG, partition_conversions="fully_integral"):
         partitioned_mod = relay.quantize.quantize(mod, params)
 
     # ensure partitioned and unpartitioned results agree
-    params = [
-        gen_rand_tvm(param.type_annotation, 0, 1)
-        for param in partitioned_mod['main'].params
-    ]
+    params = [gen_rand_tvm(param.type_annotation, 0, 1) for param in partitioned_mod["main"].params]
+
     def _eval_mod(mod):
-        vm = relay.create_executor('vm', ctx=tvm.cpu(0), target='llvm', mod=mod)
+        vm = relay.create_executor("vm", ctx=tvm.cpu(0), target="llvm", mod=mod)
         return vm.evaluate()(*params)
+
     partitioned_mod_result = _eval_mod(partitioned_mod)
     unpartitioned_mod_result = _eval_mod(unpartitioned_mod)
     tvm.testing.assert_allclose(
-        unpartitioned_mod_result.asnumpy(), partitioned_mod_result.asnumpy())
+        unpartitioned_mod_result.asnumpy(), partitioned_mod_result.asnumpy()
+    )
 
 
 def test_add_partition():
-    mod = tvm.parser.parse("""
+    mod = tvm.parser.parse(
+        """
     #[version = "0.0.5"]
     def @main(
         %x: Tensor[(10, 10), float32],
         %y: Tensor[(10, 10), float32]) {
       add(%x, %y)
     }
-    """)
+    """
+    )
     params = {}
     verify_partition_fails(mod, params)
 
 
 def test_conv2d_partition():
-    mod = tvm.parser.parse("""
+    mod = tvm.parser.parse(
+        """
     #[version = "0.0.5"]
     def @main(
         %x: Tensor[(1, 4, 16, 16), float32],
@@ -186,16 +192,16 @@ def test_conv2d_partition():
         channels=4,
         kernel_size=[3, 3])
     }
-    """)
-    weight_ty = mod['main'].params[1].checked_type
-    params = {
-        'w': gen_rand_tvm(weight_ty, 0, 1)
-    }
+    """
+    )
+    weight_ty = mod["main"].params[1].checked_type
+    params = {"w": gen_rand_tvm(weight_ty, 0, 1)}
     verify_partition(mod, params)
 
 
 def test_multiple_arg_conversions_partition():
-    mod = tvm.parser.parse("""
+    mod = tvm.parser.parse(
+        """
     #[version = "0.0.5"]
     def @main(
         %x1: Tensor[(1, 4, 16, 16), float32],
@@ -213,19 +219,18 @@ def test_multiple_arg_conversions_partition():
         kernel_size=[3, 3]);
       add(%0, %1)
     }
-    """)
+    """
+    )
 
-    w1_ty = mod['main'].params[1].checked_type
-    w2_ty = mod['main'].params[3].checked_type
-    params = {
-        'w1': gen_rand_tvm(w1_ty, 0, 1),
-        'w2': gen_rand_tvm(w2_ty, 0, 1)
-    }
+    w1_ty = mod["main"].params[1].checked_type
+    w2_ty = mod["main"].params[3].checked_type
+    params = {"w1": gen_rand_tvm(w1_ty, 0, 1), "w2": gen_rand_tvm(w2_ty, 0, 1)}
     verify_partition(mod, params)
 
 
 def test_unquantizable_prefix_partition():
-    mod = tvm.parser.parse("""
+    mod = tvm.parser.parse(
+        """
     #[version = "0.0.5"]
     def @main(
         %x: Tensor[(1, 4, 16, 16), float32],
@@ -238,18 +243,17 @@ def test_unquantizable_prefix_partition():
         channels=4,
         kernel_size=[3, 3])
     }
-    """)
-    bias_ty = mod['main'].params[1].checked_type
-    weight_ty = mod['main'].params[2].checked_type
-    params = {
-        'b': gen_rand_tvm(bias_ty, 0, 1),
-        'w': gen_rand_tvm(weight_ty, 0, 1)
-    }
+    """
+    )
+    bias_ty = mod["main"].params[1].checked_type
+    weight_ty = mod["main"].params[2].checked_type
+    params = {"b": gen_rand_tvm(bias_ty, 0, 1), "w": gen_rand_tvm(weight_ty, 0, 1)}
     verify_partition_fails(mod, params)
 
 
 def test_unquantizable_core_partition():
-    mod = tvm.parser.parse("""
+    mod = tvm.parser.parse(
+        """
     #[version = "0.0.5"]
     def @main(
         %x1: Tensor[(1, 4, 16, 16), float32],
@@ -267,20 +271,22 @@ def test_unquantizable_core_partition():
         channels=4,
         kernel_size=[3, 3])
     }
-    """)
-    w1_ty = mod['main'].params[1].checked_type
-    bias_ty = mod['main'].params[2].checked_type
-    w2_ty = mod['main'].params[3].checked_type
+    """
+    )
+    w1_ty = mod["main"].params[1].checked_type
+    bias_ty = mod["main"].params[2].checked_type
+    w2_ty = mod["main"].params[3].checked_type
     params = {
-        'w1': gen_rand_tvm(w1_ty, 0, 1),
-        'w2': gen_rand_tvm(w2_ty, 0, 1),
-        'b': gen_rand_tvm(bias_ty, 0, 1)
+        "w1": gen_rand_tvm(w1_ty, 0, 1),
+        "w2": gen_rand_tvm(w2_ty, 0, 1),
+        "b": gen_rand_tvm(bias_ty, 0, 1),
     }
     verify_partition_fails(mod, params)
 
 
 def test_unquantizable_suffix_partition():
-    mod = tvm.parser.parse("""
+    mod = tvm.parser.parse(
+        """
     #[version = "0.0.5"]
     def @main(
         %x: Tensor[(1, 4, 16, 16), float32],
@@ -293,13 +299,11 @@ def test_unquantizable_suffix_partition():
       // NOTE bias_add isn't currently quantizable
       nn.bias_add(%0, %b)
     }
-    """)
-    weight_ty = mod['main'].params[1].checked_type
-    bias_ty = mod['main'].params[2].checked_type
-    params = {
-        'w': gen_rand_tvm(weight_ty, 0, 1),
-        'b': gen_rand_tvm(bias_ty, 0, 1)
-    }
+    """
+    )
+    weight_ty = mod["main"].params[1].checked_type
+    bias_ty = mod["main"].params[2].checked_type
+    params = {"w": gen_rand_tvm(weight_ty, 0, 1), "b": gen_rand_tvm(bias_ty, 0, 1)}
     verify_partition_fails(mod, params)
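
The over-long assert rewritten in verify_partition above illustrates one more Black idiom: when an assert condition plus message no longer fits the line length, Black parenthesizes only the condition and leaves the message outside the parentheses. That shape matters, because hand-writing assert (condition, message) would assert a two-element tuple, which is always true. A hedged sketch under the same assumed configuration as the earlier examples:

import black

SRC = (
    "assert len(unpartitioned_mod.get_global_vars()) == 1, "
    "'unpartitioned module should only have one function'\n"
)
print(black.format_str(SRC, mode=black.Mode(line_length=100)), end="")
# Expected output (the condition alone gains parentheses, as in the hunk above):
#   assert (
#       len(unpartitioned_mod.get_global_vars()) == 1
#   ), "unpartitioned module should only have one function"
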
 
 
index e13547b..321d866 100644 (file)
@@ -23,11 +23,9 @@ import tvm.relay.transform as _transform
 
 def test_canonicalize_cast():
     def before(data, conv_weight, bias1, bias2):
-        x = relay.nn.conv2d(data, conv_weight,
-                          channels=16,
-                          kernel_size=(3, 3),
-                          padding=(1, 1),
-                          out_dtype="int8")
+        x = relay.nn.conv2d(
+            data, conv_weight, channels=16, kernel_size=(3, 3), padding=(1, 1), out_dtype="int8"
+        )
         x1 = relay.cast(x, dtype="int32")
         y1 = relay.add(x1, bias1)
         y2 = relay.add(x1, bias2)
@@ -35,11 +33,9 @@ def test_canonicalize_cast():
         return relay.Function([data, conv_weight, bias1, bias2], y)
 
     def expected(data, conv_weight, bias1, bias2):
-        x = relay.nn.conv2d(data, conv_weight,
-                          channels=16,
-                          kernel_size=(3, 3),
-                          padding=(1, 1),
-                          out_dtype="int8")
+        x = relay.nn.conv2d(
+            data, conv_weight, channels=16, kernel_size=(3, 3), padding=(1, 1), out_dtype="int8"
+        )
         x1 = relay.cast(x, dtype="int32")
         x2 = relay.cast(x, dtype="int32")
         y1 = relay.add(x1, bias1)
@@ -54,8 +50,9 @@ def test_canonicalize_cast():
         bias2 = relay.var("bias2", shape=(16, 1, 1), dtype="int32")
         y = before(data, conv_weight, bias1, bias2)
         mod = tvm.IRModule.from_expr(y)
-        seq = tvm.transform.Sequential([_transform.InferType(), _transform.CanonicalizeCast(),
-                                     _transform.InferType()])
+        seq = tvm.transform.Sequential(
+            [_transform.InferType(), _transform.CanonicalizeCast(), _transform.InferType()]
+        )
         with tvm.transform.PassContext(opt_level=3):
             mod = seq(mod)
         y = mod["main"]
@@ -69,5 +66,5 @@ def test_canonicalize_cast():
     check((1, 16, 7, 7))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_canonicalize_cast()
index 06fe13a..41c7754 100644
@@ -20,11 +20,12 @@ from tvm import relay
 from tvm.relay.analysis import check_kind
 import pytest
 
+
 def test_typevar_kind():
     # returns the same kind
-    tp1 = relay.TypeVar('tp1', relay.TypeKind.Type)
-    tp2 = relay.TypeVar('tp2', relay.TypeKind.ShapeVar)
-    tp3 = relay.TypeVar('tp3', relay.TypeKind.Constraint)
+    tp1 = relay.TypeVar("tp1", relay.TypeKind.Type)
+    tp2 = relay.TypeVar("tp2", relay.TypeKind.ShapeVar)
+    tp3 = relay.TypeVar("tp3", relay.TypeKind.Constraint)
 
     assert check_kind(tp1) == relay.TypeKind.Type
     assert check_kind(tp2) == relay.TypeKind.ShapeVar
@@ -33,9 +34,11 @@ def test_typevar_kind():
 
 def test_tuple_kind():
     # only contain type kinds
-    tp = relay.TypeVar('tp', relay.TypeKind.Type)
-    tt = relay.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
-    tf = relay.FuncType(tvm.runtime.convert([]), tt, tvm.runtime.convert([]), tvm.runtime.convert([]))
+    tp = relay.TypeVar("tp", relay.TypeKind.Type)
+    tt = relay.TensorType(tvm.runtime.convert([1, 2, 3]), "float32")
+    tf = relay.FuncType(
+        tvm.runtime.convert([]), tt, tvm.runtime.convert([]), tvm.runtime.convert([])
+    )
     fields = tvm.runtime.convert([tp, tf, tt])
 
     tup_ty = relay.TupleType(fields)
@@ -44,14 +47,14 @@ def test_tuple_kind():
 
 def test_func_kind():
     # only contain type kinds
-    tp1 = relay.TypeVar('tp1', relay.TypeKind.Type)
-    tp2 = relay.TypeVar('tp2', relay.TypeKind.Type)
+    tp1 = relay.TypeVar("tp1", relay.TypeKind.Type)
+    tp2 = relay.TypeVar("tp2", relay.TypeKind.Type)
 
     shape = tvm.runtime.convert([1, 2, 3])
-    dtype = 'float32'
+    dtype = "float32"
     tensor_type = relay.TensorType(shape, dtype)
 
-    tr = relay.TypeRelation(None, tvm.runtime.convert([tensor_type, tp1]) , 1, None)
+    tr = relay.TypeRelation(None, tvm.runtime.convert([tensor_type, tp1]), 1, None)
 
     type_params = tvm.runtime.convert([tp1, tp2])
     type_constraints = tvm.runtime.convert([tr])
@@ -64,8 +67,10 @@ def test_func_kind():
 
 def test_ref_kind():
     # only contain type kinds
-    tt = relay.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
-    ft = relay.FuncType(tvm.runtime.convert([]), tt, tvm.runtime.convert([]), tvm.runtime.convert([]))
+    tt = relay.TensorType(tvm.runtime.convert([1, 2, 3]), "float32")
+    ft = relay.FuncType(
+        tvm.runtime.convert([]), tt, tvm.runtime.convert([]), tvm.runtime.convert([])
+    )
 
     rt1 = relay.RefType(tt)
     assert check_kind(rt1) == relay.TypeKind.Type
@@ -77,9 +82,11 @@ def test_ref_kind():
 
 def test_relation_kind():
     # only have type kinds for arguments
-    tp = relay.TypeVar('tp', relay.TypeKind.Type)
-    tt = relay.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
-    tf = relay.FuncType(tvm.runtime.convert([]), tt, tvm.runtime.convert([]), tvm.runtime.convert([]))
+    tp = relay.TypeVar("tp", relay.TypeKind.Type)
+    tt = relay.TensorType(tvm.runtime.convert([1, 2, 3]), "float32")
+    tf = relay.FuncType(
+        tvm.runtime.convert([]), tt, tvm.runtime.convert([]), tvm.runtime.convert([])
+    )
     args = tvm.runtime.convert([tf, tt, tp])
 
     tr = relay.TypeRelation(None, args, 2, None)
@@ -87,15 +94,15 @@ def test_relation_kind():
 
 
 def test_global_typevar_kind():
-    v1 = relay.GlobalTypeVar('gtv1', relay.TypeKind.AdtHandle)
-    v2 = relay.GlobalTypeVar('gtv2', relay.TypeKind.Type)
+    v1 = relay.GlobalTypeVar("gtv1", relay.TypeKind.AdtHandle)
+    v2 = relay.GlobalTypeVar("gtv2", relay.TypeKind.Type)
 
     assert check_kind(v1) == relay.TypeKind.AdtHandle
     assert check_kind(v2) == relay.TypeKind.Type
 
 
 def test_typecall_kind():
-    gtv = relay.GlobalTypeVar('gtv')
+    gtv = relay.GlobalTypeVar("gtv")
 
     mod = tvm.IRModule()
     data = relay.TypeData(gtv, [], [])
@@ -104,7 +111,7 @@ def test_typecall_kind():
     assert check_kind(empty_call, mod) == relay.TypeKind.Type
 
     new_mod = tvm.IRModule()
-    tv = relay.TypeVar('tv')
+    tv = relay.TypeVar("tv")
     new_data = relay.TypeData(gtv, [tv], [])
     new_mod[gtv] = new_data
     call = relay.TypeCall(gtv, [relay.TupleType([])])
@@ -113,9 +120,9 @@ def test_typecall_kind():
 
 @pytest.mark.xfail(raises=tvm.error.TVMError)
 def test_invalid_tuple_kind():
-    tp1 = relay.TypeVar('tp1', relay.TypeKind.ShapeVar)
-    tp2 = relay.TypeVar('tp2', relay.TypeKind.BaseType)
-    tp3 = relay.TypeVar('tp3', relay.TypeKind.Constraint)
+    tp1 = relay.TypeVar("tp1", relay.TypeKind.ShapeVar)
+    tp2 = relay.TypeVar("tp2", relay.TypeKind.BaseType)
+    tp3 = relay.TypeVar("tp3", relay.TypeKind.Constraint)
     fields = tvm.runtime.convert([tp1, tp2, tp3])
 
     tup_ty = relay.TupleType(fields)
@@ -124,9 +131,9 @@ def test_invalid_tuple_kind():
 
 @pytest.mark.xfail(raises=tvm.error.TVMError)
 def test_invalid_func_kind():
-    tp1 = relay.TypeVar('tp1', relay.TypeKind.ShapeVar)
-    tp2 = relay.TypeVar('tp2', relay.TypeKind.BaseType)
-    tp3 = relay.TypeVar('tp3', relay.TypeKind.Constraint)
+    tp1 = relay.TypeVar("tp1", relay.TypeKind.ShapeVar)
+    tp2 = relay.TypeVar("tp2", relay.TypeKind.BaseType)
+    tp3 = relay.TypeVar("tp3", relay.TypeKind.Constraint)
 
     type_params = tvm.runtime.convert([tp1, tp2, tp3])
     type_constraints = tvm.runtime.convert([])
@@ -139,16 +146,16 @@ def test_invalid_func_kind():
 
 @pytest.mark.xfail(raises=tvm.error.TVMError)
 def test_invalid_ref_kind():
-    tp = relay.TypeVar('tp', relay.TypeKind.ShapeVar)
+    tp = relay.TypeVar("tp", relay.TypeKind.ShapeVar)
     rt = relay.RefType(tp)
     check_kind(rt)
 
 
 @pytest.mark.xfail(raises=tvm.error.TVMError)
 def test_invalid_relation_kind():
-    tp1 = relay.TypeVar('tp1', relay.TypeKind.ShapeVar)
-    tp2 = relay.TypeVar('tp2', relay.TypeKind.BaseType)
-    tp3 = relay.TypeVar('tp3', relay.TypeKind.Constraint)
+    tp1 = relay.TypeVar("tp1", relay.TypeKind.ShapeVar)
+    tp2 = relay.TypeVar("tp2", relay.TypeKind.BaseType)
+    tp3 = relay.TypeVar("tp3", relay.TypeKind.Constraint)
     args = tvm.runtime.convert([tp1, tp2, tp3])
 
     func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
@@ -159,7 +166,7 @@ def test_invalid_relation_kind():
 @pytest.mark.xfail(raises=tvm.error.TVMError)
 def test_typecall_invalid_callee():
     # global type var must be an ADT handle
-    gtv = relay.GlobalTypeVar('v1', relay.TypeKind.Type)
+    gtv = relay.GlobalTypeVar("v1", relay.TypeKind.Type)
     check_kind(relay.TypeCall(gtv, []))
 
 
@@ -167,7 +174,7 @@ def test_typecall_invalid_callee():
 def test_typecall_invalid_args():
     # args must all be type kind
     mod = tvm.IRModule()
-    gtv = relay.GlobalTypeVar('v1')
+    gtv = relay.GlobalTypeVar("v1")
     data = relay.TypeData(gtv, [], [])
     mod[gtv] = data
 
@@ -177,8 +184,8 @@ def test_typecall_invalid_args():
 @pytest.mark.xfail(raises=tvm.error.TVMError)
 def test_typecall_invalid_num_args():
     mod = tvm.IRModule()
-    gtv = relay.GlobalTypeVar('v1')
-    tv = relay.TypeVar('tv')
+    gtv = relay.GlobalTypeVar("v1")
+    tv = relay.TypeVar("tv")
     data = relay.TypeData(gtv, [tv], [])
     mod[gtv] = data
     check_kind(relay.TypeCall(gtv, []))
@@ -186,51 +193,64 @@ def test_typecall_invalid_num_args():
 
 @pytest.mark.xfail(raises=tvm.error.TVMError)
 def test_func_with_invalid_ret_type():
-    tp1 = relay.TypeVar('tp1', relay.TypeKind.Type)
-    tp2 = relay.TypeVar('tp2', relay.TypeKind.ShapeVar)
-    tf = relay.FuncType(tvm.runtime.convert([tp1]), tp2, tvm.runtime.convert([tp1, tp2]), tvm.runtime.convert([]))
+    tp1 = relay.TypeVar("tp1", relay.TypeKind.Type)
+    tp2 = relay.TypeVar("tp2", relay.TypeKind.ShapeVar)
+    tf = relay.FuncType(
+        tvm.runtime.convert([tp1]), tp2, tvm.runtime.convert([tp1, tp2]), tvm.runtime.convert([])
+    )
 
     check_kind(tf)
 
 
 @pytest.mark.xfail(raises=tvm.error.TVMError)
 def test_func_with_invalid_arg_types():
-    tp1 = relay.TypeVar('tp1', relay.TypeKind.ShapeVar)
-    tp2 = relay.TypeVar('tp2', relay.TypeKind.Type)
-    tf = relay.FuncType(tvm.runtime.convert([tp1]), tp2, tvm.runtime.convert([tp1, tp2]), tvm.runtime.convert([]))
+    tp1 = relay.TypeVar("tp1", relay.TypeKind.ShapeVar)
+    tp2 = relay.TypeVar("tp2", relay.TypeKind.Type)
+    tf = relay.FuncType(
+        tvm.runtime.convert([tp1]), tp2, tvm.runtime.convert([tp1, tp2]), tvm.runtime.convert([])
+    )
 
     check_kind(tf)
 
 
 @pytest.mark.xfail(raises=tvm.error.TVMError)
 def test_func_with_invalid_tuple():
-    tp1 = relay.TypeVar('tp1', relay.TypeKind.ShapeVar)
+    tp1 = relay.TypeVar("tp1", relay.TypeKind.ShapeVar)
 
     ret_type = relay.TupleType(tvm.runtime.convert([tp1, tp1, tp1]))
 
-    tf = relay.FuncType(tvm.runtime.convert([]), ret_type, tvm.runtime.convert([tp1]), tvm.runtime.convert([]))
+    tf = relay.FuncType(
+        tvm.runtime.convert([]), ret_type, tvm.runtime.convert([tp1]), tvm.runtime.convert([])
+    )
     check_kind(tf)
 
 
 @pytest.mark.xfail(raises=tvm.error.TVMError)
 def test_func_with_invalid_relation():
-    tp1 = relay.TypeVar('tp1', relay.TypeKind.Type)
-    tp2 = relay.TypeVar('tp2', relay.TypeKind.ShapeVar)
-    tp3 = relay.TypeVar('tp3', relay.TypeKind.Constraint)
+    tp1 = relay.TypeVar("tp1", relay.TypeKind.Type)
+    tp2 = relay.TypeVar("tp2", relay.TypeKind.ShapeVar)
+    tp3 = relay.TypeVar("tp3", relay.TypeKind.Constraint)
 
     func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Identity")
     tr = relay.TypeRelation(func, tvm.runtime.convert([tp2, tp3]), 1, None)
 
-    tf = relay.FuncType(tvm.runtime.convert([tp1]), tp1, tvm.runtime.convert([tp1, tp2, tp3]), tvm.runtime.convert([tr]))
+    tf = relay.FuncType(
+        tvm.runtime.convert([tp1]),
+        tp1,
+        tvm.runtime.convert([tp1, tp2, tp3]),
+        tvm.runtime.convert([tr]),
+    )
     check_kind(tf)
 
 
 @pytest.mark.xfail(raises=tvm.error.TVMError)
 def test_tuple_with_invalid_func():
-    tensor_type = relay.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
+    tensor_type = relay.TensorType(tvm.runtime.convert([1, 2, 3]), "float32")
 
-    tp1 = relay.TypeVar('tp1', relay.TypeKind.ShapeVar)
-    tf = relay.FuncType(tvm.runtime.convert([]), tp1, tvm.runtime.convert([tp1]), tvm.runtime.convert([]))
+    tp1 = relay.TypeVar("tp1", relay.TypeKind.ShapeVar)
+    tf = relay.FuncType(
+        tvm.runtime.convert([]), tp1, tvm.runtime.convert([tp1]), tvm.runtime.convert([])
+    )
 
     tup_ty = relay.TupleType(tvm.runtime.convert([tensor_type, tf]))
     check_kind(tup_ty)
index edede97..84fa40a 100644
@@ -28,8 +28,10 @@ def run_opt_pass(expr, opt_pass):
     mod = opt_pass(mod)
     return mod["main"]
 
+
 def test_combine_parallel_batch_matmul():
     """Simple testcase."""
+
     def before(x, w1, w2, w3):
         args = [x, w1, w2, w3]
         y1 = relay.nn.batch_matmul(x, w1)
@@ -46,21 +48,15 @@ def test_combine_parallel_batch_matmul():
         args = [x, w1, w2, w3]
         w = relay.concatenate((w1, w2, w3), axis=1)
         y = relay.nn.batch_matmul(x, w)
-        y1 = relay.strided_slice(y,
-                                 begin=[0, 0, 0],
-                                 end=[-1, -1, s1],
-                                 strides=[1, 1, 1],
-                                 slice_mode="size")
-        y2 = relay.strided_slice(y,
-                                 begin=[0, 0, s1],
-                                 end=[-1, -1, s2],
-                                 strides=[1, 1, 1],
-                                 slice_mode="size")
-        y3 = relay.strided_slice(y,
-                                 begin=[0, 0, s1+s2],
-                                 end=[-1, -1, s3],
-                                 strides=[1, 1, 1],
-                                 slice_mode="size")
+        y1 = relay.strided_slice(
+            y, begin=[0, 0, 0], end=[-1, -1, s1], strides=[1, 1, 1], slice_mode="size"
+        )
+        y2 = relay.strided_slice(
+            y, begin=[0, 0, s1], end=[-1, -1, s2], strides=[1, 1, 1], slice_mode="size"
+        )
+        y3 = relay.strided_slice(
+            y, begin=[0, 0, s1 + s2], end=[-1, -1, s3], strides=[1, 1, 1], slice_mode="size"
+        )
         y = relay.Tuple((y1, y2, y3))
         return relay.Function(args, y)
 
@@ -71,8 +67,7 @@ def test_combine_parallel_batch_matmul():
         w3 = relay.var("w3", shape=(b, j, k))
 
         y_before = before(x, w1, w2, w3)
-        y = run_opt_pass(y_before,
-                         transform.CombineParallelBatchMatmul(min_num_branches=2))
+        y = run_opt_pass(y_before, transform.CombineParallelBatchMatmul(min_num_branches=2))
         y_expected = expected(x, w1, w2, w3)
         y_expected = run_opt_pass(y_expected, transform.InferType())
         tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
@@ -80,8 +75,10 @@ def test_combine_parallel_batch_matmul():
     check(2, 3, 5, 4)
     check(1, 100, 200, 300)
 
+
 def test_combine_parallel_batch_matmul_biasadd():
     """Simple testcase with bias"""
+
     def before(x, w1, w2, w3, b1, b2, b3):
         args = [x, w1, w2, w3, b1, b2, b3]
         y1 = relay.nn.batch_matmul(x, w1)
@@ -103,21 +100,15 @@ def test_combine_parallel_batch_matmul_biasadd():
         b = relay.concatenate((b1, b2, b3), axis=-1)
         y = relay.nn.batch_matmul(x, w)
         y = relay.add(y, b)
-        y1 = relay.strided_slice(y,
-                                 begin=[0, 0, 0],
-                                 end=[-1, -1, s1],
-                                 strides=[1, 1, 1],
-                                 slice_mode="size")
-        y2 = relay.strided_slice(y,
-                                 begin=[0, 0, s1],
-                                 end=[-1, -1, s2],
-                                 strides=[1, 1, 1],
-                                 slice_mode="size")
-        y3 = relay.strided_slice(y,
-                                 begin=[0, 0, s1+s2],
-                                 end=[-1, -1, s3],
-                                 strides=[1, 1, 1],
-                                 slice_mode="size")
+        y1 = relay.strided_slice(
+            y, begin=[0, 0, 0], end=[-1, -1, s1], strides=[1, 1, 1], slice_mode="size"
+        )
+        y2 = relay.strided_slice(
+            y, begin=[0, 0, s1], end=[-1, -1, s2], strides=[1, 1, 1], slice_mode="size"
+        )
+        y3 = relay.strided_slice(
+            y, begin=[0, 0, s1 + s2], end=[-1, -1, s3], strides=[1, 1, 1], slice_mode="size"
+        )
         y = relay.Tuple((y1, y2, y3))
         return relay.Function(args, y)
 
@@ -131,8 +122,7 @@ def test_combine_parallel_batch_matmul_biasadd():
         b3 = relay.var("b3", shape=(j,))
 
         y_before = before(x, w1, w2, w3, b1, b2, b3)
-        y = run_opt_pass(y_before,
-                         transform.CombineParallelBatchMatmul(min_num_branches=2))
+        y = run_opt_pass(y_before, transform.CombineParallelBatchMatmul(min_num_branches=2))
         y_expected = expected(x, w1, w2, w3, b1, b2, b3)
         y_expected = run_opt_pass(y_expected, transform.InferType())
         tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
index f48cdd6..fbc836c 100644
@@ -24,6 +24,7 @@ def run_combine_parallel(expr, min_num_branches=3):
     mod = transform.CombineParallelConv2D(min_num_branches)(mod)
     return mod["main"]
 
+
 def run_opt_pass(expr, opt_pass):
     assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
@@ -33,6 +34,7 @@ def run_opt_pass(expr, opt_pass):
 
 def test_combine_parallel_conv2d():
     """Simple testcase."""
+
     def before(x, w1, w2, w3, w4):
         args = [x, w1, w2, w3, w4]
         y1 = relay.nn.conv2d(x, w1)
@@ -49,22 +51,20 @@ def test_combine_parallel_conv2d():
         args = [x, w1, w2, w3, w4]
         w = relay.concatenate((w1, w2, w4), axis=0)
         y = relay.nn.conv2d(x, w, channels=channels1 + channels2 + channels4)
-        y1 = relay.strided_slice(y,
-                                 begin=[0, 0],
-                                 end=[-1, channels1],
-                                 strides=[1, 1],
-                                 slice_mode="size")
-        y2 = relay.strided_slice(y,
-                                 begin=[0, channels1],
-                                 end=[-1, channels2],
-                                 strides=[1, 1],
-                                 slice_mode="size")
+        y1 = relay.strided_slice(
+            y, begin=[0, 0], end=[-1, channels1], strides=[1, 1], slice_mode="size"
+        )
+        y2 = relay.strided_slice(
+            y, begin=[0, channels1], end=[-1, channels2], strides=[1, 1], slice_mode="size"
+        )
         y3 = relay.nn.conv2d(x, w3)
-        y4 = relay.strided_slice(y,
-                                 begin=[0, channels1 + channels2],
-                                 end=[-1, channels4],
-                                 strides=[1, 1],
-                                 slice_mode="size")
+        y4 = relay.strided_slice(
+            y,
+            begin=[0, channels1 + channels2],
+            end=[-1, channels4],
+            strides=[1, 1],
+            slice_mode="size",
+        )
         y5 = relay.nn.max_pool2d(x)
         y = relay.Tuple((y1, y2, y3, y4, y5))
         return relay.Function(args, y)
@@ -78,8 +78,7 @@ def test_combine_parallel_conv2d():
         w4 = relay.var("w4", shape=(channels4, in_c, 1, 1))
 
         y_before = before(x, w1, w2, w3, w4)
-        y = run_opt_pass(y_before,
-                         transform.CombineParallelConv2D(min_num_branches=2))
+        y = run_opt_pass(y_before, transform.CombineParallelConv2D(min_num_branches=2))
         y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4)
         y_expected = run_opt_pass(y_expected, transform.InferType())
         assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
@@ -90,6 +89,7 @@ def test_combine_parallel_conv2d():
 
 def test_combine_parallel_conv2d_scale_relu():
     """Testcase of combining conv2d + scale + relu"""
+
     def before(x, w1, w2, scale1, scale2, bias):
         args = [x, w1, w2, scale1, scale2, bias]
         y1 = relay.nn.conv2d(x, w1)
@@ -109,16 +109,12 @@ def test_combine_parallel_conv2d_scale_relu():
         y = relay.nn.conv2d(x, w, channels=channels1 + channels2)
         y = relay.multiply(y, scale)
         y = relay.nn.relu(y)
-        y1 = relay.strided_slice(y,
-                                 begin=[0, 0],
-                                 end=[-1, channels1],
-                                 strides=[1, 1],
-                                 slice_mode="size")
-        y2 = relay.strided_slice(y,
-                                 begin=[0, channels1],
-                                 end=[-1, channels2],
-                                 strides=[1, 1],
-                                 slice_mode="size")
+        y1 = relay.strided_slice(
+            y, begin=[0, 0], end=[-1, channels1], strides=[1, 1], slice_mode="size"
+        )
+        y2 = relay.strided_slice(
+            y, begin=[0, channels1], end=[-1, channels2], strides=[1, 1], slice_mode="size"
+        )
         y2 = relay.add(y2, bias)
         y = relay.Tuple((y1, y2))
         return relay.Function(args, y)
@@ -132,8 +128,7 @@ def test_combine_parallel_conv2d_scale_relu():
         scale2 = relay.var("scale2", shape=(channels2, 1, 1))
         bias = relay.var("bias", shape=(channels2, 1, 1))
         y_before = before(x, w1, w2, scale1, scale2, bias)
-        y = run_opt_pass(y_before,
-                         transform.CombineParallelConv2D(min_num_branches=2))
+        y = run_opt_pass(y_before, transform.CombineParallelConv2D(min_num_branches=2))
         y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2)
         y_expected = run_opt_pass(y_expected, transform.InferType())
         assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
@@ -143,6 +138,7 @@ def test_combine_parallel_conv2d_scale_relu():
 
 def test_combine_parallel_conv2d_scale():
     """Testcase of un-combinable scale"""
+
     def before(x, w1, w2, scale1, scale2):
         args = [x, w1, w2, scale1, scale2]
         y1 = relay.nn.conv2d(x, w1)
@@ -156,16 +152,12 @@ def test_combine_parallel_conv2d_scale():
         args = [x, w1, w2, scale1, scale2]
         w = relay.concatenate((w1, w2), axis=0)
         y = relay.nn.conv2d(x, w, channels=channels1 + channels2)
-        y1 = relay.strided_slice(y,
-                                 begin=[0, 0],
-                                 end=[-1, channels1],
-                                 strides=[1, 1],
-                                 slice_mode="size")
-        y2 = relay.strided_slice(y,
-                                 begin=[0, channels1],
-                                 end=[-1, channels2],
-                                 strides=[1, 1],
-                                 slice_mode="size")
+        y1 = relay.strided_slice(
+            y, begin=[0, 0], end=[-1, channels1], strides=[1, 1], slice_mode="size"
+        )
+        y2 = relay.strided_slice(
+            y, begin=[0, channels1], end=[-1, channels2], strides=[1, 1], slice_mode="size"
+        )
         y1 = relay.multiply(y1, scale1)
         y2 = relay.multiply(y2, scale2)
         y = relay.Tuple((y1, y2))
@@ -179,8 +171,7 @@ def test_combine_parallel_conv2d_scale():
         scale1 = relay.var("scale1", shape=(1,))
         scale2 = relay.var("scale2", shape=(1,))
         y_before = before(x, w1, w2, scale1, scale2)
-        y = run_opt_pass(y_before,
-                         transform.CombineParallelConv2D(min_num_branches=2))
+        y = run_opt_pass(y_before, transform.CombineParallelConv2D(min_num_branches=2))
         y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2)
         y_expected = run_opt_pass(y_expected, transform.InferType())
         assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
@@ -203,17 +194,13 @@ def test_combine_parallel_conv2d_multiple_blocks():
         y = x
         for i in range(repeat):
             w_concat = relay.concatenate((w, w), axis=0)
-            y = relay.nn.conv2d(y, w_concat, channels=channels*2)
-            y1 = relay.strided_slice(y,
-                                     begin=[0, 0],
-                                     end=[-1, channels],
-                                     strides=[1, 1],
-                                     slice_mode="size")
-            y2 = relay.strided_slice(y,
-                                     begin=[0, channels],
-                                     end=[-1, channels],
-                                     strides=[1, 1],
-                                     slice_mode="size")
+            y = relay.nn.conv2d(y, w_concat, channels=channels * 2)
+            y1 = relay.strided_slice(
+                y, begin=[0, 0], end=[-1, channels], strides=[1, 1], slice_mode="size"
+            )
+            y2 = relay.strided_slice(
+                y, begin=[0, channels], end=[-1, channels], strides=[1, 1], slice_mode="size"
+            )
             y = relay.concatenate((y1, y2), axis=1)
         return relay.Function(args, y)
 
@@ -223,8 +210,7 @@ def test_combine_parallel_conv2d_multiple_blocks():
         out_c = in_c // 2
         w = relay.var("w", shape=(out_c, in_c, 1, 1))
         y_before = before(x, w, repeat)
-        y = run_opt_pass(y_before,
-                         transform.CombineParallelConv2D(min_num_branches=2))
+        y = run_opt_pass(y_before, transform.CombineParallelConv2D(min_num_branches=2))
         y_expected = expected(x, w, out_c, repeat)
         y_expected = run_opt_pass(y_expected, transform.InferType())
         assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
index 3c5cd9d..7cf8867 100644
@@ -25,6 +25,7 @@ def run_combine_parallel(expr, min_num_branches=3, to_batch=True):
     mod = transform.CombineParallelDense(min_num_branches, to_batch)(mod)
     return mod["main"]
 
+
 def run_opt_pass(expr, opt_pass):
     assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
@@ -34,6 +35,7 @@ def run_opt_pass(expr, opt_pass):
 
 def test_combine_parallel_dense():
     """Simple testcase. One dense cannot be combined due to shape mismatch"""
+
     def before(x, w1, w2, w3, w4):
         args = [x, w1, w2, w3, w4]
         y1 = relay.nn.dense(x, w1)
@@ -64,15 +66,14 @@ def test_combine_parallel_dense():
         return relay.Function(args, y)
 
     def check(i, j, k):
-        x =  relay.var("x", shape=(i, k))
+        x = relay.var("x", shape=(i, k))
         w1 = relay.var("w1", shape=(j, k))
         w2 = relay.var("w2", shape=(j, k))
         w3 = relay.var("w3", shape=(j + 1, k))
         w4 = relay.var("w4", shape=(j, k))
 
         y_before = before(x, w1, w2, w3, w4)
-        y = run_opt_pass(y_before,
-                         transform.CombineParallelDense(min_num_branches=2))
+        y = run_opt_pass(y_before, transform.CombineParallelDense(min_num_branches=2))
         y_expected = expected(x, w1, w2, w3, w4)
         y_expected = run_opt_pass(y_expected, transform.InferType())
         tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
@@ -83,6 +84,7 @@ def test_combine_parallel_dense():
 
 def test_combine_parallel_dense_biasadd():
     """Testcase of combining dense + 1d biasadd"""
+
     def before(x, w1, w2, b1, b2):
         args = [x, w1, w2, b1, b2]
         y1 = relay.nn.dense(x, w1)
@@ -111,7 +113,7 @@ def test_combine_parallel_dense_biasadd():
         return relay.Function(args, y)
 
     def check(i, j, k, is_2d_bias):
-        x =  relay.var("x", shape=(i, k))
+        x = relay.var("x", shape=(i, k))
         w1 = relay.var("w1", shape=(j, k))
         w2 = relay.var("w2", shape=(j, k))
 
@@ -123,8 +125,7 @@ def test_combine_parallel_dense_biasadd():
             b2 = relay.var("b2", shape=(j,))
 
         y_before = before(x, w1, w2, b1, b2)
-        y = run_opt_pass(y_before,
-                         transform.CombineParallelDense(min_num_branches=2))
+        y = run_opt_pass(y_before, transform.CombineParallelDense(min_num_branches=2))
         y_expected = expected(x, w1, w2, b1, b2, is_2d_bias)
         y_expected = run_opt_pass(y_expected, transform.InferType())
         tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
@@ -134,8 +135,10 @@ def test_combine_parallel_dense_biasadd():
     check(3, 5, 4, True)
     check(100, 200, 300, True)
 
+
 def test_combine_parallel_dense_biasadd_scale_reshape():
     """Testcase of combining dense + 1d biasadd + multiply with non-fused reshape"""
+
     def before(x, w1, w2, b1, b2, scale1, scale2, newshape):
         args = [x, w1, w2, b1, b2, scale1, scale2]
         y1 = relay.nn.dense(x, w1)
@@ -171,7 +174,7 @@ def test_combine_parallel_dense_biasadd_scale_reshape():
         return relay.Function(args, y)
 
     def check(i, j, k, scale1, scale2, newshape):
-        x =  relay.var("x", shape=(i, k))
+        x = relay.var("x", shape=(i, k))
         w1 = relay.var("w1", shape=(j, k))
         w2 = relay.var("w2", shape=(j, k))
         b1 = relay.var("b1", shape=(j,))
@@ -180,8 +183,7 @@ def test_combine_parallel_dense_biasadd_scale_reshape():
         scale2 = relay.var("scale2", shape=(1,))
 
         y_before = before(x, w1, w2, b1, b2, scale1, scale2, newshape)
-        y = run_opt_pass(y_before,
-                         transform.CombineParallelDense(min_num_branches=2))
+        y = run_opt_pass(y_before, transform.CombineParallelDense(min_num_branches=2))
         y_expected = expected(x, w1, w2, b1, b2, scale1, scale2, newshape)
         y_expected = run_opt_pass(y_expected, transform.InferType())
         tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
@@ -192,6 +194,7 @@ def test_combine_parallel_dense_biasadd_scale_reshape():
 
 def test_combine_parallel_dense_flat():
     """Simple testcase. All matmul of different output dim can be combined"""
+
     def before(x, w1, w2, w3):
         args = [x, w1, w2, w3]
         y1 = relay.nn.dense(x, w1)
@@ -205,30 +208,24 @@ def test_combine_parallel_dense_flat():
         w_stacked = relay.concatenate((w1, w2, w3), axis=0)
         y = relay.nn.dense(x, w_stacked, units=6 * j)
         strides = [1, 1]
-        y1 = relay.strided_slice(y,
-                                 begin=[0, 0],
-                                 end=[-1, j],
-                                 strides=strides, slice_mode="size")
-        y2 = relay.strided_slice(y,
-                                 begin=[0, j],
-                                 end=[-1, 2 * j],
-                                 strides=strides, slice_mode="size")
-        y3 = relay.strided_slice(y,
-                                 begin=[0, 3 * j],
-                                 end=[-1, 3 * j],
-                                 strides=strides, slice_mode="size")
+        y1 = relay.strided_slice(y, begin=[0, 0], end=[-1, j], strides=strides, slice_mode="size")
+        y2 = relay.strided_slice(
+            y, begin=[0, j], end=[-1, 2 * j], strides=strides, slice_mode="size"
+        )
+        y3 = relay.strided_slice(
+            y, begin=[0, 3 * j], end=[-1, 3 * j], strides=strides, slice_mode="size"
+        )
         y = relay.Tuple((y1, y2, y3))
         return relay.Function(args, y)
 
     def check(i, j, k):
-        x =  relay.var("x", shape=(i, k))
+        x = relay.var("x", shape=(i, k))
         w1 = relay.var("w1", shape=(j, k))
         w2 = relay.var("w2", shape=(2 * j, k))
         w3 = relay.var("w3", shape=(3 * j, k))
 
         y_before = before(x, w1, w2, w3)
-        combine_pass = transform.CombineParallelDense(min_num_branches=3,
-                                                      to_batch=False) 
+        combine_pass = transform.CombineParallelDense(min_num_branches=3, to_batch=False)
         y = run_opt_pass(y_before, combine_pass)
         y_expected = expected(x, w1, w2, w3, j)
         y_expected = run_opt_pass(y_expected, transform.InferType())
@@ -240,6 +237,7 @@ def test_combine_parallel_dense_flat():
 
 def test_combine_parallel_dense_flat_biasadd():
     """Testcase of combining dense + 1d biasadd with different out dims"""
+
     def before(x, w1, w2, b1, b2):
         args = [x, w1, w2, b1, b2]
         y1 = relay.nn.dense(x, w1)
@@ -267,28 +265,23 @@ def test_combine_parallel_dense_flat_biasadd():
         begin = [0 for _ in range(n_out_dims - 1)]
         end = [-1 for _ in range(n_out_dims - 1)]
         strides = [1 for _ in range(n_out_dims)]
-        y1 = relay.strided_slice(y,
-                                 begin=begin + [0],
-                                 end=end + [j],
-                                 strides=strides,
-                                 slice_mode="size")
-        y2 = relay.strided_slice(y,
-                                 begin=begin + [j],
-                                 end=end + [2 * j],
-                                 strides=strides,
-                                 slice_mode="size")
+        y1 = relay.strided_slice(
+            y, begin=begin + [0], end=end + [j], strides=strides, slice_mode="size"
+        )
+        y2 = relay.strided_slice(
+            y, begin=begin + [j], end=end + [2 * j], strides=strides, slice_mode="size"
+        )
         return relay.Function(args, relay.Tuple((y1, y2)))
 
     def check(i, j, k, bias_shape1, bias_shape2):
-        x =  relay.var("x", shape=(i, k))
+        x = relay.var("x", shape=(i, k))
         w1 = relay.var("w1", shape=(j, k))
         w2 = relay.var("w2", shape=(2 * j, k))
         b1 = relay.var("b1", shape=bias_shape1)
         b2 = relay.var("b2", shape=bias_shape2)
-        
+
         y_before = before(x, w1, w2, b1, b2)
-        combine_pass = transform.CombineParallelDense(min_num_branches=2,
-                                                      to_batch=False) 
+        combine_pass = transform.CombineParallelDense(min_num_branches=2, to_batch=False)
         y = run_opt_pass(y_before, combine_pass)
         y_expected = expected(x, w1, w2, b1, b2, j, bias_shape1, bias_shape2)
         y_expected = run_opt_pass(y_expected, transform.InferType())
@@ -308,10 +301,12 @@ def test_combine_parallel_dense_flat_biasadd():
     check(3, 5, 4, (9, 3, 5), (9, 3, 1))
     check(3, 5, 4, (9, 3, 1), (9, 3, 10))
 
+
 def test_combine_parallel_dense_flat_biasadd_scale_reshape():
     """Testcase of combining dense with different out dims
-       following bias add, scale, reshape ops
+    following bias add, scale, reshape ops
     """
+
     def before(x, w1, w2, b1, b2, scale1, scale2, newshape1, newshape2):
         args = [x, w1, w2, b1, b2, scale1, scale2]
         y1 = relay.nn.dense(x, w1)
@@ -328,7 +323,7 @@ def test_combine_parallel_dense_flat_biasadd_scale_reshape():
     def expected(x, w1, w2, b1, b2, scale1, scale2, newshape1, newshape2, j):
         args = [x, w1, w2, b1, b2, scale1, scale2]
         w_stacked = relay.concatenate((w1, w2), axis=0)
-        y = relay.nn.dense(x, w_stacked, units=3*j)
+        y = relay.nn.dense(x, w_stacked, units=3 * j)
         b = relay.concatenate((b1, b2), axis=0)
         y = relay.add(y, b)
         scale1 = relay.repeat(scale1, j, 0)
@@ -336,21 +331,17 @@ def test_combine_parallel_dense_flat_biasadd_scale_reshape():
         scale = relay.concatenate((scale1, scale2), axis=0)
         y = relay.multiply(y, scale)
         strides = [1, 1]
-        y1 = relay.strided_slice(y,
-                                 begin=[0, 0],
-                                 end=[-1, j],
-                                 strides=strides, slice_mode="size")
-        y2 = relay.strided_slice(y,
-                                 begin=[0, j],
-                                 end=[-1, 2 * j],
-                                 strides=strides, slice_mode="size")
+        y1 = relay.strided_slice(y, begin=[0, 0], end=[-1, j], strides=strides, slice_mode="size")
+        y2 = relay.strided_slice(
+            y, begin=[0, j], end=[-1, 2 * j], strides=strides, slice_mode="size"
+        )
         y1 = relay.reshape(y1, newshape=newshape1)
         y2 = relay.reshape(y2, newshape=newshape2)
         y = relay.Tuple((y1, y2))
         return relay.Function(args, y)
 
     def check(i, j, k, scale1, scale2, newshape1, newshape2):
-        x =  relay.var("x", shape=(i, k))
+        x = relay.var("x", shape=(i, k))
         w1 = relay.var("w1", shape=(j, k))
         w2 = relay.var("w2", shape=(2 * j, k))
         b1 = relay.var("b1", shape=(j,))
@@ -358,13 +349,10 @@ def test_combine_parallel_dense_flat_biasadd_scale_reshape():
         scale1 = relay.var("scale1", shape=(1,))
         scale2 = relay.var("scale2", shape=(1,))
 
-        y_before = before(x, w1, w2, b1, b2, scale1, scale2,
-                          newshape1, newshape2)
-        combine_pass = transform.CombineParallelDense(min_num_branches=2,
-                                                      to_batch=False)
+        y_before = before(x, w1, w2, b1, b2, scale1, scale2, newshape1, newshape2)
+        combine_pass = transform.CombineParallelDense(min_num_branches=2, to_batch=False)
         y = run_opt_pass(y_before, combine_pass)
-        y_expected = expected(x, w1, w2, b1, b2, scale1, scale2,
-                              newshape1, newshape2, j)
+        y_expected = expected(x, w1, w2, b1, b2, scale1, scale2, newshape1, newshape2, j)
         y_expected = run_opt_pass(y_expected, transform.InferType())
         tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
 
index e71cfdc..e4771a0 100644
@@ -36,11 +36,8 @@ def run_opt_pass(expr, passes):
 def test_no_convert_layout():
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight = relay.var('weight', shape=(64, 64, 3, 3))
-        y = relay.nn.conv2d(x, weight,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
         y = relay.nn.relu(y)
         y = relay.Function([x, weight], y)
         return y
@@ -49,7 +46,7 @@ def test_no_convert_layout():
         return before()
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}))
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -58,33 +55,33 @@ def test_no_convert_layout():
 def test_conv_convert_layout():
     def before():
         x = relay.var("x", shape=(1, 56, 56, 64))
-        weight = relay.var('weight', shape=(3, 3, 64, 64))
-        y = relay.nn.conv2d(x, weight,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout='NHWC',
-                            kernel_layout='HWIO')
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
         y = relay.nn.relu(y)
         y = relay.Function([x, weight], y)
         return y
 
     def expected():
         x = relay.var("x", shape=(1, 56, 56, 64))
-        weight = relay.var('weight', shape=(3, 3, 64, 64))
-        x = relay.layout_transform(x, 'NHWC', 'NCHW')
-        weight = relay.layout_transform(weight, 'HWIO', 'OIHW')
-        y = relay.nn.conv2d(x, weight,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+        x = relay.layout_transform(x, "NHWC", "NCHW")
+        weight = relay.layout_transform(weight, "HWIO", "OIHW")
+        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
         y = relay.nn.relu(y)
-        y = relay.layout_transform(y, 'NCHW', 'NHWC')
+        y = relay.layout_transform(y, "NCHW", "NHWC")
         y = relay.Function(relay.analysis.free_vars(y), y)
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}))
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -93,35 +90,41 @@ def test_conv_convert_layout():
 def test_conv_nhwc_convert_layout():
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight = relay.var('weight', shape=(64, 64, 3, 3))
-        y = relay.nn.conv2d(x, weight,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout='NCHW',
-                            kernel_layout='OIHW')
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+        )
         y = relay.nn.relu(y)
         y = relay.Function([x, weight], y)
         return y
 
     def expected():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight = relay.var('weight', shape=(64, 64, 3, 3))
-        x = relay.layout_transform(x, 'NCHW', 'NHWC')
-        weight = relay.layout_transform(weight, 'OIHW', 'HWIO')
-        y = relay.nn.conv2d(x, weight,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NHWC",
-                            kernel_layout="HWIO")
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        x = relay.layout_transform(x, "NCHW", "NHWC")
+        weight = relay.layout_transform(weight, "OIHW", "HWIO")
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
         y = relay.nn.relu(y)
-        y = relay.layout_transform(y, 'NHWC', 'NCHW')
+        y = relay.layout_transform(y, "NHWC", "NCHW")
         y = relay.Function(relay.analysis.free_vars(y), y)
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NHWC', 'default']}))
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -130,33 +133,33 @@ def test_conv_nhwc_convert_layout():
 def test_conv_transpose_convert_layout():
     def before():
         x = relay.var("x", shape=(1, 56, 56, 64))
-        weight = relay.var('weight', shape=(3, 3, 64, 64))
-        y = relay.nn.conv2d_transpose(x, weight,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout='NHWC',
-                            kernel_layout='HWIO')
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+        y = relay.nn.conv2d_transpose(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
         y = relay.nn.relu(y)
         y = relay.Function([x, weight], y)
         return y
 
     def expected():
         x = relay.var("x", shape=(1, 56, 56, 64))
-        weight = relay.var('weight', shape=(3, 3, 64, 64))
-        x = relay.layout_transform(x, 'NHWC', 'NCHW')
-        weight = relay.layout_transform(weight, 'HWIO', 'OIHW')
-        y = relay.nn.conv2d_transpose(x, weight,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+        x = relay.layout_transform(x, "NHWC", "NCHW")
+        weight = relay.layout_transform(weight, "HWIO", "OIHW")
+        y = relay.nn.conv2d_transpose(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
         y = relay.nn.relu(y)
-        y = relay.layout_transform(y, 'NCHW', 'NHWC')
+        y = relay.layout_transform(y, "NCHW", "NHWC")
         y = relay.Function(relay.analysis.free_vars(y), y)
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d_transpose': ['NCHW', 'OIHW']}))
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d_transpose": ["NCHW", "OIHW"]}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -167,14 +170,21 @@ def test_conv_bias_pool_convert_layout():
         x = relay.var("x", shape=(1, 56, 56, 64))
         bias = relay.var("bias", shape=(64,))
         weight = relay.var("weight", shape=(3, 3, 64, 64))
-        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
-                            data_layout='NHWC', kernel_layout='HWIO')
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
         y = relay.nn.bias_add(y, bias, axis=3)
         # a useless tuple, which will be eliminated
         y = relay.Tuple([y])[0]
         y = relay.nn.relu(y)
-        y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout='NHWC')
-        y = relay.cast(y, 'int32')
+        y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NHWC")
+        y = relay.cast(y, "int32")
         y = relay.nn.batch_flatten(y)
         y = relay.Function(analysis.free_vars(y), y)
         return y
@@ -183,25 +193,25 @@ def test_conv_bias_pool_convert_layout():
         x = relay.var("x", shape=(1, 56, 56, 64))
         bias = relay.var("bias", shape=(64,))
         weight = relay.var("weight", shape=(3, 3, 64, 64))
-        x = relay.layout_transform(x, 'NHWC', 'NCHW')
-        weight = relay.layout_transform(weight, 'HWIO', 'OIHW')
+        x = relay.layout_transform(x, "NHWC", "NCHW")
+        weight = relay.layout_transform(weight, "HWIO", "OIHW")
         y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
 
         bias = relay.expand_dims(bias, axis=0, num_newaxis=3)
-        bias = relay.layout_transform(bias, 'NHWC', 'NCHW')
+        bias = relay.layout_transform(bias, "NHWC", "NCHW")
         y = relay.add(y, bias)
         # a useless tuple, which will be eliminated
         y = relay.Tuple([y])[0]
         y = relay.nn.relu(y)
         y = relay.nn.max_pool2d(y, pool_size=(2, 2))
-        y = relay.cast(y, 'int32')
-        y = relay.layout_transform(y, 'NCHW', 'NHWC')
+        y = relay.cast(y, "int32")
+        y = relay.layout_transform(y, "NCHW", "NHWC")
         y = relay.nn.batch_flatten(y)
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}))
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -210,46 +220,46 @@ def test_conv_bias_pool_convert_layout():
 def test_conv_concat_convert_layout():
     def before():
         x = relay.var("x", shape=(1, 56, 56, 64))
-        weight1 = relay.var('weight1', shape=(3, 3, 64, 64))
-        weight2 = relay.var('weight2', shape=(3, 3, 64, 64))
-        y = relay.nn.conv2d(x, weight1,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout='NHWC',
-                            kernel_layout='HWIO')
-        y1 = relay.nn.conv2d(y, weight2,
-                             channels=64,
-                             kernel_size=(3, 3),
-                             padding=(1, 1),
-                             data_layout='NHWC',
-                             kernel_layout='HWIO')
+        weight1 = relay.var("weight1", shape=(3, 3, 64, 64))
+        weight2 = relay.var("weight2", shape=(3, 3, 64, 64))
+        y = relay.nn.conv2d(
+            x,
+            weight1,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        y1 = relay.nn.conv2d(
+            y,
+            weight2,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
         ret = relay.concatenate([y, y1], axis=3)
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
     def expected():
         x = relay.var("x", shape=(1, 56, 56, 64))
-        weight1 = relay.var('weight1', shape=(3, 3, 64, 64))
-        weight2 = relay.var('weight2', shape=(3, 3, 64, 64))
-        weight1 = relay.layout_transform(weight1, 'HWIO', 'OIHW')
-        weight2 = relay.layout_transform(weight2, 'HWIO', 'OIHW')
+        weight1 = relay.var("weight1", shape=(3, 3, 64, 64))
+        weight2 = relay.var("weight2", shape=(3, 3, 64, 64))
+        weight1 = relay.layout_transform(weight1, "HWIO", "OIHW")
+        weight2 = relay.layout_transform(weight2, "HWIO", "OIHW")
         y = relay.layout_transform(x, "NHWC", "NCHW")
-        y = relay.nn.conv2d(y, weight1,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
-        y1 = relay.nn.conv2d(y, weight2,
-                             channels=64,
-                             kernel_size=(3, 3),
-                             padding=(1, 1))
+        y = relay.nn.conv2d(y, weight1, channels=64, kernel_size=(3, 3), padding=(1, 1))
+        y1 = relay.nn.conv2d(y, weight2, channels=64, kernel_size=(3, 3), padding=(1, 1))
         ret = relay.concatenate([y, y1], axis=1)
         ret = relay.layout_transform(ret, "NCHW", "NHWC")
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}))
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -258,21 +268,27 @@ def test_conv_concat_convert_layout():
 def test_dual_path_convert_layout():
     def before():
         x = relay.var("x", shape=(1, 56, 56, 64))
-        weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
-        weight2 = relay.var('weight2', shape=(3, 3, 32, 32))
-        y = relay.nn.conv2d(x, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout='NHWC',
-                            kernel_layout='HWIO')
+        weight1 = relay.var("weight1", shape=(3, 3, 64, 32))
+        weight2 = relay.var("weight2", shape=(3, 3, 32, 32))
+        y = relay.nn.conv2d(
+            x,
+            weight1,
+            channels=32,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
         y = relay.nn.relu(y)
-        y1 = relay.nn.conv2d(y, weight2,
-                             channels=32,
-                             kernel_size=(3, 3),
-                             padding=(1, 1),
-                             data_layout='NHWC',
-                             kernel_layout='HWIO')
+        y1 = relay.nn.conv2d(
+            y,
+            weight2,
+            channels=32,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
         y1 = relay.nn.relu(y1)
         y2 = relay.nn.batch_flatten(y)
         ret = relay.Tuple([y1, y2])
@@ -281,20 +297,14 @@ def test_dual_path_convert_layout():
 
     def expected():
         x = relay.var("x", shape=(1, 56, 56, 64))
-        weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
-        weight2 = relay.var('weight2', shape=(3, 3, 32, 32))
-        weight1 = relay.layout_transform(weight1, 'HWIO', 'OIHW')
-        weight2 = relay.layout_transform(weight2, 'HWIO', 'OIHW')
+        weight1 = relay.var("weight1", shape=(3, 3, 64, 32))
+        weight2 = relay.var("weight2", shape=(3, 3, 32, 32))
+        weight1 = relay.layout_transform(weight1, "HWIO", "OIHW")
+        weight2 = relay.layout_transform(weight2, "HWIO", "OIHW")
         y = relay.layout_transform(x, "NHWC", "NCHW")
-        y = relay.nn.conv2d(y, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        y = relay.nn.conv2d(y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1))
         y = relay.nn.relu(y)
-        y1 = relay.nn.conv2d(y, weight2,
-                             channels=32,
-                             kernel_size=(3, 3),
-                             padding=(1, 1))
+        y1 = relay.nn.conv2d(y, weight2, channels=32, kernel_size=(3, 3), padding=(1, 1))
         y1 = relay.nn.relu(y1)
         y1 = relay.layout_transform(y1, "NCHW", "NHWC")
         y2 = relay.layout_transform(y, "NCHW", "NHWC")
@@ -304,7 +314,7 @@ def test_dual_path_convert_layout():
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}))
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -313,28 +323,34 @@ def test_dual_path_convert_layout():
 def test_bn_convert_layout():
     def before():
         x = relay.var("x", shape=(1, 56, 56, 64))
-        weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
-        y = relay.nn.conv2d(x, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout='NHWC',
-                            kernel_layout='HWIO')
+        weight1 = relay.var("weight1", shape=(3, 3, 64, 32))
+        y = relay.nn.conv2d(
+            x,
+            weight1,
+            channels=32,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
         gamma = relay.var("gamma")
         beta = relay.var("beta")
         mean = relay.var("mean")
         variance = relay.var("variance")
-        y, _, _ = relay.nn.batch_norm(y , gamma, beta, mean, variance, axis=3)
+        y, _, _ = relay.nn.batch_norm(y, gamma, beta, mean, variance, axis=3)
         return relay.Function(analysis.free_vars(y), y)
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}))
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
 
     # Check that there is only 1 NHWC to NCHW transform.
     has_lt = list()
-    find_op = lambda x : \
-            has_lt.append(isinstance(x, tvm.relay.expr.Call) and x.op.name == "layout_transform" \
-            and x.attrs.src_layout == 'NCHW' and x.attrs.dst_layout == 'NHWC')
+    find_op = lambda x: has_lt.append(
+        isinstance(x, tvm.relay.expr.Call)
+        and x.op.name == "layout_transform"
+        and x.attrs.src_layout == "NCHW"
+        and x.attrs.dst_layout == "NHWC"
+    )
     relay.analysis.post_order_visit(a, find_op)
     has_lt = list(filter(lambda x: x, has_lt))
     assert len(has_lt) == 1
@@ -343,40 +359,36 @@ def test_bn_convert_layout():
 def test_resnet_convert_layout():
     def before():
         x = relay.var("x", shape=(1, 56, 56, 64))
-        weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
-        weight2 = relay.var('weight2', shape=(1, 1, 64, 32))
-        y = relay.nn.conv2d(x, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout='NHWC',
-                            kernel_layout='HWIO')
+        weight1 = relay.var("weight1", shape=(3, 3, 64, 32))
+        weight2 = relay.var("weight2", shape=(1, 1, 64, 32))
+        y = relay.nn.conv2d(
+            x,
+            weight1,
+            channels=32,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
         y = relay.nn.relu(y)
-        y2 = relay.nn.conv2d(x, weight2,
-                             channels=32,
-                             kernel_size=(1, 1),
-                             data_layout='NHWC',
-                             kernel_layout='HWIO')
+        y2 = relay.nn.conv2d(
+            x, weight2, channels=32, kernel_size=(1, 1), data_layout="NHWC", kernel_layout="HWIO"
+        )
         y2 = relay.nn.relu(y2)
         y = y + y2
-        y = relay.nn.global_max_pool2d(y, layout='NHWC')
+        y = relay.nn.global_max_pool2d(y, layout="NHWC")
         return relay.Function(analysis.free_vars(y), y)
 
     def expected():
-        x = relay.var("x", shape=(1,56, 56, 64))
-        weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
-        weight2 = relay.var('weight2', shape=(1, 1, 64, 32))
-        weight1 = relay.layout_transform(weight1, 'HWIO', 'OIHW')
-        weight2 = relay.layout_transform(weight2, 'HWIO', 'OIHW')
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var("weight1", shape=(3, 3, 64, 32))
+        weight2 = relay.var("weight2", shape=(1, 1, 64, 32))
+        weight1 = relay.layout_transform(weight1, "HWIO", "OIHW")
+        weight2 = relay.layout_transform(weight2, "HWIO", "OIHW")
         x = relay.layout_transform(x, "NHWC", "NCHW")
-        y = relay.nn.conv2d(x, weight1,
-                            channels=32,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        y = relay.nn.conv2d(x, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1))
         y = relay.nn.relu(y)
-        y2 = relay.nn.conv2d(x, weight2,
-                             channels=32,
-                             kernel_size=(1, 1))
+        y2 = relay.nn.conv2d(x, weight2, channels=32, kernel_size=(1, 1))
         y2 = relay.nn.relu(y2)
         y = y + y2
         y = relay.nn.global_max_pool2d(y)
@@ -384,7 +396,7 @@ def test_resnet_convert_layout():
         return relay.Function(analysis.free_vars(y), y)
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}))
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -394,8 +406,15 @@ def test_scalar_convert_layout():
     def before():
         x = relay.var("x", shape=(1, 56, 56, 64))
         weight = relay.var("weight", shape=(3, 3, 64, 64))
-        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
-                            data_layout='NHWC', kernel_layout='HWIO')
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
         y = relay.add(y, relay.const(1, "float32"))
         y = relay.Function(analysis.free_vars(y), y)
         return y
@@ -403,12 +422,9 @@ def test_scalar_convert_layout():
     def expected():
         x = relay.var("x", shape=(1, 56, 56, 64))
         w = relay.var("weight", shape=(3, 3, 64, 64))
-        x = relay.layout_transform(x, 'NHWC', 'NCHW')
-        w = relay.layout_transform(w, 'HWIO', 'OIHW')
-        y = relay.nn.conv2d(x, w,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        x = relay.layout_transform(x, "NHWC", "NCHW")
+        w = relay.layout_transform(w, "HWIO", "OIHW")
+        y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1))
         y = relay.add(y, relay.const(1.0, "float32"))
 
         y = relay.layout_transform(y, "NCHW", "NHWC")
@@ -416,7 +432,7 @@ def test_scalar_convert_layout():
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}))
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -424,11 +440,19 @@ def test_scalar_convert_layout():
 
 def test_conv_bn_convert_layout():
     """ Check that layout transforms are propagated through bn. """
+
     def before():
         x = relay.var("x", shape=(1, 56, 56, 64))
         weight = relay.var("weight", shape=(3, 3, 64, 64))
-        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
-                            data_layout='NHWC', kernel_layout='HWIO')
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
 
         dtype = "float32"
         beta = relay.var("beta", relay.TensorType((64,), dtype))
@@ -444,12 +468,9 @@ def test_conv_bn_convert_layout():
     def expected():
         x = relay.var("x", shape=(1, 56, 56, 64))
         w = relay.var("weight", shape=(3, 3, 64, 64))
-        x = relay.layout_transform(x, 'NHWC', 'NCHW')
-        w = relay.layout_transform(w, 'HWIO', 'OIHW')
-        y = relay.nn.conv2d(x, w,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        x = relay.layout_transform(x, "NHWC", "NCHW")
+        w = relay.layout_transform(w, "HWIO", "OIHW")
+        y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1))
 
         dtype = "float32"
         beta = relay.var("beta", relay.TensorType((64,), dtype))
@@ -464,7 +485,7 @@ def test_conv_bn_convert_layout():
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}))
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -472,55 +493,65 @@ def test_conv_bn_convert_layout():
 
 def test_qnn_conv_requantize_convert_layout():
     def before():
-        x = relay.var("x", shape=(1, 56, 56, 64), dtype='int8')
-        weight = relay.var('weight', shape=(3, 3, 64, 64), dtype='int8')
-        y = relay.qnn.op.conv2d(x, weight,
-                                relay.const(1, 'int32'),
-                                relay.const(1, 'int32'),
-                                relay.const(1, 'float32'),
-                                relay.const(1, 'float32'),
-                                channels=64,
-                                kernel_size=(3, 3),
-                                padding=(1, 1),
-                                data_layout='NHWC',
-                                kernel_layout='HWIO')
-        y = relay.qnn.op.requantize(y,
-                                    relay.const(1, 'float32'),
-                                    relay.const(1, 'int32'),
-                                    relay.const(1, 'float32'),
-                                    relay.const(1, 'int32'),
-                                    out_dtype='int32')
+        x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
+        weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8")
+        y = relay.qnn.op.conv2d(
+            x,
+            weight,
+            relay.const(1, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        y = relay.qnn.op.requantize(
+            y,
+            relay.const(1, "float32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "int32"),
+            out_dtype="int32",
+        )
         y = relay.nn.relu(y)
         y = relay.Function([x, weight], y)
         return y
 
     def expected():
-        x = relay.var("x", shape=(1, 56, 56, 64), dtype='int8')
-        weight = relay.var('weight', shape=(3, 3, 64, 64), dtype='int8')
-        x = relay.layout_transform(x, 'NHWC', 'NCHW')
-        weight = relay.layout_transform(weight, 'HWIO', 'OIHW')
-        y = relay.qnn.op.conv2d(x, weight,
-                                relay.const(1, 'int32'),
-                                relay.const(1, 'int32'),
-                                relay.const(1, 'float32'),
-                                relay.const(1, 'float32'),
-                                channels=64,
-                                kernel_size=(3, 3),
-                                padding=(1, 1))
-        y = relay.qnn.op.requantize(y,
-                                    relay.const(1, 'float32'),
-                                    relay.const(1, 'int32'),
-                                    relay.const(1, 'float32'),
-                                    relay.const(1, 'int32'),
-                                    axis=1,
-                                    out_dtype='int32')
+        x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
+        weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8")
+        x = relay.layout_transform(x, "NHWC", "NCHW")
+        weight = relay.layout_transform(weight, "HWIO", "OIHW")
+        y = relay.qnn.op.conv2d(
+            x,
+            weight,
+            relay.const(1, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+        )
+        y = relay.qnn.op.requantize(
+            y,
+            relay.const(1, "float32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "int32"),
+            axis=1,
+            out_dtype="int32",
+        )
         y = relay.nn.relu(y)
-        y = relay.layout_transform(y, 'NCHW', 'NHWC')
+        y = relay.layout_transform(y, "NCHW", "NHWC")
         y = relay.Function(relay.analysis.free_vars(y), y)
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout({'qnn.conv2d': ['NCHW', 'default']}))
+    a = run_opt_pass(a, transform.ConvertLayout({"qnn.conv2d": ["NCHW", "default"]}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -528,77 +559,93 @@ def test_qnn_conv_requantize_convert_layout():
 
 def test_qnn_conv_concat_convert_layout():
     def before():
-        x = relay.var("x", shape=(1, 56, 56, 64), dtype='int8')
-        weight1 = relay.var('weight1', shape=(3, 3, 64, 64), dtype='int8')
-        weight2 = relay.var('weight2', shape=(3, 3, 64, 64), dtype='int8')
-        y = relay.qnn.op.conv2d(x, weight1,
-                                relay.const(1, 'int32'),
-                                relay.const(1, 'int32'),
-                                relay.const(1, 'float32'),
-                                relay.const(1, 'float32'),
-                                channels=64,
-                                kernel_size=(3, 3),
-                                padding=(1, 1),
-                                data_layout='NHWC',
-                                kernel_layout='HWIO')
-        y1 = relay.qnn.op.conv2d(y, weight2,
-                                relay.const(1, 'int32'),
-                                relay.const(1, 'int32'),
-                                relay.const(1, 'float32'),
-                                relay.const(1, 'float32'),
-                                channels=64,
-                                kernel_size=(3, 3),
-                                padding=(1, 1),
-                                data_layout='NHWC',
-                                kernel_layout='HWIO')
-        y = relay.cast(y, 'int8')
-        y1 = relay.cast(y, 'int8')
-        ret = relay.qnn.op.concatenate([y, y1],
-                                       [relay.const(1, 'float32'), relay.const(1, 'float32')],
-                                       [relay.const(1, 'int32'), relay.const(1, 'int32')],
-                                       relay.const(1, 'float32'),
-                                       relay.const(1, 'int32'),
-                                       axis=3)
+        x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
+        weight1 = relay.var("weight1", shape=(3, 3, 64, 64), dtype="int8")
+        weight2 = relay.var("weight2", shape=(3, 3, 64, 64), dtype="int8")
+        y = relay.qnn.op.conv2d(
+            x,
+            weight1,
+            relay.const(1, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        y1 = relay.qnn.op.conv2d(
+            y,
+            weight2,
+            relay.const(1, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        y = relay.cast(y, "int8")
+        y1 = relay.cast(y, "int8")
+        ret = relay.qnn.op.concatenate(
+            [y, y1],
+            [relay.const(1, "float32"), relay.const(1, "float32")],
+            [relay.const(1, "int32"), relay.const(1, "int32")],
+            relay.const(1, "float32"),
+            relay.const(1, "int32"),
+            axis=3,
+        )
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
     def expected():
-        x = relay.var("x", shape=(1, 56, 56, 64), dtype='int8')
-        weight1 = relay.var('weight1', shape=(3, 3, 64, 64), dtype='int8')
-        weight2 = relay.var('weight2', shape=(3, 3, 64, 64), dtype='int8')
-        weight1 = relay.layout_transform(weight1, 'HWIO', 'OIHW')
-        weight2 = relay.layout_transform(weight2, 'HWIO', 'OIHW')
+        x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
+        weight1 = relay.var("weight1", shape=(3, 3, 64, 64), dtype="int8")
+        weight2 = relay.var("weight2", shape=(3, 3, 64, 64), dtype="int8")
+        weight1 = relay.layout_transform(weight1, "HWIO", "OIHW")
+        weight2 = relay.layout_transform(weight2, "HWIO", "OIHW")
         y = relay.layout_transform(x, "NHWC", "NCHW")
-        y = relay.qnn.op.conv2d(y, weight1,
-                                relay.const(1, 'int32'),
-                                relay.const(1, 'int32'),
-                                relay.const(1, 'float32'),
-                                relay.const(1, 'float32'),
-                                channels=64,
-                                kernel_size=(3, 3),
-                                padding=(1, 1))
-        y1 = relay.qnn.op.conv2d(y, weight2,
-                                relay.const(1, 'int32'),
-                                relay.const(1, 'int32'),
-                                relay.const(1, 'float32'),
-                                relay.const(1, 'float32'),
-                                channels=64,
-                                kernel_size=(3, 3),
-                                padding=(1, 1))
-        y = relay.cast(y, 'int8')
-        y1 = relay.cast(y, 'int8')
-        ret = relay.qnn.op.concatenate([y, y1],
-                                      [relay.const(1, 'float32'), relay.const(1, 'float32')],
-                                      [relay.const(1, 'int32'), relay.const(1, 'int32')],
-                                      relay.const(1, 'float32'),
-                                      relay.const(1, 'int32'),
-                                      axis=1)
+        y = relay.qnn.op.conv2d(
+            y,
+            weight1,
+            relay.const(1, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+        )
+        y1 = relay.qnn.op.conv2d(
+            y,
+            weight2,
+            relay.const(1, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+        )
+        y = relay.cast(y, "int8")
+        y1 = relay.cast(y, "int8")
+        ret = relay.qnn.op.concatenate(
+            [y, y1],
+            [relay.const(1, "float32"), relay.const(1, "float32")],
+            [relay.const(1, "int32"), relay.const(1, "int32")],
+            relay.const(1, "float32"),
+            relay.const(1, "int32"),
+            axis=1,
+        )
         ret = relay.layout_transform(ret, "NCHW", "NHWC")
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout({'qnn.conv2d': ['NCHW', 'default']}))
+    a = run_opt_pass(a, transform.ConvertLayout({"qnn.conv2d": ["NCHW", "default"]}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -606,79 +653,97 @@ def test_qnn_conv_concat_convert_layout():
 
 def test_qnn_conv_add_convert_layout():
     def before():
-        x = relay.var("x", shape=(1, 56, 56, 64), dtype='int8')
-        weight1 = relay.var('weight1', shape=(3, 3, 64, 64), dtype='int8')
-        weight2 = relay.var('weight2', shape=(3, 3, 64, 64), dtype='int8')
-        y = relay.qnn.op.conv2d(x, weight1,
-                                relay.const(1, 'int32'),
-                                relay.const(1, 'int32'),
-                                relay.const(1, 'float32'),
-                                relay.const(1, 'float32'),
-                                channels=64,
-                                kernel_size=(3, 3),
-                                padding=(1, 1),
-                                data_layout='NHWC',
-                                kernel_layout='HWIO')
-        y1 = relay.qnn.op.conv2d(y, weight2,
-                                relay.const(1, 'int32'),
-                                relay.const(1, 'int32'),
-                                relay.const(1, 'float32'),
-                                relay.const(1, 'float32'),
-                                channels=64,
-                                kernel_size=(3, 3),
-                                padding=(1, 1),
-                                data_layout='NHWC',
-                                kernel_layout='HWIO')
-        y = relay.cast(y, 'int8')
-        y1 = relay.cast(y, 'int8')
-        ret = relay.qnn.op.add(y, y1,
-                               relay.const(1, 'float32'),
-                               relay.const(1, 'int32'),
-                               relay.const(1, 'float32'),
-                               relay.const(1, 'int32'),
-                               relay.const(1, 'float32'),
-                               relay.const(1, 'int32'))
+        x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
+        weight1 = relay.var("weight1", shape=(3, 3, 64, 64), dtype="int8")
+        weight2 = relay.var("weight2", shape=(3, 3, 64, 64), dtype="int8")
+        y = relay.qnn.op.conv2d(
+            x,
+            weight1,
+            relay.const(1, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        y1 = relay.qnn.op.conv2d(
+            y,
+            weight2,
+            relay.const(1, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        y = relay.cast(y, "int8")
+        y1 = relay.cast(y, "int8")
+        ret = relay.qnn.op.add(
+            y,
+            y1,
+            relay.const(1, "float32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "int32"),
+        )
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
     def expected():
-        x = relay.var("x", shape=(1, 56, 56, 64), dtype='int8')
-        weight1 = relay.var('weight1', shape=(3, 3, 64, 64), dtype='int8')
-        weight2 = relay.var('weight2', shape=(3, 3, 64, 64), dtype='int8')
-        weight1 = relay.layout_transform(weight1, 'HWIO', 'OIHW')
-        weight2 = relay.layout_transform(weight2, 'HWIO', 'OIHW')
+        x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
+        weight1 = relay.var("weight1", shape=(3, 3, 64, 64), dtype="int8")
+        weight2 = relay.var("weight2", shape=(3, 3, 64, 64), dtype="int8")
+        weight1 = relay.layout_transform(weight1, "HWIO", "OIHW")
+        weight2 = relay.layout_transform(weight2, "HWIO", "OIHW")
         y = relay.layout_transform(x, "NHWC", "NCHW")
-        y = relay.qnn.op.conv2d(y, weight1,
-                                relay.const(1, 'int32'),
-                                relay.const(1, 'int32'),
-                                relay.const(1, 'float32'),
-                                relay.const(1, 'float32'),
-                                channels=64,
-                                kernel_size=(3, 3),
-                                padding=(1, 1))
-        y1 = relay.qnn.op.conv2d(y, weight2,
-                                relay.const(1, 'int32'),
-                                relay.const(1, 'int32'),
-                                relay.const(1, 'float32'),
-                                relay.const(1, 'float32'),
-                                channels=64,
-                                kernel_size=(3, 3),
-                                padding=(1, 1))
-        y = relay.cast(y, 'int8')
-        y1 = relay.cast(y, 'int8')
-        ret = relay.qnn.op.add(y, y1,
-                               relay.const(1, 'float32'),
-                               relay.const(1, 'int32'),
-                               relay.const(1, 'float32'),
-                               relay.const(1, 'int32'),
-                               relay.const(1, 'float32'),
-                               relay.const(1, 'int32'))
+        y = relay.qnn.op.conv2d(
+            y,
+            weight1,
+            relay.const(1, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+        )
+        y1 = relay.qnn.op.conv2d(
+            y,
+            weight2,
+            relay.const(1, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+        )
+        y = relay.cast(y, "int8")
+        y1 = relay.cast(y, "int8")
+        ret = relay.qnn.op.add(
+            y,
+            y1,
+            relay.const(1, "float32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "int32"),
+        )
         ret = relay.layout_transform(ret, "NCHW", "NHWC")
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout({'qnn.conv2d': ['NCHW', 'default']}))
+    a = run_opt_pass(a, transform.ConvertLayout({"qnn.conv2d": ["NCHW", "default"]}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -686,29 +751,40 @@ def test_qnn_conv_add_convert_layout():
 
 def test_conv_convert_kernel_layout():
     """ Check that convolution kernel layout is correctly transformed. """
+
     def before():
         x = relay.var("x", shape=(1, 56, 56, 64))
         weight = relay.var("weight", shape=(3, 3, 64, 64))
-        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
-                            data_layout='NHWC', kernel_layout='HWIO')
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
     def expected():
         x = relay.var("x", shape=(1, 56, 56, 64))
         w = relay.var("weight", shape=(3, 3, 64, 64))
-        w = relay.layout_transform(w, 'HWIO', 'OHWI')
-        y = relay.nn.conv2d(x, w,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout='NHWC',
-                            kernel_layout='OHWI')
+        w = relay.layout_transform(w, "HWIO", "OHWI")
+        y = relay.nn.conv2d(
+            x,
+            w,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="OHWI",
+        )
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NHWC', 'OHWI']}))
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "OHWI"]}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -716,113 +792,145 @@ def test_conv_convert_kernel_layout():
 
 def test_default_keyword():
     """ Check that the default keyword selects correct TVM default layout. """
+
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))
         weight = relay.var("weight", shape=(64, 3, 3, 64))
-        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
-                            data_layout='NCHW', kernel_layout='OHWI')
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OHWI",
+        )
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
     def expected():
         x = relay.var("x", shape=(1, 64, 56, 56))
         w = relay.var("weight", shape=(64, 3, 3, 64))
-        w = relay.layout_transform(w, 'OHWI', 'OIHW')
-        y = relay.nn.conv2d(x, w,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout='NCHW',
-                            kernel_layout='OIHW')
+        w = relay.layout_transform(w, "OHWI", "OIHW")
+        y = relay.nn.conv2d(
+            x,
+            w,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+        )
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}))
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_different_ops_convert_layout():
-    """ Check convert layout correctly supports converting the layout of
+    """Check convert layout correctly supports converting the layout of
     different ops in the same graph.
     """
+
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))
         weight1 = relay.var("weight1", shape=(64, 3, 3, 64))
-        weight2 = relay.var("weight2", shape=(64, 3, 3, 64), dtype='int8')
+        weight2 = relay.var("weight2", shape=(64, 3, 3, 64), dtype="int8")
         weight3 = relay.var("weight3", shape=(64, 3, 3, 64))
-        out = relay.nn.conv2d(x, weight1,
-                              channels=64,
-                              kernel_size=(3, 3),
-                              padding=(1, 1),
-                              data_layout='NCHW',
-                              kernel_layout='OHWI')
-        out = relay.cast(out, 'int8')
-        out = relay.qnn.op.conv2d(out, weight2,
-                                  relay.const(1, 'int32'),
-                                  relay.const(1, 'int32'),
-                                  relay.const(1, 'float32'),
-                                  relay.const(1, 'float32'),
-                                  channels=64,
-                                  kernel_size=(3, 3),
-                                  padding=(1, 1),
-                                  data_layout='NCHW',
-                                  kernel_layout='OHWI')
-        out = relay.cast(out, 'float32')
-        out = relay.nn.conv2d_transpose(out, weight3,
-                              channels=64,
-                              kernel_size=(3, 3),
-                              padding=(1, 1),
-                              data_layout='NCHW',
-                              kernel_layout='OHWI')
+        out = relay.nn.conv2d(
+            x,
+            weight1,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OHWI",
+        )
+        out = relay.cast(out, "int8")
+        out = relay.qnn.op.conv2d(
+            out,
+            weight2,
+            relay.const(1, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OHWI",
+        )
+        out = relay.cast(out, "float32")
+        out = relay.nn.conv2d_transpose(
+            out,
+            weight3,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OHWI",
+        )
         out = relay.Function(analysis.free_vars(out), out)
         return out
 
     def expected():
         x = relay.var("x", shape=(1, 64, 56, 56))
         weight1 = relay.var("weight1", shape=(64, 3, 3, 64))
-        weight2 = relay.var("weight2", shape=(64, 3, 3, 64), dtype='int8')
+        weight2 = relay.var("weight2", shape=(64, 3, 3, 64), dtype="int8")
         weight3 = relay.var("weight3", shape=(64, 3, 3, 64))
-        x = relay.layout_transform(x, 'NCHW', 'NHWC')
-        weight1 = relay.layout_transform(weight1, 'OHWI', 'HWIO')
-        out = relay.nn.conv2d(x, weight1,
-                              channels=64,
-                              kernel_size=(3, 3),
-                              padding=(1, 1),
-                              data_layout='NHWC',
-                              kernel_layout='HWIO')
-        out = relay.cast(out, 'int8')
-        out = relay.layout_transform(out, 'NHWC', 'NCHW')
-        weight2 = relay.layout_transform(weight2, 'OHWI', 'OIHW')
-        out = relay.qnn.op.conv2d(out, weight2,
-                                  relay.const(1, 'int32'),
-                                  relay.const(1, 'int32'),
-                                  relay.const(1, 'float32'),
-                                  relay.const(1, 'float32'),
-                                  channels=64,
-                                  kernel_size=(3, 3),
-                                  padding=(1, 1),
-                                  data_layout='NCHW',
-                                  kernel_layout='OIHW')
-        out = relay.cast(out, 'float32')
-        out = relay.layout_transform(out, 'NCHW', 'NHWC')
-        weight3 = relay.layout_transform(weight3, 'OHWI', 'HWIO')
-        out = relay.nn.conv2d_transpose(out, weight3,
-                              channels=64,
-                              kernel_size=(3, 3),
-                              padding=(1, 1),
-                              data_layout='NHWC',
-                              kernel_layout='HWIO')
-        out = relay.layout_transform(out, 'NHWC', 'NCHW')
+        x = relay.layout_transform(x, "NCHW", "NHWC")
+        weight1 = relay.layout_transform(weight1, "OHWI", "HWIO")
+        out = relay.nn.conv2d(
+            x,
+            weight1,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        out = relay.cast(out, "int8")
+        out = relay.layout_transform(out, "NHWC", "NCHW")
+        weight2 = relay.layout_transform(weight2, "OHWI", "OIHW")
+        out = relay.qnn.op.conv2d(
+            out,
+            weight2,
+            relay.const(1, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+        )
+        out = relay.cast(out, "float32")
+        out = relay.layout_transform(out, "NCHW", "NHWC")
+        weight3 = relay.layout_transform(weight3, "OHWI", "HWIO")
+        out = relay.nn.conv2d_transpose(
+            out,
+            weight3,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        out = relay.layout_transform(out, "NHWC", "NCHW")
         out = relay.Function(analysis.free_vars(out), out)
         return out
 
     a = before()
-    desired_layouts = {'nn.conv2d': ['NHWC', 'HWIO'],
-                       'qnn.conv2d': ['NCHW', 'OIHW'],
-                       'nn.conv2d_transpose': ['NHWC', 'HWIO'],}
+    desired_layouts = {
+        "nn.conv2d": ["NHWC", "HWIO"],
+        "qnn.conv2d": ["NCHW", "OIHW"],
+        "nn.conv2d_transpose": ["NHWC", "HWIO"],
+    }
     a = run_opt_pass(a, transform.ConvertLayout(desired_layouts))
     b = run_opt_pass(expected(), transform.InferType())
 
index 35fd444..6da6c3e 100644
@@ -24,6 +24,7 @@ from tvm.relay.testing import inception_v3
 
 import pytest
 
+
 class env:
     def __init__(self):
         self.shape = tvm.runtime.convert([1, 2, 3])
@@ -66,6 +67,7 @@ def test_used_let():
     expected = relay.Let(e.c, e.one, e.c + e.c)
     assert tvm.ir.structural_equal(Function([], orig), Function([], expected))
 
+
 def test_inline():
     orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
     orig = run_opt_pass(orig, transform.DeadCodeElimination(True))
@@ -82,13 +84,13 @@ def use_f(func):
     f = relay.Var("f")
     n = relay.Var("n", e.int32)
     data = relay.Var("data", e.float32)
-    funcbody = relay.If(equal(n, relay.const(0)),
-                        data,
-                        relay.Call(f, [subtract(n, relay.const(1)),
-                                       log(data)]))
+    funcbody = relay.If(
+        equal(n, relay.const(0)), data, relay.Call(f, [subtract(n, relay.const(1)), log(data)])
+    )
     value = relay.Function([n, data], funcbody, e.float32, [])
     return relay.Let(f, value, func(f))
 
+
 # make sure we dont infinite loop
 def test_recursion():
     """
@@ -107,6 +109,7 @@ def test_recursion():
     orig = run_opt_pass(orig, transform.InferType())
     tvm.ir.assert_structural_equal(dced, orig)
 
+
 def test_recursion_dead():
     x = relay.Let(e.a, e.one, e.three)
     dced_f = lambda f: x
@@ -115,15 +118,14 @@ def test_recursion_dead():
 
 
 def test_op_let():
-    dced = run_opt_pass(add(relay.Let(e.a, e.one, e.three), e.two),
-                        transform.DeadCodeElimination())
+    dced = run_opt_pass(add(relay.Let(e.a, e.one, e.three), e.two), transform.DeadCodeElimination())
     assert tvm.ir.structural_equal(dced, add(e.three, e.two))
 
 
 def test_tuple_get_item():
     tt = relay.TupleType([e.float32, e.float32])
-    t = relay.Var('t', tt)
-    a = relay.Var('a')
+    t = relay.Var("t", tt)
+    a = relay.Var("a")
     g = relay.TupleGetItem(t, 0)
     dced = run_opt_pass(g, transform.DeadCodeElimination())
     assert tvm.ir.structural_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
@@ -134,7 +136,7 @@ def test_tuple_get_item():
 
 @pytest.mark.timeout(timeout=10, method="thread")
 def test_complexity():
-    g = inception_v3.get_net(1, 1000, (3, 299, 299), 'float32')
+    g = inception_v3.get_net(1, 1000, (3, 299, 299), "float32")
     run_opt_pass(g, transform.DeadCodeElimination())
 
 
index ac54ebf..4097eb7 100644
@@ -25,108 +25,113 @@ from tvm.relay.testing import Prelude
 
 # determine if type t is a FuncType or has a nested FuncType
 def has_func_type(t):
-  class FuncTypeVisitor(TypeVisitor):
-    def __init__(self):
-      super().__init__()
-      self.has_func = False
+    class FuncTypeVisitor(TypeVisitor):
+        def __init__(self):
+            super().__init__()
+            self.has_func = False
 
-    def visit_func_type(self, ftt):
-      self.has_func = True
+        def visit_func_type(self, ftt):
+            self.has_func = True
+
+    ftvisitor = FuncTypeVisitor()
+    ftvisitor.visit(t)
+    return ftvisitor.has_func
 
-  ftvisitor = FuncTypeVisitor()
-  ftvisitor.visit(t)
-  return ftvisitor.has_func
 
 # determine whether a program has any higher order functions
 # a higher order function is defined as one that:
 # - has function type arguments
 # - returns a function
 def assert_no_higher_order_functions(expr, mod):
-  class CheckFirstOrderVisitor(ExprVisitor):
-    def __init__(self, mod):
-      super().__init__()
-      self.mod = mod
-      self.hof = []
-      self.visited_gv = set()
-    
-    def visit_call(self, call):
-      is_higher_order = False
-      # check return type
-      if (has_func_type(call.checked_type)):
-        is_higher_order = True
-      # check argument types
-      for a in call.args:
-        if (has_func_type(a.checked_type)):
-          is_higher_order = True
-      # if it is higher order, save it for debugging later
-      if is_higher_order:
-        self.hof.append(call)
-      super().visit_call(call)
-
-    def visit_global_var(self, gv):
-      # visit global vars to visit entire program
-      if gv not in self.visited_gv:
-        self.visited_gv.add(gv)
-        self.visit(self.mod[gv])
-
-  mod = transform.InferType()(mod)
-  check_fo_visitor = CheckFirstOrderVisitor(mod)
-  check_fo_visitor.visit(expr)
-
-  nl = '\n--------\n'
-  errmsg = f"""found {len(check_fo_visitor.hof)} higher order functions:
+    class CheckFirstOrderVisitor(ExprVisitor):
+        def __init__(self, mod):
+            super().__init__()
+            self.mod = mod
+            self.hof = []
+            self.visited_gv = set()
+
+        def visit_call(self, call):
+            is_higher_order = False
+            # check return type
+            if has_func_type(call.checked_type):
+                is_higher_order = True
+            # check argument types
+            for a in call.args:
+                if has_func_type(a.checked_type):
+                    is_higher_order = True
+            # if it is higher order, save it for debugging later
+            if is_higher_order:
+                self.hof.append(call)
+            super().visit_call(call)
+
+        def visit_global_var(self, gv):
+            # visit global vars to visit entire program
+            if gv not in self.visited_gv:
+                self.visited_gv.add(gv)
+                self.visit(self.mod[gv])
+
+    mod = transform.InferType()(mod)
+    check_fo_visitor = CheckFirstOrderVisitor(mod)
+    check_fo_visitor.visit(expr)
+
+    nl = "\n--------\n"
+    errmsg = f"""found {len(check_fo_visitor.hof)} higher order functions:
   {nl.join(expr.astext() for expr in check_fo_visitor.hof)}"""
 
-  assert len(check_fo_visitor.hof) == 0, errmsg
+    assert len(check_fo_visitor.hof) == 0, errmsg
+
 
 # assert that a program is defunctionalized and returns
 # defunctionalized module
 # assumes program starts from mod['main']
 def defunctionalized(mod):
-  mod = transform.InferType()(mod)
-  mod['main'] = transform.Defunctionalization(mod['main'], mod)
-  mod = transform.InferType()(mod)
-  assert_no_higher_order_functions(mod['main'], mod)
+    mod = transform.InferType()(mod)
+    mod["main"] = transform.Defunctionalization(mod["main"], mod)
+    mod = transform.InferType()(mod)
+    assert_no_higher_order_functions(mod["main"], mod)
+
+    return mod
 
-  return mod
 
 # adt list to python list
 def to_list(mod, l):
-  list = mod.get_global_type_var('List')
-  list_adt = mod[list]
-  cons = list_adt.constructors[0]
-  nil = list_adt.constructors[1]
-
-  assert isinstance(l, ConstructorValue)
-  val = l
-  ret = []
-  while True:
-      if val.tag == cons.tag:
-          ret.append(val.fields[0].asnumpy())
-          val = val.fields[1]
-      else:
-          assert val.tag == nil.tag
-          break
-  return ret
+    list = mod.get_global_type_var("List")
+    list_adt = mod[list]
+    cons = list_adt.constructors[0]
+    nil = list_adt.constructors[1]
+
+    assert isinstance(l, ConstructorValue)
+    val = l
+    ret = []
+    while True:
+        if val.tag == cons.tag:
+            ret.append(val.fields[0].asnumpy())
+            val = val.fields[1]
+        else:
+            assert val.tag == nil.tag
+            break
+    return ret
+
 
 # list to adt list
 def to_adt_list(mod, arr):
-  expr = mod['main']
-  l = mod.get_global_type_var('List')
-  list_adt = mod[l]
-  cons = list_adt.constructors[0]
-  nil = list_adt.constructors[1]
-
-  li = nil()
-  for a in arr:
-    li = cons(relay.const(a), li)
-  ex = relay.create_executor(mod=mod)
-  adt = ex.evaluate(li)
-  mod['main'] = expr
-  return adt
+    expr = mod["main"]
+    l = mod.get_global_type_var("List")
+    list_adt = mod[l]
+    cons = list_adt.constructors[0]
+    nil = list_adt.constructors[1]
+
+    li = nil()
+    for a in arr:
+        li = cons(relay.const(a), li)
+    ex = relay.create_executor(mod=mod)
+    adt = ex.evaluate(li)
+    mod["main"] = expr
+    return adt
+
 
 def test_simple():
-  code = """
+    code = """
 #[version = "0.0.5"]
 def @simple[A, B](%f: fn(A) -> B, %xs: A) -> B {
   %f(%xs)
@@ -138,22 +143,22 @@ def @main(%l: Tensor[(5, 5), float32]) -> Tensor[(5, 5), float32] {
   @simple(%0, %l)
 }
 """
-  mod = tvm.parser.fromtext(code)
-  defunc_mod = defunctionalized(mod)
+    mod = tvm.parser.fromtext(code)
+    defunc_mod = defunctionalized(mod)
+
+    input = np.random.rand(5, 5).astype("float32")
 
-  input = np.random.rand(5,5).astype('float32')
+    ex = relay.create_executor("debug", mod=mod)
+    defunc_ex = relay.create_executor("debug", mod=defunc_mod)
 
-  ex = relay.create_executor('debug', mod=mod)
-  defunc_ex = relay.create_executor('debug', mod=defunc_mod)
+    out = ex.evaluate()(input)
+    defunc_out = defunc_ex.evaluate()(input)
 
-  out = ex.evaluate()(input)
-  defunc_out = defunc_ex.evaluate()(input)
+    np.testing.assert_equal(out.asnumpy(), defunc_out.asnumpy())
 
-  np.testing.assert_equal(out.asnumpy(), defunc_out.asnumpy())
-  
 
 def test_global_recursion():
-  code = """
+    code = """
 #[version = "0.0.5"]
 type List[A] {
   Cons(A, List[A]),
@@ -172,22 +177,23 @@ def @main(%l: List[float32]) -> List[float32] {
   @map(@id, %l)
 }
 """
-  mod = tvm.parser.fromtext(code)
-  defunc_mod = defunctionalized(mod)
+    mod = tvm.parser.fromtext(code)
+    defunc_mod = defunctionalized(mod)
+
+    input = np.random.rand(10).astype("float32")
 
-  input = np.random.rand(10).astype('float32')
-  
-  ex = relay.create_executor('debug', mod=mod)
-  defunc_ex = relay.create_executor('debug', mod=defunc_mod)
+    ex = relay.create_executor("debug", mod=mod)
+    defunc_ex = relay.create_executor("debug", mod=defunc_mod)
 
-  out = ex.evaluate(mod['main'])(to_adt_list(mod, input))
-  defunc_out = defunc_ex.evaluate()(to_adt_list(defunc_mod, input))
+    out = ex.evaluate(mod["main"])(to_adt_list(mod, input))
+    defunc_out = defunc_ex.evaluate()(to_adt_list(defunc_mod, input))
+
+    np.testing.assert_array_equal(to_list(mod, out), to_list(defunc_mod, defunc_out))
 
-  np.testing.assert_array_equal(to_list(mod, out), to_list(defunc_mod, defunc_out))
 
 def test_recursive_datatype():
-  # CPS will create recursive datatype
-  code = """
+    # CPS will create recursive datatype
+    code = """
 #[version = "0.0.5"]
 type List[A] {
   Cons(A, List[A]),
@@ -209,18 +215,19 @@ def @main(%l: List[int32]) -> int32 {
   @sum(@id, %l)
 }
 """
-  mod = tvm.parser.fromtext(code)
-  defunc_mod = defunctionalized(mod)
+    mod = tvm.parser.fromtext(code)
+    defunc_mod = defunctionalized(mod)
+
+    input = np.random.randint(1, 100, 10)
 
-  input = np.random.randint(1, 100, 10)
+    ex = relay.create_executor("debug", mod=mod)
+    defunc_ex = relay.create_executor("debug", mod=defunc_mod)
 
-  ex = relay.create_executor('debug', mod=mod)
-  defunc_ex = relay.create_executor('debug', mod=defunc_mod)
+    out = ex.evaluate(mod["main"])(to_adt_list(mod, input))
+    defunc_out = defunc_ex.evaluate()(to_adt_list(defunc_mod, input))
 
-  out = ex.evaluate(mod['main'])(to_adt_list(mod, input))
-  defunc_out = defunc_ex.evaluate()(to_adt_list(defunc_mod, input))
+    tvm.testing.assert_allclose(out.asnumpy(), defunc_out.asnumpy())
 
-  tvm.testing.assert_allclose(out.asnumpy(), defunc_out.asnumpy())
 
 if __name__ == "__main__":
-  pytest.main([__file__])
\ No newline at end of file
+    pytest.main([__file__])
index 210dfc8..ba3d279 100644
@@ -51,8 +51,7 @@ def test_dynamic_to_static_reshape():
         y = relay.var("y", relay.TensorType(newshape, "float32"))
         z = relay.reshape(x, relay.shape_of(y))
         func = run_infer_type(relay.Function([x, y], z))
-        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()),
-                             transform.InferType())
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
 
         zz = func2.body
         assert isinstance(zz, relay.Call)
@@ -77,8 +76,7 @@ def test_dynamic_to_static_double_reshape():
         z = relay.reshape(x, relay.shape_of(y))
         z = relay.reshape(z, relay.shape_of(x))
         func = run_infer_type(relay.Function([x, y], z))
-        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()),
-                             transform.InferType())
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
 
         zz = func2.body
         assert isinstance(zz, relay.Call)
@@ -104,8 +102,7 @@ def test_dynamic_to_static_quad_reshape():
         z3 = relay.reshape(z2, relay.shape_of(z1))
         z4 = relay.reshape(z3, relay.shape_of(z2))
         func = run_infer_type(relay.Function([x, y], z4))
-        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()),
-                             transform.InferType())
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
 
         zz = func2.body
         assert isinstance(zz, relay.Call)
@@ -128,8 +125,7 @@ def test_dynamic_to_static_tile():
         y = relay.var("y", relay.TensorType(reps, "float32"))
         z = relay.tile(x, relay.shape_of(y))
         func = run_infer_type(relay.Function([x, y], z))
-        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()),
-                             transform.InferType())
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
 
         zz = func2.body
         assert isinstance(zz, relay.Call)
@@ -174,14 +170,14 @@ def test_dynamic_to_static_topk():
                 np_values[i, :] = np_data[i, np_indices[i, :]]
         np_indices = np_indices.astype(dtype)
 
-        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()),
-                             transform.InferType())
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
         zz = func2.body
         assert isinstance(zz, relay.Call)
         assert zz.op == relay.op.get("topk")
 
         for target, ctx in tvm.testing.enabled_targets():
-            if "llvm" not in target: continue
+            if "llvm" not in target:
+                continue
             for kind in ["graph", "vm", "debug"]:
                 mod = tvm.ir.IRModule.from_expr(func2)
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
@@ -210,8 +206,7 @@ def test_dynamic_to_static_broadcast_to():
         z = relay.broadcast_to(x, shape=relay.shape_of(y))
 
         func = run_infer_type(relay.Function([x, y], z))
-        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()),
-                             transform.InferType())
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
 
         zz = func2.body
         assert isinstance(zz, relay.Call)
@@ -235,8 +230,9 @@ def test_dynamic_to_static_zeros_ones():
             y = op(relay.shape_of(x), dtype)
 
             func = run_infer_type(relay.Function([x], y))
-            func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()),
-                                 transform.InferType())
+            func2 = run_opt_pass(
+                run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()
+            )
 
             zz = func2.body
             assert isinstance(zz, relay.Constant)
@@ -246,8 +242,8 @@ def test_dynamic_to_static_zeros_ones():
             ref_res = ref(x_data.shape)
             verify_func(func2, [x_data], ref_res)
 
-    verify_ones_zeros((1, 2, 3), 'int64')
-    verify_ones_zeros((9, 8, 3, 4), 'float32')
+    verify_ones_zeros((1, 2, 3), "int64")
+    verify_ones_zeros((9, 8, 3, 4), "float32")
 
 
 @tvm.testing.uses_gpu
@@ -261,12 +257,12 @@ def test_dynamic_to_static_resize():
         x = relay.var("x", relay.TensorType(shape, "float32"))
         size_var = relay.const(np.array(size).astype("float32"))
         coord_trans = "asymmetric" if method == "nearest_neighbor" else "align_corners"
-        z = relay.image.resize(x, size_var, layout, method,
-                              coordinate_transformation_mode=coord_trans)
+        z = relay.image.resize(
+            x, size_var, layout, method, coordinate_transformation_mode=coord_trans
+        )
 
         func = run_infer_type(relay.Function([x], z))
-        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()),
-                             transform.InferType())
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
 
         zz = func2.body
         assert isinstance(zz, relay.Call)
@@ -295,8 +291,7 @@ def test_dynamic_to_static_one_hot():
         out = relay.one_hot(indices, on_value_const, off_value_const, depth_var, axis, dtype)
         func = relay.Function([indices], out)
 
-        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()),
-                             transform.InferType())
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
 
         zz = func2.body
         assert isinstance(zz, relay.Call)
@@ -306,18 +301,19 @@ def test_dynamic_to_static_one_hot():
         out_np = tvm.topi.testing.one_hot(indices_np, on_value, off_value, depth, axis, dtype)
         verify_func(func2, [indices_np], out_np)
 
-    _verify((3, ), 3, 1, 0, -1, "int32")
-    _verify((3, ), 3, 1.0, 0.0, -1, "float32")
+    _verify((3,), 3, 1, 0, -1, "int32")
+    _verify((3,), 3, 1.0, 0.0, -1, "float32")
     _verify((2, 2), 5, 2, -2, 0, "int32")
     _verify((2, 2), 5, 0.5, -0.5, 1, "float32")
     _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32")
     _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
 
+
 @tvm.testing.uses_gpu
 def test_dynamic_to_static_full():
     def verify_full(fill_value, fill_shape, dtype):
         x = relay.var("x", relay.scalar_type(dtype))
-        y = relay.var("y", relay.TensorType(fill_shape, 'int64'))
+        y = relay.var("y", relay.TensorType(fill_shape, "int64"))
         z = relay.full(x, relay.shape_of(y), dtype)
 
         func = run_infer_type(relay.Function([x, y], z))
@@ -328,11 +324,12 @@ def test_dynamic_to_static_full():
         assert zz.op == relay.op.get("full")
 
         ref_res = np.full(fill_shape, fill_value).astype(dtype)
-        y_data = np.random.uniform(low=-1, high=1, size=fill_shape).astype('int64')
+        y_data = np.random.uniform(low=-1, high=1, size=fill_shape).astype("int64")
         verify_func(func2, [fill_value, y_data], ref_res)
 
-    verify_full(4, (1, 2, 3, 4), 'int32')
-    verify_full(4.0, (1, 2, 8, 10), 'float32')
+    verify_full(4, (1, 2, 3, 4), "int32")
+    verify_full(4.0, (1, 2, 8, 10), "float32")
+
 
 def test_dynamic_to_static_upsampling():
     def verify_upsampling(data_shape, scale_h_val, scale_w_val, dtype):
@@ -352,8 +349,9 @@ def test_dynamic_to_static_upsampling():
         ref_res = tvm.topi.testing.upsampling_python(x_data, (scale_h_val, scale_w_val), "NCHW")
         verify_func(func2, [x_data], ref_res)
 
-    verify_upsampling((1, 16, 32, 32), 2, 2, 'int8')
-    verify_upsampling((1, 16, 32, 32), 4, 4, 'int32')
+    verify_upsampling((1, 16, 32, 32), 2, 2, "int8")
+    verify_upsampling((1, 16, 32, 32), 4, 4, "int32")
+
 
 def test_dynamic_to_static_upsampling3d():
     def verify_upsampling3d(data_shape, scale_d_val, scale_h_val, scale_w_val, dtype):
@@ -372,12 +370,15 @@ def test_dynamic_to_static_upsampling3d():
         assert zz.op == relay.op.get("nn.upsampling3d")
 
         x_data = np.random.uniform(size=data_shape).astype(dtype)
-        ref_res = tvm.topi.testing.upsampling3d_python(x_data, (scale_d_val, scale_h_val, scale_w_val), "NCDHW")
+        ref_res = tvm.topi.testing.upsampling3d_python(
+            x_data, (scale_d_val, scale_h_val, scale_w_val), "NCDHW"
+        )
         verify_func(func2, [x_data], ref_res)
 
-    verify_upsampling3d((1, 1, 1, 1, 1), 2, 3, 4, 'int8')
-    verify_upsampling3d((5, 7, 8, 10, 32), 3, 2, 2, 'int8')
-    verify_upsampling3d((1, 4, 2, 5, 3), 5, 4, 3, 'int32')
+    verify_upsampling3d((1, 1, 1, 1, 1), 2, 3, 4, "int8")
+    verify_upsampling3d((5, 7, 8, 10, 32), 3, 2, 2, "int8")
+    verify_upsampling3d((1, 4, 2, 5, 3), 5, 4, 3, "int32")
+
 
 def test_dynamic_to_static_pad():
     def verify_pad(data_shape, pad_width, pad_val, dtype):
@@ -390,7 +391,9 @@ def test_dynamic_to_static_pad():
         assert zz.op == relay.op.get("nn.pad")
 
         x_data = np.random.uniform(size=data_shape).astype(dtype)
-        ref_res = np.pad(x_data, pad_width, 'constant', constant_values=(((pad_val,)*2),) * len(data_shape))
+        ref_res = np.pad(
+            x_data, pad_width, "constant", constant_values=(((pad_val,) * 2),) * len(data_shape)
+        )
         verify_func(func2, [x_data], ref_res)
 
     verify_pad((4, 10, 7, 7), ((1, 1), (2, 2), (3, 3), (4, 4)), 2.0, "int32")
@@ -398,8 +401,7 @@ def test_dynamic_to_static_pad():
 
 
 def test_dynamic_to_static_strided_slice():
-    def verify(dshape, begin, end, strides, output, slice_mode="end",
-               test_ref=True, dtype="int32"):
+    def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, dtype="int32"):
         x = relay.var("x", relay.TensorType(dshape, "float32"))
         ndim = len(dshape)
         begin = begin if begin else [0] * ndim
@@ -412,27 +414,18 @@ def test_dynamic_to_static_strided_slice():
 
         # target numpy result
         x_data = np.random.uniform(size=dshape).astype("float32")
-        ref_res = tvm.topi.testing.strided_slice_python(
-            x_data, begin, end, strides, slice_mode)
+        ref_res = tvm.topi.testing.strided_slice_python(x_data, begin, end, strides, slice_mode)
         data = [x_data, np.array(begin), np.array(end)]
-        
+
         begin = relay.const(begin, dtype=dtype)
         end = relay.const(end, dtype=dtype)
 
-        
         if strides:
             data.append(np.array(strides))
             strides = relay.const(strides, dtype=dtype)
-            z = relay.strided_slice(x,
-                                    begin=begin,
-                                    end=end,
-                                    strides=strides,
-                                    slice_mode=slice_mode)
+            z = relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=slice_mode)
         else:
-            z = relay.strided_slice(x,
-                                    begin=begin,
-                                    end=end,
-                                    slice_mode=slice_mode)
+            z = relay.strided_slice(x, begin=begin, end=end, slice_mode=slice_mode)
         func = relay.Function([x], z)
 
         func = run_infer_type(func)
@@ -442,8 +435,14 @@ def test_dynamic_to_static_strided_slice():
         verify_func(func2, [x_data], ref_res)
 
     verify((1, 3, 10, 10), [0, 0, 0, 0], [1, 3, 10, 10], [1], (0, 3, 10, 10), dtype="int64")
-    verify((1, 224, 224, 3), [0, 20, 20, 0], [1, 140, 140, 3],
-           [1, 1, 1, 1], (1, 120, 120, 3), dtype="int64")
+    verify(
+        (1, 224, 224, 3),
+        [0, 20, 20, 0],
+        [1, 140, 140, 3],
+        [1, 1, 1, 1],
+        (1, 120, 120, 3),
+        dtype="int64",
+    )
     verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3), dtype="int16")
     verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2))
     verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3))
@@ -452,10 +451,10 @@ def test_dynamic_to_static_strided_slice():
     verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3))
     verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
     verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
-    verify((3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1],
-           (2, 4, 3), slice_mode="size", test_ref=False)
-    verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1],
-           (2, 2, 3), slice_mode="size", test_ref=True)
+    verify(
+        (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False
+    )
+    verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], (2, 2, 3), slice_mode="size", test_ref=True)
 
 
 if __name__ == "__main__":
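The hunks above only touch quoting and line wrapping. As a minimal, standard-library sketch (not part of the commit) of why such changes are behaviour-preserving, both spellings below, mirroring lines from this file's hunks, parse to the same Python AST:

import ast

# Quote style: Black rewrites 'int8' as "int8"; the parsed AST is identical.
OLD = "verify_upsampling((1, 16, 32, 32), 2, 2, 'int8')\n"
NEW = 'verify_upsampling((1, 16, 32, 32), 2, 2, "int8")\n'
assert ast.dump(ast.parse(OLD)) == ast.dump(ast.parse(NEW))

# Line wrapping: splitting a long call across lines is equally invisible to the parser.
FLAT = (
    "verify((3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), "
    'slice_mode="size", test_ref=False)\n'
)
WRAPPED = (
    "verify(\n"
    '    (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False\n'
    ")\n"
)
assert ast.dump(ast.parse(FLAT)) == ast.dump(ast.parse(WRAPPED))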
index 45d21a4..ac519a9 100644 (file)
@@ -76,7 +76,7 @@ def test_callback():
         return run_opt_pass(f, transform.InferType())
 
     def fskip(expr):
-        if isinstance(expr, relay.expr.Call) and expr.op.name == 'add':
+        if isinstance(expr, relay.expr.Call) and expr.op.name == "add":
             return True
         return False
 
@@ -84,13 +84,14 @@ def test_callback():
     z = run_opt_pass(z, transform.EliminateCommonSubexpr(fskip))
     assert tvm.ir.structural_equal(z, expected())
 
+
 def test_tuple_get_time():
     def before():
-        x = relay.var('x', shape=(1, 16, 1, 1))
-        var = relay.var('var', shape=(16,))
-        mean = relay.var('mean', shape=(16,))
-        beta = relay.var('beta', shape=(16,))
-        gamma = relay.var('gamma', shape=(16,))
+        x = relay.var("x", shape=(1, 16, 1, 1))
+        var = relay.var("var", shape=(16,))
+        mean = relay.var("mean", shape=(16,))
+        beta = relay.var("beta", shape=(16,))
+        gamma = relay.var("gamma", shape=(16,))
         BN = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)
         T1 = BN[0]
         T2 = BN[0]
@@ -99,11 +100,11 @@ def test_tuple_get_time():
         return f
 
     def expected():
-        x = relay.var('x', shape=(1, 16, 1, 1))
-        var = relay.var('var', shape=(16,))
-        mean = relay.var('mean', shape=(16,))
-        beta = relay.var('beta', shape=(16,))
-        gamma = relay.var('gamma', shape=(16,))
+        x = relay.var("x", shape=(1, 16, 1, 1))
+        var = relay.var("var", shape=(16,))
+        mean = relay.var("mean", shape=(16,))
+        beta = relay.var("beta", shape=(16,))
+        gamma = relay.var("gamma", shape=(16,))
         BN = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)
         T1 = BN[0]
         add = T1 + T1
@@ -114,6 +115,7 @@ def test_tuple_get_time():
     z = run_opt_pass(z, transform.EliminateCommonSubexpr())
     assert tvm.ir.structural_equal(z, expected())
 
+
 if __name__ == "__main__":
     test_simple()
     test_callback()
index 05c5f03..bbb2c2b 100644 (file)
@@ -23,8 +23,10 @@ from tvm import te
 from tvm import relay
 import tvm.relay.transform as _transform
 
+
 def test_eta_expand_global_var():
-    mod = tvm.parser.fromtext(r"""
+    mod = tvm.parser.fromtext(
+        r"""
         #[version = "0.0.5"]
         def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
             %x
@@ -32,11 +34,13 @@ def test_eta_expand_global_var():
         def @main() -> fn(Tensor[(), int32]) -> Tensor[(), int32] {
             @aux
         }
-    """)
+    """
+    )
     seq = tvm.transform.Sequential([_transform.EtaExpand(expand_global_var=True)])
     with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
-    expected = tvm.parser.fromtext(r"""
+    expected = tvm.parser.fromtext(
+        r"""
         #[version = "0.0.5"]
         def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
             %x
@@ -46,13 +50,14 @@ def test_eta_expand_global_var():
                 @aux(%x)
             }
         }
-    """)
-    tvm.ir.assert_structural_equal(mod['main'], expected['main'],
-                                   map_free_vars=True)
+    """
+    )
+    tvm.ir.assert_structural_equal(mod["main"], expected["main"], map_free_vars=True)
 
 
 def test_eta_expand_constructor():
-    mod = tvm.parser.fromtext(r"""
+    mod = tvm.parser.fromtext(
+        r"""
         #[version = "0.0.5"]
         type List[A] {
             Cons(A, List[A]),
@@ -61,11 +66,13 @@ def test_eta_expand_constructor():
         def @main[A]() -> fn(A, List[A]) -> List[A] {
             Cons
         }
-    """)
+    """
+    )
     seq = tvm.transform.Sequential([_transform.EtaExpand(expand_constructor=True)])
     with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
-    expected = tvm.parser.fromtext(r"""
+    expected = tvm.parser.fromtext(
+        r"""
         #[version = "0.0.5"]
         type List[A] {
             Cons(A, List[A]),
@@ -76,11 +83,11 @@ def test_eta_expand_constructor():
                 Cons(%x, %xs)
             }
         }
-    """)
-    tvm.ir.assert_structural_equal(mod['main'], expected['main'],
-                                   map_free_vars=True)
+    """
+    )
+    tvm.ir.assert_structural_equal(mod["main"], expected["main"], map_free_vars=True)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_eta_expand_global_var()
     test_eta_expand_constructor()
index da5eaf4..bb3fb84 100644 (file)
@@ -19,6 +19,7 @@ from tvm.ir import IRModule
 from tvm import relay
 from tvm.relay.transform import FastMath
 
+
 def test_exp():
     x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
     y = relay.exp(x)
@@ -29,10 +30,11 @@ def test_exp():
     assert "fast_exp" in fast_mod.astext()
 
     # Check that FastMath option works for relay.build.
-    with tvm.transform.PassContext(opt_level=3, required_pass=['FastMath']):
-        fast_mod = relay.optimize(mod, target='llvm', params=None)
+    with tvm.transform.PassContext(opt_level=3, required_pass=["FastMath"]):
+        fast_mod = relay.optimize(mod, target="llvm", params=None)
     assert "fast_exp" in fast_mod[0].astext()
 
+
 def test_tanh():
     x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
     y = relay.tanh(x)
@@ -43,10 +45,11 @@ def test_tanh():
     assert "fast_tanh" in fast_mod.astext()
 
     # Check that FastMath option works for relay.build.
-    with tvm.transform.PassContext(opt_level=3, required_pass=['FastMath']):
-        fast_mod = relay.optimize(mod, target='llvm', params=None)
+    with tvm.transform.PassContext(opt_level=3, required_pass=["FastMath"]):
+        fast_mod = relay.optimize(mod, target="llvm", params=None)
     assert "fast_tanh" in fast_mod[0].astext()
 
+
 def test_erf():
     x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
     y = relay.erf(x)
@@ -57,10 +60,11 @@ def test_erf():
     assert "fast_erf" in fast_mod.astext()
 
     # Check that FastMath option works for relay.build.
-    with tvm.transform.PassContext(opt_level=3, required_pass=['FastMath']):
-        fast_mod = relay.optimize(mod, target='llvm', params=None)
+    with tvm.transform.PassContext(opt_level=3, required_pass=["FastMath"]):
+        fast_mod = relay.optimize(mod, target="llvm", params=None)
     assert "fast_erf" in fast_mod[0].astext()
 
+
 if __name__ == "__main__":
     test_exp()
     test_tanh()
index a496e1d..b3ea422 100644 (file)
@@ -54,6 +54,7 @@ def test_concatenate_const():
 def test_fold_const():
     c_data = np.array([1, 2, 3]).astype("float32")
     t = relay.TensorType([1, 2, 3], "float32")
+
     def before():
         c = relay.const(c_data)
         x = relay.var("x", t)
@@ -80,6 +81,7 @@ def test_fold_const():
 def test_fold_let():
     c_data = np.array(1).astype("float32")
     t = relay.TensorType([1], "float32")
+
     def before():
         sb = relay.ScopeBuilder()
         x = relay.var("x", t)
@@ -92,7 +94,7 @@ def test_fold_let():
     def expected():
         sb = relay.ScopeBuilder()
         x = relay.var("x", t)
-        c_folded = (c_data + c_data)
+        c_folded = c_data + c_data
         t3 = sb.let("t3", relay.add(relay.const(c_folded), x))
         sb.ret(t3)
         return relay.Function([x], sb.get())
@@ -105,6 +107,7 @@ def test_fold_let():
 def test_fold_tuple():
     c_data = np.array(1).astype("float32")
     t = relay.TensorType([1], "float32")
+
     def before():
         c = relay.const(c_data)
         x = relay.var("x", t)
@@ -145,6 +148,7 @@ def test_fold_concat():
 
 def test_fold_shape_of():
     c_shape = (8, 9, 10)
+
     def before(dtype):
         x = relay.var("x", shape=c_shape, dtype="float32")
         y = relay.var("y", shape=c_shape, dtype="float32")
@@ -166,6 +170,7 @@ def test_fold_shape_of():
 
 def test_fold_ndarray_size():
     c_shape = (8, 9, 10)
+
     def before(dtype):
         x = relay.var("x", shape=c_shape, dtype="float32")
         y = relay.var("y", shape=c_shape, dtype="float32")
@@ -187,8 +192,9 @@ def test_fold_ndarray_size():
 
 def test_fold_full():
     c_shape = (8, 9, 10)
+
     def before():
-        dtype = 'float32'
+        dtype = "float32"
         return relay.full(relay.const(1.0, dtype), c_shape, dtype=dtype)
 
     def expected():
@@ -205,17 +211,20 @@ def test_fold_batch_norm():
         data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
         weight = relay.const(np.zeros((16, 3, 3, 3)))
         bias = relay.const(np.zeros((16, 1, 1)))
-        conv = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
-                               channels=16, padding=(1, 1))
+        conv = relay.nn.conv2d(
+            data=data, weight=weight, kernel_size=(3, 3), channels=16, padding=(1, 1)
+        )
         add = relay.add(conv, bias)
         return relay.Function(relay.analysis.free_vars(add), add)
 
-    remove_bn_pass = tvm.transform.Sequential([
-        relay.transform.InferType(),
-        relay.transform.SimplifyInference(),
-        relay.transform.FoldConstant(),
-        relay.transform.FoldScaleAxis(),
-    ])
+    remove_bn_pass = tvm.transform.Sequential(
+        [
+            relay.transform.InferType(),
+            relay.transform.SimplifyInference(),
+            relay.transform.FoldConstant(),
+            relay.transform.FoldScaleAxis(),
+        ]
+    )
 
     data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
     weight = relay.var("weight")
@@ -224,10 +233,11 @@ def test_fold_batch_norm():
     bn_mmean = relay.var("bn_mean")
     bn_mvar = relay.var("bn_var")
 
-    conv = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
-                           channels=16, padding=(1, 1))
-    bn_output = relay.nn.batch_norm(conv, bn_gamma, bn_beta,
-                                    bn_mmean, bn_mvar)
+    conv = relay.nn.conv2d(
+        data=data, weight=weight, kernel_size=(3, 3), channels=16, padding=(1, 1)
+    )
+    bn_output = relay.nn.batch_norm(conv, bn_gamma, bn_beta, bn_mmean, bn_mvar)
+
     def initializer(_, param):
         param = np.zeros(param.shape)
 
index 8aecf3f..421c6c5 100644 (file)
@@ -21,8 +21,9 @@ from tvm import te
 from tvm import relay
 from tvm.relay import transform
 
+
 def _get_positive_scale(size):
-    return np.random.uniform(0.5, 1, size=size).astype('float32')
+    return np.random.uniform(0.5, 1, size=size).astype("float32")
 
 
 def run_opt_pass(expr, opt_pass):
@@ -35,17 +36,21 @@ def run_opt_pass(expr, opt_pass):
 
 def test_fold_fwd_simple():
     """Simple testcase."""
+
     def before(x, conv_weight, in_bias, in_scale, channels, blocking):
         args = [x, conv_weight, in_bias]
         x = relay.multiply(x, in_scale)
         x = relay.nn.relu(x)
         x = relay.add(x, in_bias)
-        y = relay.nn.conv2d(x, conv_weight,
-                            channels=channels,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                            kernel_layout="OIHW2i{}o".format(blocking[1]) if blocking else "OIHW")
+        y = relay.nn.conv2d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW2i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
 
         return relay.Function(args, y)
 
@@ -53,43 +58,54 @@ def test_fold_fwd_simple():
         # use a fixed order of args so alpha equal check can pass
         args = [x, conv_weight, in_bias]
         if blocking:
-            squeezed_scale = relay.squeeze(in_scale, axis=[0,2,3])
+            squeezed_scale = relay.squeeze(in_scale, axis=[0, 2, 3])
             x = relay.nn.relu(x)
-            in_bias = relay.divide(in_bias, 
-                relay.reshape(squeezed_scale, (1, in_channels // blocking[0], 1, 1, blocking[0]))) #NCHWc
+            in_bias = relay.divide(
+                in_bias,
+                relay.reshape(squeezed_scale, (1, in_channels // blocking[0], 1, 1, blocking[0])),
+            )  # NCHWc
             x = relay.add(x, in_bias)
-            conv_weight = relay.multiply(conv_weight,
-                relay.reshape(squeezed_scale, (1, in_channels//2, 1, 1, 2, 1))) #OIHWio
+            conv_weight = relay.multiply(
+                conv_weight, relay.reshape(squeezed_scale, (1, in_channels // 2, 1, 1, 2, 1))
+            )  # OIHWio
         else:
-            squeezed_scale = relay.squeeze(in_scale, axis=[1,2])
+            squeezed_scale = relay.squeeze(in_scale, axis=[1, 2])
             x = relay.nn.relu(x)
-            in_bias = relay.divide(in_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
+            in_bias = relay.divide(
+                in_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)
+            )
             x = relay.add(x, in_bias)
             conv_weight = relay.multiply(
-                conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
-
-        y = relay.nn.conv2d(x, conv_weight,
-                            channels=channels,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                            kernel_layout="OIHW2i{}o".format(blocking[1]) if blocking else "OIHW")
+                conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)
+            )
+
+        y = relay.nn.conv2d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW2i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
         return relay.Function(args, y)
 
     def check(shape, channels, blocking):
-        x =  relay.var("x", shape=shape)
+        x = relay.var("x", shape=shape)
         weight = relay.var("weight")
         if blocking:
             in_channels = shape[1] * shape[4]
             in_bias = relay.var("in_bias", shape=(1, in_channels // blocking[0], 1, 1, blocking[0]))
-            in_scale = relay.const(_get_positive_scale((1, in_channels // blocking[0], 1, 1, blocking[0])))
+            in_scale = relay.const(
+                _get_positive_scale((1, in_channels // blocking[0], 1, 1, blocking[0]))
+            )
         else:
             in_channels = shape[1]
             in_bias = relay.var("in_bias", shape=(in_channels, 1, 1))
             in_scale = relay.const(_get_positive_scale((in_channels, 1, 1)))
         y1 = before(x, weight, in_bias, in_scale, channels, blocking)
         y1 = run_opt_pass(y1, transform.InferType())
-        type_dict = {x.name_hint:x.checked_type for x in y1.params}
+        type_dict = {x.name_hint: x.checked_type for x in y1.params}
         weight = relay.var("weight", type_dict["weight"])
         y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
         y1_expected = expected(x, weight, in_bias, in_scale, in_channels, channels, blocking)
@@ -101,27 +117,35 @@ def test_fold_fwd_simple():
     check((2, 4, 10, 10), 2, None)
     check((2, 2, 10, 10, 2), 8, (2, 4))
 
+
 def test_fold_fwd_dual_path():
     """scale axis being consumed by two consumers"""
+
     def before(x, conv_weight, in_bias, in_scale, channels, blocking):
         args = [x, conv_weight, in_bias]
         x = relay.multiply(in_scale, x)
         x = relay.nn.relu(x)
         x = relay.subtract(x, in_bias)
-        y1 = relay.nn.conv2d(x, conv_weight,
-                             channels=channels,
-                             kernel_size=(3, 3),
-                             data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
-                             kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
-                             groups=channels,
-                             padding=(1, 1))
-        y2 = relay.nn.conv2d(x, conv_weight,
-                             channels=channels,
-                             kernel_size=(3, 3),
-                             data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
-                             kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
-                             groups=channels,
-                             padding=(1, 1))
+        y1 = relay.nn.conv2d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
+            kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
+            groups=channels,
+            padding=(1, 1),
+        )
+        y2 = relay.nn.conv2d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
+            kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
+            groups=channels,
+            padding=(1, 1),
+        )
         z = relay.add(y1, y2)
         return relay.Function(args, z)
 
@@ -129,56 +153,70 @@ def test_fold_fwd_dual_path():
         args = [x, conv_weight, in_bias]
         x = relay.nn.relu(x)
         if blocking:
-            _in_scale = relay.reshape(in_scale, (1, 1, 1, channels//blocking[0], blocking[0])) #NHWCc
+            _in_scale = relay.reshape(
+                in_scale, (1, 1, 1, channels // blocking[0], blocking[0])
+            )  # NHWCc
         else:
             _in_scale = in_scale
         in_bias = relay.divide(in_bias, _in_scale)
         x = relay.subtract(x, in_bias)
         if blocking:
-            _in_scale = relay.reshape(in_scale, (1, 1, 1, channels//blocking[0], 1, blocking[0])) #HWIOio
-        y1 = relay.nn.conv2d(x,
-                             relay.multiply(conv_weight, _in_scale),
-                             channels=channels,
-                             kernel_size=(3, 3),
-                             data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
-                             kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
-                             groups=channels,
-                             padding=(1, 1))
+            _in_scale = relay.reshape(
+                in_scale, (1, 1, 1, channels // blocking[0], 1, blocking[0])
+            )  # HWIOio
+        y1 = relay.nn.conv2d(
+            x,
+            relay.multiply(conv_weight, _in_scale),
+            channels=channels,
+            kernel_size=(3, 3),
+            data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
+            kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
+            groups=channels,
+            padding=(1, 1),
+        )
         if blocking:
-            _in_scale = relay.reshape(in_scale, (1, 1, 1, channels//blocking[0], 1, blocking[0])) #HWIOio
-        y2 = relay.nn.conv2d(x,
-                             relay.multiply(conv_weight, _in_scale),
-                             channels=channels,
-                             kernel_size=(3, 3),
-                             data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
-                             kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
-                             groups=channels,
-                             padding=(1, 1))
+            _in_scale = relay.reshape(
+                in_scale, (1, 1, 1, channels // blocking[0], 1, blocking[0])
+            )  # HWIOio
+        y2 = relay.nn.conv2d(
+            x,
+            relay.multiply(conv_weight, _in_scale),
+            channels=channels,
+            kernel_size=(3, 3),
+            data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
+            kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
+            groups=channels,
+            padding=(1, 1),
+        )
         z = relay.add(y1, y2)
         return relay.Function(args, z)
 
     def check(dshape, channels, blocking):
-        x =  relay.var("x", shape=dshape)
+        x = relay.var("x", shape=dshape)
         if blocking:
             in_channels = dshape[3] * dshape[4]
-            wshape = (3, 3, 1, channels//blocking[1], 1, blocking[1]) # HWIOio
+            wshape = (3, 3, 1, channels // blocking[1], 1, blocking[1])  # HWIOio
             weight = relay.var("weight", shape=wshape)
-            in_bias = relay.var("in_bias", shape=(in_channels//blocking[0],blocking[0]))
-            in_scale = relay.const(_get_positive_scale((in_channels//blocking[0],blocking[0])))
+            in_bias = relay.var("in_bias", shape=(in_channels // blocking[0], blocking[0]))
+            in_scale = relay.const(_get_positive_scale((in_channels // blocking[0], blocking[0])))
         else:
             in_channels = dshape[-1]
-            wshape = (3, 3, 1, channels) # HWIO
+            wshape = (3, 3, 1, channels)  # HWIO
             weight = relay.var("weight", shape=wshape)
             in_bias = relay.var("in_bias", shape=(in_channels,))
-            in_scale = relay.const(_get_positive_scale(in_channels,))
-        
+            in_scale = relay.const(
+                _get_positive_scale(
+                    in_channels,
+                )
+            )
+
         # test depthwise
         assert in_channels == channels
 
         y1 = before(x, weight, in_bias, in_scale, channels, blocking)
         y1 = run_opt_pass(y1, transform.InferType())
         y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
-        type_dict = {x.name_hint:x.checked_type for x in y1.params}
+        type_dict = {x.name_hint: x.checked_type for x in y1.params}
         weight = relay.var("weight", type_dict["weight"])
         y1_expected = expected(x, weight, in_bias, in_scale, channels, blocking)
         y1_expected = run_opt_pass(y1_expected, transform.InferType())
@@ -187,26 +225,31 @@ def test_fold_fwd_dual_path():
     check((2, 4, 10, 3), 3, None)
     check((2, 4, 10, 2, 2), 4, (2, 2))
 
+
 def test_fold_fwd_fail():
     """testcase where we canont fold"""
+
     def before(x, conv_weight, in_bias, in_scale, channels, blocking):
         x = relay.multiply(x, in_scale)
         xx = relay.nn.leaky_relu(x, alpha=0.1)
-        y1 = relay.nn.conv2d(xx, conv_weight,
-                             channels=channels,
-                             kernel_size=(3, 3),
-                             data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
-                             kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
-                             padding=(1, 1))
+        y1 = relay.nn.conv2d(
+            xx,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
+            kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
+            padding=(1, 1),
+        )
         z = relay.add(y1, x)
         return relay.Function(relay.analysis.free_vars(z), z)
 
     def check(shape, channels, blocking):
-        x =  relay.var("x", shape=shape)
+        x = relay.var("x", shape=shape)
         if blocking:
             in_channels = shape[3] * shape[4]
-            in_bias = relay.var("in_bias", shape=(in_channels//blocking[0],blocking[0]))
-            in_scale = relay.const(_get_positive_scale((in_channels//blocking[0],blocking[0])))
+            in_bias = relay.var("in_bias", shape=(in_channels // blocking[0], blocking[0]))
+            in_scale = relay.const(_get_positive_scale((in_channels // blocking[0], blocking[0])))
         else:
             in_channels = shape[-1]
             in_bias = relay.var("in_bias", shape=(in_channels,))
@@ -220,24 +263,29 @@ def test_fold_fwd_fail():
         assert tvm.ir.structural_equal(y1, y1_folded)
 
     check((2, 11, 10, 4), 4, None)
-    check((2, 11, 10, 2, 2), 4, (2,2))
+    check((2, 11, 10, 2, 2), 4, (2, 2))
+
 
 def test_fold_fwd_relu_fail():
     """testcase where we canont fold because scale can not pass relu"""
+
     def before(x, conv_weight, in_bias, in_scale, channels, blocking):
         x = relay.multiply(x, in_scale)
         xx = relay.nn.relu(x)
-        y1 = relay.nn.conv2d(xx, conv_weight,
-                             channels=channels,
-                             kernel_size=(3, 3),
-                             data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
-                             kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
-                             padding=(1, 1))
+        y1 = relay.nn.conv2d(
+            xx,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
+            kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
+            padding=(1, 1),
+        )
         z = relay.add(y1, x)
         return relay.Function(relay.analysis.free_vars(z), z)
 
     def check(shape, channels, blocking, in_scale):
-        x =  relay.var("x", shape=shape)
+        x = relay.var("x", shape=shape)
         weight = relay.var("weight")
         if blocking:
             in_channels = shape[3] * shape[4]
@@ -257,50 +305,56 @@ def test_fold_fwd_relu_fail():
     in_scale = relay.const(-_get_positive_scale((4,)))
     check((2, 11, 10, 4), 4, None, in_scale)
 
-    in_scale = relay.var("in_scale", shape=(1,1,1,2,2))
+    in_scale = relay.var("in_scale", shape=(1, 1, 1, 2, 2))
     check((2, 11, 10, 2, 2), 4, (2, 2), in_scale)
-    in_scale = relay.const(-_get_positive_scale((1,1,1,2,2)))
+    in_scale = relay.const(-_get_positive_scale((1, 1, 1, 2, 2)))
     check((2, 11, 10, 2, 2), 4, (2, 2), in_scale)
 
 
-
-
 def test_fold_fwd_negative_scale():
     """Testcase of folding negative scale"""
+
     def before(x, conv_weight, in_scale, channels, blocking):
         args = [x, conv_weight]
         x = relay.multiply(x, in_scale)
-        y = relay.nn.conv2d(x, conv_weight,
-                             channels=channels,
-                             kernel_size=(3, 3),
-                             padding=(1, 1),
-                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                             kernel_layout="OIHW4i{}o".format(blocking[1]) if blocking else "OIHW")
+        y = relay.nn.conv2d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW4i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
         return relay.Function(args, y)
 
     def expected(x, conv_weight, in_scale, in_channels, channels, blocking):
         # use a fixed order of args so alpha equal check can pass
         args = [x, conv_weight]
         if blocking:
-            squeezed_scale = relay.squeeze(in_scale, axis=[0,2,3])
+            squeezed_scale = relay.squeeze(in_scale, axis=[0, 2, 3])
             conv_weight = relay.multiply(
-                conv_weight , relay.reshape(squeezed_scale, (1, in_channels//4, 1, 1, 4, 1)))
-            #blocking by "i" in OIHWio
+                conv_weight, relay.reshape(squeezed_scale, (1, in_channels // 4, 1, 1, 4, 1))
+            )
+            # blocking by "i" in OIHWio
         else:
-            squeezed_scale = relay.squeeze(in_scale, axis=[1,2])
+            squeezed_scale = relay.squeeze(in_scale, axis=[1, 2])
             conv_weight = relay.multiply(
-                conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
-        y = relay.nn.conv2d(x,
-                             conv_weight,
-                             channels=channels,
-                             kernel_size=(3, 3),
-                             padding=(1, 1),
-                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                             kernel_layout="OIHW4i{}o".format(blocking[1]) if blocking else "OIHW")
+                conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)
+            )
+        y = relay.nn.conv2d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW4i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
         return relay.Function(args, y)
 
     def check(shape, channels, blocking):
-        x =  relay.var("x", shape=shape)
+        x = relay.var("x", shape=shape)
         if blocking:
             in_channels = shape[1] * shape[4]
             in_scale = relay.const(-_get_positive_scale((1, shape[1], 1, 1, shape[4])))
@@ -310,7 +364,7 @@ def test_fold_fwd_negative_scale():
         weight = relay.var("weight")
         y1 = before(x, weight, in_scale, channels, blocking)
         y1 = run_opt_pass(y1, transform.InferType())
-        type_dict = {x.name_hint:x.checked_type for x in y1.params}
+        type_dict = {x.name_hint: x.checked_type for x in y1.params}
         weight = relay.var("weight", type_dict["weight"])
         y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
         y1_expected = expected(x, weight, in_scale, in_channels, channels, blocking)
@@ -320,24 +374,29 @@ def test_fold_fwd_negative_scale():
     check((2, 4, 10, 10), 4, None)
     check((2, 2, 10, 10, 2), 8, (2, 2))
 
+
 def test_fold_bwd_simple():
     """Simple testcase."""
+
     def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
         args = [x, conv_weight, out_bias]
         if blocking:
-            out_bias = relay.reshape(out_bias, (1, channels//blocking[1], 1, 1, blocking[1]))
+            out_bias = relay.reshape(out_bias, (1, channels // blocking[1], 1, 1, blocking[1]))
         else:
             out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
-        y = relay.nn.conv2d(x, conv_weight,
-                            channels=channels,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
+        y = relay.nn.conv2d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
         y = relay.add(y, out_bias)
         y = relay.nn.relu(y)
         if blocking:
-            out_scale = relay.reshape(out_scale, (1, channels//blocking[1], 1, 1, blocking[1]))
+            out_scale = relay.reshape(out_scale, (1, channels // blocking[1], 1, 1, blocking[1]))
         y = relay.multiply(y, out_scale)
         return relay.Function(args, y)
 
@@ -345,44 +404,53 @@ def test_fold_bwd_simple():
         # use a fixed order of args so alpha equal check can pass
         args = [x, conv_weight, out_bias]
         if blocking:
-            out_bias = relay.reshape(out_bias, (1, channels//blocking[1], 1, 1, blocking[1]))
-            out_scale = relay.reshape(out_scale, (1, channels//blocking[1], 1, 1, blocking[1]))
+            out_bias = relay.reshape(out_bias, (1, channels // blocking[1], 1, 1, blocking[1]))
+            out_scale = relay.reshape(out_scale, (1, channels // blocking[1], 1, 1, blocking[1]))
             squeezed_scale = relay.squeeze(out_scale, axis=[0, 2, 3])
             conv_weight = relay.multiply(
-                conv_weight , relay.reshape(squeezed_scale, (channels//blocking[1], 1, 1, 1, 1, blocking[1])))
+                conv_weight,
+                relay.reshape(squeezed_scale, (channels // blocking[1], 1, 1, 1, 1, blocking[1])),
+            )
         else:
             out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
-            squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
+            squeezed_scale = relay.squeeze(out_scale, axis=[1, 2])
             conv_weight = relay.multiply(
-                conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
-
-        y = relay.nn.conv2d(x, conv_weight,
-                            channels=channels,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
+                conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)
+            )
+
+        y = relay.nn.conv2d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
         if blocking:
-            out_bias = relay.multiply(out_bias,
-                                  relay.reshape(squeezed_scale, (1, channels//blocking[1], 1, 1, blocking[1])))
+            out_bias = relay.multiply(
+                out_bias,
+                relay.reshape(squeezed_scale, (1, channels // blocking[1], 1, 1, blocking[1])),
+            )
         else:
-            out_bias = relay.multiply(out_bias,
-                                  relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
+            out_bias = relay.multiply(
+                out_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)
+            )
         y = relay.add(y, out_bias)
         y = relay.nn.relu(y)
         return relay.Function(args, y)
 
     def check(shape, in_channels, channels, blocking):
-        x =  relay.var("x", shape=shape)
+        x = relay.var("x", shape=shape)
         weight = relay.var("weight")
         out_bias = relay.var("out_bias", shape=(channels,))
         if blocking:
             out_scale = relay.const(_get_positive_scale((channels,)))
         else:
-            out_scale = relay.const(_get_positive_scale((channels,1, 1)))
+            out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
         y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking)
         y1 = run_opt_pass(y1, transform.InferType())
-        type_dict = {x.name_hint:x.checked_type for x in y1.params}
+        type_dict = {x.name_hint: x.checked_type for x in y1.params}
         weight = relay.var("weight", type_dict["weight"])
         y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
         y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking)
@@ -395,21 +463,28 @@ def test_fold_bwd_simple():
 
 def test_fold_bwd_dual_path():
     """Dual path testcase."""
+
     def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
         args = [x, conv_weight, out_bias]
-        y1 = relay.nn.conv2d(x, conv_weight,
-                             channels=channels,
-                             kernel_size=(3, 3),
-                             padding=(1, 1),
-                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                             kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
+        y1 = relay.nn.conv2d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
         y1 = relay.nn.relu(y1)
-        y2 = relay.nn.conv2d(x, conv_weight,
-                             channels=channels,
-                             kernel_size=(3, 3),
-                             padding=(1, 1),
-                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                             kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
+        y2 = relay.nn.conv2d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
         y2 = relay.nn.relu(y2)
         y = relay.add(y1, y2)
         y = relay.multiply(y, out_scale)
@@ -420,46 +495,59 @@ def test_fold_bwd_dual_path():
         args = [x, conv_weight, out_bias]
         if not blocking:
             out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
-        squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
+        squeezed_scale = relay.squeeze(out_scale, axis=[1, 2])
+
         def fold_conv_weight():
             if blocking:
                 return relay.multiply(
-                    conv_weight ,
-                    relay.reshape(squeezed_scale, (channels//blocking[1], 1, 1, 1, 1, blocking[1])))
+                    conv_weight,
+                    relay.reshape(
+                        squeezed_scale, (channels // blocking[1], 1, 1, 1, 1, blocking[1])
+                    ),
+                )
             else:
                 return relay.multiply(
-                    conv_weight ,
-                    relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
-        y1 = relay.nn.conv2d(x, fold_conv_weight(),
-                            channels=channels,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
+                    conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)
+                )
+
+        y1 = relay.nn.conv2d(
+            x,
+            fold_conv_weight(),
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
         y1 = relay.nn.relu(y1)
-        y2 = relay.nn.conv2d(x, fold_conv_weight(),
-                            channels=channels,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
+        y2 = relay.nn.conv2d(
+            x,
+            fold_conv_weight(),
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
         y2 = relay.nn.relu(y2)
         y = relay.add(y1, y2)
         return relay.Function(args, y)
 
     def check(shape, in_channels, channels, blocking):
-        x =  relay.var("x", shape=shape)
+        x = relay.var("x", shape=shape)
         weight = relay.var("weight")
         if blocking:
             out_bias = relay.var("out_bias", shape=(channels // blocking[1], 1, 1, blocking[1]))
-            out_scale = relay.const(_get_positive_scale((channels // blocking[1], 1, 1, blocking[1])))
+            out_scale = relay.const(
+                _get_positive_scale((channels // blocking[1], 1, 1, blocking[1]))
+            )
         else:
             out_bias = relay.var("out_bias", shape=(channels,))
             out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
 
         y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking)
         y1 = run_opt_pass(y1, transform.InferType())
-        type_dict = {x.name_hint:x.checked_type for x in y1.params}
+        type_dict = {x.name_hint: x.checked_type for x in y1.params}
         weight = relay.var("weight", type_dict["weight"])
         y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
         y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking)
@@ -469,33 +557,43 @@ def test_fold_bwd_dual_path():
     check((2, 4, 10, 10), 4, 8, None)
     check((2, 2, 10, 10, 2), 4, 8, (2, 2))
 
+
 def test_fold_bwd_dual_consumer():
     def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
         args = [x, conv_weight, out_bias]
-        y0 = relay.nn.conv2d(x, conv_weight,
-                             channels=channels,
-                             kernel_size=(3, 3),
-                             padding=(1, 1),
-                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                             kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
+        y0 = relay.nn.conv2d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
         y0 = relay.multiply(y0, out_scale)
         y0 = relay.nn.relu(y0)
 
-        y1 = relay.nn.conv2d(y0, conv_weight,
-                             channels=channels,
-                             kernel_size=(3, 3),
-                             padding=(1, 1),
-                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                             kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
+        y1 = relay.nn.conv2d(
+            y0,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
         y1 = relay.multiply(y1, out_scale)
         y1 = relay.nn.relu(y1)
 
-        y2 = relay.nn.conv2d(y0, conv_weight,
-                             channels=channels,
-                             kernel_size=(3, 3),
-                             padding=(1, 1),
-                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                             kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
+        y2 = relay.nn.conv2d(
+            y0,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
         y2 = relay.multiply(y2, out_scale)
         y2 = relay.nn.relu(y2)
 
@@ -505,53 +603,69 @@ def test_fold_bwd_dual_consumer():
     def expected(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
         # use a fixed order of args so alpha equal check can pass
         args = [x, conv_weight, out_bias]
+
         def fold_conv_weight():
-            squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
+            squeezed_scale = relay.squeeze(out_scale, axis=[1, 2])
             if blocking:
                 return relay.multiply(
-                    conv_weight ,
-                    relay.reshape(squeezed_scale, (channels//blocking[1], 1, 1, 1, 1, blocking[1])))
+                    conv_weight,
+                    relay.reshape(
+                        squeezed_scale, (channels // blocking[1], 1, 1, 1, 1, blocking[1])
+                    ),
+                )
             else:
                 return relay.multiply(
-                    conv_weight ,
-                    relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
-        y0 = relay.nn.conv2d(x, fold_conv_weight(),
-                            channels=channels,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
+                    conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)
+                )
+
+        y0 = relay.nn.conv2d(
+            x,
+            fold_conv_weight(),
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
         y0 = relay.nn.relu(y0)
-        y1 = relay.nn.conv2d(y0, fold_conv_weight(),
-                            channels=channels,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
+        y1 = relay.nn.conv2d(
+            y0,
+            fold_conv_weight(),
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
         y1 = relay.nn.relu(y1)
-        y2 = relay.nn.conv2d(y0, fold_conv_weight(),
-                            channels=channels,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
+        y2 = relay.nn.conv2d(
+            y0,
+            fold_conv_weight(),
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
         y2 = relay.nn.relu(y2)
         y = relay.add(y1, y2)
         return relay.Function(args, y)
 
     def check(shape, in_channels, channels, blocking):
-        x =  relay.var("x", shape=shape)
+        x = relay.var("x", shape=shape)
         weight = relay.var("weight")
         if blocking:
             out_bias = relay.var("out_bias", shape=(channels // blocking[1], 1, 1, blocking[1]))
-            out_scale = relay.const(_get_positive_scale((channels // blocking[1], 1, 1, blocking[1])))
+            out_scale = relay.const(
+                _get_positive_scale((channels // blocking[1], 1, 1, blocking[1]))
+            )
         else:
             out_bias = relay.var("out_bias", shape=(channels,))
             out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
 
         y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking)
         y1 = run_opt_pass(y1, transform.InferType())
-        type_dict = {x.name_hint:x.checked_type for x in y1.params}
+        type_dict = {x.name_hint: x.checked_type for x in y1.params}
         weight = relay.var("weight", type_dict["weight"])
         y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
         y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking)
@@ -561,24 +675,32 @@ def test_fold_bwd_dual_consumer():
     check((2, 4, 10, 10), 4, 4, None)
     check((2, 2, 10, 10, 2), 4, 4, (2, 2))
 
+
 def test_fold_bwd_fail():
     """Dual path testcase."""
+
     def fail1(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
         args = [x, conv_weight, out_bias]
-        y1 = relay.nn.conv2d(x, conv_weight,
-                             channels=channels,
-                             kernel_size=(3, 3),
-                             padding=(1, 1),
-                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                             kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
+        y1 = relay.nn.conv2d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
         y1 = relay.nn.relu(y1)
-        y2 = relay.nn.conv2d(x, conv_weight,
-                             channels=channels,
-                             kernel_size=(3, 3),
-                             padding=(1, 1),
-                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                             kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
-                             out_layout="CNHW{}c".format(blocking[1]) if blocking else "CNHW")
+        y2 = relay.nn.conv2d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
+            out_layout="CNHW{}c".format(blocking[1]) if blocking else "CNHW",
+        )
         # fold will fail because the axis differs
         # between the two paths.
         y2 = relay.nn.relu(y2)
@@ -588,12 +710,15 @@ def test_fold_bwd_fail():
 
     def fail2(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
         args = [x, conv_weight, out_bias]
-        y1 = relay.nn.conv2d(x, conv_weight,
-                             channels=channels,
-                             kernel_size=(3, 3),
-                             padding=(1, 1),
-                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                             kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
+        y1 = relay.nn.conv2d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
         y2 = relay.nn.relu(y1)
         # fold will fail because y1 is also referred to by y2
         y1 = relay.multiply(y1, out_scale)
@@ -601,11 +726,13 @@ def test_fold_bwd_fail():
         return relay.Function(args, y)
 
     def check(shape, in_channels, channels, blocking, fbefore):
-        x =  relay.var("x", shape=shape)
+        x = relay.var("x", shape=shape)
         weight = relay.var("weight")
         if blocking:
             out_bias = relay.var("out_bias", shape=(channels // blocking[1], 1, 1, blocking[1]))
-            out_scale = relay.const(_get_positive_scale((channels // blocking[1], 1, 1, blocking[1])))
+            out_scale = relay.const(
+                _get_positive_scale((channels // blocking[1], 1, 1, blocking[1]))
+            )
         else:
             out_bias = relay.var("out_bias", shape=(channels, 1, 1))
             out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
@@ -622,19 +749,23 @@ def test_fold_bwd_fail():
 
 def test_fold_bwd_relu_fail():
     """testcase where we canont fold because scale can not pass relu"""
+
     def before(x, conv_weight, out_scale, channels, blocking):
-        y = relay.nn.conv2d(x, conv_weight,
-                             channels=channels,
-                             kernel_size=(3, 3),
-                             padding=(1, 1),
-                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                             kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
+        y = relay.nn.conv2d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
         y = relay.nn.relu(y)
         y = relay.multiply(x, out_scale)
         return relay.Function(relay.analysis.free_vars(y), y)
 
     def check(shape, channels, blocking, out_scale):
-        x =  relay.var("x", shape=shape)
+        x = relay.var("x", shape=shape)
         in_channels = shape[1]
         weight = relay.var("weight")
         y1 = before(x, weight, out_scale, channels, blocking)
@@ -649,20 +780,26 @@ def test_fold_bwd_relu_fail():
 
     out_scale = relay.var("in_scale", shape=(1, 2, 1, 1, 2))
     check((4, 2, 10, 10, 2), 4, (2, 2), out_scale)
-    out_scale = relay.const(np.random.uniform(size=(1, 2, 1, 1, 2), low=-1.0, high=0.0)).astype("float32")
+    out_scale = relay.const(np.random.uniform(size=(1, 2, 1, 1, 2), low=-1.0, high=0.0)).astype(
+        "float32"
+    )
     check((4, 2, 10, 10, 2), 4, (2, 2), out_scale)
 
 
 def test_fold_bwd_negative_scale():
     """Testcase of folding negative scale"""
+
     def before(x, conv_weight, out_scale, channels, blocking):
         args = [x, conv_weight]
-        y = relay.nn.conv2d(x, conv_weight,
-                            channels=channels,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
+        y = relay.nn.conv2d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
         y = relay.multiply(y, out_scale)
         return relay.Function(args, y)
 
@@ -670,31 +807,39 @@ def test_fold_bwd_negative_scale():
         # use a fixed order of args so alpha equal check can pass
         args = [x, conv_weight]
         if blocking:
-            squeezed_scale = relay.squeeze(out_scale, axis=[0,2,3])
+            squeezed_scale = relay.squeeze(out_scale, axis=[0, 2, 3])
             conv_weight = relay.multiply(
-                conv_weight , relay.reshape(squeezed_scale, (channels//blocking[1], 1, 1, 1, 1, blocking[1])))
+                conv_weight,
+                relay.reshape(squeezed_scale, (channels // blocking[1], 1, 1, 1, 1, blocking[1])),
+            )
         else:
-            squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
+            squeezed_scale = relay.squeeze(out_scale, axis=[1, 2])
             conv_weight = relay.multiply(
-                conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
-        y = relay.nn.conv2d(x, conv_weight,
-                            channels=channels,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
-                            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
+                conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)
+            )
+        y = relay.nn.conv2d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
+        )
         return relay.Function(args, y)
 
     def check(shape, channels, blocking):
-        x =  relay.var("x", shape=shape)
+        x = relay.var("x", shape=shape)
         weight = relay.var("weight")
         if blocking:
-            out_scale = relay.const(-_get_positive_scale((1,channels//blocking[1], 1, 1, blocking[1])))
+            out_scale = relay.const(
+                -_get_positive_scale((1, channels // blocking[1], 1, 1, blocking[1]))
+            )
         else:
             out_scale = relay.const(-_get_positive_scale((channels, 1, 1)))
         y1 = before(x, weight, out_scale, channels, blocking)
         y1 = run_opt_pass(y1, transform.InferType())
-        type_dict = {x.name_hint:x.checked_type for x in y1.params}
+        type_dict = {x.name_hint: x.checked_type for x in y1.params}
         weight = relay.var("weight", type_dict["weight"])
         y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
         y1_expected = expected(x, weight, out_scale, channels, blocking)
@@ -704,6 +849,7 @@ def test_fold_bwd_negative_scale():
     check((2, 4, 10, 10), 8, None)
     check((2, 2, 10, 10, 2), 8, (2, 2))
 
+
 if __name__ == "__main__":
     test_fold_fwd_simple()
     test_fold_fwd_dual_path()
index df30eb4..1d9cfb2 100644 (file)
@@ -24,6 +24,7 @@ import tvm.testing
 
 def test_fuse_simple():
     """Simple testcase."""
+
     def before():
         x = relay.var("x", shape=(10, 20))
         y = relay.add(x, relay.const(1, "float32"))
@@ -51,25 +52,17 @@ def test_fuse_simple():
 
 def test_conv2d_fuse():
     """Test fusion case of conv2d"""
+
     def before(dshape):
         x = relay.var("x", shape=dshape)
         x = relay.add(x, relay.const(1, "float32"))
-        y = relay.nn.conv2d(x, relay.var("w1"),
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            channels=16)
+        y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=16)
         # this is the next dominator.
         y1 = relay.add(relay.const(1, "float32"), y)
         y = relay.add(y, y1)
         # second path
-        z2 = relay.nn.conv2d(y, relay.var("w2"),
-                             kernel_size=(1, 1),
-                             padding=(0,0),
-                             channels=16)
-        z3 = relay.nn.conv2d(y, relay.var("w3"),
-                             kernel_size=(3, 3),
-                             padding=(1,1),
-                             channels=16)
+        z2 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(1, 1), padding=(0, 0), channels=16)
+        z3 = relay.nn.conv2d(y, relay.var("w3"), kernel_size=(3, 3), padding=(1, 1), channels=16)
         # add can only be fused to z1
         z = relay.add(z2, z3)
         return relay.Function(relay.analysis.free_vars(z), z)
@@ -84,10 +77,7 @@ def test_conv2d_fuse():
         # segment 1
         x = relay.var("p0", shape=dshape)
         w = relay.var("p1")
-        y = relay.nn.conv2d(x, w,
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            channels=16)
+        y = relay.nn.conv2d(x, w, kernel_size=(3, 3), padding=(1, 1), channels=16)
         y1 = relay.add(relay.const(1, "float32"), y)
         y = relay.add(y, y1)
         f1 = relay.Function([x, w], y)
@@ -96,10 +86,7 @@ def test_conv2d_fuse():
         # segment 2
         x = relay.var("p0", shape=dshape)
         w = relay.var("p1")
-        z2 = relay.nn.conv2d(x, w,
-                             kernel_size=(3, 3),
-                             padding=(1,1),
-                             channels=16)
+        z2 = relay.nn.conv2d(x, w, kernel_size=(3, 3), padding=(1, 1), channels=16)
         f2 = relay.Function([x, w], z2)
         f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
 
@@ -107,10 +94,7 @@ def test_conv2d_fuse():
         x = relay.var("p0", shape=dshape)
         w = relay.var("p1")
         offset = relay.var("p2", shape=dshape)
-        z3 = relay.nn.conv2d(x, w,
-                             kernel_size=(1, 1),
-                             padding=(0, 0),
-                             channels=16)
+        z3 = relay.nn.conv2d(x, w, kernel_size=(1, 1), padding=(0, 0), channels=16)
         z3 = relay.add(z3, offset)
         f3 = relay.Function([x, w, offset], z3)
         f3 = f3.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
@@ -148,7 +132,7 @@ def test_concatenate():
         f0 = relay.Function([x], pooled)
         f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
 
-        p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
+        p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2] // 2, dshape[3] // 2))
         p1 = relay.var("p1", shape=dshape)
         upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW")
         concat = relay.concatenate((upsampled, p1), axis=1)
@@ -187,7 +171,7 @@ def test_tuple_root():
         f0 = relay.Function([x], pooled)
         f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
 
-        p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
+        p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2] // 2, dshape[3] // 2))
         upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW")
         f1 = relay.Function([p0], upsampled)
         f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
@@ -241,34 +225,31 @@ def test_stop_fusion():
 
 def test_fuse_myia_regression():
     def before(dshape, dtype):
-        x = relay.var('x', shape=dshape, dtype=dtype)
-        y = relay.var('y', shape=dshape, dtype=dtype)
+        x = relay.var("x", shape=dshape, dtype=dtype)
+        y = relay.var("y", shape=dshape, dtype=dtype)
         sb = relay.ScopeBuilder()
         with sb.if_scope(relay.op.greater(x, y)):
             sb.ret(relay.Function([], x))
         with sb.else_scope():
             sb.ret(relay.Function([], y))
-        return relay.Function([x, y],
-            relay.Call(sb.get(), []))
+        return relay.Function([x, y], relay.Call(sb.get(), []))
 
     def expected(dshape, dtype):
-        x = relay.var('x', shape=dshape, dtype=dtype)
-        y = relay.var('y', shape=dshape, dtype=dtype)
+        x = relay.var("x", shape=dshape, dtype=dtype)
+        y = relay.var("y", shape=dshape, dtype=dtype)
         sb = relay.ScopeBuilder()
-        p1 = relay.var('p1', shape=dshape, dtype=dtype)
-        p2 = relay.var('p2', shape=dshape, dtype=dtype)
-        fused_gt = relay.Function([p1, p2],
-            relay.op.greater(p1, p2))
+        p1 = relay.var("p1", shape=dshape, dtype=dtype)
+        p2 = relay.var("p2", shape=dshape, dtype=dtype)
+        fused_gt = relay.Function([p1, p2], relay.op.greater(p1, p2))
         fused_gt = fused_gt.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
         with sb.if_scope(fused_gt(x, y)):
             sb.ret(relay.Function([], x))
         with sb.else_scope():
             sb.ret(relay.Function([], y))
-        return relay.Function([x, y],
-            relay.Call(sb.get(), []))
+        return relay.Function([x, y], relay.Call(sb.get(), []))
 
     dshape = ()
-    dtype = 'int64'
+    dtype = "int64"
     f = before(dshape, dtype)
     zz = run_opt_pass(f, transform.FuseOps())
     after = run_opt_pass(expected(dshape, dtype), transform.InferType())
@@ -353,6 +334,7 @@ def test_tuple_get_root():
 fuse0 = relay.transform.FuseOps(fuse_opt_level=0)
 fuse2 = relay.transform.FuseOps(fuse_opt_level=2)
 
+
 def test_tuple_intermediate():
     def before(x):
         inj = relay.squeeze(x)
@@ -378,7 +360,7 @@ def test_tuple_intermediate():
     orig = before(x)
     fuse0(tvm.IRModule.from_expr(orig))
     m = fuse2(tvm.IRModule.from_expr(orig))
-    relay.build(m, 'llvm')
+    relay.build(m, "llvm")
     after = run_opt_pass(expected(x), transform.InferType())
     assert tvm.ir.structural_equal(m["main"], after)
 
@@ -413,13 +395,13 @@ def test_tuple_consecutive():
         f0 = relay.Function([p0], concat)
         f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
 
-        p01 = relay.var("p01", shape=(1, dshape[1]*9, dshape[2], dshape[3]))
+        p01 = relay.var("p01", shape=(1, dshape[1] * 9, dshape[2], dshape[3]))
         pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
         out = relay.add(pooled, relay.const(1, "float32"))
         f1 = relay.Function([p01], out)
         f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
 
-        p02 = relay.var("p02", shape=(1, dshape[1]*9, dshape[2]//2, dshape[3]//2))
+        p02 = relay.var("p02", shape=(1, dshape[1] * 9, dshape[2] // 2, dshape[3] // 2))
         out = relay.add(p02, relay.const(1, "float32"))
         f2 = relay.Function([p02], out)
         f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
@@ -436,17 +418,14 @@ def test_tuple_consecutive():
     orig = before(x)
     fuse0(tvm.IRModule.from_expr(orig))
     m = fuse2(tvm.IRModule.from_expr(orig))
-    relay.build(m, 'llvm')
+    relay.build(m, "llvm")
     after = run_opt_pass(expected(dshape), transform.InferType())
     assert tvm.ir.structural_equal(m["main"], after)
 
 
 def test_inception_like():
     def conv(data):
-        y = relay.nn.conv2d(data, relay.var("w"),
-                            kernel_size=(3, 3),
-                            padding=(1, 1),
-                            channels=16)
+        y = relay.nn.conv2d(data, relay.var("w"), kernel_size=(3, 3), padding=(1, 1), channels=16)
         return relay.nn.relu(data=y)
 
     def inception_like(data):
@@ -477,7 +456,7 @@ def test_inception_like():
         f_concat1 = relay.Function([p02, p12], concat1)
         f_concat1 = f_concat1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
 
-        dshape2 = (dshape[0], dshape[1]*2, dshape[2], dshape[3])
+        dshape2 = (dshape[0], dshape[1] * 2, dshape[2], dshape[3])
 
         p03 = relay.var("p03", shape=dshape2)
         c = conv(p03)
@@ -509,13 +488,14 @@ def test_inception_like():
     orig = before(dshape)
     fuse0(tvm.IRModule.from_expr(orig))
     m = fuse2(tvm.IRModule.from_expr(orig))
-    relay.build(m, 'llvm')
+    relay.build(m, "llvm")
     after = run_opt_pass(expected(dshape), transform.InferType())
     assert tvm.ir.structural_equal(m["main"], after)
 
 
 def test_fuse_parallel_injective():
     """Test fusing parallel injective ops to an elemwise op."""
+
     def before():
         x = relay.var("x", shape=(10, 20))
         y = relay.add(x, relay.const(1, "float32"))
@@ -547,6 +527,7 @@ def test_fuse_parallel_injective():
 
 def test_immutable():
     """Verify the fusion pass won't change original module."""
+
     def before():
         x = relay.var("x", shape=(10, 20))
         y = relay.add(x, relay.const(1, "float32"))
@@ -586,8 +567,10 @@ def test_split():
     mod["main"] = relay.Function([x], a + relay.RefRead(relay.RefCreate(b)) + c)
     mod = transform.FuseOps()(mod)
 
+
 def test_fuse_max():
     """Test the constraint of number of nodes in op fusion."""
+
     def before(n):
         x = relay.var("x", shape=(10, 20))
         y = x
@@ -607,7 +590,7 @@ def test_fuse_max():
         xx = relay.var("pp", shape=(10, 20))
         yy = xx
         # it is assumed that there are two fused functions
-        for i in range(n-max_fused_ops):
+        for i in range(n - max_fused_ops):
             yy = relay.exp(yy)
         f2 = relay.Function([xx], yy)
         f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
@@ -637,23 +620,20 @@ def test_fuse_take():
     """Test fusion case involving concat and take"""
 
     def before():
-        shape = (tvm.tir.const(10, "int64"),
-                 tvm.tir.const(1, "int64"))
+        shape = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
         x = relay.var("x", shape=shape)
-        concat = relay.concatenate([x,x], axis=-1)
+        concat = relay.concatenate([x, x], axis=-1)
         out = relay.op.take(concat, indices=relay.const([0], dtype="int64"))
         return relay.Function(relay.analysis.free_vars(out), out)
 
     def expected():
-        shape1 = (tvm.tir.const(10, "int64"),
-                  tvm.tir.const(1, "int64"))
+        shape1 = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
         shape2 = (tvm.tir.const(1, "int64"),)
         x = relay.var("x", shape=shape1)
         p0 = relay.var("p0", shape=shape1)
-        p1 = relay.var("p1", shape=shape2,
-                             dtype="int64")
+        p1 = relay.var("p1", shape=shape2, dtype="int64")
         c = relay.const([0], dtype="int64")
-        concat = relay.concatenate([p0,p0], axis=-1)
+        concat = relay.concatenate([p0, p0], axis=-1)
         out = relay.op.take(concat, indices=p1)
 
         f0 = relay.Function([p0, p1], out)
@@ -664,7 +644,7 @@ def test_fuse_take():
 
     orig = before()
     m = fuse2(tvm.IRModule.from_expr(orig))
-    relay.build(m, 'llvm')
+    relay.build(m, "llvm")
     after = run_opt_pass(expected(), transform.InferType())
     assert tvm.ir.structural_equal(m["main"], after)
 
@@ -673,23 +653,20 @@ def test_fuse_gather_nd():
     """Test fusion case involving concat and gather_nd"""
 
     def before():
-        shape = (tvm.tir.const(10, "int64"),
-                 tvm.tir.const(1, "int64"))
+        shape = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
         x = relay.var("x", shape=shape)
-        concat = relay.concatenate([x,x], axis=-1)
-        out = relay.gather_nd(concat, indices=relay.expr.const([[0,1],[1,0]], dtype="int64"))
+        concat = relay.concatenate([x, x], axis=-1)
+        out = relay.gather_nd(concat, indices=relay.expr.const([[0, 1], [1, 0]], dtype="int64"))
         return relay.Function(relay.analysis.free_vars(out), out)
 
     def expected():
-        shape1 = (tvm.tir.const(10, "int64"),
-                  tvm.tir.const(1, "int64"))
-        shape2 = (tvm.tir.const(2, "int64"),
-                  tvm.tir.const(2, "int64"))
+        shape1 = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
+        shape2 = (tvm.tir.const(2, "int64"), tvm.tir.const(2, "int64"))
         x = relay.var("x", shape=shape1)
         p0 = relay.var("p0", shape=shape1)
         p1 = relay.var("p1", shape=shape2, dtype="int64")
-        c = relay.const([[0,1],[1,0]], dtype="int64")
-        concat = relay.concatenate([p0,p0], axis=-1)
+        c = relay.const([[0, 1], [1, 0]], dtype="int64")
+        concat = relay.concatenate([p0, p0], axis=-1)
         out = relay.gather_nd(concat, indices=p1)
 
         f0 = relay.Function([p0, p1], out)
@@ -700,7 +677,7 @@ def test_fuse_gather_nd():
 
     orig = before()
     m = fuse2(tvm.IRModule.from_expr(orig))
-    relay.build(m, 'llvm')
+    relay.build(m, "llvm")
     after = run_opt_pass(expected(), transform.InferType())
     assert tvm.ir.structural_equal(m["main"], after)
 
index b239ef4..5d79205 100644 (file)
@@ -26,13 +26,20 @@ from tvm.relay.analysis import free_vars, free_type_vars
 from tvm.relay import create_executor, transform
 from tvm.relay.transform import gradient
 from tvm.relay.prelude import Prelude
-from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, check_grad, rand, count_ops
+from tvm.relay.testing import (
+    add_nat_definitions,
+    make_nat_expr,
+    run_infer_type,
+    check_grad,
+    rand,
+    count_ops,
+)
 import tvm.relay.op as op
 
 
 def test_fo_id():
     shape = (10, 10)
-    dtype = 'float32'
+    dtype = "float32"
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     func = relay.Function([x], x)
@@ -45,9 +52,10 @@ def test_fo_id():
     tvm.testing.assert_allclose(forward.asnumpy(), x.asnumpy())
     tvm.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))
 
+
 def test_id():
     shape = (10, 10)
-    dtype = 'float32'
+    dtype = "float32"
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     func = relay.Function([x], x)
@@ -63,7 +71,7 @@ def test_id():
 
 def test_relu():
     shape = (10, 10)
-    dtype = 'float32'
+    dtype = "float32"
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     func = relay.Function([x], op.nn.relu(x))
@@ -75,7 +83,7 @@ def test_relu():
 
 def test_add():
     shape = (10, 10)
-    dtype = 'float32'
+    dtype = "float32"
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     func = relay.Function([x], x + x)
@@ -91,7 +99,7 @@ def test_add():
 
 def test_check_grad():
     shape = (10, 10)
-    dtype = 'float32'
+    dtype = "float32"
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     y = relay.var("y", t)
@@ -102,7 +110,7 @@ def test_check_grad():
 def test_temp_add():
     scope = relay.ScopeBuilder()
     shape = (10, 10)
-    dtype = 'float32'
+    dtype = "float32"
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     y = scope.let("y", x + x)
@@ -120,7 +128,7 @@ def test_temp_add():
 
 def test_sub():
     shape = (10, 10)
-    dtype = 'float32'
+    dtype = "float32"
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     func = relay.Function([x], x - x)
@@ -137,7 +145,7 @@ def test_sub():
 def test_broadcast_add():
     shape1 = (3, 4, 1)
     shape2 = (1, 5)
-    dtype = 'float32'
+    dtype = "float32"
     x_nd = rand(dtype, *shape1)
     y_nd = rand(dtype, *shape2)
     x_np = x_nd.asnumpy()
@@ -150,22 +158,28 @@ def test_broadcast_add():
     func = relay.Function([x, y], x + y)
     func = run_infer_type(func)
     full_func = run_infer_type(gradient(func))
-    assert full_func.checked_type == relay.FuncType([t1, t2],
-                                                    relay.TupleType([relay.TensorType(expected_forward.shape, dtype),
-                                                                     relay.TupleType([t1, t2])]))
+    assert full_func.checked_type == relay.FuncType(
+        [t1, t2],
+        relay.TupleType(
+            [relay.TensorType(expected_forward.shape, dtype), relay.TupleType([t1, t2])]
+        ),
+    )
     ex = create_executor()
     forward, (grad_x, grad_y) = ex.evaluate(full_func)(x_nd, y_nd)
     tvm.testing.assert_allclose(forward.asnumpy(), expected_forward)
-    tvm.testing.assert_allclose(grad_x.asnumpy(),
-                                np.ones_like(expected_forward).sum(axis=2, keepdims=True))
-    tvm.testing.assert_allclose(grad_y.asnumpy(),
-                                np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0))
+    tvm.testing.assert_allclose(
+        grad_x.asnumpy(), np.ones_like(expected_forward).sum(axis=2, keepdims=True)
+    )
+    tvm.testing.assert_allclose(
+        grad_y.asnumpy(),
+        np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0),
+    )
 
 
 def test_broadcast_subtract():
     shape1 = (3, 4, 1)
     shape2 = (1, 5)
-    dtype = 'float32'
+    dtype = "float32"
     x_nd = rand(dtype, *shape1)
     y_nd = rand(dtype, *shape2)
     x_np = x_nd.asnumpy()
@@ -178,40 +192,55 @@ def test_broadcast_subtract():
     func = relay.Function([x, y], x - y)
     func = run_infer_type(func)
     full_func = run_infer_type(gradient(func))
-    assert full_func.checked_type == relay.FuncType([t1, t2],
-                                                    relay.TupleType([relay.TensorType(expected_forward.shape, dtype),
-                                                                     relay.TupleType([t1, t2])]))
+    assert full_func.checked_type == relay.FuncType(
+        [t1, t2],
+        relay.TupleType(
+            [relay.TensorType(expected_forward.shape, dtype), relay.TupleType([t1, t2])]
+        ),
+    )
     ex = create_executor()
     forward, (grad_x, grad_y) = ex.evaluate(full_func)(x_nd, y_nd)
     tvm.testing.assert_allclose(forward.asnumpy(), expected_forward)
-    tvm.testing.assert_allclose(grad_x.asnumpy(),
-                                np.ones_like(expected_forward).sum(axis=2, keepdims=True))
-    tvm.testing.assert_allclose(grad_y.asnumpy(),
-                                -np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0))
+    tvm.testing.assert_allclose(
+        grad_x.asnumpy(), np.ones_like(expected_forward).sum(axis=2, keepdims=True)
+    )
+    tvm.testing.assert_allclose(
+        grad_y.asnumpy(),
+        -np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0),
+    )
 
 
 def _test_tuple(mode):
     shape = (10, 10)
-    dtype = 'float32'
+    dtype = "float32"
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     y = relay.var("y", t)
     z = relay.var("z", t)
     if mode == "higher_order":
         tup = relay.Var("tup")
-        func = relay.Function([x, y, z], relay.Let(tup, relay.Tuple([x, y, z]),
-                                                   relay.TupleGetItem(tup, 0) +
-                                                   relay.TupleGetItem(tup, 1) -
-                                                   relay.TupleGetItem(tup, 2)))
+        func = relay.Function(
+            [x, y, z],
+            relay.Let(
+                tup,
+                relay.Tuple([x, y, z]),
+                relay.TupleGetItem(tup, 0)
+                + relay.TupleGetItem(tup, 1)
+                - relay.TupleGetItem(tup, 2),
+            ),
+        )
     else:
         # first order does not use let.
         tup = relay.Tuple([x, y, z])
-        func = relay.Function([x, y, z], relay.TupleGetItem(tup, 0) +
-                                         relay.TupleGetItem(tup, 1) -
-                                         relay.TupleGetItem(tup, 2))
+        func = relay.Function(
+            [x, y, z],
+            relay.TupleGetItem(tup, 0) + relay.TupleGetItem(tup, 1) - relay.TupleGetItem(tup, 2),
+        )
     func = run_infer_type(func)
     back_func = run_infer_type(gradient(func, mode=mode))
-    assert back_func.checked_type == relay.FuncType([t, t, t], relay.TupleType([t, relay.TupleType([t, t, t])]))
+    assert back_func.checked_type == relay.FuncType(
+        [t, t, t], relay.TupleType([t, relay.TupleType([t, t, t])])
+    )
     x_nd = rand(dtype, *shape)
     y_nd = rand(dtype, *shape)
     z_nd = rand(dtype, *shape)
@@ -230,15 +259,17 @@ def _test_tuple(mode):
 def test_tuple():
     _test_tuple("higher_order")
 
+
 def test_tuple_first_order():
     _test_tuple("first_order")
 
+
 def test_pow():
     mod = tvm.IRModule()
     p = Prelude(mod)
     add_nat_definitions(p)
     shape = (10, 10)
-    dtype = 'float32'
+    dtype = "float32"
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     double = relay.Function([x], x + x)
@@ -258,7 +289,7 @@ def test_pow():
 
 def test_ref():
     shape = (10, 10)
-    dtype = 'float32'
+    dtype = "float32"
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     r = relay.Var("r")
@@ -279,14 +310,16 @@ def test_ref():
 
 def test_square_second_order():
     shape = (10, 10)
-    dtype = 'float32'
+    dtype = "float32"
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     func = relay.Function([x], x * x)
     func = run_infer_type(func)
     back_func = run_infer_type(gradient(func))
     y = relay.var("y", t)
-    back_func_adjusted = relay.Function([y], relay.TupleGetItem(relay.TupleGetItem(back_func(y), 1), 0))
+    back_func_adjusted = relay.Function(
+        [y], relay.TupleGetItem(relay.TupleGetItem(back_func(y), 1), 0)
+    )
     back_func_adjusted = run_infer_type(back_func_adjusted)
     back_back_func = run_infer_type(gradient(back_func_adjusted))
     assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
@@ -300,19 +333,19 @@ def test_square_second_order():
 def test_if():
     x = relay.var("x", shape=(1, 16, 64, 64))
     y = relay.var("y", shape=(1, 16, 64, 64))
-    cond = relay.var("cond", shape=(), dtype='uint1')
+    cond = relay.var("cond", shape=(), dtype="uint1")
     net = relay.If(cond, x, y)
     net = relay.log(net)
     func = relay.Function(free_vars(net), net)
     func = run_infer_type(func)
-    net = gradient(func, mode='higher_order')
+    net = gradient(func, mode="higher_order")
     net = run_infer_type(net)
 
 
 def test_grad_tuple():
     scope = relay.ScopeBuilder()
     shape = (10, 10)
-    dtype = 'float32'
+    dtype = "float32"
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     y = scope.let("y", x + x)
@@ -320,7 +353,9 @@ def test_grad_tuple():
     func = relay.Function([x], scope.get())
     func = run_infer_type(func)
     back_func = run_infer_type(gradient(func))
-    assert back_func.checked_type == relay.FuncType([t], relay.TupleType([relay.TupleType([t, t]), relay.TupleType([t])]))
+    assert back_func.checked_type == relay.FuncType(
+        [t], relay.TupleType([relay.TupleType([t, t]), relay.TupleType([t])])
+    )
     ex = create_executor()
     x = rand(dtype, *shape)
     (forward_four, forward_two), (grad,) = ex.evaluate(back_func)(x)
@@ -331,7 +366,7 @@ def test_grad_tuple():
 
 def test_concat():
     shape = (10, 10)
-    dtype = 'float32'
+    dtype = "float32"
     t = relay.TensorType(shape, dtype)
     rt = relay.TensorType((10, 20), dtype)
     x = relay.var("x", t)
@@ -339,37 +374,39 @@ def test_concat():
     func = relay.Function([x], y)
     func = run_infer_type(func)
     back_func = run_infer_type(gradient(func))
-    tvm.ir.assert_structural_equal(back_func.checked_type, relay.FuncType([t], relay.TupleType([rt, relay.TupleType([t])])))
+    tvm.ir.assert_structural_equal(
+        back_func.checked_type, relay.FuncType([t], relay.TupleType([rt, relay.TupleType([t])]))
+    )
     # no value validation as concatenate has a dummy gradient right now.
 
 
 def test_no_duplication():
-    x = tvm.relay.Var('x', type_annotation=tvm.relay.TensorType([12, 12]))
-    y = tvm.relay.Var('y', type_annotation=tvm.relay.TensorType([12, 12]))
+    x = tvm.relay.Var("x", type_annotation=tvm.relay.TensorType([12, 12]))
+    y = tvm.relay.Var("y", type_annotation=tvm.relay.TensorType([12, 12]))
     xy = tvm.relay.nn.dense(x, y)
 
     m = tvm.relay.sum(xy, keepdims=True)
     s = tvm.relay.sum(xy - m)
-    fn = tvm.relay.Function([x,y], s)
+    fn = tvm.relay.Function([x, y], s)
     fn = run_infer_type(fn)
-    gr = tvm.relay.transform.gradient(fn, mode='first_order')
+    gr = tvm.relay.transform.gradient(fn, mode="first_order")
 
     counts = count_ops(gr)
-    assert counts['nn.dense'] == 3, "We expect 3 dense (1 forward, two backward)"
+    assert counts["nn.dense"] == 3, "We expect 3 dense (1 forward, two backward)"
 
 
 def test_global_function():
     m = tvm.IRModule()
     shape = (10, 10)
-    dtype = 'float32'
+    dtype = "float32"
     t = relay.TensorType(shape, dtype)
-    x = relay.Var('x', t)
-    d = GlobalVar('double')
+    x = relay.Var("x", t)
+    d = GlobalVar("double")
     m[d] = relay.Function([x], x + x)
-    y = relay.Var('y', t)
-    q = GlobalVar('q')
+    y = relay.Var("y", t)
+    q = GlobalVar("q")
     m[q] = relay.Function([y], d(d(y)))
-    g = GlobalVar('grad')
+    g = GlobalVar("grad")
     m[g] = tvm.relay.transform.gradient(q, m)
     back_func = m[g]
     assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
index 3b41f07..aea3a38 100644 (file)
@@ -21,21 +21,19 @@ from tvm import relay
 
 def get_recursive_count_loop():
     mod = tvm.IRModule({})
-    sum_up = relay.GlobalVar('sum_up')
-    i = relay.var('i', shape=[], dtype='int32')
+    sum_up = relay.GlobalVar("sum_up")
+    i = relay.var("i", shape=[], dtype="int32")
     sb = relay.ScopeBuilder()
-    with sb.if_scope(relay.equal(i, relay.const(0, dtype='int32'))):
+    with sb.if_scope(relay.equal(i, relay.const(0, dtype="int32"))):
         sb.ret(i)
     with sb.else_scope():
-        one_less = relay.subtract(i, relay.const(1, dtype='int32'))
+        one_less = relay.subtract(i, relay.const(1, dtype="int32"))
         rec_call = relay.Call(sum_up, [one_less])
         sb.ret(relay.add(rec_call, i))
-    func = relay.Function([i],
-                          sb.get(),
-                          ret_type=relay.TensorType([], 'int32'))
+    func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], "int32"))
     func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
     mod[sum_up] = func
-    iarg = relay.var('i', shape=[], dtype='int32')
+    iarg = relay.var("i", shape=[], dtype="int32")
     mod["main"] = relay.Function([iarg], sum_up(iarg))
     return mod, sum_up
 
@@ -273,14 +271,14 @@ def test_recursive_call_with_global():
     def get_mod():
         mod = tvm.IRModule({})
 
-        x = relay.var('x', shape=[], dtype='int32')
+        x = relay.var("x", shape=[], dtype="int32")
         fn0 = relay.Function([x], x)
         fn0 = fn0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
         gx = relay.GlobalVar("gx")
         mod[gx] = fn0
 
-        sum_up = relay.GlobalVar('sum_up')
-        i = relay.var('i', shape=[], dtype='int32')
+        sum_up = relay.GlobalVar("sum_up")
+        i = relay.var("i", shape=[], dtype="int32")
         sb = relay.ScopeBuilder()
         with sb.if_scope(relay.equal(i, relay.const(0, dtype="int32"))):
             sb.ret(i)
@@ -289,33 +287,29 @@ def test_recursive_call_with_global():
             global_call = gx(i)
             rec_call = relay.Call(sum_up, [one_less]) + global_call
             sb.ret(relay.add(rec_call, i))
-        func = relay.Function([i],
-                              sb.get(),
-                              ret_type=relay.TensorType([], "int32"))
+        func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], "int32"))
         func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
         mod[sum_up] = func
-        iarg = relay.var("i", shape=[], dtype='int32')
+        iarg = relay.var("i", shape=[], dtype="int32")
         mod["main"] = relay.Function([iarg], sum_up(iarg))
         return mod
 
     def expected():
         mod = tvm.IRModule({})
 
-        sum_up = relay.GlobalVar('sum_up')
-        i = relay.var('i', shape=[], dtype='int32')
+        sum_up = relay.GlobalVar("sum_up")
+        i = relay.var("i", shape=[], dtype="int32")
         sb = relay.ScopeBuilder()
-        with sb.if_scope(relay.equal(i, relay.const(0, dtype='int32'))):
+        with sb.if_scope(relay.equal(i, relay.const(0, dtype="int32"))):
             sb.ret(i)
         with sb.else_scope():
-            one_less = relay.subtract(i, relay.const(1, dtype='int32'))
+            one_less = relay.subtract(i, relay.const(1, dtype="int32"))
             rec_call = relay.Call(sum_up, [one_less]) + i
             sb.ret(relay.add(rec_call, i))
-        func = relay.Function([i],
-                              sb.get(),
-                              ret_type=relay.TensorType([], 'int32'))
+        func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], "int32"))
         func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
         mod[sum_up] = func
-        iarg = relay.var('i', shape=[], dtype='int32')
+        iarg = relay.var("i", shape=[], dtype="int32")
         mod["main"] = relay.Function([iarg], sum_up(iarg))
         return mod
 
@@ -326,7 +320,7 @@ def test_recursive_call_with_global():
 
 def test_recursive_called():
     mod, sum_up = get_recursive_count_loop()
-    iarg = relay.var('i', shape=[], dtype='int32')
+    iarg = relay.var("i", shape=[], dtype="int32")
     mod["main"] = relay.Function([iarg], sum_up(iarg))
     ref_mod = mod
     mod = relay.transform.Inline()(mod)
@@ -510,12 +504,12 @@ def test_inline_globalvar_without_args():
         fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
         fn2 = relay.Function([], relay.const(2))
         fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        g1 = relay.GlobalVar('g1')
-        g2 = relay.GlobalVar('g2')
+        g1 = relay.GlobalVar("g1")
+        g2 = relay.GlobalVar("g2")
         mod[g1] = fn1
         mod[g2] = fn2
-        p = relay.var('p', 'bool')
-        mod['main'] = relay.Function([p], relay.Call(relay.If(p, g1, g2), []))
+        p = relay.var("p", "bool")
+        mod["main"] = relay.Function([p], relay.Call(relay.If(p, g1, g2), []))
         return mod
 
     def expected():
@@ -524,9 +518,8 @@ def test_inline_globalvar_without_args():
         fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
         fn2 = relay.Function([], relay.const(2))
         fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        p = relay.var('p', 'bool')
-        mod['main'] = relay.Function([p], relay.Call(
-            relay.If(p, fn1, fn2), []))
+        p = relay.var("p", "bool")
+        mod["main"] = relay.Function([p], relay.Call(relay.If(p, fn1, fn2), []))
         return mod
 
     mod = get_mod()
@@ -543,12 +536,12 @@ def test_inline_globalvar_without_args_extern_compiler():
         fn2 = relay.Function([], relay.const(2))
         fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
         fn2 = fn2.with_attr("Compiler", "b")
-        g1 = relay.GlobalVar('g1')
-        g2 = relay.GlobalVar('g2')
+        g1 = relay.GlobalVar("g1")
+        g2 = relay.GlobalVar("g2")
         mod[g1] = fn1
         mod[g2] = fn2
-        p = relay.var('p', 'bool')
-        mod['main'] = relay.Function([p], relay.Call(relay.If(p, g1, g2), []))
+        p = relay.var("p", "bool")
+        mod["main"] = relay.Function([p], relay.Call(relay.If(p, g1, g2), []))
         return mod
 
     def expected():
@@ -559,9 +552,8 @@ def test_inline_globalvar_without_args_extern_compiler():
         fn2 = relay.Function([], relay.const(2))
         fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
         fn2 = fn2.with_attr("Compiler", "b")
-        p = relay.var('p', 'bool')
-        mod['main'] = relay.Function([p], relay.Call(
-            relay.If(p, fn1, fn2), []))
+        p = relay.var("p", "bool")
+        mod["main"] = relay.Function([p], relay.Call(relay.If(p, fn1, fn2), []))
         return mod
 
     mod = get_mod()
@@ -833,5 +825,5 @@ def test_callee_not_inline_leaf_inline_extern_compiler():
     assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     pytest.main()
index e388878..b19aebd 100644 (file)
@@ -22,25 +22,27 @@ from tvm import te
 from tvm import relay
 from tvm.relay import transform
 
+
 def test_basic():
     mod = tvm.IRModule()
-    x2 = relay.var('x2', shape=(10, 5))
-    y2 = relay.var('y2', shape=(1, 5))
+    x2 = relay.var("x2", shape=(10, 5))
+    y2 = relay.var("y2", shape=(1, 5))
     level2_func = relay.Function([x2, y2], relay.op.add(x2, y2))
 
-    x1 = relay.var('x1', shape=(10, 5))
-    y1 = relay.var('y1', shape=(1, 5))
+    x1 = relay.var("x1", shape=(10, 5))
+    y1 = relay.var("y1", shape=(1, 5))
     level1_func = relay.Function([x1, y1], level2_func(x1, y1))
 
     mod["main"] = level1_func
     new_mod = transform.LambdaLift()(mod)
     assert len(new_mod.functions) == 2
 
+
 def test_closure():
     mod = tvm.IRModule()
 
-    x = relay.var('x', shape=(2,))
-    y = relay.var('y', shape=(2,))
+    x = relay.var("x", shape=(2,))
+    y = relay.var("y", shape=(2,))
     inner_func = relay.Function([x], x + y)
     outer_func = relay.Function([y], inner_func)
     clo = outer_func(relay.ones(shape=(2,), dtype="float32"))
@@ -49,25 +51,28 @@ def test_closure():
     new_mod = transform.LambdaLift()(mod)
     assert len(new_mod.functions) == 3
 
+
 def test_recursive():
     mod = tvm.IRModule()
 
-    x = relay.var('x', shape=(2,))
-    i = relay.var('i', shape=(), dtype='int32')
-    s = relay.var('s', shape=(2,))
-    cond = i < relay.const(10, dtype='int32')
+    x = relay.var("x", shape=(2,))
+    i = relay.var("i", shape=(), dtype="int32")
+    s = relay.var("s", shape=(2,))
+    cond = i < relay.const(10, dtype="int32")
 
-    loop = relay.var('while_loop')
+    loop = relay.var("while_loop")
     sb = relay.scope_builder.ScopeBuilder()
     with sb.if_scope(cond):
-        ii = i + relay.const(1, dtype='int32')
+        ii = i + relay.const(1, dtype="int32")
         ss = s + x
         sb.ret(loop(ii, ss))
     with sb.else_scope():
         sb.ret(s)
     func = relay.Function([i, s], sb.get())
 
-    ret = relay.Let(loop, func, loop(relay.const(0, dtype='int32'), relay.zeros(shape=(2,), dtype='float32')))
+    ret = relay.Let(
+        loop, func, loop(relay.const(0, dtype="int32"), relay.zeros(shape=(2,), dtype="float32"))
+    )
     mod["main"] = relay.Function([x], ret)
 
     new_mod = transform.LambdaLift()(mod)
@@ -76,4 +81,3 @@ def test_recursive():
 
 if __name__ == "__main__":
     pytest.main()
-
index 4a09e4e..403f88d 100644 (file)
@@ -24,391 +24,413 @@ import tvm.testing
 from tvm.testing import assert_allclose
 import pytest
 
+
 def test_tc():
-  """Simple testcase, check that transformation typechecks."""
-  mod = tvm.IRModule()
+    """Simple testcase, check that transformation typechecks."""
+    mod = tvm.IRModule()
+
+    shape = (20, 20)
+    dtype = "float32"
+    t = relay.TensorType(shape, dtype)
 
-  shape = (20, 20)
-  dtype = 'float32'
-  t = relay.TensorType(shape, dtype)
+    x1 = relay.var("x1", t)
+    x2 = relay.var("x2", t)
+    # f(x1,x2) = (x1-x2)*x2
+    y = relay.Function([x1, x2], (x1 - x2) * x2)
 
-  x1 = relay.var("x1", t)
-  x2 = relay.var("x2", t)
-  # f(x1,x2) = (x1-x2)*x2
-  y = relay.Function([x1, x2], (x1 - x2) * x2)
+    mod["main"] = y
+    mod = transform.LazyGradientInit()(mod)
 
-  mod["main"] = y
-  mod = transform.LazyGradientInit()(mod)
+    # function input/output types should remain the same
+    assert mod["main"].checked_type == relay.FuncType([t, t], t)
 
-  # function input/output types should remain the same
-  assert mod["main"].checked_type == relay.FuncType([t, t], t)
 
 def test_add():
-  """Simple add testcase. Check types and semantic equivalence."""
-  mod = tvm.IRModule()
+    """Simple add testcase. Check types and semantic equivalence."""
+    mod = tvm.IRModule()
 
-  shape = (10, 10)
-  dtype = 'float32'
-  t = relay.TensorType(shape, dtype)
+    shape = (10, 10)
+    dtype = "float32"
+    t = relay.TensorType(shape, dtype)
 
-  x = relay.var("x", t)
-  # f(x) = x+x
-  y = relay.Function([x], x+x)
+    x = relay.var("x", t)
+    # f(x) = x+x
+    y = relay.Function([x], x + x)
 
-  mod["main"] = y
-  mod = transform.LazyGradientInit()(mod)
-  y = mod["main"]
+    mod["main"] = y
+    mod = transform.LazyGradientInit()(mod)
+    y = mod["main"]
 
-  assert mod["main"].checked_type == relay.FuncType([t], t)
+    assert mod["main"].checked_type == relay.FuncType([t], t)
+
+    ex = create_executor(mod=mod)
+    x = rand(dtype, *shape)
+    y = ex.evaluate(y)(x)
+    assert_allclose(y.asnumpy(), x.asnumpy() + x.asnumpy())
 
-  ex = create_executor(mod=mod)
-  x = rand(dtype, *shape)
-  y = ex.evaluate(y)(x)
-  assert_allclose(y.asnumpy(), x.asnumpy() + x.asnumpy())
 
 def test_add_tuple():
-  """Add elements of tuple. Check types and semantic equivalence."""
-  mod = tvm.IRModule()
+    """Add elements of tuple. Check types and semantic equivalence."""
+    mod = tvm.IRModule()
+
+    shape = (10, 10)
+    dtype = "float32"
+    tensor_type = relay.TensorType(shape, dtype)
+    t = relay.TupleType([tensor_type, tensor_type])
 
-  shape = (10, 10)
-  dtype = 'float32'
-  tensor_type = relay.TensorType(shape, dtype)
-  t = relay.TupleType([tensor_type, tensor_type])
+    x = relay.var("x", t)
+    # f((x1,x2)) = x1 + x2
+    y = relay.Function([x], relay.TupleGetItem(x, 0) + relay.TupleGetItem(x, 1))
 
-  x = relay.var("x", t)
-  # f((x1,x2)) = x1 + x2
-  y = relay.Function([x], relay.TupleGetItem(x, 0) + relay.TupleGetItem(x, 1))
+    mod["main"] = y
+    mod = transform.LazyGradientInit()(mod)
+    mod = tvm.transform.PrintIR(show_meta_data=True)(mod)
+    y = mod["main"]
 
-  mod["main"] = y
-  mod = transform.LazyGradientInit()(mod)
-  mod = tvm.transform.PrintIR(show_meta_data=True)(mod)
-  y = mod["main"]
+    assert mod["main"].checked_type == relay.FuncType([t], tensor_type)
 
-  assert mod["main"].checked_type == relay.FuncType([t], tensor_type)
+    ex = create_executor(mod=mod)
+    x = (rand(dtype, *shape), rand(dtype, *shape))
+    y = ex.evaluate(y)(x)
+    assert_allclose(y.asnumpy(), x[0].asnumpy() + x[1].asnumpy())
 
-  ex = create_executor(mod=mod)
-  x = (rand(dtype, *shape), rand(dtype, *shape))
-  y = ex.evaluate(y)(x)
-  assert_allclose(y.asnumpy(), x[0].asnumpy() + x[1].asnumpy())
 
 def test_mult():
-  """Simple multiplication testcase. Check types and semantic equivalence."""
-  mod = tvm.IRModule()
+    """Simple multiplication testcase. Check types and semantic equivalence."""
+    mod = tvm.IRModule()
 
-  shape = (15, 15)
-  dtype = 'float32'
-  t = relay.TensorType(shape, dtype)
+    shape = (15, 15)
+    dtype = "float32"
+    t = relay.TensorType(shape, dtype)
 
-  x = relay.var("x", t)
-  # f(x) = x*x
-  y = relay.Function([x], x * x)
+    x = relay.var("x", t)
+    # f(x) = x*x
+    y = relay.Function([x], x * x)
 
-  mod["main"] = y
-  mod = transform.LazyGradientInit()(mod)
-  y = mod["main"]
+    mod["main"] = y
+    mod = transform.LazyGradientInit()(mod)
+    y = mod["main"]
 
-  assert mod["main"].checked_type == relay.FuncType([t], t)
+    assert mod["main"].checked_type == relay.FuncType([t], t)
+
+    ex = create_executor(mod=mod)
+    x = rand(dtype, *shape)
+    y = ex.evaluate(y)(x)
+    assert_allclose(y.asnumpy(), x.asnumpy() * x.asnumpy())
 
-  ex = create_executor(mod=mod)
-  x = rand(dtype, *shape)
-  y = ex.evaluate(y)(x)
-  assert_allclose(y.asnumpy(), x.asnumpy() * x.asnumpy())
 
 def test_ret_tuple():
-  """Test tuple return type. Check types and semantic equivalence."""
-  mod = tvm.IRModule()
+    """Test tuple return type. Check types and semantic equivalence."""
+    mod = tvm.IRModule()
+
+    shape = (10, 10)
+    dtype = "float32"
+    t = relay.TensorType(shape, dtype)
 
-  shape = (10, 10)
-  dtype = 'float32'
-  t = relay.TensorType(shape, dtype)
+    x = relay.var("x", t)
+    # f(x) = (x,x)
+    func = relay.Function([x], relay.Tuple([x, x * relay.const(2.0)]))
+    func = run_infer_type(func)
 
-  x = relay.var("x", t)
-  # f(x) = (x,x)
-  func = relay.Function([x], relay.Tuple([x,x * relay.const(2.0)]))
-  func = run_infer_type(func)
+    mod["main"] = func
+    mod = transform.LazyGradientInit()(mod)
+    func = mod["main"]
 
-  mod["main"] = func
-  mod = transform.LazyGradientInit()(mod)
-  func = mod["main"]
+    assert mod["main"].checked_type == relay.FuncType([t], relay.TupleType([t, t]))
 
-  assert mod["main"].checked_type == relay.FuncType([t], relay.TupleType([t, t]))
+    ex = create_executor(mod=mod)
+    x = rand(dtype, *shape)
+    y = ex.evaluate(func)(x)
+    assert_allclose(y[0].asnumpy(), x.asnumpy())
+    assert_allclose(y[1].asnumpy(), x.asnumpy() * 2.0)
 
-  ex = create_executor(mod=mod)
-  x = rand(dtype, *shape)
-  y = ex.evaluate(func)(x)
-  assert_allclose(y[0].asnumpy(), x.asnumpy())
-  assert_allclose(y[1].asnumpy(), x.asnumpy() * 2.0)
 
 def test_add_broadcast():
-  """Test adding matrices of different size. Check types and semantic equivalence."""
-  mod = tvm.IRModule()
+    """Test adding matrices of different size. Check types and semantic equivalence."""
+    mod = tvm.IRModule()
 
-  shape1 = (3, 4, 1)
-  shape2 = (1, 5)
-  dtype = 'float32'
-  t1 = relay.TensorType(shape1, dtype)
-  t2 = relay.TensorType(shape2, dtype)
+    shape1 = (3, 4, 1)
+    shape2 = (1, 5)
+    dtype = "float32"
+    t1 = relay.TensorType(shape1, dtype)
+    t2 = relay.TensorType(shape2, dtype)
 
-  x1 = relay.var("x1", t1)
-  x2 = relay.var("x2", t2)
-  func = relay.Function([x1,x2], x1 + x2)
-  func = run_infer_type(func)
+    x1 = relay.var("x1", t1)
+    x2 = relay.var("x2", t2)
+    func = relay.Function([x1, x2], x1 + x2)
+    func = run_infer_type(func)
 
-  mod["main"] = func
-  mod = transform.LazyGradientInit()(mod)
-  func = mod["main"]
+    mod["main"] = func
+    mod = transform.LazyGradientInit()(mod)
+    func = mod["main"]
 
-  x1_np = rand(dtype, *shape1).asnumpy()
-  x2_np = rand(dtype, *shape2).asnumpy()
-  expected_forward = x1_np + x2_np
+    x1_np = rand(dtype, *shape1).asnumpy()
+    x2_np = rand(dtype, *shape2).asnumpy()
+    expected_forward = x1_np + x2_np
 
-  expected_forward_type = relay.TensorType(expected_forward.shape, dtype)
-  assert mod["main"].checked_type == relay.FuncType([t1, t2], expected_forward_type)
+    expected_forward_type = relay.TensorType(expected_forward.shape, dtype)
+    assert mod["main"].checked_type == relay.FuncType([t1, t2], expected_forward_type)
 
-  ex = create_executor(mod=mod)
-  forward = ex.evaluate(func)(x1_np, x2_np)
+    ex = create_executor(mod=mod)
+    forward = ex.evaluate(func)(x1_np, x2_np)
+
+    assert_allclose(forward.asnumpy(), expected_forward)
 
-  assert_allclose(forward.asnumpy(), expected_forward)
 
 def test_reverse_ad_identity():
-  """Simple test with reverse mode ad."""
-  # of f(x) = x
-  mod = tvm.IRModule()
+    """Simple test with reverse mode ad."""
+    # of f(x) = x
+    mod = tvm.IRModule()
+
+    shape = (10, 10)
+    dtype = "float32"
+    t = relay.TensorType(shape, dtype)
 
-  shape = (10, 10)
-  dtype = 'float32'
-  t = relay.TensorType(shape, dtype)
+    x = relay.var("x", t)
 
-  x = relay.var("x", t)
+    func = relay.Function([x], x)
+    func = run_infer_type(func)
+    back_func = transform.gradient(func)
+    back_func = run_infer_type(back_func)
 
-  func = relay.Function([x], x)
-  func = run_infer_type(func)
-  back_func = transform.gradient(func)
-  back_func = run_infer_type(back_func)
+    mod["main"] = back_func
+    mod = transform.LazyGradientInit()(mod)
+    back_func = mod["main"]
 
-  mod["main"] = back_func
-  mod = transform.LazyGradientInit()(mod)
-  back_func = mod["main"]
+    assert mod["main"].checked_type == relay.FuncType(
+        [t], relay.TupleType([t, relay.TupleType([t])])
+    )
 
-  assert mod["main"].checked_type == relay.FuncType([t],
-                                                    relay.TupleType([t, relay.TupleType([t])]))
+    ex = create_executor(mod=mod)
+    x = rand(dtype, *shape)
+    (forward), (grad,) = ex.evaluate(back_func)(x)
+    assert_allclose(forward.asnumpy(), x.asnumpy())
+    assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))
 
-  ex = create_executor(mod=mod)
-  x = rand(dtype, *shape)
-  (forward), (grad,) = ex.evaluate(back_func)(x)
-  assert_allclose(forward.asnumpy(), x.asnumpy())
-  assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))
 
 def test_multivar_reverse_ad():
-  """Simple test with multivariate reverse mode ad."""
-  mod = tvm.IRModule()
+    """Simple test with multivariate reverse mode ad."""
+    mod = tvm.IRModule()
 
-  shape = (10, 10)
-  dtype = 'float32'
-  t = relay.TensorType(shape, dtype)
+    shape = (10, 10)
+    dtype = "float32"
+    t = relay.TensorType(shape, dtype)
 
-  x = relay.var("x", t)
-  y = relay.var("y", t)
+    x = relay.var("x", t)
+    y = relay.var("y", t)
 
-  func = relay.Function([x, y],  (x * y) * relay.const(np.ones(shape, dtype)))
-  func = run_infer_type(func)
-  back_func = transform.gradient(func)
-  back_func = run_infer_type(back_func)
+    func = relay.Function([x, y], (x * y) * relay.const(np.ones(shape, dtype)))
+    func = run_infer_type(func)
+    back_func = transform.gradient(func)
+    back_func = run_infer_type(back_func)
 
-  mod["main"] = back_func
-  mod = transform.LazyGradientInit()(mod)
-  back_func = mod["main"]
+    mod["main"] = back_func
+    mod = transform.LazyGradientInit()(mod)
+    back_func = mod["main"]
 
-  assert mod["main"].checked_type == relay.FuncType([t, t],
-                                                    relay.TupleType([t, relay.TupleType([t, t])]))
+    assert mod["main"].checked_type == relay.FuncType(
+        [t, t], relay.TupleType([t, relay.TupleType([t, t])])
+    )
+
+    ex = create_executor(mod=mod)
+    x = rand(dtype, *shape)
+    y = rand(dtype, *shape)
+    (forward), (grad_x, grad_y,) = ex.evaluate(
+        back_func
+    )(x, y)
+    assert_allclose(forward.asnumpy(), x.asnumpy() * y.asnumpy())
+    assert_allclose(grad_x.asnumpy(), y.asnumpy())
+    assert_allclose(grad_y.asnumpy(), x.asnumpy())
 
-  ex = create_executor(mod=mod)
-  x = rand(dtype, *shape)
-  y = rand(dtype, *shape)
-  (forward), (grad_x, grad_y, ) = ex.evaluate(back_func)(x, y)
-  assert_allclose(forward.asnumpy(), x.asnumpy() * y.asnumpy())
-  assert_allclose(grad_x.asnumpy(), y.asnumpy())
-  assert_allclose(grad_y.asnumpy(), x.asnumpy())
 
 def test_partial_eval():
-  """Test transformation following reverse mode ad and PartialEval"""
-  mod = tvm.IRModule()
+    """Test transformation following reverse mode ad and PartialEval"""
+    mod = tvm.IRModule()
+
+    shape = (10, 10)
+    dtype = "float32"
+    t = relay.TensorType(shape, dtype)
 
-  shape = (10, 10)
-  dtype = 'float32'
-  t = relay.TensorType(shape, dtype)
+    func = relay.Function([], relay.const(np.ones(shape, dtype)))
+    func = run_infer_type(func)
+    back_func = transform.gradient(func)
+    back_func = run_infer_type(back_func)
 
-  func = relay.Function([], relay.const(np.ones(shape, dtype)))
-  func = run_infer_type(func)
-  back_func = transform.gradient(func)
-  back_func = run_infer_type(back_func)
+    mod["main"] = back_func
+    back_func = mod["main"]
 
-  mod["main"] = back_func
-  back_func = mod["main"]
+    transform.PartialEvaluate()(mod)
 
-  transform.PartialEvaluate()(mod)
 
 def test_after_partial_eval():
-  """Test transformation following reverse mode ad and PartialEval"""
-  mod = tvm.IRModule()
+    """Test transformation following reverse mode ad and PartialEval"""
+    mod = tvm.IRModule()
 
-  shape = (10, 10)
-  dtype = 'float32'
-  t = relay.TensorType(shape, dtype)
+    shape = (10, 10)
+    dtype = "float32"
+    t = relay.TensorType(shape, dtype)
 
-  x = relay.var("x", t)
-  y = relay.var("y", t)
+    x = relay.var("x", t)
+    y = relay.var("y", t)
 
-  func = relay.Function([x, y], (x * y) * relay.const(np.ones(shape, dtype)))
-  func = run_infer_type(func)
-  back_func = transform.gradient(func)
-  back_func = run_infer_type(back_func)
+    func = relay.Function([x, y], (x * y) * relay.const(np.ones(shape, dtype)))
+    func = run_infer_type(func)
+    back_func = transform.gradient(func)
+    back_func = run_infer_type(back_func)
 
-  mod["main"] = back_func
-  back_func = mod["main"]
+    mod["main"] = back_func
+    back_func = mod["main"]
 
-  seq = tvm.transform.Sequential([
-    transform.PartialEvaluate(),
-    transform.LazyGradientInit(),
-    transform.DeadCodeElimination()
-  ])
+    seq = tvm.transform.Sequential(
+        [transform.PartialEvaluate(), transform.LazyGradientInit(), transform.DeadCodeElimination()]
+    )
 
-  mod = seq(mod)
+    mod = seq(mod)
 
-  assert mod["main"].checked_type == relay.FuncType([t, t],
-                                                    relay.TupleType([t, relay.TupleType([t, t])]))
+    assert mod["main"].checked_type == relay.FuncType(
+        [t, t], relay.TupleType([t, relay.TupleType([t, t])])
+    )
+
+    ex = create_executor(mod=mod)
+    x = rand(dtype, *shape)
+    y = rand(dtype, *shape)
+    (forward), (grad_x, grad_y,) = ex.evaluate(
+        back_func
+    )(x, y)
+    assert_allclose(forward.asnumpy(), x.asnumpy() * y.asnumpy())
+    assert_allclose(grad_x.asnumpy(), y.asnumpy())
+    assert_allclose(grad_y.asnumpy(), x.asnumpy())
 
-  ex = create_executor(mod=mod)
-  x = rand(dtype, *shape)
-  y = rand(dtype, *shape)
-  (forward), (grad_x, grad_y,) = ex.evaluate(back_func)(x, y)
-  assert_allclose(forward.asnumpy(), x.asnumpy() * y.asnumpy())
-  assert_allclose(grad_x.asnumpy(), y.asnumpy())
-  assert_allclose(grad_y.asnumpy(), x.asnumpy())
 
 def test_before_partial_eval():
-  """Test transformation before PartialEval"""
-  mod = tvm.IRModule()
-
-  shape = (10, 10)
-  dtype = 'float32'
-  t = relay.TensorType(shape, dtype)
-
-  x = relay.var("x", t)
-  y = relay.var("y", t)
-
-  func = relay.Function([x, y], x * y)
-  func = run_infer_type(func)
-  back_func = transform.gradient(func)
-  back_func = run_infer_type(back_func)
-
-  mod["main"] = back_func
-  seq = tvm.transform.Sequential([
-    transform.LazyGradientInit(),
-    transform.PartialEvaluate(),
-    transform.DeadCodeElimination()
-  ])
-  mod = seq(mod)
-  back_func = mod["main"]
-
-  assert mod["main"].checked_type == relay.FuncType([t, t],
-                                                    relay.TupleType([t, relay.TupleType([t, t])]))
-
-  ex = create_executor(mod=mod)
-  x = rand(dtype, *shape)
-  y = rand(dtype, *shape)
-  (forward), (grad_x, grad_y,) = ex.evaluate(back_func)(x, y)
-  assert_allclose(forward.asnumpy(), x.asnumpy() * y.asnumpy())
-  assert_allclose(grad_x.asnumpy(), y.asnumpy())
-  assert_allclose(grad_y.asnumpy(), x.asnumpy())
+    """Test transformation before PartialEval"""
+    mod = tvm.IRModule()
+
+    shape = (10, 10)
+    dtype = "float32"
+    t = relay.TensorType(shape, dtype)
+
+    x = relay.var("x", t)
+    y = relay.var("y", t)
+
+    func = relay.Function([x, y], x * y)
+    func = run_infer_type(func)
+    back_func = transform.gradient(func)
+    back_func = run_infer_type(back_func)
+
+    mod["main"] = back_func
+    seq = tvm.transform.Sequential(
+        [transform.LazyGradientInit(), transform.PartialEvaluate(), transform.DeadCodeElimination()]
+    )
+    mod = seq(mod)
+    back_func = mod["main"]
+
+    assert mod["main"].checked_type == relay.FuncType(
+        [t, t], relay.TupleType([t, relay.TupleType([t, t])])
+    )
+
+    ex = create_executor(mod=mod)
+    x = rand(dtype, *shape)
+    y = rand(dtype, *shape)
+    (forward), (grad_x, grad_y,) = ex.evaluate(
+        back_func
+    )(x, y)
+    assert_allclose(forward.asnumpy(), x.asnumpy() * y.asnumpy())
+    assert_allclose(grad_x.asnumpy(), y.asnumpy())
+    assert_allclose(grad_y.asnumpy(), x.asnumpy())
+
 
 def test_zeros():
-  """Simple test using "zeros" op"""
-  mod = tvm.IRModule()
+    """Simple test using "zeros" op"""
+    mod = tvm.IRModule()
 
-  shape = (10, 10)
-  dtype = 'float32'
-  t = relay.TensorType(shape, dtype)
+    shape = (10, 10)
+    dtype = "float32"
+    t = relay.TensorType(shape, dtype)
 
-  x = relay.var("x", t)
-  y = relay.Function([x], x + relay.zeros(shape, dtype))
+    x = relay.var("x", t)
+    y = relay.Function([x], x + relay.zeros(shape, dtype))
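+    # Adding an all-zeros tensor is an identity, so the evaluated result should equal x.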
 
-  mod["main"] = y
-  mod = transform.LazyGradientInit()(mod)
-  y = mod["main"]
+    mod["main"] = y
+    mod = transform.LazyGradientInit()(mod)
+    y = mod["main"]
 
-  assert mod["main"].checked_type == relay.FuncType([t], t)
+    assert mod["main"].checked_type == relay.FuncType([t], t)
+
+    ex = create_executor(mod=mod)
+    x = rand(dtype, *shape)
+    y = ex.evaluate(y)(x)
+    assert_allclose(y.asnumpy(), x.asnumpy())
 
-  ex = create_executor(mod=mod)
-  x = rand(dtype, *shape)
-  y = ex.evaluate(y)(x)
-  assert_allclose(y.asnumpy(), x.asnumpy())
 
 def test_ones():
-  """Simple test using "ones" op"""
-  mod = tvm.IRModule()
+    """Simple test using "ones" op"""
+    mod = tvm.IRModule()
+
+    shape = (10, 10)
+    dtype = "float32"
+    t = relay.TensorType(shape, dtype)
 
-  shape = (10, 10)
-  dtype = 'float32'
-  t = relay.TensorType(shape, dtype)
+    x = relay.var("x", t)
+    y = relay.Function([x], x + relay.ones(shape, dtype))
 
-  x = relay.var("x", t)
-  y = relay.Function([x], x + relay.ones(shape, dtype))
+    mod["main"] = y
+    mod = transform.LazyGradientInit()(mod)
+    y = mod["main"]
 
-  mod["main"] = y
-  mod = transform.LazyGradientInit()(mod)
-  y = mod["main"]
+    assert mod["main"].checked_type == relay.FuncType([t], t)
 
-  assert mod["main"].checked_type == relay.FuncType([t], t)
+    ex = create_executor(mod=mod)
+    x = rand(dtype, *shape)
+    y = ex.evaluate(y)(x)
+    assert_allclose(y.asnumpy(), x.asnumpy() + np.ones_like(x.asnumpy()))
 
-  ex = create_executor(mod=mod)
-  x = rand(dtype, *shape)
-  y = ex.evaluate(y)(x)
-  assert_allclose(y.asnumpy(), x.asnumpy() + np.ones_like(x.asnumpy()))
 
 def test_zeros_like():
-  """Simple test using "zeros_like" op"""
-  mod = tvm.IRModule()
+    """Simple test using "zeros_like" op"""
+    mod = tvm.IRModule()
 
-  shape = (10, 10)
-  dtype = 'float32'
-  t = relay.TensorType(shape, dtype)
+    shape = (10, 10)
+    dtype = "float32"
+    t = relay.TensorType(shape, dtype)
 
-  x = relay.var("x", t)
-  y = relay.Function([x], x + relay.zeros_like(x))
+    x = relay.var("x", t)
+    y = relay.Function([x], x + relay.zeros_like(x))
 
-  mod["main"] = y
-  mod = transform.LazyGradientInit()(mod)
-  y = mod["main"]
+    mod["main"] = y
+    mod = transform.LazyGradientInit()(mod)
+    y = mod["main"]
 
-  assert mod["main"].checked_type == relay.FuncType([t], t)
+    assert mod["main"].checked_type == relay.FuncType([t], t)
+
+    ex = create_executor(mod=mod)
+    x = rand(dtype, *shape)
+    y = ex.evaluate(y)(x)
+    assert_allclose(y.asnumpy(), x.asnumpy())
 
-  ex = create_executor(mod=mod)
-  x = rand(dtype, *shape)
-  y = ex.evaluate(y)(x)
-  assert_allclose(y.asnumpy(), x.asnumpy())
 
 def test_ones_like():
-  """Simple test using "ones_like" op"""
-  mod = tvm.IRModule()
+    """Simple test using "ones_like" op"""
+    mod = tvm.IRModule()
+
+    shape = (10, 10)
+    dtype = "float32"
+    t = relay.TensorType(shape, dtype)
 
-  shape = (10, 10)
-  dtype = 'float32'
-  t = relay.TensorType(shape, dtype)
+    x = relay.var("x", t)
+    y = relay.Function([x], x + relay.ones_like(x))
 
-  x = relay.var("x", t)
-  y = relay.Function([x], x + relay.ones_like(x))
+    mod["main"] = y
+    mod = transform.LazyGradientInit()(mod)
+    y = mod["main"]
 
-  mod["main"] = y
-  mod = transform.LazyGradientInit()(mod)
-  y = mod["main"]
+    assert mod["main"].checked_type == relay.FuncType([t], t)
 
-  assert mod["main"].checked_type == relay.FuncType([t], t)
+    ex = create_executor(mod=mod)
+    x = rand(dtype, *shape)
+    y = ex.evaluate(y)(x)
+    assert_allclose(y.asnumpy(), x.asnumpy() + np.ones_like(x.asnumpy()))
 
-  ex = create_executor(mod=mod)
-  x = rand(dtype, *shape)
-  y = ex.evaluate(y)(x)
-  assert_allclose(y.asnumpy(), x.asnumpy() + np.ones_like(x.asnumpy()))
 
 if __name__ == "__main__":
-  pytest.main([__file__])
+    pytest.main([__file__])
index 0882149..0d14f66 100644 (file)
@@ -34,15 +34,14 @@ def run_opt_pass(expr, passes):
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
+
 def test_legalize():
     """Test directly replacing an operator with a new one"""
+
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight = relay.var('weight', shape=(64, 64, 3, 3))
-        y = relay.nn.conv2d(x, weight,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
         y = relay.nn.relu(y)
         y = relay.Function([x, weight], y)
         return y
@@ -54,11 +53,14 @@ def test_legalize():
 
     def expected():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight = relay.var('weight', shape=(64, 64, 3, 3))
-        y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")),
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(
+            x,
+            relay.multiply(weight, relay.const(2.0, "float32")),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+        )
         y = relay.nn.relu(y)
         y = relay.Function([x, weight], y)
         return y
@@ -70,8 +72,10 @@ def test_legalize():
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
+
 def test_legalize_none():
     """Test doing nothing by returning 'None' """
+
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))
         y = relay.nn.global_max_pool2d(x)
@@ -90,17 +94,16 @@ def test_legalize_none():
         b = run_opt_pass(before(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
-    assert(called[0])
+    assert called[0]
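+    # called[0] is presumably set inside the registered legalize callback; this confirms it actually ran.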
+
 
 def test_legalize_multiple_ops():
     """Test directly replacing an operator with a new one"""
+
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight = relay.var('weight', shape=(64, 64, 3, 3))
-        y = relay.nn.conv2d(x, weight,
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
         y = relay.nn.relu(y)
         y = relay.Function([x, weight], y)
         return y
@@ -115,14 +118,16 @@ def test_legalize_multiple_ops():
         add = relay.add(tvm.relay.const(0, "float32"), data)
         return relay.nn.relu(add)
 
-
     def expected():
         x = relay.var("x", shape=(1, 64, 56, 56))
-        weight = relay.var('weight', shape=(64, 64, 3, 3))
-        y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")),
-                            channels=64,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(
+            x,
+            relay.multiply(weight, relay.const(2.0, "float32")),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+        )
         y = relay.add(tvm.relay.const(0, "float32"), y)
         y = relay.nn.relu(y)
         y = relay.Function([x, weight], y)
@@ -139,6 +144,7 @@ def test_legalize_multiple_ops():
 
 def test_legalize_multi_input():
     """Test directly replacing an operator with a new one"""
+
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))
         y = relay.var("y", shape=(1, 64, 56, 20))
@@ -164,7 +170,6 @@ def test_legalize_multi_input():
         func = relay.Function([x, y, z], func)
         return func
 
-
     with TempOpAttr("concatenate", "FTVMLegalize", legalize_concatenate):
         a = before()
         a = run_opt_pass(a, transform.Legalize())
index d490ac7..b3a062f 100644 (file)
@@ -39,13 +39,13 @@ def test_gemm():
     data1 = relay.var("data1", shape=dshape1)
     data2 = relay.var("data2", shape=dshape2)
     gemm = relay.nn.dense(data1, data2)
-    func = relay.Function([data1, data2],
-                            relay.Tuple(tvm.runtime.convert([gemm])))
+    func = relay.Function([data1, data2], relay.Tuple(tvm.runtime.convert([gemm])))
     func = run_opt_pass(func, transform.InferType())
     compute_count = analysis.get_total_mac_number(func)
     expect_count = n * m * k
     assert compute_count == expect_count
 
+
 def test_conv():
     batch_size = 1
     input_channel = 3
@@ -62,52 +62,40 @@ def test_conv():
     weight = relay.var("weight", shape=(output_channel, input_channel, kh, kw))
     data = relay.var("data", shape=dshape)
     conv2d = relay.nn.conv2d(
-        data,
-        weight,
-        channels=output_channel,
-        kernel_size=(kh, kw),
-        padding=(h_padding, w_padding))
+        data, weight, channels=output_channel, kernel_size=(kh, kw), padding=(h_padding, w_padding)
+    )
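+    # Expected MACs below: one multiply-accumulate per output element, per input channel, per kernel tap.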
     func = relay.Function([data, weight], relay.Tuple(tvm.runtime.convert([conv2d])))
     func = run_opt_pass(func, transform.InferType())
     compute_count = analysis.get_total_mac_number(func)
     expect_count = batch_size * input_channel * oh * ow * output_channel * kh * kw
     assert compute_count == expect_count
 
+
 def test_simple_network():
     batch_size = 1
     dshape = (batch_size, 64, 56, 56)
     weight_conv = relay.var("weight_conv", shape=(64, 64, 3, 3))
     data1 = relay.var("data1", shape=dshape)
     data2 = relay.var("data2", shape=dshape)
-    weight_dense = relay.var("weight_dense", shape=(1, 56*56*64))
-
-    conv2d_1 = relay.nn.conv2d(
-        data1,
-        weight_conv,
-        channels=64,
-        kernel_size=(3, 3),
-        padding=(1, 1))
-    conv2d_2 = relay.nn.conv2d(
-        data2,
-        weight_conv,
-        channels=64,
-        kernel_size=(3, 3),
-        padding=(1, 1))
+    weight_dense = relay.var("weight_dense", shape=(1, 56 * 56 * 64))
+
+    conv2d_1 = relay.nn.conv2d(data1, weight_conv, channels=64, kernel_size=(3, 3), padding=(1, 1))
+    conv2d_2 = relay.nn.conv2d(data2, weight_conv, channels=64, kernel_size=(3, 3), padding=(1, 1))
     add = relay.add(conv2d_1, conv2d_2)
     flattened = relay.nn.batch_flatten(add)
-    dense_1 = relay.nn.dense(
-        flattened,
-        weight_dense)
+    dense_1 = relay.nn.dense(flattened, weight_dense)
 
-    func = relay.Function([data1, data2, weight_conv, weight_dense],
-                            relay.Tuple(tvm.runtime.convert([conv2d_1, conv2d_2,
-                                                    dense_1, add, flattened])))
+    func = relay.Function(
+        [data1, data2, weight_conv, weight_dense],
+        relay.Tuple(tvm.runtime.convert([conv2d_1, conv2d_2, dense_1, add, flattened])),
+    )
     # alter the CONV 2D data layout to test
     func = run_opt_pass(func, transform.AlterOpLayout())
     compute_count = analysis.get_total_mac_number(func)
     expect_count = 231411712
     assert compute_count == expect_count
 
+
 def test_depthwise_conv2d():
     batch_size = 1
     dshape = (batch_size, 64, 56, 56)
@@ -115,25 +103,20 @@ def test_depthwise_conv2d():
     data1 = relay.var("data1", shape=dshape)
     data2 = relay.var("data2", shape=dshape)
     depthwise_conv2d_1 = relay.nn.conv2d(
-        data1,
-        weight_conv,
-        kernel_size=(3, 3),
-        padding=(1, 1),
-        groups=64)
+        data1, weight_conv, kernel_size=(3, 3), padding=(1, 1), groups=64
+    )
     depthwise_conv2d_2 = relay.nn.conv2d(
-        data2,
-        weight_conv,
-        kernel_size=(3, 3),
-        padding=(1, 1),
-        groups=64)
+        data2, weight_conv, kernel_size=(3, 3), padding=(1, 1), groups=64
+    )
     add = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
-    func = relay.Function([data1, data2, weight_conv],
-                            relay.Tuple(tvm.runtime.convert([depthwise_conv2d_1,
-                                                    depthwise_conv2d_2,
-                                                    add])))
+    func = relay.Function(
+        [data1, data2, weight_conv],
+        relay.Tuple(tvm.runtime.convert([depthwise_conv2d_1, depthwise_conv2d_2, add])),
+    )
     func = run_opt_pass(func, transform.InferType())
     compute_count = analysis.get_total_mac_number(func)
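+    # Depthwise conv (groups == channels) needs only kh * kw MACs per output element, for two convs here.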
-    assert compute_count == 2 * np.prod(dshape) * 3*3
+    assert compute_count == 2 * np.prod(dshape) * 3 * 3
+
 
 def test_conv_2d_transpose():
     batch_size = 1
@@ -151,18 +134,15 @@ def test_conv_2d_transpose():
     weight = relay.var("weight", shape=(input_channel, output_channel, kh, kw))
     data = relay.var("data", shape=dshape)
     conv2d_transpose = relay.nn.conv2d_transpose(
-        data,
-        weight,
-        channels=output_channel,
-        kernel_size=(kh, kw),
-        padding=(h_padding, w_padding))
-    func = relay.Function([data, weight],
-                            relay.Tuple(tvm.runtime.convert([conv2d_transpose])))
+        data, weight, channels=output_channel, kernel_size=(kh, kw), padding=(h_padding, w_padding)
+    )
+    func = relay.Function([data, weight], relay.Tuple(tvm.runtime.convert([conv2d_transpose])))
     func = run_opt_pass(func, transform.InferType())
     compute_count = analysis.get_total_mac_number(func)
     expect_count = batch_size * input_channel * oh * ow * output_channel * kh * kw
     assert compute_count == expect_count
 
+
 if __name__ == "__main__":
     test_conv()
     test_gemm()
index 65ebf76..c34ca07 100644 (file)
@@ -69,21 +69,18 @@ def update_func(func):
 
         def visit_function(self, fn):
             new_body = self.visit(fn.body)
-            return Function(
-                list(fn.params), new_body, fn.ret_type, fn.type_params,
-                fn.attrs)
+            return Function(list(fn.params), new_body, fn.ret_type, fn.type_params, fn.attrs)
 
     double_value = DoubleValues()
     return double_value.visit(func)
 
 
-class OptTester():
+class OptTester:
     """A helper class for testing the pass manager."""
 
     def __init__(self, mod):
         if not isinstance(mod, tvm.IRModule):
-            raise TypeError("mod is expected to be the type of "
-                            "tvm.IRModule")
+            raise TypeError("mod is expected to be the type of " "tvm.IRModule")
         self.mod = mod
 
     def analysis(self):
@@ -105,7 +102,7 @@ class OptTester():
         raise TypeError("Found not supported node type.")
 
 
-def get_rand(shape, dtype='float32'):
+def get_rand(shape, dtype="float32"):
     return tvm.nd.array(np.random.rand(*shape).astype(dtype))
 
 
@@ -118,7 +115,7 @@ def check_func(func, ref_func):
 @tvm.testing.uses_gpu
 def test_module_pass():
     shape = (5, 10)
-    dtype = 'float32'
+    dtype = "float32"
     tp = relay.TensorType(shape, dtype)
     x = relay.var("x", tp)
     y = relay.var("y", tp)
@@ -145,6 +142,7 @@ def test_module_pass():
     def test_pass_registration_no_decorator():
         def direct_transform(expr, ctx):
             return opt_tester.transform(expr, ctx)
+
         mod_pass = tvm.transform.module_pass(direct_transform, opt_level=3)
         assert isinstance(mod_pass, tvm.transform.ModulePass)
         pass_info = mod_pass.info
@@ -197,6 +195,7 @@ def test_function_class_pass():
     @relay.transform.function_pass(opt_level=1)
     class TestReplaceFunc:
         """Simple test function to replace one argument to another."""
+
         def __init__(self, new_func):
             self.new_func = new_func
 
@@ -218,8 +217,8 @@ def test_function_class_pass():
 
 @tvm.testing.uses_gpu
 def test_function_pass():
-    shape = (10, )
-    dtype = 'float32'
+    shape = (10,)
+    dtype = "float32"
     tp = relay.TensorType(shape, dtype)
     x = relay.var("x", tp)
     v_log = relay.GlobalVar("myLog")
@@ -249,6 +248,7 @@ def test_function_pass():
     def test_pass_registration_no_decorator():
         def direct_transform(expr, ctx):
             return opt_tester.transform(expr, ctx)
+
         mod_pass = _transform.function_pass(direct_transform, opt_level=0)
         assert isinstance(mod_pass, _transform.FunctionPass)
         pass_info = mod_pass.info
@@ -291,6 +291,7 @@ def test_module_class_pass():
     @tvm.transform.module_pass(opt_level=1)
     class TestPipeline:
         """Simple test function to replace one argument to another."""
+
         def __init__(self, new_mod, replace):
             self.new_mod = new_mod
             self.replace = replace
@@ -319,8 +320,8 @@ def test_pass_info():
 
 @tvm.testing.uses_gpu
 def test_sequential_pass():
-    shape = (10, )
-    dtype = 'float32'
+    shape = (10,)
+    dtype = "float32"
     tp = relay.TensorType(shape, dtype)
     x = relay.var("x", tp)
     y = relay.var("y", tp)
@@ -338,9 +339,7 @@ def test_sequential_pass():
         return ref_log
 
     def get_ref_sub():
-        ref_sub = relay.Function([x, y],
-                                 relay.subtract(
-                                     relay.add(x, x), relay.add(y, y)))
+        ref_sub = relay.Function([x, y], relay.subtract(relay.add(x, x), relay.add(y, y)))
         return ref_sub
 
     def get_ref_abs():
@@ -467,6 +466,7 @@ def test_sequential_with_scoping():
     shape = (1, 2, 3)
     c_data = np.array(shape).astype("float32")
     tp = relay.TensorType(shape, "float32")
+
     def before():
         c = relay.const(c_data)
         x = relay.var("x", tp)
@@ -486,12 +486,14 @@ def test_sequential_with_scoping():
         z1 = relay.add(z, z)
         return relay.Function([x], z1)
 
-    seq = tvm.transform.Sequential([
-        relay.transform.InferType(),
-        relay.transform.FoldConstant(),
-        relay.transform.EliminateCommonSubexpr(),
-        relay.transform.AlterOpLayout()
-    ])
+    seq = tvm.transform.Sequential(
+        [
+            relay.transform.InferType(),
+            relay.transform.FoldConstant(),
+            relay.transform.EliminateCommonSubexpr(),
+            relay.transform.AlterOpLayout(),
+        ]
+    )
 
     mod = tvm.IRModule({"main": before()})
     with tvm.transform.PassContext(opt_level=3):
@@ -511,12 +513,14 @@ def test_print_ir(capfd):
     y = relay.multiply(y, relay.const(2, "float32"))
     func = relay.Function([x], y)
 
-    seq = tvm.transform.Sequential([
-        relay.transform.InferType(),
-        relay.transform.FoldConstant(),
-        tvm.transform.PrintIR(),
-        relay.transform.DeadCodeElimination()
-    ])
+    seq = tvm.transform.Sequential(
+        [
+            relay.transform.InferType(),
+            relay.transform.FoldConstant(),
+            tvm.transform.PrintIR(),
+            relay.transform.DeadCodeElimination(),
+        ]
+    )
 
     mod = tvm.IRModule({"main": func})
     with tvm.transform.PassContext(opt_level=3):
@@ -527,8 +531,10 @@ def test_print_ir(capfd):
     assert "PrintIR" in out
     assert "multiply" in out
 
+
 __TRACE_COUNTER__ = 0
 
+
 def _tracer(module, info, is_before):
     global __TRACE_COUNTER__
     if bool(is_before):
@@ -544,11 +550,13 @@ def test_print_debug_callback():
     y = relay.multiply(y, relay.const(2, "float32"))
     func = relay.Function([x], y)
 
-    seq = tvm.transform.Sequential([
-        relay.transform.InferType(),
-        relay.transform.FoldConstant(),
-        relay.transform.DeadCodeElimination()
-    ])
+    seq = tvm.transform.Sequential(
+        [
+            relay.transform.InferType(),
+            relay.transform.FoldConstant(),
+            relay.transform.DeadCodeElimination(),
+        ]
+    )
 
     assert __TRACE_COUNTER__ == 0
     mod = tvm.IRModule({"main": func})
index 7d7db35..5dee052 100644 (file)
@@ -38,8 +38,9 @@ def test_diamond_graph_fanouts():
     Note that we can't just merge the three supported operators together,
     otherwise both subgraphs would depend on the other.
     """
+
     def diamond_graph_fanouts():
-        data = relay.var('data', shape=(10, 10))
+        data = relay.var("data", shape=(10, 10))
         cb_1 = compiler_begin(data, "test")
         O_1 = relay.abs(cb_1)
         ce_1 = compiler_end(O_1, "test")
@@ -49,7 +50,6 @@ def test_diamond_graph_fanouts():
         O_2 = relay.nn.relu(cb_2)
         ce_3 = compiler_end(O_2, "test")
 
-
         X = relay.tanh(cb_3)
         ce_4 = compiler_end(X, "default")
 
@@ -62,7 +62,7 @@ def test_diamond_graph_fanouts():
         return diamond
 
     def expected():
-        data = relay.var('data', shape=(10, 10))
+        data = relay.var("data", shape=(10, 10))
         cb_1 = compiler_begin(data, "test")
         O_1 = relay.abs(cb_1)
         ce_2 = compiler_end(O_1, "test")
@@ -92,17 +92,18 @@ def test_example_graph():
     See the RFC here: https://discuss.tvm.ai/t/relay-improved-graph-partitioning-algorithm/5830
     Blue nodes are adds (target: test), red nodes are subtracts (target: default).
     """
+
     def annotated():
-        in_1 = relay.var('in_1', shape=(10, 10), dtype='float32')
-        in_2 = relay.var('in_2', shape=(10, 10), dtype='float32')
-        in_3 = relay.var('in_3', shape=(10, 10), dtype='float32')
-        in_4 = relay.var('in_4', shape=(10, 10), dtype='float32')
-        in_5 = relay.var('in_5', shape=(10, 10), dtype='float32')
-        in_6 = relay.var('in_6', shape=(10, 10), dtype='float32')
-        in_7 = relay.var('in_7', shape=(10, 10), dtype='float32')
-        in_8 = relay.var('in_8', shape=(10, 10), dtype='float32')
-        in_9 = relay.var('in_9', shape=(10, 10), dtype='float32')
-        in_10 = relay.var('in_10', shape=(10, 10), dtype='float32')
+        in_1 = relay.var("in_1", shape=(10, 10), dtype="float32")
+        in_2 = relay.var("in_2", shape=(10, 10), dtype="float32")
+        in_3 = relay.var("in_3", shape=(10, 10), dtype="float32")
+        in_4 = relay.var("in_4", shape=(10, 10), dtype="float32")
+        in_5 = relay.var("in_5", shape=(10, 10), dtype="float32")
+        in_6 = relay.var("in_6", shape=(10, 10), dtype="float32")
+        in_7 = relay.var("in_7", shape=(10, 10), dtype="float32")
+        in_8 = relay.var("in_8", shape=(10, 10), dtype="float32")
+        in_9 = relay.var("in_9", shape=(10, 10), dtype="float32")
+        in_10 = relay.var("in_10", shape=(10, 10), dtype="float32")
 
         begin0 = compiler_begin(in_1, "test")
         begin1 = compiler_begin(in_2, "test")
@@ -154,16 +155,16 @@ def test_example_graph():
         return mod
 
     def expected():
-        in_1 = relay.var('in_1', shape=(10, 10), dtype='float32')
-        in_2 = relay.var('in_2', shape=(10, 10), dtype='float32')
-        in_3 = relay.var('in_3', shape=(10, 10), dtype='float32')
-        in_4 = relay.var('in_4', shape=(10, 10), dtype='float32')
-        in_5 = relay.var('in_5', shape=(10, 10), dtype='float32')
-        in_6 = relay.var('in_6', shape=(10, 10), dtype='float32')
-        in_7 = relay.var('in_7', shape=(10, 10), dtype='float32')
-        in_8 = relay.var('in_8', shape=(10, 10), dtype='float32')
-        in_9 = relay.var('in_9', shape=(10, 10), dtype='float32')
-        in_10 = relay.var('in_10', shape=(10, 10), dtype='float32')
+        in_1 = relay.var("in_1", shape=(10, 10), dtype="float32")
+        in_2 = relay.var("in_2", shape=(10, 10), dtype="float32")
+        in_3 = relay.var("in_3", shape=(10, 10), dtype="float32")
+        in_4 = relay.var("in_4", shape=(10, 10), dtype="float32")
+        in_5 = relay.var("in_5", shape=(10, 10), dtype="float32")
+        in_6 = relay.var("in_6", shape=(10, 10), dtype="float32")
+        in_7 = relay.var("in_7", shape=(10, 10), dtype="float32")
+        in_8 = relay.var("in_8", shape=(10, 10), dtype="float32")
+        in_9 = relay.var("in_9", shape=(10, 10), dtype="float32")
+        in_10 = relay.var("in_10", shape=(10, 10), dtype="float32")
 
         begin0 = compiler_begin(in_1, "test")
         begin1 = compiler_begin(in_2, "test")
index aef6ab5..1ec8f97 100644 (file)
@@ -68,10 +68,10 @@ codegen function.
 def make_add_sub_mul_pattern():
     r"""Create a pattern to match the following graph.
 
-        add  sub
-         \   /
-          \ /
-          mul
+    add  sub
+     \   /
+      \ /
+      mul
     """
     x = wildcard()
     y = wildcard()
@@ -81,48 +81,48 @@ def make_add_sub_mul_pattern():
 def make_add_relu_pattern():
     r"""Create a pattern to match the following graph.
 
-        add
-         |
-       relu
+     add
+      |
+    relu
     """
     add_node = wildcard() + wildcard()
-    r = is_op('nn.relu')(add_node)
+    r = is_op("nn.relu")(add_node)
     return r
 
 
 def make_conv_bias_relu_pattern():
     r"""Create a pattern to match the following graph.
 
-       conv2d
-         |
-      bias_add
-         |
-       relu
+     conv2d
+       |
+    bias_add
+       |
+     relu
     """
     x = wildcard()
     y = wildcard()
     z = wildcard()
-    conv_node = is_op('nn.conv2d')(x, y)
-    bias_node = is_op('nn.bias_add')(conv_node, z)
-    r = is_op('nn.relu')(bias_node)
+    conv_node = is_op("nn.conv2d")(x, y)
+    bias_node = is_op("nn.bias_add")(conv_node, z)
+    r = is_op("nn.relu")(bias_node)
     return r
 
 
 def make_pattern_with_optional():
     r"""Create a pattern to match the following graph. Note that relu is optinal.
 
-       conv2d
-         |
-      bias_add
-         |
-       (relu)
+     conv2d
+       |
+    bias_add
+       |
+     (relu)
     """
     x = wildcard()
     y = wildcard()
     z = wildcard()
-    conv_node = is_op('nn.conv2d')(x, y)
-    bias_node = is_op('nn.bias_add')(conv_node, z)
-    r = bias_node.optional(lambda x: is_op('nn.relu')(x))
+    conv_node = is_op("nn.conv2d")(x, y)
+    bias_node = is_op("nn.bias_add")(conv_node, z)
+    r = bias_node.optional(lambda x: is_op("nn.relu")(x))
     return r
 
 
@@ -140,11 +140,12 @@ def make_add_add_add_pattern():
     """
     x = wildcard()
     y = wildcard()
-    add_node = is_op('add')(x, y)
-    add_node_1 = is_op('add')(x, add_node)
-    r = is_op('add')(add_node_1, add_node)
+    add_node = is_op("add")(x, y)
+    add_node_1 = is_op("add")(x, add_node)
+    r = is_op("add")(add_node_1, add_node)
     return r
 
+
 def make_bn_relu_pattern():
     r"""Create a pattern to match the following graph.
 
@@ -159,19 +160,25 @@ def make_bn_relu_pattern():
     beta = wildcard()
     moving_mean = wildcard()
     moving_var = wildcard()
-    bn_node = is_op('nn.batch_norm')(x, gamma, beta, moving_mean, moving_var)
+    bn_node = is_op("nn.batch_norm")(x, gamma, beta, moving_mean, moving_var)
     tuple_get_item_node = TupleGetItemPattern(bn_node, 0)
-    r = is_op('nn.relu')(tuple_get_item_node)
+    r = is_op("nn.relu")(tuple_get_item_node)
     return r
 
+
 def check_result(pattern_table, graph, expected_graph, import_prelude=False):
     """Utility function to check merge composite results."""
-    result = run_opt_pass(graph, relay.transform.MergeComposite(pattern_table), import_prelude=import_prelude)
-    assert not relay.analysis.free_vars(result), \
-        "Found free vars in the result graph: {0}".format(str(result))
+    result = run_opt_pass(
+        graph, relay.transform.MergeComposite(pattern_table), import_prelude=import_prelude
+    )
+    assert not relay.analysis.free_vars(result), "Found free vars in the result graph: {0}".format(
+        str(result)
+    )
     expected = run_opt_pass(expected_graph, relay.transform.InferType())
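+    # map_free_vars=True lets structurally equal graphs match even when their free variables differ.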
-    assert tvm.ir.structural_equal(result, expected, map_free_vars=True), \
-        "Graph mismatch: output vs. expected\n{0}\n=====\n{1}".format(str(result), str(expected))
+    assert tvm.ir.structural_equal(
+        result, expected, map_free_vars=True
+    ), "Graph mismatch: output vs. expected\n{0}\n=====\n{1}".format(str(result), str(expected))
+
 
 def test_simple_merge():
     r"""Test composite function is correctly produced from simple graph.
@@ -186,24 +193,22 @@ def test_simple_merge():
        relu
 
     """
-    pattern_table = [
-        ("add_relu", make_add_relu_pattern())
-    ]
+    pattern_table = [("add_relu", make_add_relu_pattern())]
 
     def before():
-        a = relay.var('a', shape=(10, 10))
-        b = relay.var('b', shape=(10, 10))
+        a = relay.var("a", shape=(10, 10))
+        b = relay.var("b", shape=(10, 10))
         add_node = relay.add(a, b)
         r = relay.nn.relu(add_node)
         return relay.Function([a, b], r)
 
     def expected():
-        a = relay.var('a', shape=(10, 10))
-        b = relay.var('b', shape=(10, 10))
+        a = relay.var("a", shape=(10, 10))
+        b = relay.var("b", shape=(10, 10))
 
         # add_relu function
-        in_1 = relay.var('in_1', shape=(10, 10))
-        in_2 = relay.var('in_2', shape=(10, 10))
+        in_1 = relay.var("in_1", shape=(10, 10))
+        in_2 = relay.var("in_2", shape=(10, 10))
         add_node = relay.add(in_1, in_2)
         relu_node = relay.nn.relu(add_node)
         add_relu = relay.Function([in_1, in_2], relu_node)
@@ -241,14 +246,12 @@ def test_branch_merge():
         relu
     """
 
-    pattern_table = [
-        ("add_sub_mul", make_add_sub_mul_pattern())
-    ]
+    pattern_table = [("add_sub_mul", make_add_sub_mul_pattern())]
 
     def before():
-        a = relay.var('a', shape=(10, 10))
-        b = relay.var('b', shape=(10, 10))
-        c = relay.var('c', shape=(10, 10))
+        a = relay.var("a", shape=(10, 10))
+        b = relay.var("b", shape=(10, 10))
+        c = relay.var("c", shape=(10, 10))
         add_node = relay.add(a, b)
         sub_node = relay.subtract(a, b)
         mul_node = relay.multiply(add_node, sub_node)
@@ -259,13 +262,13 @@ def test_branch_merge():
         return relay.Function([a, b, c], r)
 
     def expected():
-        a = relay.var('a', shape=(10, 10))
-        b = relay.var('b', shape=(10, 10))
-        c = relay.var('c', shape=(10, 10))
+        a = relay.var("a", shape=(10, 10))
+        b = relay.var("b", shape=(10, 10))
+        c = relay.var("c", shape=(10, 10))
 
         # add_sub_mul function
-        in_1 = relay.var('in_1', shape=(10, 10))
-        in_2 = relay.var('in_2', shape=(10, 10))
+        in_1 = relay.var("in_1", shape=(10, 10))
+        in_2 = relay.var("in_2", shape=(10, 10))
         add_node = relay.add(in_1, in_2)
         sub_node = relay.subtract(in_1, in_2)
         mul_node = relay.multiply(add_node, sub_node)
@@ -274,8 +277,8 @@ def test_branch_merge():
         add_sub_mul = add_sub_mul.with_attr("PartitionedFromPattern", "add_subtract_multiply_")
 
         # add_sub_mul1 function
-        in_3 = relay.var('in_3', shape=(10, 10))
-        in_4 = relay.var('in_4', shape=(10, 10))
+        in_3 = relay.var("in_3", shape=(10, 10))
+        in_4 = relay.var("in_4", shape=(10, 10))
         add_node_1 = relay.add(in_3, in_4)
         sub_node_1 = relay.subtract(in_3, in_4)
         mul_node_1 = relay.multiply(add_node_1, sub_node_1)
@@ -310,13 +313,11 @@ def test_reuse_call_merge():
           add
 
     """
-    pattern_table = [
-        ("add_add_add", make_add_add_add_pattern())
-    ]
+    pattern_table = [("add_add_add", make_add_add_add_pattern())]
 
     def before():
-        a = relay.var('a', shape=(10, 10))
-        b = relay.var('b', shape=(10, 10))
+        a = relay.var("a", shape=(10, 10))
+        b = relay.var("b", shape=(10, 10))
         sub_node = relay.subtract(a, b)
 
         # pattern
@@ -327,12 +328,12 @@ def test_reuse_call_merge():
         return relay.Function([a, b], r)
 
     def expected():
-        a = relay.var('a', shape=(10, 10))
-        b = relay.var('b', shape=(10, 10))
+        a = relay.var("a", shape=(10, 10))
+        b = relay.var("b", shape=(10, 10))
 
         # add_relu_add function
-        in_1 = relay.var('in_1', shape=(10, 10))
-        in_2 = relay.var('in_2', shape=(10, 10))
+        in_1 = relay.var("in_1", shape=(10, 10))
+        in_2 = relay.var("in_2", shape=(10, 10))
         add_node = relay.add(in_1, in_2)
         add_node_1 = relay.add(in_1, add_node)
         add_node_2 = relay.add(add_node_1, add_node)
@@ -374,21 +375,19 @@ def test_multiple_patterns():
     """
     pattern_table = [
         ("conv2d_bias_relu", make_conv_bias_relu_pattern()),
-        ("add_relu", make_add_relu_pattern())
+        ("add_relu", make_add_relu_pattern()),
     ]
 
     def before():
-        data = relay.var('data', shape=(1, 512, 28, 28))
-        kernel = relay.var('kernel', shape=(256, 512, 1, 1))
-        bias = relay.var('bias', shape=(256,))
-        a = relay.var('a', shape=(1, 256, 28, 28))
-        b = relay.var('b', shape=(1, 256, 28, 28))
-
-        conv_node = relay.nn.conv2d(data,
-                                    kernel,
-                                    kernel_size=(1, 1),
-                                    padding=(0, 0),
-                                    strides=(1, 1))
+        data = relay.var("data", shape=(1, 512, 28, 28))
+        kernel = relay.var("kernel", shape=(256, 512, 1, 1))
+        bias = relay.var("bias", shape=(256,))
+        a = relay.var("a", shape=(1, 256, 28, 28))
+        b = relay.var("b", shape=(1, 256, 28, 28))
+
+        conv_node = relay.nn.conv2d(
+            data, kernel, kernel_size=(1, 1), padding=(0, 0), strides=(1, 1)
+        )
 
         bias_node = relay.nn.bias_add(conv_node, bias)
         relu_node = relay.nn.relu(bias_node)
@@ -398,34 +397,30 @@ def test_multiple_patterns():
         return relay.Function([data, kernel, bias, a, b], r)
 
     def expected():
-        data = relay.var('data', shape=(1, 512, 28, 28))
-        kernel = relay.var('kernel', shape=(256, 512, 1, 1))
-        bias = relay.var('bias', shape=(256,))
-        a = relay.var('a', shape=(1, 256, 28, 28))
-        b = relay.var('b', shape=(1, 256, 28, 28))
+        data = relay.var("data", shape=(1, 512, 28, 28))
+        kernel = relay.var("kernel", shape=(256, 512, 1, 1))
+        bias = relay.var("bias", shape=(256,))
+        a = relay.var("a", shape=(1, 256, 28, 28))
+        b = relay.var("b", shape=(1, 256, 28, 28))
 
         # conv_bias_relu function
-        in_1 = relay.var('in_1', shape=(1, 512, 28, 28))
-        in_2 = relay.var('in_2', shape=(256, 512, 1, 1))
-        in_3 = relay.var('in_3', shape=(256,))
+        in_1 = relay.var("in_1", shape=(1, 512, 28, 28))
+        in_2 = relay.var("in_2", shape=(256, 512, 1, 1))
+        in_3 = relay.var("in_3", shape=(256,))
 
-        conv_node = relay.nn.conv2d(in_1,
-                                    in_2,
-                                    kernel_size=(1, 1),
-                                    padding=(0, 0),
-                                    strides=(1, 1))
+        conv_node = relay.nn.conv2d(in_1, in_2, kernel_size=(1, 1), padding=(0, 0), strides=(1, 1))
 
         bias_node = relay.nn.bias_add(conv_node, in_3)
         r = relay.nn.relu(bias_node)
         conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
-        conv_bias_add_relu = conv_bias_add_relu.with_attr("Composite",
-                                                          "conv2d_bias_relu")
-        conv_bias_add_relu = conv_bias_add_relu.with_attr("PartitionedFromPattern",
-                                                          "nn.conv2d_nn.bias_add_nn.relu_")
+        conv_bias_add_relu = conv_bias_add_relu.with_attr("Composite", "conv2d_bias_relu")
+        conv_bias_add_relu = conv_bias_add_relu.with_attr(
+            "PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_"
+        )
 
         # add_relu function
-        in_4 = relay.var('in_4', shape=(1, 256, 28, 28))
-        in_5 = relay.var('in_5', shape=(1, 256, 28, 28))
+        in_4 = relay.var("in_4", shape=(1, 256, 28, 28))
+        in_5 = relay.var("in_5", shape=(1, 256, 28, 28))
         add_node = relay.add(in_4, in_5)
         r = relay.nn.relu(add_node)
         add_relu = relay.Function([in_4, in_5], r)
@@ -462,11 +457,11 @@ def test_optional_pattern():
     pattern_table = [("layer", make_pattern_with_optional())]
 
     def before():
-        x = relay.var('x', shape=(1, 3, 7, 7))
-        w1 = relay.var('w', shape=(3, 3, 1, 1))
-        b1 = relay.var('b', shape=(3, ))
-        w2 = relay.var('w', shape=(3, 3, 1, 1))
-        b2 = relay.var('b', shape=(3, ))
+        x = relay.var("x", shape=(1, 3, 7, 7))
+        w1 = relay.var("w", shape=(3, 3, 1, 1))
+        b1 = relay.var("b", shape=(3,))
+        w2 = relay.var("w", shape=(3, 3, 1, 1))
+        b2 = relay.var("b", shape=(3,))
         conv = relay.nn.conv2d(x, w1, kernel_size=(1, 1))
         bias = relay.nn.bias_add(conv, b1)
         relu = relay.nn.relu(bias)
@@ -476,9 +471,9 @@ def test_optional_pattern():
 
     def expected():
         # Matched composite function A
-        x = relay.var('x')
-        w = relay.var('w')
-        b = relay.var('b')
+        x = relay.var("x")
+        w = relay.var("w")
+        b = relay.var("b")
         conv = relay.nn.conv2d(x, w, kernel_size=(1, 1))
         bias = relay.nn.bias_add(conv, b)
         relu = relay.nn.relu(bias)
@@ -487,9 +482,9 @@ def test_optional_pattern():
         func1 = func1.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_")
 
         # Matched composite function B
-        x = relay.var('x')
-        w = relay.var('w')
-        b = relay.var('b')
+        x = relay.var("x")
+        w = relay.var("w")
+        b = relay.var("b")
         conv = relay.nn.conv2d(x, w, kernel_size=(1, 1))
         bias = relay.nn.bias_add(conv, b)
         func2 = relay.Function([x, w, b], bias)
@@ -497,11 +492,11 @@ def test_optional_pattern():
         func2 = func2.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_")
 
         # Main function
-        x = relay.var('x', shape=(1, 3, 7, 7))
-        w1 = relay.var('w', shape=(3, 3, 1, 1))
-        b1 = relay.var('b', shape=(3, ))
-        w2 = relay.var('w', shape=(3, 3, 1, 1))
-        b2 = relay.var('b', shape=(3, ))
+        x = relay.var("x", shape=(1, 3, 7, 7))
+        w1 = relay.var("w", shape=(3, 3, 1, 1))
+        b1 = relay.var("b", shape=(3,))
+        w2 = relay.var("w", shape=(3, 3, 1, 1))
+        b2 = relay.var("b", shape=(3,))
         out1 = func1(x, w1, b1)
         out2 = func2(out1, w2, b2)
         return relay.Function([x, w1, w2, b1, b2], out2)
@@ -529,69 +524,69 @@ def test_merge_order():
     def pattern_A():
         x = wildcard()
         y = wildcard()
-        out = is_op('add')(x, y)
-        out = is_op('abs')(out)
-        out = is_op('nn.relu')(out)
+        out = is_op("add")(x, y)
+        out = is_op("abs")(out)
+        out = is_op("nn.relu")(out)
         return out
 
     def pattern_B():
         x = wildcard()
         y = wildcard()
-        out = is_op('add')(x, y)
-        out = is_op('abs')(out)
+        out = is_op("add")(x, y)
+        out = is_op("abs")(out)
         return out
 
     def pattern_C():
         x = wildcard()
-        out = is_op('abs')(x)
-        out = is_op('nn.relu')(out)
+        out = is_op("abs")(x)
+        out = is_op("nn.relu")(out)
         return out
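+    # Patterns are matched in the order they appear in the pattern table, so whichever of A/B/C is
+    # listed first should take priority; the after_*_priority graphs below encode each outcome.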
 
     def before():
-        input_1 = relay.var('input_1', shape=(10, 10))
-        input_2 = relay.var('input_2', shape=(10, 10))
+        input_1 = relay.var("input_1", shape=(10, 10))
+        input_2 = relay.var("input_2", shape=(10, 10))
         out = relay.add(input_1, input_2)
         out = relay.abs(out)
         out = relay.nn.relu(out)
         return relay.Function([input_1, input_2], out)
 
     def after_A_priority():
-        input_1 = relay.var('input_1', shape=(10, 10))
-        input_2 = relay.var('input_2', shape=(10, 10))
-        x = relay.var('x')
-        y = relay.var('y')
+        input_1 = relay.var("input_1", shape=(10, 10))
+        input_2 = relay.var("input_2", shape=(10, 10))
+        x = relay.var("x")
+        y = relay.var("y")
         out = relay.add(x, y)
         out = relay.abs(out)
         out = relay.nn.relu(out)
         merged_func = relay.Function([x, y], out)
-        merged_func = merged_func.with_attr('Composite', 'A')
-        merged_func = merged_func.with_attr('PartitionedFromPattern', 'add_abs_nn.relu_')
+        merged_func = merged_func.with_attr("Composite", "A")
+        merged_func = merged_func.with_attr("PartitionedFromPattern", "add_abs_nn.relu_")
         ret = relay.Call(merged_func, [input_1, input_2])
         return relay.Function([input_1, input_2], ret)
 
     def after_B_priority():
-        input_1 = relay.var('input_1', shape=(10, 10))
-        input_2 = relay.var('input_2', shape=(10, 10))
-        x = relay.var('x')
-        y = relay.var('y')
+        input_1 = relay.var("input_1", shape=(10, 10))
+        input_2 = relay.var("input_2", shape=(10, 10))
+        x = relay.var("x")
+        y = relay.var("y")
         out = relay.add(x, y)
         out = relay.abs(out)
         merged_func = relay.Function([x, y], out)
-        merged_func = merged_func.with_attr('Composite', 'B')
-        merged_func = merged_func.with_attr('PartitionedFromPattern', 'add_abs_')
+        merged_func = merged_func.with_attr("Composite", "B")
+        merged_func = merged_func.with_attr("PartitionedFromPattern", "add_abs_")
         out = relay.Call(merged_func, [input_1, input_2])
         ret = relay.nn.relu(out)
         return relay.Function([input_1, input_2], ret)
 
     def after_C_priority():
-        input_1 = relay.var('input_1', shape=(10, 10))
-        input_2 = relay.var('input_2', shape=(10, 10))
-        x = relay.var('x')
+        input_1 = relay.var("input_1", shape=(10, 10))
+        input_2 = relay.var("input_2", shape=(10, 10))
+        x = relay.var("x")
         out = relay.abs(x)
         out = relay.nn.relu(out)
         merged_func = relay.Function([x], out)
-        merged_func = merged_func.with_attr('Composite', 'C')
-        merged_func = merged_func.with_attr('PartitionedFromPattern', 'abs_nn.relu_')
+        merged_func = merged_func.with_attr("Composite", "C")
+        merged_func = merged_func.with_attr("PartitionedFromPattern", "abs_nn.relu_")
         out = relay.add(input_1, input_2)
         ret = relay.Call(merged_func, [out])
         return relay.Function([input_1, input_2], ret)
@@ -630,8 +625,8 @@ def test_parallel_merge():
     consume the same input variables, input_1 and input_2."""
 
     def before():
-        input_1 = relay.var('input_1', shape=(10, 10))
-        input_2 = relay.var('input_2', shape=(10, 10))
+        input_1 = relay.var("input_1", shape=(10, 10))
+        input_2 = relay.var("input_2", shape=(10, 10))
         branch_1_add = relay.add(input_1, input_2)
         branch_1_sub = relay.subtract(input_1, input_2)
         branch_1 = relay.multiply(branch_1_add, branch_1_sub)
@@ -642,28 +637,26 @@ def test_parallel_merge():
         return relay.Function([input_1, input_2], out)
 
     def expected():
-        input_1 = relay.var('input_1', shape=(10, 10))
-        input_2 = relay.var('input_2', shape=(10, 10))
-        x = relay.var('x')
-        y = relay.var('y')
+        input_1 = relay.var("input_1", shape=(10, 10))
+        input_2 = relay.var("input_2", shape=(10, 10))
+        x = relay.var("x")
+        y = relay.var("y")
         branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y))
         func_1 = relay.Function([x, y], branch_1)
-        func_1 = func_1.with_attr('Composite', "add_sub_mul")
-        func_1 = func_1.with_attr('PartitionedFromPattern', "add_subtract_multiply_")
+        func_1 = func_1.with_attr("Composite", "add_sub_mul")
+        func_1 = func_1.with_attr("PartitionedFromPattern", "add_subtract_multiply_")
         call_1 = relay.Call(func_1, [input_1, input_2])
-        x1 = relay.var('x1')
-        y1 = relay.var('y1')
+        x1 = relay.var("x1")
+        y1 = relay.var("y1")
         branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1))
         func_2 = relay.Function([x1, y1], branch_2)
-        func_2 = func_2.with_attr('Composite', "add_sub_mul")
-        func_2 = func_2.with_attr('PartitionedFromPattern', "add_subtract_multiply_")
+        func_2 = func_2.with_attr("Composite", "add_sub_mul")
+        func_2 = func_2.with_attr("PartitionedFromPattern", "add_subtract_multiply_")
         call_2 = relay.Call(func_2, [input_1, input_2])
         out = relay.multiply(call_1, call_2)
         return relay.Function([input_1, input_2], out)
 
-    pattern_table = [
-        ("add_sub_mul", make_add_sub_mul_pattern())
-    ]
+    pattern_table = [("add_sub_mul", make_add_sub_mul_pattern())]
     check_result(pattern_table, before(), expected())
 
 
@@ -707,7 +700,7 @@ def test_multiple_input_subgraphs():
 
     def before():
         before_funcs = {}
-        inputs = [relay.var('input_' + str(i), shape=(10, 10)) for i in range(8)]
+        inputs = [relay.var("input_" + str(i), shape=(10, 10)) for i in range(8)]
         add_relu_1 = relay.add(inputs[0], inputs[1])
         add_relu_1 = relay.nn.relu(add_relu_1)
         add_relu_2 = relay.add(inputs[2], inputs[3])
@@ -719,53 +712,53 @@ def test_multiple_input_subgraphs():
         add = relay.add(add_relu_1, add_relu_2)
         sub = relay.subtract(add_relu_3, add_relu_4)
         out = relay.multiply(add, sub)
-        before_funcs['B'] = relay.Function(inputs, out)
+        before_funcs["B"] = relay.Function(inputs, out)
         sub = relay.subtract(add_relu_1, add_relu_2)
         out = relay.multiply(add, sub)
-        before_funcs['A'] = relay.Function(inputs[:4], out)
+        before_funcs["A"] = relay.Function(inputs[:4], out)
         return before_funcs
 
     def after_A():
-        inputs = [relay.var('input_' + str(i), shape=(10, 10)) for i in range(4)]
-        x = relay.var('x')
-        y = relay.var('y')
+        inputs = [relay.var("input_" + str(i), shape=(10, 10)) for i in range(4)]
+        x = relay.var("x")
+        y = relay.var("y")
         add_relu_1 = relay.add(x, y)
         add_relu_1 = relay.nn.relu(add_relu_1)
         add_relu_1 = relay.Function([x, y], add_relu_1)
-        add_relu_1 = add_relu_1.with_attr('Composite', 'add_relu')
-        add_relu_1 = add_relu_1.with_attr('PartitionedFromPattern', 'add_nn.relu_')
+        add_relu_1 = add_relu_1.with_attr("Composite", "add_relu")
+        add_relu_1 = add_relu_1.with_attr("PartitionedFromPattern", "add_nn.relu_")
         add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]])
-        x1 = relay.var('x1')
-        y1 = relay.var('y1')
+        x1 = relay.var("x1")
+        y1 = relay.var("y1")
         add_relu_2 = relay.add(x1, y1)
         add_relu_2 = relay.nn.relu(add_relu_2)
         add_relu_2 = relay.Function([x1, y1], add_relu_2)
-        add_relu_2 = add_relu_2.with_attr('Composite', 'add_relu')
-        add_relu_2 = add_relu_2.with_attr('PartitionedFromPattern', 'add_nn.relu_')
+        add_relu_2 = add_relu_2.with_attr("Composite", "add_relu")
+        add_relu_2 = add_relu_2.with_attr("PartitionedFromPattern", "add_nn.relu_")
         add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]])
-        x2 = relay.var('x2')
-        y2 = relay.var('y2')
+        x2 = relay.var("x2")
+        y2 = relay.var("y2")
         add = relay.add(x2, y2)
         sub = relay.subtract(x2, y2)
         add_sub_mul = relay.multiply(add, sub)
         add_sub_mul = relay.Function([x2, y2], add_sub_mul)
-        add_sub_mul = add_sub_mul.with_attr('Composite', 'add_sub_mul')
-        add_sub_mul = add_sub_mul.with_attr('PartitionedFromPattern', 'add_subtract_multiply_')
+        add_sub_mul = add_sub_mul.with_attr("Composite", "add_sub_mul")
+        add_sub_mul = add_sub_mul.with_attr("PartitionedFromPattern", "add_subtract_multiply_")
         add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2])
         return relay.Function(inputs, add_sub_mul_call)
 
     def after_B():
-        inputs = [relay.var('input_' + str(i), shape=(10, 10)) for i in range(8)]
+        inputs = [relay.var("input_" + str(i), shape=(10, 10)) for i in range(8)]
         add_relu_calls = []
         for i in range(4):
-            x = relay.var('x' + str(i))
-            y = relay.var('x' + str(i))
+            x = relay.var("x" + str(i))
+            y = relay.var("x" + str(i))
             add_relu = relay.add(x, y)
             add_relu = relay.nn.relu(add_relu)
             add_relu = relay.Function([x, y], add_relu)
-            add_relu = add_relu.with_attr('Composite', 'add_relu')
-            add_relu = add_relu.with_attr('PartitionedFromPattern', 'add_nn.relu_')
-            add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]])
+            add_relu = add_relu.with_attr("Composite", "add_relu")
+            add_relu = add_relu.with_attr("PartitionedFromPattern", "add_nn.relu_")
+            add_relu_call = relay.Call(add_relu, [inputs[i * 2], inputs[i * 2 + 1]])
             add_relu_calls.append(add_relu_call)
 
         add = relay.add(add_relu_calls[0], add_relu_calls[1])
@@ -775,20 +768,18 @@ def test_multiple_input_subgraphs():
 
     pattern_table = [
         ("add_sub_mul", make_add_sub_mul_pattern()),
-        ("add_relu", make_add_relu_pattern())
+        ("add_relu", make_add_relu_pattern()),
     ]
-    check_result(pattern_table, before()['A'], after_A())
-    check_result(pattern_table, before()['B'], after_B())
+    check_result(pattern_table, before()["A"], after_A())
+    check_result(pattern_table, before()["B"], after_B())
 
 
 def test_tuple_get_item_merge():
     """Test composite function can be merged from pattern containing TupleGetItem nodes."""
-    pattern_table = [
-        ("bn_relu", make_bn_relu_pattern())
-    ]
+    pattern_table = [("bn_relu", make_bn_relu_pattern())]
 
     def before():
-        x = relay.var('x', shape=(1, 8))
+        x = relay.var("x", shape=(1, 8))
         gamma = relay.var("gamma", shape=(8,))
         beta = relay.var("beta", shape=(8,))
         moving_mean = relay.var("moving_mean", shape=(8,))
@@ -799,25 +790,26 @@ def test_tuple_get_item_merge():
         return relay.Function([x, gamma, beta, moving_mean, moving_var], r)
 
     def expected():
-        x = relay.var('x', shape=(1, 8))
+        x = relay.var("x", shape=(1, 8))
         beta = relay.var("beta", shape=(8,))
         gamma = relay.var("gamma", shape=(8,))
         moving_mean = relay.var("moving_mean", shape=(8,))
         moving_var = relay.var("moving_var", shape=(8,))
 
         # bn_relu function
-        in_1 = relay.var('x1', shape=(1, 8))
-        in_2 = relay.var('gamma1', shape=(8,))
-        in_3 = relay.var('beta1', shape=(8,))
-        in_4 = relay.var('moving_mean1', shape=(8,))
-        in_5 = relay.var('moving_var1', shape=(8,))
+        in_1 = relay.var("x1", shape=(1, 8))
+        in_2 = relay.var("gamma1", shape=(8,))
+        in_3 = relay.var("beta1", shape=(8,))
+        in_4 = relay.var("moving_mean1", shape=(8,))
+        in_5 = relay.var("moving_var1", shape=(8,))
         bn_node = relay.nn.batch_norm(in_1, in_2, in_3, in_4, in_5)
         tuple_get_item_node = bn_node[0]
         relu_node = relay.nn.relu(tuple_get_item_node)
         bn_relu = relay.Function([in_1, in_2, in_3, in_4, in_5], relu_node)
         bn_relu = bn_relu.with_attr("Composite", "bn_relu")
-        bn_relu = bn_relu.with_attr("PartitionedFromPattern",
-                                    "nn.batch_norm_TupleGetItem0_nn.relu_")
+        bn_relu = bn_relu.with_attr(
+            "PartitionedFromPattern", "nn.batch_norm_TupleGetItem0_nn.relu_"
+        )
 
         # merged function
         r = relay.Call(bn_relu, [x, gamma, beta, moving_mean, moving_var])
@@ -828,14 +820,10 @@ def test_tuple_get_item_merge():
 
 def test_pattern_with_check():
     def before():
-        x = relay.var('x', shape=(1, 10, 10, 10))
-        w = relay.var('w', shape=(10, 10, 3, 3))
-        b = relay.var('b', shape=(8,))
-        conv = relay.nn.conv2d(x,
-                               w,
-                               kernel_size=(3, 3),
-                               kernel_layout="OIHW",
-                               data_layout="NHWC")
+        x = relay.var("x", shape=(1, 10, 10, 10))
+        w = relay.var("w", shape=(10, 10, 3, 3))
+        b = relay.var("b", shape=(8,))
+        conv = relay.nn.conv2d(x, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC")
         bias = relay.nn.bias_add(conv, b)
         relu = relay.nn.relu(bias)
         return relay.Function([x, w, b], relu)
@@ -849,9 +837,9 @@ def test_pattern_with_check():
         return conv.attrs.data_layout == "NCHW"
 
     def expected():
-        x = relay.var('x')
-        w = relay.var('w')
-        b = relay.var('b')
+        x = relay.var("x")
+        w = relay.var("w")
+        b = relay.var("b")
         conv = relay.nn.conv2d(x, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC")
         bias = relay.nn.bias_add(conv, b)
         relu = relay.nn.relu(bias)
@@ -859,19 +847,15 @@ def test_pattern_with_check():
         func = func.with_attr("Composite", "conv_bias_relu")
         func = func.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_")
 
-        x = relay.var('x', shape=(1, 10, 10, 10))
-        w = relay.var('w', shape=(10, 10, 3, 3))
-        b = relay.var('b', shape=(8,))
+        x = relay.var("x", shape=(1, 10, 10, 10))
+        w = relay.var("w", shape=(10, 10, 3, 3))
+        b = relay.var("b", shape=(8,))
         return relay.Function([x, w, b], func(x, w, b))
 
-    pattern_table_false = [
-        ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_false)
-    ]
+    pattern_table_false = [("conv_bias_relu", make_conv_bias_relu_pattern(), _check_false)]
     check_result(pattern_table_false, before(), before())
 
-    pattern_table_true = [
-        ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_true)
-    ]
+    pattern_table_true = [("conv_bias_relu", make_conv_bias_relu_pattern(), _check_true)]
     check_result(pattern_table_true, before(), expected())
 
 
@@ -887,19 +871,17 @@ def test_diamond_not_merge():
                       |  /
                       mul
     """
+
     def get_pattern():
         conv = make_conv_bias_relu_pattern()
-        clip = is_op('clip')(conv, wildcard(), wildcard())
-        return is_op('multiply')(conv, clip)
+        clip = is_op("clip")(conv, wildcard(), wildcard())
+        return is_op("multiply")(conv, clip)
 
     def get_net():
-        data = relay.var('data', shape=(1, 512, 28, 28))
-        kernel = relay.var('kernel', shape=(256, 512, 1, 1))
-        conv = relay.nn.conv2d(data, kernel,
-                               kernel_size=(1, 1),
-                               padding=(0, 0),
-                               strides=(1, 1))
-        bias = relay.nn.bias_add(conv, relay.var('bias', shape=(256,)))
+        data = relay.var("data", shape=(1, 512, 28, 28))
+        kernel = relay.var("kernel", shape=(256, 512, 1, 1))
+        conv = relay.nn.conv2d(data, kernel, kernel_size=(1, 1), padding=(0, 0), strides=(1, 1))
+        bias = relay.nn.bias_add(conv, relay.var("bias", shape=(256,)))
         relu = relay.nn.relu(bias)
         add = relay.op.add(relu, relay.const(1.0))
         clip2 = relay.op.clip(add, 0, 255)
@@ -913,28 +895,27 @@ def test_diamond_not_merge():
 
 def test_type_check():
     """Test that we can query tensor types in the 'check' function."""
+
     def before():
-        x = relay.var('x', shape=(1, 10, 10, 10))
-        w = relay.var('w', shape=(10, 10, 3, 3))
-        b = relay.var('b', shape=(8,))
+        x = relay.var("x", shape=(1, 10, 10, 10))
+        w = relay.var("w", shape=(10, 10, 3, 3))
+        b = relay.var("b", shape=(8,))
         add = relay.op.add(x, x)
         relu = relay.nn.relu(add)
-        conv = relay.nn.conv2d(relu,
-                               w,
-                               kernel_size=(3, 3),
-                               kernel_layout="OIHW",
-                               data_layout="NHWC")
+        conv = relay.nn.conv2d(
+            relu, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC"
+        )
         bias = relay.nn.bias_add(conv, b)
         relu2 = relay.nn.relu(bias)
         return run_opt_pass(relay.Function([x, w, b], relu2), relay.transform.InferType())
 
     def expected_false():
-        x = relay.var('x', shape=(1, 10, 10, 10))
-        w = relay.var('w', shape=(10, 10, 3, 3))
-        b = relay.var('b', shape=(8, ))
+        x = relay.var("x", shape=(1, 10, 10, 10))
+        w = relay.var("w", shape=(10, 10, 3, 3))
+        b = relay.var("b", shape=(8,))
 
-        x0 = relay.var('x')
-        y0 = relay.var('y')
+        x0 = relay.var("x")
+        y0 = relay.var("y")
 
         add = relay.op.add(y0, y0)
         relu = relay.nn.relu(add)
@@ -943,18 +924,20 @@ def test_type_check():
         func = func.with_attr("Composite", "add_relu")
         call = relay.Call(func, [x, x])
 
-        conv = relay.nn.conv2d(call, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC")
+        conv = relay.nn.conv2d(
+            call, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC"
+        )
         bias = relay.nn.bias_add(conv, b)
         relu2 = relay.nn.relu(bias)
         return relay.Function([x, w, b], relu2)
 
     def expected_true():
-        x = relay.var('x', shape=(1, 10, 10, 10))
-        w = relay.var('w', shape=(10, 10, 3, 3))
-        b = relay.var('b', shape=(8, ))
+        x = relay.var("x", shape=(1, 10, 10, 10))
+        w = relay.var("w", shape=(10, 10, 3, 3))
+        b = relay.var("b", shape=(8,))
 
-        x0 = relay.var('x')
-        y0 = relay.var('y')
+        x0 = relay.var("x")
+        y0 = relay.var("y")
 
         add = relay.op.add(y0, y0)
         relu = relay.nn.relu(add)
@@ -963,9 +946,9 @@ def test_type_check():
         func = func.with_attr("Composite", "add_relu")
         call = relay.Call(func, [x, x])
 
-        x2 = relay.var('x')
-        w1 = relay.var('w')
-        b1 = relay.var('b')
+        x2 = relay.var("x")
+        w1 = relay.var("w")
+        b1 = relay.var("b")
         conv = relay.nn.conv2d(x2, w1, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC")
         bias = relay.nn.bias_add(conv, b1)
         relu2 = relay.nn.relu(bias)
@@ -987,13 +970,13 @@ def test_type_check():
 
     pattern_table_false = [
         ("add_relu", make_add_relu_pattern()),
-        ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_false)
+        ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_false),
     ]
     check_result(pattern_table_false, before(), expected_false())
 
     pattern_table_true = [
         ("add_relu", make_add_relu_pattern()),
-        ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_true)
+        ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_true),
     ]
     check_result(pattern_table_true, before(), expected_true())
 
index 95805d2..f2368fc 100644 (file)
@@ -28,6 +28,7 @@ from tvm.relay import GlobalVar, Call
 from tvm.relay.transform import gradient
 from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type
 
+
 def check_eval(expr, expected_result, mod=None, rtol=1e-07):
     ctx = tvm.context("llvm", 0)
     intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
@@ -41,19 +42,17 @@ def run_opt_pass(expr, passes):
     mod = tvm.IRModule.from_expr(expr)
     seq = tvm.transform.Sequential(passes)
     with tvm.transform.PassContext(opt_level=3):
-       mod = seq(mod)
+        mod = seq(mod)
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
 
 def tipe(expr):
-    return run_opt_pass(expr, [transform.PartialEvaluate(),
-                               transform.InferType()])
+    return run_opt_pass(expr, [transform.PartialEvaluate(), transform.InferType()])
 
 
 def dcpe(expr, mod=None, grad=False):
-    passes = [transform.PartialEvaluate(),
-              transform.DeadCodeElimination(inline_once=True)]
+    passes = [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]
     if grad:
         expr = gradient(run_infer_type(expr))
     if mod:
@@ -339,5 +338,5 @@ def test_tuple_match():
     tvm.ir.assert_structural_equal(dcpe(x), const(2))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     pytest.main([__file__])
index 58bb16d..75218c5 100644 (file)
@@ -46,20 +46,20 @@ class WhiteListAnnotator:
     def transform_function(self, func, mod, ctx):
 
         annotator = self
+
         class Annotator(tvm.relay.ExprMutator):
             def visit_call(self, call):
                 op_name = call.op.name
                 if op_name in annotator.op_list:
                     new_args = []
                     for arg in call.args:
-                        ann = compiler_begin(super().visit(arg),
-                                             annotator.compiler)
+                        ann = compiler_begin(super().visit(arg), annotator.compiler)
                         new_args.append(ann)
-                    new_call = relay.Call(call.op, new_args, call.attrs,
-                                          call.type_args)
+                    new_call = relay.Call(call.op, new_args, call.attrs, call.type_args)
                     return compiler_end(new_call, annotator.compiler)
                 else:
                     return super().visit_call(call)
+
         return Annotator().visit(func)
 
 
@@ -155,14 +155,14 @@ class MobileNetAnnotator(ExprMutator):
 
     def visit_call(self, call):
 
-        if call.op.name == 'nn.global_avg_pool2d':
+        if call.op.name == "nn.global_avg_pool2d":
             self.compiler_open = True
         compiler_open = self.compiler_open
 
         params = []
         for arg in call.args:
             param = super().visit(arg)
-            if call.op.name == 'nn.global_avg_pool2d':
+            if call.op.name == "nn.global_avg_pool2d":
                 param = compiler_end(param, self.compiler)
             if compiler_open and isinstance(param, relay.expr.Var):
                 param = compiler_begin(param, self.compiler)
@@ -172,8 +172,9 @@ class MobileNetAnnotator(ExprMutator):
         return new_call
 
 
-def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
-                 ctx=tvm.cpu(), params=None):
+def check_result(
+    mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ctx=tvm.cpu(), params=None
+):
     if sys.platform == "win32":
         print("Skip test on Windows for now")
         return
@@ -186,7 +187,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
         kwargs = {}
         kwargs["options"] = ["-O2", "-std=c++14", "-I" + contrib_path]
         tmp_path = util.tempdir()
-        lib_name = 'lib.so'
+        lib_name = "lib.so"
         lib_path = tmp_path.relpath(lib_name)
         lib.export_library(lib_path, fcompile=False, **kwargs)
         lib = runtime.load_module(lib_path)
@@ -232,15 +233,15 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
 
 
 def test_multi_node_compiler():
-    x = relay.var('x', shape=(10, 10))
-    w0 = relay.var('w0', shape=(10, 10))
-    w1 = relay.var('w1', shape=(10, 10))
-    w2 = relay.var('w2', shape=(10, 10))
-    w3 = relay.var('w3', shape=(10, 10))
-    w4 = relay.var('w4', shape=(10, 10))
-    w5 = relay.var('w5', shape=(10, 10))
-    w6 = relay.var('w6', shape=(10, 10))
-    w7 = relay.var('w7', shape=(10, 10))
+    x = relay.var("x", shape=(10, 10))
+    w0 = relay.var("w0", shape=(10, 10))
+    w1 = relay.var("w1", shape=(10, 10))
+    w2 = relay.var("w2", shape=(10, 10))
+    w3 = relay.var("w3", shape=(10, 10))
+    w4 = relay.var("w4", shape=(10, 10))
+    w5 = relay.var("w5", shape=(10, 10))
+    w6 = relay.var("w6", shape=(10, 10))
+    w7 = relay.var("w7", shape=(10, 10))
 
     # C compiler
     # FIXME: We generate two compilers for this case, but they should be merged into one
@@ -265,19 +266,26 @@ def test_multi_node_compiler():
     mod = transform.PartitionGraph()(mod)
     mod = transform.InferType()(mod)
 
-    x_data = np.random.rand(10, 10).astype('float32')
+    x_data = np.random.rand(10, 10).astype("float32")
     w_data = []
     for _ in range(8):
-        w_data.append(np.random.rand(10, 10).astype('float32'))
+        w_data.append(np.random.rand(10, 10).astype("float32"))
 
     map_inputs = {"w{}".format(i): w_data[i] for i in range(8)}
     map_inputs["x"] = x_data
     check_result(
-        mod, map_inputs, (30, 10),
-        np.concatenate((((x_data + w_data[0]) - w_data[1]) * w_data[2],
-                        ((x_data + w_data[3]) - w_data[4]) * w_data[5],
-                        x_data + w_data[6] - w_data[7]),
-                       axis=0))
+        mod,
+        map_inputs,
+        (30, 10),
+        np.concatenate(
+            (
+                ((x_data + w_data[0]) - w_data[1]) * w_data[2],
+                ((x_data + w_data[3]) - w_data[4]) * w_data[5],
+                x_data + w_data[6] - w_data[7],
+            ),
+            axis=0,
+        ),
+    )
 
 
 def test_extern_ccompiler_single_op():
@@ -292,14 +300,15 @@ def test_extern_ccompiler_single_op():
                         new_args.append(ann)
                     new_call = relay.Call(call.op, new_args)
                     return compiler_end(new_call, "ccompiler")
+
             return Annotator().visit(func)
 
-    x = relay.var('x', shape=(8, 8))
-    y = relay.var('y', shape=(8, 8))
+    x = relay.var("x", shape=(8, 8))
+    y = relay.var("y", shape=(8, 8))
     z = x + y
     f = relay.Function([x, y], z)
-    x_data = np.random.rand(8, 8).astype('float32')
-    y_data = np.random.rand(8, 8).astype('float32')
+    x_data = np.random.rand(8, 8).astype("float32")
+    y_data = np.random.rand(8, 8).astype("float32")
     mod = tvm.IRModule()
     mod["main"] = f
     mod = MyAnnotator()(mod)
@@ -336,8 +345,7 @@ def test_extern_ccompiler_default_ops():
         exp = relay.exp(p0)
         concat = relay.concatenate([log, exp], axis=0)
         fused_func = relay.Function([p0], concat)
-        fused_func = fused_func.with_attr("Primitive",
-                                          tvm.tir.IntImm("int32", 1))
+        fused_func = fused_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
         fused_call = relay.Call(fused_func, [add_call])
         main = relay.Function([x, y], fused_call)
         mod["main"] = main
@@ -359,21 +367,21 @@ def test_extern_ccompiler_default_ops():
     expected_mod = expected()
     assert tvm.ir.structural_equal(fused_mod, expected_mod, map_free_vars=True)
 
-    x_data = np.random.rand(8, 8).astype('float32')
-    y_data = np.random.rand(8, 8).astype('float32')
+    x_data = np.random.rand(8, 8).astype("float32")
+    y_data = np.random.rand(8, 8).astype("float32")
     np_add = x_data + y_data
     res = np.concatenate([np.log(np_add), np.exp(np_add)])
     check_result(mod, {"x": x_data, "y": y_data}, (16, 8), res)
 
 
 def test_extern_ccompiler():
-    x = relay.var('x', shape=(2, 2))
-    y = relay.var('y', shape=(2, 2))
+    x = relay.var("x", shape=(2, 2))
+    y = relay.var("y", shape=(2, 2))
     z = x + x
     p = y * y
     f = relay.Function([x, y], p - z)
-    x_data = np.random.rand(2, 2).astype('float32')
-    y_data = np.random.rand(2, 2).astype('float32')
+    x_data = np.random.rand(2, 2).astype("float32")
+    y_data = np.random.rand(2, 2).astype("float32")
     mod = tvm.IRModule()
     mod["main"] = f
     mod = WhiteListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod)
@@ -387,23 +395,19 @@ def test_extern_dnnl():
         print("skip because DNNL codegen is not available")
         return
 
-    dtype = 'float32'
+    dtype = "float32"
     ishape = (1, 32, 14, 14)
     w1shape = (32, 1, 3, 3)
 
     def expected():
         data0 = relay.var("data", shape=(ishape), dtype=dtype)
         input0 = relay.var("input", shape=(w1shape), dtype=dtype)
-        depthwise_conv2d_1 = relay.nn.conv2d(data0,
-                                             input0,
-                                             kernel_size=(3, 3),
-                                             padding=(1, 1),
-                                             groups=32)
-        depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
-                                             input0,
-                                             kernel_size=(3, 3),
-                                             padding=(1, 1),
-                                             groups=32)
+        depthwise_conv2d_1 = relay.nn.conv2d(
+            data0, input0, kernel_size=(3, 3), padding=(1, 1), groups=32
+        )
+        depthwise_conv2d_2 = relay.nn.conv2d(
+            depthwise_conv2d_1, input0, kernel_size=(3, 3), padding=(1, 1), groups=32
+        )
         out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
 
         func = relay.Function([data0, input0], out)
@@ -422,16 +426,12 @@ def test_extern_dnnl():
     def get_func():
         data = relay.var("data", shape=(ishape), dtype=dtype)
         weight1 = relay.var("weight1", shape=(w1shape), dtype=dtype)
-        depthwise_conv2d_1 = relay.nn.conv2d(data,
-                                             weight1,
-                                             kernel_size=(3, 3),
-                                             padding=(1, 1),
-                                             groups=32)
-        depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
-                                             weight1,
-                                             kernel_size=(3, 3),
-                                             padding=(1, 1),
-                                             groups=32)
+        depthwise_conv2d_1 = relay.nn.conv2d(
+            data, weight1, kernel_size=(3, 3), padding=(1, 1), groups=32
+        )
+        depthwise_conv2d_2 = relay.nn.conv2d(
+            depthwise_conv2d_1, weight1, kernel_size=(3, 3), padding=(1, 1), groups=32
+        )
         out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
 
         return relay.Function([data, weight1], out)
@@ -450,8 +450,9 @@ def test_extern_dnnl():
 
     ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu())
     ref_res = ref_ex.evaluate()(i_data, w1_data)
-    check_result(mod, {"data": i_data, "weight1": w1_data},
-                 (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
+    check_result(
+        mod, {"data": i_data, "weight1": w1_data}, (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5
+    )
 
 
 def test_extern_dnnl_mobilenet():
@@ -459,9 +460,9 @@ def test_extern_dnnl_mobilenet():
         print("skip because DNNL codegen is not available")
         return
 
-    dtype = 'float32'
+    dtype = "float32"
     ishape = (1, 3, 224, 224)
-    ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1, dtype='float32')
+    ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1, dtype="float32")
     mod = transform.AnnotateTarget(["dnnl"])(ref_mod)
     mod = transform.MergeCompilerRegions()(mod)
     mod = transform.PartitionGraph()(mod)
@@ -478,34 +479,33 @@ def test_function_lifting():
     def partition():
         data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
         weight = relay.var("weight", relay.TensorType((16, 3, 3, 3), "float32"))
-        bn_gamma = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
-        bn_beta = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
-        bn_mmean = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
-        bn_mvar = relay.var("bn_var", relay.TensorType((16, ), "float32"))
+        bn_gamma = relay.var("bn_gamma", relay.TensorType((16,), "float32"))
+        bn_beta = relay.var("bn_beta", relay.TensorType((16,), "float32"))
+        bn_mmean = relay.var("bn_mean", relay.TensorType((16,), "float32"))
+        bn_mvar = relay.var("bn_var", relay.TensorType((16,), "float32"))
 
         conv = relay.nn.conv2d(
-            data=data,
-            weight=weight,
-            kernel_size=(3, 3),
-            channels=16,
-            padding=(1, 1))
-        bn_output = relay.nn.batch_norm(conv, bn_gamma, bn_beta, bn_mmean,
-                                        bn_mvar)
-
-        func = relay.Function([data, weight, bn_gamma, bn_beta, bn_mmean,
-                               bn_mvar], bn_output.astuple())
+            data=data, weight=weight, kernel_size=(3, 3), channels=16, padding=(1, 1)
+        )
+        bn_output = relay.nn.batch_norm(conv, bn_gamma, bn_beta, bn_mmean, bn_mvar)
+
+        func = relay.Function(
+            [data, weight, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn_output.astuple()
+        )
         mod = tvm.IRModule()
         mod["main"] = func
         op_list = ["nn.batch_norm", "nn.conv2d"]
         mod = WhiteListAnnotator(op_list, "test_compiler")(mod)
 
-        opt_pass = tvm.transform.Sequential([
-            transform.InferType(),
-            transform.PartitionGraph(),
-            transform.SimplifyInference(),
-            transform.FoldConstant(),
-            transform.AlterOpLayout(),
-        ])
+        opt_pass = tvm.transform.Sequential(
+            [
+                transform.InferType(),
+                transform.PartitionGraph(),
+                transform.SimplifyInference(),
+                transform.FoldConstant(),
+                transform.AlterOpLayout(),
+            ]
+        )
 
         with tvm.transform.PassContext(opt_level=3):
             mod = opt_pass(mod)
@@ -514,17 +514,15 @@ def test_function_lifting():
 
     def expected():
         # function for batch_norm
-        data0 = relay.var("data0", relay.TensorType((1, 16, 224, 224),
-                                                    "float32"))
+        data0 = relay.var("data0", relay.TensorType((1, 16, 224, 224), "float32"))
         mod = tvm.IRModule()
-        bn_gamma = relay.var("bn_gamma1", relay.TensorType((16, ), "float32"))
-        bn_beta = relay.var("bn_beta1", relay.TensorType((16, ), "float32"))
-        bn_mmean = relay.var("bn_mean1", relay.TensorType((16, ), "float32"))
-        bn_mvar = relay.var("bn_var1", relay.TensorType((16, ), "float32"))
+        bn_gamma = relay.var("bn_gamma1", relay.TensorType((16,), "float32"))
+        bn_beta = relay.var("bn_beta1", relay.TensorType((16,), "float32"))
+        bn_mmean = relay.var("bn_mean1", relay.TensorType((16,), "float32"))
+        bn_mvar = relay.var("bn_var1", relay.TensorType((16,), "float32"))
 
         bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
-        func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
-                               bn.astuple())
+        func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn.astuple())
         func0 = set_func_attr(func0, "test_compiler", "test_compiler_2")
         gv0 = relay.GlobalVar("test_compiler_2")
         mod[gv0] = func0
@@ -533,11 +531,8 @@ def test_function_lifting():
         data1 = relay.var("data1", relay.TensorType((1, 3, 224, 224), "float32"))
         weight1 = relay.var("weight1", relay.TensorType((16, 3, 3, 3), "float32"))
         conv = relay.nn.conv2d(
-            data=data1,
-            weight=weight1,
-            kernel_size=(3, 3),
-            channels=16,
-            padding=(1, 1))
+            data=data1, weight=weight1, kernel_size=(3, 3), channels=16, padding=(1, 1)
+        )
         func1 = relay.Function([data1, weight1], conv)
         func1 = set_func_attr(func1, "test_compiler", "test_compiler_0")
         gv1 = relay.GlobalVar("test_compiler_0")
@@ -546,15 +541,16 @@ def test_function_lifting():
         # main function
         data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
         weight = relay.var("weight", relay.TensorType((16, 3, 3, 3), "float32"))
-        bn_gamma0 = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
-        bn_beta0 = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
-        bn_mmean0 = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
-        bn_mvar0 = relay.var("bn_var", relay.TensorType((16, ), "float32"))
+        bn_gamma0 = relay.var("bn_gamma", relay.TensorType((16,), "float32"))
+        bn_beta0 = relay.var("bn_beta", relay.TensorType((16,), "float32"))
+        bn_mmean0 = relay.var("bn_mean", relay.TensorType((16,), "float32"))
+        bn_mvar0 = relay.var("bn_var", relay.TensorType((16,), "float32"))
 
         call1 = gv1(data, weight)
         call0 = gv0(call1, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0)
-        mod["main"] = relay.Function([data, weight, bn_gamma0, bn_beta0, bn_mmean0,
-                                      bn_mvar0], call0)
+        mod["main"] = relay.Function(
+            [data, weight, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0], call0
+        )
         mod = transform.InferType()(mod)
         return mod
 
@@ -566,29 +562,29 @@ def test_function_lifting():
 def test_function_lifting_inline():
     def partition():
         data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
-        bn_gamma = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
-        bn_beta = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
-        bn_mmean = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
-        bn_mvar = relay.var("bn_var", relay.TensorType((16, ), "float32"))
+        bn_gamma = relay.var("bn_gamma", relay.TensorType((16,), "float32"))
+        bn_beta = relay.var("bn_beta", relay.TensorType((16,), "float32"))
+        bn_mmean = relay.var("bn_mean", relay.TensorType((16,), "float32"))
+        bn_mvar = relay.var("bn_var", relay.TensorType((16,), "float32"))
 
-        bn_output = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean,
-                                        bn_mvar)
+        bn_output = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)
 
-        func = relay.Function([data, bn_gamma, bn_beta, bn_mmean,
-                               bn_mvar], bn_output.astuple())
+        func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn_output.astuple())
         mod = tvm.IRModule()
         mod["main"] = func
         op_list = ["nn.batch_norm", "nn.conv2d"]
         mod = WhiteListAnnotator(op_list, "test_compiler")(mod)
 
-        opt_pass = tvm.transform.Sequential([
-            transform.InferType(),
-            transform.PartitionGraph(),
-            transform.SimplifyInference(),
-            transform.FoldConstant(),
-            transform.AlterOpLayout(),
-            transform.Inline(),
-        ])
+        opt_pass = tvm.transform.Sequential(
+            [
+                transform.InferType(),
+                transform.PartitionGraph(),
+                transform.SimplifyInference(),
+                transform.FoldConstant(),
+                transform.AlterOpLayout(),
+                transform.Inline(),
+            ]
+        )
 
         with tvm.transform.PassContext(opt_level=3):
             mod = opt_pass(mod)
@@ -597,29 +593,26 @@ def test_function_lifting_inline():
 
     def expected():
         # function for batch_norm
-        data0 = relay.var("data0", relay.TensorType((1, 16, 224, 224),
-                                                    "float32"))
+        data0 = relay.var("data0", relay.TensorType((1, 16, 224, 224), "float32"))
         mod = tvm.IRModule()
-        bn_gamma = relay.var("bn_gamma1", relay.TensorType((16, ), "float32"))
-        bn_beta = relay.var("bn_beta1", relay.TensorType((16, ), "float32"))
-        bn_mmean = relay.var("bn_mean1", relay.TensorType((16, ), "float32"))
-        bn_mvar = relay.var("bn_var1", relay.TensorType((16, ), "float32"))
+        bn_gamma = relay.var("bn_gamma1", relay.TensorType((16,), "float32"))
+        bn_beta = relay.var("bn_beta1", relay.TensorType((16,), "float32"))
+        bn_mmean = relay.var("bn_mean1", relay.TensorType((16,), "float32"))
+        bn_mvar = relay.var("bn_var1", relay.TensorType((16,), "float32"))
 
         bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
-        func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
-                               bn.astuple())
+        func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn.astuple())
         func0 = set_func_attr(func0, "test_compiler", "test_compiler_0")
 
         # main function
         data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
-        bn_gamma0 = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
-        bn_beta0 = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
-        bn_mmean0 = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
-        bn_mvar0 = relay.var("bn_var", relay.TensorType((16, ), "float32"))
+        bn_gamma0 = relay.var("bn_gamma", relay.TensorType((16,), "float32"))
+        bn_beta0 = relay.var("bn_beta", relay.TensorType((16,), "float32"))
+        bn_mmean0 = relay.var("bn_mean", relay.TensorType((16,), "float32"))
+        bn_mvar0 = relay.var("bn_var", relay.TensorType((16,), "float32"))
 
         call0 = func0(data, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0)
-        mod["main"] = relay.Function([data, bn_gamma0, bn_beta0, bn_mmean0,
-                                      bn_mvar0], call0)
+        mod["main"] = relay.Function([data, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0], call0)
         mod = transform.InferType()(mod)
         return mod
 
@@ -662,52 +655,46 @@ def test_constant_propagation():
     expected_mod = expected()
     assert tvm.ir.structural_equal(mod, expected_mod, map_free_vars=True)
 
-    y_data = np.random.rand(8, 8).astype('float32')
+    y_data = np.random.rand(8, 8).astype("float32")
     np_add = ones + y_data
     check_result(mod, {"y": y_data}, (8, 8), np.log(np_add))
 
 
 def test_multiple_outputs():
-
     def create_graph():
         data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
         weight = relay.var("weight", relay.TensorType((16, 3, 3, 3), "float32"))
-        bn_gamma = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
-        bn_beta = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
-        bn_mean = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
-        bn_var = relay.var("bn_var", relay.TensorType((16, ), "float32"))
-
-        data_cb = compiler_begin(data, 'test_target')
-        weight_cb = compiler_begin(weight, 'test_target')
-        bn_gamma_cb = compiler_begin(bn_gamma, 'test_target')
-        bn_beta_cb = compiler_begin(bn_beta, 'test_target')
-        bn_mean_cb = compiler_begin(bn_mean, 'test_target')
-        bn_var_cb = compiler_begin(bn_var, 'test_target')
+        bn_gamma = relay.var("bn_gamma", relay.TensorType((16,), "float32"))
+        bn_beta = relay.var("bn_beta", relay.TensorType((16,), "float32"))
+        bn_mean = relay.var("bn_mean", relay.TensorType((16,), "float32"))
+        bn_var = relay.var("bn_var", relay.TensorType((16,), "float32"))
+
+        data_cb = compiler_begin(data, "test_target")
+        weight_cb = compiler_begin(weight, "test_target")
+        bn_gamma_cb = compiler_begin(bn_gamma, "test_target")
+        bn_beta_cb = compiler_begin(bn_beta, "test_target")
+        bn_mean_cb = compiler_begin(bn_mean, "test_target")
+        bn_var_cb = compiler_begin(bn_var, "test_target")
 
         conv_o = relay.nn.conv2d(
-            data=data_cb,
-            weight=weight_cb,
-            kernel_size=(3, 3),
-            channels=16,
-            padding=(1, 1))
+            data=data_cb, weight=weight_cb, kernel_size=(3, 3), channels=16, padding=(1, 1)
+        )
 
-        bn_o = relay.nn.batch_norm(conv_o, bn_gamma_cb, bn_beta_cb, bn_mean_cb,
-                                   bn_var_cb)
+        bn_o = relay.nn.batch_norm(conv_o, bn_gamma_cb, bn_beta_cb, bn_mean_cb, bn_var_cb)
 
         relu_o = relay.nn.relu(bn_o[0])
-        relu_o_ce = compiler_end(relu_o, 'test_target')
+        relu_o_ce = compiler_end(relu_o, "test_target")
 
         bn_omean = bn_o[1]
-        rebn_omean_ce = compiler_end(bn_omean, 'test_target')
+        rebn_omean_ce = compiler_end(bn_omean, "test_target")
         bn_ovar = bn_o[2]
-        bn_ovar_ce = compiler_end(bn_ovar, 'test_target')
+        bn_ovar_ce = compiler_end(bn_ovar, "test_target")
 
         dummy_mean_abs = relay.abs(rebn_omean_ce)
         dummy_ovar_abs = relay.abs(bn_ovar_ce)
-        dummy_tuple = relay.Tuple((relu_o_ce, dummy_mean_abs,dummy_ovar_abs))
+        dummy_tuple = relay.Tuple((relu_o_ce, dummy_mean_abs, dummy_ovar_abs))
 
-        func = relay.Function([data, weight, bn_gamma, bn_beta,
-                               bn_mean, bn_var], dummy_tuple)
+        func = relay.Function([data, weight, bn_gamma, bn_beta, bn_mean, bn_var], dummy_tuple)
         return func
 
     def expected():
@@ -716,26 +703,21 @@ def test_multiple_outputs():
         # function 0
         data = relay.var("test_target_0_i0", relay.TensorType((1, 3, 224, 224), "float32"))
         weight = relay.var("test_target_0_i1", relay.TensorType((16, 3, 3, 3), "float32"))
-        bn_gamma = relay.var("test_target_0_i2", relay.TensorType((16, ), "float32"))
-        bn_beta = relay.var("test_target_0_i3", relay.TensorType((16, ), "float32"))
-        bn_mean = relay.var("test_target_0_i4", relay.TensorType((16, ), "float32"))
-        bn_var = relay.var("test_target_0_i5", relay.TensorType((16, ), "float32"))
+        bn_gamma = relay.var("test_target_0_i2", relay.TensorType((16,), "float32"))
+        bn_beta = relay.var("test_target_0_i3", relay.TensorType((16,), "float32"))
+        bn_mean = relay.var("test_target_0_i4", relay.TensorType((16,), "float32"))
+        bn_var = relay.var("test_target_0_i5", relay.TensorType((16,), "float32"))
 
         conv_o = relay.nn.conv2d(
-            data=data,
-            weight=weight,
-            kernel_size=(3, 3),
-            channels=16,
-            padding=(1, 1))
+            data=data, weight=weight, kernel_size=(3, 3), channels=16, padding=(1, 1)
+        )
 
-        bn_o = relay.nn.batch_norm(conv_o, bn_gamma, bn_beta, bn_mean,
-                                   bn_var)
+        bn_o = relay.nn.batch_norm(conv_o, bn_gamma, bn_beta, bn_mean, bn_var)
 
         relu_o = relay.nn.relu(bn_o[0])
         tuple_o = relay.Tuple((relu_o, bn_o[1], bn_o[2]))
 
-        func0 = relay.Function([data, weight, bn_gamma, bn_beta,
-                                bn_mean, bn_var], tuple_o)
+        func0 = relay.Function([data, weight, bn_gamma, bn_beta, bn_mean, bn_var], tuple_o)
         func0 = set_func_attr(func0, "test_target", "test_target_0")
         gv0 = relay.GlobalVar("test_target_0")
         mod[gv0] = func0
@@ -743,10 +725,10 @@ def test_multiple_outputs():
         # body
         data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
         weight = relay.var("weight", relay.TensorType((16, 3, 3, 3), "float32"))
-        bn_gamma = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
-        bn_beta = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
-        bn_mean = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
-        bn_var = relay.var("bn_var", relay.TensorType((16, ), "float32"))
+        bn_gamma = relay.var("bn_gamma", relay.TensorType((16,), "float32"))
+        bn_beta = relay.var("bn_beta", relay.TensorType((16,), "float32"))
+        bn_mean = relay.var("bn_mean", relay.TensorType((16,), "float32"))
+        bn_var = relay.var("bn_var", relay.TensorType((16,), "float32"))
 
         f0_o = gv0(data, weight, bn_gamma, bn_beta, bn_mean, bn_var)
         f0_relu_o = relay.TupleGetItem(f0_o, 0)
@@ -757,8 +739,7 @@ def test_multiple_outputs():
         f0_var_abs = relay.abs(f0_var_o)
         main_tuple = relay.Tuple((f0_relu_o, f0_mean_abs, f0_var_abs))
 
-        func = relay.Function([data, weight, bn_gamma,
-                               bn_beta, bn_mean, bn_var], main_tuple)
+        func = relay.Function([data, weight, bn_gamma, bn_beta, bn_mean, bn_var], main_tuple)
         mod["main"] = func
         return mod
 
@@ -771,20 +752,20 @@ def test_multiple_outputs():
 
 def test_mixed_single_multiple_outputs():
     def create_graph():
-        data = relay.var('data', shape=(10, 10))
+        data = relay.var("data", shape=(10, 10))
 
-        cb_1 = compiler_begin(data, 'test_target')
+        cb_1 = compiler_begin(data, "test_target")
         O_1 = relay.abs(cb_1)
-        ce_2 = compiler_end(O_1, 'test_target')
+        ce_2 = compiler_end(O_1, "test_target")
         O_2 = relay.nn.relu(O_1)
-        ce_3 = compiler_end(O_2, 'test_target')
+        ce_3 = compiler_end(O_2, "test_target")
 
         X = relay.tanh(ce_2)
 
-        cb_3 = compiler_begin(ce_3, 'test_target')
-        cb_4 = compiler_begin(X, 'test_target')
+        cb_3 = compiler_begin(ce_3, "test_target")
+        cb_4 = compiler_begin(X, "test_target")
         O_3 = relay.add(cb_3, cb_4)
-        ce_4 = compiler_end(O_3, 'test_target')
+        ce_4 = compiler_end(O_3, "test_target")
 
         func = relay.Function([data], ce_4)
         return func
@@ -793,7 +774,7 @@ def test_mixed_single_multiple_outputs():
         mod = tvm.IRModule()
 
         # function 1
-        f1_cb1 = relay.var('test_target_0_i0', shape=(10, 10))
+        f1_cb1 = relay.var("test_target_0_i0", shape=(10, 10))
         f1_O_1 = relay.abs(f1_cb1)
         f1_O_2 = relay.nn.relu(f1_O_1)
         f1_out = relay.Tuple((f1_O_2, f1_O_1))
@@ -803,8 +784,8 @@ def test_mixed_single_multiple_outputs():
         mod[gv1] = func1
 
         # function 0
-        f2_cb3 = relay.var('test_target_1_i0', shape=(10, 10))
-        f2_cb4 = relay.var('test_target_1_i1', shape=(10, 10))
+        f2_cb3 = relay.var("test_target_1_i0", shape=(10, 10))
+        f2_cb4 = relay.var("test_target_1_i1", shape=(10, 10))
         f2_O_3 = relay.add(f2_cb3, f2_cb4)
         func0 = relay.Function([f2_cb3, f2_cb4], f2_O_3)
         func0 = set_func_attr(func0, "test_target", "test_target_1")
@@ -812,7 +793,7 @@ def test_mixed_single_multiple_outputs():
         mod[gv0] = func0
 
         # body
-        data = relay.var('data', shape=(10, 10))
+        data = relay.var("data", shape=(10, 10))
         tuple_out = gv1(data)
         ce_2 = relay.TupleGetItem(tuple_out, 1)
         ce_3 = relay.TupleGetItem(tuple_out, 0)
@@ -836,19 +817,18 @@ def test_dnnl_fuse():
     dnnl_patterns = get_pattern_table("dnnl")
     conv2d_bias_relu_pat, conv2d_relu_pat = dnnl_patterns
 
-    def get_blocks(prefix, data, in_channel, out_channel,
-                   include_bn=True, include_sigmoid=False):
+    def get_blocks(prefix, data, in_channel, out_channel, include_bn=True, include_sigmoid=False):
         weight = relay.var(prefix + "weight")
         bn_gamma = relay.var(prefix + "bn_gamma")
         bn_beta = relay.var(prefix + "bn_beta")
         bn_mmean = relay.var(prefix + "bn_mean")
         bn_mvar = relay.var(prefix + "bn_var")
 
-        layer = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
-                                channels=out_channel, padding=(1, 1))
+        layer = relay.nn.conv2d(
+            data=data, weight=weight, kernel_size=(3, 3), channels=out_channel, padding=(1, 1)
+        )
         if include_bn:
-            bn_output = relay.nn.batch_norm(layer, bn_gamma, bn_beta,
-                                            bn_mmean, bn_mvar)
+            bn_output = relay.nn.batch_norm(layer, bn_gamma, bn_beta, bn_mmean, bn_mvar)
             layer = bn_output[0]
         if include_sigmoid:
             # dummy layer to prevent pattern detection
@@ -867,29 +847,31 @@ def test_dnnl_fuse():
         # This is required for constant folding
         mod["main"] = bind_params_by_name(mod["main"], params)
 
-        remove_bn_pass = tvm.transform.Sequential([
-            transform.InferType(),
-            transform.SimplifyInference(),
-            transform.FoldConstant(),
-            transform.FoldScaleAxis(),
-        ])
-        composite_partition = tvm.transform.Sequential([
-            remove_bn_pass,
-            transform.MergeComposite(pattern_table),
-            transform.AnnotateTarget("dnnl"),
-            transform.PartitionGraph()
-        ])
-
-        with tvm.transform.PassContext(opt_level=3,
-                                       disabled_pass=["AlterOpLayout"]):
+        remove_bn_pass = tvm.transform.Sequential(
+            [
+                transform.InferType(),
+                transform.SimplifyInference(),
+                transform.FoldConstant(),
+                transform.FoldScaleAxis(),
+            ]
+        )
+        composite_partition = tvm.transform.Sequential(
+            [
+                remove_bn_pass,
+                transform.MergeComposite(pattern_table),
+                transform.AnnotateTarget("dnnl"),
+                transform.PartitionGraph(),
+            ]
+        )
+
+        with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
             return composite_partition(mod)
 
-    def test_detect_pattern(pattern_table, include_bn, include_sigmoid,
-                            num_expected_partition):
+    def test_detect_pattern(pattern_table, include_bn, include_sigmoid, num_expected_partition):
         net = get_net(include_bn, include_sigmoid)
         mod, params = tvm.relay.testing.create_workload(net)
         mod = get_partitoned_mod(mod, params, pattern_table)
-        assert(len(mod.functions) - 1 == num_expected_partition)  # -1 for main
+        assert len(mod.functions) - 1 == num_expected_partition  # -1 for main
 
     def test_partition():
         # conv + bn + relu, conv + relu -> fused conv_bias_relu, conv, and relu
@@ -909,7 +891,7 @@ def test_dnnl_fuse():
         mod, params = relay.testing.mobilenet.get_workload()
         mod = get_partitoned_mod(mod, params, dnnl_patterns)
         # 27 fused conv + bn + relu and one dense
-        assert(len(mod.functions) - 1 == 28)  # -1 for main
+        assert len(mod.functions) - 1 == 28  # -1 for main
 
     def test_exec(mod, params, ref_mod, ref_params, out_shape):
         ishape = (1, 3, 224, 224)
@@ -920,8 +902,7 @@ def test_dnnl_fuse():
 
         mod = get_partitoned_mod(mod, params, dnnl_patterns)
 
-        check_result(mod, {"data": i_data},
-                     out_shape, ref_res.asnumpy(), tol=1e-5, params=params)
+        check_result(mod, {"data": i_data}, out_shape, ref_res.asnumpy(), tol=1e-5, params=params)
 
     test_partition()
     test_partition_mobilenet()
@@ -1026,15 +1007,16 @@ def test_multiple_use_of_an_output():
     test_same_output_region()
     test_different_output_region()
 
+
 def test_duplicate_outputs():
     target = "test_duplicate_outputs"
 
     @tvm.ir.register_op_attr("abs", "target." + target)
-    def abs(attrs, args): # pylint: disable=unused-variable
+    def abs(attrs, args):  # pylint: disable=unused-variable
         return True
 
     def create_graph():
-        data = relay.var('data', shape=(10, 10))
+        data = relay.var("data", shape=(10, 10))
         x = relay.abs(data)
         out_1 = relay.nn.relu(x)
         out_2 = relay.tanh(x)
@@ -1047,19 +1029,19 @@ def test_duplicate_outputs():
         mod = tvm.IRModule()
 
         # function 0
-        f0_i0 = relay.var(target+"_0_i0", shape=(10, 10))
+        f0_i0 = relay.var(target + "_0_i0", shape=(10, 10))
         f0_o0 = relay.abs(f0_i0)
         func0 = relay.Function([f0_i0], f0_o0)
 
         func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
         func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
         func0 = func0.with_attr("Compiler", target)
-        func0 = func0.with_attr("global_symbol", target+"_0")
-        gv0 = relay.GlobalVar(target+"_0")
+        func0 = func0.with_attr("global_symbol", target + "_0")
+        gv0 = relay.GlobalVar(target + "_0")
         mod[gv0] = func0
 
         # body
-        data = relay.var('data', shape=(10, 10))
+        data = relay.var("data", shape=(10, 10))
         function_out = gv0(data)
         out_1 = relay.nn.relu(function_out)
         out_2 = relay.tanh(function_out)
@@ -1072,29 +1054,32 @@ def test_duplicate_outputs():
     mod = tvm.IRModule()
     mod["main"] = create_graph()
 
-    seq = tvm.transform.Sequential([
-        transform.AnnotateTarget(target),
-        transform.MergeCompilerRegions(),
-        transform.PartitionGraph(),
-    ])
+    seq = tvm.transform.Sequential(
+        [
+            transform.AnnotateTarget(target),
+            transform.MergeCompilerRegions(),
+            transform.PartitionGraph(),
+        ]
+    )
 
     ref_mod = expected()
     partitioned = seq(mod)
     assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
 
+
 def test_duplicate_merge_and_tuplegetitem():
     target = "test_duplicate_merge_and_tuplegetitem"
 
     @tvm.ir.register_op_attr("nn.batch_norm", "target." + target)
-    def batch_norm(attrs, args): # pylint: disable=unused-variable
+    def batch_norm(attrs, args):  # pylint: disable=unused-variable
         return True
 
     @tvm.ir.register_op_attr("nn.relu", "target." + target)
-    def relu(attrs, args): # pylint: disable=unused-variable
+    def relu(attrs, args):  # pylint: disable=unused-variable
         return True
 
     def create_graph():
-        data = relay.var('data', shape=(10, 10))
+        data = relay.var("data", shape=(10, 10))
         bn_gamma = relay.var("bn_gamma")
         bn_beta = relay.var("bn_beta")
         bn_mmean = relay.var("bn_mean")
@@ -1131,7 +1116,7 @@ def test_duplicate_merge_and_tuplegetitem():
         mod[gv0] = func0
 
         # body
-        data = relay.var('data', shape=(10, 10))
+        data = relay.var("data", shape=(10, 10))
         bn_gamma = relay.var("bn_gamma")
         bn_beta = relay.var("bn_beta")
         bn_mmean = relay.var("bn_mean")
@@ -1149,44 +1134,51 @@ def test_duplicate_merge_and_tuplegetitem():
     mod = tvm.IRModule()
     mod["main"] = create_graph()
 
-    seq = tvm.transform.Sequential([
-        transform.AnnotateTarget(target),
-        transform.MergeCompilerRegions(),
-        transform.PartitionGraph(),
-    ])
+    seq = tvm.transform.Sequential(
+        [
+            transform.AnnotateTarget(target),
+            transform.MergeCompilerRegions(),
+            transform.PartitionGraph(),
+        ]
+    )
 
     ref_mod = expected()
     partitioned = seq(mod)
     assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
 
+
 def test_constant_tuples():
     @tvm.ir.register_op_attr("qnn.concatenate", "target.const_tuples")
     def add(attrs, args):  # pylint: disable=unused-variable
         return True
 
     def create_graph():
-        a = relay.var('a', shape=(10, 10), dtype="uint8")
-        b = relay.var('b', shape=(10, 10), dtype="uint8")
+        a = relay.var("a", shape=(10, 10), dtype="uint8")
+        b = relay.var("b", shape=(10, 10), dtype="uint8")
         a1 = relay.abs(a)
 
         zeroi = relay.const(1, "int32")
         zerof = relay.const(0, "float32")
-        con = relay.qnn.op.concatenate((a1, b),
-                                       input_scales=(zerof, zerof),
-                                       input_zero_points=(zeroi, zeroi),
-                                       output_scale=zerof,
-                                       output_zero_point=zeroi,
-                                       axis=1)
+        con = relay.qnn.op.concatenate(
+            (a1, b),
+            input_scales=(zerof, zerof),
+            input_zero_points=(zeroi, zeroi),
+            output_scale=zerof,
+            output_zero_point=zeroi,
+            axis=1,
+        )
 
         f = relay.Function([a, b], con)
         mod = tvm.IRModule.from_expr(f)
         return mod
 
-    seq = tvm.transform.Sequential([
-        transform.AnnotateTarget("const_tuples"),
-        transform.MergeCompilerRegions(),
-        transform.PartitionGraph(),
-    ])
+    seq = tvm.transform.Sequential(
+        [
+            transform.AnnotateTarget("const_tuples"),
+            transform.MergeCompilerRegions(),
+            transform.PartitionGraph(),
+        ]
+    )
 
     partitioned = seq(create_graph())
     concat = partitioned["const_tuples_0"].body
@@ -1195,22 +1187,23 @@ def test_constant_tuples():
     assert type(concat.args[3]) == relay.Constant
     assert type(concat.args[4]) == relay.Constant
 
+
 def test_flatten_tuple_output():
     target = "test_flatten_tuple_output"
 
     @tvm.ir.register_op_attr("split", "target." + target)
-    def split(attrs, args): # pylint: disable=unused-variable
+    def split(attrs, args):  # pylint: disable=unused-variable
         return True
 
     @tvm.ir.register_op_attr("abs", "target." + target)
-    def abs(attrs, args): # pylint: disable=unused-variable
+    def abs(attrs, args):  # pylint: disable=unused-variable
         return True
 
     def create_graph():
-        a = relay.var('a', shape=(10, 10), dtype="uint8")
+        a = relay.var("a", shape=(10, 10), dtype="uint8")
 
         a_split = relay.split(a, 2)
-        a_split_0 = relay.TupleGetItem(a_split.astuple(),0)
+        a_split_0 = relay.TupleGetItem(a_split.astuple(), 0)
         a_split_0_abs = relay.abs(a_split_0)
 
         a_con = relay.concatenate(a_split, 0)
@@ -1241,49 +1234,56 @@ def test_flatten_tuple_output():
         gv0 = relay.GlobalVar(target + "_0")
         mod[gv0] = func0
 
-        #body
-        data = relay.var('a', shape=(10, 10), dtype="uint8")
+        # body
+        data = relay.var("a", shape=(10, 10), dtype="uint8")
         f_out = gv0(data)
         f_out_0 = relay.TupleGetItem(f_out, 0)
         f_out_1 = relay.TupleGetItem(f_out, 1)
         tuple = relay.Tuple((f_out_0, f_out_1))
-        concat = relay.concatenate(tuple,0)
+        concat = relay.concatenate(tuple, 0)
         f_out_2 = relay.TupleGetItem(f_out, 2)
         relu = relay.nn.relu(f_out_2)
         ret_tuple = relay.Tuple((concat, relu))
         mod["main"] = relay.Function([data], ret_tuple)
         return mod
 
-    seq = tvm.transform.Sequential([
-        transform.AnnotateTarget(target),
-        transform.MergeCompilerRegions(),
-        transform.PartitionGraph(),
-    ])
+    seq = tvm.transform.Sequential(
+        [
+            transform.AnnotateTarget(target),
+            transform.MergeCompilerRegions(),
+            transform.PartitionGraph(),
+        ]
+    )
 
     partitioned = seq(create_graph())
     assert tvm.ir.structural_equal(partitioned, expected(), map_free_vars=True)
 
+
 def test_tuple_output_exec():
     """Test C codegen and runtime for a subgraph with a tuple output"""
-    a = relay.var('a', shape=(10, 10), dtype='float32')
-    b = relay.var('b', shape=(10, 10), dtype='float32')
-    ba = relay.annotation.compiler_begin(a, 'ccompiler')
-    bb = relay.annotation.compiler_begin(b, 'ccompiler')
+    a = relay.var("a", shape=(10, 10), dtype="float32")
+    b = relay.var("b", shape=(10, 10), dtype="float32")
+    ba = relay.annotation.compiler_begin(a, "ccompiler")
+    bb = relay.annotation.compiler_begin(b, "ccompiler")
     add = relay.add(ba, bb)
     sub = relay.subtract(ba, bb)
     out = relay.Tuple((add, sub))
-    eout = relay.annotation.compiler_end(out, 'ccompiler')
-    func=relay.Function([a, b], eout)
+    eout = relay.annotation.compiler_end(out, "ccompiler")
+    func = relay.Function([a, b], eout)
     mod = tvm.IRModule()
     mod["main"] = func
     mod = transform.PartitionGraph()(mod)
 
-    a_data = np.random.rand(10, 10).astype('float32')
-    b_data = np.random.rand(10, 10).astype('float32')
+    a_data = np.random.rand(10, 10).astype("float32")
+    b_data = np.random.rand(10, 10).astype("float32")
+
+    check_result(
+        mod,
+        {"a": a_data, "b": b_data},
+        [(10, 10), (10, 10)],
+        [(a_data + b_data), (a_data - b_data)],
+    )
 
-    check_result(mod, {'a': a_data, 'b': b_data},
-                 [(10, 10), (10, 10)],
-                 [(a_data + b_data), (a_data - b_data)])
 
 def test_extern_opt():
     def Optimize(mod):
@@ -1291,13 +1291,13 @@ def test_extern_opt():
 
     tvm.register_func("relay.ext.test_target.optimize", Optimize)
 
-    x = relay.var('x', shape=(2, 2))
-    y0 = relay.var('y0', shape=(2, 2))
-    y1 = relay.var('y1', shape=(2, 2))
-    yy0 = relay.annotation.compiler_begin(y0, 'test_target')
-    yy1 = relay.annotation.compiler_begin(y1, 'test_target')
+    x = relay.var("x", shape=(2, 2))
+    y0 = relay.var("y0", shape=(2, 2))
+    y1 = relay.var("y1", shape=(2, 2))
+    yy0 = relay.annotation.compiler_begin(y0, "test_target")
+    yy1 = relay.annotation.compiler_begin(y1, "test_target")
     z = yy0 + yy1
-    end = relay.annotation.compiler_end(z, 'test_target')
+    end = relay.annotation.compiler_end(z, "test_target")
     f = relay.Function([x, y0, y1], end * x)
     c = np.ones(shape=(2, 2), dtype="float32")
     f = bind_params_by_name(f, {"y0": tvm.nd.array(c), "y1": tvm.nd.array(c)})
@@ -1313,8 +1313,8 @@ def test_extern_opt():
     assert isinstance(t0.body, relay.Constant)
     expected = np.empty([2, 2])
     expected.fill(2)
-    tvm.testing.assert_allclose(t0.body.data.asnumpy(), expected, rtol=1e-5,
-                                atol=1e-5)
+    tvm.testing.assert_allclose(t0.body.data.asnumpy(), expected, rtol=1e-5, atol=1e-5)
+
 
 if __name__ == "__main__":
     test_multi_node_compiler()
index cf72107..37da3ab 100644 (file)
@@ -24,15 +24,16 @@ from tvm.contrib import graph_runtime
 from tvm.relay import transform, analysis
 from tvm.relay.testing.temp_op_attr import TempOpAttr
 
+
 def alpha_equal(x, y):
     """
     Wrapper around alpha equality which ensures that
     the hash function respects equality.
     """
-    x = x['main']
-    y = y['main']
-    return tvm.ir.structural_equal(x, y) and \
-            tvm.ir.structural_hash(x) == tvm.ir.structural_hash(y)
+    x = x["main"]
+    y = y["main"]
+    return tvm.ir.structural_equal(x, y) and tvm.ir.structural_hash(x) == tvm.ir.structural_hash(y)
+
 
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
@@ -43,39 +44,47 @@ def run_opt_pass(expr, passes):
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
+
 def test_qnn_legalize():
     """Test directly replacing an operator with a new one"""
+
     def before():
-        x = relay.var("x", shape=(1, 64, 56, 56), dtype='int8')
-        y = relay.qnn.op.requantize(x,
-                                    input_scale=relay.const(1, 'float32'),
-                                    input_zero_point=relay.const(0, 'int32'),
-                                    output_scale=relay.const(1, 'float32'),
-                                    output_zero_point=relay.const(0, 'int32'),
-                                    out_dtype='int8')
+        x = relay.var("x", shape=(1, 64, 56, 56), dtype="int8")
+        y = relay.qnn.op.requantize(
+            x,
+            input_scale=relay.const(1, "float32"),
+            input_zero_point=relay.const(0, "int32"),
+            output_scale=relay.const(1, "float32"),
+            output_zero_point=relay.const(0, "int32"),
+            out_dtype="int8",
+        )
         y = relay.Function([x], y)
         return y
 
     def legalize_qnn_requantize(attrs, inputs, types):
         data = inputs[0]
-        data = relay.add(relay.const(0, 'int8'), data)
-        y = relay.qnn.op.requantize(data,
-                                    input_scale=relay.const(1, 'float32'),
-                                    input_zero_point=relay.const(0, 'int32'),
-                                    output_scale=relay.const(1, 'float32'),
-                                    output_zero_point=relay.const(0, 'int32'),
-                                    out_dtype='int8')
+        data = relay.add(relay.const(0, "int8"), data)
+        y = relay.qnn.op.requantize(
+            data,
+            input_scale=relay.const(1, "float32"),
+            input_zero_point=relay.const(0, "int32"),
+            output_scale=relay.const(1, "float32"),
+            output_zero_point=relay.const(0, "int32"),
+            out_dtype="int8",
+        )
         return y
 
     def expected():
-        x = relay.var("x", shape=(1, 64, 56, 56), dtype='int8')
-        y = relay.add(relay.const(0, 'int8'), x)
-        z = relay.qnn.op.requantize(y,
-                                    input_scale=relay.const(1, 'float32'),
-                                    input_zero_point=relay.const(0, 'int32'),
-                                    output_scale=relay.const(1, 'float32'),
-                                    output_zero_point=relay.const(0, 'int32'),
-                                    out_dtype='int8')
+        x = relay.var("x", shape=(1, 64, 56, 56), dtype="int8")
+        y = relay.add(relay.const(0, "int8"), x)
+        z = relay.qnn.op.requantize(
+            y,
+            input_scale=relay.const(1, "float32"),
+            input_zero_point=relay.const(0, "int32"),
+            output_scale=relay.const(1, "float32"),
+            output_zero_point=relay.const(0, "int32"),
+            out_dtype="int8",
+        )
         z = relay.Function([x], z)
         return z
 
@@ -98,42 +107,44 @@ def test_qnn_legalize_qnn_conv2d():
     def _get_mod(data_dtype, kernel_dtype):
         data_shape = (1, 64, 256, 256)
         kernel_shape = (128, 64, 3, 3)
-        data = relay.var("data", shape=data_shape,
-                dtype=data_dtype)
-        kernel = relay.var("kernel", shape=kernel_shape,
-                dtype=kernel_dtype)
+        data = relay.var("data", shape=data_shape, dtype=data_dtype)
+        kernel = relay.var("kernel", shape=kernel_shape, dtype=kernel_dtype)
         func = relay.qnn.op.conv2d(
-                data, kernel,
-                input_zero_point=relay.const(1, 'int32'),
-                kernel_zero_point=relay.const(1, 'int32'),
-                input_scale=relay.const(1.0, 'float32'),
-                kernel_scale=relay.const(1.0, 'float32'),
-                kernel_size=(3, 3),
-                channels=kernel_shape[0],
-                strides=(1, 1),
-                dilation=(1, 1),
-                out_dtype='int32',
-                data_layout='NCHW',
-                kernel_layout='OIHW')
+            data,
+            kernel,
+            input_zero_point=relay.const(1, "int32"),
+            kernel_zero_point=relay.const(1, "int32"),
+            input_scale=relay.const(1.0, "float32"),
+            kernel_scale=relay.const(1.0, "float32"),
+            kernel_size=(3, 3),
+            channels=kernel_shape[0],
+            strides=(1, 1),
+            dilation=(1, 1),
+            out_dtype="int32",
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+        )
 
         mod = relay.Function(relay.analysis.free_vars(func), func)
         mod = tvm.IRModule.from_expr(mod)
         return mod
 
     # Check uint8 x uint8 and int8 x int8 transformation
-    for dtype in ('uint8', 'int8'):
+    for dtype in ("uint8", "int8"):
         mod = _get_mod(dtype, dtype)
 
         #############################################################
         # Check transformations for platforms with fast Int8 support.
         #############################################################
         # Check that Intel VNNI gets picked up.
-        with tvm.target.Target('llvm -mcpu=skylake-avx512'):
+        with tvm.target.Target("llvm -mcpu=skylake-avx512"):
             legalized_mod = relay.qnn.transform.Legalize()(mod)
-            assert 'cast' in legalized_mod.astext() and "qnn.conv2d" in legalized_mod.astext()
+            assert "cast" in legalized_mod.astext() and "qnn.conv2d" in legalized_mod.astext()
 
         # Since the dtypes are the same, there should not be any transformation
-        with tvm.target.Target('llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
+        with tvm.target.Target(
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod"
+        ):
             legalized_mod = relay.qnn.transform.Legalize()(mod)
             assert tvm.ir.structural_equal(mod, legalized_mod)
 
@@ -141,86 +152,90 @@ def test_qnn_legalize_qnn_conv2d():
         # Check transformations for platforms without fast Int8 support.
         ################################################################
         # Older Intel versions.
-        with tvm.target.Target('llvm'):
+        with tvm.target.Target("llvm"):
             legalized_mod = relay.qnn.transform.Legalize()(mod)
-            assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+            assert "cast" in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
 
         # Older ARM versions.
-        with tvm.target.Target('llvm -device=arm_cpu -mtriple=aarch64-linux-gnu'):
+        with tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu"):
             legalized_mod = relay.qnn.transform.Legalize()(mod)
-            assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+            assert "cast" in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
 
     # Check uint8 x int8 transformation
-    mod = _get_mod('uint8', 'int8')
+    mod = _get_mod("uint8", "int8")
     #############################################################
     # Check transformations for platforms with fast Int8 support.
     #############################################################
     # Check no transformation for Intel VNNI.
-    with tvm.target.Target('llvm -mcpu=skylake-avx512'):
+    with tvm.target.Target("llvm -mcpu=skylake-avx512"):
         legalized_mod = relay.qnn.transform.Legalize()(mod)
         assert tvm.ir.structural_equal(mod, legalized_mod)
 
     # ARM - so check that transformation has happened.
-    with tvm.target.Target('llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
+    with tvm.target.Target(
+        "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod"
+    ):
         legalized_mod = relay.qnn.transform.Legalize()(mod)
-        assert 'cast' in legalized_mod.astext() and "qnn.conv2d" in legalized_mod.astext()
+        assert "cast" in legalized_mod.astext() and "qnn.conv2d" in legalized_mod.astext()
 
     ################################################################
     # Check transformations for platforms without fast Int8 support.
     ################################################################
     # Older Intel versions.
-    with tvm.target.Target('llvm'):
+    with tvm.target.Target("llvm"):
         legalized_mod = relay.qnn.transform.Legalize()(mod)
-        assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+        assert "cast" in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
 
     # Older ARM versions.
-    with tvm.target.Target('llvm -device=arm_cpu -mtriple=aarch64-linux-gnu'):
+    with tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu"):
         legalized_mod = relay.qnn.transform.Legalize()(mod)
-        assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+        assert "cast" in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
 
     ###########################################
     # Check transformations for CUDA platforms.
     ###########################################
-    with tvm.target.Target('cuda'):
+    with tvm.target.Target("cuda"):
         legalized_mod = relay.qnn.transform.Legalize()(mod)
-        assert 'cast' in legalized_mod.astext() and "qnn" in legalized_mod.astext()
+        assert "cast" in legalized_mod.astext() and "qnn" in legalized_mod.astext()
 
 
 def test_qnn_legalize_qnn_dense():
     def _get_mod(data_dtype, kernel_dtype):
         data_shape = (10, 3)
         kernel_shape = (20, 3)
-        data = relay.var("data", shape=data_shape,
-                dtype=data_dtype)
-        kernel = relay.var("kernel", shape=kernel_shape,
-                dtype=kernel_dtype)
+        data = relay.var("data", shape=data_shape, dtype=data_dtype)
+        kernel = relay.var("kernel", shape=kernel_shape, dtype=kernel_dtype)
         func = relay.qnn.op.dense(
-                data, kernel,
-                input_zero_point=relay.const(1, 'int32'),
-                kernel_zero_point=relay.const(1, 'int32'),
-                input_scale=relay.const(1, 'float32'),
-                kernel_scale=relay.const(1, 'float32'),
-                units=kernel_shape[0],
-                out_dtype='int32')
+            data,
+            kernel,
+            input_zero_point=relay.const(1, "int32"),
+            kernel_zero_point=relay.const(1, "int32"),
+            input_scale=relay.const(1, "float32"),
+            kernel_scale=relay.const(1, "float32"),
+            units=kernel_shape[0],
+            out_dtype="int32",
+        )
 
         mod = relay.Function(relay.analysis.free_vars(func), func)
         mod = tvm.IRModule.from_expr(mod)
         return mod
 
     # Check uint8 x uint8 and int8 x int8 transformation
-    for dtype in ('uint8', 'int8'):
+    for dtype in ("uint8", "int8"):
         mod = _get_mod(dtype, dtype)
 
         #############################################################
         # Check transformations for platforms with fast Int8 support.
         #############################################################
         # Check that Intel VNNI gets picked up.
-        with tvm.target.Target('llvm -mcpu=skylake-avx512'):
+        with tvm.target.Target("llvm -mcpu=skylake-avx512"):
             legalized_mod = relay.qnn.transform.Legalize()(mod)
-            assert 'cast' in legalized_mod.astext() and "qnn.dense" in legalized_mod.astext()
+            assert "cast" in legalized_mod.astext() and "qnn.dense" in legalized_mod.astext()
 
         # Since same dtype, there should not be any transformation
-        with tvm.target.Target('llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
+        with tvm.target.Target(
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod"
+        ):
             legalized_mod = relay.qnn.transform.Legalize()(mod)
             assert tvm.ir.structural_equal(mod, legalized_mod)
 
@@ -228,49 +243,51 @@ def test_qnn_legalize_qnn_dense():
         # Check transformations for platforms without fast Int8 support.
         ################################################################
         # Older Intel versions.
-        with tvm.target.Target('llvm'):
+        with tvm.target.Target("llvm"):
             legalized_mod = relay.qnn.transform.Legalize()(mod)
-            assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+            assert "cast" in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
 
         # Older ARM versions.
-        with tvm.target.Target('llvm -device=arm_cpu -mtriple=aarch64-linux-gnu'):
+        with tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu"):
             legalized_mod = relay.qnn.transform.Legalize()(mod)
-            assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+            assert "cast" in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
 
     # Check uint8 x int8 transformation
-    mod = _get_mod('uint8', 'int8')
+    mod = _get_mod("uint8", "int8")
     #############################################################
     # Check transformations for platforms with fast Int8 support.
     #############################################################
     # Check no transformation for Intel VNNI.
-    with tvm.target.Target('llvm -mcpu=skylake-avx512'):
+    with tvm.target.Target("llvm -mcpu=skylake-avx512"):
         legalized_mod = relay.qnn.transform.Legalize()(mod)
         assert tvm.ir.structural_equal(mod, legalized_mod)
 
     # ARM - so check that transformation has happened.
-    with tvm.target.Target('llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
+    with tvm.target.Target(
+        "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod"
+    ):
         legalized_mod = relay.qnn.transform.Legalize()(mod)
-        assert 'cast' in legalized_mod.astext() and "qnn.dense" in legalized_mod.astext()
+        assert "cast" in legalized_mod.astext() and "qnn.dense" in legalized_mod.astext()
 
     ################################################################
     # Check transformations for platforms without fast Int8 support.
     ################################################################
     # Older Intel versions.
-    with tvm.target.Target('llvm'):
+    with tvm.target.Target("llvm"):
         legalized_mod = relay.qnn.transform.Legalize()(mod)
-        assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+        assert "cast" in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
 
     # Older ARM versions.
-    with tvm.target.Target('llvm -device=arm_cpu -mtriple=aarch64-linux-gnu'):
+    with tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu"):
         legalized_mod = relay.qnn.transform.Legalize()(mod)
-        assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+        assert "cast" in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
 
     ###########################################
     # Check transformations for CUDA platforms.
     ###########################################
-    with tvm.target.Target('cuda'):
+    with tvm.target.Target("cuda"):
         legalized_mod = relay.qnn.transform.Legalize()(mod)
-        assert 'cast' in legalized_mod.astext() and "qnn" in legalized_mod.astext()
+        assert "cast" in legalized_mod.astext() and "qnn" in legalized_mod.astext()
 
 
 if __name__ == "__main__":
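
The hunks above show the two conversions that repeat throughout this commit: string literals are normalized to double quotes, and calls that no longer fit on one line are exploded to one argument per line with a trailing comma. As a hedged illustration only (the exact Black configuration is not visible in these hunks; line_length=100 below is an assumption inferred from where the output wraps), the same rewrite can be reproduced programmatically:

    # Illustrative sketch, not the project's actual tooling: reformat a snippet
    # the way the hunks above were reformatted. line_length=100 is an assumption.
    import black

    src = "data = relay.var('data', shape=data_shape,\n        dtype=data_dtype)\n"
    out = black.format_str(src, mode=black.FileMode(line_length=100))
    print(out)
    # -> data = relay.var("data", shape=data_shape, dtype=data_dtype)
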
index 43b54e9..271dc8e 100644 (file)
@@ -29,7 +29,7 @@ def test_remove_all_prelude_functions():
     mod["main"] = relay.Function([x], x)
     mod = relay.transform.RemoveUnusedFunctions()(mod)
     l = set([x[0].name_hint for x in mod.functions.items()])
-    assert l == set(['main'])
+    assert l == set(["main"])
 
 
 def test_remove_all_prelude_functions_but_referenced_functions():
@@ -37,13 +37,13 @@ def test_remove_all_prelude_functions_but_referenced_functions():
     p = Prelude(mod)
     x = relay.var("x", shape=(1, 16))
     id_func = relay.Function([x], x)
-    id_name = relay.GlobalVar('id_func')
+    id_name = relay.GlobalVar("id_func")
     mod[id_name] = id_func
 
     mod["main"] = relay.Function([x], id_name(x))
     mod = relay.transform.RemoveUnusedFunctions()(mod)
     l = set([x[0].name_hint for x in mod.functions.items()])
-    assert l == set(['id_func', 'main'])
+    assert l == set(["id_func", "main"])
 
 
 def test_keep_only_referenced_prelude_functions():
@@ -56,7 +56,7 @@ def test_keep_only_referenced_prelude_functions():
     mod["main"] = relay.Function([], body)
     mod = relay.transform.RemoveUnusedFunctions()(mod)
     l = set([x[0].name_hint for x in mod.functions.items()])
-    assert l == set(['tl', 'hd', 'main'])
+    assert l == set(["tl", "hd", "main"])
 
 
 def test_multiple_entry_functions():
@@ -70,29 +70,29 @@ def test_multiple_entry_functions():
 
     x = relay.var("x", shape=(1, 16))
     id_func = relay.Function([x], x)
-    id_name = relay.GlobalVar('id_func')
+    id_name = relay.GlobalVar("id_func")
     mod[id_name] = id_func
     mod["main2"] = relay.Function([x], id_name(x))
-    mod = relay.transform.RemoveUnusedFunctions(['main1', 'main2'])(mod)
+    mod = relay.transform.RemoveUnusedFunctions(["main1", "main2"])(mod)
     l = set([x[0].name_hint for x in mod.functions.items()])
-    assert l == set(['tl', 'hd', 'main2', 'id_func', 'main1'])
+    assert l == set(["tl", "hd", "main2", "id_func", "main1"])
 
 
 def test_globalvar_as_call_arg():
     mod = tvm.IRModule()
     p = Prelude(mod)
-    tensor_array = p.get_var('tensor_array', 'int32')
-    tensor1 = p.get_var('tensor1', 'int32')
-    write = p.get_var('tensor_array_write', 'int32')
-    stack = p.get_var('tensor_array_stack', 'int32')
-    v = relay.var('v')
+    tensor_array = p.get_var("tensor_array", "int32")
+    tensor1 = p.get_var("tensor1", "int32")
+    write = p.get_var("tensor_array_write", "int32")
+    stack = p.get_var("tensor_array_stack", "int32")
+    v = relay.var("v")
     init_tensor_array = tensor_array(relay.const(3))
     tensor_array1 = write(init_tensor_array, relay.const(0), tensor1(v))
     tensor_array2 = stack(tensor_array1)
     mod["main"] = relay.Function([v], tensor_array2)
     mod = relay.transform.RemoveUnusedFunctions()(mod)
     l = set([x[0].name_hint for x in mod.functions.items()])
-    assert 'tensor_array_int32' in l
+    assert "tensor_array_int32" in l
 
 
 def test_call_globalvar_without_args():
@@ -100,18 +100,19 @@ def test_call_globalvar_without_args():
         mod = tvm.IRModule({})
         fn1 = relay.Function([], relay.const(1))
         fn2 = relay.Function([], relay.const(2))
-        g1 = relay.GlobalVar('g1')
-        g2 = relay.GlobalVar('g2')
+        g1 = relay.GlobalVar("g1")
+        g2 = relay.GlobalVar("g2")
         mod[g1] = fn1
         mod[g2] = fn2
-        p = relay.var('p', 'bool')
-        mod['main'] = relay.Function([p], relay.Call(relay.If(p, g1, g2), []))
+        p = relay.var("p", "bool")
+        mod["main"] = relay.Function([p], relay.Call(relay.If(p, g1, g2), []))
         return mod
+
     mod = get_mod()
     ref_mod = get_mod()
     mod = relay.transform.RemoveUnusedFunctions()(mod)
     assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     pytest.main()
index e934c11..b57abc6 100644 (file)
@@ -19,6 +19,7 @@ from tvm import relay
 from tvm.relay import transform
 from tvm.relay.testing import run_opt_pass
 
+
 def test_simplify_reshape():
     def before():
         x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
@@ -37,7 +38,7 @@ def test_simplify_reshape():
         return relay.Function([x, w], y)
 
     def symbolic():
-        b = tvm.te.size_var('b')
+        b = tvm.te.size_var("b")
         x = relay.var("x", shape=(b, 16, 16, 16), dtype="float32")
         w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32")
         y = relay.nn.conv2d(x, w, padding=(1, 1))
@@ -56,5 +57,6 @@ def test_simplify_reshape():
     after = run_opt_pass(symbolic(), transform.InferType())
     assert tvm.ir.structural_equal(zz, after)
 
+
 if __name__ == "__main__":
     test_simplify_reshape()
index 3a8c90b..f557e2c 100644 (file)
@@ -18,14 +18,14 @@ from tvm.ir import IRModule, structural_equal
 from tvm import relay as rly
 from tvm.relay.transform import SimplifyInference
 
-def test_simplify_batchnorm(dtype='float32'):
-    def simple_bn(x, gamma, beta, moving_mean, moving_var,
-                  axis=1, epsilon=1e-5, shape=None):
+
+def test_simplify_batchnorm(dtype="float32"):
+    def simple_bn(x, gamma, beta, moving_mean, moving_var, axis=1, epsilon=1e-5, shape=None):
         # expect = (x - moving_mean) / sqrt(moving_var + eps) * gamma + beta
-        scale = rly.multiply(rly.const(1, dtype) /
-                rly.sqrt(moving_var + rly.const(epsilon, dtype)), gamma)
-        shift = rly.add(
-            rly.multiply(rly.negative(moving_mean), scale), beta)
+        scale = rly.multiply(
+            rly.const(1, dtype) / rly.sqrt(moving_var + rly.const(epsilon, dtype)), gamma
+        )
+        shift = rly.add(rly.multiply(rly.negative(moving_mean), scale), beta)
         num_newaxis = len(shape) - (axis + 1)
         if num_newaxis:
             scale = rly.expand_dims(scale, axis=1, num_newaxis=num_newaxis)
@@ -44,12 +44,26 @@ def test_simplify_batchnorm(dtype='float32'):
         y1, y2 = x, x
 
         for _ in range(nstep):
-            y1, _, _ = rly.nn.batch_norm(y1 + rly.const(1, dtype),
-                gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
+            y1, _, _ = rly.nn.batch_norm(
+                y1 + rly.const(1, dtype),
+                gamma,
+                beta,
+                moving_mean,
+                moving_var,
+                epsilon=eps,
+                axis=axis,
+            )
             y1 = rly.nn.dropout(y1)
-            y2 = simple_bn(y2 + rly.const(1, dtype),
-                           gamma, beta, moving_mean, moving_var,
-                           epsilon=eps, axis=axis, shape=ttype1.shape)
+            y2 = simple_bn(
+                y2 + rly.const(1, dtype),
+                gamma,
+                beta,
+                moving_mean,
+                moving_var,
+                epsilon=eps,
+                axis=axis,
+                shape=ttype1.shape,
+            )
 
         mod = IRModule.from_expr(y1)
         simplify = SimplifyInference()
@@ -64,5 +78,5 @@ def test_simplify_batchnorm(dtype='float32'):
 
 
 if __name__ == "__main__":
-    test_simplify_batchnorm(dtype='float32')
-    test_simplify_batchnorm(dtype='float16')
+    test_simplify_batchnorm(dtype="float32")
+    test_simplify_batchnorm(dtype="float16")
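
For context on the algebra the test above relies on: at inference time SimplifyInference folds batch_norm into one multiply and one add, since (x - mean) / sqrt(var + eps) * gamma + beta equals x * scale + shift with scale = gamma / sqrt(var + eps) and shift = beta - mean * scale, which is exactly what simple_bn constructs. A standalone NumPy check of that identity (a sketch with made-up shapes, independent of the test):

    # Numeric sanity check of the batch_norm folding used above (illustrative only).
    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=(4, 8))
    gamma, beta = rng.normal(size=8), rng.normal(size=8)
    mean, var, eps = rng.normal(size=8), rng.random(8) + 0.1, 1e-5

    bn = (x - mean) / np.sqrt(var + eps) * gamma + beta  # reference batch_norm
    scale = gamma / np.sqrt(var + eps)
    shift = beta - mean * scale
    np.testing.assert_allclose(bn, x * scale + shift, rtol=1e-6)
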
index 5a63db7..7f5d6a4 100644 (file)
@@ -30,7 +30,7 @@ def run_opt_pass(expr, passes):
     mod = tvm.IRModule.from_expr(expr)
     seq = tvm.transform.Sequential(passes)
     with tvm.transform.PassContext(opt_level=3):
-       mod = seq(mod)
+        mod = seq(mod)
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
@@ -64,11 +64,11 @@ def test_order():
     val = x + y * z
     check_eval(val, 7.0)
     anf = run_opt_pass(val, [transform.ToANormalForm(), transform.InferType()])
-    a = relay.Var('a', relay.IncompleteType())
-    b = relay.Var('b', relay.IncompleteType())
-    c = relay.Var('c', relay.IncompleteType())
-    d = relay.Var('d', relay.IncompleteType())
-    e = relay.Var('e', relay.IncompleteType())
+    a = relay.Var("a", relay.IncompleteType())
+    b = relay.Var("b", relay.IncompleteType())
+    c = relay.Var("c", relay.IncompleteType())
+    d = relay.Var("d", relay.IncompleteType())
+    e = relay.Var("e", relay.IncompleteType())
     expected_output = e
     expected_output = relay.Let(e, a + d, expected_output)
     expected_output = relay.Let(d, b * c, expected_output)
@@ -83,10 +83,10 @@ def test_if():
     cond = relay.const(True)
     x = relay.If(cond, relay.const(2), relay.const(3))
     anf = run_opt_pass(x, [transform.ToANormalForm(), transform.InferType()])
-    a = relay.Var('a', relay.IncompleteType())
-    b = relay.Var('b', relay.IncompleteType())
-    c = relay.Var('c', relay.IncompleteType())
-    d = relay.Var('d', relay.IncompleteType())
+    a = relay.Var("a", relay.IncompleteType())
+    b = relay.Var("b", relay.IncompleteType())
+    c = relay.Var("c", relay.IncompleteType())
+    d = relay.Var("d", relay.IncompleteType())
     true_branch = relay.Let(a, relay.const(2), a)
     false_branch = relay.Let(b, relay.const(3), b)
     expected_output = relay.If(c, true_branch, false_branch)
@@ -112,27 +112,27 @@ def test_recursion():
        f(5);
     """
     mod = tvm.IRModule()
-    i64 = relay.TensorType((), 'int64')
+    i64 = relay.TensorType((), "int64")
     f = relay.GlobalVar("f")
     n = relay.Var("n", i64)
-    m = n * relay.const(2, 'int64')
-    funcbody = relay.If(relay.equal(n, relay.const(0, 'int64')),
-                        m,
-                        m + f(n - relay.const(1, 'int64')))
+    m = n * relay.const(2, "int64")
+    funcbody = relay.If(
+        relay.equal(n, relay.const(0, "int64")), m, m + f(n - relay.const(1, "int64"))
+    )
     value = relay.Function([n], funcbody, i64, [])
     mod[f] = value
-    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
+    check_eval(f(relay.const(5, "int64")), 30.0, mod=mod)
     old_f = mod[f]
     mod = transform.ToANormalForm()(mod)
     f = mod[f]
-    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
+    check_eval(f(relay.const(5, "int64")), 30.0, mod=mod)
 
 
 def test_ref():
-    i = relay.Var('i')
-    iv = relay.Var('iv')
-    u = relay.Var('u')
-    uv = relay.Var('uv')
+    i = relay.Var("i")
+    iv = relay.Var("iv")
+    u = relay.Var("u")
+    uv = relay.Var("uv")
     body = relay.add(iv, uv)
     body = relay.Let(uv, relay.RefRead(i), body)
     body = relay.Let(u, relay.RefWrite(i, relay.const(2)), body)
@@ -167,7 +167,7 @@ def test_nat_add():
 def test_let():
     x = relay.Var("x")
     y = relay.Var("y")
-    d = relay.const(4.0, 'float32')
+    d = relay.const(4.0, "float32")
     body = relay.Let(y, x, x + y)
     body = relay.Let(x, d, body)
     check_eval(body, 8)
@@ -176,10 +176,10 @@ def test_let():
 
 
 def test_function():
-    t = relay.TensorType((), 'float32')
+    t = relay.TensorType((), "float32")
     x = relay.Var("x", t)
     f = relay.Function([x], x + x)
-    d = relay.const(4.0, 'float32')
+    d = relay.const(4.0, "float32")
     anf_f = run_opt_pass(f, transform.ToANormalForm())
     assert isinstance(anf_f, relay.Function)
     check_eval(f(d), 8)
@@ -189,17 +189,17 @@ def test_function():
 def test_gradient_if():
     x = relay.var("a", shape=(1, 16))
     y = relay.var("y", shape=(1, 16))
-    cond = relay.var("cond", shape=(), dtype='uint1')
+    cond = relay.var("cond", shape=(), dtype="uint1")
     net = relay.If(cond, x, x)
     net = relay.add(x, net)
-    net = relay.Function([cond,x,y], net)
+    net = relay.Function([cond, x, y], net)
     mod = tvm.IRModule.from_expr(net)
     mod = relay.transform.ToANormalForm()(mod)
-    mod["main"] = relay.transform.gradient(mod["main"], mode='higher_order')
+    mod["main"] = relay.transform.gradient(mod["main"], mode="higher_order")
     mod = relay.transform.ToANormalForm()(mod)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_explicit_bound()
     test_order()
     test_if()
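
The reformatted tests above exercise ToANormalForm, which let-binds every intermediate value, so that for an expression like x + y * z the product is named before the sum. A minimal sketch of running the pass directly (illustrative; it mirrors test_order rather than adding anything new):

    # Sketch: apply ToANormalForm to a tiny function and print the let-bound result.
    import tvm
    from tvm import relay

    x, y, z = (relay.var(name, shape=(), dtype="float32") for name in "xyz")
    mod = tvm.IRModule.from_expr(relay.Function([x, y, z], x + y * z))
    anf = relay.transform.ToANormalForm()(mod)
    print(anf["main"])  # the multiply and the add each appear under their own let
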
index 05c6544..dafd1d1 100644 (file)
@@ -32,7 +32,7 @@ def run_opt_pass(expr, passes):
     mod = tvm.IRModule.from_expr(expr)
     seq = tvm.transform.Sequential(passes)
     with tvm.transform.PassContext(opt_level=3):
-       mod = seq(mod)
+        mod = seq(mod)
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
@@ -64,14 +64,15 @@ def test_no_explicit_bind():
     check_eval(bblock(), 8.0)
     check_basic_block_normal_form(bblock)
 
+
 def test_top_level_nested_if():
-    x = relay.var('x', shape=(), dtype='bool')
-    y = relay.var('y', shape=(), dtype='float32')
-    z = relay.var('z', shape=(), dtype='float32')
+    x = relay.var("x", shape=(), dtype="bool")
+    y = relay.var("y", shape=(), dtype="float32")
+    z = relay.var("z", shape=(), dtype="float32")
     cond_t = relay.const(True)
     cond_f = relay.const(False)
-    one = relay.const(1, dtype='float32')
-    three = relay.const(3, dtype='float32')
+    one = relay.const(1, dtype="float32")
+    three = relay.const(3, dtype="float32")
     y2 = relay.add(y, y)
     z2 = relay.add(z, z)
     true_branch = relay.If(cond_t, relay.add(z2, y2), relay.add(three, y2))
@@ -97,16 +98,17 @@ def test_top_level_nested_if():
       }
     }
     """
+
     def expected():
-        x = relay.var('x', shape=(), dtype='bool')
-        y = relay.var('y', shape=(), dtype='float32')
-        z = relay.var('z', shape=(), dtype='float32')
+        x = relay.var("x", shape=(), dtype="bool")
+        y = relay.var("y", shape=(), dtype="float32")
+        z = relay.var("z", shape=(), dtype="float32")
         cond_t = relay.const(True)
         cond_f = relay.const(False)
-        one = relay.const(1, dtype='float32')
-        three = relay.const(3, dtype='float32')
-        y2 = relay.var('y2')
-        z2 = relay.var('z2')
+        one = relay.const(1, dtype="float32")
+        three = relay.const(3, dtype="float32")
+        y2 = relay.var("y2")
+        z2 = relay.var("z2")
         true_branch = relay.If(cond_t, relay.add(z2, y2), relay.add(three, y2))
         true_branch = relay.Let(y2, relay.add(y, y), true_branch)
         false_branch = relay.If(cond_f, z2, one)
@@ -138,14 +140,15 @@ def test_top_level_nested_if():
     expected_output = run_opt_pass(expected(), transform.InferType())
     assert tvm.ir.structural_equal(bblock, expected_output, map_free_vars=True)
 
+
 def test_nested_if():
-    x = relay.var('x', shape=(), dtype='bool')
-    y = relay.var('y', shape=(), dtype='float32')
+    x = relay.var("x", shape=(), dtype="bool")
+    y = relay.var("y", shape=(), dtype="float32")
     cond_t = relay.const(True)
     cond_f = relay.const(False)
-    one = relay.const(1, dtype='float32')
-    two = relay.const(2, dtype='float32')
-    three = relay.const(3, dtype='float32')
+    one = relay.const(1, dtype="float32")
+    two = relay.const(2, dtype="float32")
+    three = relay.const(3, dtype="float32")
     y2 = relay.add(y, y)
     true_branch = relay.If(cond_t, y2, relay.add(three, y2))
     false_branch = relay.If(cond_f, two, one)
@@ -168,15 +171,16 @@ def test_nested_if():
       }
     }
     """
+
     def expected():
-        x = relay.var('x', shape=(), dtype='bool')
-        y = relay.var('y', shape=(), dtype='float32')
+        x = relay.var("x", shape=(), dtype="bool")
+        y = relay.var("y", shape=(), dtype="float32")
         cond_t = relay.const(True)
         cond_f = relay.const(False)
-        one = relay.const(1, dtype='float32')
-        two = relay.const(2, dtype='float32')
-        three = relay.const(3, dtype='float32')
-        y2 = relay.var('y2')
+        one = relay.const(1, dtype="float32")
+        two = relay.const(2, dtype="float32")
+        three = relay.const(3, dtype="float32")
+        y2 = relay.var("y2")
         true_branch = relay.If(cond_t, y2, relay.add(three, y2))
         true_branch = relay.Let(y2, relay.add(y, y), true_branch)
         false_branch = relay.If(cond_f, two, one)
@@ -223,27 +227,28 @@ def test_recursion():
        f(5);
     """
     mod = tvm.IRModule()
-    i64 = relay.TensorType((), 'int64')
+    i64 = relay.TensorType((), "int64")
     f = relay.GlobalVar("f")
     n = relay.Var("n", i64)
-    m = n * relay.const(2, 'int64')
-    cond = relay.equal(n, relay.const(0, 'int64'))
-    false_branch = m + f(n - relay.const(1, 'int64'))
+    m = n * relay.const(2, "int64")
+    cond = relay.equal(n, relay.const(0, "int64"))
+    false_branch = m + f(n - relay.const(1, "int64"))
     funcbody = relay.If(cond, m, false_branch)
     value = relay.Function([n], funcbody, i64, [])
     mod[f] = value
-    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
+    check_eval(f(relay.const(5, "int64")), 30.0, mod=mod)
     old_f = mod[f]
     mod = transform.ToBasicBlockNormalForm()(mod)
     f = mod[f]
-    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
+    check_eval(f(relay.const(5, "int64")), 30.0, mod=mod)
     check_basic_block_normal_form(f)
 
+
 def test_ref():
-    i = relay.Var('i')
-    iv = relay.Var('iv')
-    u = relay.Var('u')
-    uv = relay.Var('uv')
+    i = relay.Var("i")
+    iv = relay.Var("iv")
+    u = relay.Var("u")
+    uv = relay.Var("uv")
     body = relay.add(iv, uv)
     body = relay.Let(uv, relay.RefRead(i), body)
     body = relay.Let(u, relay.RefWrite(i, relay.const(2)), body)
@@ -276,10 +281,11 @@ def test_nat_add():
     assert not Feature.fLet in detect_feature(mod[add])
     check_basic_block_normal_form(opt_expr)
 
+
 def test_let():
     def test_let1():
         x = relay.Var("x")
-        c = relay.const(4.0, 'float32')
+        c = relay.const(4.0, "float32")
         body = relay.Let(x, c, x)
         body = run_opt_pass(body, transform.InferType())
         """
@@ -289,20 +295,20 @@ def test_let():
         opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm())
         assert tvm.ir.structural_equal(body, opt_body)
         check_basic_block_normal_form(opt_body)
-        
+
     def test_let1_1():
         x = relay.Var("y")
-        d = relay.const(4.0, 'float32')
-        body = relay.Let(x, d, relay.add(x,x))
+        d = relay.const(4.0, "float32")
+        body = relay.Let(x, d, relay.add(x, x))
         body = run_opt_pass(body, transform.InferType())
         opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm())
         assert tvm.ir.structural_equal(body, opt_body)
         check_basic_block_normal_form(opt_body)
-    
+
     def test_let2():
         x = relay.Var("x")
         y = relay.Var("y")
-        d = relay.const(4.0, 'float32')
+        d = relay.const(4.0, "float32")
         body = relay.Let(y, x, x)
         body = relay.Let(x, d, body)
         body = run_opt_pass(body, transform.InferType())
@@ -311,7 +317,7 @@ def test_let():
         def expected():
             x = relay.Var("x")
             y = relay.Var("y")
-            d = relay.const(4.0, 'float32')
+            d = relay.const(4.0, "float32")
             body = relay.Let(y, x, y)
             body = relay.Let(x, d, body)
             return body
@@ -325,8 +331,8 @@ def test_let():
         x = relay.Var("x")
         y = relay.Var("y")
         z = relay.Var("z")
-        c = relay.const(3.0, 'float32')
-        d = relay.const(4.0, 'float32')
+        c = relay.const(3.0, "float32")
+        d = relay.const(4.0, "float32")
         body = relay.Let(z, x + y, x + z)
         body = relay.Let(x, d, body)
         body = relay.Let(y, c, body)
@@ -340,31 +346,34 @@ def test_let():
     test_let2()
     test_let3()
 
+
 def test_function():
-    t = relay.TensorType((), 'float32')
+    t = relay.TensorType((), "float32")
     x = relay.Var("x", t)
     f = relay.Function([x], x + x)
-    d = relay.const(4.0, 'float32')
+    d = relay.const(4.0, "float32")
     bblock = run_opt_pass(f, transform.ToBasicBlockNormalForm())
     assert isinstance(bblock, relay.Function)
     check_eval(f(d), 8)
     check_eval(bblock(d), 8)
     check_basic_block_normal_form(bblock)
 
+
 def test_gradient_if():
     x = relay.var("a", shape=(1, 16))
     y = relay.var("y", shape=(1, 16))
-    cond = relay.var("cond", shape=(), dtype='uint1')
+    cond = relay.var("cond", shape=(), dtype="uint1")
     net = relay.If(cond, x, x)
     net = relay.add(x, net)
-    net = relay.Function([cond,x,y], net)
+    net = relay.Function([cond, x, y], net)
     mod = tvm.IRModule.from_expr(net)
     mod = relay.transform.ToBasicBlockNormalForm()(mod)
-    net_grad = relay.transform.gradient(mod["main"], mode='higher_order')
+    net_grad = relay.transform.gradient(mod["main"], mode="higher_order")
     mod["main"] = net_grad
     mod_grad = relay.transform.ToBasicBlockNormalForm()(mod)
-    check_basic_block_normal_form(mod_grad['main'])
-    check_basic_block_normal_form(mod['main'])
+    check_basic_block_normal_form(mod_grad["main"])
+    check_basic_block_normal_form(mod["main"])
+
 
 def test_if():
     def if_expr(x):
@@ -378,8 +387,8 @@ def test_if():
           multiply(%1, 1f)
         }
         """
-        one = relay.const(1, dtype='float32')
-        two = relay.const(2, dtype='float32')
+        one = relay.const(1, dtype="float32")
+        two = relay.const(2, dtype="float32")
         v1 = relay.add(x, one)
         v2 = relay.equal(x, two)
         true_branch = relay.multiply(v1, two)
@@ -398,9 +407,9 @@ def test_if():
           multiply(%v1, 1f /* ty=float32 */) /* ty=float32 */
         }
         """
-        one = relay.const(1, dtype='float32')
-        two = relay.const(2, dtype='float32')
-        v1 = relay.var('v1')
+        one = relay.const(1, dtype="float32")
+        two = relay.const(2, dtype="float32")
+        v1 = relay.var("v1")
         v2 = relay.equal(x, two)
         true_branch = relay.multiply(v1, two)
         false_branch = relay.multiply(v1, one)
@@ -408,7 +417,7 @@ def test_if():
         body = relay.Let(v1, relay.add(x, one), body)
         return body
 
-    x = relay.var('x', shape=(), dtype='float32')
+    x = relay.var("x", shape=(), dtype="float32")
     body = if_expr(x)
     expected_body = expected_if_expr(x)
     bblock = run_opt_pass(body, transform.ToBasicBlockNormalForm())
@@ -423,13 +432,14 @@ def test_if():
     assert tvm.ir.structural_equal(bblock, expected_bblock)
     check_basic_block_normal_form(bblock)
 
+
 def test_higher_order_return():
-    x = relay.var('x', shape=(1,), dtype='float32')#, a)
-    y = relay.var('y', shape=(1,), dtype='float32')#, a)
-    z = relay.var('z', shape=(1,), dtype='float32')#, a)
+    x = relay.var("x", shape=(1,), dtype="float32")  # , a)
+    y = relay.var("y", shape=(1,), dtype="float32")  # , a)
+    z = relay.var("z", shape=(1,), dtype="float32")  # , a)
     x2 = relay.add(x, x)
-    func_a = relay.Function([y], relay.add(x2, y)) #, a, [a])
-    func_b = relay.Function([z], relay.add(x2, z)) #, a, [a])
+    func_a = relay.Function([y], relay.add(x2, y))  # , a, [a])
+    func_b = relay.Function([z], relay.add(x2, z))  # , a, [a])
     body = relay.Tuple([func_a, func_b])
     body = relay.Function([x], body)
     """
@@ -450,13 +460,13 @@ def test_higher_order_return():
 
 
 def test_higher_order_nested():
-    x = relay.var('x', dtype='float32', shape=(1,))
-    s = relay.var('s', dtype='float32', shape=(1,))
+    x = relay.var("x", dtype="float32", shape=(1,))
+    s = relay.var("s", dtype="float32", shape=(1,))
     shared = relay.add(s, s)
     func_true = relay.Function([x], relay.add(x, shared))
-    choice_t = relay.FuncType([], relay.scalar_type('bool'))
-    f = relay.Var('f', choice_t)
-    z = relay.Var('z')
+    choice_t = relay.FuncType([], relay.scalar_type("bool"))
+    f = relay.Var("f", choice_t)
+    z = relay.Var("z")
     body = relay.If(f(), func_true, relay.Function([z], relay.add(z, shared)))
     top = relay.Function([f, s], body)
     """
@@ -478,5 +488,6 @@ def test_higher_order_nested():
     bblock = run_opt_pass(top, transform.ToBasicBlockNormalForm())
     check_basic_block_normal_form(bblock)
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     pytest.main([__file__])
index 6edf185..85e3a62 100644 (file)
@@ -46,7 +46,7 @@ def test_recursion():
     p = Prelude(mod)
     add_nat_definitions(p)
     shape = (10, 10)
-    dtype = 'float32'
+    dtype = "float32"
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     double = relay.Function([x], x + x)
@@ -71,22 +71,29 @@ def test_cps_pe():
         x = run_infer_type(x)
         y = un_cps(x)
         y = run_infer_type(y)
-        x = run_opt_pass(x, tvm.transform.Sequential(
-            [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]))
+        x = run_opt_pass(
+            x,
+            tvm.transform.Sequential(
+                [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]
+            ),
+        )
         assert Feature.fRefCreate not in detect_feature(x)
-    unit = relay.Function([], relay.const(0., dtype='float32'))
+
+    unit = relay.Function([], relay.const(0.0, dtype="float32"))
     f_ref = relay.Var("f_ref")
 
-    one = relay.const(1., dtype='float32')
-    two = relay.const(2., dtype='float32')
-    cond = relay.var(shape=(), dtype='uint1', name_hint='cond')
+    one = relay.const(1.0, dtype="float32")
+    two = relay.const(2.0, dtype="float32")
+    cond = relay.var(shape=(), dtype="uint1", name_hint="cond")
     true_branch = relay.RefWrite(f_ref, relay.Function([], one))
     false_branch = relay.RefWrite(f_ref, relay.Function([], two))
     if_expr = relay.If(cond, true_branch, false_branch)
 
-    stmt = relay.Let(f_ref, relay.RefCreate(unit),
-                     relay.Let(relay.Var("x"), if_expr,
-                               relay.Call(relay.RefRead(f_ref), [])))
+    stmt = relay.Let(
+        f_ref,
+        relay.RefCreate(unit),
+        relay.Let(relay.Var("x"), if_expr, relay.Call(relay.RefRead(f_ref), [])),
+    )
 
     F = relay.Function([cond], stmt)
     destroy_ref(F)
@@ -99,15 +106,15 @@ def test_cps_pe():
     x = relay.var("x", shape=(1, 16))
     y = relay.var("y", shape=(1, 16))
     z = relay.var("z", shape=(1, 16))
-    cond = relay.var("cond", shape=(), dtype='uint1')
+    cond = relay.var("cond", shape=(), dtype="uint1")
     H = relay.If(cond, x, y)
     H = relay.add(H, z)
-    H = relay.Function([cond,x,y,z], H)
+    H = relay.Function([cond, x, y, z], H)
     H = run_infer_type(H)
     H = relay.transform.gradient(H)
     destroy_ref(H)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_recursion()
     test_cps_pe()
index 9488622..88d6829 100644 (file)
@@ -41,9 +41,9 @@ def check_eval(expr, args, expected_result, mod=None, rtol=1e-07):
 
 
 def test_implicit_share():
-    x = relay.Var('x')
-    y = relay.Var('y')
-    z = relay.Var('z')
+    x = relay.Var("x")
+    y = relay.Var("y")
+    z = relay.Var("z")
     body = relay.Let(z, op.add(y, y), op.add(z, z))
     body = relay.Let(y, op.add(x, x), body)
     f = relay.Function([], relay.Let(x, relay.const(1), body))
@@ -55,9 +55,9 @@ def test_implicit_share():
 
 
 def test_round_trip():
-    x = relay.Var('x')
-    y = relay.Var('y')
-    z = relay.Var('z')
+    x = relay.Var("x")
+    y = relay.Var("y")
+    z = relay.Var("z")
     body = relay.Let(z, op.add(y, y), op.add(z, z))
     body = relay.Let(y, op.add(x, x), body)
     f = relay.Function([], relay.Let(x, relay.const(1), body))
@@ -69,6 +69,7 @@ def test_round_trip():
     check_eval(g, [], 8.0)
     check_eval(h, [], 8.0)
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     test_implicit_share()
     test_round_trip()
index 07193e1..a410347 100644 (file)
@@ -22,9 +22,10 @@ from tvm.relay.prelude import Prelude
 from tvm.relay.analysis import unmatched_cases
 import pytest
 
+
 def test_empty_match_block():
     # empty match block will not match anything, so it should return a wildcard pattern
-    v = relay.Var('v')
+    v = relay.Var("v")
     match = relay.Match(v, [])
 
     unmatched = unmatched_cases(match)
@@ -34,46 +35,50 @@ def test_empty_match_block():
 
 def test_trivial_matches():
     # a match clause with a wildcard will match anything
-    v = relay.Var('v')
-    match = relay.Match(v, [
-        relay.Clause(relay.PatternWildcard(), v)
-    ])
+    v = relay.Var("v")
+    match = relay.Match(v, [relay.Clause(relay.PatternWildcard(), v)])
     assert len(unmatched_cases(match)) == 0
 
     # same with a pattern var
-    w = relay.Var('w')
-    match = relay.Match(v, [
-        relay.Clause(relay.PatternVar(w), w)
-    ])
+    w = relay.Var("w")
+    match = relay.Match(v, [relay.Clause(relay.PatternVar(w), w)])
     assert len(unmatched_cases(match)) == 0
 
 
 def test_single_constructor_adt():
     mod = tvm.IRModule()
-    box = relay.GlobalTypeVar('box')
-    a = relay.TypeVar('a')
-    box_ctor = relay.Constructor('box', [a], box)
+    box = relay.GlobalTypeVar("box")
+    a = relay.TypeVar("a")
+    box_ctor = relay.Constructor("box", [a], box)
     box_data = relay.TypeData(box, [a], [box_ctor])
     mod[box] = box_data
 
-    v = relay.Var('v')
-    match = relay.Match(v, [
-        relay.Clause(relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), v)
-    ])
+    v = relay.Var("v")
+    match = relay.Match(
+        v, [relay.Clause(relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), v)]
+    )
 
     # with one constructor, having one pattern constructor case is exhaustive
     assert len(unmatched_cases(match, mod)) == 0
 
     # this will be so if we nest the constructors too
-    nested_pattern = relay.Match(v, [
-        relay.Clause(
-            relay.PatternConstructor(
-                box_ctor,
-                [relay.PatternConstructor(box_ctor,
-                                          [relay.PatternConstructor(
-                                              box_ctor,
-                                              [relay.PatternWildcard()])])]), v)
-    ])
+    nested_pattern = relay.Match(
+        v,
+        [
+            relay.Clause(
+                relay.PatternConstructor(
+                    box_ctor,
+                    [
+                        relay.PatternConstructor(
+                            box_ctor,
+                            [relay.PatternConstructor(box_ctor, [relay.PatternWildcard()])],
+                        )
+                    ],
+                ),
+                v,
+            )
+        ],
+    )
     assert len(unmatched_cases(nested_pattern, mod)) == 0
 
 
@@ -81,14 +86,24 @@ def test_too_specific_match():
     mod = tvm.IRModule()
     p = Prelude(mod)
 
-    v = relay.Var('v')
-    match = relay.Match(v, [
-        relay.Clause(
-            relay.PatternConstructor(
-                p.cons, [relay.PatternWildcard(),
-                         relay.PatternConstructor(p.cons, [relay.PatternWildcard(),
-                                                           relay.PatternWildcard()])]), v)
-    ])
+    v = relay.Var("v")
+    match = relay.Match(
+        v,
+        [
+            relay.Clause(
+                relay.PatternConstructor(
+                    p.cons,
+                    [
+                        relay.PatternWildcard(),
+                        relay.PatternConstructor(
+                            p.cons, [relay.PatternWildcard(), relay.PatternWildcard()]
+                        ),
+                    ],
+                ),
+                v,
+            )
+        ],
+    )
 
     unmatched = unmatched_cases(match, mod)
 
@@ -107,14 +122,24 @@ def test_too_specific_match():
     assert nil_found and single_length_found
 
     # if we add a wildcard, this should work
-    new_match = relay.Match(v, [
-        relay.Clause(
-            relay.PatternConstructor(
-                p.cons, [relay.PatternWildcard(),
-                         relay.PatternConstructor(p.cons, [relay.PatternWildcard(),
-                                                           relay.PatternWildcard()])]), v),
-        relay.Clause(relay.PatternWildcard(), v)
-    ])
+    new_match = relay.Match(
+        v,
+        [
+            relay.Clause(
+                relay.PatternConstructor(
+                    p.cons,
+                    [
+                        relay.PatternWildcard(),
+                        relay.PatternConstructor(
+                            p.cons, [relay.PatternWildcard(), relay.PatternWildcard()]
+                        ),
+                    ],
+                ),
+                v,
+            ),
+            relay.Clause(relay.PatternWildcard(), v),
+        ],
+    )
     assert len(unmatched_cases(new_match, mod)) == 0
 
 
@@ -122,29 +147,47 @@ def test_multiple_constructor_clauses():
     mod = tvm.IRModule()
     p = Prelude(mod)
 
-    v = relay.Var('v')
-    match = relay.Match(v, [
-        # list of length exactly 1
-        relay.Clause(
-            relay.PatternConstructor(p.cons, [relay.PatternWildcard(),
-                                              relay.PatternConstructor(p.nil, [])]), v),
-        # list of length exactly 2
-        relay.Clause(
-            relay.PatternConstructor(
-                p.cons, [relay.PatternWildcard(),
-                         relay.PatternConstructor(p.cons, [relay.PatternWildcard(),
-                                                           relay.PatternConstructor(p.nil, [])
-                         ])]), v),
-        # empty list
-        relay.Clause(
-            relay.PatternConstructor(p.nil, []), v),
-        # list of length 2 or more
-        relay.Clause(
-            relay.PatternConstructor(
-                p.cons, [relay.PatternWildcard(),
-                         relay.PatternConstructor(p.cons, [relay.PatternWildcard(),
-                                                           relay.PatternWildcard()])]), v)
-    ])
+    v = relay.Var("v")
+    match = relay.Match(
+        v,
+        [
+            # list of length exactly 1
+            relay.Clause(
+                relay.PatternConstructor(
+                    p.cons, [relay.PatternWildcard(), relay.PatternConstructor(p.nil, [])]
+                ),
+                v,
+            ),
+            # list of length exactly 2
+            relay.Clause(
+                relay.PatternConstructor(
+                    p.cons,
+                    [
+                        relay.PatternWildcard(),
+                        relay.PatternConstructor(
+                            p.cons, [relay.PatternWildcard(), relay.PatternConstructor(p.nil, [])]
+                        ),
+                    ],
+                ),
+                v,
+            ),
+            # empty list
+            relay.Clause(relay.PatternConstructor(p.nil, []), v),
+            # list of length 2 or more
+            relay.Clause(
+                relay.PatternConstructor(
+                    p.cons,
+                    [
+                        relay.PatternWildcard(),
+                        relay.PatternConstructor(
+                            p.cons, [relay.PatternWildcard(), relay.PatternWildcard()]
+                        ),
+                    ],
+                ),
+                v,
+            ),
+        ],
+    )
     assert len(unmatched_cases(match, mod)) == 0
 
 
@@ -152,28 +195,40 @@ def test_missing_in_the_middle():
     mod = tvm.IRModule()
     p = Prelude(mod)
 
-    v = relay.Var('v')
-    match = relay.Match(v, [
-        # list of length exactly 1
-        relay.Clause(
-            relay.PatternConstructor(p.cons, [relay.PatternWildcard(),
-                                              relay.PatternConstructor(p.nil, [])]), v),
-        # empty list
-        relay.Clause(
-            relay.PatternConstructor(p.nil, []), v),
-        # list of length 3 or more
-        relay.Clause(
-            relay.PatternConstructor(
-                p.cons, [relay.PatternWildcard(),
-                         relay.PatternConstructor(
-                             p.cons,
-                             [relay.PatternWildcard(),
-                              relay.PatternConstructor(
-                                  p.cons,
-                                  [relay.PatternWildcard(),
-                                   relay.PatternWildcard()])])]),
-            v)
-    ])
+    v = relay.Var("v")
+    match = relay.Match(
+        v,
+        [
+            # list of length exactly 1
+            relay.Clause(
+                relay.PatternConstructor(
+                    p.cons, [relay.PatternWildcard(), relay.PatternConstructor(p.nil, [])]
+                ),
+                v,
+            ),
+            # empty list
+            relay.Clause(relay.PatternConstructor(p.nil, []), v),
+            # list of length 3 or more
+            relay.Clause(
+                relay.PatternConstructor(
+                    p.cons,
+                    [
+                        relay.PatternWildcard(),
+                        relay.PatternConstructor(
+                            p.cons,
+                            [
+                                relay.PatternWildcard(),
+                                relay.PatternConstructor(
+                                    p.cons, [relay.PatternWildcard(), relay.PatternWildcard()]
+                                ),
+                            ],
+                        ),
+                    ],
+                ),
+                v,
+            ),
+        ],
+    )
 
     # fails to match a list of length exactly two
     unmatched = unmatched_cases(match, mod)
@@ -188,22 +243,31 @@ def test_missing_in_the_middle():
 
 def test_mixed_adt_constructors():
     mod = tvm.IRModule()
-    box = relay.GlobalTypeVar('box')
-    a = relay.TypeVar('a')
-    box_ctor = relay.Constructor('box', [a], box)
+    box = relay.GlobalTypeVar("box")
+    a = relay.TypeVar("a")
+    box_ctor = relay.Constructor("box", [a], box)
     box_data = relay.TypeData(box, [a], [box_ctor])
     mod[box] = box_data
 
     p = Prelude(mod)
 
-    v = relay.Var('v')
-    box_of_lists_inc = relay.Match(v, [
-        relay.Clause(
-            relay.PatternConstructor(
-                box_ctor,
-                [relay.PatternConstructor(p.cons, [
-                    relay.PatternWildcard(), relay.PatternWildcard()])]), v)
-    ])
+    v = relay.Var("v")
+    box_of_lists_inc = relay.Match(
+        v,
+        [
+            relay.Clause(
+                relay.PatternConstructor(
+                    box_ctor,
+                    [
+                        relay.PatternConstructor(
+                            p.cons, [relay.PatternWildcard(), relay.PatternWildcard()]
+                        )
+                    ],
+                ),
+                v,
+            )
+        ],
+    )
 
     # will fail to match a box containing an empty list
     unmatched = unmatched_cases(box_of_lists_inc, mod)
@@ -212,23 +276,42 @@ def test_mixed_adt_constructors():
     assert unmatched[0].constructor == box_ctor
     assert len(unmatched[0].patterns) == 1 and unmatched[0].patterns[0].constructor == p.nil
 
-    box_of_lists_comp = relay.Match(v, [
-        relay.Clause(
-            relay.PatternConstructor(
-                box_ctor, [relay.PatternConstructor(p.nil, [])]), v),
-        relay.Clause(
-            relay.PatternConstructor(
-                box_ctor, [relay.PatternConstructor(p.cons, [
-                    relay.PatternWildcard(), relay.PatternWildcard()])]), v)
-    ])
+    box_of_lists_comp = relay.Match(
+        v,
+        [
+            relay.Clause(
+                relay.PatternConstructor(box_ctor, [relay.PatternConstructor(p.nil, [])]), v
+            ),
+            relay.Clause(
+                relay.PatternConstructor(
+                    box_ctor,
+                    [
+                        relay.PatternConstructor(
+                            p.cons, [relay.PatternWildcard(), relay.PatternWildcard()]
+                        )
+                    ],
+                ),
+                v,
+            ),
+        ],
+    )
     assert len(unmatched_cases(box_of_lists_comp, mod)) == 0
 
-    list_of_boxes_inc = relay.Match(v, [
-        relay.Clause(
-            relay.PatternConstructor(
-                p.cons, [relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]),
-                         relay.PatternWildcard()]), v)
-    ])
+    list_of_boxes_inc = relay.Match(
+        v,
+        [
+            relay.Clause(
+                relay.PatternConstructor(
+                    p.cons,
+                    [
+                        relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]),
+                        relay.PatternWildcard(),
+                    ],
+                ),
+                v,
+            )
+        ],
+    )
 
     # fails to match empty list of boxes
     unmatched = unmatched_cases(list_of_boxes_inc, mod)
@@ -236,36 +319,73 @@ def test_mixed_adt_constructors():
     assert isinstance(unmatched[0], relay.PatternConstructor)
     assert unmatched[0].constructor == p.nil
 
-    list_of_boxes_comp = relay.Match(v, [
-        # exactly one box
-        relay.Clause(
-            relay.PatternConstructor(
-                p.cons, [relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]),
-                         relay.PatternConstructor(p.nil, [])]), v),
-        # exactly two boxes
-        relay.Clause(
-            relay.PatternConstructor(
-                p.cons, [relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]),
-                         relay.PatternConstructor(p.cons, [
-                             relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]),
-                             relay.PatternConstructor(p.nil, [])
-                         ])]), v),
-        # exactly three boxes
-        relay.Clause(
-            relay.PatternConstructor(
-                p.cons, [relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]),
-                         relay.PatternConstructor(p.cons, [
-                             relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]),
-                             relay.PatternConstructor(p.cons, [
-                                 relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]),
-                                 relay.PatternConstructor(p.nil, [])
-                             ])])]), v),
-        # one or more boxes
-        relay.Clause(relay.PatternConstructor(p.cons, [relay.PatternWildcard(),
-                                                       relay.PatternWildcard()]), v),
-        # no boxes
-        relay.Clause(relay.PatternConstructor(p.nil, []), v)
-    ])
+    list_of_boxes_comp = relay.Match(
+        v,
+        [
+            # exactly one box
+            relay.Clause(
+                relay.PatternConstructor(
+                    p.cons,
+                    [
+                        relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]),
+                        relay.PatternConstructor(p.nil, []),
+                    ],
+                ),
+                v,
+            ),
+            # exactly two boxes
+            relay.Clause(
+                relay.PatternConstructor(
+                    p.cons,
+                    [
+                        relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]),
+                        relay.PatternConstructor(
+                            p.cons,
+                            [
+                                relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]),
+                                relay.PatternConstructor(p.nil, []),
+                            ],
+                        ),
+                    ],
+                ),
+                v,
+            ),
+            # exactly three boxes
+            relay.Clause(
+                relay.PatternConstructor(
+                    p.cons,
+                    [
+                        relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]),
+                        relay.PatternConstructor(
+                            p.cons,
+                            [
+                                relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]),
+                                relay.PatternConstructor(
+                                    p.cons,
+                                    [
+                                        relay.PatternConstructor(
+                                            box_ctor, [relay.PatternWildcard()]
+                                        ),
+                                        relay.PatternConstructor(p.nil, []),
+                                    ],
+                                ),
+                            ],
+                        ),
+                    ],
+                ),
+                v,
+            ),
+            # one or more boxes
+            relay.Clause(
+                relay.PatternConstructor(
+                    p.cons, [relay.PatternWildcard(), relay.PatternWildcard()]
+                ),
+                v,
+            ),
+            # no boxes
+            relay.Clause(relay.PatternConstructor(p.nil, []), v),
+        ],
+    )
     assert len(unmatched_cases(list_of_boxes_comp, mod)) == 0
 
 
@@ -297,5 +417,6 @@ def @shallow_opt[A](%a: Arith[A]) -> Arith[A] {
     tvm.parser.fromtext(code)
     # fromtext parses the module, then checks it (which includes strictness checking).
 
+
 if __name__ == "__main__":
     pytest.main([__file__])
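
The exhaustiveness tests above are built around relay.analysis.unmatched_cases, which returns the constructor patterns a Match expression fails to cover. A minimal illustrative use against the same Prelude list type (a sketch, not part of the test file):

    # Sketch: a match that only handles cons; unmatched_cases reports the missing nil case.
    import tvm
    from tvm import relay
    from tvm.relay.prelude import Prelude
    from tvm.relay.analysis import unmatched_cases

    mod = tvm.IRModule()
    p = Prelude(mod)
    v = relay.Var("v")
    cons_pat = relay.PatternConstructor(
        p.cons, [relay.PatternWildcard(), relay.PatternWildcard()]
    )
    match = relay.Match(v, [relay.Clause(cons_pat, v)])
    print(unmatched_cases(match, mod))  # expect a single pattern built from p.nil
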
index 1aad74b..a5a0e50 100644 (file)
 import tvm
 from tvm import te
 from tvm import relay
-from tvm.relay.analysis import (free_vars, free_type_vars,
-                                bound_vars, bound_type_vars,
-                                all_vars, all_type_vars)
+from tvm.relay.analysis import (
+    free_vars,
+    free_type_vars,
+    bound_vars,
+    bound_type_vars,
+    all_vars,
+    all_type_vars,
+)
+
 
 def assert_vars_match(actual, expected):
     assert len(actual) == len(expected)
@@ -43,7 +49,7 @@ def test_free_vars():
 
 
 def test_free_vars_tuple():
-    t = relay.Var('t')
+    t = relay.Var("t")
     fv = free_vars(relay.Tuple([t, t]))
     assert len(fv) == 1
     assert fv[0] == t
@@ -86,26 +92,30 @@ def test_match_vars():
     mod = tvm.IRModule()
     p = relay.prelude.Prelude(mod)
 
-    x = relay.Var('x')
-    y = relay.Var('y')
-    z = relay.Var('z')
-
-    match1 = relay.Match(p.nil(), [
-        relay.Clause(relay.PatternConstructor(p.nil), z),
-        relay.Clause(relay.PatternConstructor(p.cons,
-                                              [relay.PatternVar(x),
-                                               relay.PatternVar(y)]),
-                     p.cons(x, y))
-    ])
-
-    match2 = relay.Match(p.nil(), [
-        relay.Clause(relay.PatternConstructor(p.cons, [
-            relay.PatternWildcard(),
-            relay.PatternVar(x)
-        ]),
-                     y),
-        relay.Clause(relay.PatternWildcard(), z)
-    ])
+    x = relay.Var("x")
+    y = relay.Var("y")
+    z = relay.Var("z")
+
+    match1 = relay.Match(
+        p.nil(),
+        [
+            relay.Clause(relay.PatternConstructor(p.nil), z),
+            relay.Clause(
+                relay.PatternConstructor(p.cons, [relay.PatternVar(x), relay.PatternVar(y)]),
+                p.cons(x, y),
+            ),
+        ],
+    )
+
+    match2 = relay.Match(
+        p.nil(),
+        [
+            relay.Clause(
+                relay.PatternConstructor(p.cons, [relay.PatternWildcard(), relay.PatternVar(x)]), y
+            ),
+            relay.Clause(relay.PatternWildcard(), z),
+        ],
+    )
 
     assert_vars_match(bound_vars(match1), [x, y])
     assert_vars_match(free_vars(match1), [z])
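
The import hunk at the top of this file shows the other recurring rewrites: an over-long "from ... import" is wrapped in parentheses with one name per line plus a trailing comma, and single-quoted strings become double-quoted. The same normalization can be reproduced through Black's Python API (a minimal sketch; it assumes the black package is installed and uses Black's default settings, not necessarily the line length configured for this conversion):

# Minimal sketch: run one over-long import line through black.format_str and
# print the result. With the default 88-character limit the import is wrapped
# in parentheses, one name per line, with a trailing comma after the last one.
import black

SRC = (
    "from tvm.relay.analysis import (free_vars, free_type_vars, "
    "bound_vars, bound_type_vars, all_vars, all_type_vars)\n"
)

print(black.format_str(SRC, mode=black.FileMode()))
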
index f6b1b24..4eefa71 100644
@@ -28,15 +28,15 @@ from tvm.relay.backend.interpreter import RefValue, ConstructorValue
 def seq(*exprs):
     ret = exprs[0]
     for expr in exprs[1:]:
-        ret = relay.Let(relay.var('_'), ret, expr)
+        ret = relay.Let(relay.var("_"), ret, expr)
     return ret
 
 
 # creates a dummy ADT for testing
 def init_box_adt(mod):
-    box = relay.GlobalTypeVar('box')
-    a = relay.TypeVar('a')
-    box_ctor = relay.Constructor('box', [a], box)
+    box = relay.GlobalTypeVar("box")
+    a = relay.TypeVar("a")
+    box_ctor = relay.Constructor("box", [a], box)
     mod[box] = relay.TypeData(box, [a], [box_ctor])
     return (box, box_ctor)
 
@@ -81,13 +81,9 @@ def test_create_tensor():
 
 
 def test_create_nested_tuple():
-    relay_tup = relay.Tuple([
-        relay.const(1), relay.const(2),
-        relay.Tuple([
-            relay.const(3),
-            relay.const(4)
-        ])
-    ])
+    relay_tup = relay.Tuple(
+        [relay.const(1), relay.const(2), relay.Tuple([relay.const(3), relay.const(4)])]
+    )
     tup_val = run_as_python(relay_tup)
     assert_adt_len(tup_val, 3)
     for i in range(2):
@@ -98,13 +94,9 @@ def test_create_nested_tuple():
 
 
 def test_tuple_get_item():
-    relay_tup = relay.Tuple([
-        relay.const(1), relay.const(2),
-        relay.Tuple([
-            relay.const(3),
-            relay.const(4)
-        ])
-    ])
+    relay_tup = relay.Tuple(
+        [relay.const(1), relay.const(2), relay.Tuple([relay.const(3), relay.const(4)])]
+    )
     for i in range(2):
         index = relay.TupleGetItem(relay_tup, i)
         val = run_as_python(index)
@@ -117,7 +109,7 @@ def test_tuple_get_item():
 
 
 def test_create_let():
-    v = relay.Var('v')
+    v = relay.Var("v")
     let = relay.Let(v, relay.Tuple([]), relay.Tuple([v, v]))
     tup_val = run_as_python(let)
     assert_adt_len(tup_val, 2)
@@ -133,7 +125,7 @@ def test_create_ref():
 
 
 def test_ref_read():
-    v = relay.Var('v')
+    v = relay.Var("v")
     assign = relay.Let(v, relay.RefCreate(relay.Tuple([])), relay.RefRead(v))
     read_val = run_as_python(assign)
     assert_adt_len(read_val, 0)
@@ -141,21 +133,30 @@ def test_ref_read():
 
 def test_ref_write():
     # check that the result of a ref write is an empty tuple
-    v = relay.Var('v')
-    initial_write = relay.Let(v, relay.RefCreate(relay.Tuple([relay.const(1)])),
-                              relay.RefWrite(v, relay.Tuple([relay.const(2)])))
+    v = relay.Var("v")
+    initial_write = relay.Let(
+        v,
+        relay.RefCreate(relay.Tuple([relay.const(1)])),
+        relay.RefWrite(v, relay.Tuple([relay.const(2)])),
+    )
     write_val = run_as_python(initial_write)
     assert_adt_len(write_val, 0)
 
     # now ensure that the value, once written, can be read back
     # (we read the value before and after mutation)
-    w = relay.Var('w')
+    w = relay.Var("w")
     read_after_write = relay.Let(
-        v, relay.RefCreate(relay.Tuple([relay.const(1)])),
+        v,
+        relay.RefCreate(relay.Tuple([relay.const(1)])),
         relay.Let(
-            w, relay.RefCreate(relay.RefRead(v)),
-            seq(relay.RefWrite(v, relay.Tuple([relay.const(2)])),
-                relay.Tuple([relay.RefRead(w), relay.RefRead(v)]))))
+            w,
+            relay.RefCreate(relay.RefRead(v)),
+            seq(
+                relay.RefWrite(v, relay.Tuple([relay.const(2)])),
+                relay.Tuple([relay.RefRead(w), relay.RefRead(v)]),
+            ),
+        ),
+    )
     read_val = run_as_python(read_after_write)
     assert_adt_len(read_val, 2)
     assert_adt_len(read_val[0], 1)
@@ -169,14 +170,16 @@ def test_if():
     true_cond = relay.const(True)
     false_cond = relay.const(False)
 
-    v  = relay.Var('v')
+    v = relay.Var("v")
     true_branch = seq(relay.RefWrite(v, relay.const(1)), relay.RefRead(v))
     false_branch = seq(relay.RefWrite(v, relay.const(2)), relay.RefRead(v))
 
-    true_expr = relay.Let(v, relay.RefCreate(relay.const(0)),
-                          relay.If(true_cond, true_branch, false_branch))
-    false_expr = relay.Let(v, relay.RefCreate(relay.const(0)),
-                           relay.If(false_cond, true_branch, false_branch))
+    true_expr = relay.Let(
+        v, relay.RefCreate(relay.const(0)), relay.If(true_cond, true_branch, false_branch)
+    )
+    false_expr = relay.Let(
+        v, relay.RefCreate(relay.const(0)), relay.If(false_cond, true_branch, false_branch)
+    )
 
     true_val = run_as_python(true_expr)
     assert_tensor_value(true_val, 1)
@@ -186,9 +189,9 @@ def test_if():
 
 
 def test_local_function():
-    v = relay.Var('v')
+    v = relay.Var("v")
     ident = relay.Function([v], v)
-    f = relay.Var('f')
+    f = relay.Var("f")
     call1 = relay.Let(f, ident, f(relay.Tuple([])))
     call2 = relay.Let(f, ident, f(relay.const(2)))
 
@@ -201,9 +204,9 @@ def test_local_function():
 
 def test_global_function():
     mod = tvm.IRModule()
-    ident = relay.GlobalVar('ident')
-    a = relay.TypeVar('a')
-    v = relay.Var('v', a)
+    ident = relay.GlobalVar("ident")
+    a = relay.TypeVar("a")
+    v = relay.Var("v", a)
     mod[ident] = relay.Function([v], v, a, [a])
 
     call1 = ident(relay.const(1))
@@ -238,12 +241,12 @@ def test_constructor():
 def test_match_wildcard():
     mod = tvm.IRModule()
     box, box_ctor = init_box_adt(mod)
-    v = relay.Var('v')
+    v = relay.Var("v")
     match = relay.Let(
-        v, box_ctor(relay.Tuple([])),
-        relay.Match(v, [
-            relay.Clause(relay.PatternWildcard(), relay.const(1))
-        ]))
+        v,
+        box_ctor(relay.Tuple([])),
+        relay.Match(v, [relay.Clause(relay.PatternWildcard(), relay.const(1))]),
+    )
 
     match_val = run_as_python(match, mod)
     assert_tensor_value(match_val, 1)
@@ -252,13 +255,11 @@ def test_match_wildcard():
 def test_match_var():
     mod = tvm.IRModule()
     box, box_ctor = init_box_adt(mod)
-    v = relay.Var('v')
-    w = relay.Var('w')
+    v = relay.Var("v")
+    w = relay.Var("w")
     match = relay.Let(
-        v, box_ctor(relay.const(1)),
-        relay.Match(v, [
-            relay.Clause(relay.PatternVar(w), w)
-        ]))
+        v, box_ctor(relay.const(1)), relay.Match(v, [relay.Clause(relay.PatternVar(w), w)])
+    )
 
     match_val = run_as_python(match, mod)
     assert_constructor_value(match_val, box_ctor, 1)
@@ -268,13 +269,15 @@ def test_match_var():
 def test_match_pattern():
     mod = tvm.IRModule()
     box, box_ctor = init_box_adt(mod)
-    v = relay.Var('v')
-    w = relay.Var('w')
+    v = relay.Var("v")
+    w = relay.Var("w")
     match = relay.Let(
-        v, box_ctor(relay.const(1)),
-        relay.Match(v, [
-            relay.Clause(relay.PatternConstructor(box_ctor, [relay.PatternVar(w)]), w)
-        ]))
+        v,
+        box_ctor(relay.const(1)),
+        relay.Match(
+            v, [relay.Clause(relay.PatternConstructor(box_ctor, [relay.PatternVar(w)]), w)]
+        ),
+    )
     match_val = run_as_python(match, mod)
     assert_tensor_value(match_val, 1)
 
@@ -282,36 +285,49 @@ def test_match_pattern():
 def test_nested_match_pattern():
     mod = tvm.IRModule()
     box, box_ctor = init_box_adt(mod)
-    v = relay.Var('v')
-    w = relay.Var('w')
+    v = relay.Var("v")
+    w = relay.Var("w")
     match = relay.Let(
-        v, box_ctor(box_ctor(relay.const(2))),
-        relay.Match(v, [
-            relay.Clause(
-                relay.PatternConstructor(
-                    box_ctor, [
-                        relay.PatternConstructor(box_ctor, [relay.PatternVar(w)])
-                    ]),
-                w)]))
+        v,
+        box_ctor(box_ctor(relay.const(2))),
+        relay.Match(
+            v,
+            [
+                relay.Clause(
+                    relay.PatternConstructor(
+                        box_ctor, [relay.PatternConstructor(box_ctor, [relay.PatternVar(w)])]
+                    ),
+                    w,
+                )
+            ],
+        ),
+    )
     match_val = run_as_python(match, mod)
     assert_tensor_value(match_val, 2)
 
+
 def test_match_order():
     mod = tvm.IRModule()
     box, box_ctor = init_box_adt(mod)
-    v = relay.Var('v')
-    w = relay.Var('w')
+    v = relay.Var("v")
+    w = relay.Var("w")
     # wildcard pattern goes first
     match = relay.Let(
-        v, box_ctor(box_ctor(relay.const(2))),
-        relay.Match(v, [
-            relay.Clause(relay.PatternWildcard(), relay.const(1)),
-            relay.Clause(
-                relay.PatternConstructor(
-                    box_ctor, [
-                        relay.PatternConstructor(box_ctor, [relay.PatternVar(w)])
-                    ]),
-                w)]))
+        v,
+        box_ctor(box_ctor(relay.const(2))),
+        relay.Match(
+            v,
+            [
+                relay.Clause(relay.PatternWildcard(), relay.const(1)),
+                relay.Clause(
+                    relay.PatternConstructor(
+                        box_ctor, [relay.PatternConstructor(box_ctor, [relay.PatternVar(w)])]
+                    ),
+                    w,
+                ),
+            ],
+        ),
+    )
     match_val = run_as_python(match, mod)
     assert_tensor_value(match_val, 1)
 
@@ -320,21 +336,31 @@ def test_local_recursion():
     mod = tvm.IRModule()
     p = Prelude(mod)
 
-    v = relay.Var('v')
-    h = relay.Var('h')
-    t = relay.Var('t')
-    f = relay.Var('f')
+    v = relay.Var("v")
+    h = relay.Var("h")
+    t = relay.Var("t")
+    f = relay.Var("f")
 
     # just returns the same list
-    let = relay.Let(f, relay.Function([v], relay.Match(v, [
-        relay.Clause(relay.PatternConstructor(p.cons,
-                                              [relay.PatternVar(h), relay.PatternVar(t)]),
-                     p.cons(h, f(t))),
-        relay.Clause(relay.PatternConstructor(p.nil, []), p.nil())
-    ])),
-                    f(p.cons(relay.const(1),
-                             p.cons(relay.const(2),
-                                    p.cons(relay.const(3), p.nil())))))
+    let = relay.Let(
+        f,
+        relay.Function(
+            [v],
+            relay.Match(
+                v,
+                [
+                    relay.Clause(
+                        relay.PatternConstructor(
+                            p.cons, [relay.PatternVar(h), relay.PatternVar(t)]
+                        ),
+                        p.cons(h, f(t)),
+                    ),
+                    relay.Clause(relay.PatternConstructor(p.nil, []), p.nil()),
+                ],
+            ),
+        ),
+        f(p.cons(relay.const(1), p.cons(relay.const(2), p.cons(relay.const(3), p.nil())))),
+    )
 
     val = run_as_python(let, mod)
     assert_constructor_value(val, p.cons, 2)
@@ -349,18 +375,27 @@ def test_local_recursion():
 def test_global_recursion():
     mod = tvm.IRModule()
     p = Prelude(mod)
-    copy = relay.GlobalVar('copy')
+    copy = relay.GlobalVar("copy")
     # same as above: it copies the given list
-    a = relay.TypeVar('a')
-    v = relay.Var('v', p.l(a))
-    h = relay.Var('h')
-    t = relay.Var('t')
-    copy_def = relay.Function([v], relay.Match(v, [
-        relay.Clause(relay.PatternConstructor(p.cons,
-                                              [relay.PatternVar(h), relay.PatternVar(t)]),
-                     p.cons(h, copy(t))),
-        relay.Clause(relay.PatternConstructor(p.nil, []), p.nil())
-    ]), p.l(a), [a])
+    a = relay.TypeVar("a")
+    v = relay.Var("v", p.l(a))
+    h = relay.Var("h")
+    t = relay.Var("t")
+    copy_def = relay.Function(
+        [v],
+        relay.Match(
+            v,
+            [
+                relay.Clause(
+                    relay.PatternConstructor(p.cons, [relay.PatternVar(h), relay.PatternVar(t)]),
+                    p.cons(h, copy(t)),
+                ),
+                relay.Clause(relay.PatternConstructor(p.nil, []), p.nil()),
+            ],
+        ),
+        p.l(a),
+        [a],
+    )
     mod[copy] = copy_def
 
     call1 = copy_def(p.cons(relay.const(1), p.cons(relay.const(2), p.nil())))
@@ -380,20 +415,23 @@ def test_global_recursion():
 
 def test_higher_order_call():
     # test with anon func
-    h = relay.Var('h')
-    f = relay.Var('f')
-    x = relay.Var('x')
-    ho_anon = relay.Let(h, relay.Function([f], f(relay.Tuple([]))),
-                        h(relay.Function([x], relay.const(1))))
+    h = relay.Var("h")
+    f = relay.Var("f")
+    x = relay.Var("x")
+    ho_anon = relay.Let(
+        h, relay.Function([f], f(relay.Tuple([]))), h(relay.Function([x], relay.const(1)))
+    )
 
     anon_val = run_as_python(ho_anon)
     assert_tensor_value(anon_val, 1)
 
     # test with named func
-    g = relay.Var('g')
-    ho_named = relay.Let(h, relay.Function([f], f(relay.Tuple([]))),
-                         relay.Let(g, relay.Function([x], relay.const(2)),
-                           h(g)))
+    g = relay.Var("g")
+    ho_named = relay.Let(
+        h,
+        relay.Function([f], f(relay.Tuple([]))),
+        relay.Let(g, relay.Function([x], relay.const(2)), h(g)),
+    )
     named_val = run_as_python(ho_named)
     assert_tensor_value(named_val, 2)
 
@@ -404,19 +442,25 @@ def test_match_effect_exactly_once():
 
     # the list should be of length 1!
     # Unless we mistakenly execute the data clause more than once
-    r = relay.Var('r')
+    r = relay.Var("r")
     data = seq(relay.RefWrite(r, p.cons(relay.Tuple([]), relay.RefRead(r))), relay.RefRead(r))
     match = relay.Let(
-        r, relay.RefCreate(p.nil()),
-        relay.Match(data, [
-            relay.Clause(relay.PatternConstructor(p.nil, []), relay.const(0)),
-            relay.Clause(
-                relay.PatternConstructor(
-                    p.cons,
-                    [relay.PatternWildcard(), relay.PatternConstructor(p.nil, [])]),
-                relay.const(1)),
-            relay.Clause(relay.PatternWildcard(), relay.const(2))
-        ]))
+        r,
+        relay.RefCreate(p.nil()),
+        relay.Match(
+            data,
+            [
+                relay.Clause(relay.PatternConstructor(p.nil, []), relay.const(0)),
+                relay.Clause(
+                    relay.PatternConstructor(
+                        p.cons, [relay.PatternWildcard(), relay.PatternConstructor(p.nil, [])]
+                    ),
+                    relay.const(1),
+                ),
+                relay.Clause(relay.PatternWildcard(), relay.const(2)),
+            ],
+        ),
+    )
 
     match_val = run_as_python(match, mod)
     assert_tensor_value(match_val, 1)
@@ -426,17 +470,21 @@ def test_arbitrary_let_nesting():
     # something that is tricky to do in Python but comes naturally in Relay
     mod = tvm.IRModule()
     p = Prelude(mod)
-    x = relay.Var('x')
-    r = relay.Var('r')
-    y = relay.Var('y')
-    z = relay.Var('z')
-    expr = relay.Tuple([
-        relay.Let(x, relay.Tuple([relay.const(1), relay.const(2)]),
-                  relay.TupleGetItem(x, 1)),
-        relay.Let(r, relay.RefCreate(relay.const(1)),
-                  seq(relay.RefWrite(r, relay.const(3)), relay.RefRead(r))),
-        relay.Let(y, p.id(relay.Let(z, relay.const(4), z)), y)
-    ])
+    x = relay.Var("x")
+    r = relay.Var("r")
+    y = relay.Var("y")
+    z = relay.Var("z")
+    expr = relay.Tuple(
+        [
+            relay.Let(x, relay.Tuple([relay.const(1), relay.const(2)]), relay.TupleGetItem(x, 1)),
+            relay.Let(
+                r,
+                relay.RefCreate(relay.const(1)),
+                seq(relay.RefWrite(r, relay.const(3)), relay.RefRead(r)),
+            ),
+            relay.Let(y, p.id(relay.Let(z, relay.const(4), z)), y),
+        ]
+    )
 
     tup_val = run_as_python(expr, mod)
     assert_adt_len(tup_val, 3)
@@ -447,34 +495,39 @@ def test_arbitrary_let_nesting():
 
 def test_ref_execution_order():
     # we want to have effects execute from left to right
-    x = relay.Var('x')
-    y = relay.Var('y')
-    f = relay.Var('f')
-    r = relay.Var('r')
-
-    expr = relay.Let(f, relay.Function([x, y], x),
-                     # r = 1
-                     relay.Let(r, relay.RefCreate(relay.const(1)),
-                               relay.Tuple([
-                                   # should be 1
-                                   relay.RefRead(r),
-                                   # set r to 2 and read back
-                                   seq(relay.RefWrite(r, relay.const(2)),
-                                       relay.RefRead(r)),
-                                   # set r to 3 and read back
-                                   seq(relay.RefWrite(r, relay.const(3)),
-                                       relay.RefRead(r)),
-                                   # set r to 4 and read as first arg to f
-                                   # set r to 5 and read as second arg to f
-                                   # f should evaluate to 4
-                                   f(
-                                       seq(relay.RefWrite(r, relay.const(4)),
-                                           relay.RefRead(r)),
-                                       seq(relay.RefWrite(r, relay.const(5)),
-                                           relay.RefRead(r))),
-                                   # read back 5
-                                   relay.RefRead(r)
-                  ])))
+    x = relay.Var("x")
+    y = relay.Var("y")
+    f = relay.Var("f")
+    r = relay.Var("r")
+
+    expr = relay.Let(
+        f,
+        relay.Function([x, y], x),
+        # r = 1
+        relay.Let(
+            r,
+            relay.RefCreate(relay.const(1)),
+            relay.Tuple(
+                [
+                    # should be 1
+                    relay.RefRead(r),
+                    # set r to 2 and read back
+                    seq(relay.RefWrite(r, relay.const(2)), relay.RefRead(r)),
+                    # set r to 3 and read back
+                    seq(relay.RefWrite(r, relay.const(3)), relay.RefRead(r)),
+                    # set r to 4 and read as first arg to f
+                    # set r to 5 and read as second arg to f
+                    # f should evaluate to 4
+                    f(
+                        seq(relay.RefWrite(r, relay.const(4)), relay.RefRead(r)),
+                        seq(relay.RefWrite(r, relay.const(5)), relay.RefRead(r)),
+                    ),
+                    # read back 5
+                    relay.RefRead(r),
+                ]
+            ),
+        ),
+    )
 
     tup_val = run_as_python(expr)
     assert_adt_len(tup_val, 5)
@@ -495,7 +548,7 @@ def test_op_add():
 # adapted from test_stack in test_op_level3
 def test_op_stack():
     def verify_stack(dshapes, axis):
-        x_data = [np.random.normal(size=shape).astype('int32') for shape in dshapes]
+        x_data = [np.random.normal(size=shape).astype("int32") for shape in dshapes]
         ref_res = np.stack(x_data, axis=axis)
 
         args = []
@@ -516,7 +569,7 @@ def test_op_stack():
 # adapted from test_split_infer_type in test_op_level3
 def test_split():
     def verify_split(shape, indices_or_sections, axis=0):
-        x = np.random.normal(size=shape).astype('float32')
+        x = np.random.normal(size=shape).astype("float32")
         ref_res = np.split(x, indices_or_sections, axis=axis)
         call = relay.split(relay.const(x), indices_or_sections, axis=axis)
         call_val = run_as_python(call)
@@ -534,13 +587,14 @@ def test_split():
 # ensure we can generate code for batch_norm, since it requires simplify_inference
 def test_batch_norm():
     def verify_batch_norm(shapes):
-        data = [np.absolute(np.random.normal(size=shape).astype('float32'))
-                for shape in shapes]
+        data = [np.absolute(np.random.normal(size=shape).astype("float32")) for shape in shapes]
         relay_args = [relay.const(arg) for arg in data]
 
         eps = 1e-5
+
         def reference(x, gamma, beta, moving_mean, moving_var):
             return (x - moving_mean) / np.sqrt(moving_var + eps) * gamma + beta
+
         ref_res = reference(*data)
 
         call = relay.nn.batch_norm(*relay_args, epsilon=eps)[0]
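
The batch-norm hunk above also shows Black's blank-line rules: a def nested inside a function body gets one empty line before and after it, while top-level definitions stay separated by two. A stand-alone sketch of the resulting shape (hypothetical code, not from this change):

# Hypothetical sketch of the blank-line layout Black enforces: one empty line
# around the nested helper, two empty lines between top-level definitions.
import numpy as np


def verify(eps=1e-5):
    data = np.ones((2, 3), dtype="float32")

    def reference(x):
        # nested helper: surrounded by single blank lines after formatting
        return (x - x.mean()) / np.sqrt(x.var() + eps)

    return reference(data)


def main():
    print(verify())


if __name__ == "__main__":
    main()
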
index e29038c..26fc356 100644
@@ -26,16 +26,18 @@ from tvm.ir import IRModule
 from tvm import relay
 from tvm.relay.data_dep_optimization import simplify_fc_transpose
 
+
 def run_func(func, params, x):
     with tvm.transform.PassContext(opt_level=3):
         graph, lib, new_params = relay.build(func, "llvm", params=params)
 
     from tvm.contrib import graph_runtime
+
     ctx = tvm.cpu(0)
-    dtype = 'float32'
+    dtype = "float32"
     m = graph_runtime.create(graph, lib, ctx)
     # set inputs
-    m.set_input('data', tvm.nd.array(x.astype(dtype)))
+    m.set_input("data", tvm.nd.array(x.astype(dtype)))
     m.set_input(**new_params)
     # execute
     m.run()
@@ -43,6 +45,7 @@ def run_func(func, params, x):
     tvm_output = m.get_output(0)
     return tvm_output.asnumpy()
 
+
 def test_simplify_fc_transpose():
     data = relay.var("data", shape=(1, 32), dtype="float32")
     x = relay.nn.relu(data)
@@ -54,7 +57,7 @@ def test_simplify_fc_transpose():
     func = relay.Function(relay.analysis.free_vars(zz), zz)
     params = {
         "w1": tvm.nd.array(np.random.uniform(-1, 1, (32, 64)).astype("float32")),
-        "w2": tvm.nd.array(np.random.uniform(-1, 1, (64, 16)).astype("float32"))
+        "w2": tvm.nd.array(np.random.uniform(-1, 1, (64, 16)).astype("float32")),
     }
     x_np = np.random.randn(1, 32).astype("float32")
     old_result = run_func(func, params, x_np)
@@ -63,5 +66,6 @@ def test_simplify_fc_transpose():
     new_result = run_func(new_func, new_params, x_np)
     np.testing.assert_allclose(old_result, new_result, atol=1e-5, rtol=1e-5)
 
+
 if __name__ == "__main__":
     test_simplify_fc_transpose()
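
The run_func and params hunks above capture two smaller rules: a blank line is inserted after an import statement even when the import is local to a function, and the last entry of a multi-line dict literal gains a trailing comma. A brief stand-alone sketch (hypothetical parameter shapes, not from this change):

# Hypothetical sketch mirroring the run_func/params hunks above: blank line
# after the function-local import, double-quoted strings, and a trailing
# comma on the last entry of the multi-line dict.
def make_params():
    import numpy as np

    rng = np.random.default_rng(0)
    return {
        "w1": rng.standard_normal((32, 64)).astype("float32"),
        "w2": rng.standard_normal((64, 16)).astype("float32"),
    }
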
index e0204ae..e3644e9 100644
@@ -34,27 +34,31 @@ def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype="float32"):
     num_blocks = int(nnz / (BS_R * BS_C)) + 1
     candidate_blocks = np.asarray(list(itertools.product(range(0, M, BS_R), range(0, N, BS_C))))
     assert candidate_blocks.shape[0] == M // BS_R * N // BS_C
-    chosen_blocks = candidate_blocks[np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False)]
+    chosen_blocks = candidate_blocks[
+        np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False)
+    ]
     for i in range(len(chosen_blocks)):
         r, c = chosen_blocks[i]
-        Y[r:r+BS_R,c:c+BS_C] = np.random.randn(BS_R, BS_C)
+        Y[r : r + BS_R, c : c + BS_C] = np.random.randn(BS_R, BS_C)
     s = sp.bsr_matrix(Y, blocksize=(BS_R, BS_C))
     assert s.data.shape == (num_blocks, BS_R, BS_C)
     assert s.data.size >= nnz
-    assert s.indices.shape == (num_blocks, )
-    assert s.indptr.shape == (M // BS_R + 1, )
+    assert s.indices.shape == (num_blocks,)
+    assert s.indptr.shape == (M // BS_R + 1,)
     return s
 
+
 def run_func(func, params, x):
     with tvm.transform.PassContext(opt_level=3):
         graph, lib, new_params = relay.build(func, "llvm", params=params)
 
     from tvm.contrib import graph_runtime
+
     ctx = tvm.cpu(0)
-    dtype = 'float32'
+    dtype = "float32"
     m = graph_runtime.create(graph, lib, ctx)
     # set inputs
-    m.set_input('data', tvm.nd.array(x.astype(dtype)))
+    m.set_input("data", tvm.nd.array(x.astype(dtype)))
     m.set_input(**new_params)
     # execute
     m.run()
@@ -62,6 +66,7 @@ def run_func(func, params, x):
     tvm_output = m.get_output(0)
     return tvm_output.asnumpy()
 
+
 def test_bsr_sparse_dense():
     data = relay.var("data", shape=(1, 128), dtype="float32")
     x = relay.nn.relu(data)
@@ -70,9 +75,7 @@ def test_bsr_sparse_dense():
     z = relay.nn.relu(y)
     func = relay.Function(relay.analysis.free_vars(z), z)
 
-    params = {
-        "weight": tvm.nd.array(random_bsr_matrix(768, 128, 32, 1, 0.1).todense())
-    }
+    params = {"weight": tvm.nd.array(random_bsr_matrix(768, 128, 32, 1, 0.1).todense())}
 
     x_np = np.random.randn(1, 128).astype("float32")
     # dense output
@@ -82,5 +85,6 @@ def test_bsr_sparse_dense():
     sparse_output = run_func(sparse_func, params, x_np)
     np.testing.assert_allclose(sparse_output, dense_output, atol=1e-5, rtol=1e-5)
 
+
 if __name__ == "__main__":
     test_bsr_sparse_dense()
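
The random_bsr_matrix hunks above demonstrate two more mechanical rules: when a slice bound is itself an expression, Black spaces the colon like a low-priority binary operator (Y[r : r + BS_R]), and a one-element tuple is written with the comma flush against the value ((num_blocks,)). A stand-alone sketch (hypothetical block sizes, not from this change):

# Hypothetical sketch of Black's slice and tuple spacing as seen above:
# complex slice bounds get spaces around the colon (plain names or literals
# stay unspaced), and one-element tuples have no space before the comma.
import numpy as np

BS_R, BS_C = 4, 2
Y = np.zeros((16, 8), dtype="float32")
r, c = 4, 2

Y[r : r + BS_R, c : c + BS_C] = np.random.randn(BS_R, BS_C)

indptr_shape = (Y.shape[0] // BS_R + 1,)  # one-element tuple, no space before ","
assert indptr_shape == (5,)
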
index b90a688..8370b2a 100644
@@ -18,10 +18,20 @@ import tvm
 from tvm import te
 from tvm import relay
 from tvm.relay import TypeFunctor, TypeMutator, TypeVisitor
-from tvm.relay.ty import (TypeVar, IncompleteType, TensorType, FuncType,
-                 TupleType, TypeRelation, RefType, GlobalTypeVar, TypeCall)
+from tvm.relay.ty import (
+    TypeVar,
+    IncompleteType,
+    TensorType,
+    FuncType,
+    TupleType,
+    TypeRelation,
+    RefType,
+    GlobalTypeVar,
+    TypeCall,
+)
 from tvm.relay.adt import TypeData
 
+
 def check_visit(typ):
     try:
         ef = TypeFunctor()
@@ -33,12 +43,11 @@ def check_visit(typ):
     ev = TypeVisitor()
     ev.visit(typ)
 
-    tvm.ir.assert_structural_equal(TypeMutator().visit(typ), typ,
-                                   map_free_vars=True)
+    tvm.ir.assert_structural_equal(TypeMutator().visit(typ), typ, map_free_vars=True)
 
 
 def test_type_var():
-    tv = TypeVar('a')
+    tv = TypeVar("a")
     check_visit(tv)
 
 
@@ -53,8 +62,8 @@ def test_tensor_type():
 
 
 def test_func_type():
-    tv = TypeVar('tv')
-    tt = relay.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
+    tv = TypeVar("tv")
+    tt = relay.TensorType(tvm.runtime.convert([1, 2, 3]), "float32")
     ft = FuncType([tt], tt, type_params=[tv])
     check_visit(ft)
 
@@ -65,11 +74,11 @@ def test_tuple_type():
 
 
 def test_type_relation():
-    func = tvm.ir.EnvFunc.get('tvm.relay.type_relation.Broadcast')
-    attrs = tvm.ir.make_node('attrs.TestAttrs', name='attr', padding=(3,4))
-    tp = TypeVar('tp')
+    func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
+    attrs = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3, 4))
+    tp = TypeVar("tp")
     tf = FuncType([], TupleType([]), [], [])
-    tt = TensorType([1, 2, 3], 'float32')
+    tt = TensorType([1, 2, 3], "float32")
     tr = TypeRelation(func, [tp, tf, tt], 2, attrs)
 
     check_visit(tr)
@@ -81,17 +90,17 @@ def test_ref_type():
 
 
 def test_global_type_var():
-    gtv = GlobalTypeVar('gtv')
+    gtv = GlobalTypeVar("gtv")
     check_visit(gtv)
 
 
 def test_type_call():
-    tc = TypeCall(GlobalTypeVar('tf'), [TupleType([])])
+    tc = TypeCall(GlobalTypeVar("tf"), [TupleType([])])
     check_visit(tc)
 
 
 def test_type_data():
-    td = TypeData(GlobalTypeVar('td'), [TypeVar('tv')], [])
+    td = TypeData(GlobalTypeVar("td"), [TypeVar("tv")], [])
     check_visit(td)
 
 
index 70e0c3f..455c8ce 100644
@@ -24,6 +24,7 @@ from tvm import relay
 from tvm.relay import op, transform, analysis
 from tvm.relay import Any
 
+
 def run_infer_type(expr, mod=None):
     if not mod:
         mod = tvm.IRModule.from_expr(expr)
@@ -50,15 +51,14 @@ def assert_has_type(expr, typ, mod=tvm.IRModule({})):
     checked_expr = run_infer_type(expr, mod)
     checked_type = checked_expr.checked_type
     if checked_type != typ:
-        raise RuntimeError("Type mismatch %s vs %s" % (
-            checked_type, typ))
+        raise RuntimeError("Type mismatch %s vs %s" % (checked_type, typ))
 
 
 # initializes simple ADT for tests
 def initialize_box_adt(mod):
-    box = relay.GlobalTypeVar('box')
-    tv = relay.TypeVar('tv')
-    constructor = relay.Constructor('constructor', [tv], box)
+    box = relay.GlobalTypeVar("box")
+    tv = relay.TypeVar("tv")
+    constructor = relay.Constructor("constructor", [tv], box)
     data = relay.TypeData(box, [tv], [constructor])
     mod[box] = data
     return (box, constructor)
@@ -67,17 +67,17 @@ def initialize_box_adt(mod):
 def test_monomorphic_let():
     "Program: let %x = 1; %x"
     sb = relay.ScopeBuilder()
-    x = sb.let('x', relay.const(1.0, "float64"))
+    x = sb.let("x", relay.const(1.0, "float64"))
     sb.ret(x)
     xchecked = run_infer_type(sb.get())
-    assert xchecked.checked_type == relay.scalar_type("float64" )
+    assert xchecked.checked_type == relay.scalar_type("float64")
 
 
 def test_single_op():
     "Program: fn (%x : float32) { let %t1 = f(%x); %t1 }"
-    x = relay.var('x', shape=[])
+    x = relay.var("x", shape=[])
     func = relay.Function([x], op.log(x))
-    ttype = relay.TensorType([], dtype='float32')
+    ttype = relay.TensorType([], dtype="float32")
     assert_has_type(func, relay.FuncType([ttype], ttype))
 
 
@@ -89,24 +89,24 @@ def test_add_broadcast_op():
             %x + %y
         }
     """
-    x = relay.var('x', shape=(10, 4))
-    y = relay.var('y', shape=(5, 10, 1))
+    x = relay.var("x", shape=(10, 4))
+    y = relay.var("y", shape=(5, 10, 1))
     z = x + y
     func = relay.Function([x, y], z)
-    t1 = relay.TensorType((10, 4), 'float32')
-    t2 = relay.TensorType((5, 10, 1), 'float32')
-    t3 = relay.TensorType((5, 10, 4), 'float32')
+    t1 = relay.TensorType((10, 4), "float32")
+    t2 = relay.TensorType((5, 10, 1), "float32")
+    t3 = relay.TensorType((5, 10, 4), "float32")
     expected_ty = relay.FuncType([t1, t2], t3)
     assert_has_type(func, expected_ty)
 
 
 def test_dual_op():
     """Program:
-       fn (%x : Tensor[(10, 10), float32]) {
-         let %t1 = log(x);
-         let %t2 = add(%t1, %x);
-         %t1
-       }
+    fn (%x : Tensor[(10, 10), float32]) {
+      let %t1 = log(x);
+      let %t2 = add(%t1, %x);
+      %t1
+    }
     """
     tp = relay.TensorType((10, 10), "float32")
     x = relay.var("x", tp)
@@ -121,9 +121,9 @@ def test_dual_op():
 
 def test_decl():
     """Program:
-       def @f(%x : Tensor[(10, 10), float32]) {
-           log(%x)
-       }
+    def @f(%x : Tensor[(10, 10), float32]) {
+        log(%x)
+    }
     """
     tp = relay.TensorType((10, 10))
     x = relay.var("x", tp)
@@ -161,9 +161,9 @@ def test_recursion():
 
 
 def test_incomplete_call():
-    tt = relay.scalar_type('int32')
-    x = relay.var('x', tt)
-    f = relay.var('f')
+    tt = relay.scalar_type("int32")
+    x = relay.var("x", tt)
+    f = relay.var("f")
     func = relay.Function([x, f], relay.Call(f, [x]), tt)
 
     ft = run_infer_type(func)
@@ -172,30 +172,30 @@ def test_incomplete_call():
 
 
 def test_higher_order_argument():
-    a = relay.TypeVar('a')
-    x = relay.Var('x', a)
+    a = relay.TypeVar("a")
+    x = relay.Var("x", a)
     id_func = relay.Function([x], x, a, [a])
 
-    b = relay.TypeVar('b')
-    f = relay.Var('f', relay.FuncType([b], b))
-    y = relay.Var('y', b)
+    b = relay.TypeVar("b")
+    f = relay.Var("f", relay.FuncType([b], b))
+    y = relay.Var("y", b)
     ho_func = relay.Function([f, y], f(y), b, [b])
 
     # id func should be an acceptable argument to the higher-order
     # function even though id_func takes a type parameter
-    ho_call = ho_func(id_func, relay.const(0, 'int32'))
+    ho_call = ho_func(id_func, relay.const(0, "int32"))
 
     hc = run_infer_type(ho_call)
-    expected = relay.scalar_type('int32')
+    expected = relay.scalar_type("int32")
     assert hc.checked_type == expected
 
 
 def test_higher_order_return():
-    a = relay.TypeVar('a')
-    x = relay.Var('x', a)
+    a = relay.TypeVar("a")
+    x = relay.Var("x", a)
     id_func = relay.Function([x], x, a, [a])
 
-    b = relay.TypeVar('b')
+    b = relay.TypeVar("b")
     nested_id = relay.Function([], id_func, relay.FuncType([b], b), [b])
 
     ft = run_infer_type(nested_id)
@@ -203,20 +203,18 @@ def test_higher_order_return():
 
 
 def test_higher_order_nested():
-    a = relay.TypeVar('a')
-    x = relay.Var('x', a)
+    a = relay.TypeVar("a")
+    x = relay.Var("x", a)
     id_func = relay.Function([x], x, a, [a])
 
-    choice_t = relay.FuncType([], relay.scalar_type('bool'))
-    f = relay.Var('f', choice_t)
+    choice_t = relay.FuncType([], relay.scalar_type("bool"))
+    f = relay.Var("f", choice_t)
 
-    b = relay.TypeVar('b')
-    z = relay.Var('z')
+    b = relay.TypeVar("b")
+    z = relay.Var("z")
     top = relay.Function(
-        [f],
-        relay.If(f(), id_func, relay.Function([z], z)),
-        relay.FuncType([b], b),
-        [b])
+        [f], relay.If(f(), id_func, relay.Function([z], z)), relay.FuncType([b], b), [b]
+    )
 
     expected = relay.FuncType([choice_t], relay.FuncType([b], b), [b])
     ft = run_infer_type(top)
@@ -227,7 +225,7 @@ def test_tuple():
     tp = relay.TensorType((10,))
     x = relay.var("x", tp)
     res = relay.Tuple([x, x])
-    assert (run_infer_type(res).checked_type == relay.TupleType([tp, tp]))
+    assert run_infer_type(res).checked_type == relay.TupleType([tp, tp])
 
 
 def test_ref():
@@ -271,8 +269,8 @@ def test_type_args():
 def test_global_var_recursion():
     mod = tvm.IRModule({})
     gv = relay.GlobalVar("main")
-    x = relay.var('x', shape=[])
-    tt = relay.scalar_type('float32')
+    x = relay.var("x", shape=[])
+    tt = relay.scalar_type("float32")
 
     func = relay.Function([x], relay.Call(gv, [x]), tt)
     mod[gv] = func
@@ -282,20 +280,22 @@ def test_global_var_recursion():
 
 
 def test_equal():
-    i = relay.var('i', shape=[], dtype='int32')
-    eq = op.equal(i, relay.const(0, dtype='int32'))
+    i = relay.var("i", shape=[], dtype="int32")
+    eq = op.equal(i, relay.const(0, dtype="int32"))
     func = relay.Function([i], eq)
     ft = run_infer_type(func)
 
-    assert ft.checked_type == relay.FuncType([relay.scalar_type('int32')], relay.scalar_type('bool'))
+    assert ft.checked_type == relay.FuncType(
+        [relay.scalar_type("int32")], relay.scalar_type("bool")
+    )
 
 
 def test_constructor_type():
     mod = tvm.IRModule()
     box, constructor = initialize_box_adt(mod)
 
-    a = relay.TypeVar('a')
-    x = relay.Var('x', a)
+    a = relay.TypeVar("a")
+    x = relay.Var("x", a)
     ct = run_infer_type(relay.Function([x], constructor(x), box(a), [a]), mod)
     expected = relay.FuncType([a], box(a), [a])
     assert ct.checked_type == expected
@@ -306,27 +306,29 @@ def test_constructor_call():
     box, constructor = initialize_box_adt(mod)
 
     box_unit = constructor(relay.Tuple([]))
-    box_constant = constructor(relay.const(0, 'float32'))
+    box_constant = constructor(relay.const(0, "float32"))
 
     ut = run_infer_type(box_unit, mod)
     ct = run_infer_type(box_constant, mod)
     assert ut.checked_type == box(relay.TupleType([]))
-    assert ct.checked_type == box(relay.TensorType((), 'float32'))
+    assert ct.checked_type == box(relay.TensorType((), "float32"))
 
 
 def test_adt_match():
     mod = tvm.IRModule()
     box, constructor = initialize_box_adt(mod)
 
-    v = relay.Var('v', relay.TensorType((), 'float32'))
-    match = relay.Match(constructor(relay.const(0, 'float32')),
-                        [relay.Clause(
-                            relay.PatternConstructor(constructor,
-                                                     [relay.PatternVar(v)]),
-                            relay.Tuple([])),
-                         # redundant but shouldn't matter to typechecking
-                         relay.Clause(relay.PatternWildcard(),
-                                      relay.Tuple([]))])
+    v = relay.Var("v", relay.TensorType((), "float32"))
+    match = relay.Match(
+        constructor(relay.const(0, "float32")),
+        [
+            relay.Clause(
+                relay.PatternConstructor(constructor, [relay.PatternVar(v)]), relay.Tuple([])
+            ),
+            # redundant but shouldn't matter to typechecking
+            relay.Clause(relay.PatternWildcard(), relay.Tuple([])),
+        ],
+    )
 
     mt = run_infer_type(match, mod)
     assert mt.checked_type == relay.TupleType([])
@@ -338,14 +340,17 @@ def test_adt_match_type_annotations():
 
     # the only type annotation is inside the match pattern var
     # but that should be enough info
-    tt = relay.TensorType((2, 2), 'float32')
-    x = relay.Var('x')
-    mv = relay.Var('mv', tt)
-    match = relay.Match(constructor(x),
-                        [relay.Clause(
-                            relay.PatternConstructor(constructor,
-                                                     [relay.PatternVar(mv)]),
-                                                     relay.Tuple([]))])
+    tt = relay.TensorType((2, 2), "float32")
+    x = relay.Var("x")
+    mv = relay.Var("mv", tt)
+    match = relay.Match(
+        constructor(x),
+        [
+            relay.Clause(
+                relay.PatternConstructor(constructor, [relay.PatternVar(mv)]), relay.Tuple([])
+            )
+        ],
+    )
 
     func = relay.Function([x], match)
     ft = run_infer_type(func, mod)
@@ -364,16 +369,17 @@ def test_let_polymorphism():
 
 
 def test_if():
-    choice_t = relay.FuncType([], relay.scalar_type('bool'))
-    f = relay.Var('f', choice_t)
-    true_branch = relay.Var('True', relay.TensorType([Any(), 1], dtype='float32'))
-    false_branch = relay.Var('False', relay.TensorType([Any(), Any()], dtype='float32'))
+    choice_t = relay.FuncType([], relay.scalar_type("bool"))
+    f = relay.Var("f", choice_t)
+    true_branch = relay.Var("True", relay.TensorType([Any(), 1], dtype="float32"))
+    false_branch = relay.Var("False", relay.TensorType([Any(), Any()], dtype="float32"))
     top = relay.Function([f, true_branch, false_branch], relay.If(f(), true_branch, false_branch))
     ft = run_infer_type(top)
-    tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype='float32'))
+    tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype="float32"))
+
 
 def test_type_arg_infer():
-  code = """
+    code = """
 #[version = "0.0.5"]
 def @id[A](%x: A) -> A {
   %x
@@ -382,9 +388,10 @@ def @main(%f: float32) -> float32 {
   @id(%f)
 }
 """
-  mod = tvm.parser.fromtext(code)
-  mod = transform.InferType()(mod)
-  tvm.ir.assert_structural_equal(mod['main'].body.type_args, [relay.TensorType((), 'float32')])
+    mod = tvm.parser.fromtext(code)
+    mod = transform.InferType()(mod)
+    tvm.ir.assert_structural_equal(mod["main"].body.type_args, [relay.TensorType((), "float32")])
+
 
 if __name__ == "__main__":
     pytest.main([__file__])
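
The test_type_arg_infer hunk above had drifted to two-space indentation; Black re-emits every block with four-space indents, while the triple-quoted Relay program is a string literal and is left untouched. A stand-alone sketch of that behaviour (hypothetical function, not from this change):

# Hypothetical sketch: the function body is (re)indented with four spaces,
# but the embedded Relay source inside the string keeps its own two-space
# layout, because Black never reflows the contents of string literals.
def get_code():
    code = """
def @id[A](%x: A) -> A {
  %x
}
"""
    return code


print(get_code())
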
index 6d72ad3..88bdd16 100644
@@ -25,6 +25,7 @@ def make_rel(name, args, num_inputs=None, attrs=None):
         num_inputs = len(args) - 1
     return relay.ty.TypeRelation(func, args, num_inputs, attrs)
 
+
 def make_solver():
     solver = relay.analysis._ffi_api._test_type_solver()
     solver.Solve = solver("Solve")
@@ -81,14 +82,14 @@ def test_unify_tuple():
 def test_unify_global_type_var():
     # should only be able to unify if they're the same
     solver = make_solver()
-    gtv = relay.GlobalTypeVar('gtv')
+    gtv = relay.GlobalTypeVar("gtv")
     unified = solver.Unify(gtv, gtv)
     assert unified == gtv
 
 
 def test_unify_typecall():
     solver = make_solver()
-    gtv = relay.GlobalTypeVar('gtv')
+    gtv = relay.GlobalTypeVar("gtv")
 
     # yeah, typecalls are shaped like tuples so the same
     # tests work out
@@ -153,7 +154,7 @@ def test_unify_vars_under_tuples():
     tup3 = relay.ty.TupleType([t1, t2])
     tup4 = relay.ty.TupleType([t2, t1])
     unified = solver.Unify(tup3, tup4)
-    assert (unified == tup1 or unified == tup2)
+    assert unified == tup1 or unified == tup2
 
 
 def test_binding_over_typevars():
@@ -162,15 +163,15 @@ def test_binding_over_typevars():
     t1 = relay.ty.IncompleteType()
     t2 = relay.ty.IncompleteType()
 
-    a = relay.ty.TypeVar('a')
-    b = relay.ty.TypeVar('b')
-    c = relay.ty.TypeVar('c')
-    d = relay.ty.TypeVar('d')
+    a = relay.ty.TypeVar("a")
+    b = relay.ty.TypeVar("b")
+    c = relay.ty.TypeVar("c")
+    d = relay.ty.TypeVar("d")
 
     ft1 = relay.ty.FuncType([t1], t2, [c, d])
     ft2 = relay.ty.FuncType([a], b, [a, b])
     unified = solver.Unify(ft1, ft2)
-    assert (unified == solver.Resolve(ft1))
+    assert unified == solver.Resolve(ft1)
 
 
 def test_recursive_backward_solving():
@@ -226,7 +227,7 @@ def test_backward_solving_after_child_update():
 
 def test_unify_quantified_funcs():
     solver = make_solver()
-    a, b, c = relay.TypeVar('a'), relay.TypeVar('b'), relay.TypeVar('c')
+    a, b, c = relay.TypeVar("a"), relay.TypeVar("b"), relay.TypeVar("c")
     ft1 = relay.FuncType([a, b], c, [a, b, c])
     ft2 = relay.FuncType([a, a], a, [a])
     unified = solver.Unify(ft1, ft2)
@@ -240,7 +241,7 @@ def test_unify_quantified_funcs():
 
 def test_unify_quantified_func_and_concrete():
     solver = make_solver()
-    a, b = relay.TypeVar('a'), relay.TypeVar('b')
+    a, b = relay.TypeVar("a"), relay.TypeVar("b")
     ft1 = relay.FuncType([a], b, [a, b])
     ft2 = relay.FuncType([b], relay.TupleType([]), [b])
     unified = solver.Unify(ft1, ft2)
@@ -249,7 +250,7 @@ def test_unify_quantified_func_and_concrete():
 
 def test_unify_quantified_funcs_nesting():
     solver = make_solver()
-    a, b, c = relay.TypeVar('a'), relay.TypeVar('b'), relay.TypeVar('c')
+    a, b, c = relay.TypeVar("a"), relay.TypeVar("b"), relay.TypeVar("c")
 
     ft1 = relay.FuncType([a, relay.TupleType([b, c])], relay.TupleType([a, b, c]), [a, b, c])
     ft2 = relay.FuncType([a, relay.TupleType([a, a])], relay.TupleType([a, a, a]), [a])
@@ -259,7 +260,7 @@ def test_unify_quantified_funcs_nesting():
 
 def test_unify_quantified_funcs_var_order():
     solver = make_solver()
-    a, b, c = relay.TypeVar('a'), relay.TypeVar('b'), relay.TypeVar('c')
+    a, b, c = relay.TypeVar("a"), relay.TypeVar("b"), relay.TypeVar("c")
 
     ft1 = relay.FuncType([a, relay.TupleType([b, c])], relay.TupleType([a, b, c]), [a, b, c])
     ft2 = relay.FuncType([a, relay.TupleType([a, c])], relay.TupleType([a, a, c]), [a, c])
@@ -292,16 +293,16 @@ def test_bad_recursive_unification():
 @pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
 def test_unify_invalid_global_typevars():
     solver = make_solver()
-    gtv1 = relay.GlobalTypeVar('gtv1')
-    gtv2 = relay.GlobalTypeVar('gtv2')
+    gtv1 = relay.GlobalTypeVar("gtv1")
+    gtv2 = relay.GlobalTypeVar("gtv2")
     solver.Unify(gtv1, gtv2)
 
 
 @pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
 def test_incompatible_typecall_var_unification():
     solver = make_solver()
-    gtv1 = relay.GlobalTypeVar('gtv1')
-    gtv2 = relay.GlobalTypeVar('gtv2')
+    gtv1 = relay.GlobalTypeVar("gtv1")
+    gtv2 = relay.GlobalTypeVar("gtv2")
 
     t1 = relay.IncompleteType()
     t2 = relay.IncompleteType()
@@ -314,7 +315,7 @@ def test_incompatible_typecall_var_unification():
 @pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
 def test_incompatible_typecall_args_unification():
     solver = make_solver()
-    gtv = relay.GlobalTypeVar('gtv1')
+    gtv = relay.GlobalTypeVar("gtv1")
     t1 = relay.IncompleteType()
     t2 = relay.IncompleteType()
 
@@ -330,7 +331,7 @@ def test_incompatible_typecall_args_unification():
 @pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
 def test_incompatible_quantified_func_unification():
     solver = make_solver()
-    a, b, c = relay.TypeVar('a'), relay.TypeVar('b'), relay.TypeVar('c')
+    a, b, c = relay.TypeVar("a"), relay.TypeVar("b"), relay.TypeVar("c")
 
     ft1 = relay.FuncType([a, b], c, [a, b, c])
     ft2 = relay.FuncType([b, c], relay.TupleType([a]), [a, b, c])
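
Several hunks in this file drop redundant parentheses, e.g. assert (unified == tup1 or unified == tup2) becomes assert unified == tup1 or unified == tup2: Black treats the outer parentheses around an assert's test as removable and only re-adds them if the statement has to wrap. A tiny sketch (hypothetical values, not from this change):

# Hypothetical sketch of the parenthesis removal seen in the hunks above.
unified, tup1, tup2 = (1, 2), (1, 2), (2, 1)

# before: assert (unified == tup1 or unified == tup2)
assert unified == tup1 or unified == tup2
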
index 491047d..1cfa661 100644
@@ -19,6 +19,7 @@ from tvm import te
 from tvm import relay
 from tvm.relay import transform
 
+
 def test_dup_type():
     a = relay.TypeVar("a")
     av = relay.Var("av", a)
index 1e8069b..0ee6acc 100644
@@ -26,6 +26,7 @@ from tvm.relay.loops import while_loop
 from tvm.relay import testing
 import tvm.testing
 
+
 def check_result(args, expected_result, mod=None):
     """
     Check that evaluating `expr` applied to the arguments produces
@@ -40,10 +41,11 @@ def check_result(args, expected_result, mod=None):
         The expected result of running the expression.
     """
     for target, ctx in tvm.testing.enabled_targets():
-        vm = relay.create_executor('vm', ctx=ctx, target=target, mod=mod)
+        vm = relay.create_executor("vm", ctx=ctx, target=target, mod=mod)
         rts_result = vm.evaluate()(*args)
         tvm.testing.assert_allclose(expected_result, rts_result.asnumpy())
 
+
 def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
     if isinstance(f, relay.Expr):
         mod = tvm.IRModule()
@@ -55,6 +57,7 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
     vm = runtime.vm.VirtualMachine(exe, ctx)
     return vm.invoke("main", *args)
 
+
 def vmobj_to_list(o):
     if isinstance(o, tvm.nd.NDArray):
         return [o.asnumpy().tolist()]
@@ -66,61 +69,71 @@ def vmobj_to_list(o):
     else:
         raise RuntimeError("Unknown object type: %s" % type(o))
 
+
 @tvm.testing.uses_gpu
 def test_split():
-    x = relay.var('x', shape=(12,))
+    x = relay.var("x", shape=(12,))
     y = relay.split(x, 3, axis=0).astuple()
     f = relay.Function([x], y)
 
-    x_data = np.random.rand(12,).astype('float32')
+    x_data = np.random.rand(
+        12,
+    ).astype("float32")
     ref_res = np.split(x_data, 3, axis=0)
     for tgt, ctx in tvm.testing.enabled_targets():
         res = veval(f, x_data, ctx=ctx, target=tgt)
         for i in range(3):
             tvm.testing.assert_allclose(res[i].asnumpy(), ref_res[i])
 
+
 @tvm.testing.uses_gpu
 def test_split_no_fuse():
-    x = relay.var('x', shape=(12,))
+    x = relay.var("x", shape=(12,))
     y = relay.split(x, 3, axis=0).astuple()
     z = relay.concatenate([relay.TupleGetItem(y, 0)], axis=0)
     z = relay.annotation.stop_fusion(z)
     f = relay.Function([x], z)
-    x_data = np.random.rand(12,).astype('float32')
+    x_data = np.random.rand(
+        12,
+    ).astype("float32")
     for tgt, ctx in tvm.testing.enabled_targets():
         res = veval(f, x_data, ctx=ctx, target=tgt)
         tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
 
+
 @tvm.testing.uses_gpu
 def test_id():
-    x = relay.var('x', shape=(10, 10), dtype='float64')
+    x = relay.var("x", shape=(10, 10), dtype="float64")
     f = relay.Function([x], x)
-    x_data = np.random.rand(10, 10).astype('float64')
+    x_data = np.random.rand(10, 10).astype("float64")
     mod = tvm.IRModule()
     mod["main"] = f
     check_result([x_data], x_data, mod=mod)
 
+
 @tvm.testing.uses_gpu
 def test_op():
-    x = relay.var('x', shape=(10, 10))
+    x = relay.var("x", shape=(10, 10))
     f = relay.Function([x], x + x)
-    x_data = np.random.rand(10, 10).astype('float32')
+    x_data = np.random.rand(10, 10).astype("float32")
     mod = tvm.IRModule()
     mod["main"] = f
     check_result([x_data], 2 * x_data, mod=mod)
 
+
 def any(x):
     x = relay.op.nn.batch_flatten(x)
     return relay.op.min(x, axis=[0, 1])
 
+
 @tvm.testing.uses_gpu
 def test_cond():
-    x = relay.var('x', shape=(10, 10))
-    y = relay.var('y', shape=(10, 10))
+    x = relay.var("x", shape=(10, 10))
+    y = relay.var("y", shape=(10, 10))
     # f = relay.Function([x, y], relay.op.equal(x, y))
     f = relay.Function([x, y], any(relay.op.equal(x, y)))
-    x_data = np.random.rand(10, 10).astype('float32')
-    y_data = np.random.rand(10, 10).astype('float32')
+    x_data = np.random.rand(10, 10).astype("float32")
+    y_data = np.random.rand(10, 10).astype("float32")
 
     mod = tvm.IRModule()
     mod["main"] = f
@@ -130,14 +143,14 @@ def test_cond():
     # diff
     check_result([x_data, y_data], False, mod=mod)
 
+
 @tvm.testing.uses_gpu
 def test_simple_if():
-    x = relay.var('x', shape=(10, 10))
-    y = relay.var('y', shape=(10, 10))
-    f = relay.Function([x, y],
-        relay.If(any(relay.op.equal(x, y)), x, y))
-    x_data = np.random.rand(10, 10).astype('float32')
-    y_data = np.random.rand(10, 10).astype('float32')
+    x = relay.var("x", shape=(10, 10))
+    y = relay.var("y", shape=(10, 10))
+    f = relay.Function([x, y], relay.If(any(relay.op.equal(x, y)), x, y))
+    x_data = np.random.rand(10, 10).astype("float32")
+    y_data = np.random.rand(10, 10).astype("float32")
 
     mod = tvm.IRModule()
     mod["main"] = f
@@ -147,107 +160,114 @@ def test_simple_if():
     # diff
     check_result([x_data, y_data], y_data, mod=mod)
 
+
 @tvm.testing.uses_gpu
 def test_multiple_ifs():
     mod = tvm.IRModule({})
-    b = relay.var('b')
-    v0 = relay.var('v0')
-    v1 = relay.var('v1')
-    v2 = relay.var('v2')
-    v3 = relay.var('v3')
+    b = relay.var("b")
+    v0 = relay.var("v0")
+    v1 = relay.var("v1")
+    v2 = relay.var("v2")
+    v3 = relay.var("v3")
     out = relay.Tuple([v2, v3])
     out = relay.Let(v3, relay.If(b, v1, v0), out)
     out = relay.Let(v2, relay.If(b, v0, v1), out)
     out = relay.Let(v1, relay.Tuple([relay.const(1)]), out)
     out = relay.Let(v0, relay.Tuple([relay.const(0)]), out)
     fn = relay.Function([b], out)
-    mod['main'] = fn
-    ctx = tvm.runtime.ndarray.context('llvm', 0)
-    vm = relay.create_executor(ctx=ctx, mod=mod, kind='vm')
+    mod["main"] = fn
+    ctx = tvm.runtime.ndarray.context("llvm", 0)
+    vm = relay.create_executor(ctx=ctx, mod=mod, kind="vm")
     res = vmobj_to_list(vm.evaluate()(False))
-    assert(res == [1, 0])
+    assert res == [1, 0]
+
 
 @tvm.testing.uses_gpu
 def test_simple_call():
     mod = tvm.IRModule({})
-    sum_up = relay.GlobalVar('sum_up')
-    i = relay.var('i', shape=[], dtype='int32')
+    sum_up = relay.GlobalVar("sum_up")
+    i = relay.var("i", shape=[], dtype="int32")
     sb = ScopeBuilder()
     sb.ret(i)
-    func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
+    func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], "int32"))
     mod[sum_up] = func
-    i_data = np.array(0, dtype='int32')
-    iarg = relay.var('iarg', shape=[], dtype='int32')
+    i_data = np.array(0, dtype="int32")
+    iarg = relay.var("iarg", shape=[], dtype="int32")
     mod["main"] = relay.Function([iarg], sum_up(iarg))
     check_result([i_data], i_data, mod=mod)
 
+
 @tvm.testing.uses_gpu
 def test_count_loop():
     mod = tvm.IRModule({})
-    sum_up = relay.GlobalVar('sum_up')
-    i = relay.var('i', shape=[], dtype='int32')
+    sum_up = relay.GlobalVar("sum_up")
+    i = relay.var("i", shape=[], dtype="int32")
     sb = ScopeBuilder()
-    with sb.if_scope(relay.equal(i, relay.const(0, dtype='int32'))):
+    with sb.if_scope(relay.equal(i, relay.const(0, dtype="int32"))):
         sb.ret(i)
     with sb.else_scope():
-        one_less = relay.subtract(i, relay.const(1, dtype='int32'))
+        one_less = relay.subtract(i, relay.const(1, dtype="int32"))
         rec_call = relay.Call(sum_up, [one_less])
         sb.ret(relay.add(rec_call, i))
-    func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
+    func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], "int32"))
     mod[sum_up] = func
-    i_data = np.array(0, dtype='int32')
-    iarg = relay.var('i', shape=[], dtype='int32')
+    i_data = np.array(0, dtype="int32")
+    iarg = relay.var("i", shape=[], dtype="int32")
     mod["main"] = relay.Function([iarg], sum_up(iarg))
     for tgt, ctx in tvm.testing.enabled_targets():
         result = veval(mod, i_data, ctx=ctx, target=tgt)
         tvm.testing.assert_allclose(result.asnumpy(), i_data)
     check_result([i_data], i_data, mod=mod)
 
+
 @tvm.testing.uses_gpu
 def test_sum_loop():
     mod = tvm.IRModule({})
-    sum_up = relay.GlobalVar('sum_up')
-    i = relay.var('i', shape=[], dtype='int32')
-    accum = relay.var('accum', shape=[], dtype='int32')
+    sum_up = relay.GlobalVar("sum_up")
+    i = relay.var("i", shape=[], dtype="int32")
+    accum = relay.var("accum", shape=[], dtype="int32")
     sb = ScopeBuilder()
-    with sb.if_scope(relay.equal(i, relay.const(0, 'int32'))):
+    with sb.if_scope(relay.equal(i, relay.const(0, "int32"))):
         sb.ret(accum)
     with sb.else_scope():
-        one_less = relay.subtract(i, relay.const(1, 'int32'))
+        one_less = relay.subtract(i, relay.const(1, "int32"))
         new_accum = relay.add(accum, i)
         sb.ret(relay.Call(sum_up, [one_less, new_accum]))
     func = relay.Function([i, accum], sb.get())
     mod[sum_up] = func
     loop_bound = 0
-    i_data = np.array(loop_bound, dtype='int32')
-    accum_data = np.array(0, dtype='int32')
-    iarg = relay.var('i', shape=[], dtype='int32')
-    aarg = relay.var('accum', shape=[], dtype='int32')
+    i_data = np.array(loop_bound, dtype="int32")
+    accum_data = np.array(0, dtype="int32")
+    iarg = relay.var("i", shape=[], dtype="int32")
+    aarg = relay.var("accum", shape=[], dtype="int32")
     mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg))
     check_result([i_data, accum_data], sum(range(1, loop_bound + 1)), mod=mod)
 
+
 @tvm.testing.uses_gpu
 def test_tuple_fst():
     ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))])
-    tup = relay.var('tup', type_annotation=ttype)
+    tup = relay.var("tup", type_annotation=ttype)
     f = relay.Function([tup], relay.TupleGetItem(tup, 0))
-    i_data = np.random.rand(41).astype('float32')
-    j_data = np.random.rand(10).astype('float32')
+    i_data = np.random.rand(41).astype("float32")
+    j_data = np.random.rand(10).astype("float32")
     mod = tvm.IRModule()
     mod["main"] = f
     check_result([(i_data, j_data)], i_data, mod=mod)
 
+
 @tvm.testing.uses_gpu
 def test_tuple_second():
     ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))])
-    tup = relay.var('tup', type_annotation=ttype)
+    tup = relay.var("tup", type_annotation=ttype)
     f = relay.Function([tup], relay.TupleGetItem(tup, 1))
-    i_data = np.random.rand(41).astype('float32')
-    j_data = np.random.rand(10).astype('float32')
+    i_data = np.random.rand(41).astype("float32")
+    j_data = np.random.rand(10).astype("float32")
     mod = tvm.IRModule()
     mod["main"] = f
     check_result([(i_data, j_data)], j_data, mod=mod)
 
+
 @tvm.testing.uses_gpu
 def test_list_constructor():
     mod = tvm.IRModule()
@@ -270,44 +290,47 @@ def test_list_constructor():
         assert len(result[1]) == 2
 
         obj = vmobj_to_list(result)
-        tvm.testing.assert_allclose(obj, np.array([3,2,1]))
+        tvm.testing.assert_allclose(obj, np.array([3, 2, 1]))
+
 
 @tvm.testing.uses_gpu
 def test_let_tensor():
     sb = relay.ScopeBuilder()
     shape = (1,)
-    x = relay.var('x', shape=shape, dtype='float32')
-    x1 = relay.var('x1', shape=shape, dtype='float32')
+    x = relay.var("x", shape=shape, dtype="float32")
+    x1 = relay.var("x1", shape=shape, dtype="float32")
 
     x1 = sb.let(x1, x)
-    xplusone = x1 + relay.const(42.0, 'float32')
+    xplusone = x1 + relay.const(42.0, "float32")
     sb.ret(xplusone)
     body = sb.get()
 
     f = relay.Function([x], body)
 
-    x_data = np.random.rand(*shape).astype('float32')
+    x_data = np.random.rand(*shape).astype("float32")
     mod = tvm.IRModule()
     mod["main"] = f
     check_result([x_data], x_data + 42.0, mod=mod)
 
+
 @tvm.testing.uses_gpu
 def test_let_scalar():
     sb = relay.ScopeBuilder()
 
-    x = relay.var('x', 'float32')
-    x1 = sb.let('x1', x)
-    xplusone = x1 + relay.const(42.0, 'float32')
+    x = relay.var("x", "float32")
+    x1 = sb.let("x1", x)
+    xplusone = x1 + relay.const(42.0, "float32")
     sb.ret(xplusone)
     body = sb.get()
 
     f = relay.Function([x], body)
 
-    x_data = np.array(np.random.rand()).astype('float32')
+    x_data = np.array(np.random.rand()).astype("float32")
     mod = tvm.IRModule()
     mod["main"] = f
     check_result([x_data], x_data + 42.0, mod=mod)
 
+
 @tvm.testing.uses_gpu
 def test_compose():
     mod = tvm.IRModule()
@@ -317,9 +340,9 @@ def test_compose():
 
     # add_one = fun x -> x + 1
     sb = relay.ScopeBuilder()
-    x = relay.var('x', 'float32')
-    x1 = sb.let('x1', x)
-    xplusone = x1 + relay.const(1.0, 'float32')
+    x = relay.var("x", "float32")
+    x1 = sb.let("x1", x)
+    xplusone = x1 + relay.const(1.0, "float32")
     sb.ret(xplusone)
     body = sb.get()
     add_one = relay.GlobalVar("add_one")
@@ -327,8 +350,8 @@ def test_compose():
 
     # add_two = compose(add_one, add_one)
     sb = relay.ScopeBuilder()
-    y = relay.var('y', 'float32')
-    add_two_func = sb.let('add_two', compose(add_one_func, add_one_func))
+    y = relay.var("y", "float32")
+    add_two_func = sb.let("add_two", compose(add_one_func, add_one_func))
     add_two_res = add_two_func(y)
     sb.ret(add_two_res)
     add_two_body = sb.get()
@@ -338,11 +361,12 @@ def test_compose():
     f = relay.Function([y], add_two_body)
     mod["main"] = f
 
-    x_data = np.array(np.random.rand()).astype('float32')
+    x_data = np.array(np.random.rand()).astype("float32")
     for tgt, ctx in tvm.testing.enabled_targets():
         result = veval(mod, [x_data], ctx=ctx, target=tgt)
         tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0)
 
+
 @tvm.testing.uses_gpu
 def test_list_hd():
     mod = tvm.IRModule()
@@ -365,6 +389,7 @@ def test_list_hd():
         result = veval(mod, ctx=ctx, target=tgt)
         tvm.testing.assert_allclose(result.asnumpy(), 3)
 
+
 @pytest.mark.xfail
 def test_list_tl_empty_list():
     mod = tvm.IRModule()
@@ -381,6 +406,7 @@ def test_list_tl_empty_list():
     for tgt, ctx in tvm.testing.enabled_targets():
         result = veval(mod, ctx=ctx, target=tgt)
 
+
 @tvm.testing.uses_gpu
 def test_list_tl():
     mod = tvm.IRModule()
@@ -401,7 +427,8 @@ def test_list_tl():
 
     for tgt, ctx in tvm.testing.enabled_targets():
         result = veval(mod, ctx=ctx, target=tgt)
-        tvm.testing.assert_allclose(vmobj_to_list(result), np.array([2,1]))
+        tvm.testing.assert_allclose(vmobj_to_list(result), np.array([2, 1]))
+
 
 @tvm.testing.uses_gpu
 def test_list_nth():
@@ -424,6 +451,7 @@ def test_list_nth():
             result = veval(mod, ctx=ctx, target=tgt)
             tvm.testing.assert_allclose(result.asnumpy(), expected[i])
 
+
 @tvm.testing.uses_gpu
 def test_list_update():
     expected = list(range(10))
@@ -450,6 +478,7 @@ def test_list_update():
         result = veval(mod, ctx=ctx, target=tgt)
         tvm.testing.assert_allclose(vmobj_to_list(result), np.array(expected))
 
+
 @tvm.testing.uses_gpu
 def test_list_length():
     expected = list(range(10))
@@ -474,12 +503,13 @@ def test_list_length():
         result = veval(mod, ctx=ctx, target=tgt)
         tvm.testing.assert_allclose(result.asnumpy(), 10)
 
+
 @tvm.testing.uses_gpu
 def test_list_map():
     mod = tvm.IRModule()
     p = Prelude(mod)
 
-    x = relay.var('x', 'int32')
+    x = relay.var("x", "int32")
     add_one_func = relay.Function([x], relay.const(1) + x)
 
     nil = p.nil
@@ -494,6 +524,7 @@ def test_list_map():
         result = veval(mod, ctx=ctx, target=tgt)
         tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 2]))
 
+
 @tvm.testing.uses_gpu
 def test_list_foldl():
     mod = tvm.IRModule()
@@ -514,6 +545,7 @@ def test_list_foldl():
         result = veval(mod, ctx=ctx, target=tgt)
         tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 3, 2, 2, 1, 1]))
 
+
 @tvm.testing.uses_gpu
 def test_list_foldr():
     mod = tvm.IRModule()
@@ -534,6 +566,7 @@ def test_list_foldr():
         result = veval(mod, ctx=ctx, target=tgt)
         tvm.testing.assert_allclose(vmobj_to_list(result), np.array([1, 2, 3]))
 
+
 @tvm.testing.uses_gpu
 def test_list_sum():
     mod = tvm.IRModule()
@@ -550,6 +583,7 @@ def test_list_sum():
         result = veval(mod, ctx=ctx, target=tgt)
         tvm.testing.assert_allclose(result.asnumpy(), 6)
 
+
 @tvm.testing.uses_gpu
 def test_list_filter():
     mod = tvm.IRModule()
@@ -559,23 +593,25 @@ def test_list_filter():
     cons = p.cons
     filter = p.filter
 
-    x = relay.var("x", 'int32')
+    x = relay.var("x", "int32")
     greater_than_one = relay.Function([x], x > relay.const(1))
-    l = cons(relay.const(1),
-            cons(relay.const(3),
-                cons(relay.const(1),
-                    cons(relay.const(5),
-                        cons(relay.const(1), nil())))))
+    l = cons(
+        relay.const(1),
+        cons(
+            relay.const(3), cons(relay.const(1), cons(relay.const(5), cons(relay.const(1), nil())))
+        ),
+    )
     f = relay.Function([], filter(greater_than_one, l))
     mod["main"] = f
     for tgt, ctx in tvm.testing.enabled_targets():
         result = veval(mod, ctx=ctx, target=tgt)
         tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 5]))
 
+
 @tvm.testing.uses_gpu
 def test_closure():
-    x = relay.var('x', shape=())
-    y = relay.var('y', shape=())
+    x = relay.var("x", shape=())
+    y = relay.var("y", shape=())
     f = relay.Function([x], x + y)
     ff = relay.Function([y], f)
     clo = ff(relay.const(1.0))
@@ -584,6 +620,7 @@ def test_closure():
         res = veval(main, ctx=ctx, target=tgt)
         tvm.testing.assert_allclose(res.asnumpy(), 3.0)
 
+
 @tvm.testing.uses_gpu
 def test_add_op_scalar():
     """
@@ -593,14 +630,15 @@ def test_add_op_scalar():
         }
     """
     mod = tvm.IRModule()
-    x = relay.var('x', shape=())
-    y = relay.var('y', shape=())
+    x = relay.var("x", shape=())
+    y = relay.var("y", shape=())
     func = relay.Function([x, y], relay.op.add(x, y))
-    x_data = np.array(10.0, dtype='float32')
-    y_data = np.array(1.0, dtype='float32')
+    x_data = np.array(10.0, dtype="float32")
+    y_data = np.array(1.0, dtype="float32")
     mod["main"] = func
     check_result([x_data, y_data], x_data + y_data, mod=mod)
 
+
 @tvm.testing.uses_gpu
 def test_add_op_tensor():
     """
@@ -610,14 +648,15 @@ def test_add_op_tensor():
         }
     """
     mod = tvm.IRModule()
-    x = relay.var('x', shape=(10, 5))
-    y = relay.var('y', shape=(10, 5))
+    x = relay.var("x", shape=(10, 5))
+    y = relay.var("y", shape=(10, 5))
     func = relay.Function([x, y], relay.op.add(x, y))
-    x_data = np.random.rand(10, 5).astype('float32')
-    y_data = np.random.rand(10, 5).astype('float32')
+    x_data = np.random.rand(10, 5).astype("float32")
+    y_data = np.random.rand(10, 5).astype("float32")
     mod["main"] = func
     check_result([x_data, y_data], x_data + y_data, mod=mod)
 
+
 @tvm.testing.uses_gpu
 def test_add_op_broadcast():
     """
@@ -627,37 +666,40 @@ def test_add_op_broadcast():
         }
     """
     mod = tvm.IRModule()
-    x = relay.var('x', shape=(10, 5))
-    y = relay.var('y', shape=(1, 5))
+    x = relay.var("x", shape=(10, 5))
+    y = relay.var("y", shape=(1, 5))
     func = relay.Function([x, y], relay.op.add(x, y))
-    x_data = np.random.rand(10, 5).astype('float32')
-    y_data = np.random.rand(1, 5).astype('float32')
+    x_data = np.random.rand(10, 5).astype("float32")
+    y_data = np.random.rand(1, 5).astype("float32")
     mod["main"] = func
     check_result([x_data, y_data], x_data + y_data, mod=mod)
 
+
 def test_vm_optimize_dynamic():
-    dtype = 'float32'
-    x = relay.var('x', shape=(relay.Any(), relay.Any()), dtype=dtype)
-    y = relay.var('y', shape=(relay.Any(), relay.Any()), dtype=dtype)
+    dtype = "float32"
+    x = relay.var("x", shape=(relay.Any(), relay.Any()), dtype=dtype)
+    y = relay.var("y", shape=(relay.Any(), relay.Any()), dtype=dtype)
     mod = tvm.IRModule()
-    mod['main'] = relay.Function([x, y], relay.add(x, y))
+    mod["main"] = relay.Function([x, y], relay.add(x, y))
     comp = relay.vm.VMCompiler()
     opt_mod, _ = comp.optimize(mod, target="llvm")
     assert "shape_func" in opt_mod.astext(False)
 
+
 def test_vm_optimize():
     mod, params = testing.synthetic.get_workload()
     comp = relay.vm.VMCompiler()
     opt_mod, _ = comp.optimize(mod, target="llvm", params=params)
 
+
 @tvm.testing.uses_gpu
 def test_loop_free_var():
-    x = relay.var('x', shape=(), dtype='int32')
-    i = relay.var('i', shape=(), dtype='int32')
-    s = relay.var('s', shape=(), dtype='int32')
+    x = relay.var("x", shape=(), dtype="int32")
+    i = relay.var("i", shape=(), dtype="int32")
+    s = relay.var("s", shape=(), dtype="int32")
 
     def cond(i, _):
-        return i < relay.const(10, dtype='int32')
+        return i < relay.const(10, dtype="int32")
 
     def body_no_free_var(i, acc):
         incr = relay.const(1, "int32")
@@ -667,16 +709,15 @@ def test_loop_free_var():
         incr = relay.const(1, "int32")
         return i + incr, acc + x
 
-    for args, body, expected in zip([[], [1]],
-                                    [body_no_free_var, body_with_free_var],
-                                    [45, 10]):
+    for args, body, expected in zip([[], [1]], [body_no_free_var, body_with_free_var], [45, 10]):
         loop = while_loop(cond, [i, s], body)
-        tup = loop(relay.const(0, dtype='int32'), relay.zeros(shape=(), dtype='int32'))
+        tup = loop(relay.const(0, dtype="int32"), relay.zeros(shape=(), dtype="int32"))
         ret = relay.TupleGetItem(tup, 1)
         mod = tvm.IRModule()
         mod["main"] = relay.Function(relay.analysis.free_vars(ret), ret)
         check_result(args, expected, mod=mod)
 
+
 @tvm.testing.uses_gpu
 def test_vm_reshape_tensor():
     x_np = np.random.uniform(size=(8, 16)).astype("float32")
@@ -700,7 +741,7 @@ def test_vm_reshape_tensor():
     check_result([x_np], x_np.reshape([4, 4, 8]), mod)
 
     # reshape with symbolic/any shape
-    for n in [tvm.tir.Any(), tvm.te.size_var('n')]:
+    for n in [tvm.tir.Any(), tvm.te.size_var("n")]:
         x = relay.var("x", shape=(n, 16), dtype="float32")
         y = relay.reshape(x, [-1, 4])
         y = relay.reshape(y, [0, 2, -1])
@@ -725,5 +766,6 @@ def test_vm_reshape_tensor():
     y_np = np.array([8, 2, 8]).astype("int32")
     check_result([x_np, y_np], x_np.reshape([8, 2, 8]), mod)
 
+
 if __name__ == "__main__":
     pytest.main([__file__])
index b304c43..4be2fe9 100644
@@ -29,6 +29,7 @@ from tvm.relay.prelude import Prelude
 from tvm.contrib import util
 from tvm.relay import testing
 
+
 def create_exec(f, target="llvm", params=None):
     if isinstance(f, relay.Expr):
         mod = tvm.IRModule()
@@ -41,8 +42,7 @@ def create_exec(f, target="llvm", params=None):
         return executable
 
 
-def get_serialized_output(mod, *data, params=None, target="llvm",
-                          ctx=tvm.cpu()):
+def get_serialized_output(mod, *data, params=None, target="llvm", ctx=tvm.cpu()):
     exe = create_exec(mod, target, params=params)
     code, lib = exe.save()
     des_exec = _vm.Executable.load_exec(code, lib)
@@ -50,11 +50,10 @@ def get_serialized_output(mod, *data, params=None, target="llvm",
     result = des_vm.run(*data)
     return result
 
-def run_network(mod,
-                params,
-                dtype='float32'):
-    def get_vm_output(mod, data, params, target, ctx, dtype='float32'):
-        ex = relay.create_executor('vm', mod=mod, ctx=ctx)
+
+def run_network(mod, params, dtype="float32"):
+    def get_vm_output(mod, data, params, target, ctx, dtype="float32"):
+        ex = relay.create_executor("vm", mod=mod, ctx=ctx)
         result = ex.evaluate()(data, **params)
         return result.asnumpy().astype(dtype)
 
@@ -64,30 +63,29 @@ def run_network(mod,
     target = "llvm"
     ctx = tvm.cpu(0)
 
-    tvm_out = get_vm_output(mod, tvm.nd.array(data.astype(dtype)), params,
-                            target, ctx, dtype)
-    vm_out = get_serialized_output(mod, tvm.nd.array(data.astype(dtype)),
-                                   params=params, target=target, ctx=ctx)
-    tvm.testing.assert_allclose(vm_out.asnumpy().astype(dtype), tvm_out,
-                                rtol=1e-5, atol=1e-5)
+    tvm_out = get_vm_output(mod, tvm.nd.array(data.astype(dtype)), params, target, ctx, dtype)
+    vm_out = get_serialized_output(
+        mod, tvm.nd.array(data.astype(dtype)), params=params, target=target, ctx=ctx
+    )
+    tvm.testing.assert_allclose(vm_out.asnumpy().astype(dtype), tvm_out, rtol=1e-5, atol=1e-5)
 
 
 def test_serializer():
     mod = tvm.IRModule({})
     a = relay.const(1.0, "float32")
-    x = relay.var('x', shape=(10, 10), dtype='float32')
+    x = relay.var("x", shape=(10, 10), dtype="float32")
     f1 = relay.Function([x], x + a)
     glb_f1 = relay.GlobalVar("f1")
     mod[glb_f1] = f1
 
     b = relay.const(2.0, "float32")
-    y = relay.var('y', shape=(10, 10), dtype='float32')
+    y = relay.var("y", shape=(10, 10), dtype="float32")
     f2 = relay.Function([y], y - b)
     glb_f2 = relay.GlobalVar("f2")
     mod[glb_f2] = f2
 
-    x1 = relay.var('x1', shape=(10, 10), dtype='float32')
-    y1 = relay.var('y1', shape=(10, 10), dtype='float32')
+    x1 = relay.var("x1", shape=(10, 10), dtype="float32")
+    y1 = relay.var("y1", shape=(10, 10), dtype="float32")
     main = relay.Function([x1, y1], glb_f1(x1) * glb_f2(y1))
     mod["main"] = main
 
@@ -100,9 +98,9 @@ def test_serializer():
     assert "main" in glbs
 
     prim_ops = exe.primitive_ops
-    assert any(item.startswith('fused_add') for item in prim_ops)
-    assert any(item.startswith('fused_subtract') for item in prim_ops)
-    assert any(item.startswith('fused_multiply') for item in prim_ops)
+    assert any(item.startswith("fused_add") for item in prim_ops)
+    assert any(item.startswith("fused_subtract") for item in prim_ops)
+    assert any(item.startswith("fused_multiply") for item in prim_ops)
 
     code = exe.bytecode
     assert "main(x1, y1)" in code
@@ -115,9 +113,9 @@ def test_serializer():
 
 
 def test_save_load():
-    x = relay.var('x', shape=(10, 10))
+    x = relay.var("x", shape=(10, 10))
     f = relay.Function([x], x + x)
-    x_data = np.random.rand(10, 10).astype('float32')
+    x_data = np.random.rand(10, 10).astype("float32")
 
     # serialize.
     vm = create_exec(f)
@@ -144,22 +142,21 @@ def test_save_load():
 
 def test_const():
     c = relay.const(1.0, "float32")
-    x = relay.var('x', shape=(10, 10), dtype='float32')
+    x = relay.var("x", shape=(10, 10), dtype="float32")
     f = relay.Function([x], x + c)
-    x_data = np.random.rand(10, 10).astype('float32')
+    x_data = np.random.rand(10, 10).astype("float32")
     res = get_serialized_output(f, x_data)
     tvm.testing.assert_allclose(res.asnumpy(), x_data + 1)
 
 
 def test_if():
-    x = relay.var('x', shape=(10, 10))
-    y = relay.var('y', shape=(10, 10))
+    x = relay.var("x", shape=(10, 10))
+    y = relay.var("y", shape=(10, 10))
     equal = relay.op.equal(x, y)
     equal = relay.op.nn.batch_flatten(equal)
-    f = relay.Function([x, y], relay.If(relay.op.min(equal, axis=[0, 1]), x,
-                                        y))
-    x_data = np.random.rand(10, 10).astype('float32')
-    y_data = np.random.rand(10, 10).astype('float32')
+    f = relay.Function([x, y], relay.If(relay.op.min(equal, axis=[0, 1]), x, y))
+    x_data = np.random.rand(10, 10).astype("float32")
+    y_data = np.random.rand(10, 10).astype("float32")
 
     # same
     res = get_serialized_output(f, x_data, x_data)
@@ -172,23 +169,23 @@ def test_if():
 
 def test_loop():
     mod = tvm.IRModule({})
-    sum_up = relay.GlobalVar('sum_up')
-    i = relay.var('i', shape=[], dtype='int32')
-    accum = relay.var('accum', shape=[], dtype='int32')
+    sum_up = relay.GlobalVar("sum_up")
+    i = relay.var("i", shape=[], dtype="int32")
+    accum = relay.var("accum", shape=[], dtype="int32")
     sb = ScopeBuilder()
-    with sb.if_scope(relay.equal(i, relay.const(0, 'int32'))):
+    with sb.if_scope(relay.equal(i, relay.const(0, "int32"))):
         sb.ret(accum)
     with sb.else_scope():
-        one_less = relay.subtract(i, relay.const(1, 'int32'))
+        one_less = relay.subtract(i, relay.const(1, "int32"))
         new_accum = relay.add(accum, i)
         sb.ret(relay.Call(sum_up, [one_less, new_accum]))
     func = relay.Function([i, accum], sb.get())
     mod[sum_up] = func
     loop_bound = 0
-    i_data = np.array(loop_bound, dtype='int32')
-    accum_data = np.array(0, dtype='int32')
-    iarg = relay.var('i', shape=[], dtype='int32')
-    aarg = relay.var('accum', shape=[], dtype='int32')
+    i_data = np.array(loop_bound, dtype="int32")
+    accum_data = np.array(0, dtype="int32")
+    iarg = relay.var("i", shape=[], dtype="int32")
+    aarg = relay.var("accum", shape=[], dtype="int32")
     mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg))
 
     result = get_serialized_output(mod, i_data, accum_data)
@@ -197,10 +194,10 @@ def test_loop():
 
 def test_tuple():
     ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))])
-    tup = relay.var('tup', type_annotation=ttype)
+    tup = relay.var("tup", type_annotation=ttype)
     f = relay.Function([tup], relay.TupleGetItem(tup, 1))
-    i_data = np.random.rand(41).astype('float32')
-    j_data = np.random.rand(10).astype('float32')
+    i_data = np.random.rand(41).astype("float32")
+    j_data = np.random.rand(10).astype("float32")
 
     result = get_serialized_output(f, (i_data, j_data))
     tvm.testing.assert_allclose(result.asnumpy(), j_data)
@@ -236,9 +233,9 @@ def test_adt_compose():
 
     # add_one = fun x -> x + 1
     sb = relay.ScopeBuilder()
-    x = relay.var('x', 'float32')
-    x1 = sb.let('x1', x)
-    xplusone = x1 + relay.const(1.0, 'float32')
+    x = relay.var("x", "float32")
+    x1 = sb.let("x1", x)
+    xplusone = x1 + relay.const(1.0, "float32")
     sb.ret(xplusone)
     body = sb.get()
     add_one = relay.GlobalVar("add_one")
@@ -246,8 +243,8 @@ def test_adt_compose():
 
     # add_two = compose(add_one, add_one)
     sb = relay.ScopeBuilder()
-    y = relay.var('y', 'float32')
-    add_two_func = sb.let('add_two', compose(add_one_func, add_one_func))
+    y = relay.var("y", "float32")
+    add_two_func = sb.let("add_two", compose(add_one_func, add_one_func))
     add_two_res = add_two_func(y)
     sb.ret(add_two_res)
     add_two_body = sb.get()
@@ -257,14 +254,14 @@ def test_adt_compose():
     f = relay.Function([y], add_two_body)
     mod["main"] = f
 
-    x_data = np.array(np.random.rand()).astype('float32')
+    x_data = np.array(np.random.rand()).astype("float32")
     result = get_serialized_output(mod, x_data)
     tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0)
 
 
 def test_closure():
-    x = relay.var('x', shape=())
-    y = relay.var('y', shape=())
+    x = relay.var("x", shape=())
+    y = relay.var("y", shape=())
     f = relay.Function([x], x + y)
     ff = relay.Function([y], f)
     clo = ff(relay.const(1.0))
@@ -285,13 +282,13 @@ def test_mobilenet():
 
 
 def test_vm_shape_of():
-    x = relay.var('x', shape=(relay.Any(), relay.Any(), relay.Any()), dtype="float32")
+    x = relay.var("x", shape=(relay.Any(), relay.Any(), relay.Any()), dtype="float32")
     relu_x = relay.nn.relu(x)
-    data = np.random.uniform(size=(2, 3, 4)).astype('float32')
+    data = np.random.uniform(size=(2, 3, 4)).astype("float32")
     args = [data]
 
-    newshape_var = relay.var('newshape', shape=(2,), dtype='int64')
-    args.append(np.array((1, -1), dtype='int64'))
+    newshape_var = relay.var("newshape", shape=(2,), dtype="int64")
+    args.append(np.array((1, -1), dtype="int64"))
     main = relay.Function([x, newshape_var], relay.reshape(relu_x, newshape=newshape_var))
 
     res = get_serialized_output(main, *args).asnumpy()
@@ -299,17 +296,16 @@ def test_vm_shape_of():
 
 
 def test_dynamic_bcast():
-    dtype = 'float32'
-    x = relay.var('x', shape=(relay.Any(), 2), dtype=dtype)
-    y = relay.var('y', shape=(3, 2), dtype=dtype)
+    dtype = "float32"
+    x = relay.var("x", shape=(relay.Any(), 2), dtype=dtype)
+    y = relay.var("y", shape=(3, 2), dtype=dtype)
     mod = tvm.IRModule()
-    mod['main'] = relay.Function([x, y], relay.add(x, y))
+    mod["main"] = relay.Function([x, y], relay.add(x, y))
     x_data = np.random.uniform(size=(1, 2)).astype(dtype)
     y_data = np.random.uniform(size=(3, 2)).astype(dtype)
     res_np = np.add(x_data, y_data)
     for target, ctx in testing.enabled_targets():
-        res = get_serialized_output(mod, *(x_data, y_data), target=target,
-                                    ctx=ctx)
+        res = get_serialized_output(mod, *(x_data, y_data), target=target, ctx=ctx)
         tvm.testing.assert_allclose(res.asnumpy(), res_np)
 
 
index d63251e..c0c4b1e 100644
@@ -19,6 +19,7 @@
 from tvm import autotvm
 from tvm.autotvm.task.space import FallbackConfigEntity
 
+
 class Int8Fallback(autotvm.FallbackContext):
     def _query_inside(self, target, workload):
         key = (target, workload)
index 0df4822..458fabf 100644
@@ -25,12 +25,12 @@ import numpy as np
 from tvm.contrib.pickle_memoize import memoize
 
 
-def verify_fifo_buffer(buffer_shape, data_shape, axis, dtype='float32'):
-    buffer = te.placeholder(buffer_shape, name='buffer', dtype=dtype)
-    data = te.placeholder(data_shape, name='data', dtype=dtype)
+def verify_fifo_buffer(buffer_shape, data_shape, axis, dtype="float32"):
+    buffer = te.placeholder(buffer_shape, name="buffer", dtype=dtype)
+    data = te.placeholder(data_shape, name="data", dtype=dtype)
 
     # Use memoize, pickle the test data for next time use
-    @memoize('topi.tests.test_fifo_buffer')
+    @memoize("topi.tests.test_fifo_buffer")
     def get_ref_data():
         buffer_np = np.random.uniform(size=buffer_shape).astype(dtype)
         data_np = np.random.uniform(size=data_shape).astype(dtype)
@@ -47,7 +47,7 @@ def verify_fifo_buffer(buffer_shape, data_shape, axis, dtype='float32'):
     buffer_np, data_np, out_np = get_ref_data()
 
     def check_device(device, ctx):
-        print('  Running on target: {}'.format(device))
+        print("  Running on target: {}".format(device))
 
         with tvm.target.Target(device):
             out = topi.nn.fifo_buffer(data, buffer, axis=axis)
@@ -56,13 +56,14 @@ def verify_fifo_buffer(buffer_shape, data_shape, axis, dtype='float32'):
         buffer_tvm = tvm.nd.array(buffer_np, ctx=ctx)
         data_tvm = tvm.nd.array(data_np, ctx=ctx)
         out_tvm = tvm.nd.empty(shape=buffer_shape, ctx=ctx, dtype=dtype)
-        f = tvm.build(s, [data, buffer, out], device, name='fifo')
+        f = tvm.build(s, [data, buffer, out], device, name="fifo")
         f(data_tvm, buffer_tvm, out_tvm)
         tvm.testing.assert_allclose(out_tvm.asnumpy(), out_np)
 
     for device, ctx in tvm.testing.enabled_targets():
         check_device(device, ctx)
 
+
 def verify_conv1d_integration():
     batch_size = 1
     num_channel = 1
@@ -93,21 +94,22 @@ def verify_conv1d_integration():
     # Rule: Convolution of Tensor[context_shape] and Tensor[kernel_shape]
     #       produces Tensor[inc_input_shape]
 
-    dtype = 'float32'
+    dtype = "float32"
 
-    inc_input = te.placeholder(inc_input_shape, name='inc_input', dtype=dtype)
-    input_window = te.placeholder(input_window_shape, name='input_window', dtype=dtype)
-    context = te.placeholder(context_shape, name='context', dtype=dtype)
-    kernel = te.placeholder(kernel_shape, name='kernel', dtype=dtype)
-    inc_output = te.placeholder(inc_input_shape, name='inc_output', dtype=dtype)
-    output_window = te.placeholder(output_window_shape, name='output_window', dtype=dtype)
+    inc_input = te.placeholder(inc_input_shape, name="inc_input", dtype=dtype)
+    input_window = te.placeholder(input_window_shape, name="input_window", dtype=dtype)
+    context = te.placeholder(context_shape, name="context", dtype=dtype)
+    kernel = te.placeholder(kernel_shape, name="kernel", dtype=dtype)
+    inc_output = te.placeholder(inc_input_shape, name="inc_output", dtype=dtype)
+    output_window = te.placeholder(output_window_shape, name="output_window", dtype=dtype)
 
     # Use memoize, pickle the test data for next time use
-    @memoize('topi.tests.test_fifo_buffer_conv1d_integration')
+    @memoize("topi.tests.test_fifo_buffer_conv1d_integration")
     def get_data():
         # Generate [num_iteration] slices of input
-        inc_input_np = np.random.uniform(size=tuple([num_iteration] + list(inc_input_shape)))\
-                       .astype(dtype)
+        inc_input_np = np.random.uniform(
+            size=tuple([num_iteration] + list(inc_input_shape))
+        ).astype(dtype)
         input_window_np = np.zeros(input_window_shape, dtype=dtype)
         kernel_np = np.random.uniform(size=kernel_shape).astype(dtype)
         context_np = np.zeros(context_shape, dtype=dtype)
@@ -119,32 +121,34 @@ def verify_conv1d_integration():
     inc_input_np, input_window_np, kernel_np, context_np, output_window_np = get_data()
 
     def check_device(device, ctx):
-        print('  Running on target: {}'.format(device))
+        print("  Running on target: {}".format(device))
 
         conv2d_nchw, schedule_conv2d_nchw = tvm.topi.testing.get_conv2d_nchw_implement(device)
 
         with tvm.target.Target(device):
             out = topi.nn.fifo_buffer(inc_input, context, axis=buffer_axis)
             s = tvm.topi.testing.get_injective_schedule(device)([out])
-            update_context = tvm.build(s, [inc_input, context, out], device, name='update_context')
+            update_context = tvm.build(s, [inc_input, context, out], device, name="update_context")
 
             out = conv2d_nchw(context, kernel, stride, padding, dilate, dtype)
             s = schedule_conv2d_nchw([out])
-            conv2d_inc = tvm.build(s, [context, kernel, out], device, name='conv2d_inc')
+            conv2d_inc = tvm.build(s, [context, kernel, out], device, name="conv2d_inc")
 
             out = topi.nn.fifo_buffer(inc_output, output_window, axis=buffer_axis)
             s = tvm.topi.testing.get_injective_schedule(device)([out])
-            update_output_window = tvm.build(s, [inc_output, output_window, out], device,
-                 name='update_output_window')
+            update_output_window = tvm.build(
+                s, [inc_output, output_window, out], device, name="update_output_window"
+            )
 
             out = topi.nn.fifo_buffer(inc_input, input_window, axis=buffer_axis)
             s = tvm.topi.testing.get_injective_schedule(device)([out])
-            update_input_window = tvm.build(s, [inc_input, input_window, out], device,
-                                            name='update_input_window')
+            update_input_window = tvm.build(
+                s, [inc_input, input_window, out], device, name="update_input_window"
+            )
 
             out = conv2d_nchw(input_window, kernel, stride, padding, dilate, dtype)
             s = schedule_conv2d_nchw([out])
-            conv2d = tvm.build(s, [input_window, kernel, out], device, name='conv2d')
+            conv2d = tvm.build(s, [input_window, kernel, out], device, name="conv2d")
 
         input_window_tvm = tvm.nd.array(input_window_np, ctx=ctx)
         new_input_window_tvm = tvm.nd.empty(shape=input_window_shape, ctx=ctx, dtype=dtype)
@@ -173,27 +177,34 @@ def verify_conv1d_integration():
             conv2d(input_window_tvm, kernel_tvm, output_window_ref_tvm)
             # Incrementally updating the output window should be equivalent to computing it from
             # scratch using the input window
-            tvm.testing.assert_allclose(output_window_tvm.asnumpy(),
-                                        output_window_ref_tvm.asnumpy())
+            tvm.testing.assert_allclose(
+                output_window_tvm.asnumpy(), output_window_ref_tvm.asnumpy()
+            )
 
     for device, ctx in tvm.testing.enabled_targets():
         check_device(device, ctx)
 
+
 @tvm.testing.uses_gpu
 def test_fifo_buffer():
     for ndim in [1, 2, 3, 4, 5, 6]:
         for axis in range(ndim):
             buffer_shape = tuple(7 for _ in range(ndim))
             data_shape = tuple((2 if i == axis else 7) for i in range(ndim))
-            print('Testing FIFO buffer op: buffer_shape = {}, data_shape = {}, axis = {}'
-                  .format(buffer_shape, data_shape, axis))
+            print(
+                "Testing FIFO buffer op: buffer_shape = {}, data_shape = {}, axis = {}".format(
+                    buffer_shape, data_shape, axis
+                )
+            )
             verify_fifo_buffer(buffer_shape, data_shape, axis)
 
+
 @tvm.testing.uses_gpu
 def test_conv1d_integration():
-    print('Testing FIFO buffer with 1D convolution')
+    print("Testing FIFO buffer with 1D convolution")
     verify_conv1d_integration()
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     test_fifo_buffer()
     test_conv1d_integration()
index e1e5cf8..0743198 100644
@@ -27,9 +27,9 @@ def test_util():
 
 
 def test_ewise():
-    m = te.var('m')
-    l = te.var('l')
-    A = te.placeholder((m, l), name='A')
+    m = te.var("m")
+    l = te.var("l")
+    A = te.placeholder((m, l), name="A")
 
     def test_apply(func, name):
         B = func(A)
index c23716c..769822e 100644
@@ -31,9 +31,10 @@ _batch_matmul_implement = {
     "gpu": (topi.cuda.batch_matmul, topi.cuda.schedule_batch_matmul),
 }
 
+
 def verify_batch_matmul(batch, M, N, K):
-    x = te.placeholder((batch, M, K), name='x')
-    y = te.placeholder((batch, N, K), name='y')
+    x = te.placeholder((batch, M, K), name="x")
+    y = te.placeholder((batch, N, K), name="y")
     dtype = x.dtype
 
     # use memoize to pickle the test data for next time use
@@ -43,6 +44,7 @@ def verify_batch_matmul(batch, M, N, K):
         b_np = np.random.uniform(size=(batch, N, K)).astype(dtype)
         c_np = tvm.topi.testing.batch_matmul(a_np, b_np)
         return (a_np, b_np, c_np)
+
     # get the test data
     a_np, b_np, c_np = get_ref_data()
 
@@ -62,6 +64,7 @@ def verify_batch_matmul(batch, M, N, K):
     for device, ctx in tvm.testing.enabled_targets():
         check_device(device, ctx)
 
+
 @tvm.testing.uses_gpu
 def test_batch_matmul():
     verify_batch_matmul(1, 16, 16, 32)
index 91d6da2..b276aad 100644
@@ -22,22 +22,35 @@ import tvm.topi.testing
 from tvm.topi.util import get_const_tuple
 from tvm.contrib.pickle_memoize import memoize
 
+
 def generate_quantized_np(shape, bits, out_dtype):
     min_val = 0
     max_val = 1 << bits
     return np.random.randint(min_val, max_val, size=shape).astype(out_dtype)
 
-def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, stride, padding,
-                                 activation_bits, weight_bits, unipolar):
+
+def verify_bitserial_conv2d_nchw(
+    batch,
+    in_size,
+    in_channel,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    activation_bits,
+    weight_bits,
+    unipolar,
+):
     in_height = in_width = in_size
-    input_dtype = 'uint32'
-    out_dtype = 'int32'
-
-    with tvm.target.Target('llvm'):
-        A = te.placeholder((batch, in_channel, in_height, in_width), dtype=input_dtype, name='A')
-        W = te.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_dtype, name='W')
-        B = topi.x86.bitserial_conv2d_nchw(A, W, stride, padding, activation_bits, weight_bits,
-                                           input_dtype, out_dtype, unipolar)
+    input_dtype = "uint32"
+    out_dtype = "int32"
+
+    with tvm.target.Target("llvm"):
+        A = te.placeholder((batch, in_channel, in_height, in_width), dtype=input_dtype, name="A")
+        W = te.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_dtype, name="W")
+        B = topi.x86.bitserial_conv2d_nchw(
+            A, W, stride, padding, activation_bits, weight_bits, input_dtype, out_dtype, unipolar
+        )
         s = topi.x86.schedule_bitserial_conv2d_nchw([B])
 
     a_shape = get_const_tuple(A.shape)
@@ -49,12 +62,13 @@ def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel,
         w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_dtype)
         if unipolar:
             w_ = np.copy(w_np).astype(out_dtype)
-            for x in np.nditer(w_, op_flags=['readwrite']):
+            for x in np.nditer(w_, op_flags=["readwrite"]):
                 x[...] = 1 if x == 1 else -1
             b_np = tvm.topi.testing.conv2d_nchw_python(a_np.astype(out_dtype), w_, stride, padding)
         else:
             b_np = tvm.topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
         return a_np, w_np, b_np
+
     a_np, w_np, b_np = get_ref_data()
 
     ctx = tvm.cpu(0)
@@ -65,17 +79,29 @@ def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel,
     func(a, w, b)
     tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding,
-                                 activation_bits, weight_bits, unipolar):
+
+def verify_bitserial_conv2d_nhwc(
+    batch,
+    in_size,
+    in_channel,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    activation_bits,
+    weight_bits,
+    unipolar,
+):
     in_height = in_width = in_size
-    input_dtype='uint32'
-    out_dtype='int32'
-
-    with tvm.target.Target('llvm'):
-        A = te.placeholder((batch, in_height, in_width, in_channel), dtype=input_dtype, name='A')
-        W = te.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_dtype, name='W')
-        B = topi.x86.bitserial_conv2d_nhwc(A, W, stride, padding, activation_bits, weight_bits,
-                                           input_dtype, out_dtype, unipolar)
+    input_dtype = "uint32"
+    out_dtype = "int32"
+
+    with tvm.target.Target("llvm"):
+        A = te.placeholder((batch, in_height, in_width, in_channel), dtype=input_dtype, name="A")
+        W = te.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_dtype, name="W")
+        B = topi.x86.bitserial_conv2d_nhwc(
+            A, W, stride, padding, activation_bits, weight_bits, input_dtype, out_dtype, unipolar
+        )
         s = topi.x86.schedule_bitserial_conv2d_nhwc([B])
 
     a_shape = get_const_tuple(A.shape)
@@ -87,23 +113,27 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
         w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_dtype)
         if unipolar:
             w_ = np.copy(w_np).astype(out_dtype)
-            for x in np.nditer(w_, op_flags=['readwrite']):
+            for x in np.nditer(w_, op_flags=["readwrite"]):
                 x[...] = 1 if x == 1 else -1
             b_np = tvm.topi.testing.conv2d_nhwc_python(a_np, w_, stride, padding).astype(out_dtype)
         else:
-            b_np = tvm.topi.testing.conv2d_nhwc_python(a_np, w_np, stride, padding).astype(out_dtype)
+            b_np = tvm.topi.testing.conv2d_nhwc_python(a_np, w_np, stride, padding).astype(
+                out_dtype
+            )
         return a_np, w_np, b_np
+
     a_np, w_np, b_np = get_ref_data()
 
     ctx = tvm.cpu(0)
     a = tvm.nd.array(a_np, ctx)
     w = tvm.nd.array(w_np, ctx)
     b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
-    func = tvm.build(s, [A, W, B], 'llvm')
+    func = tvm.build(s, [A, W, B], "llvm")
 
     func(a, w, b)
     tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
+
 def test_bitserial_conv2d():
     in_size = 56
     ic, oc = 64, 64
@@ -122,5 +152,6 @@ def test_bitserial_conv2d():
     verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, False)
     verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 2, False)
 
+
 if __name__ == "__main__":
     test_bitserial_conv2d()
index 659ee21..76f3422 100644
@@ -23,40 +23,53 @@ from tvm import topi
 import tvm.topi.testing
 from tvm.topi.util import get_const_tuple
 
+
 def generate_quantized_np(shape, bits, out_dtype):
     np.random.seed(0)
     min_val = 0
     max_val = 1 << bits
     return np.random.randint(min_val, max_val, size=shape).astype(out_dtype)
 
+
 # Verify that certain special instructions from the tensorize pass exist
-def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding,
-                                 activation_bits, weight_bits, unipolar):
+def verify_bitserial_conv2d_nhwc(
+    batch,
+    in_size,
+    in_channel,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    activation_bits,
+    weight_bits,
+    unipolar,
+):
     in_height = in_width = in_size
-    input_type = 'uint32'
-    out_dtype = 'int16'
+    input_type = "uint32"
+    out_dtype = "int16"
 
-    device = 'llvm -device=arm_cpu -model=bcm2837 -mtriple=armv7l-linux-gnueabihf -mattr=+neon'
+    device = "llvm -device=arm_cpu -model=bcm2837 -mtriple=armv7l-linux-gnueabihf -mattr=+neon"
     with tvm.target.Target(device):
-        A = te.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
-        W = te.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
-        B = topi.arm_cpu.bitserial_conv2d_nhwc(A, W, stride, padding, activation_bits, weight_bits,
-                                               'uint8', out_dtype, unipolar)
+        A = te.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name="A")
+        W = te.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name="W")
+        B = topi.arm_cpu.bitserial_conv2d_nhwc(
+            A, W, stride, padding, activation_bits, weight_bits, "uint8", out_dtype, unipolar
+        )
         s = topi.arm_cpu.schedule_bitserial_conv2d_nhwc([B])
 
     func = tvm.build(s, [A, W, B], device)
 
-    assembly = func.get_source('asm')
+    assembly = func.get_source("asm")
     matches = re.findall("vpadal", assembly)
-    assert (len(matches) > 0)
+    assert len(matches) > 0
     matches = re.findall("vcnt", assembly)
-    assert (len(matches) > 0)
+    assert len(matches) > 0
     matches = re.findall("vpadd", assembly)
-    assert (len(matches) > 0)
+    assert len(matches) > 0
 
     ctx = tvm.context(device, 0)
-    if 'arm' not in os.uname()[4]:
-        print ("Skipped running code, not an arm device")
+    if "arm" not in os.uname()[4]:
+        print("Skipped running code, not an arm device")
         return
 
     print("Running on target: %s" % device)
@@ -66,12 +79,15 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
         w_np = generate_quantized_np(get_const_tuple(W.shape), weight_bits, input_type)
         if unipolar:
             w_ = np.copy(w_np).astype(out_dtype)
-            for x in np.nditer(w_, op_flags=['readwrite']):
+            for x in np.nditer(w_, op_flags=["readwrite"]):
                 x[...] = 1 if x == 1 else -1
             b_np = tvm.topi.testing.conv2d_nhwc_python(a_np, w_, stride, padding).astype(out_dtype)
         else:
-            b_np = tvm.topi.testing.conv2d_nhwc_python(a_np, w_np, stride, padding).astype(out_dtype)
+            b_np = tvm.topi.testing.conv2d_nhwc_python(a_np, w_np, stride, padding).astype(
+                out_dtype
+            )
         return a_np, w_np, b_np
+
     a_np, w_np, b_np = get_ref_data()
     a = tvm.nd.array(a_np, ctx)
     w = tvm.nd.array(w_np, ctx)
@@ -81,6 +97,7 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
     func(a, w, b)
     np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
+
 def test_bitserial_conv2d():
     in_size = 56
     ic, oc = 64, 64
@@ -94,6 +111,6 @@ def test_bitserial_conv2d():
     verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, True)
     verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, True)
 
+
 if __name__ == "__main__":
     test_bitserial_conv2d()
-
index 19a4d94..db6ad5b 100644
@@ -30,20 +30,22 @@ _bitserial_dense_implement = {
     "arm_cpu": (topi.arm_cpu.bitserial_dense, topi.arm_cpu.schedule_bitserial_dense),
 }
 
+
 def generate_quantized_np(shape, bits, out_dtype):
     min_val = 0
     max_val = 1 << bits
     return np.random.randint(min_val, max_val, size=shape).astype(out_dtype)
 
+
 def verify_bitserial_dense(batch, in_dim, out_dim, activation_bits, weight_bits, unipolar):
-    out_dtype = 'int16'
+    out_dtype = "int16"
 
     def get_ref_data(a_shape, b_shape, input_dtype):
         a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_dtype)
         b_np = generate_quantized_np(get_const_tuple(b_shape), weight_bits, input_dtype)
         if unipolar:
             b_ = np.copy(b_np).astype(out_dtype)
-            for x in np.nditer(b_, op_flags=['readwrite']):
+            for x in np.nditer(b_, op_flags=["readwrite"]):
                 x[...] = 1 if x == 1 else -1
             c_np = np.dot(a_np, b_.T)
         else:
@@ -51,15 +53,14 @@ def verify_bitserial_dense(batch, in_dim, out_dim, activation_bits, weight_bits,
         return a_np, b_np, c_np
 
     for target in ["llvm", "llvm -device=arm_cpu"]:
-        if "arm_cpu" in target and 'arm' not in os.uname()[4]:
-            print ("Skipped running code, not an arm device")
+        if "arm_cpu" in target and "arm" not in os.uname()[4]:
+            print("Skipped running code, not an arm device")
             continue
-        input_dtype = 'uint8' if "arm_cpu" in target else "uint32"
-        A = te.placeholder((batch, in_dim), dtype=input_dtype, name='A')
-        B = te.placeholder((out_dim, in_dim), dtype=input_dtype, name='B')
+        input_dtype = "uint8" if "arm_cpu" in target else "uint32"
+        A = te.placeholder((batch, in_dim), dtype=input_dtype, name="A")
+        B = te.placeholder((out_dim, in_dim), dtype=input_dtype, name="B")
         fcompute, fschedule = tvm.topi.testing.dispatch(target, _bitserial_dense_implement)
-        C = fcompute(A, B, activation_bits, weight_bits,
-                     input_dtype, out_dtype, unipolar)
+        C = fcompute(A, B, activation_bits, weight_bits, input_dtype, out_dtype, unipolar)
         s = fschedule([C])
 
         a_shape = get_const_tuple(A.shape)
@@ -74,6 +75,7 @@ def verify_bitserial_dense(batch, in_dim, out_dim, activation_bits, weight_bits,
         func(a, b, c)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
 
+
 def test_bitserial_dense():
     verify_bitserial_dense(1, 1024, 1000, 1, 1, True)
     verify_bitserial_dense(1, 1024, 1000, 2, 1, True)
@@ -81,5 +83,6 @@ def test_bitserial_dense():
     verify_bitserial_dense(1, 1024, 1000, 1, 1, False)
     verify_bitserial_dense(1, 1024, 1000, 2, 1, False)
 
+
 if __name__ == "__main__":
     test_bitserial_dense()
index 95d3e04..753fb17 100644
@@ -24,8 +24,8 @@ from tvm.contrib.pickle_memoize import memoize
 
 
 def verify_binary_dense(batch, in_dim, out_dim):
-    A = te.placeholder((batch, in_dim), name='A')
-    B = te.placeholder((out_dim, in_dim), name='B')
+    A = te.placeholder((batch, in_dim), name="A")
+    B = te.placeholder((out_dim, in_dim), name="B")
     bnn_A = topi.nn.binarize_pack(A)
     bnn_B = topi.nn.binarize_pack(B)
     # binary dense
@@ -33,12 +33,13 @@ def verify_binary_dense(batch, in_dim, out_dim):
     bnn_B1 = te.placeholder(bnn_B.shape, dtype=bnn_B.dtype)
     bnn_C = topi.nn.binary_dense(bnn_A1, bnn_B1)
     # schedule
-    with tvm.target.Target('llvm'):
+    with tvm.target.Target("llvm"):
         s1 = topi.x86.schedule_binarize_pack(bnn_A)
         s2 = topi.x86.schedule_binarize_pack(bnn_B)
         s3 = topi.x86.schedule_binary_dense(bnn_C)
 
     dtype = A.dtype
+
     @memoize("topi.tests.test_topi_binary_dense")
     def get_ref_data():
         # generate random matrix of +1 or -1 value
@@ -55,14 +56,15 @@ def verify_binary_dense(batch, in_dim, out_dim):
     bnn_a = tvm.nd.array(np.zeros(get_const_tuple(bnn_A.shape), dtype=bnn_A.dtype), ctx)
     bnn_b = tvm.nd.array(np.zeros(get_const_tuple(bnn_B.shape), dtype=bnn_B.dtype), ctx)
     bnn_c = tvm.nd.array(np.zeros(get_const_tuple(bnn_C.shape), dtype=bnn_C.dtype), ctx)
-    f1 = tvm.build(s1, [A, bnn_A], 'llvm')
-    f2 = tvm.build(s2, [B, bnn_B], 'llvm')
-    f3 = tvm.build(s3, [bnn_A1, bnn_B1, bnn_C], 'llvm')
+    f1 = tvm.build(s1, [A, bnn_A], "llvm")
+    f2 = tvm.build(s2, [B, bnn_B], "llvm")
+    f3 = tvm.build(s3, [bnn_A1, bnn_B1, bnn_C], "llvm")
     f1(a, bnn_a)
     f2(b, bnn_b)
     f3(bnn_a, bnn_b, bnn_c)
     tvm.testing.assert_allclose(bnn_c.asnumpy(), c_np, rtol=1e-5)
 
+
 def test_binary_dense():
     verify_binary_dense(1, 4096, 1024)
     verify_binary_dense(1, 1024, 1000)
index b41f7f7..6f3e91b 100644
@@ -48,30 +48,41 @@ def verify_broadcast_to_ele(in_shape, out_shape, fbcast):
     check_device("sdaccel")
 
 
-def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
-                                ftopi, fnumpy,
-                                lhs_min=-100, lhs_max=100,
-                                rhs_min=-100, rhs_max=100,
-                                dtype="float32"):
+def verify_broadcast_binary_ele(
+    lhs_shape,
+    rhs_shape,
+    ftopi,
+    fnumpy,
+    lhs_min=-100,
+    lhs_max=100,
+    rhs_min=-100,
+    rhs_max=100,
+    dtype="float32",
+):
     # Build the logic and compile the function
-    A = (te.var("A", dtype=dtype) if lhs_shape is None
-         else te.placeholder(shape=lhs_shape, name="A", dtype=dtype))
-    B = (te.var("B", dtype=dtype) if rhs_shape is None
-         else te.placeholder(shape=rhs_shape, name="B", dtype=dtype))
+    A = (
+        te.var("A", dtype=dtype)
+        if lhs_shape is None
+        else te.placeholder(shape=lhs_shape, name="A", dtype=dtype)
+    )
+    B = (
+        te.var("B", dtype=dtype)
+        if rhs_shape is None
+        else te.placeholder(shape=rhs_shape, name="B", dtype=dtype)
+    )
     C = ftopi(A, B)
     if isinstance(A, tvm.tir.PrimExpr) and isinstance(B, tvm.tir.PrimExpr):
-        assert(isinstance(C, tvm.tir.PrimExpr))
+        assert isinstance(C, tvm.tir.PrimExpr)
         return
 
     def gen_operand(shape, low, high, ctx):
         if shape is None:
             npy = float(np.random.uniform(low=low, high=high))
-            if dtype.startswith('int'):
+            if dtype.startswith("int"):
                 npy = int(npy)
             nd = npy
         else:
-            npy = np.random.uniform(low=low, high=high,
-                                    size=shape).astype(dtype)
+            npy = np.random.uniform(low=low, high=high, size=shape).astype(dtype)
             nd = tvm.nd.array(npy, ctx)
         return npy, nd
 
@@ -91,7 +102,7 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
 
         out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
         foo(lhs_nd, rhs_nd, out_nd)
-        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
+        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1e-4, atol=1e-4)
 
     for target, ctx in tvm.testing.enabled_targets():
         check_device(target)
@@ -108,77 +119,88 @@ def test_broadcast_to():
 
 @tvm.testing.uses_gpu
 def test_add():
-    verify_broadcast_binary_ele(
-        (), (), topi.add, np.add)
-    verify_broadcast_binary_ele(
-        (5, 2, 3), (2, 1), topi.add, np.add)
+    verify_broadcast_binary_ele((), (), topi.add, np.add)
+    verify_broadcast_binary_ele((5, 2, 3), (2, 1), topi.add, np.add)
 
 
 @tvm.testing.uses_gpu
 def test_subtract():
-    verify_broadcast_binary_ele(
-        (5, 2, 3), (), topi.subtract, np.subtract)
-    verify_broadcast_binary_ele(
-        (5, 2, 3), None, topi.subtract, np.subtract)
-    verify_broadcast_binary_ele(
-        None, None, topi.subtract, np.subtract)
-    verify_broadcast_binary_ele(
-        (1, 32), (64, 32), topi.subtract, np.subtract)
+    verify_broadcast_binary_ele((5, 2, 3), (), topi.subtract, np.subtract)
+    verify_broadcast_binary_ele((5, 2, 3), None, topi.subtract, np.subtract)
+    verify_broadcast_binary_ele(None, None, topi.subtract, np.subtract)
+    verify_broadcast_binary_ele((1, 32), (64, 32), topi.subtract, np.subtract)
 
 
 @tvm.testing.uses_gpu
 def test_multiply():
-    verify_broadcast_binary_ele(
-        (5, 64, 128), (2, 5, 64, 1), topi.multiply, np.multiply)
+    verify_broadcast_binary_ele((5, 64, 128), (2, 5, 64, 1), topi.multiply, np.multiply)
 
 
 @tvm.testing.uses_gpu
 def test_divide():
-    verify_broadcast_binary_ele(
-        None, (10,), topi.divide, np.divide, rhs_min=0.0001)
-    verify_broadcast_binary_ele(
-        (), None, topi.divide, np.divide, rhs_min=0.0001)
-    verify_broadcast_binary_ele(
-        (2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001)
+    verify_broadcast_binary_ele(None, (10,), topi.divide, np.divide, rhs_min=0.0001)
+    verify_broadcast_binary_ele((), None, topi.divide, np.divide, rhs_min=0.0001)
+    verify_broadcast_binary_ele((2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001)
+
 
 @tvm.testing.uses_gpu
 def test_floor_divide():
-    def _canonical_floor_div(a,b):
+    def _canonical_floor_div(a, b):
         return np.floor(a / b)
+
     verify_broadcast_binary_ele(
-        None, (10,), topi.floor_divide, _canonical_floor_div, rhs_min=0.0001)
-    verify_broadcast_binary_ele(
-        (), None, topi.floor_divide, _canonical_floor_div, rhs_min=0.0001)
+        None, (10,), topi.floor_divide, _canonical_floor_div, rhs_min=0.0001
+    )
+    verify_broadcast_binary_ele((), None, topi.floor_divide, _canonical_floor_div, rhs_min=0.0001)
     verify_broadcast_binary_ele(
-        (2, 3, 64, 32), (64, 32), topi.floor_divide, _canonical_floor_div, rhs_min=0.0001)
+        (2, 3, 64, 32), (64, 32), topi.floor_divide, _canonical_floor_div, rhs_min=0.0001
+    )
+
 
 @tvm.testing.uses_gpu
 def test_maximum_minmum():
-    verify_broadcast_binary_ele(
-        (32,), (64, 32), topi.maximum, np.maximum)
-    verify_broadcast_binary_ele(
-        (1, 2, 2, 1, 32), (64, 32), topi.minimum, np.minimum)
+    verify_broadcast_binary_ele((32,), (64, 32), topi.maximum, np.maximum)
+    verify_broadcast_binary_ele((1, 2, 2, 1, 32), (64, 32), topi.minimum, np.minimum)
 
 
 @tvm.testing.uses_gpu
 def test_power():
     verify_broadcast_binary_ele(
-        (1, 2, 2), (2,), topi.power, np.power, lhs_min=0.001, rhs_min=0.001, rhs_max=2)
+        (1, 2, 2), (2,), topi.power, np.power, lhs_min=0.001, rhs_min=0.001, rhs_max=2
+    )
 
 
 @tvm.testing.uses_gpu
 def test_mod():
     verify_broadcast_binary_ele(
-        (1, 2, 2), (2,), topi.mod, np.mod, lhs_min=0.001, rhs_min=1, dtype="int32")
+        (1, 2, 2), (2,), topi.mod, np.mod, lhs_min=0.001, rhs_min=1, dtype="int32"
+    )
+
 
 @tvm.testing.uses_gpu
 def test_floor_mod():
-    def _canonical_floor_mod(a,b):
+    def _canonical_floor_mod(a, b):
         return a - np.floor(a / b) * b
+
     verify_broadcast_binary_ele(
-        (1, 2, 2), (2,), topi.floor_mod, _canonical_floor_mod, lhs_min=0.001, rhs_min=1, dtype="int32")
+        (1, 2, 2),
+        (2,),
+        topi.floor_mod,
+        _canonical_floor_mod,
+        lhs_min=0.001,
+        rhs_min=1,
+        dtype="int32",
+    )
     verify_broadcast_binary_ele(
-        (3, 4, 5), (3, 4, 5), topi.floor_mod, _canonical_floor_mod, lhs_min=0.001, rhs_min=1, dtype="float32")
+        (3, 4, 5),
+        (3, 4, 5),
+        topi.floor_mod,
+        _canonical_floor_mod,
+        lhs_min=0.001,
+        rhs_min=1,
+        dtype="float32",
+    )
+
 
 @tvm.testing.uses_gpu
 def test_cmp():
@@ -200,54 +222,85 @@ def test_cmp():
 
     def less_equal(x, y):
         return topi.less_equal(x, y).astype("int8")
-    verify_broadcast_binary_ele(
-        (1, 2, 2), (2,), greater, np.greater)
-    verify_broadcast_binary_ele(
-        (2, 1, 2), (2, 3, 1), less, np.less)
-    verify_broadcast_binary_ele(
-        (2, 1, 2), (2, 3, 1), equal, np.equal,
-        lhs_min=-2, lhs_max=2, rhs_min=-2, rhs_max=2, dtype='int32')
-    verify_broadcast_binary_ele(
-        (2, 1, 2), (2, 3, 1), not_equal, np.not_equal,
-        lhs_min=-2, lhs_max=2, rhs_min=-2, rhs_max=2, dtype='int32')
-    verify_broadcast_binary_ele(
-        (7, 1, 5), (7, 3, 1), greater_equal, np.greater_equal,
-        lhs_min=-3, lhs_max=3, rhs_min=-3, rhs_max=3, dtype='int32')
-    verify_broadcast_binary_ele(
-        (7, 1, 5), (7, 3, 1), less_equal, np.less_equal,
-        lhs_min=-3, lhs_max=3, rhs_min=-3, rhs_max=3, dtype='int32')
+
+    verify_broadcast_binary_ele((1, 2, 2), (2,), greater, np.greater)
+    verify_broadcast_binary_ele((2, 1, 2), (2, 3, 1), less, np.less)
+    verify_broadcast_binary_ele(
+        (2, 1, 2),
+        (2, 3, 1),
+        equal,
+        np.equal,
+        lhs_min=-2,
+        lhs_max=2,
+        rhs_min=-2,
+        rhs_max=2,
+        dtype="int32",
+    )
+    verify_broadcast_binary_ele(
+        (2, 1, 2),
+        (2, 3, 1),
+        not_equal,
+        np.not_equal,
+        lhs_min=-2,
+        lhs_max=2,
+        rhs_min=-2,
+        rhs_max=2,
+        dtype="int32",
+    )
+    verify_broadcast_binary_ele(
+        (7, 1, 5),
+        (7, 3, 1),
+        greater_equal,
+        np.greater_equal,
+        lhs_min=-3,
+        lhs_max=3,
+        rhs_min=-3,
+        rhs_max=3,
+        dtype="int32",
+    )
+    verify_broadcast_binary_ele(
+        (7, 1, 5),
+        (7, 3, 1),
+        less_equal,
+        np.less_equal,
+        lhs_min=-3,
+        lhs_max=3,
+        rhs_min=-3,
+        rhs_max=3,
+        dtype="int32",
+    )
 
 
 @tvm.testing.uses_gpu
 def test_shift():
     # explicit specify the output type
     verify_broadcast_binary_ele(
-        (2, 1, 2), None, topi.right_shift, np.right_shift,
-        dtype="int32", rhs_min=0, rhs_max=32)
+        (2, 1, 2), None, topi.right_shift, np.right_shift, dtype="int32", rhs_min=0, rhs_max=32
+    )
 
     verify_broadcast_binary_ele(
-        (1, 2, 2), (2,), topi.left_shift, np.left_shift,
-        dtype="int32", rhs_min=0, rhs_max=32)
+        (1, 2, 2), (2,), topi.left_shift, np.left_shift, dtype="int32", rhs_min=0, rhs_max=32
+    )
 
     verify_broadcast_binary_ele(
-        (1, 2, 2), (2,), topi.left_shift, np.left_shift,
-        dtype="int8", rhs_min=0, rhs_max=32)
+        (1, 2, 2), (2,), topi.left_shift, np.left_shift, dtype="int8", rhs_min=0, rhs_max=32
+    )
 
 
 @tvm.testing.uses_gpu
 def test_logical_single_ele():
     def test_apply(
-            func,
-            name,
-            f_numpy,
-            indata,
-            dtype="bool",
+        func,
+        name,
+        f_numpy,
+        indata,
+        dtype="bool",
     ):
         # Build the logic and compile the function
         A = te.placeholder(shape=indata.shape, name="A", dtype=dtype)
         B = func(A)
         if isinstance(A, tvm.tir.PrimExpr):
-            assert (isinstance(B, tvm.tir.PrimExpr))
+            assert isinstance(B, tvm.tir.PrimExpr)
             return
 
         def check_device(device, ctx):
@@ -274,18 +327,18 @@ def test_logical_single_ele():
 @tvm.testing.uses_gpu
 def test_bitwise_not():
     def test_apply(
-            func,
-            name,
-            f_numpy,
-            shape,
-            dtype="int32",
+        func,
+        name,
+        f_numpy,
+        shape,
+        dtype="int32",
     ):
         # Build the logic and compile the function
         A = te.placeholder(shape=shape, name="A", dtype=dtype)
         B = func(A)
 
         if isinstance(A, tvm.tir.PrimExpr):
-            assert (isinstance(B, tvm.tir.PrimExpr))
+            assert isinstance(B, tvm.tir.PrimExpr)
             return
 
         def check_device(device, ctx):
@@ -312,19 +365,19 @@ def test_bitwise_not():
 @tvm.testing.uses_gpu
 def test_logical_binary_ele():
     def test_apply(
-            func,
-            name,
-            f_numpy,
-            lhs,
-            rhs,
-            dtype="bool",
+        func,
+        name,
+        f_numpy,
+        lhs,
+        rhs,
+        dtype="bool",
     ):
         # Build the logic and compile the function
-        A = (te.var("A", dtype=dtype))
-        B = (te.var("B", dtype=dtype))
+        A = te.var("A", dtype=dtype)
+        B = te.var("B", dtype=dtype)
         C = func(A, B)
         if isinstance(A, tvm.tir.PrimExpr) and isinstance(B, tvm.tir.PrimExpr):
-            assert (isinstance(C, tvm.tir.PrimExpr))
+            assert isinstance(C, tvm.tir.PrimExpr)
             return
 
         def check_device(device, ctx):
@@ -339,7 +392,7 @@ def test_logical_binary_ele():
             out_npy = f_numpy(lhs, rhs)
             out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
             foo(lhs_nd, rhs_nd, out_nd)
-            tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
+            tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1e-4, atol=1e-4)
 
         for device, ctx in tvm.testing.enabled_targets():
             check_device(device, ctx)
@@ -354,32 +407,24 @@ def test_logical_binary_ele():
 
 @tvm.testing.uses_gpu
 def test_bitwise_and():
+    verify_broadcast_binary_ele(None, None, topi.bitwise_and, np.bitwise_and, dtype="int32")
     verify_broadcast_binary_ele(
-        None, None, topi.bitwise_and, np.bitwise_and,
-        dtype="int32")
-    verify_broadcast_binary_ele(
-        (2, 1, 2), (2, 1, 2), topi.bitwise_and, np.bitwise_and,
-        dtype="int32")
+        (2, 1, 2), (2, 1, 2), topi.bitwise_and, np.bitwise_and, dtype="int32"
+    )
 
 
 @tvm.testing.uses_gpu
 def test_bitwise_or():
-    verify_broadcast_binary_ele(
-        None, None, topi.bitwise_or, np.bitwise_or,
-        dtype="int32")
-    verify_broadcast_binary_ele(
-        (2, 1, 2), (2, 1, 2), topi.bitwise_or, np.bitwise_or,
-        dtype="int32")
+    verify_broadcast_binary_ele(None, None, topi.bitwise_or, np.bitwise_or, dtype="int32")
+    verify_broadcast_binary_ele((2, 1, 2), (2, 1, 2), topi.bitwise_or, np.bitwise_or, dtype="int32")
 
 
 @tvm.testing.uses_gpu
 def test_bitwise_xor():
+    verify_broadcast_binary_ele(None, None, topi.bitwise_xor, np.bitwise_xor, dtype="int32")
     verify_broadcast_binary_ele(
-        None, None, topi.bitwise_xor, np.bitwise_xor,
-        dtype="int32")
-    verify_broadcast_binary_ele(
-        (2, 1, 2), (2, 1, 2), topi.bitwise_xor, np.bitwise_xor,
-        dtype="int32")
+        (2, 1, 2), (2, 1, 2), topi.bitwise_xor, np.bitwise_xor, dtype="int32"
+    )
 
 
 if __name__ == "__main__":
index 8f018b5..bee31a5 100644
@@ -26,16 +26,17 @@ from tvm.contrib.pickle_memoize import memoize
 
 
 def verify_clip(N, a_min, a_max, dtype):
-    A = te.placeholder((N, N), dtype=dtype, name='A')
+    A = te.placeholder((N, N), dtype=dtype, name="A")
     B = topi.clip(A, a_min, a_max)
     s = te.create_schedule([B.op])
 
     # use memoize to pickle the test data for next time use
     @memoize("topi.tests.test_topi_clip")
     def get_ref_data():
-        a_np = np.random.uniform(a_min*2, a_max*2, size=(N, N)).astype(dtype)
+        a_np = np.random.uniform(a_min * 2, a_max * 2, size=(N, N)).astype(dtype)
         b_np = np.clip(a_np, a_min, a_max)
         return a_np, b_np
+
     a_np, b_np = get_ref_data()
 
     def check_device(device, ctx):
@@ -52,11 +53,12 @@ def verify_clip(N, a_min, a_max, dtype):
     for device, ctx in tvm.testing.enabled_targets():
         check_device(device, ctx)
 
+
 @tvm.testing.uses_gpu
 def test_clip():
-    verify_clip(1024, -127, 127, 'float32')
-    verify_clip(1024, -127, 127, 'int16')
-    verify_clip(1024, -127, 127, 'int8')
+    verify_clip(1024, -127, 127, "float32")
+    verify_clip(1024, -127, 127, "int16")
+    verify_clip(1024, -127, 127, "int8")
 
 
 if __name__ == "__main__":
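
The clip test above caches its NumPy reference data with tvm.contrib.pickle_memoize so repeated runs reuse the pickled arrays instead of regenerating them. A minimal sketch of that pattern (the cache key string and array shape below are illustrative only):

    import numpy as np
    from tvm.contrib.pickle_memoize import memoize

    # Sketch only: key and shape are illustrative, not taken from the test suite.
    @memoize("topi.tests.sketch_clip_ref_data")
    def get_ref_data():
        a_np = np.random.uniform(-254, 254, size=(1024, 1024)).astype("float32")
        b_np = np.clip(a_np, -127, 127)
        return a_np, b_np

    a_np, b_np = get_ref_data()  # generated once, then loaded from the pickle cache
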
index 77a37ff..3d98120 100644
@@ -28,46 +28,49 @@ from tvm.topi.util import get_const_tuple
 _conv1d_ncw_implement = {
     "generic": (topi.nn.conv1d_ncw, topi.generic.schedule_conv1d_ncw),
     "cpu": (topi.nn.conv1d_ncw, topi.x86.schedule_conv1d_ncw),
-    "gpu": (topi.cuda.conv1d_ncw, topi.cuda.schedule_conv1d_ncw)
+    "gpu": (topi.cuda.conv1d_ncw, topi.cuda.schedule_conv1d_ncw),
 }
 
 _conv1d_nwc_implement = {
     "generic": (topi.nn.conv1d_nwc, topi.generic.schedule_conv1d_nwc),
     "cpu": (topi.nn.conv1d_nwc, topi.x86.schedule_conv1d_nwc),
-    "gpu": (topi.cuda.conv1d_nwc, topi.cuda.schedule_conv1d_nwc)
+    "gpu": (topi.cuda.conv1d_nwc, topi.cuda.schedule_conv1d_nwc),
 }
 
-def verify_conv1d(batch,
-                  in_channels,
-                  in_width,
-                  filters,
-                  kernel_size=3,
-                  stride=1,
-                  dilation=1,
-                  padding='VALID',
-                  layout='NCW'):
-    if layout == 'NCW':
+
+def verify_conv1d(
+    batch,
+    in_channels,
+    in_width,
+    filters,
+    kernel_size=3,
+    stride=1,
+    dilation=1,
+    padding="VALID",
+    layout="NCW",
+):
+    if layout == "NCW":
         in_shape = [batch, in_channels, in_width]
         kernel_shape = [filters, in_channels, kernel_size]
     else:
         in_shape = [batch, in_width, in_channels]
         kernel_shape = [kernel_size, in_channels, filters]
 
-    dtype = 'float32'
-    A = te.placeholder(in_shape, name='A', dtype=dtype)
-    W = te.placeholder(kernel_shape, name='W', dtype=dtype)
+    dtype = "float32"
+    A = te.placeholder(in_shape, name="A", dtype=dtype)
+    W = te.placeholder(kernel_shape, name="W", dtype=dtype)
 
     def get_ref_data(layout):
         a_np = np.random.uniform(size=in_shape).astype(dtype)
         w_np = np.random.uniform(size=kernel_shape).astype(dtype)
-        if layout == 'NWC':
+        if layout == "NWC":
             np_in = np.transpose(a_np, [0, 2, 1])
             np_w = np.transpose(w_np, [2, 1, 0])
         else:
             np_in = a_np
             np_w = w_np
         b_np = tvm.topi.testing.conv1d_ncw_python(np_in, np_w, stride, padding, dilation)
-        if layout == 'NWC':
+        if layout == "NWC":
             b_np = np.transpose(b_np, [0, 2, 1])
         return a_np, w_np, b_np
 
@@ -79,7 +82,7 @@ def verify_conv1d(batch,
         else:
             fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv1d_nwc_implement)
         with tvm.target.Target(device):
-            B = fcompute(A, W, stride, padding, dilation, 'float32')
+            B = fcompute(A, W, stride, padding, dilation, "float32")
             s = fschedule([B])
 
         a = tvm.nd.array(a_np, ctx)
@@ -98,25 +101,24 @@ def verify_conv1d(batch,
 def test_conv1d():
     for layout in ["NCW", "NWC"]:
         # Most basic test case
-        verify_conv1d(1, 1, 8, 1, 3, 1, 1, 'VALID', layout)
+        verify_conv1d(1, 1, 8, 1, 3, 1, 1, "VALID", layout)
         # With padding
-        verify_conv1d(1, 1, 8, 1, 3, 1, 1, 'SAME', layout)
+        verify_conv1d(1, 1, 8, 1, 3, 1, 1, "SAME", layout)
         # Realistic dimensions
-        verify_conv1d(1, 16, 32, 16, 3, 1, 1, 'SAME', layout)
+        verify_conv1d(1, 16, 32, 16, 3, 1, 1, "SAME", layout)
         # With stride
-        verify_conv1d(1, 16, 32, 16, 3, 2, 1, 'SAME', layout)
+        verify_conv1d(1, 16, 32, 16, 3, 2, 1, "SAME", layout)
         # With dilation
-        verify_conv1d(1, 16, 32, 16, 3, 1, 2, 'SAME', layout)
+        verify_conv1d(1, 16, 32, 16, 3, 1, 2, "SAME", layout)
         # Large batch size
-        verify_conv1d(8, 16, 32, 16, 3, 1, 1, 'SAME', layout)
+        verify_conv1d(8, 16, 32, 16, 3, 1, 1, "SAME", layout)
         # Other kernel sizes
-        verify_conv1d(1, 16, 32, 16, 3, 1, 1, 'SAME', layout)
-        verify_conv1d(1, 16, 32, 16, 2, 1, 1, 'SAME', layout)
-        verify_conv1d(1, 16, 32, 16, 1, 1, 1, 'SAME', layout)
+        verify_conv1d(1, 16, 32, 16, 3, 1, 1, "SAME", layout)
+        verify_conv1d(1, 16, 32, 16, 2, 1, 1, "SAME", layout)
+        verify_conv1d(1, 16, 32, 16, 1, 1, 1, "SAME", layout)
         # Non-power-of-two shape
-        verify_conv1d(1, 17, 12, 21, 3, 1, 1, 'SAME', layout)
-        verify_conv1d(1, 5, 27, 18, 3, 1, 1, 'VALID', layout)
-
+        verify_conv1d(1, 17, 12, 21, 3, 1, 1, "SAME", layout)
+        verify_conv1d(1, 5, 27, 18, 3, 1, 1, "VALID", layout)
 
 
 if __name__ == "__main__":
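
For the "SAME" / "VALID" cases above, the expected output width follows the usual 1-D convolution arithmetic. A stand-alone sketch of that arithmetic (assuming SAME adds a total of dilated_kernel - 1 padding, which is what the reference cases appear to rely on):

    def conv1d_out_width(in_width, kernel, stride=1, dilation=1, padding="VALID"):
        dilated_kernel = (kernel - 1) * dilation + 1
        pad = dilated_kernel - 1 if padding == "SAME" else 0
        return (in_width + pad - dilated_kernel) // stride + 1

    assert conv1d_out_width(8, 3, padding="VALID") == 6              # basic case above
    assert conv1d_out_width(32, 3, stride=2, padding="SAME") == 16   # strided SAME case
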
index d575579..5b02d58 100644
@@ -27,13 +27,16 @@ import tvm.testing
 
 _conv1d_transpose_ncw_implement = {
     "generic": (topi.nn.conv1d_transpose_ncw, topi.generic.schedule_conv1d_transpose_ncw),
-    "gpu": (topi.cuda.conv1d_transpose_ncw, topi.cuda.schedule_conv1d_transpose_ncw)
+    "gpu": (topi.cuda.conv1d_transpose_ncw, topi.cuda.schedule_conv1d_transpose_ncw),
 }
 
-def verify_conv1d_transpose_ncw(batch, in_channel, in_size, num_filter, kernel, stride, padding, output_padding):
+
+def verify_conv1d_transpose_ncw(
+    batch, in_channel, in_size, num_filter, kernel, stride, padding, output_padding
+):
     in_width = in_size
-    A = te.placeholder((batch, in_channel, in_width), name='A')
-    W = te.placeholder((in_channel, num_filter, kernel), name='W')
+    A = te.placeholder((batch, in_channel, in_width), name="A")
+    W = te.placeholder((in_channel, num_filter, kernel), name="W")
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -43,7 +46,9 @@ def verify_conv1d_transpose_ncw(batch, in_channel, in_size, num_filter, kernel,
     def get_ref_data():
         a_np = np.random.uniform(size=a_shape).astype(dtype)
         w_np = np.random.uniform(size=w_shape).astype(dtype)
-        b_np = tvm.topi.testing.conv1d_transpose_ncw_python(a_np, w_np, stride, padding, output_padding)
+        b_np = tvm.topi.testing.conv1d_transpose_ncw_python(
+            a_np, w_np, stride, padding, output_padding
+        )
         c_np = np.maximum(b_np, 0)
         return a_np, w_np, b_np, c_np
 
@@ -86,9 +91,10 @@ def test_conv1d_transpose_ncw():
     verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 2, 256, (0,))
     verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 5, 256, (0,))
     verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 5, 256, (3,))
-    verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (0,3), (0,))
-    verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (1,3), (0,))
-    verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (2,3), (0,))
+    verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (0, 3), (0,))
+    verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (1, 3), (0,))
+    verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (2, 3), (0,))
+
 
 if __name__ == "__main__":
     test_conv1d_transpose_ncw()
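
For orientation, a stand-alone sketch of the standard transposed-convolution width arithmetic the cases above exercise (this is the usual convention, not TVM code; padding may be a single int or the (left, right) tuple used in the asymmetric cases):

    def conv1d_transpose_out_width(in_width, kernel, stride, padding, output_padding):
        pad_left, pad_right = (padding, padding) if isinstance(padding, int) else padding
        return (in_width - 1) * stride + kernel - pad_left - pad_right + output_padding

    assert conv1d_transpose_out_width(1024, 512, 5, 256, 3) == 5118
    assert conv1d_transpose_out_width(10, 5, 1, (1, 3), 0) == 10
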
index b1df358..a21790d 100644
@@ -27,34 +27,51 @@ from tvm.contrib.pickle_memoize import memoize
 from tvm.topi.nn.util import get_pad_tuple
 from tvm.topi.util import get_const_tuple
 
+
 def _transform_data(data, bn):
     # NCHW -> NCHW[x]c
     batch_size, channel, height, width = data.shape
-    data = np.reshape(data, (batch_size, channel//bn, bn, height, width))
+    data = np.reshape(data, (batch_size, channel // bn, bn, height, width))
     data = np.transpose(data, (0, 1, 3, 4, 2))
     return data
 
+
 def _transform_kernel(kernel, ic_bn, oc_bn):
     # OIHW -> OIHW[x]i[x]o
     out_channel, in_channel, kh, kw = kernel.shape
-    kernel = np.reshape(kernel, (out_channel//oc_bn, oc_bn, in_channel//ic_bn, ic_bn, kh, kw))
+    kernel = np.reshape(kernel, (out_channel // oc_bn, oc_bn, in_channel // ic_bn, ic_bn, kh, kw))
     kernel = np.transpose(kernel, (0, 2, 4, 5, 3, 1))
     return kernel
 
+
 def _transform_bias(bias, bn):
     # [num_filter, 1, 1] -> [num_filter//bn, 1, 1, bn]
     num_filter, h, w = bias.shape
-    bias = np.reshape(bias, (num_filter//bn, bn, h, w))
+    bias = np.reshape(bias, (num_filter // bn, bn, h, w))
     bias = np.transpose(bias, (0, 2, 3, 1))
     return bias
 
-def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
-                        padding, dilation=1, add_bias=False, add_relu=False, dtype="float32"):
+
+def verify_conv2d_NCHWc(
+    batch,
+    in_channel,
+    in_size,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    dilation=1,
+    add_bias=False,
+    add_relu=False,
+    dtype="float32",
+):
     pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
     padding_sum = pad_top + pad_left + pad_bottom + pad_right
     in_height = in_width = in_size
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d)" %
-          (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum))
+    print(
+        "Workload: (%d, %d, %d, %d, %d, %d, %d)"
+        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum)
+    )
 
     # for testing functionality,
     # we choose arbitrary block size that can divide the channel,
@@ -71,9 +88,12 @@ def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
             ic_block = bn
             break
 
-    A = te.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='A')
-    W = te.placeholder((num_filter//oc_block, in_channel//ic_block, kernel, kernel, ic_block, oc_block), name='W')
-    bias = te.placeholder((num_filter//oc_block, 1, 1, oc_block), name='bias')
+    A = te.placeholder((batch, in_channel // ic_block, in_height, in_width, ic_block), name="A")
+    W = te.placeholder(
+        (num_filter // oc_block, in_channel // ic_block, kernel, kernel, ic_block, oc_block),
+        name="W",
+    )
+    bias = te.placeholder((num_filter // oc_block, 1, 1, oc_block), name="bias")
 
     @memoize("topi.tests.test_topi_conv2d_NCHWc.verify_conv2d_NCHWc")
     def get_ref_data():
@@ -86,8 +106,12 @@ def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
             c_np += b_np
         if add_relu:
             c_np = np.maximum(c_np, 0)
-        return _transform_data(a_np, ic_block), _transform_kernel(w_np, ic_block, oc_block), \
-               _transform_bias(b_np, oc_block), _transform_data(c_np, oc_block)
+        return (
+            _transform_data(a_np, ic_block),
+            _transform_kernel(w_np, ic_block, oc_block),
+            _transform_bias(b_np, oc_block),
+            _transform_data(c_np, oc_block),
+        )
 
     a_np, w_np, b_np, c_np = get_ref_data()
 
@@ -98,11 +122,16 @@ def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
             return
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
-            C = topi.x86.conv2d_NCHWc(A, W, (stride, stride), padding,
-                                      (dilation, dilation),
-                                      'NCHW%dc'%ic_block,
-                                      "NCHW%dc"%oc_block,
-                                      dtype)
+            C = topi.x86.conv2d_NCHWc(
+                A,
+                W,
+                (stride, stride),
+                padding,
+                (dilation, dilation),
+                "NCHW%dc" % ic_block,
+                "NCHW%dc" % oc_block,
+                dtype,
+            )
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
@@ -114,14 +143,22 @@ def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         if add_bias:
-            func = tvm.build(s, [A, W, bias, C], device,
-                             name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
-                                  (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func = tvm.build(
+                s,
+                [A, W, bias, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
             func(a, w, b, c)
         else:
-            func = tvm.build(s, [A, W, C], device,
-                             name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
-                                  (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func = tvm.build(
+                s,
+                [A, W, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
             func(a, w, c)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-3)
 
@@ -133,18 +170,18 @@ def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
 
 def test_conv2d_NCHWc():
     # ResNet18 workloads
-    verify_conv2d_NCHWc(1,   3, 224,  64, 7, 2, 3)
-    verify_conv2d_NCHWc(1,  64,  56,  64, 3, 1, 1)
-    verify_conv2d_NCHWc(1,  64,  56,  64, 1, 1, 0)
-    verify_conv2d_NCHWc(1,  64,  56, 128, 3, 2, 1)
-    verify_conv2d_NCHWc(1,  64,  56, 128, 1, 2, 0)
-    verify_conv2d_NCHWc(1, 128,  28, 128, 3, 1, 1)
-    verify_conv2d_NCHWc(1, 128,  28, 256, 3, 2, 1)
-    verify_conv2d_NCHWc(1, 128,  28, 256, 1, 2, 0)
-    verify_conv2d_NCHWc(1, 256,  14, 256, 3, 1, 1)
-    verify_conv2d_NCHWc(1, 256,  14, 512, 3, 2, 1)
-    verify_conv2d_NCHWc(1, 256,  14, 512, 1, 2, 0)
-    verify_conv2d_NCHWc(1, 512,   7, 512, 3, 1, 1)
+    verify_conv2d_NCHWc(1, 3, 224, 64, 7, 2, 3)
+    verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1)
+    verify_conv2d_NCHWc(1, 64, 56, 64, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 64, 56, 128, 3, 2, 1)
+    verify_conv2d_NCHWc(1, 64, 56, 128, 1, 2, 0)
+    verify_conv2d_NCHWc(1, 128, 28, 128, 3, 1, 1)
+    verify_conv2d_NCHWc(1, 128, 28, 256, 3, 2, 1)
+    verify_conv2d_NCHWc(1, 128, 28, 256, 1, 2, 0)
+    verify_conv2d_NCHWc(1, 256, 14, 256, 3, 1, 1)
+    verify_conv2d_NCHWc(1, 256, 14, 512, 3, 2, 1)
+    verify_conv2d_NCHWc(1, 256, 14, 512, 1, 2, 0)
+    verify_conv2d_NCHWc(1, 512, 7, 512, 3, 1, 1)
 
     # bias, relu
     verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, add_relu=True)
@@ -171,70 +208,69 @@ def test_conv2d_NCHWc():
     # verify_conv2d_NCHWc(2, 13, 71, 59, 3, 1, 1)
 
     # inception v3 workloads
-    verify_conv2d_NCHWc(1,    3, 299,  32, 3, 2, 0)
-    verify_conv2d_NCHWc(1,   32, 149,  32, 3, 1, 0)
-    verify_conv2d_NCHWc(1,   32, 147,  64, 3, 1, 1)
-    verify_conv2d_NCHWc(1,   64,  73,  80, 1, 1, 0)
-    verify_conv2d_NCHWc(1,   80,  73, 192, 3, 1, 0)
-    verify_conv2d_NCHWc(1,  192,  35,  64, 1, 1, 0)
-    verify_conv2d_NCHWc(1,  192,  35,  48, 1, 1, 0)
-    verify_conv2d_NCHWc(1,   48,  35,  64, 5, 1, 2)
-    verify_conv2d_NCHWc(1,   64,  35,  96, 3, 1, 1)
-    verify_conv2d_NCHWc(1,   96,  35,  96, 3, 1, 1)
-    verify_conv2d_NCHWc(1,  192,  35,  32, 1, 1, 0)
-    verify_conv2d_NCHWc(1,  256,  35,  64, 1, 1, 0)
-    verify_conv2d_NCHWc(1,  256,  35,  48, 1, 1, 0)
-    verify_conv2d_NCHWc(1,  288,  35,  64, 1, 1, 0)
-    verify_conv2d_NCHWc(1,  288,  35,  48, 1, 1, 0)
-    verify_conv2d_NCHWc(1,  288,  35, 384, 3, 2, 0)
-    verify_conv2d_NCHWc(1,   96,  35,  96, 3, 2, 0)
-    verify_conv2d_NCHWc(1,  768,  17, 192, 1, 1, 0)
-    verify_conv2d_NCHWc(1,  768,  17, 128, 1, 1, 0)
-    verify_conv2d_NCHWc(1,  128,  17, 128, 1, 1, 0)
-    verify_conv2d_NCHWc(1,  128,  17, 192, 7, 1, 3)
-    verify_conv2d_NCHWc(1,  128,  17, 128, 7, 1, 3)
-    verify_conv2d_NCHWc(1,  128,  17, 192, 1, 1, 0)
-    verify_conv2d_NCHWc(1,  768,  17, 160, 1, 1, 0)
-    verify_conv2d_NCHWc(1,  160,  17, 160, 1, 1, 0)
-    verify_conv2d_NCHWc(1,  160,  17, 192, 7, 1, 3)
-    verify_conv2d_NCHWc(1,  160,  17, 160, 7, 1, 3)
-    verify_conv2d_NCHWc(1,  160,  17, 192, 1, 1, 0)
-    verify_conv2d_NCHWc(1,  192,  17, 192, 1, 1, 0)
-    verify_conv2d_NCHWc(1,  192,  17, 192, 7, 1, 3)
-    verify_conv2d_NCHWc(1,  192,  17, 320, 3, 2, 0)
-    verify_conv2d_NCHWc(1,  192,  17, 192, 3, 2, 0)
-    verify_conv2d_NCHWc(1, 1280,   8, 320, 1, 1, 0)
-    verify_conv2d_NCHWc(1, 1280,   8, 384, 1, 1, 0)
-    verify_conv2d_NCHWc(1,  384,   8, 384, 1, 1, 0)
-    verify_conv2d_NCHWc(1,  384,   8, 384, 3, 1, 1)
-    verify_conv2d_NCHWc(1, 1280,   8, 448, 1, 1, 0)
-    verify_conv2d_NCHWc(1,  448,   8, 384, 3, 1, 1)
-    verify_conv2d_NCHWc(1, 1280,   8, 192, 1, 1, 0)
-    verify_conv2d_NCHWc(1, 2048,   8, 320, 1, 1, 0)
-    verify_conv2d_NCHWc(1, 2048,   8, 384, 1, 1, 0)
-    verify_conv2d_NCHWc(1, 2048,   8, 448, 1, 1, 0)
-    verify_conv2d_NCHWc(1, 2048,   8, 192, 1, 1, 0)
-    verify_conv2d_NCHWc(1, 1024,  19,  84, 3, 1, 1)
-    verify_conv2d_NCHWc(1, 2048,  10, 126, 3, 1, 1)
-    verify_conv2d_NCHWc(1,  512,   5, 126, 3, 1, 1)
-    verify_conv2d_NCHWc(1,  256,   3, 126, 3, 1, 1)
+    verify_conv2d_NCHWc(1, 3, 299, 32, 3, 2, 0)
+    verify_conv2d_NCHWc(1, 32, 149, 32, 3, 1, 0)
+    verify_conv2d_NCHWc(1, 32, 147, 64, 3, 1, 1)
+    verify_conv2d_NCHWc(1, 64, 73, 80, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 80, 73, 192, 3, 1, 0)
+    verify_conv2d_NCHWc(1, 192, 35, 64, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 192, 35, 48, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 48, 35, 64, 5, 1, 2)
+    verify_conv2d_NCHWc(1, 64, 35, 96, 3, 1, 1)
+    verify_conv2d_NCHWc(1, 96, 35, 96, 3, 1, 1)
+    verify_conv2d_NCHWc(1, 192, 35, 32, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 256, 35, 64, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 256, 35, 48, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 288, 35, 64, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 288, 35, 48, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 288, 35, 384, 3, 2, 0)
+    verify_conv2d_NCHWc(1, 96, 35, 96, 3, 2, 0)
+    verify_conv2d_NCHWc(1, 768, 17, 192, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 768, 17, 128, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 128, 17, 128, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 128, 17, 192, 7, 1, 3)
+    verify_conv2d_NCHWc(1, 128, 17, 128, 7, 1, 3)
+    verify_conv2d_NCHWc(1, 128, 17, 192, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 768, 17, 160, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 160, 17, 160, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 160, 17, 192, 7, 1, 3)
+    verify_conv2d_NCHWc(1, 160, 17, 160, 7, 1, 3)
+    verify_conv2d_NCHWc(1, 160, 17, 192, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 192, 17, 192, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 192, 17, 192, 7, 1, 3)
+    verify_conv2d_NCHWc(1, 192, 17, 320, 3, 2, 0)
+    verify_conv2d_NCHWc(1, 192, 17, 192, 3, 2, 0)
+    verify_conv2d_NCHWc(1, 1280, 8, 320, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 1280, 8, 384, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 384, 8, 384, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 384, 8, 384, 3, 1, 1)
+    verify_conv2d_NCHWc(1, 1280, 8, 448, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 448, 8, 384, 3, 1, 1)
+    verify_conv2d_NCHWc(1, 1280, 8, 192, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 2048, 8, 320, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 2048, 8, 384, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 2048, 8, 448, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 2048, 8, 192, 1, 1, 0)
+    verify_conv2d_NCHWc(1, 1024, 19, 84, 3, 1, 1)
+    verify_conv2d_NCHWc(1, 2048, 10, 126, 3, 1, 1)
+    verify_conv2d_NCHWc(1, 512, 5, 126, 3, 1, 1)
+    verify_conv2d_NCHWc(1, 256, 3, 126, 3, 1, 1)
 
     # Asymmetric padding
-    verify_conv2d_NCHWc(1,  32,   17,  64,  7, 2, (0, 0, 1, 1))
-    verify_conv2d_NCHWc(1,  32,   35, 128,  3, 1, (3, 3, 2, 2))
-    verify_conv2d_NCHWc(1,  32,   35,  32,  1, 1, (1, 2, 2, 1))
-    verify_conv2d_NCHWc(1,  32,   17, 192,  1, 1, (1, 2))
-    verify_conv2d_NCHWc(1,  32,    8,  32,  3, 1, (3, 1))
-    verify_conv2d_NCHWc(1, 128,    8, 384,  3, 1, (0, 2))
-    verify_conv2d_NCHWc(1,  32,    8,  32,  1, 1, "VALID")
-    verify_conv2d_NCHWc(1, 388,    8,  32,  3, 1, "VALID")
-    verify_conv2d_NCHWc(1, 512,   19,  32,  1, 1, "SAME")
-    verify_conv2d_NCHWc(1,  32,   10,  32,  2, 1, "SAME")
-    verify_conv2d_NCHWc(1,  32,    8,  32,  3, 1, (1, 2, 2, 1), add_relu=True)
-    verify_conv2d_NCHWc(1,  32,    8,  32,  5, 2, (1, 3), add_bias=True)
-    verify_conv2d_NCHWc(1,  32,    8,  32,  3, 1, "VALID", add_bias=True, add_relu=True)
-    verify_conv2d_NCHWc(1,  32,    8,  32, 24, 1, "SAME", add_bias=True, add_relu=True)
-
+    verify_conv2d_NCHWc(1, 32, 17, 64, 7, 2, (0, 0, 1, 1))
+    verify_conv2d_NCHWc(1, 32, 35, 128, 3, 1, (3, 3, 2, 2))
+    verify_conv2d_NCHWc(1, 32, 35, 32, 1, 1, (1, 2, 2, 1))
+    verify_conv2d_NCHWc(1, 32, 17, 192, 1, 1, (1, 2))
+    verify_conv2d_NCHWc(1, 32, 8, 32, 3, 1, (3, 1))
+    verify_conv2d_NCHWc(1, 128, 8, 384, 3, 1, (0, 2))
+    verify_conv2d_NCHWc(1, 32, 8, 32, 1, 1, "VALID")
+    verify_conv2d_NCHWc(1, 388, 8, 32, 3, 1, "VALID")
+    verify_conv2d_NCHWc(1, 512, 19, 32, 1, 1, "SAME")
+    verify_conv2d_NCHWc(1, 32, 10, 32, 2, 1, "SAME")
+    verify_conv2d_NCHWc(1, 32, 8, 32, 3, 1, (1, 2, 2, 1), add_relu=True)
+    verify_conv2d_NCHWc(1, 32, 8, 32, 5, 2, (1, 3), add_bias=True)
+    verify_conv2d_NCHWc(1, 32, 8, 32, 3, 1, "VALID", add_bias=True, add_relu=True)
+    verify_conv2d_NCHWc(1, 32, 8, 32, 24, 1, "SAME", add_bias=True, add_relu=True)
 
 
 if __name__ == "__main__":
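
The _transform_data helper above packs NCHW data into the blocked NCHW[x]c layout by splitting the channel axis and moving the block lane innermost. A small stand-alone check of that reshape/transpose (shapes are illustrative; bn must divide the channel count):

    import numpy as np

    def to_nchwc(data, bn):
        n, c, h, w = data.shape
        data = np.reshape(data, (n, c // bn, bn, h, w))
        return np.transpose(data, (0, 1, 3, 4, 2))  # -> (N, C//bn, H, W, bn)

    x = np.arange(1 * 8 * 4 * 4).reshape(1, 8, 4, 4)
    y = to_nchwc(x, bn=4)
    assert y.shape == (1, 2, 4, 4, 4)
    # element (n, c, h, w) lands at (n, c // bn, h, w, c % bn)
    assert y[0, 1, 2, 3, 1] == x[0, 5, 2, 3]
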
index 71a83fc..a16499a 100644
@@ -32,12 +32,13 @@ _conv2d_hwcn_implement = {
     "opencl": (topi.cuda.conv2d_hwcn, topi.cuda.schedule_conv2d_hwcn),
 }
 
+
 def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
     in_height = in_width = in_size
 
-    A = te.placeholder((in_height, in_width, in_channel, batch), name='A')
-    W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W')
-    B = te.placeholder((1, num_filter, 1), name='bias')
+    A = te.placeholder((in_height, in_width, in_channel, batch), name="A")
+    W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W")
+    B = te.placeholder((1, num_filter, 1), name="bias")
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -75,12 +76,9 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
         w = tvm.nd.array(w_np, ctx)
         b = tvm.nd.array(b_np, ctx)
 
-        conv_out = tvm.nd.array(
-            np.zeros(get_const_tuple(t_conv.shape), dtype=t_conv.dtype), ctx)
-        bias_out = tvm.nd.array(
-            np.zeros(get_const_tuple(t_bias.shape), dtype=t_bias.dtype), ctx)
-        relu_out = tvm.nd.array(
-            np.zeros(get_const_tuple(t_relu.shape), dtype=t_relu.dtype), ctx)
+        conv_out = tvm.nd.array(np.zeros(get_const_tuple(t_conv.shape), dtype=t_conv.dtype), ctx)
+        bias_out = tvm.nd.array(np.zeros(get_const_tuple(t_bias.shape), dtype=t_bias.dtype), ctx)
+        relu_out = tvm.nd.array(np.zeros(get_const_tuple(t_relu.shape), dtype=t_relu.dtype), ctx)
         func1 = tvm.build(s1, [A, W, t_conv], device)
         func2 = tvm.build(s2, [A, W, B, t_bias], device)
         func3 = tvm.build(s3, [A, W, B, t_relu], device)
@@ -91,7 +89,7 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
         tvm.testing.assert_allclose(bias_out.asnumpy(), c2_np, rtol=1e-5)
         tvm.testing.assert_allclose(relu_out.asnumpy(), c3_np, rtol=1e-5)
 
-    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
+    for device in ["cuda", "opencl", "metal", "rocm", "vulkan", "nvptx"]:
         check_device(device)
 
 
@@ -108,5 +106,6 @@ def test_conv2d_hwcn():
     # dilation = 2
     verify_conv2d_hwcn(1, 256, 32, 256, 3, 1, "SAME", dilation=2)
 
+
 if __name__ == "__main__":
     test_conv2d_hwcn()
index f0eb2d2..81563ba 100644
@@ -31,39 +31,52 @@ _conv2d_hwnc_tensorcore_implement = {
     "cuda": (topi.cuda.conv2d_hwnc_tensorcore, topi.cuda.schedule_conv2d_hwnc_tensorcore)
 }
 
-def verify_conv2d_hwnc(batch, in_channel, in_size, num_filter, kernel, stride,
-                       padding, dilation=1, dtype='int4'):
+
+def verify_conv2d_hwnc(
+    batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, dtype="int4"
+):
     """Test the conv2d with tensorcore for hwnc layout"""
     pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
     padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (
-        batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
-    # choose dtype from int4, int8 
-    assert dtype in ['int4', 'int8']
+    print(
+        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+    )
+    # choose dtype from int4, int8
+    assert dtype in ["int4", "int8"]
 
     in_height = in_width = in_size
 
-    A = te.placeholder((in_height, in_width, batch, in_channel), name='A', dtype=dtype)
-    W = te.placeholder((kernel, kernel, num_filter, in_channel), name='W', dtype=dtype)
+    A = te.placeholder((in_height, in_width, batch, in_channel), name="A", dtype=dtype)
+    W = te.placeholder((kernel, kernel, num_filter, in_channel), name="W", dtype=dtype)
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
+
     @memoize("topi.tests.test_topi_conv2d_hwnc.verify_conv2d_hwnc")
     def get_ref_data():
-        if  dtype == 'int4':
+        if dtype == "int4":
             a_np = np.random.randint(low=-8, high=7, size=a_shape).transpose((2, 0, 1, 3))
             w_np = np.random.randint(low=-8, high=7, size=w_shape)
-            dw_np = topi.testing.dilate_python(w_np.transpose((0, 1, 3, 2)), (1, 1, dilation, dilation))
-        elif dtype == 'int8':
-            a_np = np.random.randint(low=-128, high=127, size=a_shape).transpose((2, 0, 1, 3)).astype(dtype)
+            dw_np = topi.testing.dilate_python(
+                w_np.transpose((0, 1, 3, 2)), (1, 1, dilation, dilation)
+            )
+        elif dtype == "int8":
+            a_np = (
+                np.random.randint(low=-128, high=127, size=a_shape)
+                .transpose((2, 0, 1, 3))
+                .astype(dtype)
+            )
             w_np = np.random.randint(low=-128, high=127, size=w_shape).astype(dtype)
-            dw_np = topi.testing.dilate_python(w_np.transpose((0, 1, 3, 2)), (1, 1, dilation, dilation))
+            dw_np = topi.testing.dilate_python(
+                w_np.transpose((0, 1, 3, 2)), (1, 1, dilation, dilation)
+            )
 
         c_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding)
         return a_np, w_np, c_np
-        
+
     def convert_int32_into_int4(a_int32):
-        """ convert int32 values into int4
+        """convert int32 values into int4
         Parameters
         ----------
         a_int32 : int
@@ -78,12 +91,14 @@ def verify_conv2d_hwnc(batch, in_channel, in_size, num_filter, kernel, stride,
             for j in range(J):
                 for k in range(K):
                     for l in range(L // 8):
-                        for m in range(min(8, L-l*8)):
-                            a_int4[i, j, k, l] = a_int4[i, j, k, l] | ((a_int32[i, j, k, l * 8 + m] & 0xf) << ((7 - m) * 4))
+                        for m in range(min(8, L - l * 8)):
+                            a_int4[i, j, k, l] = a_int4[i, j, k, l] | (
+                                (a_int32[i, j, k, l * 8 + m] & 0xF) << ((7 - m) * 4)
+                            )
         return a_int4
 
     a_np, w_np, c_np = get_ref_data()
-    if dtype == 'int4':
+    if dtype == "int4":
         a_np = convert_int32_into_int4(a_np)
         w_np = convert_int32_into_int4(w_np)
 
@@ -98,28 +113,33 @@ def verify_conv2d_hwnc(batch, in_channel, in_size, num_filter, kernel, stride,
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
             fcompute, fschedule = topi.testing.dispatch(device, _conv2d_hwnc_tensorcore_implement)
-            C = fcompute(A, W, stride, padding, dilation, dtype, 'int32')
+            C = fcompute(A, W, stride, padding, dilation, dtype, "int32")
             s = fschedule([C])
 
         a = tvm.nd.array(a_np.transpose((1, 2, 0, 3)), ctx)
         w = tvm.nd.array(w_np, ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
 
-        func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
-            batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+        func = tvm.build(
+            s,
+            [A, W, C],
+            device,
+            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+        )
         func(a, w, c)
 
         rtol = 1e-3
         tvm.testing.assert_allclose(c.asnumpy().transpose((2, 0, 1, 3)), c_np, rtol=rtol)
 
-    check_device('cuda')
+    check_device("cuda")
 
 
 @tvm.testing.requires_tensorcore
 def test_conv2d_hwnc_tensorcore():
     """Test the conv2d with tensorcore for hwnc layout"""
-    verify_conv2d_hwnc(8, 64, 56, 64, 3, 1, 1, dtype='int8')
-    verify_conv2d_hwnc(8, 64, 56, 64, 1, 1, 0, dtype='int4')
+    verify_conv2d_hwnc(8, 64, 56, 64, 3, 1, 1, dtype="int8")
+    verify_conv2d_hwnc(8, 64, 56, 64, 1, 1, 0, dtype="int4")
     verify_conv2d_hwnc(8, 64, 56, 128, 3, 2, 1)
     verify_conv2d_hwnc(8, 64, 56, 64, 1, 2, 0)
     verify_conv2d_hwnc(8, 128, 28, 128, 3, 1, 1)
@@ -130,5 +150,6 @@ def test_conv2d_hwnc_tensorcore():
     verify_conv2d_hwnc(8, 256, 14, 512, 1, 2, 0)
     verify_conv2d_hwnc(8, 512, 9, 512, 3, 1, 1)
 
+
 if __name__ == "__main__":
     test_conv2d_hwnc_tensorcore()
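
convert_int32_into_int4 above packs eight signed 4-bit values into each 32-bit word, with the first value in the most significant nibble. A pure-Python sketch of the same packing, reduced to a flat list for clarity:

    def pack_int4(values):
        assert len(values) % 8 == 0
        packed = []
        for base in range(0, len(values), 8):
            word = 0
            for m in range(8):
                word |= (values[base + m] & 0xF) << ((7 - m) * 4)
            packed.append(word)
        return packed

    assert pack_int4([1, 2, 3, 4, 5, 6, 7, -8]) == [0x12345678]
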
index 8082044..238517e 100644
@@ -31,18 +31,31 @@ from tvm.topi.arm_cpu.conv2d_gemm import is_aarch64_arm
 from common import Int8Fallback
 import tvm.testing
 
-def compile_conv2d_NHWC_gemm_int8_arm(batch, in_channel, in_size, num_filter, kernel, stride, padding,
-                                 dilation=1, add_bias=False, add_relu=False):
+
+def compile_conv2d_NHWC_gemm_int8_arm(
+    batch,
+    in_channel,
+    in_size,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    dilation=1,
+    add_bias=False,
+    add_relu=False,
+):
     pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
     padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter,
-                                                          kernel, stride, padding_sum, dilation))
+    print(
+        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+    )
 
     in_height = in_width = in_size
-    A = te.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='int8')
-    W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype='int8')
-    bias = te.placeholder((num_filter,), name='bias', dtype='int8')
-    dtype = 'int32'
+    A = te.placeholder((batch, in_height, in_width, in_channel), name="A", dtype="int8")
+    W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W", dtype="int8")
+    bias = te.placeholder((num_filter,), name="bias", dtype="int8")
+    dtype = "int32"
     device = "llvm --device arm_cpu --mtriple aarch64-linux-gnu"
 
     ctx = tvm.context(device, 0)
@@ -53,8 +66,9 @@ def compile_conv2d_NHWC_gemm_int8_arm(batch, in_channel, in_size, num_filter, ke
     with tvm.target.Target(device):
         assert is_aarch64_arm(), "AArch64 target not recognized"
 
-        C = topi.arm_cpu.compute_conv2d_NHWC_quantized(A, W, (stride, stride), padding,
-                                                       (dilation, dilation), dtype)
+        C = topi.arm_cpu.compute_conv2d_NHWC_quantized(
+            A, W, (stride, stride), padding, (dilation, dilation), dtype
+        )
         if add_bias:
             C = topi.add(C, bias)
         if add_relu:
@@ -62,47 +76,54 @@ def compile_conv2d_NHWC_gemm_int8_arm(batch, in_channel, in_size, num_filter, ke
         s = topi.arm_cpu.schedule_conv2d_NHWC_quantized([C])
 
     if add_bias:
-        tvm.build(s, [A, W, bias, C], device,
-                  name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
-                                                         in_channel,
-                                                         in_size,
-                                                         num_filter,
-                                                         kernel,
-                                                         stride,
-                                                         padding_sum,
-                                                         dilation))
-        func = tvm.build(s, [A, W, bias, C], device,
-                         name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
-                                                                in_channel,
-                                                                in_size,
-                                                                num_filter,
-                                                                kernel,
-                                                                stride,
-                                                                padding_sum,
-                                                                dilation))
+        tvm.build(
+            s,
+            [A, W, bias, C],
+            device,
+            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+        )
+        func = tvm.build(
+            s,
+            [A, W, bias, C],
+            device,
+            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+        )
     else:
-        func = tvm.build(s, [A, W, C], device,
-                         name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
-                                                                in_channel,
-                                                                in_size,
-                                                                num_filter,
-                                                                kernel,
-                                                                stride,
-                                                                padding_sum,
-                                                                dilation))
-
-def verify_conv2d_NHWC_gemm_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding,
-                                 dilation=1, add_bias=False, add_relu=False):
+        func = tvm.build(
+            s,
+            [A, W, C],
+            device,
+            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+        )
+
+
+def verify_conv2d_NHWC_gemm_int8(
+    batch,
+    in_channel,
+    in_size,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    dilation=1,
+    add_bias=False,
+    add_relu=False,
+):
     pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
     padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter,
-                                                          kernel, stride, padding_sum, dilation))
+    print(
+        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+    )
 
     in_height = in_width = in_size
 
-    A = te.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='int8')
-    W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype='int8')
-    bias = te.placeholder((num_filter,), name='bias', dtype='int8')
+    A = te.placeholder((batch, in_height, in_width, in_channel), name="A", dtype="int8")
+    W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W", dtype="int8")
+    bias = te.placeholder((num_filter,), name="bias", dtype="int8")
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -134,8 +155,9 @@ def verify_conv2d_NHWC_gemm_int8(batch, in_channel, in_size, num_filter, kernel,
             return
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
-            C = topi.arm_cpu.compute_conv2d_NHWC_quantized(A, W, (stride, stride), padding,
-                                                           (dilation, dilation), dtype)
+            C = topi.arm_cpu.compute_conv2d_NHWC_quantized(
+                A, W, (stride, stride), padding, (dilation, dilation), dtype
+            )
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
@@ -147,52 +169,64 @@ def verify_conv2d_NHWC_gemm_int8(batch, in_channel, in_size, num_filter, kernel,
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         if add_bias:
-            tvm.build(s, [A, W, bias, C], device,
-                      name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
-                                                             in_channel,
-                                                             in_size,
-                                                             num_filter,
-                                                             kernel,
-                                                             stride,
-                                                             padding_sum,
-                                                             dilation))
-            func = tvm.build(s, [A, W, bias, C], device,
-                             name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
-                                                                    in_channel,
-                                                                    in_size,
-                                                                    num_filter,
-                                                                    kernel,
-                                                                    stride,
-                                                                    padding_sum,
-                                                                    dilation))
+            tvm.build(
+                s,
+                [A, W, bias, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
+            func = tvm.build(
+                s,
+                [A, W, bias, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
             func(a, w, b, c)
         else:
-            func = tvm.build(s, [A, W, C], device,
-                             name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
-                                                                    in_channel,
-                                                                    in_size,
-                                                                    num_filter,
-                                                                    kernel,
-                                                                    stride,
-                                                                    padding_sum,
-                                                                    dilation))
+            func = tvm.build(
+                s,
+                [A, W, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
             func(a, w, c)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
 
     check_device("llvm")
 
+
 oc_block_factor = 4
-def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
+
+
+def verify_conv2d_NCHWc_int8(
+    batch,
+    in_channel,
+    in_size,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    dilation=1,
+    add_bias=False,
+    add_relu=False,
+):
     pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
     padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+    print(
+        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+    )
 
     in_height = in_width = in_size
 
-    A = te.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='int8')
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name='W', dtype='int8')
-    bias = te.placeholder((num_filter // oc_block_factor, 1, 1, oc_block_factor), name='bias',
-                            dtype='int8')
+    A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype="int8")
+    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype="int8")
+    bias = te.placeholder(
+        (num_filter // oc_block_factor, 1, 1, oc_block_factor), name="bias", dtype="int8"
+    )
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -209,8 +243,9 @@ def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, str
 
         # convert to NCHWc
         _, _, out_height, out_width = c_np.shape
-        c_np = c_np.reshape((batch, num_filter // oc_block_factor, oc_block_factor, \
-                out_height, out_width)).transpose(0, 1, 3, 4, 2)
+        c_np = c_np.reshape(
+            (batch, num_filter // oc_block_factor, oc_block_factor, out_height, out_width)
+        ).transpose(0, 1, 3, 4, 2)
 
         if add_bias:
             b_np = np.random.uniform(size=bias_shape).astype(dtype)
@@ -233,8 +268,9 @@ def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, str
 
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
-            C = topi.cuda.conv2d_NCHWc_int8(A, W, (stride, stride), padding, (dilation, dilation),
-                                            'NCHW', dtype)
+            C = topi.cuda.conv2d_NCHWc_int8(
+                A, W, (stride, stride), padding, (dilation, dilation), "NCHW", dtype
+            )
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
@@ -246,11 +282,29 @@ def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, str
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         if add_bias:
-            tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
-            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            tvm.build(
+                s,
+                [A, W, bias, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
+            func = tvm.build(
+                s,
+                [A, W, bias, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
             func(a, w, b, c)
         else:
-            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func = tvm.build(
+                s,
+                [A, W, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
             func(a, w, c)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
 
@@ -258,16 +312,30 @@ def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, str
         check_device(device)
 
 
-def verify_conv2d_nchw_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
+def verify_conv2d_nchw_int8(
+    batch,
+    in_channel,
+    in_size,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    dilation=1,
+    add_bias=False,
+    add_relu=False,
+):
     pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
     padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+    print(
+        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+    )
 
     in_height = in_width = in_size
 
-    A = te.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='int8')
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name='W', dtype='int8')
-    bias = te.placeholder((num_filter, 1, 1), name='bias', dtype='int8')
+    A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype="int8")
+    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W", dtype="int8")
+    bias = te.placeholder((num_filter, 1, 1), name="bias", dtype="int8")
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -303,8 +371,9 @@ def verify_conv2d_nchw_int8(batch, in_channel, in_size, num_filter, kernel, stri
 
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
-            C = topi.cuda.conv2d_nchw_int8(A, W, (stride, stride), padding, (dilation, dilation),
-                                           dtype)
+            C = topi.cuda.conv2d_nchw_int8(
+                A, W, (stride, stride), padding, (dilation, dilation), dtype
+            )
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
@@ -316,11 +385,29 @@ def verify_conv2d_nchw_int8(batch, in_channel, in_size, num_filter, kernel, stri
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         if add_bias:
-            tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
-            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            tvm.build(
+                s,
+                [A, W, bias, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
+            func = tvm.build(
+                s,
+                [A, W, bias, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
             func(a, w, b, c)
         else:
-            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func = tvm.build(
+                s,
+                [A, W, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
             func(a, w, c)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
 
@@ -332,17 +419,17 @@ def verify_conv2d_nchw_int8(batch, in_channel, in_size, num_filter, kernel, stri
 def test_conv2d_nchw():
     with Int8Fallback():
         # ResNet18 workloads where channels in / out are multiple of oc_block_factor
-        verify_conv2d_NCHWc_int8(1,  64,  56,  64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(1,  64,  56,  64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1,  64,  56, 128, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(1,  64,  56, 128, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(1, 128,  28, 128, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(1, 128,  28, 256, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(1, 128,  28, 256, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(1, 256,  14, 256, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(1, 256,  14, 512, 3, 2, 1)
-        verify_conv2d_NCHWc_int8(1, 256,  14, 512, 1, 2, 0)
-        verify_conv2d_NCHWc_int8(1, 512,   7, 512, 3, 1, 1)
+        verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, 1)
+        verify_conv2d_NCHWc_int8(1, 64, 56, 64, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 64, 56, 128, 3, 2, 1)
+        verify_conv2d_NCHWc_int8(1, 64, 56, 128, 1, 2, 0)
+        verify_conv2d_NCHWc_int8(1, 128, 28, 128, 3, 1, 1)
+        verify_conv2d_NCHWc_int8(1, 128, 28, 256, 3, 2, 1)
+        verify_conv2d_NCHWc_int8(1, 128, 28, 256, 1, 2, 0)
+        verify_conv2d_NCHWc_int8(1, 256, 14, 256, 3, 1, 1)
+        verify_conv2d_NCHWc_int8(1, 256, 14, 512, 3, 2, 1)
+        verify_conv2d_NCHWc_int8(1, 256, 14, 512, 1, 2, 0)
+        verify_conv2d_NCHWc_int8(1, 512, 7, 512, 3, 1, 1)
 
         # bias, relu
         verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, 1, add_relu=True)
@@ -360,120 +447,121 @@ def test_conv2d_nchw():
         verify_conv2d_NCHWc_int8(4, 4, 4, 4, 4, 4, 4)
 
         # inception v3 workloads where channels in / out are multiple of oc_block_factor
-        verify_conv2d_NCHWc_int8(1,   32, 149,  32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(1,   32, 147,  64, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(1,   64,  73,  80, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1,   80,  73, 192, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(1,  192,  35,  64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1,  192,  35,  48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1,   48,  35,  64, 5, 1, 2)
-        verify_conv2d_NCHWc_int8(1,   64,  35,  96, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(1,   96,  35,  96, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(1,  192,  35,  32, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1,  256,  35,  64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1,  256,  35,  48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1,  288,  35,  64, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1,  288,  35,  48, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1,  288,  35, 384, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(1,   96,  35,  96, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(1,  768,  17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1,  768,  17, 128, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1,  128,  17, 128, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1,  128,  17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(1,  128,  17, 128, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(1,  128,  17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1,  768,  17, 160, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1,  160,  17, 160, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1,  160,  17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(1,  160,  17, 160, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(1,  160,  17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1,  192,  17, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1,  192,  17, 192, 7, 1, 3)
-        verify_conv2d_NCHWc_int8(1,  192,  17, 320, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(1,  192,  17, 192, 3, 2, 0)
-        verify_conv2d_NCHWc_int8(1, 1280,   8, 320, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1, 1280,   8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1,  384,   8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1,  384,   8, 384, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(1, 1280,   8, 448, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1,  448,   8, 384, 3, 1, 1)
-        verify_conv2d_NCHWc_int8(1, 1280,   8, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1, 2048,   8, 320, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1, 2048,   8, 384, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1, 2048,   8, 448, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1, 2048,   8, 192, 1, 1, 0)
-        verify_conv2d_NCHWc_int8(1, 1024,  19,  84, 3, 1, 1)
+        verify_conv2d_NCHWc_int8(1, 32, 149, 32, 3, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 32, 147, 64, 3, 1, 1)
+        verify_conv2d_NCHWc_int8(1, 64, 73, 80, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 80, 73, 192, 3, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 192, 35, 64, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 192, 35, 48, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 48, 35, 64, 5, 1, 2)
+        verify_conv2d_NCHWc_int8(1, 64, 35, 96, 3, 1, 1)
+        verify_conv2d_NCHWc_int8(1, 96, 35, 96, 3, 1, 1)
+        verify_conv2d_NCHWc_int8(1, 192, 35, 32, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 256, 35, 64, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 256, 35, 48, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 288, 35, 64, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 288, 35, 48, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 288, 35, 384, 3, 2, 0)
+        verify_conv2d_NCHWc_int8(1, 96, 35, 96, 3, 2, 0)
+        verify_conv2d_NCHWc_int8(1, 768, 17, 192, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 768, 17, 128, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 128, 17, 128, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 128, 17, 192, 7, 1, 3)
+        verify_conv2d_NCHWc_int8(1, 128, 17, 128, 7, 1, 3)
+        verify_conv2d_NCHWc_int8(1, 128, 17, 192, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 768, 17, 160, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 160, 17, 160, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 160, 17, 192, 7, 1, 3)
+        verify_conv2d_NCHWc_int8(1, 160, 17, 160, 7, 1, 3)
+        verify_conv2d_NCHWc_int8(1, 160, 17, 192, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 192, 17, 192, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 192, 17, 192, 7, 1, 3)
+        verify_conv2d_NCHWc_int8(1, 192, 17, 320, 3, 2, 0)
+        verify_conv2d_NCHWc_int8(1, 192, 17, 192, 3, 2, 0)
+        verify_conv2d_NCHWc_int8(1, 1280, 8, 320, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 1280, 8, 384, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 384, 8, 384, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 384, 8, 384, 3, 1, 1)
+        verify_conv2d_NCHWc_int8(1, 1280, 8, 448, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 448, 8, 384, 3, 1, 1)
+        verify_conv2d_NCHWc_int8(1, 1280, 8, 192, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 2048, 8, 320, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 2048, 8, 384, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 2048, 8, 448, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 2048, 8, 192, 1, 1, 0)
+        verify_conv2d_NCHWc_int8(1, 1024, 19, 84, 3, 1, 1)
 
         # batch > 1
-        verify_conv2d_NCHWc_int8(7,   32, 149,  32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(8,   32, 149,  32, 3, 1, 0)
-        verify_conv2d_NCHWc_int8(32,  32, 149,  32, 3, 1, 0)
+        verify_conv2d_NCHWc_int8(7, 32, 149, 32, 3, 1, 0)
+        verify_conv2d_NCHWc_int8(8, 32, 149, 32, 3, 1, 0)
+        verify_conv2d_NCHWc_int8(32, 32, 149, 32, 3, 1, 0)
 
         # Asymmetric padding
-        verify_conv2d_NCHWc_int8(1,  32,   35,  64,  7, 2, (0, 0, 1, 1))
-        verify_conv2d_NCHWc_int8(1,  64,    8, 128,  3, 1, (3, 3, 2, 2))
-        verify_conv2d_NCHWc_int8(1,  64,    8,  64,  1, 1, (1, 2, 2, 1))
-        verify_conv2d_NCHWc_int8(1,  64,   17, 192,  1, 1, (1, 2))
-        verify_conv2d_NCHWc_int8(1,  64,    8,  64,  3, 1, (3, 1))
-        verify_conv2d_NCHWc_int8(1, 128,    8, 384,  3, 1, (0, 2))
-        verify_conv2d_NCHWc_int8(1,  64,    8,  64,  1, 1, "VALID")
-        verify_conv2d_NCHWc_int8(1, 388,    8,  64,  3, 1, "VALID")
-        verify_conv2d_NCHWc_int8(1, 512,   19,  64,  1, 1, "SAME")
-        verify_conv2d_NCHWc_int8(1,  64,   16,  32,  2, 1, "SAME")
-        verify_conv2d_NCHWc_int8(1,  64,    8,  64,  3, 1, (1, 2, 2, 1), add_relu=True)
-        verify_conv2d_NCHWc_int8(1,  64,    8,  64,  5, 2, (1, 3), add_bias=True)
-        verify_conv2d_NCHWc_int8(1,  64,   56,  64,  3, 1, "VALID", add_bias=True, add_relu=True)
-        verify_conv2d_NCHWc_int8(1,  64,   56,  64, 24, 1, "SAME", add_bias=True, add_relu=True)
+        verify_conv2d_NCHWc_int8(1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
+        verify_conv2d_NCHWc_int8(1, 64, 8, 128, 3, 1, (3, 3, 2, 2))
+        verify_conv2d_NCHWc_int8(1, 64, 8, 64, 1, 1, (1, 2, 2, 1))
+        verify_conv2d_NCHWc_int8(1, 64, 17, 192, 1, 1, (1, 2))
+        verify_conv2d_NCHWc_int8(1, 64, 8, 64, 3, 1, (3, 1))
+        verify_conv2d_NCHWc_int8(1, 128, 8, 384, 3, 1, (0, 2))
+        verify_conv2d_NCHWc_int8(1, 64, 8, 64, 1, 1, "VALID")
+        verify_conv2d_NCHWc_int8(1, 388, 8, 64, 3, 1, "VALID")
+        verify_conv2d_NCHWc_int8(1, 512, 19, 64, 1, 1, "SAME")
+        verify_conv2d_NCHWc_int8(1, 64, 16, 32, 2, 1, "SAME")
+        verify_conv2d_NCHWc_int8(1, 64, 8, 64, 3, 1, (1, 2, 2, 1), add_relu=True)
+        verify_conv2d_NCHWc_int8(1, 64, 8, 64, 5, 2, (1, 3), add_bias=True)
+        verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, "VALID", add_bias=True, add_relu=True)
+        verify_conv2d_NCHWc_int8(1, 64, 56, 64, 24, 1, "SAME", add_bias=True, add_relu=True)
 
         # Conv2d NCHW int8 schedule testing. Internally, it uses the NCHWc schedule, so only
         # basic testing is performed - one test covering the different scenarios (batch, dilation, etc.).
-        verify_conv2d_nchw_int8(1,  64,  56,  64, 3, 1, 1)
+        verify_conv2d_nchw_int8(1, 64, 56, 64, 3, 1, 1)
         verify_conv2d_nchw_int8(1, 64, 56, 64, 3, 1, 1, add_relu=True)
         verify_conv2d_nchw_int8(1, 64, 56, 64, 3, 1, 1, dilation=2)
         verify_conv2d_nchw_int8(9, 64, 56, 64, 3, 1, 1)
         verify_conv2d_nchw_int8(4, 4, 4, 4, 4, 4, 4)
-        verify_conv2d_nchw_int8(1,   32, 149,  32, 3, 1, 0)
-        verify_conv2d_nchw_int8(7,   32, 149,  32, 3, 1, 0)
-        verify_conv2d_nchw_int8(1,  32,   35,  64,  7, 2, (0, 0, 1, 1))
+        verify_conv2d_nchw_int8(1, 32, 149, 32, 3, 1, 0)
+        verify_conv2d_nchw_int8(7, 32, 149, 32, 3, 1, 0)
+        verify_conv2d_nchw_int8(1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
+
 
 def test_conv2d_nhwc():
     with Int8Fallback():
         # Subset of inception v3 expanded (dilation > 1, batch > 1, 'VALID' padding)
-        verify_conv2d_NHWC_gemm_int8(1, 3, 299, 32, 3, 2, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 32, 149, 32, 3, 1, 'SAME', dilation=2)
-        verify_conv2d_NHWC_gemm_int8(4, 32, 147, 64, 3, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 64, 73, 80, 1, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 80, 73, 192, 3, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 48, 1, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 64, 1, 1, 'VALID')
-        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 32, 1, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 48, 35, 64, 5, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 256, 35, 48, 1, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 256, 35, 64, 1, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 288, 35, 64, 1, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 288, 35, 48, 1, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 2, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 128, 17, 192, 7, 1, 'SAME', dilation=2)
-        verify_conv2d_NHWC_gemm_int8(1, 160, 17, 160, 7, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 160, 17, 192, 1, 1, 'VALID')
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 1, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 768, 5, 128, 1, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 320, 3, 2, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 3, 2, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 192, 1, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 384, 1, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 320, 1, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 448, 1, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 1, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 3, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 448, 8, 384, 3, 1, 'VALID')
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 320, 1, 1, 'SAME')
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 448, 1, 1, 'SAME', add_bias=True, add_relu=True)
-        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 192, 1, 1, 'SAME', add_bias=True)
+        verify_conv2d_NHWC_gemm_int8(1, 3, 299, 32, 3, 2, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 32, 149, 32, 3, 1, "SAME", dilation=2)
+        verify_conv2d_NHWC_gemm_int8(4, 32, 147, 64, 3, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 64, 73, 80, 1, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 80, 73, 192, 3, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 48, 1, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 64, 1, 1, "VALID")
+        verify_conv2d_NHWC_gemm_int8(1, 192, 35, 32, 1, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 48, 35, 64, 5, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 256, 35, 48, 1, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 256, 35, 64, 1, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 288, 35, 64, 1, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 288, 35, 48, 1, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 2, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 128, 17, 192, 7, 1, "SAME", dilation=2)
+        verify_conv2d_NHWC_gemm_int8(1, 160, 17, 160, 7, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 160, 17, 192, 1, 1, "VALID")
+        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 1, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 768, 5, 128, 1, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 320, 3, 2, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 3, 2, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 192, 1, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 384, 1, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 320, 1, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 448, 1, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 1, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 3, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 448, 8, 384, 3, 1, "VALID")
+        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 320, 1, 1, "SAME")
+        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 448, 1, 1, "SAME", add_bias=True, add_relu=True)
+        verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 192, 1, 1, "SAME", add_bias=True)
 
         # Let's also verify that it compiles fine on AArch64 targets
-        compile_conv2d_NHWC_gemm_int8_arm(1, 3, 299, 32, 3, 2, 'SAME')
+        compile_conv2d_NHWC_gemm_int8_arm(1, 3, 299, 32, 3, 2, "SAME")
 
 
 if __name__ == "__main__":
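
The hunks above are representative of what the reformatting does to these test files: extra alignment spaces inside argument lists are collapsed, single-quoted strings become double-quoted, and only lines that exceed the configured line length get wrapped. A minimal sketch of reproducing that behaviour through black's Python API follows; it assumes black is installed and a 100-character line length, which is inferred from where the long lines above wrap rather than stated anywhere in this excerpt.

import black

# Two call styles taken from the hunks above: one with hand-aligned columns,
# one with a single-quoted padding mode.
SRC = (
    "verify_conv2d_NCHWc_int8(1,  288,  35,  48, 1, 1, 0)\n"
    "verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 448, 1, 1, 'SAME', add_bias=True, add_relu=True)\n"
)

# format_str and FileMode are black's public API; line_length=100 is an assumption.
formatted = black.format_str(SRC, mode=black.FileMode(line_length=100))
print(formatted)
# The alignment spaces are collapsed and 'SAME' becomes "SAME", matching the '+' lines
# in the hunk; neither call is wrapped because both fit within 100 columns.
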
index 033869f..fef46e3 100644 (file)
@@ -28,18 +28,33 @@ from tvm.topi.util import get_const_tuple
 
 import tvm.testing
 
-def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False,\
-        use_cudnn=False):
+
+def verify_conv2d_nchw(
+    batch,
+    in_channel,
+    in_size,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    dilation=1,
+    add_bias=False,
+    add_relu=False,
+    use_cudnn=False,
+):
 
     pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
     padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+    print(
+        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+    )
 
     in_height = in_width = in_size
 
-    A = te.placeholder((batch, in_channel, in_height, in_width), name='A')
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name='W')
-    bias = te.placeholder((num_filter, 1, 1), name='bias')
+    A = te.placeholder((batch, in_channel, in_height, in_width), name="A")
+    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W")
+    bias = te.placeholder((num_filter, 1, 1), name="bias")
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -75,7 +90,9 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
 
         with tvm.target.Target(device):
             if "cudnn" in device:
-                C = fcompute(A, W, (stride, stride), padding, (dilation, dilation), 1, "NCHW", dtype)
+                C = fcompute(
+                    A, W, (stride, stride), padding, (dilation, dilation), 1, "NCHW", dtype
+                )
             else:
                 C = fcompute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
             if add_bias:
@@ -90,10 +107,22 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
 
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         if add_bias:
-            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func = tvm.build(
+                s,
+                [A, W, bias, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
             func(a, w, b, c)
         else:
-            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func = tvm.build(
+                s,
+                [A, W, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
             func(a, w, c)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)
 
@@ -108,18 +137,18 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
 @tvm.testing.uses_gpu
 def test_conv2d_nchw():
     # ResNet18 workloads
-    verify_conv2d_nchw(1,   3, 224,  64, 7, 2, 3)
-    verify_conv2d_nchw(1,  64,  56,  64, 3, 1, 1)
-    verify_conv2d_nchw(1,  64,  56,  64, 1, 1, 0)
-    verify_conv2d_nchw(1,  64,  56, 128, 3, 2, 1)
-    verify_conv2d_nchw(1,  64,  56, 128, 1, 2, 0)
-    verify_conv2d_nchw(1, 128,  28, 128, 3, 1, 1)
-    verify_conv2d_nchw(1, 128,  28, 256, 3, 2, 1)
-    verify_conv2d_nchw(1, 128,  28, 256, 1, 2, 0)
-    verify_conv2d_nchw(1, 256,  14, 256, 3, 1, 1)
-    verify_conv2d_nchw(1, 256,  14, 512, 3, 2, 1)
-    verify_conv2d_nchw(1, 256,  14, 512, 1, 2, 0)
-    verify_conv2d_nchw(1, 512,   7, 512, 3, 1, 1)
+    verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3)
+    verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1)
+    verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0)
+    verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1)
+    verify_conv2d_nchw(1, 64, 56, 128, 1, 2, 0)
+    verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1)
+    verify_conv2d_nchw(1, 128, 28, 256, 3, 2, 1)
+    verify_conv2d_nchw(1, 128, 28, 256, 1, 2, 0)
+    verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1)
+    verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1)
+    verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0)
+    verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1)
 
     # bias, relu
     verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_relu=True)
@@ -146,73 +175,73 @@ def test_conv2d_nchw():
     # verify_conv2d_nchw(2, 13, 71, 59, 3, 1, 1)
 
     # inception v3 workloads
-    verify_conv2d_nchw(1,    3, 299,  32, 3, 2, 0)
-    verify_conv2d_nchw(1,   32, 149,  32, 3, 1, 0)
-    verify_conv2d_nchw(1,   32, 147,  64, 3, 1, 1)
-    verify_conv2d_nchw(1,   64,  73,  80, 1, 1, 0)
-    verify_conv2d_nchw(1,   80,  73, 192, 3, 1, 0)
-    verify_conv2d_nchw(1,  192,  35,  64, 1, 1, 0)
-    verify_conv2d_nchw(1,  192,  35,  48, 1, 1, 0)
-    verify_conv2d_nchw(1,   48,  35,  64, 5, 1, 2)
-    verify_conv2d_nchw(1,   64,  35,  96, 3, 1, 1)
-    verify_conv2d_nchw(1,   96,  35,  96, 3, 1, 1)
-    verify_conv2d_nchw(1,  192,  35,  32, 1, 1, 0)
-    verify_conv2d_nchw(1,  256,  35,  64, 1, 1, 0)
-    verify_conv2d_nchw(1,  256,  35,  48, 1, 1, 0)
-    verify_conv2d_nchw(1,  288,  35,  64, 1, 1, 0)
-    verify_conv2d_nchw(1,  288,  35,  48, 1, 1, 0)
-    verify_conv2d_nchw(1,  288,  35, 384, 3, 2, 0)
-    verify_conv2d_nchw(1,   96,  35,  96, 3, 2, 0)
-    verify_conv2d_nchw(1,  768,  17, 192, 1, 1, 0)
-    verify_conv2d_nchw(1,  768,  17, 128, 1, 1, 0)
-    verify_conv2d_nchw(1,  128,  17, 128, 1, 1, 0)
-    verify_conv2d_nchw(1,  128,  17, 192, 7, 1, 3)
-    verify_conv2d_nchw(1,  128,  17, 128, 7, 1, 3)
-    verify_conv2d_nchw(1,  128,  17, 192, 1, 1, 0)
-    verify_conv2d_nchw(1,  768,  17, 160, 1, 1, 0)
+    verify_conv2d_nchw(1, 3, 299, 32, 3, 2, 0)
+    verify_conv2d_nchw(1, 32, 149, 32, 3, 1, 0)
+    verify_conv2d_nchw(1, 32, 147, 64, 3, 1, 1)
+    verify_conv2d_nchw(1, 64, 73, 80, 1, 1, 0)
+    verify_conv2d_nchw(1, 80, 73, 192, 3, 1, 0)
+    verify_conv2d_nchw(1, 192, 35, 64, 1, 1, 0)
+    verify_conv2d_nchw(1, 192, 35, 48, 1, 1, 0)
+    verify_conv2d_nchw(1, 48, 35, 64, 5, 1, 2)
+    verify_conv2d_nchw(1, 64, 35, 96, 3, 1, 1)
+    verify_conv2d_nchw(1, 96, 35, 96, 3, 1, 1)
+    verify_conv2d_nchw(1, 192, 35, 32, 1, 1, 0)
+    verify_conv2d_nchw(1, 256, 35, 64, 1, 1, 0)
+    verify_conv2d_nchw(1, 256, 35, 48, 1, 1, 0)
+    verify_conv2d_nchw(1, 288, 35, 64, 1, 1, 0)
+    verify_conv2d_nchw(1, 288, 35, 48, 1, 1, 0)
+    verify_conv2d_nchw(1, 288, 35, 384, 3, 2, 0)
+    verify_conv2d_nchw(1, 96, 35, 96, 3, 2, 0)
+    verify_conv2d_nchw(1, 768, 17, 192, 1, 1, 0)
+    verify_conv2d_nchw(1, 768, 17, 128, 1, 1, 0)
+    verify_conv2d_nchw(1, 128, 17, 128, 1, 1, 0)
+    verify_conv2d_nchw(1, 128, 17, 192, 7, 1, 3)
+    verify_conv2d_nchw(1, 128, 17, 128, 7, 1, 3)
+    verify_conv2d_nchw(1, 128, 17, 192, 1, 1, 0)
+    verify_conv2d_nchw(1, 768, 17, 160, 1, 1, 0)
     # disable these tests due to bugs in llvm with nvptx
     # verify_conv2d_nchw(1,  160,  17, 160, 1, 1, 0)
-    verify_conv2d_nchw(1,  160,  17, 192, 7, 1, 3)
-    verify_conv2d_nchw(1,  160,  17, 160, 7, 1, 3)
-    verify_conv2d_nchw(1,  160,  17, 192, 1, 1, 0)
-    verify_conv2d_nchw(1,  192,  17, 192, 1, 1, 0)
-    verify_conv2d_nchw(1,  192,  17, 192, 7, 1, 3)
-    verify_conv2d_nchw(1,  192,  17, 320, 3, 2, 0)
-    verify_conv2d_nchw(1,  192,  17, 192, 3, 2, 0)
-    verify_conv2d_nchw(1, 1280,   8, 320, 1, 1, 0)
-    verify_conv2d_nchw(1, 1280,   8, 384, 1, 1, 0)
-    verify_conv2d_nchw(1,  384,   8, 384, 1, 1, 0)
-    verify_conv2d_nchw(1,  384,   8, 384, 3, 1, 1)
-    verify_conv2d_nchw(1, 1280,   8, 448, 1, 1, 0)
-    verify_conv2d_nchw(1,  448,   8, 384, 3, 1, 1)
-    verify_conv2d_nchw(1, 1280,   8, 192, 1, 1, 0)
-    verify_conv2d_nchw(1, 2048,   8, 320, 1, 1, 0)
-    verify_conv2d_nchw(1, 2048,   8, 384, 1, 1, 0)
-    verify_conv2d_nchw(1, 2048,   8, 448, 1, 1, 0)
-    verify_conv2d_nchw(1, 2048,   8, 192, 1, 1, 0)
-    verify_conv2d_nchw(1, 1024,  19,  84, 3, 1, 1)
-    verify_conv2d_nchw(1, 2048,  10, 126, 3, 1, 1)
-    verify_conv2d_nchw(1,  512,   5, 126, 3, 1, 1)
-    verify_conv2d_nchw(1,  256,   3, 126, 3, 1, 1)
+    verify_conv2d_nchw(1, 160, 17, 192, 7, 1, 3)
+    verify_conv2d_nchw(1, 160, 17, 160, 7, 1, 3)
+    verify_conv2d_nchw(1, 160, 17, 192, 1, 1, 0)
+    verify_conv2d_nchw(1, 192, 17, 192, 1, 1, 0)
+    verify_conv2d_nchw(1, 192, 17, 192, 7, 1, 3)
+    verify_conv2d_nchw(1, 192, 17, 320, 3, 2, 0)
+    verify_conv2d_nchw(1, 192, 17, 192, 3, 2, 0)
+    verify_conv2d_nchw(1, 1280, 8, 320, 1, 1, 0)
+    verify_conv2d_nchw(1, 1280, 8, 384, 1, 1, 0)
+    verify_conv2d_nchw(1, 384, 8, 384, 1, 1, 0)
+    verify_conv2d_nchw(1, 384, 8, 384, 3, 1, 1)
+    verify_conv2d_nchw(1, 1280, 8, 448, 1, 1, 0)
+    verify_conv2d_nchw(1, 448, 8, 384, 3, 1, 1)
+    verify_conv2d_nchw(1, 1280, 8, 192, 1, 1, 0)
+    verify_conv2d_nchw(1, 2048, 8, 320, 1, 1, 0)
+    verify_conv2d_nchw(1, 2048, 8, 384, 1, 1, 0)
+    verify_conv2d_nchw(1, 2048, 8, 448, 1, 1, 0)
+    verify_conv2d_nchw(1, 2048, 8, 192, 1, 1, 0)
+    verify_conv2d_nchw(1, 1024, 19, 84, 3, 1, 1)
+    verify_conv2d_nchw(1, 2048, 10, 126, 3, 1, 1)
+    verify_conv2d_nchw(1, 512, 5, 126, 3, 1, 1)
+    verify_conv2d_nchw(1, 256, 3, 126, 3, 1, 1)
 
     # Asymmetric padding
-    verify_conv2d_nchw(1,   3,   35,  64,  7, 2, (0, 0, 1, 1))
-    verify_conv2d_nchw(1,  64,    8, 128,  3, 1, (3, 3, 2, 2))
-    verify_conv2d_nchw(1,  64,    8,  64,  1, 1, (1, 2, 2, 1))
-    verify_conv2d_nchw(1,  64,   17, 192,  1, 1, (1, 2))
-    verify_conv2d_nchw(1,  64,    8,  64,  3, 1, (3, 1))
-    verify_conv2d_nchw(1, 128,    8, 384,  3, 1, (0, 2))
-    verify_conv2d_nchw(1,  64,   35,  64,  3, 1, (1, 2), use_cudnn=True)
-    verify_conv2d_nchw(1,  64,    8,  64,  1, 1, "VALID")
-    verify_conv2d_nchw(1, 388,    8,  64,  3, 1, "VALID")
-    verify_conv2d_nchw(1,  64,   10,  48,  3, 1, "VALID", use_cudnn=True)
-    verify_conv2d_nchw(1, 512,   19,  64,  1, 1, "SAME")
-    verify_conv2d_nchw(1,  64,    5,  32,  2, 1, "SAME")
-    verify_conv2d_nchw(1,  64,    8,  64,  3, 1, "SAME", use_cudnn=True)
-    verify_conv2d_nchw(1,  64,    8,  64,  3, 1, (1, 2, 2, 1), add_relu=True)
-    verify_conv2d_nchw(1,  64,    8,  64,  5, 2, (1, 3), add_bias=True)
-    verify_conv2d_nchw(1,  64,    8,  64,  3, 1, "VALID", add_bias=True, add_relu=True)
-    verify_conv2d_nchw(1,  64,    8,  64, 24, 1, "SAME", add_bias=True, add_relu=True)
+    verify_conv2d_nchw(1, 3, 35, 64, 7, 2, (0, 0, 1, 1))
+    verify_conv2d_nchw(1, 64, 8, 128, 3, 1, (3, 3, 2, 2))
+    verify_conv2d_nchw(1, 64, 8, 64, 1, 1, (1, 2, 2, 1))
+    verify_conv2d_nchw(1, 64, 17, 192, 1, 1, (1, 2))
+    verify_conv2d_nchw(1, 64, 8, 64, 3, 1, (3, 1))
+    verify_conv2d_nchw(1, 128, 8, 384, 3, 1, (0, 2))
+    verify_conv2d_nchw(1, 64, 35, 64, 3, 1, (1, 2), use_cudnn=True)
+    verify_conv2d_nchw(1, 64, 8, 64, 1, 1, "VALID")
+    verify_conv2d_nchw(1, 388, 8, 64, 3, 1, "VALID")
+    verify_conv2d_nchw(1, 64, 10, 48, 3, 1, "VALID", use_cudnn=True)
+    verify_conv2d_nchw(1, 512, 19, 64, 1, 1, "SAME")
+    verify_conv2d_nchw(1, 64, 5, 32, 2, 1, "SAME")
+    verify_conv2d_nchw(1, 64, 8, 64, 3, 1, "SAME", use_cudnn=True)
+    verify_conv2d_nchw(1, 64, 8, 64, 3, 1, (1, 2, 2, 1), add_relu=True)
+    verify_conv2d_nchw(1, 64, 8, 64, 5, 2, (1, 3), add_bias=True)
+    verify_conv2d_nchw(1, 64, 8, 64, 3, 1, "VALID", add_bias=True, add_relu=True)
+    verify_conv2d_nchw(1, 64, 8, 64, 24, 1, "SAME", add_bias=True, add_relu=True)
 
 
 if __name__ == "__main__":
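
For readers following what one of the verify_conv2d_nchw(...) calls above actually exercises, here is a condensed, hedged sketch of the flow visible in the hunks: declare placeholders, compute the convolution with topi, build for a target, and compare against the NumPy reference. It uses only calls that appear in this diff (2020-era TVM API) plus a plain te.create_schedule, which stands in for the tuned schedule the real test dispatches per target; the "llvm" target and float32 dtype are assumptions.

import numpy as np
import tvm
import tvm.testing
import tvm.topi.testing
from tvm import te, topi
from tvm.topi.util import get_const_tuple

# One ResNet-18 workload from the list above: (1, 64, 56, 64, 3, 1, 1).
batch, in_channel, in_size, num_filter, kernel, stride, padding = 1, 64, 56, 64, 3, 1, 1

A = te.placeholder((batch, in_channel, in_size, in_size), name="A")
W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W")
C = topi.nn.conv2d_nchw(A, W, (stride, stride), padding, (1, 1))
s = te.create_schedule(C.op)  # naive schedule; the test picks a tuned one per target

a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype("float32")
w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype("float32")
c_np = tvm.topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)

ctx = tvm.context("llvm", 0)
func = tvm.build(s, [A, W, C], "llvm")
a, w = tvm.nd.array(a_np, ctx), tvm.nd.array(w_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)
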
index 7482d64..747bd4f 100644 (file)
@@ -30,17 +30,19 @@ _conv2d_nhwc_implement = {
     "llvm": (topi.nn.conv2d_nhwc, topi.generic.schedule_conv2d_nhwc),
     "cuda": (topi.cuda.conv2d_nhwc, topi.cuda.schedule_conv2d_nhwc),
     "cpu": (topi.nn.conv2d_nhwc, topi.x86.schedule_conv2d_nhwc),
-    "arm_cpu": (topi.arm_cpu.conv2d_nhwc_spatial_pack,
-                topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack),
-    "hls": (topi.nn.conv2d_nhwc, topi.hls.schedule_conv2d_nhwc)
+    "arm_cpu": (
+        topi.arm_cpu.conv2d_nhwc_spatial_pack,
+        topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack,
+    ),
+    "hls": (topi.nn.conv2d_nhwc, topi.hls.schedule_conv2d_nhwc),
 }
 
 
 def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
     in_height = in_width = in_size
 
-    A = te.placeholder((batch, in_height, in_width, in_channel), name='A')
-    W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W')
+    A = te.placeholder((batch, in_height, in_width, in_channel), name="A")
+    W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W")
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -53,6 +55,7 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, p
         dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
         b_np = tvm.topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding)
         return a_np, w_np, b_np
+
     a_np, w_np, b_np = get_ref_data()
 
     def check_device(device):
@@ -72,7 +75,7 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, p
         func(a, w, b)
         tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ['llvm', 'cuda']:
+    for device in ["llvm", "cuda"]:
         check_device(device)
 
 
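
The _conv2d_nhwc_implement table above is the pattern most of these tests share: a target key maps to a (compute, schedule) pair, and tvm.topi.testing.dispatch(device, table) selects the entry for the device under test. A simplified stand-in for that lookup, reusing two entries shown in the hunk, is sketched below; the pick_implement helper is only an illustration of the idea, not the real dispatch function.

from tvm import topi

# Entries copied from the table in the hunk above.
_conv2d_nhwc_implement = {
    "llvm": (topi.nn.conv2d_nhwc, topi.generic.schedule_conv2d_nhwc),
    "cuda": (topi.cuda.conv2d_nhwc, topi.cuda.schedule_conv2d_nhwc),
}


def pick_implement(target, table, default="llvm"):
    # Return the (compute, schedule) pair whose key appears in the target string,
    # falling back to the generic llvm entry. Illustrative only.
    for key, impl in table.items():
        if key != default and key in target:
            return impl
    return table[default]


fcompute, fschedule = pick_implement("cuda", _conv2d_nhwc_implement)
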
index dc9599c..f661737 100644 (file)
@@ -28,11 +28,13 @@ from tvm.contrib.pickle_memoize import memoize
 from tvm.topi.util import get_const_tuple
 
 
-def verify_conv2d_1x1_nhwc_pack_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
+def verify_conv2d_1x1_nhwc_pack_int8(
+    batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1
+):
     in_height = in_width = in_size
 
-    A = te.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='uint8')
-    W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype='int8')
+    A = te.placeholder((batch, in_height, in_width, in_channel), name="A", dtype="uint8")
+    W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W", dtype="int8")
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -57,7 +59,7 @@ def verify_conv2d_1x1_nhwc_pack_int8(batch, in_channel, in_size, num_filter, ker
         print("Running on target: %s" % device)
 
         with tvm.target.Target(device):
-            B = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NHWC', out_dtype="int32")
+            B = topi.nn.conv2d(A, W, stride, padding, dilation, layout="NHWC", out_dtype="int32")
             s = topi.x86.schedule_conv2d_nhwc_pack_int8([B])
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
@@ -67,7 +69,7 @@ def verify_conv2d_1x1_nhwc_pack_int8(batch, in_channel, in_size, num_filter, ker
         tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
     # for device in ['llvm -mcpu=skylake-avx512']:
-    for device in ['llvm']:
+    for device in ["llvm"]:
         check_device(device)
 
 
index 1223b0e..8d881b0 100644 (file)
@@ -34,19 +34,32 @@ _conv2d_nhwc_tensorcore_implement = {
 }
 
 
-def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride,
-                       padding, dilation=1, add_bias=False, add_relu=False, devices='cuda'):
+def verify_conv2d_nhwc(
+    batch,
+    in_channel,
+    in_size,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    dilation=1,
+    add_bias=False,
+    add_relu=False,
+    devices="cuda",
+):
     """Test the conv2d with tensorcore for nhwc layout"""
     pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
     padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (
-        batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+    print(
+        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+    )
 
     in_height = in_width = in_size
 
-    A = te.placeholder((batch, in_height, in_width, in_channel), name='A')
-    W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W')
-    bias = te.placeholder((1, 1, 1, num_filter), name='bias')
+    A = te.placeholder((batch, in_height, in_width, in_channel), name="A")
+    W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W")
+    bias = te.placeholder((1, 1, 1, num_filter), name="bias")
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -79,8 +92,10 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride,
             return
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
-            fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv2d_nhwc_tensorcore_implement)
-            C = fcompute(A, W, stride, padding, dilation, 'float32')
+            fcompute, fschedule = tvm.topi.testing.dispatch(
+                device, _conv2d_nhwc_tensorcore_implement
+            )
+            C = fcompute(A, W, stride, padding, dilation, "float32")
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
@@ -92,12 +107,22 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride,
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         if add_bias:
-            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
-                batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func = tvm.build(
+                s,
+                [A, W, bias, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
             func(a, w, b, c)
         else:
-            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
-                batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func = tvm.build(
+                s,
+                [A, W, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
             func(a, w, c)
 
         rtol = 1e-3
index 0bb0e69..3ffa4ac 100644 (file)
@@ -30,30 +30,44 @@ import tvm.testing
 
 
 _conv2d_nhwc_winograd_tensorcore = {
-    "cuda": (topi.cuda.conv2d_nhwc_winograd_tensorcore,
-             topi.cuda.schedule_conv2d_nhwc_winograd_tensorcore)
+    "cuda": (
+        topi.cuda.conv2d_nhwc_winograd_tensorcore,
+        topi.cuda.schedule_conv2d_nhwc_winograd_tensorcore,
+    )
 }
 
 _conv2d_nhwc_winograd_direct = {
-    "cuda": (topi.cuda.conv2d_nhwc_winograd_direct,
-             topi.cuda.schedule_conv2d_nhwc_winograd_direct)
+    "cuda": (topi.cuda.conv2d_nhwc_winograd_direct, topi.cuda.schedule_conv2d_nhwc_winograd_direct)
 }
 
 
-def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride,
-                       padding, dilation=1, add_bias=False, add_relu=False,
-                       devices='cuda', bgemm="direct"):
+def verify_conv2d_nhwc(
+    batch,
+    in_channel,
+    in_size,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    dilation=1,
+    add_bias=False,
+    add_relu=False,
+    devices="cuda",
+    bgemm="direct",
+):
     """Test the conv2d with winograd for nhwc layout"""
     pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
     padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (
-        batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+    print(
+        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+    )
 
     in_height = in_width = in_size
 
-    A = te.placeholder((batch, in_height, in_width, in_channel), name='A')
-    W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W')
-    bias = te.placeholder((1, 1, 1, num_filter), name='bias')
+    A = te.placeholder((batch, in_height, in_width, in_channel), name="A")
+    W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W")
+    bias = te.placeholder((1, 1, 1, num_filter), name="bias")
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -81,12 +95,14 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride,
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
             if bgemm == "direct":
-                fcompute, fschedule = tvm.topi.testing.dispatch(device,
-                                                            _conv2d_nhwc_winograd_direct)
+                fcompute, fschedule = tvm.topi.testing.dispatch(
+                    device, _conv2d_nhwc_winograd_direct
+                )
             elif bgemm == "tensorcore":
-                fcompute, fschedule = tvm.topi.testing.dispatch(device,
-                                                            _conv2d_nhwc_winograd_tensorcore)
-            C = fcompute(A, W, stride, padding, dilation, 'float32')
+                fcompute, fschedule = tvm.topi.testing.dispatch(
+                    device, _conv2d_nhwc_winograd_tensorcore
+                )
+            C = fcompute(A, W, stride, padding, dilation, "float32")
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
@@ -98,12 +114,22 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride,
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         if add_bias:
-            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
-                batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func = tvm.build(
+                s,
+                [A, W, bias, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
             func(a, w, b, c)
         else:
-            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
-                batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func = tvm.build(
+                s,
+                [A, W, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
             func(a, w, c)
 
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=2e-3)
@@ -117,34 +143,34 @@ def test_conv2d_nhwc_winograd_direct():
     """Test the conv2d with winograd for nhwc layout"""
     # resnet 18 workloads
     print("test_winograd_direct...")
-    verify_conv2d_nhwc(1,  64, 56,  64, 3, 1, 1, bgemm="direct")
+    verify_conv2d_nhwc(1, 64, 56, 64, 3, 1, 1, bgemm="direct")
     verify_conv2d_nhwc(1, 128, 28, 128, 3, 1, 1)
     verify_conv2d_nhwc(1, 256, 14, 256, 3, 1, 1)
-    verify_conv2d_nhwc(1, 512,  7, 512, 3, 1, 1)
-    verify_conv2d_nhwc(1,  48, 35,  64, 5, 1, 2)
+    verify_conv2d_nhwc(1, 512, 7, 512, 3, 1, 1)
+    verify_conv2d_nhwc(1, 48, 35, 64, 5, 1, 2)
 
     # weird workloads
-    verify_conv2d_nhwc(1,  1,  1,  1, 3, 1, 1)
-    verify_conv2d_nhwc(3,  3,  3,  3, 3, 1, 1)
+    verify_conv2d_nhwc(1, 1, 1, 1, 3, 1, 1)
+    verify_conv2d_nhwc(3, 3, 3, 3, 3, 1, 1)
     verify_conv2d_nhwc(2, 13, 71, 59, 3, 1, 1)
 
     # Asymmetric padding
-    verify_conv2d_nhwc(1, 512,  7, 512, 3, 1, "SAME")
-    verify_conv2d_nhwc(2,  48, 56,  48, 3, 1, (1, 1), add_relu=True)
-    verify_conv2d_nhwc(2,  48, 56,  48, 3, 1, "SAME", add_relu=True, add_bias=True)
-    verify_conv2d_nhwc(1, 48, 35,  48, 5, 1, "VALID")
+    verify_conv2d_nhwc(1, 512, 7, 512, 3, 1, "SAME")
+    verify_conv2d_nhwc(2, 48, 56, 48, 3, 1, (1, 1), add_relu=True)
+    verify_conv2d_nhwc(2, 48, 56, 48, 3, 1, "SAME", add_relu=True, add_bias=True)
+    verify_conv2d_nhwc(1, 48, 35, 48, 5, 1, "VALID")
 
 
 @tvm.testing.requires_cuda
 @tvm.testing.requires_tensorcore
 def test_conv2d_nhwc_winograd_tensorcore():
     """Test the conv2d with winograd for nhwc layout"""
-    verify_conv2d_nhwc(8,  64, 56,  64, 3, 1, 1, bgemm="tensorcore")
+    verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1, bgemm="tensorcore")
     verify_conv2d_nhwc(8, 128, 28, 128, 3, 1, 1, bgemm="tensorcore")
     verify_conv2d_nhwc(8, 256, 14, 256, 3, 1, 1, bgemm="tensorcore")
 
-    verify_conv2d_nhwc(2,  64, 56,  64, 3, 1, (1, 1), add_relu=True, bgemm="tensorcore")
-    verify_conv2d_nhwc(2,  64, 56,  64, 3, 1, "SAME", add_relu=True, bgemm="tensorcore")
+    verify_conv2d_nhwc(2, 64, 56, 64, 3, 1, (1, 1), add_relu=True, bgemm="tensorcore")
+    verify_conv2d_nhwc(2, 64, 56, 64, 3, 1, "SAME", add_relu=True, bgemm="tensorcore")
 
 
 if __name__ == "__main__":
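
The decorators in the hunk above are what gate these tests: @tvm.testing.requires_cuda and @tvm.testing.requires_tensorcore skip a test unless a suitable GPU is present, while @tvm.testing.uses_gpu pairs with tvm.testing.enabled_targets() to iterate whatever targets are enabled. A tiny illustration of the stacking, with hypothetical test names, follows.

import tvm.testing


@tvm.testing.requires_cuda
@tvm.testing.requires_tensorcore
def test_needs_tensorcore_gpu():
    # Runs under pytest only when a CUDA device with tensor core support exists.
    pass


@tvm.testing.uses_gpu
def test_runs_on_enabled_targets():
    for device, ctx in tvm.testing.enabled_targets():
        print("would run the check on", device)
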
index 89928f2..742892d 100644 (file)
@@ -34,14 +34,17 @@ _conv2d_transpose_nchw_implement = {
     "hls": (topi.nn.conv2d_transpose_nchw, topi.hls.schedule_conv2d_transpose_nchw),
 }
 
-def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, output_padding):
+
+def verify_conv2d_transpose_nchw(
+    batch, in_channel, in_size, num_filter, kernel, stride, padding, output_padding
+):
     in_height, in_width = in_size
     kernel_height, kernel_width = kernel
     stride_height, stride_width = stride
     pad_top, pad_left, pad_bottom, pad_right = padding
 
-    A = te.placeholder((batch, in_channel, in_height, in_width), name='A')
-    W = te.placeholder((in_channel, num_filter, kernel_height, kernel_width), name='W')
+    A = te.placeholder((batch, in_channel, in_height, in_width), name="A")
+    W = te.placeholder((in_channel, num_filter, kernel_height, kernel_width), name="W")
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -51,7 +54,9 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
     def get_ref_data():
         a_np = np.random.uniform(size=a_shape).astype(dtype)
         w_np = np.random.uniform(size=w_shape).astype(dtype)
-        b_np = tvm.topi.testing.conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding)
+        b_np = tvm.topi.testing.conv2d_transpose_nchw_python(
+            a_np, w_np, stride, padding, output_padding
+        )
         c_np = np.maximum(b_np, 0)
         return a_np, w_np, b_np, c_np
 
@@ -60,11 +65,17 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
     def check_device(device, ctx):
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
-            fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv2d_transpose_nchw_implement)
-            B = fcompute(A, W,
-                         [stride_height, stride_width],
-                         [pad_top, pad_left, pad_bottom, pad_right],
-                         A.dtype, output_padding)
+            fcompute, fschedule = tvm.topi.testing.dispatch(
+                device, _conv2d_transpose_nchw_implement
+            )
+            B = fcompute(
+                A,
+                W,
+                [stride_height, stride_width],
+                [pad_top, pad_left, pad_bottom, pad_right],
+                A.dtype,
+                output_padding,
+            )
             C = topi.nn.relu(B)
             s1 = fschedule([B])
             s2 = fschedule([C])
@@ -79,20 +90,21 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
         func2(a, w, c)
         tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+
     for device, ctx in tvm.testing.enabled_targets():
         check_device(device, ctx)
 
 
 @tvm.testing.uses_gpu
 def test_conv2d_transpose_nchw():
-    verify_conv2d_transpose_nchw(1, 3, (224, 224),  1, (1, 1), (1, 1), (0, 0, 0, 0), (0, 0))
-    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (3, 3), (1, 1), (0, 0, 0, 0), (0, 0))
-    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (3, 3), (3, 3), (0, 0, 0, 0), (0, 0))
-    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (3, 3), (1, 1), (0, 0, 0, 0), (0, 0))
-    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (3, 3), (2, 2), (1, 1, 1, 1), (0, 0))
-    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (3, 3), (2, 2), (1, 1, 1, 1), (1, 0))
-    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (2, 2), (2, 2), (0, 0, 0, 0), (0, 0))
-    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (2, 2), (2, 2), (0, 0, 0, 0), (1, 1))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224), 1, (1, 1), (1, 1), (0, 0, 0, 0), (0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0), (0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (3, 3), (0, 0, 0, 0), (0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0), (0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (2, 2), (1, 1, 1, 1), (0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (2, 2), (1, 1, 1, 1), (1, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (2, 2), (2, 2), (0, 0, 0, 0), (0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (2, 2), (2, 2), (0, 0, 0, 0), (1, 1))
     verify_conv2d_transpose_nchw(1, 32, (32, 32), 128, (5, 5), (1, 1), (0, 0, 0, 0), (0, 0))
     verify_conv2d_transpose_nchw(1, 32, (32, 32), 128, (5, 5), (2, 2), (1, 1, 1, 1), (0, 0))
     verify_conv2d_transpose_nchw(16, 32, (8192, 1), 8, (31, 1), (2, 1), (14, 0, 15, 0), (0, 0))
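
As a quick sanity check on the transposed-convolution workloads above, the output size follows the standard formula out = (in - 1) * stride - pad_before - pad_after + kernel + output_padding; the short arithmetic below applies it to one listed workload. The formula is the textbook one and an assumption here, not something stated in the diff.

def transpose_out_size(in_size, kernel, stride, pad_before, pad_after, output_padding):
    # Standard transposed-convolution output-size formula (assumed, see note above).
    return (in_size - 1) * stride - pad_before - pad_after + kernel + output_padding


# Workload (1, 3, (224, 224), 32, (3, 3), (2, 2), (1, 1, 1, 1), (1, 0)) from the list:
out_h = transpose_out_size(224, 3, 2, 1, 1, 1)  # 448
out_w = transpose_out_size(224, 3, 2, 1, 1, 0)  # 447
print(out_h, out_w)
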
index bb9fdee..69ef4f7 100644 (file)
@@ -36,17 +36,31 @@ _conv2d_nchw_winograd_implement = {
 }
 
 
-def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False,
-        devices=['cuda', 'llvm -device=arm_cpu', 'opencl -device=mali']):
+def verify_conv2d_nchw(
+    batch,
+    in_channel,
+    in_size,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    dilation=1,
+    add_bias=False,
+    add_relu=False,
+    devices=["cuda", "llvm -device=arm_cpu", "opencl -device=mali"],
+):
     pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
     padding_sum = pad_top + pad_left + pad_bottom + pad_right
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+    print(
+        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+    )
 
     in_height = in_width = in_size
 
-    A = te.placeholder((batch, in_channel, in_height, in_width), name='A')
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name='W')
-    bias = te.placeholder((num_filter, 1, 1), name='bias')
+    A = te.placeholder((batch, in_channel, in_height, in_width), name="A")
+    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W")
+    bias = te.placeholder((num_filter, 1, 1), name="bias")
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -89,16 +103,27 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         if add_bias:
-            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func = tvm.build(
+                s,
+                [A, W, bias, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
             func(a, w, b, c)
         else:
-            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func = tvm.build(
+                s,
+                [A, W, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
             func(a, w, c)
 
         rtol = 1e-3
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol)
 
-
     for device in devices:
         check_device(device)
 
@@ -106,16 +131,16 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
 @tvm.testing.uses_gpu
 def test_conv2d_nchw():
     # inception v3 workloads
-    verify_conv2d_nchw(1, 128, 17, 192, 7, 1, 3, devices=['cuda'])
-    verify_conv2d_nchw(1, 128, 17, 128, 7, 1, 3, devices=['cuda'])
-    verify_conv2d_nchw(1, 160, 17, 160, 7, 1, 3, devices=['cuda'])
+    verify_conv2d_nchw(1, 128, 17, 192, 7, 1, 3, devices=["cuda"])
+    verify_conv2d_nchw(1, 128, 17, 128, 7, 1, 3, devices=["cuda"])
+    verify_conv2d_nchw(1, 160, 17, 160, 7, 1, 3, devices=["cuda"])
 
     # resnet 18 workloads
     verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1)
     verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1)
     verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1)
     verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1)
-    verify_conv2d_nchw(1, 48,  35, 64, 5, 1, 2, devices=['cuda'])
+    verify_conv2d_nchw(1, 48, 35, 64, 5, 1, 2, devices=["cuda"])
 
     # batch size = 2
     verify_conv2d_nchw(2, 64, 56, 64, 3, 1, 1)
@@ -131,18 +156,18 @@ def test_conv2d_nchw():
     verify_conv2d_nchw(2, 13, 71, 59, 3, 1, 1)
 
     # Asymmetric padding
-    verify_conv2d_nchw(1,  48, 56,  48, 3, 1, (1, 1, 1, 1))
-    verify_conv2d_nchw(1,  64, 28,  64, 3, 1, (1, 1, 1, 1))
+    verify_conv2d_nchw(1, 48, 56, 48, 3, 1, (1, 1, 1, 1))
+    verify_conv2d_nchw(1, 64, 28, 64, 3, 1, (1, 1, 1, 1))
     verify_conv2d_nchw(1, 128, 14, 128, 3, 1, (1, 1))
-    verify_conv2d_nchw(1, 512,  7, 512, 3, 1, "SAME")
-    verify_conv2d_nchw(2, 13,  71,  59, 3, 1, (1, 1, 1, 1))
-    verify_conv2d_nchw(2,  48, 56,  48, 3, 1, (1, 1, 1, 1), add_bias=True)
-    verify_conv2d_nchw(2,  48, 56,  48, 3, 1, (1, 1), add_relu=True)
-    verify_conv2d_nchw(2,  48, 56,  48, 3, 1, "SAME", add_relu=True, add_bias=True)
-    verify_conv2d_nchw(1,  64, 17, 192, 7, 1, (3, 1), devices=['cuda'])
-    verify_conv2d_nchw(1,  64, 17,  64, 7, 1, (3, 3, 2, 2), devices=['cuda'])
-    verify_conv2d_nchw(1, 160, 17, 160, 7, 1, "SAME", devices=['cuda'])
-    verify_conv2d_nchw(1,  48, 35,  48, 5, 1, "VALID", devices=['cuda'])
+    verify_conv2d_nchw(1, 512, 7, 512, 3, 1, "SAME")
+    verify_conv2d_nchw(2, 13, 71, 59, 3, 1, (1, 1, 1, 1))
+    verify_conv2d_nchw(2, 48, 56, 48, 3, 1, (1, 1, 1, 1), add_bias=True)
+    verify_conv2d_nchw(2, 48, 56, 48, 3, 1, (1, 1), add_relu=True)
+    verify_conv2d_nchw(2, 48, 56, 48, 3, 1, "SAME", add_relu=True, add_bias=True)
+    verify_conv2d_nchw(1, 64, 17, 192, 7, 1, (3, 1), devices=["cuda"])
+    verify_conv2d_nchw(1, 64, 17, 64, 7, 1, (3, 3, 2, 2), devices=["cuda"])
+    verify_conv2d_nchw(1, 160, 17, 160, 7, 1, "SAME", devices=["cuda"])
+    verify_conv2d_nchw(1, 48, 35, 48, 5, 1, "VALID", devices=["cuda"])
 
 
 if __name__ == "__main__":
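
One detail worth noting in the signature above: devices=["cuda", "llvm -device=arm_cpu", "opencl -device=mali"] is a mutable default argument, and the formatter leaves it alone because it only changes layout, never semantics. It is harmless as written since the list is only iterated, but the usual defensive idiom, shown purely as a hypothetical aside and not as a change this commit makes, is a None default.

def verify_conv2d_nchw_example(
    batch,
    in_channel,
    in_size,
    num_filter,
    kernel,
    stride,
    padding,
    dilation=1,
    add_bias=False,
    add_relu=False,
    devices=None,
):
    # Hypothetical variant illustrating the None-default idiom; the list is only
    # built when the caller does not pass one.
    if devices is None:
        devices = ["cuda", "llvm -device=arm_cpu", "opencl -device=mali"]
    return devices
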
index 73de19c..58f30fb 100644 (file)
@@ -33,17 +33,33 @@ _conv3d_ncdhw_implement = {
     "gpu": (topi.cuda.conv3d_ncdhw, topi.cuda.schedule_conv3d_ncdhw),
 }
 
-def verify_conv3d_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
-    pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = get_pad_tuple3d(padding, (kernel, kernel, kernel))
+
+def verify_conv3d_ncdhw(
+    batch,
+    in_channel,
+    in_size,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    dilation=1,
+    add_bias=False,
+    add_relu=False,
+):
+    pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = get_pad_tuple3d(
+        padding, (kernel, kernel, kernel)
+    )
     padding_sum = pad_front + pad_back + pad_top + pad_left + pad_bottom + pad_right
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride,
-          padding_sum, dilation))
+    print(
+        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+    )
 
     in_depth = in_height = in_width = in_size
 
-    A = te.placeholder((batch, in_channel, in_depth, in_height, in_width), name='A')
-    W = te.placeholder((num_filter, in_channel, kernel, kernel, kernel), name='W')
-    bias = te.placeholder((num_filter, 1, 1, 1), name='bias')
+    A = te.placeholder((batch, in_channel, in_depth, in_height, in_width), name="A")
+    W = te.placeholder((num_filter, in_channel, kernel, kernel, kernel), name="W")
+    bias = te.placeholder((num_filter, 1, 1, 1), name="bias")
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -69,8 +85,9 @@ def verify_conv3d_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride,
         print("Running on target: %s" % device)
         fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv3d_ncdhw_implement)
         with tvm.target.Target(device):
-            C = fcompute(A, W, (stride, stride, stride), padding,
-                         (dilation, dilation, dilation), dtype)
+            C = fcompute(
+                A, W, (stride, stride, stride), padding, (dilation, dilation, dilation), dtype
+            )
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
@@ -82,10 +99,22 @@ def verify_conv3d_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride,
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         if add_bias:
-            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func = tvm.build(
+                s,
+                [A, W, bias, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
             func(a, w, b, c)
         else:
-            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func = tvm.build(
+                s,
+                [A, W, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
             func(a, w, c)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)
 
@@ -93,9 +122,10 @@ def verify_conv3d_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride,
         with autotvm.tophub.context(device):  # load tophub pre-tuned parameters
             check_device(device, ctx)
 
+
 @tvm.testing.uses_gpu
 def test_conv3d_ncdhw():
-    #3DCNN  workloads
+    # 3DCNN  workloads
     verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, 0)
     verify_conv3d_ncdhw(1, 32, 32, 1, 1, 1, 0)
     verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, 1)
@@ -125,5 +155,6 @@ def test_conv3d_ncdhw():
     verify_conv3d_ncdhw(1, 32, 32, 1, 3, 1, "VALID")
     verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, "VALID")
 
+
 if __name__ == "__main__":
     test_conv3d_ncdhw()
index 82216c8..a9d54ea 100644 (file)
@@ -31,7 +31,10 @@ _conv3d_ndhwc_implement = {
     "gpu": (topi.cuda.conv3d_ndhwc, topi.cuda.schedule_conv3d_ndhwc),
 }
 
-def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
+
+def verify_conv3d_ndhwc(
+    batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1
+):
     if isinstance(in_size, tuple):
         in_depth, in_height, in_width = in_size
     else:
@@ -41,8 +44,10 @@ def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride,
     else:
         kernel_depth = kernel_height = kernel_width = kernel
 
-    A = te.placeholder((batch, in_depth, in_height, in_width, in_channel), name='A')
-    W = te.placeholder((kernel_depth, kernel_height, kernel_width, in_channel, num_filter), name='W')
+    A = te.placeholder((batch, in_depth, in_height, in_width, in_channel), name="A")
+    W = te.placeholder(
+        (kernel_depth, kernel_height, kernel_width, in_channel, num_filter), name="W"
+    )
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -55,6 +60,7 @@ def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride,
         dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, dilation, 1, 1))
         b_np = tvm.topi.testing.conv3d_ndhwc_python(a_np, dw_np, stride, padding)
         return a_np, w_np, b_np
+
     a_np, w_np, b_np = get_ref_data()
 
     def check_device(device, ctx):
@@ -88,10 +94,8 @@ def test_conv3d_ndhwc():
     verify_conv3d_ndhwc(1, 64, 32, 64, 3, 1, "SAME", dilation=2)
 
     verify_conv3d_ndhwc(1, 1, (20, 256, 256), 32, (1, 3, 3), (1, 2, 2), "SAME")
-    verify_conv3d_ndhwc(1, 1, (20, 256, 256), 32,
-                        (1, 6, 6), (1, 2, 2), (0, 2, 2))
-    verify_conv3d_ndhwc(1, 4, (20, 256, 256), 8,
-                        (1, 5, 5), (1, 2, 2), (0, 2, 2))
+    verify_conv3d_ndhwc(1, 1, (20, 256, 256), 32, (1, 6, 6), (1, 2, 2), (0, 2, 2))
+    verify_conv3d_ndhwc(1, 4, (20, 256, 256), 8, (1, 5, 5), (1, 2, 2), (0, 2, 2))
 
 
 if __name__ == "__main__":
index 3a6d244..9a7d99a 100644 (file)
@@ -34,20 +34,34 @@ _conv3d_ndhwc_tensorcore_implement = {
 }
 
 
-def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride,
-                        padding, dilation=1, add_bias=False, add_relu=False, devices='cuda'):
+def verify_conv3d_ndhwc(
+    batch,
+    in_channel,
+    in_size,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    dilation=1,
+    add_bias=False,
+    add_relu=False,
+    devices="cuda",
+):
     """Test the conv3d with tensorcore for ndhwc layout"""
     pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = get_pad_tuple3d(
-        padding, (kernel, kernel, kernel))
+        padding, (kernel, kernel, kernel)
+    )
     padding_sum = pad_front + pad_top + pad_left + pad_back + pad_bottom + pad_right
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (
-        batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+    print(
+        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
+    )
 
     in_depth = in_height = in_width = in_size
 
-    A = te.placeholder((batch, in_depth, in_height, in_width, in_channel), name='A')
-    W = te.placeholder((kernel, kernel, kernel, in_channel, num_filter), name='W')
-    bias = te.placeholder((1, 1, 1, 1, num_filter), name='bias')
+    A = te.placeholder((batch, in_depth, in_height, in_width, in_channel), name="A")
+    W = te.placeholder((kernel, kernel, kernel, in_channel, num_filter), name="W")
+    bias = te.placeholder((1, 1, 1, 1, num_filter), name="bias")
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -74,8 +88,10 @@ def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride,
         ctx = tvm.context(device, 0)
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
-            fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv3d_ndhwc_tensorcore_implement)
-            C = fcompute(A, W, stride, padding, dilation, 'float32')
+            fcompute, fschedule = tvm.topi.testing.dispatch(
+                device, _conv3d_ndhwc_tensorcore_implement
+            )
+            C = fcompute(A, W, stride, padding, dilation, "float32")
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
@@ -87,12 +103,22 @@ def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride,
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         if add_bias:
-            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
-                batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func = tvm.build(
+                s,
+                [A, W, bias, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
             func(a, w, b, c)
         else:
-            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
-                batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func = tvm.build(
+                s,
+                [A, W, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+            )
             func(a, w, c)
 
         rtol = 1e-3
index 7a0121d..480ec19 100644 (file)
@@ -31,14 +31,19 @@ _conv3d_transpose_ncdhw_implement = {
     "gpu": (topi.cuda.conv3d_transpose_ncdhw, topi.cuda.schedule_conv3d_transpose_ncdhw),
 }
 
-def verify_conv3d_transpose_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride, padding, output_padding):
+
+def verify_conv3d_transpose_ncdhw(
+    batch, in_channel, in_size, num_filter, kernel, stride, padding, output_padding
+):
     in_depth, in_height, in_width = in_size
     kernel_depth, kernel_height, kernel_width = kernel
     stride_depth, stride_height, stride_width = stride
     pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = padding
 
-    A = te.placeholder((batch, in_channel, in_depth, in_height, in_width), name='A')
-    W = te.placeholder((in_channel, num_filter, kernel_depth, kernel_height, kernel_width), name='W')
+    A = te.placeholder((batch, in_channel, in_depth, in_height, in_width), name="A")
+    W = te.placeholder(
+        (in_channel, num_filter, kernel_depth, kernel_height, kernel_width), name="W"
+    )
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -48,7 +53,9 @@ def verify_conv3d_transpose_ncdhw(batch, in_channel, in_size, num_filter, kernel
     def get_ref_data():
         a_np = np.random.uniform(size=a_shape).astype(dtype)
         w_np = np.random.uniform(size=w_shape).astype(dtype)
-        b_np = tvm.topi.testing.conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding, output_padding)
+        b_np = tvm.topi.testing.conv3d_transpose_ncdhw_python(
+            a_np, w_np, stride, padding, output_padding
+        )
         c_np = np.maximum(b_np, 0)
         return a_np, w_np, b_np, c_np
 
@@ -57,11 +64,17 @@ def verify_conv3d_transpose_ncdhw(batch, in_channel, in_size, num_filter, kernel
     def check_device(device, ctx):
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
-            fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv3d_transpose_ncdhw_implement)
-            B = fcompute(A, W,
-                         [stride_depth, stride_height, stride_width],
-                         [pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right],
-                         A.dtype, output_padding)
+            fcompute, fschedule = tvm.topi.testing.dispatch(
+                device, _conv3d_transpose_ncdhw_implement
+            )
+            B = fcompute(
+                A,
+                W,
+                [stride_depth, stride_height, stride_width],
+                [pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right],
+                A.dtype,
+                output_padding,
+            )
             C = topi.nn.relu(B)
             s1 = fschedule([B])
             s2 = fschedule([C])
@@ -76,24 +89,50 @@ def verify_conv3d_transpose_ncdhw(batch, in_channel, in_size, num_filter, kernel
         func2(a, w, c)
         tvm.testing.assert_allclose(b.asnumpy(), b_np, atol=1e-4, rtol=1e-4)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, atol=1e-4, rtol=1e-4)
+
     for device, ctx in tvm.testing.enabled_targets():
         check_device(device, ctx)
 
 
 @tvm.testing.uses_gpu
 def test_conv3d_transpose_ncdhw():
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 1,  (1, 1, 1), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 2, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (3, 3, 3), (0, 0, 0, 0, 0, 0), (0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (3, 3, 3), (0, 0, 0, 0, 0, 0), (2, 2, 2))
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (3, 3, 3), (0, 0, 0, 0, 0, 0), (1, 0, 2))
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (2, 2, 2), (1, 1, 1, 1, 1, 1), (0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (2, 2, 2), (2, 2, 2), (0, 0, 0, 0, 0, 0), (0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 8, (32, 32, 32), 32, (5, 5, 5), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 8, (32, 32, 32), 64, (5, 5, 5), (2, 2, 2), (1, 1, 1, 1, 1, 1), (0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 8, (32, 32, 32), 64, (5, 5, 5), (2, 2, 2), (1, 1, 1, 1, 1, 1), (1, 1, 1))
+    verify_conv3d_transpose_ncdhw(
+        1, 3, (24, 24, 24), 1, (1, 1, 1), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0)
+    )
+    verify_conv3d_transpose_ncdhw(
+        1, 3, (24, 24, 24), 2, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0)
+    )
+    verify_conv3d_transpose_ncdhw(
+        1, 3, (24, 24, 24), 16, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0)
+    )
+    verify_conv3d_transpose_ncdhw(
+        1, 3, (24, 24, 24), 16, (3, 3, 3), (3, 3, 3), (0, 0, 0, 0, 0, 0), (0, 0, 0)
+    )
+    verify_conv3d_transpose_ncdhw(
+        1, 3, (24, 24, 24), 16, (3, 3, 3), (3, 3, 3), (0, 0, 0, 0, 0, 0), (2, 2, 2)
+    )
+    verify_conv3d_transpose_ncdhw(
+        1, 3, (24, 24, 24), 16, (3, 3, 3), (3, 3, 3), (0, 0, 0, 0, 0, 0), (1, 0, 2)
+    )
+    verify_conv3d_transpose_ncdhw(
+        1, 3, (24, 24, 24), 16, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0)
+    )
+    verify_conv3d_transpose_ncdhw(
+        1, 3, (24, 24, 24), 16, (3, 3, 3), (2, 2, 2), (1, 1, 1, 1, 1, 1), (0, 0, 0)
+    )
+    verify_conv3d_transpose_ncdhw(
+        1, 3, (24, 24, 24), 16, (2, 2, 2), (2, 2, 2), (0, 0, 0, 0, 0, 0), (0, 0, 0)
+    )
+    verify_conv3d_transpose_ncdhw(
+        1, 8, (32, 32, 32), 32, (5, 5, 5), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0)
+    )
+    verify_conv3d_transpose_ncdhw(
+        1, 8, (32, 32, 32), 64, (5, 5, 5), (2, 2, 2), (1, 1, 1, 1, 1, 1), (0, 0, 0)
+    )
+    verify_conv3d_transpose_ncdhw(
+        1, 8, (32, 32, 32), 64, (5, 5, 5), (2, 2, 2), (1, 1, 1, 1, 1, 1), (1, 1, 1)
+    )
+
 
 if __name__ == "__main__":
     test_conv3d_transpose_ncdhw()
index e049aec..fbb2995 100644
@@ -33,28 +33,33 @@ _conv3d_ncdhw_implement = {
 }
 
 
-def verify_conv3d_ncdhw(batch,
-                        in_channel,
-                        in_size,
-                        num_filter,
-                        depth_kernel,
-                        space_kernel,
-                        stride,
-                        padding,
-                        dilation=1,
-                        add_bias=False,
-                        add_relu=False):
+def verify_conv3d_ncdhw(
+    batch,
+    in_channel,
+    in_size,
+    num_filter,
+    depth_kernel,
+    space_kernel,
+    stride,
+    padding,
+    dilation=1,
+    add_bias=False,
+    add_relu=False,
+):
     pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = get_pad_tuple3d(
-        padding, (depth_kernel, space_kernel, space_kernel))
+        padding, (depth_kernel, space_kernel, space_kernel)
+    )
     padding_sum = pad_front + pad_back + pad_top + pad_left + pad_bottom + pad_right
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" %
-          (batch, in_channel, in_size, num_filter, space_kernel, stride, padding_sum, dilation))
+    print(
+        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+        % (batch, in_channel, in_size, num_filter, space_kernel, stride, padding_sum, dilation)
+    )
 
     in_depth = in_height = in_width = in_size
 
-    A = te.placeholder((batch, in_channel, in_depth, in_height, in_width), name='A')
-    W = te.placeholder((num_filter, in_channel, depth_kernel, space_kernel, space_kernel), name='W')
-    bias = te.placeholder((num_filter, 1, 1, 1), name='bias')
+    A = te.placeholder((batch, in_channel, in_depth, in_height, in_width), name="A")
+    W = te.placeholder((num_filter, in_channel, depth_kernel, space_kernel, space_kernel), name="W")
+    bias = te.placeholder((num_filter, 1, 1, 1), name="bias")
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -84,8 +89,9 @@ def verify_conv3d_ncdhw(batch,
         print("Running on target: %s" % device)
         fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv3d_ncdhw_implement)
         with tvm.target.Target(device):
-            C = fcompute(A, W, (stride, stride, stride), padding, (dilation, dilation, dilation),
-                         dtype)
+            C = fcompute(
+                A, W, (stride, stride, stride), padding, (dilation, dilation, dilation), dtype
+            )
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
@@ -98,17 +104,39 @@ def verify_conv3d_ncdhw(batch,
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         if add_bias:
             func = tvm.build(
-                s, [A, W, bias, C],
+                s,
+                [A, W, bias, C],
                 device,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
-                (batch, in_channel, in_size, num_filter, space_kernel, stride, padding_sum, dilation))
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    space_kernel,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
+            )
             func(a, w, b, c)
         else:
             func = tvm.build(
-                s, [A, W, C],
+                s,
+                [A, W, C],
                 device,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
-                (batch, in_channel, in_size, num_filter, space_kernel, stride, padding_sum, dilation))
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    space_kernel,
+                    stride,
+                    padding_sum,
+                    dilation,
+                ),
+            )
             func(a, w, c)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)
 
@@ -120,7 +148,7 @@ def verify_conv3d_ncdhw(batch,
 @tvm.testing.requires_gpu
 def test_conv3d_ncdhw():
     # Try without depth transformation
-    #3DCNN  workloads
+    # 3DCNN  workloads
     verify_conv3d_ncdhw(1, 61, 20, 120, 3, 3, 1, 0)
     verify_conv3d_ncdhw(1, 61, 20, 120, 1, 3, 1, 0)
     verify_conv3d_ncdhw(1, 61, 20, 120, 5, 3, 1, 0)
index b792830..acff50d 100644
@@ -30,31 +30,47 @@ _correlation_implement = {
 }
 
 
-def verify_correlation_nchw(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size,
-                            is_multiply):
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d, %d)" % (data_shape[0], data_shape[1], data_shape[2], data_shape[3],
-                                                                  kernel_size, max_displacement, stride1, stride2, pad_size,
-                                                                  is_multiply))
+def verify_correlation_nchw(
+    data_shape, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply
+):
+    print(
+        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d, %d)"
+        % (
+            data_shape[0],
+            data_shape[1],
+            data_shape[2],
+            data_shape[3],
+            kernel_size,
+            max_displacement,
+            stride1,
+            stride2,
+            pad_size,
+            is_multiply,
+        )
+    )
 
-    A = te.placeholder(data_shape, name='data1')
-    B = te.placeholder(data_shape, name='data2')
+    A = te.placeholder(data_shape, name="data1")
+    B = te.placeholder(data_shape, name="data2")
     dtype = A.dtype
 
     @memoize("topi.tests.test_topi_correlation_nchw.verify_correlation_nchw")
     def get_ref_data():
         a_np = np.random.uniform(size=data_shape).astype(dtype)
         b_np = np.random.uniform(size=data_shape).astype(dtype)
-        c_np = tvm.topi.testing.correlation_nchw_python(a_np, b_np, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply)
+        c_np = tvm.topi.testing.correlation_nchw_python(
+            a_np, b_np, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply
+        )
         return a_np, b_np, c_np
 
     a_np, b_np, c_np = get_ref_data()
 
     def check_device(device, ctx):
         print("Running on target: %s" % device)
-        fcompute, fschedule = tvm.topi.testing.dispatch(
-            device, _correlation_implement)
+        fcompute, fschedule = tvm.topi.testing.dispatch(device, _correlation_implement)
         with tvm.target.Target(device):
-            C = fcompute(A, B, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply)
+            C = fcompute(
+                A, B, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply
+            )
             s = fschedule([C])
 
             a = tvm.nd.array(a_np, ctx)
@@ -71,16 +87,51 @@ def verify_correlation_nchw(data_shape, kernel_size, max_displacement, stride1,
 
 @tvm.testing.uses_gpu
 def test_correlation_nchw():
-    verify_correlation_nchw((1, 3, 10, 10), kernel_size=1, max_displacement=4,
-                        stride1=1, stride2=1, pad_size=4, is_multiply=True)
-    verify_correlation_nchw((1, 3, 10, 10), kernel_size=1, max_displacement=5,
-                            stride1=1, stride2=1, pad_size=5, is_multiply=True)
-    verify_correlation_nchw((5, 1, 4, 4), kernel_size=3, max_displacement=1,
-                            stride1=2, stride2=1, pad_size=2, is_multiply=True)
-    verify_correlation_nchw((5, 1, 6, 4), kernel_size=3, max_displacement=1,
-                            stride1=2, stride2=2, pad_size=2, is_multiply=False)
-    verify_correlation_nchw((5, 1, 11, 11), kernel_size=5, max_displacement=1,
-                            stride1=1, stride2=1, pad_size=2, is_multiply=False)
+    verify_correlation_nchw(
+        (1, 3, 10, 10),
+        kernel_size=1,
+        max_displacement=4,
+        stride1=1,
+        stride2=1,
+        pad_size=4,
+        is_multiply=True,
+    )
+    verify_correlation_nchw(
+        (1, 3, 10, 10),
+        kernel_size=1,
+        max_displacement=5,
+        stride1=1,
+        stride2=1,
+        pad_size=5,
+        is_multiply=True,
+    )
+    verify_correlation_nchw(
+        (5, 1, 4, 4),
+        kernel_size=3,
+        max_displacement=1,
+        stride1=2,
+        stride2=1,
+        pad_size=2,
+        is_multiply=True,
+    )
+    verify_correlation_nchw(
+        (5, 1, 6, 4),
+        kernel_size=3,
+        max_displacement=1,
+        stride1=2,
+        stride2=2,
+        pad_size=2,
+        is_multiply=False,
+    )
+    verify_correlation_nchw(
+        (5, 1, 11, 11),
+        kernel_size=5,
+        max_displacement=1,
+        stride1=1,
+        stride2=1,
+        pad_size=2,
+        is_multiply=False,
+    )
 
 
 if __name__ == "__main__":
index 3f3eca6..f57f421 100644
@@ -31,15 +31,42 @@ _deformable_conv2d_implement = {
     "cuda": (topi.cuda.deformable_conv2d_nchw, topi.cuda.schedule_deformable_conv2d_nchw),
 }
 
-def verify_deformable_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, deformable_groups=1, groups=1):
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size,
-            num_filter, kernel, stride, padding, dilation, deformable_groups, groups))
 
-    A = te.placeholder((batch, in_channel, in_size, in_size), name='A')
+def verify_deformable_conv2d_nchw(
+    batch,
+    in_channel,
+    in_size,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    dilation=1,
+    deformable_groups=1,
+    groups=1,
+):
+    print(
+        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d, %d)"
+        % (
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            deformable_groups,
+            groups,
+        )
+    )
+
+    A = te.placeholder((batch, in_channel, in_size, in_size), name="A")
     out_size = (in_size - (kernel - 1) * dilation - 1 + 2 * padding) // stride + 1
-    Offset = te.placeholder((batch, deformable_groups * kernel * kernel * 2, out_size, out_size), name='offset')
-    W = te.placeholder((num_filter, in_channel, kernel, kernel), name='W')
-    bias = te.placeholder((num_filter, 1, 1), name='bias')
+    Offset = te.placeholder(
+        (batch, deformable_groups * kernel * kernel * 2, out_size, out_size), name="offset"
+    )
+    W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W")
+    bias = te.placeholder((num_filter, 1, 1), name="bias")
 
     a_shape = get_const_tuple(A.shape)
     offset_shape = get_const_tuple(Offset.shape)
@@ -53,8 +80,9 @@ def verify_deformable_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel
         offset_np = np.random.randn(*offset_shape).astype(dtype)
         w_np = np.random.uniform(size=w_shape).astype(dtype)
         b_np = np.random.uniform(size=bias_shape).astype(dtype)
-        c_np = tvm.topi.testing.deformable_conv2d_nchw_python(a_np, offset_np, w_np, stride, padding,
-                                                          dilation, deformable_groups, groups)
+        c_np = tvm.topi.testing.deformable_conv2d_nchw_python(
+            a_np, offset_np, w_np, stride, padding, dilation, deformable_groups, groups
+        )
 
         return a_np, offset_np, w_np, c_np
 
@@ -68,8 +96,7 @@ def verify_deformable_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel
         print("Running on target: %s" % device)
         fcompute, fschedule = tvm.topi.testing.dispatch(device, _deformable_conv2d_implement)
         with tvm.target.Target(device):
-            C = fcompute(A, Offset, W, stride, padding, dilation,
-                         deformable_groups, groups, dtype)
+            C = fcompute(A, Offset, W, stride, padding, dilation, deformable_groups, groups, dtype)
             s = fschedule([C])
 
             a = tvm.nd.array(a_np, ctx)
@@ -81,7 +108,7 @@ def verify_deformable_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel
             func(a, offset, w, c)
             tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
 
-    for device in ['llvm', 'cuda']:
+    for device in ["llvm", "cuda"]:
         check_device(device)
 
 
index 94e3670..f46a271 100644
@@ -28,20 +28,25 @@ import tvm.testing
 
 _dense_implement = {
     "generic": [(topi.nn.dense, topi.generic.schedule_dense)],
-    "cpu": [(topi.x86.dense_nopack, topi.x86.schedule_dense_nopack),
-            (topi.x86.dense_pack, topi.x86.schedule_dense_pack)],
-    "gpu": [(topi.cuda.dense_small_batch, topi.cuda.schedule_dense_small_batch),
-            (topi.cuda.dense_large_batch, topi.cuda.schedule_dense_large_batch)],
+    "cpu": [
+        (topi.x86.dense_nopack, topi.x86.schedule_dense_nopack),
+        (topi.x86.dense_pack, topi.x86.schedule_dense_pack),
+    ],
+    "gpu": [
+        (topi.cuda.dense_small_batch, topi.cuda.schedule_dense_small_batch),
+        (topi.cuda.dense_large_batch, topi.cuda.schedule_dense_large_batch),
+    ],
     "mali": [(topi.mali.dense, topi.mali.schedule_dense)],
     "bifrost": [(topi.bifrost.dense, topi.bifrost.schedule_dense)],
     "rocm": [(topi.rocm.dense, topi.rocm.schedule_dense)],
     "hls": [(topi.nn.dense, topi.hls.schedule_dense)],
 }
 
+
 def verify_dense(batch, in_dim, out_dim, use_bias=True):
-    A = te.placeholder((batch, in_dim), name='A')
-    B = te.placeholder((out_dim, in_dim), name='B')
-    C = te.placeholder((out_dim,), name='C')
+    A = te.placeholder((batch, in_dim), name="A")
+    B = te.placeholder((out_dim, in_dim), name="B")
+    C = te.placeholder((out_dim,), name="C")
     dtype = A.dtype
 
     # use memoize to pickle the test data for next time use
@@ -55,6 +60,7 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
         else:
             d_np = np.maximum(np.dot(a_np, b_np.T), 0.0)
         return (a_np, b_np, c_np, d_np)
+
     # get the test data
     a_np, b_np, c_np, d_np = get_ref_data()
 
@@ -78,11 +84,11 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
 
 
 def verify_dense_int8(batch, in_dim, out_dim, use_bias=True):
-    dtype = 'int8'
-    out_dtype = 'int32'
-    A = te.placeholder((batch, in_dim), name='A', dtype=dtype)
-    B = te.placeholder((out_dim, in_dim), name='B', dtype=dtype)
-    C = te.placeholder((out_dim,), name='C', dtype=out_dtype)
+    dtype = "int8"
+    out_dtype = "int32"
+    A = te.placeholder((batch, in_dim), name="A", dtype=dtype)
+    B = te.placeholder((out_dim, in_dim), name="B", dtype=dtype)
+    C = te.placeholder((out_dim,), name="C", dtype=out_dtype)
 
     # use memoize to pickle the test data for next time use
     @memoize("topi.tests.test_topi_dense_int8")
@@ -118,7 +124,7 @@ def verify_dense_int8(batch, in_dim, out_dim, use_bias=True):
         f(a, b, c, d)
         tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
 
-    for device in ['cuda']:
+    for device in ["cuda"]:
         check_device(device)
 
 
index 11dc407..07dab35 100644
@@ -26,15 +26,14 @@ from tvm.contrib.pickle_memoize import memoize
 import tvm.testing
 
 
-_dense_implement = {
-    "gpu": [(topi.cuda.dense_tensorcore, topi.cuda.schedule_dense_tensorcore)]
-}
+_dense_implement = {"gpu": [(topi.cuda.dense_tensorcore, topi.cuda.schedule_dense_tensorcore)]}
+
 
 def verify_dense(batch, in_dim, out_dim, use_bias=True):
     """Dense tensorcore verify function"""
-    A = te.placeholder((batch, in_dim), name='A')
-    B = te.placeholder((out_dim, in_dim), name='B')
-    C = te.placeholder((out_dim,), name='C')
+    A = te.placeholder((batch, in_dim), name="A")
+    B = te.placeholder((out_dim, in_dim), name="B")
+    C = te.placeholder((out_dim,), name="C")
     dtype = A.dtype
 
     # use memoize to pickle the test data for next time use
@@ -48,6 +47,7 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
         else:
             d_np = np.maximum(np.dot(a_np, b_np.T), 0.0)
         return (a_np, b_np, c_np, d_np)
+
     # get the test data
     a_np, b_np, c_np, d_np = get_ref_data()
 
@@ -67,8 +67,7 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
             f(a, b, c, d)
             tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-3)
 
-
-    check_device('cuda')
+    check_device("cuda")
 
 
 @tvm.testing.requires_tensorcore
index 182f099..cb16f9b 100644
@@ -23,29 +23,31 @@ import tvm.testing
 import tvm.topi.testing
 
 
-def verify_depth_to_space(block_size, batch, in_channel, in_height, in_width, layout='NCHW', mode='DCR'):
+def verify_depth_to_space(
+    block_size, batch, in_channel, in_height, in_width, layout="NCHW", mode="DCR"
+):
     out_channel = int(in_channel / (block_size * block_size))
     out_height = int(in_height * block_size)
     out_width = int(in_width * block_size)
 
-    if layout == 'NCHW':
+    if layout == "NCHW":
         in_shape = [batch, in_channel, in_height, in_width]
         out_shape = [batch, out_channel, out_height, out_width]
-    elif layout == 'NHWC':
+    elif layout == "NHWC":
         in_shape = [batch, in_height, in_width, in_channel]
         out_shape = [batch, out_height, out_width, out_channel]
     else:
-        raise NotImplementedError('Layout not supported {}'.format(layout))
+        raise NotImplementedError("Layout not supported {}".format(layout))
 
-    A = te.placeholder(in_shape, name='A', dtype='float32')
+    A = te.placeholder(in_shape, name="A", dtype="float32")
     dtype = A.dtype
     a_np = np.random.uniform(size=in_shape).astype(dtype)
 
     B = topi.nn.depth_to_space(A, block_size=block_size, layout=layout, mode=mode)
-    if layout == 'NHWC':
+    if layout == "NHWC":
         a_np = np.transpose(a_np, axes=[0, 3, 1, 2])
     b_np = tvm.topi.testing.depth_to_space_python(a_np, block_size, mode=mode)
-    if layout == 'NHWC':
+    if layout == "NHWC":
         a_np = np.transpose(a_np, axes=[0, 2, 3, 1])
         b_np = np.transpose(b_np, axes=[0, 2, 3, 1])
 
@@ -65,8 +67,8 @@ def verify_depth_to_space(block_size, batch, in_channel, in_height, in_width, la
 
 @tvm.testing.uses_gpu
 def test_depth_to_space():
-    for layout in ['NCHW', 'NHWC']:
-        for mode in ['DCR', 'CDR']:
+    for layout in ["NCHW", "NHWC"]:
+        for mode in ["DCR", "CDR"]:
             # Simplest possible case
             verify_depth_to_space(2, 1, 4, 1, 1, layout=layout, mode=mode)
             # Average input size
index f9c0a1c..07ddeab 100644
@@ -28,24 +28,37 @@ import tvm.testing
 
 _depthwise_conv2d_nchw_implement = {
     "generic": [(topi.nn.depthwise_conv2d_nchw, topi.generic.schedule_depthwise_conv2d_nchw)],
-    "arm_cpu": [(topi.arm_cpu.depthwise_conv2d_nchw, topi.arm_cpu.schedule_depthwise_conv2d_nchw),
-                (topi.arm_cpu.depthwise_conv2d_nchw_spatial_pack,
-                 topi.arm_cpu.schedule_depthwise_conv2d_nchw_spatial_pack)],
+    "arm_cpu": [
+        (topi.arm_cpu.depthwise_conv2d_nchw, topi.arm_cpu.schedule_depthwise_conv2d_nchw),
+        (
+            topi.arm_cpu.depthwise_conv2d_nchw_spatial_pack,
+            topi.arm_cpu.schedule_depthwise_conv2d_nchw_spatial_pack,
+        ),
+    ],
     "gpu": [(topi.cuda.depthwise_conv2d_nchw, topi.cuda.schedule_depthwise_conv2d_nchw)],
     "mali": [(topi.mali.depthwise_conv2d_nchw, topi.mali.schedule_depthwise_conv2d_nchw)],
     "bifrost": [(topi.nn.depthwise_conv2d_nchw, topi.bifrost.schedule_depthwise_conv2d_nchw)],
-    "intel_graphics": [(topi.intel_graphics.depthwise_conv2d_nchw,
-                        topi.intel_graphics.schedule_depthwise_conv2d_nchw)],
+    "intel_graphics": [
+        (
+            topi.intel_graphics.depthwise_conv2d_nchw,
+            topi.intel_graphics.schedule_depthwise_conv2d_nchw,
+        )
+    ],
 }
 
 _depthwise_conv2d_nhwc_implement = {
     "generic": (topi.nn.depthwise_conv2d_nhwc, topi.generic.schedule_depthwise_conv2d_nhwc),
-    "arm_cpu": (topi.arm_cpu.compute_depthwise_conv2d_nhwc, topi.arm_cpu.schedule_depthwise_conv2d_nhwc),
+    "arm_cpu": (
+        topi.arm_cpu.compute_depthwise_conv2d_nhwc,
+        topi.arm_cpu.schedule_depthwise_conv2d_nhwc,
+    ),
     "gpu": (topi.nn.depthwise_conv2d_nhwc, topi.cuda.schedule_depthwise_conv2d_nhwc),
 }
 
 
-def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1):
+def depthwise_conv2d_with_workload_nchw(
+    batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1
+):
     in_width = in_height
     filter_channel = in_channel
     filter_width = filter_height
@@ -60,25 +73,30 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
         padding_args = padding
 
     # placeholder
-    Input = te.placeholder((batch, in_channel, in_height, in_width), name='Input')
-    Filter = te.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter')
-    Scale = te.placeholder((in_channel * channel_multiplier,), name='Scale')
-    Shift = te.placeholder((in_channel * channel_multiplier,), name='Shift')
+    Input = te.placeholder((batch, in_channel, in_height, in_width), name="Input")
+    Filter = te.placeholder(
+        (filter_channel, channel_multiplier, filter_height, filter_width), name="Filter"
+    )
+    Scale = te.placeholder((in_channel * channel_multiplier,), name="Scale")
+    Shift = te.placeholder((in_channel * channel_multiplier,), name="Shift")
 
-    dtype = 'float32'
+    dtype = "float32"
 
     def check_device(device, ctx):
         print("Running on target: %s" % device)
 
         impl_list = tvm.topi.testing.dispatch(device, _depthwise_conv2d_nchw_implement)[:]
         if device == "llvm" and channel_multiplier == 1 and dilation == 1:
-            impl_list.append((topi.x86.depthwise_conv2d_nchw, topi.x86.schedule_depthwise_conv2d_nchw))
+            impl_list.append(
+                (topi.x86.depthwise_conv2d_nchw, topi.x86.schedule_depthwise_conv2d_nchw)
+            )
 
         for fcompute, fschedule in impl_list:
             with tvm.target.Target(device):
                 # declare
-                DepthwiseConv2d = fcompute(Input, Filter, (stride_h, stride_w),
-                                           padding_args, dilation, dtype)
+                DepthwiseConv2d = fcompute(
+                    Input, Filter, (stride_h, stride_w), padding_args, dilation, dtype
+                )
                 ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
                 Relu = topi.nn.relu(ScaleShift)
                 # schedule
@@ -102,30 +120,56 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
             def get_ref_data():
                 input_np = np.random.uniform(size=input_shape).astype(dtype)
                 filter_np = np.random.uniform(size=filter_shape).astype(dtype)
-                dilated_filter_np = tvm.topi.testing.dilate_python(filter_np, (1, 1, dilation, dilation))
+                dilated_filter_np = tvm.topi.testing.dilate_python(
+                    filter_np, (1, 1, dilation, dilation)
+                )
                 scale_np = np.random.uniform(size=scale_shape).astype(dtype)
                 shift_np = np.random.uniform(size=shift_shape).astype(dtype)
                 # correctness with scipy
                 depthwise_conv2d_scipy = tvm.topi.testing.depthwise_conv2d_python_nchw(
-                    input_np, dilated_filter_np, stride, padding)
+                    input_np, dilated_filter_np, stride, padding
+                )
                 scale_shift_scipy = np.zeros(shape=scale_shift_shape)
                 for c in range(in_channel * channel_multiplier):
-                    scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
+                    scale_shift_scipy[:, c, :, :] = (
+                        depthwise_conv2d_scipy[:, c, :, :] * scale_np[c] + shift_np[c]
+                    )
                     relu_scipy = np.maximum(scale_shift_scipy, 0)
-                return (input_np, filter_np, scale_np, shift_np,
-                        depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy)
+                return (
+                    input_np,
+                    filter_np,
+                    scale_np,
+                    shift_np,
+                    depthwise_conv2d_scipy,
+                    scale_shift_scipy,
+                    relu_scipy,
+                )
 
             # Get the test data
-            (input_np, filter_np, scale_np, shift_np,
-             depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy) = get_ref_data()
+            (
+                input_np,
+                filter_np,
+                scale_np,
+                shift_np,
+                depthwise_conv2d_scipy,
+                scale_shift_scipy,
+                relu_scipy,
+            ) = get_ref_data()
 
             input_tvm = tvm.nd.array(input_np, ctx)
             filter_tvm = tvm.nd.array(filter_np, ctx)
             scale_tvm = tvm.nd.array(scale_np, ctx)
             shift_tvm = tvm.nd.array(shift_np, ctx)
-            depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx)
-            scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx)
-            relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
+            depthwise_conv2d_tvm = tvm.nd.array(
+                np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype),
+                ctx,
+            )
+            scale_shift_tvm = tvm.nd.array(
+                np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx
+            )
+            relu_tvm = tvm.nd.array(
+                np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx
+            )
             # launch kernel 1 (depthwise_conv2d)
             timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1)
             tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean
@@ -135,7 +179,9 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
             # launch kernel 3 (depthwise_conv2d + scale_shift + relu)
             timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1)
             tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
-            tvm.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
+            tvm.testing.assert_allclose(
+                depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5
+            )
             tvm.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
             tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
 
@@ -144,7 +190,9 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
             check_device(device, ctx)
 
 
-def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding, dilation=1):
+def depthwise_conv2d_with_workload_nhwc(
+    batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding, dilation=1
+):
     in_width = in_height
     filter_channel = in_channel
     filter_width = filter_height
@@ -159,12 +207,14 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
         padding_args = padding
 
     # placeholder
-    Input = te.placeholder((batch, in_height, in_width, in_channel), name='Input')
-    Filter = te.placeholder((filter_height, filter_width,filter_channel, channel_multiplier), name='Filter')
-    Scale = te.placeholder((in_channel * channel_multiplier,), name='Scale')
-    Shift = te.placeholder((in_channel * channel_multiplier,), name='Shift')
+    Input = te.placeholder((batch, in_height, in_width, in_channel), name="Input")
+    Filter = te.placeholder(
+        (filter_height, filter_width, filter_channel, channel_multiplier), name="Filter"
+    )
+    Scale = te.placeholder((in_channel * channel_multiplier,), name="Scale")
+    Shift = te.placeholder((in_channel * channel_multiplier,), name="Shift")
 
-    dtype = 'float32'
+    dtype = "float32"
 
     def check_device(device, ctx):
         print("Running on target: %s" % device)
@@ -172,8 +222,9 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
         fcompute, fschedule = tvm.topi.testing.dispatch(device, _depthwise_conv2d_nhwc_implement)
         with tvm.target.Target(device):
             # declare
-            DepthwiseConv2d = fcompute(Input, Filter,
-                (stride_h, stride_w), padding_args, dilation, dtype)
+            DepthwiseConv2d = fcompute(
+                Input, Filter, (stride_h, stride_w), padding_args, dilation, dtype
+            )
             ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
             Relu = topi.nn.relu(ScaleShift)
             # schedule
@@ -197,29 +248,53 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
         def get_ref_data():
             input_np = np.random.uniform(size=input_shape).astype(dtype)
             filter_np = np.random.uniform(size=filter_shape).astype(dtype)
-            dilated_filter_np = tvm.topi.testing.dilate_python(filter_np, (dilation, dilation, 1, 1))
+            dilated_filter_np = tvm.topi.testing.dilate_python(
+                filter_np, (dilation, dilation, 1, 1)
+            )
             scale_np = np.random.uniform(size=scale_shape).astype(dtype)
             shift_np = np.random.uniform(size=shift_shape).astype(dtype)
             # correctness with scipy
             depthwise_conv2d_scipy = tvm.topi.testing.depthwise_conv2d_python_nhwc(
-                input_np, dilated_filter_np, stride=[stride_h, stride_w], padding=padding)
+                input_np, dilated_filter_np, stride=[stride_h, stride_w], padding=padding
+            )
             scale_shift_scipy = np.zeros(shape=scale_shift_shape)
             for c in range(in_channel * channel_multiplier):
-                scale_shift_scipy[:,:,:,c] = depthwise_conv2d_scipy[:,:,:,c] * scale_np[c] + shift_np[c]
+                scale_shift_scipy[:, :, :, c] = (
+                    depthwise_conv2d_scipy[:, :, :, c] * scale_np[c] + shift_np[c]
+                )
                 relu_scipy = np.maximum(scale_shift_scipy, 0)
-            return (input_np, filter_np, scale_np, shift_np,
-                    depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy)
+            return (
+                input_np,
+                filter_np,
+                scale_np,
+                shift_np,
+                depthwise_conv2d_scipy,
+                scale_shift_scipy,
+                relu_scipy,
+            )
+
         # Get the test data
-        (input_np, filter_np, scale_np, shift_np,
-         depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy) = get_ref_data()
+        (
+            input_np,
+            filter_np,
+            scale_np,
+            shift_np,
+            depthwise_conv2d_scipy,
+            scale_shift_scipy,
+            relu_scipy,
+        ) = get_ref_data()
 
         # prepare data
         input_tvm = tvm.nd.array(input_np, ctx)
         filter_tvm = tvm.nd.array(filter_np, ctx)
         scale_tvm = tvm.nd.array(scale_np, ctx)
         shift_tvm = tvm.nd.array(shift_np, ctx)
-        depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx)
-        scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx)
+        depthwise_conv2d_tvm = tvm.nd.array(
+            np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx
+        )
+        scale_shift_tvm = tvm.nd.array(
+            np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx
+        )
         relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
         # launch kernel 1 (depthwise_conv2d)
         timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1)
@@ -231,7 +306,9 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
         timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1)
         tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
         relu_scipy = np.maximum(scale_shift_scipy, 0)
-        tvm.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
+        tvm.testing.assert_allclose(
+            depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5
+        )
         tvm.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
         tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
 
@@ -239,29 +316,36 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
         with autotvm.tophub.context(device):  # load tophub pre-tuned parameters
             check_device(device, ctx)
 
+
 def _transform_data(data, bn):
     # NCHW -> NCHW[x]c
     batch_size, channel, height, width = data.shape
-    data = np.reshape(data, (batch_size, channel//bn, bn, height, width))
+    data = np.reshape(data, (batch_size, channel // bn, bn, height, width))
     data = np.transpose(data, (0, 1, 3, 4, 2))
     return data
 
+
 def _transform_kernel(kernel, bn):
     # channel, channel_multiplier, kh, kw -> out_channel_chunk, kh, kw, out_channel_block
     channel, channel_multiplier, kh, kw = kernel.shape
     out_channel = channel * channel_multiplier
-    kernel = np.reshape(kernel, (out_channel//bn, bn, kh, kw))
+    kernel = np.reshape(kernel, (out_channel // bn, bn, kh, kw))
     kernel = np.transpose(kernel, (0, 2, 3, 1))
     out_channel_chunk, kh, kw, out_channel_block = kernel.shape
     return kernel.reshape(out_channel_chunk, 1, kh, kw, 1, out_channel_block)
 
-def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1):
+
+def depthwise_conv2d_with_workload_NCHWc(
+    batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1
+):
     in_width = in_height
     filter_channel = in_channel
     filter_width = filter_height
     stride_h = stride_w = stride
 
-    assert channel_multiplier == 1, "depthwise_conv2d_NCHWc currently does not support channel multiplier > 1."
+    assert (
+        channel_multiplier == 1
+    ), "depthwise_conv2d_NCHWc currently does not support channel multiplier > 1."
     pad_h, pad_w, _, _ = get_pad_tuple(padding, (filter_height, filter_width))
     padding_args = (pad_h, pad_w)
 
@@ -282,11 +366,15 @@ def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_m
             break
 
     # placeholder
-    Input = te.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='Input')
-    Filter = te.placeholder((out_channel//oc_block, 1, filter_height, filter_width, 1, oc_block), name='Filter')
+    Input = te.placeholder(
+        (batch, in_channel // ic_block, in_height, in_width, ic_block), name="Input"
+    )
+    Filter = te.placeholder(
+        (out_channel // oc_block, 1, filter_height, filter_width, 1, oc_block), name="Filter"
+    )
     in_layout = "NCHW%dc" % ic_block
     out_layout = "NCHW%dc" % oc_block
-    dtype = 'float32'
+    dtype = "float32"
 
     def check_device(device):
         ctx = tvm.context(device, 0)
@@ -296,12 +384,16 @@ def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_m
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
             # declare
-            DepthwiseConv2d = topi.x86.depthwise_conv2d_NCHWc(Input, Filter,
-                                                              (stride_h, stride_w),
-                                                              padding,
-                                                              (dilation, dilation),
-                                                              in_layout,
-                                                              out_layout, dtype)
+            DepthwiseConv2d = topi.x86.depthwise_conv2d_NCHWc(
+                Input,
+                Filter,
+                (stride_h, stride_w),
+                padding,
+                (dilation, dilation),
+                in_layout,
+                out_layout,
+                dtype,
+            )
             # TODO: add scale_shift implement for NCHWc and add test here
             Relu = topi.nn.relu(DepthwiseConv2d)
             # schedule
@@ -321,14 +413,19 @@ def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_m
             input_np = np.random.uniform(size=input_shape).astype(dtype)
             filter_np = np.random.uniform(size=filter_shape).astype(dtype)
             # correctness with scipy
-            dw_np = tvm.topi.testing.dilate_python(filter_np, (1, 1, dilation, dilation)).astype(dtype)
+            dw_np = tvm.topi.testing.dilate_python(filter_np, (1, 1, dilation, dilation)).astype(
+                dtype
+            )
             depthwise_conv2d_scipy = tvm.topi.testing.depthwise_conv2d_python_nchw(
-                input_np, dw_np, stride, padding)
+                input_np, dw_np, stride, padding
+            )
             relu_scipy = np.maximum(depthwise_conv2d_scipy, 0)
-            return (_transform_data(input_np, ic_block),
-                    _transform_kernel(filter_np, oc_block),
-                    _transform_data(depthwise_conv2d_scipy, oc_block),
-                    _transform_data(relu_scipy, oc_block))
+            return (
+                _transform_data(input_np, ic_block),
+                _transform_kernel(filter_np, oc_block),
+                _transform_data(depthwise_conv2d_scipy, oc_block),
+                _transform_data(relu_scipy, oc_block),
+            )
 
         # Get the test data
         (input_np, filter_np, depthwise_conv2d_scipy, relu_scipy) = get_ref_data()
@@ -336,14 +433,17 @@ def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_m
         input_tvm = tvm.nd.array(input_np, ctx)
         filter_tvm = tvm.nd.array(filter_np, ctx)
 
-        depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape),
-                                                     dtype=DepthwiseConv2d.dtype), ctx)
+        depthwise_conv2d_tvm = tvm.nd.array(
+            np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx
+        )
         relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
         # launch kernel 1 (depthwise_conv2d)
         f1(input_tvm, filter_tvm, depthwise_conv2d_tvm)
         # launch kernel 2 (depthwise_conv2d + relu)
         f2(input_tvm, filter_tvm, relu_tvm)
-        tvm.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
+        tvm.testing.assert_allclose(
+            depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5
+        )
         tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
 
     # test llvm only for now since depthwise_conv2d_NCHWc implement is missing in other backend.
index 25ef6f1..8b4575f 100644
@@ -27,26 +27,34 @@ from tvm.topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_backward_in
 import tvm.testing
 
 
-def verify_depthwise_conv2d_back_input(batch, in_channel, in_h, channel_multiplier, filter_h, stride_h, padding_h):
+def verify_depthwise_conv2d_back_input(
+    batch, in_channel, in_h, channel_multiplier, filter_h, stride_h, padding_h
+):
     in_w = in_h
     filter_channel = in_channel
     filter_w = filter_h
     stride_w = stride_h
     padding_w = padding_h
 
-    out_h = np.int((in_h+2*padding_h-filter_h)/stride_h+1)
-    out_w = np.int((in_w+2*padding_w-filter_w)/stride_w+1)
+    out_h = np.int((in_h + 2 * padding_h - filter_h) / stride_h + 1)
+    out_w = np.int((in_w + 2 * padding_w - filter_w) / stride_w + 1)
     out_channel = in_channel * channel_multiplier
 
     ishape = [batch, in_h, in_w, in_channel]
     oshape = [batch, out_h, out_w, out_channel]
 
     # placeholder
-    Out_grad = te.placeholder(oshape, name='Out_grad')
+    Out_grad = te.placeholder(oshape, name="Out_grad")
     Filter = te.placeholder((filter_h, filter_w, filter_channel, channel_multiplier))
     # declare
-    In_grad = topi.nn.depthwise_conv2d_backward_input_nhwc(Filter, Out_grad, oshape, ishape,
-        stride=[stride_h, stride_w], padding=[padding_h, padding_w])
+    In_grad = topi.nn.depthwise_conv2d_backward_input_nhwc(
+        Filter,
+        Out_grad,
+        oshape,
+        ishape,
+        stride=[stride_h, stride_w],
+        padding=[padding_h, padding_w],
+    )
     # schedule
     schedule = schedule_depthwise_conv2d_backward_input_nhwc(In_grad)
 
@@ -68,26 +76,43 @@ def verify_depthwise_conv2d_back_input(batch, in_channel, in_h, channel_multipli
         def get_ref_data():
             out_grad_np = np.random.uniform(size=out_grad_shape).astype(dtype)
             filter_np = np.random.uniform(size=filter_shape).astype(dtype)
-            dilated_out_grad_np = tvm.topi.testing.dilate_python(out_grad_np, [1, stride_h, stride_w, 1])
+            dilated_out_grad_np = tvm.topi.testing.dilate_python(
+                out_grad_np, [1, stride_h, stride_w, 1]
+            )
             # padding params in forward propagation
-            fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple([padding_h, padding_w], (filter_h, filter_w))
+            fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(
+                [padding_h, padding_w], (filter_h, filter_w)
+            )
             # padding params in backward propagation
             bpad_top = filter_h - 1 - fpad_top
             bpad_bottom = (filter_h - 1 - fpad_bottom) + (stride_h - 1)
             bpad_left = filter_w - 1 - fpad_left
             bpad_right = (filter_w - 1 - fpad_right) + (stride_w - 1)
 
-            padded_out_grad = np.zeros((batch, dilated_out_grad_np.shape[1]+bpad_top+bpad_bottom,
-                dilated_out_grad_np.shape[2]+bpad_left+bpad_right, out_channel))
-            padded_out_grad[:, bpad_top:dilated_out_grad_np.shape[1]+bpad_top,
-                bpad_left:dilated_out_grad_np.shape[2]+bpad_left, :] = dilated_out_grad_np
+            padded_out_grad = np.zeros(
+                (
+                    batch,
+                    dilated_out_grad_np.shape[1] + bpad_top + bpad_bottom,
+                    dilated_out_grad_np.shape[2] + bpad_left + bpad_right,
+                    out_channel,
+                )
+            )
+            padded_out_grad[
+                :,
+                bpad_top : dilated_out_grad_np.shape[1] + bpad_top,
+                bpad_left : dilated_out_grad_np.shape[2] + bpad_left,
+                :,
+            ] = dilated_out_grad_np
 
             in_grad_np = np.zeros((batch, in_h, in_w, in_channel))
             for b in range(batch):
                 for c in range(in_channel):
                     for m in range(channel_multiplier):
-                        in_grad_np[b, :, :, c] += signal.convolve2d(padded_out_grad[b, :, :, c*channel_multiplier+m], \
-                                filter_np[:, :, c, m], mode='valid')[0:in_h, 0:in_w]
+                        in_grad_np[b, :, :, c] += signal.convolve2d(
+                            padded_out_grad[b, :, :, c * channel_multiplier + m],
+                            filter_np[:, :, c, m],
+                            mode="valid",
+                        )[0:in_h, 0:in_w]
             return (out_grad_np, filter_np, in_grad_np)
 
         (out_grad_np, filter_np, in_grad_np) = get_ref_data()
@@ -107,6 +132,7 @@ def verify_depthwise_conv2d_back_input(batch, in_channel, in_h, channel_multipli
     check_device("vulkan")
     check_device("nvptx")
 
+
 @tvm.testing.requires_gpu
 def test_topi_depthwise_conv2d_backward_input_nhwc():
     verify_depthwise_conv2d_back_input(16, 256, 56, 1, 3, 1, 1)
index 5ebc56d..3826f6f 100644
@@ -27,26 +27,29 @@ from tvm.topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_backward_we
 import tvm.testing
 
 
-def verify_depthwise_conv2d_back_weight(batch, in_channel, in_h, channel_multiplier, filter_h, stride_h, padding_h):
+def verify_depthwise_conv2d_back_weight(
+    batch, in_channel, in_h, channel_multiplier, filter_h, stride_h, padding_h
+):
     in_w = in_h
     filter_channel = in_channel
     filter_w = filter_h
     stride_w = stride_h
     padding_w = padding_h
 
-    out_h = np.int((in_h+2*padding_h-filter_h)/stride_h+1)
-    out_w = np.int((in_w+2*padding_w-filter_w)/stride_w+1)
+    out_h = np.int((in_h + 2 * padding_h - filter_h) / stride_h + 1)
+    out_w = np.int((in_w + 2 * padding_w - filter_w) / stride_w + 1)
     out_channel = in_channel * channel_multiplier
 
     oshape = [batch, out_h, out_w, out_channel]
     fshape = [filter_h, filter_w, in_channel, channel_multiplier]
 
     # placeholder
-    Out_grad = te.placeholder(oshape, name='Out_grad')
-    Input = te.placeholder((batch, in_h, in_w, in_channel), name='In_grad')
+    Out_grad = te.placeholder(oshape, name="Out_grad")
+    Input = te.placeholder((batch, in_h, in_w, in_channel), name="In_grad")
     # declare
-    Weight_grad = topi.nn.depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape,
-        stride=[stride_h, stride_w], padding=[padding_h, padding_w])
+    Weight_grad = topi.nn.depthwise_conv2d_backward_weight_nhwc(
+        Input, Out_grad, oshape, fshape, stride=[stride_h, stride_w], padding=[padding_h, padding_w]
+    )
     # schedule
     schedule = schedule_depthwise_conv2d_backward_weight_nhwc(Weight_grad)
 
@@ -68,19 +71,32 @@ def verify_depthwise_conv2d_back_weight(batch, in_channel, in_h, channel_multipl
         def get_ref_data():
             out_grad_np = np.random.uniform(size=out_grad_shape).astype(dtype)
             input_np = np.random.uniform(size=in_shape).astype(dtype)
-            dilated_out_grad_np = tvm.topi.testing.dilate_python(out_grad_np, [1, stride_h, stride_w, 1])
+            dilated_out_grad_np = tvm.topi.testing.dilate_python(
+                out_grad_np, [1, stride_h, stride_w, 1]
+            )
 
-            pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple([padding_h, padding_w], (filter_h, filter_w))
-            padded_input_np = np.zeros((batch, in_h+pad_top+pad_bottom, in_w+pad_left+pad_right, in_channel))
-            padded_input_np[:, pad_top:in_h+pad_top, pad_left:in_w+pad_left, :] = input_np
+            pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
+                [padding_h, padding_w], (filter_h, filter_w)
+            )
+            padded_input_np = np.zeros(
+                (batch, in_h + pad_top + pad_bottom, in_w + pad_left + pad_right, in_channel)
+            )
+            padded_input_np[:, pad_top : in_h + pad_top, pad_left : in_w + pad_left, :] = input_np
 
             weight_grad_np = np.zeros((filter_h, filter_w, in_channel, channel_multiplier))
             for c in range(in_channel):
                 for m in range(channel_multiplier):
                     for b in range(batch):
-                        weight_grad_np[:, :, c, m] += signal.convolve2d(padded_input_np[b, :, :, c], \
-                            np.rot90(dilated_out_grad_np[b, :, :, c*channel_multiplier+m%channel_multiplier], 2), \
-                            mode='valid')[0:filter_h, 0:filter_w]
+                        weight_grad_np[:, :, c, m] += signal.convolve2d(
+                            padded_input_np[b, :, :, c],
+                            np.rot90(
+                                dilated_out_grad_np[
+                                    b, :, :, c * channel_multiplier + m % channel_multiplier
+                                ],
+                                2,
+                            ),
+                            mode="valid",
+                        )[0:filter_h, 0:filter_w]
             return (out_grad_np, input_np, weight_grad_np)
 
         (out_grad_np, input_np, weight_grad_np) = get_ref_data()
@@ -100,6 +116,7 @@ def verify_depthwise_conv2d_back_weight(batch, in_channel, in_h, channel_multipl
     check_device("vulkan")
     check_device("nvptx")
 
+
 @tvm.testing.requires_gpu
 def test_topi_depthwise_conv2d_backward_weight_nhwc():
     verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 3, 1, 1)
@@ -120,5 +137,6 @@ def test_topi_depthwise_conv2d_backward_weight_nhwc():
     verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 5, 2, 0)
     verify_depthwise_conv2d_back_weight(15, 256, 56, 2, 5, 2, 0)
 
+
 if __name__ == "__main__":
     test_topi_depthwise_conv2d_backward_weight_nhwc()
index 60f2083..872ee05 100644
@@ -22,7 +22,7 @@ import numpy as np
 
 
 def test_dilate():
-    target = 'llvm'
+    target = "llvm"
     ctx = tvm.cpu(0)
 
     def _test_dilate(input_size, strides):
@@ -39,13 +39,13 @@ def test_dilate():
         tvm.testing.assert_allclose(output_tvm.asnumpy(), output_np, rtol=1e-5)
 
     _test_dilate((32,), (2,))
-    _test_dilate((32,32), (2,2))
-    _test_dilate((1,3,32,32), (1,1,1,1))
-    _test_dilate((1,3,32,32), (2,2,2,2))
-    _test_dilate((1,32,32,3,3), (1,1,1,1,1))
-    _test_dilate((1,32,32,3,3), (2,2,2,2,2))
-    _test_dilate((1,32,32,32,3,3), (1,1,1,2,2,2))
-    _test_dilate((1,32,32,32,3,3), (2,2,2,1,1,1))
+    _test_dilate((32, 32), (2, 2))
+    _test_dilate((1, 3, 32, 32), (1, 1, 1, 1))
+    _test_dilate((1, 3, 32, 32), (2, 2, 2, 2))
+    _test_dilate((1, 32, 32, 3, 3), (1, 1, 1, 1, 1))
+    _test_dilate((1, 32, 32, 3, 3), (2, 2, 2, 2, 2))
+    _test_dilate((1, 32, 32, 32, 3, 3), (1, 1, 1, 2, 2, 2))
+    _test_dilate((1, 32, 32, 32, 3, 3), (2, 2, 2, 1, 1, 1))
 
 
 if __name__ == "__main__":
index 959b15c..d97716b 100644
@@ -36,16 +36,29 @@ _group_conv2d_nchw_implement = {
 }
 
 
-def verify_group_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups, add_bias=False, add_relu=False):
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d)" %
-        (batch, in_channel, in_size, num_filter,
-         kernel, stride, padding, dilation, groups))
+def verify_group_conv2d_nchw(
+    batch,
+    in_channel,
+    in_size,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    dilation,
+    groups,
+    add_bias=False,
+    add_relu=False,
+):
+    print(
+        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d)"
+        % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups)
+    )
 
     in_height = in_width = in_size
 
-    A = te.placeholder((batch, in_channel, in_height, in_width), name='A')
-    W = te.placeholder((num_filter, in_channel // groups, kernel, kernel), name='W')
-    bias = te.placeholder((num_filter, 1, 1), name='bias')
+    A = te.placeholder((batch, in_channel, in_height, in_width), name="A")
+    W = te.placeholder((num_filter, in_channel // groups, kernel, kernel), name="W")
+    bias = te.placeholder((num_filter, 1, 1), name="bias")
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -58,7 +71,9 @@ def verify_group_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, str
         w_np = np.random.uniform(size=w_shape).astype(dtype)
         b_np = np.random.uniform(size=bias_shape).astype(dtype)
         dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
-        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding, groups).astype(dtype)
+        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding, groups).astype(
+            dtype
+        )
 
         if add_bias:
             b_np = np.random.uniform(size=bias_shape).astype(dtype)
@@ -91,12 +106,42 @@ def verify_group_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, str
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         if add_bias:
-            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" %\
-                (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups))
+            func = tvm.build(
+                s,
+                [A, W, bias, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding,
+                    dilation,
+                    groups,
+                ),
+            )
             func(a, w, b, c)
         else:
-            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" % \
-            (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups))
+            func = tvm.build(
+                s,
+                [A, W, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding,
+                    dilation,
+                    groups,
+                ),
+            )
             func(a, w, c)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
 
@@ -107,17 +152,31 @@ def verify_group_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, str
 oc_block_factor = 4
 
 
-def verify_group_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups, add_bias=False, add_relu=False):
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d)" %
-        (batch, in_channel, in_size, num_filter,
-         kernel, stride, padding, dilation, groups))
+def verify_group_conv2d_NCHWc_int8(
+    batch,
+    in_channel,
+    in_size,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    dilation,
+    groups,
+    add_bias=False,
+    add_relu=False,
+):
+    print(
+        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d)"
+        % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups)
+    )
 
     in_height = in_width = in_size
 
-    A = te.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='int8')
-    W = te.placeholder((num_filter, in_channel // groups, kernel, kernel), name='W', dtype='int8')
-    bias = te.placeholder((num_filter // oc_block_factor, 1, 1, oc_block_factor), name='bias',
-                            dtype='int8')
+    A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype="int8")
+    W = te.placeholder((num_filter, in_channel // groups, kernel, kernel), name="W", dtype="int8")
+    bias = te.placeholder(
+        (num_filter // oc_block_factor, 1, 1, oc_block_factor), name="bias", dtype="int8"
+    )
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -130,12 +189,15 @@ def verify_group_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kerne
         w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
         b_np = np.random.uniform(size=bias_shape).astype(dtype)
         dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
-        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding, groups).astype(dtype)
+        c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding, groups).astype(
+            dtype
+        )
 
         # convert to NCHWc
         _, _, out_height, out_width = c_np.shape
-        c_np = c_np.reshape((batch, num_filter // oc_block_factor, oc_block_factor, \
-                out_height, out_width)).transpose(0, 1, 3, 4, 2)
+        c_np = c_np.reshape(
+            (batch, num_filter // oc_block_factor, oc_block_factor, out_height, out_width)
+        ).transpose(0, 1, 3, 4, 2)
 
         if add_bias:
             b_np = np.random.uniform(size=bias_shape).astype(dtype)
@@ -170,12 +232,42 @@ def verify_group_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kerne
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         if add_bias:
-            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" %\
-                (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups))
+            func = tvm.build(
+                s,
+                [A, W, bias, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding,
+                    dilation,
+                    groups,
+                ),
+            )
             func(a, w, b, c)
         else:
-            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" % \
-            (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups))
+            func = tvm.build(
+                s,
+                [A, W, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d"
+                % (
+                    batch,
+                    in_channel,
+                    in_size,
+                    num_filter,
+                    kernel,
+                    stride,
+                    padding,
+                    dilation,
+                    groups,
+                ),
+            )
             func(a, w, c)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
 
@@ -197,8 +289,7 @@ def test_group_conv2d_nchw():
     # bias, relu
     verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True)
     verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32, add_bias=True)
-    verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True,
-                             add_bias=True)
+    verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True, add_bias=True)
 
     # dilation
     verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 2, 32)
@@ -208,7 +299,6 @@ def test_group_conv2d_nchw():
     verify_group_conv2d_nchw(9, 128, 56, 128, 3, 1, 1, 1, 32)
 
 
-
 @tvm.testing.requires_cuda
 def test_group_conv2d_NCHWc_int8():
     with Int8Fallback():
@@ -224,8 +314,9 @@ def test_group_conv2d_NCHWc_int8():
         # bias, relu
         verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True)
         verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_bias=True)
-        verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True,
-                                       add_bias=True)
+        verify_group_conv2d_NCHWc_int8(
+            1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True, add_bias=True
+        )
         # dilation
         verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 2, 32)
 
index 20c4490..36b7f29 100644
@@ -27,25 +27,44 @@ from tvm.contrib.pickle_memoize import memoize
 from tvm.topi.util import get_const_tuple
 import pytest
 
+
 def _transform_data(data, bn):
     # NCHW -> NCHW[x]c
     batch_size, channel, height, width = data.shape
-    data = np.reshape(data, (batch_size, channel//bn, bn, height, width))
+    data = np.reshape(data, (batch_size, channel // bn, bn, height, width))
     data = np.transpose(data, (0, 1, 3, 4, 2))
     return data
 
+
 def _transform_kernel(kernel, ic_bn, oc_bn):
     # OIHW -> OIHW[x]i[x]o
     out_channel, in_channel, kh, kw = kernel.shape
-    kernel = np.reshape(kernel, (out_channel//oc_bn, oc_bn, in_channel//ic_bn, ic_bn//4, kh, kw, 4))
+    kernel = np.reshape(
+        kernel, (out_channel // oc_bn, oc_bn, in_channel // ic_bn, ic_bn // 4, kh, kw, 4)
+    )
     kernel = np.transpose(kernel, (0, 2, 4, 5, 3, 1, 6))
     return kernel
 
-def verify_group_conv2d_NCHWc_int8(batch, in_channel, groups, in_size, num_filter, kernel, stride,
-                        padding, dilation=1, add_bias=False, add_relu=False, dtype="int32"):
+
+def verify_group_conv2d_NCHWc_int8(
+    batch,
+    in_channel,
+    groups,
+    in_size,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    dilation=1,
+    add_bias=False,
+    add_relu=False,
+    dtype="int32",
+):
     assert dilation == 1, "conv2d_NCHWc does not support dilation for now."
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" %
-          (batch, in_channel, groups, in_size, num_filter, kernel, stride, padding))
+    print(
+        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
+        % (batch, in_channel, groups, in_size, num_filter, kernel, stride, padding)
+    )
 
     in_height = in_width = in_size
 
@@ -60,16 +79,35 @@ def verify_group_conv2d_NCHWc_int8(batch, in_channel, groups, in_size, num_filte
 
     ic_block = 8
     autotvm.GLOBAL_SCOPE.silent = True
-    A = te.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='A', dtype='uint8')
-    W = te.placeholder((num_filter//oc_block, in_channel//ic_block//groups, kernel, kernel, ic_block//4, oc_block, 4), name='W', dtype='int8')
+    A = te.placeholder(
+        (batch, in_channel // ic_block, in_height, in_width, ic_block), name="A", dtype="uint8"
+    )
+    W = te.placeholder(
+        (
+            num_filter // oc_block,
+            in_channel // ic_block // groups,
+            kernel,
+            kernel,
+            ic_block // 4,
+            oc_block,
+            4,
+        ),
+        name="W",
+        dtype="int8",
+    )
 
     @memoize("topi.tests.test_topi_conv2d_NCHWc_int8.verify_conv2d_NCHWc_int8")
     def get_ref_data():
         a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype("uint8")
-        w_np = np.random.uniform(size=(num_filter, in_channel//groups, kernel, kernel)).astype("int8")
+        w_np = np.random.uniform(size=(num_filter, in_channel // groups, kernel, kernel)).astype(
+            "int8"
+        )
         c_np = tvm.topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding, groups)
-        return _transform_data(a_np, ic_block), _transform_kernel(w_np, ic_block, oc_block), \
-               _transform_data(c_np, oc_block)
+        return (
+            _transform_data(a_np, ic_block),
+            _transform_kernel(w_np, ic_block, oc_block),
+            _transform_data(c_np, oc_block),
+        )
 
     a_np, w_np, c_np = get_ref_data()
 
@@ -80,19 +118,28 @@ def verify_group_conv2d_NCHWc_int8(batch, in_channel, groups, in_size, num_filte
             return
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
-            C = topi.x86.conv2d_NCHWc(A, W, (stride, stride), (padding, padding),
-                                      (dilation, dilation),
-                                      'NCHW%dc'%ic_block,
-                                      "NCHW%dc"%oc_block,
-                                      dtype)
+            C = topi.x86.conv2d_NCHWc(
+                A,
+                W,
+                (stride, stride),
+                (padding, padding),
+                (dilation, dilation),
+                "NCHW%dc" % ic_block,
+                "NCHW%dc" % oc_block,
+                dtype,
+            )
             s = topi.x86.schedule_conv2d_NCHWc([C])
 
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
-        func = tvm.build(s, [A, W, C], device,
-                         name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
-                              (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
+        func = tvm.build(
+            s,
+            [A, W, C],
+            device,
+            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation),
+        )
         # print(tvm.lower(s, [A, W, C], simple_mode=True))
         func(a, w, c)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-3)
@@ -103,12 +150,14 @@ def verify_group_conv2d_NCHWc_int8(batch, in_channel, groups, in_size, num_filte
             check_device(device)
     autotvm.GLOBAL_SCOPE.silent = False
 
+
 @tvm.testing.uses_gpu
 @pytest.mark.skip
 def test_conv2d_NCHWc():
     # ResNet50 workloads
     verify_group_conv2d_NCHWc_int8(1, 256, 32, 224, 64, 7, 2, 3)
 
+
 if __name__ == "__main__":
     # The test requires Skylake and newer Intel machines to generate the correct
     # instruction. This test directly calls the topi operator, requiring correct
index 207dfe7..79e0930 100644
@@ -23,24 +23,40 @@ import tvm.topi.testing
 from tvm.contrib.pickle_memoize import memoize
 
 
-def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width,
-                  layout='NCHW', coord_trans="align_corners", method="bilinear"):
-    if layout == 'NCHW':
-        A = te.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='float32')
+def verify_resize(
+    batch,
+    in_channel,
+    in_height,
+    in_width,
+    out_height,
+    out_width,
+    layout="NCHW",
+    coord_trans="align_corners",
+    method="bilinear",
+):
+    if layout == "NCHW":
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype="float32")
         dtype = A.dtype
         out_shape = (batch, in_channel, out_height, out_width)
         a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
-    elif layout == 'NHWC':
-        A = te.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='float32')
+    elif layout == "NHWC":
+        A = te.placeholder((batch, in_height, in_width, in_channel), name="A", dtype="float32")
         dtype = A.dtype
         out_shape = (batch, out_height, out_width, in_channel)
         a_np = np.random.uniform(size=(batch, in_height, in_width, in_channel)).astype(dtype)
     else:
-        raise NotImplementedError(
-            'Layout not supported {} '.format(layout))
-    B = topi.image.resize(A, (out_height, out_width), layout=layout, coordinate_transformation_mode=coord_trans, method=method)
+        raise NotImplementedError("Layout not supported {} ".format(layout))
+    B = topi.image.resize(
+        A,
+        (out_height, out_width),
+        layout=layout,
+        coordinate_transformation_mode=coord_trans,
+        method=method,
+    )
     if method == "bilinear":
-        b_np = tvm.topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, coord_trans)
+        b_np = tvm.topi.testing.bilinear_resize_python(
+            a_np, (out_height, out_width), layout, coord_trans
+        )
     else:
         scale_h = out_height / in_height
         scale_w = out_width / in_width
@@ -64,46 +80,70 @@ def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width,
 @tvm.testing.uses_gpu
 def test_resize():
     # Scale NCHW
-    verify_resize(4, 16, 32, 32, 50, 50, 'NCHW')
+    verify_resize(4, 16, 32, 32, 50, 50, "NCHW")
     # Scale NCHW + Align Corners
-    verify_resize(6, 32, 64, 64, 20, 20, 'NCHW')
+    verify_resize(6, 32, 64, 64, 20, 20, "NCHW")
     # Scale NHWC
     verify_resize(4, 16, 32, 32, 50, 50, "NHWC")
     # Scale NHWC + Align Corners
     verify_resize(6, 32, 64, 64, 20, 20, "NHWC")
     # Nearest + Fractional
-    verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', "asymmetric", method="nearest_neighbor")
-    verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', "asymmetric", method="nearest_neighbor")
+    verify_resize(4, 16, 32, 32, 50, 50, "NCHW", "asymmetric", method="nearest_neighbor")
+    verify_resize(4, 16, 32, 32, 50, 50, "NHWC", "asymmetric", method="nearest_neighbor")
     # half_pixel
-    verify_resize(4, 16, 16, 16, 32, 32, 'NCHW', "half_pixel", method="bilinear")
-    verify_resize(4, 16, 16, 16, 32, 32, 'NHWC', "half_pixel", method="bilinear")
+    verify_resize(4, 16, 16, 16, 32, 32, "NCHW", "half_pixel", method="bilinear")
+    verify_resize(4, 16, 16, 16, 32, 32, "NHWC", "half_pixel", method="bilinear")
     # Bilinear + Fractional
-    verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', "asymmetric", method="bilinear")
-    verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', "asymmetric", method="bilinear")
-
-
-def verify_resize3d(batch, in_channel, in_depth, in_height, in_width, out_depth, out_height, out_width,
-                    layout='NCDHW', coordinate_transformation_mode="half_pixel", method="trilinear"):
-    if layout == 'NCDHW':
-        A = te.placeholder((batch, in_channel, in_depth, in_height, in_width), name='A', dtype='float32')
+    verify_resize(4, 16, 32, 32, 50, 50, "NCHW", "asymmetric", method="bilinear")
+    verify_resize(4, 16, 32, 32, 50, 50, "NHWC", "asymmetric", method="bilinear")
+
+
+def verify_resize3d(
+    batch,
+    in_channel,
+    in_depth,
+    in_height,
+    in_width,
+    out_depth,
+    out_height,
+    out_width,
+    layout="NCDHW",
+    coordinate_transformation_mode="half_pixel",
+    method="trilinear",
+):
+    if layout == "NCDHW":
+        A = te.placeholder(
+            (batch, in_channel, in_depth, in_height, in_width), name="A", dtype="float32"
+        )
         dtype = A.dtype
         out_shape = (batch, in_channel, out_depth, out_height, out_width)
-        a_np = np.random.uniform(size=(batch, in_channel, in_depth, in_height, in_width)).astype(dtype)
-    elif layout == 'NDHWC':
-        A = te.placeholder((batch, in_depth, in_height, in_width, in_channel), name='A', dtype='float32')
+        a_np = np.random.uniform(size=(batch, in_channel, in_depth, in_height, in_width)).astype(
+            dtype
+        )
+    elif layout == "NDHWC":
+        A = te.placeholder(
+            (batch, in_depth, in_height, in_width, in_channel), name="A", dtype="float32"
+        )
         dtype = A.dtype
         out_shape = (batch, out_depth, out_height, out_width, in_channel)
-        a_np = np.random.uniform(size=(batch, in_depth, in_height, in_width, in_channel)).astype(dtype)
+        a_np = np.random.uniform(size=(batch, in_depth, in_height, in_width, in_channel)).astype(
+            dtype
+        )
     else:
-        raise NotImplementedError(
-            'Layout not supported {} '.format(layout))
+        raise NotImplementedError("Layout not supported {} ".format(layout))
 
-    B = topi.image.resize3d(A, (out_depth, out_height, out_width), layout=layout,
-                            coordinate_transformation_mode=coordinate_transformation_mode, method=method)
+    B = topi.image.resize3d(
+        A,
+        (out_depth, out_height, out_width),
+        layout=layout,
+        coordinate_transformation_mode=coordinate_transformation_mode,
+        method=method,
+    )
 
     if method == "trilinear":
-        b_np = tvm.topi.testing.trilinear_resize3d_python(a_np, (out_depth, out_height, out_width), layout,
-                                                      coordinate_transformation_mode)
+        b_np = tvm.topi.testing.trilinear_resize3d_python(
+            a_np, (out_depth, out_height, out_width), layout, coordinate_transformation_mode
+        )
     else:
         scale_d = out_depth / in_depth
         scale_h = out_height / in_height
@@ -128,46 +168,60 @@ def verify_resize3d(batch, in_channel, in_depth, in_height, in_width, out_depth,
 @tvm.testing.uses_gpu
 def test_resize3d():
     # Trilinear
-    verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, 'NCDHW')
+    verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, "NCDHW")
     verify_resize3d(1, 8, 16, 16, 16, 25, 25, 25, "NDHWC")
-    verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, 'NCDHW', "align_corners")
-    verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, 'NDHWC', "align_corners")
-    verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, 'NCDHW', "asymmetric")
-    verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, 'NDHWC', "asymmetric")
+    verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NCDHW", "align_corners")
+    verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NDHWC", "align_corners")
+    verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NCDHW", "asymmetric")
+    verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NDHWC", "asymmetric")
 
     # Nearest neighbor
-    verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, 'NCDHW', method="nearest_neighbor")
-    verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, 'NDHWC', method="nearest_neighbor")
+    verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, "NCDHW", method="nearest_neighbor")
+    verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, "NDHWC", method="nearest_neighbor")
 
 
 @tvm.testing.uses_gpu
 def test_crop_and_resize():
-    def verify_crop_and_resize(image_shape, np_boxes, np_box_indices, np_crop_size, layout='NHWC',
-                               method="bilinear", extrapolation_value=0.0):
-
-        images = te.placeholder(image_shape, name='images', dtype='float32')
+    def verify_crop_and_resize(
+        image_shape,
+        np_boxes,
+        np_box_indices,
+        np_crop_size,
+        layout="NHWC",
+        method="bilinear",
+        extrapolation_value=0.0,
+    ):
+
+        images = te.placeholder(image_shape, name="images", dtype="float32")
         np_images = np.random.uniform(size=image_shape).astype("float32")
         boxes = te.placeholder(np_boxes.shape, name="boxes", dtype="float32")
         box_ind = te.placeholder(np_box_indices.shape, name="box_ind", dtype="int32")
 
         batch = len(np_box_indices)
         target_height, target_width = np_crop_size[0], np_crop_size[1]
-        if layout == 'NHWC':
+        if layout == "NHWC":
             channel = image_shape[3]
             out_shape = (batch, target_height, target_width, channel)
-        elif layout == 'NCHW':
+        elif layout == "NCHW":
             channel = image_shape[1]
             out_shape = (batch, channel, target_height, target_width)
         else:
-            raise NotImplementedError(
-                'Layout {} is not supported.'.format(layout))
-
-        out = topi.image.crop_and_resize(images, boxes, box_ind, np_crop_size, layout=layout,
-                                         method=method, extrapolation_value=extrapolation_value)
+            raise NotImplementedError("Layout {} is not supported.".format(layout))
+
+        out = topi.image.crop_and_resize(
+            images,
+            boxes,
+            box_ind,
+            np_crop_size,
+            layout=layout,
+            method=method,
+            extrapolation_value=extrapolation_value,
+        )
+
+        baseline_np = tvm.topi.testing.crop_and_resize_python(
+            np_images, np_boxes, np_box_indices, np_crop_size, layout, method, extrapolation_value
+        )
 
-        baseline_np = tvm.topi.testing.crop_and_resize_python(np_images, np_boxes, np_box_indices,
-                                                          np_crop_size, layout, method,
-                                                          extrapolation_value)
         def check_device(device, ctx):
             print("Running on target: %s" % device)
             with tvm.target.Target(device):
@@ -184,18 +238,18 @@ def test_crop_and_resize():
         for device, ctx in tvm.testing.enabled_targets():
             check_device(device, ctx)
 
-    boxes_1 = np.array([[.2, .3, .7, .9]], dtype="float32")
-    boxes_2 = np.array([[.2, .3, .7, .9], [0, .1, .8, 1]], dtype="float32")
+    boxes_1 = np.array([[0.2, 0.3, 0.7, 0.9]], dtype="float32")
+    boxes_2 = np.array([[0.2, 0.3, 0.7, 0.9], [0, 0.1, 0.8, 1]], dtype="float32")
     indices_1 = np.array([0], dtype="int32")
     indices_2 = np.array([1, 0], dtype="int32")
     size_1 = (7, 11)
     size_2 = (90, 60)
 
     verify_crop_and_resize((1, 255, 255, 3), boxes_1, indices_1, size_1, layout="NHWC")
-    verify_crop_and_resize((10, 224, 224, 5), boxes_2, indices_2,
-                           size_2, extrapolation_value=0.3, layout="NHWC")
-    verify_crop_and_resize((1, 100, 100, 3), boxes_1, indices_1,
-                           size_1, method='nearest_neighbor')
+    verify_crop_and_resize(
+        (10, 224, 224, 5), boxes_2, indices_2, size_2, extrapolation_value=0.3, layout="NHWC"
+    )
+    verify_crop_and_resize((1, 100, 100, 3), boxes_1, indices_1, size_1, method="nearest_neighbor")
     verify_crop_and_resize((1, 3, 224, 224), boxes_1, indices_1, size_1, layout="NCHW")
 
 
@@ -224,8 +278,7 @@ def test_affine_grid():
             f = tvm.build(s, [data, out], device)
             f(tvm_data, tvm_out)
 
-            tvm.testing.assert_allclose(
-                tvm_out.asnumpy(), out_np, rtol=1e-5, atol=1e-5)
+            tvm.testing.assert_allclose(tvm_out.asnumpy(), out_np, rtol=1e-5, atol=1e-5)
 
         for device, ctx in tvm.testing.enabled_targets():
             check_device(device, ctx)
@@ -240,14 +293,14 @@ def test_grid_sample():
         dtype = "float32"
         data = te.placeholder(data_shape, dtype=dtype)
         grid = te.placeholder(grid_shape, dtype=dtype)
-        out = topi.image.grid_sample(data, grid, 'bilinear')
+        out = topi.image.grid_sample(data, grid, "bilinear")
 
         @memoize("topi.tests.test_grid_sample.verify_grid_sample")
         def get_ref_data():
             data_np = np.random.uniform(size=data_shape).astype(dtype)
             # allow grid values to be out-of-bound
             grid_np = np.random.uniform(size=grid_shape, low=-1.5, high=1.5).astype(dtype)
-            out_np = tvm.topi.testing.grid_sample_nchw_python(data_np, grid_np, 'bilinear')
+            out_np = tvm.topi.testing.grid_sample_nchw_python(data_np, grid_np, "bilinear")
             return data_np, grid_np, out_np
 
         data_np, grid_np, out_np = get_ref_data()
@@ -262,8 +315,7 @@ def test_grid_sample():
             f = tvm.build(s, [data, grid, out], device)
             f(tvm_data, tvm_grid, tvm_out)
 
-            tvm.testing.assert_allclose(
-                tvm_out.asnumpy(), out_np, rtol=1e-5, atol=1e-5)
+            tvm.testing.assert_allclose(tvm_out.asnumpy(), out_np, rtol=1e-5, atol=1e-5)
 
         for device, ctx in tvm.testing.enabled_targets():
             check_device(device, ctx)
index b753ca1..7e3300c 100644
@@ -33,8 +33,9 @@ _lrn_schedule = {
     "nvptx": topi.cuda.schedule_lrn,
 }
 
+
 def verify_lrn(shape, size, axis, bias, alpha, beta):
-    A = te.placeholder(shape, name='A')
+    A = te.placeholder(shape, name="A")
     B = topi.nn.lrn(A, size, axis, alpha, beta, bias)
     dtype = A.dtype
 
@@ -56,14 +57,16 @@ def verify_lrn(shape, size, axis, bias, alpha, beta):
         f(a, b)
         tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
+    for device in ["llvm", "cuda", "opencl", "metal", "rocm", "vulkan", "nvptx"]:
         check_device(device)
 
+
 @tvm.testing.uses_gpu
 def test_lrn():
     verify_lrn((1, 3, 5, 5), 3, 1, 1.0, 1.0, 0.5)
     verify_lrn((1, 3, 5, 5), 3, 3, 1.0, 1.0, 0.5)
     verify_lrn((1, 3, 20, 20), 3, 1, 2.0, 1.0, 0.75)
 
+
 if __name__ == "__main__":
     test_lrn()
index f068c97..48b6d9b 100644
@@ -107,7 +107,7 @@ def test_ewise():
             check_device(target, ctx)
 
     def test_infiniteness_ops(topi_op, ref_op, name):
-        for dtype in ['float32', 'float64', 'int32', 'int16']:
+        for dtype in ["float32", "float64", "int32", "int16"]:
             m = te.var("m")
             l = te.var("l")
             A = te.placeholder((m, l), dtype=dtype, name="A")
@@ -115,9 +115,13 @@ def test_ewise():
             assert tuple(B.shape) == tuple(A.shape)
 
             a_np = np.random.uniform(size=(8, 8)).astype(A.dtype) * 10
-            if dtype.startswith('float'):
-                a_np.ravel()[np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False)] = np.infty
-                a_np.ravel()[np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False)] = np.nan
+            if dtype.startswith("float"):
+                a_np.ravel()[
+                    np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False)
+                ] = np.infty
+                a_np.ravel()[
+                    np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False)
+                ] = np.nan
             b_np = ref_op(a_np)
 
             def check_device(device, ctx):
@@ -144,15 +148,17 @@ def test_ewise():
     test_apply(topi.sigmoid, "sigmoid", lambda x: 1 / (1 + np.exp(-x)), -1, 1)
     test_apply(topi.log, "log", np.log, 0, 100)
     test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
-    test_apply(topi.rsqrt, "rsqrt", lambda x: np.ones_like(x) / np.sqrt(x), 0, 100, skip_name_check=True)
-    test_apply(topi.cos, "cos", np.cos, -2.0*np.pi, 2.0*np.pi)
-    test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float32')
-    test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float64')
-    test_apply(topi.sin, "sin", np.sin, -2.0*np.pi, 2.0*np.pi)
-    test_apply(topi.erf, "erf", scipy.special.erf, -.1, .1, dtype="float32")
+    test_apply(
+        topi.rsqrt, "rsqrt", lambda x: np.ones_like(x) / np.sqrt(x), 0, 100, skip_name_check=True
+    )
+    test_apply(topi.cos, "cos", np.cos, -2.0 * np.pi, 2.0 * np.pi)
+    test_apply(topi.tan, "tan", np.tan, -2.0 * np.pi, 2.0 * np.pi, dtype="float32")
+    test_apply(topi.tan, "tan", np.tan, -2.0 * np.pi, 2.0 * np.pi, dtype="float64")
+    test_apply(topi.sin, "sin", np.sin, -2.0 * np.pi, 2.0 * np.pi)
+    test_apply(topi.erf, "erf", scipy.special.erf, -0.1, 0.1, dtype="float32")
     test_isnan(-100, 100)
-    test_infiniteness_ops(topi.isfinite, np.isfinite, 'isifinite')
-    test_infiniteness_ops(topi.isinf, np.isinf, 'isinf')
+    test_infiniteness_ops(topi.isfinite, np.isfinite, "isfinite")
+    test_infiniteness_ops(topi.isinf, np.isinf, "isinf")
 
 
 @tvm.testing.uses_gpu
@@ -191,15 +197,7 @@ def test_cast():
 
 
 def test_fastmath():
-    def test_apply(
-        func,
-        name,
-        f_numpy,
-        low,
-        high,
-        step,
-        dtype="float32"
-    ):
+    def test_apply(func, name, f_numpy, low, high, step, dtype="float32"):
         a_np = np.arange(low, high, step).astype(dtype)
         b_np = f_numpy(a_np)
         A = te.placeholder(a_np.shape, dtype=dtype, name="A")
@@ -219,16 +217,13 @@ def test_fastmath():
             func(a, b)
             tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
 
-        check_device('llvm')
-        check_device('llvm -device=arm-cpu')
+        check_device("llvm")
+        check_device("llvm -device=arm-cpu")
 
+    test_apply(topi.fast_exp, "fast_exp", np.exp, low=-88, high=88, step=0.01)
+    test_apply(topi.fast_erf, "fast_erf", scipy.special.erf, low=-10, high=10, step=0.01)
+    test_apply(topi.fast_tanh, "fast_tanh", np.tanh, low=-10, high=10, step=0.01)
 
-    test_apply(topi.fast_exp, "fast_exp", np.exp,
-               low=-88, high=88, step=0.01)
-    test_apply(topi.fast_erf, "fast_erf", scipy.special.erf,
-               low=-10, high=10, step=0.01)
-    test_apply(topi.fast_tanh, "fast_tanh", np.tanh,
-               low=-10, high=10, step=0.01)
 
 if __name__ == "__main__":
     test_util()
index 4ffa29e..f3b3396 100644
@@ -20,41 +20,44 @@ from tvm import te
 from tvm import topi
 from tvm.topi.util import get_const_tuple
 
+
 def with_tvm(lam, *args):
-    """ Take numpy arrays as args, convert them to TVM tensors and call `lam`.
+    """Take numpy arrays as args, convert them to TVM tensors and call `lam`.
     Result of lambda is converted back to numpy array and returned.
     """
     ctx = tvm.cpu(0)
-    pls = []     # placeholders
-    vals_nd = [] # initial values
-    for i,arg in enumerate(args):
-        pls.append(te.placeholder(arg.shape, name='pl'+str(i)))
+    pls = []  # placeholders
+    vals_nd = []  # initial values
+    for i, arg in enumerate(args):
+        pls.append(te.placeholder(arg.shape, name="pl" + str(i)))
         vals_nd.append(tvm.nd.array(arg, ctx))
 
     out = lam(*pls)
     out_nd = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out.dtype), ctx)
     s = te.create_schedule([out.op])
     m = tvm.build(s, pls + [out], "llvm")
-    m(*(vals_nd+[out_nd]))
+    m(*(vals_nd + [out_nd]))
     return out_nd.asnumpy()
 
+
 def verify_matmul(sa, sb, transp_a, transp_b):
     a = np.random.uniform(low=-1.0, high=1.0, size=sa).astype(np.float32)
     b = np.random.uniform(low=-1.0, high=1.0, size=sb).astype(np.float32)
-    c1 = np.matmul(np.transpose(a) if transp_a else a,
-                   np.transpose(b) if transp_b else b)
-    c2 = with_tvm(lambda A,B: topi.matmul(A,B,transp_a,transp_b), a,b)
+    c1 = np.matmul(np.transpose(a) if transp_a else a, np.transpose(b) if transp_b else b)
+    c2 = with_tvm(lambda A, B: topi.matmul(A, B, transp_a, transp_b), a, b)
     tvm.testing.assert_allclose(c1, c2, rtol=1e-5, atol=1e-5)
 
+
 def test_matmul():
-    verify_matmul((1,1),(1,1),False,False)
-    verify_matmul((1,1),(1,1),True,True)
-    verify_matmul((2,2),(2,2),False,False)
-    verify_matmul((2,2),(2,2),True,True)
-    verify_matmul((2,3),(3,5),False,False)
-    verify_matmul((5,3),(3,2),False,False)
-    verify_matmul((3,5),(3,2),True,False)
-    verify_matmul((3,5),(2,3),True,True)
+    verify_matmul((1, 1), (1, 1), False, False)
+    verify_matmul((1, 1), (1, 1), True, True)
+    verify_matmul((2, 2), (2, 2), False, False)
+    verify_matmul((2, 2), (2, 2), True, True)
+    verify_matmul((2, 3), (3, 5), False, False)
+    verify_matmul((5, 3), (3, 2), False, False)
+    verify_matmul((3, 5), (3, 2), True, False)
+    verify_matmul((3, 5), (2, 3), True, True)
+
 
 def verify_tensordot(sa, sb, axes):
     a = np.random.uniform(low=-1.0, high=1.0, size=sa).astype(np.float32)
@@ -63,6 +66,7 @@ def verify_tensordot(sa, sb, axes):
     c2 = with_tvm(lambda A, B: topi.tensordot(A, B, axes), a, b)
     tvm.testing.assert_allclose(c1, c2, rtol=1e-5, atol=1e-5)
 
+
 def test_tensordot():
     verify_tensordot((3), (3), 0)
     verify_tensordot((2, 3), (3, 5), 1)
@@ -72,7 +76,7 @@ def test_tensordot():
     verify_tensordot((3, 2, 2), (2, 3, 5), ((1, 0), (0, 1)))
     verify_tensordot((4, 3, 2, 2), (2, 4, 3, 5), ((1, 2, 0), (2, 0, 1)))
 
+
 if __name__ == "__main__":
     test_matmul()
     test_tensordot()
-
index c64624f..2df26dd 100644
@@ -75,7 +75,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
 
         foo = tvm.build(s, [A, B], device, name=type)
         # Test
-        if dtype == 'bool':
+        if dtype == "bool":
             in_npy_map = in_npy = np.random.choice([True, False], size=in_shape)
         else:
             in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype)
@@ -83,7 +83,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
 
         if type == "sum":
             out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims)
-        elif type == "all" and dtype == 'bool':
+        elif type == "all" and dtype == "bool":
             out_npy = in_npy_map.all(axis=axis, keepdims=keepdims)
         elif type == "any" and dtype == "bool":
             out_npy = in_npy_map.any(axis=axis, keepdims=keepdims)
@@ -108,15 +108,16 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
             if axis is None:
                 out_tvm_val = in_npy_map.ravel()[out_tvm_indices]
             else:
-                other_indices = tuple(np.indices(in_shape[0:axis] + in_shape[(axis+1):]))
+                other_indices = tuple(np.indices(in_shape[0:axis] + in_shape[(axis + 1) :]))
                 sel_indices = other_indices[0:axis] + (out_tvm_indices,) + other_indices[axis:]
                 out_tvm_val = in_npy_map[sel_indices]
             if type == "argmax":
-                tvm.testing.assert_allclose(out_tvm_val, in_npy_map.max(axis=axis), 1E-3, 1E-3)
+                tvm.testing.assert_allclose(out_tvm_val, in_npy_map.max(axis=axis), 1e-3, 1e-3)
             elif type == "argmin":
-                tvm.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1E-3, 1E-3)
+                tvm.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1e-3, 1e-3)
         else:
-            tvm.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
+            tvm.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1e-3, 1e-3)
+
     for device, ctx in tvm.testing.enabled_targets():
         check_device(device, ctx)
 
@@ -124,77 +125,31 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
 @tvm.testing.uses_gpu
 def test_reduce_map():
 
-    verify_reduce_map_ele(in_shape=(32,),
-                          axis=0,
-                          keepdims=False,
-                          type="argmax")
-    verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
-                          axis=(1, 2, 3),
-                          keepdims=True,
-                          type="sum")
-    verify_reduce_map_ele(in_shape=(2, 3),
-                          axis=None,
-                          keepdims=True,
-                          type="all",
-                          dtype='bool')
-    verify_reduce_map_ele(in_shape=(128, 24 * 128 * 24),
-                          axis=(1,),
-                          keepdims=False,
-                          type="max")
-    verify_reduce_map_ele(in_shape=(32, 128, 24),
-                          axis=None,
-                          keepdims=True,
-                          type="sum")
-    verify_reduce_map_ele(in_shape=(32, 128, 24),
-                          axis=None,
-                          keepdims=True,
-                          dtype='bool',
-                          type="all")
-    verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
-                          axis=(0, 2),
-                          keepdims=False,
-                          type="min")
-    verify_reduce_map_ele(in_shape=(32, 128),
-                          axis=1,
-                          keepdims=True,
-                          type="argmax")
-    verify_reduce_map_ele(in_shape=(32, 24, 32, 24),
-                          axis=2,
-                          keepdims=False,
-                          type="argmin")
-    verify_reduce_map_ele(in_shape=(31, 21, 15),
-                          axis=None,
-                          keepdims=True,
-                          type="argmax")
-    verify_reduce_map_ele(in_shape=(31, 21, 15),
-                          axis=None,
-                          keepdims=False,
-                          type="sum")
-    verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
-                          axis=(1, 2, 3),
-                          keepdims=True,
-                          type="sum",
-                          dtype="float64")
-    verify_reduce_map_ele(in_shape=(2, 3),
-                          axis=None,
-                          keepdims=True,
-                          type="any",
-                          dtype="bool")
-    verify_reduce_map_ele(in_shape=(32, 128, 24),
-                          axis=None,
-                          keepdims=True,
-                          type="any",
-                          dtype="bool")
-    verify_reduce_map_ele(in_shape=(1, 4, 7),
-                          axis=1,
-                          keepdims=True,
-                          type="any",
-                          dtype="bool")
-    verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
-                          axis=2,
-                          keepdims=False,
-                          type="any",
-                          dtype="bool")
+    verify_reduce_map_ele(in_shape=(32,), axis=0, keepdims=False, type="argmax")
+    verify_reduce_map_ele(in_shape=(128, 24, 128, 24), axis=(1, 2, 3), keepdims=True, type="sum")
+    verify_reduce_map_ele(in_shape=(2, 3), axis=None, keepdims=True, type="all", dtype="bool")
+    verify_reduce_map_ele(in_shape=(128, 24 * 128 * 24), axis=(1,), keepdims=False, type="max")
+    verify_reduce_map_ele(in_shape=(32, 128, 24), axis=None, keepdims=True, type="sum")
+    verify_reduce_map_ele(
+        in_shape=(32, 128, 24), axis=None, keepdims=True, dtype="bool", type="all"
+    )
+    verify_reduce_map_ele(in_shape=(128, 24, 128, 24), axis=(0, 2), keepdims=False, type="min")
+    verify_reduce_map_ele(in_shape=(32, 128), axis=1, keepdims=True, type="argmax")
+    verify_reduce_map_ele(in_shape=(32, 24, 32, 24), axis=2, keepdims=False, type="argmin")
+    verify_reduce_map_ele(in_shape=(31, 21, 15), axis=None, keepdims=True, type="argmax")
+    verify_reduce_map_ele(in_shape=(31, 21, 15), axis=None, keepdims=False, type="sum")
+    verify_reduce_map_ele(
+        in_shape=(128, 24, 128, 24), axis=(1, 2, 3), keepdims=True, type="sum", dtype="float64"
+    )
+    verify_reduce_map_ele(in_shape=(2, 3), axis=None, keepdims=True, type="any", dtype="bool")
+    verify_reduce_map_ele(
+        in_shape=(32, 128, 24), axis=None, keepdims=True, type="any", dtype="bool"
+    )
+    verify_reduce_map_ele(in_shape=(1, 4, 7), axis=1, keepdims=True, type="any", dtype="bool")
+    verify_reduce_map_ele(
+        in_shape=(128, 24, 128, 24), axis=2, keepdims=False, type="any", dtype="bool"
+    )
+
 
 if __name__ == "__main__":
     test_reduce_map()
index 21e06b5..aa68f23 100644
@@ -26,8 +26,9 @@ from tvm.contrib.nvcc import have_fp16
 
 import tvm.testing
 
+
 def verify_relu(m, n, dtype="float32"):
-    A = te.placeholder((m, n), name='A', dtype=dtype)
+    A = te.placeholder((m, n), name="A", dtype=dtype)
     B = topi.nn.relu(A)
 
     a_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(A.shape)).astype(A.dtype)
@@ -52,7 +53,7 @@ def verify_relu(m, n, dtype="float32"):
 
 
 def verify_leaky_relu(m, alpha):
-    A = te.placeholder((m,), name='A')
+    A = te.placeholder((m,), name="A")
     B = topi.nn.leaky_relu(A, alpha)
     s = te.create_schedule([B.op])
 
@@ -67,13 +68,13 @@ def verify_leaky_relu(m, alpha):
 
 
 def verify_prelu(x, w, axis, weight_reshape):
-    X = te.placeholder((x), name='X')
-    W = te.placeholder((w), name='W')
+    X = te.placeholder((x), name="X")
+    W = te.placeholder((w), name="W")
     x_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(X.shape)).astype(X.dtype)
     w_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(W.shape)).astype(W.dtype)
 
     def _prelu_numpy(x, W):
-        return (x < 0) * (x *W.reshape(weight_reshape)) + (x>=0) * x
+        return (x < 0) * (x * W.reshape(weight_reshape)) + (x >= 0) * x
 
     B = topi.nn.prelu(X, W, axis)
     s = te.create_schedule([B.op])
@@ -88,22 +89,27 @@ def verify_prelu(x, w, axis, weight_reshape):
     out_np = _prelu_numpy(x_np, w_np)
     tvm.testing.assert_allclose(b.asnumpy(), out_np, rtol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_relu():
     verify_relu(10, 128, "float32")
     verify_relu(128, 64, "float16")
 
+
 @tvm.testing.uses_gpu
 def test_schedule_big_array():
-    verify_relu(1024 * 100 , 512)
+    verify_relu(1024 * 100, 512)
+
 
 def test_leaky_relu():
     verify_leaky_relu(100, 0.1)
 
+
 def test_prelu():
     verify_prelu((1, 3, 2, 2), (3,), 1, (3, 1, 1))
     verify_prelu((1, 3, 2, 2), (2,), 2, (2, 1))
-    verify_prelu((1, 3), (3,), 1, (3, ))
+    verify_prelu((1, 3), (3,), 1, (3,))
+
 
 if __name__ == "__main__":
     test_schedule_big_array()
index 95e028d..37a0eb9 100644
@@ -28,11 +28,12 @@ _reorg_schedule = {
     "gpu": topi.cuda.schedule_reorg,
 }
 
+
 def verify_reorg(batch, in_size, in_channel, stride):
-    '''Verify reorg operator by comparing outputs from tvm and numpy implementation'''
+    """Verify reorg operator by comparing outputs from tvm and numpy implementation"""
     in_height = in_width = in_size
 
-    A = te.placeholder((batch, in_channel, in_height, in_width), name='A')
+    A = te.placeholder((batch, in_channel, in_height, in_width), name="A")
     B = topi.vision.reorg(A, stride)
 
     a_shape = get_const_tuple(A.shape)
@@ -46,7 +47,7 @@ def verify_reorg(batch, in_size, in_channel, stride):
     a_np, b_np = get_ref_data_reorg()
 
     def check_device(device):
-        '''Cheching devices is enabled or not'''
+        """Cheching devices is enabled or not"""
         ctx = tvm.context(device, 0)
         if not tvm.testing.device_enabled(device):
             print("Skip because %s is not enabled" % device)
@@ -61,12 +62,14 @@ def verify_reorg(batch, in_size, in_channel, stride):
         func(a, b)
         tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ['llvm', 'cuda']:
+    for device in ["llvm", "cuda"]:
         check_device(device)
 
+
 @tvm.testing.uses_gpu
 def test_reorg():
     verify_reorg(1, 20, 8, 2)
 
+
 if __name__ == "__main__":
     test_reorg()
index 5107b64..df4c4ea 100644
@@ -33,6 +33,7 @@ _softmax_schedule = {
     "hls": topi.hls.schedule_softmax,
 }
 
+
 def check_device(A, B, a_np, b_np, device, ctx, name):
     print("Running on target: %s" % device)
     with tvm.target.Target(device):
@@ -45,8 +46,9 @@ def check_device(A, B, a_np, b_np, device, ctx, name):
     f(a, b)
     tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
+
 def verify_softmax(m, n, dtype="float32"):
-    A = te.placeholder((m, n), dtype=dtype, name='A')
+    A = te.placeholder((m, n), dtype=dtype, name="A")
     B = topi.nn.softmax(A)
     # confirm lower works
     s = te.create_schedule([B.op])
@@ -58,18 +60,20 @@ def verify_softmax(m, n, dtype="float32"):
     for device, ctx in tvm.testing.enabled_targets():
         check_device(A, B, a_np, b_np, device, ctx, "softmax")
 
+
 def verify_softmax_4d(shape, dtype="float32"):
-    A = te.placeholder(shape, dtype=dtype, name='A')
+    A = te.placeholder(shape, dtype=dtype, name="A")
     B = topi.nn.softmax(A, axis=1)
 
     _, c, h, w = shape
     a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
-    b_np = tvm.topi.testing.softmax_python(a_np.transpose(0, 2, 3, 1).reshape(h*w, c))
+    b_np = tvm.topi.testing.softmax_python(a_np.transpose(0, 2, 3, 1).reshape(h * w, c))
     b_np = b_np.reshape(1, h, w, c).transpose(0, 3, 1, 2)
 
     for device, ctx in tvm.testing.enabled_targets():
         check_device(A, B, a_np, b_np, device, ctx, "softmax")
 
+
 @tvm.testing.uses_gpu
 def test_softmax():
     verify_softmax(32, 10)
@@ -77,8 +81,9 @@ def test_softmax():
     verify_softmax(32, 10, "float64")
     verify_softmax_4d((1, 16, 256, 256))
 
+
 def verify_log_softmax(m, n, dtype="float32"):
-    A = te.placeholder((m, n), dtype=dtype, name='A')
+    A = te.placeholder((m, n), dtype=dtype, name="A")
     B = topi.nn.log_softmax(A)
     # confirm lower works
     s = te.create_schedule([B.op])
@@ -96,6 +101,7 @@ def test_log_softmax():
     verify_log_softmax(3, 4)
     verify_log_softmax(32, 10, "float64")
 
+
 if __name__ == "__main__":
     logging.basicConfig(level=logging.DEBUG)
     test_softmax()
index e33531f..7e0c982 100644
@@ -33,6 +33,7 @@ _topk_implement = {
     "gpu": (topi.cuda.topk, topi.cuda.schedule_topk),
 }
 
+
 def verify_argsort(axis, is_ascend):
     dshape = (20, 100)
     data_dtype = "float32"
@@ -48,9 +49,9 @@ def verify_argsort(axis, is_ascend):
         np_indices = np.argsort(-np_data, axis=axis)
 
     if axis == 0:
-        np_indices = np_indices[:dshape[axis], :]
+        np_indices = np_indices[: dshape[axis], :]
     else:
-        np_indices = np_indices[:, :dshape[axis]]
+        np_indices = np_indices[:, : dshape[axis]]
 
     def check_device(device):
         if not tvm.testing.device_enabled(device):
@@ -69,7 +70,7 @@ def verify_argsort(axis, is_ascend):
         f(tvm_data, tvm_out)
         tvm.testing.assert_allclose(tvm_out.asnumpy(), np_indices.astype(data_dtype), rtol=1e0)
 
-    for device in ['llvm', 'cuda', 'opencl']:
+    for device in ["llvm", "cuda", "opencl"]:
         check_device(device)
 
 
@@ -121,7 +122,7 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype):
         else:
             tvm.testing.assert_allclose(tvm_res[0].asnumpy(), np_indices)
 
-    for device in ['llvm', 'cuda', 'opencl']:
+    for device in ["llvm", "cuda", "opencl"]:
         check_device(device)
 
 
index 504c359..7cdaa68 100644
@@ -22,29 +22,29 @@ from tvm import topi
 import tvm.topi.testing
 
 
-def verify_space_to_depth(block_size, batch, in_channel, in_height, in_width, layout='NCHW'):
+def verify_space_to_depth(block_size, batch, in_channel, in_height, in_width, layout="NCHW"):
     out_channel = int(in_channel * (block_size * block_size))
     out_height = int(in_height / block_size)
     out_width = int(in_width / block_size)
 
-    if layout == 'NCHW':
+    if layout == "NCHW":
         in_shape = [batch, in_channel, in_height, in_width]
         out_shape = [batch, out_channel, out_height, out_width]
-    elif layout == 'NHWC':
+    elif layout == "NHWC":
         in_shape = [batch, in_height, in_width, in_channel]
         out_shape = [batch, out_height, out_width, out_channel]
     else:
-        raise NotImplementedError('Layout not supported {}'.format(layout))
+        raise NotImplementedError("Layout not supported {}".format(layout))
 
-    A = te.placeholder(in_shape, name='A', dtype='float32')
+    A = te.placeholder(in_shape, name="A", dtype="float32")
     dtype = A.dtype
     a_np = np.random.uniform(size=in_shape).astype(dtype)
 
     B = topi.nn.space_to_depth(A, block_size=block_size, layout=layout)
-    if layout == 'NHWC':
+    if layout == "NHWC":
         a_np = np.transpose(a_np, axes=[0, 3, 1, 2])
     b_np = tvm.topi.testing.space_to_depth_python(a_np, block_size)
-    if layout == 'NHWC':
+    if layout == "NHWC":
         a_np = np.transpose(a_np, axes=[0, 2, 3, 1])
         b_np = np.transpose(b_np, axes=[0, 2, 3, 1])
 
@@ -64,7 +64,7 @@ def verify_space_to_depth(block_size, batch, in_channel, in_height, in_width, la
 
 @tvm.testing.uses_gpu
 def test_space_to_depth():
-    for layout in ['NCHW', 'NHWC']:
+    for layout in ["NCHW", "NHWC"]:
         # Simplest possible case
         verify_space_to_depth(2, 1, 1, 2, 2, layout=layout)
         # Average input size
index dbab292..b50110a 100644
@@ -30,29 +30,31 @@ import tvm.testing
 _sparse_dense_implement = {
     "generic": (topi.nn.sparse_dense, topi.generic.schedule_sparse_dense),
     "cuda": (topi.cuda.sparse_dense, topi.cuda.schedule_sparse_dense),
-    "x86": (topi.nn.sparse_dense, topi.x86.schedule_sparse_dense)
+    "x86": (topi.nn.sparse_dense, topi.x86.schedule_sparse_dense),
 }
 
+
 def verify_dynamic_csrmv(batch, in_dim, out_dim, use_bias=True):
     nr, nc, n = te.var("nr"), te.var("nc"), te.var("n")
-    dtype = 'float32'
-    A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, dtype=dtype, name='A')
-    B = te.placeholder((in_dim, 1), name='B')
-    C = te.placeholder((nr,), name='C')
+    dtype = "float32"
+    A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, dtype=dtype, name="A")
+    B = te.placeholder((in_dim, 1), name="B")
+    C = te.placeholder((nr,), name="C")
     D = topi.sparse.csrmv(A, B, C if use_bias else None)
     s = te.create_schedule(D.op)
     dtype = A.dtype
 
     # get the test data
     def get_ref_data():
-        a_np = np.maximum(np.random.uniform(size=(batch, in_dim)).astype(dtype)-0.5, 0.)
-        b_np = np.random.uniform(size=(in_dim, 1)).astype(dtype)-0.5
-        c_np = np.random.uniform(size=(batch, )).astype(dtype)
+        a_np = np.maximum(np.random.uniform(size=(batch, in_dim)).astype(dtype) - 0.5, 0.0)
+        b_np = np.random.uniform(size=(in_dim, 1)).astype(dtype) - 0.5
+        c_np = np.random.uniform(size=(batch,)).astype(dtype)
         if use_bias:
             d_np = np.dot(a_np, b_np) + c_np.reshape((batch, 1))
         else:
             d_np = np.dot(a_np, b_np)
         return (a_np, b_np, c_np, d_np)
+
     a_np, b_np, c_np, d_np = get_ref_data()
 
     def check_device(device):
@@ -63,7 +65,7 @@ def verify_dynamic_csrmv(batch, in_dim, out_dim, use_bias=True):
         print("Running on target: %s" % device)
         a = tvmsp.array(a_np, ctx)
         _nr, _nc, _n = a.shape[0], a.shape[1], a.data.shape[0]
-        assert a.shape[0] == a.indptr.shape[0]-1
+        assert a.shape[0] == a.indptr.shape[0] - 1
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(c_np, ctx)
         d = tvm.nd.array(np.zeros((_nr, 1), dtype=dtype), ctx)
@@ -77,26 +79,28 @@ def verify_dynamic_csrmv(batch, in_dim, out_dim, use_bias=True):
     for device in ["llvm"]:
         check_device(device)
 
+
 def verify_dynamic_csrmm(batch, in_dim, out_dim, use_bias=True):
     nr, nc, n = te.var("nr"), te.var("nc"), te.var("n")
-    dtype = 'float32'
-    A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, dtype=dtype, name='A')
-    B = te.placeholder((in_dim, out_dim), name='B')
-    C = te.placeholder((nr,), name='C')
+    dtype = "float32"
+    A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, dtype=dtype, name="A")
+    B = te.placeholder((in_dim, out_dim), name="B")
+    C = te.placeholder((nr,), name="C")
     D = topi.sparse.csrmm(A, B, C if use_bias else None)
     s = te.create_schedule(D.op)
     dtype = A.dtype
 
     # get the test data
     def get_ref_data():
-        a_np = np.maximum(np.random.uniform(size=(batch, in_dim)).astype(dtype)-0.5, 0.)
-        b_np = np.random.uniform(size=(in_dim, out_dim)).astype(dtype)-0.5
-        c_np = np.random.uniform(size=(batch, )).astype(dtype)
+        a_np = np.maximum(np.random.uniform(size=(batch, in_dim)).astype(dtype) - 0.5, 0.0)
+        b_np = np.random.uniform(size=(in_dim, out_dim)).astype(dtype) - 0.5
+        c_np = np.random.uniform(size=(batch,)).astype(dtype)
         if use_bias:
             d_np = np.dot(a_np, b_np) + c_np.reshape((batch, 1))
         else:
             d_np = np.dot(a_np, b_np)
         return (a_np, b_np, c_np, d_np)
+
     a_np, b_np, c_np, d_np = get_ref_data()
 
     def check_device(device):
@@ -107,7 +111,7 @@ def verify_dynamic_csrmm(batch, in_dim, out_dim, use_bias=True):
         print("Running on target: %s" % device)
         a = tvmsp.array(a_np, ctx)
         _nr, _nc, _n = a.shape[0], a.shape[1], a.data.shape[0]
-        assert a.shape[0] == a.indptr.shape[0]-1
+        assert a.shape[0] == a.indptr.shape[0] - 1
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(c_np, ctx)
         d = tvm.nd.array(np.zeros((_nr, out_dim), dtype=dtype), ctx)
@@ -119,25 +123,31 @@ def verify_dynamic_csrmm(batch, in_dim, out_dim, use_bias=True):
     for device in ["llvm"]:
         check_device(device)
 
-def verify_dense_si(batch, in_dim, out_dim, use_bias=True, dtype='float32'):
-    nonzeros = te.var('nonzeros')
-    A = tvmsp.placeholder(shape=(batch, in_dim), nonzeros=nonzeros, dtype=dtype, name='A')
-    B = te.placeholder((out_dim, in_dim), dtype=dtype, name='B')
-    C = te.placeholder((out_dim,), dtype=dtype, name='C')
+
+def verify_dense_si(batch, in_dim, out_dim, use_bias=True, dtype="float32"):
+    nonzeros = te.var("nonzeros")
+    A = tvmsp.placeholder(shape=(batch, in_dim), nonzeros=nonzeros, dtype=dtype, name="A")
+    B = te.placeholder((out_dim, in_dim), dtype=dtype, name="B")
+    C = te.placeholder((out_dim,), dtype=dtype, name="C")
     D = topi.sparse.dense(A, B, C if use_bias else None)
     s = te.create_schedule(D.op)
 
     # get the test data
     def get_ref_data():
-        mag = 10.
-        a_np = np.maximum(mag*(np.random.uniform(size=(batch, in_dim)).astype('float32')-0.5), 0.).astype(dtype)
-        b_np = (mag*(np.random.uniform(size=(out_dim, in_dim)).astype('float32')-.5)).astype(dtype)
-        c_np = (mag*(np.random.uniform(size=(out_dim,)).astype('float32')-.5)).astype(dtype)
+        mag = 10.0
+        a_np = np.maximum(
+            mag * (np.random.uniform(size=(batch, in_dim)).astype("float32") - 0.5), 0.0
+        ).astype(dtype)
+        b_np = (mag * (np.random.uniform(size=(out_dim, in_dim)).astype("float32") - 0.5)).astype(
+            dtype
+        )
+        c_np = (mag * (np.random.uniform(size=(out_dim,)).astype("float32") - 0.5)).astype(dtype)
         if use_bias:
             d_np = np.dot(a_np, b_np.T) + c_np
         else:
             d_np = np.dot(a_np, b_np.T)
         return (a_np, b_np, c_np, d_np)
+
     a_np, b_np, c_np, d_np = get_ref_data()
 
     def check_device(device):
@@ -154,27 +164,33 @@ def verify_dense_si(batch, in_dim, out_dim, use_bias=True, dtype='float32'):
         f(a.data, a.indices, a.indptr, b, c, d)
         tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-4, atol=1e-4)
 
-    check_device('llvm')
+    check_device("llvm")
 
-def verify_dense_sw(batch, in_dim, out_dim, use_bias=True, dtype='float32'):
-    nonzeros = te.var('nonzeros')
-    A = te.placeholder((batch, in_dim), dtype=dtype, name='A')
-    B = tvmsp.placeholder(shape=(out_dim, in_dim), nonzeros=nonzeros, dtype=dtype, name='B')
-    C = te.placeholder((out_dim,), dtype=dtype, name='C')
+
+def verify_dense_sw(batch, in_dim, out_dim, use_bias=True, dtype="float32"):
+    nonzeros = te.var("nonzeros")
+    A = te.placeholder((batch, in_dim), dtype=dtype, name="A")
+    B = tvmsp.placeholder(shape=(out_dim, in_dim), nonzeros=nonzeros, dtype=dtype, name="B")
+    C = te.placeholder((out_dim,), dtype=dtype, name="C")
     D = topi.sparse.dense(A, B, C if use_bias else None)
     s = te.create_schedule(D.op)
 
     # get the test data
     def get_ref_data():
-        mag = 10.
-        a_np = (mag*(np.random.uniform(size=(batch, in_dim)).astype('float32')-.5)).astype(dtype)
-        b_np = np.maximum(mag*(np.random.uniform(size=(out_dim, in_dim)).astype('float32')-0.5), 0.).astype(dtype)
-        c_np = (mag*(np.random.uniform(size=(out_dim,)).astype('float32')-.5)).astype(dtype)
+        mag = 10.0
+        a_np = (mag * (np.random.uniform(size=(batch, in_dim)).astype("float32") - 0.5)).astype(
+            dtype
+        )
+        b_np = np.maximum(
+            mag * (np.random.uniform(size=(out_dim, in_dim)).astype("float32") - 0.5), 0.0
+        ).astype(dtype)
+        c_np = (mag * (np.random.uniform(size=(out_dim,)).astype("float32") - 0.5)).astype(dtype)
         if use_bias:
             d_np = np.dot(a_np, b_np.T) + c_np
         else:
             d_np = np.dot(a_np, b_np.T)
         return (a_np, b_np, c_np, d_np)
+
     a_np, b_np, c_np, d_np = get_ref_data()
 
     def check_device(device):
@@ -191,34 +207,39 @@ def verify_dense_sw(batch, in_dim, out_dim, use_bias=True, dtype='float32'):
         f(a, b.data, b.indices, b.indptr, c, d)
         tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-4, atol=1e-4)
 
-    check_device('llvm')
+    check_device("llvm")
+
 
 def test_csrmv():
     verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, use_bias=False)
     verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, use_bias=True)
 
+
 def test_csrmm():
     M, K, N = 5, 7, 2
     verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, use_bias=False)
     verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, use_bias=True)
 
+
 def test_dense_si():
     M, K, N = 3, 5, 2
-    verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype='float32')
-    verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype='float32')
-    verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype='int32')
-    verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype='int32')
-    verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype='int16')
-    verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype='int16')
+    verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype="float32")
+    verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype="float32")
+    verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype="int32")
+    verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype="int32")
+    verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype="int16")
+    verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype="int16")
+
 
 def test_dense_sw():
     M, K, N = 3, 5, 2
-    verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype='float32')
-    verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype='float32')
-    verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype='int32')
-    verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype='int32')
-    verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype='int16')
-    verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype='int16')
+    verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype="float32")
+    verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype="float32")
+    verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype="int32")
+    verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype="int32")
+    verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype="int16")
+    verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype="int16")
+
 
 def test_dense():
     test_dense_si()
@@ -228,7 +249,7 @@ def test_dense():
 def test_sparse_dense_csr():
     M, N, K, density = 1, 17, 47, 0.2
     X_np = np.random.randn(M, K).astype("float32")
-    W_sp_np = sp.random(N, K, density=density, format='csr', dtype="float32")
+    W_sp_np = sp.random(N, K, density=density, format="csr", dtype="float32")
     W_np = W_sp_np.todense()
     Y_np = X_np.dot(W_np.T)
 
@@ -240,13 +261,20 @@ def test_sparse_dense_csr():
     s = te.create_schedule(Y.op)
     func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
     Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
-    func(tvm.nd.array(X_np), tvm.nd.array(W_sp_np.data), tvm.nd.array(W_sp_np.indices), tvm.nd.array(W_sp_np.indptr), Y_tvm)
+    func(
+        tvm.nd.array(X_np),
+        tvm.nd.array(W_sp_np.data),
+        tvm.nd.array(W_sp_np.indices),
+        tvm.nd.array(W_sp_np.indptr),
+        Y_tvm,
+    )
     tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
 
+
 def test_sparse_transpose_csr():
     N, density = 1023, 0.3
 
-    X_sp = sp.random(N, N, density=density, format='csr', dtype='float32')
+    X_sp = sp.random(N, N, density=density, format="csr", dtype="float32")
 
     X_sp_T = X_sp.transpose()
     X_np_T = X_sp_T.todense()
@@ -259,19 +287,28 @@ def test_sparse_transpose_csr():
     s = te.create_schedule([X_T_data.op, X_T_indices.op, X_T_indptr.op])
     func = tvm.build(s, [X_data, X_indices, X_indptr, X_T_data, X_T_indices, X_T_indptr])
 
-
     X_T_data_tvm = tvm.nd.array(np.zeros(X_sp_T.data.shape, dtype=X_sp_T.data.dtype))
     X_T_indices_tvm = tvm.nd.array(np.zeros(X_sp_T.indices.shape, dtype=X_sp_T.indices.dtype))
     X_T_indptr_tvm = tvm.nd.array(np.zeros(X_sp_T.indptr.shape, dtype=X_sp_T.indptr.dtype))
 
-    func(tvm.nd.array(X_sp.data), tvm.nd.array(X_sp.indices), tvm.nd.array(X_sp.indptr),
-        X_T_data_tvm,  X_T_indices_tvm, X_T_indptr_tvm)
-
-    X_T_out = sp.csr_matrix((X_T_data_tvm.asnumpy(), X_T_indices_tvm.asnumpy(), X_T_indptr_tvm.asnumpy()), shape=(N,N)).todense()
+    func(
+        tvm.nd.array(X_sp.data),
+        tvm.nd.array(X_sp.indices),
+        tvm.nd.array(X_sp.indptr),
+        X_T_data_tvm,
+        X_T_indices_tvm,
+        X_T_indptr_tvm,
+    )
+
+    X_T_out = sp.csr_matrix(
+        (X_T_data_tvm.asnumpy(), X_T_indices_tvm.asnumpy(), X_T_indptr_tvm.asnumpy()), shape=(N, N)
+    ).todense()
     tvm.testing.assert_allclose(X_np_T, X_T_out, atol=1e-4, rtol=1e-4)
 
+
 def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype):
     import itertools
+
     Y = np.zeros((M, N), dtype=dtype)
     assert M % BS_R == 0
     assert N % BS_C == 0
@@ -279,16 +316,19 @@ def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype):
     num_blocks = int(nnz / (BS_R * BS_C)) + 1
     candidate_blocks = np.asarray(list(itertools.product(range(0, M, BS_R), range(0, N, BS_C))))
     assert candidate_blocks.shape[0] == M // BS_R * N // BS_C
-    chosen_blocks = candidate_blocks[np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False)]
+    chosen_blocks = candidate_blocks[
+        np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False)
+    ]
     for i in range(len(chosen_blocks)):
         r, c = chosen_blocks[i]
-        Y[r:r + BS_R, c:c + BS_C] = np.random.randn(BS_R, BS_C)
+        Y[r : r + BS_R, c : c + BS_C] = np.random.randn(BS_R, BS_C)
     s = sp.bsr_matrix(Y, blocksize=(BS_R, BS_C))
     assert s.data.shape == (num_blocks, BS_R, BS_C)
-    assert s.indices.shape == (num_blocks, )
-    assert s.indptr.shape == (M // BS_R + 1, )
+    assert s.indices.shape == (num_blocks,)
+    assert s.indptr.shape == (M // BS_R + 1,)
     return s
 
+
 def verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu):
     X_np = np.random.randn(M, K).astype("float32")
     W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32")
@@ -316,22 +356,26 @@ def verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu):
             s = fschedule([Y])
             func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
             Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx)
-            func(tvm.nd.array(X_np, ctx=ctx),
-                 tvm.nd.array(W_sp_np.data, ctx=ctx),
-                 tvm.nd.array(W_sp_np.indices, ctx=ctx),
-                 tvm.nd.array(W_sp_np.indptr, ctx=ctx),
-                 Y_tvm)
+            func(
+                tvm.nd.array(X_np, ctx=ctx),
+                tvm.nd.array(W_sp_np.data, ctx=ctx),
+                tvm.nd.array(W_sp_np.indices, ctx=ctx),
+                tvm.nd.array(W_sp_np.indptr, ctx=ctx),
+                Y_tvm,
+            )
             tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
 
-    for device in ['llvm', 'cuda']:
+    for device in ["llvm", "cuda"]:
         check_device(device)
 
+
 @tvm.testing.uses_gpu
 def test_sparse_dense_bsr():
     M, N, K, BS_R, BS_C, density = 1, 64, 128, 8, 16, 0.9
     verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu=True)
     verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu=False)
 
+
 @tvm.testing.uses_gpu
 def test_sparse_dense_bsr_randomized():
     for _ in range(20):
@@ -364,14 +408,16 @@ def test_sparse_dense_bsr_randomized():
                 s = fschedule([Y])
                 func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
                 Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx)
-                func(tvm.nd.array(X_np, ctx=ctx),
-                     tvm.nd.array(W_sp_np.data, ctx=ctx),
-                     tvm.nd.array(W_sp_np.indices, ctx=ctx),
-                     tvm.nd.array(W_sp_np.indptr, ctx=ctx),
-                     Y_tvm)
+                func(
+                    tvm.nd.array(X_np, ctx=ctx),
+                    tvm.nd.array(W_sp_np.data, ctx=ctx),
+                    tvm.nd.array(W_sp_np.indices, ctx=ctx),
+                    tvm.nd.array(W_sp_np.indptr, ctx=ctx),
+                    Y_tvm,
+                )
                 tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5)
 
-        for device in ['llvm', 'cuda']:
+        for device in ["llvm", "cuda"]:
             check_device(device)
 
 
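Note on the sparse tests above: the CSR kernels are checked against a plain NumPy dense reference. A minimal sketch of that reference, assuming SciPy's csr_matrix uses the same data/indices/indptr layout as tvmsp.array:

import numpy as np
import scipy.sparse as sp

batch, in_dim, out_dim = 5, 7, 2
# Sparsify the input the same way get_ref_data() does: threshold at zero.
a_np = np.maximum(np.random.uniform(size=(batch, in_dim)).astype("float32") - 0.5, 0.0)
b_np = np.random.uniform(size=(in_dim, out_dim)).astype("float32") - 0.5
c_np = np.random.uniform(size=(batch,)).astype("float32")

a_csr = sp.csr_matrix(a_np)  # data / indices / indptr, as the kernel consumes them
d_ref = a_csr.dot(b_np) + c_np.reshape((batch, 1))  # dense reference the test asserts against
np.testing.assert_allclose(d_ref, np.dot(a_np, b_np) + c_np.reshape((batch, 1)), rtol=1e-5, atol=1e-6)
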
index 7052a7e..d384767 100644
@@ -24,21 +24,21 @@ from tvm.contrib.pickle_memoize import memoize
 from tvm.contrib.nvcc import have_fp16
 import tvm.testing
 
+
 def verify_elemwise_sum(num_args, dtype):
-    shape = (3,5,4)
+    shape = (3, 5, 4)
 
     tvm_placeholders = []
     for i in range(num_args):
-        tvm_placeholders.append(
-            te.placeholder(shape, name="data"+str(i), dtype=dtype))
+        tvm_placeholders.append(te.placeholder(shape, name="data" + str(i), dtype=dtype))
     esum = topi.elemwise_sum(tvm_placeholders)
     s = te.create_schedule([esum.op])
 
     @memoize("topi.tests.test_topi_elemwise_sum")
     def get_ref_data():
-        np_nd = [np.random.uniform(0, 10, size=shape).astype(dtype)
-                 for i in range(num_args)]
+        np_nd = [np.random.uniform(0, 10, size=shape).astype(dtype) for i in range(num_args)]
         return np_nd
+
     np_nd = get_ref_data()
 
     def check_device(device):
@@ -68,6 +68,7 @@ def verify_full(shape, dtype, fill_value):
     @memoize("topi.tests.test_topi_full")
     def get_ref_data():
         return np.full(shape, fill_value, dtype)
+
     np_nd = get_ref_data()
 
     def check_device(device):
@@ -88,6 +89,7 @@ def verify_full(shape, dtype, fill_value):
     for device in ["llvm"]:
         check_device(device)
 
+
 def verify_vectorization(n, m, dtype):
     def check_device(device):
         if not tvm.testing.device_enabled(device):
@@ -98,14 +100,12 @@ def verify_vectorization(n, m, dtype):
             return
         with tvm.target.Target(device):
             ctx = tvm.context(device, 0)
-            A = te.placeholder((n, m), name='A', dtype=dtype)
-            B = te.compute((n, m), lambda i, j:
-                             A[i, j] + tvm.tir.const(1, A.dtype), name='B')
+            A = te.placeholder((n, m), name="A", dtype=dtype)
+            B = te.compute((n, m), lambda i, j: A[i, j] + tvm.tir.const(1, A.dtype), name="B")
             S = tvm.topi.testing.get_elemwise_schedule(device)(B)
 
             fun = tvm.build(S, [A, B], device)
-            np_A = tvm.nd.empty((n, m), A.dtype, ctx).copyfrom(
-                                np.random.uniform(size=(n, m)))
+            np_A = tvm.nd.empty((n, m), A.dtype, ctx).copyfrom(np.random.uniform(size=(n, m)))
             np_B = tvm.nd.empty((n, m), B.dtype, ctx)
             fun(np_A, np_B)
             tvm.testing.assert_allclose(np_B.asnumpy(), np_A.asnumpy() + 1, rtol=1e-5)
@@ -113,20 +113,24 @@ def verify_vectorization(n, m, dtype):
     for device in ["cuda"]:
         check_device(device)
 
+
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
 def test_vectorization():
     verify_vectorization(128, 64, "float16")
 
+
 def test_elemwise_sum():
     verify_elemwise_sum(1, "float32")
     verify_elemwise_sum(5, "float32")
     verify_elemwise_sum(4, "int32")
 
+
 def test_full():
-    verify_full((3,4,5), "float32", 3.14)
+    verify_full((3, 4, 5), "float32", 3.14)
     verify_full((10,), "int32", 7)
 
+
 if __name__ == "__main__":
     test_elemwise_sum()
     test_full()
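
Note on the tensor tests above: the expected outputs are simple NumPy constructions. A minimal sketch, assuming the elemwise_sum reference is a plain element-wise sum of the generated arrays (the actual reference lives outside the hunks shown here):

import numpy as np

shape, num_args, dtype = (3, 5, 4), 5, "float32"
np_nd = [np.random.uniform(0, 10, size=shape).astype(dtype) for _ in range(num_args)]
esum_ref = np.sum(np.stack(np_nd), axis=0)            # reference for topi.elemwise_sum
full_ref = np.full((3, 4, 5), 3.14, dtype="float32")  # reference for verify_full((3, 4, 5), ...)
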
index fc6f19f..a32d41a 100644
@@ -25,9 +25,11 @@ from tvm.contrib.nvcc import have_fp16
 
 import tvm.testing
 
+
 def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
     A = te.placeholder(shape=in_shape, name="A")
     B = topi.expand_dims(A, axis, num_newaxis)
+
     def check_device(device, ctx):
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
@@ -47,8 +49,9 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
 def verify_reinterpret(in_shape, in_dtype, out_dtype, generator):
     A = te.placeholder(shape=in_shape, name="A", dtype=in_dtype)
     B = topi.reinterpret(A, out_dtype)
+
     def check_device(device, ctx):
-        if in_dtype == "float16" and device == 'cuda' and not have_fp16(ctx.compute_version):
+        if in_dtype == "float16" and device == "cuda" and not have_fp16(ctx.compute_version):
             print("Skip because %s does not have fp16 support" % device)
             return
         print("Running on target: %s" % device)
@@ -69,6 +72,7 @@ def verify_reinterpret(in_shape, in_dtype, out_dtype, generator):
 def verify_transpose(in_shape, axes):
     A = te.placeholder(shape=in_shape, name="A")
     B = topi.transpose(A, axes)
+
     def check_device(device, ctx):
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
@@ -88,6 +92,7 @@ def verify_transpose(in_shape, axes):
 def verify_reshape(src_shape, dst_shape):
     A = te.placeholder(shape=src_shape, name="A")
     B = topi.reshape(A, dst_shape)
+
     def check_device(device, ctx):
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
@@ -107,6 +112,7 @@ def verify_reshape(src_shape, dst_shape):
 def verify_squeeze(src_shape, axis):
     A = te.placeholder(shape=src_shape, name="A")
     B = topi.squeeze(A, axis=axis)
+
     def check_device(device, ctx):
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
@@ -124,8 +130,8 @@ def verify_squeeze(src_shape, axis):
     for device, ctx in tvm.testing.enabled_targets():
         check_device(device, ctx)
 
-def verify_concatenate(shapes, axis):
 
+def verify_concatenate(shapes, axis):
     def get_concat_schedule(target):
         schedule_map = {
             "cpu": topi.x86.schedule_concatenate,
@@ -142,6 +148,7 @@ def verify_concatenate(shapes, axis):
     for i, shape in enumerate(shapes):
         tensor_l.append(te.placeholder(shape, name="A" + str(i)))
     out_tensor = topi.concatenate(a_tuple=tensor_l, axis=axis)
+
     def check_device(device, ctx):
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
@@ -158,11 +165,13 @@ def verify_concatenate(shapes, axis):
     for device, ctx in tvm.testing.enabled_targets():
         check_device(device, ctx)
 
+
 def verify_stack(shapes, axis):
     tensor_l = []
     for i, shape in enumerate(shapes):
         tensor_l.append(te.placeholder(shape, name="A" + str(i)))
     out_tensor = topi.stack(tensor_l, axis)
+
     def check_device(device, ctx):
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
@@ -183,6 +192,7 @@ def verify_stack(shapes, axis):
 def verify_split(src_shape, indices_or_sections, axis):
     A = te.placeholder(shape=src_shape, name="A")
     tensor_l = topi.split(A, indices_or_sections, axis=axis)
+
     def check_device(device, ctx):
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
@@ -192,7 +202,9 @@ def verify_split(src_shape, indices_or_sections, axis):
         data_npy = np.random.normal(size=src_shape).astype(A.dtype)
         out_npys = np.split(data_npy, indices_or_sections, axis=axis)
         data_nd = tvm.nd.array(data_npy, ctx)
-        out_nds = [tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=tensor_l[0].dtype) for out_npy in out_npys]
+        out_nds = [
+            tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=tensor_l[0].dtype) for out_npy in out_npys
+        ]
         foo(*([data_nd] + out_nds))
         for out_nd, out_npy in zip(out_nds, out_npys):
             tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
@@ -221,7 +233,7 @@ def verify_expand_like(in_shape, out_shape, axis):
         for x in real_axis:
             input = np.expand_dims(input, x).astype(A.dtype)
         for x in real_axis:
-            input = np.concatenate([input]*out_shape[x], axis=x).astype(A.dtype)
+            input = np.concatenate([input] * out_shape[x], axis=x).astype(A.dtype)
         assert input.shape == out_shape
 
         tvm_shape_like = tvm.nd.array(np.zeros(out_shape).astype(B.dtype), ctx)
@@ -232,9 +244,11 @@ def verify_expand_like(in_shape, out_shape, axis):
     for device in ["llvm"]:
         check_device(device)
 
+
 def verify_flip(in_shape, axis):
     A = te.placeholder(shape=in_shape, name="A")
     B = topi.flip(A, axis) + 1
+
     def check_device(device):
         ctx = tvm.context(device, 0)
         if not tvm.testing.device_enabled(device):
@@ -281,46 +295,53 @@ def test_reverse_sequence():
             check_device(device, ctx)
 
     indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
-    result = [[0, 5, 10, 15],
-              [4, 1, 6, 11],
-              [8, 9, 2, 7],
-              [12, 13, 14, 3]]
+    result = [[0, 5, 10, 15], [4, 1, 6, 11], [8, 9, 2, 7], [12, 13, 14, 3]]
     verify_reverse_sequence(indata, [1, 2, 3, 4], 1, 0, np.array(result))
     verify_reverse_sequence(indata, [1, 2, 3, 4], -1, 0, np.array(result))
-    verify_reverse_sequence(indata.astype("float32"), [1, 2, 3, 4], 1, 0, np.array(result).astype("float32"))
+    verify_reverse_sequence(
+        indata.astype("float32"), [1, 2, 3, 4], 1, 0, np.array(result).astype("float32")
+    )
 
     indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
-    result = [[0, 1, 2, 3],
-              [5, 4, 6, 7],
-              [10, 9, 8, 11],
-              [15, 14, 13, 12]]
+    result = [[0, 1, 2, 3], [5, 4, 6, 7], [10, 9, 8, 11], [15, 14, 13, 12]]
     verify_reverse_sequence(indata, [1, 2, 3, 4], 0, 1, np.array(result))
     verify_reverse_sequence(indata, [1, 2, 3, 4], 0, -1, np.array(result))
-    verify_reverse_sequence(indata.astype("float32"), [1, 2, 3, 4], 0, 1, np.array(result).astype("float32"))
+    verify_reverse_sequence(
+        indata.astype("float32"), [1, 2, 3, 4], 0, 1, np.array(result).astype("float32")
+    )
 
     indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
-    result = [[0, 1, 2, 3],
-              [4, 5, 6, 7],
-              [8, 9, 10, 11],
-              [15, 14, 13, 12]]
+    result = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [15, 14, 13, 12]]
     verify_reverse_sequence(indata, [-1, 0, 1, 5], 0, 1, np.array(result))
 
     indata = np.array(np.arange(0, 54)).reshape([2, 3, 3, 3]).astype("int32")
-    result = [[[[18, 19, 20], [21, 22, 23], [24, 25, 26]],
-               [[9, 10, 11], [12, 13, 14], [15, 16, 17]],
-               [[0,  1,  2], [3,  4,  5], [6,  7,  8]]],
-              [[[45, 46, 47], [48, 49, 50], [51, 52, 53]],
-               [[36, 37, 38], [39, 40, 41], [42, 43, 44]],
-               [[27, 28, 29], [30, 31, 32], [33, 34, 35]]]]
+    result = [
+        [
+            [[18, 19, 20], [21, 22, 23], [24, 25, 26]],
+            [[9, 10, 11], [12, 13, 14], [15, 16, 17]],
+            [[0, 1, 2], [3, 4, 5], [6, 7, 8]],
+        ],
+        [
+            [[45, 46, 47], [48, 49, 50], [51, 52, 53]],
+            [[36, 37, 38], [39, 40, 41], [42, 43, 44]],
+            [[27, 28, 29], [30, 31, 32], [33, 34, 35]],
+        ],
+    ]
     verify_reverse_sequence(indata, [3, 3], 0, 1, np.array(result))
 
     indata = np.array(np.arange(0, 54)).reshape([2, 3, 3, 3]).astype("int32")
-    result = [[[[9, 10, 11], [21, 22, 23], [15, 16, 17]],
-               [[0, 1, 2], [12, 13, 14], [6, 7, 8]],
-               [[18, 19, 20], [3, 4, 5], [24, 25, 26]]],
-              [[[36, 37, 38], [48, 49, 50], [42, 43, 44]],
-               [[27, 28, 29], [39, 40, 41], [33, 34, 35]],
-               [[45, 46, 47], [30, 31, 32], [51, 52, 53]]]]
+    result = [
+        [
+            [[9, 10, 11], [21, 22, 23], [15, 16, 17]],
+            [[0, 1, 2], [12, 13, 14], [6, 7, 8]],
+            [[18, 19, 20], [3, 4, 5], [24, 25, 26]],
+        ],
+        [
+            [[36, 37, 38], [48, 49, 50], [42, 43, 44]],
+            [[27, 28, 29], [39, 40, 41], [33, 34, 35]],
+            [[45, 46, 47], [30, 31, 32], [51, 52, 53]],
+        ],
+    ]
     verify_reverse_sequence(indata, [2, 3, 2], 2, 1, np.array(result))
 
     indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
@@ -328,8 +349,11 @@ def test_reverse_sequence():
     with pytest.raises(Exception) as execinfo:
         verify_reverse_sequence(indata, [2, 3, 2, 4, 5], 1, 0, np.array(result))
 
-    assert "For reverse_sequnece seq_lengths size should match with dimension of batch axis," \
-           " but got dimension of batch_axis = 4, and seq_length size = 5" in execinfo.value.args[0]
+    assert (
+        "For reverse_sequnece seq_lengths size should match with dimension of batch axis,"
+        " but got dimension of batch_axis = 4, and seq_length size = 5" in execinfo.value.args[0]
+    )
+
 
 def verify_take(src_shape, indices_src, axis=None, mode="clip"):
     src_dtype = "float32"
@@ -351,7 +375,7 @@ def verify_take(src_shape, indices_src, axis=None, mode="clip"):
         with tvm.target.Target(device):
             s = tvm.topi.testing.get_injective_schedule(device)(out_tensor)
 
-        foo = tvm.build(s, [A] + [indices] + [out_tensor] , device, name="take")
+        foo = tvm.build(s, [A] + [indices] + [out_tensor], device, name="take")
         shape_size = 1
         for i in range(len(src_shape)):
             shape_size = shape_size * src_shape[i]
@@ -372,9 +396,10 @@ def verify_take(src_shape, indices_src, axis=None, mode="clip"):
     for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
         check_device(device)
 
+
 def verify_strided_slice(in_shape, begin, end, strides=None):
     A = te.placeholder(shape=in_shape, name="A")
-    strides = [1,1,1] if strides is None else strides
+    strides = [1, 1, 1] if strides is None else strides
     B = topi.strided_slice(A, begin, end, strides) + 1
 
     def check_device(device):
@@ -388,8 +413,7 @@ def verify_strided_slice(in_shape, begin, end, strides=None):
 
         foo = tvm.build(s, [A, B], device, name="stride_slice")
         x_np = np.random.uniform(size=in_shape).astype(A.dtype)
-        out_npy = tvm.topi.testing.strided_slice_python(
-            x_np, begin, end, strides) + 1
+        out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides) + 1
         data_nd = tvm.nd.array(x_np, ctx)
         out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=A.dtype)
         foo(data_nd, out_nd)
@@ -398,13 +422,14 @@ def verify_strided_slice(in_shape, begin, end, strides=None):
     for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
         check_device(device)
 
+
 def verify_strided_set(in_shape, v_shape, begin, end, strides=None):
     A = te.placeholder(shape=in_shape, name="A")
     V = te.placeholder(shape=v_shape, name="V")
-    b = te.placeholder(shape=(len(begin),), name="b", dtype='int32')
-    e = te.placeholder(shape=(len(end),), name="e", dtype='int32')
+    b = te.placeholder(shape=(len(begin),), name="b", dtype="int32")
+    e = te.placeholder(shape=(len(end),), name="e", dtype="int32")
     if strides is not None:
-        st = te.placeholder(shape=(len(strides),), name="st", dtype='int32')
+        st = te.placeholder(shape=(len(strides),), name="st", dtype="int32")
         B = topi.strided_set(A, V, b, e, st) + 1
     else:
         B = topi.strided_set(A, V, b, e) + 1
@@ -420,16 +445,15 @@ def verify_strided_set(in_shape, v_shape, begin, end, strides=None):
 
         if strides is not None:
             foo = tvm.build(s, [A, V, b, e, st, B], device, name="stride_set")
-            s_np = np.asarray(strides).astype('int32')
+            s_np = np.asarray(strides).astype("int32")
             s_nd = tvm.nd.array(s_np, ctx)
         else:
             foo = tvm.build(s, [A, V, b, e, B], device, name="stride_set")
         x_np = np.random.uniform(size=in_shape).astype(A.dtype)
         v_np = np.random.uniform(size=v_shape).astype(V.dtype)
-        b_np = np.asarray(begin).astype('int32')
-        e_np = np.asarray(end).astype('int32')
-        out_npy = tvm.topi.testing.strided_set_python(
-            x_np, v_np, begin, end, strides) + 1
+        b_np = np.asarray(begin).astype("int32")
+        e_np = np.asarray(end).astype("int32")
+        out_npy = tvm.topi.testing.strided_set_python(x_np, v_np, begin, end, strides) + 1
         data_nd = tvm.nd.array(x_np, ctx)
         v_nd = tvm.nd.array(v_np, ctx)
         b_nd = tvm.nd.array(b_np, ctx)
@@ -444,6 +468,7 @@ def verify_strided_set(in_shape, v_shape, begin, end, strides=None):
     for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
         check_device(device)
 
+
 def verify_gather(data, axis, indices):
     data = np.asarray(data)
     indices = np.asarray(indices)
@@ -457,7 +482,7 @@ def verify_gather(data, axis, indices):
         with tvm.target.Target(device):
             s = tvm.topi.testing.get_injective_schedule(device)(out_tensor)
 
-        func = tvm.build(s, [var_data, var_indices, out_tensor] , device, name="gather")
+        func = tvm.build(s, [var_data, var_indices, out_tensor], device, name="gather")
         out_npys = tvm.topi.testing.gather_python(data, axis, indices)
 
         data_nd = tvm.nd.array(data, ctx)
@@ -469,6 +494,7 @@ def verify_gather(data, axis, indices):
     for device, ctx in tvm.testing.enabled_targets():
         check_device(device, ctx)
 
+
 def verify_gather_nd(src_shape, indices_src, indices_dtype):
     src_dtype = "float32"
     indices_src = np.array(indices_src, dtype=indices_dtype)
@@ -481,7 +507,7 @@ def verify_gather_nd(src_shape, indices_src, indices_dtype):
         with tvm.target.Target(device):
             s = tvm.topi.testing.get_injective_schedule(device)(out_tensor)
 
-        func = tvm.build(s, [A, indices, out_tensor] , device, name="take")
+        func = tvm.build(s, [A, indices, out_tensor], device, name="take")
         shape_size = 1
         for i in range(len(src_shape)):
             shape_size = shape_size * src_shape[i]
@@ -497,6 +523,7 @@ def verify_gather_nd(src_shape, indices_src, indices_dtype):
     for device, ctx in tvm.testing.enabled_targets():
         check_device(device, ctx)
 
+
 def verify_arange(start, stop, step):
     if start is None and step is None:
         A = topi.arange(stop)
@@ -516,16 +543,18 @@ def verify_arange(start, stop, step):
         with tvm.target.Target(device):
             s = tvm.topi.testing.get_injective_schedule(device)(A)
         f = tvm.build(s, [A], device, name="arange")
-        a_nd = tvm.nd.empty(a_np.shape, dtype='float32', ctx=ctx)
+        a_nd = tvm.nd.empty(a_np.shape, dtype="float32", ctx=ctx)
         f(a_nd)
         tvm.testing.assert_allclose(a_nd.asnumpy(), a_np)
 
     for device, ctx in tvm.testing.enabled_targets():
         check_device(device, ctx)
 
+
 def verify_repeat(in_shape, repeats, axis):
     A = te.placeholder(shape=in_shape, name="A")
     B = topi.repeat(A, repeats, axis)
+
     def check_device(device, ctx):
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
@@ -541,9 +570,11 @@ def verify_repeat(in_shape, repeats, axis):
     for device, ctx in tvm.testing.enabled_targets():
         check_device(device, ctx)
 
+
 def verify_tile(in_shape, reps):
     A = te.placeholder(shape=in_shape, name="A")
     B = topi.tile(A, reps)
+
     def check_device(device, ctx):
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
@@ -559,12 +590,14 @@ def verify_tile(in_shape, reps):
     for device, ctx in tvm.testing.enabled_targets():
         check_device(device, ctx)
 
+
 def verify_where(in_shape):
     Cond = te.placeholder(shape=in_shape, name="cond")
     dtype = Cond.dtype
     A = te.placeholder(shape=in_shape, name="A")
     B = te.placeholder(shape=in_shape, name="B")
     C = topi.where(Cond, A, B)
+
     def check_device(device, ctx):
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
@@ -584,11 +617,15 @@ def verify_where(in_shape):
     for device, ctx in tvm.testing.enabled_targets():
         check_device(device, ctx)
 
+
 def verify_one_hot(indices_shape, depth, on_value, off_value, axis, dtype):
     indices = te.placeholder(shape=indices_shape, name="indices", dtype="int32")
     on_value_const = tvm.tir.const(on_value, dtype)
     off_value_const = tvm.tir.const(off_value, dtype)
-    one_hot_result = topi.transform.one_hot(indices, on_value_const, off_value_const, depth, axis, dtype)
+    one_hot_result = topi.transform.one_hot(
+        indices, on_value_const, off_value_const, depth, axis, dtype
+    )
+
     def check_device(device, ctx):
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
@@ -634,14 +671,19 @@ def verify_unravel_index(indices, shape, dtype):
     for device, ctx in tvm.testing.enabled_targets():
         check_device(device, ctx)
 
+
 def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape, xpected):
     sparse_indices_data = np.array(sparse_indices)
     sparse_values_data = np.array(sparse_values)
     output_shape_data = np.array(output_shape)
     default_value_data = np.array(default_value)
 
-    A = te.placeholder(shape=sparse_indices_data.shape, name="sparse_indices", dtype=str(sparse_indices_data.dtype))
-    B = te.placeholder(shape=sparse_values_data.shape, name="sparse_values", dtype=str(sparse_values_data.dtype))
+    A = te.placeholder(
+        shape=sparse_indices_data.shape, name="sparse_indices", dtype=str(sparse_indices_data.dtype)
+    )
+    B = te.placeholder(
+        shape=sparse_values_data.shape, name="sparse_values", dtype=str(sparse_values_data.dtype)
+    )
     if default_value is None:
         args = [A, B]
         D = topi.sparse_to_dense(A, output_shape, B)
@@ -672,6 +714,7 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_
     for device, ctx in tvm.testing.enabled_targets():
         check_device(device, ctx)
 
+
 def verify_matrix_set_diag(input_shape, dtype):
     diagonal_shape = list(input_shape[:-2])
     diagonal_shape.append(min(input_shape[-2], input_shape[-1]))
@@ -698,6 +741,7 @@ def verify_matrix_set_diag(input_shape, dtype):
     for target, ctx in tvm.testing.enabled_targets():
         check_device(target, ctx)
 
+
 def verify_adv_index(data_shape, index_shapes):
     dtype = "float32"
     data = te.placeholder(shape=data_shape, name="data", dtype=dtype)
@@ -733,6 +777,7 @@ def verify_adv_index(data_shape, index_shapes):
     for target, ctx in tvm.testing.enabled_targets():
         check_device(target, ctx)
 
+
 @tvm.testing.uses_gpu
 def test_strided_slice():
     verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
@@ -743,6 +788,7 @@ def test_strided_slice():
     verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3])
     verify_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3])
 
+
 @tvm.testing.uses_gpu
 def test_strided_set():
     verify_strided_set((3, 4, 3), (3, 2, 2), [0, 3, 0], [4, 1, 4], [1, -1, 2])
@@ -755,6 +801,7 @@ def test_strided_set():
     verify_strided_set((3, 4, 3), (2, 3, 3), [1, 1, 0], [4, 4, 3])
     verify_strided_set((3, 4, 3), (2, 3, 3), [1, 1], [4, 4, 3])
 
+
 @tvm.testing.uses_gpu
 def test_expand_dims():
     verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
@@ -763,16 +810,17 @@ def test_expand_dims():
 
 @tvm.testing.uses_gpu
 def test_reinterpret():
-    verify_reinterpret((1000,), "float32", "int32",
-                       lambda shape: np.random.randn(*shape) * 1000)
-    verify_reinterpret((1000,), "float16", "int16",
-                       lambda shape: np.random.randn(*shape) * 100)
-    verify_reinterpret((1000,), "int16", "uint16",
-                       lambda shape: np.random.randint(-1000, 1000, size=shape))
-    verify_reinterpret((1000,), "uint32", "int32",
-                       lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape))
-    verify_reinterpret((1000,), "uint32", "int32",
-                       lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape))
+    verify_reinterpret((1000,), "float32", "int32", lambda shape: np.random.randn(*shape) * 1000)
+    verify_reinterpret((1000,), "float16", "int16", lambda shape: np.random.randn(*shape) * 100)
+    verify_reinterpret(
+        (1000,), "int16", "uint16", lambda shape: np.random.randint(-1000, 1000, size=shape)
+    )
+    verify_reinterpret(
+        (1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape)
+    )
+    verify_reinterpret(
+        (1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape)
+    )
 
 
 @tvm.testing.uses_gpu
@@ -787,7 +835,7 @@ def test_reshape():
     verify_reshape((1, 2, 3, 4), (2, 3, 4))
     verify_reshape((4, 2, 3, 4), (2, 4, 12))
     verify_reshape((4, 2, 3, 4), (2, 48))
-    verify_reshape((16, ), (2, 2, 2, 2))
+    verify_reshape((16,), (2, 2, 2, 2))
     verify_reshape((4, 0), (2, 0, 2))
 
 
@@ -804,17 +852,17 @@ def test_squeeze():
     verify_squeeze((1, 1, 1, 1), None)
 
     # a special case to trigger inline let expression
-    A = te.placeholder((2,), 'float32', 'A')
+    A = te.placeholder((2,), "float32", "A")
     E = topi.squeeze(A)
-    C = te.compute((1,), lambda i: E[(2 * A[0] - 1).astype('int32')])
-    for device in ['cuda', 'opencl']:
+    C = te.compute((1,), lambda i: E[(2 * A[0] - 1).astype("int32")])
+    for device in ["cuda", "opencl"]:
         ctx = tvm.context(device, 0)
         if tvm.testing.device_enabled(device):
             with tvm.target.Target(device):
                 s = tvm.topi.testing.get_injective_schedule(device)(C)
                 func = tvm.build(s, [A, C])
-            a = tvm.nd.array(np.array((1, 2)).astype('float32'), ctx=ctx)
-            c = tvm.nd.empty((1,), dtype='float32', ctx=ctx)
+            a = tvm.nd.array(np.array((1, 2)).astype("float32"), ctx=ctx)
+            c = tvm.nd.empty((1,), dtype="float32", ctx=ctx)
             func(a, c)
             assert c.asnumpy()[0] == 2
 
@@ -824,11 +872,7 @@ def test_concatenate():
     verify_concatenate([(2,), (2,), (2,)], -1)
     verify_concatenate([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1)
     verify_concatenate([(1, 2, 4), (1, 2, 3), (1, 2, 7), (1, 2, 8), (1, 2, 1)], -1)
-    verify_concatenate([(5, 6, 7, 3),
-                        (16, 6, 7, 3),
-                        (12, 6, 7, 3),
-                        (8, 6, 7, 3),
-                        (2, 6, 7, 3)], 0)
+    verify_concatenate([(5, 6, 7, 3), (16, 6, 7, 3), (12, 6, 7, 3), (8, 6, 7, 3), (2, 6, 7, 3)], 0)
     verify_concatenate([(1, 14400), (1, 2400), (1, 640), (1, 240)], 1)
 
 
@@ -847,6 +891,7 @@ def test_split():
     verify_split((2, 12, 3), [2, 4], 1)
     verify_split((10, 12, 24), [5, 7, 9], -1)
 
+
 @tvm.testing.uses_gpu
 def test_flip():
     verify_flip((3, 4, 3), 1)
@@ -856,6 +901,7 @@ def test_flip():
     verify_flip((3, 4, 3), -3)
     verify_flip((3, 4, 3), -2)
 
+
 @tvm.testing.requires_llvm
 def test_expand_like():
     verify_expand_like((3,), (2, 3), [0])
@@ -863,25 +909,27 @@ def test_expand_like():
     verify_expand_like((3, 4), (3, 5, 4), [1])
     verify_expand_like((5, 7), (5, 6, 7, 8), [1, 3])
 
+
 @tvm.testing.uses_gpu
 def test_take():
     verify_take((4,), [1])
-    verify_take((4,), [[0,1,2,3]])
-    verify_take((3,3,3), [[11,25]])
-    verify_take((4,), [[0,1],[2,3]])
+    verify_take((4,), [[0, 1, 2, 3]])
+    verify_take((3, 3, 3), [[11, 25]])
+    verify_take((4,), [[0, 1], [2, 3]])
     verify_take((4,), [1], 0)
-    verify_take((2,2), [[[1,0],[0,1]]], 0)
-    verify_take((2,2), [[[1,0],[0,1]]], 1)
-    verify_take((4,3,5,6), [[2,1,0,0]], -2)
-    verify_take((3,4), [-5, 20])
-    verify_take((3,4), [-5, 20], mode="wrap")
-    verify_take((3,4), [-1, 2], axis=0)
-    verify_take((3,4), [-1, 2], axis=0, mode="wrap")
-    verify_take((3,4), [-1, 2], axis=1)
-    verify_take((3,4), [-1, 2], axis=1, mode="wrap")
-    verify_take((3,3,3), [[11,25]], mode="fast")
-    verify_take((3,4), [0, 2], axis=0, mode="fast")
-    verify_take((3,4), [0, 2], axis=1, mode="fast")
+    verify_take((2, 2), [[[1, 0], [0, 1]]], 0)
+    verify_take((2, 2), [[[1, 0], [0, 1]]], 1)
+    verify_take((4, 3, 5, 6), [[2, 1, 0, 0]], -2)
+    verify_take((3, 4), [-5, 20])
+    verify_take((3, 4), [-5, 20], mode="wrap")
+    verify_take((3, 4), [-1, 2], axis=0)
+    verify_take((3, 4), [-1, 2], axis=0, mode="wrap")
+    verify_take((3, 4), [-1, 2], axis=1)
+    verify_take((3, 4), [-1, 2], axis=1, mode="wrap")
+    verify_take((3, 3, 3), [[11, 25]], mode="fast")
+    verify_take((3, 4), [0, 2], axis=0, mode="fast")
+    verify_take((3, 4), [0, 2], axis=1, mode="fast")
+
 
 @tvm.testing.uses_gpu
 def test_gather():
@@ -893,9 +941,10 @@ def test_gather():
     verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 2)))
     verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 10)))
 
+
 @tvm.testing.uses_gpu
 def test_gather_nd():
-    for indices_dtype in ['int32', 'float32']:
+    for indices_dtype in ["int32", "float32"]:
         verify_gather_nd((4,), [[1.8]], indices_dtype)
         verify_gather_nd((4,), [[1, 3, 2]], indices_dtype)
         verify_gather_nd((2, 3), [[1]], indices_dtype)
@@ -903,11 +952,12 @@ def test_gather_nd():
         verify_gather_nd((2, 3), [[1, 0], [0, 2]], indices_dtype)
         verify_gather_nd((2, 3, 4), [[1, 0], [0, 2]], indices_dtype)
         verify_gather_nd((2, 3, 4), [[1, 0], [0, 2], [3, 1]], indices_dtype)
-        verify_gather_nd((2, 3, 4), [[[1, 0], [0, 1]], [[0, 2], [1, 2]],
-                                     [[3, 1], [0, 2]]], indices_dtype)
+        verify_gather_nd(
+            (2, 3, 4), [[[1, 0], [0, 1]], [[0, 2], [1, 2]], [[3, 1], [0, 2]]], indices_dtype
+        )
         verify_gather_nd((2, 3, 4, 5), [[1, 0], [0, 2]], indices_dtype)
-        verify_gather_nd((2, 3, 4, 5), [[1, 0], [2, 1], [3, 2], [4, 2]],
-                         indices_dtype)
+        verify_gather_nd((2, 3, 4, 5), [[1, 0], [2, 1], [3, 2], [4, 2]], indices_dtype)
+
 
 @tvm.testing.uses_gpu
 def test_arange():
@@ -921,6 +971,7 @@ def test_arange():
     verify_arange(20, 1, -1)
     verify_arange(20, 1, -1.5)
 
+
 @tvm.testing.uses_gpu
 def test_repeat():
     verify_repeat((2,), 1, 0)
@@ -928,13 +979,15 @@ def test_repeat():
     verify_repeat((3, 2, 4), 3, 1)
     verify_repeat((1, 3, 2, 4), 4, -1)
 
+
 @tvm.testing.uses_gpu
 def test_tile():
     verify_tile((3, 2), (2, 3))
     verify_tile((3, 2, 5), (2,))
-    verify_tile((3, ), (2, 3, 3))
+    verify_tile((3,), (2, 3, 3))
     verify_tile((4, 0), (5,))
 
+
 @tvm.testing.uses_gpu
 def test_layout_transform():
     in_shape = (1, 32, 8, 8)
@@ -1008,9 +1061,11 @@ def test_sequence_mask():
                     f = tvm.build(s, [A, B, C], device, name="SequenceMask")
                     f(tvm_A, tvm_B, tvm_C)
                     tvm.testing.assert_allclose(tvm_C.asnumpy(), C_gt_data)
+
                 for backend, ctx in tvm.testing.enabled_targets():
                     check_device(backend, ctx)
 
+
 @tvm.testing.uses_gpu
 def test_ndarray_size():
     in_shape = (5, 11, 7)
@@ -1038,17 +1093,18 @@ def test_ndarray_size():
 @tvm.testing.uses_gpu
 def test_where_fusion():
     """integration test that where and zeros should be properly inlined"""
+
     def check_device(device, ctx):
         with tvm.target.Target(device):
             print("Running on target: %s" % device)
             conv2d_compute, conv2d_schedule = tvm.topi.testing.get_conv2d_nchw_implement(device)
-            data = te.placeholder((2, 1, 2, 4), 'int8', 'data')
-            w = te.placeholder((3, 1, 2, 2), 'int8', 'w')
-            conv1 = conv2d_compute(data, w, 1, 0, 1, 'int32')
-            zeros = topi.full((2, 3, 1, 3), 'int32', tvm.tir.const(0, dtype='int32'))
+            data = te.placeholder((2, 1, 2, 4), "int8", "data")
+            w = te.placeholder((3, 1, 2, 2), "int8", "w")
+            conv1 = conv2d_compute(data, w, 1, 0, 1, "int32")
+            zeros = topi.full((2, 3, 1, 3), "int32", tvm.tir.const(0, dtype="int32"))
             gt = topi.greater_equal(conv1, zeros)
-            one = topi.full((2, 3, 1, 3), 'int32', tvm.tir.const(1, dtype='int32'))
-            two = topi.full((2, 3, 1, 3), 'int32', tvm.tir.const(2, dtype='int32'))
+            one = topi.full((2, 3, 1, 3), "int32", tvm.tir.const(1, dtype="int32"))
+            two = topi.full((2, 3, 1, 3), "int32", tvm.tir.const(2, dtype="int32"))
             where = topi.where(gt, one, two)
             add = topi.add(conv1, where)
             outs = [add]
@@ -1058,6 +1114,7 @@ def test_where_fusion():
     for backend, ctx in tvm.testing.enabled_targets():
         check_device(backend, ctx)
 
+
 @tvm.testing.uses_gpu
 def test_one_hot():
     verify_one_hot((3,), 3, 1, 0, -1, "int32")
@@ -1076,42 +1133,50 @@ def test_unravel_index():
         verify_unravel_index(144, [5, 5, 5, 2], dtype)
         verify_unravel_index([100, 13, 5], [5, 5, 5, 2], dtype)
 
+
 @tvm.testing.uses_gpu
 def test_sparse_to_dense():
-    verify_sparse_to_dense(1, 3, 0, [5], [0, 3, 0, 0, 0]) #scalar
-    verify_sparse_to_dense([0, 1, 4], [3, 3, 3], 0, [5], [3, 3, 0, 0, 3]) #vector
-    verify_sparse_to_dense([[0, 0], [1, 2]], [1, 2], 0, [3, 4], [[1, 0, 0, 0],[0, 0, 2, 0],[0, 0, 0, 0]]) #nXd
+    verify_sparse_to_dense(1, 3, 0, [5], [0, 3, 0, 0, 0])  # scalar
+    verify_sparse_to_dense([0, 1, 4], [3, 3, 3], 0, [5], [3, 3, 0, 0, 3])  # vector
+    verify_sparse_to_dense(
+        [[0, 0], [1, 2]], [1, 2], 0, [3, 4], [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]]
+    )  # nXd
     verify_sparse_to_dense(
         [[0, 0, 0], [1, 2, 3]],
         [1, 2],
         4,
         [2, 3, 4],
-        [[[1, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4]],  [[4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 2]]]
-    ) #nXd
-    verify_sparse_to_dense([0, 1, 4], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])  #floats
+        [[[1, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4]], [[4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 2]]],
+    )  # nXd
+    verify_sparse_to_dense(
+        [0, 1, 4], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]
+    )  # floats
     verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0])  # default value not specified
 
-    #negative test cases
-    #sparse indices should be ints
-    #verify_sparse_to_dense([[0.1, 1.1, 4.1], [0,2,4]], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
-    #sparse_values should be 0d or 1d only
-    #verify_sparse_to_dense([[0, 1, 4], [0, 2, 4]], [[[3.1, 3.1, 3.1]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
-    #sparse_indices should not be > 2d tensor
-    #verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[3.1, 3.1, 3.1]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
+    # negative test cases
+    # sparse indices should be ints
+    # verify_sparse_to_dense([[0.1, 1.1, 4.1], [0,2,4]], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
+    # sparse_values should be 0d or 1d only
+    # verify_sparse_to_dense([[0, 1, 4], [0, 2, 4]], [[[3.1, 3.1, 3.1]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
+    # sparse_indices should not be > 2d tensor
+    # verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[3.1, 3.1, 3.1]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
+
 
 @tvm.testing.uses_gpu
 def test_matrix_set_diag():
-    for dtype in ['float32', 'int32']:
+    for dtype in ["float32", "int32"]:
         verify_matrix_set_diag((2, 2), dtype)
         verify_matrix_set_diag((4, 3, 3), dtype)
         verify_matrix_set_diag((2, 3, 4), dtype)
 
+
 @tvm.testing.uses_gpu
 def test_adv_index():
-    verify_adv_index((3, 4, 5), [(2,), (2, ), (1,)])
+    verify_adv_index((3, 4, 5), [(2,), (2,), (1,)])
     verify_adv_index((10, 15, 5), [(1, 1), (2, 7)])
     verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)])
 
+
 if __name__ == "__main__":
     test_strided_slice()
     test_concatenate()
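
Note on the strided_slice cases above: for positive strides, tvm.topi.testing.strided_slice_python behaves like ordinary NumPy slicing. A minimal sketch of that equivalence (negative strides and out-of-range bounds need the extra handling the real helper provides):

import numpy as np

def strided_slice_ref(x, begin, end, strides):
    # NumPy-slicing equivalent for the simple, positive-stride cases
    return x[tuple(slice(b, e, s) for b, e, s in zip(begin, end, strides))]

x = np.random.uniform(size=(3, 4, 3)).astype("float32")
ref = strided_slice_ref(x, [1, 1, 0], [4, 4, 3], [1, 1, 1])
np.testing.assert_allclose(ref, x[1:4, 1:4, 0:3])
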
index 09ca58d..4bce660 100644
@@ -23,35 +23,62 @@ import tvm.topi.testing
 import math
 from tvm.topi.util import nchw_pack_layout
 
-def verify_upsampling(batch, in_channel, in_height, in_width, scale_h, scale_w,
-                      layout='NCHW', method="nearest_neighbor",
-                      in_batch_block = 0, in_channel_block = 0):
-    if layout == 'NCHW':
-        A = te.placeholder((batch, in_channel, in_height, in_width), name='A')
+
+def verify_upsampling(
+    batch,
+    in_channel,
+    in_height,
+    in_width,
+    scale_h,
+    scale_w,
+    layout="NCHW",
+    method="nearest_neighbor",
+    in_batch_block=0,
+    in_channel_block=0,
+):
+    if layout == "NCHW":
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A")
         dtype = A.dtype
-        out_shape = (batch, in_channel, int(round(in_height*scale_h)), int(round(in_width*scale_w)))
+        out_shape = (
+            batch,
+            in_channel,
+            int(round(in_height * scale_h)),
+            int(round(in_width * scale_w)),
+        )
         a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
     elif nchw_pack_layout(layout):
-        A = te.placeholder((batch, in_channel, in_height, in_width, in_batch_block, in_channel_block),
-                             name='A')
+        A = te.placeholder(
+            (batch, in_channel, in_height, in_width, in_batch_block, in_channel_block), name="A"
+        )
         dtype = A.dtype
-        out_shape = (batch, in_channel, int(round(in_height*scale_h)), int(round(in_width*scale_w)),
-                     in_batch_block, in_channel_block)
-        a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width,
-                                 in_batch_block, in_channel_block)).astype(dtype)
-    elif layout == 'NHWC':
-        A = te.placeholder((batch, in_height, in_width, in_channel), name='A')
+        out_shape = (
+            batch,
+            in_channel,
+            int(round(in_height * scale_h)),
+            int(round(in_width * scale_w)),
+            in_batch_block,
+            in_channel_block,
+        )
+        a_np = np.random.uniform(
+            size=(batch, in_channel, in_height, in_width, in_batch_block, in_channel_block)
+        ).astype(dtype)
+    elif layout == "NHWC":
+        A = te.placeholder((batch, in_height, in_width, in_channel), name="A")
         dtype = A.dtype
-        out_shape = (batch, int(round(in_height*scale_h)), int(round(in_width*scale_w)), in_channel)
+        out_shape = (
+            batch,
+            int(round(in_height * scale_h)),
+            int(round(in_width * scale_w)),
+            in_channel,
+        )
         a_np = np.random.uniform(size=(batch, in_height, in_width, in_channel)).astype(dtype)
     else:
-        raise NotImplementedError(
-            'Layout not supported {} '.format(layout))
+        raise NotImplementedError("Layout not supported {} ".format(layout))
 
     B = topi.nn.upsampling(A, scale_h, scale_w, layout=layout, method=method, align_corners=False)
 
     if method == "bilinear":
-        out_size = (int(round(in_height*scale_h)), int(round(in_width*scale_w)))
+        out_size = (int(round(in_height * scale_h)), int(round(in_width * scale_w)))
         b_np = tvm.topi.testing.bilinear_resize_python(a_np, out_size, layout, "asymmetric")
     else:
         b_np = tvm.topi.testing.upsampling_python(a_np, (scale_h, scale_w), layout)
@@ -70,6 +97,7 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale_h, scale_w,
     for device, ctx in tvm.testing.enabled_targets():
         check_device(device, ctx)
 
+
 @tvm.testing.uses_gpu
 def test_upsampling():
     # nearest_neighbor - NCHW
@@ -88,51 +116,114 @@ def test_upsampling():
     verify_upsampling(1, 64, 22, 32, 1.954545497894287, 2.0, method="bilinear")
 
     # nearest_neighbor - NCHWinic
-    verify_upsampling(2, 2, 32, 32, in_batch_block=4, in_channel_block=8,
-                      scale_h=2.0, scale_w=2.0)
-    verify_upsampling(2, 2, 64, 64, in_batch_block=1, in_channel_block=16,
-                      scale_h=3.0, scale_w=3.0)
-    verify_upsampling(1, 4, 22, 32, in_batch_block=1, in_channel_block=16,
-                      scale_h=1.954545497894287, scale_w=2.0)
+    verify_upsampling(2, 2, 32, 32, in_batch_block=4, in_channel_block=8, scale_h=2.0, scale_w=2.0)
+    verify_upsampling(2, 2, 64, 64, in_batch_block=1, in_channel_block=16, scale_h=3.0, scale_w=3.0)
+    verify_upsampling(
+        1, 4, 22, 32, in_batch_block=1, in_channel_block=16, scale_h=1.954545497894287, scale_w=2.0
+    )
 
     # bilinear - NCHWinic
-    verify_upsampling(2, 2, 32, 32, in_batch_block=1, in_channel_block=1,
-                      scale_h=2.0, scale_w=2.0, method="bilinear")
-    verify_upsampling(2, 2, 32, 32, in_batch_block=1, in_channel_block=1,
-                      scale_h=3.0, scale_w=3.0, method="bilinear")
-    verify_upsampling(2, 4, 22, 32, in_batch_block=1, in_channel_block=16,
-                      scale_h=1.954545497894287, scale_w=2.0, layout="NCHW1n16c", method="bilinear")
+    verify_upsampling(
+        2,
+        2,
+        32,
+        32,
+        in_batch_block=1,
+        in_channel_block=1,
+        scale_h=2.0,
+        scale_w=2.0,
+        method="bilinear",
+    )
+    verify_upsampling(
+        2,
+        2,
+        32,
+        32,
+        in_batch_block=1,
+        in_channel_block=1,
+        scale_h=3.0,
+        scale_w=3.0,
+        method="bilinear",
+    )
+    verify_upsampling(
+        2,
+        4,
+        22,
+        32,
+        in_batch_block=1,
+        in_channel_block=16,
+        scale_h=1.954545497894287,
+        scale_w=2.0,
+        layout="NCHW1n16c",
+        method="bilinear",
+    )
 
     # bilinear - NHWC
     verify_upsampling(2, 2, 32, 32, 2.0, 2.0, layout="NHWC", method="bilinear")
     verify_upsampling(2, 2, 32, 32, 3.0, 3.0, layout="NHWC", method="bilinear")
-    verify_upsampling(1, 64, 22, 32,  3.0, 3.0, layout="NHWC", method="bilinear")
+    verify_upsampling(1, 64, 22, 32, 3.0, 3.0, layout="NHWC", method="bilinear")
+
 
-def verify_upsampling3d(batch, in_channel, in_depth, in_height, in_width, scale_d, scale_h, scale_w,
-                        layout='NCDHW', method="nearest_neighbor"):
-    if layout == 'NCDHW':
-        A = te.placeholder((batch, in_channel, in_depth, in_height, in_width), name='A')
+def verify_upsampling3d(
+    batch,
+    in_channel,
+    in_depth,
+    in_height,
+    in_width,
+    scale_d,
+    scale_h,
+    scale_w,
+    layout="NCDHW",
+    method="nearest_neighbor",
+):
+    if layout == "NCDHW":
+        A = te.placeholder((batch, in_channel, in_depth, in_height, in_width), name="A")
         dtype = A.dtype
-        out_shape = (batch, in_channel, int(round(in_depth*scale_d)), int(round(in_height*scale_h)),
-                     int(round(in_width*scale_w)))
-        a_np = np.random.uniform(size=(batch, in_channel, in_depth, in_height, in_width)).astype(dtype)
-    elif layout == 'NDHWC':
-        A = te.placeholder((batch, in_depth, in_height, in_width, in_channel), name='A')
+        out_shape = (
+            batch,
+            in_channel,
+            int(round(in_depth * scale_d)),
+            int(round(in_height * scale_h)),
+            int(round(in_width * scale_w)),
+        )
+        a_np = np.random.uniform(size=(batch, in_channel, in_depth, in_height, in_width)).astype(
+            dtype
+        )
+    elif layout == "NDHWC":
+        A = te.placeholder((batch, in_depth, in_height, in_width, in_channel), name="A")
         dtype = A.dtype
-        out_shape = (batch, int(round(in_depth*scale_d)), int(round(in_height*scale_h)),
-                     int(round(in_width*scale_w)), in_channel)
-        a_np = np.random.uniform(size=(batch, in_depth, in_height, in_width, in_channel)).astype(dtype)
+        out_shape = (
+            batch,
+            int(round(in_depth * scale_d)),
+            int(round(in_height * scale_h)),
+            int(round(in_width * scale_w)),
+            in_channel,
+        )
+        a_np = np.random.uniform(size=(batch, in_depth, in_height, in_width, in_channel)).astype(
+            dtype
+        )
     else:
-        raise NotImplementedError(
-            'Layout not supported {} '.format(layout))
+        raise NotImplementedError("Layout not supported {} ".format(layout))
 
-    B = topi.nn.upsampling3d(A, scale_d, scale_h, scale_w, layout=layout, method=method,
-                             coordinate_transformation_mode="half_pixel")
+    B = topi.nn.upsampling3d(
+        A,
+        scale_d,
+        scale_h,
+        scale_w,
+        layout=layout,
+        method=method,
+        coordinate_transformation_mode="half_pixel",
+    )
 
     if method == "trilinear":
-        out_size = (int(round(in_depth*scale_d)), int(round(in_height*scale_h)), int(round(in_width*scale_w)))
-        b_np = tvm.topi.testing.trilinear_resize3d_python(a_np, out_size, layout,
-                                                      coordinate_transformation_mode="half_pixel")
+        out_size = (
+            int(round(in_depth * scale_d)),
+            int(round(in_height * scale_h)),
+            int(round(in_width * scale_w)),
+        )
+        b_np = tvm.topi.testing.trilinear_resize3d_python(
+            a_np, out_size, layout, coordinate_transformation_mode="half_pixel"
+        )
     else:
         b_np = tvm.topi.testing.upsampling3d_python(a_np, (scale_d, scale_h, scale_w), layout)
 
@@ -150,6 +241,7 @@ def verify_upsampling3d(batch, in_channel, in_depth, in_height, in_width, scale_
     for device, ctx in tvm.testing.enabled_targets():
         check_device(device, ctx)
 
+
 @tvm.testing.uses_gpu
 def test_upsampling3d():
     # nearest_neighbor - NCDHW
@@ -170,7 +262,10 @@ def test_upsampling3d():
     # trilinear - NDHWC
     verify_upsampling3d(2, 2, 16, 16, 16, 2.0, 2.0, 2.0, layout="NDHWC", method="trilinear")
     verify_upsampling3d(2, 2, 32, 32, 32, 3.0, 3.0, 3.0, layout="NDHWC", method="trilinear")
-    verify_upsampling3d(1, 2, 11, 16, 6, 1.954545497894287, 2.0, 1.5, layout="NDHWC", method="trilinear")
+    verify_upsampling3d(
+        1, 2, 11, 16, 6, 1.954545497894287, 2.0, 1.5, layout="NDHWC", method="trilinear"
+    )
+
 
 if __name__ == "__main__":
     test_upsampling()
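
For readers skimming the reformatted test, a tiny standalone sketch of the output-shape arithmetic that the NCDHW branch of verify_upsampling3d above asserts (plain Python, no TVM; the helper name is made up for illustration):

    def upsample3d_out_shape(shape, scale_d, scale_h, scale_w):
        # NCDHW input: the three spatial dims are scaled and rounded,
        # mirroring int(round(dim * scale)) in verify_upsampling3d.
        n, c, d, h, w = shape
        return (n, c, int(round(d * scale_d)), int(round(h * scale_h)), int(round(w * scale_w)))

    assert upsample3d_out_shape((2, 2, 16, 16, 16), 2.0, 2.0, 2.0) == (2, 2, 32, 32, 32)
    # The fractional trilinear case (1, 2, 11, 16, 6) with scales
    # (1.954545497894287, 2.0, 1.5) rounds to depth 22, height 32, width 9.
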
index a6287b1..18182dc 100644 (file)
@@ -21,8 +21,10 @@ from tvm import topi
 
 def verify_get_shape(src_shape, src_layout, dst_layout, expect_shape):
     dst_shape = topi.util.get_shape(src_shape, src_layout, dst_layout)
-    assert dst_shape == expect_shape, \
-        "Shape mismatch: expecting %s but got %s" % (expect_shape, dst_shape)
+    assert dst_shape == expect_shape, "Shape mismatch: expecting %s but got %s" % (
+        expect_shape,
+        dst_shape,
+    )
 
 
 def test_get_shape():
@@ -31,5 +33,6 @@ def test_get_shape():
     verify_get_shape((3, 2, 32, 48, 16), "NCHW16c", "NC16cWH", (3, 2, 16, 48, 32))
     verify_get_shape((2, 3, 32, 32, 16, 8), "OIHW16i8o", "HWO8oI16i", (32, 32, 2, 8, 3, 16))
 
+
 if __name__ == "__main__":
     test_get_shape()
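
As a rough illustration of what the get_shape assertions above encode, here is a standalone layout-permutation sketch. It assumes every destination token ("N", "C", "16c", ...) appears verbatim in the source layout; layout_tokens and get_shape_sketch are invented names, not topi.util APIs:

    import re

    def layout_tokens(layout):
        # Split a layout string such as "NCHW16c" into axis tokens:
        # ["N", "C", "H", "W", "16c"]; an optional factor prefixes a
        # lower-case sub-axis letter.
        return ["".join(m) for m in re.findall(r"(\d*)([a-zA-Z])", layout)]

    def get_shape_sketch(src_shape, src_layout, dst_layout):
        # Each destination token picks up the extent of the identical
        # token in the source layout.
        extents = dict(zip(layout_tokens(src_layout), src_shape))
        return tuple(extents[tok] for tok in layout_tokens(dst_layout))

    assert get_shape_sketch((3, 2, 32, 48, 16), "NCHW16c", "NC16cWH") == (3, 2, 16, 48, 32)
    assert get_shape_sketch((2, 3, 32, 32, 16, 8), "OIHW16i8o", "HWO8oI16i") == (32, 32, 2, 8, 3, 16)
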
index 22f4683..0d02dd8 100644 (file)
@@ -65,6 +65,7 @@ _proposal_implement = {
     "gpu": (topi.cuda.proposal, topi.cuda.schedule_proposal),
 }
 
+
 def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
     dtype = "float32"
     batch_size, num_anchor, elem_length = dshape
@@ -116,13 +117,15 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
             tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
             tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
 
-    for device in ['llvm', 'cuda', 'opencl']:
+    for device in ["llvm", "cuda", "opencl"]:
         check_device(device)
 
 
 @tvm.testing.uses_gpu
-@pytest.mark.skip("Skip this test as it is intermittent."
-                  "See https://github.com/apache/incubator-tvm/pull/4901#issuecomment-595040094")
+@pytest.mark.skip(
+    "Skip this test as it is intermittent."
+    "See https://github.com/apache/incubator-tvm/pull/4901#issuecomment-595040094"
+)
 def test_get_valid_counts():
     verify_get_valid_counts((1, 1000, 5), 0.5, -1, 0)
     verify_get_valid_counts((1, 2500, 6), 0, 0, 1)
@@ -131,8 +134,20 @@ def test_get_valid_counts():
     verify_get_valid_counts((16, 500, 5), 0.95, -1, 1)
 
 
-def verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, max_output_size,
-                               iou_threshold, force_suppress, top_k, coord_start, score_index, id_index):
+def verify_non_max_suppression(
+    np_data,
+    np_valid_count,
+    np_indices,
+    np_result,
+    np_indices_result,
+    max_output_size,
+    iou_threshold,
+    force_suppress,
+    top_k,
+    coord_start,
+    score_index,
+    id_index,
+):
     dshape = np_data.shape
     batch, num_anchors, _ = dshape
     indices_dshape = (batch, num_anchors)
@@ -148,12 +163,32 @@ def verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, n
         print("Running on target: %s" % device)
         with tvm.target.Target(device):
             fcompute, fschedule = tvm.topi.testing.dispatch(device, _nms_implement)
-            out = fcompute(data, valid_count, indices, max_output_size, iou_threshold, force_suppress,
-                           top_k, coord_start=coord_start, score_index=score_index, id_index=id_index,
-                           return_indices=False)
-            indices_out = fcompute(data, valid_count, indices, max_output_size, iou_threshold, force_suppress,
-                                   top_k, coord_start=coord_start, score_index=score_index, id_index=id_index,
-                                   return_indices=True)
+            out = fcompute(
+                data,
+                valid_count,
+                indices,
+                max_output_size,
+                iou_threshold,
+                force_suppress,
+                top_k,
+                coord_start=coord_start,
+                score_index=score_index,
+                id_index=id_index,
+                return_indices=False,
+            )
+            indices_out = fcompute(
+                data,
+                valid_count,
+                indices,
+                max_output_size,
+                iou_threshold,
+                force_suppress,
+                top_k,
+                coord_start=coord_start,
+                score_index=score_index,
+                id_index=id_index,
+                return_indices=True,
+            )
             s = fschedule(out)
             indices_s = fschedule(indices_out)
 
@@ -167,7 +202,7 @@ def verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, n
         tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4)
 
         tvm_indices_out = tvm.nd.array(np.zeros(indices_dshape, dtype="int32"), ctx)
-        if device == 'llvm':
+        if device == "llvm":
             f = tvm.build(indices_s, [data, valid_count, indices, indices_out[0]], device)
             f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out)
         else:
@@ -175,40 +210,99 @@ def verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, n
             f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out)
         tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4)
 
-    for device in ['llvm', 'cuda', 'opencl']:
+    for device in ["llvm", "cuda", "opencl"]:
         check_device(device)
 
+
 @tvm.testing.uses_gpu
 def test_non_max_suppression():
-    np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80],
-                         [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79],
-                         [1, 0.5, 100, 60, 70, 110]]]).astype("float32")
+    np_data = np.array(
+        [
+            [
+                [0, 0.8, 1, 20, 25, 45],
+                [1, 0.7, 30, 60, 50, 80],
+                [0, 0.4, 4, 21, 19, 40],
+                [2, 0.9, 35, 61, 52, 79],
+                [1, 0.5, 100, 60, 70, 110],
+            ]
+        ]
+    ).astype("float32")
     np_valid_count = np.array([4]).astype("int32")
     np_indices = np.array([[0, 1, 2, 3, 4]]).astype("int32")
     max_output_size = -1
-    np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
-                           [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
-                           [-1, -1, -1, -1, -1, -1]]])
+    np_result = np.array(
+        [
+            [
+                [2, 0.9, 35, 61, 52, 79],
+                [0, 0.8, 1, 20, 25, 45],
+                [-1, -1, -1, -1, -1, -1],
+                [-1, -1, -1, -1, -1, -1],
+                [-1, -1, -1, -1, -1, -1],
+            ]
+        ]
+    )
     np_indices_result = np.array([[3, 0, -1, -1, -1]])
 
-    verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result,
-                               max_output_size, 0.7, True, 2, 2, 1, 0)
-
-    np_data = np.array([[[0.8, 1, 20, 25, 45], [0.7, 30, 60, 50, 80],
-                         [0.4, 4, 21, 19, 40], [0.9, 35, 61, 52, 79],
-                         [0.5, 100, 60, 70, 110]]]).astype("float32")
+    verify_non_max_suppression(
+        np_data,
+        np_valid_count,
+        np_indices,
+        np_result,
+        np_indices_result,
+        max_output_size,
+        0.7,
+        True,
+        2,
+        2,
+        1,
+        0,
+    )
+
+    np_data = np.array(
+        [
+            [
+                [0.8, 1, 20, 25, 45],
+                [0.7, 30, 60, 50, 80],
+                [0.4, 4, 21, 19, 40],
+                [0.9, 35, 61, 52, 79],
+                [0.5, 100, 60, 70, 110],
+            ]
+        ]
+    ).astype("float32")
     np_valid_count = np.array([4]).astype("int32")
     np_indices = np.array([[0, 1, 2, 3, 4]]).astype("int32")
     max_output_size = 2
-    np_result = np.array([[[0.9, 35, 61, 52, 79], [0.8, 1, 20, 25, 45],
-                           [-1, -1, -1, -1, -1], [-1, -1, -1, -1, -1],
-                           [-1, -1, -1, -1, -1]]])
+    np_result = np.array(
+        [
+            [
+                [0.9, 35, 61, 52, 79],
+                [0.8, 1, 20, 25, 45],
+                [-1, -1, -1, -1, -1],
+                [-1, -1, -1, -1, -1],
+                [-1, -1, -1, -1, -1],
+            ]
+        ]
+    )
     np_indices_result = np.array([[3, 0, -1, -1, -1]])
-    verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result,
-                               max_output_size, 0.7, False, 2, 1, 0, -1)
-
-
-def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, 0.5), clip=False):
+    verify_non_max_suppression(
+        np_data,
+        np_valid_count,
+        np_indices,
+        np_result,
+        np_indices_result,
+        max_output_size,
+        0.7,
+        False,
+        2,
+        1,
+        0,
+        -1,
+    )
+
+
+def verify_multibox_prior(
+    dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, 0.5), clip=False
+):
     data = te.placeholder(dshape, name="data")
 
     dtype = data.dtype
@@ -232,11 +326,25 @@ def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offse
         for j in range(in_width):
             center_w = (j + offset_w) * steps_w
             for k in range(num_sizes + num_ratios - 1):
-                w = size_ratio_concat[k] * in_height / in_width / 2.0 if k < num_sizes else \
-                    size_ratio_concat[0] * in_height / in_width * math.sqrt(size_ratio_concat[k + 1]) / 2.0
-                h = size_ratio_concat[k] / 2.0 if k < num_sizes else \
-                    size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0
-                count = i * in_width * (num_sizes + num_ratios - 1) + j * (num_sizes + num_ratios - 1) + k
+                w = (
+                    size_ratio_concat[k] * in_height / in_width / 2.0
+                    if k < num_sizes
+                    else size_ratio_concat[0]
+                    * in_height
+                    / in_width
+                    * math.sqrt(size_ratio_concat[k + 1])
+                    / 2.0
+                )
+                h = (
+                    size_ratio_concat[k] / 2.0
+                    if k < num_sizes
+                    else size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0
+                )
+                count = (
+                    i * in_width * (num_sizes + num_ratios - 1)
+                    + j * (num_sizes + num_ratios - 1)
+                    + k
+                )
                 np_out[0][count][0] = center_w - w
                 np_out[0][count][1] = center_h - h
                 np_out[0][count][2] = center_w + w
@@ -262,7 +370,7 @@ def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offse
         f(tvm_input_data, tvm_out)
         tvm.testing.assert_allclose(tvm_out.asnumpy(), np_out, rtol=1e-3)
 
-    for device in ['llvm', 'opencl', 'cuda']:
+    for device in ["llvm", "opencl", "cuda"]:
         check_device(device)
 
 
@@ -270,7 +378,9 @@ def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offse
 def test_multibox_prior():
     verify_multibox_prior((1, 3, 50, 50))
     verify_multibox_prior((1, 3, 224, 224), sizes=(0.5, 0.25, 0.1), ratios=(1, 2, 0.5))
-    verify_multibox_prior((1, 32, 32, 32), sizes=(0.5, 0.25), ratios=(1, 2), steps=(2, 2), clip=True)
+    verify_multibox_prior(
+        (1, 32, 32, 32), sizes=(0.5, 0.25), ratios=(1, 2), steps=(2, 2), clip=True
+    )
 
 
 @tvm.testing.uses_gpu
@@ -287,9 +397,15 @@ def test_multibox_detection():
     np_loc_preds = np.array([[0.1, -0.2, 0.3, 0.2, 0.2, 0.4, 0.5, -0.3, 0.7, -0.2, -0.4, -0.8]])
     np_anchors = np.array([[[-0.1, -0.1, 0.1, 0.1], [-0.2, -0.2, 0.2, 0.2], [1.2, 1.2, 1.5, 1.5]]])
 
-    expected_np_out = np.array([[[1, 0.69999999, 0, 0, 0.10818365, 0.10008108],
-                                 [0, 0.44999999, 1, 1, 1, 1],
-                                 [0, 0.30000001, 0, 0, 0.22903419, 0.20435292]]])
+    expected_np_out = np.array(
+        [
+            [
+                [1, 0.69999999, 0, 0, 0.10818365, 0.10008108],
+                [0, 0.44999999, 1, 1, 1, 1],
+                [0, 0.30000001, 0, 0, 0.22903419, 0.20435292],
+            ]
+        ]
+    )
 
     def check_device(device):
         ctx = tvm.context(device, 0)
@@ -311,7 +427,7 @@ def test_multibox_detection():
         f(tvm_cls_prob, tvm_loc_preds, tvm_anchors, tvm_out)
         tvm.testing.assert_allclose(tvm_out.asnumpy(), expected_np_out, rtol=1e-4)
 
-    for device in ['llvm', 'opencl', 'cuda']:
+    for device in ["llvm", "opencl", "cuda"]:
         check_device(device)
 
 
@@ -324,12 +440,16 @@ def verify_roi_align(batch, in_channel, in_size, num_roi, pooled_size, spatial_s
 
     @memoize("topi.tests.test_topi_vision.verify_roi_align")
     def get_ref_data():
-        a_np = np.random.uniform(size=a_shape).astype('float32')
-        rois_np = np.random.uniform(size=rois_shape).astype('float32') * in_size
-        rois_np[:, 0] = np.random.randint(low = 0, high = batch, size = num_roi)
-        b_np = tvm.topi.testing.roi_align_nchw_python(a_np, rois_np, pooled_size=pooled_size,
-                                                  spatial_scale=spatial_scale,
-                                                  sample_ratio=sample_ratio)
+        a_np = np.random.uniform(size=a_shape).astype("float32")
+        rois_np = np.random.uniform(size=rois_shape).astype("float32") * in_size
+        rois_np[:, 0] = np.random.randint(low=0, high=batch, size=num_roi)
+        b_np = tvm.topi.testing.roi_align_nchw_python(
+            a_np,
+            rois_np,
+            pooled_size=pooled_size,
+            spatial_scale=spatial_scale,
+            sample_ratio=sample_ratio,
+        )
 
         return a_np, rois_np, b_np
 
@@ -344,9 +464,13 @@ def verify_roi_align(batch, in_channel, in_size, num_roi, pooled_size, spatial_s
 
         with tvm.target.Target(device):
             fcompute, fschedule = tvm.topi.testing.dispatch(device, _roi_align_implement)
-            b = fcompute(a, rois, pooled_size=pooled_size,
-                         spatial_scale=spatial_scale,
-                         sample_ratio=sample_ratio)
+            b = fcompute(
+                a,
+                rois,
+                pooled_size=pooled_size,
+                spatial_scale=spatial_scale,
+                sample_ratio=sample_ratio,
+            )
             s = fschedule(b)
 
         tvm_a = tvm.nd.array(a_np, ctx)
@@ -356,7 +480,7 @@ def verify_roi_align(batch, in_channel, in_size, num_roi, pooled_size, spatial_s
         f(tvm_a, tvm_rois, tvm_b)
         tvm.testing.assert_allclose(tvm_b.asnumpy(), b_np, rtol=1e-3)
 
-    for device in ['llvm', 'cuda', 'opencl']:
+    for device in ["llvm", "cuda", "opencl"]:
         check_device(device)
 
 
@@ -377,12 +501,13 @@ def verify_roi_pool(batch, in_channel, in_size, num_roi, pooled_size, spatial_sc
 
     @memoize("topi.tests.test_topi_vision.verify_roi_pool")
     def get_ref_data():
-        a_np = np.random.uniform(size=a_shape).astype('float32')
-        rois_np = np.random.uniform(size=rois_shape).astype('float32') * in_size
-        rois_np[:, 0] = np.random.randint(low = 0, high = batch, size = num_roi).astype('float32')
+        a_np = np.random.uniform(size=a_shape).astype("float32")
+        rois_np = np.random.uniform(size=rois_shape).astype("float32") * in_size
+        rois_np[:, 0] = np.random.randint(low=0, high=batch, size=num_roi).astype("float32")
 
-        b_np = tvm.topi.testing.roi_pool_nchw_python(a_np, rois_np, pooled_size=pooled_size,
-                                                 spatial_scale=spatial_scale)
+        b_np = tvm.topi.testing.roi_pool_nchw_python(
+            a_np, rois_np, pooled_size=pooled_size, spatial_scale=spatial_scale
+        )
         return a_np, rois_np, b_np
 
     a_np, rois_np, b_np = get_ref_data()
@@ -395,8 +520,9 @@ def verify_roi_pool(batch, in_channel, in_size, num_roi, pooled_size, spatial_sc
         print("Running on target: %s" % device)
 
         with tvm.target.Target(device):
-            b = topi.vision.rcnn.roi_pool_nchw(a, rois, pooled_size=pooled_size,
-                                                spatial_scale=spatial_scale)
+            b = topi.vision.rcnn.roi_pool_nchw(
+                a, rois, pooled_size=pooled_size, spatial_scale=spatial_scale
+            )
             s_func = tvm.topi.testing.dispatch(device, _roi_pool_schedule)
             s = s_func(b)
 
@@ -407,7 +533,7 @@ def verify_roi_pool(batch, in_channel, in_size, num_roi, pooled_size, spatial_sc
         f(tvm_a, tvm_rois, tvm_b)
         tvm.testing.assert_allclose(tvm_b.asnumpy(), b_np, rtol=1e-4)
 
-    for device in ['cuda', 'llvm']:
+    for device in ["cuda", "llvm"]:
         check_device(device)
 
 
@@ -440,48 +566,66 @@ def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs):
             f(tvm_cls_prob, tvm_bbox_pred, tvm_im_info, tvm_out)
             tvm.testing.assert_allclose(tvm_out.asnumpy(), np_out, rtol=1e-4)
 
-    for device in ['llvm', 'cuda']:
+    for device in ["llvm", "cuda"]:
         check_device(device)
 
 
 @tvm.testing.uses_gpu
 def test_proposal():
-    attrs = {'scales': (0.5,),'ratios': (0.5,),
-        'feature_stride': 16,
-        'iou_loss': False,
-        'rpn_min_size': 16,
-        'threshold': 0.7,
-        'rpn_pre_nms_top_n': 200,
-        'rpn_post_nms_top_n': 4,
+    attrs = {
+        "scales": (0.5,),
+        "ratios": (0.5,),
+        "feature_stride": 16,
+        "iou_loss": False,
+        "rpn_min_size": 16,
+        "threshold": 0.7,
+        "rpn_pre_nms_top_n": 200,
+        "rpn_post_nms_top_n": 4,
     }
-    np_cls_prob = np.array([[
-        [[0.3, 0.6, 0.2], [0.4, 0.7, 0.5], [0.1, 0.4, 0.3]],
-        [[0.7, 0.5, 0.3], [0.6, 0.4, 0.8], [0.9, 0.2, 0.5]]
-    ]], dtype='float32')
-    np_bbox_pred = np.array([[
-        [[0.5, 1.0, 0.6], [0.8,  1.2, 2.0], [0.9, 1.0, 0.8]],
-        [[0.5, 1.0, 0.7], [0.8,  1.2, 1.6], [2.1, 1.5, 0.7]],
-        [[1.0, 0.5, 0.7], [1.5,  0.9, 1.6], [1.4, 1.5, 0.8]],
-        [[1.0, 0.5, 0.6], [1.5,  0.9, 2.0], [1.8, 1.0, 0.9]],
-    ]], dtype='float32')
-    np_im_info = np.array([[48., 48., 1.]], dtype='float32')
-    np_out = np.array([
-        [0., 0., 2.8451548,28.38012, 18.154846],
-        [0., 0., 15.354933, 41.96971, 41.245064],
-        [0., 18.019852, 1.0538368, 51.98015, 25.946163],
-        [0., 27.320923, -1.266357, 55., 24.666357]
-    ], dtype='float32')
+    np_cls_prob = np.array(
+        [
+            [
+                [[0.3, 0.6, 0.2], [0.4, 0.7, 0.5], [0.1, 0.4, 0.3]],
+                [[0.7, 0.5, 0.3], [0.6, 0.4, 0.8], [0.9, 0.2, 0.5]],
+            ]
+        ],
+        dtype="float32",
+    )
+    np_bbox_pred = np.array(
+        [
+            [
+                [[0.5, 1.0, 0.6], [0.8, 1.2, 2.0], [0.9, 1.0, 0.8]],
+                [[0.5, 1.0, 0.7], [0.8, 1.2, 1.6], [2.1, 1.5, 0.7]],
+                [[1.0, 0.5, 0.7], [1.5, 0.9, 1.6], [1.4, 1.5, 0.8]],
+                [[1.0, 0.5, 0.6], [1.5, 0.9, 2.0], [1.8, 1.0, 0.9]],
+            ]
+        ],
+        dtype="float32",
+    )
+    np_im_info = np.array([[48.0, 48.0, 1.0]], dtype="float32")
+    np_out = np.array(
+        [
+            [0.0, 0.0, 2.8451548, 28.38012, 18.154846],
+            [0.0, 0.0, 15.354933, 41.96971, 41.245064],
+            [0.0, 18.019852, 1.0538368, 51.98015, 25.946163],
+            [0.0, 27.320923, -1.266357, 55.0, 24.666357],
+        ],
+        dtype="float32",
+    )
 
     verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs)
 
-    np_out = np.array([
-        [ 0., -5.25, -2.5, 21.75, 19.],
-        [ 0., 11.25, -2., 37.25, 18.5],
-        [ 0., 26.849998, -2.3000002, 53.45, 18.6],
-        [ 0., -4.95, 13.799999, 22.25, 35.5]
-    ], dtype='float32')
-
-    attrs['iou_loss'] = True
+    np_out = np.array(
+        [
+            [0.0, -5.25, -2.5, 21.75, 19.0],
+            [0.0, 11.25, -2.0, 37.25, 18.5],
+            [0.0, 26.849998, -2.3000002, 53.45, 18.6],
+            [0.0, -4.95, 13.799999, 22.25, 35.5],
+        ],
+        dtype="float32",
+    )
+
+    attrs["iou_loss"] = True
     verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs)
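
The expected arrays in test_non_max_suppression above (np_result, np_indices_result) encode which boxes survive suppression at iou_threshold=0.7. As a self-contained reminder of the overlap test involved (a generic IoU helper, not the topi kernel), with boxes laid out as [x1, y1, x2, y2], i.e. the coordinate columns of np_data after coord_start:

    def iou(box_a, box_b):
        # Intersection-over-union of two axis-aligned boxes.
        xx1, yy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
        xx2, yy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
        inter = max(0.0, xx2 - xx1) * max(0.0, yy2 - yy1)
        area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
        area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
        return inter / (area_a + area_b - inter)

    # Boxes 3 and 0 of the first np_data (the two highest scores with top_k=2)
    # do not overlap at all, so both are kept in np_result.
    print(iou([35, 61, 52, 79], [1, 20, 25, 45]))  # 0.0
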
 
 
index 39d5d61..65c8ec3 100644 (file)
@@ -17,6 +17,7 @@
 import tvm
 from tvm import te
 
+
 class CanonicalChecker:
     def __init__(self):
         self.analyzer = tvm.arith.Analyzer()
@@ -24,16 +25,16 @@ class CanonicalChecker:
     def verify(self, data, expected):
         res = self.analyzer.canonical_simplify(data)
         expected = tvm.runtime.convert(expected)
-        assert tvm.ir.structural_equal(
-            res, expected), "\ndata={}\nres={}\nexpected={}".format(data, res, expected)
+        assert tvm.ir.structural_equal(res, expected), "\ndata={}\nres={}\nexpected={}".format(
+            data, res, expected
+        )
 
 
 def test_mul_sum_simplify():
     ck = CanonicalChecker()
     x, y, z = te.var("x"), te.var("y"), te.var("z")
 
-    ck.verify(2 + (3 * x + z + y + 1) * 4 + x,
-              x * 13 + z * 4 + y * 4 +6)
+    ck.verify(2 + (3 * x + z + y + 1) * 4 + x, x * 13 + z * 4 + y * 4 + 6)
     ck.verify(x * 3 - 4 * x + 1, 1 - x)
     ck.verify(y + x * 3 - 5 * x + 1 + y, y * 2 + 1 - x * 2)
     tdiv = tvm.tir.truncdiv
@@ -60,7 +61,7 @@ def test_split_index_simplify():
     tmod = tvm.tir.truncmod
 
     # split div const
-    ck.verify(tdiv(x, 3) *3 + tmod(x, 3), x)
+    ck.verify(tdiv(x, 3) * 3 + tmod(x, 3), x)
     ck.verify(tdiv(x, 6) * 6 + tmod(tdiv(x, 3), 2) * 3 + tmod(x, 3), x)
     ck.verify(tdiv(tdiv(tmod(x, 16), 2) * 2, 4), tdiv(tmod(x, 16), 4))
     ck.verify(tdiv(tmod(x, 2), 8), 0)
@@ -85,7 +86,7 @@ def test_split_index_simplify():
     # floordiv
     fld = tvm.te.floordiv
     flm = tvm.te.floormod
-    ck.verify(fld(x*5, 2), fld(x*5, 2))
+    ck.verify(fld(x * 5, 2), fld(x * 5, 2))
     ck.verify(fld(x, 3) * 3 + flm(x, 3), x)
     ck.verify(fld(x, 6) * 6 + flm(fld(x, 3), 2) * 3 + flm(x, 3), x)
     ck.verify(fld(fld(flm(x, 16), 2) * 2, 4), fld(flm(x, 16), 4))
@@ -94,7 +95,7 @@ def test_split_index_simplify():
     ck.verify(fld(fld(flm(x, 16), 2) * 2, 6), fld(flm(x, 16), 6))
 
     # cannot simplify mixed case, unless we canonicalize into one mode.
-    ck.verify(tdiv(x,6) * 2 + tmod(fld(x,3), 2), tdiv(x,6) * 2 + tmod(fld(x,3), 2))
+    ck.verify(tdiv(x, 6) * 2 + tmod(fld(x, 3), 2), tdiv(x, 6) * 2 + tmod(fld(x, 3), 2))
 
 
 def test_div_simplify():
@@ -103,7 +104,7 @@ def test_div_simplify():
     tdiv = tvm.tir.truncdiv
 
     # trunc div
-    ck.verify(tdiv(16+48*x,16), x*3 + 1)
+    ck.verify(tdiv(16 + 48 * x, 16), x * 3 + 1)
     # (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0
     # (17+48*x)/16 != 1+3*x
     ck.verify(tdiv(17 + 48 * x, 16), tdiv(x * 48 + 17, 16))
@@ -116,18 +117,17 @@ def test_div_simplify():
     # floordiv
     fld = tvm.te.floordiv
     ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 10000), True)
-    ck.verify(fld(16+48*x, 16), x*3 + 1)
-    ck.verify(fld(17+48*x, 16), x * 3 + 1)
-    ck.verify(fld(17+47*x, 16), fld(x * 47 + 17, 16))
+    ck.verify(fld(16 + 48 * x, 16), x * 3 + 1)
+    ck.verify(fld(17 + 48 * x, 16), x * 3 + 1)
+    ck.verify(fld(17 + 47 * x, 16), fld(x * 47 + 17, 16))
 
 
 def test_floormod_simplify():
     ck = CanonicalChecker()
     flm = tvm.te.floormod
     x, y = te.var("x"), te.var("y")
-    ck.verify(flm(flm((x*4) + y - 466036, 24528) - 24512,  16),
-              flm((x*4) + y + 12, 16))
-    ck.verify(flm(flm((x*4), 16), 8), flm(x, 2) * 4)
+    ck.verify(flm(flm((x * 4) + y - 466036, 24528) - 24512, 16), flm((x * 4) + y + 12, 16))
+    ck.verify(flm(flm((x * 4), 16), 8), flm(x, 2) * 4)
 
 
 def test_canonical_mixed():
@@ -136,48 +136,51 @@ def test_canonical_mixed():
     z = tvm.tir.const(3, "int32")
     tdiv = tvm.tir.truncdiv
     tmod = tvm.tir.truncmod
-    ck.verify(tdiv(x, (z*z)) - tdiv(x, (z*z)), 0)
-    ck.verify(tdiv(x, (z+z)) - tdiv(x, (z+z)), 0)
+    ck.verify(tdiv(x, (z * z)) - tdiv(x, (z * z)), 0)
+    ck.verify(tdiv(x, (z + z)) - tdiv(x, (z + z)), 0)
     ck.verify(x - 2 < 3, x < 5)
     ck.verify(tvm.te.max(x, 1) - tvm.te.max(x, 1), 0)
     ck.verify(tvm.te.min(x, 1) - tvm.te.min(x, 1), 0)
     ck.verify(x * x - x * x, 0)
 
     fld = tvm.te.floordiv
-    ck.verify(fld(x, (z*z)) - fld(x, (z*z)), 0)
-    ck.verify(fld(x, (z+z)) - fld(x, (z+z)), 0)
+    ck.verify(fld(x, (z * z)) - fld(x, (z * z)), 0)
+    ck.verify(fld(x, (z + z)) - fld(x, (z + z)), 0)
 
 
 def test_reduce_combiner_simplify():
     ck = CanonicalChecker()
-    dummy = te.var('dummy')
+    dummy = te.var("dummy")
     comm_reducer = te.comm_reducer
-    prod = comm_reducer(lambda x, y: x*y, lambda t0: tvm.tir.const(1, t0))
+    prod = comm_reducer(lambda x, y: x * y, lambda t0: tvm.tir.const(1, t0))
 
     sum_or_prod = comm_reducer(
-        lambda x, y: tvm.tir.Select(dummy < 0,
-                                     x + y, x*y),
-        lambda t0: tvm.tir.Select(dummy < 0,
-                                   tvm.tir.const(0, t0), tvm.tir.const(1, t0)))
+        lambda x, y: tvm.tir.Select(dummy < 0, x + y, x * y),
+        lambda t0: tvm.tir.Select(dummy < 0, tvm.tir.const(0, t0), tvm.tir.const(1, t0)),
+    )
     sum_and_prod = comm_reducer(
-        lambda x, y: (x[0] + y[0],
-                      x[1]*y[1]),
-        lambda t0, t1: (tvm.tir.const(0, t0),
-                        tvm.tir.const(5, t0) - tvm.tir.const(4, t0)))
+        lambda x, y: (x[0] + y[0], x[1] * y[1]),
+        lambda t0, t1: (tvm.tir.const(0, t0), tvm.tir.const(5, t0) - tvm.tir.const(4, t0)),
+    )
     some_reducer1 = comm_reducer(
-        lambda x, y: (x[0] + y[0],
-                      x[0] + y[0] + x[1] + y[1],
-                      x[0]*y[2] + y[0]*x[2],
-                      x[1] + y[2],
-                    4.0),
-        lambda t0, t1, t2, t3, t4: (tvm.tir.const(0, t0),
-                                    tvm.tir.const(1, t1),
-                                    tvm.tir.const(2, t2),
-                                    tvm.tir.const(3, t3),
-                                    tvm.tir.const(4, t4)))
+        lambda x, y: (
+            x[0] + y[0],
+            x[0] + y[0] + x[1] + y[1],
+            x[0] * y[2] + y[0] * x[2],
+            x[1] + y[2],
+            4.0,
+        ),
+        lambda t0, t1, t2, t3, t4: (
+            tvm.tir.const(0, t0),
+            tvm.tir.const(1, t1),
+            tvm.tir.const(2, t2),
+            tvm.tir.const(3, t3),
+            tvm.tir.const(4, t4),
+        ),
+    )
 
     k = te.reduce_axis((0, 10), name="k")
-    A = te.placeholder((10,), name='A')
+    A = te.placeholder((10,), name="A")
     # Test that SimplifyCombiner makes use of vranges
     ck.analyzer.update(dummy, tvm.arith.ConstIntBound(-10, -4))
     ck.verify(sum_or_prod(A[k], k), te.sum(A[k], k))
@@ -186,19 +189,22 @@ def test_reduce_combiner_simplify():
     ck.verify(sum_or_prod(A[k], k), prod(A[k], k))
     ck.verify(sum_or_prod(A[k], k, init=1), prod(A[k], k, init=1))
     ck.analyzer.update(dummy, tvm.arith.ConstIntBound(-10, 100), True)
-    ck.verify(sum_and_prod((A[k], A[10-k]), k)[0], te.sum(A[k], k))
-    ck.verify(sum_and_prod((A[k], A[10-k]), k)[1], prod(A[10-k], k))
-
-    reference_simplified_sources = [[A[0]],
-                                    [A[0], A[1]],
-                                    [A[0], A[2]],
-                                    [A[0], A[1], A[2], A[3]],
-                                    [A[4]]]
+    ck.verify(sum_and_prod((A[k], A[10 - k]), k)[0], te.sum(A[k], k))
+    ck.verify(sum_and_prod((A[k], A[10 - k]), k)[1], prod(A[10 - k], k))
+
+    reference_simplified_sources = [
+        [A[0]],
+        [A[0], A[1]],
+        [A[0], A[2]],
+        [A[0], A[1], A[2], A[3]],
+        [A[4]],
+    ]
     for j in range(5):
         # Here we use the j-th component of the result, so only it and the components it
         # depends on are left.
         simplified = ck.analyzer.canonical_simplify(
-            some_reducer1((A[0], A[1], A[2], A[3], A[4]), k)[j])
+            some_reducer1((A[0], A[1], A[2], A[3], A[4]), k)[j]
+        )
 
         # Check that the remaining components are the expected ones.
         for lhs, rhs in zip(simplified.source, reference_simplified_sources[j]):
@@ -207,21 +213,21 @@ def test_reduce_combiner_simplify():
     # Test that components with side effects are not removed
     dummy = tvm.ir.GlobalVar("dummy")
     side_effect = lambda *xs: tvm.tir.Call("int32", dummy, xs)
-    ck.verify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0],
-             sum_and_prod((A[k], side_effect(A[10-k])), k)[0])
-    ck.verify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0],
-              te.sum(side_effect(A[k]), k))
+    ck.verify(
+        sum_and_prod((A[k], side_effect(A[10 - k])), k)[0],
+        sum_and_prod((A[k], side_effect(A[10 - k])), k)[0],
+    )
+    ck.verify(sum_and_prod((side_effect(A[k]), A[10 - k]), k)[0], te.sum(side_effect(A[k]), k))
 
 
 def test_reduce_simplify():
     ck = CanonicalChecker()
     k = te.reduce_axis((0, 10), name="k")
     j = te.reduce_axis((-5, 3), name="j")
-    A = te.placeholder((10,), name='A')
-    ck.verify(te.sum(tvm.tir.Select(k + j < 12, k + j, 0), [k, j]),
-              te.sum(k + j, [k, j]))
+    A = te.placeholder((10,), name="A")
+    ck.verify(te.sum(tvm.tir.Select(k + j < 12, k + j, 0), [k, j]), te.sum(k + j, [k, j]))
     ck.verify(te.sum(A[3], []), A[3])
-    ck.verify(te.sum(A[3], [], where=k > 12, init=1.0), tvm.tir.const(1.0, dtype='float32'))
+    ck.verify(te.sum(A[3], [], where=k > 12, init=1.0), tvm.tir.const(1.0, dtype="float32"))
     # The rule below is not typical, removed for now
     ck.verify(te.sum(te.div(k, 10), k), te.sum(tvm.tir.const(0, "int32"), k))
 
@@ -233,20 +239,32 @@ def test_simplify_if_then_else():
     tdiv = tvm.tir.truncdiv
     tmod = tvm.tir.truncmod
     # simplification that takes condition into account.
-    res = tvm.tir.if_then_else((x * 4 + y) >= 466036,
-                           tvm.tir.if_then_else(24512 <= tmod(((x*4) + y) - 466036, 24528),
-                                            tmod(tmod(((x*4) + y)  - 466036, 24528) -24512, 16),
-                                            x), y)
-
-    res2 = tvm.tir.if_then_else((x * 4) >= 466036 - y,
-                           tvm.tir.if_then_else(24512 <= tmod(((x*4) + y) - 466036, 24528),
-                                            tmod(tmod(((x*4) + y)  - 466036, 24528) -24512, 16),
-                                            x), y)
+    res = tvm.tir.if_then_else(
+        (x * 4 + y) >= 466036,
+        tvm.tir.if_then_else(
+            24512 <= tmod(((x * 4) + y) - 466036, 24528),
+            tmod(tmod(((x * 4) + y) - 466036, 24528) - 24512, 16),
+            x,
+        ),
+        y,
+    )
+
+    res2 = tvm.tir.if_then_else(
+        (x * 4) >= 466036 - y,
+        tvm.tir.if_then_else(
+            24512 <= tmod(((x * 4) + y) - 466036, 24528),
+            tmod(tmod(((x * 4) + y) - 466036, 24528) - 24512, 16),
+            x,
+        ),
+        y,
+    )
     expected = tvm.tir.if_then_else(
         tvm.tir.LE(466036, (x * 4 + y)),
-        tvm.tir.if_then_else(tvm.tir.LE(24512, tmod(((x*4) + y) - 4, 24528)),
-                                     tmod(((x*4) + y)  - 4, 16),
-                         x), y)
+        tvm.tir.if_then_else(
+            tvm.tir.LE(24512, tmod(((x * 4) + y) - 4, 24528)), tmod(((x * 4) + y) - 4, 16), x
+        ),
+        y,
+    )
     ck.verify(res, expected)
     ck.verify(res2, expected)
     # can only simplify if condition
@@ -254,13 +272,11 @@ def test_simplify_if_then_else():
     expected = tvm.tir.Select(tvm.tir.all(x >= -1, y >= 0), tmod(x + y + 1, 3), tmod(x + 100, 3))
     ck.verify(res, ck.analyzer.canonical_simplify(expected))
 
-    res = tvm.tir.Select(x >= 10,
-                          tvm.tir.if_then_else(tdiv(x, 3) > 2, x, 0), 0)
+    res = tvm.tir.Select(x >= 10, tvm.tir.if_then_else(tdiv(x, 3) > 2, x, 0), 0)
     expected = tvm.tir.Select(x >= 10, x, 0)
     ck.verify(res, ck.analyzer.canonical_simplify(expected))
 
-    res = tvm.tir.Select(x >= 10,
-                          tvm.tir.if_then_else(tdiv(x, 3) < 2, x, 0), 0)
+    res = tvm.tir.Select(x >= 10, tvm.tir.if_then_else(tdiv(x, 3) < 2, x, 0), 0)
     ck.verify(res, 0)
 
 
@@ -270,18 +286,28 @@ def test_complex_cases():
     y = te.var("y")
     tdiv = tvm.tir.truncdiv
     tmod = tvm.tir.truncmod
-    res2 = (tdiv(tdiv(tmod(x*128 + y, 1296),36)*2 + 1,2)*36 +
-            tdiv(tmod((x*128) + y, 36)*2 + 1,2)
-            - tmod((x*128) + y, 1296) + 1)
+    res2 = (
+        tdiv(tdiv(tmod(x * 128 + y, 1296), 36) * 2 + 1, 2) * 36
+        + tdiv(tmod((x * 128) + y, 36) * 2 + 1, 2)
+        - tmod((x * 128) + y, 1296)
+        + 1
+    )
     ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 5))
     ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 127))
     ck.verify(res2, 1)
 
     ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1024), True)
-    res3 = (tdiv(x*1024 + y,65536) + tdiv(tmod(x*1024 + y, 65536),256)
-            + tdiv(tmod(x*1024 + y, 256),16) + tmod(x*1024 + y, 16) - tdiv(y,256) -
-            tdiv(tmod(y, 256),16) - tmod(y, 16) - (x*4))
-    ck.verify(res3, tdiv((x*1024) + y, 256) - tdiv(y,256) - (x*4))
+    res3 = (
+        tdiv(x * 1024 + y, 65536)
+        + tdiv(tmod(x * 1024 + y, 65536), 256)
+        + tdiv(tmod(x * 1024 + y, 256), 16)
+        + tmod(x * 1024 + y, 16)
+        - tdiv(y, 256)
+        - tdiv(tmod(y, 256), 16)
+        - tmod(y, 16)
+        - (x * 4)
+    )
+    ck.verify(res3, tdiv((x * 1024) + y, 256) - tdiv(y, 256) - (x * 4))
 
 
 if __name__ == "__main__":
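
The comment in test_div_simplify above (that (17+48*x)/16 is not simplifiable for arbitrary x) comes down to truncating versus floor division; a minimal numeric check without TVM:

    def tdiv(a, b):
        # C-style division, rounding toward zero, like tvm.tir.truncdiv on ints.
        return int(a / b)

    x = -1
    # Truncating division disagrees with the folded form 3*x + 1 once the
    # numerator goes negative, while floor division still matches it.
    print(tdiv(17 + 48 * x, 16), (17 + 48 * x) // 16, 3 * x + 1)  # -1 -2 -2
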
index 9ead0d4..badbcbc 100644 (file)
@@ -17,6 +17,7 @@
 import tvm
 from tvm import te
 
+
 def test_dtype_bound():
     analyzer = tvm.arith.Analyzer()
 
@@ -44,8 +45,7 @@ def test_cast_bound():
     assert bd.min_value == 0
     assert bd.max_value == 2
 
-    bd = analyzer.const_int_bound(
-        tmod(x, 3).astype("float32").astype("int32"))
+    bd = analyzer.const_int_bound(tmod(x, 3).astype("float32").astype("int32"))
     assert bd.min_value == -2
     assert bd.max_value == 2
 
@@ -240,8 +240,7 @@ def test_select_bound():
     analyzer.update(x, tvm.arith.ConstIntBound(-9, 11))
     analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
 
-    bd = analyzer.const_int_bound(
-        tvm.tir.Select(x > 1, (y < 0).astype("int32"), y + 1))
+    bd = analyzer.const_int_bound(tvm.tir.Select(x > 1, (y < 0).astype("int32"), y + 1))
     assert bd.min_value == 0
     assert bd.max_value == 11
 
index 372f0e9..d72a0e0 100644 (file)
@@ -19,10 +19,10 @@ from tvm import te
 
 
 def test_deduce():
-    a = te.var('a')
-    b = te.var('b')
-    c = te.var('c')
-    d = te.var('d')
+    a = te.var("a")
+    b = te.var("b")
+    c = te.var("c")
+    d = te.var("d")
 
     b_s = tvm.arith.IntervalSet(2, 3)
     c_s = tvm.arith.IntervalSet(10, 15)
@@ -31,51 +31,48 @@ def test_deduce():
 
     fdiv = tvm.te.floordiv
 
-    e0 = (-b)*a+c-d
-    res0 = tvm.arith.deduce_bound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
-    ans0 = fdiv(d - c, b*-1)
+    e0 = (-b) * a + c - d
+    res0 = tvm.arith.deduce_bound(a, e0 >= 0, {b: b_s, c: c_s, d: d_s}, {})
+    ans0 = fdiv(d - c, b * -1)
     tvm.testing.assert_prim_expr_equal(res0.max_value, ans0)
 
     # expression containing variable a is on rhs
     res0 = tvm.arith.deduce_bound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
     tvm.testing.assert_prim_expr_equal(res0.max_value, ans0)
 
-    e0 = d*a+c-d
-    res0 = tvm.arith.deduce_bound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
-    ans0 = fdiv(d-c, d)
+    e0 = d * a + c - d
+    res0 = tvm.arith.deduce_bound(a, e0 >= 0, {b: b_s, c: c_s, d: d_s}, {})
+    ans0 = fdiv(d - c, d)
     tvm.testing.assert_prim_expr_equal(res0.max_value, ans0)
 
     # expression containing variable a is on rhs
     res0 = tvm.arith.deduce_bound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
     tvm.testing.assert_prim_expr_equal(res0.max_value, ans0)
 
-
-    e1 = (a*4+b < c)
+    e1 = a * 4 + b < c
     res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
-    ans1 = fdiv(c-1-b, 4)
+    ans1 = fdiv(c - 1 - b, 4)
     tvm.testing.assert_prim_expr_equal(res1.max_value, ans1)
 
-
     # expression containing variable a is on rhs
-    e1 = (c > a*4+b)
+    e1 = c > a * 4 + b
     res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
     tvm.testing.assert_prim_expr_equal(res1.max_value, ans1)
 
-
-    e2 = (tvm.te.max(5, a * 4) < 0)
+    e2 = tvm.te.max(5, a * 4) < 0
     res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
     assert str(res2.max_value) == "neg_inf: handle"
     assert str(res2.min_value) == "pos_inf: handle"
 
     # expression containing variable a is on rhs
-    e2 = (zero < tvm.te.max(5, a * 4))
+    e2 = zero < tvm.te.max(5, a * 4)
     res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
     assert str(res2.max_value) == "neg_inf: handle"
     assert str(res2.min_value) == "pos_inf: handle"
 
-    e3 = (-b)+a*c-d
-    res3 = tvm.arith.deduce_bound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
-    ans3 = fdiv(2,c)+1
+    e3 = (-b) + a * c - d
+    res3 = tvm.arith.deduce_bound(a, e3 >= 0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
+    ans3 = fdiv(2, c) + 1
     tvm.testing.assert_prim_expr_equal(res3.min_value, ans3)
 
     res3 = tvm.arith.deduce_bound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
@@ -97,8 +94,8 @@ def test_deduce():
     tvm.testing.assert_prim_expr_equal(res6.min_value, 10)
 
     # Add, Sub in `EQ`
-    e4 = ((a - c) == (b + d))
-    ans4 = (b + d + c)
+    e4 = (a - c) == (b + d)
+    ans4 = b + d + c
     res7 = tvm.arith.deduce_bound(a, e4, {b: b_s, c: c_s, d: d_s}, {})
     tvm.testing.assert_prim_expr_equal(res7.max_value, ans4)
     tvm.testing.assert_prim_expr_equal(res7.min_value, ans4)
@@ -109,47 +106,50 @@ def test_deduce():
     tvm.testing.assert_prim_expr_equal(res8.min_value, -2)
 
     # Unsatisfiable Mul in `EQ`
-    e5 = (4 * a == b)
+    e5 = 4 * a == b
     res9 = tvm.arith.deduce_bound(a, e5, {b: b_s}, {})
     assert str(res9.max_value) == "neg_inf: handle"
     assert str(res9.min_value) == "pos_inf: handle"
 
     # Unsatisfiable Mul in `EQ`
-    res10 = tvm.arith.deduce_bound(a, (b * a == b), {b: b_s}, {})    # simplifier is not able to prove that (b % b == 0)
+    res10 = tvm.arith.deduce_bound(
+        a, (b * a == b), {b: b_s}, {}
+    )  # simplifier is not able to prove that (b % b == 0)
     assert str(res10.max_value) == "neg_inf: handle"
     assert str(res10.min_value) == "pos_inf: handle"
 
 
 def test_check():
-    a = te.var('a')
-    b = te.var('b')
-    c = te.var('c')
-    d = te.var('d')
+    a = te.var("a")
+    b = te.var("b")
+    c = te.var("c")
+    d = te.var("d")
 
     b_s = tvm.arith.IntervalSet(2, 3)
     c_s = tvm.arith.IntervalSet(5, 7)
     d_s = tvm.arith.IntervalSet(-3, -1)
 
     # no compare operator
-    res1 = tvm.arith.deduce_bound(a, a+b, {b: b_s}, {})
+    res1 = tvm.arith.deduce_bound(a, a + b, {b: b_s}, {})
     assert res1.is_nothing()
 
     # multiple compare operators
-    res2 = tvm.arith.deduce_bound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {})
+    res2 = tvm.arith.deduce_bound(a, (a + b > 3).astype(c.dtype) > c, {b: b_s, c: c_s}, {})
     assert res2.is_nothing()
 
     # multiple target variable
-    res2 = tvm.arith.deduce_bound(a, a*2-a>b, {b: b_s}, {})
+    res2 = tvm.arith.deduce_bound(a, a * 2 - a > b, {b: b_s}, {})
     assert res2.is_nothing()
 
+
 def test_deduce_basic():
     def test_basic(a1, a2, coff):
-        a = te.var('a')
-        b = te.var('b')
+        a = te.var("a")
+        b = te.var("b")
         b_s = tvm.arith.IntervalSet(a1, a2)
-        e0 = b + a*coff + 3
+        e0 = b + a * coff + 3
 
-        res1 = tvm.arith.deduce_bound(a, e0<17, {b: b_s}, {b: b_s})
+        res1 = tvm.arith.deduce_bound(a, e0 < 17, {b: b_s}, {b: b_s})
         [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
         tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) < 17, True)
 
@@ -159,12 +159,12 @@ def test_deduce_basic():
         tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) > 17, True)
 
         # expression containing variable a is on rhs
-        res1 = tvm.arith.deduce_bound(a, tvm.tir.const(17, "int32")>= e0, {b: b_s}, {b: b_s})
+        res1 = tvm.arith.deduce_bound(a, tvm.tir.const(17, "int32") >= e0, {b: b_s}, {b: b_s})
         [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
 
         tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) <= 17, True)
 
-        res1 = tvm.arith.deduce_bound(a, e0>=17, {b: b_s}, {b: b_s})
+        res1 = tvm.arith.deduce_bound(a, e0 >= 17, {b: b_s}, {b: b_s})
         [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
         tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) >= 17, True)
 
@@ -175,30 +175,31 @@ def test_deduce_basic():
     test_basic(1, 5, -4)
     test_basic(2, 6, -4)
 
+
 def test_deduce_complex():
     def test_complex(a1, a2, coff):
-        a = te.var('a')
-        b = te.var('b')
+        a = te.var("a")
+        b = te.var("b")
         b_s = tvm.arith.IntervalSet(a1, a2)
-        e0 = (b*3 + a* coff) * 4
+        e0 = (b * 3 + a * coff) * 4
 
-        res1 = tvm.arith.deduce_bound(a, e0<63, {b: b_s}, {b: b_s})
+        res1 = tvm.arith.deduce_bound(a, e0 < 63, {b: b_s}, {b: b_s})
         [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
-        tvm.testing.assert_prim_expr_equal(((x*3 + t* coff) * 4) < 63, True)
+        tvm.testing.assert_prim_expr_equal(((x * 3 + t * coff) * 4) < 63, True)
 
         # expression containing variable a is on rhs
-        res1 = tvm.arith.deduce_bound(a, tvm.tir.const(63, "int32")>= e0, {b: b_s}, {b: b_s})
+        res1 = tvm.arith.deduce_bound(a, tvm.tir.const(63, "int32") >= e0, {b: b_s}, {b: b_s})
         [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
-        tvm.testing.assert_prim_expr_equal(((x*3 + t* coff) * 4) <= 63, True)
+        tvm.testing.assert_prim_expr_equal(((x * 3 + t * coff) * 4) <= 63, True)
 
-        res1 = tvm.arith.deduce_bound(a, e0>63, {b: b_s}, {b: b_s})
+        res1 = tvm.arith.deduce_bound(a, e0 > 63, {b: b_s}, {b: b_s})
         [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
-        tvm.testing.assert_prim_expr_equal(((x*3 + t* coff) * 4) > 63, True)
+        tvm.testing.assert_prim_expr_equal(((x * 3 + t * coff) * 4) > 63, True)
 
         # expression containing variable a is on rhs
         res1 = tvm.arith.deduce_bound(a, tvm.tir.const(63, "int32") <= e0, {b: b_s}, {b: b_s})
         [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
-        tvm.testing.assert_prim_expr_equal(((x*3 + t* coff) * 4) >= 63, True)
+        tvm.testing.assert_prim_expr_equal(((x * 3 + t * coff) * 4) >= 63, True)
 
     test_complex(0, 4, 4)
     test_complex(0, 4, -4)
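
A numeric spot-check of the first deduce_bound case in test_deduce above, taking one sample from each of the b, c, d intervals; this is plain arithmetic, not the TVM API:

    # (-b)*a + c - d >= 0 with b in [2, 3], c in [10, 15], d in [-3, -1]
    # bounds a from above by fdiv(d - c, b * -1).
    b, c, d = 2, 10, -3
    bound = (d - c) // (-b)                        # floordiv(d - c, -b) = 6
    assert (-b) * bound + c - d >= 0               # the bound itself still satisfies e0
    assert not ((-b) * (bound + 1) + c - d >= 0)   # one past it does not
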
index 129237a..4d887d6 100644 (file)
 import tvm
 from tvm import te
 
+
 def test_basic():
     a = te.var("a")
     b = te.var("b")
     c = te.var("c")
-    m = tvm.arith.detect_clip_bound(tvm.tir.all(a * 1 < b * 6,
-                                          a - 1 > 0), [a])
+    m = tvm.arith.detect_clip_bound(tvm.tir.all(a * 1 < b * 6, a - 1 > 0), [a])
     tvm.testing.assert_prim_expr_equal(m[1], b * 6 - 1)
     assert m[0].value == 2
-    m = tvm.arith.detect_clip_bound(tvm.tir.all(a * 1 < b * 6,
-                                          a - 1 > 0), [a, b])
+    m = tvm.arith.detect_clip_bound(tvm.tir.all(a * 1 < b * 6, a - 1 > 0), [a, b])
     assert len(m) == 0
-    m = tvm.arith.detect_clip_bound(tvm.tir.all(a + 10 * c <= 20,
-                                          b - 1 > 0), [a, b])
+    m = tvm.arith.detect_clip_bound(tvm.tir.all(a + 10 * c <= 20, b - 1 > 0), [a, b])
     tvm.testing.assert_prim_expr_equal(m[1], 20 - 10 * c)
     tvm.testing.assert_prim_expr_equal(m[2], 2)
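
For the first detect_clip_bound case above, the asserted bounds follow from integer reasoning: a - 1 > 0 gives a >= 2, and a < b * 6 gives a <= b * 6 - 1. A brute-force check at an arbitrary b:

    b = 4
    sols = [a for a in range(-50, 50) if a * 1 < b * 6 and a - 1 > 0]
    assert sols == list(range(2, 6 * b))   # i.e. 2 .. 6*b - 1 inclusive
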
 
index 82153ab..16b75f7 100644 (file)
@@ -17,6 +17,7 @@
 import tvm
 from tvm import te
 
+
 def test_basic():
     a = te.var("a")
     b = te.var("b")
@@ -24,10 +25,10 @@ def test_basic():
     assert m[0].value == 4
     tvm.testing.assert_prim_expr_equal(m[1], b * 6 + 7)
 
-    m = tvm.arith.detect_linear_equation(a * 4 * (a+1) + b * 6 + 7, [a])
+    m = tvm.arith.detect_linear_equation(a * 4 * (a + 1) + b * 6 + 7, [a])
     assert len(m) == 0
 
-    m = tvm.arith.detect_linear_equation(a * 4  + (a+1) + b * 6 + 7, [a])
+    m = tvm.arith.detect_linear_equation(a * 4 + (a + 1) + b * 6 + 7, [a])
     assert m[0].value == 5
     tvm.testing.assert_prim_expr_equal(m[1], b * 6 + 7 + 1)
 
@@ -41,6 +42,7 @@ def test_basic():
     assert len(m) == 1
     tvm.testing.assert_prim_expr_equal(m[0], b * 7)
 
+
 def test_multivariate():
     v = [te.var("v%d" % i) for i in range(4)]
     b = te.var("b")
@@ -48,28 +50,29 @@ def test_multivariate():
 
     tvm.testing.assert_prim_expr_equal(m[0], b + 5)
 
-    assert(m[1].value == 8)
+    assert m[1].value == 8
 
     m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[2], v)
-    assert(len(m) == 0)
+    assert len(m) == 0
 
     m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[1] + v[3], v)
-    assert(len(m) == 0)
+    assert len(m) == 0
 
     m = tvm.arith.detect_linear_equation(((v[0] * b + v[1]) * 8 + v[2] + 1) * 2, v)
-    assert(m[1].value == 16)
-    assert(m[2].value == 2)
-    assert(m[len(m)-1].value == 2)
+    assert m[1].value == 16
+    assert m[2].value == 2
+    assert m[len(m) - 1].value == 2
 
     m = tvm.arith.detect_linear_equation((v[0] - v[1]), [v[2]])
-    assert(m[0].value == 0)
+    assert m[0].value == 0
 
     tvm.testing.assert_prim_expr_equal(m[1], v[0] - v[1])
 
     m = tvm.arith.detect_linear_equation((v[0] - v[1]), [])
-    assert(len(m) == 1)
+    assert len(m) == 1
     tvm.testing.assert_prim_expr_equal(m[0], v[0] - v[1])
 
+
 if __name__ == "__main__":
     test_basic()
     test_multivariate()
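
The detection results above report a coefficient and a remainder with respect to the listed variables; for the single-variable cases the same pair can be recovered numerically from two samples, e.g. for a * 4 + (a + 1) + b * 6 + 7 at an arbitrary b:

    b = 3

    def f(a):
        return a * 4 + (a + 1) + b * 6 + 7

    coeff, rest = f(1) - f(0), f(0)        # slope in a, and the a-free remainder
    assert (coeff, rest) == (5, b * 6 + 7 + 1)
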
index 1033721..ca5df4a 100644 (file)
 import tvm
 from tvm import te
 
+
 def test_domain_touched():
-    i = te.var('i')
-    j = te.var('j')
+    i = te.var("i")
+    j = te.var("j")
     n = tvm.runtime.convert(100)
-    m = te.var('m')
-
-    a = tvm.tir.decl_buffer((n, m), name='a')
-    b = tvm.tir.decl_buffer((n, m), name='b')
+    m = te.var("m")
 
+    a = tvm.tir.decl_buffer((n, m), name="a")
+    b = tvm.tir.decl_buffer((n, m), name="b")
 
     ir = tvm.tir.For(
-            i, 0, n, 0, 0,
-            tvm.tir.For(j, 0, m, 0, 0,
-                tvm.tir.BufferStore(
-                    a,
-                    tvm.tir.BufferLoad(b, [i - 1, j + 1]) +
-                    tvm.tir.BufferLoad(a, [i - 1, j - 1]),
-                    [i, j]
-                )
-            )
+        i,
+        0,
+        n,
+        0,
+        0,
+        tvm.tir.For(
+            j,
+            0,
+            m,
+            0,
+            0,
+            tvm.tir.BufferStore(
+                a,
+                tvm.tir.BufferLoad(b, [i - 1, j + 1]) + tvm.tir.BufferLoad(a, [i - 1, j - 1]),
+                [i, j],
+            ),
+        ),
     )
 
     a_domain_r = tvm.arith._ffi_api.DomainTouched(ir, a, True, False)
@@ -44,20 +52,20 @@ def test_domain_touched():
     assert a_domain_r[0].min.value == -1
     assert a_domain_r[0].extent.value == 100
     assert a_domain_r[1].min.value == -1
-    assert a_domain_r[1].extent.name == 'm'
+    assert a_domain_r[1].extent.name == "m"
 
     a_domain_w = tvm.arith._ffi_api.DomainTouched(ir, a, False, True)
     assert a_domain_w[0].min.value == 0
     assert a_domain_w[0].extent.value == 100
     assert a_domain_w[1].min.value == 0
-    assert a_domain_w[1].extent.name == 'm'
+    assert a_domain_w[1].extent.name == "m"
 
-    a_domain_rw= tvm.arith._ffi_api.DomainTouched(ir, a, True, True)
+    a_domain_rw = tvm.arith._ffi_api.DomainTouched(ir, a, True, True)
     assert a_domain_rw[0].min.value == -1
     assert a_domain_rw[0].extent.value == 101
     assert a_domain_rw[1].min.value == -1
     assert isinstance(a_domain_rw[1].extent, tvm.tir.Add)
-    assert a_domain_rw[1].extent.a.name == 'm'
+    assert a_domain_rw[1].extent.a.name == "m"
     assert a_domain_rw[1].extent.b.value == 1
 
     b_domain_r = tvm.arith._ffi_api.DomainTouched(ir, b, True, False)
@@ -65,11 +73,12 @@ def test_domain_touched():
     assert b_domain_r[0].min.value == -1
     assert b_domain_r[0].extent.value == 100
     assert b_domain_r[1].min.value == 1
-    assert b_domain_r[1].extent.name == 'm'
+    assert b_domain_r[1].extent.name == "m"
 
     b_domain_w = tvm.arith._ffi_api.DomainTouched(ir, b, False, True)
     assert isinstance(b_domain_w, tvm.container.Array)
     assert len(b_domain_w) == 0
 
+
 if __name__ == "__main__":
     test_domain_touched()
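
The touched-domain assertions above can be reproduced by brute force for a concrete m. The snippet only models accesses to buffer `a` (reads at a[i-1, j-1], writes at a[i, j]), matching the loop nest the test builds:

    m = 7
    reads = {(i - 1, j - 1) for i in range(100) for j in range(m)}
    writes = {(i, j) for i in range(100) for j in range(m)}

    def span(axis, pts):
        # (min, extent) of the touched indices along one axis.
        lo = min(p[axis] for p in pts)
        hi = max(p[axis] for p in pts)
        return lo, hi - lo + 1

    print(span(0, reads), span(1, reads))                     # (-1, 100) (-1, 7)
    print(span(0, reads | writes), span(1, reads | writes))   # (-1, 101) (-1, 8), i.e. extent m + 1
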
index 5e8c947..5c4cc94 100644 (file)
@@ -24,14 +24,18 @@ class IntSetChecker:
 
     def verify(self, data, dmap, expected):
         res = self.analyzer.int_set(data, dmap)
+
         def err_msg():
             return "\ndata={}\ndmap={}\nres={}\nexpected={}".format(data, dmap, res, expected)
+
         def equal(x, y):
             res = self.analyzer.canonical_simplify(x - y)
             return tvm.tir.analysis.expr_deep_equal(res, 0)
+
         assert equal(res.min_value, expected[0]), err_msg()
         assert equal(res.max_value, expected[1]), err_msg()
 
+
 def test_basic():
     s = tvm.arith.IntervalSet(2, 3)
     assert s.min_value.value == 2
@@ -54,13 +58,10 @@ def test_vector():
 def test_add_sub():
     ck = IntSetChecker()
     x, y = te.var("x"), te.var("y")
-    ck.verify(x + y, {x : tvm.arith.IntervalSet(0, 10)}, (y, 10 + y))
-    ck.verify(x + y,
-              {x : tvm.arith.IntervalSet(0, 10), y : tvm.arith.IntervalSet(1, 11)},
-              (1, 21))
-    ck.verify(x - y,
-              {x : tvm.arith.IntervalSet(0, 10), y : tvm.arith.IntervalSet(1, 11)},
-              (-11, 9))
+    ck.verify(x + y, {x: tvm.arith.IntervalSet(0, 10)}, (y, 10 + y))
+    ck.verify(x + y, {x: tvm.arith.IntervalSet(0, 10), y: tvm.arith.IntervalSet(1, 11)}, (1, 21))
+    ck.verify(x - y, {x: tvm.arith.IntervalSet(0, 10), y: tvm.arith.IntervalSet(1, 11)}, (-11, 9))
+
 
 def test_mul_div():
     ck = IntSetChecker()
@@ -68,16 +69,16 @@ def test_mul_div():
 
     tdiv = tvm.tir.truncdiv
     ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
-    ck.verify(x * y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 * y))
-    ck.verify(x * 2, {x : tvm.arith.IntervalSet(1, 10)}, (2, 20))
-    ck.verify(x * -2, {x : tvm.arith.IntervalSet(1, 10)}, (-20, -2))
+    ck.verify(x * y, {x: tvm.arith.IntervalSet(0, 10)}, (0, 10 * y))
+    ck.verify(x * 2, {x: tvm.arith.IntervalSet(1, 10)}, (2, 20))
+    ck.verify(x * -2, {x: tvm.arith.IntervalSet(1, 10)}, (-20, -2))
 
-    ck.verify(tdiv(x, y), {x : tvm.arith.IntervalSet(0, 10)}, (0, tdiv(10, y)))
-    ck.verify(tdiv(x, 2), {x : tvm.arith.IntervalSet(1, 10)}, (0, 5))
+    ck.verify(tdiv(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, tdiv(10, y)))
+    ck.verify(tdiv(x, 2), {x: tvm.arith.IntervalSet(1, 10)}, (0, 5))
 
     fld = tvm.te.floordiv
-    ck.verify(fld(x, y), {x : tvm.arith.IntervalSet(0, 10)}, (0, fld(10, y)))
-    ck.verify(fld(x, 2), {x : tvm.arith.IntervalSet(-1, 10)}, (-1, 5))
+    ck.verify(fld(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, fld(10, y)))
+    ck.verify(fld(x, 2), {x: tvm.arith.IntervalSet(-1, 10)}, (-1, 5))
 
 
 def test_mod():
@@ -85,32 +86,33 @@ def test_mod():
     x, y = te.var("x"), te.var("y")
     tmod = tvm.tir.truncmod
     ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
-    ck.verify(tmod(x, y), {x : tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
-    ck.verify(tmod(x, 10), {x : tvm.arith.IntervalSet(1, 10)}, (0, 9))
+    ck.verify(tmod(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
+    ck.verify(tmod(x, 10), {x: tvm.arith.IntervalSet(1, 10)}, (0, 9))
 
     flm = tvm.te.floormod
-    ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9))
-    ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 5)}, (3, 5))
-    ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(13, 15)}, (3, 5))
-    ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 15)}, (0, 9))
-    ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 11)}, (0, 9))
-    ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(1, 21)}, (0, 9))
+    ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(-10, 10)}, (0, 9))
+    ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 5)}, (3, 5))
+    ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(13, 15)}, (3, 5))
+    ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 15)}, (0, 9))
+    ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 11)}, (0, 9))
+    ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(1, 21)}, (0, 9))
 
     floordiv = tvm.te.floordiv
     z = te.var("z")
     ck.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 3))
-    ck.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)},
-              (0, 7))
+    ck.verify(flm(y, 8), {y: tvm.arith.IntervalSet(z * 8 + x * 4, z * 8 + x * 4 + 3)}, (0, 7))
     ck1 = IntSetChecker()
     ck1.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 2))
-    ck1.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)}, (x*4, x*4+3))
+    ck1.verify(
+        flm(y, 8), {y: tvm.arith.IntervalSet(z * 8 + x * 4, z * 8 + x * 4 + 3)}, (x * 4, x * 4 + 3)
+    )
 
 
 def test_max_min():
     ck = IntSetChecker()
     x, y = te.var("x"), te.var("y")
-    ck.verify(tvm.te.max(x, x + 1), {x : tvm.arith.IntervalSet(0, 10)}, (1, 11))
-    ck.verify(tvm.te.min(x - 1, x + 1), {x : tvm.arith.IntervalSet(0, 10)}, (-1, 9))
+    ck.verify(tvm.te.max(x, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (1, 11))
+    ck.verify(tvm.te.min(x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (-1, 9))
     ck.verify(tvm.te.min(x, y), {}, (tvm.te.min(x, y), tvm.te.min(x, y)))
     ck.verify(tvm.te.max(x, y), {}, (tvm.te.max(x, y), tvm.te.max(x, y)))
 
@@ -118,8 +120,7 @@ def test_max_min():
 def test_select():
     ck = IntSetChecker()
     x, y = te.var("x"), te.var("y")
-    ck.verify(tvm.tir.Select(x > 0, x - 1, x + 1),
-              {x : tvm.arith.IntervalSet(0, 10)}, (-1, 11))
+    ck.verify(tvm.tir.Select(x > 0, x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (-1, 11))
 
 
 if __name__ == "__main__":
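
Two of the floormod interval results above, checked by enumeration (Python's % is floor-mod for a positive divisor, matching tvm.te.floormod here):

    r = [x % 10 for x in range(13, 16)]   # x in [13, 15]
    assert (min(r), max(r)) == (3, 5)
    r = [x % 10 for x in range(3, 16)]    # x in [3, 15] crosses a multiple of 10
    assert (min(r), max(r)) == (0, 9)
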
index 7d9f739..4a4cd6a 100644 (file)
@@ -24,8 +24,7 @@ def test_cast():
     m = analyzer.modular_set((x * 3).astype("uint32"))
     assert m.coeff == 3
     assert m.base == 0
-    m = analyzer.modular_set(
-        (x * 3 + 1).astype("float32").astype("int32"))
+    m = analyzer.modular_set((x * 3 + 1).astype("float32").astype("int32"))
     assert m.coeff == 3
     assert m.base == 1
 
@@ -111,7 +110,7 @@ def test_mix_index():
     assert m.coeff == 2
     assert m.base == 0
 
-    m = analyzer.modular_set((a * 12 + 1) - (b * 3 * 7  + 2))
+    m = analyzer.modular_set((a * 12 + 1) - (b * 3 * 7 + 2))
     assert m.coeff == 3
     assert m.base == 2
 
@@ -142,6 +141,7 @@ def test_constraint_scope():
     assert m.coeff == 1
     assert m.base == 0
 
+
 def test_intersect():
     a = te.var("a")
     analyzer = tvm.arith.Analyzer()
@@ -159,6 +159,7 @@ def test_intersect():
                 assert m.coeff == 105
                 assert m.base == 23
 
+
 def test_let():
     analyzer = tvm.arith.Analyzer()
     x = te.var("x")
index ae7b432..0571ede 100644 (file)
 import tvm
 from tvm import te
 
+
 class RewriteChecker:
     def __init__(self):
         self.analyzer = tvm.arith.Analyzer()
 
     def verify(self, data, expected):
         res = self.analyzer.rewrite_simplify(data)
-        assert tvm.ir.structural_equal(res, expected), "data={}, res={}, expected={}".format(data, res, expected)
+        assert tvm.ir.structural_equal(res, expected), "data={}, res={}, expected={}".format(
+            data, res, expected
+        )
 
 
 def test_vector_simplify():
     ck = RewriteChecker()
     x, y, z = te.var("x"), te.var("y"), te.var("z")
     # Add rules
-    ck.verify(tvm.tir.Ramp(x, 1, 4) + tvm.tir.Ramp(y, 2, 4),
-              tvm.tir.Ramp(x + y, 3, 4))
-    ck.verify(tvm.tir.Ramp(x, 1, 2) + y,
-              tvm.tir.Ramp(x + y, 1, 2))
-    ck.verify(y + tvm.tir.Ramp(x, 1, 2) ,
-              tvm.tir.Ramp(y + x, 1, 2))
-    ck.verify(y.astype("int32x2") + x.astype("int32x2"),
-              (y + x).astype("int32x2"))
-    ck.verify(tvm.tir.Broadcast(0, 4) + y,
-              tvm.tir.Broadcast(y, 4))
-    ck.verify(tvm.tir.Ramp(x, 1, 4).astype('float32x4') + tvm.tir.Broadcast(0.0, 4),
-              tvm.tir.Ramp(x, 1, 4).astype('float32x4'))
+    ck.verify(tvm.tir.Ramp(x, 1, 4) + tvm.tir.Ramp(y, 2, 4), tvm.tir.Ramp(x + y, 3, 4))
+    ck.verify(tvm.tir.Ramp(x, 1, 2) + y, tvm.tir.Ramp(x + y, 1, 2))
+    ck.verify(y + tvm.tir.Ramp(x, 1, 2), tvm.tir.Ramp(y + x, 1, 2))
+    ck.verify(y.astype("int32x2") + x.astype("int32x2"), (y + x).astype("int32x2"))
+    ck.verify(tvm.tir.Broadcast(0, 4) + y, tvm.tir.Broadcast(y, 4))
+    ck.verify(
+        tvm.tir.Ramp(x, 1, 4).astype("float32x4") + tvm.tir.Broadcast(0.0, 4),
+        tvm.tir.Ramp(x, 1, 4).astype("float32x4"),
+    )
     # Sub rules
-    ck.verify(tvm.tir.Ramp(x, 4, 4) - tvm.tir.Ramp(y, 2, 4),
-              tvm.tir.Ramp(x - y, 2, 4))
-    ck.verify(tvm.tir.Ramp(x, 1, 2) - y,
-              tvm.tir.Ramp(x - y, 1, 2))
-    ck.verify(y - tvm.tir.Ramp(x, 1, 2) ,
-              tvm.tir.Ramp(y - x, -1, 2))
-    ck.verify(y.astype("int32x2") - x.astype("int32x2"),
-              (y - x).astype("int32x2"))
+    ck.verify(tvm.tir.Ramp(x, 4, 4) - tvm.tir.Ramp(y, 2, 4), tvm.tir.Ramp(x - y, 2, 4))
+    ck.verify(tvm.tir.Ramp(x, 1, 2) - y, tvm.tir.Ramp(x - y, 1, 2))
+    ck.verify(y - tvm.tir.Ramp(x, 1, 2), tvm.tir.Ramp(y - x, -1, 2))
+    ck.verify(y.astype("int32x2") - x.astype("int32x2"), (y - x).astype("int32x2"))
 
     # Mul rules
-    ck.verify(y.astype("int32x2") * x.astype("int32x2"),
-              (y * x).astype("int32x2"))
-    ck.verify(tvm.tir.Ramp(x, 4, 4) * 2,
-              tvm.tir.Ramp(x * 2, 8, 4))
-    ck.verify(2 * tvm.tir.Ramp(x, 4, 4),
-              tvm.tir.Ramp(x * 2, 8, 4))
-    ck.verify(tvm.tir.Broadcast(0, 4) * x,
-              tvm.tir.Broadcast(0, 4))
-    ck.verify(tvm.tir.Broadcast(0.0, 4) * x,
-              tvm.tir.Broadcast(0.0, 4))
+    ck.verify(y.astype("int32x2") * x.astype("int32x2"), (y * x).astype("int32x2"))
+    ck.verify(tvm.tir.Ramp(x, 4, 4) * 2, tvm.tir.Ramp(x * 2, 8, 4))
+    ck.verify(2 * tvm.tir.Ramp(x, 4, 4), tvm.tir.Ramp(x * 2, 8, 4))
+    ck.verify(tvm.tir.Broadcast(0, 4) * x, tvm.tir.Broadcast(0, 4))
+    ck.verify(tvm.tir.Broadcast(0.0, 4) * x, tvm.tir.Broadcast(0.0, 4))
 
     ## DivMod rules
     tdiv = tvm.tir.truncdiv
     tmod = tvm.tir.truncmod
     # trunc div
-    ck.verify(tdiv(y.astype("int32x2"), x.astype("int32x2")),
-              tdiv(y, x).astype("int32x2"))
-    ck.verify(tdiv(tvm.tir.Ramp(x, 4, 4), 2),
-              tvm.tir.Ramp(tdiv(x, 2), 2, 4))
+    ck.verify(tdiv(y.astype("int32x2"), x.astype("int32x2")), tdiv(y, x).astype("int32x2"))
+    ck.verify(tdiv(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Ramp(tdiv(x, 2), 2, 4))
     ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
-    ck.verify(tdiv(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8),
-              (x).astype("int32x4"))
-    ck.verify(tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8),
-              tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8))
+    ck.verify(tdiv(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), (x).astype("int32x4"))
+    ck.verify(tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8), tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8))
     # trunc mod
-    ck.verify(tmod(y.astype("int32x2"), x.astype("int32x2")),
-              tmod(y, x).astype("int32x2"))
-    ck.verify(tmod(tvm.tir.Ramp(x, 4, 4), 2),
-              tvm.tir.Broadcast(tmod(x, 2), 4))
-    ck.verify(tmod(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8),
-              tvm.tir.Ramp(1, 1, 4))
-    ck.verify(tmod(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8),
-              tmod(tvm.tir.Ramp(1, 15, 4), 8))
+    ck.verify(tmod(y.astype("int32x2"), x.astype("int32x2")), tmod(y, x).astype("int32x2"))
+    ck.verify(tmod(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Broadcast(tmod(x, 2), 4))
+    ck.verify(tmod(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), tvm.tir.Ramp(1, 1, 4))
+    ck.verify(tmod(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8), tmod(tvm.tir.Ramp(1, 15, 4), 8))
 
     # floor div
     fld = tvm.te.floordiv
     flm = tvm.te.floormod
     ck.analyzer.update(x, tvm.arith.ConstIntBound(-10, 1000), override=True)
-    ck.verify(fld(y.astype("int32x2"), x.astype("int32x2")),
-              fld(y, x).astype("int32x2"))
-    ck.verify(fld(tvm.tir.Ramp(x, 4, 4), 2),
-              tvm.tir.Ramp(fld(x, 2), 2, 4))
-    ck.verify(fld(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8),
-              (x).astype("int32x4"))
-    ck.verify(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8),
-              fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8))
-    ck.verify(fld(tvm.tir.Ramp(x, 8, 5), tvm.tir.Broadcast(4, 5)),
-              tvm.tir.Ramp(fld(x, 4), 2, 5))
-    ck.verify(fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)),
-              fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)))
-    ck.verify(fld(tvm.tir.Ramp(x * 8, 1, 4), tvm.tir.Broadcast(4, 4)),
-              tvm.tir.Broadcast(x * 2, 4))
-    ck.verify(fld(tvm.tir.Ramp(x * 8, 3, 4), tvm.tir.Broadcast(4, 4)),
-              fld(tvm.tir.Ramp(x * 8, 3, 4), tvm.tir.Broadcast(4, 4)))
-    ck.verify(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), tvm.tir.Broadcast(4, 4)),
-              fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), tvm.tir.Broadcast(4, 4)))
-    ck.verify(fld(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)),
-              tvm.tir.Broadcast(fld(x, 16), 4))
-    ck.verify(fld(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)),
-              tvm.tir.Broadcast(fld(x, 8), 4))
-    ck.verify(fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)),
-              fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)))
-    ck.verify(fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)),
-              fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)))
-    ck.verify(fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)),
-              fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)))
+    ck.verify(fld(y.astype("int32x2"), x.astype("int32x2")), fld(y, x).astype("int32x2"))
+    ck.verify(fld(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Ramp(fld(x, 2), 2, 4))
+    ck.verify(fld(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), (x).astype("int32x4"))
+    ck.verify(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8), fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8))
+    ck.verify(fld(tvm.tir.Ramp(x, 8, 5), tvm.tir.Broadcast(4, 5)), tvm.tir.Ramp(fld(x, 4), 2, 5))
+    ck.verify(
+        fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)),
+        fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)),
+    )
+    ck.verify(fld(tvm.tir.Ramp(x * 8, 1, 4), tvm.tir.Broadcast(4, 4)), tvm.tir.Broadcast(x * 2, 4))
+    ck.verify(
+        fld(tvm.tir.Ramp(x * 8, 3, 4), tvm.tir.Broadcast(4, 4)),
+        fld(tvm.tir.Ramp(x * 8, 3, 4), tvm.tir.Broadcast(4, 4)),
+    )
+    ck.verify(
+        fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), tvm.tir.Broadcast(4, 4)),
+        fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), tvm.tir.Broadcast(4, 4)),
+    )
+    ck.verify(
+        fld(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Broadcast(fld(x, 16), 4)
+    )
+    ck.verify(
+        fld(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Broadcast(fld(x, 8), 4)
+    )
+    ck.verify(
+        fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)),
+        fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)),
+    )
+    ck.verify(
+        fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)),
+        fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)),
+    )
+    ck.verify(
+        fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)),
+        fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)),
+    )
     # floor mod
-    ck.verify(flm(y.astype("int32x2"), x.astype("int32x2")),
-              flm(y, x).astype("int32x2"))
-    ck.verify(flm(tvm.tir.Ramp(x, 4, 4), 2),
-              tvm.tir.Broadcast(flm(x, 2), 4))
-    ck.verify(flm(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8),
-              tvm.tir.Ramp(1, 1, 4))
-    ck.verify(flm(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8),
-              flm(tvm.tir.Ramp(1, 15, 4), 8))
-    ck.verify(flm(tvm.tir.Ramp(x, 8, 4), tvm.tir.Broadcast(4, 4)),
-              tvm.tir.Broadcast(flm(x, 4), 4))
-    ck.verify(flm(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)),
-              flm(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)))
-    ck.verify(flm(tvm.tir.Ramp(x * 8, 1, 4), tvm.tir.Broadcast(4, 4)),
-              tvm.tir.Ramp(0, 1, 4))
-    ck.verify(flm(tvm.tir.Ramp(x * 8, 1, 5), tvm.tir.Broadcast(4, 5)),
-              flm(tvm.tir.Ramp(0, 1, 5), tvm.tir.Broadcast(4, 5)))
-    ck.verify(flm(tvm.tir.Ramp(x * 8 + 7, 1, 4), tvm.tir.Broadcast(4, 4)),
-              flm(tvm.tir.Ramp(3, 1, 4), tvm.tir.Broadcast(4, 4)))
-    ck.verify(flm(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)),
-              tvm.tir.Ramp(flm(x * 4, 64), 1, 4))
-    ck.verify(flm(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)),
-              tvm.tir.Ramp(flm(x * 8, 64), 2, 4))
-    ck.verify(flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)),
-              tvm.tir.Ramp(flm(x * 4, 64), 1, 5))
-    ck.verify(flm(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)),
-              tvm.tir.Ramp(flm(x * 4 + 3, 64), 1, 4))
-    ck.verify(flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)),
-              flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)))
+    ck.verify(flm(y.astype("int32x2"), x.astype("int32x2")), flm(y, x).astype("int32x2"))
+    ck.verify(flm(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Broadcast(flm(x, 2), 4))
+    ck.verify(flm(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), tvm.tir.Ramp(1, 1, 4))
+    ck.verify(flm(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8), flm(tvm.tir.Ramp(1, 15, 4), 8))
+    ck.verify(flm(tvm.tir.Ramp(x, 8, 4), tvm.tir.Broadcast(4, 4)), tvm.tir.Broadcast(flm(x, 4), 4))
+    ck.verify(
+        flm(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)),
+        flm(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)),
+    )
+    ck.verify(flm(tvm.tir.Ramp(x * 8, 1, 4), tvm.tir.Broadcast(4, 4)), tvm.tir.Ramp(0, 1, 4))
+    ck.verify(
+        flm(tvm.tir.Ramp(x * 8, 1, 5), tvm.tir.Broadcast(4, 5)),
+        flm(tvm.tir.Ramp(0, 1, 5), tvm.tir.Broadcast(4, 5)),
+    )
+    ck.verify(
+        flm(tvm.tir.Ramp(x * 8 + 7, 1, 4), tvm.tir.Broadcast(4, 4)),
+        flm(tvm.tir.Ramp(3, 1, 4), tvm.tir.Broadcast(4, 4)),
+    )
+    ck.verify(
+        flm(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x * 4, 64), 1, 4)
+    )
+    ck.verify(
+        flm(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x * 8, 64), 2, 4)
+    )
+    ck.verify(
+        flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), tvm.tir.Ramp(flm(x * 4, 64), 1, 5)
+    )
+    ck.verify(
+        flm(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)),
+        tvm.tir.Ramp(flm(x * 4 + 3, 64), 1, 4),
+    )
+    ck.verify(
+        flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)),
+        flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)),
+    )
 
     # Min/Max rules
     vx = te.var("vx", dtype="int32x2")
     vc = te.var("vc", dtype="uint1")
-    ck.verify(tvm.te.min(y.astype("int32x2"), x.astype("int32x2")),
-              tvm.te.min(y, x).astype("int32x2"))
-    ck.verify(tvm.te.min(tvm.te.min(vx, y.astype("int32x2")), x.astype("int32x2")),
-              tvm.te.min(vx, tvm.te.min(y, x).astype("int32x2")))
-    ck.verify(tvm.te.max(y.astype("int32x2"), x.astype("int32x2")),
-              tvm.te.max(y, x).astype("int32x2"))
-    ck.verify(tvm.te.max(tvm.te.max(vx, y.astype("int32x2")), x.astype("int32x2")),
-              tvm.te.max(vx, tvm.te.max(y, x).astype("int32x2")))
+    ck.verify(
+        tvm.te.min(y.astype("int32x2"), x.astype("int32x2")), tvm.te.min(y, x).astype("int32x2")
+    )
+    ck.verify(
+        tvm.te.min(tvm.te.min(vx, y.astype("int32x2")), x.astype("int32x2")),
+        tvm.te.min(vx, tvm.te.min(y, x).astype("int32x2")),
+    )
+    ck.verify(
+        tvm.te.max(y.astype("int32x2"), x.astype("int32x2")), tvm.te.max(y, x).astype("int32x2")
+    )
+    ck.verify(
+        tvm.te.max(tvm.te.max(vx, y.astype("int32x2")), x.astype("int32x2")),
+        tvm.te.max(vx, tvm.te.max(y, x).astype("int32x2")),
+    )
 
     ## Logical rules
-    ck.verify(y.astype("int32x2").equal(x.astype("int32x2")),
-              (y.equal(x)).astype("uint1x2"))
-    ck.verify(tvm.tir.NE(y.astype("int32x2"), (x.astype("int32x2"))),
-              (tvm.tir.NE(y, x)).astype("uint1x2"))
-    ck.verify(y.astype("int32x2") > x.astype("int32x2"),
-              (x < y).astype("uint1x2"))
-    ck.verify(y.astype("int32x2") >= x.astype("int32x2"),
-              (x <= y).astype("uint1x2"))
-    ck.verify(y.astype("int32x2") < x.astype("int32x2"),
-              (y < x).astype("uint1x2"))
-    ck.verify(y.astype("int32x2") <= x.astype("int32x2"),
-              (y <= x).astype("uint1x2"))
-    ck.verify(tvm.tir.And(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")),
-              (tvm.tir.And(y <= x, vc)).astype("uint1x2"))
-    ck.verify(tvm.tir.Or(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")),
-              (tvm.tir.Or(y <= x, vc)).astype("uint1x2"))
+    ck.verify(y.astype("int32x2").equal(x.astype("int32x2")), (y.equal(x)).astype("uint1x2"))
+    ck.verify(
+        tvm.tir.NE(y.astype("int32x2"), (x.astype("int32x2"))), (tvm.tir.NE(y, x)).astype("uint1x2")
+    )
+    ck.verify(y.astype("int32x2") > x.astype("int32x2"), (x < y).astype("uint1x2"))
+    ck.verify(y.astype("int32x2") >= x.astype("int32x2"), (x <= y).astype("uint1x2"))
+    ck.verify(y.astype("int32x2") < x.astype("int32x2"), (y < x).astype("uint1x2"))
+    ck.verify(y.astype("int32x2") <= x.astype("int32x2"), (y <= x).astype("uint1x2"))
+    ck.verify(
+        tvm.tir.And(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")),
+        (tvm.tir.And(y <= x, vc)).astype("uint1x2"),
+    )
+    ck.verify(
+        tvm.tir.Or(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")),
+        (tvm.tir.Or(y <= x, vc)).astype("uint1x2"),
+    )
 
 
 def test_select_simplify():
     ck = RewriteChecker()
     x, y, z = te.var("x"), te.var("y"), te.var("z")
     # Add rules
-    ck.verify(tvm.tir.Select(x < 0, y, 0) + tvm.tir.Select(x < 0, 1, z),
-              tvm.tir.Select(x < 0, y + 1, z))
-    ck.verify(tvm.tir.Select(x < 0, y, 1) - tvm.tir.Select(x < 0, 1, z),
-              tvm.tir.Select(x < 0, y + (-1), 1 - z))
-    ck.verify(tvm.tir.Select(x < 0, y, z) - y,
-              tvm.tir.Select(x < 0, 0, z - y))
-    ck.verify(tvm.tir.Select(x < 0, y, z) - z,
-              tvm.tir.Select(x < 0, y - z, 0))
-    ck.verify(tvm.te.min(tvm.tir.Select(x < 0, y, 0), tvm.tir.Select(x < 0, 1, z)),
-              tvm.tir.Select(x < 0, tvm.te.min(y, 1), tvm.te.min(0, z)))
-    ck.verify(tvm.te.max(tvm.tir.Select(x < 0, y, 0), tvm.tir.Select(x < 0, 1, z)),
-              tvm.tir.Select(x < 0, tvm.te.max(y, 1), tvm.te.max(0, z)))
+    ck.verify(
+        tvm.tir.Select(x < 0, y, 0) + tvm.tir.Select(x < 0, 1, z), tvm.tir.Select(x < 0, y + 1, z)
+    )
+    ck.verify(
+        tvm.tir.Select(x < 0, y, 1) - tvm.tir.Select(x < 0, 1, z),
+        tvm.tir.Select(x < 0, y + (-1), 1 - z),
+    )
+    ck.verify(tvm.tir.Select(x < 0, y, z) - y, tvm.tir.Select(x < 0, 0, z - y))
+    ck.verify(tvm.tir.Select(x < 0, y, z) - z, tvm.tir.Select(x < 0, y - z, 0))
+    ck.verify(
+        tvm.te.min(tvm.tir.Select(x < 0, y, 0), tvm.tir.Select(x < 0, 1, z)),
+        tvm.tir.Select(x < 0, tvm.te.min(y, 1), tvm.te.min(0, z)),
+    )
+    ck.verify(
+        tvm.te.max(tvm.tir.Select(x < 0, y, 0), tvm.tir.Select(x < 0, 1, z)),
+        tvm.tir.Select(x < 0, tvm.te.max(y, 1), tvm.te.max(0, z)),
+    )
 
     ck.verify(tvm.tir.Select(x * 3 + 1 != 0, y, z), y)
     ck.verify(tvm.tir.Select(x * 3 + 1 == 0, y, z), z)
@@ -216,12 +225,12 @@ def test_add_index_simplify():
     ck.verify(tvm.te.max(x, y - 10) + 10, tvm.te.max(x + 10, y))
     ck.verify(tvm.te.max(x - 11, y) + 11, tvm.te.max(x, y + 11))
 
-    ck.verify(tvm.te.max(x, y * 2) + tvm.te.min(x, y * 2), x + y * 2);
-    ck.verify(tvm.te.min(x, y * 2) + tvm.te.max(x, y * 2), x + y * 2);
+    ck.verify(tvm.te.max(x, y * 2) + tvm.te.min(x, y * 2), x + y * 2)
+    ck.verify(tvm.te.min(x, y * 2) + tvm.te.max(x, y * 2), x + y * 2)
 
-    ck.verify(tvm.te.max(x, y + 2) + (-2), tvm.te.max(x + (-2), y));
-    ck.verify(tvm.te.min(x, y + 2) + (-2), tvm.te.min(x + (-2), y));
-    ck.verify(tvm.te.min(x + 2, y + 3) + (-2), tvm.te.min(x, y + 1));
+    ck.verify(tvm.te.max(x, y + 2) + (-2), tvm.te.max(x + (-2), y))
+    ck.verify(tvm.te.min(x, y + 2) + (-2), tvm.te.min(x + (-2), y))
+    ck.verify(tvm.te.min(x + 2, y + 3) + (-2), tvm.te.min(x, y + 1))
 
     ck.verify(tvm.te.max(0, 1 - x * 4) + x * 4, tvm.te.max(x * 4, 1))
     ck.verify(tvm.te.max(2 - x * 4, 0) + x * 4, tvm.te.max(x * 4, 2))
@@ -244,10 +253,9 @@ def test_add_index_simplify():
     ck.verify(x + 3 + y, x + y + 3)
     ck.verify((3 - y) + x, x - y + 3)
 
-
     # canonicalization
-    ck.verify(x + 2 + 3 + 4 + x, x * 2 + 9);
-    ck.verify(x + 2 + 3 + 4 + x * 3, x * 4 + 9);
+    ck.verify(x + 2 + 3 + 4 + x, x * 2 + 9)
+    ck.verify(x + 2 + 3 + 4 + x * 3, x * 4 + 9)
 
     # DivMod rules
     tdiv = tvm.tir.truncdiv
@@ -264,7 +272,6 @@ def test_add_index_simplify():
     ck.verify(fld(x, 8) * 8 + flm(x, 8), x)
 
 
-
 def test_sub_index_simplify():
     ck = RewriteChecker()
     x, y, z = te.var("x"), te.var("y"), te.var("z")
@@ -493,12 +500,12 @@ def test_mod_index_simplify():
     ck.verify(tmod(x * 10 + y, 2), tmod(y, 2))
     ck.verify(tmod(x + 10, 2), tmod(x, 2))
     ck.verify(tmod(x + y * 10, 2), tmod(x, 2))
-    ck.verify(tmod(x* 10 + 1 + y * 2 + 2, 2), 1)
+    ck.verify(tmod(x * 10 + 1 + y * 2 + 2, 2), 1)
     ck.verify(tmod(x * 10, -2), 0)
     ck.verify(tmod(x * 10 + y, -2), tmod(y, 2))
     ck.verify(tmod(x + 10, -2), tmod(x, 2))
     ck.verify(tmod(x + y * 10, -2), tmod(x, 2))
-    ck.verify(tmod(x* 10 + 1 + y * 2 + 2, -2), 1)
+    ck.verify(tmod(x * 10 + 1 + y * 2 + 2, -2), 1)
 
     ck.verify(tmod(x * (-10), 2), 0)
     ck.verify(tmod(x * (-10) + y, 2), tmod(x * (-10) + y, 2))
@@ -527,7 +534,7 @@ def test_floormod_index_simplify():
     ck.verify(flm(x * 10 + y, 2), flm(y, 2))
     ck.verify(flm(x + 10, 2), flm(x, 2))
     ck.verify(flm(x + y * 10, 2), flm(x, 2))
-    ck.verify(flm(x* 10 + 1 + y * 2 + 2, 2), 1)
+    ck.verify(flm(x * 10 + 1 + y * 2 + 2, 2), 1)
     ck.verify(flm(x * (-10), 2), 0)
     ck.verify(flm(x * (-10) + y, 2), flm(y, 2))
     ck.verify(flm(x + (-10), 2), flm(x, 2))
@@ -565,12 +572,15 @@ def test_min_index_simplify():
     ck.verify(tvm.te.min(x, tvm.te.min(x, y)), tvm.te.min(x, y))
     ck.verify(tvm.te.min(y, tvm.te.min(x, y)), tvm.te.min(x, y))
 
-    ck.verify(tvm.te.min(tvm.te.min(tvm.te.min(x, y), z), y),
-              tvm.te.min(tvm.te.min(x, y), z))
-    ck.verify(tvm.te.min(tvm.te.min(tvm.te.min(tvm.te.min(x, y), z), x * 2), y),
-              tvm.te.min(tvm.te.min(tvm.te.min(x, y), z), x * 2))
-    ck.verify(tvm.te.min(tvm.te.min(tvm.te.min(tvm.te.min(tvm.te.min(x, y), z), x * 2), z * 2), y),
-              tvm.te.min(tvm.te.min(tvm.te.min(tvm.te.min(x, y), z), x * 2), z * 2))
+    ck.verify(tvm.te.min(tvm.te.min(tvm.te.min(x, y), z), y), tvm.te.min(tvm.te.min(x, y), z))
+    ck.verify(
+        tvm.te.min(tvm.te.min(tvm.te.min(tvm.te.min(x, y), z), x * 2), y),
+        tvm.te.min(tvm.te.min(tvm.te.min(x, y), z), x * 2),
+    )
+    ck.verify(
+        tvm.te.min(tvm.te.min(tvm.te.min(tvm.te.min(tvm.te.min(x, y), z), x * 2), z * 2), y),
+        tvm.te.min(tvm.te.min(tvm.te.min(tvm.te.min(x, y), z), x * 2), z * 2),
+    )
 
     ck.verify(tvm.te.min(tvm.te.max(x, y), tvm.te.max(x, z)), tvm.te.max(tvm.te.min(y, z), x))
     ck.verify(tvm.te.min(tvm.te.max(x, y), tvm.te.max(z, x)), tvm.te.max(tvm.te.min(y, z), x))
@@ -591,7 +601,7 @@ def test_min_index_simplify():
     ck.verify(tvm.te.min(x * 3, 9), tvm.te.min(x, 3) * 3)
     ck.verify(tvm.te.min(x * 2, 0), tvm.te.min(x, 0) * 2)
     ck.verify(tvm.te.min(0 - x * 2, 0), tvm.te.max(x, 0) * -2)
-    ck.verify(tvm.te.min(3 - x, 2), 3 - tvm.te.max(x,  1))
+    ck.verify(tvm.te.min(3 - x, 2), 3 - tvm.te.max(x, 1))
     ck.verify(tvm.te.min(x * (-2), -4), tvm.te.max(x, 2) * -2)
     ck.verify(tvm.te.min(x * (-2), 4), tvm.te.max(x, -2) * -2)
     ck.verify(tvm.te.min(x * (0), 4), 0)
@@ -606,8 +616,7 @@ def test_min_index_simplify():
     ck.verify(tvm.te.min(tvm.te.max(x, 4), tdiv(x + 3, 4) * 4), tvm.te.max(x, 4))
     ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True)
     ck.verify(tvm.te.min(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.te.min(x, y), 10))
-    ck.verify(tvm.te.min(tdiv(x, (-10)), tdiv(y, (-10))),
-              tdiv(tvm.te.max(x, y), (-10)))
+    ck.verify(tvm.te.min(tdiv(x, (-10)), tdiv(y, (-10))), tdiv(tvm.te.max(x, y), (-10)))
 
     # floor div
     ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True)
@@ -651,12 +660,15 @@ def test_max_index_simplify():
     ck.verify(tvm.te.max(x, tvm.te.max(x, y)), tvm.te.max(x, y))
     ck.verify(tvm.te.max(y, tvm.te.max(x, y)), tvm.te.max(x, y))
 
-    ck.verify(tvm.te.max(tvm.te.max(tvm.te.max(x, y), z), y),
-              tvm.te.max(tvm.te.max(x, y), z))
-    ck.verify(tvm.te.max(tvm.te.max(tvm.te.max(tvm.te.max(x, y), z), x * 2), y),
-              tvm.te.max(tvm.te.max(tvm.te.max(x, y), z), x * 2))
-    ck.verify(tvm.te.max(tvm.te.max(tvm.te.max(tvm.te.max(tvm.te.max(x, y), z), x * 2), z * 2), y),
-              tvm.te.max(tvm.te.max(tvm.te.max(tvm.te.max(x, y), z), x * 2), z * 2))
+    ck.verify(tvm.te.max(tvm.te.max(tvm.te.max(x, y), z), y), tvm.te.max(tvm.te.max(x, y), z))
+    ck.verify(
+        tvm.te.max(tvm.te.max(tvm.te.max(tvm.te.max(x, y), z), x * 2), y),
+        tvm.te.max(tvm.te.max(tvm.te.max(x, y), z), x * 2),
+    )
+    ck.verify(
+        tvm.te.max(tvm.te.max(tvm.te.max(tvm.te.max(tvm.te.max(x, y), z), x * 2), z * 2), y),
+        tvm.te.max(tvm.te.max(tvm.te.max(tvm.te.max(x, y), z), x * 2), z * 2),
+    )
 
     ck.verify(tvm.te.max(tvm.te.min(x, y), tvm.te.min(x, z)), tvm.te.min(tvm.te.max(y, z), x))
     ck.verify(tvm.te.max(tvm.te.min(x, y), tvm.te.min(z, x)), tvm.te.min(tvm.te.max(y, z), x))
@@ -675,7 +687,7 @@ def test_max_index_simplify():
     ck.verify(tvm.te.max(tvm.te.max(x, 11), 10), tvm.te.max(x, 11))
 
     ck.verify(tvm.te.max(x * 3, 9), tvm.te.max(x, 3) * 3)
-    ck.verify(tvm.te.max(3 - x, 1), 3 - tvm.te.min(x,  2))
+    ck.verify(tvm.te.max(3 - x, 1), 3 - tvm.te.min(x, 2))
     ck.verify(tvm.te.max(x * 2, 0), tvm.te.max(x, 0) * 2)
     ck.verify(tvm.te.max(0 - x * 2, 0), tvm.te.min(x, 0) * -2)
     ck.verify(tvm.te.max(x * (-2), -4), tvm.te.min(x, 2) * -2)
@@ -839,10 +851,10 @@ def test_cmp_simplify():
     ck.verify(x + y >= -10, tvm.tir.const(1, "bool"))
     ck.verify(z - 5 <= y + 10, tvm.tir.const(1, "bool"))
     ck.verify(tvm.tir.all(x > -1, z <= x + 5), tvm.tir.const(1, "bool"))
-    ck.verify(x*y <= 0, tvm.tir.const(1, "bool"))
-    ck.verify((x + 1)*(y - 1) < 0, tvm.tir.const(1, "bool"))
-    ck.verify(y*y >= 0, tvm.tir.const(1, "bool"))
-    ck.verify(x*6 <= -3, tvm.tir.const(0, "bool"))
+    ck.verify(x * y <= 0, tvm.tir.const(1, "bool"))
+    ck.verify((x + 1) * (y - 1) < 0, tvm.tir.const(1, "bool"))
+    ck.verify(y * y >= 0, tvm.tir.const(1, "bool"))
+    ck.verify(x * 6 <= -3, tvm.tir.const(0, "bool"))
     ck.verify(tmod(y - 1, 3) == 0, tmod(y + (-1), 3) == 0)
 
 
@@ -850,10 +862,8 @@ def test_logical_simplify():
     ck = RewriteChecker()
     x, y, z = te.var("x"), te.var("y"), te.var("z")
 
-    ck.verify(tvm.tir.And(tvm.tir.EQ(x, y), tvm.tir.NE(x, y)),
-              tvm.tir.const(False, "bool"))
-    ck.verify(tvm.tir.And(tvm.tir.NE(x, y), tvm.tir.EQ(x, y)),
-              tvm.tir.const(False, "bool"))
+    ck.verify(tvm.tir.And(tvm.tir.EQ(x, y), tvm.tir.NE(x, y)), tvm.tir.const(False, "bool"))
+    ck.verify(tvm.tir.And(tvm.tir.NE(x, y), tvm.tir.EQ(x, y)), tvm.tir.const(False, "bool"))
     ck.verify(tvm.tir.And(x > 1, tvm.tir.Not(x > 1)), tvm.tir.const(False, "bool"))
     ck.verify(tvm.tir.And(x <= y, y < x), tvm.tir.const(False, "bool"))
     ck.verify(tvm.tir.And(y < x, x <= y), tvm.tir.const(False, "bool"))
@@ -867,11 +877,8 @@ def test_logical_simplify():
     ck.verify(tvm.tir.And(2 <= x, x <= 1), tvm.tir.const(False, "bool"))
     ck.verify(tvm.tir.And(x == 1, x != 2), x == 1)
 
-
-    ck.verify(tvm.tir.Or(tvm.tir.EQ(x, y), tvm.tir.NE(x, y)),
-              tvm.tir.const(True, "bool"))
-    ck.verify(tvm.tir.Or(tvm.tir.NE(x, y), tvm.tir.EQ(x, y)),
-              tvm.tir.const(True, "bool"))
+    ck.verify(tvm.tir.Or(tvm.tir.EQ(x, y), tvm.tir.NE(x, y)), tvm.tir.const(True, "bool"))
+    ck.verify(tvm.tir.Or(tvm.tir.NE(x, y), tvm.tir.EQ(x, y)), tvm.tir.const(True, "bool"))
     ck.verify(tvm.tir.Or(x > y, tvm.tir.Not(x > y)), tvm.tir.const(True, "bool"))
 
     ck.verify(tvm.tir.Or(x <= y, y < x), tvm.tir.const(True, "bool"))
@@ -888,12 +895,14 @@ def test_logical_simplify():
     ck.verify(tvm.tir.Or(2 <= x, x <= 1), tvm.tir.const(True, "bool"))
     ck.verify(tvm.tir.Or(x != 1, x == 2), x != 1)
 
+
 def test_let_simplify():
     ck = RewriteChecker()
     x, y = te.var("x"), te.var("y")
     z = tvm.tir.Let(x, 1, x + 1)
     ck.verify(z + z, 4)
 
+
 def test_cast_simplify():
     ck = RewriteChecker()
     x = te.var("x")
@@ -906,6 +915,7 @@ def test_cast_simplify():
             for i in [0, 1, 2, 3]:
                 ck.verify(tvm.tir.Cast(dtype1, tvm.tir.const(i, dtype2)), tvm.tir.const(i, dtype1))
 
+
 if __name__ == "__main__":
     test_floordiv_index_simplify()
     test_floormod_index_simplify()
index 968e40b..87aea26 100644
@@ -23,8 +23,10 @@ from tvm import te, arith, ir, tir, testing
 
 def test_solution_consistency():
     seed = random.randrange(sys.maxsize)
-    print("\nThis test is intentionally non-deterministic, "
-          "if it fails please report it in github issue together with this seed {}\n".format(seed))
+    print(
+        "\nThis test is intentionally non-deterministic, "
+        "if it fails please report it in github issue together with this seed {}\n".format(seed)
+    )
     random.seed(seed)
 
     def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)):
@@ -32,9 +34,9 @@ def test_solution_consistency():
 
         relations = []
         for i in range(num_formulas):
-            s1 = sum([v*random.randint(coef[0], coef[1]) for v in variables])
+            s1 = sum([v * random.randint(coef[0], coef[1]) for v in variables])
             s1 += random.randint(coef[0], coef[1])
-            s2 = sum([v*random.randint(coef[0], coef[1]) for v in variables])
+            s2 = sum([v * random.randint(coef[0], coef[1]) for v in variables])
             s2 += random.randint(coef[0], coef[1])
             if random.random() < 0.7:
                 op = tvm.tir.EQ
@@ -99,10 +101,13 @@ def test_empty_var_to_solve():
 def test_unique_solution():
     x, y = te.var("x"), te.var("y")
 
-    solution = arith.solve_linear_equations([
-        tvm.tir.EQ(x + y, 20),
-        tvm.tir.EQ(x - y, 10),
-    ], [x, y])
+    solution = arith.solve_linear_equations(
+        [
+            tvm.tir.EQ(x + y, 20),
+            tvm.tir.EQ(x - y, 10),
+        ],
+        [x, y],
+    )
     assert list(solution.dst.variables) == []
     assert ir.structural_equal(solution.src_to_dst[x], 15)
     assert ir.structural_equal(solution.src_to_dst[y], 5)
@@ -112,10 +117,14 @@ def test_low_rank():
     x, y, z = te.var("x"), te.var("y"), te.var("z")
     ranges = {}
 
-    solution = arith.solve_linear_equations([
-        tvm.tir.EQ(x + y + z, 15),
-        tvm.tir.EQ(x + y, 10),
-    ], [x, y, z], ranges)
+    solution = arith.solve_linear_equations(
+        [
+            tvm.tir.EQ(x + y + z, 15),
+            tvm.tir.EQ(x + y, 10),
+        ],
+        [x, y, z],
+        ranges,
+    )
     [n0] = solution.dst.variables
     assert ir.structural_equal(solution.src_to_dst[x], n0 + 10)
     assert ir.structural_equal(solution.src_to_dst[y], -n0)
@@ -129,9 +138,13 @@ def test_infer_range():
         y: tvm.ir.Range.from_min_extent(0, 10),
     }
 
-    solution = arith.solve_linear_equations([
-        tvm.tir.EQ(x + y, 0),
-    ], [x, y], ranges)
+    solution = arith.solve_linear_equations(
+        [
+            tvm.tir.EQ(x + y, 0),
+        ],
+        [x, y],
+        ranges,
+    )
     [n0] = solution.dst.variables
     assert ir.structural_equal(solution.src_to_dst[x], n0)
     assert ir.structural_equal(solution.src_to_dst[y], -n0)
@@ -148,11 +161,15 @@ def test_infer_range():
 def test_ill_formed():
     x, y = te.var("x"), te.var("y")
 
-    solution = arith.solve_linear_equations([
-        tvm.tir.EQ(x + y, 0),
-        tvm.tir.EQ(x - y, 0),
-        tvm.tir.EQ(x, 5),
-    ], [x, y], {})
+    solution = arith.solve_linear_equations(
+        [
+            tvm.tir.EQ(x + y, 0),
+            tvm.tir.EQ(x - y, 0),
+            tvm.tir.EQ(x, 5),
+        ],
+        [x, y],
+        {},
+    )
     assert list(solution.dst.variables) == []
     [rel] = solution.dst.relations
     assert ir.structural_equal(rel, False)
index 80618dd..b2a6c9c 100644
@@ -23,8 +23,10 @@ from tvm import te, arith, ir, tir, testing
 
 def test_solution_consistency():
     seed = random.randrange(sys.maxsize)
-    print("\nThis test is intentionally non-deterministic, "
-          "if it fails please report it in github issue together with this seed {}\n".format(seed))
+    print(
+        "\nThis test is intentionally non-deterministic, "
+        "if it fails please report it in github issue together with this seed {}\n".format(seed)
+    )
     random.seed(seed)
 
     def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)):
@@ -32,17 +34,17 @@ def test_solution_consistency():
 
         fs = []
         for i in range(formulas):
-            s1 = sum([v*random.randint(coef[0], coef[1]) for v in vs])
+            s1 = sum([v * random.randint(coef[0], coef[1]) for v in vs])
             s1 += random.randint(coef[0], coef[1])
-            s2 = sum([v*random.randint(coef[0], coef[1]) for v in vs])
+            s2 = sum([v * random.randint(coef[0], coef[1]) for v in vs])
             s2 += random.randint(coef[0], coef[1])
             op = random.choice([tir.expr.EQ, tir.expr.LE, tir.expr.LT, tir.expr.GE, tir.expr.GT])
             fs.append(op(s1, s2))
 
         vranges = {v: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for v in vs}
-        before = te.all(tir.const(1, 'bool'), *fs)
+        before = te.all(tir.const(1, "bool"), *fs)
         after = arith._ffi_api.SolveInequalitiesAsCondition(vs, vranges, fs)
-        after = te.all(tir.const(1, 'bool'), *after)
+        after = te.all(tir.const(1, "bool"), *after)
         testing.check_bool_expr_is_true(before == after, vranges)
 
         solution = arith.solve_linear_inequalities(fs, vs, vranges, deskew_range=True)
@@ -109,7 +111,7 @@ def test_dual_variable():
     solution = arith.solve_linear_inequalities(problem, variables, ranges, deskew_range=True)
     [x_new, y_new] = solution.dst.variables
     [rel] = solution.dst.relations
-    assert ir.structural_equal(rel, (y_new*2) + x_new <= 10)
+    assert ir.structural_equal(rel, (y_new * 2) + x_new <= 10)
     assert ir.structural_equal(solution.dst.ranges[x_new].min, 0)
     assert ir.structural_equal(solution.dst.ranges[x_new].extent, 11)
     assert ir.structural_equal(solution.dst.ranges[y_new].min, 0)
@@ -164,8 +166,9 @@ def test_multi_equal():
     # (z*y - 6) <= 0 && (6 - z*y) <= 0
     ana = tvm.arith.Analyzer()
     assert ana.simplify(solution.relations[1].a + solution.relations[2].a) == 0
-    assert ir.structural_equal(solution.relations[1].a, (z*y - 6)) or \
-        ir.structural_equal(solution.relations[2].a, (z*y - 6))
+    assert ir.structural_equal(solution.relations[1].a, (z * y - 6)) or ir.structural_equal(
+        solution.relations[2].a, (z * y - 6)
+    )
 
     solution = arith.solve_linear_inequalities(problem, [x, y, z], deskew_range=True)
     assert solution.src_to_dst[y] == y
@@ -176,7 +179,7 @@ def test_multi_equal():
 def test_no_solution():
     x = te.var("x0")
     vranges = {x: tvm.ir.Range.from_min_extent(-20, 41)}
-    problem = [-x - 4 <= -5*x + 2, x*4 + 5 <= x*5]
+    problem = [-x - 4 <= -5 * x + 2, x * 4 + 5 <= x * 5]
 
     solution = arith.solve_linear_inequalities(problem, [x], vranges, deskew_range=True)
     assert list(solution.dst.variables) == []
index a1c12ed..33e498e 100644
@@ -28,44 +28,51 @@ from tvm.topi.util import get_const_tuple
 
 @auto_scheduler.register_workload
 def matmul_auto_scheduler_test(N, M, K):
-    A = te.placeholder((N, K), name='A')
-    B = te.placeholder((K, M), name='B')
-    k = te.reduce_axis((0, K), name='k')
-    C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C', attrs={"layout_free_placeholders":[B]})
+    A = te.placeholder((N, K), name="A")
+    B = te.placeholder((K, M), name="B")
+    k = te.reduce_axis((0, K), name="k")
+    C = te.compute(
+        (N, M),
+        lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]),
+        name="C",
+        attrs={"layout_free_placeholders": [B]},
+    )
     return [A, B, C]
 
 
 # Test for register_workload with different name
 @auto_scheduler.register_workload("matmul_auto_scheduler_test_rename_1")
 def matmul_auto_scheduler_test_rename_0(N, M, K):
-    A = te.placeholder((N, K), name='A')
-    B = te.placeholder((K, M), name='B')
-    k = te.reduce_axis((0, K), name='k')
-    C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
+    A = te.placeholder((N, K), name="A")
+    B = te.placeholder((K, M), name="B")
+    k = te.reduce_axis((0, K), name="k")
+    C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="C")
     return [A, B, C]
 
 
 @auto_scheduler.register_workload
-def conv2d_nchw_bn_relu_auto_scheduler_test(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1):
-    data = te.placeholder((N, CI, H, W), name='Data')
-    kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name='Kernel')
-    bias = te.placeholder((CO, 1, 1), name='Bias')
-    bn_scale = te.placeholder((CO, 1, 1), name='Bn_scale')
-    bn_offset = te.placeholder((CO, 1, 1), name='Bn_offset')
+def conv2d_nchw_bn_relu_auto_scheduler_test(
+    N, H, W, CI, CO, kernel_size, strides, padding, dilation=1
+):
+    data = te.placeholder((N, CI, H, W), name="Data")
+    kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name="Kernel")
+    bias = te.placeholder((CO, 1, 1), name="Bias")
+    bn_scale = te.placeholder((CO, 1, 1), name="Bn_scale")
+    bn_offset = te.placeholder((CO, 1, 1), name="Bn_offset")
 
     OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1
     OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1
 
     conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation)
-    conv = te.compute((N, CO, OH, OW),
-                      lambda i, j, k, l: conv[i, j, k, l] + bias[j, 0, 0],
-                      name='Bias_add')
-    conv = te.compute((N, CO, OH, OW),
-                      lambda i, j, k, l: conv[i, j, k, l] * bn_scale[j, 0, 0],
-                      name='Bn_mul')
-    conv = te.compute((N, CO, OH, OW),
-                      lambda i, j, k, l: conv[i, j, k, l] + bn_offset[j, 0, 0],
-                      name='Bn_add')
+    conv = te.compute(
+        (N, CO, OH, OW), lambda i, j, k, l: conv[i, j, k, l] + bias[j, 0, 0], name="Bias_add"
+    )
+    conv = te.compute(
+        (N, CO, OH, OW), lambda i, j, k, l: conv[i, j, k, l] * bn_scale[j, 0, 0], name="Bn_mul"
+    )
+    conv = te.compute(
+        (N, CO, OH, OW), lambda i, j, k, l: conv[i, j, k, l] + bn_offset[j, 0, 0], name="Bn_add"
+    )
     out = topi.nn.relu(conv)
 
     return [data, kernel, bias, bn_offset, bn_scale, out]
@@ -73,15 +80,15 @@ def conv2d_nchw_bn_relu_auto_scheduler_test(N, H, W, CI, CO, kernel_size, stride
 
 @auto_scheduler.register_workload
 def max_pool2d_auto_scheduler_test(N, H, W, CI, padding):
-    data = te.placeholder((N, CI, H, W), name='Data')
-    out = topi.nn.pool(data, [2, 2], [1, 1], [padding, padding, padding, padding], 'max')
+    data = te.placeholder((N, CI, H, W), name="Data")
+    out = topi.nn.pool(data, [2, 2], [1, 1], [padding, padding, padding, padding], "max")
 
     return [data, out]
 
 
 @auto_scheduler.register_workload
 def min_nm_auto_scheduler_test(N, M):
-    A = te.placeholder((N, M), name='A')
+    A = te.placeholder((N, M), name="A")
     B = topi.min(A, axis=-1)
 
     return [A, B]
@@ -89,7 +96,7 @@ def min_nm_auto_scheduler_test(N, M):
 
 @auto_scheduler.register_workload
 def softmax_nm_auto_scheduler_test(N, M):
-    A = te.placeholder((N, M), name='A')
+    A = te.placeholder((N, M), name="A")
     B = topi.nn.softmax(A, axis=1)
 
     return [A, B]
@@ -97,16 +104,18 @@ def softmax_nm_auto_scheduler_test(N, M):
 
 @auto_scheduler.register_workload
 def softmax_abcd_auto_scheduler_test(a, b, c, d):
-    A = te.placeholder((a, b, c, d), name='A')
+    A = te.placeholder((a, b, c, d), name="A")
     B = topi.nn.softmax(A, axis=-1)
 
     return [A, B]
 
 
 @auto_scheduler.register_workload
-def conv2d_winograd_nhwc_auto_scheduler_test(N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, dilation=1):
+def conv2d_winograd_nhwc_auto_scheduler_test(
+    N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, dilation=1
+):
     tile_size = 4
-    inputs = te.placeholder((N, H, W, CI), name='inputs')
+    inputs = te.placeholder((N, H, W, CI), name="inputs")
     N, H, W, CI = get_const_tuple(inputs.shape)
     if isinstance(dilation, int):
         dilation_h = dilation_w = dilation
@@ -125,57 +134,74 @@ def conv2d_winograd_nhwc_auto_scheduler_test(N, H, W, CI, CO, kernel_size=3, str
     r = KW
     m = tile_size
     alpha = m + r - 1
-    A, B, G = winograd_transform_matrices(m, r, 'float32')
+    A, B, G = winograd_transform_matrices(m, r, "float32")
 
     H = (H + 2 * HPAD - KH) // HSTR + 1
     W = (W + 2 * WPAD - KW) // WSTR + 1
     nH, nW = (H + m - 1) // m, (W + m - 1) // m
     P = N * nH * nW
-    r_kh = te.reduce_axis((0, KH), name='r_kh')
-    r_kw = te.reduce_axis((0, KW), name='r_kw')
+    r_kh = te.reduce_axis((0, KH), name="r_kh")
+    r_kw = te.reduce_axis((0, KW), name="r_kw")
     kshape = (alpha, alpha, CI, CO)
     kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight")
 
     idxdiv = te.indexdiv
     idxmod = te.indexmod
     # pack input tile
-    input_tile = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci:
-                             data_pad[idxdiv(p, (nH * nW))][idxmod(idxdiv(p, nW), nH) * m + eps]
-                                     [idxmod(p, nW) * m + nu][ci], name='input_tile')
+    input_tile = te.compute(
+        (alpha, alpha, P, CI),
+        lambda eps, nu, p, ci: data_pad[idxdiv(p, (nH * nW))][idxmod(idxdiv(p, nW), nH) * m + eps][
+            idxmod(p, nW) * m + nu
+        ][ci],
+        name="input_tile",
+    )
 
     # transform data
-    r_a = te.reduce_axis((0, alpha), 'r_a')
-    r_b = te.reduce_axis((0, alpha), 'r_b')
-    data_pack = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci:
-                            te.sum(input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu],
-                                    axis=[r_a, r_b]), name='data_pack',
-                            attrs={"auto_scheduler_simplify_const_tensor_indices": ["eps", "nu", "r_a", "r_b"]})
+    r_a = te.reduce_axis((0, alpha), "r_a")
+    r_b = te.reduce_axis((0, alpha), "r_b")
+    data_pack = te.compute(
+        (alpha, alpha, P, CI),
+        lambda eps, nu, p, ci: te.sum(
+            input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]
+        ),
+        name="data_pack",
+        attrs={"auto_scheduler_simplify_const_tensor_indices": ["eps", "nu", "r_a", "r_b"]},
+    )
 
     # do batch gemm
-    ci = te.reduce_axis((0, CI), name='ci')
-    bgemm = te.compute((alpha, alpha, P, CO), lambda eps, nu, p, co:
-                        te.sum(data_pack[eps][nu][p][ci] *
-                                kernel_pack[eps][nu][ci][co],
-                                axis=[ci]), name='bgemm')
+    ci = te.reduce_axis((0, CI), name="ci")
+    bgemm = te.compute(
+        (alpha, alpha, P, CO),
+        lambda eps, nu, p, co: te.sum(
+            data_pack[eps][nu][p][ci] * kernel_pack[eps][nu][ci][co], axis=[ci]
+        ),
+        name="bgemm",
+    )
 
     # inverse transform
-    r_a = te.reduce_axis((0, alpha), 'r_a')
-    r_b = te.reduce_axis((0, alpha), 'r_b')
-    inverse = te.compute((m, m, P, CO), lambda vh, vw, p, co:
-                          te.sum(bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw],
-                                  axis=[r_a, r_b]), name='inverse',
-                          attrs={"auto_scheduler_simplify_const_tensor_indices": ["vh", "vw", "r_a", "r_b"]})
+    r_a = te.reduce_axis((0, alpha), "r_a")
+    r_b = te.reduce_axis((0, alpha), "r_b")
+    inverse = te.compute(
+        (m, m, P, CO),
+        lambda vh, vw, p, co: te.sum(
+            bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]
+        ),
+        name="inverse",
+        attrs={"auto_scheduler_simplify_const_tensor_indices": ["vh", "vw", "r_a", "r_b"]},
+    )
 
     # output
-    output = te.compute((N, H, W, CO), lambda n, h, w, co:
-                         inverse[idxmod(h, m),
-                                 idxmod(w, m),
-                                 n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
-                                 co],
-                         name='conv2d_winograd')
+    output = te.compute(
+        (N, H, W, CO),
+        lambda n, h, w, co: inverse[
+            idxmod(h, m), idxmod(w, m), n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), co
+        ],
+        name="conv2d_winograd",
+    )
 
     return [inputs, kernel_pack, output]
 
+
 def get_tiled_matmul():
     A, B, C = matmul_auto_scheduler_test(512, 512, 512)
     dag = auto_scheduler.ComputeDAG([A, B, C])
@@ -183,8 +209,9 @@ def get_tiled_matmul():
     s0 = dag.get_init_state()
     its0 = s0.split(C, s0[C].iters[0], [4, 8, 8])
     its1 = s0.split(C, s0[C].iters[4], [8, 4, 4])
-    s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], its1[3],
-                   s0[C].iters[8]])
+    s0.reorder(
+        C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], its1[3], s0[C].iters[8]]
+    )
 
     return dag, s0
 
index 6b76dc6..a58f2ca 100644
@@ -51,7 +51,7 @@ def test_estimate_flop():
     dag = auto_scheduler.ComputeDAG([A, B, E])
     assert abs(dag.flop_ct - 2 * N ** 3) < 0.5
 
-    F = te.compute((N, N), lambda i, j: E[i,j], name='F', attrs={"FLOP": 1234})
+    F = te.compute((N, N), lambda i, j: E[i, j], name="F", attrs={"FLOP": 1234})
     dag = auto_scheduler.ComputeDAG([A, B, F])
     assert abs(dag.flop_ct - (2 * N ** 3 + 1234)) < 0.5
 
index 5d58ae0..a28618c 100644
@@ -30,17 +30,18 @@ from test_auto_scheduler_common import matmul_auto_scheduler_test
 def get_sample_records(number):
     """Generate random a list of random MeasureInput and MeasureResult pairs"""
     N = 128
-    workload_key = auto_scheduler.make_workload_key(
-        matmul_auto_scheduler_test, (N, N, N))
+    workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (N, N, N))
     dag = auto_scheduler.ComputeDAG(workload_key)
-    target = tvm.target.Target('llvm')
+    target = tvm.target.Target("llvm")
     task = auto_scheduler.SearchTask(dag, workload_key, target)
     policy = auto_scheduler.SketchPolicy(task, verbose=0)
     states = policy.sample_initial_population(number)
 
     inputs = [auto_scheduler.MeasureInput(task, s) for s in states]
-    results = [auto_scheduler.MeasureResult([np.random.uniform(0.5, 1.0)], 0, "", 0.1, 0)
-               for _ in range(len(inputs))]
+    results = [
+        auto_scheduler.MeasureResult([np.random.uniform(0.5, 1.0)], 0, "", 0.1, 0)
+        for _ in range(len(inputs))
+    ]
 
     return task, dag, inputs, results
 
@@ -65,8 +66,7 @@ def test_xgb_model():
     costs = [np.mean([x.value for x in res.costs]) for res in results]
     throughputs = np.min(costs) / costs
 
-    rmse = np.sqrt(np.mean([np.square(pred - label)
-                            for pred, label in zip(preds, throughputs)]))
+    rmse = np.sqrt(np.mean([np.square(pred - label) for pred, label in zip(preds, throughputs)]))
     assert rmse <= 0.3
 
     with tempfile.NamedTemporaryFile() as fp:
index ff43432..eb706b7 100644
@@ -29,8 +29,8 @@ class MockCostModel(PythonBasedModel):
         scores = []
         found = False
         for state in states:
-            for line in str(state).split('\n'):
-                if line.find('k.1') != -1 and line.find('(0,2)') != -1:
+            for line in str(state).split("\n"):
+                if line.find("k.1") != -1 and line.find("(0,2)") != -1:
                     found = True
                     break
             scores.append(1 if found else 0)
@@ -44,20 +44,17 @@ def test_evo_search():
     This unit test has been tested with 1,000 runs with no failures, meaning that
     the failure rate is less than 0.1%.
     """
-    workload_key = auto_scheduler.make_workload_key(
-        matmul_auto_scheduler_test, (10, 10, 4))
+    workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (10, 10, 4))
     dag = auto_scheduler.ComputeDAG(workload_key)
-    task = auto_scheduler.SearchTask(
-        dag, workload_key, tvm.target.Target('llvm'))
-    policy = auto_scheduler.SketchPolicy(
-        task, schedule_cost_model=MockCostModel(), verbose=0)
+    task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.Target("llvm"))
+    policy = auto_scheduler.SketchPolicy(task, schedule_cost_model=MockCostModel(), verbose=0)
     states = policy.sample_initial_population(50)
     pruned_states = []
     for state in states:
         found = False
-        for line in str(state).split('\n'):
+        for line in str(state).split("\n"):
             # Remove all tile_k=2 states and expect evo search to find them.
-            if line.find('k.1') != -1 and line.find('(0,2)') != -1:
+            if line.find("k.1") != -1 and line.find("(0,2)") != -1:
                 found = True
                 break
         if not found:
@@ -66,9 +63,9 @@ def test_evo_search():
     new_states = policy.evolutionary_search(pruned_states, 50)
     found = False
     for state in new_states:
-        for line in str(state).split('\n'):
+        for line in str(state).split("\n"):
             # Check if evo search found at least one state with tile_k=2.
-            if line.find('k.1') != -1 and line.find('(0,2)') != -1:
+            if line.find("k.1") != -1 and line.find("(0,2)") != -1:
                 found = True
                 break
         if found:
index 6d6f5cd..8cbe201 100644
@@ -44,11 +44,10 @@ def test_cpu_matmul():
     s.parallel(C, jo)
     s.unroll(C, k)
 
-    target = tvm.target.Target('llvm')
+    target = tvm.target.Target("llvm")
     task = auto_scheduler.SearchTask(dag, "test", target)
     names = auto_scheduler.feature.get_per_store_feature_names()
-    fea = auto_scheduler.feature.get_per_store_features_from_states([s], task)[
-        0]
+    fea = auto_scheduler.feature.get_per_store_features_from_states([s], task)[0]
 
     stage_0 = fea[0]
     assert len(stage_0) == len(names), "%d vs %d" % (len(stage_0), len(names))
@@ -79,11 +78,9 @@ def test_cpu_matmul():
 
     # check touched memory in bytes, touched unique memory in bytes, reuse distance, etc.
     assert fequal(fea_dict[c_name + ".bytes"], math.log2(512 ** 3 * 4 + 1))
-    assert fequal(fea_dict[b_name + ".unique_bytes"],
-                  math.log2(512 ** 2 * 4 + 1))
+    assert fequal(fea_dict[b_name + ".unique_bytes"], math.log2(512 ** 2 * 4 + 1))
     assert fequal(fea_dict[c_name + ".reuse_dis_iter"], math.log2(8 * 16 + 1))
-    assert fequal(fea_dict[c_name + ".reuse_dis_bytes"],
-                  math.log2((8 * 16 + 8 + 16) * 4 + 1))
+    assert fequal(fea_dict[c_name + ".reuse_dis_bytes"], math.log2((8 * 16 + 8 + 16) * 4 + 1))
     assert fequal(fea_dict[c_name + ".reuse_ct"], math.log2(512 + 1))
 
     # check annotations
@@ -91,26 +88,24 @@ def test_cpu_matmul():
     # assert fequal(fea_dict["unroll_type.kPosInnerReduce"], 1.0)
     assert fequal(fea_dict["vec_num"], math.log2(1 + 1))
     assert fequal(fea_dict["parallel_num"], math.log2(2 + 1))
-    assert fequal(fea_dict["parallel_prod"],
-                  math.log2((512 * 512 / 16 / 8) + 1))
+    assert fequal(fea_dict["parallel_prod"], math.log2((512 * 512 / 16 / 8) + 1))
 
 
 def test_cpu_fusion():
     def fusion_test(N, M):
-        A = te.placeholder((N, M), name='A')
-        B = te.compute((N, M), lambda i, j: A[i][j], name='B')
-        C = te.compute((N, M), lambda i, j: B[i][j], name='C')
+        A = te.placeholder((N, M), name="A")
+        B = te.compute((N, M), lambda i, j: A[i][j], name="B")
+        C = te.compute((N, M), lambda i, j: B[i][j], name="C")
         return [A, B, C]
 
     dag = auto_scheduler.ComputeDAG(fusion_test(64, 32))
     s = dag.get_init_state()
     s.compute_at(1, 2, s.stages[2].iters[1])
 
-    target = tvm.target.Target('llvm')
+    target = tvm.target.Target("llvm")
     task = auto_scheduler.SearchTask(dag, "test", target)
     names = auto_scheduler.feature.get_per_store_feature_names()
-    fea = auto_scheduler.feature.get_per_store_features_from_states([s], task)[
-        0]
+    fea = auto_scheduler.feature.get_per_store_features_from_states([s], task)[0]
 
     """
     lowered IR:
@@ -128,7 +123,7 @@ def test_cpu_fusion():
     found = False
     for stage_fea in fea:
         for i, (name, value) in enumerate(zip(names, stage_fea)):
-            if 'reuse_type.kSerialMultipleReadWrite' in name and value > 0.5:
+            if "reuse_type.kSerialMultipleReadWrite" in name and value > 0.5:
                 # reuse distance in #iter
                 assert fequal(stage_fea[i + 2], 1.0)
                 # reuse distance in bytes
@@ -139,12 +134,14 @@ def test_cpu_fusion():
 
 def test_gpu_feature():
     # Use records to build a complicated GPU program
-    json_records = "\n".join((
-        """{"i": [["[\\"matmul_auto_scheduler_test\\", 512, 512, 512]", "cuda"], [[], [["CHW", 2, "local"], ["SP", 2, 0, 512, [1, 16, 32, 1], 1], ["SP", 2, 5, 512, [4, 1, 1, 16], 1], ["SP", 2, 10, 512, [1, 2], 1], ["RE", 2, [0, 5, 1, 6, 2, 7, 10, 11, 3, 8, 12, 4, 9]], ["FSP", 3, 0, 1, 3], ["FSP", 3, 4, 2, 3], ["RE", 3, [0, 4, 1, 5, 2, 6, 3, 7]], ["FU", 2, [0, 1]], ["FU", 3, [0, 1]], ["FU", 2, [1, 2]], ["FU", 3, [1, 2]], ["FU", 2, [2, 3]], ["FU", 3, [2, 3]], ["CA", 2, 3, 2], ["CHR", 1, "shared", [2]], ["CA", 2, 3, 3], ["FU", 2, [0, 1]], ["FFSP", 2, 0, [1, 2], 1, 1], ["AN", 2, 1, 6], ["CHR", 0, "shared", [3]], ["CA", 1, 4, 3], ["FU", 1, [0, 1]], ["FFSP", 1, 0, [1, 2], 1, 1], ["AN", 1, 1, 6], ["AN", 5, 0, 5], ["AN", 5, 1, 4], ["AN", 5, 2, 6], ["PR", 4, 0, "auto_unroll_max_step$1024"]]]], "r": [[0.00536798], 0, 2.49277, 1585564852], "v": "v0.1"}""",
-    ))
+    json_records = "\n".join(
+        (
+            """{"i": [["[\\"matmul_auto_scheduler_test\\", 512, 512, 512]", "cuda"], [[], [["CHW", 2, "local"], ["SP", 2, 0, 512, [1, 16, 32, 1], 1], ["SP", 2, 5, 512, [4, 1, 1, 16], 1], ["SP", 2, 10, 512, [1, 2], 1], ["RE", 2, [0, 5, 1, 6, 2, 7, 10, 11, 3, 8, 12, 4, 9]], ["FSP", 3, 0, 1, 3], ["FSP", 3, 4, 2, 3], ["RE", 3, [0, 4, 1, 5, 2, 6, 3, 7]], ["FU", 2, [0, 1]], ["FU", 3, [0, 1]], ["FU", 2, [1, 2]], ["FU", 3, [1, 2]], ["FU", 2, [2, 3]], ["FU", 3, [2, 3]], ["CA", 2, 3, 2], ["CHR", 1, "shared", [2]], ["CA", 2, 3, 3], ["FU", 2, [0, 1]], ["FFSP", 2, 0, [1, 2], 1, 1], ["AN", 2, 1, 6], ["CHR", 0, "shared", [3]], ["CA", 1, 4, 3], ["FU", 1, [0, 1]], ["FFSP", 1, 0, [1, 2], 1, 1], ["AN", 1, 1, 6], ["AN", 5, 0, 5], ["AN", 5, 1, 4], ["AN", 5, 2, 6], ["PR", 4, 0, "auto_unroll_max_step$1024"]]]], "r": [[0.00536798], 0, 2.49277, 1585564852], "v": "v0.1"}""",
+        )
+    )
 
     # load states
-    with tempfile.NamedTemporaryFile(mode='w') as f:
+    with tempfile.NamedTemporaryFile(mode="w") as f:
         f.write(json_records)
         f.flush()
         inputs, results = auto_scheduler.RecordReader(f.name).read_lines()
@@ -152,11 +149,15 @@ def test_gpu_feature():
         inp = inputs[0]
         dag = auto_scheduler.ComputeDAG(inp.task.workload_key)
         task = auto_scheduler.SearchTask(
-            dag, inp.task.workload_key, inp.task.target, None, auto_scheduler.HardwareParams(100000, 16, 64))
+            dag,
+            inp.task.workload_key,
+            inp.task.target,
+            None,
+            auto_scheduler.HardwareParams(100000, 16, 64),
+        )
 
         state = dag.infer_bound_from_state(inputs[0].state)
-        fea = auto_scheduler.feature.get_per_store_features_from_states([state], task)[
-            0]
+        fea = auto_scheduler.feature.get_per_store_features_from_states([state], task)[0]
         names = auto_scheduler.feature.get_per_store_feature_names()
 
         # build feature dict
@@ -192,12 +193,12 @@ def test_gpu_feature():
         """
 
         # check gpu-related features
-        assert fequal(fea_dicts[0]['blockIdx_x_len'], math.log2(8 + 1))
-        assert fequal(fea_dicts[0]['vthread_len'], math.log2(4 + 1))
-        assert fequal(fea_dicts[1]['threadIdx_x_len'], math.log2(16 + 1))
-        assert fequal(fea_dicts[0]['threadIdx_y_len'], math.log2(1 + 1))
-        assert fequal(fea_dicts[2]['blockIdx_z_len'], math.log2(1 + 1))
-        assert fequal(fea_dicts[0]['is_gpu'], 1.0)
+        assert fequal(fea_dicts[0]["blockIdx_x_len"], math.log2(8 + 1))
+        assert fequal(fea_dicts[0]["vthread_len"], math.log2(4 + 1))
+        assert fequal(fea_dicts[1]["threadIdx_x_len"], math.log2(16 + 1))
+        assert fequal(fea_dicts[0]["threadIdx_y_len"], math.log2(1 + 1))
+        assert fequal(fea_dicts[2]["blockIdx_z_len"], math.log2(1 + 1))
+        assert fequal(fea_dicts[0]["is_gpu"], 1.0)
 
 
 if __name__ == "__main__":
index 8842362..aba2784 100644
@@ -38,6 +38,7 @@ def test_apply_steps_with_layout_rewrite():
     assert bufs[1].shape[3] == 4
     assert bufs[1].shape[4] == 512
 
+
 def test_layout_rewrite_correctness():
     N = 128
     target = "llvm"
@@ -52,8 +53,12 @@ def test_layout_rewrite_correctness():
 
         search_policy = auto_scheduler.SketchPolicy(task)
 
-        tuning_options = auto_scheduler.TuningOptions(num_measure_trials=2,
-                runner='local', verbose=1, measure_callbacks=[auto_scheduler.RecordToFile(log_file)])
+        tuning_options = auto_scheduler.TuningOptions(
+            num_measure_trials=2,
+            runner="local",
+            verbose=1,
+            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
+        )
         auto_scheduler.auto_schedule(task, search_policy, tuning_options)
         inp, _ = auto_scheduler.load_best(log_file, workload_key, target)
         s, bufs = dag.apply_steps_from_state(inp.state, layout_rewrite=True)
@@ -70,12 +75,22 @@ def test_layout_rewrite_correctness():
             out_dim = weight.shape[3 + base] * weight.shape[5 + base]
             for i in range(base + 2):
                 out_dim *= weight.shape[i]
-            new_order = [2 + base, 4 + base,] + list(range(base + 2)) + [3 + base, 5 + base,]
+            new_order = (
+                [
+                    2 + base,
+                    4 + base,
+                ]
+                + list(range(base + 2))
+                + [
+                    3 + base,
+                    5 + base,
+                ]
+            )
             np_args_ref[1] = np_args_ref[1].transpose(new_order)
             np_args_ref[1] = np_args_ref[1].reshape((red_dim, out_dim))
 
         func = tvm.build(s, bufs, target=inp.task.target, target_host=inp.task.target_host)
-        func_ref = tvm.build(s_ref, bufs_ref, target='llvm')
+        func_ref = tvm.build(s_ref, bufs_ref, target="llvm")
 
         ctx = tvm.context(str(inp.task.target))
         ctx_ref = tvm.cpu()
@@ -91,6 +106,7 @@ def test_layout_rewrite_correctness():
         np.testing.assert_allclose(np_args[0], np_args_ref[0])
         np.testing.assert_allclose(np_args[2], np_args_ref[2])
 
+
 if __name__ == "__main__":
     test_apply_steps_with_layout_rewrite()
     test_layout_rewrite_correctness()
index aeed420..44ed1fc 100644 (file)
@@ -23,7 +23,10 @@ import tvm
 from tvm import auto_scheduler, te
 from tvm import topi
 
-from test_auto_scheduler_common import matmul_auto_scheduler_test, conv2d_nchw_bn_relu_auto_scheduler_test
+from test_auto_scheduler_common import (
+    matmul_auto_scheduler_test,
+    conv2d_nchw_bn_relu_auto_scheduler_test,
+)
 
 
 def test_split_fuse_reorder_annotation():
@@ -85,10 +88,13 @@ def test_split_fuse_reorder_annotation():
     assert res == s1[C].iters[5]
     assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["vectorize"]
 
+
 def test_compute_at_root_inline():
-    dag = auto_scheduler.ComputeDAG(conv2d_nchw_bn_relu_auto_scheduler_test(N=1, H=224, W=224, CI=3,
-                                                                            CO=64, kernel_size=7, strides=2,
-                                                                            padding=3))
+    dag = auto_scheduler.ComputeDAG(
+        conv2d_nchw_bn_relu_auto_scheduler_test(
+            N=1, H=224, W=224, CI=3, CO=64, kernel_size=7, strides=2, padding=3
+        )
+    )
     s0 = dag.get_init_state()
 
     # data, padding, kernel = 0, 1, 2
@@ -142,18 +148,18 @@ def test_compute_at_root_inline():
     assert s0[conv].iters[5].range.extent == 7
     assert s0[conv].iters[6].range.extent == 7
 
+
 def test_cache_read_write():
-    N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, (
-        1, 1), (1, 1)
-
-    data = te.placeholder((N, CI, H, W), name='Data')
-    kernel_data = te.placeholder((CO, CI, KH, KW), name='Kernel_data')
-    k0, k1 = te.compute(kernel_data.shape,
-                        lambda *i: (kernel_data(*i)+1, kernel_data(*i)/2),
-                        name='Kernel_split')
-    kernel = te.compute(kernel_data.shape,
-                        lambda *i: k0(*i) + k1(*i),
-                        name='Kernel')
+    N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
+
+    data = te.placeholder((N, CI, H, W), name="Data")
+    kernel_data = te.placeholder((CO, CI, KH, KW), name="Kernel_data")
+    k0, k1 = te.compute(
+        kernel_data.shape,
+        lambda *i: (kernel_data(*i) + 1, kernel_data(*i) / 2),
+        name="Kernel_split",
+    )
+    kernel = te.compute(kernel_data.shape, lambda *i: k0(*i) + k1(*i), name="Kernel")
     conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1)
     relu = topi.nn.relu(conv)
     add = topi.add(data, relu)
@@ -416,6 +422,7 @@ def test_cache_read_write():
     for it0, it1 in zip(s0[kernel_split].iters, s0[kernel_split_global].iters):
         assert it0.range == it1.range
 
+
 def test_follow_split_follow_fused_split():
     A, B, C = matmul_auto_scheduler_test(512, 512, 512)
     dag = auto_scheduler.ComputeDAG([A, B, C])
@@ -428,8 +435,7 @@ def test_follow_split_follow_fused_split():
         tmp = s0.copy()
         tmp.follow_split(C_global, tmp[C_global].iters[0], split_step0, level)
         for i in range(0, level):
-            assert tmp[C].iters[i].range.extent == \
-                   tmp[C_global].iters[i].range.extent
+            assert tmp[C].iters[i].range.extent == tmp[C_global].iters[i].range.extent
 
     its1 = s0.split(C, s0[C].iters[5], [2, 2, 4, 8])
     split_step1 = len(s0.transform_steps) - 1
@@ -443,17 +449,17 @@ def test_follow_split_follow_fused_split():
 
     for level in range(0, 4):
         tmp = s0.copy()
-        tmp.follow_fused_split(C_global, tmp[C_global].iters[0],
-                               [split_step0, split_step1], level, False)
-        assert tmp[C].iters[level + 1].range.extent == \
-               tmp[C_global].iters[0].range.extent
+        tmp.follow_fused_split(
+            C_global, tmp[C_global].iters[0], [split_step0, split_step1], level, False
+        )
+        assert tmp[C].iters[level + 1].range.extent == tmp[C_global].iters[0].range.extent
 
     for level in range(0, 4):
         tmp = s0.copy()
-        tmp.follow_fused_split(C_global, tmp[C_global].iters[0],
-                               [split_step0, split_step1], level, True)
-        assert tmp[C].iters[level + 1].range.extent == \
-               tmp[C_global].iters[1].range.extent
+        tmp.follow_fused_split(
+            C_global, tmp[C_global].iters[0], [split_step0, split_step1], level, True
+        )
+        assert tmp[C].iters[level + 1].range.extent == tmp[C_global].iters[1].range.extent
 
 
 def test_rfactor():
index c12240d..5dae2a5 100644 (file)
@@ -51,11 +51,10 @@ def test_record_split_reorder_fuse_annotation():
     if not tvm.testing.device_enabled("llvm"):
         return
 
-    A = te.placeholder((512, 512), name='A')
-    B = te.placeholder((512, 512), name='B')
-    k = te.reduce_axis((0, 512), name='k')
-    C = te.compute((512, 512), lambda i, j: te.sum(
-        A[i][k] * B[k][j], axis=[k]), name='C')
+    A = te.placeholder((512, 512), name="A")
+    B = te.placeholder((512, 512), name="B")
+    k = te.reduce_axis((0, 512), name="k")
+    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="C")
 
     dag = auto_scheduler.ComputeDAG([A, B, C])
     s = dag.get_init_state()
@@ -64,8 +63,9 @@ def test_record_split_reorder_fuse_annotation():
     its0 = s.split(C, s[C].iters[0], [4, 8, 8])
     its1 = s.split(C, s[C].iters[4], [8, 4, 4])
     # Reorder
-    s.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], s[C].iters[8],
-                  its1[3]])
+    s.reorder(
+        C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], s[C].iters[8], its1[3]]
+    )
     # Fuse
     s.fuse(C, [s[C].iters[0], s[C].iters[1], s[C].iters[2]])
     # Parallel
@@ -86,12 +86,11 @@ def test_record_compute_at_root_inline_cache_read_write():
     if not tvm.testing.device_enabled("llvm"):
         return
 
-    A = te.placeholder((512, 512), name='A')
+    A = te.placeholder((512, 512), name="A")
     AA = topi.nn.relu(A)
-    B = te.placeholder((512, 512), name='B')
-    k = te.reduce_axis((0, 512), name='k')
-    C = te.compute((512, 512), lambda i, j: te.sum(
-        AA[i][k] * B[k][j], axis=[k]), name='C')
+    B = te.placeholder((512, 512), name="B")
+    k = te.reduce_axis((0, 512), name="k")
+    C = te.compute((512, 512), lambda i, j: te.sum(AA[i][k] * B[k][j], axis=[k]), name="C")
 
     dag = auto_scheduler.ComputeDAG([A, B, C])
     s = dag.get_init_state()
@@ -115,11 +114,10 @@ def test_record_follow_split_follow_fused_split():
     if not tvm.testing.device_enabled("llvm"):
         return
 
-    A = te.placeholder((512, 512), name='A')
-    B = te.placeholder((512, 512), name='B')
-    k = te.reduce_axis((0, 512), name='k')
-    C = te.compute((512, 512), lambda i, j: te.sum(
-        A[i][k] * B[k][j], axis=[k]), name='C')
+    A = te.placeholder((512, 512), name="A")
+    B = te.placeholder((512, 512), name="B")
+    k = te.reduce_axis((0, 512), name="k")
+    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="C")
     D = topi.nn.relu(C)
     E = topi.nn.relu(D)
 
@@ -150,11 +148,10 @@ def test_record_pragma_storage_align_rfactor():
     if not tvm.testing.device_enabled("llvm"):
         return
 
-    A = te.placeholder((512, 512), name='A')
-    B = te.placeholder((512, 512), name='B')
-    k = te.reduce_axis((0, 512), name='k')
-    C = te.compute((512, 512), lambda i, j: te.sum(
-        A[i][k] * B[k][j], axis=[k]), name='C')
+    A = te.placeholder((512, 512), name="A")
+    B = te.placeholder((512, 512), name="B")
+    k = te.reduce_axis((0, 512), name="k")
+    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="C")
 
     dag = auto_scheduler.ComputeDAG([A, B, C])
     s = dag.get_init_state()
@@ -180,8 +177,9 @@ def test_measure_local_builder_runner(enable_cpu_cache_flush=False):
 
     minp = auto_scheduler.MeasureInput(task, s0)
     local_builder = auto_scheduler.LocalBuilder()
-    local_runner = auto_scheduler.LocalRunner(timeout=60,
-                                              enable_cpu_cache_flush=enable_cpu_cache_flush)
+    local_runner = auto_scheduler.LocalRunner(
+        timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
+    )
 
     bress = local_builder.build([minp])
     assert bress[0].error_no == 0
@@ -199,8 +197,9 @@ def test_measure_local_builder_rpc_runner(enable_cpu_cache_flush=False):
 
     minp = auto_scheduler.MeasureInput(task, s0)
     local_builder = auto_scheduler.LocalBuilder()
-    measure_ctx = auto_scheduler.LocalRPCMeasureContext(timeout=60,
-                                                        enable_cpu_cache_flush=enable_cpu_cache_flush)
+    measure_ctx = auto_scheduler.LocalRPCMeasureContext(
+        timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
+    )
     rpc_runner = measure_ctx.runner
 
     bress = local_builder.build([minp])
@@ -218,4 +217,3 @@ if __name__ == "__main__":
     test_measure_local_builder_runner(enable_cpu_cache_flush=False)
     test_measure_local_builder_rpc_runner(enable_cpu_cache_flush=True)
     test_measure_local_builder_rpc_runner(enable_cpu_cache_flush=False)
-
index b10d520..3302dd5 100644 (file)
@@ -27,10 +27,16 @@ from tvm import auto_scheduler
 from test_auto_scheduler_common import matmul_auto_scheduler_test, PropagatingThread
 
 
-def search_common(workload=matmul_auto_scheduler_test, target="llvm",
-                  search_policy='empty', seed=random.randint(1, 1 << 30), runner='local',
-                  cost_model=auto_scheduler.RandomModel(), num_measure_trials=2,
-                  init_search_callbacks=None):
+def search_common(
+    workload=matmul_auto_scheduler_test,
+    target="llvm",
+    search_policy="empty",
+    seed=random.randint(1, 1 << 30),
+    runner="local",
+    cost_model=auto_scheduler.RandomModel(),
+    num_measure_trials=2,
+    init_search_callbacks=None,
+):
     print("Test %s schedule search with the default search policy" % (target))
 
     random.seed(seed)
@@ -44,22 +50,25 @@ def search_common(workload=matmul_auto_scheduler_test, target="llvm",
         log_file = fp.name
 
         init_search_callbacks = init_search_callbacks or []
-        init_search_callbacks.append(
-            auto_scheduler.PreloadMeasuredStates(log_file))
+        init_search_callbacks.append(auto_scheduler.PreloadMeasuredStates(log_file))
 
-        if search_policy == 'empty':
+        if search_policy == "empty":
             search_policy = auto_scheduler.EmptyPolicy(task)
-        elif search_policy == 'sketch':
-            search_policy = auto_scheduler.SketchPolicy(task, schedule_cost_model=cost_model,
-                                                        init_search_callbacks=init_search_callbacks)
-
-        tuning_options = auto_scheduler.TuningOptions(num_measure_trials=num_measure_trials,
-                                                      runner=runner, verbose=1, measure_callbacks=[auto_scheduler.RecordToFile(log_file)])
-        sch, args = auto_scheduler.auto_schedule(
-            task, search_policy, tuning_options)
-        print("*"*80)
+        elif search_policy == "sketch":
+            search_policy = auto_scheduler.SketchPolicy(
+                task, schedule_cost_model=cost_model, init_search_callbacks=init_search_callbacks
+            )
+
+        tuning_options = auto_scheduler.TuningOptions(
+            num_measure_trials=num_measure_trials,
+            runner=runner,
+            verbose=1,
+            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
+        )
+        sch, args = auto_scheduler.auto_schedule(task, search_policy, tuning_options)
+        print("*" * 80)
         print(target)
-        print("*"*80)
+        print("*" * 80)
         inp, res = auto_scheduler.load_best(log_file, workload_key, target)
 
         print("==== Python Code ====")
@@ -76,8 +85,7 @@ def search_common(workload=matmul_auto_scheduler_test, target="llvm",
             b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
             c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
             mod(a, b, c)
-            tvm.testing.assert_allclose(c.asnumpy(), np.dot(
-                a.asnumpy(), b.asnumpy()), rtol=1e-5)
+            tvm.testing.assert_allclose(c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
             print("==== Verification passed ====")
         except Exception:
             raise Exception("Error encountered with seed: %d" % (seed))
@@ -88,15 +96,18 @@ def search_common(workload=matmul_auto_scheduler_test, target="llvm",
 def test_workload_registry_search_basic():
     # wrap the search in a new thread to avoid the conflict
     # between python's multiprocessing and tvm's thread pool
-    t = PropagatingThread(target=search_common, kwargs={'seed': 944563397})
+    t = PropagatingThread(target=search_common, kwargs={"seed": 944563397})
     t.start()
     t.join()
-    t = PropagatingThread(target=search_common,
-                          kwargs={'seed': 944563397, 'workload': "matmul_auto_scheduler_test"})
+    t = PropagatingThread(
+        target=search_common, kwargs={"seed": 944563397, "workload": "matmul_auto_scheduler_test"}
+    )
     t.start()
     t.join()
-    t = PropagatingThread(target=search_common,
-                          kwargs={'seed': 944563397, 'workload': "matmul_auto_scheduler_test_rename_1"})
+    t = PropagatingThread(
+        target=search_common,
+        kwargs={"seed": 944563397, "workload": "matmul_auto_scheduler_test_rename_1"},
+    )
     t.start()
     t.join()
 
@@ -105,8 +116,9 @@ def test_workload_registry_search_basic():
 def test_sketch_search_policy_basic():
     # wrap the search in a new thread to avoid the conflict
     # between python's multiprocessing and tvm's thread pool
-    t = PropagatingThread(target=search_common,
-                          kwargs={'seed': 944563397, 'search_policy': 'sketch'})
+    t = PropagatingThread(
+        target=search_common, kwargs={"seed": 944563397, "search_policy": "sketch"}
+    )
     t.start()
     t.join()
 
@@ -115,9 +127,14 @@ def test_sketch_search_policy_basic():
 def test_sketch_search_policy_xgbmodel():
     # wrap the search in a new thread to avoid the conflict
     # between python's multiprocessing and tvm's thread pool
-    t = PropagatingThread(target=search_common,
-                          kwargs={'seed': 944563397, 'search_policy': 'sketch',
-                                  'cost_model': auto_scheduler.XGBModel()})
+    t = PropagatingThread(
+        target=search_common,
+        kwargs={
+            "seed": 944563397,
+            "search_policy": "sketch",
+            "cost_model": auto_scheduler.XGBModel(),
+        },
+    )
     t.start()
     t.join()
 
@@ -127,9 +144,15 @@ def test_sketch_search_policy_cuda_rpc_runner():
     measure_ctx = auto_scheduler.LocalRPCMeasureContext()
     # wrap the search in a new thread to avoid the conflict
     # between python's multiprocessing and tvm's thread pool
-    t = PropagatingThread(target=search_common,
-                          kwargs={'seed': 944563397, 'search_policy': 'sketch', 'target': 'cuda',
-                                  'runner': measure_ctx.runner})
+    t = PropagatingThread(
+        target=search_common,
+        kwargs={
+            "seed": 944563397,
+            "search_policy": "sketch",
+            "target": "cuda",
+            "runner": measure_ctx.runner,
+        },
+    )
     t.start()
     t.join()
 
@@ -140,9 +163,16 @@ def test_sketch_search_policy_cuda_xgbmodel_rpc_runner():
     measure_ctx = auto_scheduler.LocalRPCMeasureContext()
     # wrap the search in a new thread to avoid the conflict
     # between python's multiprocessing and tvm's thread pool
-    t = PropagatingThread(target=search_common,
-                          kwargs={'seed': 944563397, 'search_policy': 'sketch', 'target': 'cuda',
-                                  'runner': measure_ctx.runner, 'cost_model': auto_scheduler.XGBModel()})
+    t = PropagatingThread(
+        target=search_common,
+        kwargs={
+            "seed": 944563397,
+            "search_policy": "sketch",
+            "target": "cuda",
+            "runner": measure_ctx.runner,
+            "cost_model": auto_scheduler.XGBModel(),
+        },
+    )
     t.start()
     t.join()
 
index 6d4c263..fa67756 100644 (file)
@@ -23,10 +23,15 @@ from tvm import te, auto_scheduler
 from tvm.auto_scheduler import _ffi_api
 from tvm.auto_scheduler.loop_state import Stage
 
-from test_auto_scheduler_common import (matmul_auto_scheduler_test, conv2d_nchw_bn_relu_auto_scheduler_test,
-                                        max_pool2d_auto_scheduler_test, min_nm_auto_scheduler_test,
-                                        softmax_nm_auto_scheduler_test, softmax_abcd_auto_scheduler_test,
-                                        conv2d_winograd_nhwc_auto_scheduler_test)
+from test_auto_scheduler_common import (
+    matmul_auto_scheduler_test,
+    conv2d_nchw_bn_relu_auto_scheduler_test,
+    max_pool2d_auto_scheduler_test,
+    min_nm_auto_scheduler_test,
+    softmax_nm_auto_scheduler_test,
+    softmax_abcd_auto_scheduler_test,
+    conv2d_winograd_nhwc_auto_scheduler_test,
+)
 
 
 def generate_sketches(workload_func, args, target, print_for_debug=False):
@@ -36,35 +41,42 @@ def generate_sketches(workload_func, args, target, print_for_debug=False):
     policy = auto_scheduler.SketchPolicy(task, verbose=0)
     return policy.generate_sketches(print_for_debug)
 
+
 def assert_compute_at_condition(stage, condition):
     assert stage.compute_at == Stage.COMPUTE_AT_TRANS_TABLE[condition]
 
+
 def assert_is_tiled(stage):
     assert _ffi_api.SearchPolicyUtilsIsTiled(stage)
 
+
 def assert_is_not_tiled(stage):
     assert not _ffi_api.SearchPolicyUtilsIsTiled(stage)
 
+
 def assert_has_cache_write(state, stage_id):
     assert _ffi_api.SearchPolicyUtilsHasCacheWriteStage(state, stage_id)
 
+
 def assert_has_cache_read(state, stage_id):
     assert _ffi_api.SearchPolicyUtilsHasCacheReadStage(state, stage_id)
 
+
 def assert_has_rfactor(state, stage_id):
     assert _ffi_api.SearchPolicyUtilsHasRfactorStage(state, stage_id)
 
+
 def assert_has_cross_thread_reduction(state, stage_id):
     assert _ffi_api.SearchPolicyUtilsHasCrossThreadReduction(state, stage_id)
 
 
 def test_cpu_matmul_sketch():
-    sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), 'llvm')
-    ''' 3 multi-level tiling sketches
+    sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), "llvm")
+    """ 3 multi-level tiling sketches
         0 - Multi-level tiling
         1 - Multi-level tiling with cache write on position 0
         2 - Multi-level tiling with cache write on position 1
-    '''
+    """
     assert len(sketches) == 3
     # Sketch 0
     assert_is_tiled(sketches[0].stages[2])
@@ -78,14 +90,14 @@ def test_cpu_matmul_sketch():
     assert_compute_at_condition(sketches[2].stages[2], "iter")
     assert sketches[1] != sketches[2]
 
-    sketches = generate_sketches(matmul_auto_scheduler_test, (8, 8, 512), 'llvm')
-    ''' 2 rfactor sketches + 3 multi-level tiling sketches
+    sketches = generate_sketches(matmul_auto_scheduler_test, (8, 8, 512), "llvm")
+    """ 2 rfactor sketches + 3 multi-level tiling sketches
         0 - Rfactor with factor position 0
         1 - Rfactor with factor position 1
         2 - Multi-level tiling
         3 - Multi-level tiling with cache write on position 0
         4 - Multi-level tiling with cache write on position 1
-    '''
+    """
     assert len(sketches) == 5
     # Sketch 0
     assert_has_rfactor(sketches[0], 2)
@@ -106,13 +118,14 @@ def test_cpu_matmul_sketch():
 
 
 def test_cpu_conv2d_bn_relu_sketch():
-    sketches = generate_sketches(conv2d_nchw_bn_relu_auto_scheduler_test,
-                                 (1, 56, 56, 512, 512, 3, 1, 1), 'llvm')
-    ''' 3 multi-level tiling sketches
+    sketches = generate_sketches(
+        conv2d_nchw_bn_relu_auto_scheduler_test, (1, 56, 56, 512, 512, 3, 1, 1), "llvm"
+    )
+    """ 3 multi-level tiling sketches
         0 - Conv2d multi-level tiling with fusion on position 0
         1 - Conv2d multi-level tiling with fusion on position 1
         2 - Conv2d multi-level tiling without fusion
-    '''
+    """
     assert len(sketches) == 3
     # Sketch 0
     assert_is_not_tiled(sketches[0].stages[1])
@@ -141,20 +154,20 @@ def test_cpu_conv2d_bn_relu_sketch():
 
 
 def test_cpu_max_pool2d_sketch():
-    sketches = generate_sketches(max_pool2d_auto_scheduler_test, (1, 56, 56, 512, 1), 'llvm')
-    ''' 1 default sketch '''
+    sketches = generate_sketches(max_pool2d_auto_scheduler_test, (1, 56, 56, 512, 1), "llvm")
+    """ 1 default sketch """
     assert len(sketches) == 1
     # Sketch 0
     assert len(sketches[0].transform_steps) == 0
 
 
 def test_cpu_min_sketch():
-    sketches = generate_sketches(min_nm_auto_scheduler_test, (10, 1024), 'llvm')
-    ''' 2 rfactor sketches + 1 default sketch
+    sketches = generate_sketches(min_nm_auto_scheduler_test, (10, 1024), "llvm")
+    """ 2 rfactor sketches + 1 default sketch
         0 - Rfactor with factor position 0
         1 - Rfactor with factor position 1
         2 - Default sketch
-    '''
+    """
     assert len(sketches) == 3
     # Sketch 0
     assert_has_rfactor(sketches[0], 1)
@@ -166,8 +179,8 @@ def test_cpu_min_sketch():
 
 
 def test_cpu_softmax_sketch():
-    sketches = generate_sketches(softmax_nm_auto_scheduler_test, (1, 1024), 'llvm')
-    ''' (2 rfactor sketches + 1 default sketch) * (2 rfactor sketches + 1 default sketch) '''
+    sketches = generate_sketches(softmax_nm_auto_scheduler_test, (1, 1024), "llvm")
+    """ (2 rfactor sketches + 1 default sketch) * (2 rfactor sketches + 1 default sketch) """
     assert len(sketches) == (3 * 3)
     for i in range(0, 3):
         for j in range(0, 3):
@@ -178,8 +191,8 @@ def test_cpu_softmax_sketch():
                 assert_has_rfactor(sketch, 4 if j in [0, 1] else 3)
     assert len(sketches[8].transform_steps) == 0
 
-    sketches = generate_sketches(softmax_abcd_auto_scheduler_test, (1, 12, 128, 128), 'llvm')
-    ''' (2 rfactor sketches + 1 default sketch) * (2 rfactor sketches + 1 default sketch) '''
+    sketches = generate_sketches(softmax_abcd_auto_scheduler_test, (1, 12, 128, 128), "llvm")
+    """ (2 rfactor sketches + 1 default sketch) * (2 rfactor sketches + 1 default sketch) """
     assert len(sketches) == (3 * 3)
     for i in range(0, 3):
         for j in range(0, 3):
@@ -192,13 +205,14 @@ def test_cpu_softmax_sketch():
 
 
 def test_cpu_conv2d_winograd_sketch():
-    sketches = generate_sketches(conv2d_winograd_nhwc_auto_scheduler_test,
-                                 (1, 28, 28, 128, 128, 3, 1, 1), 'llvm')
-    ''' 3 multi-level tiling sketches
+    sketches = generate_sketches(
+        conv2d_winograd_nhwc_auto_scheduler_test, (1, 28, 28, 128, 128, 3, 1, 1), "llvm"
+    )
+    """ 3 multi-level tiling sketches
         0 - Bgemm multi-level tiling
         1 - Bgemm multi-level tiling with cache write on position 0
         2 - Bgemm multi-level tiling with cache write on position 1
-    '''
+    """
     assert len(sketches) == 3
     # Sketch 0
     assert_is_not_tiled(sketches[0].stages[1])
@@ -236,8 +250,8 @@ def test_cpu_conv2d_winograd_sketch():
 
 @tvm.testing.requires_cuda
 def test_cuda_matmul_sketch():
-    sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), 'cuda')
-    ''' 1 multi-level tiling sketch '''
+    sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), "cuda")
+    """ 1 multi-level tiling sketch """
     assert len(sketches) == 1
     assert_has_cache_read(sketches[0], 0)
     assert_compute_at_condition(sketches[0].stages[1], "iter")
@@ -248,8 +262,8 @@ def test_cuda_matmul_sketch():
     assert_compute_at_condition(sketches[0].stages[4], "iter")
     assert_is_tiled(sketches[0].stages[5])
 
-    sketches = generate_sketches(matmul_auto_scheduler_test, (8, 8, 1024), 'cuda')
-    ''' 1 cross thread reduction sketch + 1 multi-level tiling sketch '''
+    sketches = generate_sketches(matmul_auto_scheduler_test, (8, 8, 1024), "cuda")
+    """ 1 cross thread reuction sketch + 1 multi-level tiling sketch """
     assert len(sketches) == 2
     # Sketch 0
     assert_has_cross_thread_reduction(sketches[0], 2)
@@ -266,9 +280,10 @@ def test_cuda_matmul_sketch():
 
 @tvm.testing.requires_cuda
 def test_cuda_conv2d_bn_relu_sketch():
-    sketches = generate_sketches(conv2d_nchw_bn_relu_auto_scheduler_test,
-                                 (1, 56, 56, 512, 512, 3, 1, 1), 'cuda')
-    ''' 1 multi-level tiling sketch '''
+    sketches = generate_sketches(
+        conv2d_nchw_bn_relu_auto_scheduler_test, (1, 56, 56, 512, 512, 3, 1, 1), "cuda"
+    )
+    """ 1 multi-level tiling sketch """
     assert len(sketches) == 1
     assert_has_cache_read(sketches[0], 1)
     assert_compute_at_condition(sketches[0].stages[1], "inlined")
@@ -285,16 +300,16 @@ def test_cuda_conv2d_bn_relu_sketch():
 
 @tvm.testing.requires_cuda
 def test_cuda_max_pool2d_sketch():
-    sketches = generate_sketches(max_pool2d_auto_scheduler_test, (1, 56, 56, 512, 0), 'cuda')
-    ''' 1 default sketch '''
+    sketches = generate_sketches(max_pool2d_auto_scheduler_test, (1, 56, 56, 512, 0), "cuda")
+    """ 1 default sketch """
     assert len(sketches) == 1
     assert len(sketches[0].transform_steps) == 0
 
 
 @tvm.testing.requires_cuda
 def test_cuda_min_sketch():
-    sketches = generate_sketches(min_nm_auto_scheduler_test, (10, 1024), 'cuda')
-    ''' 1 cross thread reduction sketch + 1 default sketch '''
+    sketches = generate_sketches(min_nm_auto_scheduler_test, (10, 1024), "cuda")
+    """ 1 cross thread reuction sketch + 1 default sketch """
     assert len(sketches) == 2
     # Sketch 0
     assert_has_cross_thread_reduction(sketches[0], 1)
@@ -304,8 +319,8 @@ def test_cuda_min_sketch():
 
 @tvm.testing.requires_cuda
 def test_cuda_softmax_sketch():
-    sketches = generate_sketches(softmax_nm_auto_scheduler_test, (2, 1024), 'cuda')
-    ''' (1 cross thread reduction sketch + 1 default sketch) * (1 cross thread reduction sketch + 1 default sketch) '''
+    sketches = generate_sketches(softmax_nm_auto_scheduler_test, (2, 1024), "cuda")
+    """ (1 cross thread reuction sketch + 1 default sketch) * (1 cross thread reuction sketch + 1 default sketch) """
     assert len(sketches) == (2 * 2)
     # Sketch 0
     assert_has_cross_thread_reduction(sketches[0], 1)
@@ -320,8 +335,8 @@ def test_cuda_softmax_sketch():
     # Sketch 3
     assert_compute_at_condition(sketches[3].stages[2], "inlined")
 
-    sketches = generate_sketches(softmax_abcd_auto_scheduler_test, (1, 12, 128, 128), 'cuda')
-    ''' (1 cross thread reduction sketch + 1 default sketch) * (1 cross thread reduction sketch + 1 default sketch) '''
+    sketches = generate_sketches(softmax_abcd_auto_scheduler_test, (1, 12, 128, 128), "cuda")
+    """ (1 cross thread reuction sketch + 1 default sketch) * (1 cross thread reuction sketch + 1 default sketch) """
     assert len(sketches) == (2 * 2)
     # Sketch 0
     assert_has_cross_thread_reduction(sketches[0], 1)
@@ -339,9 +354,10 @@ def test_cuda_softmax_sketch():
 
 @tvm.testing.requires_cuda
 def test_cuda_conv2d_winograd_sketch():
-    sketches = generate_sketches(conv2d_winograd_nhwc_auto_scheduler_test,
-                                 (1, 28, 28, 128, 128, 3, 1, 1), 'cuda')
-    ''' 1 multi-level tiling sketch '''
+    sketches = generate_sketches(
+        conv2d_winograd_nhwc_auto_scheduler_test, (1, 28, 28, 128, 128, 3, 1, 1), "cuda"
+    )
+    """ 1 multi-level tiling sketch """
     assert len(sketches) == 1
     assert_compute_at_condition(sketches[0].stages[1], "inlined")
     assert_compute_at_condition(sketches[0].stages[2], "inlined")
index 8c22ccb..917036f 100644 (file)
@@ -31,8 +31,10 @@ class DummyRunner(Runner):
         super(DummyRunner, self).__init__(1, 1)
 
     def run(self, measure_inputs, build_results):
-        return [MeasureResult((np.random.random(),), 0, 0.2, time.time())
-                for _ in range(len(measure_inputs))]
+        return [
+            MeasureResult((np.random.random(),), 0, 0.2, time.time())
+            for _ in range(len(measure_inputs))
+        ]
 
     def get_build_kwargs(self):
         return {}
@@ -40,12 +42,11 @@ class DummyRunner(Runner):
 
 @autotvm.template("testing/matmul")
 def matmul(N, L, M, dtype):
-    A = te.placeholder((N, L), name='A', dtype=dtype)
-    B = te.placeholder((L, M), name='B', dtype=dtype)
+    A = te.placeholder((N, L), name="A", dtype=dtype)
+    B = te.placeholder((L, M), name="B", dtype=dtype)
 
-    k = te.reduce_axis((0, L), name='k')
-    C = te.compute((N, M), lambda i, j: te.sum(
-        A[i, k] * B[k, j], axis=k), name='C')
+    k = te.reduce_axis((0, L), name="k")
+    C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
     s = te.create_schedule(C.op)
 
     # schedule
@@ -69,13 +70,12 @@ def matmul(N, L, M, dtype):
 
 @autotvm.template("testing/bad_matmul")
 def bad_matmul(N, L, M, dtype):
-    if 'bad_device' in tvm.target.Target.current().keys:
-        A = te.placeholder((N, L), name='A', dtype=dtype)
-        B = te.placeholder((L, M), name='B', dtype=dtype)
+    if "bad_device" in tvm.target.Target.current().keys:
+        A = te.placeholder((N, L), name="A", dtype=dtype)
+        B = te.placeholder((L, M), name="B", dtype=dtype)
 
-        k = te.reduce_axis((0, L-1), name='k')
-        C = te.compute((N, M), lambda i, j: te.sum(
-            A[i, k] * B[k, j], axis=k), name='C')
+        k = te.reduce_axis((0, L - 1), name="k")
+        C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
         s = te.create_schedule(C.op)
 
         # schedule
@@ -91,8 +91,7 @@ def bad_matmul(N, L, M, dtype):
 def get_sample_task(n=128):
     """return a sample task for testing"""
     target = tvm.target.Target("llvm")
-    task = autotvm.task.create(
-        "testing/matmul", args=(n, n, n, 'float32'), target=target)
+    task = autotvm.task.create("testing/matmul", args=(n, n, n, "float32"), target=target)
     return task, target
 
 
@@ -103,5 +102,5 @@ def get_sample_records(n):
     inps, ress = [], []
     for i in range(n):
         inps.append(MeasureInput(target, tsk, tsk.config_space.get(i)))
-        ress.append(MeasureResult((i+1,), 0, i, time.time()))
+        ress.append(MeasureResult((i + 1,), 0, i, time.time()))
     return list(zip(inps, ress))
index 3884444..197243e 100644 (file)
@@ -23,6 +23,7 @@ from tvm.autotvm.record import encode, MeasureResult
 
 from test_autotvm_common import get_sample_records
 
+
 def test_save_load():
     logging.info("test basic db load/save ...")
     records = get_sample_records(3)
@@ -43,14 +44,16 @@ def test_save_load():
     assert load3 is None
     assert load1 != load2
 
+
 TRIAL_LIMIT = 2
 
+
 def test_db_hash():
     logging.info("test db hash check ...")
     inp1, res1 = get_sample_records(1)[0]
     inp2 = copy.deepcopy(inp1)
-    inp1.config.code_hash = 'cafecafe'
-    inp2.config.code_hash = 'dbffdbff'
+    inp1.config.code_hash = "cafecafe"
+    inp2.config.code_hash = "dbffdbff"
     res2l = list(tuple(res1))
 
     # set timestamp
@@ -67,6 +70,7 @@ def test_db_hash():
     assert load1.timestamp != -1
     assert load2.timestamp == -1
 
+
 def test_db_latest_all():
     logging.info("test db load w/ multiple results ...")
     inp1, res1 = get_sample_records(1)[0]
@@ -99,6 +103,7 @@ def test_db_latest_all():
     assert encode(inp1, load4[1]) == encode(inp1, res2)
     assert encode(inp1, load4[2]) == encode(inp1, res3)
 
+
 def test_db_filter():
     logging.info("test db filter ...")
     records = get_sample_records(5)
@@ -110,7 +115,8 @@ def test_db_filter():
     records = _db.filter(lambda inp, ress: any(r.costs[0] <= 2 for r in ress))
     assert len(records) == 2
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
     test_save_load()
     test_db_hash()
index 8b073c0..4064ede 100644 (file)
@@ -20,8 +20,8 @@ to the parameters of workload"""
 
 from tvm import autotvm
 
-def test_fallback():
 
+def test_fallback():
     @autotvm.template("testing/dispatch_fallback")
     def simple_template(a, b):
         cfg = autotvm.get_config()
index 3877db3..9757576 100644 (file)
@@ -19,14 +19,17 @@ import time
 
 from tvm.autotvm.measure import LocalExecutor, executor
 
+
 def slow(n):
     r = 0
-    for i in range(0, n+1):
+    for i in range(0, n + 1):
         r += i
     return r
 
+
 def fast(n):
-    return n*(n+1)//2
+    return n * (n + 1) // 2
+
 
 def test_local_measure_async():
     ex = LocalExecutor()
@@ -44,9 +47,11 @@ def test_local_measure_async():
     assert t2 < t1, "Expected fast async job to finish first!"
     assert f1.get() == f2.get()
 
+
 def timeout_job(n):
     time.sleep(n * 1.5)
 
+
 def test_timeout():
     timeout = 0.5
 
@@ -58,6 +63,7 @@ def test_timeout():
     res = f1.get()
     assert isinstance(res, executor.TimeoutError)
 
+
 if __name__ == "__main__":
     test_local_measure_async()
     test_timeout()
index 59ad464..26268e5 100644 (file)
@@ -22,16 +22,14 @@ import tvm
 from tvm import te
 from tvm.autotvm import feature
 
+
 def test_iter_feature_gemm():
     N = 128
 
-    k = te.reduce_axis((0, N), 'k')
-    A = te.placeholder((N, N), name='A')
-    B = te.placeholder((N, N), name='B')
-    C = te.compute(
-        A.shape,
-        lambda y, x: te.sum(A[y, k] * B[k, x], axis=k),
-        name='C')
+    k = te.reduce_axis((0, N), "k")
+    A = te.placeholder((N, N), name="A")
+    B = te.placeholder((N, N), name="B")
+    C = te.compute(A.shape, lambda y, x: te.sum(A[y, k] * B[k, x], axis=k), name="C")
 
     s = te.create_schedule(C.op)
 
@@ -39,20 +37,26 @@ def test_iter_feature_gemm():
 
     expected = [
         {
-            '_attr_': [128, 1, 128, 2097152, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
-            'A_0': [128, -1, 16384, 128, 0, 0], 'B_0': [0, -1, 16384, 128, 0, 0],
-            'C_0': [128, -1, 16384, 128, 0, 0], 'C_1': [128, -1, 16384, 128, 0, 0],
+            "_attr_": [128, 1, 128, 2097152, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
+            "A_0": [128, -1, 16384, 128, 0, 0],
+            "B_0": [0, -1, 16384, 128, 0, 0],
+            "C_0": [128, -1, 16384, 128, 0, 0],
+            "C_1": [128, -1, 16384, 128, 0, 0],
         },
         {
-            '_attr_': [128, 2, 16384, 16384, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
-            'A_0': [0, -1, 128, 128, 0, 0], 'B_0': [1, -1, 16384, 1, 0, 0],
-            'C_0': [1, -1, 128, 128, 0, 0], 'C_1': [1, -1, 128, 128, 0, 0],
+            "_attr_": [128, 2, 16384, 16384, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
+            "A_0": [0, -1, 128, 128, 0, 0],
+            "B_0": [1, -1, 16384, 1, 0, 0],
+            "C_0": [1, -1, 128, 128, 0, 0],
+            "C_1": [1, -1, 128, 128, 0, 0],
         },
         {
-            '_attr_': [128, 3, 2097152, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
-            'A_0': [1, -1, 128, 1, 0, 0], 'B_0': [128, -1, 128, 1, 0, 0],
-            'C_1': [0, -1, 1, 128, 0, 0], 'C_2':  [0, -1, 1, 128, 0, 0],
-        }
+            "_attr_": [128, 3, 2097152, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
+            "A_0": [1, -1, 128, 1, 0, 0],
+            "B_0": [128, -1, 128, 1, 0, 0],
+            "C_1": [0, -1, 1, 128, 0, 0],
+            "C_2": [0, -1, 1, 128, 0, 0],
+        },
     ]
 
     for ans, row in zip(expected, feas):
@@ -65,13 +69,10 @@ def test_iter_feature_gemm():
 def test_curve_feature_gemm():
     N = 128
 
-    k = te.reduce_axis((0, N), 'k')
-    A = te.placeholder((N, N), name='A')
-    B = te.placeholder((N, N), name='B')
-    C = te.compute(
-        A.shape,
-        lambda y, x: te.sum(A[y, k] * B[k, x], axis=k),
-        name='C')
+    k = te.reduce_axis((0, N), "k")
+    A = te.placeholder((N, N), name="A")
+    B = te.placeholder((N, N), name="B")
+    C = te.compute(A.shape, lambda y, x: te.sum(A[y, k] * B[k, x], axis=k), name="C")
 
     s = te.create_schedule(C.op)
 
@@ -79,6 +80,7 @@ def test_curve_feature_gemm():
     # sample_n * #buffers * #curves * 2 numbers per curve
     assert len(feas) == 30 * 3 * 4 * 2
 
+
 def test_feature_shape():
     """test the dimensions of flatten feature are the same"""
 
@@ -86,11 +88,10 @@ def test_feature_shape():
     n_sample = 100
 
     def get_gemm_feature(target):
-        k = te.reduce_axis((0, N), 'k')
-        A = te.placeholder((N, N), name='A')
-        B = te.placeholder((N, N), name='B')
-        C = te.compute(A.shape, lambda y, x: te.sum(A[y, k] * B[k, x], axis=k),
-                        name='C')
+        k = te.reduce_axis((0, N), "k")
+        A = te.placeholder((N, N), name="A")
+        B = te.placeholder((N, N), name="B")
+        C = te.compute(A.shape, lambda y, x: te.sum(A[y, k] * B[k, x], axis=k), name="C")
 
         s = te.create_schedule(C.op)
 
@@ -124,12 +125,12 @@ def test_feature_shape():
     for target in targets:
         dim = len(get_gemm_feature(target))
         for i in range(n_sample):
-            assert dim == len(get_gemm_feature(target)), "dimensions of feature do not match" \
-                                                   " for different configurations"
+            assert dim == len(get_gemm_feature(target)), (
+                "dimensions of feature do not match" " for different configurations"
+            )
 
 
 if __name__ == "__main__":
     test_iter_feature_gemm()
     test_curve_feature_gemm()
     test_feature_shape()
-
index e06010b..e07cdac 100644 (file)
@@ -22,11 +22,13 @@ import numpy as np
 
 from tvm.autotvm.task.task import compute_flop
 
+
 def random_dtypes():
     """Return pair of (input, accumulator) dtypes"""
     candidates = [("float32", "float32"), ("float16", "float32"), ("int8", "int32")]
     return candidates[np.random.choice(len(candidates))]
 
+
 def test_conv():
     for i in range(5):
         N, H, W, CO, CI, KH, KW = [np.random.randint(10, 32) for _ in range(7)]
@@ -44,14 +46,19 @@ def test_conv():
         OH = (H - KH) + 1
         OW = (W - KW) + 1
 
-        C = te.compute((N, CO, OH, OW), lambda n, co, h, w:
-        te.sum(D[n][ci][h][w].astype(acc_dtype) * K[co][ci][h][w].astype(acc_dtype),
-                axis=[ci, kh, kw]))
+        C = te.compute(
+            (N, CO, OH, OW),
+            lambda n, co, h, w: te.sum(
+                D[n][ci][h][w].astype(acc_dtype) * K[co][ci][h][w].astype(acc_dtype),
+                axis=[ci, kh, kw],
+            ),
+        )
 
         s = te.create_schedule([C.op])
 
         assert compute_flop(s) == 2 * N * CO * OH * OW * CI * KH * KW
 
+
 def test_pack_gemm():
     for i in range(5):
         N, L, M = [np.random.randint(10, 128) * 4 for _ in range(3)]
@@ -66,13 +73,20 @@ def test_pack_gemm():
 
         A_pack = te.compute((N // bn, L, bn), lambda i, j, k: A[i * bn + k][j])
         B_pack = te.compute((M // bn, L, bn), lambda i, j, k: B[i * bn + k][j])
-        C_pack = te.compute((N // bn, M // bn, bn, bn), lambda i, j, ii, jj:
-        te.sum(A_pack[i, k, ii].astype(acc_dtype) * B_pack[j, k, jj].astype(acc_dtype), axis=[k]))
-        C = te.compute((N, M), lambda i, j: C_pack[idxd(i, bn)][idxd(j, bn)][idxm(i, bn)][idxm(j, bn)])
+        C_pack = te.compute(
+            (N // bn, M // bn, bn, bn),
+            lambda i, j, ii, jj: te.sum(
+                A_pack[i, k, ii].astype(acc_dtype) * B_pack[j, k, jj].astype(acc_dtype), axis=[k]
+            ),
+        )
+        C = te.compute(
+            (N, M), lambda i, j: C_pack[idxd(i, bn)][idxd(j, bn)][idxm(i, bn)][idxm(j, bn)]
+        )
 
         s = te.create_schedule([C.op])
         assert compute_flop(s) == 2 * N * L * M
 
+
 def test_outer_dot():
     for i in range(5):
         N, M = [np.random.randint(10, 128) * 4 for _ in range(2)]
@@ -85,6 +99,7 @@ def test_outer_dot():
         s = te.create_schedule([C.op])
         assert compute_flop(s) == N * M
 
+
 def test_max_pool():
     for i in range(5):
         N, H, W, CO, CI, KH, KW = [np.random.randint(10, 32) for _ in range(7)]
@@ -101,13 +116,14 @@ def test_max_pool():
         OW = (W - KW) + 1
 
         C = te.compute(
-            (N, CO, OH, OW),
-            lambda n, co, h, w: tvm.te.max(D[n][co][h + kh][w + kw], axis=[kh, kw]))
+            (N, CO, OH, OW), lambda n, co, h, w: tvm.te.max(D[n][co][h + kh][w + kw], axis=[kh, kw])
+        )
 
         s = te.create_schedule([C.op])
 
         assert compute_flop(s) == N * CO * OH * OW * KH * KW
 
+
 def test_average_pool():
     for i in range(5):
         N, H, W, CO, CI, KH, KW = [np.random.randint(10, 32) for _ in range(7)]
@@ -123,16 +139,18 @@ def test_average_pool():
         OH = (H - KH) + 1
         OW = (W - KW) + 1
 
-
         C = te.compute(
             (N, CO, OH, OW),
             lambda n, co, h, w: te.sum(
-                te.div(D[n][co][h + kh][w + kw].astype(acc_dtype), (KW * KH)), axis=[kh, kw]))
+                te.div(D[n][co][h + kh][w + kw].astype(acc_dtype), (KW * KH)), axis=[kh, kw]
+            ),
+        )
 
         s = te.create_schedule([C.op])
 
         assert compute_flop(s) == 2 * N * CO * OH * OW * KH * KW
 
+
 def test_move():
     """No float number operation in simple move. So the estimator should raise an error """
     N = 1024
@@ -147,7 +165,8 @@ def test_move():
     except RuntimeError:
         pass
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     test_conv()
     test_pack_gemm()
     test_outer_dot()
index f577b63..f594761 100644 (file)
@@ -34,12 +34,13 @@ from tvm.autotvm.measure import MeasureResult, MeasureInput
 from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner
 
 
-def _create_args(dshape, kshape, strides, padding, dilation, layout, out_layout,
-                 dtype, out_dtype):
+def _create_args(dshape, kshape, strides, padding, dilation, layout, out_layout, dtype, out_dtype):
     data = tvm.te.placeholder(dshape, dtype=dtype)
     kernel = tvm.te.placeholder(kshape, dtype=dtype)
-    return autotvm.task.serialize_args([data, kernel, strides, padding, dilation,
-                                        layout, layout, out_dtype])
+    return autotvm.task.serialize_args(
+        [data, kernel, strides, padding, dilation, layout, layout, out_dtype]
+    )
+
 
 def _create_data(target, dshape, dtype, layout):
     data = relay.var("data", shape=dshape, dtype=dtype)
@@ -52,38 +53,71 @@ def _create_data(target, dshape, dtype, layout):
     out = relay.add(conv1, conv2)
     net = relay.Function(relay.analysis.free_vars(out), out)
     mod, params = relay.testing.create_workload(net)
-    tasks = autotvm.task.extract_from_program(mod["main"],
-                                              target=target,
-                                              params=params,
-                                              ops=(relay.op.get("nn.conv2d"),))
+    tasks = autotvm.task.extract_from_program(
+        mod["main"], target=target, params=params, ops=(relay.op.get("nn.conv2d"),)
+    )
     new_args = [
-        _create_args((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
-        _create_args((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0, 0, 0), (1, 1), layout, layout, dtype, dtype),
-        _create_args((1, 32, 8, 8), (32, 32, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
+        _create_args(
+            (1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype
+        ),
+        _create_args(
+            (1, 16, 8, 8),
+            (32, 16, 1, 1),
+            (1, 1),
+            (0, 0, 0, 0),
+            (1, 1),
+            layout,
+            layout,
+            dtype,
+            dtype,
+        ),
+        _create_args(
+            (1, 32, 8, 8),
+            (32, 32, 3, 3),
+            (1, 1),
+            (1, 1, 1, 1),
+            (1, 1),
+            layout,
+            layout,
+            dtype,
+            dtype,
+        ),
     ]
 
     costs = [0.04, 0.012, 0.03]
     config_list = []
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [3, 1]],
-                           ["tile_oc", "sp", [4, 4]],
-                           ["tile_ow", "sp", [4, 2]],
-                           ["unroll_kw", "ot", True]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [3, 1]],
+            ["tile_oc", "sp", [4, 4]],
+            ["tile_ow", "sp", [4, 2]],
+            ["unroll_kw", "ot", True],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [2, 8]],
-                           ["tile_oc", "sp", [1, 32]],
-                           ["tile_oh", "ot", 1],
-                           ["tile_ow", "sp", [4, 2]]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [2, 8]],
+            ["tile_oc", "sp", [1, 32]],
+            ["tile_oh", "ot", 1],
+            ["tile_ow", "sp", [4, 2]],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [8, 4]],
-                           ["tile_oc", "sp", [4, 8]],
-                           ["tile_ow", "sp", [2, 4]],
-                           ["unroll_kw", "ot", False]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [8, 4]],
+            ["tile_oc", "sp", [4, 8]],
+            ["tile_ow", "sp", [2, 4]],
+            ["unroll_kw", "ot", False],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
 
     records = []
@@ -95,20 +129,20 @@ def _create_data(target, dshape, dtype, layout):
 
     ltf_records = []
     ltf_arg = [te.placeholder((1, 64, 16, 16, 8), dtype=dtype), "NCHW8c", "NCHW512c"]
-    ltf_task = autotvm.task.create('layout_transform', ltf_arg, target)
+    ltf_task = autotvm.task.create("layout_transform", ltf_arg, target)
     ms_input = MeasureInput(target=target, task=ltf_task, config=None)
-    ms_output =  MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
+    ms_output = MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
     ltf_records.append((ms_input, ms_output))
 
     ltf_keys = []
     ltf_arg = [te.placeholder((1, 4, 8, 8, 4), dtype=dtype), "NCHW4c", "NCHW8c"]
-    ltf_wkl = autotvm.task.args_to_workload(ltf_arg, 'layout_transform')
+    ltf_wkl = autotvm.task.args_to_workload(ltf_arg, "layout_transform")
     ltf_keys.append(ltf_wkl)
     ltf_arg = [te.placeholder((1, 1, 8, 8, 32), dtype=dtype), "NCHW32c", "NCHW4c"]
-    ltf_wkl = autotvm.task.args_to_workload(ltf_arg, 'layout_transform')
+    ltf_wkl = autotvm.task.args_to_workload(ltf_arg, "layout_transform")
     ltf_keys.append(ltf_wkl)
     ltf_arg = [te.placeholder((1, 4, 8, 8, 8), dtype=dtype), "NCHW8c", "NCHW32c"]
-    ltf_wkl = autotvm.task.args_to_workload(ltf_arg, 'layout_transform')
+    ltf_wkl = autotvm.task.args_to_workload(ltf_arg, "layout_transform")
     ltf_keys.append(ltf_wkl)
 
     return net, records, ltf_records, ltf_keys, tasks
@@ -145,9 +179,13 @@ def test_graph_tuner_layout_transform():
             flops *= i
         expected_time = flops * avg_time
         out_time = out[ltf_workload][1].costs[0]
-        assert expected_time == out_time, "Inferred layout transformation time mismatch for %s: " \
-                                          "expecting %f but got %f" % (str(ltf_workload), expected_time,
-                                                                       out_time)
+        assert (
+            expected_time == out_time
+        ), "Inferred layout transformation time mismatch for %s: " "expecting %f but got %f" % (
+            str(ltf_workload),
+            expected_time,
+            out_time,
+        )
 
 
 def test_DPTuner_run():
@@ -164,26 +202,38 @@ def test_DPTuner_run():
     mod["main"] = g
     costs = [0.02, 0.02, 0.045]
     config_list = []
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [1, 3]],
-                           ["tile_oc", "sp", [2, 8]],
-                           ["tile_ow", "sp", [4, 2]],
-                           ["unroll_kw", "ot", True]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [1, 3]],
+            ["tile_oc", "sp", [2, 8]],
+            ["tile_ow", "sp", [4, 2]],
+            ["unroll_kw", "ot", True],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [4, 4]],
-                           ["tile_oc", "sp", [2, 16]],
-                           ["tile_oh", "ot", 1],
-                           ["tile_ow", "sp", [4, 2]]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [4, 4]],
+            ["tile_oc", "sp", [2, 16]],
+            ["tile_oh", "ot", 1],
+            ["tile_ow", "sp", [4, 2]],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [16, 2]],
-                           ["tile_oc", "sp", [8, 4]],
-                           ["tile_ow", "sp", [2, 4]],
-                           ["unroll_kw", "ot", False]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [16, 2]],
+            ["tile_oc", "sp", [8, 4]],
+            ["tile_ow", "sp", [2, 4]],
+            ["unroll_kw", "ot", False],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
     for cost, config, task in zip(costs, config_list, tasks):
         ms_input = MeasureInput(target=target, task=task, config=config)
@@ -195,8 +245,10 @@ def test_DPTuner_run():
     executor.run()
     out = [record[0].config for record in executor.get_optimal_records()]
     expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
-    assert expected_out == out, "Output mismatch: expecting %s but got %s" \
-                                % (str(expected_out), str(out))
+    assert expected_out == out, "Output mismatch: expecting %s but got %s" % (
+        str(expected_out),
+        str(out),
+    )
     assert os.path.isfile(log_file), "No log file with name %s exists." % log_file
 
 
@@ -211,26 +263,38 @@ def test_PBQPTuner_run():
     g, records, ltf_records, ltf_keys, tasks = _create_data(target, dshape, dtype, layout)
     costs = [0.02, 0.02, 0.045]
     config_list = []
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [1, 3]],
-                           ["tile_oc", "sp", [2, 8]],
-                           ["tile_ow", "sp", [4, 2]],
-                           ["unroll_kw", "ot", True]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [1, 3]],
+            ["tile_oc", "sp", [2, 8]],
+            ["tile_ow", "sp", [4, 2]],
+            ["unroll_kw", "ot", True],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [4, 4]],
-                           ["tile_oc", "sp", [2, 16]],
-                           ["tile_oh", "ot", 1],
-                           ["tile_ow", "sp", [4, 2]]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [4, 4]],
+            ["tile_oc", "sp", [2, 16]],
+            ["tile_oh", "ot", 1],
+            ["tile_ow", "sp", [4, 2]],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [16, 2]],
-                           ["tile_oc", "sp", [8, 4]],
-                           ["tile_ow", "sp", [2, 4]],
-                           ["unroll_kw", "ot", False]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [16, 2]],
+            ["tile_oc", "sp", [8, 4]],
+            ["tile_ow", "sp", [2, 4]],
+            ["unroll_kw", "ot", False],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
     for cost, config, task in zip(costs, config_list, tasks):
         ms_input = MeasureInput(target=target, task=task, config=config)
@@ -242,8 +306,10 @@ def test_PBQPTuner_run():
     executor.run()
     out = [record[0].config for record in executor.get_optimal_records()]
     expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
-    assert expected_out == out, "Output mismatch: expecting %s but got %s" \
-                           % (str(expected_out), str(out))
+    assert expected_out == out, "Output mismatch: expecting %s but got %s" % (
+        str(expected_out),
+        str(out),
+    )
 
 
 def test_many_sub_graphs():
@@ -271,59 +337,104 @@ def test_many_sub_graphs():
     net = relay.Function(relay.analysis.free_vars(out), out)
     net, params = relay.testing.create_workload(net)
 
-    tasks = autotvm.task.extract_from_program(net["main"],
-                                              target=target,
-                                              params=params,
-                                              ops=(conv2d,))
+    tasks = autotvm.task.extract_from_program(
+        net["main"], target=target, params=params, ops=(conv2d,)
+    )
     new_args = [
-        _create_args((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
-        _create_args((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0, 0, 0), (1, 1), layout, layout, dtype, dtype),
-        _create_args((1, 32, 8, 8), (32, 32, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
+        _create_args(
+            (1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype
+        ),
+        _create_args(
+            (1, 16, 8, 8),
+            (32, 16, 1, 1),
+            (1, 1),
+            (0, 0, 0, 0),
+            (1, 1),
+            layout,
+            layout,
+            dtype,
+            dtype,
+        ),
+        _create_args(
+            (1, 32, 8, 8),
+            (32, 32, 3, 3),
+            (1, 1),
+            (1, 1, 1, 1),
+            (1, 1),
+            layout,
+            layout,
+            dtype,
+            dtype,
+        ),
     ]
 
     costs = [0.04, 0.012, 0.03, 0.02, 0.02, 0.045]
     config_list = []
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [3, 1]],
-                           ["tile_oc", "sp", [4, 4]],
-                           ["tile_ow", "sp", [4, 2]],
-                           ["unroll_kw", "ot", True]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [3, 1]],
+            ["tile_oc", "sp", [4, 4]],
+            ["tile_ow", "sp", [4, 2]],
+            ["unroll_kw", "ot", True],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [2, 8]],
-                           ["tile_oc", "sp", [1, 32]],
-                           ["tile_oh", "ot", 1],
-                           ["tile_ow", "sp", [4, 2]]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [2, 8]],
+            ["tile_oc", "sp", [1, 32]],
+            ["tile_oh", "ot", 1],
+            ["tile_ow", "sp", [4, 2]],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [8, 4]],
-                           ["tile_oc", "sp", [4, 8]],
-                           ["tile_ow", "sp", [2, 4]],
-                           ["unroll_kw", "ot", False]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [8, 4]],
+            ["tile_oc", "sp", [4, 8]],
+            ["tile_ow", "sp", [2, 4]],
+            ["unroll_kw", "ot", False],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [1, 3]],
-                           ["tile_oc", "sp", [2, 8]],
-                           ["tile_ow", "sp", [4, 2]],
-                           ["unroll_kw", "ot", True]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [1, 3]],
+            ["tile_oc", "sp", [2, 8]],
+            ["tile_ow", "sp", [4, 2]],
+            ["unroll_kw", "ot", True],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [4, 4]],
-                           ["tile_oc", "sp", [2, 16]],
-                           ["tile_oh", "ot", 1],
-                           ["tile_ow", "sp", [4, 2]]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [4, 4]],
+            ["tile_oc", "sp", [2, 16]],
+            ["tile_oh", "ot", 1],
+            ["tile_ow", "sp", [4, 2]],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [16, 2]],
-                           ["tile_oc", "sp", [8, 4]],
-                           ["tile_ow", "sp", [2, 4]],
-                           ["unroll_kw", "ot", False]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [16, 2]],
+            ["tile_oc", "sp", [8, 4]],
+            ["tile_ow", "sp", [2, 4]],
+            ["unroll_kw", "ot", False],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
 
     records = []
@@ -337,9 +448,9 @@ def test_many_sub_graphs():
 
     ltf_records = []
     ltf_arg = [te.placeholder((1, 64, 16, 16, 8), dtype=dtype), "NCHW8c", "NCHW512c"]
-    ltf_task = autotvm.task.create('layout_transform', ltf_arg, target)
+    ltf_task = autotvm.task.create("layout_transform", ltf_arg, target)
     ms_input = MeasureInput(target=target, task=ltf_task, config=None)
-    ms_output =  MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
+    ms_output = MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
     ltf_records.append((ms_input, ms_output))
 
     executor = DPTuner(net, {"data": dshape}, records, target_ops, target)
@@ -347,16 +458,20 @@ def test_many_sub_graphs():
     executor.run()
     out = [record[0].config for record in executor.get_optimal_records()]
     expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
-    assert expected_out == out, "Output mismatch: expecting %s but got %s" \
-                                % (str(expected_out), str(out))
+    assert expected_out == out, "Output mismatch: expecting %s but got %s" % (
+        str(expected_out),
+        str(out),
+    )
 
     executor = PBQPTuner(net, {"data": dshape}, records, target_ops, target)
     executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
     executor.run()
     out = [record[0].config for record in executor.get_optimal_records()]
     expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
-    assert expected_out == out, "Output mismatch: expecting %s but got %s" \
-                                % (str(expected_out), str(out))
+    assert expected_out == out, "Output mismatch: expecting %s but got %s" % (
+        str(expected_out),
+        str(out),
+    )
 
 
 def test_tuple():
@@ -376,43 +491,62 @@ def test_tuple():
     net = relay.Function(relay.analysis.free_vars(out), out)
     net, params = relay.testing.create_workload(net)
 
-    tasks = autotvm.task.extract_from_program(net["main"],
-                                              target=target,
-                                              params=params,
-                                              ops=(conv2d,))
+    tasks = autotvm.task.extract_from_program(
+        net["main"], target=target, params=params, ops=(conv2d,)
+    )
     new_args = [
-        _create_args((1, 5, 32, 32), (2, 5, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
-        _create_args((1, 5, 32, 32), (3, 5, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
+        _create_args(
+            (1, 5, 32, 32), (2, 5, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype
+        ),
+        _create_args(
+            (1, 5, 32, 32), (3, 5, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype
+        ),
     ]
     costs = [0.01, 0.012, 0.03, 0.04]
     config_list = []
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [1, 5]],
-                           ["tile_oc", "sp", [1, 2]],
-                           ["tile_ow", "sp", [4, 8]],
-                           ["unroll_kw", "ot", True]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [1, 5]],
+            ["tile_oc", "sp", [1, 2]],
+            ["tile_ow", "sp", [4, 8]],
+            ["unroll_kw", "ot", True],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [1, 5]],
-                           ["tile_oc", "sp", [1, 3]],
-                           ["tile_ow", "sp", [2, 16]],
-                           ["unroll_kw", "ot", False]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [1, 5]],
+            ["tile_oc", "sp", [1, 3]],
+            ["tile_ow", "sp", [2, 16]],
+            ["unroll_kw", "ot", False],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [1, 5]],
-                           ["tile_oc", "sp", [2, 1]],
-                           ["tile_ow", "sp", [4, 8]],
-                           ["unroll_kw", "ot", True]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [1, 5]],
+            ["tile_oc", "sp", [2, 1]],
+            ["tile_ow", "sp", [4, 8]],
+            ["unroll_kw", "ot", True],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [1, 5]],
-                           ["tile_oc", "sp", [3, 1]],
-                           ["tile_ow", "sp", [2, 16]],
-                           ["unroll_kw", "ot", False]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [1, 5]],
+            ["tile_oc", "sp", [3, 1]],
+            ["tile_ow", "sp", [2, 16]],
+            ["unroll_kw", "ot", False],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
 
     records = []
@@ -426,9 +560,9 @@ def test_tuple():
 
     ltf_records = []
     ltf_arg = [te.placeholder((1, 64, 16, 16, 8), dtype=dtype), "NCHW8c", "NCHW512c"]
-    ltf_task = autotvm.task.create('layout_transform', ltf_arg, target)
+    ltf_task = autotvm.task.create("layout_transform", ltf_arg, target)
     ms_input = MeasureInput(target=target, task=ltf_task, config=None)
-    ms_output =  MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
+    ms_output = MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
     ltf_records.append((ms_input, ms_output))
 
     executor = DPTuner(net, {"data": dshape}, records, target_ops, target)
@@ -436,16 +570,20 @@ def test_tuple():
     executor.run()
     out = [record[0].config for record in executor.get_optimal_records()]
     expected_out = [records[2][0].config, records[1][0].config]
-    assert expected_out == out, "Output mismatch: expecting %s but got %s" \
-                                % (str(expected_out), str(out))
+    assert expected_out == out, "Output mismatch: expecting %s but got %s" % (
+        str(expected_out),
+        str(out),
+    )
 
     executor = PBQPTuner(net, {"data": dshape}, records, target_ops, target)
     executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
     executor.run()
     out = [record[0].config for record in executor.get_optimal_records()]
     expected_out = [records[2][0].config, records[1][0].config]
-    assert expected_out == out, "Output mismatch: expecting %s but got %s" \
-                                % (str(expected_out), str(out))
+    assert expected_out == out, "Output mismatch: expecting %s but got %s" % (
+        str(expected_out),
+        str(out),
+    )
 
 
 def test_triangle_block():
@@ -467,58 +605,95 @@ def test_triangle_block():
     net = relay.Function(relay.analysis.free_vars(out), out)
     net, params = relay.testing.create_workload(net)
 
-    tasks = autotvm.task.extract_from_program(net["main"],
-                                              target=target,
-                                              params=params,
-                                              ops=(conv2d,))
+    tasks = autotvm.task.extract_from_program(
+        net["main"], target=target, params=params, ops=(conv2d,)
+    )
     new_args = [
-        _create_args((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
-        _create_args((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0, 0, 0), (1, 1), layout, layout, dtype, dtype),
-        _create_args((1, 3, 8, 8), (32, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
+        _create_args(
+            (1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype
+        ),
+        _create_args(
+            (1, 16, 8, 8),
+            (32, 16, 1, 1),
+            (1, 1),
+            (0, 0, 0, 0),
+            (1, 1),
+            layout,
+            layout,
+            dtype,
+            dtype,
+        ),
+        _create_args(
+            (1, 3, 8, 8), (32, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype
+        ),
     ]
     costs = [0.04, 0.012, 0.03, 0.02, 0.02, 0.045]
     config_list = []
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [3, 1]],
-                           ["tile_oc", "sp", [4, 4]],
-                           ["tile_ow", "sp", [4, 2]],
-                           ["unroll_kw", "ot", True]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [3, 1]],
+            ["tile_oc", "sp", [4, 4]],
+            ["tile_ow", "sp", [4, 2]],
+            ["unroll_kw", "ot", True],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [2, 8]],
-                           ["tile_oc", "sp", [1, 32]],
-                           ["tile_oh", "ot", 1],
-                           ["tile_ow", "sp", [4, 2]]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [2, 8]],
+            ["tile_oc", "sp", [1, 32]],
+            ["tile_oh", "ot", 1],
+            ["tile_ow", "sp", [4, 2]],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [8, 4]],
-                           ["tile_oc", "sp", [4, 8]],
-                           ["tile_ow", "sp", [2, 4]],
-                           ["unroll_kw", "ot", False]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [8, 4]],
+            ["tile_oc", "sp", [4, 8]],
+            ["tile_ow", "sp", [2, 4]],
+            ["unroll_kw", "ot", False],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [1, 3]],
-                           ["tile_oc", "sp", [2, 8]],
-                           ["tile_ow", "sp", [4, 2]],
-                           ["unroll_kw", "ot", True]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [1, 3]],
+            ["tile_oc", "sp", [2, 8]],
+            ["tile_ow", "sp", [4, 2]],
+            ["unroll_kw", "ot", True],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [4, 4]],
-                           ["tile_oc", "sp", [2, 16]],
-                           ["tile_oh", "ot", 1],
-                           ["tile_ow", "sp", [4, 2]]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [4, 4]],
+            ["tile_oc", "sp", [2, 16]],
+            ["tile_oh", "ot", 1],
+            ["tile_ow", "sp", [4, 2]],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"index": -1,
-                "code_hash": None,
-                "entity": [["tile_ic", "sp", [16, 2]],
-                           ["tile_oc", "sp", [8, 4]],
-                           ["tile_ow", "sp", [2, 4]],
-                           ["unroll_kw", "ot", False]]}
+    cfg_dict = {
+        "index": -1,
+        "code_hash": None,
+        "entity": [
+            ["tile_ic", "sp", [16, 2]],
+            ["tile_oc", "sp", [8, 4]],
+            ["tile_ow", "sp", [2, 4]],
+            ["unroll_kw", "ot", False],
+        ],
+    }
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
 
     records = []
@@ -532,9 +707,9 @@ def test_triangle_block():
 
     ltf_records = []
     ltf_arg = [te.placeholder((1, 64, 16, 16, 8), dtype=dtype), "NCHW8c", "NCHW512c"]
-    ltf_task = autotvm.task.create('layout_transform', ltf_arg, target)
+    ltf_task = autotvm.task.create("layout_transform", ltf_arg, target)
     ms_input = MeasureInput(target=target, task=ltf_task, config=None)
-    ms_output =  MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
+    ms_output = MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
     ltf_records.append((ms_input, ms_output))
 
     executor = DPTuner(net, {"data": dshape}, records, target_ops, target)
@@ -542,19 +717,23 @@ def test_triangle_block():
     executor.run()
     out = [record[0].config for record in executor.get_optimal_records()]
     expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
-    assert expected_out == out, "Output mismatch: expecting %s but got %s" \
-                                % (str(expected_out), str(out))
+    assert expected_out == out, "Output mismatch: expecting %s but got %s" % (
+        str(expected_out),
+        str(out),
+    )
 
     executor = PBQPTuner(net, {"data": dshape}, records, target_ops, target)
     executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
     executor.run()
     out = [record[0].config for record in executor.get_optimal_records()]
     expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
-    assert expected_out == out, "Output mismatch: expecting %s but got %s" \
-                                % (str(expected_out), str(out))
+    assert expected_out == out, "Output mismatch: expecting %s but got %s" % (
+        str(expected_out),
+        str(out),
+    )
 
 
-if __name__=="__main__":
+if __name__ == "__main__":
     test_graph_tuner_layout_transform()
     test_DPTuner_run()
     test_PBQPTuner_run()
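The assert rewrites repeated throughout the hunks above are purely mechanical: Black drops the backslash continuation and moves the % format arguments into a parenthesized tuple with a trailing comma so the statement fits the configured line length. A minimal before/after sketch with illustrative names (not taken from the tests above):

    expected = [1, 2, 3]  # placeholder values for illustration
    actual = [1, 2, 3]

    # before Black: a backslash continuation carries the format arguments
    assert expected == actual, "Output mismatch: expecting %s but got %s" \
                               % (str(expected), str(actual))

    # after Black: same expression, arguments wrapped in a parenthesized tuple
    assert expected == actual, "Output mismatch: expecting %s but got %s" % (
        str(expected),
        str(actual),
    )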
index b675df7..9fc415c 100644
@@ -25,16 +25,25 @@ from tvm import te
 
 from tvm import autotvm, relay
 from tvm.relay.testing import synthetic
-from tvm.autotvm.graph_tuner.utils import has_multiple_inputs, get_direct_ancestor, get_in_nodes, \
-    get_out_nodes, expr2graph, bind_inputs
+from tvm.autotvm.graph_tuner.utils import (
+    has_multiple_inputs,
+    get_direct_ancestor,
+    get_in_nodes,
+    get_out_nodes,
+    expr2graph,
+    bind_inputs,
+)
 from tvm.autotvm.graph_tuner._base import OPT_OUT_OP
 from tvm.relay.expr import Call, TupleGetItem, Tuple, Var
 
 
 def verify_has_multiple_inputs(node_list, node_idx, input_names, expected_result):
     out = has_multiple_inputs(node_list, node_idx, input_names, OPT_OUT_OP)
-    assert out == expected_result, "Output mismatch: expecting checking %s to be %s but got %s." \
-                                   % (node_list[node_idx]["op"], str(expected_result), str(out))
+    assert out == expected_result, "Output mismatch: expecting checking %s to be %s but got %s." % (
+        node_list[node_idx]["op"],
+        str(expected_result),
+        str(out),
+    )
 
 
 def test_has_multiple_inputs():
@@ -61,19 +70,24 @@ def test_expr2graph():
     node_list = []
     target_ops = [relay.op.get("nn.conv2d")]
     op_name_list = []
+
     def _count_node(node):
         if isinstance(node, Call):
             op_name_list.append(node.op)
         elif isinstance(node, (Var, TupleGetItem, Tuple)):
             op_name_list.append(None)
+
     relay.analysis.post_order_visit(mod["main"], _count_node)
 
     expr2graph(mod["main"], target_ops, node_dict, node_list)
     assert len(node_list) == len(op_name_list)
     for i, item in enumerate(zip(op_name_list, node_list)):
         op_name, node = item
-        assert op_name == node["op"], "%dth Node operator mismatch: expecting %s but got %s" \
-                                      % (i, str(op_name), str(node["op"]))
+        assert op_name == node["op"], "%dth Node operator mismatch: expecting %s but got %s" % (
+            i,
+            str(op_name),
+            str(node["op"]),
+        )
 
 
 def test_get_direct_ancestor():
@@ -115,7 +129,9 @@ def test_get_in_nodes():
     expected_out = {3: [0], 4: [3, 0], 7: [4]}
     diff_set = set(out) ^ set(expected_out)
     if len(diff_set) != 0:
-        raise RuntimeError("Output mismatch: expecting %s but got %s." % (str(expected_out), str(out)))
+        raise RuntimeError(
+            "Output mismatch: expecting %s but got %s." % (str(expected_out), str(out))
+        )
 
 
 def test_get_out_nodes():
@@ -124,8 +140,9 @@ def test_get_out_nodes():
     out = get_out_nodes(in_nodes_dict)
     diff_set = set(out) ^ set(expected_out)
     if len(diff_set) != 0:
-        raise RuntimeError("Output mismatch: expecting %s but got %s." % (str(expected_out), str(out)))
-
+        raise RuntimeError(
+            "Output mismatch: expecting %s but got %s." % (str(expected_out), str(out))
+        )
 
 
 if __name__ == "__main__":
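The import block at the top of this file shows the other common rewrite: a backslash-continued "from ... import" becomes a single parenthesized import with one name per line and a trailing comma. A sketch of the same shape, using os.path as a stand-in for the tvm module imported above:

    # before Black:
    # from os.path import join, split, dirname, basename, \
    #     splitext, exists

    # shape Black produces when the one-line form would exceed the line length:
    from os.path import (
        join,
        split,
        dirname,
        basename,
        splitext,
        exists,
    )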
index 2875fd7..05f1211 100644
@@ -63,6 +63,6 @@ def test_random_tuner():
         assert 8 <= idx <= 15
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_gridsearch_tuner()
     test_random_tuner()
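The only change in this file is quote style. Black normalizes string literals to double quotes, leaving single quotes in place only when switching would force extra escaping. A tiny illustration with arbitrary values:

    target = 'llvm'            # before
    target = "llvm"            # after: double quotes preferred
    note = 'he said "hello"'   # left alone: converting would require escapes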
index 11c95eb..c8760d2 100644
@@ -31,17 +31,16 @@ def test_task_tuner_without_measurement():
     """test task and tuner without measurement"""
     task, _ = get_sample_task()
 
-    measure_option = autotvm.measure_option(
-        builder=autotvm.LocalBuilder(),
-        runner=DummyRunner()
-    )
+    measure_option = autotvm.measure_option(builder=autotvm.LocalBuilder(), runner=DummyRunner())
 
     logging.info("%s", task.config_space)
 
-    for tuner_class in [autotvm.tuner.RandomTuner,
-                        autotvm.tuner.GridSearchTuner,
-                        autotvm.tuner.GATuner,
-                        autotvm.tuner.XGBTuner]:
+    for tuner_class in [
+        autotvm.tuner.RandomTuner,
+        autotvm.tuner.GridSearchTuner,
+        autotvm.tuner.GATuner,
+        autotvm.tuner.XGBTuner,
+    ]:
         tuner = tuner_class(task)
         tuner.tune(n_trial=10, measure_option=measure_option)
         assert tuner.best_flops > 1
@@ -51,8 +50,7 @@ def test_check_correctness():
     task, target = get_sample_task()
 
     measure_option = autotvm.measure_option(
-        builder=autotvm.LocalBuilder(),
-        runner=autotvm.LocalRunner(check_correctness=True)
+        builder=autotvm.LocalBuilder(), runner=autotvm.LocalRunner(check_correctness=True)
     )
 
     def _callback_correct(tuner, measure_inputs, measure_results):
@@ -60,25 +58,22 @@ def test_check_correctness():
             assert res.error_no == 0
 
     tuner = autotvm.tuner.RandomTuner(task)
-    tuner.tune(n_trial=2, measure_option=measure_option,
-               callbacks=[_callback_correct])
+    tuner.tune(n_trial=2, measure_option=measure_option, callbacks=[_callback_correct])
 
     # a bad template
     n = 128
     target = tvm.target.Target("llvm -device=bad_device")
-    task = autotvm.task.create(
-        "testing/bad_matmul", args=(n, n, n, 'float32'), target=target)
+    task = autotvm.task.create("testing/bad_matmul", args=(n, n, n, "float32"), target=target)
 
     def _callback_wrong(tuner, measure_inputs, measure_results):
         for _, res in zip(measure_inputs, measure_results):
             assert res.error_no == MeasureErrorNo.WRONG_ANSWER
 
     tuner = autotvm.tuner.RandomTuner(task)
-    tuner.tune(n_trial=2, measure_option=measure_option,
-               callbacks=[_callback_wrong])
+    tuner.tune(n_trial=2, measure_option=measure_option, callbacks=[_callback_wrong])
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
 
     test_task_tuner_without_measurement()
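Two complementary behaviours are visible in this hunk: a wrapped call whose arguments fit within the line limit is joined back onto a single line (the measure_option call above), while a collection that would overflow is split one element per line and given a trailing comma (the tuner_class list). A sketch of the two resulting shapes, with placeholder names:

    # joined, because the arguments fit on one line
    options = dict(builder="local", runner="dummy")

    # split with a trailing comma, the layout Black uses when the
    # one-line form would be too long (as in the tuner_class list above)
    tuner_names = [
        "RandomTuner",
        "GridSearchTuner",
        "GATuner",
        "XGBTuner",
    ]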
index bcc9a93..c9d2c49 100644
@@ -27,19 +27,23 @@ from tvm.autotvm.record import encode, decode, ApplyHistoryBest, measure_str_key
 
 from test_autotvm_common import get_sample_task
 
+
 def test_load_dump():
     task, target = get_sample_task()
 
     inp = MeasureInput(target, task, task.config_space.get(0))
-    result = MeasureResult((2.0, 2.23, 0.23, 0.123, 0.234, 0.123), MeasureErrorNo.NO_ERROR,
-                           2.3, time.time())
+    result = MeasureResult(
+        (2.0, 2.23, 0.23, 0.123, 0.234, 0.123), MeasureErrorNo.NO_ERROR, 2.3, time.time()
+    )
 
-    for protocol in ['json', 'pickle']:
+    for protocol in ["json", "pickle"]:
         row = encode(inp, result, protocol=protocol)
         inp_2, result_2 = decode(row, protocol=protocol)
 
-        assert measure_str_key(inp) == measure_str_key(inp_2), \
-            "%s vs %s" % (measure_str_key(inp), measure_str_key(inp_2))
+        assert measure_str_key(inp) == measure_str_key(inp_2), "%s vs %s" % (
+            measure_str_key(inp),
+            measure_str_key(inp_2),
+        )
         assert result.costs == result_2.costs
         assert result.error_no == result_2.error_no
         assert result.timestamp == result_2.timestamp
@@ -51,10 +55,10 @@ def test_file_io():
 
     tsk, target = get_sample_task()
     inputs = [MeasureInput(target, tsk, tsk.config_space.get(i)) for i in range(0, 10)]
-    results = [MeasureResult((i, ), 0, 0, 0) for i in range(0, 10)]
+    results = [MeasureResult((i,), 0, 0, 0) for i in range(0, 10)]
 
     invalid_inp = MeasureInput(target, tsk, tsk.config_space.get(10))
-    invalid_res = MeasureResult((10, ), 0, 0, 0)
+    invalid_res = MeasureResult((10,), 0, 0, 0)
 
     # Erase the entity map to test if it will be ignored when loading back.
     invalid_inp.config._entity_map = {}
@@ -76,7 +80,7 @@ def test_apply_history_best():
         (MeasureInput(target, tsk, tsk.config_space.get(0)), MeasureResult((0.1,), 0, 2.3, 0)),
         (MeasureInput(target, tsk, tsk.config_space.get(1)), MeasureResult((0.3,), 0, 2.3, 0)),
         (MeasureInput(target, tsk, tsk.config_space.get(2)), MeasureResult((0.01,), 0, 2.3, 0)),
-        (MeasureInput(target, tsk, tsk.config_space.get(4)), MeasureResult((0.4,), 0, 2.3, 0))
+        (MeasureInput(target, tsk, tsk.config_space.get(4)), MeasureResult((0.4,), 0, 2.3, 0)),
     ]
     hist_best = ApplyHistoryBest(records)
     x = hist_best.query(target, tsk.workload)
index 2694c49..2d40371 100644
@@ -20,48 +20,50 @@ import tvm
 from tvm import te
 from tvm.autotvm.task.space import ConfigSpace, FallbackConfigEntity
 
+
 def gemm_func(cfg, N):
-    A = te.placeholder((N, N), name='A')
-    B = te.placeholder((N, N), name='B')
+    A = te.placeholder((N, N), name="A")
+    B = te.placeholder((N, N), name="B")
 
-    k = te.reduce_axis((0, N), name='k')
-    C = te.compute((N, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=[k]), name='C')
+    k = te.reduce_axis((0, N), name="k")
+    C = te.compute((N, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=[k]), name="C")
 
     s = te.create_schedule([C.op])
 
     y, x = s[C].op.axis
 
-    cfg.define_split('tile_y', cfg.axis(y), num_outputs=2)
-    cfg.define_split('tile_x', cfg.axis(x), num_outputs=2)
+    cfg.define_split("tile_y", cfg.axis(y), num_outputs=2)
+    cfg.define_split("tile_x", cfg.axis(x), num_outputs=2)
 
     return s, [A, B, C]
 
+
 def test_split():
     cfg = ConfigSpace()
 
     gemm_func(cfg, 128)
     assert len(cfg) == 64
-    assert len(cfg.space_map['tile_y']) == 8
+    assert len(cfg.space_map["tile_y"]) == 8
 
     # test policy
     cfg = ConfigSpace()
-    cfg.define_split('tile_x', cfg.axis(256), policy='factors', num_outputs=3)
-    assert len(cfg.space_map['tile_x']) == 45
+    cfg.define_split("tile_x", cfg.axis(256), policy="factors", num_outputs=3)
+    assert len(cfg.space_map["tile_x"]) == 45
 
-    cfg.define_split('tile_y', cfg.axis(256), policy='power2', num_outputs=3)
-    assert len(cfg.space_map['tile_y']) == 45
+    cfg.define_split("tile_y", cfg.axis(256), policy="power2", num_outputs=3)
+    assert len(cfg.space_map["tile_y"]) == 45
 
-    cfg.define_split('tile_z', cfg.axis(256), policy='verbose', num_outputs=3)
-    assert len(cfg.space_map['tile_z']) == 45
+    cfg.define_split("tile_z", cfg.axis(256), policy="verbose", num_outputs=3)
+    assert len(cfg.space_map["tile_z"]) == 45
 
-    cfg.define_split('tile_a', cfg.axis(224), policy='factors', num_outputs=3)
-    assert len(cfg.space_map['tile_a']) == 63
+    cfg.define_split("tile_a", cfg.axis(224), policy="factors", num_outputs=3)
+    assert len(cfg.space_map["tile_a"]) == 63
 
-    cfg.define_split('tile_b', cfg.axis(224), policy='power2', num_outputs=3)
-    assert len(cfg.space_map['tile_b']) == 36
+    cfg.define_split("tile_b", cfg.axis(224), policy="power2", num_outputs=3)
+    assert len(cfg.space_map["tile_b"]) == 36
 
-    cfg.define_split('tile_c', cfg.axis(224), policy='verbose', num_outputs=3)
-    assert len(cfg.space_map['tile_c']) == 84
+    cfg.define_split("tile_c", cfg.axis(224), policy="verbose", num_outputs=3)
+    assert len(cfg.space_map["tile_c"]) == 84
 
     # Count the number of non-negative integer solutions of a + b + c + d = n
     def count4(n):
@@ -74,29 +76,29 @@ def test_split():
     # test overflow
     n = 25
     cfg = ConfigSpace()
-    cfg.define_split('x', cfg.axis(2**n), policy='factors', num_outputs=4)
+    cfg.define_split("x", cfg.axis(2 ** n), policy="factors", num_outputs=4)
     # count4(25) is 3276.
-    assert len(cfg.space_map['x']) == count4(n)
+    assert len(cfg.space_map["x"]) == count4(n)
 
     # test fallback
     cfg = FallbackConfigEntity()
-    cfg.define_split('tile_n', cfg.axis(128), num_outputs=3)
-    cfg.fallback_split('tile_n', [-1, 8, 4])
-    assert cfg['tile_n'].size == [4, 8, 4]
+    cfg.define_split("tile_n", cfg.axis(128), num_outputs=3)
+    cfg.fallback_split("tile_n", [-1, 8, 4])
+    assert cfg["tile_n"].size == [4, 8, 4]
 
     cfg = FallbackConfigEntity()
-    cfg.define_split('tile_n', cfg.axis(49), num_outputs=3)
-    cfg.fallback_split('tile_n', [-1, 8, 4])
-    assert cfg['tile_n'].size == [7, 7, 1]
+    cfg.define_split("tile_n", cfg.axis(49), num_outputs=3)
+    cfg.fallback_split("tile_n", [-1, 8, 4])
+    assert cfg["tile_n"].size == [7, 7, 1]
 
     cfg = FallbackConfigEntity()
-    cfg.define_split('tile_n', cfg.axis(49), num_outputs=3)
+    cfg.define_split("tile_n", cfg.axis(49), num_outputs=3)
     try:
-        cfg.fallback_split('tile_n', [-1, 1, 0])
+        cfg.fallback_split("tile_n", [-1, 1, 0])
         assert False
     except RuntimeError:
         pass
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_split()
index 214a600..5789a9f 100644
@@ -31,10 +31,10 @@ def test_fit():
     task, target = get_sample_task()
     records = get_sample_records(n=500)
 
-    base_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
+    base_model = XGBoostCostModel(task, feature_type="itervar", loss_type="rank")
     base_model.fit_log(records, plan_size=32)
 
-    upper_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
+    upper_model = XGBoostCostModel(task, feature_type="itervar", loss_type="rank")
     upper_model.load_basemodel(base_model)
 
     xs = np.arange(10)
@@ -54,4 +54,3 @@ def test_tuner():
 if __name__ == "__main__":
     test_fit()
     test_tuner()
-
index 73b0eef..6080a41 100644
@@ -23,124 +23,143 @@ import tempfile
 
 
 def setup_git_repo(worktree=False):
-  git_repo_dir = tempfile.mkdtemp()
-  to_rm = [git_repo_dir]
-  try:
-      subprocess.check_output(['git', 'init', '.'], cwd=git_repo_dir)
+    git_repo_dir = tempfile.mkdtemp()
+    to_rm = [git_repo_dir]
+    try:
+        subprocess.check_output(["git", "init", "."], cwd=git_repo_dir)
 
-      with open(f'{git_repo_dir}/committed', 'w') as committed_f:
-          committed_f.write('normal committed file\n')
+        with open(f"{git_repo_dir}/committed", "w") as committed_f:
+            committed_f.write("normal committed file\n")
 
-      subprocess.check_output(['git', 'add', 'committed'], cwd=git_repo_dir)
+        subprocess.check_output(["git", "add", "committed"], cwd=git_repo_dir)
 
-      with open(f'{git_repo_dir}/committed-ignored', 'w') as gitignore_f:
-          gitignore_f.write('this file is gitignored, but committed already')
+        with open(f"{git_repo_dir}/committed-ignored", "w") as gitignore_f:
+            gitignore_f.write("this file is gitignored, but committed already")
 
-      subprocess.check_output(['git', 'add', 'committed-ignored'], cwd=git_repo_dir)
+        subprocess.check_output(["git", "add", "committed-ignored"], cwd=git_repo_dir)
 
-      with open(f'{git_repo_dir}/.gitignore', 'w') as gitignore_f:
-          gitignore_f.write('ignored\n'
-                            'committed-ignored\n')
+        with open(f"{git_repo_dir}/.gitignore", "w") as gitignore_f:
+            gitignore_f.write("ignored\n" "committed-ignored\n")
 
-      subprocess.check_output(['git', 'add', '.gitignore'], cwd=git_repo_dir)
+        subprocess.check_output(["git", "add", ".gitignore"], cwd=git_repo_dir)
 
-      # NOTE: explicitly set the author so this test passes in the CI.
-      subprocess.check_output(['git',
-                               '-c', 'user.name=Unit Test',
-                               '-c', 'user.email=unit.test@testing.tvm.ai',
-                               'commit', '-m', 'initial commit'],
-                              cwd=git_repo_dir)
+        # NOTE: explicitly set the author so this test passes in the CI.
+        subprocess.check_output(
+            [
+                "git",
+                "-c",
+                "user.name=Unit Test",
+                "-c",
+                "user.email=unit.test@testing.tvm.ai",
+                "commit",
+                "-m",
+                "initial commit",
+            ],
+            cwd=git_repo_dir,
+        )
 
-      if worktree:
-        worktree_dir = tempfile.mkdtemp()
-        to_rm.append(worktree_dir)
-        subprocess.check_output(['git', 'worktree', 'add', worktree_dir], cwd=git_repo_dir)
-        git_repo_dir = worktree_dir
+        if worktree:
+            worktree_dir = tempfile.mkdtemp()
+            to_rm.append(worktree_dir)
+            subprocess.check_output(["git", "worktree", "add", worktree_dir], cwd=git_repo_dir)
+            git_repo_dir = worktree_dir
 
-      with open(f'{git_repo_dir}/ignored', 'w') as gitignore_f:
-          gitignore_f.write('this file is gitignored')
+        with open(f"{git_repo_dir}/ignored", "w") as gitignore_f:
+            gitignore_f.write("this file is gitignored")
 
-      with open(f'{git_repo_dir}/added-to-index', 'w') as added_f:
-          added_f.write('only added to git index\n')
+        with open(f"{git_repo_dir}/added-to-index", "w") as added_f:
+            added_f.write("only added to git index\n")
 
-      subprocess.check_output(['git', 'add', 'added-to-index'], cwd=git_repo_dir)
+        subprocess.check_output(["git", "add", "added-to-index"], cwd=git_repo_dir)
 
-      with open(f'{git_repo_dir}/ignored-added-to-index', 'w') as ignored_f:
-          ignored_f.write('this file is gitignored but in the index already\n')
+        with open(f"{git_repo_dir}/ignored-added-to-index", "w") as ignored_f:
+            ignored_f.write("this file is gitignored but in the index already\n")
 
-      subprocess.check_output(['git', 'add', '-f', 'ignored-added-to-index'], cwd=git_repo_dir)
+        subprocess.check_output(["git", "add", "-f", "ignored-added-to-index"], cwd=git_repo_dir)
 
-      with open(f'{git_repo_dir}/untracked', 'w') as untracked_f:
-          untracked_f.write('this file is untracked\n')
+        with open(f"{git_repo_dir}/untracked", "w") as untracked_f:
+            untracked_f.write("this file is untracked\n")
 
-      os.mkdir(f'{git_repo_dir}/subdir')
-      with open(f'{git_repo_dir}/subdir/untracked', 'w') as untracked_f:
-          untracked_f.write('this file is untracked\n')
+        os.mkdir(f"{git_repo_dir}/subdir")
+        with open(f"{git_repo_dir}/subdir/untracked", "w") as untracked_f:
+            untracked_f.write("this file is untracked\n")
 
-      with open(f'{git_repo_dir}/subdir/untracked2', 'w') as untracked_f:
-          untracked_f.write('this file is also untracked\n')
+        with open(f"{git_repo_dir}/subdir/untracked2", "w") as untracked_f:
+            untracked_f.write("this file is also untracked\n")
 
-      return git_repo_dir, to_rm
+        return git_repo_dir, to_rm
 
-  except Exception:
-      for rm_dir in to_rm:
-          shutil.rmtree(rm_dir)
-      raise
+    except Exception:
+        for rm_dir in to_rm:
+            shutil.rmtree(rm_dir)
+        raise
 
 
 def run_test(repo_path, passed_files, filtered_files):
-    test_input = '\n'.join(
-        passed_files +
-        filtered_files +
-        [f'./{f}' for f in passed_files] +
-        [f'./{f}' for f in filtered_files]) + '\n'
-
-    test_script_dir = f'{repo_path}/test-script-dir'
+    test_input = (
+        "\n".join(
+            passed_files
+            + filtered_files
+            + [f"./{f}" for f in passed_files]
+            + [f"./{f}" for f in filtered_files]
+        )
+        + "\n"
+    )
+
+    test_script_dir = f"{repo_path}/test-script-dir"
     os.mkdir(test_script_dir)
 
-    filter_script_path = f'{test_script_dir}/filter_untracked.py'
+    filter_script_path = f"{test_script_dir}/filter_untracked.py"
     test_script_dirname = os.path.dirname(__file__) or os.getcwd()
-    shutil.copy(os.path.realpath(f'{test_script_dirname}/../../lint/filter_untracked.py'),
-                filter_script_path)
+    shutil.copy(
+        os.path.realpath(f"{test_script_dirname}/../../lint/filter_untracked.py"),
+        filter_script_path,
+    )
     filter_proc = subprocess.Popen(
         [sys.executable, filter_script_path],
         cwd=repo_path,
         stdin=subprocess.PIPE,
         stdout=subprocess.PIPE,
-        encoding='utf-8')
+        encoding="utf-8",
+    )
     filter_output, _ = filter_proc.communicate(test_input)
-    filter_output_lines = [l for l in filter_output.split('\n') if l]
+    filter_output_lines = [l for l in filter_output.split("\n") if l]
 
     for pass_f in passed_files:
-        assert pass_f in filter_output_lines, (
-            f'expected in filter output: {pass_f}\filter output: {filter_output}')
-        assert f'./{pass_f}' in filter_output_lines, (
-            f'expected in filter output: ./{pass_f}\filter output: {filter_output}')
+        assert (
+            pass_f in filter_output_lines
+        ), f"expected in filter output: {pass_f}\filter output: {filter_output}"
+        assert (
+            f"./{pass_f}" in filter_output_lines
+        ), f"expected in filter output: ./{pass_f}\filter output: {filter_output}"
 
     for filter_f in filtered_files:
-        assert filter_f not in filter_output_lines, (
-            f'expected not in filter output: {filter_f}\nfilter_output: {filter_output}')
-        assert f'./{filter_f}' not in filter_output_lines, (
-            f'expected not in filter output: ./{filter_f}\nfilter_output: {filter_output}')
+        assert (
+            filter_f not in filter_output_lines
+        ), f"expected not in filter output: {filter_f}\nfilter_output: {filter_output}"
+        assert (
+            f"./{filter_f}" not in filter_output_lines
+        ), f"expected not in filter output: ./{filter_f}\nfilter_output: {filter_output}"
 
-    assert len(filter_output_lines) == 2 * len(passed_files), (
-        f'expected {len(filter_output_lines)} == 2 * {len(passed_files)}')
+    assert len(filter_output_lines) == 2 * len(
+        passed_files
+    ), f"expected {len(filter_output_lines)} == 2 * {len(passed_files)}"
 
 
 def test_filter_untracked():
     repo_path, to_rm = setup_git_repo()
     try:
         passed_files = [
-            'committed',
-            'committed-ignored',
-            'added-to-index',
-            'ignored-added-to-index',
+            "committed",
+            "committed-ignored",
+            "added-to-index",
+            "ignored-added-to-index",
         ]
         filtered_files = [
-            'ignored',
-            'untracked',
-            'subdir/untracked',
-            'subdir/untracked2',
+            "ignored",
+            "untracked",
+            "subdir/untracked",
+            "subdir/untracked2",
         ]
         run_test(repo_path, passed_files, filtered_files)
 
@@ -153,17 +172,17 @@ def test_worktree():
     repo_path, to_rm = setup_git_repo(worktree=True)
     try:
         passed_files = [
-            'committed',
-            'committed-ignored',
-            'added-to-index',
-            'ignored-added-to-index',
+            "committed",
+            "committed-ignored",
+            "added-to-index",
+            "ignored-added-to-index",
         ]
         filtered_files = [
-            'ignored',
-            'untracked',
-            'subdir/untracked',
-            'subdir/untracked2',
-            '.git',
+            "ignored",
+            "untracked",
+            "subdir/untracked",
+            "subdir/untracked2",
+            ".git",
         ]
         run_test(repo_path, passed_files, filtered_files)
 
@@ -172,6 +191,6 @@ def test_worktree():
             shutil.rmtree(rm_dir)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_filter_untracked()
     test_worktree()
index 69be62a..b655654 100644
@@ -20,22 +20,22 @@ import random
 from tvm.autotvm import util
 
 
-SI_PREFIXES = 'yzafpn\xb5m kMGTPEZY'
+SI_PREFIXES = "yzafpn\xb5m kMGTPEZY"
 
 
 def test_format_si_prefix():
-  # test float conversion
-  assert util.format_si_prefix(1024, 'k') == 1.024
+    # test float conversion
+    assert util.format_si_prefix(1024, "k") == 1.024
 
-  for i, prefix in enumerate(SI_PREFIXES):
-    integer, decimal = random.randint(0, 1000), random.randint(0, 1000)
-    exp = -24 + 3 * i   # 0th prefix (yocto) is 10^-24
-    number = integer * (10 ** exp) + decimal * (10 ** (exp - 3))
-    expected = (integer + decimal / 1000)
-    assert isclose(util.format_si_prefix(number, prefix), expected)
+    for i, prefix in enumerate(SI_PREFIXES):
+        integer, decimal = random.randint(0, 1000), random.randint(0, 1000)
+        exp = -24 + 3 * i  # 0th prefix (yocto) is 10^-24
+        number = integer * (10 ** exp) + decimal * (10 ** (exp - 3))
+        expected = integer + decimal / 1000
+        assert isclose(util.format_si_prefix(number, prefix), expected)
 
-  assert util.format_si_prefix(0, 'y') == 0
+    assert util.format_si_prefix(0, "y") == 0
 
 
-if __name__ == '__main__':
-  test_format_si_prefix()
+if __name__ == "__main__":
+    test_format_si_prefix()
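Besides wrapping, Black also normalizes indentation: the two-space bodies in this file (and in the filter_untracked tests earlier) come out at four spaces per level. A self-contained sketch with an illustrative function:

    # before Black (two-space indentation):
    # def scale(value, factor):
    #   return value * factor

    # after Black (four spaces per level, code otherwise untouched):
    def scale(value, factor):
        return value * factor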
index 0dfdbbd..39b8bfc 100644
@@ -91,11 +91,11 @@ def wrap_error(module, lineno):
     assert error is not None
     e = error.value
     print(e)
-    msg = str(e).split('\n')[-1].split(':', maxsplit=1)[0].strip().split(' ')[-1].strip()
+    msg = str(e).split("\n")[-1].split(":", maxsplit=1)[0].strip().split(" ")[-1].strip()
     assert int(msg) == lineno
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     wrap_error(Module1, 29)
     wrap_error(Module2, 39)
     wrap_error(Module3, 50)
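All of the hunks in this change were generated by the formatter rather than written by hand. For readers who want to reproduce a single rewrite locally, here is a minimal sketch using Black's Python API; it assumes the black package is installed and uses Black's default mode, whereas the project's own settings live in the pyproject.toml added by this change:

    import black

    src = "x = {'a': 1,'b': 2}\n"
    # format_str applies the same rules as the command-line tool
    print(black.format_str(src, mode=black.FileMode()), end="")
    # prints: x = {"a": 1, "b": 2}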
index 7b706bd..90d76a2 100644
@@ -37,24 +37,42 @@ class Module1:
         for x in tir.range(0, 32, "parallel"):
             for y in tir.range(0, 1024):
                 for z in tir.range(0, 32, "vectorized"):
-                    packedB[x, y, z] = B_1[y, ((x*32) + z)]
+                    packedB[x, y, z] = B_1[y, ((x * 32) + z)]
         tir.attr(C_1, "realize_scope", "")
         tir.realize(C_1[0:1024, 0:1024])
         for x_outer in tir.range(0, 32, "parallel"):
             for y_outer in tir.range(0, 32):
                 tir.attr(C_global, "realize_scope", "global")
-                tir.realize(C_global[(x_outer*32):((x_outer*32) + 32), (y_outer*32):((y_outer*32) + 32)])
+                tir.realize(
+                    C_global[
+                        (x_outer * 32) : ((x_outer * 32) + 32),
+                        (y_outer * 32) : ((y_outer * 32) + 32),
+                    ]
+                )
                 for x_c_init in tir.range(0, 32):
                     for y_c_init in tir.range(0, 32, "vectorized"):
-                        C_global[(x_c_init + (x_outer*32)), (y_c_init + (y_outer*32))] = tir.float32(0)
+                        C_global[
+                            (x_c_init + (x_outer * 32)), (y_c_init + (y_outer * 32))
+                        ] = tir.float32(0)
                 for k_outer in tir.range(0, 256):
                     for x_c in tir.range(0, 32):
                         for k_inner in tir.range(0, 4, "unroll"):
                             for y_c in tir.range(0, 32, "vectorized"):
-                                C_global[(x_c + (x_outer*32)), (y_c + (y_outer*32))] = (C_global[(x_c + (x_outer*32)), (y_c + (y_outer*32))] + (A_1[(x_c + (x_outer*32)), (k_inner + (k_outer*4))]*packedB[tir.floordiv((y_c + (y_outer*32)), 32), (k_inner + (k_outer*4)), tir.floormod((y_c + (y_outer*32)), 32)]))
+                                C_global[(x_c + (x_outer * 32)), (y_c + (y_outer * 32))] = C_global[
+                                    (x_c + (x_outer * 32)), (y_c + (y_outer * 32))
+                                ] + (
+                                    A_1[(x_c + (x_outer * 32)), (k_inner + (k_outer * 4))]
+                                    * packedB[
+                                        tir.floordiv((y_c + (y_outer * 32)), 32),
+                                        (k_inner + (k_outer * 4)),
+                                        tir.floormod((y_c + (y_outer * 32)), 32),
+                                    ]
+                                )
                 for x_inner in tir.range(0, 32):
                     for y_inner in tir.range(0, 32):
-                        C_1[(x_inner + (x_outer*32)), (y_inner + (y_outer*32))] = C_global[(x_inner + (x_outer*32)), (y_inner + (y_outer*32))]
+                        C_1[(x_inner + (x_outer * 32)), (y_inner + (y_outer * 32))] = C_global[
+                            (x_inner + (x_outer * 32)), (y_inner + (y_outer * 32))
+                        ]
 
 
 def test_opt_gemm_normalize():
@@ -81,20 +99,164 @@ class Module2:
         tir.allocate(C_global, "float32", [1024])
         for x in tir.range(0, 32, "parallel"):
             for y in tir.range(0, 1024):
-                tir.store(packedB, tir.ramp(((x*32768) + (y*32)), 1, 32), tir.load("float32x32", B_1.data, tir.ramp(((y*1024) + (x*32)), 1, 32), tir.broadcast(True, 32)), tir.broadcast(True, 32))
+                tir.store(
+                    packedB,
+                    tir.ramp(((x * 32768) + (y * 32)), 1, 32),
+                    tir.load(
+                        "float32x32",
+                        B_1.data,
+                        tir.ramp(((y * 1024) + (x * 32)), 1, 32),
+                        tir.broadcast(True, 32),
+                    ),
+                    tir.broadcast(True, 32),
+                )
         for x_outer in tir.range(0, 32):
             for y_outer in tir.range(0, 32):
                 for x_c_init in tir.range(0, 32):
-                    tir.store(C_global, tir.ramp((x_c_init*32), 1, 32), tir.broadcast(tir.float32(0), 32), tir.broadcast(True, 32))
+                    tir.store(
+                        C_global,
+                        tir.ramp((x_c_init * 32), 1, 32),
+                        tir.broadcast(tir.float32(0), 32),
+                        tir.broadcast(True, 32),
+                    )
                 for k_outer in tir.range(0, 256):
                     for x_c in tir.range(0, 32):
-                        tir.store(C_global, tir.ramp((x_c*32), 1, 32), (tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)) + (tir.broadcast(tir.load("float32", A_1.data, (((x_outer*32768) + (x_c*1024)) + (k_outer*4))), 32)*tir.load("float32x32", packedB, tir.ramp(((y_outer*32768) + (k_outer*128)), 1, 32), tir.broadcast(True, 32)))), tir.broadcast(True, 32))
-                        tir.store(C_global, tir.ramp((x_c*32), 1, 32), (tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)) + (tir.broadcast(tir.load("float32", A_1.data, ((((x_outer*32768) + (x_c*1024)) + (k_outer*4)) + 1)), 32)*tir.load("float32x32", packedB, tir.ramp((((y_outer*32768) + (k_outer*128)) + 32), 1, 32), tir.broadcast(True, 32)))), tir.broadcast(True, 32))
-                        tir.store(C_global, tir.ramp((x_c*32), 1, 32), (tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)) + (tir.broadcast(tir.load("float32", A_1.data, ((((x_outer*32768) + (x_c*1024)) + (k_outer*4)) + 2)), 32)*tir.load("float32x32", packedB, tir.ramp((((y_outer*32768) + (k_outer*128)) + 64), 1, 32), tir.broadcast(True, 32)))), tir.broadcast(True, 32))
-                        tir.store(C_global, tir.ramp((x_c*32), 1, 32), (tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)) + (tir.broadcast(tir.load("float32", A_1.data, ((((x_outer*32768) + (x_c*1024)) + (k_outer*4)) + 3)), 32)*tir.load("float32x32", packedB, tir.ramp((((y_outer*32768) + (k_outer*128)) + 96), 1, 32), tir.broadcast(True, 32)))), tir.broadcast(True, 32))
+                        tir.store(
+                            C_global,
+                            tir.ramp((x_c * 32), 1, 32),
+                            (
+                                tir.load(
+                                    "float32x32",
+                                    C_global,
+                                    tir.ramp((x_c * 32), 1, 32),
+                                    tir.broadcast(True, 32),
+                                )
+                                + (
+                                    tir.broadcast(
+                                        tir.load(
+                                            "float32",
+                                            A_1.data,
+                                            (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)),
+                                        ),
+                                        32,
+                                    )
+                                    * tir.load(
+                                        "float32x32",
+                                        packedB,
+                                        tir.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32),
+                                        tir.broadcast(True, 32),
+                                    )
+                                )
+                            ),
+                            tir.broadcast(True, 32),
+                        )
+                        tir.store(
+                            C_global,
+                            tir.ramp((x_c * 32), 1, 32),
+                            (
+                                tir.load(
+                                    "float32x32",
+                                    C_global,
+                                    tir.ramp((x_c * 32), 1, 32),
+                                    tir.broadcast(True, 32),
+                                )
+                                + (
+                                    tir.broadcast(
+                                        tir.load(
+                                            "float32",
+                                            A_1.data,
+                                            (
+                                                (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4))
+                                                + 1
+                                            ),
+                                        ),
+                                        32,
+                                    )
+                                    * tir.load(
+                                        "float32x32",
+                                        packedB,
+                                        tir.ramp(
+                                            (((y_outer * 32768) + (k_outer * 128)) + 32), 1, 32
+                                        ),
+                                        tir.broadcast(True, 32),
+                                    )
+                                )
+                            ),
+                            tir.broadcast(True, 32),
+                        )
+                        tir.store(
+                            C_global,
+                            tir.ramp((x_c * 32), 1, 32),
+                            (
+                                tir.load(
+                                    "float32x32",
+                                    C_global,
+                                    tir.ramp((x_c * 32), 1, 32),
+                                    tir.broadcast(True, 32),
+                                )
+                                + (
+                                    tir.broadcast(
+                                        tir.load(
+                                            "float32",
+                                            A_1.data,
+                                            (
+                                                (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4))
+                                                + 2
+                                            ),
+                                        ),
+                                        32,
+                                    )
+                                    * tir.load(
+                                        "float32x32",
+                                        packedB,
+                                        tir.ramp(
+                                            (((y_outer * 32768) + (k_outer * 128)) + 64), 1, 32
+                                        ),
+                                        tir.broadcast(True, 32),
+                                    )
+                                )
+                            ),
+                            tir.broadcast(True, 32),
+                        )
+                        tir.store(
+                            C_global,
+                            tir.ramp((x_c * 32), 1, 32),
+                            (
+                                tir.load(
+                                    "float32x32",
+                                    C_global,
+                                    tir.ramp((x_c * 32), 1, 32),
+                                    tir.broadcast(True, 32),
+                                )
+                                + (
+                                    tir.broadcast(
+                                        tir.load(
+                                            "float32",
+                                            A_1.data,
+                                            (
+                                                (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4))
+                                                + 3
+                                            ),
+                                        ),
+                                        32,
+                                    )
+                                    * tir.load(
+                                        "float32x32",
+                                        packedB,
+                                        tir.ramp(
+                                            (((y_outer * 32768) + (k_outer * 128)) + 96), 1, 32
+                                        ),
+                                        tir.broadcast(True, 32),
+                                    )
+                                )
+                            ),
+                            tir.broadcast(True, 32),
+                        )
                 for x_inner in tir.range(0, 32):
                     for y_inner in tir.range(0, 32):
-                        C_1.data[((((x_outer*32768) + (x_inner*1024)) + (y_outer*32)) + y_inner)] = tir.load("float32", C_global, ((x_inner*32) + y_inner))
+                        C_1.data[
+                            ((((x_outer * 32768) + (x_inner * 1024)) + (y_outer * 32)) + y_inner)
+                        ] = tir.load("float32", C_global, ((x_inner * 32) + y_inner))
 
 
 def test_opt_gemm_lower():
@@ -105,14 +267,27 @@ def test_opt_gemm_lower():
 
 @tvm.hybrid.script
 class Module3:
-    def mmult(args: ty.handle, arg_type_ids: ty.handle, num_args: ty.int32, out_ret_value: ty.handle, out_ret_tcode: ty.handle) -> ty.int32:
+    def mmult(
+        args: ty.handle,
+        arg_type_ids: ty.handle,
+        num_args: ty.int32,
+        out_ret_value: ty.handle,
+        out_ret_tcode: ty.handle,
+    ) -> ty.int32:
         # function attr dict
-        tir.func_attr({"tir.noalias": True, "global_symbol": "mmult", "tir.is_entry_func": True, "calling_conv": 1})
+        tir.func_attr(
+            {
+                "tir.noalias": True,
+                "global_symbol": "mmult",
+                "tir.is_entry_func": True,
+                "calling_conv": 1,
+            }
+        )
         # var definition
         C_global = tir.var("handle")
         packedB = tir.var("handle")
         # body
-        assert (num_args == 3), "mmult: num_args should be 3"
+        assert num_args == 3, "mmult: num_args should be 3"
         arg0: ty.handle = tir.tvm_struct_get(args, 0, 12, dtype="handle")
         arg0_code: ty.int32 = tir.load("int32", arg_type_ids, 0)
         arg1: ty.handle = tir.tvm_struct_get(args, 1, 12, dtype="handle")
@@ -132,71 +307,309 @@ class Module3:
         tir.attr(C, "storage_alignment", 128)
         arg2_shape: ty.handle = tir.tvm_struct_get(arg2, 0, 2, dtype="handle")
         arg2_strides: ty.handle = tir.tvm_struct_get(arg2, 0, 3, dtype="handle")
-        assert ((((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or (arg0_code == 4)), "mmult: Expect arg[0] to be pointer"
-        assert ((((arg1_code == 3) or (arg1_code == 13)) or (arg1_code == 7)) or (arg1_code == 4)), "mmult: Expect arg[1] to be pointer"
-        assert ((((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or (arg2_code == 4)), "mmult: Expect arg[2] to be pointer"
-        assert (2 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32")), "arg0.ndim is expected to equal 2"
-        assert (2 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32")), "arg0.ndim is expected to equal 2"
-        assert (((tir.tvm_struct_get(arg0, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg0, 0, 6, dtype="uint8") == tir.uint8(32))) and (tir.tvm_struct_get(arg0, 0, 7, dtype="uint16") == tir.uint16(1))), "arg0.dtype is expected to be float32"
-        assert (1024 == tir.cast("int32", tir.load("int64", arg0_shape, 0))), "Argument arg0.shape[0] has an unsatisfied constraint"
-        assert (1024 == tir.cast("int32", tir.load("int64", arg0_shape, 1))), "Argument arg0.shape[1] has an unsatisfied constraint"
+        assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or (
+            arg0_code == 4
+        ), "mmult: Expect arg[0] to be pointer"
+        assert (((arg1_code == 3) or (arg1_code == 13)) or (arg1_code == 7)) or (
+            arg1_code == 4
+        ), "mmult: Expect arg[1] to be pointer"
+        assert (((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or (
+            arg2_code == 4
+        ), "mmult: Expect arg[2] to be pointer"
+        assert 2 == tir.tvm_struct_get(
+            arg0, 0, 4, dtype="int32"
+        ), "arg0.ndim is expected to equal 2"
+        assert 2 == tir.tvm_struct_get(
+            arg0, 0, 4, dtype="int32"
+        ), "arg0.ndim is expected to equal 2"
+        assert (
+            (tir.tvm_struct_get(arg0, 0, 5, dtype="uint8") == tir.uint8(2))
+            and (tir.tvm_struct_get(arg0, 0, 6, dtype="uint8") == tir.uint8(32))
+        ) and (
+            tir.tvm_struct_get(arg0, 0, 7, dtype="uint16") == tir.uint16(1)
+        ), "arg0.dtype is expected to be float32"
+        assert 1024 == tir.cast(
+            "int32", tir.load("int64", arg0_shape, 0)
+        ), "Argument arg0.shape[0] has an unsatisfied constraint"
+        assert 1024 == tir.cast(
+            "int32", tir.load("int64", arg0_shape, 1)
+        ), "Argument arg0.shape[1] has an unsatisfied constraint"
         if not (tir.isnullptr(arg0_strides, dtype="bool")):
-            assert ((1 == tir.cast("int32", tir.load("int64", arg0_strides, 1))) and (1024 == tir.cast("int32", tir.load("int64", arg0_strides, 0)))), "arg0.strides: expected to be compact array"
+            assert (1 == tir.cast("int32", tir.load("int64", arg0_strides, 1))) and (
+                1024 == tir.cast("int32", tir.load("int64", arg0_strides, 0))
+            ), "arg0.strides: expected to be compact array"
             tir.evaluate(0)
-        assert (tir.uint64(0) == tir.tvm_struct_get(arg0, 0, 8, dtype="uint64")), "Argument arg0.byte_offset has an unsatisfied constraint"
-        assert (1 == tir.tvm_struct_get(arg0, 0, 10, dtype="int32")), "Argument arg0.device_type has an unsatisfied constraint"
-        assert (2 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32")), "arg1.ndim is expected to equal 2"
-        assert (2 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32")), "arg1.ndim is expected to equal 2"
-        assert (((tir.tvm_struct_get(arg1, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg1, 0, 6, dtype="uint8") == tir.uint8(32))) and (tir.tvm_struct_get(arg1, 0, 7, dtype="uint16") == tir.uint16(1))), "arg1.dtype is expected to be float32"
-        assert (1024 == tir.cast("int32", tir.load("int64", arg1_shape, 0))), "Argument arg1.shape[0] has an unsatisfied constraint"
-        assert (1024 == tir.cast("int32", tir.load("int64", arg1_shape, 1))), "Argument arg1.shape[1] has an unsatisfied constraint"
+        assert tir.uint64(0) == tir.tvm_struct_get(
+            arg0, 0, 8, dtype="uint64"
+        ), "Argument arg0.byte_offset has an unsatisfied constraint"
+        assert 1 == tir.tvm_struct_get(
+            arg0, 0, 10, dtype="int32"
+        ), "Argument arg0.device_type has an unsatisfied constraint"
+        assert 2 == tir.tvm_struct_get(
+            arg1, 0, 4, dtype="int32"
+        ), "arg1.ndim is expected to equal 2"
+        assert 2 == tir.tvm_struct_get(
+            arg1, 0, 4, dtype="int32"
+        ), "arg1.ndim is expected to equal 2"
+        assert (
+            (tir.tvm_struct_get(arg1, 0, 5, dtype="uint8") == tir.uint8(2))
+            and (tir.tvm_struct_get(arg1, 0, 6, dtype="uint8") == tir.uint8(32))
+        ) and (
+            tir.tvm_struct_get(arg1, 0, 7, dtype="uint16") == tir.uint16(1)
+        ), "arg1.dtype is expected to be float32"
+        assert 1024 == tir.cast(
+            "int32", tir.load("int64", arg1_shape, 0)
+        ), "Argument arg1.shape[0] has an unsatisfied constraint"
+        assert 1024 == tir.cast(
+            "int32", tir.load("int64", arg1_shape, 1)
+        ), "Argument arg1.shape[1] has an unsatisfied constraint"
         if not (tir.isnullptr(arg1_strides, dtype="bool")):
-            assert ((1 == tir.cast("int32", tir.load("int64", arg1_strides, 1))) and (1024 == tir.cast("int32", tir.load("int64", arg1_strides, 0)))), "arg1.strides: expected to be compact array"
+            assert (1 == tir.cast("int32", tir.load("int64", arg1_strides, 1))) and (
+                1024 == tir.cast("int32", tir.load("int64", arg1_strides, 0))
+            ), "arg1.strides: expected to be compact array"
             tir.evaluate(0)
-        assert (tir.uint64(0) == tir.tvm_struct_get(arg1, 0, 8, dtype="uint64")), "Argument arg1.byte_offset has an unsatisfied constraint"
-        assert (1 == tir.tvm_struct_get(arg1, 0, 10, dtype="int32")), "Argument arg1.device_type has an unsatisfied constraint"
-        assert (dev_id == tir.tvm_struct_get(arg1, 0, 9, dtype="int32")), "Argument arg1.device_id has an unsatisfied constraint"
-        assert (2 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32")), "arg2.ndim is expected to equal 2"
-        assert (2 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32")), "arg2.ndim is expected to equal 2"
-        assert (((tir.tvm_struct_get(arg2, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg2, 0, 6, dtype="uint8") == tir.uint8(32))) and (tir.tvm_struct_get(arg2, 0, 7, dtype="uint16") == tir.uint16(1))), "arg2.dtype is expected to be float32"
-        assert (1024 == tir.cast("int32", tir.load("int64", arg2_shape, 0))), "Argument arg2.shape[0] has an unsatisfied constraint"
-        assert (1024 == tir.cast("int32", tir.load("int64", arg2_shape, 1))), "Argument arg2.shape[1] has an unsatisfied constraint"
+        assert tir.uint64(0) == tir.tvm_struct_get(
+            arg1, 0, 8, dtype="uint64"
+        ), "Argument arg1.byte_offset has an unsatisfied constraint"
+        assert 1 == tir.tvm_struct_get(
+            arg1, 0, 10, dtype="int32"
+        ), "Argument arg1.device_type has an unsatisfied constraint"
+        assert dev_id == tir.tvm_struct_get(
+            arg1, 0, 9, dtype="int32"
+        ), "Argument arg1.device_id has an unsatisfied constraint"
+        assert 2 == tir.tvm_struct_get(
+            arg2, 0, 4, dtype="int32"
+        ), "arg2.ndim is expected to equal 2"
+        assert 2 == tir.tvm_struct_get(
+            arg2, 0, 4, dtype="int32"
+        ), "arg2.ndim is expected to equal 2"
+        assert (
+            (tir.tvm_struct_get(arg2, 0, 5, dtype="uint8") == tir.uint8(2))
+            and (tir.tvm_struct_get(arg2, 0, 6, dtype="uint8") == tir.uint8(32))
+        ) and (
+            tir.tvm_struct_get(arg2, 0, 7, dtype="uint16") == tir.uint16(1)
+        ), "arg2.dtype is expected to be float32"
+        assert 1024 == tir.cast(
+            "int32", tir.load("int64", arg2_shape, 0)
+        ), "Argument arg2.shape[0] has an unsatisfied constraint"
+        assert 1024 == tir.cast(
+            "int32", tir.load("int64", arg2_shape, 1)
+        ), "Argument arg2.shape[1] has an unsatisfied constraint"
         if not (tir.isnullptr(arg2_strides, dtype="bool")):
-            assert ((1 == tir.cast("int32", tir.load("int64", arg2_strides, 1))) and (1024 == tir.cast("int32", tir.load("int64", arg2_strides, 0)))), "arg2.strides: expected to be compact array"
+            assert (1 == tir.cast("int32", tir.load("int64", arg2_strides, 1))) and (
+                1024 == tir.cast("int32", tir.load("int64", arg2_strides, 0))
+            ), "arg2.strides: expected to be compact array"
             tir.evaluate(0)
-        assert (tir.uint64(0) == tir.tvm_struct_get(arg2, 0, 8, dtype="uint64")), "Argument arg2.byte_offset has an unsatisfied constraint"
-        assert (1 == tir.tvm_struct_get(arg2, 0, 10, dtype="int32")), "Argument arg2.device_type has an unsatisfied constraint"
-        assert (dev_id == tir.tvm_struct_get(arg2, 0, 9, dtype="int32")), "Argument arg2.device_id has an unsatisfied constraint"
+        assert tir.uint64(0) == tir.tvm_struct_get(
+            arg2, 0, 8, dtype="uint64"
+        ), "Argument arg2.byte_offset has an unsatisfied constraint"
+        assert 1 == tir.tvm_struct_get(
+            arg2, 0, 10, dtype="int32"
+        ), "Argument arg2.device_type has an unsatisfied constraint"
+        assert dev_id == tir.tvm_struct_get(
+            arg2, 0, 9, dtype="int32"
+        ), "Argument arg2.device_id has an unsatisfied constraint"
         tir.attr(0, "compute_scope", "mmult_compute_")
         tir.attr(packedB, "storage_scope", "global")
         tir.attr(packedB, "storage_alignment", 128)
-        with tir.let(packedB, tir.TVMBackendAllocWorkspace(1, dev_id, tir.uint64(4194304), 2, 32, dtype="handle")):
+        with tir.let(
+            packedB,
+            tir.TVMBackendAllocWorkspace(1, dev_id, tir.uint64(4194304), 2, 32, dtype="handle"),
+        ):
             if tir.isnullptr(packedB, dtype="bool"):
                 tir.evaluate(tir.tvm_throw_last_error(dtype="int32"))
             for x in tir.range(0, 32, "parallel"):
                 for y in tir.range(0, 1024):
-                    tir.store(packedB, tir.ramp(((x*32768) + (y*32)), 1, 32), tir.load("float32x32", B, tir.ramp(((y*1024) + (x*32)), 1, 32), tir.broadcast(True, 32)), tir.broadcast(True, 32))
+                    tir.store(
+                        packedB,
+                        tir.ramp(((x * 32768) + (y * 32)), 1, 32),
+                        tir.load(
+                            "float32x32",
+                            B,
+                            tir.ramp(((y * 1024) + (x * 32)), 1, 32),
+                            tir.broadcast(True, 32),
+                        ),
+                        tir.broadcast(True, 32),
+                    )
             for x_outer in tir.range(0, 32, "parallel"):
                 tir.attr(C_global, "storage_scope", "global")
                 tir.attr(C_global, "storage_alignment", 128)
-                with tir.let(C_global, tir.TVMBackendAllocWorkspace(1, dev_id, tir.uint64(4096), 2, 32, dtype="handle")):
+                with tir.let(
+                    C_global,
+                    tir.TVMBackendAllocWorkspace(
+                        1, dev_id, tir.uint64(4096), 2, 32, dtype="handle"
+                    ),
+                ):
                     if tir.isnullptr(C_global, dtype="bool"):
                         tir.evaluate(tir.tvm_throw_last_error(dtype="int32"))
                     for y_outer in tir.range(0, 32):
                         for x_c_init in tir.range(0, 32):
-                            tir.store(C_global, tir.ramp((x_c_init*32), 1, 32), tir.broadcast(tir.float32(0), 32), tir.broadcast(True, 32))
+                            tir.store(
+                                C_global,
+                                tir.ramp((x_c_init * 32), 1, 32),
+                                tir.broadcast(tir.float32(0), 32),
+                                tir.broadcast(True, 32),
+                            )
                         for k_outer in tir.range(0, 256):
                             for x_c in tir.range(0, 32):
-                                tir.store(C_global, tir.ramp((x_c*32), 1, 32), tir.call_llvm_pure_intrin(tir.uint32(97), tir.uint32(3), tir.broadcast(tir.load("float32", A, (((x_outer*32768) + (x_c*1024)) + (k_outer*4))), 32), tir.load("float32x32", packedB, tir.ramp(((y_outer*32768) + (k_outer*128)), 1, 32), tir.broadcast(True, 32)), tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)), dtype="float32x32"), tir.broadcast(True, 32))
-                                tir.store(C_global, tir.ramp((x_c*32), 1, 32), tir.call_llvm_pure_intrin(tir.uint32(97), tir.uint32(3), tir.broadcast(tir.load("float32", A, ((((x_outer*32768) + (x_c*1024)) + (k_outer*4)) + 1)), 32), tir.load("float32x32", packedB, tir.ramp((((y_outer*32768) + (k_outer*128)) + 32), 1, 32), tir.broadcast(True, 32)), tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)), dtype="float32x32"), tir.broadcast(True, 32))
-                                tir.store(C_global, tir.ramp((x_c*32), 1, 32), tir.call_llvm_pure_intrin(tir.uint32(97), tir.uint32(3), tir.broadcast(tir.load("float32", A, ((((x_outer*32768) + (x_c*1024)) + (k_outer*4)) + 2)), 32), tir.load("float32x32", packedB, tir.ramp((((y_outer*32768) + (k_outer*128)) + 64), 1, 32), tir.broadcast(True, 32)), tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)), dtype="float32x32"), tir.broadcast(True, 32))
-                                tir.store(C_global, tir.ramp((x_c*32), 1, 32), tir.call_llvm_pure_intrin(tir.uint32(97), tir.uint32(3), tir.broadcast(tir.load("float32", A, ((((x_outer*32768) + (x_c*1024)) + (k_outer*4)) + 3)), 32), tir.load("float32x32", packedB, tir.ramp((((y_outer*32768) + (k_outer*128)) + 96), 1, 32), tir.broadcast(True, 32)), tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)), dtype="float32x32"), tir.broadcast(True, 32))
+                                tir.store(
+                                    C_global,
+                                    tir.ramp((x_c * 32), 1, 32),
+                                    tir.call_llvm_pure_intrin(
+                                        tir.uint32(97),
+                                        tir.uint32(3),
+                                        tir.broadcast(
+                                            tir.load(
+                                                "float32",
+                                                A,
+                                                (
+                                                    ((x_outer * 32768) + (x_c * 1024))
+                                                    + (k_outer * 4)
+                                                ),
+                                            ),
+                                            32,
+                                        ),
+                                        tir.load(
+                                            "float32x32",
+                                            packedB,
+                                            tir.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32),
+                                            tir.broadcast(True, 32),
+                                        ),
+                                        tir.load(
+                                            "float32x32",
+                                            C_global,
+                                            tir.ramp((x_c * 32), 1, 32),
+                                            tir.broadcast(True, 32),
+                                        ),
+                                        dtype="float32x32",
+                                    ),
+                                    tir.broadcast(True, 32),
+                                )
+                                tir.store(
+                                    C_global,
+                                    tir.ramp((x_c * 32), 1, 32),
+                                    tir.call_llvm_pure_intrin(
+                                        tir.uint32(97),
+                                        tir.uint32(3),
+                                        tir.broadcast(
+                                            tir.load(
+                                                "float32",
+                                                A,
+                                                (
+                                                    (
+                                                        ((x_outer * 32768) + (x_c * 1024))
+                                                        + (k_outer * 4)
+                                                    )
+                                                    + 1
+                                                ),
+                                            ),
+                                            32,
+                                        ),
+                                        tir.load(
+                                            "float32x32",
+                                            packedB,
+                                            tir.ramp(
+                                                (((y_outer * 32768) + (k_outer * 128)) + 32), 1, 32
+                                            ),
+                                            tir.broadcast(True, 32),
+                                        ),
+                                        tir.load(
+                                            "float32x32",
+                                            C_global,
+                                            tir.ramp((x_c * 32), 1, 32),
+                                            tir.broadcast(True, 32),
+                                        ),
+                                        dtype="float32x32",
+                                    ),
+                                    tir.broadcast(True, 32),
+                                )
+                                tir.store(
+                                    C_global,
+                                    tir.ramp((x_c * 32), 1, 32),
+                                    tir.call_llvm_pure_intrin(
+                                        tir.uint32(97),
+                                        tir.uint32(3),
+                                        tir.broadcast(
+                                            tir.load(
+                                                "float32",
+                                                A,
+                                                (
+                                                    (
+                                                        ((x_outer * 32768) + (x_c * 1024))
+                                                        + (k_outer * 4)
+                                                    )
+                                                    + 2
+                                                ),
+                                            ),
+                                            32,
+                                        ),
+                                        tir.load(
+                                            "float32x32",
+                                            packedB,
+                                            tir.ramp(
+                                                (((y_outer * 32768) + (k_outer * 128)) + 64), 1, 32
+                                            ),
+                                            tir.broadcast(True, 32),
+                                        ),
+                                        tir.load(
+                                            "float32x32",
+                                            C_global,
+                                            tir.ramp((x_c * 32), 1, 32),
+                                            tir.broadcast(True, 32),
+                                        ),
+                                        dtype="float32x32",
+                                    ),
+                                    tir.broadcast(True, 32),
+                                )
+                                tir.store(
+                                    C_global,
+                                    tir.ramp((x_c * 32), 1, 32),
+                                    tir.call_llvm_pure_intrin(
+                                        tir.uint32(97),
+                                        tir.uint32(3),
+                                        tir.broadcast(
+                                            tir.load(
+                                                "float32",
+                                                A,
+                                                (
+                                                    (
+                                                        ((x_outer * 32768) + (x_c * 1024))
+                                                        + (k_outer * 4)
+                                                    )
+                                                    + 3
+                                                ),
+                                            ),
+                                            32,
+                                        ),
+                                        tir.load(
+                                            "float32x32",
+                                            packedB,
+                                            tir.ramp(
+                                                (((y_outer * 32768) + (k_outer * 128)) + 96), 1, 32
+                                            ),
+                                            tir.broadcast(True, 32),
+                                        ),
+                                        tir.load(
+                                            "float32x32",
+                                            C_global,
+                                            tir.ramp((x_c * 32), 1, 32),
+                                            tir.broadcast(True, 32),
+                                        ),
+                                        dtype="float32x32",
+                                    ),
+                                    tir.broadcast(True, 32),
+                                )
                         for x_inner in tir.range(0, 32):
                             for y_inner in tir.range(0, 32):
-                                C[((((x_outer*32768) + (x_inner*1024)) + (y_outer*32)) + y_inner)] = tir.load("float32", C_global, ((x_inner*32) + y_inner))
-                if (tir.TVMBackendFreeWorkspace(1, dev_id, C_global, dtype="int32") != 0):
+                                C[
+                                    (
+                                        (((x_outer * 32768) + (x_inner * 1024)) + (y_outer * 32))
+                                        + y_inner
+                                    )
+                                ] = tir.load("float32", C_global, ((x_inner * 32) + y_inner))
+                if tir.TVMBackendFreeWorkspace(1, dev_id, C_global, dtype="int32") != 0:
                     tir.evaluate(tir.tvm_throw_last_error(dtype="int32"))
-        if (tir.TVMBackendFreeWorkspace(1, dev_id, packedB, dtype="int32") != 0):
+        if tir.TVMBackendFreeWorkspace(1, dev_id, packedB, dtype="int32") != 0:
             tir.evaluate(tir.tvm_throw_last_error(dtype="int32"))
 
 
@@ -218,23 +631,49 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) -
     threadIdx_y = tir.var("int32")
     threadIdx_z = tir.var("int32")
     # buffer definition
-    Apad_shared = tir.buffer_decl([16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1)
-    Apad_shared_wmma_matrix_a = tir.buffer_decl([16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1)
-    BA = tir.buffer_decl([16, 16], dtype="float16", scope="wmma.matrix_a", align=32, offset_factor=256)
-    BB = tir.buffer_decl([16, 16], dtype="float16", scope="wmma.matrix_b", align=32, offset_factor=256)
+    Apad_shared = tir.buffer_decl(
+        [16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1
+    )
+    Apad_shared_wmma_matrix_a = tir.buffer_decl(
+        [16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1
+    )
+    BA = tir.buffer_decl(
+        [16, 16], dtype="float16", scope="wmma.matrix_a", align=32, offset_factor=256
+    )
+    BB = tir.buffer_decl(
+        [16, 16], dtype="float16", scope="wmma.matrix_b", align=32, offset_factor=256
+    )
     BC = tir.buffer_decl([16, 16], scope="wmma.accumulator", align=32, offset_factor=256)
-    Conv_wmma_accumulator = tir.buffer_decl([16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1)
-    W_shared = tir.buffer_decl([3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1)
-    W_shared_wmma_matrix_b = tir.buffer_decl([3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1)
+    Conv_wmma_accumulator = tir.buffer_decl(
+        [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1
+    )
+    W_shared = tir.buffer_decl(
+        [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1
+    )
+    W_shared_wmma_matrix_b = tir.buffer_decl(
+        [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1
+    )
     buffer = tir.buffer_decl([16, 16], dtype="float16", scope="shared", align=32, offset_factor=256)
-    buffer_1 = tir.buffer_decl([16, 16], dtype="float16", scope="wmma.matrix_a", align=32, offset_factor=256)
-    buffer_2 = tir.buffer_decl([16, 16], dtype="float16", scope="shared", align=32, offset_factor=256)
-    buffer_3 = tir.buffer_decl([16, 16], dtype="float16", scope="wmma.matrix_b", align=32, offset_factor=256)
+    buffer_1 = tir.buffer_decl(
+        [16, 16], dtype="float16", scope="wmma.matrix_a", align=32, offset_factor=256
+    )
+    buffer_2 = tir.buffer_decl(
+        [16, 16], dtype="float16", scope="shared", align=32, offset_factor=256
+    )
+    buffer_3 = tir.buffer_decl(
+        [16, 16], dtype="float16", scope="wmma.matrix_b", align=32, offset_factor=256
+    )
     buffer_4 = tir.buffer_decl([16, 16], scope="wmma.accumulator", align=32, offset_factor=256)
     buffer_5 = tir.buffer_decl([16, 16], align=32, offset_factor=256)
-    A_1 = tir.buffer_bind(A, [16, 14, 14, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1)
-    W_1 = tir.buffer_bind(W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1)
-    Conv_1 = tir.buffer_bind(Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1)
+    A_1 = tir.buffer_bind(
+        A, [16, 14, 14, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1
+    )
+    W_1 = tir.buffer_bind(
+        W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1
+    )
+    Conv_1 = tir.buffer_bind(
+        Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1
+    )
     # body
     tir.attr(Conv_1, "realize_scope", "")
     tir.realize(Conv_1[0:16, 0:14, 0:14, 0:32, 0:16, 0:16])
@@ -244,52 +683,414 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) -
     tir.attr(tir.iter_var(threadIdx_y, None, "ThreadIndex", "threadIdx.y"), "thread_extent", 4)
     tir.attr(tir.iter_var(threadIdx_z, None, "ThreadIndex", "threadIdx.z"), "thread_extent", 2)
     tir.attr(Conv_wmma_accumulator, "realize_scope", "wmma.accumulator")
-    tir.realize(Conv_wmma_accumulator[((blockIdx_x*8) + (threadIdx_y*2)):(((blockIdx_x*8) + (threadIdx_y*2)) + 2), tir.floordiv(blockIdx_z, 14):(tir.floordiv(blockIdx_z, 14) + 1), tir.floormod(blockIdx_z, 14):(tir.floormod(blockIdx_z, 14) + 1), ((blockIdx_y*8) + (threadIdx_z*4)):(((blockIdx_y*8) + (threadIdx_z*4)) + 4), 0:16, 0:16])
+    tir.realize(
+        Conv_wmma_accumulator[
+            ((blockIdx_x * 8) + (threadIdx_y * 2)) : (((blockIdx_x * 8) + (threadIdx_y * 2)) + 2),
+            tir.floordiv(blockIdx_z, 14) : (tir.floordiv(blockIdx_z, 14) + 1),
+            tir.floormod(blockIdx_z, 14) : (tir.floormod(blockIdx_z, 14) + 1),
+            ((blockIdx_y * 8) + (threadIdx_z * 4)) : (((blockIdx_y * 8) + (threadIdx_z * 4)) + 4),
+            0:16,
+            0:16,
+        ]
+    )
     for n_c_init in tir.range(0, 2):
         for o_c_init in tir.range(0, 4):
-            tir.attr([BC, Conv_wmma_accumulator], "buffer_bind_scope", tir.tvm_tuple((n_c_init + ((blockIdx_x*8) + (threadIdx_y*2))), 1, tir.floordiv(blockIdx_z, 14), 1, tir.floormod(blockIdx_z, 14), 1, (o_c_init + ((blockIdx_y*8) + (threadIdx_z*4))), 1, 0, 16, 0, 16, dtype="handle"))
-            tir.evaluate(tir.tvm_fill_fragment(BC.data, 16, 16, 16, tir.floordiv(BC.elem_offset, 256), tir.float32(0), dtype="handle"))
+            tir.attr(
+                [BC, Conv_wmma_accumulator],
+                "buffer_bind_scope",
+                tir.tvm_tuple(
+                    (n_c_init + ((blockIdx_x * 8) + (threadIdx_y * 2))),
+                    1,
+                    tir.floordiv(blockIdx_z, 14),
+                    1,
+                    tir.floormod(blockIdx_z, 14),
+                    1,
+                    (o_c_init + ((blockIdx_y * 8) + (threadIdx_z * 4))),
+                    1,
+                    0,
+                    16,
+                    0,
+                    16,
+                    dtype="handle",
+                ),
+            )
+            tir.evaluate(
+                tir.tvm_fill_fragment(
+                    BC.data,
+                    16,
+                    16,
+                    16,
+                    tir.floordiv(BC.elem_offset, 256),
+                    tir.float32(0),
+                    dtype="handle",
+                )
+            )
     for ic_outer in tir.range(0, 8):
         for kh in tir.range(0, 3):
             tir.attr(Apad_shared, "realize_scope", "shared")
-            tir.realize(Apad_shared[(blockIdx_x*8):((blockIdx_x*8) + 8), (tir.floordiv(blockIdx_z, 14) + kh):((tir.floordiv(blockIdx_z, 14) + kh) + 1), tir.floormod(blockIdx_z, 14):(tir.floormod(blockIdx_z, 14) + 3), (ic_outer*2):((ic_outer*2) + 2), 0:16, 0:16])
+            tir.realize(
+                Apad_shared[
+                    (blockIdx_x * 8) : ((blockIdx_x * 8) + 8),
+                    (tir.floordiv(blockIdx_z, 14) + kh) : ((tir.floordiv(blockIdx_z, 14) + kh) + 1),
+                    tir.floormod(blockIdx_z, 14) : (tir.floormod(blockIdx_z, 14) + 3),
+                    (ic_outer * 2) : ((ic_outer * 2) + 2),
+                    0:16,
+                    0:16,
+                ]
+            )
             for ax2 in tir.range(0, 3):
                 for ax3 in tir.range(0, 2):
                     for ax4_ax5_fused_outer in tir.range(0, 8):
-                        tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32)
-                        Apad_shared[((threadIdx_z + (threadIdx_y*2)) + (blockIdx_x*8)), (tir.floordiv(blockIdx_z, 14) + kh), (ax2 + tir.floormod(blockIdx_z, 14)), (ax3 + (ic_outer*2)), tir.floordiv((threadIdx_x + (ax4_ax5_fused_outer*32)), 16), tir.floormod((threadIdx_x + (ax4_ax5_fused_outer*32)), 16)] = tir.if_then_else((((((tir.floordiv(blockIdx_z, 14) + kh) >= 1) and (((tir.floordiv(blockIdx_z, 14) + kh) - 1) < 14)) and ((ax2 + tir.floormod(blockIdx_z, 14)) >= 1)) and (((ax2 + tir.floormod(blockIdx_z, 14)) - 1) < 14)), A_1[((threadIdx_z + (threadIdx_y*2)) + (blockIdx_x*8)), ((tir.floordiv(blockIdx_z, 14) + kh) - 1), ((ax2 + tir.floormod(blockIdx_z, 14)) - 1), (ax3 + (ic_outer*2)), tir.floordiv((threadIdx_x + (ax4_ax5_fused_outer*32)), 16), tir.floormod((threadIdx_x + (ax4_ax5_fused_outer*32)), 16)], tir.float16(0), dtype="float16")
+                        tir.attr(
+                            tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"),
+                            "thread_extent",
+                            32,
+                        )
+                        Apad_shared[
+                            ((threadIdx_z + (threadIdx_y * 2)) + (blockIdx_x * 8)),
+                            (tir.floordiv(blockIdx_z, 14) + kh),
+                            (ax2 + tir.floormod(blockIdx_z, 14)),
+                            (ax3 + (ic_outer * 2)),
+                            tir.floordiv((threadIdx_x + (ax4_ax5_fused_outer * 32)), 16),
+                            tir.floormod((threadIdx_x + (ax4_ax5_fused_outer * 32)), 16),
+                        ] = tir.if_then_else(
+                            (
+                                (
+                                    (
+                                        ((tir.floordiv(blockIdx_z, 14) + kh) >= 1)
+                                        and (((tir.floordiv(blockIdx_z, 14) + kh) - 1) < 14)
+                                    )
+                                    and ((ax2 + tir.floormod(blockIdx_z, 14)) >= 1)
+                                )
+                                and (((ax2 + tir.floormod(blockIdx_z, 14)) - 1) < 14)
+                            ),
+                            A_1[
+                                ((threadIdx_z + (threadIdx_y * 2)) + (blockIdx_x * 8)),
+                                ((tir.floordiv(blockIdx_z, 14) + kh) - 1),
+                                ((ax2 + tir.floormod(blockIdx_z, 14)) - 1),
+                                (ax3 + (ic_outer * 2)),
+                                tir.floordiv((threadIdx_x + (ax4_ax5_fused_outer * 32)), 16),
+                                tir.floormod((threadIdx_x + (ax4_ax5_fused_outer * 32)), 16),
+                            ],
+                            tir.float16(0),
+                            dtype="float16",
+                        )
             tir.attr(W_shared, "realize_scope", "shared")
-            tir.realize(W_shared[kh:(kh + 1), 0:3, (ic_outer*2):((ic_outer*2) + 2), (blockIdx_y*8):((blockIdx_y*8) + 8), 0:16, 0:16])
+            tir.realize(
+                W_shared[
+                    kh : (kh + 1),
+                    0:3,
+                    (ic_outer * 2) : ((ic_outer * 2) + 2),
+                    (blockIdx_y * 8) : ((blockIdx_y * 8) + 8),
+                    0:16,
+                    0:16,
+                ]
+            )
             for ax1 in tir.range(0, 3):
                 for ax2_1 in tir.range(0, 2):
-                    tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32)
+                    tir.attr(
+                        tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"),
+                        "thread_extent",
+                        32,
+                    )
                     for ax4_ax5_fused_inner in tir.range(0, 8, "vectorized"):
-                        W_shared[kh, ax1, (ax2_1 + (ic_outer*2)), ((threadIdx_z + (threadIdx_y*2)) + (blockIdx_y*8)), tir.floordiv((ax4_ax5_fused_inner + (threadIdx_x*8)), 16), tir.floormod((ax4_ax5_fused_inner + (threadIdx_x*8)), 16)] = W_1[kh, ax1, (ax2_1 + (ic_outer*2)), ((threadIdx_z + (threadIdx_y*2)) + (blockIdx_y*8)), tir.floordiv((ax4_ax5_fused_inner + (threadIdx_x*8)), 16), tir.floormod((ax4_ax5_fused_inner + (threadIdx_x*8)), 16)]
+                        W_shared[
+                            kh,
+                            ax1,
+                            (ax2_1 + (ic_outer * 2)),
+                            ((threadIdx_z + (threadIdx_y * 2)) + (blockIdx_y * 8)),
+                            tir.floordiv((ax4_ax5_fused_inner + (threadIdx_x * 8)), 16),
+                            tir.floormod((ax4_ax5_fused_inner + (threadIdx_x * 8)), 16),
+                        ] = W_1[
+                            kh,
+                            ax1,
+                            (ax2_1 + (ic_outer * 2)),
+                            ((threadIdx_z + (threadIdx_y * 2)) + (blockIdx_y * 8)),
+                            tir.floordiv((ax4_ax5_fused_inner + (threadIdx_x * 8)), 16),
+                            tir.floormod((ax4_ax5_fused_inner + (threadIdx_x * 8)), 16),
+                        ]
             for ic_inner in tir.range(0, 2):
                 for kw in tir.range(0, 3):
                     tir.attr(Apad_shared_wmma_matrix_a, "realize_scope", "wmma.matrix_a")
-                    tir.realize(Apad_shared_wmma_matrix_a[((blockIdx_x*8) + (threadIdx_y*2)):(((blockIdx_x*8) + (threadIdx_y*2)) + 2), (tir.floordiv(blockIdx_z, 14) + kh):((tir.floordiv(blockIdx_z, 14) + kh) + 1), (kw + tir.floormod(blockIdx_z, 14)):((kw + tir.floormod(blockIdx_z, 14)) + 1), ((ic_outer*2) + ic_inner):(((ic_outer*2) + ic_inner) + 1), 0:16, 0:16])
+                    tir.realize(
+                        Apad_shared_wmma_matrix_a[
+                            ((blockIdx_x * 8) + (threadIdx_y * 2)) : (
+                                ((blockIdx_x * 8) + (threadIdx_y * 2)) + 2
+                            ),
+                            (tir.floordiv(blockIdx_z, 14) + kh) : (
+                                (tir.floordiv(blockIdx_z, 14) + kh) + 1
+                            ),
+                            (kw + tir.floormod(blockIdx_z, 14)) : (
+                                (kw + tir.floormod(blockIdx_z, 14)) + 1
+                            ),
+                            ((ic_outer * 2) + ic_inner) : (((ic_outer * 2) + ic_inner) + 1),
+                            0:16,
+                            0:16,
+                        ]
+                    )
                     for ax0 in tir.range(0, 2):
-                        tir.attr([buffer, Apad_shared], "buffer_bind_scope", tir.tvm_tuple((ax0 + ((blockIdx_x*8) + (threadIdx_y*2))), 1, (tir.floordiv(blockIdx_z, 14) + kh), 1, (kw + tir.floormod(blockIdx_z, 14)), 1, ((ic_outer*2) + ic_inner), 1, 0, 16, 0, 16, dtype="handle"))
-                        tir.attr([buffer_1, Apad_shared_wmma_matrix_a], "buffer_bind_scope", tir.tvm_tuple((ax0 + ((blockIdx_x*8) + (threadIdx_y*2))), 1, (tir.floordiv(blockIdx_z, 14) + kh), 1, (kw + tir.floormod(blockIdx_z, 14)), 1, ((ic_outer*2) + ic_inner), 1, 0, 16, 0, 16, dtype="handle"))
-                        tir.evaluate(tir.tvm_load_matrix_sync(buffer_1.data, 16, 16, 16, tir.floordiv(buffer_1.elem_offset, 256), tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), buffer.data, buffer.elem_offset, 256, 1, dtype="handle"), 16, "row_major", dtype="handle"))
+                        tir.attr(
+                            [buffer, Apad_shared],
+                            "buffer_bind_scope",
+                            tir.tvm_tuple(
+                                (ax0 + ((blockIdx_x * 8) + (threadIdx_y * 2))),
+                                1,
+                                (tir.floordiv(blockIdx_z, 14) + kh),
+                                1,
+                                (kw + tir.floormod(blockIdx_z, 14)),
+                                1,
+                                ((ic_outer * 2) + ic_inner),
+                                1,
+                                0,
+                                16,
+                                0,
+                                16,
+                                dtype="handle",
+                            ),
+                        )
+                        tir.attr(
+                            [buffer_1, Apad_shared_wmma_matrix_a],
+                            "buffer_bind_scope",
+                            tir.tvm_tuple(
+                                (ax0 + ((blockIdx_x * 8) + (threadIdx_y * 2))),
+                                1,
+                                (tir.floordiv(blockIdx_z, 14) + kh),
+                                1,
+                                (kw + tir.floormod(blockIdx_z, 14)),
+                                1,
+                                ((ic_outer * 2) + ic_inner),
+                                1,
+                                0,
+                                16,
+                                0,
+                                16,
+                                dtype="handle",
+                            ),
+                        )
+                        tir.evaluate(
+                            tir.tvm_load_matrix_sync(
+                                buffer_1.data,
+                                16,
+                                16,
+                                16,
+                                tir.floordiv(buffer_1.elem_offset, 256),
+                                tir.tvm_access_ptr(
+                                    tir.type_annotation(dtype="float16"),
+                                    buffer.data,
+                                    buffer.elem_offset,
+                                    256,
+                                    1,
+                                    dtype="handle",
+                                ),
+                                16,
+                                "row_major",
+                                dtype="handle",
+                            )
+                        )
                     tir.attr(W_shared_wmma_matrix_b, "realize_scope", "wmma.matrix_b")
-                    tir.realize(W_shared_wmma_matrix_b[kh:(kh + 1), kw:(kw + 1), ((ic_outer*2) + ic_inner):(((ic_outer*2) + ic_inner) + 1), ((blockIdx_y*8) + (threadIdx_z*4)):(((blockIdx_y*8) + (threadIdx_z*4)) + 4), 0:16, 0:16])
+                    tir.realize(
+                        W_shared_wmma_matrix_b[
+                            kh : (kh + 1),
+                            kw : (kw + 1),
+                            ((ic_outer * 2) + ic_inner) : (((ic_outer * 2) + ic_inner) + 1),
+                            ((blockIdx_y * 8) + (threadIdx_z * 4)) : (
+                                ((blockIdx_y * 8) + (threadIdx_z * 4)) + 4
+                            ),
+                            0:16,
+                            0:16,
+                        ]
+                    )
                     for ax3_1 in tir.range(0, 4):
-                        tir.attr([buffer_2, W_shared], "buffer_bind_scope", tir.tvm_tuple(kh, 1, kw, 1, ((ic_outer*2) + ic_inner), 1, (ax3_1 + ((blockIdx_y*8) + (threadIdx_z*4))), 1, 0, 16, 0, 16, dtype="handle"))
-                        tir.attr([buffer_3, W_shared_wmma_matrix_b], "buffer_bind_scope", tir.tvm_tuple(kh, 1, kw, 1, ((ic_outer*2) + ic_inner), 1, (ax3_1 + ((blockIdx_y*8) + (threadIdx_z*4))), 1, 0, 16, 0, 16, dtype="handle"))
-                        tir.evaluate(tir.tvm_load_matrix_sync(buffer_3.data, 16, 16, 16, tir.floordiv(buffer_3.elem_offset, 256), tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), buffer_2.data, buffer_2.elem_offset, 256, 1, dtype="handle"), 16, "row_major", dtype="handle"))
+                        tir.attr(
+                            [buffer_2, W_shared],
+                            "buffer_bind_scope",
+                            tir.tvm_tuple(
+                                kh,
+                                1,
+                                kw,
+                                1,
+                                ((ic_outer * 2) + ic_inner),
+                                1,
+                                (ax3_1 + ((blockIdx_y * 8) + (threadIdx_z * 4))),
+                                1,
+                                0,
+                                16,
+                                0,
+                                16,
+                                dtype="handle",
+                            ),
+                        )
+                        tir.attr(
+                            [buffer_3, W_shared_wmma_matrix_b],
+                            "buffer_bind_scope",
+                            tir.tvm_tuple(
+                                kh,
+                                1,
+                                kw,
+                                1,
+                                ((ic_outer * 2) + ic_inner),
+                                1,
+                                (ax3_1 + ((blockIdx_y * 8) + (threadIdx_z * 4))),
+                                1,
+                                0,
+                                16,
+                                0,
+                                16,
+                                dtype="handle",
+                            ),
+                        )
+                        tir.evaluate(
+                            tir.tvm_load_matrix_sync(
+                                buffer_3.data,
+                                16,
+                                16,
+                                16,
+                                tir.floordiv(buffer_3.elem_offset, 256),
+                                tir.tvm_access_ptr(
+                                    tir.type_annotation(dtype="float16"),
+                                    buffer_2.data,
+                                    buffer_2.elem_offset,
+                                    256,
+                                    1,
+                                    dtype="handle",
+                                ),
+                                16,
+                                "row_major",
+                                dtype="handle",
+                            )
+                        )
                     for n_c in tir.range(0, 2):
                         for o_c in tir.range(0, 4):
-                            tir.attr([BA, Apad_shared_wmma_matrix_a], "buffer_bind_scope", tir.tvm_tuple((n_c + ((blockIdx_x*8) + (threadIdx_y*2))), 1, (tir.floordiv(blockIdx_z, 14) + kh), 1, (tir.floormod(blockIdx_z, 14) + kw), 1, ((ic_outer*2) + ic_inner), 1, 0, 16, 0, 16, dtype="handle"))
-                            tir.attr([BB, W_shared_wmma_matrix_b], "buffer_bind_scope", tir.tvm_tuple(kh, 1, kw, 1, ((ic_outer*2) + ic_inner), 1, (o_c + ((blockIdx_y*8) + (threadIdx_z*4))), 1, 0, 16, 0, 16, dtype="handle"))
-                            tir.attr([BC, Conv_wmma_accumulator], "buffer_bind_scope", tir.tvm_tuple((n_c + ((blockIdx_x*8) + (threadIdx_y*2))), 1, tir.floordiv(blockIdx_z, 14), 1, tir.floormod(blockIdx_z, 14), 1, (o_c + ((blockIdx_y*8) + (threadIdx_z*4))), 1, 0, 16, 0, 16, dtype="handle"))
-                            tir.evaluate(tir.tvm_mma_sync(BC.data, tir.floordiv(BC.elem_offset, 256), BA.data, tir.floordiv(BA.elem_offset, 256), BB.data, tir.floordiv(BB.elem_offset, 256), BC.data, tir.floordiv(BC.elem_offset, 256), dtype="handle"))
+                            tir.attr(
+                                [BA, Apad_shared_wmma_matrix_a],
+                                "buffer_bind_scope",
+                                tir.tvm_tuple(
+                                    (n_c + ((blockIdx_x * 8) + (threadIdx_y * 2))),
+                                    1,
+                                    (tir.floordiv(blockIdx_z, 14) + kh),
+                                    1,
+                                    (tir.floormod(blockIdx_z, 14) + kw),
+                                    1,
+                                    ((ic_outer * 2) + ic_inner),
+                                    1,
+                                    0,
+                                    16,
+                                    0,
+                                    16,
+                                    dtype="handle",
+                                ),
+                            )
+                            tir.attr(
+                                [BB, W_shared_wmma_matrix_b],
+                                "buffer_bind_scope",
+                                tir.tvm_tuple(
+                                    kh,
+                                    1,
+                                    kw,
+                                    1,
+                                    ((ic_outer * 2) + ic_inner),
+                                    1,
+                                    (o_c + ((blockIdx_y * 8) + (threadIdx_z * 4))),
+                                    1,
+                                    0,
+                                    16,
+                                    0,
+                                    16,
+                                    dtype="handle",
+                                ),
+                            )
+                            tir.attr(
+                                [BC, Conv_wmma_accumulator],
+                                "buffer_bind_scope",
+                                tir.tvm_tuple(
+                                    (n_c + ((blockIdx_x * 8) + (threadIdx_y * 2))),
+                                    1,
+                                    tir.floordiv(blockIdx_z, 14),
+                                    1,
+                                    tir.floormod(blockIdx_z, 14),
+                                    1,
+                                    (o_c + ((blockIdx_y * 8) + (threadIdx_z * 4))),
+                                    1,
+                                    0,
+                                    16,
+                                    0,
+                                    16,
+                                    dtype="handle",
+                                ),
+                            )
+                            tir.evaluate(
+                                tir.tvm_mma_sync(
+                                    BC.data,
+                                    tir.floordiv(BC.elem_offset, 256),
+                                    BA.data,
+                                    tir.floordiv(BA.elem_offset, 256),
+                                    BB.data,
+                                    tir.floordiv(BB.elem_offset, 256),
+                                    BC.data,
+                                    tir.floordiv(BC.elem_offset, 256),
+                                    dtype="handle",
+                                )
+                            )
     for n_inner in tir.range(0, 2):
         for o_inner in tir.range(0, 4):
-            tir.attr([buffer_4, Conv_wmma_accumulator], "buffer_bind_scope", tir.tvm_tuple(((((blockIdx_x*4) + threadIdx_y)*2) + n_inner), 1, tir.floordiv(blockIdx_z, 14), 1, tir.floormod(blockIdx_z, 14), 1, ((((blockIdx_y*2) + threadIdx_z)*4) + o_inner), 1, 0, 16, 0, 16, dtype="handle"))
-            tir.attr([buffer_5, Conv_1], "buffer_bind_scope", tir.tvm_tuple(((((blockIdx_x*4) + threadIdx_y)*2) + n_inner), 1, tir.floordiv(blockIdx_z, 14), 1, tir.floormod(blockIdx_z, 14), 1, ((((blockIdx_y*2) + threadIdx_z)*4) + o_inner), 1, 0, 16, 0, 16, dtype="handle"))
-            tir.evaluate(tir.tvm_store_matrix_sync(buffer_4.data, 16, 16, 16, tir.floordiv(buffer_4.elem_offset, 256), tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), buffer_5.data, buffer_5.elem_offset, 256, 2, dtype="handle"), 16, "row_major", dtype="handle"))
+            tir.attr(
+                [buffer_4, Conv_wmma_accumulator],
+                "buffer_bind_scope",
+                tir.tvm_tuple(
+                    ((((blockIdx_x * 4) + threadIdx_y) * 2) + n_inner),
+                    1,
+                    tir.floordiv(blockIdx_z, 14),
+                    1,
+                    tir.floormod(blockIdx_z, 14),
+                    1,
+                    ((((blockIdx_y * 2) + threadIdx_z) * 4) + o_inner),
+                    1,
+                    0,
+                    16,
+                    0,
+                    16,
+                    dtype="handle",
+                ),
+            )
+            tir.attr(
+                [buffer_5, Conv_1],
+                "buffer_bind_scope",
+                tir.tvm_tuple(
+                    ((((blockIdx_x * 4) + threadIdx_y) * 2) + n_inner),
+                    1,
+                    tir.floordiv(blockIdx_z, 14),
+                    1,
+                    tir.floormod(blockIdx_z, 14),
+                    1,
+                    ((((blockIdx_y * 2) + threadIdx_z) * 4) + o_inner),
+                    1,
+                    0,
+                    16,
+                    0,
+                    16,
+                    dtype="handle",
+                ),
+            )
+            tir.evaluate(
+                tir.tvm_store_matrix_sync(
+                    buffer_4.data,
+                    16,
+                    16,
+                    16,
+                    tir.floordiv(buffer_4.elem_offset, 256),
+                    tir.tvm_access_ptr(
+                        tir.type_annotation(dtype="float32"),
+                        buffer_5.data,
+                        buffer_5.elem_offset,
+                        256,
+                        2,
+                        dtype="handle",
+                    ),
+                    16,
+                    "row_major",
+                    dtype="handle",
+                )
+            )
 
 
 def test_opt_conv_tensorcore_normalize():
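Since the commit title notes the conversion lands without a CI check, a natural follow-up guard is to re-run black over files like this one and fail if anything would change. A minimal sketch, assuming black is installed and configured as in the added pyproject.toml (line length 100 is an assumption); the test-file path below is hypothetical, not taken from this diff.

    from pathlib import Path

    import black

    mode = black.Mode(line_length=100)  # assumed to mirror the committed pyproject.toml
    path = Path("tests/python/unittest/test_hybrid_roundtrip.py")  # hypothetical file name
    source = path.read_text()
    # If format_str returns the input unchanged, the file already complies with the
    # formatting applied in this commit; otherwise a CI job could fail the build.
    print("clean" if black.format_str(source, mode=mode) == source else "needs reformatting")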
@@ -314,9 +1115,15 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No
     threadIdx_x = tir.var("int32")
     threadIdx_y = tir.var("int32")
     threadIdx_z = tir.var("int32")
-    A_1 = tir.buffer_bind(A, [16, 14, 14, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1)
-    W_1 = tir.buffer_bind(W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1)
-    Conv_1 = tir.buffer_bind(Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1)
+    A_1 = tir.buffer_bind(
+        A, [16, 14, 14, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1
+    )
+    W_1 = tir.buffer_bind(
+        W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1
+    )
+    Conv_1 = tir.buffer_bind(
+        Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1
+    )
     # body
     tir.attr(tir.iter_var(blockIdx_z, None, "ThreadIndex", "blockIdx.z"), "thread_extent", 196)
     tir.attr(Conv_wmma_accumulator, "storage_scope", "wmma.accumulator")
@@ -333,85 +1140,1616 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No
     tir.attr(tir.iter_var(blockIdx_y, None, "ThreadIndex", "blockIdx.y"), "thread_extent", 4)
     tir.attr(tir.iter_var(threadIdx_y, None, "ThreadIndex", "threadIdx.y"), "thread_extent", 4)
     tir.attr(tir.iter_var(threadIdx_z, None, "ThreadIndex", "threadIdx.z"), "thread_extent", 2)
-    tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 0, tir.float32(0), dtype="handle"))
-    tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 1, tir.float32(0), dtype="handle"))
-    tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 2, tir.float32(0), dtype="handle"))
-    tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 3, tir.float32(0), dtype="handle"))
-    tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 4, tir.float32(0), dtype="handle"))
-    tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 5, tir.float32(0), dtype="handle"))
-    tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 6, tir.float32(0), dtype="handle"))
-    tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 7, tir.float32(0), dtype="handle"))
+    tir.evaluate(
+        tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 0, tir.float32(0), dtype="handle")
+    )
+    tir.evaluate(
+        tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 1, tir.float32(0), dtype="handle")
+    )
+    tir.evaluate(
+        tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 2, tir.float32(0), dtype="handle")
+    )
+    tir.evaluate(
+        tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 3, tir.float32(0), dtype="handle")
+    )
+    tir.evaluate(
+        tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 4, tir.float32(0), dtype="handle")
+    )
+    tir.evaluate(
+        tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 5, tir.float32(0), dtype="handle")
+    )
+    tir.evaluate(
+        tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 6, tir.float32(0), dtype="handle")
+    )
+    tir.evaluate(
+        tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 7, tir.float32(0), dtype="handle")
+    )
     for ic_outer in tir.range(0, 8):
         for kh in tir.range(0, 3):
             for ax2 in tir.range(0, 3):
-                with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                    Apad_shared[((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61440)), tir.float16(0), dtype="float16")
-                with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                    Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 32)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61408)), tir.float16(0), dtype="float16")
-                with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                    Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 64)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61376)), tir.float16(0), dtype="float16")
-                with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                    Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 96)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61344)), tir.float16(0), dtype="float16")
-                with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                    Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 128)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61312)), tir.float16(0), dtype="float16")
-                with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                    Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 160)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61280)), tir.float16(0), dtype="float16")
-                with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                    Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 192)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61248)), tir.float16(0), dtype="float16")
-                with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                    Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 224)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61216)), tir.float16(0), dtype="float16")
-                with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                    Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 256)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61184)), tir.float16(0), dtype="float16")
-                with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                    Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 288)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61152)), tir.float16(0), dtype="float16")
-                with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                    Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 320)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61120)), tir.float16(0), dtype="float16")
-                with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                    Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 352)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61088)), tir.float16(0), dtype="float16")
-                with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                    Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 384)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61056)), tir.float16(0), dtype="float16")
-                with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                    Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 416)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61024)), tir.float16(0), dtype="float16")
-                with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                    Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 448)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 60992)), tir.float16(0), dtype="float16")
-                tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32)
-                Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 480)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 60960)), tir.float16(0), dtype="float16")
-            with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                tir.store(W_shared, tir.ramp((((threadIdx_y*512) + (threadIdx_z*256)) + (threadIdx_x*8)), 1, 8), tir.load("float16x8", W_1.data, tir.ramp(((((((kh*393216) + (ic_outer*16384)) + (blockIdx_y*2048)) + (threadIdx_y*512)) + (threadIdx_z*256)) + (threadIdx_x*8)), 1, 8), tir.broadcast(True, 8)), tir.broadcast(True, 8))
-            with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                tir.store(W_shared, tir.ramp(((((threadIdx_y*512) + (threadIdx_z*256)) + (threadIdx_x*8)) + 2048), 1, 8), tir.load("float16x8", W_1.data, tir.ramp((((((((kh*393216) + (ic_outer*16384)) + (blockIdx_y*2048)) + (threadIdx_y*512)) + (threadIdx_z*256)) + (threadIdx_x*8)) + 8192), 1, 8), tir.broadcast(True, 8)), tir.broadcast(True, 8))
-            with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                tir.store(W_shared, tir.ramp(((((threadIdx_y*512) + (threadIdx_z*256)) + (threadIdx_x*8)) + 4096), 1, 8), tir.load("float16x8", W_1.data, tir.ramp((((((((kh*393216) + (ic_outer*16384)) + (blockIdx_y*2048)) + (threadIdx_y*512)) + (threadIdx_z*256)) + (threadIdx_x*8)) + 131072), 1, 8), tir.broadcast(True, 8)), tir.broadcast(True, 8))
-            with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                tir.store(W_shared, tir.ramp(((((threadIdx_y*512) + (threadIdx_z*256)) + (threadIdx_x*8)) + 6144), 1, 8), tir.load("float16x8", W_1.data, tir.ramp((((((((kh*393216) + (ic_outer*16384)) + (blockIdx_y*2048)) + (threadIdx_y*512)) + (threadIdx_z*256)) + (threadIdx_x*8)) + 139264), 1, 8), tir.broadcast(True, 8)), tir.broadcast(True, 8))
-            with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                tir.store(W_shared, tir.ramp(((((threadIdx_y*512) + (threadIdx_z*256)) + (threadIdx_x*8)) + 8192), 1, 8), tir.load("float16x8", W_1.data, tir.ramp((((((((kh*393216) + (ic_outer*16384)) + (blockIdx_y*2048)) + (threadIdx_y*512)) + (threadIdx_z*256)) + (threadIdx_x*8)) + 262144), 1, 8), tir.broadcast(True, 8)), tir.broadcast(True, 8))
-            with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32):
-                tir.store(W_shared, tir.ramp(((((threadIdx_y*512) + (threadIdx_z*256)) + (threadIdx_x*8)) + 10240), 1, 8), tir.load("float16x8", W_1.data, tir.ramp((((((((kh*393216) + (ic_outer*16384)) + (blockIdx_y*2048)) + (threadIdx_y*512)) + (threadIdx_z*256)) + (threadIdx_x*8)) + 270336), 1, 8), tir.broadcast(True, 8)), tir.broadcast(True, 8))
+                with tir.attr(
+                    tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"),
+                    "thread_extent",
+                    32,
+                ):
+                    Apad_shared[
+                        (
+                            (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512))
+                            + threadIdx_x
+                        )
+                    ] = tir.if_then_else(
+                        (
+                            (
+                                (
+                                    (1 <= (tir.floordiv(blockIdx_z, 14) + kh))
+                                    and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)
+                                )
+                                and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))
+                            )
+                            and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)
+                        ),
+                        tir.load(
+                            "float16",
+                            A_1.data,
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                (
+                                                    (
+                                                        (
+                                                            (blockIdx_x * 6422528)
+                                                            + (threadIdx_y * 1605632)
+                                                        )
+                                                        + (threadIdx_z * 802816)
+                                                    )
+                                                    + (kh * 57344)
+                                                )
+                                                + (blockIdx_z * 4096)
+                                            )
+                                            + (ax2 * 4096)
+                                        )
+                                        + (ic_outer * 512)
+                                    )
+                                    + threadIdx_x
+                                )
+                                - 61440
+                            ),
+                        ),
+                        tir.float16(0),
+                        dtype="float16",
+                    )
+                with tir.attr(
+                    tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"),
+                    "thread_extent",
+                    32,
+                ):
+                    Apad_shared[
+                        (
+                            (
+                                (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512))
+                                + threadIdx_x
+                            )
+                            + 32
+                        )
+                    ] = tir.if_then_else(
+                        (
+                            (
+                                (
+                                    (1 <= (tir.floordiv(blockIdx_z, 14) + kh))
+                                    and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)
+                                )
+                                and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))
+                            )
+                            and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)
+                        ),
+                        tir.load(
+                            "float16",
+                            A_1.data,
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                (
+                                                    (
+                                                        (
+                                                            (blockIdx_x * 6422528)
+                                                            + (threadIdx_y * 1605632)
+                                                        )
+                                                        + (threadIdx_z * 802816)
+                                                    )
+                                                    + (kh * 57344)
+                                                )
+                                                + (blockIdx_z * 4096)
+                                            )
+                                            + (ax2 * 4096)
+                                        )
+                                        + (ic_outer * 512)
+                                    )
+                                    + threadIdx_x
+                                )
+                                - 61408
+                            ),
+                        ),
+                        tir.float16(0),
+                        dtype="float16",
+                    )
+                with tir.attr(
+                    tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"),
+                    "thread_extent",
+                    32,
+                ):
+                    Apad_shared[
+                        (
+                            (
+                                (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512))
+                                + threadIdx_x
+                            )
+                            + 64
+                        )
+                    ] = tir.if_then_else(
+                        (
+                            (
+                                (
+                                    (1 <= (tir.floordiv(blockIdx_z, 14) + kh))
+                                    and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)
+                                )
+                                and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))
+                            )
+                            and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)
+                        ),
+                        tir.load(
+                            "float16",
+                            A_1.data,
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                (
+                                                    (
+                                                        (
+                                                            (blockIdx_x * 6422528)
+                                                            + (threadIdx_y * 1605632)
+                                                        )
+                                                        + (threadIdx_z * 802816)
+                                                    )
+                                                    + (kh * 57344)
+                                                )
+                                                + (blockIdx_z * 4096)
+                                            )
+                                            + (ax2 * 4096)
+                                        )
+                                        + (ic_outer * 512)
+                                    )
+                                    + threadIdx_x
+                                )
+                                - 61376
+                            ),
+                        ),
+                        tir.float16(0),
+                        dtype="float16",
+                    )
+                with tir.attr(
+                    tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"),
+                    "thread_extent",
+                    32,
+                ):
+                    Apad_shared[
+                        (
+                            (
+                                (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512))
+                                + threadIdx_x
+                            )
+                            + 96
+                        )
+                    ] = tir.if_then_else(
+                        (
+                            (
+                                (
+                                    (1 <= (tir.floordiv(blockIdx_z, 14) + kh))
+                                    and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)
+                                )
+                                and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))
+                            )
+                            and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)
+                        ),
+                        tir.load(
+                            "float16",
+                            A_1.data,
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                (
+                                                    (
+                                                        (
+                                                            (blockIdx_x * 6422528)
+                                                            + (threadIdx_y * 1605632)
+                                                        )
+                                                        + (threadIdx_z * 802816)
+                                                    )
+                                                    + (kh * 57344)
+                                                )
+                                                + (blockIdx_z * 4096)
+                                            )
+                                            + (ax2 * 4096)
+                                        )
+                                        + (ic_outer * 512)
+                                    )
+                                    + threadIdx_x
+                                )
+                                - 61344
+                            ),
+                        ),
+                        tir.float16(0),
+                        dtype="float16",
+                    )
+                with tir.attr(
+                    tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"),
+                    "thread_extent",
+                    32,
+                ):
+                    Apad_shared[
+                        (
+                            (
+                                (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512))
+                                + threadIdx_x
+                            )
+                            + 128
+                        )
+                    ] = tir.if_then_else(
+                        (
+                            (
+                                (
+                                    (1 <= (tir.floordiv(blockIdx_z, 14) + kh))
+                                    and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)
+                                )
+                                and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))
+                            )
+                            and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)
+                        ),
+                        tir.load(
+                            "float16",
+                            A_1.data,
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                (
+                                                    (
+                                                        (
+                                                            (blockIdx_x * 6422528)
+                                                            + (threadIdx_y * 1605632)
+                                                        )
+                                                        + (threadIdx_z * 802816)
+                                                    )
+                                                    + (kh * 57344)
+                                                )
+                                                + (blockIdx_z * 4096)
+                                            )
+                                            + (ax2 * 4096)
+                                        )
+                                        + (ic_outer * 512)
+                                    )
+                                    + threadIdx_x
+                                )
+                                - 61312
+                            ),
+                        ),
+                        tir.float16(0),
+                        dtype="float16",
+                    )
+                with tir.attr(
+                    tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"),
+                    "thread_extent",
+                    32,
+                ):
+                    Apad_shared[
+                        (
+                            (
+                                (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512))
+                                + threadIdx_x
+                            )
+                            + 160
+                        )
+                    ] = tir.if_then_else(
+                        (
+                            (
+                                (
+                                    (1 <= (tir.floordiv(blockIdx_z, 14) + kh))
+                                    and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)
+                                )
+                                and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))
+                            )
+                            and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)
+                        ),
+                        tir.load(
+                            "float16",
+                            A_1.data,
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                (
+                                                    (
+                                                        (
+                                                            (blockIdx_x * 6422528)
+                                                            + (threadIdx_y * 1605632)
+                                                        )
+                                                        + (threadIdx_z * 802816)
+                                                    )
+                                                    + (kh * 57344)
+                                                )
+                                                + (blockIdx_z * 4096)
+                                            )
+                                            + (ax2 * 4096)
+                                        )
+                                        + (ic_outer * 512)
+                                    )
+                                    + threadIdx_x
+                                )
+                                - 61280
+                            ),
+                        ),
+                        tir.float16(0),
+                        dtype="float16",
+                    )
+                with tir.attr(
+                    tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"),
+                    "thread_extent",
+                    32,
+                ):
+                    Apad_shared[
+                        (
+                            (
+                                (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512))
+                                + threadIdx_x
+                            )
+                            + 192
+                        )
+                    ] = tir.if_then_else(
+                        (
+                            (
+                                (
+                                    (1 <= (tir.floordiv(blockIdx_z, 14) + kh))
+                                    and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)
+                                )
+                                and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))
+                            )
+                            and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)
+                        ),
+                        tir.load(
+                            "float16",
+                            A_1.data,
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                (
+                                                    (
+                                                        (
+                                                            (blockIdx_x * 6422528)
+                                                            + (threadIdx_y * 1605632)
+                                                        )
+                                                        + (threadIdx_z * 802816)
+                                                    )
+                                                    + (kh * 57344)
+                                                )
+                                                + (blockIdx_z * 4096)
+                                            )
+                                            + (ax2 * 4096)
+                                        )
+                                        + (ic_outer * 512)
+                                    )
+                                    + threadIdx_x
+                                )
+                                - 61248
+                            ),
+                        ),
+                        tir.float16(0),
+                        dtype="float16",
+                    )
+                with tir.attr(
+                    tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"),
+                    "thread_extent",
+                    32,
+                ):
+                    Apad_shared[
+                        (
+                            (
+                                (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512))
+                                + threadIdx_x
+                            )
+                            + 224
+                        )
+                    ] = tir.if_then_else(
+                        (
+                            (
+                                (
+                                    (1 <= (tir.floordiv(blockIdx_z, 14) + kh))
+                                    and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)
+                                )
+                                and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))
+                            )
+                            and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)
+                        ),
+                        tir.load(
+                            "float16",
+                            A_1.data,
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                (
+                                                    (
+                                                        (
+                                                            (blockIdx_x * 6422528)
+                                                            + (threadIdx_y * 1605632)
+                                                        )
+                                                        + (threadIdx_z * 802816)
+                                                    )
+                                                    + (kh * 57344)
+                                                )
+                                                + (blockIdx_z * 4096)
+                                            )
+                                            + (ax2 * 4096)
+                                        )
+                                        + (ic_outer * 512)
+                                    )
+                                    + threadIdx_x
+                                )
+                                - 61216
+                            ),
+                        ),
+                        tir.float16(0),
+                        dtype="float16",
+                    )
+                with tir.attr(
+                    tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"),
+                    "thread_extent",
+                    32,
+                ):
+                    Apad_shared[
+                        (
+                            (
+                                (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512))
+                                + threadIdx_x
+                            )
+                            + 256
+                        )
+                    ] = tir.if_then_else(
+                        (
+                            (
+                                (
+                                    (1 <= (tir.floordiv(blockIdx_z, 14) + kh))
+                                    and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)
+                                )
+                                and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))
+                            )
+                            and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)
+                        ),
+                        tir.load(
+                            "float16",
+                            A_1.data,
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                (
+                                                    (
+                                                        (
+                                                            (blockIdx_x * 6422528)
+                                                            + (threadIdx_y * 1605632)
+                                                        )
+                                                        + (threadIdx_z * 802816)
+                                                    )
+                                                    + (kh * 57344)
+                                                )
+                                                + (blockIdx_z * 4096)
+                                            )
+                                            + (ax2 * 4096)
+                                        )
+                                        + (ic_outer * 512)
+                                    )
+                                    + threadIdx_x
+                                )
+                                - 61184
+                            ),
+                        ),
+                        tir.float16(0),
+                        dtype="float16",
+                    )
+                with tir.attr(
+                    tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"),
+                    "thread_extent",
+                    32,
+                ):
+                    Apad_shared[
+                        (
+                            (
+                                (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512))
+                                + threadIdx_x
+                            )
+                            + 288
+                        )
+                    ] = tir.if_then_else(
+                        (
+                            (
+                                (
+                                    (1 <= (tir.floordiv(blockIdx_z, 14) + kh))
+                                    and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)
+                                )
+                                and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))
+                            )
+                            and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)
+                        ),
+                        tir.load(
+                            "float16",
+                            A_1.data,
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                (
+                                                    (
+                                                        (
+                                                            (blockIdx_x * 6422528)
+                                                            + (threadIdx_y * 1605632)
+                                                        )
+                                                        + (threadIdx_z * 802816)
+                                                    )
+                                                    + (kh * 57344)
+                                                )
+                                                + (blockIdx_z * 4096)
+                                            )
+                                            + (ax2 * 4096)
+                                        )
+                                        + (ic_outer * 512)
+                                    )
+                                    + threadIdx_x
+                                )
+                                - 61152
+                            ),
+                        ),
+                        tir.float16(0),
+                        dtype="float16",
+                    )
+                with tir.attr(
+                    tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"),
+                    "thread_extent",
+                    32,
+                ):
+                    Apad_shared[
+                        (
+                            (
+                                (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512))
+                                + threadIdx_x
+                            )
+                            + 320
+                        )
+                    ] = tir.if_then_else(
+                        (
+                            (
+                                (
+                                    (1 <= (tir.floordiv(blockIdx_z, 14) + kh))
+                                    and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)
+                                )
+                                and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))
+                            )
+                            and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)
+                        ),
+                        tir.load(
+                            "float16",
+                            A_1.data,
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                (
+                                                    (
+                                                        (
+                                                            (blockIdx_x * 6422528)
+                                                            + (threadIdx_y * 1605632)
+                                                        )
+                                                        + (threadIdx_z * 802816)
+                                                    )
+                                                    + (kh * 57344)
+                                                )
+                                                + (blockIdx_z * 4096)
+                                            )
+                                            + (ax2 * 4096)
+                                        )
+                                        + (ic_outer * 512)
+                                    )
+                                    + threadIdx_x
+                                )
+                                - 61120
+                            ),
+                        ),
+                        tir.float16(0),
+                        dtype="float16",
+                    )
+                with tir.attr(
+                    tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"),
+                    "thread_extent",
+                    32,
+                ):
+                    Apad_shared[
+                        (
+                            (
+                                (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512))
+                                + threadIdx_x
+                            )
+                            + 352
+                        )
+                    ] = tir.if_then_else(
+                        (
+                            (
+                                (
+                                    (1 <= (tir.floordiv(blockIdx_z, 14) + kh))
+                                    and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)
+                                )
+                                and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))
+                            )
+                            and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)
+                        ),
+                        tir.load(
+                            "float16",
+                            A_1.data,
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                (
+                                                    (
+                                                        (
+                                                            (blockIdx_x * 6422528)
+                                                            + (threadIdx_y * 1605632)
+                                                        )
+                                                        + (threadIdx_z * 802816)
+                                                    )
+                                                    + (kh * 57344)
+                                                )
+                                                + (blockIdx_z * 4096)
+                                            )
+                                            + (ax2 * 4096)
+                                        )
+                                        + (ic_outer * 512)
+                                    )
+                                    + threadIdx_x
+                                )
+                                - 61088
+                            ),
+                        ),
+                        tir.float16(0),
+                        dtype="float16",
+                    )
+                with tir.attr(
+                    tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"),
+                    "thread_extent",
+                    32,
+                ):
+                    Apad_shared[
+                        (
+                            (
+                                (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512))
+                                + threadIdx_x
+                            )
+                            + 384
+                        )
+                    ] = tir.if_then_else(
+                        (
+                            (
+                                (
+                                    (1 <= (tir.floordiv(blockIdx_z, 14) + kh))
+                                    and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)
+                                )
+                                and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))
+                            )
+                            and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)
+                        ),
+                        tir.load(
+                            "float16",
+                            A_1.data,
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                (
+                                                    (
+                                                        (
+                                                            (blockIdx_x * 6422528)
+                                                            + (threadIdx_y * 1605632)
+                                                        )
+                                                        + (threadIdx_z * 802816)
+                                                    )
+                                                    + (kh * 57344)
+                                                )
+                                                + (blockIdx_z * 4096)
+                                            )
+                                            + (ax2 * 4096)
+                                        )
+                                        + (ic_outer * 512)
+                                    )
+                                    + threadIdx_x
+                                )
+                                - 61056
+                            ),
+                        ),
+                        tir.float16(0),
+                        dtype="float16",
+                    )
+                with tir.attr(
+                    tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"),
+                    "thread_extent",
+                    32,
+                ):
+                    Apad_shared[
+                        (
+                            (
+                                (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512))
+                                + threadIdx_x
+                            )
+                            + 416
+                        )
+                    ] = tir.if_then_else(
+                        (
+                            (
+                                (
+                                    (1 <= (tir.floordiv(blockIdx_z, 14) + kh))
+                                    and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)
+                                )
+                                and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))
+                            )
+                            and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)
+                        ),
+                        tir.load(
+                            "float16",
+                            A_1.data,
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                (
+                                                    (
+                                                        (
+                                                            (blockIdx_x * 6422528)
+                                                            + (threadIdx_y * 1605632)
+                                                        )
+                                                        + (threadIdx_z * 802816)
+                                                    )
+                                                    + (kh * 57344)
+                                                )
+                                                + (blockIdx_z * 4096)
+                                            )
+                                            + (ax2 * 4096)
+                                        )
+                                        + (ic_outer * 512)
+                                    )
+                                    + threadIdx_x
+                                )
+                                - 61024
+                            ),
+                        ),
+                        tir.float16(0),
+                        dtype="float16",
+                    )
+                with tir.attr(
+                    tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"),
+                    "thread_extent",
+                    32,
+                ):
+                    Apad_shared[
+                        (
+                            (
+                                (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512))
+                                + threadIdx_x
+                            )
+                            + 448
+                        )
+                    ] = tir.if_then_else(
+                        (
+                            (
+                                (
+                                    (1 <= (tir.floordiv(blockIdx_z, 14) + kh))
+                                    and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)
+                                )
+                                and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))
+                            )
+                            and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)
+                        ),
+                        tir.load(
+                            "float16",
+                            A_1.data,
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                (
+                                                    (
+                                                        (
+                                                            (blockIdx_x * 6422528)
+                                                            + (threadIdx_y * 1605632)
+                                                        )
+                                                        + (threadIdx_z * 802816)
+                                                    )
+                                                    + (kh * 57344)
+                                                )
+                                                + (blockIdx_z * 4096)
+                                            )
+                                            + (ax2 * 4096)
+                                        )
+                                        + (ic_outer * 512)
+                                    )
+                                    + threadIdx_x
+                                )
+                                - 60992
+                            ),
+                        ),
+                        tir.float16(0),
+                        dtype="float16",
+                    )
+                tir.attr(
+                    tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"),
+                    "thread_extent",
+                    32,
+                )
+                Apad_shared[
+                    (
+                        (
+                            (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512))
+                            + threadIdx_x
+                        )
+                        + 480
+                    )
+                ] = tir.if_then_else(
+                    (
+                        (
+                            (
+                                (1 <= (tir.floordiv(blockIdx_z, 14) + kh))
+                                and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)
+                            )
+                            and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))
+                        )
+                        and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)
+                    ),
+                    tir.load(
+                        "float16",
+                        A_1.data,
+                        (
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                (
+                                                    (
+                                                        (blockIdx_x * 6422528)
+                                                        + (threadIdx_y * 1605632)
+                                                    )
+                                                    + (threadIdx_z * 802816)
+                                                )
+                                                + (kh * 57344)
+                                            )
+                                            + (blockIdx_z * 4096)
+                                        )
+                                        + (ax2 * 4096)
+                                    )
+                                    + (ic_outer * 512)
+                                )
+                                + threadIdx_x
+                            )
+                            - 60960
+                        ),
+                    ),
+                    tir.float16(0),
+                    dtype="float16",
+                )
+            with tir.attr(
+                tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32
+            ):
+                tir.store(
+                    W_shared,
+                    tir.ramp(
+                        (((threadIdx_y * 512) + (threadIdx_z * 256)) + (threadIdx_x * 8)), 1, 8
+                    ),
+                    tir.load(
+                        "float16x8",
+                        W_1.data,
+                        tir.ramp(
+                            (
+                                (
+                                    (
+                                        (((kh * 393216) + (ic_outer * 16384)) + (blockIdx_y * 2048))
+                                        + (threadIdx_y * 512)
+                                    )
+                                    + (threadIdx_z * 256)
+                                )
+                                + (threadIdx_x * 8)
+                            ),
+                            1,
+                            8,
+                        ),
+                        tir.broadcast(True, 8),
+                    ),
+                    tir.broadcast(True, 8),
+                )
+            with tir.attr(
+                tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32
+            ):
+                tir.store(
+                    W_shared,
+                    tir.ramp(
+                        ((((threadIdx_y * 512) + (threadIdx_z * 256)) + (threadIdx_x * 8)) + 2048),
+                        1,
+                        8,
+                    ),
+                    tir.load(
+                        "float16x8",
+                        W_1.data,
+                        tir.ramp(
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                ((kh * 393216) + (ic_outer * 16384))
+                                                + (blockIdx_y * 2048)
+                                            )
+                                            + (threadIdx_y * 512)
+                                        )
+                                        + (threadIdx_z * 256)
+                                    )
+                                    + (threadIdx_x * 8)
+                                )
+                                + 8192
+                            ),
+                            1,
+                            8,
+                        ),
+                        tir.broadcast(True, 8),
+                    ),
+                    tir.broadcast(True, 8),
+                )
+            with tir.attr(
+                tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32
+            ):
+                tir.store(
+                    W_shared,
+                    tir.ramp(
+                        ((((threadIdx_y * 512) + (threadIdx_z * 256)) + (threadIdx_x * 8)) + 4096),
+                        1,
+                        8,
+                    ),
+                    tir.load(
+                        "float16x8",
+                        W_1.data,
+                        tir.ramp(
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                ((kh * 393216) + (ic_outer * 16384))
+                                                + (blockIdx_y * 2048)
+                                            )
+                                            + (threadIdx_y * 512)
+                                        )
+                                        + (threadIdx_z * 256)
+                                    )
+                                    + (threadIdx_x * 8)
+                                )
+                                + 131072
+                            ),
+                            1,
+                            8,
+                        ),
+                        tir.broadcast(True, 8),
+                    ),
+                    tir.broadcast(True, 8),
+                )
+            with tir.attr(
+                tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32
+            ):
+                tir.store(
+                    W_shared,
+                    tir.ramp(
+                        ((((threadIdx_y * 512) + (threadIdx_z * 256)) + (threadIdx_x * 8)) + 6144),
+                        1,
+                        8,
+                    ),
+                    tir.load(
+                        "float16x8",
+                        W_1.data,
+                        tir.ramp(
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                ((kh * 393216) + (ic_outer * 16384))
+                                                + (blockIdx_y * 2048)
+                                            )
+                                            + (threadIdx_y * 512)
+                                        )
+                                        + (threadIdx_z * 256)
+                                    )
+                                    + (threadIdx_x * 8)
+                                )
+                                + 139264
+                            ),
+                            1,
+                            8,
+                        ),
+                        tir.broadcast(True, 8),
+                    ),
+                    tir.broadcast(True, 8),
+                )
+            with tir.attr(
+                tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32
+            ):
+                tir.store(
+                    W_shared,
+                    tir.ramp(
+                        ((((threadIdx_y * 512) + (threadIdx_z * 256)) + (threadIdx_x * 8)) + 8192),
+                        1,
+                        8,
+                    ),
+                    tir.load(
+                        "float16x8",
+                        W_1.data,
+                        tir.ramp(
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                ((kh * 393216) + (ic_outer * 16384))
+                                                + (blockIdx_y * 2048)
+                                            )
+                                            + (threadIdx_y * 512)
+                                        )
+                                        + (threadIdx_z * 256)
+                                    )
+                                    + (threadIdx_x * 8)
+                                )
+                                + 262144
+                            ),
+                            1,
+                            8,
+                        ),
+                        tir.broadcast(True, 8),
+                    ),
+                    tir.broadcast(True, 8),
+                )
+            with tir.attr(
+                tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32
+            ):
+                tir.store(
+                    W_shared,
+                    tir.ramp(
+                        ((((threadIdx_y * 512) + (threadIdx_z * 256)) + (threadIdx_x * 8)) + 10240),
+                        1,
+                        8,
+                    ),
+                    tir.load(
+                        "float16x8",
+                        W_1.data,
+                        tir.ramp(
+                            (
+                                (
+                                    (
+                                        (
+                                            (
+                                                ((kh * 393216) + (ic_outer * 16384))
+                                                + (blockIdx_y * 2048)
+                                            )
+                                            + (threadIdx_y * 512)
+                                        )
+                                        + (threadIdx_z * 256)
+                                    )
+                                    + (threadIdx_x * 8)
+                                )
+                                + 270336
+                            ),
+                            1,
+                            8,
+                        ),
+                        tir.broadcast(True, 8),
+                    ),
+                    tir.broadcast(True, 8),
+                )
             for ic_inner in tir.range(0, 2):
                 for kw in tir.range(0, 3):
-                    tir.evaluate(tir.tvm_load_matrix_sync(Apad_shared_wmma_matrix_a, 16, 16, 16, 0, tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), Apad_shared, (((threadIdx_y*3072) + (kw*512)) + (ic_inner*256)), 256, 1, dtype="handle"), 16, "row_major", dtype="handle"))
-                    tir.evaluate(tir.tvm_load_matrix_sync(Apad_shared_wmma_matrix_a, 16, 16, 16, 1, tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), Apad_shared, ((((threadIdx_y*3072) + (kw*512)) + (ic_inner*256)) + 1536), 256, 1, dtype="handle"), 16, "row_major", dtype="handle"))
-                    tir.evaluate(tir.tvm_load_matrix_sync(W_shared_wmma_matrix_b, 16, 16, 16, 0, tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), W_shared, (((kw*4096) + (ic_inner*2048)) + (threadIdx_z*1024)), 256, 1, dtype="handle"), 16, "row_major", dtype="handle"))
-                    tir.evaluate(tir.tvm_load_matrix_sync(W_shared_wmma_matrix_b, 16, 16, 16, 1, tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), W_shared, ((((kw*4096) + (ic_inner*2048)) + (threadIdx_z*1024)) + 256), 256, 1, dtype="handle"), 16, "row_major", dtype="handle"))
-                    tir.evaluate(tir.tvm_load_matrix_sync(W_shared_wmma_matrix_b, 16, 16, 16, 2, tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), W_shared, ((((kw*4096) + (ic_inner*2048)) + (threadIdx_z*1024)) + 512), 256, 1, dtype="handle"), 16, "row_major", dtype="handle"))
-                    tir.evaluate(tir.tvm_load_matrix_sync(W_shared_wmma_matrix_b, 16, 16, 16, 3, tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), W_shared, ((((kw*4096) + (ic_inner*2048)) + (threadIdx_z*1024)) + 768), 256, 1, dtype="handle"), 16, "row_major", dtype="handle"))
-                    tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 0, Apad_shared_wmma_matrix_a, 0, W_shared_wmma_matrix_b, 0, Conv_wmma_accumulator, 0, dtype="handle"))
-                    tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 1, Apad_shared_wmma_matrix_a, 0, W_shared_wmma_matrix_b, 1, Conv_wmma_accumulator, 1, dtype="handle"))
-                    tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 2, Apad_shared_wmma_matrix_a, 0, W_shared_wmma_matrix_b, 2, Conv_wmma_accumulator, 2, dtype="handle"))
-                    tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 3, Apad_shared_wmma_matrix_a, 0, W_shared_wmma_matrix_b, 3, Conv_wmma_accumulator, 3, dtype="handle"))
-                    tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 4, Apad_shared_wmma_matrix_a, 1, W_shared_wmma_matrix_b, 0, Conv_wmma_accumulator, 4, dtype="handle"))
-                    tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 5, Apad_shared_wmma_matrix_a, 1, W_shared_wmma_matrix_b, 1, Conv_wmma_accumulator, 5, dtype="handle"))
-                    tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 6, Apad_shared_wmma_matrix_a, 1, W_shared_wmma_matrix_b, 2, Conv_wmma_accumulator, 6, dtype="handle"))
-                    tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 7, Apad_shared_wmma_matrix_a, 1, W_shared_wmma_matrix_b, 3, Conv_wmma_accumulator, 7, dtype="handle"))
-    tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 0, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, (((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)), 256, 2, dtype="handle"), 16, "row_major", dtype="handle"))
-    tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 1, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, ((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 256), 256, 2, dtype="handle"), 16, "row_major", dtype="handle"))
-    tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 2, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, ((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 512), 256, 2, dtype="handle"), 16, "row_major", dtype="handle"))
-    tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 3, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, ((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 768), 256, 2, dtype="handle"), 16, "row_major", dtype="handle"))
-    tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 4, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, ((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 1605632), 256, 2, dtype="handle"), 16, "row_major", dtype="handle"))
-    tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 5, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, ((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 1605888), 256, 2, dtype="handle"), 16, "row_major", dtype="handle"))
-    tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 6, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, ((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 1606144), 256, 2, dtype="handle"), 16, "row_major", dtype="handle"))
-    tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 7, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, ((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 1606400), 256, 2, dtype="handle"), 16, "row_major", dtype="handle"))
+                    tir.evaluate(
+                        tir.tvm_load_matrix_sync(
+                            Apad_shared_wmma_matrix_a,
+                            16,
+                            16,
+                            16,
+                            0,
+                            tir.tvm_access_ptr(
+                                tir.type_annotation(dtype="float16"),
+                                Apad_shared,
+                                (((threadIdx_y * 3072) + (kw * 512)) + (ic_inner * 256)),
+                                256,
+                                1,
+                                dtype="handle",
+                            ),
+                            16,
+                            "row_major",
+                            dtype="handle",
+                        )
+                    )
+                    tir.evaluate(
+                        tir.tvm_load_matrix_sync(
+                            Apad_shared_wmma_matrix_a,
+                            16,
+                            16,
+                            16,
+                            1,
+                            tir.tvm_access_ptr(
+                                tir.type_annotation(dtype="float16"),
+                                Apad_shared,
+                                ((((threadIdx_y * 3072) + (kw * 512)) + (ic_inner * 256)) + 1536),
+                                256,
+                                1,
+                                dtype="handle",
+                            ),
+                            16,
+                            "row_major",
+                            dtype="handle",
+                        )
+                    )
+                    tir.evaluate(
+                        tir.tvm_load_matrix_sync(
+                            W_shared_wmma_matrix_b,
+                            16,
+                            16,
+                            16,
+                            0,
+                            tir.tvm_access_ptr(
+                                tir.type_annotation(dtype="float16"),
+                                W_shared,
+                                (((kw * 4096) + (ic_inner * 2048)) + (threadIdx_z * 1024)),
+                                256,
+                                1,
+                                dtype="handle",
+                            ),
+                            16,
+                            "row_major",
+                            dtype="handle",
+                        )
+                    )
+                    tir.evaluate(
+                        tir.tvm_load_matrix_sync(
+                            W_shared_wmma_matrix_b,
+                            16,
+                            16,
+                            16,
+                            1,
+                            tir.tvm_access_ptr(
+                                tir.type_annotation(dtype="float16"),
+                                W_shared,
+                                ((((kw * 4096) + (ic_inner * 2048)) + (threadIdx_z * 1024)) + 256),
+                                256,
+                                1,
+                                dtype="handle",
+                            ),
+                            16,
+                            "row_major",
+                            dtype="handle",
+                        )
+                    )
+                    tir.evaluate(
+                        tir.tvm_load_matrix_sync(
+                            W_shared_wmma_matrix_b,
+                            16,
+                            16,
+                            16,
+                            2,
+                            tir.tvm_access_ptr(
+                                tir.type_annotation(dtype="float16"),
+                                W_shared,
+                                ((((kw * 4096) + (ic_inner * 2048)) + (threadIdx_z * 1024)) + 512),
+                                256,
+                                1,
+                                dtype="handle",
+                            ),
+                            16,
+                            "row_major",
+                            dtype="handle",
+                        )
+                    )
+                    tir.evaluate(
+                        tir.tvm_load_matrix_sync(
+                            W_shared_wmma_matrix_b,
+                            16,
+                            16,
+                            16,
+                            3,
+                            tir.tvm_access_ptr(
+                                tir.type_annotation(dtype="float16"),
+                                W_shared,
+                                ((((kw * 4096) + (ic_inner * 2048)) + (threadIdx_z * 1024)) + 768),
+                                256,
+                                1,
+                                dtype="handle",
+                            ),
+                            16,
+                            "row_major",
+                            dtype="handle",
+                        )
+                    )
+                    tir.evaluate(
+                        tir.tvm_mma_sync(
+                            Conv_wmma_accumulator,
+                            0,
+                            Apad_shared_wmma_matrix_a,
+                            0,
+                            W_shared_wmma_matrix_b,
+                            0,
+                            Conv_wmma_accumulator,
+                            0,
+                            dtype="handle",
+                        )
+                    )
+                    tir.evaluate(
+                        tir.tvm_mma_sync(
+                            Conv_wmma_accumulator,
+                            1,
+                            Apad_shared_wmma_matrix_a,
+                            0,
+                            W_shared_wmma_matrix_b,
+                            1,
+                            Conv_wmma_accumulator,
+                            1,
+                            dtype="handle",
+                        )
+                    )
+                    tir.evaluate(
+                        tir.tvm_mma_sync(
+                            Conv_wmma_accumulator,
+                            2,
+                            Apad_shared_wmma_matrix_a,
+                            0,
+                            W_shared_wmma_matrix_b,
+                            2,
+                            Conv_wmma_accumulator,
+                            2,
+                            dtype="handle",
+                        )
+                    )
+                    tir.evaluate(
+                        tir.tvm_mma_sync(
+                            Conv_wmma_accumulator,
+                            3,
+                            Apad_shared_wmma_matrix_a,
+                            0,
+                            W_shared_wmma_matrix_b,
+                            3,
+                            Conv_wmma_accumulator,
+                            3,
+                            dtype="handle",
+                        )
+                    )
+                    tir.evaluate(
+                        tir.tvm_mma_sync(
+                            Conv_wmma_accumulator,
+                            4,
+                            Apad_shared_wmma_matrix_a,
+                            1,
+                            W_shared_wmma_matrix_b,
+                            0,
+                            Conv_wmma_accumulator,
+                            4,
+                            dtype="handle",
+                        )
+                    )
+                    tir.evaluate(
+                        tir.tvm_mma_sync(
+                            Conv_wmma_accumulator,
+                            5,
+                            Apad_shared_wmma_matrix_a,
+                            1,
+                            W_shared_wmma_matrix_b,
+                            1,
+                            Conv_wmma_accumulator,
+                            5,
+                            dtype="handle",
+                        )
+                    )
+                    tir.evaluate(
+                        tir.tvm_mma_sync(
+                            Conv_wmma_accumulator,
+                            6,
+                            Apad_shared_wmma_matrix_a,
+                            1,
+                            W_shared_wmma_matrix_b,
+                            2,
+                            Conv_wmma_accumulator,
+                            6,
+                            dtype="handle",
+                        )
+                    )
+                    tir.evaluate(
+                        tir.tvm_mma_sync(
+                            Conv_wmma_accumulator,
+                            7,
+                            Apad_shared_wmma_matrix_a,
+                            1,
+                            W_shared_wmma_matrix_b,
+                            3,
+                            Conv_wmma_accumulator,
+                            7,
+                            dtype="handle",
+                        )
+                    )
+    tir.evaluate(
+        tir.tvm_store_matrix_sync(
+            Conv_wmma_accumulator,
+            16,
+            16,
+            16,
+            0,
+            tir.tvm_access_ptr(
+                tir.type_annotation(dtype="float32"),
+                Conv_1.data,
+                (
+                    (
+                        (((blockIdx_x * 12845056) + (threadIdx_y * 3211264)) + (blockIdx_z * 8192))
+                        + (blockIdx_y * 2048)
+                    )
+                    + (threadIdx_z * 1024)
+                ),
+                256,
+                2,
+                dtype="handle",
+            ),
+            16,
+            "row_major",
+            dtype="handle",
+        )
+    )
+    tir.evaluate(
+        tir.tvm_store_matrix_sync(
+            Conv_wmma_accumulator,
+            16,
+            16,
+            16,
+            1,
+            tir.tvm_access_ptr(
+                tir.type_annotation(dtype="float32"),
+                Conv_1.data,
+                (
+                    (
+                        (
+                            (
+                                ((blockIdx_x * 12845056) + (threadIdx_y * 3211264))
+                                + (blockIdx_z * 8192)
+                            )
+                            + (blockIdx_y * 2048)
+                        )
+                        + (threadIdx_z * 1024)
+                    )
+                    + 256
+                ),
+                256,
+                2,
+                dtype="handle",
+            ),
+            16,
+            "row_major",
+            dtype="handle",
+        )
+    )
+    tir.evaluate(
+        tir.tvm_store_matrix_sync(
+            Conv_wmma_accumulator,
+            16,
+            16,
+            16,
+            2,
+            tir.tvm_access_ptr(
+                tir.type_annotation(dtype="float32"),
+                Conv_1.data,
+                (
+                    (
+                        (
+                            (
+                                ((blockIdx_x * 12845056) + (threadIdx_y * 3211264))
+                                + (blockIdx_z * 8192)
+                            )
+                            + (blockIdx_y * 2048)
+                        )
+                        + (threadIdx_z * 1024)
+                    )
+                    + 512
+                ),
+                256,
+                2,
+                dtype="handle",
+            ),
+            16,
+            "row_major",
+            dtype="handle",
+        )
+    )
+    tir.evaluate(
+        tir.tvm_store_matrix_sync(
+            Conv_wmma_accumulator,
+            16,
+            16,
+            16,
+            3,
+            tir.tvm_access_ptr(
+                tir.type_annotation(dtype="float32"),
+                Conv_1.data,
+                (
+                    (
+                        (
+                            (
+                                ((blockIdx_x * 12845056) + (threadIdx_y * 3211264))
+                                + (blockIdx_z * 8192)
+                            )
+                            + (blockIdx_y * 2048)
+                        )
+                        + (threadIdx_z * 1024)
+                    )
+                    + 768
+                ),
+                256,
+                2,
+                dtype="handle",
+            ),
+            16,
+            "row_major",
+            dtype="handle",
+        )
+    )
+    tir.evaluate(
+        tir.tvm_store_matrix_sync(
+            Conv_wmma_accumulator,
+            16,
+            16,
+            16,
+            4,
+            tir.tvm_access_ptr(
+                tir.type_annotation(dtype="float32"),
+                Conv_1.data,
+                (
+                    (
+                        (
+                            (
+                                ((blockIdx_x * 12845056) + (threadIdx_y * 3211264))
+                                + (blockIdx_z * 8192)
+                            )
+                            + (blockIdx_y * 2048)
+                        )
+                        + (threadIdx_z * 1024)
+                    )
+                    + 1605632
+                ),
+                256,
+                2,
+                dtype="handle",
+            ),
+            16,
+            "row_major",
+            dtype="handle",
+        )
+    )
+    tir.evaluate(
+        tir.tvm_store_matrix_sync(
+            Conv_wmma_accumulator,
+            16,
+            16,
+            16,
+            5,
+            tir.tvm_access_ptr(
+                tir.type_annotation(dtype="float32"),
+                Conv_1.data,
+                (
+                    (
+                        (
+                            (
+                                ((blockIdx_x * 12845056) + (threadIdx_y * 3211264))
+                                + (blockIdx_z * 8192)
+                            )
+                            + (blockIdx_y * 2048)
+                        )
+                        + (threadIdx_z * 1024)
+                    )
+                    + 1605888
+                ),
+                256,
+                2,
+                dtype="handle",
+            ),
+            16,
+            "row_major",
+            dtype="handle",
+        )
+    )
+    tir.evaluate(
+        tir.tvm_store_matrix_sync(
+            Conv_wmma_accumulator,
+            16,
+            16,
+            16,
+            6,
+            tir.tvm_access_ptr(
+                tir.type_annotation(dtype="float32"),
+                Conv_1.data,
+                (
+                    (
+                        (
+                            (
+                                ((blockIdx_x * 12845056) + (threadIdx_y * 3211264))
+                                + (blockIdx_z * 8192)
+                            )
+                            + (blockIdx_y * 2048)
+                        )
+                        + (threadIdx_z * 1024)
+                    )
+                    + 1606144
+                ),
+                256,
+                2,
+                dtype="handle",
+            ),
+            16,
+            "row_major",
+            dtype="handle",
+        )
+    )
+    tir.evaluate(
+        tir.tvm_store_matrix_sync(
+            Conv_wmma_accumulator,
+            16,
+            16,
+            16,
+            7,
+            tir.tvm_access_ptr(
+                tir.type_annotation(dtype="float32"),
+                Conv_1.data,
+                (
+                    (
+                        (
+                            (
+                                ((blockIdx_x * 12845056) + (threadIdx_y * 3211264))
+                                + (blockIdx_z * 8192)
+                            )
+                            + (blockIdx_y * 2048)
+                        )
+                        + (threadIdx_z * 1024)
+                    )
+                    + 1606400
+                ),
+                256,
+                2,
+                dtype="handle",
+            ),
+            16,
+            "row_major",
+            dtype="handle",
+        )
+    )
 
 
 def test_opt_conv_tensorcore_lower():
@@ -421,13 +2759,27 @@ def test_opt_conv_tensorcore_lower():
 
 
 @tvm.hybrid.script
-def opt_conv_tensorcore_mod_host(args: ty.handle, arg_type_ids: ty.handle, num_args: ty.int32, out_ret_value: ty.handle, out_ret_tcode: ty.handle, resource_handle: ty.handle) -> ty.int32:
+def opt_conv_tensorcore_mod_host(
+    args: ty.handle,
+    arg_type_ids: ty.handle,
+    num_args: ty.int32,
+    out_ret_value: ty.handle,
+    out_ret_tcode: ty.handle,
+    resource_handle: ty.handle,
+) -> ty.int32:
     # function attr dict
-    tir.func_attr({"tir.noalias": True, "global_symbol": "default_function", "tir.is_entry_func": True, "calling_conv": 1})
+    tir.func_attr(
+        {
+            "tir.noalias": True,
+            "global_symbol": "default_function",
+            "tir.is_entry_func": True,
+            "calling_conv": 1,
+        }
+    )
     # body
     stack_tcode: ty.handle = tir.tvm_stack_alloca("arg_tcode", 10, dtype="handle")
     stack_value: ty.handle = tir.tvm_stack_alloca("arg_value", 10, dtype="handle")
-    assert (num_args == 3), "default_function: num_args should be 3"
+    assert num_args == 3, "default_function: num_args should be 3"
     arg0: ty.handle = tir.tvm_struct_get(args, 0, 12, dtype="handle")
     arg0_code: ty.int32 = tir.load("int32", arg_type_ids, 0)
     arg1: ty.handle = tir.tvm_struct_get(args, 1, 12, dtype="handle")
@@ -447,58 +2799,177 @@ def opt_conv_tensorcore_mod_host(args: ty.handle, arg_type_ids: ty.handle, num_a
     tir.attr(Conv, "storage_alignment", 128)
     arg2_shape: ty.handle = tir.tvm_struct_get(arg2, 0, 2, dtype="handle")
     arg2_strides: ty.handle = tir.tvm_struct_get(arg2, 0, 3, dtype="handle")
-    assert ((((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or (arg0_code == 4)), "default_function: Expect arg[0] to be pointer"
-    assert ((((arg1_code == 3) or (arg1_code == 13)) or (arg1_code == 7)) or (arg1_code == 4)), "default_function: Expect arg[1] to be pointer"
-    assert ((((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or (arg2_code == 4)), "default_function: Expect arg[2] to be pointer"
-    assert (6 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32")), "arg0.ndim is expected to equal 6"
-    assert (6 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32")), "arg0.ndim is expected to equal 6"
-    assert (((tir.tvm_struct_get(arg0, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg0, 0, 6, dtype="uint8") == tir.uint8(16))) and (tir.tvm_struct_get(arg0, 0, 7, dtype="uint16") == tir.uint16(1))), "arg0.dtype is expected to be float16"
-    assert (16 == tir.cast("int32", tir.load("int64", arg0_shape, 0))), "Argument arg0.shape[0] has an unsatisfied constraint"
-    assert (14 == tir.cast("int32", tir.load("int64", arg0_shape, 1))), "Argument arg0.shape[1] has an unsatisfied constraint"
-    assert (14 == tir.cast("int32", tir.load("int64", arg0_shape, 2))), "Argument arg0.shape[2] has an unsatisfied constraint"
-    assert (16 == tir.cast("int32", tir.load("int64", arg0_shape, 3))), "Argument arg0.shape[3] has an unsatisfied constraint"
-    assert (16 == tir.cast("int32", tir.load("int64", arg0_shape, 4))), "Argument arg0.shape[4] has an unsatisfied constraint"
-    assert (16 == tir.cast("int32", tir.load("int64", arg0_shape, 5))), "Argument arg0.shape[5] has an unsatisfied constraint"
+    assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or (
+        arg0_code == 4
+    ), "default_function: Expect arg[0] to be pointer"
+    assert (((arg1_code == 3) or (arg1_code == 13)) or (arg1_code == 7)) or (
+        arg1_code == 4
+    ), "default_function: Expect arg[1] to be pointer"
+    assert (((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or (
+        arg2_code == 4
+    ), "default_function: Expect arg[2] to be pointer"
+    assert 6 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 6"
+    assert 6 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 6"
+    assert (
+        (tir.tvm_struct_get(arg0, 0, 5, dtype="uint8") == tir.uint8(2))
+        and (tir.tvm_struct_get(arg0, 0, 6, dtype="uint8") == tir.uint8(16))
+    ) and (
+        tir.tvm_struct_get(arg0, 0, 7, dtype="uint16") == tir.uint16(1)
+    ), "arg0.dtype is expected to be float16"
+    assert 16 == tir.cast(
+        "int32", tir.load("int64", arg0_shape, 0)
+    ), "Argument arg0.shape[0] has an unsatisfied constraint"
+    assert 14 == tir.cast(
+        "int32", tir.load("int64", arg0_shape, 1)
+    ), "Argument arg0.shape[1] has an unsatisfied constraint"
+    assert 14 == tir.cast(
+        "int32", tir.load("int64", arg0_shape, 2)
+    ), "Argument arg0.shape[2] has an unsatisfied constraint"
+    assert 16 == tir.cast(
+        "int32", tir.load("int64", arg0_shape, 3)
+    ), "Argument arg0.shape[3] has an unsatisfied constraint"
+    assert 16 == tir.cast(
+        "int32", tir.load("int64", arg0_shape, 4)
+    ), "Argument arg0.shape[4] has an unsatisfied constraint"
+    assert 16 == tir.cast(
+        "int32", tir.load("int64", arg0_shape, 5)
+    ), "Argument arg0.shape[5] has an unsatisfied constraint"
     if not (tir.isnullptr(arg0_strides, dtype="bool")):
-        assert ((((((1 == tir.cast("int32", tir.load("int64", arg0_strides, 5))) and (16 == tir.cast("int32", tir.load("int64", arg0_strides, 4)))) and (256 == tir.cast("int32", tir.load("int64", arg0_strides, 3)))) and (4096 == tir.cast("int32", tir.load("int64", arg0_strides, 2)))) and (57344 == tir.cast("int32", tir.load("int64", arg0_strides, 1)))) and (802816 == tir.cast("int32", tir.load("int64", arg0_strides, 0)))), "arg0.strides: expected to be compact array"
+        assert (
+            (
+                (
+                    (
+                        (1 == tir.cast("int32", tir.load("int64", arg0_strides, 5)))
+                        and (16 == tir.cast("int32", tir.load("int64", arg0_strides, 4)))
+                    )
+                    and (256 == tir.cast("int32", tir.load("int64", arg0_strides, 3)))
+                )
+                and (4096 == tir.cast("int32", tir.load("int64", arg0_strides, 2)))
+            )
+            and (57344 == tir.cast("int32", tir.load("int64", arg0_strides, 1)))
+        ) and (
+            802816 == tir.cast("int32", tir.load("int64", arg0_strides, 0))
+        ), "arg0.strides: expected to be compact array"
         tir.evaluate(0)
-    assert (tir.uint64(0) == tir.tvm_struct_get(arg0, 0, 8, dtype="uint64")), "Argument arg0.byte_offset has an unsatisfied constraint"
-    assert (2 == tir.tvm_struct_get(arg0, 0, 10, dtype="int32")), "Argument arg0.device_type has an unsatisfied constraint"
-    assert (6 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32")), "arg1.ndim is expected to equal 6"
-    assert (6 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32")), "arg1.ndim is expected to equal 6"
-    assert (((tir.tvm_struct_get(arg1, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg1, 0, 6, dtype="uint8") == tir.uint8(16))) and (tir.tvm_struct_get(arg1, 0, 7, dtype="uint16") == tir.uint16(1))), "arg1.dtype is expected to be float16"
-    assert (3 == tir.cast("int32", tir.load("int64", arg1_shape, 0))), "Argument arg1.shape[0] has an unsatisfied constraint"
-    assert (3 == tir.cast("int32", tir.load("int64", arg1_shape, 1))), "Argument arg1.shape[1] has an unsatisfied constraint"
-    assert (16 == tir.cast("int32", tir.load("int64", arg1_shape, 2))), "Argument arg1.shape[2] has an unsatisfied constraint"
-    assert (32 == tir.cast("int32", tir.load("int64", arg1_shape, 3))), "Argument arg1.shape[3] has an unsatisfied constraint"
-    assert (16 == tir.cast("int32", tir.load("int64", arg1_shape, 4))), "Argument arg1.shape[4] has an unsatisfied constraint"
-    assert (16 == tir.cast("int32", tir.load("int64", arg1_shape, 5))), "Argument arg1.shape[5] has an unsatisfied constraint"
+    assert tir.uint64(0) == tir.tvm_struct_get(
+        arg0, 0, 8, dtype="uint64"
+    ), "Argument arg0.byte_offset has an unsatisfied constraint"
+    assert 2 == tir.tvm_struct_get(
+        arg0, 0, 10, dtype="int32"
+    ), "Argument arg0.device_type has an unsatisfied constraint"
+    assert 6 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 6"
+    assert 6 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 6"
+    assert (
+        (tir.tvm_struct_get(arg1, 0, 5, dtype="uint8") == tir.uint8(2))
+        and (tir.tvm_struct_get(arg1, 0, 6, dtype="uint8") == tir.uint8(16))
+    ) and (
+        tir.tvm_struct_get(arg1, 0, 7, dtype="uint16") == tir.uint16(1)
+    ), "arg1.dtype is expected to be float16"
+    assert 3 == tir.cast(
+        "int32", tir.load("int64", arg1_shape, 0)
+    ), "Argument arg1.shape[0] has an unsatisfied constraint"
+    assert 3 == tir.cast(
+        "int32", tir.load("int64", arg1_shape, 1)
+    ), "Argument arg1.shape[1] has an unsatisfied constraint"
+    assert 16 == tir.cast(
+        "int32", tir.load("int64", arg1_shape, 2)
+    ), "Argument arg1.shape[2] has an unsatisfied constraint"
+    assert 32 == tir.cast(
+        "int32", tir.load("int64", arg1_shape, 3)
+    ), "Argument arg1.shape[3] has an unsatisfied constraint"
+    assert 16 == tir.cast(
+        "int32", tir.load("int64", arg1_shape, 4)
+    ), "Argument arg1.shape[4] has an unsatisfied constraint"
+    assert 16 == tir.cast(
+        "int32", tir.load("int64", arg1_shape, 5)
+    ), "Argument arg1.shape[5] has an unsatisfied constraint"
     if not (tir.isnullptr(arg1_strides, dtype="bool")):
-        assert ((((((1 == tir.cast("int32", tir.load("int64", arg1_strides, 5))) and (16 == tir.cast("int32", tir.load("int64", arg1_strides, 4)))) and (256 == tir.cast("int32", tir.load("int64", arg1_strides, 3)))) and (8192 == tir.cast("int32", tir.load("int64", arg1_strides, 2)))) and (131072 == tir.cast("int32", tir.load("int64", arg1_strides, 1)))) and (393216 == tir.cast("int32", tir.load("int64", arg1_strides, 0)))), "arg1.strides: expected to be compact array"
+        assert (
+            (
+                (
+                    (
+                        (1 == tir.cast("int32", tir.load("int64", arg1_strides, 5)))
+                        and (16 == tir.cast("int32", tir.load("int64", arg1_strides, 4)))
+                    )
+                    and (256 == tir.cast("int32", tir.load("int64", arg1_strides, 3)))
+                )
+                and (8192 == tir.cast("int32", tir.load("int64", arg1_strides, 2)))
+            )
+            and (131072 == tir.cast("int32", tir.load("int64", arg1_strides, 1)))
+        ) and (
+            393216 == tir.cast("int32", tir.load("int64", arg1_strides, 0))
+        ), "arg1.strides: expected to be compact array"
         tir.evaluate(0)
-    assert (tir.uint64(0) == tir.tvm_struct_get(arg1, 0, 8, dtype="uint64")), "Argument arg1.byte_offset has an unsatisfied constraint"
-    assert (2 == tir.tvm_struct_get(arg1, 0, 10, dtype="int32")), "Argument arg1.device_type has an unsatisfied constraint"
-    assert (dev_id == tir.tvm_struct_get(arg1, 0, 9, dtype="int32")), "Argument arg1.device_id has an unsatisfied constraint"
-    assert (6 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32")), "arg2.ndim is expected to equal 6"
-    assert (6 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32")), "arg2.ndim is expected to equal 6"
-    assert (((tir.tvm_struct_get(arg2, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg2, 0, 6, dtype="uint8") == tir.uint8(32))) and (tir.tvm_struct_get(arg2, 0, 7, dtype="uint16") == tir.uint16(1))), "arg2.dtype is expected to be float32"
-    assert (16 == tir.cast("int32", tir.load("int64", arg2_shape, 0))), "Argument arg2.shape[0] has an unsatisfied constraint"
-    assert (14 == tir.cast("int32", tir.load("int64", arg2_shape, 1))), "Argument arg2.shape[1] has an unsatisfied constraint"
-    assert (14 == tir.cast("int32", tir.load("int64", arg2_shape, 2))), "Argument arg2.shape[2] has an unsatisfied constraint"
-    assert (32 == tir.cast("int32", tir.load("int64", arg2_shape, 3))), "Argument arg2.shape[3] has an unsatisfied constraint"
-    assert (16 == tir.cast("int32", tir.load("int64", arg2_shape, 4))), "Argument arg2.shape[4] has an unsatisfied constraint"
-    assert (16 == tir.cast("int32", tir.load("int64", arg2_shape, 5))), "Argument arg2.shape[5] has an unsatisfied constraint"
+    assert tir.uint64(0) == tir.tvm_struct_get(
+        arg1, 0, 8, dtype="uint64"
+    ), "Argument arg1.byte_offset has an unsatisfied constraint"
+    assert 2 == tir.tvm_struct_get(
+        arg1, 0, 10, dtype="int32"
+    ), "Argument arg1.device_type has an unsatisfied constraint"
+    assert dev_id == tir.tvm_struct_get(
+        arg1, 0, 9, dtype="int32"
+    ), "Argument arg1.device_id has an unsatisfied constraint"
+    assert 6 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 6"
+    assert 6 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 6"
+    assert (
+        (tir.tvm_struct_get(arg2, 0, 5, dtype="uint8") == tir.uint8(2))
+        and (tir.tvm_struct_get(arg2, 0, 6, dtype="uint8") == tir.uint8(32))
+    ) and (
+        tir.tvm_struct_get(arg2, 0, 7, dtype="uint16") == tir.uint16(1)
+    ), "arg2.dtype is expected to be float32"
+    assert 16 == tir.cast(
+        "int32", tir.load("int64", arg2_shape, 0)
+    ), "Argument arg2.shape[0] has an unsatisfied constraint"
+    assert 14 == tir.cast(
+        "int32", tir.load("int64", arg2_shape, 1)
+    ), "Argument arg2.shape[1] has an unsatisfied constraint"
+    assert 14 == tir.cast(
+        "int32", tir.load("int64", arg2_shape, 2)
+    ), "Argument arg2.shape[2] has an unsatisfied constraint"
+    assert 32 == tir.cast(
+        "int32", tir.load("int64", arg2_shape, 3)
+    ), "Argument arg2.shape[3] has an unsatisfied constraint"
+    assert 16 == tir.cast(
+        "int32", tir.load("int64", arg2_shape, 4)
+    ), "Argument arg2.shape[4] has an unsatisfied constraint"
+    assert 16 == tir.cast(
+        "int32", tir.load("int64", arg2_shape, 5)
+    ), "Argument arg2.shape[5] has an unsatisfied constraint"
     if not (tir.isnullptr(arg2_strides, dtype="bool")):
-        assert ((((((1 == tir.cast("int32", tir.load("int64", arg2_strides, 5))) and (16 == tir.cast("int32", tir.load("int64", arg2_strides, 4)))) and (256 == tir.cast("int32", tir.load("int64", arg2_strides, 3)))) and (8192 == tir.cast("int32", tir.load("int64", arg2_strides, 2)))) and (114688 == tir.cast("int32", tir.load("int64", arg2_strides, 1)))) and (1605632 == tir.cast("int32", tir.load("int64", arg2_strides, 0)))), "arg2.strides: expected to be compact array"
+        assert (
+            (
+                (
+                    (
+                        (1 == tir.cast("int32", tir.load("int64", arg2_strides, 5)))
+                        and (16 == tir.cast("int32", tir.load("int64", arg2_strides, 4)))
+                    )
+                    and (256 == tir.cast("int32", tir.load("int64", arg2_strides, 3)))
+                )
+                and (8192 == tir.cast("int32", tir.load("int64", arg2_strides, 2)))
+            )
+            and (114688 == tir.cast("int32", tir.load("int64", arg2_strides, 1)))
+        ) and (
+            1605632 == tir.cast("int32", tir.load("int64", arg2_strides, 0))
+        ), "arg2.strides: expected to be compact array"
         tir.evaluate(0)
-    assert (tir.uint64(0) == tir.tvm_struct_get(arg2, 0, 8, dtype="uint64")), "Argument arg2.byte_offset has an unsatisfied constraint"
-    assert (2 == tir.tvm_struct_get(arg2, 0, 10, dtype="int32")), "Argument arg2.device_type has an unsatisfied constraint"
-    assert (dev_id == tir.tvm_struct_get(arg2, 0, 9, dtype="int32")), "Argument arg2.device_id has an unsatisfied constraint"
+    assert tir.uint64(0) == tir.tvm_struct_get(
+        arg2, 0, 8, dtype="uint64"
+    ), "Argument arg2.byte_offset has an unsatisfied constraint"
+    assert 2 == tir.tvm_struct_get(
+        arg2, 0, 10, dtype="int32"
+    ), "Argument arg2.device_type has an unsatisfied constraint"
+    assert dev_id == tir.tvm_struct_get(
+        arg2, 0, 9, dtype="int32"
+    ), "Argument arg2.device_id has an unsatisfied constraint"
     tir.evaluate(tir.tvm_struct_set(stack_value, 0, 12, tir.cast("int64", 2), dtype="int32"))
     stack_tcode[0] = 0
     tir.evaluate(tir.tvm_struct_set(stack_value, 1, 12, tir.cast("int64", dev_id), dtype="int32"))
     stack_tcode[1] = 0
-    tir.evaluate(tir.tvm_call_packed_lowered("__tvm_set_device", stack_value, stack_tcode, 0, 2, dtype="int32"))
+    tir.evaluate(
+        tir.tvm_call_packed_lowered(
+            "__tvm_set_device", stack_value, stack_tcode, 0, 2, dtype="int32"
+        )
+    )
     tir.attr(0, "compute_scope", "default_function_compute_")
     tir.evaluate(tir.tvm_struct_set(stack_value, 0, 12, A, dtype="int32"))
     stack_tcode[0] = 3
@@ -518,7 +2989,11 @@ def opt_conv_tensorcore_mod_host(args: ty.handle, arg_type_ids: ty.handle, num_a
     stack_tcode[7] = 0
     tir.evaluate(tir.tvm_struct_set(stack_value, 8, 12, tir.cast("int64", 32), dtype="int32"))
     stack_tcode[8] = 0
-    tir.evaluate(tir.tvm_call_packed_lowered("default_function_kernel0", stack_value, stack_tcode, 0, 9, dtype="int32"))
+    tir.evaluate(
+        tir.tvm_call_packed_lowered(
+            "default_function_kernel0", stack_value, stack_tcode, 0, 9, dtype="int32"
+        )
+    )
 
 
 def test_opt_conv_tensorcore_mod_host():
@@ -527,7 +3002,7 @@ def test_opt_conv_tensorcore_mod_host():
     tvm.ir.assert_structural_equal(mod, rt_mod, True)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_opt_gemm_normalize()
     test_opt_gemm_mod_host()
     test_opt_gemm_lower()
index 233e59b..9ac0648 100644 (file)
@@ -18,6 +18,7 @@ import tvm
 import pytest
 import tvm.ir._ffi_api
 
+
 def test_make_attrs():
     with pytest.raises(AttributeError):
         x = tvm.ir.make_node("attrs.TestAttrs", unknown_key=1, name="xx")
@@ -25,7 +26,7 @@ def test_make_attrs():
     with pytest.raises(AttributeError):
         x = tvm.ir.make_node("attrs.TestAttrs", axis=100, name="xx")
 
-    x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4))
+    x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
     assert x.name == "xx"
     assert x.padding[0].value == 3
     assert x.padding[1].value == 4
@@ -33,7 +34,7 @@ def test_make_attrs():
 
 
 def test_dict_attrs():
-    dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
+    dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0, 0))
     assert dattr.x.value == 1
     dattr = tvm.ir.load_json(tvm.ir.save_json(dattr))
     assert dattr.name == "xyz"
@@ -55,7 +56,6 @@ def test_attrs_equal():
     assert not tvm.ir.structural_equal([1, 2], tvm.runtime.convert(1))
 
 
-
 if __name__ == "__main__":
     test_make_attrs()
     test_dict_attrs()
index c2d3aba..46388b3 100644 (file)
@@ -18,25 +18,26 @@ import tvm
 from tvm import te
 import numpy as np
 
+
 def test_array():
-    a = tvm.runtime.convert([1,2,3])
+    a = tvm.runtime.convert([1, 2, 3])
     assert len(a) == 3
     assert a[-1].value == 3
     a_slice = a[-3:-1]
     assert (a_slice[0].value, a_slice[1].value) == (1, 2)
 
+
 def test_array_save_load_json():
-    a = tvm.runtime.convert([1,2,3])
+    a = tvm.runtime.convert([1, 2, 3])
     json_str = tvm.ir.save_json(a)
     a_loaded = tvm.ir.load_json(json_str)
-    assert(a_loaded[1].value == 2)
+    assert a_loaded[1].value == 2
 
 
 def test_map():
-    a = te.var('a')
-    b = te.var('b')
-    amap = tvm.runtime.convert({a: 2,
-                        b: 3})
+    a = te.var("a")
+    b = te.var("b")
+    amap = tvm.runtime.convert({a: 2, b: 3})
     assert a in amap
     assert len(amap) == 2
     dd = dict(amap.items())
@@ -46,35 +47,35 @@ def test_map():
 
 
 def test_str_map():
-    amap = tvm.runtime.convert({'a': 2, 'b': 3})
-    assert 'a' in amap
+    amap = tvm.runtime.convert({"a": 2, "b": 3})
+    assert "a" in amap
     assert len(amap) == 2
     dd = dict(amap.items())
-    assert amap['a'].value == 2
-    assert 'a' in dd
-    assert 'b' in dd
+    assert amap["a"].value == 2
+    assert "a" in dd
+    assert "b" in dd
 
 
 def test_map_save_load_json():
-    a = te.var('a')
-    b = te.var('b')
-    amap = tvm.runtime.convert({a: 2,
-                        b: 3})
+    a = te.var("a")
+    b = te.var("b")
+    amap = tvm.runtime.convert({a: 2, b: 3})
     json_str = tvm.ir.save_json(amap)
     amap = tvm.ir.load_json(json_str)
     assert len(amap) == 2
-    dd = {kv[0].name : kv[1].value for kv in amap.items()}
-    assert(dd == {"a": 2, "b": 3})
+    dd = {kv[0].name: kv[1].value for kv in amap.items()}
+    assert dd == {"a": 2, "b": 3}
 
 
 def test_in_container():
-    arr = tvm.runtime.convert(['a', 'b', 'c'])
-    assert 'a' in arr
-    assert tvm.tir.StringImm('a') in arr
-    assert 'd' not in arr
+    arr = tvm.runtime.convert(["a", "b", "c"])
+    assert "a" in arr
+    assert tvm.tir.StringImm("a") in arr
+    assert "d" not in arr
+
 
 def test_ndarray_container():
-    x = tvm.nd.array([1,2,3])
+    x = tvm.nd.array([1, 2, 3])
     arr = tvm.runtime.convert([x, x])
     assert arr[0].same_as(x)
     assert arr[1].same_as(x)
index 1072efb..986e48d 100644 (file)
@@ -17,6 +17,7 @@
 """Test type nodes in the IR"""
 import tvm
 
+
 def check_json_roundtrip(node):
     json_str = tvm.ir.save_json(node)
     back = tvm.ir.load_json(json_str)
@@ -35,9 +36,10 @@ def test_tensor_type_bad_constructor():
     except tvm.error.TVMError:
         pass
 
+
 def test_tensor_type():
     shape = tvm.runtime.convert([1, 2, 3])
-    dtype = 'float32'
+    dtype = "float32"
     tt = tvm.ir.TensorType(shape, dtype)
     assert tt.dtype == dtype
     assert tt.shape == shape
@@ -47,7 +49,7 @@ def test_tensor_type():
 
 
 def test_type_param():
-    tp = tvm.ir.TypeVar('name', tvm.ir.TypeKind.Type)
+    tp = tvm.ir.TypeVar("name", tvm.ir.TypeKind.Type)
     assert tp.kind == tvm.ir.TypeKind.Type
     # assert tp.span  # TODO allow us to set span
     str(tp)
@@ -58,7 +60,7 @@ def test_func_type():
     type_params = tvm.runtime.convert([])
     type_constraints = tvm.runtime.convert([])  # TODO: fill me in
     arg_types = tvm.runtime.convert([])
-    ret_type = tvm.ir.TensorType((1, 2, 3), 'float32')
+    ret_type = tvm.ir.TensorType((1, 2, 3), "float32")
     tf = tvm.ir.FuncType(arg_types, ret_type, type_params, type_constraints)
     assert tf.type_params == type_params
     assert tf.type_constraints == type_constraints
@@ -71,9 +73,9 @@ def test_func_type():
 
 
 def test_tuple_type():
-    tp = tvm.ir.TypeVar('tp', tvm.ir.TypeKind.Type)
+    tp = tvm.ir.TypeVar("tp", tvm.ir.TypeKind.Type)
     tf = tvm.ir.FuncType([], tvm.ir.TupleType([]), [], [])
-    tt = tvm.ir.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
+    tt = tvm.ir.TensorType(tvm.runtime.convert([1, 2, 3]), "float32")
     fields = tvm.runtime.convert([tp, tf, tt])
 
     tup_ty = tvm.ir.TupleType(fields)
@@ -81,16 +83,16 @@ def test_tuple_type():
     str(tup_ty)
     check_json_roundtrip(tup_ty)
 
+
 def test_type_relation():
-    tp = tvm.ir.TypeVar('tp', tvm.ir.TypeKind.Type)
+    tp = tvm.ir.TypeVar("tp", tvm.ir.TypeKind.Type)
     tf = tvm.ir.FuncType([], None, [], [])
-    tt = tvm.ir.TensorType(
-        tvm.runtime.convert([1, 2, 3]), 'float32')
+    tt = tvm.ir.TensorType(tvm.runtime.convert([1, 2, 3]), "float32")
     args = tvm.runtime.convert([tp, tf, tt])
 
     num_inputs = 2
     func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
-    attrs = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
+    attrs = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3, 4))
 
     tr = tvm.ir.TypeRelation(func, args, num_inputs, attrs)
     assert tr.args == args
@@ -98,6 +100,7 @@ def test_type_relation():
     str(tr)
     check_json_roundtrip(tr)
 
+
 if __name__ == "__main__":
     test_tensor_type_bad_constructor()
     test_tensor_type()
index edf8b42..0db25f7 100644 (file)
@@ -18,6 +18,7 @@ import tvm
 import pytest
 from tvm import te
 
+
 def test_const_saveload_json():
     # save load json
     x = tvm.tir.const(1, "int32")
@@ -28,16 +29,19 @@ def test_const_saveload_json():
     zz = tvm.ir.load_json(json_str)
     tvm.ir.assert_structural_equal(zz, z, map_free_vars=True)
 
+
 def _test_infinity_value(value, dtype):
     x = tvm.tir.const(value, dtype)
     json_str = tvm.ir.save_json(x)
     tvm.ir.assert_structural_equal(x, tvm.ir.load_json(json_str))
 
+
 def test_infinity_value():
-    _test_infinity_value(float("inf"), 'float64')
-    _test_infinity_value(float("-inf"), 'float64')
-    _test_infinity_value(float("inf"), 'float32')
-    _test_infinity_value(float("-inf"), 'float32')
+    _test_infinity_value(float("inf"), "float64")
+    _test_infinity_value(float("-inf"), "float64")
+    _test_infinity_value(float("inf"), "float32")
+    _test_infinity_value(float("-inf"), "float32")
+
 
 def test_make_smap():
     # save load json
@@ -56,12 +60,10 @@ def test_make_node():
     x = tvm.ir.make_node("IntImm", dtype="int32", value=10)
     assert isinstance(x, tvm.tir.IntImm)
     assert x.value == 10
-    A = te.placeholder((10, ), name='A')
-    AA = tvm.ir.make_node("Tensor",
-                       shape=A.shape,
-                       dtype=A.dtype,
-                       op=A.op,
-                       value_index=A.value_index)
+    A = te.placeholder((10,), name="A")
+    AA = tvm.ir.make_node(
+        "Tensor", shape=A.shape, dtype=A.dtype, op=A.op, value_index=A.value_index
+    )
     assert AA.op == A.op
     assert AA.value_index == A.value_index
 
@@ -69,8 +71,8 @@ def test_make_node():
 
 
 def test_make_sum():
-    A = te.placeholder((2, 10), name='A')
-    k = te.reduce_axis((0,10), "k")
+    A = te.placeholder((2, 10), name="A")
+    k = te.reduce_axis((0, 10), "k")
     B = te.compute((2,), lambda i: te.sum(A[i, k], axis=k), name="B")
     json_str = tvm.ir.save_json(B)
     BB = tvm.ir.load_json(json_str)
@@ -92,7 +94,7 @@ def test_env_func():
     assert y(1) == 2
     assert y.func(1) == 2
 
-    x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4), func=y)
+    x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4), func=y)
     assert x.name == "xx"
     assert x.padding[0].value == 3
     assert x.padding[1].value == 4
@@ -115,12 +117,15 @@ def test_string():
 
 
 def test_pass_config():
-    cfg = tvm.transform.PassContext(opt_level=1, config={
-        "tir.UnrollLoop": {
-            "auto_max_step": 10,
-        }
-    })
-    cfg.opt_level  == 1
+    cfg = tvm.transform.PassContext(
+        opt_level=1,
+        config={
+            "tir.UnrollLoop": {
+                "auto_max_step": 10,
+            }
+        },
+    )
+    cfg.opt_level == 1
 
     assert cfg.config["tir.UnrollLoop"].auto_max_step == 10
     # default option
@@ -128,22 +133,19 @@ def test_pass_config():
 
     # schema checking for specific config key
     with pytest.raises(AttributeError):
-        cfg = tvm.transform.PassContext(config={
-            "tir.UnrollLoop": { "invalid": 1 }
-        })
+        cfg = tvm.transform.PassContext(config={"tir.UnrollLoop": {"invalid": 1}})
 
     # schema check for un-registered config
     with pytest.raises(AttributeError):
-        cfg = tvm.transform.PassContext(config={ "inavlid-opt": True })
+        cfg = tvm.transform.PassContext(config={"inavlid-opt": True})
 
     # schema check for wrong type
     with pytest.raises(AttributeError):
-        cfg = tvm.transform.PassContext(config={
-            "tir.UnrollLoop": 1
-        })
+        cfg = tvm.transform.PassContext(config={"tir.UnrollLoop": 1})
+
 
 def test_dict():
-    x = tvm.tir.const(1) # a class that has Python-defined methods
+    x = tvm.tir.const(1)  # a class that has Python-defined methods
     # instances should see the full class dict
     assert set(dir(x.__class__)) <= set(dir(x))
 
index 5ecc21e..a56a65b 100644
@@ -37,21 +37,18 @@ def test_adt_constructor():
 
 def test_tuple_object():
     x = relay.var(
-        'x',
-        type_annotation=relay.ty.TupleType([
-            relay.ty.TensorType((), 'int32'),
-            relay.ty.TensorType((), 'int32')
-        ]))
+        "x",
+        type_annotation=relay.ty.TupleType(
+            [relay.ty.TensorType((), "int32"), relay.ty.TensorType((), "int32")]
+        ),
+    )
 
     fn = relay.Function([x], relay.expr.TupleGetItem(x, 0))
     mod = tvm.IRModule.from_expr(fn)
 
-    exe = relay.create_executor(
-        kind="vm", mod=mod, ctx=nd.cpu(), target="llvm")
+    exe = relay.create_executor(kind="vm", mod=mod, ctx=nd.cpu(), target="llvm")
     f = exe.evaluate()
-    value_tuple = _container.tuple_object(
-        [nd.array(np.array(11)),
-         nd.array(np.array(12))])
+    value_tuple = _container.tuple_object([nd.array(np.array(11)), nd.array(np.array(12))])
     # pass an ADT object to evaluate
     out = f(value_tuple)
     tvm.testing.assert_allclose(out.asnumpy(), np.array(11))
index 70166b3..3d7a218 100644
@@ -19,9 +19,9 @@ import tvm
 from tvm import te
 import tvm.testing
 
+
 def test_op_translation():
-    ferror = tvm.testing.test_raise_error_callback(
-        "OpNotImplemented: myop")
+    ferror = tvm.testing.test_raise_error_callback("OpNotImplemented: myop")
     try:
         ferror()
         assert False
@@ -30,8 +30,7 @@ def test_op_translation():
         assert isinstance(e, NotImplementedError)
         assert msg.find("ffi_testing.cc") != -1
 
-    fchk_eq = tvm.testing.test_check_eq_callback(
-        "InternalError: myop")
+    fchk_eq = tvm.testing.test_check_eq_callback("InternalError: myop")
     try:
         fchk_eq(0, 1)
         assert False
@@ -50,12 +49,17 @@ def test_op_translation():
 def test_deep_callback():
     def error_callback():
         raise ValueError("callback error")
+
     wrap1 = tvm.testing.test_wrap_callback(error_callback)
+
     def flevel2():
         wrap1()
+
     wrap2 = tvm.testing.test_wrap_callback(flevel2)
+
     def flevel3():
         wrap2()
+
     wrap3 = tvm.testing.test_wrap_callback(flevel3)
 
     try:
index 2207eb3..14f128c 100644
@@ -22,6 +22,7 @@ import numpy as np
 @tvm.register_extension
 class MyTensorView(object):
     _tvm_tcode = tvm._ffi.runtime_ctypes.ArgTypeCode.DLTENSOR_HANDLE
+
     def __init__(self, arr):
         self.arr = arr
 
@@ -29,24 +30,25 @@ class MyTensorView(object):
     def _tvm_handle(self):
         return self.arr._tvm_handle
 
+
 def test_dltensor_compatible():
-    dtype = 'int64'
-    n = te.var('n')
+    dtype = "int64"
+    n = te.var("n")
     Ab = tvm.tir.decl_buffer((n,), dtype)
-    i = te.var('i')
+    i = te.var("i")
     ib = tvm.tir.ir_builder.create()
     A = ib.buffer_ptr(Ab)
     with ib.for_range(0, n - 1, "i") as i:
         A[i + 1] = A[i] + 1
     stmt = ib.get()
 
-    mod = tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "arange"))
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "arange"))
     f = tvm.build(mod, target="stackvm")
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
     aview = MyTensorView(a)
     f(aview)
     np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
 
+
 if __name__ == "__main__":
     test_dltensor_compatible()
index d718f20..4e6f1e9 100644
@@ -21,35 +21,38 @@ import json
 from tvm import rpc
 from tvm.contrib import util, graph_runtime
 
+
 @tvm.testing.requires_llvm
 def test_graph_simple():
     n = 4
-    A = te.placeholder((n,), name='A')
-    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
     s = te.create_schedule(B.op)
 
     node0 = {"op": "null", "name": "x", "inputs": []}
-    node1 = {"op": "tvm_op", "name": "add",
-             "inputs": [[0, 0, 0]],
-             "attrs": {"func_name": "myadd",
-                       "flatten_data": "1",
-                       "num_inputs" : "1",
-                    "num_outputs" : "1"}}
+    node1 = {
+        "op": "tvm_op",
+        "name": "add",
+        "inputs": [[0, 0, 0]],
+        "attrs": {"func_name": "myadd", "flatten_data": "1", "num_inputs": "1", "num_outputs": "1"},
+    }
     nodes = [node0, node1]
     arg_nodes = [0]
     node_row_ptr = [0, 1, 2]
     outputs = [[1, 0, 0]]
     shape = (4,)
     attrs = {
-        "shape" : ["list_shape", [shape, shape]],
-        "dltype" : ["list_str", ["float32", "float32"]],
-        "storage_id" : ["list_int", [0, 1]],
+        "shape": ["list_shape", [shape, shape]],
+        "dltype": ["list_str", ["float32", "float32"]],
+        "storage_id": ["list_int", [0, 1]],
+    }
+    graph = {
+        "nodes": nodes,
+        "arg_nodes": arg_nodes,
+        "node_row_ptr": node_row_ptr,
+        "heads": outputs,
+        "attrs": attrs,
     }
-    graph = {"nodes": nodes,
-             "arg_nodes": arg_nodes,
-             "node_row_ptr": node_row_ptr,
-             "heads": outputs,
-             "attrs": attrs}
     graph = json.dumps(graph)
 
     def check_verify():
@@ -79,20 +82,20 @@ def test_graph_simple():
 
     def check_sharing():
         from tvm import relay
-        x = relay.var('x', shape=(1, 10))
-        y = relay.var('y', shape=(1, 10))
+
+        x = relay.var("x", shape=(1, 10))
+        y = relay.var("y", shape=(1, 10))
         z = relay.add(x, y)
         func = relay.Function([x, y], z)
 
         x_in = np.ones((1, 10)).astype("float32")
-        params = {'x': x_in}
+        params = {"x": x_in}
         graph, lib, params = relay.build(func, target="llvm", params=params)
 
         mod_shared = graph_runtime.create(graph, lib, tvm.cpu(0))
         mod_shared.load_params(relay.save_param_dict(params))
         num_mods = 10
-        mods = [graph_runtime.create(graph, lib, tvm.cpu(0))
-                for _ in range(num_mods)]
+        mods = [graph_runtime.create(graph, lib, tvm.cpu(0)) for _ in range(num_mods)]
 
         for mod in mods:
             mod.share_params(mod_shared, relay.save_param_dict(params))
@@ -115,5 +118,6 @@ def test_graph_simple():
     check_remote()
     check_sharing()
 
+
 if __name__ == "__main__":
     test_graph_simple()
index f284ba6..db9b0ce 100644
@@ -23,35 +23,38 @@ from tvm import rpc
 from tvm.contrib import util
 from tvm.contrib.debugger import debug_runtime as graph_runtime
 
+
 @tvm.testing.requires_llvm
 def test_graph_simple():
     n = 4
-    A = te.placeholder((n,), name='A')
-    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
     s = te.create_schedule(B.op)
 
     node0 = {"op": "null", "name": "x", "inputs": []}
-    node1 = {"op": "tvm_op", "name": "add",
-             "inputs": [[0, 0, 0]],
-             "attrs": {"func_name": "myadd",
-                       "flatten_data": "1",
-                       "num_inputs" : "1",
-                    "num_outputs" : "1"}}
+    node1 = {
+        "op": "tvm_op",
+        "name": "add",
+        "inputs": [[0, 0, 0]],
+        "attrs": {"func_name": "myadd", "flatten_data": "1", "num_inputs": "1", "num_outputs": "1"},
+    }
     nodes = [node0, node1]
     arg_nodes = [0]
     node_row_ptr = [0, 1, 2]
     outputs = [[1, 0, 0]]
     shape = (4,)
     attrs = {
-        "shape" : ["list_shape", [shape, shape]],
-        "dltype" : ["list_str", ["float32", "float32"]],
-        "storage_id" : ["list_int", [0, 1]],
+        "shape": ["list_shape", [shape, shape]],
+        "dltype": ["list_str", ["float32", "float32"]],
+        "storage_id": ["list_int", [0, 1]],
+    }
+    graph = {
+        "nodes": nodes,
+        "arg_nodes": arg_nodes,
+        "node_row_ptr": node_row_ptr,
+        "heads": outputs,
+        "attrs": attrs,
     }
-    graph = {"nodes": nodes,
-             "arg_nodes": arg_nodes,
-             "node_row_ptr": node_row_ptr,
-             "heads": outputs,
-             "attrs": attrs}
     graph = json.dumps(graph)
 
     def check_verify():
@@ -64,17 +67,17 @@ def test_graph_simple():
         a = np.random.uniform(size=(n,)).astype(A.dtype)
         mod.set_input(x=a)
 
-        #verify dumproot created
+        # verify dumproot created
         directory = mod._dump_path
-        assert(os.path.exists(directory))
+        assert os.path.exists(directory)
 
-        #verify graph is there
-        GRAPH_DUMP_FILE_NAME = '_tvmdbg_graph_dump.json'
-        assert(len(os.listdir(directory)) == 1)
+        # verify graph is there
+        GRAPH_DUMP_FILE_NAME = "_tvmdbg_graph_dump.json"
+        assert len(os.listdir(directory)) == 1
 
-        #verify the file name is proper
+        # verify the file name is proper
         graph_dump_path = os.path.join(directory, GRAPH_DUMP_FILE_NAME)
-        assert(os.path.exists(graph_dump_path))
+        assert os.path.exists(graph_dump_path)
 
         # verify the graph contains some expected keys
         with open(graph_dump_path) as graph_f:
@@ -85,32 +88,32 @@ def test_graph_simple():
             assert k in dumped_graph, f"key {k} not in dumped graph {graph!r}"
 
         mod.run()
-        #Verify the tensors are dumped
-        assert(len(os.listdir(directory)) > 1)
+        # Verify the tensors are dumped
+        assert len(os.listdir(directory)) > 1
 
-        CHROME_TRACE_FILE_NAME = '_tvmdbg_execution_trace.json'
-        assert(os.path.exists(os.path.join(directory, CHROME_TRACE_FILE_NAME)))
+        CHROME_TRACE_FILE_NAME = "_tvmdbg_execution_trace.json"
+        assert os.path.exists(os.path.join(directory, CHROME_TRACE_FILE_NAME))
 
         with open(os.path.join(directory, CHROME_TRACE_FILE_NAME)) as f:
             trace = json.load(f)
         assert trace["displayTimeUnit"] == "ns"
         events = trace["traceEvents"]
         assert len(events) == 4
-        assert all(event["ph"] in ('B', 'E') for event in events)
+        assert all(event["ph"] in ("B", "E") for event in events)
         assert all(event["pid"] == 1 for event in events)
         assert all(event["tid"] == 1 for event in events)
-        assert all(event["name"] == 'x' for event in events[:2])
-        assert all(event["name"] == 'add' for event in events[2:])
+        assert all(event["name"] == "x" for event in events[:2])
+        assert all(event["name"] == "add" for event in events[2:])
         assert events[0]["ts"] == 0
-        assert events[0]["ph"] == 'B'
+        assert events[0]["ph"] == "B"
 
-        #verify the output is correct
+        # verify the output is correct
         out = mod.get_output(0, tvm.nd.empty((n,)))
         np.testing.assert_equal(out.asnumpy(), a + 1)
 
         mod.exit()
-        #verify dump root delete after cleanup
-        assert(not os.path.exists(directory))
+        # verify dump root delete after cleanup
+        assert not os.path.exists(directory)
 
     def check_remote():
         mlib = tvm.build(s, [A, B], "llvm", name="myadd")
@@ -136,5 +139,6 @@ def test_graph_simple():
     check_verify()
     check_remote()
 
+
 if __name__ == "__main__":
     test_graph_simple()
index 4bf8651..80b330c 100644
@@ -53,14 +53,15 @@ def get_simplex_graph(host_dev_type, device_dev_type):
     var_a = {"op": "null", "name": "A", "inputs": []}
     var_b = {"op": "null", "name": "B", "inputs": []}
     elemwise_add = {
-        "op": "tvm_op", "name": "elemwise_add",
+        "op": "tvm_op",
+        "name": "elemwise_add",
         "attrs": {
             "flatten_data": "1",
             "func_name": "elemwise_add",
             "num_inputs": "2",
-            "num_outputs": "1"
+            "num_outputs": "1",
         },
-        "inputs": [[0, 0, 0], [1, 0, 0]]
+        "inputs": [[0, 0, 0], [1, 0, 0]],
     }
     copy = {
         "op": "tvm_op",
@@ -69,20 +70,21 @@ def get_simplex_graph(host_dev_type, device_dev_type):
             "flatten_data": "0",
             "func_name": "__copy",
             "num_inputs": "1",
-            "num_outputs": "1"
+            "num_outputs": "1",
         },
-        "inputs": [[2, 0, 0]]
+        "inputs": [[2, 0, 0]],
     }
     var_c = {"op": "null", "name": "C", "inputs": []}
     elemwise_sub = {
-        "op": "tvm_op", "name": "elemwise_sub",
+        "op": "tvm_op",
+        "name": "elemwise_sub",
         "attrs": {
             "flatten_data": "0",
             "func_name": "elemwise_sub",
             "num_inputs": "2",
-            "num_outputs": "1"
+            "num_outputs": "1",
         },
-        "inputs": [[3, 0, 0], [4, 0, 0]]
+        "inputs": [[3, 0, 0], [4, 0, 0]],
     }
 
     # Group the nodes.
@@ -94,20 +96,29 @@ def get_simplex_graph(host_dev_type, device_dev_type):
     attrs = {
         "storage_id": ["list_int", [3, 4, 0, 1, 5, 2]],
         "shape": ["list_shape", [shape, shape, shape, shape, shape, shape]],
-        "device_index": ["list_int", [device_dev_type, device_dev_type,
-                                      device_dev_type, host_dev_type,
-                                      host_dev_type, host_dev_type]],
+        "device_index": [
+            "list_int",
+            [
+                device_dev_type,
+                device_dev_type,
+                device_dev_type,
+                host_dev_type,
+                host_dev_type,
+                host_dev_type,
+            ],
+        ],
         "dtype": ["list_int", [0, 0, 0, 0, 0, 0]],
-        "dltype": ["list_str", ["float32", "float32", "float32",
-                                "float32", "float32", "float32"]]
+        "dltype": ["list_str", ["float32", "float32", "float32", "float32", "float32", "float32"]],
     }
 
     # Construct the graph.
-    graph = {"nodes": nodes,
-             "arg_nodes": arg_nodes,
-             "node_row_ptr": node_row_ptr,
-             "heads": heads,
-             "attrs": attrs}
+    graph = {
+        "nodes": nodes,
+        "arg_nodes": arg_nodes,
+        "node_row_ptr": node_row_ptr,
+        "heads": heads,
+        "attrs": attrs,
+    }
     return json.dumps(graph)
 
 
@@ -136,12 +147,12 @@ def test_simplex_data_transferring():
         # Create module for add whose target is the device.
         tensor_a = te.placeholder(shape, name="A")
         tensor_b = te.placeholder(shape, name="B")
-        elemwise_add = te.compute(shape, lambda *i: tensor_a(*i)
-                                  + tensor_b(*i), name="elemwise_add")
+        elemwise_add = te.compute(
+            shape, lambda *i: tensor_a(*i) + tensor_b(*i), name="elemwise_add"
+        )
         target = topi.cpp.TEST_create_target(device)
         schedule_add = topi.cpp.cuda.schedule_injective(target, [elemwise_add])
-        lower_add = tvm.lower(schedule_add, [tensor_a, tensor_b, elemwise_add],
-                              name="elemwise_add")
+        lower_add = tvm.lower(schedule_add, [tensor_a, tensor_b, elemwise_add], name="elemwise_add")
 
         # Insert copy. Neither compute nor schedule is required for the copy
         # node. The compute will be performed at runtime which is just data
@@ -150,29 +161,26 @@ def test_simplex_data_transferring():
 
         # Create module for sub whose target is the host.
         tensor_c = te.placeholder(shape, name="C")
-        elemwise_sub = te.compute(shape, lambda *i: tensor_copy(*i)
-                                  - tensor_c(*i), name="elemwise_sub")
+        elemwise_sub = te.compute(
+            shape, lambda *i: tensor_copy(*i) - tensor_c(*i), name="elemwise_sub"
+        )
         schedule_sub = te.create_schedule(elemwise_sub.op)
-        lower_sub = tvm.lower(schedule_sub, [tensor_copy, tensor_c,
-                                             elemwise_sub],
-                              name="elemwise_sub")
+        lower_sub = tvm.lower(
+            schedule_sub, [tensor_copy, tensor_c, elemwise_sub], name="elemwise_sub"
+        )
 
         target_flist = {target_device: lower_add, target_host: lower_sub}
         mhost = tvm.build(target_flist, target_host=target_host)
         ctx = [host_ctx, device_ctx]
         mod = graph_runtime.create(graph, mhost, ctx)
         params = {}
-        params["A"] = tensor_a = np.random.uniform(
-            size=shape).astype(tensor_a.dtype)
-        params["B"] = tensor_b = np.random.uniform(
-            size=shape).astype(tensor_b.dtype)
-        params["C"] = tensor_c = np.random.uniform(
-            size=shape).astype(tensor_c.dtype)
+        params["A"] = tensor_a = np.random.uniform(size=shape).astype(tensor_a.dtype)
+        params["B"] = tensor_b = np.random.uniform(size=shape).astype(tensor_b.dtype)
+        params["C"] = tensor_c = np.random.uniform(size=shape).astype(tensor_c.dtype)
         mod.set_input(**params)
         mod.run()
         out = mod.get_output(0, tvm.nd.empty(shape))
-        np.testing.assert_equal(
-            out.asnumpy(), (tensor_a + tensor_b) - tensor_c)
+        np.testing.assert_equal(out.asnumpy(), (tensor_a + tensor_b) - tensor_c)
 
     dev_tar = {"cuda": "cuda", "opencl": "opencl"}
     for device, target in dev_tar.items():
@@ -212,14 +220,15 @@ def get_duplex_graph(host_dev_type, device_dev_type):
     var_a = {"op": "null", "name": "A", "inputs": []}
     var_b = {"op": "null", "name": "B", "inputs": []}
     elemwise_add0 = {
-        "op": "tvm_op", "name": "elemwise_add0",
+        "op": "tvm_op",
+        "name": "elemwise_add0",
         "attrs": {
             "flatten_data": "1",
             "func_name": "elemwise_add0",
             "num_inputs": "2",
-            "num_outputs": "1"
+            "num_outputs": "1",
         },
-        "inputs": [[0, 0, 0], [1, 0, 0]]
+        "inputs": [[0, 0, 0], [1, 0, 0]],
     }
     copy_add_sub = {
         "op": "tvm_op",
@@ -228,20 +237,21 @@ def get_duplex_graph(host_dev_type, device_dev_type):
             "flatten_data": "0",
             "func_name": "__copy",
             "num_inputs": "1",
-            "num_outputs": "1"
+            "num_outputs": "1",
         },
-        "inputs": [[2, 0, 0]]
+        "inputs": [[2, 0, 0]],
     }
     var_c = {"op": "null", "name": "C", "inputs": []}
     elemwise_sub = {
-        "op": "tvm_op", "name": "elemwise_sub",
+        "op": "tvm_op",
+        "name": "elemwise_sub",
         "attrs": {
             "flatten_data": "0",
             "func_name": "elemwise_sub",
             "num_inputs": "2",
-            "num_outputs": "1"
+            "num_outputs": "1",
         },
-        "inputs": [[3, 0, 0], [4, 0, 0]]
+        "inputs": [[3, 0, 0], [4, 0, 0]],
     }
     copy_sub_add = {
         "op": "tvm_op",
@@ -250,50 +260,81 @@ def get_duplex_graph(host_dev_type, device_dev_type):
             "flatten_data": "0",
             "func_name": "__copy",
             "num_inputs": "1",
-            "num_outputs": "1"
+            "num_outputs": "1",
         },
-        "inputs": [[5, 0, 0]]
+        "inputs": [[5, 0, 0]],
     }
     var_d = {"op": "null", "name": "D", "inputs": []}
     elemwise_add1 = {
-        "op": "tvm_op", "name": "elemwise_add1",
+        "op": "tvm_op",
+        "name": "elemwise_add1",
         "attrs": {
             "flatten_data": "0",
             "func_name": "elemwise_add1",
             "num_inputs": "2",
-            "num_outputs": "1"
+            "num_outputs": "1",
         },
-        "inputs": [[6, 0, 0], [7, 0, 0]]
+        "inputs": [[6, 0, 0], [7, 0, 0]],
     }
 
     # Group the nodes.
-    nodes = [var_a, var_b, elemwise_add0, copy_add_sub, var_c, elemwise_sub,
-             copy_sub_add, var_d, elemwise_add1]
+    nodes = [
+        var_a,
+        var_b,
+        elemwise_add0,
+        copy_add_sub,
+        var_c,
+        elemwise_sub,
+        copy_sub_add,
+        var_d,
+        elemwise_add1,
+    ]
     arg_nodes = [0, 1, 4, 7]
     node_row_ptr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
     heads = [[8, 0, 0]]
     shape = (4,)
     attrs = {
         "storage_id": ["list_int", [4, 5, 0, 1, 6, 2, 0, 7, 3]],
-        "shape": ["list_shape", [shape, shape, shape, shape, shape, shape,
-                                 shape, shape, shape]],
-        "device_index": ["list_int", [device_dev_type, device_dev_type,
-                                      device_dev_type,
-                                      host_dev_type, host_dev_type, host_dev_type,
-                                      device_dev_type, device_dev_type,
-                                      device_dev_type]],
+        "shape": ["list_shape", [shape, shape, shape, shape, shape, shape, shape, shape, shape]],
+        "device_index": [
+            "list_int",
+            [
+                device_dev_type,
+                device_dev_type,
+                device_dev_type,
+                host_dev_type,
+                host_dev_type,
+                host_dev_type,
+                device_dev_type,
+                device_dev_type,
+                device_dev_type,
+            ],
+        ],
         "dtype": ["list_int", [0, 0, 0, 0, 0, 0, 0, 0, 0]],
-        "dltype": ["list_str", ["float32", "float32", "float32",
-                                "float32", "float32", "float32",
-                                "float32", "float32", "float32"]]
+        "dltype": [
+            "list_str",
+            [
+                "float32",
+                "float32",
+                "float32",
+                "float32",
+                "float32",
+                "float32",
+                "float32",
+                "float32",
+                "float32",
+            ],
+        ],
     }
 
     # Construct the graph.
-    graph = {"nodes": nodes,
-             "arg_nodes": arg_nodes,
-             "node_row_ptr": node_row_ptr,
-             "heads": heads,
-             "attrs": attrs}
+    graph = {
+        "nodes": nodes,
+        "arg_nodes": arg_nodes,
+        "node_row_ptr": node_row_ptr,
+        "heads": heads,
+        "attrs": attrs,
+    }
     return json.dumps(graph)
 
 
@@ -331,52 +372,47 @@ def test_duplex_data_transferring():
         tensor_a = te.placeholder(shape, name="A")
         tensor_b = te.placeholder(shape, name="B")
         tensor_d = te.placeholder(shape, name="D")
-        elemwise_add0 = te.compute(shape, lambda *i: tensor_a(*i)
-                                   + tensor_b(*i), name="elemwise_add0")
-        elemwise_add1 = te.compute(shape, lambda *i: copy_sub_add(*i)
-                                   + tensor_d(*i), name="elemwise_add1")
+        elemwise_add0 = te.compute(
+            shape, lambda *i: tensor_a(*i) + tensor_b(*i), name="elemwise_add0"
+        )
+        elemwise_add1 = te.compute(
+            shape, lambda *i: copy_sub_add(*i) + tensor_d(*i), name="elemwise_add1"
+        )
         target = topi.cpp.TEST_create_target(device)
-        add_schedule0 = topi.cpp.cuda.schedule_injective(
-            target, [elemwise_add0])
+        add_schedule0 = topi.cpp.cuda.schedule_injective(target, [elemwise_add0])
         lower_add0 = tvm.lower(
-            add_schedule0, [tensor_a, tensor_b, elemwise_add0],
-            name="elemwise_add0")
-        add_schedule1 = topi.cpp.cuda.schedule_injective(
-            target, [elemwise_add1])
+            add_schedule0, [tensor_a, tensor_b, elemwise_add0], name="elemwise_add0"
+        )
+        add_schedule1 = topi.cpp.cuda.schedule_injective(target, [elemwise_add1])
         lower_add1 = tvm.lower(
-            add_schedule1, [tensor_d, copy_sub_add, elemwise_add1],
-            name="elemwise_add1")
+            add_schedule1, [tensor_d, copy_sub_add, elemwise_add1], name="elemwise_add1"
+        )
         # Create module for sub whose target is the host.
         tensor_c = te.placeholder(shape, name="C")
-        elemwise_sub = te.compute(shape, lambda *i: copy_add_sub(*i)
-                                  - tensor_c(*i), name="elemwise_sub")
+        elemwise_sub = te.compute(
+            shape, lambda *i: copy_add_sub(*i) - tensor_c(*i), name="elemwise_sub"
+        )
         sub_schedule = te.create_schedule(elemwise_sub.op)
-        lower_sub = tvm.lower(sub_schedule, [copy_add_sub, tensor_c,
-                                             elemwise_sub],
-                              name="elemwise_sub")
+        lower_sub = tvm.lower(
+            sub_schedule, [copy_add_sub, tensor_c, elemwise_sub], name="elemwise_sub"
+        )
 
         lower_add0.update(lower_add1)
-        target_flist = {target_device: lower_add0, target_host:
-                        lower_sub}
+        target_flist = {target_device: lower_add0, target_host: lower_sub}
         mhost = tvm.build(target_flist, target_host=target_host)
         ctx = [host_ctx, device_ctx]
         params = {}
-        params["A"] = tensor_a = np.random.uniform(
-            size=shape).astype(tensor_a.dtype)
-        params["B"] = tensor_b = np.random.uniform(
-            size=shape).astype(tensor_b.dtype)
-        params["C"] = tensor_c = np.random.uniform(
-            size=shape).astype(tensor_c.dtype)
-        params["D"] = tensor_d = np.random.uniform(
-            size=shape).astype(tensor_d.dtype)
+        params["A"] = tensor_a = np.random.uniform(size=shape).astype(tensor_a.dtype)
+        params["B"] = tensor_b = np.random.uniform(size=shape).astype(tensor_b.dtype)
+        params["C"] = tensor_c = np.random.uniform(size=shape).astype(tensor_c.dtype)
+        params["D"] = tensor_d = np.random.uniform(size=shape).astype(tensor_d.dtype)
 
         def check_verify():
             mod = graph_runtime.create(graph, mhost, ctx)
             mod.set_input(**params)
             mod.run()
             out = mod.get_output(0, tvm.nd.empty(shape))
-            np.testing.assert_equal(
-                out.asnumpy(), tensor_a + tensor_b - tensor_c + tensor_d)
+            np.testing.assert_equal(out.asnumpy(), tensor_a + tensor_b - tensor_c + tensor_d)
 
         def check_load_module():
             temp = util.tempdir()
@@ -390,8 +426,7 @@ def test_duplex_data_transferring():
             mod.set_input(**params)
             mod.run()
             out = mod.get_output(0, tvm.nd.empty(shape))
-            np.testing.assert_equal(
-                out.asnumpy(), tensor_a + tensor_b - tensor_c + tensor_d)
+            np.testing.assert_equal(out.asnumpy(), tensor_a + tensor_b - tensor_c + tensor_d)
 
         check_verify()
         check_load_module()
index 25361a1..77f32a0 100644
@@ -33,13 +33,12 @@ def test_min_repeat_ms():
         with open(filename, "a") as fout:
             fout.write("c")
 
-    X = te.compute((), lambda : tvm.tir.call_packed("my_debug", filename))
+    X = te.compute((), lambda: tvm.tir.call_packed("my_debug", filename))
     s = te.create_schedule(X.op)
     func = tvm.build(s, [X])
 
     x = tvm.nd.empty((), dtype="int32")
-    ftimer = func.time_evaluator(func.entry_name, tvm.cpu(),
-                                 number=1, repeat=1)
+    ftimer = func.time_evaluator(func.entry_name, tvm.cpu(), number=1, repeat=1)
     ftimer(x)
 
     with open(filename, "r") as fin:
@@ -47,9 +46,7 @@ def test_min_repeat_ms():
 
     assert ct == 2
 
-
-    ftimer = func.time_evaluator(func.entry_name, tvm.cpu(),
-                                 number=1, repeat=1, min_repeat_ms=1000)
+    ftimer = func.time_evaluator(func.entry_name, tvm.cpu(), number=1, repeat=1, min_repeat_ms=1000)
     ftimer(x)
 
     # make sure we get more than 10 calls
@@ -61,4 +58,3 @@ def test_min_repeat_ms():
 
 if __name__ == "__main__":
     test_min_repeat_ms()
-
index 841bffb..45ec9bc 100644
@@ -27,7 +27,8 @@ from tvm.micro import create_micro_mod
 # # Use the host emulated micro device.
 DEV_CONFIG_A = micro.device.host.generate_config()
 DEV_CONFIG_B = micro.device.host.generate_config()
-TARGET = 'c --runtime=c'
+TARGET = "c --runtime=c"
+
 
 def relay_micro_build(func, dev_config, params=None):
     """Create a graph runtime module with a micro device context from a Relay function.
@@ -48,9 +49,9 @@ def relay_micro_build(func, dev_config, params=None):
     mod : tvm.runtime.Module
         graph runtime module for the target device
     """
-    with tvm.transform.PassContext(disabled_pass={'FuseOps'}, config={
-        "tir.disable_vectorize": True
-    }):
+    with tvm.transform.PassContext(
+        disabled_pass={"FuseOps"}, config={"tir.disable_vectorize": True}
+    ):
         graph, c_mod, params = relay.build(func, target=TARGET, params=params)
     micro_mod = micro.create_micro_mod(c_mod, dev_config)
     ctx = tvm.micro_dev(0)
@@ -68,11 +69,11 @@ break UTVMDone
 
 
 def reset_gdbinit():
-    if 'server_port' not in DEV_CONFIG_A:
+    if "server_port" not in DEV_CONFIG_A:
         return
-    gdb_init_dir = os.environ['MICRO_GDB_INIT_DIR']
-    with open(f'{gdb_init_dir}/.gdbinit', 'w') as f:
-        gdb_port = DEV_CONFIG_A['server_port'] - 3333
+    gdb_init_dir = os.environ["MICRO_GDB_INIT_DIR"]
+    with open(f"{gdb_init_dir}/.gdbinit", "w") as f:
+        gdb_port = DEV_CONFIG_A["server_port"] - 3333
         f.write(GDB_INIT_TEMPLATE.format(gdb_port=gdb_port))
 
 
@@ -121,13 +122,10 @@ def test_add():
         micro_func(a, b, c)
 
         # ensure inputs weren't corrupted
-        tvm.testing.assert_allclose(
-                a.asnumpy(), a_np)
-        tvm.testing.assert_allclose(
-                b.asnumpy(), b_np)
+        tvm.testing.assert_allclose(a.asnumpy(), a_np)
+        tvm.testing.assert_allclose(b.asnumpy(), b_np)
         # ensure output is correct
-        tvm.testing.assert_allclose(
-                c.asnumpy(), a.asnumpy() + b.asnumpy())
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
 
 def test_workspace_add():
@@ -160,11 +158,9 @@ def test_workspace_add():
         micro_func(a, c)
 
         # ensure input wasn't corrupted
-        tvm.testing.assert_allclose(
-                a.asnumpy(), a_np)
+        tvm.testing.assert_allclose(a.asnumpy(), a_np)
         # ensure output is correct
-        tvm.testing.assert_allclose(
-                c.asnumpy(), a.asnumpy() + 2.0)
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 2.0)
 
 
 def test_graph_runtime():
@@ -187,10 +183,8 @@ def test_graph_runtime():
         mod.run(x=x_in)
         result = mod.get_output(0).asnumpy()
 
-        tvm.testing.assert_allclose(
-                mod.get_input(0).asnumpy(), x_in)
-        tvm.testing.assert_allclose(
-                result, x_in * x_in + 1.0)
+        tvm.testing.assert_allclose(mod.get_input(0).asnumpy(), x_in)
+        tvm.testing.assert_allclose(result, x_in * x_in + 1.0)
 
 
 def test_conv2d():
@@ -201,29 +195,23 @@ def test_conv2d():
     from tvm.relay import transform
 
     dshape = (1, 4, 16, 16)
-    dtype = 'int8'
-    func_name = 'fused_nn_conv2d'
+    dtype = "int8"
+    func_name = "fused_nn_conv2d"
 
     reset_gdbinit()
 
     # Construct Relay program.
     x = relay.var("x", shape=dshape, dtype=dtype)
-    conv_expr = relay.nn.conv2d(
-            x, relay.var("w"),
-            kernel_size=(3, 3),
-            padding=(1, 1),
-            channels=4)
+    conv_expr = relay.nn.conv2d(x, relay.var("w"), kernel_size=(3, 3), padding=(1, 1), channels=4)
     func = relay.Function(relay.analysis.free_vars(conv_expr), conv_expr)
     mod = tvm.IRModule.from_expr(func)
     mod = transform.InferType()(mod)
 
-    x_shape = list(map(lambda x: x.value, mod['main'].params[0].checked_type.shape))
-    w_shape = list(map(lambda x: x.value, mod['main'].params[1].checked_type.shape))
-    out_shape = list(map(lambda x: x.value, mod['main'].ret_type.shape))
+    x_shape = list(map(lambda x: x.value, mod["main"].params[0].checked_type.shape))
+    w_shape = list(map(lambda x: x.value, mod["main"].params[1].checked_type.shape))
+    out_shape = list(map(lambda x: x.value, mod["main"].ret_type.shape))
 
-    with tvm.transform.PassContext(config={
-        "tir.disable_vectorize": True
-    }):
+    with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
         graph, c_mod, params = relay.build(mod, target="c")
 
     with micro.Session(DEV_CONFIG_A):
@@ -234,7 +222,7 @@ def test_conv2d():
                 micro_func = micro_mod[candidate_func_name]
                 break
             except tvm.TVMError as e:
-                candidate_func_name = f'{func_name}_{i}'
+                candidate_func_name = f"{func_name}_{i}"
         else:
             assert False
         ctx = tvm.micro_dev(0)
@@ -245,9 +233,9 @@ def test_conv2d():
         micro_func(x_data, w_data, result)
 
         out_data = np.zeros(out_shape, dtype=dtype)
-        params = { 'x': x_data.asnumpy(), 'w': w_data.asnumpy() }
-        intrp = create_executor('debug')
-        expected_result = intrp.evaluate(mod['main'])(x_data, w_data)
+        params = {"x": x_data.asnumpy(), "w": w_data.asnumpy()}
+        intrp = create_executor("debug")
+        expected_result = intrp.evaluate(mod["main"])(x_data, w_data)
 
         tvm.testing.assert_allclose(result.asnumpy(), expected_result.asnumpy())
 
@@ -276,14 +264,12 @@ def test_interleave_sessions():
         add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A)
         add_const_mod.run(x=micro_tensor_a)
         add_result = add_const_mod.get_output(0).asnumpy()
-        tvm.testing.assert_allclose(
-                add_result, np_tensor_a + 1.0)
+        tvm.testing.assert_allclose(add_result, np_tensor_a + 1.0)
     with sess_b:
         add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_B)
         add_const_mod.run(x=micro_tensor_b)
         add_result = add_const_mod.get_output(0).asnumpy()
-        tvm.testing.assert_allclose(
-                add_result, np_tensor_b + 1.0)
+        tvm.testing.assert_allclose(add_result, np_tensor_b + 1.0)
 
 
 def test_nested_sessions():
@@ -309,8 +295,7 @@ def test_nested_sessions():
         add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A)
         add_const_mod.run(x=micro_tensor_a)
         add_result = add_const_mod.get_output(0).asnumpy()
-        tvm.testing.assert_allclose(
-                add_result, np_tensor_a + 1.0)
+        tvm.testing.assert_allclose(add_result, np_tensor_a + 1.0)
 
 
 def test_inactive_session_use():
@@ -336,8 +321,7 @@ def test_inactive_session_use():
         # These objects belong to `sess_a`.
         add_const_mod.run(x=micro_tensor_a)
         add_result = add_const_mod.get_output(0).asnumpy()
-        tvm.testing.assert_allclose(
-                add_result, np_tensor_a + 1.0)
+        tvm.testing.assert_allclose(add_result, np_tensor_a + 1.0)
 
 
 # TODO add workspace alloc/free stress test
@@ -345,33 +329,33 @@ def test_inactive_session_use():
 if __name__ == "__main__":
     test_alloc()
     print()
-    print('finished alloc test')
-    input('[press enter to continue]')
+    print("finished alloc test")
+    input("[press enter to continue]")
     test_add()
     print()
-    print('finished add test')
-    input('[press enter to continue]')
+    print("finished add test")
+    input("[press enter to continue]")
     test_workspace_add()
     print()
-    print('finished workspace add test')
-    input('[press enter to continue]')
+    print("finished workspace add test")
+    input("[press enter to continue]")
     test_graph_runtime()
     print()
-    print('finished graph runtime test')
-    input('[press enter to continue]')
+    print("finished graph runtime test")
+    input("[press enter to continue]")
     test_conv2d()
     print()
-    print('finished conv2d test')
-    input('[press enter to continue]')
+    print("finished conv2d test")
+    input("[press enter to continue]")
     test_interleave_sessions()
     print()
-    print('finished interleaved sessions test')
-    input('[press enter to continue]')
+    print("finished interleaved sessions test")
+    input("[press enter to continue]")
     test_nested_sessions()
     print()
-    print('finished nested sessions test')
-    input('[press enter to continue]')
+    print("finished nested sessions test")
+    input("[press enter to continue]")
     test_inactive_session_use()
     print()
-    print('finished use inactive session test')
-    input('[press enter to continue]')
+    print("finished use inactive session test")
+    input("[press enter to continue]")
index 512fefd..1d682d2 100644
@@ -22,9 +22,11 @@ from tvm.contrib import graph_runtime
 from tvm.contrib.debugger import debug_runtime
 import tvm.testing
 
+
 def input_shape(mod):
     return [int(x) for x in mod["main"].checked_type.arg_types[0].shape]
 
+
 def verify(data):
     if not tvm.runtime.enabled("llvm"):
         print("Skip because llvm is not enabled")
@@ -42,6 +44,7 @@ def verify(data):
 
     return out
 
+
 def test_legacy_compatibility():
     if not tvm.testing.device_enabled("llvm"):
         print("Skip because llvm is not enabled")
@@ -58,6 +61,7 @@ def test_legacy_compatibility():
     out = module.get_output(0).asnumpy()
     tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
 
+
 def test_cpu():
     if not tvm.testing.device_enabled("llvm"):
         print("Skip because llvm is not enabled")
@@ -68,7 +72,7 @@ def test_cpu():
     data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32")
     # raw api
     ctx = tvm.cpu()
-    gmod = complied_graph_lib['default'](ctx)
+    gmod = complied_graph_lib["default"](ctx)
     set_input = gmod["set_input"]
     run = gmod["run"]
     get_output = gmod["get_output"]
@@ -78,12 +82,13 @@ def test_cpu():
     tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
 
     # graph runtime wrapper
-    gmod = graph_runtime.GraphModule(complied_graph_lib['default'](ctx))
+    gmod = graph_runtime.GraphModule(complied_graph_lib["default"](ctx))
     gmod.set_input("data", data)
     gmod.run()
     out = gmod.get_output(0).asnumpy()
     tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
 
+
 @tvm.testing.requires_cuda
 @tvm.testing.requires_gpu
 def test_gpu():
@@ -94,7 +99,7 @@ def test_gpu():
     ctx = tvm.gpu()
 
     # raw api
-    gmod = complied_graph_lib['default'](ctx)
+    gmod = complied_graph_lib["default"](ctx)
     set_input = gmod["set_input"]
     run = gmod["run"]
     get_output = gmod["get_output"]
@@ -104,12 +109,13 @@ def test_gpu():
     tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
 
     # graph runtime wrapper
-    gmod = graph_runtime.GraphModule(complied_graph_lib['default'](ctx))
+    gmod = graph_runtime.GraphModule(complied_graph_lib["default"](ctx))
     gmod.set_input("data", data)
     gmod.run()
     out = gmod.get_output(0).asnumpy()
     tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
 
+
 @tvm.testing.uses_gpu
 def test_mod_export():
     def verify_cpu_export(obj_format):
@@ -121,6 +127,7 @@ def test_mod_export():
             complied_graph_lib = relay.build_module.build(mod, "llvm", params=params)
 
         from tvm.contrib import util
+
         temp = util.tempdir()
         if obj_format == ".so":
             file_name = "deploy_lib.so"
@@ -131,7 +138,7 @@ def test_mod_export():
         complied_graph_lib.export_library(path_lib)
         loaded_lib = tvm.runtime.load_module(path_lib)
         ctx = tvm.cpu(0)
-        gmod = loaded_lib['default'](ctx)
+        gmod = loaded_lib["default"](ctx)
 
         # raw api
         set_input = gmod["set_input"]
@@ -144,7 +151,7 @@ def test_mod_export():
         tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
 
         # graph runtime wrapper
-        gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx))
+        gmod = graph_runtime.GraphModule(loaded_lib["default"](ctx))
         gmod.set_input("data", data)
         gmod.run()
         out = gmod.get_output(0).asnumpy()
@@ -159,6 +166,7 @@ def test_mod_export():
             complied_graph_lib = relay.build_module.build(mod, "cuda", params=params)
 
         from tvm.contrib import util
+
         temp = util.tempdir()
         if obj_format == ".so":
             file_name = "deploy_lib.so"
@@ -172,7 +180,7 @@ def test_mod_export():
         ctx = tvm.gpu()
 
         # raw api
-        gmod = loaded_lib['default'](ctx)
+        gmod = loaded_lib["default"](ctx)
         set_input = gmod["set_input"]
         run = gmod["run"]
         get_output = gmod["get_output"]
@@ -182,7 +190,7 @@ def test_mod_export():
         tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
 
         # graph runtime wrapper
-        gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx))
+        gmod = graph_runtime.GraphModule(loaded_lib["default"](ctx))
         gmod.set_input("data", data)
         gmod.run()
         out = gmod.get_output(0).asnumpy()
@@ -197,6 +205,7 @@ def test_mod_export():
             complied_graph_lib = relay.build_module.build(mod, "llvm", params=params)
 
         from tvm.contrib import util
+
         temp = util.tempdir()
         if obj_format == ".so":
             file_name = "deploy_lib.so"
@@ -207,6 +216,7 @@ def test_mod_export():
         complied_graph_lib.export_library(path_lib)
 
         from tvm import rpc
+
         remote = rpc.LocalSession()
         remote.upload(path_lib)
         loaded_lib = remote.load_module(path_lib)
@@ -214,7 +224,7 @@ def test_mod_export():
         ctx = remote.cpu()
 
         # raw api
-        gmod = loaded_lib['default'](ctx)
+        gmod = loaded_lib["default"](ctx)
         set_input = gmod["set_input"]
         run = gmod["run"]
         get_output = gmod["get_output"]
@@ -224,7 +234,7 @@ def test_mod_export():
         tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
 
         # graph runtime wrapper
-        gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx))
+        gmod = graph_runtime.GraphModule(loaded_lib["default"](ctx))
         gmod.set_input("data", data)
         gmod.run()
         out = gmod.get_output(0).asnumpy()
@@ -239,6 +249,7 @@ def test_mod_export():
             complied_graph_lib = relay.build_module.build(mod, "cuda", params=params)
 
         from tvm.contrib import util
+
         temp = util.tempdir()
         if obj_format == ".so":
             file_name = "deploy_lib.so"
@@ -249,6 +260,7 @@ def test_mod_export():
         complied_graph_lib.export_library(path_lib)
 
         from tvm import rpc
+
         server = rpc.Server("localhost", use_popen=True, port=9094)
         remote = rpc.connect(server.host, server.port)
         remote.upload(path_lib)
@@ -257,7 +269,7 @@ def test_mod_export():
         ctx = remote.gpu()
 
         # raw api
-        gmod = loaded_lib['default'](ctx)
+        gmod = loaded_lib["default"](ctx)
         set_input = gmod["set_input"]
         run = gmod["run"]
         get_output = gmod["get_output"]
@@ -267,7 +279,7 @@ def test_mod_export():
         tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
 
         # graph runtime wrapper
-        gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx))
+        gmod = graph_runtime.GraphModule(loaded_lib["default"](ctx))
         gmod.set_input("data", data)
         gmod.run()
         out = gmod.get_output(0).asnumpy()
@@ -279,6 +291,7 @@ def test_mod_export():
         verify_rpc_cpu_export(obj_format)
         verify_rpc_gpu_export(obj_format)
 
+
 @tvm.testing.uses_gpu
 def test_remove_package_params():
     def verify_cpu_remove_package_params(obj_format):
@@ -290,6 +303,7 @@ def test_remove_package_params():
             complied_graph_lib = relay.build_module.build(mod, "llvm", params=params)
 
         from tvm.contrib import util
+
         temp = util.tempdir()
         if obj_format == ".so":
             file_name = "deploy_lib.so"
@@ -306,7 +320,7 @@ def test_remove_package_params():
         ctx = tvm.cpu(0)
 
         # raw api
-        gmod = loaded_lib['default'](ctx)
+        gmod = loaded_lib["default"](ctx)
         set_input = gmod["set_input"]
         run = gmod["run"]
         get_output = gmod["get_output"]
@@ -319,7 +333,7 @@ def test_remove_package_params():
         tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
 
         # graph runtime wrapper
-        gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx))
+        gmod = graph_runtime.GraphModule(loaded_lib["default"](ctx))
         loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read())
         gmod.set_input("data", data)
         gmod.load_params(loaded_params)
@@ -336,6 +350,7 @@ def test_remove_package_params():
             complied_graph_lib = relay.build_module.build(mod, "cuda", params=params)
 
         from tvm.contrib import util
+
         temp = util.tempdir()
         if obj_format == ".so":
             file_name = "deploy_lib.so"
@@ -352,7 +367,7 @@ def test_remove_package_params():
         ctx = tvm.gpu(0)
 
         # raw api
-        gmod = loaded_lib['default'](ctx)
+        gmod = loaded_lib["default"](ctx)
         set_input = gmod["set_input"]
         run = gmod["run"]
         get_output = gmod["get_output"]
@@ -365,7 +380,7 @@ def test_remove_package_params():
         tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
 
         # graph runtime wrapper
-        gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx))
+        gmod = graph_runtime.GraphModule(loaded_lib["default"](ctx))
         loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read())
         gmod.set_input("data", data)
         gmod.load_params(loaded_params)
@@ -382,6 +397,7 @@ def test_remove_package_params():
             complied_graph_lib = relay.build_module.build(mod, "llvm", params=params)
 
         from tvm.contrib import util
+
         temp = util.tempdir()
         if obj_format == ".so":
             file_name = "deploy_lib.so"
@@ -396,6 +412,7 @@ def test_remove_package_params():
             fo.write(relay.save_param_dict(complied_graph_lib.get_params()))
 
         from tvm import rpc
+
         remote = rpc.LocalSession()
         remote.upload(path_lib)
         loaded_lib = remote.load_module(path_lib)
@@ -403,7 +420,7 @@ def test_remove_package_params():
         ctx = remote.cpu()
 
         # raw api
-        gmod = loaded_lib['default'](ctx)
+        gmod = loaded_lib["default"](ctx)
         set_input = gmod["set_input"]
         run = gmod["run"]
         get_output = gmod["get_output"]
@@ -416,7 +433,7 @@ def test_remove_package_params():
         tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
 
         # graph runtime wrapper
-        gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx))
+        gmod = graph_runtime.GraphModule(loaded_lib["default"](ctx))
         loaded_params = bytearray(open(path_params, "rb").read())
         gmod.set_input("data", data)
         gmod.load_params(loaded_params)
@@ -433,6 +450,7 @@ def test_remove_package_params():
             complied_graph_lib = relay.build_module.build(mod, "cuda", params=params)
 
         from tvm.contrib import util
+
         temp = util.tempdir()
         if obj_format == ".so":
             file_name = "deploy_lib.so"
@@ -447,6 +465,7 @@ def test_remove_package_params():
             fo.write(relay.save_param_dict(complied_graph_lib.get_params()))
 
         from tvm import rpc
+
         remote = rpc.LocalSession()
         remote.upload(path_lib)
         loaded_lib = remote.load_module(path_lib)
@@ -454,7 +473,7 @@ def test_remove_package_params():
         ctx = remote.gpu()
 
         # raw api
-        gmod = loaded_lib['default'](ctx)
+        gmod = loaded_lib["default"](ctx)
         set_input = gmod["set_input"]
         run = gmod["run"]
         get_output = gmod["get_output"]
@@ -467,7 +486,7 @@ def test_remove_package_params():
         tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
 
         # graph runtime wrapper
-        gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx))
+        gmod = graph_runtime.GraphModule(loaded_lib["default"](ctx))
         loaded_params = bytearray(open(path_params, "rb").read())
         gmod.set_input("data", data)
         gmod.load_params(loaded_params)
@@ -481,6 +500,7 @@ def test_remove_package_params():
         verify_rpc_cpu_remove_package_params(obj_format)
         verify_rpc_gpu_remove_package_params(obj_format)
 
+
 def test_debug_graph_runtime():
     if not tvm.testing.device_enabled("llvm"):
         print("Skip because llvm is not enabled")
@@ -493,7 +513,7 @@ def test_debug_graph_runtime():
     # raw api
     ctx = tvm.cpu()
     try:
-        gmod = complied_graph_lib['debug_create']('default', ctx)
+        gmod = complied_graph_lib["debug_create"]("default", ctx)
     except:
         print("Skip because debug graph_runtime not enabled")
         return
@@ -506,13 +526,18 @@ def test_debug_graph_runtime():
     tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
 
     # debug graph runtime wrapper
-    debug_g_mod = debug_runtime.GraphModuleDebug(complied_graph_lib['debug_create']('default', ctx), [ctx],
-                                                 complied_graph_lib.get_json(), None)
+    debug_g_mod = debug_runtime.GraphModuleDebug(
+        complied_graph_lib["debug_create"]("default", ctx),
+        [ctx],
+        complied_graph_lib.get_json(),
+        None,
+    )
     debug_g_mod.set_input("data", data)
     debug_g_mod.run()
     out = debug_g_mod.get_output(0).asnumpy()
     tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
 
+
 if __name__ == "__main__":
     test_legacy_compatibility()
     test_cpu()
index bc5e7fb..fcdd906 100644
@@ -22,11 +22,12 @@ from tvm import te
 import tvm.testing
 
 from tvm.contrib import util
+
 header_file_dir_path = util.tempdir()
 
 
 def gen_engine_header():
-    code = r'''
+    code = r"""
         #ifndef _ENGINE_H_
         #define _ENGINE_H_
         #include <cstdint>
@@ -37,14 +38,14 @@ def gen_engine_header():
         };
 
         #endif
-        '''
+        """
     header_file = header_file_dir_path.relpath("gcc_engine.h")
-    with open(header_file, 'w') as f:
+    with open(header_file, "w") as f:
         f.write(code)
 
 
 def generate_engine_module():
-    code = r'''
+    code = r"""
         #include <tvm/runtime/c_runtime_api.h>
         #include <dlpack/dlpack.h>
         #include "gcc_engine.h"
@@ -53,11 +54,11 @@ def generate_engine_module():
                 float* gcc_input6, float* gcc_input7, float* out) {
             Engine engine;
         }
-        '''
+        """
     import tvm.runtime._ffi_api
+
     gen_engine_header()
-    csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc", "",
-                                                              None)
+    csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc", "", None)
     return csource_module
 
 
@@ -72,10 +73,15 @@ def test_mod_export():
         synthetic_mod, synthetic_params = relay.testing.synthetic.get_workload()
         synthetic_llvm_mod, synthetic_llvm_params = relay.testing.synthetic.get_workload()
         with tvm.transform.PassContext(opt_level=3):
-            _, synthetic_gpu_lib, _ = relay.build_module.build(synthetic_mod, "cuda", params=synthetic_params)
-            _, synthetic_llvm_cpu_lib, _ = relay.build_module.build(synthetic_llvm_mod, "llvm", params=synthetic_llvm_params)
+            _, synthetic_gpu_lib, _ = relay.build_module.build(
+                synthetic_mod, "cuda", params=synthetic_params
+            )
+            _, synthetic_llvm_cpu_lib, _ = relay.build_module.build(
+                synthetic_llvm_mod, "llvm", params=synthetic_llvm_params
+            )
 
         from tvm.contrib import util
+
         temp = util.tempdir()
         if obj_format == ".so":
             file_name = "deploy_lib.so"
@@ -98,13 +104,16 @@ def test_mod_export():
 
         synthetic_mod, synthetic_params = relay.testing.synthetic.get_workload()
         with tvm.transform.PassContext(opt_level=3):
-            _, synthetic_cpu_lib, _ = relay.build_module.build(synthetic_mod, "llvm", params=synthetic_params)
+            _, synthetic_cpu_lib, _ = relay.build_module.build(
+                synthetic_mod, "llvm", params=synthetic_params
+            )
 
-        A = te.placeholder((1024,), name='A')
-        B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+        A = te.placeholder((1024,), name="A")
+        B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
         s = te.create_schedule(B.op)
         f = tvm.build(s, [A, B], "llvm", name="myadd")
         from tvm.contrib import util
+
         temp = util.tempdir()
         if obj_format == ".so":
             file_name = "deploy_lib.so"
@@ -125,32 +134,35 @@ def test_mod_export():
                 return
 
         # Get subgraph Json.
-        subgraph_json = ("json_rt_0\n" +
-                         "input 0 10 10\n" +
-                         "input 1 10 10\n" +
-                         "input 2 10 10\n" +
-                         "input 3 10 10\n" +
-                         "add 4 inputs: 0 1 shape: 10 10\n" +
-                         "sub 5 inputs: 4 2 shape: 10 10\n" +
-                         "mul 6 inputs: 5 3 shape: 10 10\n" +
-                         "json_rt_1\n" +
-                         "input 0 10 10\n" +
-                         "input 1 10 10\n" +
-                         "input 2 10 10\n" +
-                         "input 3 10 10\n" +
-                         "add 4 inputs: 0 1 shape: 10 10\n" +
-                         "sub 5 inputs: 4 2 shape: 10 10\n" +
-                         "mul 6 inputs: 5 3 shape: 10 10")
+        subgraph_json = (
+            "json_rt_0\n"
+            + "input 0 10 10\n"
+            + "input 1 10 10\n"
+            + "input 2 10 10\n"
+            + "input 3 10 10\n"
+            + "add 4 inputs: 0 1 shape: 10 10\n"
+            + "sub 5 inputs: 4 2 shape: 10 10\n"
+            + "mul 6 inputs: 5 3 shape: 10 10\n"
+            + "json_rt_1\n"
+            + "input 0 10 10\n"
+            + "input 1 10 10\n"
+            + "input 2 10 10\n"
+            + "input 3 10 10\n"
+            + "add 4 inputs: 0 1 shape: 10 10\n"
+            + "sub 5 inputs: 4 2 shape: 10 10\n"
+            + "mul 6 inputs: 5 3 shape: 10 10"
+        )
 
         from tvm.contrib import util
+
         temp = util.tempdir()
-        subgraph_path = temp.relpath('subgraph.examplejson')
-        with open(subgraph_path, 'w') as f:
+        subgraph_path = temp.relpath("subgraph.examplejson")
+        with open(subgraph_path, "w") as f:
             f.write(subgraph_json)
 
         # Get Json and module.
-        A = te.placeholder((1024,), name='A')
-        B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+        A = te.placeholder((1024,), name="A")
+        B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
         s = te.create_schedule(B.op)
         f = tvm.build(s, [A, B], "llvm", name="myadd")
         try:
@@ -172,6 +184,7 @@ def test_mod_export():
 
     def verify_multi_c_mod_export():
         from shutil import which
+
         if which("gcc") is None:
             print("Skip test because gcc is not available.")
 
@@ -182,14 +195,17 @@ def test_mod_export():
 
         synthetic_mod, synthetic_params = relay.testing.synthetic.get_workload()
         with tvm.transform.PassContext(opt_level=3):
-            _, synthetic_cpu_lib, _ = relay.build_module.build(synthetic_mod, "llvm", params=synthetic_params)
+            _, synthetic_cpu_lib, _ = relay.build_module.build(
+                synthetic_mod, "llvm", params=synthetic_params
+            )
 
-        A = te.placeholder((1024,), name='A')
-        B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+        A = te.placeholder((1024,), name="A")
+        B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
         s = te.create_schedule(B.op)
         f = tvm.build(s, [A, B], "c", name="myadd")
         engine_module = generate_engine_module()
         from tvm.contrib import util
+
         temp = util.tempdir()
         file_name = "deploy_lib.so"
         path_lib = temp.relpath(file_name)
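
The hunks above are purely mechanical: black normalizes string quotes to double quotes and wraps calls that exceed the configured line length. A minimal sketch of the same rewriting through black's Python API; the format_str/FileMode names and the 100-column limit are assumptions about the installed black and the new pyproject.toml, not details recorded in these hunks:

    import black

    src = "A = te.placeholder((1024,), name='A')\n"
    # line_length=100 mirrors the wrapping visible in the hunks above (assumed config).
    out = black.format_str(src, mode=black.FileMode(line_length=100))
    print(out)  # A = te.placeholder((1024,), name="A")
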
index 6e7df06..81aa2ba 100644
@@ -42,25 +42,23 @@ np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
 print("Finish runtime checking...")
 """
 
+
 def test_dso_module_load():
     if not tvm.testing.device_enabled("llvm"):
         return
-    dtype = 'int64'
+    dtype = "int64"
     temp = util.tempdir()
 
     def save_object(names):
-        n = te.size_var('n')
-        Ab = tvm.tir.decl_buffer((n, ), dtype)
-        i = te.var('i')
+        n = te.size_var("n")
+        Ab = tvm.tir.decl_buffer((n,), dtype)
+        i = te.var("i")
         # for i in 0 to n-1:
         stmt = tvm.tir.For(
-            i, 0, n - 1, 0, 0,
-            tvm.tir.Store(Ab.data,
-                           tvm.tir.Load(dtype, Ab.data, i) + 1,
-                           i + 1))
+            i, 0, n - 1, 0, 0, tvm.tir.Store(Ab.data, tvm.tir.Load(dtype, Ab.data, i) + 1, i + 1)
+        )
         mod = tvm.IRModule.from_expr(
-            tvm.tir.PrimFunc([Ab], stmt).with_attr(
-                "global_symbol", "main")
+            tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "main")
         )
         m = tvm.driver.build(mod, target="llvm")
         for name in names:
@@ -86,17 +84,15 @@ def test_dso_module_load():
     with open(path_runtime_py, "w") as fo:
         fo.write(runtime_py)
 
-    subprocess.check_call(
-        "python3 %s %s %s" % (path_runtime_py, path_dso, dtype),
-        shell=True)
+    subprocess.check_call("python3 %s %s %s" % (path_runtime_py, path_dso, dtype), shell=True)
 
 
 @tvm.testing.requires_gpu
 def test_device_module_dump():
     # graph
     n = tvm.runtime.convert(1024)
-    A = te.placeholder((n,), name='A')
-    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
     s = te.create_schedule(B.op)
     # create iter var and assign them tags.
     num_thread = 8
@@ -111,7 +107,7 @@ def test_device_module_dump():
             return
         temp = util.tempdir()
         name = "myadd_%s" % device
-        if sys.platform == "darwin" or sys.platform.startswith('linux'):
+        if sys.platform == "darwin" or sys.platform.startswith("linux"):
             f = tvm.build(s, [A, B], device, "llvm -system-lib", name=name)
         elif sys.platform == "win32":
             f = tvm.build(s, [A, B], device, "llvm", name=name)
@@ -152,19 +148,20 @@ def test_device_module_dump():
         check_device(device)
         check_stackvm(device)
 
+
 def test_combine_module_llvm():
     """Test combine multiple module into one shared lib."""
     # graph
     nn = 12
     n = tvm.runtime.convert(nn)
-    A = te.placeholder((n,), name='A')
-    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
     s = te.create_schedule(B.op)
 
     def check_llvm():
         ctx = tvm.cpu(0)
         if not tvm.testing.device_enabled("llvm"):
-            print("Skip because llvm is not enabled" )
+            print("Skip because llvm is not enabled")
             return
         temp = util.tempdir()
         fadd1 = tvm.build(s, [A, B], "llvm", name="myadd1")
@@ -177,8 +174,8 @@ def test_combine_module_llvm():
         # create shared library with multiple functions
         cc.create_shared(path_dso, [path1, path2])
         m = tvm.runtime.load_module(path_dso)
-        fadd1 = m['myadd1']
-        fadd2 = m['myadd2']
+        fadd1 = m["myadd1"]
+        fadd2 = m["myadd2"]
         a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
         b = tvm.nd.array(np.zeros(nn, dtype=A.dtype), ctx)
         fadd1(a, b)
@@ -189,7 +186,7 @@ def test_combine_module_llvm():
     def check_system_lib():
         ctx = tvm.cpu(0)
         if not tvm.testing.device_enabled("llvm"):
-            print("Skip because llvm is not enabled" )
+            print("Skip because llvm is not enabled")
             return
         temp = util.tempdir()
         fadd1 = tvm.build(s, [A, B], "llvm -system-lib", name="myadd1")
@@ -206,9 +203,9 @@ def test_combine_module_llvm():
         mm = tvm.runtime.system_lib()
         a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
         b = tvm.nd.array(np.zeros(nn, dtype=A.dtype), ctx)
-        mm['myadd1'](a, b)
+        mm["myadd1"](a, b)
         np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
-        mm['myadd2'](a, b)
+        mm["myadd2"](a, b)
         np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
 
     if sys.platform != "win32":
@@ -216,7 +213,6 @@ def test_combine_module_llvm():
     check_llvm()
 
 
-
 if __name__ == "__main__":
     test_combine_module_llvm()
     test_device_module_dump()
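
For readers skimming these test hunks, a minimal standalone sketch of the build/export/reload cycle they exercise, assuming a TVM build with the LLVM backend enabled (the kernel and file names here are illustrative):

    import numpy as np
    import tvm
    from tvm import te
    from tvm.contrib import cc, util

    # Build a trivial element-wise kernel.
    n = 16
    A = te.placeholder((n,), name="A")
    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
    s = te.create_schedule(B.op)
    fadd = tvm.build(s, [A, B], "llvm", name="myadd")

    # Save the object file, link it into a shared library, then reload it by path.
    temp = util.tempdir()
    obj_path = temp.relpath("myadd.o")
    lib_path = temp.relpath("deploy_lib.so")
    fadd.save(obj_path)
    cc.create_shared(lib_path, [obj_path])
    m = tvm.runtime.load_module(lib_path)

    # Call the reloaded function by name and check the result.
    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
    b = tvm.nd.array(np.zeros(n, dtype=A.dtype), ctx)
    m["myadd"](a, b)
    np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
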
index bda987d..0183ecd 100644
@@ -23,8 +23,7 @@ import tvm.testing
 @tvm.testing.uses_gpu
 def test_nd_create():
     for target, ctx in tvm.testing.enabled_targets():
-        for dtype in ["uint8", "int8", "uint16", "int16", "uint32", "int32",
-                      "float32"]:
+        for dtype in ["uint8", "int8", "uint16", "int16", "uint32", "int32", "float32"]:
             x = np.random.randint(0, 10, size=(3, 4))
             x = np.array(x, dtype=dtype)
             y = tvm.nd.array(x, ctx=ctx)
@@ -41,12 +40,12 @@ def test_nd_create():
 def test_fp16_conversion():
     n = 100
 
-    for (src, dst) in [('float32', 'float16'), ('float16', 'float32')]:
+    for (src, dst) in [("float32", "float16"), ("float16", "float32")]:
         A = te.placeholder((n,), dtype=src)
         B = te.compute((n,), lambda i: A[i].astype(dst))
 
         s = te.create_schedule([B.op])
-        func = tvm.build(s, [A, B], 'llvm')
+        func = tvm.build(s, [A, B], "llvm")
 
         x_tvm = tvm.nd.array(100 * np.random.randn(n).astype(src) - 50)
         y_tvm = tvm.nd.array(100 * np.random.randn(n).astype(dst) - 50)
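
A minimal sketch of the host round trip these checks rely on (the dtype and shape are arbitrary):

    import numpy as np
    import tvm

    x = np.random.randint(0, 10, size=(3, 4)).astype("int32")
    y = tvm.nd.array(x, ctx=tvm.cpu(0))      # copy host data into a TVM NDArray
    np.testing.assert_equal(y.asnumpy(), x)  # and copy it back out
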
index aab84a1..718fe03 100644
@@ -19,21 +19,25 @@ from tvm import te
 import tvm.testing
 import numpy as np
 
+
 def test_get_global():
     targs = (10, 10.0, "hello")
     # register into global function table
     @tvm.register_func
     def my_packed_func(*args):
-        assert(tuple(args) == targs)
+        assert tuple(args) == targs
         return 10
+
     # get it out from global function table
     f = tvm.get_global_func("my_packed_func")
     assert isinstance(f, tvm.runtime.PackedFunc)
     y = f(*targs)
     assert y == 10
 
+
 def test_get_callback_with_node():
     x = tvm.runtime.convert(10)
+
     def test(y):
         assert y.handle != x.handle
         return y
@@ -49,14 +53,16 @@ def test_get_callback_with_node():
     f = tvm.get_global_func("my_callback_with_node")
     assert isinstance(f, tvm.runtime.PackedFunc)
     y = f(x, f2)
-    assert(y.value == 10)
+    assert y.value == 10
 
 
 def test_return_func():
     def addy(y):
         def add(x):
             return tvm.runtime.convert(x + y)
+
         return add
+
     myf = tvm.runtime.convert(addy)
     f = myf(10)
     assert f(11).value == 21
@@ -65,18 +71,21 @@ def test_return_func():
 def test_convert():
     # convert a function to tvm function
     targs = (10, 10.0, "hello", 10)
+
     def myfunc(*args):
-        assert(tuple(args) == targs)
+        assert tuple(args) == targs
 
     f = tvm.runtime.convert(myfunc)
     assert isinstance(f, tvm.runtime.PackedFunc)
 
+
 def test_byte_array():
     s = "hello"
     a = bytearray(s, encoding="ascii")
 
     def myfunc(ss):
         assert ss == a
+
     f = tvm.runtime.convert(myfunc)
     f(a)
 
@@ -84,6 +93,7 @@ def test_byte_array():
 def test_empty_array():
     def myfunc(ss):
         assert tuple(ss) == ()
+
     x = tvm.runtime.convert(())
     tvm.runtime.convert(myfunc)(x)
 
@@ -92,6 +102,7 @@ def test_ctx():
     def test_ctx_func(ctx):
         assert tvm.gpu(7) == ctx
         return tvm.cpu(0)
+
     x = test_ctx_func(tvm.gpu(7))
     assert x == tvm.cpu(0)
     x = tvm.opencl(10)
@@ -127,14 +138,15 @@ def test_rvalue_ref():
 
 def test_trace_default_action():
     n = 2
-    x = te.placeholder((n,n,n), name="X", dtype="float32")
+    x = te.placeholder((n, n, n), name="X", dtype="float32")
     y = te.compute(x.shape, lambda i, j, k: tvm.tir.trace([i, j, k, x[i][j][k]]))
     s = te.create_schedule(y.op)
     f = tvm.build(s, [x, y], target="llvm")
-    xnd = tvm.nd.array(np.ones((n,n,n), dtype=x.dtype))
-    ynd = tvm.nd.array(np.zeros((n,n,n), dtype=y.dtype))
+    xnd = tvm.nd.array(np.ones((n, n, n), dtype=x.dtype))
+    ynd = tvm.nd.array(np.zeros((n, n, n), dtype=y.dtype))
     f(xnd, ynd)
 
+
 def test_trace_expr_assign():
     @tvm.register_func("tvm.tir.trace_callback2")
     def trace_buffer(x):
@@ -142,24 +154,29 @@ def test_trace_expr_assign():
 
     def check_assign(dtype):
         n = 4
-        x = te.placeholder((n,n,n), name="X", dtype=dtype)
-        y = te.compute(x.shape, lambda i, j, k: tvm.tir.trace([x[i][j][k]], "tvm.tir.trace_callback2"))
-        z = te.compute(x.shape, lambda i, j, k: tvm.tir.trace([y[i][j][k]], "tvm.tir.trace_callback2"))
+        x = te.placeholder((n, n, n), name="X", dtype=dtype)
+        y = te.compute(
+            x.shape, lambda i, j, k: tvm.tir.trace([x[i][j][k]], "tvm.tir.trace_callback2")
+        )
+        z = te.compute(
+            x.shape, lambda i, j, k: tvm.tir.trace([y[i][j][k]], "tvm.tir.trace_callback2")
+        )
         s = te.create_schedule(z.op)
         f = tvm.build(s, [x, y, z], "llvm")
 
-        xnd = tvm.nd.array(np.ones((n,n,n), dtype=x.dtype))
-        ynd = tvm.nd.array(np.zeros((n,n,n), dtype=y.dtype))
-        znd = tvm.nd.array(np.zeros((n,n,n), dtype=z.dtype))
+        xnd = tvm.nd.array(np.ones((n, n, n), dtype=x.dtype))
+        ynd = tvm.nd.array(np.zeros((n, n, n), dtype=y.dtype))
+        znd = tvm.nd.array(np.zeros((n, n, n), dtype=z.dtype))
         f(xnd, ynd, znd)
 
-        assert(np.array_equal(xnd.asnumpy(), np.ones((n,n,n))))
-        assert(np.array_equal(ynd.asnumpy(), np.ones((n,n,n))))
-        assert(np.array_equal(znd.asnumpy(), np.ones((n,n,n))))
+        assert np.array_equal(xnd.asnumpy(), np.ones((n, n, n)))
+        assert np.array_equal(ynd.asnumpy(), np.ones((n, n, n)))
+        assert np.array_equal(znd.asnumpy(), np.ones((n, n, n)))
 
     for t in ["float64", "float32", "int64", "int32"]:
         check_assign(t)
 
+
 def test_trace_expr_sum_generated():
     @tvm.register_func("tvm.tir.trace_callback3")
     def trace_buffer(x):
@@ -167,53 +184,60 @@ def test_trace_expr_sum_generated():
 
     def check_expr_sum(dtype):
         n = 4
-        a = te.placeholder((n,n,n), name="a", dtype=dtype)
-        b = te.placeholder((n,n,n), name="b", dtype=dtype)
-        c = te.compute(a.shape, lambda i, j, k: tvm.tir.trace([a[i][j][k]],"tvm.tir.trace_callback3")
-                                         + tvm.tir.trace([b[i][j][k]],"tvm.tir.trace_callback3"))
+        a = te.placeholder((n, n, n), name="a", dtype=dtype)
+        b = te.placeholder((n, n, n), name="b", dtype=dtype)
+        c = te.compute(
+            a.shape,
+            lambda i, j, k: tvm.tir.trace([a[i][j][k]], "tvm.tir.trace_callback3")
+            + tvm.tir.trace([b[i][j][k]], "tvm.tir.trace_callback3"),
+        )
         s = te.create_schedule(c.op)
         f = tvm.build(s, [a, b, c])
-        xnd = tvm.nd.array(np.array(np.ones((n,n,n), dtype=a.dtype)))
-        ynd = tvm.nd.array(np.array(np.ones((n,n,n), dtype=b.dtype)))
-        znd = tvm.nd.array(np.zeros((n,n,n), dtype=c.dtype))
+        xnd = tvm.nd.array(np.array(np.ones((n, n, n), dtype=a.dtype)))
+        ynd = tvm.nd.array(np.array(np.ones((n, n, n), dtype=b.dtype)))
+        znd = tvm.nd.array(np.zeros((n, n, n), dtype=c.dtype))
         f(xnd, ynd, znd)
-        assert(np.array_equal(znd.asnumpy(), xnd.asnumpy() + ynd.asnumpy()))
+        assert np.array_equal(znd.asnumpy(), xnd.asnumpy() + ynd.asnumpy())
 
     for t in ["float64", "float32", "int64", "int32"]:
         check_expr_sum(t)
 
+
 def test_trace_expr_sum_args():
     @tvm.register_func("tvm.tir.trace_silent")
     def silent(*args):
-      return
+        return
 
     def check_expr_sum(dtype):
         n = 4
-        a = te.placeholder((n,n,n), name="a", dtype=dtype)
-        b = te.placeholder((n,n,n), name="b", dtype=dtype)
-        e = te.placeholder((n,n,n), name="e", dtype=dtype)
-        d = te.placeholder((n,n,n), name="d", dtype=dtype)
-
-        c = te.compute(a.shape, lambda i, j, k: tvm.tir.trace([i, j, k, a[i][j][k]], "tvm.tir.trace_silent")
-                                               + tvm.tir.trace([i, j, k, b[i][j][k]], "tvm.tir.trace_silent")
-                                               + tvm.tir.trace([i, j, k, d[i][j][k]], "tvm.tir.trace_silent")
-                                               + tvm.tir.trace([i, j, k, e[i][j][k]], "tvm.tir.trace_silent"))
+        a = te.placeholder((n, n, n), name="a", dtype=dtype)
+        b = te.placeholder((n, n, n), name="b", dtype=dtype)
+        e = te.placeholder((n, n, n), name="e", dtype=dtype)
+        d = te.placeholder((n, n, n), name="d", dtype=dtype)
+
+        c = te.compute(
+            a.shape,
+            lambda i, j, k: tvm.tir.trace([i, j, k, a[i][j][k]], "tvm.tir.trace_silent")
+            + tvm.tir.trace([i, j, k, b[i][j][k]], "tvm.tir.trace_silent")
+            + tvm.tir.trace([i, j, k, d[i][j][k]], "tvm.tir.trace_silent")
+            + tvm.tir.trace([i, j, k, e[i][j][k]], "tvm.tir.trace_silent"),
+        )
         s = te.create_schedule(c.op)
         f = tvm.build(s, [a, b, d, e, c])
-        a_nd = tvm.nd.array(np.array(np.ones((n,n,n), dtype=a.dtype)))
-        b_nd = tvm.nd.array(np.array(np.ones((n,n,n), dtype=b.dtype)))
-        d_nd = tvm.nd.array(np.array(np.ones((n,n,n), dtype=d.dtype)))
-        e_nd = tvm.nd.array(np.array(np.ones((n,n,n), dtype=e.dtype)))
-        c_nd = tvm.nd.array(np.zeros((n,n,n), dtype=c.dtype))
+        a_nd = tvm.nd.array(np.array(np.ones((n, n, n), dtype=a.dtype)))
+        b_nd = tvm.nd.array(np.array(np.ones((n, n, n), dtype=b.dtype)))
+        d_nd = tvm.nd.array(np.array(np.ones((n, n, n), dtype=d.dtype)))
+        e_nd = tvm.nd.array(np.array(np.ones((n, n, n), dtype=e.dtype)))
+        c_nd = tvm.nd.array(np.zeros((n, n, n), dtype=c.dtype))
         f(a_nd, b_nd, d_nd, e_nd, c_nd)
-        assert(np.array_equal(c_nd.asnumpy(), a_nd.asnumpy()
-                                            + b_nd.asnumpy()
-                                            + d_nd.asnumpy()
-                                            + e_nd.asnumpy()))
+        assert np.array_equal(
+            c_nd.asnumpy(), a_nd.asnumpy() + b_nd.asnumpy() + d_nd.asnumpy() + e_nd.asnumpy()
+        )
 
     for t in ["float64", "float32", "int64", "int32"]:
         check_expr_sum(t)
 
+
 def test_trace_expr_sum_custom():
     @tvm.register_func("tvm.tir.trace_callback4")
     def trace_buffer(x):
@@ -221,23 +245,27 @@ def test_trace_expr_sum_custom():
 
     def check_expr_sum_custom(dtype):
         n = 4
-        a = te.placeholder((n,n), name="a", dtype=dtype)
-        b = te.placeholder((n,n), name="b", dtype=dtype)
-        c = te.compute(a.shape, lambda i,j: tvm.tir.trace([a[i][j]], "tvm.tir.trace_callback4")
-                                         + tvm.tir.trace([b[i][j]], "tvm.tir.trace_callback4"))
+        a = te.placeholder((n, n), name="a", dtype=dtype)
+        b = te.placeholder((n, n), name="b", dtype=dtype)
+        c = te.compute(
+            a.shape,
+            lambda i, j: tvm.tir.trace([a[i][j]], "tvm.tir.trace_callback4")
+            + tvm.tir.trace([b[i][j]], "tvm.tir.trace_callback4"),
+        )
         s = te.create_schedule(c.op)
         f = tvm.build(s, [a, b, c])
-        npa = np.array([[1,0,0,0], [0,1,0,0],[0,0,1,0],[0,0,0,1]], dtype=a.dtype)
-        npb = np.array([[1,0,0,0], [0,1,0,0],[0,0,1,0],[0,0,0,1]], dtype=a.dtype)
+        npa = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=a.dtype)
+        npb = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=a.dtype)
         xnd = tvm.nd.array(npa)
         ynd = tvm.nd.array(npb)
-        znd = tvm.nd.array(np.zeros((n,n), dtype=c.dtype))
+        znd = tvm.nd.array(np.zeros((n, n), dtype=c.dtype))
         f(xnd, ynd, znd)
-        assert(np.array_equal(znd.asnumpy(), npa + npb))
+        assert np.array_equal(znd.asnumpy(), npa + npb)
 
     for t in ["float64", "float32", "int64", "int32"]:
         check_expr_sum_custom(t)
 
+
 def test_trace_can_change_traced_value_int():
     @tvm.register_func("tvm.tir.trace_change_int_first")
     def trace_buffer(x):
@@ -261,12 +289,13 @@ def test_trace_can_change_traced_value_int():
         f(xnd, ynd, znd)
         check_array_first = np.array([13, 13, 13, 13])
         check_array_second = np.array([14, 14, 14, 14])
-        assert(np.array_equal(ynd.asnumpy(), check_array_first))
-        assert(np.array_equal(znd.asnumpy(), check_array_second))
+        assert np.array_equal(ynd.asnumpy(), check_array_first)
+        assert np.array_equal(znd.asnumpy(), check_array_second)
 
     for t in ["int64", "int32"]:
         check_assign(t)
 
+
 def test_trace_can_change_traced_value_float():
     @tvm.register_func("tvm.tir.trace_change_float_first")
     def trace_buffer(x):
@@ -280,7 +309,9 @@ def test_trace_can_change_traced_value_float():
         n = 4
         x = te.placeholder((n,), name="X", dtype=dtype)
         y = te.compute(x.shape, lambda i: tvm.tir.trace([x[i]], "tvm.tir.trace_change_float_first"))
-        z = te.compute(x.shape, lambda i: tvm.tir.trace([y[i]], "tvm.tir.trace_change_float_second"))
+        z = te.compute(
+            x.shape, lambda i: tvm.tir.trace([y[i]], "tvm.tir.trace_change_float_second")
+        )
         s = te.create_schedule(z.op)
         f = tvm.build(s, [x, y, z], "llvm")
 
@@ -290,14 +321,15 @@ def test_trace_can_change_traced_value_float():
         f(xnd, ynd, znd)
         check_array_first = np.array([13.0, 13.0, 13.0, 13.0])
         check_array_second = np.array([14.0, 14.0, 14.0, 14.0])
-        assert(np.array_equal(ynd.asnumpy(), check_array_first))
-        assert(np.array_equal(znd.asnumpy(), check_array_second))
+        assert np.array_equal(ynd.asnumpy(), check_array_first)
+        assert np.array_equal(znd.asnumpy(), check_array_second)
 
     for t in ["float64", "float32"]:
         check_assign(t)
 
+
 def test_numpy_scalar():
-    maxint = (1<<63) - 1
+    maxint = (1 << 63) - 1
     assert tvm.testing.echo(np.int64(maxint)) == maxint
 
 
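The pattern under test throughout this file is the global PackedFunc registry; a minimal sketch, using a hypothetical registry name:

    import tvm

    @tvm.register_func("demo.add_prefix")  # hypothetical name, not provided by TVM
    def add_prefix(name, x):
        return "%s:%d" % (name, x)

    # Look the function up from the global table and call it through PackedFunc.
    f = tvm.get_global_func("demo.add_prefix")
    assert isinstance(f, tvm.runtime.PackedFunc)
    assert f("value", 3) == "value:3"
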
index 50c753f..e106ead 100644
@@ -36,9 +36,10 @@ def test_bigendian_rpc():
     port = os.environ.get("TVM_POWERPC_TEST_PORT", 9090)
     if host is None:
         return
+
     def verify_rpc(remote, target, shape, dtype):
         A = te.placeholder(shape, dtype=dtype)
-        B = te.compute(A.shape, lambda i: A[i]+tvm.tir.const(1, A.dtype))
+        B = te.compute(A.shape, lambda i: A[i] + tvm.tir.const(1, A.dtype))
         s = te.create_schedule(B.op)
         f = tvm.build(s, [A, B], target, name="myadd")
 
@@ -63,9 +64,11 @@ def test_bigendian_rpc():
 def test_rpc_simple():
     if not tvm.runtime.enabled("rpc"):
         return
+
     @tvm.register_func("rpc.test.addone")
     def addone(x):
         return x + 1
+
     @tvm.register_func("rpc.test.strcat")
     def strcat(name, x):
         return "%s:%d" % (name, x)
@@ -90,6 +93,7 @@ def test_rpc_simple():
 def test_rpc_runtime_string():
     if not tvm.runtime.enabled("rpc"):
         return
+
     @tvm.register_func("rpc.test.runtime_str_concat")
     def strcat(x, y):
         return x + y
@@ -106,9 +110,11 @@ def test_rpc_array():
     if not tvm.runtime.enabled("rpc"):
         return
     x = np.random.randint(0, 10, size=(3, 4))
+
     @tvm.register_func("rpc.test.remote_array_func")
     def remote_array_func(y):
         np.testing.assert_equal(y.asnumpy(), x)
+
     server = rpc.Server("localhost")
     remote = rpc.connect(server.host, server.port)
     r_cpu = tvm.nd.array(x, remote.cpu(0))
@@ -123,8 +129,8 @@ def test_rpc_large_array():
     server = rpc.Server("localhost")
     remote = rpc.connect(server.host, server.port)
     ctx = remote.cpu(0)
-    a_np = np.ones((5041, 720)).astype('float32')
-    b_np = np.ones((720, 192)).astype('float32')
+    a_np = np.ones((5041, 720)).astype("float32")
+    b_np = np.ones((720, 192)).astype("float32")
     a = tvm.nd.array(a_np, ctx)
     b = tvm.nd.array(b_np, ctx)
     np.testing.assert_equal(a.asnumpy(), a_np)
@@ -134,21 +140,19 @@ def test_rpc_large_array():
 def test_rpc_echo():
     def check(remote):
         fecho = remote.get_function("testing.echo")
-        assert(fecho(1, 2, 3) == 1)
-        assert(fecho(100, 2, 3) == 100)
-        assert(fecho("xyz") == "xyz")
-        assert(bytes(fecho(bytearray(b"123"))) == b"123")
+        assert fecho(1, 2, 3) == 1
+        assert fecho(100, 2, 3) == 100
+        assert fecho("xyz") == "xyz"
+        assert bytes(fecho(bytearray(b"123"))) == b"123"
 
         with pytest.raises(RuntimeError):
-            raise_err = remote.get_function(
-                "testing.test_raise_error_callback")("RuntimeError")
+            raise_err = remote.get_function("testing.test_raise_error_callback")("RuntimeError")
             raise_err()
 
         remote.cpu().sync()
         with pytest.raises(AttributeError):
             f3 = remote.system_lib()["notexist"]
 
-
     temp = rpc.server._server_env([])
     server = rpc.Server("localhost")
     client = rpc.connect(server.host, server.port)
@@ -163,9 +167,10 @@ def test_rpc_echo():
     # minrpc on the remote
     server = rpc.Server("localhost")
     client = rpc.connect(
-        server.host, server.port,
-        session_constructor_args=["rpc.PopenSession",
-                             open(minrpc_exec, "rb").read()])
+        server.host,
+        server.port,
+        session_constructor_args=["rpc.PopenSession", open(minrpc_exec, "rb").read()],
+    )
     check(client)
 
 
@@ -177,7 +182,8 @@ def test_rpc_file_exchange():
     blob = bytearray(np.random.randint(0, 10, size=(10)))
     remote.upload(blob, "dat.bin")
     rev = remote.download("dat.bin")
-    assert(rev == blob)
+    assert rev == blob
+
 
 @tvm.testing.requires_llvm
 def test_rpc_remote_module():
@@ -185,17 +191,19 @@ def test_rpc_remote_module():
         return
     # graph
     n = tvm.runtime.convert(102)
-    A = te.placeholder((n,), name='A')
-    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
     s = te.create_schedule(B.op)
 
     server0 = rpc.Server("localhost", key="x0")
     server1 = rpc.Server("localhost", key="x1")
 
     client = rpc.connect(
-        server0.host, server0.port, key="x0",
-        session_constructor_args=[
-        "rpc.Connect", server1.host, server1.port, "x1"])
+        server0.host,
+        server0.port,
+        key="x0",
+        session_constructor_args=["rpc.Connect", server1.host, server1.port, "x1"],
+    )
 
     def check_remote(remote):
         temp = util.tempdir()
@@ -209,7 +217,7 @@ def test_rpc_remote_module():
         b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx)
         time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10)
         cost = time_f(a, b).mean
-        print('%g secs/op' % cost)
+        print("%g secs/op" % cost)
         np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
 
     def check_minrpc():
@@ -240,7 +248,6 @@ def test_rpc_remote_module():
         with pytest.raises(RuntimeError):
             rpc.PopenSession(path_minrpc)
 
-
     def check_remote_link_cl(remote):
         """Test function to run remote code such as cl
 
@@ -291,11 +298,10 @@ def test_rpc_remote_module():
     check_minrpc()
 
 
-
 def test_rpc_return_func():
     @tvm.register_func("rpc.test.remote_func")
     def addone(x):
-        return lambda y: x+y
+        return lambda y: x + y
 
     server = rpc.Server("localhost", key="x1")
     client = rpc.connect(server.host, server.port, key="x1")
@@ -312,24 +318,29 @@ def test_rpc_session_constructor_args():
     def check_multi_hop():
         # use server0 as proxy to connect to server1
         client = rpc.connect(
-            server0.host, server0.port, key="x0",
-            session_constructor_args=[
-                "rpc.Connect", server1.host, server1.port, "x1"])
+            server0.host,
+            server0.port,
+            key="x0",
+            session_constructor_args=["rpc.Connect", server1.host, server1.port, "x1"],
+        )
 
         fecho = client.get_function("testing.echo")
-        assert(fecho(1, 2, 3) == 1)
-        assert(fecho(100, 2, 3) == 100)
-        assert(fecho("xyz") == "xyz")
-        assert(bytes(fecho(bytearray(b"123"))) == b"123")
+        assert fecho(1, 2, 3) == 1
+        assert fecho(100, 2, 3) == 100
+        assert fecho("xyz") == "xyz"
+        assert bytes(fecho(bytearray(b"123"))) == b"123"
 
-        nd = tvm.nd.array([1,2,3], ctx=client.cpu(0))
-        assert(nd.asnumpy()[1] == 2)
+        nd = tvm.nd.array([1, 2, 3], ctx=client.cpu(0))
+        assert nd.asnumpy()[1] == 2
 
     def check_error_handling():
         with pytest.raises(tvm.error.RPCError):
             client = rpc.connect(
-                server0.host, server0.port, key="x0",
-                session_constructor_args=["rpc.NonExistingConstructor"])
+                server0.host,
+                server0.port,
+                key="x0",
+                session_constructor_args=["rpc.NonExistingConstructor"],
+            )
 
     check_multi_hop()
     check_error_handling()
@@ -338,12 +349,13 @@ def test_rpc_session_constructor_args():
 def test_rpc_return_ndarray():
     # Use closure to check the ref counter correctness
     nd = tvm.nd.array(np.zeros(10).astype("float32"))
+
     @tvm.register_func("rpc.test.remote_return_nd")
     def my_module(name):
         if name == "get_arr":
-            return lambda : nd
+            return lambda: nd
         elif name == "ref_count":
-            return lambda : tvm.testing.object_use_count(nd)
+            return lambda: tvm.testing.object_use_count(nd)
         elif name == "get_elem":
             return lambda idx: nd.asnumpy()[idx]
         elif name == "get_arr_elem":
@@ -379,7 +391,8 @@ def test_rpc_return_ndarray():
 def test_local_func():
     @tvm.register_func("rpc.test.remote_func2")
     def addone(x):
-        return lambda y: x+y
+        return lambda y: x + y
+
     client = rpc.LocalSession()
     f1 = client.get_function("rpc.test.remote_func2")
     fadd = f1(10)
@@ -390,44 +403,54 @@ def test_local_func():
     rev = client.download("dat.bin")
     assert rev == blob
 
+
 def test_rpc_tracker_register():
     # test registration
-    tracker = Tracker('localhost', port=9000, port_end=10000)
-    device_key = 'test_device'
-    server = rpc.Server('localhost', port=9000, port_end=10000,
-                        key=device_key,
-                        tracker_addr=(tracker.host, tracker.port))
+    tracker = Tracker("localhost", port=9000, port_end=10000)
+    device_key = "test_device"
+    server = rpc.Server(
+        "localhost",
+        port=9000,
+        port_end=10000,
+        key=device_key,
+        tracker_addr=(tracker.host, tracker.port),
+    )
     time.sleep(1)
     client = rpc.connect_tracker(tracker.host, tracker.port)
 
     summary = client.summary()
-    assert summary['queue_info'][device_key]['free'] == 1
+    assert summary["queue_info"][device_key]["free"] == 1
 
     remote = client.request(device_key)
     summary = client.summary()
-    assert summary['queue_info'][device_key]['free'] == 0
+    assert summary["queue_info"][device_key]["free"] == 0
 
     del remote
     time.sleep(1)
 
     summary = client.summary()
-    assert summary['queue_info'][device_key]['free'] == 1
+    assert summary["queue_info"][device_key]["free"] == 1
 
     server.terminate()
     time.sleep(1)
 
     summary = client.summary()
-    assert summary['queue_info'][device_key]['free'] == 0
+    assert summary["queue_info"][device_key]["free"] == 0
 
     tracker.terminate()
 
+
 def test_rpc_tracker_request():
     # test concurrent request
-    tracker = Tracker('localhost', port=9000, port_end=10000)
-    device_key = 'test_device'
-    server = rpc.Server('localhost', port=9000, port_end=10000,
-                        key=device_key,
-                        tracker_addr=(tracker.host, tracker.port))
+    tracker = Tracker("localhost", port=9000, port_end=10000)
+    device_key = "test_device"
+    server = rpc.Server(
+        "localhost",
+        port=9000,
+        port_end=10000,
+        key=device_key,
+        tracker_addr=(tracker.host, tracker.port),
+    )
     client = rpc.connect_tracker(tracker.host, tracker.port)
 
     def target(host, port, device_key, timeout):
@@ -437,10 +460,10 @@ def test_rpc_tracker_request():
             pass
         remote.cpu()
 
-    proc1 = multiprocessing.Process(target=target,
-                                    args=(tracker.host, tracker.port, device_key, 4))
-    proc2 = multiprocessing.Process(target=target,
-                                    args=(tracker.host, tracker.port, device_key, 200))
+    proc1 = multiprocessing.Process(target=target, args=(tracker.host, tracker.port, device_key, 4))
+    proc2 = multiprocessing.Process(
+        target=target, args=(tracker.host, tracker.port, device_key, 200)
+    )
     proc1.start()
     time.sleep(0.5)
     proc2.start()
@@ -448,16 +471,16 @@ def test_rpc_tracker_request():
 
     summary = client.summary()
 
-    assert summary['queue_info'][device_key]['free'] == 0
-    assert summary['queue_info'][device_key]['pending'] == 1
+    assert summary["queue_info"][device_key]["free"] == 0
+    assert summary["queue_info"][device_key]["pending"] == 1
 
     proc1.terminate()
     proc1.join()
     time.sleep(0.5)
 
     summary = client.summary()
-    assert summary['queue_info'][device_key]['free'] == 0
-    assert summary['queue_info'][device_key]['pending'] == 0
+    assert summary["queue_info"][device_key]["free"] == 0
+    assert summary["queue_info"][device_key]["pending"] == 0
 
     proc2.terminate()
     proc2.join()
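
A minimal sketch of the local RPC round trip these tests build on, assuming TVM was compiled with the RPC runtime (the tests guard on tvm.runtime.enabled("rpc") for the same reason):

    import tvm
    from tvm import rpc

    server = rpc.Server("localhost")                # in-process server for testing
    remote = rpc.connect(server.host, server.port)  # client session

    fecho = remote.get_function("testing.echo")     # built-in echo helper
    assert fecho(1, 2, 3) == 1
    assert fecho("xyz") == "xyz"

    server.terminate()
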
index 9e48435..2971522 100644
@@ -20,6 +20,7 @@ from tvm.runtime import profiler_vm
 from tvm import relay
 from tvm.relay.testing import resnet, enabled_targets
 
+
 def test_basic():
     mod, params = resnet.get_workload()
     if not profiler_vm.enabled():
@@ -29,10 +30,11 @@ def test_basic():
         exe = relay.vm.compile(mod, target, params=params)
         vm = profiler_vm.VirtualMachineProfiler(exe, ctx)
 
-        data = np.random.rand(1, 3, 224, 224).astype('float32')
+        data = np.random.rand(1, 3, 224, 224).astype("float32")
         res = vm.invoke("main", [data])
         print("\n{}".format(vm.get_stat()))
         print("\n{}".format(vm.get_stat(False)))
 
+
 if __name__ == "__main__":
     test_basic()
index ace41d0..b5c69d6 100644
@@ -20,48 +20,54 @@ import re
 import os
 import ctypes
 
+
 def test_popcount():
-    target = 'llvm -mtriple=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon'
+    target = "llvm -mtriple=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon"
 
     def check_correct_assembly(type, elements, counts):
         n = tvm.runtime.convert(elements)
-        A = te.placeholder(n, dtype=type, name='A')
-        B = te.compute(A.shape, lambda i: tvm.tir.popcount(A[i]), name='B')
+        A = te.placeholder(n, dtype=type, name="A")
+        B = te.compute(A.shape, lambda i: tvm.tir.popcount(A[i]), name="B")
         s = te.create_schedule(B.op)
         s[B].vectorize(s[B].op.axis[0])
         f = tvm.build(s, [A, B], target)
 
         # Verify we see the correct number of vpaddl and vcnt instructions in the assembly
-        assembly = f.get_source('asm')
+        assembly = f.get_source("asm")
         matches = re.findall("vpaddl", assembly)
-        assert (len(matches) == counts)
+        assert len(matches) == counts
         matches = re.findall("vcnt", assembly)
-        assert (len(matches) == 1)
-    check_correct_assembly('uint16', 8, 1)
-    check_correct_assembly('uint16', 4, 1)
-    check_correct_assembly('uint32', 4, 2)
-    check_correct_assembly('uint32', 2, 2)
-    check_correct_assembly('uint64', 2, 3)
+        assert len(matches) == 1
+
+    check_correct_assembly("uint16", 8, 1)
+    check_correct_assembly("uint16", 4, 1)
+    check_correct_assembly("uint32", 4, 2)
+    check_correct_assembly("uint32", 2, 2)
+    check_correct_assembly("uint64", 2, 3)
 
 
 def test_vmlal_s16():
-    target = 'llvm -mtriple=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon'
+    target = "llvm -mtriple=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon"
 
     def check_correct_assembly(N):
         K = te.size_var("K")
-        A = te.placeholder((K, N), dtype="int8", name='A')
-        B = te.placeholder((K, N), dtype="int8", name='B')
+        A = te.placeholder((K, N), dtype="int8", name="A")
+        B = te.placeholder((K, N), dtype="int8", name="B")
         k = te.reduce_axis((0, K))
-        C = te.compute((N, ), lambda n: te.sum(
-            A[k, n].astype("int32") * B[k, n].astype("int32"), axis=[k]), name='C')
+        C = te.compute(
+            (N,),
+            lambda n: te.sum(A[k, n].astype("int32") * B[k, n].astype("int32"), axis=[k]),
+            name="C",
+        )
         s = te.create_schedule(C.op)
         s[C].vectorize(s[C].op.axis[0])
         f = tvm.build(s, [A, B, C], target)
 
         # Verify we see the correct number of vmlal.s16 instructions
-        assembly = f.get_source('asm')
+        assembly = f.get_source("asm")
         matches = re.findall("vmlal.s16", assembly)
-        assert (len(matches) == N // 4)
+        assert len(matches) == N // 4
+
     check_correct_assembly(8)
     check_correct_assembly(16)
     check_correct_assembly(32)
@@ -69,20 +75,23 @@ def test_vmlal_s16():
 
     def check_broadcast_correct_assembly(N):
         K = te.size_var("K")
-        A = te.placeholder((K, N), dtype="int8", name='A')
-        B = te.placeholder((K,), dtype="int8", name='B')
+        A = te.placeholder((K, N), dtype="int8", name="A")
+        B = te.placeholder((K,), dtype="int8", name="B")
         k = te.reduce_axis((0, K))
-        C = te.compute((N, ), lambda n: te.sum(
-            A[k, n].astype("int32") * B[k].astype("int32"),
-            axis=[k]), name='C')
+        C = te.compute(
+            (N,),
+            lambda n: te.sum(A[k, n].astype("int32") * B[k].astype("int32"), axis=[k]),
+            name="C",
+        )
         s = te.create_schedule(C.op)
         s[C].vectorize(s[C].op.axis[0])
         f = tvm.build(s, [A, B, C], target)
 
         # Verify we see the correct number of vmlal.s16 instructions
-        assembly = f.get_source('asm')
+        assembly = f.get_source("asm")
         matches = re.findall("vmlal.s16", assembly)
         assert len(matches) == N // 4
+
     check_broadcast_correct_assembly(8)
     check_broadcast_correct_assembly(16)
     check_broadcast_correct_assembly(32)
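
A minimal sketch of the assembly-inspection pattern used in this file; it only cross-compiles, so no ARM board is needed, but it does assume LLVM was built with the ARM target:

    import re
    import tvm
    from tvm import te

    target = "llvm -mtriple=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon"
    A = te.placeholder((8,), dtype="uint16", name="A")
    B = te.compute(A.shape, lambda i: tvm.tir.popcount(A[i]), name="B")
    s = te.create_schedule(B.op)
    s[B].vectorize(s[B].op.axis[0])
    f = tvm.build(s, [A, B], target)

    asm = f.get_source("asm")                 # textual assembly of the module
    assert len(re.findall("vcnt", asm)) == 1  # NEON population-count instruction
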
index 758643d..dedabf4 100644
@@ -24,6 +24,7 @@ from tvm import te
 import ctypes
 import tvm.testing
 
+
 @tvm.testing.uses_gpu
 def test_synthetic():
     for device in ["llvm", "cuda"]:
@@ -47,9 +48,12 @@ def test_synthetic():
 
     synthetic_mod, synthetic_params = relay.testing.synthetic.get_workload(input_shape=input_shape)
     with tvm.transform.PassContext(opt_level=3):
-        graph, synthetic_gpu_lib, graph_params = relay.build_module.build(synthetic_mod, "cuda", params=synthetic_params)
+        graph, synthetic_gpu_lib, graph_params = relay.build_module.build(
+            synthetic_mod, "cuda", params=synthetic_params
+        )
 
     from tvm.contrib import util
+
     temp = util.tempdir()
     path_lib = temp.relpath("deploy_lib.so")
     synthetic_gpu_lib.export_library(path_lib)
@@ -81,14 +85,15 @@ def test_cuda_lib():
             return
     nn = 12
     n = tvm.runtime.convert(nn)
-    A = te.placeholder((n,), name='A')
-    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
     s = te.create_schedule(B.op)
     bx, tx = s[B].split(B.op.axis[0], factor=4)
     s[B].bind(bx, te.thread_axis("blockIdx.x"))
     s[B].bind(tx, te.thread_axis("threadIdx.x"))
 
     from tvm.contrib import util
+
     temp = util.tempdir()
     fn_add = tvm.build(s, [A, B], target="cuda", target_host="llvm", name="add")
     path_lib = temp.relpath("deploy_lib.so")
@@ -96,7 +101,7 @@ def test_cuda_lib():
     m = tvm.runtime.load_module(path_lib)
     a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
     b = tvm.nd.array(np.zeros(nn, dtype=A.dtype), ctx)
-    m['add'](a, b)
+    m["add"](a, b)
     np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
 
 
index f8d6e32..f4b5f90 100644
@@ -21,15 +21,14 @@ from tvm import te
 import numpy as np
 import tvm.testing
 
+
 @tvm.testing.uses_gpu
 def test_cmp_load_store():
     n = 32
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
-    C = te.compute(A.shape, lambda *i: A(*i) > B(*i), name='C')
-    D = te.compute(C.shape, lambda *i: tvm.tir.all(C(*i),
-                                                A(*i) > 1).astype('float32'), name="D")
-
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
+    C = te.compute(A.shape, lambda *i: A(*i) > B(*i), name="C")
+    D = te.compute(C.shape, lambda *i: tvm.tir.all(C(*i), A(*i) > 1).astype("float32"), name="D")
 
     def check_llvm():
         if not tvm.testing.device_enabled("llvm"):
@@ -47,7 +46,9 @@ def test_cmp_load_store():
         d = tvm.nd.array(np.zeros(n, dtype=D.dtype), ctx)
         f(a, b, d)
         np.testing.assert_equal(
-            d.asnumpy(), np.logical_and(a.asnumpy() > b.asnumpy(), a.asnumpy() > 1).astype('float32'))
+            d.asnumpy(),
+            np.logical_and(a.asnumpy() > b.asnumpy(), a.asnumpy() > 1).astype("float32"),
+        )
 
     def check_device(device):
         if not tvm.testing.device_enabled(device):
@@ -65,14 +66,14 @@ def test_cmp_load_store():
         d = tvm.nd.array(np.zeros(n, dtype=D.dtype), ctx)
         f(a, b, d)
         np.testing.assert_equal(
-            d.asnumpy(), np.logical_and(a.asnumpy() > b.asnumpy(), a.asnumpy() > 1).astype('float32'))
-
+            d.asnumpy(),
+            np.logical_and(a.asnumpy() > b.asnumpy(), a.asnumpy() > 1).astype("float32"),
+        )
 
     check_llvm()
     for device in ["vulkan", "opencl", "cuda", "rocm", "metal"]:
         check_device(device)
 
 
-
 if __name__ == "__main__":
     test_cmp_load_store()
index 31353ef..6afbfc7 100644
@@ -19,12 +19,13 @@ from tvm import te
 import numpy as np
 from tvm.contrib import util
 
+
 def test_add():
     nn = 1024
     n = tvm.runtime.convert(nn)
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
-    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
+    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C")
     s = te.create_schedule(C.op)
 
     def check_c():
@@ -33,7 +34,7 @@ def test_add():
         path_dso = temp.relpath("temp.so")
         mhost.export_library(path_dso)
         m = tvm.runtime.load_module(path_dso)
-        fadd = m['fadd']
+        fadd = m["fadd"]
         ctx = tvm.cpu(0)
         # launch the kernel.
         n = nn
@@ -41,20 +42,20 @@ def test_add():
         b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
         c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
         fadd(a, b, c)
-        tvm.testing.assert_allclose(
-            c.asnumpy(), a.asnumpy() + b.asnumpy())
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
     check_c()
 
 
 def test_add_pipeline():
     nn = 1024
     n = tvm.runtime.convert(nn)
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
-    AA = te.compute((n,), lambda *i: A(*i), name='A')
-    BB = te.compute((n,), lambda *i: B(*i), name='B')
-    T = te.compute(A.shape, lambda *i: AA(*i) + BB(*i), name='T')
-    C = te.compute(A.shape, lambda *i: T(*i), name='C')
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
+    AA = te.compute((n,), lambda *i: A(*i), name="A")
+    BB = te.compute((n,), lambda *i: B(*i), name="B")
+    T = te.compute(A.shape, lambda *i: AA(*i) + BB(*i), name="T")
+    C = te.compute(A.shape, lambda *i: T(*i), name="C")
     s = te.create_schedule(C.op)
     xo, xi = s[C].split(C.op.axis[0], factor=4)
     xo1, xo2 = s[C].split(xo, factor=13)
@@ -67,13 +68,11 @@ def test_add_pipeline():
     def check_c():
         # Specifically allow offset to test codepath when offset is available
         Ab = tvm.tir.decl_buffer(
-            A.shape, A.dtype,
-            elem_offset=te.size_var('Aoffset'),
-            offset_factor=8,
-            name='A')
-        binds = {A : Ab}
+            A.shape, A.dtype, elem_offset=te.size_var("Aoffset"), offset_factor=8, name="A"
+        )
+        binds = {A: Ab}
         # BUILD and invoke the kernel.
-        f1 = tvm.lower(s, [A,B,C], name="fadd_pipeline")
+        f1 = tvm.lower(s, [A, B, C], name="fadd_pipeline")
         mhost = tvm.build(f1, target="c")
 
         temp = util.tempdir()
@@ -88,8 +87,7 @@ def test_add_pipeline():
         b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
         c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
         fadd(a, b, c)
-        tvm.testing.assert_allclose(
-            c.asnumpy(), a.asnumpy() + b.asnumpy())
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
     check_c()
 
@@ -97,8 +95,10 @@ def test_add_pipeline():
 def test_reinterpret():
     nn = 1024
     n = tvm.runtime.convert(nn)
-    A = te.placeholder((n,), name='A', dtype="int32")
-    B = te.compute(A.shape, lambda *i: tvm.tir.call_intrin("float32", "tir.reinterpret", 2 + A(*i)), name='B')
+    A = te.placeholder((n,), name="A", dtype="int32")
+    B = te.compute(
+        A.shape, lambda *i: tvm.tir.call_intrin("float32", "tir.reinterpret", 2 + A(*i)), name="B"
+    )
     s = te.create_schedule(B.op)
 
     def check_c():
@@ -107,14 +107,14 @@ def test_reinterpret():
         path_dso = temp.relpath("temp.so")
         mhost.export_library(path_dso)
         m = tvm.runtime.load_module(path_dso)
-        fadd = m['reinterpret']
+        fadd = m["reinterpret"]
         ctx = tvm.cpu(0)
         n = nn
-        a = tvm.nd.array(np.random.randint(-2 ** 30, 2 ** 30, size=n).astype(A.dtype), ctx)
+        a = tvm.nd.array(np.random.randint(-(2 ** 30), 2 ** 30, size=n).astype(A.dtype), ctx)
         b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
         fadd(a, b)
-        tvm.testing.assert_allclose(
-            b.asnumpy(), (2 + a.asnumpy()).view('float32'))
+        tvm.testing.assert_allclose(b.asnumpy(), (2 + a.asnumpy()).view("float32"))
+
     check_c()
 
 
index 64a10d8..2c6c519 100644
@@ -23,13 +23,14 @@ from tvm import rpc
 from tvm.contrib import util, cc
 import numpy as np
 
+
 @tvm.testing.requires_llvm
 def test_llvm_add_pipeline():
     nn = 1024
     n = tvm.runtime.convert(nn)
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
-    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
+    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C")
     s = te.create_schedule(C.op)
     xo, xi = s[C].split(C.op.axis[0], factor=4)
     s[C].parallel(xo)
@@ -38,10 +39,10 @@ def test_llvm_add_pipeline():
     def verify_elf(path, e_machine):
         with open(path, "rb") as fi:
             arr = fi.read(20)
-            assert struct.unpack('ccc', arr[1:4]) == (b'E',b'L',b'F')
-            endian = struct.unpack('b', arr[0x5:0x6])[0]
-            endian = '<' if endian == 1 else '>'
-            assert struct.unpack(endian + 'h', arr[0x12:0x14])[0] == e_machine
+            assert struct.unpack("ccc", arr[1:4]) == (b"E", b"L", b"F")
+            endian = struct.unpack("b", arr[0x5:0x6])[0]
+            endian = "<" if endian == 1 else ">"
+            assert struct.unpack(endian + "h", arr[0x12:0x14])[0] == e_machine
 
     def build_i386():
         temp = util.tempdir()
@@ -64,10 +65,10 @@ def test_llvm_add_pipeline():
         asm_path = temp.relpath("myadd.asm")
         f.save(asm_path)
         # Do a RPC verification, launch kernel on Arm Board if available.
-        host = os.environ.get('TVM_RPC_ARM_HOST', None)
+        host = os.environ.get("TVM_RPC_ARM_HOST", None)
         remote = None
         if host:
-            port = int(os.environ['TVM_RPC_ARM_PORT'])
+            port = int(os.environ["TVM_RPC_ARM_PORT"])
             try:
                 remote = rpc.connect(host, port)
             except tvm.error.TVMError as e:
@@ -82,12 +83,12 @@ def test_llvm_add_pipeline():
             b = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
             c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
             farm(a, b, c)
-            tvm.testing.assert_allclose(
-                c.asnumpy(), a.asnumpy() + b.asnumpy())
+            tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
             print("Verification finish on remote..")
 
     build_i386()
     build_arm()
 
+
 if __name__ == "__main__":
     test_llvm_add_pipeline()
index c046874..e877674 100644
@@ -1,4 +1,3 @@
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -40,17 +39,15 @@ def test_cuda_vectorize_add():
         if dtype == "int8" and not have_int8(tvm.gpu(0).compute_version):
             print("skip because gpu does not support int8")
             return
-        A = te.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
-        B = te.compute((n,), lambda i: A[i] +
-                       tvm.tir.const(1, A.dtype), name='B')
+        A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes))
+        B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B")
         s = te.create_schedule(B.op)
         xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
         s[B].bind(xo, bx)
         s[B].bind(xi, tx)
         fun = tvm.build(s, [A, B], "cuda")
         ctx = tvm.gpu(0)
-        a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(
-            np.random.uniform(size=(n, lanes)))
+        a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(np.random.uniform(size=(n, lanes)))
         c = tvm.nd.empty((n,), B.dtype, ctx)
         fun(a, c)
         tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1)
@@ -58,12 +55,12 @@ def test_cuda_vectorize_add():
     check_cuda("float32", 64, 2)
     check_cuda("float32", 64, 3)
     check_cuda("float32", 64, 4)
-    check_cuda("int8",    64, 2)
-    check_cuda("int8",    64, 3)
-    check_cuda("int8",    64, 4)
-    check_cuda("uint8",   64, 2)
-    check_cuda("uint8",   64, 3)
-    check_cuda("uint8",   64, 4)
+    check_cuda("int8", 64, 2)
+    check_cuda("int8", 64, 3)
+    check_cuda("int8", 64, 4)
+    check_cuda("uint8", 64, 2)
+    check_cuda("uint8", 64, 3)
+    check_cuda("uint8", 64, 4)
     check_cuda("float16", 64, 2)
     check_cuda("float16", 64, 4)
     check_cuda("float16", 64, 6)
@@ -79,11 +76,12 @@ def test_cuda_multiply_add():
         if dtype == "int8" and not have_int8(tvm.gpu(0).compute_version):
             print("skip because gpu does not support int8")
             return
-        A = te.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
-        B = te.placeholder((n,), name='B', dtype="%sx%d" % (dtype, lanes))
-        C = te.placeholder((n,), name='C', dtype="int32")
-        D = te.compute((n,),
-                       lambda i: tvm.tir.call_pure_extern("int32", "__dp4a", A[i], B[i], C[i]), name='D')
+        A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes))
+        B = te.placeholder((n,), name="B", dtype="%sx%d" % (dtype, lanes))
+        C = te.placeholder((n,), name="C", dtype="int32")
+        D = te.compute(
+            (n,), lambda i: tvm.tir.call_pure_extern("int32", "__dp4a", A[i], B[i], C[i]), name="D"
+        )
         s = te.create_schedule(D.op)
         xo, xi = s[D].split(D.op.axis[0], factor=num_thread)
         s[D].bind(xo, bx)
@@ -100,6 +98,7 @@ def test_cuda_multiply_add():
         d = tvm.nd.empty((n,), D.dtype, ctx)
         fun(a, b, c, d)
         tvm.testing.assert_allclose(d.asnumpy(), np_d)
+
     check_cuda("int8", 64, 4)
 
 
@@ -110,8 +109,8 @@ def test_cuda_vectorize_load():
 
     def check_cuda(dtype, n, lanes):
         ctx = tvm.gpu(0)
-        A = te.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
-        B = te.compute((n,), lambda i: A[i], name='B')
+        A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes))
+        B = te.compute((n,), lambda i: A[i], name="B")
         s = te.create_schedule(B.op)
         block, thread = s[B].split(B.op.axis[0], factor=num_thread)
         s[B].bind(block, bx)
@@ -122,6 +121,7 @@ def test_cuda_vectorize_load():
         b = tvm.nd.empty((n,), B.dtype, ctx)
         fun(a, b)
         tvm.testing.assert_allclose(a.asnumpy(), b.asnumpy())
+
     check_cuda("int8", 64, 2)
     check_cuda("int8", 64, 3)
     check_cuda("int8", 64, 4)
@@ -133,10 +133,9 @@ def test_cuda_vectorize_load():
 @tvm.testing.requires_cuda
 def test_cuda_make_int8():
     def check_cuda(n, value, lanes):
-        dtype = 'int8'
+        dtype = "int8"
         ctx = tvm.gpu(0)
-        A = te.compute((n, lanes), lambda i,
-                       j: tvm.tir.const(value, dtype=dtype))
+        A = te.compute((n, lanes), lambda i, j: tvm.tir.const(value, dtype=dtype))
         s = te.create_schedule(A.op)
         y, x = s[A].op.axis
         s[A].vectorize(x)
@@ -146,6 +145,7 @@ def test_cuda_make_int8():
         a = tvm.nd.empty(np_a.shape, dtype, ctx)
         fun(a)
         np.testing.assert_equal(a.asnumpy(), np_a)
+
     check_cuda(64, 0xAB, 4)
     check_cuda(64, 0, 4)
     check_cuda(64, -3, 4)
@@ -160,12 +160,12 @@ def test_cuda_make_int8():
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
 def test_cuda_inf_nan():
-    target = 'cuda'
+    target = "cuda"
 
     def check_inf_nan(ctx, n, value, dtype):
-        A = te.placeholder((n,), name='A', dtype=dtype)
+        A = te.placeholder((n,), name="A", dtype=dtype)
         inf_value = tvm.tir.const(value, dtype=dtype)
-        C = te.compute((n,), lambda i: inf_value, name='C')
+        C = te.compute((n,), lambda i: inf_value, name="C")
         s = te.create_schedule(C.op)
         s[C].bind(s[C].op.axis[0], tx)
         fun = tvm.build(s, [A, C], target)
@@ -176,22 +176,21 @@ def test_cuda_inf_nan():
 
     ctx = tvm.context(target, 0)
 
-    check_inf_nan(ctx, 1, -float('inf'), 'float32')
-    check_inf_nan(ctx, 1, -float('inf'), 'float64')
-    check_inf_nan(ctx, 1, float('inf'), 'float32')
-    check_inf_nan(ctx, 1, float('inf'), 'float64')
-    check_inf_nan(ctx, 1, float('nan'), 'float32')
-    check_inf_nan(ctx, 1, float('nan'), 'float64')
+    check_inf_nan(ctx, 1, -float("inf"), "float32")
+    check_inf_nan(ctx, 1, -float("inf"), "float64")
+    check_inf_nan(ctx, 1, float("inf"), "float32")
+    check_inf_nan(ctx, 1, float("inf"), "float64")
+    check_inf_nan(ctx, 1, float("nan"), "float32")
+    check_inf_nan(ctx, 1, float("nan"), "float64")
 
 
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
 def test_cuda_shuffle():
     idxm = tvm.tir.indexmod
-    a = te.placeholder((64, ), 'int32')
-    b = te.placeholder((64, ), 'int32')
-    c = te.compute((64, ), lambda x: a[x] +
-                   b[x - idxm(x, 4) + (3 - idxm(x, 4))])
+    a = te.placeholder((64,), "int32")
+    b = te.placeholder((64,), "int32")
+    c = te.compute((64,), lambda x: a[x] + b[x - idxm(x, 4) + (3 - idxm(x, 4))])
     sch = te.create_schedule(c.op)
     x = c.op.axis[0]
     xo, xi = sch[c].split(x, 4)
@@ -202,34 +201,37 @@ def test_cuda_shuffle():
     def MyVectorize():
         def vectorizer(op):
             if op.for_type == tvm.tir.For.Vectorized:
-                four = tvm.tir.const(4, 'int32')
-                idx = tvm.tir.Ramp(
-                    thrx.var * four, tvm.tir.const(1, 'int32'), 4)
-                all_ones = tvm.tir.const(1, 'int32x4')
+                four = tvm.tir.const(4, "int32")
+                idx = tvm.tir.Ramp(thrx.var * four, tvm.tir.const(1, "int32"), 4)
+                all_ones = tvm.tir.const(1, "int32x4")
                 store = op.body
                 value = store.value
-                new_a = tvm.tir.Load(
-                    'int32x4', value.a.buffer_var, idx, all_ones)
+                new_a = tvm.tir.Load("int32x4", value.a.buffer_var, idx, all_ones)
                 bs, ids = [], []
                 for i in range(4):
-                    bs.append(tvm.tir.Load('int32', value.b.buffer_var,
-                                           thrx.var * four + tvm.tir.const(i, 'int32')))
-                    ids.append(tvm.tir.const(3 - i, 'int32'))
+                    bs.append(
+                        tvm.tir.Load(
+                            "int32", value.b.buffer_var, thrx.var * four + tvm.tir.const(i, "int32")
+                        )
+                    )
+                    ids.append(tvm.tir.const(3 - i, "int32"))
                 new_b = tvm.tir.Shuffle(bs, ids)
                 return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
             return None
 
         def _transform(f, *_):
             return f.with_body(
-                tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['tir.For']))
+                tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ["tir.For"])
+            )
+
         return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize")
 
     with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, MyVectorize())]}):
-        module = tvm.build(sch, [a, b, c], target='cuda')
-        a_ = np.array(list(range(64)), dtype='int32')
-        b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32')
-        c_ = np.zeros((64, ), dtype='int32')
-        ref = a_ + np.array((list(range(4))) * 16, dtype='int32')
+        module = tvm.build(sch, [a, b, c], target="cuda")
+        a_ = np.array(list(range(64)), dtype="int32")
+        b_ = np.array((list(range(4))[::-1]) * 16, dtype="int32")
+        c_ = np.zeros((64,), dtype="int32")
+        ref = a_ + np.array((list(range(4))) * 16, dtype="int32")
         nda, ndb, ndc = [tvm.nd.array(i, tvm.gpu(0)) for i in [a_, b_, c_]]
         module(nda, ndb, ndc)
         tvm.testing.assert_allclose(ndc.asnumpy(), ref)
@@ -239,7 +241,7 @@ def test_cuda_shuffle():
 def test_crossthread_reduction1(target, ctx):
     n = te.var("n")
     m = te.var("m")
-    A = te.placeholder((n, m), name='A')
+    A = te.placeholder((n, m), name="A")
     k = te.reduce_axis((0, m), "m")
     B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
 
@@ -255,14 +257,13 @@ def test_crossthread_reduction1(target, ctx):
         func = sched(nthd)
         nn = 3
         # checks three typical cases
-        vals = [nthd-1, nthd, nthd+1]
+        vals = [nthd - 1, nthd, nthd + 1]
         for kk in [x for x in vals]:
             size = (nn, kk)
             a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx)
             b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx)
             func(a, b)
-            tvm.testing.assert_allclose(b.asnumpy(),
-                                        np.sum(a.asnumpy(), axis=1), rtol=1e-3)
+            tvm.testing.assert_allclose(b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-3)
 
     verify(16)
     verify(32)
@@ -274,11 +275,10 @@ def test_crossthread_reduction2(target, ctx):
     n = te.var("n")
     k0 = te.var("k0")
     k1 = te.var("k1")
-    A = te.placeholder((n, k0, k1), name='A')
+    A = te.placeholder((n, k0, k1), name="A")
     k0 = te.reduce_axis((0, k0), "k0")
     k1 = te.reduce_axis((0, k1), "k1")
-    B = te.compute((n,), lambda i: te.sum(
-        A[i, k0, k1], axis=(k0, k1)), name="B")
+    B = te.compute((n,), lambda i: te.sum(A[i, k0, k1], axis=(k0, k1)), name="B")
 
     def sched(nthdx, nthdy):
         s = te.create_schedule(B.op)
@@ -294,15 +294,14 @@ def test_crossthread_reduction2(target, ctx):
         func = sched(nthdx, nthdy)
         nn = 3
         # checks three typical cases
-        vx = [nthdx-1, nthdx, nthdx+1]
-        vy = [nthdy-1, nthdy, nthdy+1]
+        vx = [nthdx - 1, nthdx, nthdx + 1]
+        vy = [nthdy - 1, nthdy, nthdy + 1]
         for kk0, kk1 in [(x, y) for x in vx for y in vy]:
             size = (nn, kk0, kk1)
             a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx)
             b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx)
             func(a, b)
-            tvm.testing.assert_allclose(b.asnumpy(),
-                                        np.sum(a.asnumpy(), axis=(1, 2)), rtol=1e-3)
+            tvm.testing.assert_allclose(b.asnumpy(), np.sum(a.asnumpy(), axis=(1, 2)), rtol=1e-3)
 
     verify(16, 16)
     verify(32, 32)
@@ -313,11 +312,9 @@ def test_crossthread_reduction2(target, ctx):
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
 def test_cuda_reduction_binding():
-    k = te.reduce_axis((0, 32), 'k')
-    A = te.placeholder((96, 32), name='A')
-    B = te.compute((96,), lambda m:
-                   te.sum(A[m, k], axis=k),
-                   name='B')
+    k = te.reduce_axis((0, 32), "k")
+    A = te.placeholder((96, 32), name="A")
+    B = te.compute((96,), lambda m: te.sum(A[m, k], axis=k), name="B")
     s = te.create_schedule(B.op)
 
     s[B].reorder(B.op.reduce_axis[0], B.op.axis[0])
@@ -330,13 +327,9 @@ def test_cuda_reduction_binding():
 
 @tvm.testing.parametrize_targets("cuda", "rocm")
 def test_rfactor_predicates(target, ctx):
-    n = te.reduce_axis((0, 129), 'n')
-    A = te.placeholder((129,), name='A')
-    B = te.compute((1, ), lambda b:
-                   te.sum(A[n],
-                          axis=n),
-                   name='B'
-                   )
+    n = te.reduce_axis((0, 129), "n")
+    A = te.placeholder((129,), name="A")
+    B = te.compute((1,), lambda b: te.sum(A[n], axis=n), name="B")
 
     s = te.create_schedule(B.op)
 
@@ -366,18 +359,19 @@ def test_cuda_const_float_to_half():
     # This import is required so that code generation is done by nvcc;
     # otherwise code generation falls back to nvrtc.
     from tvm import autotvm
+
     shape = (2, 3, 4)
-    a = te.placeholder(shape, dtype='float16', name='a')
-    b = tvm.tir.const(0.5, dtype='float16')
-    c = te.compute(shape, lambda i, j, k: a[i, j, k] > b, name='c')
+    a = te.placeholder(shape, dtype="float16", name="a")
+    b = tvm.tir.const(0.5, dtype="float16")
+    c = te.compute(shape, lambda i, j, k: a[i, j, k] > b, name="c")
     s = te.create_schedule(c.op)
     axes = [axis for axis in c.op.axis]
     fused = s[c].fuse(*axes)
     bx, tx = s[c].split(fused, factor=64)
-    s[c].bind(bx, te.thread_axis('blockIdx.x'))
-    s[c].bind(tx, te.thread_axis('threadIdx.x'))
+    s[c].bind(bx, te.thread_axis("blockIdx.x"))
+    s[c].bind(tx, te.thread_axis("threadIdx.x"))
 
-    func = tvm.build(s, [a, c], 'cuda')
+    func = tvm.build(s, [a, c], "cuda")
     ctx = tvm.gpu(0)
     a_np = np.random.uniform(size=shape).astype(a.dtype)
     c_np = np.zeros(shape=shape, dtype=c.dtype)
@@ -456,19 +450,19 @@ def test_cuda_floordiv_with_vectorization():
         # B[i] = A[floordiv(i, k)]
         n = 256
         k = 37
-        A = te.placeholder((n,), name='A')
-        B = te.compute((n,), lambda i: A[tvm.tir.floordiv(i, k)], name='B')
+        A = te.placeholder((n,), name="A")
+        B = te.compute((n,), lambda i: A[tvm.tir.floordiv(i, k)], name="B")
         s = te.create_schedule(B.op)
         xo, xi = s[B].split(B.op.axis[0], nparts=1)
         xio, xii = s[B].split(xi, factor=4)
         s[B].vectorize(xii)
         s[B].bind(xo, bx)
         s[B].bind(xio, tx)
-        func = tvm.build(s, [A, B], 'cuda')
+        func = tvm.build(s, [A, B], "cuda")
 
         ctx = tvm.gpu(0)
         a_np = np.random.uniform(size=(n,)).astype(A.dtype)
-        b_np = np.array([a_np[i//k] for i in range(0, n)])
+        b_np = np.array([a_np[i // k] for i in range(0, n)])
         a_nd = tvm.nd.array(a_np, ctx)
         b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), ctx)
         func(a_nd, b_nd)
@@ -482,15 +476,15 @@ def test_cuda_floormod_with_vectorization():
         # B[i] = A[floormod(i, k)]
         n = 256
         k = 37
-        A = te.placeholder((n,), name='A')
-        B = te.compute((n,), lambda i: A[tvm.tir.floormod(i, k)], name='B')
+        A = te.placeholder((n,), name="A")
+        B = te.compute((n,), lambda i: A[tvm.tir.floormod(i, k)], name="B")
         s = te.create_schedule(B.op)
         xo, xi = s[B].split(B.op.axis[0], nparts=1)
         xio, xii = s[B].split(xi, factor=4)
         s[B].vectorize(xii)
         s[B].bind(xo, bx)
         s[B].bind(xio, tx)
-        func = tvm.build(s, [A, B], 'cuda')
+        func = tvm.build(s, [A, B], "cuda")
 
         ctx = tvm.gpu(0)
         a_np = np.random.uniform(size=(n,)).astype(A.dtype)
@@ -511,10 +505,9 @@ def test_vectorized_casts():
 
         # compute
         n = 128
-        A = te.placeholder((n,), dtype=t0, name='A')
-        B = te.placeholder((n,), dtype=t1, name='B')
-        C = te.compute((n,), lambda i: A[i] +
-                       topi.cast(B[i], A.dtype), name='C')
+        A = te.placeholder((n,), dtype=t0, name="A")
+        B = te.placeholder((n,), dtype=t1, name="B")
+        C = te.compute((n,), lambda i: A[i] + topi.cast(B[i], A.dtype), name="C")
 
         # schedule
         s = tvm.te.create_schedule(C.op)
@@ -526,8 +519,7 @@ def test_vectorized_casts():
 
         # correctness
         ctx = tvm.gpu(0)
-        low, high = (0, 20) if t0.startswith(
-            'u') or t1.startswith('u') else (-10, 10)
+        low, high = (0, 20) if t0.startswith("u") or t1.startswith("u") else (-10, 10)
         a_np = np.random.randint(low, high, size=n).astype(A.dtype)
         b_np = np.random.randint(low, high, size=n).astype(B.dtype)
         c_np = (a_np + b_np).astype(A.dtype)
@@ -546,8 +538,7 @@ def test_vectorized_casts():
             return True
         return False
 
-    types = ["float16", "float32", "int8", "uint8",
-             "int16", "uint16", "int32", "uint32"]
+    types = ["float16", "float32", "int8", "uint8", "int16", "uint16", "int32", "uint32"]
     for t0, t1 in [(x, y) for x in types for y in types if not skip(x, y)]:
         check(t0, t1)
 
@@ -593,29 +584,29 @@ def test_vectorized_intrin1():
             print("Skip because gpu does not have fp16 support")
             return
         # set of intrinsics does not support fp16 yet.
-        skip_set = {tvm.tir.abs,
-                    tvm.tir.round,
-                    tvm.tir.tan,
-                    tvm.tir.atan,
-                    tvm.tir.tanh,
-                    tvm.tir.cosh,
-                    tvm.tir.sinh}
+        skip_set = {
+            tvm.tir.abs,
+            tvm.tir.round,
+            tvm.tir.tan,
+            tvm.tir.atan,
+            tvm.tir.tanh,
+            tvm.tir.cosh,
+            tvm.tir.sinh,
+        }
         if dtype == "float16" and tvm_intrin in skip_set:
-            print("Skip because '{0}' does not support fp16 yet".format(
-                tvm_intrin.__name__))
+            print("Skip because '{0}' does not support fp16 yet".format(tvm_intrin.__name__))
             return
 
         n = 128
-        A = te.placeholder((n,), dtype=dtype, name='A')
-        B = te.compute((n,), lambda *i: tvm_intrin(A(*i)), name='B')
+        A = te.placeholder((n,), dtype=dtype, name="A")
+        B = te.compute((n,), lambda *i: tvm_intrin(A(*i)), name="B")
         s = sched(B)
         f = tvm.build(s, [A, B], "cuda")
         ctx = tvm.gpu(0)
         a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), ctx)
         b = tvm.nd.array(np.zeros(shape=(n,)).astype(A.dtype), ctx)
         f(a, b)
-        tvm.testing.assert_allclose(b.asnumpy(), np_func(
-            a.asnumpy()), atol=1e-3, rtol=1e-3)
+        tvm.testing.assert_allclose(b.asnumpy(), np_func(a.asnumpy()), atol=1e-3, rtol=1e-3)
 
     for func in test_funcs:
         run_test(*func, "float32")
@@ -628,21 +619,20 @@ def test_vectorized_intrin2(dtype="float32"):
     c2 = tvm.tir.const(2, dtype=dtype)
     test_funcs = [
         (tvm.tir.power, lambda x: np.power(x, 2.0)),
-        (tvm.tir.fmod, lambda x: np.fmod(x, 2.0))
+        (tvm.tir.fmod, lambda x: np.fmod(x, 2.0)),
     ]
 
     def run_test(tvm_intrin, np_func):
         n = 128
-        A = te.placeholder((n,), dtype=dtype, name='A')
-        B = te.compute((n,), lambda i: tvm_intrin(A[i], c2), name='B')
+        A = te.placeholder((n,), dtype=dtype, name="A")
+        B = te.compute((n,), lambda i: tvm_intrin(A[i], c2), name="B")
         s = sched(B)
         f = tvm.build(s, [A, B], "cuda")
         ctx = tvm.gpu(0)
         a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), ctx)
         b = tvm.nd.array(np.zeros(shape=(n,)).astype(A.dtype), ctx)
         f(a, b)
-        tvm.testing.assert_allclose(b.asnumpy(), np_func(
-            a.asnumpy()), atol=1e-3, rtol=1e-3)
+        tvm.testing.assert_allclose(b.asnumpy(), np_func(a.asnumpy()), atol=1e-3, rtol=1e-3)
 
     for func in test_funcs:
         run_test(*func)
@@ -660,13 +650,12 @@ def test_vectorized_popcount():
 
     def run_test(dtype):
         n = 128
-        A = te.placeholder((n,), dtype=dtype, name='A')
-        B = te.compute((n,), lambda i: tvm.tir.popcount(A[i]), name='B')
+        A = te.placeholder((n,), dtype=dtype, name="A")
+        B = te.compute((n,), lambda i: tvm.tir.popcount(A[i]), name="B")
         s = sched(B)
         f = tvm.build(s, [A, B], "cuda")
         ctx = tvm.gpu(0)
-        a = tvm.nd.array(np.random.randint(
-            0, 100000, size=n).astype(A.dtype), ctx)
+        a = tvm.nd.array(np.random.randint(0, 100000, size=n).astype(A.dtype), ctx)
         b = tvm.nd.array(np.zeros(shape=(n,)).astype(B.dtype), ctx)
         f(a, b)
         ref = np.vectorize(ref_popcount)(a.asnumpy())
@@ -685,26 +674,30 @@ def test_cuda_vectorize_load_permute_pad():
             return
 
         ctx = tvm.gpu(0)
-        A = tvm.te.placeholder((n, l), name='A', dtype=dtype)
-        B = tvm.te.compute((n // lanes, l + 2 * padding, lanes),
-                           lambda i, j, k: tvm.te.if_then_else(
-            tvm.te.any(j < padding, j >= l + padding),
-            tvm.runtime.convert(0).astype(dtype), A[i * lanes + k, j - padding]),
-            name='B')
+        A = tvm.te.placeholder((n, l), name="A", dtype=dtype)
+        B = tvm.te.compute(
+            (n // lanes, l + 2 * padding, lanes),
+            lambda i, j, k: tvm.te.if_then_else(
+                tvm.te.any(j < padding, j >= l + padding),
+                tvm.runtime.convert(0).astype(dtype),
+                A[i * lanes + k, j - padding],
+            ),
+            name="B",
+        )
         s = te.create_schedule(B.op)
         block, thread, vectorize = s[B].op.axis
         s[B].bind(block, bx)
         s[B].bind(thread, tx)
         s[B].vectorize(vectorize)
         fun = tvm.build(s, [A, B], "cuda", name="vector_load_permute_pad")
-        np_a = np.random.randint(
-            low=-128, high=127, size=(n, l)).astype(A.dtype)
+        np_a = np.random.randint(low=-128, high=127, size=(n, l)).astype(A.dtype)
         a = tvm.nd.empty((n, l), A.dtype, ctx).copyfrom(np_a)
         b = tvm.nd.empty((n // lanes, l + padding * 2, lanes), B.dtype, ctx)
         fun(a, b)
         np_a_reshape = np_a.reshape(n // lanes, lanes, l).transpose(0, 2, 1)
-        ref = np.pad(np_a_reshape, ((0, 0), (padding, padding),
-                                    (0, 0)), mode='constant', constant_values=0)
+        ref = np.pad(
+            np_a_reshape, ((0, 0), (padding, padding), (0, 0)), mode="constant", constant_values=0
+        )
         tvm.testing.assert_allclose(b.asnumpy(), ref)
 
     check_cuda("int8", 64, 16, 3, 2)
@@ -733,8 +726,7 @@ def vcf_check_common(s, args):
         if isinstance(stmt, tvm.tir.Broadcast):
             inside_broadcast[0] = True
             # Check Broadcast[Imm numbers] or Broadcast[Load] patterns
-            assert isinstance(stmt.value, (tvm.tir.IntImm,
-                                           tvm.tir.FloatImm, tvm.tir.Load))
+            assert isinstance(stmt.value, (tvm.tir.IntImm, tvm.tir.FloatImm, tvm.tir.Load))
         if isinstance(stmt, tvm.tir.Store):
             # Check Store[Ramp] pattern
             assert isinstance(stmt.index, tvm.tir.Ramp)
@@ -750,7 +742,7 @@ def vcf_check_common(s, args):
             inside_broadcast[0] = False
         return None
 
-    tvm.tir.stmt_functor.ir_transform(stmt['main'].body, pre_visit, post_visit)
+    tvm.tir.stmt_functor.ir_transform(stmt["main"].body, pre_visit, post_visit)
 
     tgt = tvm.target.cuda()
     mod = tvm.build(s, args, tgt)
@@ -762,17 +754,16 @@ def vcf_check_common(s, args):
     b = tvm.nd.array(np.random.uniform(size=(512, 512)).astype("float32"), ctx)
     c = tvm.nd.array(np.zeros((512, 512), dtype="float32"), ctx)
     mod(a, b, c)
-    tvm.testing.assert_allclose(c.asnumpy(), np.dot(
-        a.asnumpy(), b.asnumpy()), rtol=1e-5)
+    tvm.testing.assert_allclose(c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
 
 
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
 def test_vectorized_cooperative_fetching_x():
     N = 512
-    A = te.placeholder((N, N), name='A', dtype='float32')
-    B = te.placeholder((N, N), name='B', dtype='float32')
-    k = te.reduce_axis((0, N), name='k')
+    A = te.placeholder((N, N), name="A", dtype="float32")
+    B = te.placeholder((N, N), name="B", dtype="float32")
+    k = te.reduce_axis((0, N), name="k")
     C = te.compute((N, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k))
     s = te.create_schedule(C.op)
     i, j = s[C].op.axis
@@ -822,9 +813,9 @@ def test_vectorized_cooperative_fetching_x():
 @tvm.testing.requires_cuda
 def test_vectorized_cooperative_fetching_xy():
     N = 512
-    A = te.placeholder((N, N), name='A')
-    B = te.placeholder((N, N), name='B')
-    k = te.reduce_axis((0, N), name='k')
+    A = te.placeholder((N, N), name="A")
+    B = te.placeholder((N, N), name="B")
+    k = te.reduce_axis((0, N), name="k")
     C = te.compute((N, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k))
     s = te.create_schedule(C.op)
     i, j = s[C].op.axis
@@ -877,16 +868,15 @@ def test_vectorized_cooperative_fetching_xy():
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
 def test_unrolled_vectorization():
-    dtype = 'float32'
-    target = 'cuda'
+    dtype = "float32"
+    target = "cuda"
 
     # Compute declaration
     N = 128
-    A = te.placeholder((N, N), name='A')
-    B = te.placeholder((N, N), name='B')
-    k = te.reduce_axis((0, N), name='k')
-    C = te.compute((N, N), lambda i, j: te.sum(
-        A[i][k] * B[k][j], axis=[k]), name='C')
+    A = te.placeholder((N, N), name="A")
+    B = te.placeholder((N, N), name="B")
+    k = te.reduce_axis((0, N), name="k")
+    C = te.compute((N, N), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="C")
 
     # Schedule
     s = te.create_schedule([C.op])
index 3289e38..3cdcb2d 100644
@@ -20,14 +20,15 @@ from tvm.contrib import util
 import numpy as np
 import tvm.testing
 
+
 @tvm.testing.requires_gpu
 def test_large_uint_imm():
-    value =  (1 << 63) + 123
+    value = (1 << 63) + 123
     other = tvm.tir.const(3, "uint64")
     n = 12
     num_thread = 2
 
-    A = te.compute((n,), lambda *i: tvm.tir.const(value, "uint64") + other, name='A')
+    A = te.compute((n,), lambda *i: tvm.tir.const(value, "uint64") + other, name="A")
     s = te.create_schedule(A.op)
     xo, xi = s[A].split(A.op.axis[0], factor=num_thread)
     s[A].bind(xi, te.thread_axis("threadIdx.x"))
@@ -39,7 +40,7 @@ def test_large_uint_imm():
         ctx = tvm.context(device, 0)
         f = tvm.build(s, [A], device)
         # launch the kernel.
-        a = tvm.nd.empty((n, ), dtype=A.dtype, ctx=ctx)
+        a = tvm.nd.empty((n,), dtype=A.dtype, ctx=ctx)
         f(a)
         assert a.asnumpy()[0] == value + 3
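# Illustrative arithmetic (my own check, not from the diff): (1 << 63) + 123 exceeds
# INT64_MAX, so the constant above only fits when built with dtype "uint64".
_value = (1 << 63) + 123
assert _value > 2**63 - 1  # too large for a signed 64-bit integer
assert _value <= 2**64 - 1  # still representable as uint64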
 
@@ -49,11 +50,11 @@ def test_large_uint_imm():
 
 @tvm.testing.requires_gpu
 def test_add_pipeline():
-    n = te.size_var('n')
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((), name='B')
-    C = te.compute(A.shape, lambda *i: A(*i) + B(), name='C')
-    D = te.compute(A.shape, lambda *i: C(*i) + 1, name='D')
+    n = te.size_var("n")
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((), name="B")
+    C = te.compute(A.shape, lambda *i: A(*i) + B(), name="C")
+    D = te.compute(A.shape, lambda *i: C(*i) + 1, name="D")
     s = te.create_schedule(D.op)
 
     # GPU schedule have to split by gridIdx and threadIdx
@@ -78,8 +79,7 @@ def test_add_pipeline():
         b = tvm.nd.array(np.random.uniform(size=()).astype(B.dtype), ctx)
         d = tvm.nd.array(np.zeros(n, dtype=D.dtype), ctx)
         f(a, b, d)
-        tvm.testing.assert_allclose(
-            d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)
+        tvm.testing.assert_allclose(d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)
 
     check_target("cuda", host="llvm")
     check_target("nvptx", host="llvm")
index ef98816..032b105 100644
@@ -19,18 +19,23 @@ from tvm import te
 import numpy as np
 import tvm.testing
 
+
 @tvm.testing.uses_gpu
 def test_add_pipeline():
     nn = 64
     max_threads = 4
     n = tvm.runtime.convert(nn)
-    A = te.placeholder((n,), name='A')
+    A = te.placeholder((n,), name="A")
 
     def extern_generator(ins, outs):
         """Manually write the IR for the extern function, add pipeline"""
         ib = tvm.tir.ir_builder.create()
-        with ib.for_range(0, (n+1) // 2) as i:
-            ib.emit(outs[0].vstore(i*2, ins[0].vload(i*2, "float32x2") + tvm.tir.const(1, "float32x2")))
+        with ib.for_range(0, (n + 1) // 2) as i:
+            ib.emit(
+                outs[0].vstore(
+                    i * 2, ins[0].vload(i * 2, "float32x2") + tvm.tir.const(1, "float32x2")
+                )
+            )
         return ib.get()
 
     def extern_generator_gpu(ins, outs):
@@ -38,15 +43,19 @@ def test_add_pipeline():
         ib = tvm.tir.ir_builder.create()
         bx = te.thread_axis("blockIdx.x")
         tx = te.thread_axis("threadIdx.x")
-        ib.scope_attr(bx, "thread_extent", (nn+max_threads-1) // max_threads)
+        ib.scope_attr(bx, "thread_extent", (nn + max_threads - 1) // max_threads)
         ib.scope_attr(tx, "thread_extent", max_threads)
         idx = bx.var * max_threads + tx.var
         with ib.if_scope(ib.likely(idx < n)):
-            ib.emit(outs[0].vstore(idx*2, ins[0].vload(idx*2, "float32x2") + tvm.tir.const(1, "float32x2")))
+            ib.emit(
+                outs[0].vstore(
+                    idx * 2, ins[0].vload(idx * 2, "float32x2") + tvm.tir.const(1, "float32x2")
+                )
+            )
         return ib.get()
 
-    C_cpu = te.extern(A.shape, [A], extern_generator, name='C')
-    C_gpu = te.extern(A.shape, [A], extern_generator_gpu, name='C')
+    C_cpu = te.extern(A.shape, [A], extern_generator, name="C")
+    C_gpu = te.extern(A.shape, [A], extern_generator_gpu, name="C")
     s_cpu = te.create_schedule(C_cpu.op)
     s_gpu = te.create_schedule(C_gpu.op)
     print(tvm.lower(s_cpu, [A, C_cpu], simple_mode=True))
@@ -55,8 +64,8 @@ def test_add_pipeline():
     def check_target(target):
         if not tvm.testing.device_enabled(target):
             return
-        s = s_gpu if target in ['opencl', 'cuda'] else s_cpu
-        C = C_gpu if target in ['opencl', 'cuda'] else C_cpu
+        s = s_gpu if target in ["opencl", "cuda"] else s_cpu
+        C = C_gpu if target in ["opencl", "cuda"] else C_cpu
         # build and invoke the kernel.
         f = tvm.build(s, [A, C], target)
         ctx = tvm.context(target, 0)
@@ -71,22 +80,23 @@ def test_add_pipeline():
     check_target("opencl")
     check_target("cuda")
 
+
 def test_pack_buffer_simple():
     nn = 1024
     n = tvm.runtime.convert(nn)
-    A = te.placeholder((n,), name='A')
+    A = te.placeholder((n,), name="A")
+
     def extern_generator(ins, outs):
         """Manually write the IR for the extern function, add pipeline."""
         return tvm.tir.call_packed("my_extern_array_func1", ins[0], outs[0])
 
-    C = te.extern(A.shape, [A], extern_generator, name='C')
+    C = te.extern(A.shape, [A], extern_generator, name="C")
     s = te.create_schedule(C.op)
 
     @tvm.register_func
     def my_extern_array_func1(aa, bb):
         aa.copyto(bb)
 
-
     def check_target(target):
         if not tvm.testing.device_enabled(target):
             return
@@ -99,8 +109,8 @@ def test_pack_buffer_simple():
         c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
 
         f(a, c)
-        tvm.testing.assert_allclose(
-            c.asnumpy(), a.asnumpy())
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy())
+
     check_target("stackvm")
     check_target("llvm")
 
@@ -108,13 +118,14 @@ def test_pack_buffer_simple():
 def test_pack_buffer_intermediate():
     nn = 1024
     n = tvm.runtime.convert(nn)
-    A = te.placeholder((n,), name='A')
+    A = te.placeholder((n,), name="A")
     B = te.compute((n,), lambda i: A[i] + 1, name="B")
+
     def extern_generator(ins, outs):
         """Manually write the IR for the extern function, add pipeline."""
         return tvm.tir.call_packed("my_extern_array_func2", ins[0], outs[0])
 
-    C = te.extern(B.shape, [B], extern_generator, name='C')
+    C = te.extern(B.shape, [B], extern_generator, name="C")
     s = te.create_schedule(C.op)
 
     def check_target(target):
@@ -131,13 +142,11 @@ def test_pack_buffer_intermediate():
         @tvm.register_func
         def my_extern_array_func2(aa, bb):
             assert aa.shape == a.shape
-            tvm.testing.assert_allclose(
-                aa.asnumpy(), a.asnumpy() + 1)
+            tvm.testing.assert_allclose(aa.asnumpy(), a.asnumpy() + 1)
             aa.copyto(bb)
 
         f(a, c)
-        tvm.testing.assert_allclose(
-            c.asnumpy(), a.asnumpy() + 1)
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1)
 
     check_target("llvm")
 
index 1478e2e..d42693e 100644
@@ -23,40 +23,40 @@ import tvm.contrib.hexagon as hexagon
 
 def check_prereq_and_setup():
     if tvm.target.codegen.llvm_version_major() <= 7:
-        print('Skipping test: need LLVM 7 or later for codegen')
+        print("Skipping test: need LLVM 7 or later for codegen")
         return False
-    if os.name != 'posix':
-        print('Skipping test on non-POSIX platforms')
+    if os.name != "posix":
+        print("Skipping test on non-POSIX platforms")
         return False
-    if not tvm.runtime.enabled('hexagon'):
-        print('Hexagon runtime not enabled')
+    if not tvm.runtime.enabled("hexagon"):
+        print("Hexagon runtime not enabled")
         return False
     # Register a phony linker, so that we can test codegen without a Hexagon toolchain.
-    hexagon.register_linker(lambda: '/bin/true')
+    hexagon.register_linker(lambda: "/bin/true")
     return True
 
 
 def test_basic():
     if not check_prereq_and_setup():
         return
-    target = tvm.target.hexagon('v66', hvx=128)
+    target = tvm.target.hexagon("v66", hvx=128)
 
     def check_add(offload):
-        A = tvm.te.placeholder((128,), dtype='uint8', name='A')
-        B = tvm.te.placeholder((128,), dtype='uint8', name='A')
-        C = tvm.te.compute((128,), lambda i: A[i] + B[i], name='C')
+        A = tvm.te.placeholder((128,), dtype="uint8", name="A")
+        B = tvm.te.placeholder((128,), dtype="uint8", name="A")
+        C = tvm.te.compute((128,), lambda i: A[i] + B[i], name="C")
         s = tvm.te.create_schedule(C.op)
 
         if offload:
             xo, xi = s[C].split(s[C].op.axis[0], nparts=1)
-            s[C].bind(xo, tvm.te.thread_axis('pipeline'))
-            m = tvm.build(s, [C, A, B], target=target, name='offload_add')
+            s[C].bind(xo, tvm.te.thread_axis("pipeline"))
+            m = tvm.build(s, [C, A, B], target=target, name="offload_add")
             hexm = m.imported_modules[0]
         else:
-            hexm = tvm.build(s, [C, A, B], target=target, target_host=target, name='native_add')
+            hexm = tvm.build(s, [C, A, B], target=target, target_host=target, name="native_add")
 
-        asm = hexm.get_source('s')
-        vadds = re.findall(r'v[0-9]+.b = vadd\(v[0-9]+.b,v[0-9]+.b\)', asm)
+        asm = hexm.get_source("s")
+        vadds = re.findall(r"v[0-9]+.b = vadd\(v[0-9]+.b,v[0-9]+.b\)", asm)
         assert vadds  # Check that it's non-empty
 
     check_add(True)
@@ -66,30 +66,30 @@ def test_basic():
 def test_alloc_vtcm():
     if not check_prereq_and_setup():
         return
-    target = tvm.target.hexagon('v66')
+    target = tvm.target.hexagon("v66")
 
     buf_len = 2048
-    A = tvm.te.placeholder((buf_len,), name='A', dtype='int8')
-    B = tvm.te.placeholder((buf_len,), name='B', dtype='int8')
+    A = tvm.te.placeholder((buf_len,), name="A", dtype="int8")
+    B = tvm.te.placeholder((buf_len,), name="B", dtype="int8")
 
-    A_buf = tvm.te.compute((buf_len,), lambda *i: A(*i), 'A_buf')
-    B_buf = tvm.te.compute((buf_len,), lambda *i: B(*i), 'B_buf')
-    C = tvm.te.compute((buf_len,), lambda *i: A_buf(*i) + B_buf(*i), name='C')
+    A_buf = tvm.te.compute((buf_len,), lambda *i: A(*i), "A_buf")
+    B_buf = tvm.te.compute((buf_len,), lambda *i: B(*i), "B_buf")
+    C = tvm.te.compute((buf_len,), lambda *i: A_buf(*i) + B_buf(*i), name="C")
     s = tvm.te.create_schedule(C.op)
 
     # Use VTCM for each buffer.
     s[A_buf].set_scope("local.vtcm")
     s[B_buf].set_scope("local.vtcm")
 
-    config = {'tir.add_lower_pass': hexagon.ir_lower_vtcm_pass()}
-    with tvm.transform.PassContext(config = config):
-        irmod = tvm.lower(s, [A, B, C], name = 'alloc_vtcm')
+    config = {"tir.add_lower_pass": hexagon.ir_lower_vtcm_pass()}
+    with tvm.transform.PassContext(config=config):
+        irmod = tvm.lower(s, [A, B, C], name="alloc_vtcm")
 
-    calls = re.findall('HexagonBackend[A-Za-z]*VTCM', str(irmod['alloc_vtcm']))
-    assert 'HexagonBackendAllocateVTCM' in calls
-    assert 'HexagonBackendFreeVTCM' in calls
+    calls = re.findall("HexagonBackend[A-Za-z]*VTCM", str(irmod["alloc_vtcm"]))
+    assert "HexagonBackendAllocateVTCM" in calls
+    assert "HexagonBackendFreeVTCM" in calls
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_basic()
     test_alloc_vtcm()
index 9aa843e..98190ac 100644
@@ -29,19 +29,11 @@ def test_llvm_intrin():
     ib = tvm.tir.ir_builder.create()
     n = tvm.runtime.convert(4)
     A = ib.pointer("float32", name="A")
-    args = [
-        tvm.tir.call_intrin("handle", "tir.address_of", A[0]),
-        0, 3, 1
-    ]
-    ib.emit(tvm.tir.Evaluate(
-        tvm.tir.Call(
-            "int32", "tir.prefetch", args)))
+    args = [tvm.tir.call_intrin("handle", "tir.address_of", A[0]), 0, 3, 1]
+    ib.emit(tvm.tir.Evaluate(tvm.tir.Call("int32", "tir.prefetch", args)))
     body = ib.get()
 
-    mod = tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([A], body).with_attr(
-            "global_symbol", "prefetch")
-    )
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "prefetch"))
     fcode = tvm.build(mod, None, "llvm")
 
 
@@ -50,12 +42,10 @@ def test_llvm_void_intrin():
     ib = tvm.tir.ir_builder.create()
     A = ib.pointer("uint8", name="A")
     # Create an intrinsic that returns void.
-    x = tvm.tir.call_llvm_intrin(
-        '', 'llvm.va_start', tvm.tir.const(1, 'uint32'), A)
+    x = tvm.tir.call_llvm_intrin("", "llvm.va_start", tvm.tir.const(1, "uint32"), A)
     ib.emit(x)
     body = ib.get()
-    mod = tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main"))
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main"))
     fcode = tvm.build(mod, None, "llvm")
 
 
@@ -69,19 +59,20 @@ def test_llvm_overloaded_intrin():
     def use_llvm_intrinsic(A, C):
         ib = tvm.tir.ir_builder.create()
         L = A.vload((0, 0))
-        I = tvm.tir.call_llvm_pure_intrin('int32', 'llvm.ctlz',
-                                          tvm.tir.const(2, 'uint32'), L, tvm.tir.const(0, 'int1'))
+        I = tvm.tir.call_llvm_pure_intrin(
+            "int32", "llvm.ctlz", tvm.tir.const(2, "uint32"), L, tvm.tir.const(0, "int1")
+        )
         S = C.vstore((0, 0), I)
         ib.emit(S)
         return ib.get()
 
-    A = tvm.te.placeholder((1, 1), dtype='int32', name='A')
-    C = tvm.te.extern((1, 1), [A],
-                      lambda ins, outs: use_llvm_intrinsic(ins[0], outs[0]),
-                      name='C', dtype='int32')
+    A = tvm.te.placeholder((1, 1), dtype="int32", name="A")
+    C = tvm.te.extern(
+        (1, 1), [A], lambda ins, outs: use_llvm_intrinsic(ins[0], outs[0]), name="C", dtype="int32"
+    )
 
     s = tvm.te.create_schedule(C.op)
-    f = tvm.build(s, [A, C], target='llvm')
+    f = tvm.build(s, [A, C], target="llvm")
 
 
 @tvm.testing.requires_llvm
@@ -93,10 +84,10 @@ def test_llvm_import():
     }
     """
     n = 10
-    A = te.placeholder((n,), name='A')
-    B = te.compute((n,), lambda *i:
-                   tvm.tir.call_pure_extern("float32", "my_add", A(*i), 1.0),
-                   name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.compute(
+        (n,), lambda *i: tvm.tir.call_pure_extern("float32", "my_add", A(*i), 1.0), name="B"
+    )
 
     def check_llvm(use_file):
         if not clang.find_clang(required=False):
@@ -117,8 +108,8 @@ def test_llvm_import():
         a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
         b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
         f(a, b)
-        tvm.testing.assert_allclose(
-            b.asnumpy(), a.asnumpy() + 1.0)
+        tvm.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1.0)
+
     check_llvm(use_file=True)
     check_llvm(use_file=False)
 
@@ -127,13 +118,13 @@ def test_llvm_import():
 def test_llvm_lookup_intrin():
     ib = tvm.tir.ir_builder.create()
     A = ib.pointer("uint8x8", name="A")
-    z = tvm.tir.const(0, 'int32')
+    z = tvm.tir.const(0, "int32")
     x = tvm.tir.call_llvm_pure_intrin(
-        "uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, 'uint32'), A[z])
+        "uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, "uint32"), A[z]
+    )
     ib.emit(x)
     body = ib.get()
-    mod = tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main"))
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main"))
     fcode = tvm.build(mod, None, "llvm")
 
 
@@ -141,8 +132,7 @@ def test_llvm_lookup_intrin():
 def test_llvm_large_uintimm():
     value = (1 << 63) + 123
     other = tvm.tir.const(3, "uint64")
-    A = te.compute((), lambda: tvm.tir.const(
-        value, "uint64") + other, name='A')
+    A = te.compute((), lambda: tvm.tir.const(value, "uint64") + other, name="A")
     s = te.create_schedule(A.op)
 
     def check_llvm():
@@ -160,12 +150,12 @@ def test_llvm_large_uintimm():
 def test_llvm_add_pipeline():
     nn = 1024
     n = tvm.runtime.convert(nn)
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
-    AA = te.compute((n,), lambda *i: A(*i), name='A')
-    BB = te.compute((n,), lambda *i: B(*i), name='B')
-    T = te.compute(A.shape, lambda *i: AA(*i) + BB(*i), name='T')
-    C = te.compute(A.shape, lambda *i: T(*i), name='C')
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
+    AA = te.compute((n,), lambda *i: A(*i), name="A")
+    BB = te.compute((n,), lambda *i: B(*i), name="B")
+    T = te.compute(A.shape, lambda *i: AA(*i) + BB(*i), name="T")
+    C = te.compute(A.shape, lambda *i: T(*i), name="C")
     s = te.create_schedule(C.op)
     xo, xi = s[C].split(C.op.axis[0], factor=4)
     xo1, xo2 = s[C].split(xo, factor=13)
@@ -178,10 +168,8 @@ def test_llvm_add_pipeline():
     def check_llvm():
         # Specifically allow offset to test codepath when offset is available
         Ab = tvm.tir.decl_buffer(
-            A.shape, A.dtype,
-            elem_offset=te.size_var('Aoffset'),
-            offset_factor=8,
-            name='A')
+            A.shape, A.dtype, elem_offset=te.size_var("Aoffset"), offset_factor=8, name="A"
+        )
         binds = {A: Ab}
         # BUILD and invoke the kernel.
         f = tvm.build(s, [A, B, C], "llvm", binds=binds)
@@ -192,8 +180,7 @@ def test_llvm_add_pipeline():
         b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
         c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
         f(a, b, c)
-        tvm.testing.assert_allclose(
-            c.asnumpy(), a.asnumpy() + b.asnumpy())
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
     check_llvm()
 
@@ -201,9 +188,9 @@ def test_llvm_add_pipeline():
 @tvm.testing.requires_llvm
 def test_llvm_persist_parallel():
     n = 128
-    A = te.placeholder((n,), name='A')
-    B = te.compute(A.shape, lambda *i: A(*i) + 1, name='B')
-    C = te.compute(A.shape, lambda *i: te.sqrt(B(*i)) * 2 + 2, name='C')
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda *i: A(*i) + 1, name="B")
+    C = te.compute(A.shape, lambda *i: te.sqrt(B(*i)) * 2 + 2, name="C")
     s = te.create_schedule(C.op)
     xo, xi = s[C].split(C.op.axis[0], factor=8)
     xo1, xo2 = s[C].split(xo, nparts=1)
@@ -222,9 +209,7 @@ def test_llvm_persist_parallel():
         a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
         c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
         f(a, c)
-        tvm.testing.assert_allclose(c.asnumpy(),
-                                    np.sqrt(a.asnumpy() + 1) * 2 + 2,
-                                    rtol=1e-5)
+        tvm.testing.assert_allclose(c.asnumpy(), np.sqrt(a.asnumpy() + 1) * 2 + 2, rtol=1e-5)
 
     check_llvm()
 
@@ -233,8 +218,8 @@ def test_llvm_persist_parallel():
 def test_llvm_flip_pipeline():
     def check_llvm(nn, base):
         n = tvm.runtime.convert(nn)
-        A = te.placeholder((n + base), name='A')
-        C = te.compute((n,), lambda i: A(nn + base - i - 1), name='C')
+        A = te.placeholder((n + base), name="A")
+        C = te.compute((n,), lambda i: A(nn + base - i - 1), name="C")
         s = te.create_schedule(C.op)
         xo, xi = s[C].split(C.op.axis[0], factor=4)
         s[C].parallel(xo)
@@ -244,12 +229,11 @@ def test_llvm_flip_pipeline():
         ctx = tvm.cpu(0)
         # launch the kernel.
         n = nn
-        a = tvm.nd.array(np.random.uniform(
-            size=(n + base)).astype(A.dtype), ctx)
+        a = tvm.nd.array(np.random.uniform(size=(n + base)).astype(A.dtype), ctx)
         c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
         f(a, c)
-        tvm.testing.assert_allclose(
-            c.asnumpy(), a.asnumpy()[::-1][:n])
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy()[::-1][:n])
+
     check_llvm(4, 0)
     check_llvm(128, 8)
     check_llvm(3, 0)
@@ -259,10 +243,9 @@ def test_llvm_flip_pipeline():
 @tvm.testing.requires_llvm
 def test_llvm_vadd_pipeline():
     def check_llvm(n, lanes):
-        A = te.placeholder((n,), name='A', dtype="float32x%d" % lanes)
-        B = te.compute((n,), lambda i: A[i], name='B')
-        C = te.compute((n,), lambda i: B[i] +
-                       tvm.tir.const(1, A.dtype), name='C')
+        A = te.placeholder((n,), name="A", dtype="float32x%d" % lanes)
+        B = te.compute((n,), lambda i: A[i], name="B")
+        C = te.compute((n,), lambda i: B[i] + tvm.tir.const(1, A.dtype), name="C")
         s = te.create_schedule(C.op)
         xo, xi = s[C].split(C.op.axis[0], nparts=2)
         _, xi = s[C].split(xi, factor=2)
@@ -275,12 +258,11 @@ def test_llvm_vadd_pipeline():
         f = tvm.build(s, [A, C], "llvm")
         ctx = tvm.cpu(0)
         # launch the kernel.
-        a = tvm.nd.empty((n,), A.dtype).copyfrom(
-            np.random.uniform(size=(n, lanes)))
+        a = tvm.nd.empty((n,), A.dtype).copyfrom(np.random.uniform(size=(n, lanes)))
         c = tvm.nd.empty((n,), C.dtype, ctx)
         f(a, c)
-        tvm.testing.assert_allclose(
-            c.asnumpy(), a.asnumpy() + 1)
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1)
+
     check_llvm(64, 2)
     check_llvm(512, 2)
 
@@ -289,8 +271,8 @@ def test_llvm_vadd_pipeline():
 def test_llvm_madd_pipeline():
     def check_llvm(nn, base, stride):
         n = tvm.runtime.convert(nn)
-        A = te.placeholder((n + base, stride), name='A')
-        C = te.compute((n, stride), lambda i, j: A(base + i, j) + 1, name='C')
+        A = te.placeholder((n + base, stride), name="A")
+        C = te.compute((n, stride), lambda i, j: A(base + i, j) + 1, name="C")
         s = te.create_schedule(C.op)
         xo, xi = s[C].split(C.op.axis[0], factor=4)
         s[C].parallel(xo)
@@ -300,12 +282,11 @@ def test_llvm_madd_pipeline():
         ctx = tvm.cpu(0)
         # launch the kernel.
         n = nn
-        a = tvm.nd.array(np.random.uniform(
-            size=(n + base, stride)).astype(A.dtype), ctx)
+        a = tvm.nd.array(np.random.uniform(size=(n + base, stride)).astype(A.dtype), ctx)
         c = tvm.nd.array(np.zeros((n, stride), dtype=C.dtype), ctx)
         f(a, c)
-        tvm.testing.assert_allclose(
-            c.asnumpy(), a.asnumpy()[base:] + 1)
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy()[base:] + 1)
+
     check_llvm(64, 0, 2)
     check_llvm(4, 0, 1)
 
@@ -317,9 +298,9 @@ def test_llvm_madd_pipeline():
 def test_llvm_temp_space():
     nn = 1024
     n = tvm.runtime.convert(nn)
-    A = te.placeholder((n,), name='A')
-    B = te.compute(A.shape, lambda i: A(i) + 1, name='B')
-    C = te.compute(A.shape, lambda i: B(i) + 1, name='C')
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda i: A(i) + 1, name="B")
+    C = te.compute(A.shape, lambda i: B(i) + 1, name="C")
     s = te.create_schedule(C.op)
 
     def check_llvm():
@@ -331,8 +312,8 @@ def test_llvm_temp_space():
         a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
         c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
         f(a, c)
-        tvm.testing.assert_allclose(
-            c.asnumpy(), a.asnumpy() + 1 + 1)
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1 + 1)
+
     check_llvm()
 
 
@@ -340,9 +321,9 @@ def test_llvm_temp_space():
 def test_multiple_func():
     nn = 1024
     n = tvm.runtime.convert(nn)
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
-    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
+    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C")
     s = te.create_schedule(C.op)
     xo, xi = s[C].split(C.op.axis[0], factor=4)
     s[C].parallel(xo)
@@ -353,8 +334,8 @@ def test_multiple_func():
         f2 = tvm.lower(s, [A, B, C], name="fadd1")
         f1 = tvm.lower(s, [A, B, C], name="fadd2")
         m = tvm.build([f1, f2], "llvm")
-        fadd2 = m['fadd2']
-        fadd1 = m['fadd1']
+        fadd2 = m["fadd2"]
+        fadd1 = m["fadd1"]
 
         ctx = tvm.cpu(0)
         # launch the kernel.
@@ -363,20 +344,18 @@ def test_multiple_func():
         b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
         c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
         fadd1(a, b, c)
-        tvm.testing.assert_allclose(
-            c.asnumpy(), a.asnumpy() + b.asnumpy())
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
         fadd2(a, b, c)
-        tvm.testing.assert_allclose(
-            c.asnumpy(), a.asnumpy() + b.asnumpy())
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
     check_llvm()
 
 
 @tvm.testing.requires_llvm
 def test_llvm_condition():
     def check_llvm(n, offset):
-        A = te.placeholder((n, ), name='A')
-        C = te.compute((n,), lambda i: tvm.tir.if_then_else(
-            i >= offset, A[i], 0.0), name='C')
+        A = te.placeholder((n,), name="A")
+        C = te.compute((n,), lambda i: tvm.tir.if_then_else(i >= offset, A[i], 0.0), name="C")
         s = te.create_schedule(C.op)
         # build and invoke the kernel.
         f = tvm.build(s, [A, C], "llvm")
@@ -388,33 +367,34 @@ def test_llvm_condition():
         c_np = a.asnumpy()
         c_np[:offset] = 0
         tvm.testing.assert_allclose(c.asnumpy(), c_np)
+
     check_llvm(64, 8)
 
 
 @tvm.testing.requires_llvm
 def test_llvm_bool():
     def check_llvm(n):
-        A = te.placeholder((n, ), name='A', dtype="int32")
-        C = te.compute((n,), lambda i: A[i].equal(1).astype("float"), name='C')
+        A = te.placeholder((n,), name="A", dtype="int32")
+        C = te.compute((n,), lambda i: A[i].equal(1).astype("float"), name="C")
         s = te.create_schedule(C.op)
         # build and invoke the kernel.
         f = tvm.build(s, [A, C], "llvm")
         ctx = tvm.cpu(0)
         # launch the kernel.
-        a = tvm.nd.array(np.random.randint(
-            0, 2, size=(n,)).astype(A.dtype), ctx)
+        a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), ctx)
         c = tvm.nd.empty((n,), C.dtype, ctx)
         f(a, c)
         c_np = a.asnumpy() == 1
         tvm.testing.assert_allclose(c.asnumpy(), c_np)
+
     check_llvm(64)
 
 
 @tvm.testing.requires_llvm
 def test_rank_zero():
     def check_llvm(n):
-        A = te.placeholder((n, ), name='A')
-        scale = te.placeholder((), name='scale')
+        A = te.placeholder((n,), name="A")
+        scale = te.placeholder((), name="scale")
         k = te.reduce_axis((0, n), name="k")
         C = te.compute((), lambda: te.sum(A[k] * scale(), axis=k), name="C")
         D = te.compute((), lambda: C() + 1)
@@ -423,14 +403,13 @@ def test_rank_zero():
         f = tvm.build(s, [A, scale, D], "llvm")
         ctx = tvm.cpu(0)
         # launch the kernel.
-        a = tvm.nd.array(np.random.randint(
-            0, 2, size=(n,)).astype(A.dtype), ctx)
-        sc = tvm.nd.array(
-            np.random.randint(0, 2, size=()).astype(scale.dtype), ctx)
+        a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), ctx)
+        sc = tvm.nd.array(np.random.randint(0, 2, size=()).astype(scale.dtype), ctx)
         d = tvm.nd.empty((), D.dtype, ctx)
         f(a, sc, d)
         d_np = np.sum(a.asnumpy()) * sc.asnumpy() + 1
         tvm.testing.assert_allclose(d.asnumpy(), d_np)
+
     check_llvm(64)
 
 
@@ -438,33 +417,31 @@ def test_rank_zero():
 def test_rank_zero_bound_checkers():
     def check_llvm(n):
         with tvm.transform.PassContext(config={"tir.instrument_bound_checkers": True}):
-            A = te.placeholder((n, ), name='A')
-            scale = te.placeholder((), name='scale')
+            A = te.placeholder((n,), name="A")
+            scale = te.placeholder((), name="scale")
             k = te.reduce_axis((0, n), name="k")
-            C = te.compute((), lambda: te.sum(
-                A[k] * scale(), axis=k), name="C")
+            C = te.compute((), lambda: te.sum(A[k] * scale(), axis=k), name="C")
             D = te.compute((), lambda: C() + 1)
             s = te.create_schedule(D.op)
             # build and invoke the kernel.
             f = tvm.build(s, [A, scale, D], "llvm")
             ctx = tvm.cpu(0)
             # launch the kernel.
-            a = tvm.nd.array(np.random.randint(
-                0, 2, size=(n,)).astype(A.dtype), ctx)
-            sc = tvm.nd.array(
-                np.random.randint(0, 2, size=()).astype(scale.dtype), ctx)
+            a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), ctx)
+            sc = tvm.nd.array(np.random.randint(0, 2, size=()).astype(scale.dtype), ctx)
             d = tvm.nd.empty((), D.dtype, ctx)
             f(a, sc, d)
             d_np = np.sum(a.asnumpy()) * sc.asnumpy() + 1
             tvm.testing.assert_allclose(d.asnumpy(), d_np)
+
     check_llvm(64)
 
 
 @tvm.testing.requires_llvm
 def test_alignment():
     n = tvm.runtime.convert(1024)
-    A = te.placeholder((n,), name='A')
-    B = te.compute(A.shape, lambda i: A[i] * 3, name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda i: A[i] * 3, name="B")
     s = te.create_schedule(B.op)
     bx, tx = s[B].split(B.op.axis[0], factor=8)
     s[B].vectorize(tx)
@@ -482,7 +459,7 @@ def test_alignment():
     # listed there.
     def has_param_alignment():
         for l in lines:
-            if re.search(r'test_alignment_compute_\([^(]*align [0-9]', l):
+            if re.search(r"test_alignment_compute_\([^(]*align [0-9]", l):
                 return True
         return False
 
@@ -494,7 +471,7 @@ def test_alignment():
     # a much more detailed analysis of the LLVM IR.
     def has_call_to_assume():
         for l in lines:
-            if re.search(r'call.*llvm.assume', l):
+            if re.search(r"call.*llvm.assume", l):
                 return True
         return False
 
@@ -504,6 +481,7 @@ def test_alignment():
 @tvm.testing.requires_llvm
 def test_llvm_div():
     """Check that the semantics of div and mod is correct"""
+
     def check(start, end, dstart, dend, dtype, floor_div=False):
         div = tvm.te.floordiv if floor_div else tvm.tir.truncdiv
         mod = tvm.te.floormod if floor_div else tvm.tir.truncmod
@@ -513,19 +491,30 @@ def test_llvm_div():
         B = te.placeholder((dend - dstart + 1,), name="B", dtype=dtype)
         # We clip values with min and max so that simplifiers know the ranges of values
 
-        def clipa(x): return tvm.te.min(tvm.tir.const(end, dtype),
-                                        tvm.te.max(tvm.tir.const(start, dtype), x))
-        def clipb(x): return tvm.te.min(tvm.tir.const(dend, dtype),
-                                        tvm.te.max(tvm.tir.const(dstart, dtype), x))
+        def clipa(x):
+            return tvm.te.min(tvm.tir.const(end, dtype), tvm.te.max(tvm.tir.const(start, dtype), x))
+
+        def clipb(x):
+            return tvm.te.min(
+                tvm.tir.const(dend, dtype), tvm.te.max(tvm.tir.const(dstart, dtype), x)
+            )
+
         # If the range is just a single point, use the constant itself
         if start == end:
-            def clipa(x): return tvm.tir.const(start, dtype)
+
+            def clipa(x):
+                return tvm.tir.const(start, dtype)
+
         if dstart == dend:
-            def clipb(x): return tvm.tir.const(dstart, dtype)
+
+            def clipb(x):
+                return tvm.tir.const(dstart, dtype)
+
         # D are division results and M are modulo results
-        [D, M] = te.compute((end - start + 1, dend - dstart + 1),
-                            lambda i, j: (div(clipa(A[i]), clipb(B[j])),
-                                          mod(clipa(A[i]), clipb(B[j]))))
+        [D, M] = te.compute(
+            (end - start + 1, dend - dstart + 1),
+            lambda i, j: (div(clipa(A[i]), clipb(B[j])), mod(clipa(A[i]), clipb(B[j]))),
+        )
 
         s = te.create_schedule([D.op, M.op])
         f = tvm.build(s, [A, B, D, M], "llvm")
@@ -571,71 +560,90 @@ def test_llvm_div():
 
                 if D_arr[i - start, j - dstart] != dref:
                     _show_info()
-                    raise AssertionError("Incorrect division result: {}({}, {}) is {} "
-                                         "but should be {}".format(div.__name__, i, j,
-                                                                   D_arr[i - start,
-                                                                         j - dstart],
-                                                                   dref))
+                    raise AssertionError(
+                        "Incorrect division result: {}({}, {}) is {} "
+                        "but should be {}".format(
+                            div.__name__, i, j, D_arr[i - start, j - dstart], dref
+                        )
+                    )
                 if M_arr[i - start, j - dstart] != mref:
                     _show_info()
-                    raise AssertionError("Incorrect modulo result: {}({}, {}) is {} "
-                                         "but should be {}".format(mod.__name__, i, j,
-                                                                   M_arr[i - start,
-                                                                         j - dstart],
-                                                                   mref))
+                    raise AssertionError(
+                        "Incorrect modulo result: {}({}, {}) is {} "
+                        "but should be {}".format(
+                            mod.__name__, i, j, M_arr[i - start, j - dstart], mref
+                        )
+                    )
 
     # Try different ranges to cover different cases
-    for start, end in [(-12, -12), (-11, -1), (-11,  0), (0, 0),
-                       (12,  12), (1, 11), (0, 11), (-11, 11)]:
-        for dstart, dend in [(-11, -1), (-11,  0), (-4, -4), (-2, -2),
-                             (1, 11), (0, 11), (4,  4), (2,  2), (-11, 11)]:
+    for start, end in [
+        (-12, -12),
+        (-11, -1),
+        (-11, 0),
+        (0, 0),
+        (12, 12),
+        (1, 11),
+        (0, 11),
+        (-11, 11),
+    ]:
+        for dstart, dend in [
+            (-11, -1),
+            (-11, 0),
+            (-4, -4),
+            (-2, -2),
+            (1, 11),
+            (0, 11),
+            (4, 4),
+            (2, 2),
+            (-11, 11),
+        ]:
             if end < start or dend < dstart or (dend == 0 and dstart == 0):
                 continue
-            check(start, end, dstart, dend, 'int32', floor_div=False)
-            check(start, end, dstart, dend, 'int32', floor_div=True)
-            check(start, end, dstart, dend, 'int8', floor_div=False)
-            check(start, end, dstart, dend, 'int8', floor_div=True)
+            check(start, end, dstart, dend, "int32", floor_div=False)
+            check(start, end, dstart, dend, "int32", floor_div=True)
+            check(start, end, dstart, dend, "int8", floor_div=False)
+            check(start, end, dstart, dend, "int8", floor_div=True)
             if start >= 0 and dstart >= 0:
-                check(start, end, dstart, dend, 'uint32', floor_div=False)
-                check(start, end, dstart, dend, 'uint32', floor_div=True)
+                check(start, end, dstart, dend, "uint32", floor_div=False)
+                check(start, end, dstart, dend, "uint32", floor_div=True)
 
     # Additional tests for uint8
     for dstart, dend in [(0, 11), (1, 11), (2, 2), (4, 4)]:
-        check(123, 133, dstart, dend, 'uint8', floor_div=False)
-        check(123, 133, dstart, dend, 'uint8', floor_div=True)
-        check(0, 255, dstart, dend, 'uint8', floor_div=False)
-        check(0, 255, dstart, dend, 'uint8', floor_div=True)
+        check(123, 133, dstart, dend, "uint8", floor_div=False)
+        check(123, 133, dstart, dend, "uint8", floor_div=True)
+        check(0, 255, dstart, dend, "uint8", floor_div=False)
+        check(0, 255, dstart, dend, "uint8", floor_div=True)
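# Hedged pure-Python reference (helper names are mine, not TVM APIs) for the semantics
# checked above: truncdiv/truncmod round toward zero, while floordiv/floormod round
# toward negative infinity; in both conventions a == q * b + r holds.
def trunc_divmod(a, b):
    q = int(a / b)  # truncates toward zero; fine for the small operands used here
    return q, a - q * b


def floor_divmod(a, b):
    q = a // b  # Python's // already floors
    return q, a - q * b


assert trunc_divmod(-7, 2) == (-3, -1)
assert floor_divmod(-7, 2) == (-4, 1)
assert trunc_divmod(7, 2) == floor_divmod(7, 2) == (3, 1)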
 
 
 @tvm.testing.requires_llvm
 def test_llvm_fp_math():
     def check_llvm_reciprocal(n):
-        A = te.placeholder((n,), name='A')
-        B = te.compute((n,), lambda i: te.div(1.0, (1e+37*A[i])), name='B')
+        A = te.placeholder((n,), name="A")
+        B = te.compute((n,), lambda i: te.div(1.0, (1e37 * A[i])), name="B")
 
         s = te.create_schedule(B.op)
         f = tvm.build(s, [A, B], "llvm")
 
-        a = tvm.nd.array(np.full((n,), 100, 'float32'))
-        b = tvm.nd.empty((n,), 'float32')
+        a = tvm.nd.array(np.full((n,), 100, "float32"))
+        b = tvm.nd.empty((n,), "float32")
         f(a, b)
-        tvm.testing.assert_allclose(b.asnumpy(), np.zeros((n,), 'float32'))
+        tvm.testing.assert_allclose(b.asnumpy(), np.zeros((n,), "float32"))
 
     check_llvm_reciprocal(4)
     check_llvm_reciprocal(8)
     check_llvm_reciprocal(16)
 
     def check_llvm_sigmoid(n):
-        A = te.placeholder((n,), name='A')
-        B = te.compute((n,), lambda i: te.sigmoid(A[i]), name='B')
+        A = te.placeholder((n,), name="A")
+        B = te.compute((n,), lambda i: te.sigmoid(A[i]), name="B")
 
         s = te.create_schedule(B.op)
         f = tvm.build(s, [A, B], "llvm")
 
-        a = tvm.nd.array(np.full((n,), -1000, 'float32'))
-        b = tvm.nd.empty((n,), 'float32')
+        a = tvm.nd.array(np.full((n,), -1000, "float32"))
+        b = tvm.nd.empty((n,), "float32")
         f(a, b)
-        tvm.testing.assert_allclose(b.asnumpy(), np.zeros((n,), 'float32'))
+        tvm.testing.assert_allclose(b.asnumpy(), np.zeros((n,), "float32"))
 
     check_llvm_sigmoid(4)
     check_llvm_sigmoid(8)
@@ -646,9 +654,9 @@ def test_llvm_fp_math():
 def test_dwarf_debug_information():
     nn = 1024
     n = tvm.runtime.convert(nn)
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
-    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
+    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C")
     s = te.create_schedule(C.op)
     xo, xi = s[C].split(C.op.axis[0], factor=4)
     s[C].parallel(xo)
@@ -683,7 +691,7 @@ def test_dwarf_debug_information():
             assert re.search(r"""DW_AT_name.*fadd2""", str(output))
 
         # Try objdump (Linux) - Darwin objdump has different DWARF syntax.
-        if shutil.which("objdump") and sys.platform != 'darwin':
+        if shutil.which("objdump") and sys.platform != "darwin":
             output = subprocess.check_output(["objdump", "--dwarf", o_path])
             assert re.search(r"""DW_AT_name.*fadd1""", str(output))
             assert re.search(r"""DW_AT_name.*fadd2""", str(output))
@@ -701,12 +709,12 @@ def test_dwarf_debug_information():
 
         # On non-Darwin OS, don't explicitly specify DWARF version.
         import re
-        assert not re.search(r""""Dwarf Version""""", ll)
+
+        assert not re.search(r""""Dwarf Version""" "", ll)
         assert re.search(r"""llvm.dbg.value""", ll)
 
         # Try Darwin, require DWARF-2
-        m = tvm.build([f1, f2],
-                      target="llvm -mtriple=x86_64-apple-darwin-macho")
+        m = tvm.build([f1, f2], target="llvm -mtriple=x86_64-apple-darwin-macho")
         ll = m.get_source("ll")
         assert re.search(r"""i32 4, !"Dwarf Version", i32 2""", ll)
         assert re.search(r"""llvm.dbg.value""", ll)
@@ -717,67 +725,64 @@ def test_dwarf_debug_information():
 
 @tvm.testing.requires_llvm
 def test_llvm_shuffle():
-    a = te.placeholder((8, ), 'int32')
-    b = te.placeholder((8, ), 'int32')
-    c = te.compute((8, ), lambda x: a[x] + b[7-x])
+    a = te.placeholder((8,), "int32")
+    b = te.placeholder((8,), "int32")
+    c = te.compute((8,), lambda x: a[x] + b[7 - x])
     sch = te.create_schedule(c.op)
 
     def my_vectorize():
         def vectorizer(op):
             store = op.body
-            idx = tvm.tir.Ramp(tvm.tir.const(0, 'int32'),
-                               tvm.tir.const(1, 'int32'), 8)
-            all_ones = tvm.tir.const(1, 'int32x8')
+            idx = tvm.tir.Ramp(tvm.tir.const(0, "int32"), tvm.tir.const(1, "int32"), 8)
+            all_ones = tvm.tir.const(1, "int32x8")
             value = store.value
-            b_idx = tvm.tir.Shuffle(
-                [idx], [tvm.tir.const(i, 'int32') for i in range(7, -1, -1)])
-            new_a = tvm.tir.Load('int32x8', value.a.buffer_var, idx, all_ones)
-            new_b = tvm.tir.Load(
-                'int32x8', value.b.buffer_var, b_idx, all_ones)
+            b_idx = tvm.tir.Shuffle([idx], [tvm.tir.const(i, "int32") for i in range(7, -1, -1)])
+            new_a = tvm.tir.Load("int32x8", value.a.buffer_var, idx, all_ones)
+            new_b = tvm.tir.Load("int32x8", value.b.buffer_var, b_idx, all_ones)
             value = new_a + new_b
             return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
 
         def _transform(f, *_):
             return f.with_body(
-                tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['tir.For']))
+                tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ["tir.For"])
+            )
 
         return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize")
 
     with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, my_vectorize())]}):
         ir = tvm.lower(sch, [a, b, c], simple_mode=True)
         module = tvm.build(sch, [a, b, c])
-        a_ = tvm.nd.array(np.arange(1, 9, dtype='int32'))
-        b_ = tvm.nd.array(np.arange(8, 0, -1, dtype='int32'))
-        c_ = tvm.nd.array(np.zeros((8, ), dtype='int32'))
+        a_ = tvm.nd.array(np.arange(1, 9, dtype="int32"))
+        b_ = tvm.nd.array(np.arange(8, 0, -1, dtype="int32"))
+        c_ = tvm.nd.array(np.zeros((8,), dtype="int32"))
         module(a_, b_, c_)
-        tvm.testing.assert_allclose(
-            c_.asnumpy(), (a_.asnumpy() * 2).astype('int32'))
+        tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype("int32"))
 
 
 def np_float2np_bf16(arr):
-    ''' Convert a numpy array of float to a numpy array
-    of bf16 in uint16'''
-    orig = arr.view('<u4')
+    """Convert a numpy array of float to a numpy array
+    of bf16 in uint16"""
+    orig = arr.view("<u4")
     bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
-    return np.right_shift(orig + bias, 16).astype('uint16')
+    return np.right_shift(orig + bias, 16).astype("uint16")
 
 
 def np_float2tvm_bf16(arr):
-    ''' Convert a numpy array of float to a TVM array
-    of bf16'''
+    """Convert a numpy array of float to a TVM array
+    of bf16"""
     nparr = np_float2np_bf16(arr)
-    return tvm.nd.empty(nparr.shape, 'uint16').copyfrom(nparr)
+    return tvm.nd.empty(nparr.shape, "uint16").copyfrom(nparr)
 
 
 def np_bf162np_float(arr):
-    ''' Convert a numpy array of bf16 (uint16) to a numpy array
-    of float'''
-    u32 = np.left_shift(arr.astype('uint32'), 16)
-    return u32.view('<f4')
+    """Convert a numpy array of bf16 (uint16) to a numpy array
+    of float"""
+    u32 = np.left_shift(arr.astype("uint32"), 16)
+    return u32.view("<f4")
 
 
 def np_bf16_cast_and_cast_back(arr):
-    ''' Convert a numpy array of float to bf16 and cast back'''
+    """ Convert a numpy array of float to bf16 and cast back"""
     return np_bf162np_float(np_float2np_bf16(arr))
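The three numpy helpers above emulate bfloat16 without any TVM machinery: the bias term in np_float2np_bf16 rounds the upper 16 bits of the float32 encoding to nearest, ties to even. A minimal sketch of the behaviour they implement (plain numpy, assuming the helpers above are in scope; the concrete values are illustrative only):

    import numpy as np

    vals = np.array([1.0, 1.00390625, 1.01171875, 1.005859375], dtype="float32")
    rounded = np_bf16_cast_and_cast_back(vals)
    # 1.0 is exactly representable in bfloat16; the two halfway cases tie to the
    # nearest even value (down to 1.0, up to 1.015625); 1.005859375 rounds up.
    np.testing.assert_array_equal(
        rounded, np.array([1.0, 1.0, 1.015625, 1.0078125], dtype="float32")
    )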
 
 
@@ -785,38 +790,38 @@ def np_bf16_cast_and_cast_back(arr):
 def test_llvm_bf16():
     def dotest(do_vectorize):
         np.random.seed(122)
-        A = te.placeholder((32, ), dtype='bfloat16')
-        B = te.placeholder((32, ), dtype='bfloat16')
-        d = te.compute((32, ), lambda x: A[x] + B[x])
+        A = te.placeholder((32,), dtype="bfloat16")
+        B = te.placeholder((32,), dtype="bfloat16")
+        d = te.compute((32,), lambda x: A[x] + B[x])
         sch = te.create_schedule(d.op)
         print(tvm.lower(sch, [A, B, d]))
         if do_vectorize:
             sch[d].vectorize(d.op.axis[0])
         module = tvm.build(sch, [A, B, d])
-        npa = np.random.rand(32).astype('float32')
-        npb = np.random.rand(32).astype('float32')
+        npa = np.random.rand(32).astype("float32")
+        npb = np.random.rand(32).astype("float32")
         va = np_bf16_cast_and_cast_back(npa)
         vb = np_bf16_cast_and_cast_back(npb)
         res = np_bf16_cast_and_cast_back(va + vb)
         a_ = np_float2tvm_bf16(npa)
         b_ = np_float2tvm_bf16(npb)
-        c_ = tvm.nd.empty((32,), 'uint16')
+        c_ = tvm.nd.empty((32,), "uint16")
         module(a_, b_, c_)
         tvm.testing.assert_allclose(np_bf162np_float(c_.asnumpy()), res)
+
     dotest(True)
     dotest(False)
 
 
 @tvm.testing.requires_llvm
 def test_llvm_crt_static_lib():
-    A = te.placeholder((32, ), dtype='bfloat16')
-    B = te.placeholder((32, ), dtype='bfloat16')
-    d = te.compute((32, ), lambda x: A[x] + B[x])
+    A = te.placeholder((32,), dtype="bfloat16")
+    B = te.placeholder((32,), dtype="bfloat16")
+    d = te.compute((32,), lambda x: A[x] + B[x])
     sch = te.create_schedule(d.op)
-    module = tvm.build(sch, [A, B, d], target=tvm.target.Target(
-        'llvm --system-lib --runtime=c'))
+    module = tvm.build(sch, [A, B, d], target=tvm.target.Target("llvm --system-lib --runtime=c"))
     print(module.get_source())
-    module.save('test.o')
+    module.save("test.o")
 
 
 if __name__ == "__main__":
index 9a03a79..8a070da 100644
@@ -18,18 +18,19 @@ import tvm
 from tvm import te
 import tvm.testing
 
-target = 'opencl'
+target = "opencl"
+
 
 @tvm.testing.requires_gpu
 @tvm.testing.requires_opencl
 def test_opencl_ternary_expression():
     def check_if_then_else(ctx, n, dtype):
-        A = te.placeholder((n,), name='A', dtype=dtype)
+        A = te.placeholder((n,), name="A", dtype=dtype)
         true_value = tvm.tir.const(1, dtype=dtype)
         false_value = tvm.tir.const(3, dtype=dtype)
         max_lhs = tvm.tir.const(2, dtype=dtype)
         max_rhs = tvm.tir.if_then_else(A[0] > 0, true_value, false_value)
-        C = te.compute((n,), lambda i: tvm.te.max(max_lhs, max_rhs), name='C')
+        C = te.compute((n,), lambda i: tvm.te.max(max_lhs, max_rhs), name="C")
         s = te.create_schedule(C.op)
         s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
         fun = tvm.build(s, [A, C], target)
@@ -40,12 +41,12 @@ def test_opencl_ternary_expression():
         fun(a, c)
 
     def check_select(ctx, n, dtype):
-        A = te.placeholder((n,), name='A', dtype=dtype)
+        A = te.placeholder((n,), name="A", dtype=dtype)
         true_value = tvm.tir.const(1, dtype=dtype)
         false_value = tvm.tir.const(3, dtype=dtype)
         max_lhs = tvm.tir.const(2, dtype=dtype)
         max_rhs = tvm.tir.Select(A[0] > 0, true_value, false_value)
-        C = te.compute((n,), lambda i: tvm.te.max(max_lhs, max_rhs), name='C')
+        C = te.compute((n,), lambda i: tvm.te.max(max_lhs, max_rhs), name="C")
         s = te.create_schedule(C.op)
         s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
         fun = tvm.build(s, [A, C], target)
@@ -57,22 +58,23 @@ def test_opencl_ternary_expression():
 
     ctx = tvm.context(target, 0)
 
-    check_if_then_else(ctx, 1, 'int8')
-    check_if_then_else(ctx, 1, 'uint8')
-    check_if_then_else(ctx, 1, 'int16')
-    check_if_then_else(ctx, 1, 'uint16')
-    check_select(ctx, 1, 'int8')
-    check_select(ctx, 1, 'uint8')
-    check_select(ctx, 1, 'int16')
-    check_select(ctx, 1, 'uint16')
+    check_if_then_else(ctx, 1, "int8")
+    check_if_then_else(ctx, 1, "uint8")
+    check_if_then_else(ctx, 1, "int16")
+    check_if_then_else(ctx, 1, "uint16")
+    check_select(ctx, 1, "int8")
+    check_select(ctx, 1, "uint8")
+    check_select(ctx, 1, "int16")
+    check_select(ctx, 1, "uint16")
+
 
 @tvm.testing.requires_gpu
 @tvm.testing.requires_opencl
 def test_opencl_inf_nan():
     def check_inf_nan(ctx, n, value, dtype):
-        A = te.placeholder((n,), name='A', dtype=dtype)
+        A = te.placeholder((n,), name="A", dtype=dtype)
         inf_value = tvm.tir.const(value, dtype=dtype)
-        C = te.compute((n,), lambda i: inf_value, name='C')
+        C = te.compute((n,), lambda i: inf_value, name="C")
         s = te.create_schedule(C.op)
         s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
         fun = tvm.build(s, [A, C], target)
@@ -83,22 +85,22 @@ def test_opencl_inf_nan():
 
     ctx = tvm.context(target, 0)
 
-    check_inf_nan(ctx, 1, -float('inf'), 'float32')
-    check_inf_nan(ctx, 1, -float('inf'), 'float64')
-    check_inf_nan(ctx, 1, float('inf'), 'float32')
-    check_inf_nan(ctx, 1, float('inf'), 'float64')
-    check_inf_nan(ctx, 1, float('nan'), 'float32')
-    check_inf_nan(ctx, 1, float('nan'), 'float64')
+    check_inf_nan(ctx, 1, -float("inf"), "float32")
+    check_inf_nan(ctx, 1, -float("inf"), "float64")
+    check_inf_nan(ctx, 1, float("inf"), "float32")
+    check_inf_nan(ctx, 1, float("inf"), "float64")
+    check_inf_nan(ctx, 1, float("nan"), "float32")
+    check_inf_nan(ctx, 1, float("nan"), "float64")
 
 
 @tvm.testing.requires_gpu
 @tvm.testing.requires_opencl
 def test_opencl_max():
     def check_max(ctx, n, dtype):
-        A = te.placeholder((n,), name='A', dtype=dtype)
+        A = te.placeholder((n,), name="A", dtype=dtype)
         max_lhs = A[0] + tvm.tir.const(1, dtype=dtype)
         max_rhs = tvm.tir.const(0, dtype=dtype)
-        C = te.compute((n,), lambda i: tvm.te.max(max_lhs, max_rhs), name='C')
+        C = te.compute((n,), lambda i: tvm.te.max(max_lhs, max_rhs), name="C")
         s = te.create_schedule(C.op)
         s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
         fun = tvm.build(s, [A, C], target)
@@ -110,12 +112,12 @@ def test_opencl_max():
 
     ctx = tvm.context(target, 0)
 
-    check_max(ctx, 1, 'int8')
-    check_max(ctx, 1, 'uint8')
-    check_max(ctx, 1, 'int16')
-    check_max(ctx, 1, 'uint16')
-    check_max(ctx, 1, 'float32')
-    check_max(ctx, 1, 'float64')
+    check_max(ctx, 1, "int8")
+    check_max(ctx, 1, "uint8")
+    check_max(ctx, 1, "int16")
+    check_max(ctx, 1, "uint16")
+    check_max(ctx, 1, "float32")
+    check_max(ctx, 1, "float64")
 
 
 if __name__ == "__main__":
index 2adc1c8..1d5cbd8 100644
@@ -24,12 +24,13 @@ ty = te.thread_axis("threadIdx.y")
 bx = te.thread_axis("blockIdx.x")
 by = te.thread_axis("blockIdx.y")
 
+
 @tvm.testing.requires_rocm
 def test_rocm_cross_thread_reduction():
     # based on the reduction tutorial
     n = te.size_var("n")
     m = te.size_var("m")
-    A = te.placeholder((n, m), name='A')
+    A = te.placeholder((n, m), name="A")
     k = te.reduce_axis((0, m), "k")
     B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
     s = te.create_schedule(B.op)
@@ -48,16 +49,15 @@ def test_rocm_cross_thread_reduction():
     a = tvm.nd.array(np.random.uniform(size=(nn, nn)).astype(A.dtype), ctx)
     b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx)
     frocm(a, b)
-    tvm.testing.assert_allclose(
-      b.asnumpy(),  np.sum(a.asnumpy(), axis=1), rtol=1e-4)
+    tvm.testing.assert_allclose(b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)
 
 
 @tvm.testing.requires_rocm
 def test_rocm_inf_nan():
     def check_inf_nan(ctx, n, value, dtype):
-        A = te.placeholder((n,), name='A', dtype=dtype)
+        A = te.placeholder((n,), name="A", dtype=dtype)
         inf_value = tvm.tir.const(value, dtype=dtype)
-        C = te.compute((n,), lambda i: inf_value, name='C')
+        C = te.compute((n,), lambda i: inf_value, name="C")
         s = te.create_schedule(C.op)
         s[C].bind(s[C].op.axis[0], tx)
         fun = tvm.build(s, [A, C], "rocm")
@@ -68,20 +68,19 @@ def test_rocm_inf_nan():
 
     ctx = tvm.rocm(0)
 
-    check_inf_nan(ctx, 1, -float('inf'), 'float32')
-    check_inf_nan(ctx, 1, -float('inf'), 'float64')
-    check_inf_nan(ctx, 1, float('inf'), 'float32')
-    check_inf_nan(ctx, 1, float('inf'), 'float64')
-    check_inf_nan(ctx, 1, float('nan'), 'float32')
-    check_inf_nan(ctx, 1, float('nan'), 'float64')
+    check_inf_nan(ctx, 1, -float("inf"), "float32")
+    check_inf_nan(ctx, 1, -float("inf"), "float64")
+    check_inf_nan(ctx, 1, float("inf"), "float32")
+    check_inf_nan(ctx, 1, float("inf"), "float64")
+    check_inf_nan(ctx, 1, float("nan"), "float32")
+    check_inf_nan(ctx, 1, float("nan"), "float64")
+
 
 @tvm.testing.requires_rocm
 def test_rocm_reduction_binding():
-    k = te.reduce_axis((0, 32), 'k')
-    A = te.placeholder((96, 32), name='A')
-    B = te.compute( (96,), lambda m:
-                     te.sum(A[m, k], axis=k),
-                     name='B')
+    k = te.reduce_axis((0, 32), "k")
+    A = te.placeholder((96, 32), name="A")
+    B = te.compute((96,), lambda m: te.sum(A[m, k], axis=k), name="B")
     s = te.create_schedule(B.op)
 
     s[B].reorder(B.op.reduce_axis[0], B.op.axis[0])
@@ -89,11 +88,11 @@ def test_rocm_reduction_binding():
     mo, _ = s[B].split(B.op.axis[0], 32)
     s[B].bind(mo, bx)
 
+
 @tvm.testing.requires_rocm
 def test_rocm_copy():
-
     def check_rocm(dtype, n):
-        A = te.placeholder((n,), name='A', dtype=dtype)
+        A = te.placeholder((n,), name="A", dtype=dtype)
         ctx = tvm.rocm(0)
         a_np = np.random.uniform(size=(n,)).astype(A.dtype)
         a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(a_np)
@@ -107,21 +106,21 @@ def test_rocm_copy():
         peturb = np.random.uniform(low=0.5, high=1.5)
         check_rocm(dtype, int(peturb * (2 ** logN)))
 
+
 @tvm.testing.requires_rocm
 def test_rocm_vectorize_add():
     num_thread = 8
 
     def check_rocm(dtype, n, lanes):
-        A = te.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
-        B = te.compute((n,), lambda i: A[i]+tvm.tir.const(1, A.dtype), name='B')
+        A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes))
+        B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B")
         s = te.create_schedule(B.op)
         xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
         s[B].bind(xo, bx)
         s[B].bind(xi, tx)
         fun = tvm.build(s, [A, B], "rocm")
         ctx = tvm.rocm(0)
-        a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(
-            np.random.uniform(size=(n, lanes)))
+        a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(np.random.uniform(size=(n, lanes)))
         c = tvm.nd.empty((n,), B.dtype, ctx)
         fun(a, c)
         tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1)
@@ -129,6 +128,7 @@ def test_rocm_vectorize_add():
     check_rocm("float32", 64, 2)
     check_rocm("float16", 64, 2)
 
+
 if __name__ == "__main__":
     test_rocm_cross_thread_reduction()
     test_rocm_inf_nan()
index 97a0fc3..179e302 100644
@@ -21,10 +21,10 @@ import numpy as np
 
 
 def test_static_callback():
-    dtype = 'int64'
-    n = te.size_var('n')
-    Ab = tvm.tir.decl_buffer((n, ), dtype)
-    i = te.size_var('i')
+    dtype = "int64"
+    n = te.size_var("n")
+    Ab = tvm.tir.decl_buffer((n,), dtype)
+    i = te.size_var("i")
     ib = tvm.tir.ir_builder.create()
     A = ib.buffer_ptr(Ab)
     cp = te.thread_axis((0, 1), "cop")
@@ -34,24 +34,22 @@ def test_static_callback():
         A[i] = A[i] + 1
     stmt = ib.get()
 
-    mod = tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp")
-    )
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp"))
     f = tvm.driver.build(mod, target="llvm")
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
     f(a)
     f(a)
     np.testing.assert_equal(a.asnumpy(), np.ones(a.shape[0]))
 
+
 def test_static_init():
-    dtype = 'int64'
-    n = te.size_var('n')
-    Ab = tvm.tir.decl_buffer((n, ), dtype)
-    i = te.size_var('i')
+    dtype = "int64"
+    n = te.size_var("n")
+    Ab = tvm.tir.decl_buffer((n,), dtype)
+    i = te.size_var("i")
     ib = tvm.tir.ir_builder.create()
     handle = tvm.tir.call_intrin("handle", "tir.tvm_static_handle")
-    ib.emit(
-        tvm.tir.call_packed("test_static_callback", handle, Ab))
+    ib.emit(tvm.tir.call_packed("test_static_callback", handle, Ab))
 
     @tvm.register_func("test_static_callback")
     def test_cb(sh, A):
@@ -59,8 +57,7 @@ def test_static_init():
         return sh
 
     stmt = ib.get()
-    mod = tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp"))
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp"))
     f = tvm.driver.build(mod, target="llvm")
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
     f(a)
index 55c7c31..26f1493 100644
@@ -19,6 +19,7 @@ import tvm.testing
 from tvm import te
 import numpy as np
 
+
 def run_jit(fapi, check):
     for target in ["llvm", "stackvm"]:
         if not tvm.testing.device_enabled(target):
@@ -27,19 +28,22 @@ def run_jit(fapi, check):
         s = f.get_source()
         check(f)
 
+
 def test_stack_vm_basic():
-    a = tvm.nd.array(np.zeros(10, dtype='float32'))
+    a = tvm.nd.array(np.zeros(10, dtype="float32"))
+
     @tvm.register_func
     def tvm_call_back_get_shape(shape0):
         print(shape0)
         assert shape0 == a.shape[0]
 
-    n = te.size_var('n')
-    Ab = tvm.tir.decl_buffer((n, ), "float32")
+    n = te.size_var("n")
+    Ab = tvm.tir.decl_buffer((n,), "float32")
     stmt = tvm.tir.Evaluate(tvm.tir.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
 
     mod = tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "print_shape"))
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "print_shape")
+    )
 
     run_jit(mod, lambda f: f(a))
 
@@ -48,11 +52,12 @@ def test_stack_vm_basic():
 def tvm_stack_vm_print(*x):
     print(x)
 
+
 def test_stack_vm_loop():
-    dtype = 'int64'
-    n = te.size_var('n')
-    Ab = tvm.tir.decl_buffer((n, ), dtype)
-    i = te.size_var('i')
+    dtype = "int64"
+    n = te.size_var("n")
+    Ab = tvm.tir.decl_buffer((n,), dtype)
+    i = te.size_var("i")
 
     ib = tvm.tir.ir_builder.create()
     A = ib.buffer_ptr(Ab)
@@ -61,55 +66,59 @@ def test_stack_vm_loop():
         ib.emit(tvm.tir.call_packed("tvm_stack_vm_print", i))
 
     stmt = ib.get()
-    mod = tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp"))
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp"))
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
+
     def check(f):
         f(a)
         np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
+
     run_jit(mod, check)
 
 
 def test_stack_vm_cond():
-    dtype = 'int64'
-    n = te.size_var('n')
-    Ab = tvm.tir.decl_buffer((n, ), dtype)
+    dtype = "int64"
+    n = te.size_var("n")
+    Ab = tvm.tir.decl_buffer((n,), dtype)
 
     ib = tvm.tir.ir_builder.create()
     A = ib.buffer_ptr(Ab)
     with ib.for_range(0, n - 1, "i") as i:
-        with ib.if_scope(tvm.tir.EQ(i,  4)):
+        with ib.if_scope(tvm.tir.EQ(i, 4)):
             A[i + 1] = A[i] + 1
         with ib.else_scope():
             A[i + 1] = A[i] + 2
 
     stmt = ib.get()
-    mod = tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "test"))
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "test"))
+
     def check(f):
         a = tvm.nd.array(np.zeros(10, dtype=dtype))
         f(a)
         y = np.arange(a.shape[0]) * 2
         y[5:] -= 1
         np.testing.assert_equal(a.asnumpy(), y)
+
     run_jit(mod, check)
 
+
 def test_vm_parallel():
-    dtype = 'int64'
-    n = te.size_var('n')
-    Ab = tvm.tir.decl_buffer((n, ), dtype)
-    i = te.size_var('i')
+    dtype = "int64"
+    n = te.size_var("n")
+    Ab = tvm.tir.decl_buffer((n,), dtype)
+    i = te.size_var("i")
     ib = tvm.tir.ir_builder.create()
     A = ib.buffer_ptr(Ab)
     with ib.for_range(0, n, "i", for_type="parallel") as i:
         A[i] = A[i] + 1
     stmt = ib.get()
-    mod = tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "test"))
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "test"))
+
     def check(f):
         a = tvm.nd.array(np.zeros(10, dtype=dtype))
         f(a)
         np.testing.assert_equal(a.asnumpy(), np.ones(a.shape[0]))
+
     run_jit(mod, check)
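For reference, the expected vector asserted in test_stack_vm_cond above follows directly from the emitted loop; a minimal plain-Python sketch of the same recurrence (independent of TVM):

    import numpy as np

    a = np.zeros(10, dtype="int64")
    for i in range(10 - 1):                       # ib.for_range(0, n - 1, "i")
        a[i + 1] = a[i] + (1 if i == 4 else 2)    # if_scope(i == 4) / else_scope

    y = np.arange(10) * 2
    y[5:] -= 1
    np.testing.assert_equal(a, y)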
 
 
index a036cd8..61ac36e 100644
@@ -22,16 +22,18 @@ import numpy as np
 
 @tvm.testing.requires_vulkan
 def test_vector_comparison():
-    target = 'vulkan'
+    target = "vulkan"
 
     def check_correct_assembly(dtype):
         n = (1024,)
-        A = te.placeholder(n, dtype=dtype, name='A')
+        A = te.placeholder(n, dtype=dtype, name="A")
         B = te.compute(
             A.shape,
             lambda i: tvm.tir.Select(
-                A[i] >= 0, A[i] + tvm.tir.const(1, dtype),
-                tvm.tir.const(0, dtype)), name='B')
+                A[i] >= 0, A[i] + tvm.tir.const(1, dtype), tvm.tir.const(0, dtype)
+            ),
+            name="B",
+        )
         s = te.create_schedule(B.op)
 
         (bx, tx) = s[B].split(s[B].op.axis[0], factor=128)
@@ -48,9 +50,10 @@ def test_vector_comparison():
         assert len(matches) == 1
         matches = re.findall("OpSelect %v4.*", assembly)
         assert len(matches) == 1
-    check_correct_assembly('float32')
-    check_correct_assembly('int32')
-    check_correct_assembly('float16')
+
+    check_correct_assembly("float32")
+    check_correct_assembly("int32")
+    check_correct_assembly("float16")
 
 
 tx = te.thread_axis("threadIdx.x")
@@ -59,9 +62,8 @@ bx = te.thread_axis("blockIdx.x")
 
 @tvm.testing.requires_vulkan
 def test_vulkan_copy():
-
     def check_vulkan(dtype, n):
-        A = te.placeholder((n,), name='A', dtype=dtype)
+        A = te.placeholder((n,), name="A", dtype=dtype)
         ctx = tvm.vulkan(0)
         a_np = np.random.uniform(size=(n,)).astype(A.dtype)
         a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(a_np)
@@ -81,16 +83,15 @@ def test_vulkan_vectorize_add():
     num_thread = 8
 
     def check_vulkan(dtype, n, lanes):
-        A = te.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
-        B = te.compute((n,), lambda i: A[i]+tvm.tir.const(1, A.dtype), name='B')
+        A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes))
+        B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B")
         s = te.create_schedule(B.op)
         xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
         s[B].bind(xo, bx)
         s[B].bind(xi, tx)
         fun = tvm.build(s, [A, B], "vulkan")
         ctx = tvm.vulkan(0)
-        a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(
-            np.random.uniform(size=(n, lanes)))
+        a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(np.random.uniform(size=(n, lanes)))
         c = tvm.nd.empty((n,), B.dtype, ctx)
         fun(a, c)
         tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1)
@@ -107,20 +108,21 @@ def test_vulkan_stress():
     """
     import random
     import threading
+
     n = 1024
     num_thread = 64
 
     def run_stress():
         def worker():
-            A = te.placeholder((n,), name='A', dtype="float32")
-            B = te.placeholder((n,), name='B', dtype="float32")
+            A = te.placeholder((n,), name="A", dtype="float32")
+            B = te.placeholder((n,), name="B", dtype="float32")
             functions = [
-                (lambda: te.compute((n,), lambda i: 2 * A[i] + 3 * B[i]),
-                 lambda a, b: 2 * a + 3 * b),
-                (lambda: te.compute((n,), lambda i: A[i]+B[i]),
-                 lambda a, b: a + b),
-                (lambda: te.compute((n,), lambda i: A[i]+2 * B[i]),
-                 lambda a, b: a + 2 * b),
+                (
+                    lambda: te.compute((n,), lambda i: 2 * A[i] + 3 * B[i]),
+                    lambda a, b: 2 * a + 3 * b,
+                ),
+                (lambda: te.compute((n,), lambda i: A[i] + B[i]), lambda a, b: a + b),
+                (lambda: te.compute((n,), lambda i: A[i] + 2 * B[i]), lambda a, b: a + 2 * b),
             ]
 
             def build_f(f_ref):
@@ -133,23 +135,20 @@ def test_vulkan_stress():
                 fun = tvm.build(s, [A, B, C], "vulkan")
                 return (fun, ref)
 
-            fs = [build_f(random.choice(functions))
-                  for _ in range(np.random.randint(low=1, high=10))]
+            fs = [
+                build_f(random.choice(functions)) for _ in range(np.random.randint(low=1, high=10))
+            ]
             ctx = tvm.vulkan(0)
-            a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(
-                np.random.uniform(size=(n,)))
-            b = tvm.nd.empty((n,), B.dtype, ctx).copyfrom(
-                np.random.uniform(size=(n,)))
+            a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(np.random.uniform(size=(n,)))
+            b = tvm.nd.empty((n,), B.dtype, ctx).copyfrom(np.random.uniform(size=(n,)))
             cs = [tvm.nd.empty((n,), A.dtype, ctx) for _ in fs]
             for ((f, _), c) in zip(fs, cs):
                 f(a, b, c)
 
             for ((_, ref), c) in zip(fs, cs):
-                tvm.testing.assert_allclose(
-                    c.asnumpy(), ref(a.asnumpy(), b.asnumpy()))
+                tvm.testing.assert_allclose(c.asnumpy(), ref(a.asnumpy(), b.asnumpy()))
 
-        ts = [threading.Thread(target=worker)
-              for _ in range(np.random.randint(1, 10))]
+        ts = [threading.Thread(target=worker) for _ in range(np.random.randint(1, 10))]
         for t in ts:
             t.start()
         for t in ts:
index cdba774..ec11d26 100644
@@ -21,20 +21,23 @@ import re
 
 def test_fp16_to_fp32():
     if tvm.target.codegen.llvm_version_major() < 6:
-        print("Skipping due to LLVM version being {} < 6".format(
-            tvm.target.codegen.llvm_version_major()))
+        print(
+            "Skipping due to LLVM version being {} < 6".format(
+                tvm.target.codegen.llvm_version_major()
+            )
+        )
         return
 
     def fp16_to_fp32(target, width, match=None, not_match=None):
         elements = 64
         n = tvm.runtime.convert(elements)
-        A = te.placeholder((n, width), dtype="float16", name='A')
-        B = te.compute(A.shape, lambda *i: A(*i).astype("float32"), name='B')
+        A = te.placeholder((n, width), dtype="float16", name="A")
+        B = te.compute(A.shape, lambda *i: A(*i).astype("float32"), name="B")
         s = te.create_schedule(B.op)
         s[B].vectorize(s[B].op.axis[1])
         f = tvm.build(s, [A, B], target)
 
-        assembly = f.get_source('asm').splitlines()
+        assembly = f.get_source("asm").splitlines()
         if match:
             matches = [l for l in assembly if re.search(match, l)]
             assert matches
@@ -42,35 +45,22 @@ def test_fp16_to_fp32():
             not_matches = [l for l in assembly if re.search(not_match, l)]
             assert not not_matches
 
-
-    fp16_to_fp32(
-        'llvm -mcpu=skylake-avx512', 15,
-        match="vcvtph2ps.*ymm", not_match="vcvtph2ps.*zmm")
-    fp16_to_fp32(
-        'llvm -mcpu=skylake-avx512', 16,
-        match="vcvtph2ps.*zmm")
-    fp16_to_fp32(
-        'llvm -mcpu=skylake-avx512', 17,
-        match="vcvtph2ps.*zmm")
     fp16_to_fp32(
-        'llvm -mcpu=skylake-avx512', 49,
-        match="vcvtph2ps.*zmm")
+        "llvm -mcpu=skylake-avx512", 15, match="vcvtph2ps.*ymm", not_match="vcvtph2ps.*zmm"
+    )
+    fp16_to_fp32("llvm -mcpu=skylake-avx512", 16, match="vcvtph2ps.*zmm")
+    fp16_to_fp32("llvm -mcpu=skylake-avx512", 17, match="vcvtph2ps.*zmm")
+    fp16_to_fp32("llvm -mcpu=skylake-avx512", 49, match="vcvtph2ps.*zmm")
     fp16_to_fp32(
-        'llvm -mcpu=skylake-avx512 -mattr=-avx512f', 49,
+        "llvm -mcpu=skylake-avx512 -mattr=-avx512f",
+        49,
         match="vcvtph2ps.*ymm",
-        not_match="vcvtph2ps.*zmm")
-    fp16_to_fp32(
-        'llvm -mcpu=skylake-avx512 -mattr=-f16c,-avx512f', 49,
-        not_match="vcvtph2ps")
-    fp16_to_fp32(
-        'llvm -mcpu=core-avx2', 8,
-        match="vcvtph2ps.*ymm")
-    fp16_to_fp32(
-        'llvm -mcpu=core-avx2', 9,
-        match="vcvtph2ps.*ymm")
-    fp16_to_fp32(
-        'llvm', 9,
-        not_match="vcvtph2ps")
+        not_match="vcvtph2ps.*zmm",
+    )
+    fp16_to_fp32("llvm -mcpu=skylake-avx512 -mattr=-f16c,-avx512f", 49, not_match="vcvtph2ps")
+    fp16_to_fp32("llvm -mcpu=core-avx2", 8, match="vcvtph2ps.*ymm")
+    fp16_to_fp32("llvm -mcpu=core-avx2", 9, match="vcvtph2ps.*ymm")
+    fp16_to_fp32("llvm", 9, not_match="vcvtph2ps")
 
 
 if __name__ == "__main__":
index 4c32066..9b2b85d 100644
@@ -32,17 +32,29 @@ def setup_module():
     tvm.target.datatype.register("bfloat", 129)
 
     tvm.target.datatype.register_op(
-        tvm.target.datatype.create_lower_func("FloatToBFloat16_wrapper"), "Cast",
-        "llvm", "bfloat", "float")
+        tvm.target.datatype.create_lower_func("FloatToBFloat16_wrapper"),
+        "Cast",
+        "llvm",
+        "bfloat",
+        "float",
+    )
     tvm.target.datatype.register_op(
-        tvm.target.datatype.create_lower_func("BFloat16ToFloat_wrapper"), "Cast",
-        "llvm", "float", "bfloat")
+        tvm.target.datatype.create_lower_func("BFloat16ToFloat_wrapper"),
+        "Cast",
+        "llvm",
+        "float",
+        "bfloat",
+    )
     tvm.target.datatype.register_op(
-        tvm.target.datatype.create_lower_func("BFloat16Add_wrapper"), "Add", "llvm",
-        "bfloat")
+        tvm.target.datatype.create_lower_func("BFloat16Add_wrapper"), "Add", "llvm", "bfloat"
+    )
     tvm.target.datatype.register_op(
-        tvm.target.datatype.create_lower_func("FloatToBFloat16_wrapper"), "FloatImm",
-        "llvm", "bfloat")
+        tvm.target.datatype.create_lower_func("FloatToBFloat16_wrapper"),
+        "FloatImm",
+        "llvm",
+        "bfloat",
+    )
+
 
 def lower_datatypes_and_build(schedule, args):
     """Create schedule and lower, manually lowering datatypes.
@@ -58,15 +70,15 @@ def lower_datatypes_and_build(schedule, args):
 
 
 def test_bfloat_add_and_cast_1():
-    X = te.placeholder((3, ), name="X")
-    Y = te.placeholder((3, ), name="Y")
+    X = te.placeholder((3,), name="X")
+    Y = te.placeholder((3,), name="Y")
     Z = topi.cast(
-        topi.cast(X, dtype="custom[bfloat]16") +
-        topi.cast(Y, dtype="custom[bfloat]16"),
-        dtype="float")
+        topi.cast(X, dtype="custom[bfloat]16") + topi.cast(Y, dtype="custom[bfloat]16"),
+        dtype="float",
+    )
 
     s = te.create_schedule([Z.op])
-    built_cast = lower_datatypes_and_build(s, [X,Y,Z])
+    built_cast = lower_datatypes_and_build(s, [X, Y, Z])
 
     ctx = tvm.context(tgt, 0)
 
@@ -74,13 +86,9 @@ def test_bfloat_add_and_cast_1():
     # with at most 7-bit mantissas which, when added, produce a result with at
     # most 7-bit mantissas. This is to ensure there are no errors due to
     # float32->bfloat16 conversions.
-    x = tvm.nd.array(
-        np.array([4.4103796E-32, 14942208.0, 1.78125]).astype("float32"),
-        ctx=ctx)
-    y = tvm.nd.array(
-        np.array([-3.330669E-14, 19660800.0, 2.25]).astype("float32"), ctx=ctx)
-    z_expected = np.array([-3.330669E-14, 34603008.0,
-                           4.03125]).astype("float32")
+    x = tvm.nd.array(np.array([4.4103796e-32, 14942208.0, 1.78125]).astype("float32"), ctx=ctx)
+    y = tvm.nd.array(np.array([-3.330669e-14, 19660800.0, 2.25]).astype("float32"), ctx=ctx)
+    z_expected = np.array([-3.330669e-14, 34603008.0, 4.03125]).astype("float32")
     z = tvm.nd.empty(Z.shape, dtype=Z.dtype, ctx=ctx)
 
     built_cast(x, y, z)
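As a cross-check of the comment above about 7-bit mantissas: bfloat16 keeps only the top 16 bits of the float32 encoding, so an input survives the float32 -> bfloat16 cast exactly when the low 16 bits of its bit pattern are zero. A minimal sketch verifying this for the inputs used above (plain numpy, independent of the test itself):

    import numpy as np

    # Inputs of test_bfloat_add_and_cast_1; all are exactly representable in bfloat16.
    vals = np.array(
        [4.4103796e-32, 14942208.0, 1.78125, -3.330669e-14, 19660800.0, 2.25],
        dtype="float32",
    )
    assert not np.any(vals.view("<u4") & np.uint32(0xFFFF))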
@@ -89,15 +97,15 @@ def test_bfloat_add_and_cast_1():
 
 
 def test_bfloat_add_and_cast_2():
-    X = te.placeholder((3, ), name="X")
-    Y = te.placeholder((3, ), name="Y")
+    X = te.placeholder((3,), name="X")
+    Y = te.placeholder((3,), name="Y")
     Z = topi.cast(
-        topi.cast(X, dtype="custom[bfloat]16") +
-        topi.cast(Y, dtype="custom[bfloat]16"),
-        dtype="float")
+        topi.cast(X, dtype="custom[bfloat]16") + topi.cast(Y, dtype="custom[bfloat]16"),
+        dtype="float",
+    )
 
     s = te.create_schedule([Z.op])
-    built_cast = lower_datatypes_and_build(s, [X,Y,Z])
+    built_cast = lower_datatypes_and_build(s, [X, Y, Z])
 
     ctx = tvm.context(tgt, 0)
 
@@ -108,14 +116,9 @@ def test_bfloat_add_and_cast_2():
     # numbers. To simulate bfloat16 add implemented in mybfloat, I cut off all
     # but 7 bits of the result's mantissa. I then copied that value into
     # z_expected.
-    x = tvm.nd.array(
-        np.array([1.2348297, -1.0298302E25, 1.2034023E-30]).astype("float32"),
-        ctx=ctx)
-    y = tvm.nd.array(
-        np.array([-2.4992788, -9.888288E19, 9.342338E-29]).astype("float32"),
-        ctx=ctx)
-    z_expected = np.array([-1.25, -1.027587E25,
-                           9.426888E-29]).astype("float32")
+    x = tvm.nd.array(np.array([1.2348297, -1.0298302e25, 1.2034023e-30]).astype("float32"), ctx=ctx)
+    y = tvm.nd.array(np.array([-2.4992788, -9.888288e19, 9.342338e-29]).astype("float32"), ctx=ctx)
+    z_expected = np.array([-1.25, -1.027587e25, 9.426888e-29]).astype("float32")
     z = tvm.nd.empty(Z.shape, dtype=Z.dtype, ctx=ctx)
 
     built_cast(x, y, z)
@@ -124,15 +127,14 @@ def test_bfloat_add_and_cast_2():
 
 
 def test_bfloat_add_and_cast_FloatImm():
-    X = te.placeholder((3, ), name="X")
+    X = te.placeholder((3,), name="X")
     Z = topi.cast(
-        topi.add(
-            topi.cast(X, dtype="custom[bfloat]16"),
-            tvm.tir.FloatImm("custom[bfloat]16", 1.5)),
-        dtype="float")
+        topi.add(topi.cast(X, dtype="custom[bfloat]16"), tvm.tir.FloatImm("custom[bfloat]16", 1.5)),
+        dtype="float",
+    )
 
     s = te.create_schedule([Z.op])
-    built_cast = lower_datatypes_and_build(s, [X,Z])
+    built_cast = lower_datatypes_and_build(s, [X, Z])
 
     ctx = tvm.context(tgt, 0)
 
index b34dcad..643043f 100644
@@ -65,8 +65,8 @@ def test_target_string_parse():
 
     assert target.kind.name == "cuda"
     assert target.model == "unknown"
-    assert set(target.keys) == set(['cuda', 'gpu'])
-    assert set(target.libs) == set(['cublas', 'cudnn'])
+    assert set(target.keys) == set(["cuda", "gpu"])
+    assert set(target.libs) == set(["cublas", "cudnn"])
     assert str(target) == str(tvm.target.cuda(options="-libs=cublas,cudnn"))
 
     assert tvm.target.intel_graphics().device_name == "intel_graphics"
@@ -75,8 +75,7 @@ def test_target_string_parse():
 
 
 def test_target_create():
-    targets = [cuda(), rocm(), mali(), intel_graphics(),
-               arm_cpu('rk3399'), vta(), bifrost()]
+    targets = [cuda(), rocm(), mali(), intel_graphics(), arm_cpu("rk3399"), vta(), bifrost()]
     for tgt in targets:
         assert tgt is not None
 
@@ -86,27 +85,26 @@ def test_target_config():
     Test that constructing a target from a dictionary works.
     """
     target_config = {
-        'kind': 'llvm',
-        'keys': ['arm_cpu', 'cpu'],
-        'device': 'arm_cpu',
-        'libs': ['cblas'],
-        'system-lib': True,
-        'mfloat-abi': 'hard',
-        'mattr': ['+neon', '-avx512f'],
+        "kind": "llvm",
+        "keys": ["arm_cpu", "cpu"],
+        "device": "arm_cpu",
+        "libs": ["cblas"],
+        "system-lib": True,
+        "mfloat-abi": "hard",
+        "mattr": ["+neon", "-avx512f"],
     }
     # Convert config dictionary to json string.
     target_config_str = json.dumps(target_config)
     # Test both dictionary input and json string.
     for config in [target_config, target_config_str]:
         target = tvm.target.Target(config)
-        assert target.kind.name == 'llvm'
-        assert all([key in target.keys for key in ['arm_cpu', 'cpu']])
-        assert target.device_name == 'arm_cpu'
-        assert target.libs == ['cblas']
-        assert 'system-lib' in str(target)
-        assert target.attrs['mfloat-abi'] == 'hard'
-        assert all([attr in target.attrs['mattr']
-                    for attr in ['+neon', '-avx512f']])
+        assert target.kind.name == "llvm"
+        assert all([key in target.keys for key in ["arm_cpu", "cpu"]])
+        assert target.device_name == "arm_cpu"
+        assert target.libs == ["cblas"]
+        assert "system-lib" in str(target)
+        assert target.attrs["mfloat-abi"] == "hard"
+        assert all([attr in target.attrs["mattr"] for attr in ["+neon", "-avx512f"]])
 
 
 def test_config_map():
@@ -114,10 +112,7 @@ def test_config_map():
     Confirm that constructing a target with invalid
     attributes fails as expected.
     """
-    target_config = {
-        'kind': 'llvm',
-        'libs': {'a': 'b', 'c': 'd'}
-    }
+    target_config = {"kind": "llvm", "libs": {"a": "b", "c": "d"}}
     failed = False
     try:
         tvm.target.Target(target_config)
@@ -127,8 +122,7 @@ def test_config_map():
 
 
 def test_composite_target():
-    tgt = tvm.target.Target(
-        "composite --target_host=llvm --devices=cuda,opencl")
+    tgt = tvm.target.Target("composite --target_host=llvm --devices=cuda,opencl")
     assert tgt.kind.name == "composite"
     assert tgt.attrs["target_host"].kind.name == "llvm"
     assert len(tgt.attrs["devices"]) == 2
index 5bebf3d..7b591dc 100644
@@ -25,7 +25,9 @@ import pytest
 import numpy as np
 
 
-def check_grad(out, inputs, args=[], data_range=(-10, 10), desired_grads=None, assert_no_jacobian=True):
+def check_grad(
+    out, inputs, args=[], data_range=(-10, 10), desired_grads=None, assert_no_jacobian=True
+):
     inputs = inputs if isinstance(inputs, list) else [inputs]
 
     def check_device(device, host="llvm"):
@@ -38,12 +40,16 @@ def check_grad(out, inputs, args=[], data_range=(-10, 10), desired_grads=None, a
         out_shape = get_const_tuple(out.shape)
 
         l, h = data_range
-        input_data = [tvm.nd.array(
-            np.random.uniform(l, h, size=get_const_tuple(input.shape)).astype(input.dtype))
-            for input in inputs]
-        arg_vals = [tvm.nd.array(
-            np.random.uniform(l, h, size=get_const_tuple(arg.shape)).astype(arg.dtype))
-            for arg in args]
+        input_data = [
+            tvm.nd.array(
+                np.random.uniform(l, h, size=get_const_tuple(input.shape)).astype(input.dtype)
+            )
+            for input in inputs
+        ]
+        arg_vals = [
+            tvm.nd.array(np.random.uniform(l, h, size=get_const_tuple(arg.shape)).astype(arg.dtype))
+            for arg in args
+        ]
 
         ones = topi.full_like(out, 1.0)
         # we provide head to sum and reduce the output dimension,
@@ -56,8 +62,7 @@ def check_grad(out, inputs, args=[], data_range=(-10, 10), desired_grads=None, a
             lowered_ir = str(tvm.lower(grad_sched, list(grads) + inputs + args, simple_mode=True))
             assert "jacobian" not in lowered_ir, lowered_ir
 
-        grad_data = [tvm.nd.empty(get_const_tuple(i.shape), g.dtype)
-                     for i, g in zip(inputs, grads)]
+        grad_data = [tvm.nd.empty(get_const_tuple(i.shape), g.dtype) for i, g in zip(inputs, grads)]
 
         mgrad(*grad_data, *input_data, *arg_vals)
         g_res = [g.asnumpy() for g in grad_data]
@@ -67,11 +72,15 @@ def check_grad(out, inputs, args=[], data_range=(-10, 10), desired_grads=None, a
             for actual, desired in zip(g_res, desired_grads):
                 assert_allclose(actual, desired, rtol=0.1, atol=1e-2)
         else:
+
             def forward(*in_data):
                 out_data = tvm.nd.empty(out_shape, out.dtype)
                 mout(out_data, *[tvm.nd.array(d) for d in list(in_data)])
                 return out_data.asnumpy().sum()
-            tvm.testing.check_numerical_grads(forward, [d.asnumpy() for d in input_data + arg_vals], g_res)
+
+            tvm.testing.check_numerical_grads(
+                forward, [d.asnumpy() for d in input_data + arg_vals], g_res
+            )
 
     check_device("cpu")
 
@@ -79,83 +88,83 @@ def check_grad(out, inputs, args=[], data_range=(-10, 10), desired_grads=None, a
 def test_basic_operation():
     np.random.seed(0)
     shape = (10, 10)
-    x = te.var("x", dtype='float32')
+    x = te.var("x", dtype="float32")
     k = te.reduce_axis((0, 10), name="k")
     l = te.reduce_axis((0, 10), name="l")
-    A0 = te.placeholder(shape, name='A0')
-    A1 = te.placeholder(shape, name='A1')
+    A0 = te.placeholder(shape, name="A0")
+    A1 = te.placeholder(shape, name="A1")
     zeros = np.zeros(shape)
 
-    B = te.compute(shape, lambda i, j: A0[i, j], name='B')
+    B = te.compute(shape, lambda i, j: A0[i, j], name="B")
     check_grad(B, [A0])
 
-    B = te.compute(shape, lambda i, j: A0[i, j] + A1[i, j], name='B')
+    B = te.compute(shape, lambda i, j: A0[i, j] + A1[i, j], name="B")
     check_grad(B, [A0, A1])
 
-    B = te.compute(shape, lambda i, j: A0[i, j] + A0[j, i], name='B')
+    B = te.compute(shape, lambda i, j: A0[i, j] + A0[j, i], name="B")
     check_grad(B, A0)
 
-    B = te.compute(shape, lambda i, j: te.floor(A0[i, j]), name='B')
+    B = te.compute(shape, lambda i, j: te.floor(A0[i, j]), name="B")
     check_grad(B, A0, desired_grads=[zeros])
 
-    B = te.compute(shape, lambda i, j: te.ceil(A0[i, j]), name='B')
+    B = te.compute(shape, lambda i, j: te.ceil(A0[i, j]), name="B")
     check_grad(B, A0, desired_grads=[zeros])
 
-    B = te.compute(shape, lambda i, j: te.trunc(A0[i, j]), name='B')
+    B = te.compute(shape, lambda i, j: te.trunc(A0[i, j]), name="B")
     check_grad(B, A0, desired_grads=[zeros])
 
-    B = te.compute(shape, lambda i, j: te.round(A0[i, j]), name='B')
+    B = te.compute(shape, lambda i, j: te.round(A0[i, j]), name="B")
     check_grad(B, A0, desired_grads=[zeros])
 
-    B = te.compute(shape, lambda i, j: A0[i, j] + te.exp(A0[j, i]), name='B')
+    B = te.compute(shape, lambda i, j: A0[i, j] + te.exp(A0[j, i]), name="B")
     check_grad(B, A0)
 
-    B = te.compute(shape, lambda i, j: te.log(0.1 + te.abs(A0[i, j] + te.exp(A0[j, i]))), name='B')
+    B = te.compute(shape, lambda i, j: te.log(0.1 + te.abs(A0[i, j] + te.exp(A0[j, i]))), name="B")
     check_grad(B, A0)
 
-    B = te.compute(shape, lambda i, j: te.sigmoid(A0[i, j]*A0[i, j]*A0[j, i]), name='B')
+    B = te.compute(shape, lambda i, j: te.sigmoid(A0[i, j] * A0[i, j] * A0[j, i]), name="B")
     check_grad(B, A0)
 
-    B = te.compute(shape, lambda i, j: te.tanh(A0[i, j]*A0[i, j]*A0[j, i]), name='B')
+    B = te.compute(shape, lambda i, j: te.tanh(A0[i, j] * A0[i, j] * A0[j, i]), name="B")
     check_grad(B, A0)
 
-    B = te.compute(shape, lambda i, j: te.sqrt(A0[i, j]*A0[i, j]*A0[j, i]), name='B')
+    B = te.compute(shape, lambda i, j: te.sqrt(A0[i, j] * A0[i, j] * A0[j, i]), name="B")
     check_grad(B, A0, data_range=(0.1, 10))
 
-    B = te.compute(shape, lambda i, j: te.power(te.abs(A0[i, j]), A0[j, i]), name='B')
+    B = te.compute(shape, lambda i, j: te.power(te.abs(A0[i, j]), A0[j, i]), name="B")
     check_grad(B, A0, data_range=(-4, 4))
 
-    B = te.compute(shape, lambda i, j: A0[i, j] * A0[j, i], name='B')
+    B = te.compute(shape, lambda i, j: A0[i, j] * A0[j, i], name="B")
     check_grad(B, A0)
 
-    B = te.compute((10,), lambda i: te.sum(A0[i, k]*A0[k, i], axis=k), name='B')
+    B = te.compute((10,), lambda i: te.sum(A0[i, k] * A0[k, i], axis=k), name="B")
     check_grad(B, A0)
 
-    B = te.compute(shape, lambda i, j: te.sum(A0[i, k]*A0[k, i] + 5, axis=k), name='B')
+    B = te.compute(shape, lambda i, j: te.sum(A0[i, k] * A0[k, i] + 5, axis=k), name="B")
     check_grad(B, A0)
 
-    B = te.compute(shape, lambda i, j: te.max(A0[i, k]*A0[k, j] + 5, axis=k), name='B')
+    B = te.compute(shape, lambda i, j: te.max(A0[i, k] * A0[k, j] + 5, axis=k), name="B")
     check_grad(B, A0)
 
-    B = te.compute(shape, lambda i, j: A0[i, j] * (A1[j, i] + A0[j, i]), name='B')
+    B = te.compute(shape, lambda i, j: A0[i, j] * (A1[j, i] + A0[j, i]), name="B")
     check_grad(B, [A0, A1])
 
-    B = te.compute(shape, lambda i, j: te.sum(A0[k, k] -
-                                              A0[te.min(j + k, 9), j]*A0[i, k],
-                                              axis=k), name='B')
+    B = te.compute(
+        shape, lambda i, j: te.sum(A0[k, k] - A0[te.min(j + k, 9), j] * A0[i, k], axis=k), name="B"
+    )
     check_grad(B, A0)
 
     def fcombine(x, y):
-        return x*y
+        return x * y
 
     def fidentity(t0):
         return tvm.tir.const(1, t0)
 
-    prod = te.comm_reducer(fcombine, fidentity, name='prod')
-    B = te.compute((10, 10), lambda i, j: prod(A0[i, k] + A0[k, i], axis=k), name='B')
+    prod = te.comm_reducer(fcombine, fidentity, name="prod")
+    B = te.compute((10, 10), lambda i, j: prod(A0[i, k] + A0[k, i], axis=k), name="B")
     check_grad(B, A0)
 
-    X = te.placeholder((10,), name='X')
+    X = te.placeholder((10,), name="X")
     A = te.compute((10,), lambda i: X[i] + X[9 - i])
     B = te.compute((10,), lambda i: X[i] * X[9 - i])
     Y = topi.tensordot(A, B, 1)
@@ -163,10 +172,10 @@ def test_basic_operation():
 
 
 def test_topi():
-    X = te.placeholder((1, 2, 4, 4), name='X')
-    W = te.placeholder((5, 2, 3, 3), name='W')
-    W1 = te.placeholder((2, 5, 3, 3), name='W1')
-    W2 = te.placeholder((1,), name='W2')
+    X = te.placeholder((1, 2, 4, 4), name="X")
+    W = te.placeholder((5, 2, 3, 3), name="W")
+    W1 = te.placeholder((2, 5, 3, 3), name="W1")
+    W2 = te.placeholder((1,), name="W2")
 
     R = topi.nn.conv2d(X, W, 1, 1, 1)
     check_grad(R, [X, W])
@@ -180,18 +189,18 @@ def test_topi():
     R = topi.nn.conv2d(X, topi.broadcast_to(W2, (5, 2, 3, 3)), 1, 1, 1)
     check_grad(R, [X, W2])
 
-    R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'avg')
+    R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], "avg")
     check_grad(R, X)
 
-    R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'max')
+    R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], "max")
     check_grad(R, X)
 
-    X = te.placeholder((1, 2, 5, 5), name='X')
+    X = te.placeholder((1, 2, 5, 5), name="X")
     R = topi.reshape(X, (1, 32))
     check_grad(R, [X])
 
-    X = te.placeholder((1, 2, 5, 5), name='X')
-    W = te.placeholder((2, 2, 3, 3), name='W')
+    X = te.placeholder((1, 2, 5, 5), name="X")
+    W = te.placeholder((2, 2, 3, 3), name="W")
 
     S = topi.reshape(X, (1, 50))
     check_grad(S, [X])
@@ -212,12 +221,12 @@ def test_topi():
     check_grad(S, [X, W])
     check_grad(S, [W], [X])
 
-    X = te.placeholder((1, 2, 3, 5), name='X')
-    Y = te.placeholder((1, 2, 7, 5), name='Y')
+    X = te.placeholder((1, 2, 3, 5), name="X")
+    Y = te.placeholder((1, 2, 7, 5), name="Y")
     S = topi.concatenate((X, Y), 2)
     check_grad(S, [X, Y])
 
-    X = te.placeholder((1, 2, 6, 5), name='X')
+    X = te.placeholder((1, 2, 6, 5), name="X")
     (S, R) = topi.split(X, 2, 2)
     check_grad(S, [X])
     check_grad(R, [X])
@@ -226,21 +235,21 @@ def test_topi():
     R2 = topi.concatenate((R, S), 2)
     check_grad(R2, [X])
 
-    X = te.placeholder((4, 5), name='X')
-    I = te.placeholder((100,), name='I', dtype='int32')
+    X = te.placeholder((4, 5), name="X")
+    I = te.placeholder((100,), name="I", dtype="int32")
     R = topi.take(X, topi.abs(I))
     check_grad(R, [X], [I])
 
-    W = te.placeholder((5, 5), name='W')
+    W = te.placeholder((5, 5), name="W")
     exps = topi.exp(topi.nn.dense(X, W))
     sumexps = topi.sum(exps, axis=-1, keepdims=True)
-    R = exps/sumexps
+    R = exps / sumexps
     check_grad(R, [X, W], data_range=(-1, 1))
 
 
 def test_stride_dilation():
-    X = te.placeholder((1, 2, 10, 10), name='X')
-    W = te.placeholder((2, 2, 1, 1), name='W')
+    X = te.placeholder((1, 2, 10, 10), name="X")
+    W = te.placeholder((2, 2, 1, 1), name="W")
 
     Y = topi.nn.conv2d(X, W, 1, 0, 1)
     check_grad(Y, [X, W])
@@ -261,7 +270,7 @@ def test_stride_dilation():
     Y = topi.nn.conv2d(X, W, 3, 0, 3)
     check_grad(Y, [X, W])
 
-    W = te.placeholder((2, 2, 2, 2), name='W')
+    W = te.placeholder((2, 2, 2, 2), name="W")
 
     Y = topi.nn.conv2d(X, W, 1, 0, 1)
     check_grad(Y, [X, W])
@@ -282,7 +291,7 @@ def test_stride_dilation():
     Y = topi.nn.conv2d(X, W, 3, 0, 3)
     check_grad(Y, [X, W])
 
-    W = te.placeholder((2, 2, 3, 3), name='W')
+    W = te.placeholder((2, 2, 3, 3), name="W")
 
     Y = topi.nn.conv2d(X, W, 1, 0, 1)
     check_grad(Y, [X, W])
@@ -303,35 +312,37 @@ def test_stride_dilation():
     Y = topi.nn.conv2d(X, W, 3, 0, 3)
     check_grad(Y, [X, W])
 
-    Y = topi.nn.pool(X, [1, 1], [1, 1], [0, 0, 0, 0], 'max')
+    Y = topi.nn.pool(X, [1, 1], [1, 1], [0, 0, 0, 0], "max")
     check_grad(Y, [X])
-    Y = topi.nn.pool(X, [1, 1], [2, 2], [0, 0, 0, 0], 'max')
+    Y = topi.nn.pool(X, [1, 1], [2, 2], [0, 0, 0, 0], "max")
     check_grad(Y, [X])
-    Y = topi.nn.pool(X, [1, 1], [3, 3], [0, 0, 0, 0], 'max')
+    Y = topi.nn.pool(X, [1, 1], [3, 3], [0, 0, 0, 0], "max")
     check_grad(Y, [X])
-    Y = topi.nn.pool(X, [2, 2], [1, 1], [0, 0, 0, 0], 'max')
+    Y = topi.nn.pool(X, [2, 2], [1, 1], [0, 0, 0, 0], "max")
     check_grad(Y, [X])
-    Y = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'max')
+    Y = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], "max")
     check_grad(Y, [X])
-    Y = topi.nn.pool(X, [2, 2], [3, 3], [0, 0, 0, 0], 'max')
+    Y = topi.nn.pool(X, [2, 2], [3, 3], [0, 0, 0, 0], "max")
     check_grad(Y, [X])
-    Y = topi.nn.pool(X, [3, 3], [1, 1], [0, 0, 0, 0], 'max')
+    Y = topi.nn.pool(X, [3, 3], [1, 1], [0, 0, 0, 0], "max")
     check_grad(Y, [X])
-    Y = topi.nn.pool(X, [3, 3], [2, 2], [0, 0, 0, 0], 'max')
+    Y = topi.nn.pool(X, [3, 3], [2, 2], [0, 0, 0, 0], "max")
     check_grad(Y, [X])
-    Y = topi.nn.pool(X, [3, 3], [3, 3], [0, 0, 0, 0], 'max')
+    Y = topi.nn.pool(X, [3, 3], [3, 3], [0, 0, 0, 0], "max")
     check_grad(Y, [X])
 
+
 @pytest.mark.xfail
 def test_reduction_init():
     np.random.seed(0)
     shape = (10, 10)
     k = te.reduce_axis((0, 10), name="k")
-    A0 = te.placeholder(shape, name='A0')
+    A0 = te.placeholder(shape, name="A0")
 
-    B = te.compute((10,), lambda i: te.sum(A0[i, k]*A0[k, i], axis=k, init=0.0), name='B')
+    B = te.compute((10,), lambda i: te.sum(A0[i, k] * A0[k, i], axis=k, init=0.0), name="B")
     check_grad(B, A0)
 
+
 if __name__ == "__main__":
     test_basic_operation()
     test_topi()
index 1fc2fcd..50d5119 100644
 import tvm
 from tvm import te
 
+
 def test_lower_rfactor():
     n = te.size_var("n")
     m = te.size_var("m")
-    A = te.placeholder((n, m), name='A')
+    A = te.placeholder((n, m), name="A")
     k = te.reduce_axis((0, m), "k")
     B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
     s = te.create_schedule(B.op)
@@ -33,16 +34,22 @@ def test_lower_rfactor():
     s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
     fapi = tvm.lower(s, [A, B])
 
+
 def test_dependent_output_shape():
-    n, m, x = te.size_var('n'), te.size_var('m'), te.size_var('x')
+    n, m, x = te.size_var("n"), te.size_var("m"), te.size_var("x")
     A = te.placeholder((n, m))
-    B = te.compute((m, n//x), lambda i, j: A[i,j] , name='B')
+    B = te.compute((m, n // x), lambda i, j: A[i, j], name="B")
     s = te.create_schedule(B.op)
     mod = tvm.build(s, [A, B, x])
 
+
 def test_split_uneven_unique_likely():
-    a = te.placeholder((16, 16),)
-    b = te.placeholder((16, 16),)
+    a = te.placeholder(
+        (16, 16),
+    )
+    b = te.placeholder(
+        (16, 16),
+    )
     c = te.compute((16, 16), lambda x, y: a[x, y] + b[x, y])
 
     x, y = c.op.axis
index 0f1118d..e57040a 100644
@@ -18,6 +18,7 @@
 import tvm
 from tvm import te
 
+
 def test_scan_group():
     m = te.size_var("m")
     n = te.size_var("n")
@@ -25,7 +26,7 @@ def test_scan_group():
     s_state = te.placeholder((m, n))
     s_init = te.compute((1, n), lambda _, i: x[0, i])
 
-    s_update1 = te.compute((m, n), lambda t, i: s_state[t-1, i] + x[t, i])
+    s_update1 = te.compute((m, n), lambda t, i: s_state[t - 1, i] + x[t, i])
     s_update2 = te.compute((m, n), lambda t, i: s_update1[t, i] + 1)
     s_update3 = te.compute((m, n), lambda t, i: s_update2[t, i] + 1)
     res = tvm.te.scan(s_init, s_update3, s_state, inputs=x)
@@ -50,6 +51,7 @@ def test_scan_group():
     except tvm.error.TVMError:
         pass
 
+
 def test_compute_group():
     m = te.size_var("m")
     n = te.size_var("n")
@@ -64,6 +66,7 @@ def test_compute_group():
     assert g.attach_stage == s[x2]
     assert g.num_child_stages == 2
 
+
 def test_nest_group():
     m = te.size_var("m")
     n = te.size_var("n")
@@ -80,6 +83,7 @@ def test_nest_group():
     assert g2.num_child_stages == 2
     assert g1.num_child_stages == 1
 
+
 if __name__ == "__main__":
     test_nest_group()
     test_compute_group()
index 94ec355..3afdb66 100644
@@ -23,8 +23,9 @@ from tvm.te.hybrid.runtime import HYBRID_GLOBALS
 
 import tvm.testing
 
+
 @pytest.mark.skip
-def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
+def run_and_check(func, args, var_dict={}, target="llvm", sch=None, outs=None):
     def tvm_val_2_py_val(val):
         val = tvm.tir.stmt_functor.substitute(val, var_dict)
         val = tvm.arith.Analyzer().simplify(val)
@@ -57,11 +58,10 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
             assert isinstance(i, list)
             emu_args.append(numpy.array(i))
 
-    compile_args = [i for i in args if isinstance(i, (te.tensor.Tensor, tvm.tir.Var))] + \
-                   (outs if isinstance(outs, list) else [outs])
-    module = tvm.build(sch,
-                       compile_args,
-                       target=target)
+    compile_args = [i for i in args if isinstance(i, (te.tensor.Tensor, tvm.tir.Var))] + (
+        outs if isinstance(outs, list) else [outs]
+    )
+    module = tvm.build(sch, compile_args, target=target)
     assert module
 
     out_tensors = []
@@ -86,6 +86,7 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
 
     return h_module, module_args, module_outs
 
+
 @script
 def outer_product(n, m, a, b):
     """This is a simple outer product.
@@ -99,33 +100,34 @@ def outer_product(n, m, a, b):
             c[i, j] = a[i] * b[j]
     return c
 
-#Test global function
-#Test bridge between frontend and backend
+
+# Test global function
+# Test bridge between frontend and backend
 def test_outer_product():
-    n = te.size_var('n')
-    m = te.size_var('m')
-    a = te.placeholder((n, ), name='a')
-    b = te.placeholder((m, ), name='b')
+    n = te.size_var("n")
+    m = te.size_var("m")
+    a = te.placeholder((n,), name="a")
+    b = te.placeholder((m,), name="b")
 
     try:
         c = outer_product(n, m, a, b)
         ir = c.op.body
     except IOError as err:
-        assert sys.version_info[0] == 2 and str(err) == 'could not get source code'
+        assert sys.version_info[0] == 2 and str(err) == "could not get source code"
         return
 
-    #Check for i in (0, n)
+    # Check for i in (0, n)
     assert isinstance(ir, tvm.tir.For)
-    assert ir.loop_var.name == 'i'
+    assert ir.loop_var.name == "i"
     assert ir.min.value == 0
-    assert ir.extent.name == 'n'
+    assert ir.extent.name == "n"
     ibody = ir.body
     assert isinstance(ibody, tvm.tir.For)
-    #Check for j in (0, m)
-    assert ibody.loop_var.name == 'j'
+    # Check for j in (0, m)
+    assert ibody.loop_var.name == "j"
     assert ibody.min.value == 0
-    assert ibody.extent.name == 'm'
-    #Check loop body
+    assert ibody.extent.name == "m"
+    # Check loop body
     jblock = ibody.body
     assert isinstance(jblock, tvm.tir.SeqStmt)
     jbody = jblock[0]
@@ -134,19 +136,19 @@ def test_outer_product():
     assert jbody.message.value == "index out of range!"
     jbody = jblock[1]
     assert isinstance(jbody, tvm.tir.ProducerStore)
-    assert jbody.producer.op.name == 'c'
+    assert jbody.producer.op.name == "c"
     assert len(jbody.indices) == 2
-    assert jbody.indices[0].name == 'i'
-    assert jbody.indices[1].name == 'j'
+    assert jbody.indices[0].name == "i"
+    assert jbody.indices[1].name == "j"
     assert isinstance(jbody.value, tvm.tir.Mul)
     mul = jbody.value
     assert isinstance(mul.a, tvm.tir.ProducerLoad)
-    assert mul.a.producer.name == 'a'
-    assert mul.b.producer.name == 'b'
+    assert mul.a.producer.name == "a"
+    assert mul.b.producer.name == "b"
 
     func, ins, outs = run_and_check(outer_product, [n, m, a, b], {n: 99, m: 101})
     temp = util.tempdir()
-    path = temp.relpath('%s.py' % func.name)
+    path = temp.relpath("%s.py" % func.name)
     func.save(path)
     func_ = te.hybrid.HybridModule()
     func_.load(path)
@@ -156,13 +158,14 @@ def test_outer_product():
         assert key not in globals().keys()
         assert key not in outer_product.__globals__.keys()
 
-#Test local function
-#Test allocation of local variable
+
+# Test local function
+# Test allocation of local variable
 def test_fanout():
     @script
     def fanout(n, a):
         three = 3.0
-        b = output_tensor((a.shape[0] - 3, ), a.dtype)
+        b = output_tensor((a.shape[0] - 3,), a.dtype)
         for i in range(a.shape[0] - 3):
             sigma = 0.0
             for j in range(3):
@@ -171,67 +174,67 @@ def test_fanout():
             b[i] = sigma
         return b
 
-    n = te.size_var('n')
-    a = te.placeholder((n, ), 'float32', name='a')
+    n = te.size_var("n")
+    a = te.placeholder((n,), "float32", name="a")
     try:
         b = fanout(n, a)
         ir = b.op.body
     except IOError as err:
-        assert sys.version_info[0] == 2 and str(err) == 'could not get source code'
+        assert sys.version_info[0] == 2 and str(err) == "could not get source code"
         return
 
-    #Check for i in (0, n-3)
+    # Check for i in (0, n-3)
     assert isinstance(ir, tvm.tir.For)
-    assert ir.loop_var.name == 'i'
+    assert ir.loop_var.name == "i"
     assert ir.min.value == 0
     assert tvm.ir.structural_equal(ir.extent, n - 3)
-    #Check loopbody
+    # Check loop body
     ibody = ir.body
     assert isinstance(ibody, tvm.tir.AttrStmt)
     abody = ibody.body
     assert isinstance(abody, tvm.tir.ProducerRealize)
     assert abody.bounds[0].min.value == 0
     assert abody.bounds[0].extent.value == 1
-    assert abody.producer.op.name == 'sigma'
-    #Check i loop body
+    assert abody.producer.op.name == "sigma"
+    # Check i loop body
     rbody = abody.body
     assert isinstance(rbody[0], tvm.tir.ProducerStore)
-    assert rbody[0].producer.op.name == 'sigma'
+    assert rbody[0].producer.op.name == "sigma"
     assert len(rbody[0].indices) == 1
     assert rbody[0].indices[0].value == 0
-    #Check fanout loop
+    # Check fanout loop
     jloop = rbody[1]
-    assert jloop.loop_var.name == 'j'
+    assert jloop.loop_var.name == "j"
     assert jloop.min.value == 0
     assert jloop.extent.value == 3
     jbody = jloop.body
     assert isinstance(jbody, tvm.tir.ProducerStore)
     assert len(jbody.indices) == 1
     assert jbody.indices[0].value == 0
-    assert jbody.producer.op.name == 'sigma'
+    assert jbody.producer.op.name == "sigma"
     assert isinstance(jbody.value, tvm.tir.Add)
     value = jbody.value
     assert isinstance(value.a, tvm.tir.ProducerLoad)
-    assert value.a.producer.name == 'sigma'
+    assert value.a.producer.name == "sigma"
     assert len(value.a.indices) == 1
     assert value.a.indices[0].value == 0
-    assert value.b.producer.name == 'a'
+    assert value.b.producer.name == "a"
     assert len(value.b.indices) == 1
     assert tvm.ir.structural_equal(value.b.indices[0], ir.loop_var + jloop.loop_var)
-    divide= rbody[2]
+    divide = rbody[2]
     assert isinstance(divide, tvm.tir.ProducerStore)
     assert len(divide.indices) == 1
     assert divide.indices[0].value == 0
     value = divide.value
     assert isinstance(value, tvm.tir.Mul)
-    assert value.a.producer.name == 'sigma'
+    assert value.a.producer.name == "sigma"
     assert len(value.a.indices) == 1
     assert value.a.indices[0].value == 0
     assert abs(value.b.value - (1 / 3.0)) < 1e-5
     write = rbody[3]
     assert isinstance(write, tvm.tir.ProducerStore)
-    assert write.producer.op.name == 'b'
-    assert write.value.producer.name == 'sigma'
+    assert write.producer.op.name == "b"
+    assert write.value.producer.name == "sigma"
     assert len(write.value.indices) == 1
     assert write.value.indices[0].value == 0
 
@@ -242,9 +245,9 @@ def test_fanout():
 def test_looptype():
     @script
     def looptype(a, b, c):
-        d = output_tensor((16, ), 'int32')
-        e = output_tensor((16, ), 'int32')
-        f = output_tensor((16, ), 'int32')
+        d = output_tensor((16,), "int32")
+        e = output_tensor((16,), "int32")
+        f = output_tensor((16,), "int32")
         for i in parallel(16):
             d[i] = a[i]
         for j in vectorize(16):
@@ -253,9 +256,9 @@ def test_looptype():
             f[k] = c[k]
         return d, e, f
 
-    a = te.placeholder((16, ), name='a', dtype='int32')
-    b = te.placeholder((16, ), name='b', dtype='int32')
-    c = te.placeholder((16, ), name='c', dtype='int32')
+    a = te.placeholder((16,), name="a", dtype="int32")
+    b = te.placeholder((16,), name="b", dtype="int32")
+    c = te.placeholder((16,), name="c", dtype="int32")
     try:
         d, e, f = looptype(a, b, c)
         ir = d.op.body
@@ -275,8 +278,8 @@ def test_looptype():
 def test_if():
     @script
     def if_then_else(a):
-        b = output_tensor((10, ), 'int32')
-        c = output_tensor((10, ), 'int32')
+        b = output_tensor((10,), "int32")
+        c = output_tensor((10,), "int32")
         for i in range(10):
             if i % 2 == 0:
                 c[i] = a[i]
@@ -286,14 +289,14 @@ def test_if():
             b[i] = -1 if i % 2 == 0 else 1
         return b, c
 
-    a = te.placeholder((10, ), dtype='int32', name='a')
+    a = te.placeholder((10,), dtype="int32", name="a")
 
     func, ins, outs = run_and_check(if_then_else, [a])
     run_and_check(func, ins, outs=outs)
 
     @script
     def if_triple_condition(a):
-        b = output_tensor((10, ), 'int32')
+        b = output_tensor((10,), "int32")
         for i in range(10):
             if 0 <= i < 5:
                 b[i] = a[i]
@@ -306,7 +309,7 @@ def test_if():
 
     @script
     def if_and(a):
-        b = output_tensor((10, ), 'int32')
+        b = output_tensor((10,), "int32")
         for i in range(10):
             if i >= 0 and i < 5:
                 b[i] = a[i]
@@ -323,74 +326,73 @@ def test_if():
 def test_bind():
     @script
     def vec_add(a, b):
-        c = output_tensor((1000, ), 'float32')
-        for tx in bind('threadIdx.x', 1000):
+        c = output_tensor((1000,), "float32")
+        for tx in bind("threadIdx.x", 1000):
             c[tx] = a[tx] + b[tx]
         return c
 
-    a = te.placeholder((1000, ), dtype='float32', name='a')
-    b = te.placeholder((1000, ), dtype='float32', name='b')
-    func, ins, outs = run_and_check(vec_add, [a, b], target='cuda')
-    run_and_check(func, ins, outs=outs, target='cuda')
+    a = te.placeholder((1000,), dtype="float32", name="a")
+    b = te.placeholder((1000,), dtype="float32", name="b")
+    func, ins, outs = run_and_check(vec_add, [a, b], target="cuda")
+    run_and_check(func, ins, outs=outs, target="cuda")
 
     @script
     def raw(a, b):
-        c = output_tensor((1000, ), 'float32')
+        c = output_tensor((1000,), "float32")
         for i in range(1000):
             c[i] = a[i] + b[i]
         return c
 
     c = raw(a, b)
     sch = te.create_schedule(c.op)
-    x = te.thread_axis('threadIdx.x')
+    x = te.thread_axis("threadIdx.x")
     sch[c].bind(c.op.axis[0], x)
-    func, ins, outs = run_and_check(raw, [a, b], sch=sch, outs=[c], target='cuda')
-    run_and_check(func, ins, outs=outs, target='cuda')
-
+    func, ins, outs = run_and_check(raw, [a, b], sch=sch, outs=[c], target="cuda")
+    run_and_check(func, ins, outs=outs, target="cuda")
 
     @te.hybrid.script
     def foo(a):
         c = output_tensor((a.shape[0],), a.dtype)
-        total = allocate((1,), a.dtype, 'local')
+        total = allocate((1,), a.dtype, "local")
         len_i = a.shape[0]
         len_j = a.shape[1]
-        for i in bind('threadIdx.x', len_i):
-            total[0] = 0.
+        for i in bind("threadIdx.x", len_i):
+            total[0] = 0.0
             for k in const_range(len_j):
                 total[0] += a[i, k]
             c[i] = total[0]
 
         return c
 
-    a = te.placeholder((8, 4), 'float32')
+    a = te.placeholder((8, 4), "float32")
     c = foo(a)
     s = te.create_schedule(c.op)
     ir = tvm.lower(s, [a, c])
 
-    func, ins, outs = run_and_check(foo, [a], target='cuda')
-    run_and_check(func, ins, outs=outs, target='cuda')
+    func, ins, outs = run_and_check(foo, [a], target="cuda")
+    run_and_check(func, ins, outs=outs, target="cuda")
 
     @te.hybrid.script
     def max_threads(a):
         b = output_tensor(a.shape, a.dtype)
         n = a.shape[0]
         m = max_num_threads(True)
-        for i in bind('threadIdx.x', m):
-            for j in bind('blockIdx.x', ceil_div(n, m)):
+        for i in bind("threadIdx.x", m):
+            for j in bind("blockIdx.x", ceil_div(n, m)):
                 if i * m + j < n:
                     b[i * m + j] = a[i * m + j] + a[i * m + j]
         return b
 
-    a = te.placeholder((10000, ), 'float32')
-    with tvm.target.Target('cuda'):
-        func, ins, outs = run_and_check(max_threads, [a], target='cuda')
-        run_and_check(func, ins, outs=outs, target='cuda')
+    a = te.placeholder((10000,), "float32")
+    with tvm.target.Target("cuda"):
+        func, ins, outs = run_and_check(max_threads, [a], target="cuda")
+        run_and_check(func, ins, outs=outs, target="cuda")
 
 
 def test_math_intrin():
     @script
     def intrin_real(a):
-        b = output_tensor((8, ), 'float32')
+        b = output_tensor((8,), "float32")
         b[0] = sqrt(a[0])
         b[1] = log(a[1])
         b[2] = exp(a[2])
@@ -401,84 +403,86 @@ def test_math_intrin():
         b[7] = max(a[5], a[6])
         return b
 
-    a8 = te.placeholder((8, ), dtype='float32', name='a')
+    a8 = te.placeholder((8,), dtype="float32", name="a")
     b8 = intrin_real(a8)
     sch = te.create_schedule(b8.op)
     func = tvm.build(sch, [a8, b8])
     assert func
-    a = numpy.arange(2, 10).astype('float32')
+    a = numpy.arange(2, 10).astype("float32")
     tvm_a = tvm.nd.array(a)
-    tvm_b = tvm.nd.array(numpy.zeros((8, ), dtype='float32'))
+    tvm_b = tvm.nd.array(numpy.zeros((8,), dtype="float32"))
     b = intrin_real(a)
     func(tvm_a, tvm_b)
     tvm.testing.assert_allclose(b, tvm_b.asnumpy(), rtol=1e-5)
 
     @script
     def intrin_int(a):
-        b = output_tensor((1, ), 'int32')
+        b = output_tensor((1,), "int32")
         b[0] = popcount(a[0])
         return b
 
-    a1 = te.placeholder((1, ), dtype='int32')
+    a1 = te.placeholder((1,), dtype="int32")
     b1 = intrin_int(a1)
     sch = te.create_schedule(b1.op)
     func = tvm.build(sch, [a1, b1])
     assert func
-    a = numpy.array([114514]).astype('int32')
+    a = numpy.array([114514]).astype("int32")
     tvm_a = tvm.nd.array(a)
-    tvm_b = tvm.nd.array(numpy.array([0]).astype('int32'))
+    tvm_b = tvm.nd.array(numpy.array([0]).astype("int32"))
     b = intrin_int(a)
     func(tvm_a, tvm_b)
     assert tvm_b.asnumpy()[0] == b[0]
 
+
 # test non canonical loops
 def test_non_zero():
     @te.hybrid.script
     def blur(a):
-        b = output_tensor((30, 30), 'float32')
+        b = output_tensor((30, 30), "float32")
         for i in range(2, 32):
             for j in range(2, 32):
                 s = 0.0
                 for di in range(3):
                     for dj in range(3):
-                        s += a[i-di, j-dj]
-                b[i-2, j-2] = s / 9.0
+                        s += a[i - di, j - dj]
+                b[i - 2, j - 2] = s / 9.0
         return b
 
-    a = te.placeholder((32, 32), 'float32', 'a')
+    a = te.placeholder((32, 32), "float32", "a")
     func, ins, outs = run_and_check(blur, [a])
     run_and_check(func, ins, outs=outs)
 
     @te.hybrid.script
     def triangle(a, b):
-        c = output_tensor((10, 10), dtype='float32')
+        c = output_tensor((10, 10), dtype="float32")
         for i in range(10):
             for j in range(i, 10):
                 c[i, j] = a[i] * b[j]
         return c
 
-    a = te.placeholder((10, ), dtype='float32', name='a')
-    b = te.placeholder((10, ), dtype='float32', name='b')
+    a = te.placeholder((10,), dtype="float32", name="a")
+    b = te.placeholder((10,), dtype="float32", name="b")
 
     func, ins, outs = run_and_check(triangle, [a, b])
     run_and_check(func, ins, outs=outs)
 
+
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
 def test_allocate():
     @te.hybrid.script
     def blur2d(a):
-        b = output_tensor((30, 30), 'float32')
+        b = output_tensor((30, 30), "float32")
         for i in range(30):
-            ha = allocate((3, 30), 'float32')
+            ha = allocate((3, 30), "float32")
             for j in range(3):
                 for k in range(30):
-                    ha[j, k] = a[i+j, k] + a[i+j, k+1] + a[i+j, k+2]
+                    ha[j, k] = a[i + j, k] + a[i + j, k + 1] + a[i + j, k + 2]
             for j in range(30):
                 b[i, j] = (ha[0, j] + ha[1, j] + ha[2, j]) / 9.0
         return b
 
-    a = te.placeholder((32, 32), 'float32', 'a')
+    a = te.placeholder((32, 32), "float32", "a")
     b = blur2d(a)
     sch = te.create_schedule(b.op)
     func, ins, outs = run_and_check(blur2d, [a])
@@ -486,111 +490,114 @@ def test_allocate():
 
     @te.hybrid.script
     def share_vec_add(a, b):
-        c = output_tensor((256, ), 'float32')
-        shared = allocate((256, ), 'float32', 'shared')
+        c = output_tensor((256,), "float32")
+        shared = allocate((256,), "float32", "shared")
         for i in bind("threadIdx.x", 256):
             shared[i] = a[i]
-        local = allocate((256, ), 'float32', 'local')
+        local = allocate((256,), "float32", "local")
         for i in bind("threadIdx.x", 256):
             local[i] = b[i]
         for i in bind("threadIdx.x", 256):
             c[i] = shared[i] + local[i]
         return c
 
-    a = te.placeholder((256, ), dtype='float32', name='a')
-    b = te.placeholder((256, ), dtype='float32', name='b')
+    a = te.placeholder((256,), dtype="float32", name="a")
+    b = te.placeholder((256,), dtype="float32", name="b")
     c = share_vec_add(a, b)
-    func, ins, outs = run_and_check(share_vec_add, [a, b], target='cuda')
-    run_and_check(func, ins, outs=outs, target='cuda')
+    func, ins, outs = run_and_check(share_vec_add, [a, b], target="cuda")
+    run_and_check(func, ins, outs=outs, target="cuda")
+
 
 def test_upstream():
     @te.hybrid.script
     def upstream(a):
-        b = output_tensor((20, ), 'float32')
+        b = output_tensor((20,), "float32")
         for i in range(20):
             b[i] = a[i] * i
         return b
 
-    a = te.placeholder((20, ), 'float32')
-    b = te.placeholder((20, ), 'float32')
-    c = te.compute((20, ), lambda x: a[x] + b[x])
+    a = te.placeholder((20,), "float32")
+    b = te.placeholder((20,), "float32")
+    c = te.compute((20,), lambda x: a[x] + b[x])
     d = upstream(c)
     sch = te.create_schedule([c.op, d.op])
     ir = tvm.lower(sch, [a, b, d])
     func = tvm.build(sch, [a, b, d])
-    assert(func)
+    assert func
 
-    a = numpy.random.randn(20).astype('float32')
-    b = numpy.random.randn(20).astype('float32')
-    ref = numpy.zeros((20, ), 'float32')
+    a = numpy.random.randn(20).astype("float32")
+    b = numpy.random.randn(20).astype("float32")
+    ref = numpy.zeros((20,), "float32")
     for i in range(20):
         ref[i] = (a[i] + b[i]) * i
 
     tvm_a = tvm.nd.array(a)
     tvm_b = tvm.nd.array(b)
-    tvm_d = tvm.nd.array(numpy.zeros((20, )).astype('float32'))
+    tvm_d = tvm.nd.array(numpy.zeros((20,)).astype("float32"))
 
     func(tvm_a, tvm_b, tvm_d)
     tvm.testing.assert_allclose(tvm_d.asnumpy(), ref, 1e-5, 1e-5)
 
+
 def test_downstream():
     @te.hybrid.script
     def downstream(a):
-        b = output_tensor((20, ), 'float32')
+        b = output_tensor((20,), "float32")
         for i in range(20):
             b[i] = a[i] * i
         return b
 
-
-    a = te.placeholder((20, ), 'float32')
+    a = te.placeholder((20,), "float32")
     b = downstream(a)
-    c = te.compute((20, ), lambda x: b[x] + 1.0)
+    c = te.compute((20,), lambda x: b[x] + 1.0)
 
     sch = te.create_schedule(c.op)
     module = tvm.build(sch, [a, c])
     assert module
 
-    a = numpy.random.randn(20).astype('float32')
-    ref = numpy.zeros((20, )).astype('float32')
+    a = numpy.random.randn(20).astype("float32")
+    ref = numpy.zeros((20,)).astype("float32")
     for i in range(20):
         ref[i] = (a[i] * i) + 1.0
 
     tvm_a = tvm.nd.array(a)
-    tvm_c = tvm.nd.array(numpy.zeros((20, )).astype('float32'))
+    tvm_c = tvm.nd.array(numpy.zeros((20,)).astype("float32"))
     module(tvm_a, tvm_c)
     tvm.testing.assert_allclose(tvm_c.asnumpy(), ref, 1e-5, 1e-5)
 
+
 def test_const_param():
     @te.hybrid.script
     def add_something(a, b):
-        c = output_tensor((11, ), 'int32')
+        c = output_tensor((11,), "int32")
         for i in range(11):
             c[i] = a[i] + b
         return c
 
-    a = te.placeholder((11, ), dtype='int32', name='a')
-    b = tvm.tir.const(11, 'int32')
+    a = te.placeholder((11,), dtype="int32", name="a")
+    b = tvm.tir.const(11, "int32")
     c = add_something(a, b)
     sch = te.create_schedule(c.op)
-    module = tvm.build(sch, [a, c], 'llvm')
-    assert(module)
+    module = tvm.build(sch, [a, c], "llvm")
+    assert module
 
-    np_a = numpy.arange(11).astype('int32')
+    np_a = numpy.arange(11).astype("int32")
     np_b = 11
-    np_c = numpy.zeros((11, )).astype('int32')
+    np_c = numpy.zeros((11,)).astype("int32")
 
     nd_a = tvm.nd.array(np_a)
-    nd_c = tvm.nd.array(numpy.zeros((11, )).astype('int32'))
+    nd_c = tvm.nd.array(numpy.zeros((11,)).astype("int32"))
     module(nd_a, nd_c)
     ref = add_something(np_a, 11)
 
     tvm.testing.assert_allclose(nd_c.asnumpy(), ref, 1e-5, 1e-5)
 
+
 def test_value_index():
     @te.hybrid.script
     def kernel_a(a):
-        b = output_tensor((16, ), 'int32')
-        c = output_tensor((4, 4), 'int32')
+        b = output_tensor((16,), "int32")
+        c = output_tensor((4, 4), "int32")
         for i in range(16):
             b[i] = a[i] + 2
             c[i // 4, i % 4] = a[i] + 1
@@ -598,27 +605,28 @@ def test_value_index():
 
     @te.hybrid.script
     def kernel_b(b, a):
-        c = output_tensor((4, 4), 'int32')
+        c = output_tensor((4, 4), "int32")
         for i in range(4):
             for j in range(4):
                 c[i, j] = a[i * 4 + j] * b[i, j]
         return c
 
-    a = te.placeholder((16, ), 'int32')
+    a = te.placeholder((16,), "int32")
     b, c = kernel_a(a)
     d = kernel_b(c, b)
     sch = te.create_schedule(d.op)
     module = tvm.build(sch, [a, d])
     assert module
 
-    np_a = numpy.arange(16).astype('int32')
+    np_a = numpy.arange(16).astype("int32")
     np_b, np_c = kernel_a(np_a)
     ref = kernel_b(np_c, np_b)
 
-    res = tvm.nd.array(numpy.zeros((4, 4)).astype('int32'))
+    res = tvm.nd.array(numpy.zeros((4, 4)).astype("int32"))
     module(tvm.nd.array(np_a), res)
     tvm.testing.assert_allclose(res.asnumpy(), ref)
 
+
 def test_func_call():
     @te.hybrid.script
     def foo(a, b):
@@ -633,11 +641,12 @@ def test_func_call():
                 d[i, j] = c[i, j] + i * j
         return d
 
-    a = te.placeholder((10, ), name='a')
-    b = te.placeholder((10, ), name='b')
+    a = te.placeholder((10,), name="a")
+    b = te.placeholder((10,), name="b")
     func, ins, outs = run_and_check(foo, [a, b])
     run_and_check(func, ins, outs=outs)
 
+
 def test_bool():
     @te.hybrid.script
     def foo(a):
@@ -649,15 +658,17 @@ def test_bool():
             else:
                 b[i] = 0.0
         return b
-    a = te.placeholder((10, ), name='a')
+
+    a = te.placeholder((10,), name="a")
     func, ins, outs = run_and_check(foo, [a])
     run_and_check(func, ins, outs=outs)
 
+
 def test_const_range():
     @te.hybrid.script
     def foo(a, b):
         c = output_tensor(a.shape, a.dtype)
-        d = output_tensor(a.shape, 'int32')
+        d = output_tensor(a.shape, "int32")
 
         for i in const_range(2):
             for j in const_range(5):
@@ -669,7 +680,7 @@ def test_const_range():
 
         return c, d
 
-    a = te.placeholder((2, 5), name='a', dtype='float32')
+    a = te.placeholder((2, 5), name="a", dtype="float32")
     b = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]]
     func, ins, outs = run_and_check(foo, [a, b])
     run_and_check(func, ins, outs=outs)
@@ -684,7 +695,8 @@ def test_const_range():
             else:
                 c[i - len_b] = a[i - len_b] + b[i - len_b]
         return c
-    a = te.placeholder((5, ), name='a', dtype='int32')
+
+    a = te.placeholder((5,), name="a", dtype="int32")
     b = [1, 2, 3, 4, 5]
     c = goo(a, tvm.runtime.convert(b))
     sch = te.create_schedule(c.op)
@@ -701,11 +713,13 @@ def test_const_range():
                 d += a[i] + b[j]
                 c[i] = d
         return c
-    a = te.placeholder((5, ), name='a', dtype='int32')
+
+    a = te.placeholder((5,), name="a", dtype="int32")
     b = [1, 2, 3, 4, 5]
     func, ins, outs = run_and_check(hoo, [a, b])
     run_and_check(func, ins, outs=outs)
 
+
 def test_schedule():
     @script
     def outer_product(a, b):
@@ -714,8 +728,9 @@ def test_schedule():
             for j in range(64):
                 c[i, j] = a[i] * b[j]
         return c
-    a = te.placeholder((64,), name='a', dtype='float32')
-    b = te.placeholder((64,), name='b', dtype='float32')
+
+    a = te.placeholder((64,), name="a", dtype="float32")
+    b = te.placeholder((64,), name="b", dtype="float32")
     c = outer_product(a, b)
 
     # Test perfect loop split
@@ -733,16 +748,16 @@ def test_schedule():
     assert isinstance(ir, tvm.tir.AttrStmt)
     ir = ir.body
     assert isinstance(ir, tvm.tir.For)
-    assert ir.loop_var.name == 'i.inner'
+    assert ir.loop_var.name == "i.inner"
     ir = ir.body
     assert isinstance(ir, tvm.tir.For)
-    assert ir.loop_var.name == 'i.outer'
+    assert ir.loop_var.name == "i.outer"
     ir = ir.body
     assert isinstance(ir, tvm.tir.For)
-    assert ir.loop_var.name == 'j.outer.outer'
+    assert ir.loop_var.name == "j.outer.outer"
     ir = ir.body
     assert isinstance(ir, tvm.tir.For)
-    assert ir.loop_var.name == 'j.outer.inner'
+    assert ir.loop_var.name == "j.outer.inner"
     ir = ir.body
     func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c])
     run_and_check(func, ins, outs=outs)
@@ -754,7 +769,7 @@ def test_schedule():
     assert isinstance(ir, tvm.tir.AttrStmt)
     ir = ir.body
     assert isinstance(ir, tvm.tir.For)
-    assert ir.loop_var.name == 'i.j.fused'
+    assert ir.loop_var.name == "i.j.fused"
     func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c])
     run_and_check(func, ins, outs=outs)
 
@@ -767,6 +782,7 @@ def test_schedule():
 
     # Test loop binds
 
+
 def test_capture():
     n = 8
 
@@ -776,16 +792,17 @@ def test_capture():
 
     @te.hybrid.script
     def add_something(a):
-        c = output_tensor((constant_tuple[1],), 'int32')
+        c = output_tensor((constant_tuple[1],), "int32")
         for i in range(constant_tuple[1]):
             c[i] = a[i] + constant_list[1][const_value]
         return c
 
-    a = te.placeholder((n, ), dtype='int32', name='a')
+    a = te.placeholder((n,), dtype="int32", name="a")
 
     func, ins, outs = run_and_check(add_something, [a])
     run_and_check(func, ins, outs=outs)
 
+
 def test_array_inputs():
     @script
     def sum_array(inputs):
@@ -795,28 +812,30 @@ def test_array_inputs():
             for j in const_range(n):
                 out[i] += inputs[j][i]
         return out
+
     n = 5
     inputs = []
     for i in range(n):
-        inputs.append(te.placeholder((10,), name='t%s' % i, dtype='float32'))
+        inputs.append(te.placeholder((10,), name="t%s" % i, dtype="float32"))
 
     out = sum_array(tvm.runtime.convert(inputs))
     assert len(out.op.inputs) == n
 
     sch = te.create_schedule(out.op)
-    mod = tvm.build(sch, inputs + [out], target='llvm')
+    mod = tvm.build(sch, inputs + [out], target="llvm")
     assert mod
 
     input_nd = []
     out_ref = numpy.zeros((10,))
     for _ in range(n):
-        arr = numpy.random.uniform(size=(10,)).astype('float32')
+        arr = numpy.random.uniform(size=(10,)).astype("float32")
         input_nd.append(tvm.nd.array(arr))
         out_ref += arr
-    out_nd = tvm.nd.array(numpy.zeros((10,), 'float32'))
+    out_nd = tvm.nd.array(numpy.zeros((10,), "float32"))
     mod(*input_nd, out_nd)
     tvm.testing.assert_allclose(out_nd.asnumpy(), out_ref)
 
+
 if __name__ == "__main__":
     test_outer_product()
     test_fanout()
index c00ee70..316aa6f 100644
@@ -19,12 +19,13 @@ import tvm
 from tvm import te
 import pickle as pkl
 
+
 def test_schedule_create():
-    m = te.size_var('m')
-    n = te.size_var('n')
-    l = te.size_var('l')
-    A = te.placeholder((m, l), name='A')
-    B = te.placeholder((n, l), name='B')
+    m = te.size_var("m")
+    n = te.size_var("n")
+    l = te.size_var("l")
+    A = te.placeholder((m, l), name="A")
+    B = te.placeholder((n, l), name="B")
     AA = te.compute((m, l), lambda i, j: A[i, j])
     T = te.compute((m, n, l), lambda i, j, k: AA(i, k) * B(j, k))
     s = te.create_schedule(T.op)
@@ -40,19 +41,19 @@ def test_schedule_create():
     json_str = tvm.ir.save_json(s)
     s_loaded = tvm.ir.load_json(json_str)
     assert isinstance(s_loaded, tvm.te.schedule.Schedule)
-    assert(str(s_loaded.outputs[0].body) == str(s.outputs[0].body))
+    assert str(s_loaded.outputs[0].body) == str(s.outputs[0].body)
 
     # pickle unpickle
     dump = pkl.dumps(s)
     s_loaded = pkl.loads(dump)
     assert isinstance(s_loaded, tvm.te.schedule.Schedule)
-    assert(str(s_loaded.outputs[0].body) == str(s.outputs[0].body))
+    assert str(s_loaded.outputs[0].body) == str(s.outputs[0].body)
 
 
 def test_reorder():
-    m = te.size_var('m')
-    A = te.placeholder((m,), name='A')
-    T = te.compute(m, lambda i: A[i+1])
+    m = te.size_var("m")
+    A = te.placeholder((m,), name="A")
+    T = te.compute(m, lambda i: A[i + 1])
 
     s = te.create_schedule(T.op)
     xo, xi = s[T].split(T.op.axis[0], factor=10)
@@ -69,9 +70,10 @@ def test_reorder():
     except tvm.error.TVMError:
         pass
 
+
 def test_split():
-    m = te.size_var('m')
-    A = te.placeholder((m,), name='A')
+    m = te.size_var("m")
+    A = te.placeholder((m,), name="A")
     T = te.compute((m,), lambda i: A[i])
 
     s = te.create_schedule(T.op)
@@ -80,9 +82,9 @@ def test_split():
 
 
 def test_tile():
-    m = te.size_var('m')
-    n = te.size_var('n')
-    A = te.placeholder((m, n), name='A')
+    m = te.size_var("m")
+    n = te.size_var("n")
+    A = te.placeholder((m, n), name="A")
     T = te.compute((m, n), lambda i, j: A[i, j])
 
     s = te.create_schedule(T.op)
@@ -91,9 +93,9 @@ def test_tile():
 
 
 def test_fuse():
-    m = te.size_var('m')
-    n = te.size_var('n')
-    A = te.placeholder((m, n), name='A')
+    m = te.size_var("m")
+    n = te.size_var("n")
+    A = te.placeholder((m, n), name="A")
     T = te.compute((m, n), lambda i, j: A[i, j])
 
     s = te.create_schedule(T.op)
@@ -102,10 +104,11 @@ def test_fuse():
     assert any(isinstance(x, tvm.te.schedule.Fuse) for x in s[T].relations)
     assert tuple(s[T].leaf_iter_vars) == (fused, xi, yi)
 
+
 def test_fuse_with_split():
-    m = te.size_var('m')
-    n = te.size_var('n')
-    A = te.placeholder((m, n), name='A')
+    m = te.size_var("m")
+    n = te.size_var("n")
+    A = te.placeholder((m, n), name="A")
     T = te.compute((m, n), lambda i, j: A[i, j])
 
     s = te.create_schedule(T.op)
@@ -115,10 +118,11 @@ def test_fuse_with_split():
     assert any(isinstance(x, tvm.te.schedule.Fuse) for x in s[T].relations)
     assert tuple(s[T].leaf_iter_vars) == (xo, fused)
 
+
 def test_fuse_with_out_of_order_axis():
-    m = te.size_var('m')
-    n = te.size_var('n')
-    A = te.placeholder((m, n), name='A')
+    m = te.size_var("m")
+    n = te.size_var("n")
+    A = te.placeholder((m, n), name="A")
     T = te.compute((m, n), lambda i, j: A[i, j])
 
     s = te.create_schedule(T.op)
@@ -126,19 +130,20 @@ def test_fuse_with_out_of_order_axis():
     xo, xi = s[T].split(T.op.axis[0], factor=10)
 
     with pytest.raises(RuntimeError):
-            fused = s[T].fuse(xo, y) # should throw here
+        fused = s[T].fuse(xo, y)  # should throw here
+
 
 def test_fuse_with_out_of_order_axis_with_reorder():
-    m = te.size_var('m')
-    n = te.size_var('n')
-    A = te.placeholder((m, n), name='A')
+    m = te.size_var("m")
+    n = te.size_var("n")
+    A = te.placeholder((m, n), name="A")
     T = te.compute((m, n), lambda i, j: A[i, j])
 
     s = te.create_schedule(T.op)
     y = T.op.axis[1]
     xo, xi = s[T].split(T.op.axis[0], factor=10)
     s[T].reorder(y, xo, xi)
-    fused = s[T].fuse(y, xo) # should be ok
+    fused = s[T].fuse(y, xo)  # should be ok
 
     s = te.create_schedule(T.op)
     y = T.op.axis[1]
@@ -146,11 +151,12 @@ def test_fuse_with_out_of_order_axis_with_reorder():
     s[T].reorder(y, xo, xi)
 
     with pytest.raises(RuntimeError):
-        fused = s[T].fuse(y, xi) # should throw here
+        fused = s[T].fuse(y, xi)  # should throw here
+
 
 def test_singleton():
-    A = te.placeholder((), name='A')
-    T = te.compute((), lambda : A() + 1)
+    A = te.placeholder((), name="A")
+    T = te.compute((), lambda: A() + 1)
     s = te.create_schedule(T.op)
     fused = s[T].fuse()
     assert any(isinstance(x, tvm.te.schedule.Singleton) for x in s[T].relations)
@@ -161,9 +167,9 @@ def test_singleton():
 
 
 def test_vectorize():
-    m = te.size_var('m')
-    n = te.size_var('n')
-    A = te.placeholder((m, n), name='A')
+    m = te.size_var("m")
+    n = te.size_var("n")
+    A = te.placeholder((m, n), name="A")
     T = te.compute((m, n), lambda i, j: A[i, j])
 
     s = te.create_schedule(T.op)
@@ -177,16 +183,17 @@ def test_vectorize():
 
 
 def test_vectorize_commreduce():
-    V = te.placeholder((128,), name='V')
-    ax = te.reduce_axis((0, 128), name='ax')
+    V = te.placeholder((128,), name="V")
+    ax = te.reduce_axis((0, 128), name="ax")
     O = te.compute((1,), lambda _: te.sum(V[ax], axis=[ax]))
     s = te.create_schedule(O.op)
     with pytest.raises(RuntimeError):
-        s[O].vectorize(ax) # should throw here
+        s[O].vectorize(ax)  # should throw here
+
 
 def test_pragma():
     m = 100
-    A = te.placeholder((m,), name='A')
+    A = te.placeholder((m,), name="A")
     T = te.compute((m,), lambda i: A[i])
 
     s = te.create_schedule(T.op)
@@ -199,89 +206,93 @@ def test_pragma():
 
 
 def test_rfactor():
-    n = te.size_var('n')
+    n = te.size_var("n")
     k1 = te.reduce_axis((0, n), name="k1")
     k2 = te.reduce_axis((0, n), name="k2")
-    A = te.placeholder((n, n, n), name='A')
-    B = te.compute((n, ), lambda i: te.sum(A[i, k1, k2], axis=[k1, k2]))
+    A = te.placeholder((n, n, n), name="A")
+    B = te.compute((n,), lambda i: te.sum(A[i, k1, k2], axis=[k1, k2]))
     # normal schedule
     s = te.create_schedule(B.op)
     BF = s.rfactor(B, k1)
-    assert(tuple(BF.shape) == (n, n))
-    assert(set(BF.op.body[0].axis) == set([k2]))
-    assert(s[B].op.body[0].axis[0].dom.extent == n)
-    assert(len(s[B].all_iter_vars) == 2)
+    assert tuple(BF.shape) == (n, n)
+    assert set(BF.op.body[0].axis) == set([k2])
+    assert s[B].op.body[0].axis[0].dom.extent == n
+    assert len(s[B].all_iter_vars) == 2
     # schedule with split
     s = te.create_schedule(B.op)
     ko, ki = s[B].split(k1, factor=4)
     xo, xi = s[B].split(B.op.axis[0], factor=8)
     BF = s.rfactor(B, ki)
-    assert(BF.shape[0].value == 4)
-    assert(BF.shape[1] == n)
-    assert(BF.op.body[0].axis[0] ==  k2)
-    assert(BF.op.body[0].axis[1].var ==  ko.var)
-    assert(s[B].op.body[0].axis[0].dom.extent.value == 4)
+    assert BF.shape[0].value == 4
+    assert BF.shape[1] == n
+    assert BF.op.body[0].axis[0] == k2
+    assert BF.op.body[0].axis[1].var == ko.var
+    assert s[B].op.body[0].axis[0].dom.extent.value == 4
     # schedule with factor_axis
     s = te.create_schedule(B.op)
     ko, ki = s[B].split(k1, factor=4)
     xo, xi = s[B].split(B.op.axis[0], factor=8)
     BF = s.rfactor(B, ki, 1)
-    assert(n == BF.shape[0])
-    assert(BF.shape[1].value == 4)
-    assert(BF.op.body[0].axis[0] ==  k2)
-    assert(BF.op.body[0].axis[1].var ==  ko.var)
-    assert(s[B].op.body[0].axis[0].dom.extent.value == 4)
+    assert n == BF.shape[0]
+    assert BF.shape[1].value == 4
+    assert BF.op.body[0].axis[0] == k2
+    assert BF.op.body[0].axis[1].var == ko.var
+    assert s[B].op.body[0].axis[0].dom.extent.value == 4
+
 
 def test_tensor_intrin():
     n = 16
-    x = te.placeholder((n,), name='x')
-    y = te.placeholder((n,), name='y')
-    z = te.compute(x.shape, lambda i: x[i] + y[i], name='z')
+    x = te.placeholder((n,), name="x")
+    y = te.placeholder((n,), name="y")
+    z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")
+
     def intrin_func(ins, outs):
-        assert(isinstance(ins[0], tvm.te.schedule.Buffer))
-        assert(ins[0].shape[0].value == n)
+        assert isinstance(ins[0], tvm.te.schedule.Buffer)
+        assert ins[0].shape[0].value == n
         return tvm.tir.call_packed("vadd", ins[0].data, outs[0].data, ins[0].shape[0])
+
     intrin = te.decl_tensor_intrin(z.op, intrin_func)
     assert intrin.op == z.op
     assert intrin.reduce_init is None
     assert tuple(intrin.inputs) == tuple(z.op.input_tensors)
-    assert(intrin.buffers[0].shape[0].value == n)
+    assert intrin.buffers[0].shape[0].value == n
     m = 32
-    x = te.placeholder((m,), name='x')
-    y = te.placeholder((m,), name='y')
-    z = te.compute(x.shape, lambda i: x[i] + y[i], name='z')
+    x = te.placeholder((m,), name="x")
+    y = te.placeholder((m,), name="y")
+    z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")
     s = te.create_schedule(z.op)
     xo, xi = s[z].split(z.op.axis[0], factor=n)
     s[z].tensorize(xi, intrin)
-    assert(s[z].iter_var_attrs[xi].tensor_intrin == intrin)
-    assert(s[z].iter_var_attrs[xi].iter_type == tvm.te.schedule.IterVar.Tensorized)
+    assert s[z].iter_var_attrs[xi].tensor_intrin == intrin
+    assert s[z].iter_var_attrs[xi].iter_type == tvm.te.schedule.IterVar.Tensorized
+
 
 def test_tensor_intrin_scalar_params():
     n = te.size_var("n")
-    x = te.placeholder((n,), name='x')
+    x = te.placeholder((n,), name="x")
     v = te.size_var("v")
     w = te.size_var("w")
-    z = te.compute((n,), lambda i: x[i]*v + w, name='z')
+    z = te.compute((n,), lambda i: x[i] * v + w, name="z")
 
     def intrin_func(ins, outs, sp):
-        assert(isinstance(ins[0], tvm.te.schedule.Buffer))
-        assert(ins[0].shape[0] == n)
-        assert(sp[0] == v)
-        assert(sp[1] == w)
+        assert isinstance(ins[0], tvm.te.schedule.Buffer)
+        assert ins[0].shape[0] == n
+        assert sp[0] == v
+        assert sp[1] == w
         return tvm.tir.call_packed("hw_func", ins[0].data, outs[0].data, sp[0], sp[1])
 
-    intrin = te.decl_tensor_intrin(z.op, intrin_func, scalar_params=[v, w], default_buffer_params={
-        "offset_factor": 1
-    })
+    intrin = te.decl_tensor_intrin(
+        z.op, intrin_func, scalar_params=[v, w], default_buffer_params={"offset_factor": 1}
+    )
     assert intrin.op == z.op
     assert intrin.reduce_init is None
     assert tuple(intrin.inputs) == tuple(z.op.input_tensors)
-    assert(intrin.buffers[0].shape[0] == n)
+    assert intrin.buffers[0].shape[0] == n
     assert tuple(intrin.scalar_params) == tuple((v, w))
 
-    A = te.placeholder((10,10), name='A')
+    A = te.placeholder((10, 10), name="A")
     # Pass scalar inputs to the TensorIntrin, interleaved with tensor inputs
-    C = te.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C")
+    C = te.compute((10, 10), lambda i, j: intrin(i * i, A[i, j], i + j), name="C")
     s = te.create_schedule(C.op)
     stmt = tvm.lower(s, [A, C])["main"].body
     assert isinstance(stmt.body.body, tvm.tir.Evaluate)
@@ -289,25 +300,27 @@ def test_tensor_intrin_scalar_params():
     assert str(stmt.body.body.value.args[3]) == "(i: int32*i)"
     assert str(stmt.body.body.value.args[4]) == "(i: int32 + j: int32)"
 
+
 def test_legalize_invalid_attach():
-    A = te.compute((10, 10), lambda i, j: 1.0, name='A')
-    B = te.compute((10, 10), lambda i, j: A[i][j], name='B')
+    A = te.compute((10, 10), lambda i, j: 1.0, name="A")
+    B = te.compute((10, 10), lambda i, j: A[i][j], name="B")
 
     # Case 1: Split an axis which is the target of a compute_at
     s = te.create_schedule([B.op])
     s[A].compute_at(s[B], B.op.axis[1])
     s[B].split(B.op.axis[1], 2)
 
-    stmt = tvm.lower(s, [A, B], simple_mode=True)['main'].body
+    stmt = tvm.lower(s, [A, B], simple_mode=True)["main"].body
     assert isinstance(stmt.body.body, tvm.tir.stmt.For)
 
     # Case 2: Fuse an axis which is the target of a compute_at
     s = te.create_schedule([B.op])
     s[A].compute_at(s[B], B.op.axis[1])
     s[B].fuse(B.op.axis[0], B.op.axis[1])
-    stmt = tvm.lower(s, [A, B], simple_mode=True)['main'].body
+    stmt = tvm.lower(s, [A, B], simple_mode=True)["main"].body
     assert isinstance(stmt, tvm.tir.stmt.For)
 
+
 if __name__ == "__main__":
     test_singleton()
     test_pragma()
index e226b7a..de2178c 100644
 import tvm
 from tvm import te
 
+
 def test_bound1():
-    m = te.var('m')
-    l = te.var('l')
-    A = te.placeholder((m, l), name='A')
-    A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
-    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
+    m = te.var("m")
+    l = te.var("l")
+    A = te.placeholder((m, l), name="A")
+    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
+    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2")
 
     s = te.create_schedule([A2.op])
     xo, xi = s[A2].split(s[A2].op.axis[0], 8)
     s[A1].compute_at(s[A2], xo)
     bounds = tvm.te.schedule.InferBound(s)
     assert isinstance(bounds, tvm.container.Map)
-    assert(bounds[A1.op.axis[0]].extent.value == 8)
+    assert bounds[A1.op.axis[0]].extent.value == 8
+
 
 def test_bound2():
-    m = te.var('m')
-    l = te.var('l')
-    A = te.placeholder((m, l), name='A')
-    A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
-    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
+    m = te.var("m")
+    l = te.var("l")
+    A = te.placeholder((m, l), name="A")
+    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
+    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2")
     s = te.create_schedule(A2.op)
     xo, yo, xi, yi = s[A2].tile(A2.op.axis[0], A2.op.axis[1], 8, 8)
     # test normalize not affecting schedule
@@ -44,15 +46,16 @@ def test_bound2():
     s[A1].compute_at(s[A2], yo)
     bounds = tvm.te.schedule.InferBound(s)
     assert isinstance(bounds, tvm.container.Map)
-    assert(bounds[A1.op.axis[0]].extent.value == 8)
-    assert(bounds[A1.op.axis[1]].extent.value == 8)
+    assert bounds[A1.op.axis[0]].extent.value == 8
+    assert bounds[A1.op.axis[1]].extent.value == 8
+
 
 def test_bound3():
-    m = te.var('m')
-    l = te.var('l')
-    A = te.placeholder((m, l), name='A')
-    A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
-    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
+    m = te.var("m")
+    l = te.var("l")
+    A = te.placeholder((m, l), name="A")
+    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
+    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2")
 
     s = te.create_schedule(A2.op)
     s[A1].set_scope("shared")
@@ -67,40 +70,43 @@ def test_bound3():
 
     bounds = tvm.te.schedule.InferBound(s)
     assert isinstance(bounds, tvm.container.Map)
-    assert(bounds[A1.op.axis[0]].extent.value==32)
-    assert(bounds[A1.op.axis[1]].extent.value==16)
+    assert bounds[A1.op.axis[0]].extent.value == 32
+    assert bounds[A1.op.axis[1]].extent.value == 16
+
 
 def test_bound_split_ext_less_than_factor():
     m = 8
-    I = te.placeholder((m,), name='I')
-    EF = te.compute((m,), lambda i: I[i] * 2, name = "EF")
-    E = te.compute((m,), lambda i: EF[i] * 2, name = "E")
+    I = te.placeholder((m,), name="I")
+    EF = te.compute((m,), lambda i: I[i] * 2, name="EF")
+    E = te.compute((m,), lambda i: EF[i] * 2, name="E")
     s = te.create_schedule([E.op])
-    xo, xi = s[E].split(s[E].op.axis[0], factor = 32)
+    xo, xi = s[E].split(s[E].op.axis[0], factor=32)
     s[EF].compute_at(s[E], xo)
 
     bounds = tvm.te.schedule.InferBound(s)
     assert isinstance(bounds, tvm.container.Map)
     assert bounds[xi].extent.value == m
 
+
 def test_bound_split_ext_less_than_naprts():
     m = 8
-    I = te.placeholder((m,), name='I')
-    EF = te.compute((m,), lambda i: I[i] * 2, name = "EF")
-    E = te.compute((m,), lambda i: EF[i] * 2, name = "E")
+    I = te.placeholder((m,), name="I")
+    EF = te.compute((m,), lambda i: I[i] * 2, name="EF")
+    E = te.compute((m,), lambda i: EF[i] * 2, name="E")
     s = te.create_schedule([E.op])
-    xo, xi = s[E].split(s[E].op.axis[0], nparts = 32)
+    xo, xi = s[E].split(s[E].op.axis[0], nparts=32)
     s[EF].compute_at(s[E], xo)
 
     bounds = tvm.te.schedule.InferBound(s)
     assert isinstance(bounds, tvm.container.Map)
     assert bounds[xo].extent.value == m
 
+
 def test_bound_split_divisible():
-    m = te.var('m')
-    l = te.var('l')
-    A = te.placeholder((8 * m, l), name='A')
-    B = te.compute((8 * m, l), lambda i, j: A[i, j], name='B')
+    m = te.var("m")
+    l = te.var("l")
+    A = te.placeholder((8 * m, l), name="A")
+    B = te.compute((8 * m, l), lambda i, j: A[i, j], name="B")
     s = te.create_schedule(B.op)
     xo, xi = s[B].split(B.op.axis[0], 8)
     bounds = tvm.te.schedule.InferBound(s)
@@ -108,12 +114,13 @@ def test_bound_split_divisible():
     assert bounds[xo].extent == m
     assert bounds[xi].extent.value == 8
 
+
 def test_bound_tile_divisible():
-    m = te.var('m')
-    l = te.var('l')
+    m = te.var("m")
+    l = te.var("l")
     shape = (8 * m, 32 * l)
-    A = te.placeholder(shape, name='A')
-    B = te.compute(shape, lambda i, j: A[i, j], name='B')
+    A = te.placeholder(shape, name="A")
+    B = te.compute(shape, lambda i, j: A[i, j], name="B")
     s = te.create_schedule(B.op)
     xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], 8, 32)
     bounds = tvm.te.schedule.InferBound(s)
@@ -123,13 +130,14 @@ def test_bound_tile_divisible():
     assert bounds[yo].extent == l
     assert bounds[yi].extent.value == 32
 
+
 def test_bound_fusesplit1():
-    m = te.var('m')
-    l = te.var('l')
-    split1 = te.var('s')
-    A = te.placeholder((m, l), name='A')
-    A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
-    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
+    m = te.var("m")
+    l = te.var("l")
+    split1 = te.var("s")
+    A = te.placeholder((m, l), name="A")
+    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
+    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2")
 
     s = te.create_schedule(A2.op)
     fused_axes = s[A2].fuse(A2.op.axis[0], A2.op.axis[1])
@@ -139,28 +147,34 @@ def test_bound_fusesplit1():
     bounds = tvm.te.schedule.InferBound(s)
     assert isinstance(bounds, tvm.container.Map)
     idxdiv = tvm.tir.indexdiv
-    tvm.testing.assert_prim_expr_equal(
-        bounds[A1.op.axis[0]].min, idxdiv(xo * split1, l))
+    tvm.testing.assert_prim_expr_equal(bounds[A1.op.axis[0]].min, idxdiv(xo * split1, l))
 
-    expected_extent = (idxdiv((xo + 1) * split1 - 1, l) - idxdiv(xo * split1, l) + 1)
+    expected_extent = idxdiv((xo + 1) * split1 - 1, l) - idxdiv(xo * split1, l) + 1
     for i in range(1, 6):
         for j in range(1, 6):
             for k in range(1, 6):
-                vars = tvm.runtime.convert({split1: tvm.tir.const(i, "int32"), l: tvm.tir.const(j, "int32"), xo.var: tvm.tir.const(k, "int32")})
+                vars = tvm.runtime.convert(
+                    {
+                        split1: tvm.tir.const(i, "int32"),
+                        l: tvm.tir.const(j, "int32"),
+                        xo.var: tvm.tir.const(k, "int32"),
+                    }
+                )
                 tvm.testing.assert_prim_expr_equal(
                     tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[0]].extent, vars),
-                    tvm.tir.stmt_functor.substitute(expected_extent, vars)
+                    tvm.tir.stmt_functor.substitute(expected_extent, vars),
                 )
 
     tvm.testing.assert_prim_expr_equal(bounds[A1.op.axis[1]].extent, l)
 
+
 def test_bound_fusesplit2():
     m = te.var("m")
     l = tvm.runtime.convert(6)
     split = tvm.runtime.convert(3)
-    A = te.placeholder((m, l), name='A')
-    A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
-    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
+    A = te.placeholder((m, l), name="A")
+    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
+    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2")
 
     s = te.create_schedule(A2.op)
     fused_axes = s[A2].fuse(A2.op.axis[0], A2.op.axis[1])
@@ -170,18 +184,26 @@ def test_bound_fusesplit2():
     bounds = tvm.te.schedule.InferBound(s)
     assert isinstance(bounds, tvm.container.Map)
     vars = tvm.runtime.convert({xo.var: tvm.tir.const(5, "int32")})
-    tvm.testing.assert_prim_expr_equal(tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[0]].min, vars), 2)
-    tvm.testing.assert_prim_expr_equal(tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[1]].min, vars), 3)
-    tvm.testing.assert_prim_expr_equal(tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[0]].extent, vars), 1)
-    tvm.testing.assert_prim_expr_equal(tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[1]].extent, vars), 3)
+    tvm.testing.assert_prim_expr_equal(
+        tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[0]].min, vars), 2
+    )
+    tvm.testing.assert_prim_expr_equal(
+        tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[1]].min, vars), 3
+    )
+    tvm.testing.assert_prim_expr_equal(
+        tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[0]].extent, vars), 1
+    )
+    tvm.testing.assert_prim_expr_equal(
+        tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[1]].extent, vars), 3
+    )
 
 
 def test_bound_warp():
-    m = te.var('m')
-    l = te.var('l')
-    A = te.placeholder((m, l), name='A')
-    A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
-    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
+    m = te.var("m")
+    l = te.var("l")
+    A = te.placeholder((m, l), name="A")
+    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
+    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2")
 
     s = te.create_schedule(A2.op)
     s[A1].set_scope("warp")
@@ -196,7 +218,8 @@ def test_bound_warp():
     s[A1].bind(xi, tx)
     bounds = tvm.te.schedule.InferBound(s)
     assert isinstance(bounds, tvm.container.Map)
-    assert(bounds[A1.op.axis[0]].extent.value==16)
+    assert bounds[A1.op.axis[0]].extent.value == 16
+
 
 def test_bound_scan():
     m = te.var("m")
@@ -204,7 +227,7 @@ def test_bound_scan():
     X = te.compute((m, n), lambda i, j: tvm.tir.const(1, "float32"), name="x")
     s_state = te.placeholder((m, n))
     s_init = te.compute((1, n), lambda _, i: X[0, i])
-    s_update = te.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
+    s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
     s_scan = tvm.te.scan(s_init, s_update, s_state)
 
     assert tuple(s_scan.shape) == (m, n)
@@ -217,40 +240,47 @@ def test_bound_scan():
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
     assert bounds[XX.op.axis[1]].extent.value == 4
 
+
 def test_bound_conv1d():
-    n = te.var('n')
-    A = te.compute((n+2), lambda i: 1,  name='A')
+    n = te.var("n")
+    A = te.compute((n + 2), lambda i: 1, name="A")
+
     def computeB(ii):
         i = ii + 1
-        return A[i-1] + A[i] + A[i+1]
-    B = te.compute(n, computeB, name='B')
+        return A[i - 1] + A[i] + A[i + 1]
+
+    B = te.compute(n, computeB, name="B")
     s = te.create_schedule(B.op)
     s[A].compute_at(s[B], B.op.axis[0])
     s = s.normalize()
     bounds = tvm.te.schedule.InferBound(s)
-    assert(bounds[A.op.axis[0]].extent.value == 3)
+    assert bounds[A.op.axis[0]].extent.value == 3
+
 
 def test_bound_blur():
     n = tvm.runtime.convert(12)
-    A = te.compute((n, n), lambda i, j: 1, name='A')
+    A = te.compute((n, n), lambda i, j: 1, name="A")
+
     def computeB(ii, jj):
         # set the correct center
         i = ii + 1
         j = jj + 1
-        return A[i][j] + A[i-1][j] + A[i+1][j] + A[i][j+1] + A[i][j-1]
-    B = te.compute((n-2, n-2), computeB, name='B')
+        return A[i][j] + A[i - 1][j] + A[i + 1][j] + A[i][j + 1] + A[i][j - 1]
+
+    B = te.compute((n - 2, n - 2), computeB, name="B")
     s = te.create_schedule(B.op)
     s[A].compute_at(s[B], B.op.axis[1])
     s = s.normalize()
     bounds = tvm.te.schedule.InferBound(s)
-    assert(bounds[A.op.axis[0]].extent.value == 3)
-    assert(bounds[A.op.axis[1]].extent.value == 3)
+    assert bounds[A.op.axis[0]].extent.value == 3
+    assert bounds[A.op.axis[1]].extent.value == 3
+
 
 def test_bound_rfactor():
-    n = te.var('n')
-    A = te.placeholder((n,), name='A')
+    n = te.var("n")
+    A = te.placeholder((n,), name="A")
     k = te.reduce_axis((0, n))
-    B = te.compute((1,), lambda i: te.sum(A[k], axis=k, where=(i>1)), name='B')
+    B = te.compute((1,), lambda i: te.sum(A[k], axis=k, where=(i > 1)), name="B")
     # schedule
     s = te.create_schedule(B.op)
     kf, ki = s[B].split(k, nparts=4)
@@ -258,8 +288,9 @@ def test_bound_rfactor():
     s = s.normalize()
     bounds = tvm.te.schedule.InferBound(s)
 
-    assert(bounds[BF.op.axis[0]].extent.value == 4)
-    assert(bounds[BF.op.axis[1]].extent.value == 1)
+    assert bounds[BF.op.axis[0]].extent.value == 4
+    assert bounds[BF.op.axis[1]].extent.value == 1
+
 
 def test_bound_group_schedule():
     m = te.var("m")
@@ -277,6 +308,7 @@ def test_bound_group_schedule():
     assert bounds[x.op.axis[0]].extent.value == 1
     assert bounds[x.op.axis[1]].extent == n
 
+
 def test_bound_nest_group():
     m = te.var("m")
     n = te.var("n")
@@ -299,11 +331,11 @@ def test_bound_nest_group():
 
 
 def test_bound_nest_thread():
-    m = te.var('m')
-    A = te.placeholder((m), name='A')
-    A1 = te.compute((m,), lambda i: A[i], name='A1')
-    A2 = te.compute((m,), lambda i: A1[i] + 2, name='A2')
-    A3 = te.compute((m,), lambda i: A2[i] + 3, name='A3')
+    m = te.var("m")
+    A = te.placeholder((m), name="A")
+    A1 = te.compute((m,), lambda i: A[i], name="A1")
+    A2 = te.compute((m,), lambda i: A1[i] + 2, name="A2")
+    A3 = te.compute((m,), lambda i: A2[i] + 3, name="A3")
 
     s = te.create_schedule(A3.op)
     s[A2].set_scope("shared")
@@ -320,20 +352,18 @@ def test_bound_nest_thread():
     s[A1].compute_at(s[A3], tx)
     s = s.normalize()
     bounds = tvm.te.schedule.InferBound(s)
-    assert(bounds[A1.op.axis[0]].extent.value==1)
-    assert(bounds[A2.op.axis[0]].extent.value==32)
-    assert(bounds[A3.op.axis[0]].extent == m)
+    assert bounds[A1.op.axis[0]].extent.value == 1
+    assert bounds[A2.op.axis[0]].extent.value == 32
+    assert bounds[A3.op.axis[0]].extent == m
+
 
 def test_gemm_bound():
     nn = 1024
     n = tvm.runtime.convert(nn)
-    A = te.placeholder((n, n), name='A')
-    B = te.placeholder((n, n), name='B')
-    k = te.reduce_axis((0, n), name='k')
-    C = te.compute(
-        (n, n),
-        lambda ii, jj: te.sum(A[ii, k] * B[jj, k], axis=k),
-        name='CC')
+    A = te.placeholder((n, n), name="A")
+    B = te.placeholder((n, n), name="B")
+    k = te.reduce_axis((0, n), name="k")
+    C = te.compute((n, n), lambda ii, jj: te.sum(A[ii, k] * B[jj, k], axis=k), name="CC")
     # schedule
     s = te.create_schedule(C.op)
     xtile, ytile = 32, 32
@@ -376,18 +406,18 @@ def test_gemm_bound():
     s[BB].bind(tx, thread_x)
     s = s.normalize()
     bounds = tvm.te.schedule.InferBound(s)
-    assert(bounds[BB.op.axis[0]].extent.value==64)
-    assert(bounds[AA.op.axis[0]].extent.value==64)
-    assert(bounds[CC.op.axis[0]].extent.value == 8)
-    assert(bounds[CC.op.axis[1]].extent.value == 8)
+    assert bounds[BB.op.axis[0]].extent.value == 64
+    assert bounds[AA.op.axis[0]].extent.value == 64
+    assert bounds[CC.op.axis[0]].extent.value == 8
+    assert bounds[CC.op.axis[1]].extent.value == 8
 
 
 def test_bound_tensor_compute_op():
     def intrin_test():
         m1 = te.var("m1")
         n1 = te.var("n1")
-        a = te.placeholder((m1, n1), name='a')
-        c = te.compute((1, n1), lambda i, j : a[0, j] + a[1, j] + a[2, j], name='c')
+        a = te.placeholder((m1, n1), name="a")
+        c = te.compute((1, n1), lambda i, j: a[0, j] + a[1, j] + a[2, j], name="c")
 
         Ab = tvm.tir.decl_buffer(a.shape, name="Abuf", offset_factor=1)
         Cb = tvm.tir.decl_buffer(c.shape, name="Cbuf", offset_factor=1)
@@ -395,21 +425,27 @@ def test_bound_tensor_compute_op():
         def intrin_func(ins, outs):
             aa = ins[0]
             cc = outs[0]
+
             def _body():
                 ib = tvm.tir.ir_builder.create()
-                ib.emit(tvm.tir.call_extern("int32", "test", cc.access_ptr("w"), aa.access_ptr("r")))
+                ib.emit(
+                    tvm.tir.call_extern("int32", "test", cc.access_ptr("w"), aa.access_ptr("r"))
+                )
                 return ib.get()
+
             return _body()
-        return te.decl_tensor_intrin(c.op, intrin_func, binds={a : Ab, c : Cb})
+
+        return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, c: Cb})
 
     test_func = intrin_test()
-    A = te.placeholder((20,20), name='A')
-    B = te.compute(A.shape, lambda i,j : A[i,j], name='B')
-    C = te.compute((10, 20), lambda i : test_func(B[i:10, 0:20]), name='C')
+    A = te.placeholder((20, 20), name="A")
+    B = te.compute(A.shape, lambda i, j: A[i, j], name="B")
+    C = te.compute((10, 20), lambda i: test_func(B[i:10, 0:20]), name="C")
     s = te.create_schedule(C.op)
     bounds = tvm.te.schedule.InferBound(s)
     assert isinstance(bounds, tvm.container.Map)
-    assert(bounds[B.op.axis[0]].extent.value == 10)
+    assert bounds[B.op.axis[0]].extent.value == 10
+
 
 def test_bound_simplification_failure():
     # Check that the bounds are not expanded
@@ -423,15 +459,17 @@ def test_bound_simplification_failure():
         if not bounds[A.op.axis[0]].extent.value <= 2:
             print(stmt)
             assert bounds[A.op.axis[0]].extent.value <= 2
+
     tdiv = tvm.tir.truncdiv
     # These are hard to simplify; moreover, we don't simplify them
-    _check(te.compute((10,), lambda i: A[tvm.te.min(3*i, 4*i) + tvm.te.min(-3*i, -2*i)]))
-    _check(te.compute((10,), lambda i: A[tvm.te.min(3*i, 4*i) + tvm.te.max(-3*i, -4*i)]))
-    _check(te.compute((10,), lambda i: A[-2*tdiv(i,2) - tvm.te.min(i, 0-i)]))
+    _check(te.compute((10,), lambda i: A[tvm.te.min(3 * i, 4 * i) + tvm.te.min(-3 * i, -2 * i)]))
+    _check(te.compute((10,), lambda i: A[tvm.te.min(3 * i, 4 * i) + tvm.te.max(-3 * i, -4 * i)]))
+    _check(te.compute((10,), lambda i: A[-2 * tdiv(i, 2) - tvm.te.min(i, 0 - i)]))
     _check(te.compute((10,), lambda i: A[i + (0 - i)]))
     # This would cause out of bounds, but we nevertheless include it
     _check(te.compute((10,), lambda i: A[i]))
 
+
 if __name__ == "__main__":
     test_bound_nest_thread()
     test_bound1()
index 3893bb6..039fe08 100644
 import tvm
 from tvm import te
 
+
 def test_bound_tile_mod():
     def compute(M_tiles, N_tiles, factor, dtype):
         # Algo
         M = M_tiles * factor
         N = N_tiles * factor
 
-        A = tvm.te.placeholder((N, M), name='A', dtype=dtype)
-        C = tvm.te.compute((N, M), lambda n, m: A[n, m], name='C')
+        A = tvm.te.placeholder((N, M), name="A", dtype=dtype)
+        C = tvm.te.compute((N, M), lambda n, m: A[n, m], name="C")
         s = tvm.te.create_schedule(C.op)
 
         return s, A, C
@@ -37,7 +38,7 @@ def test_bound_tile_mod():
         nio, nii = s[C].split(ni, 2)
         n = s[C].fuse(nii, mi)
         C_shared = s.cache_write(C, "shared")
-        bn, bm, ni, mi = C_shared.op.axis       
+        bn, bm, ni, mi = C_shared.op.axis
         s[C_shared].storage_align(ni, factor * 2, padding)
 
         n, m = s[C].op.axis
@@ -51,10 +52,11 @@ def test_bound_tile_mod():
     s, A, C = compute(2, 2, 128, "float16")
     s = schedule(s, 128, 8, A, C)
     bounds = tvm.te.schedule.InferBound(s)
-    check = (bounds[s.stages[2].op.axis[2]].extent == 16)
-    if(not check):
+    check = bounds[s.stages[2].op.axis[2]].extent == 16
+    if not check:
         print(tvm.lower(s, [A, C], simple_mode=True))
-    assert(check)
+    assert check
+
 
 if __name__ == "__main__":
     test_bound_tile_mod()
index 7d11020..05ca9fd 100644
@@ -17,6 +17,7 @@
 import tvm
 from tvm import te
 
+
 def test_scan():
     m = te.var("m")
     n = te.var("n")
@@ -36,13 +37,14 @@ def test_scan():
         s = te.create_schedule(s_scan.op)
         s[x_trans].compute_at(s[s_update], s_update.op.axis[0])
         apath = tvm.te.schedule.CreateAttachPath(s)
-        assert(tuple(apath[s_update.op]) == tuple([s_scan.op.scan_axis]))
-        assert(tuple(apath[x_trans.op]) == tuple([s_update.op.axis[0], s_scan.op.scan_axis]))
+        assert tuple(apath[s_update.op]) == tuple([s_scan.op.scan_axis])
+        assert tuple(apath[x_trans.op]) == tuple([s_update.op.axis[0], s_scan.op.scan_axis])
 
     def test_fix_pt():
         body = tvm.te.schedule.ScanGetBody(s_scan.op)
         fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op)
-        assert(fxpt[s_scan.spatial_axis_[0]].value != 0)
+        assert fxpt[s_scan.spatial_axis_[0]].value != 0
+
 
 def test_scan_fix_point():
     m = te.var("m")
@@ -53,42 +55,47 @@ def test_scan_fix_point():
     s_init = te.compute((1, m, n), lambda _, i, j: x[0, i, j], name="s_init")
 
     def test_scan0():
-        s_update = te.compute((l, m, n),
-                               lambda t, i, j: x[t, j, i]  + s_state[t-1, i, j], name="update")
+        s_update = te.compute(
+            (l, m, n), lambda t, i, j: x[t, j, i] + s_state[t - 1, i, j], name="update"
+        )
         s_scan = tvm.te.scan(s_init, s_update, s_state)
         body = tvm.te.schedule.ScanGetBody(s_scan.op)
         fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op)
-        assert(fxpt[s_scan.op.spatial_axis_[0]].value == 1)
-        assert(fxpt[s_scan.op.spatial_axis_[1]].value == 1)
+        assert fxpt[s_scan.op.spatial_axis_[0]].value == 1
+        assert fxpt[s_scan.op.spatial_axis_[1]].value == 1
 
     def test_scan1():
-        s_update = te.compute((l, m, n),
-                               lambda t, i, j: x[t, j, i]  + s_state[t-1, j, i], name="update")
+        s_update = te.compute(
+            (l, m, n), lambda t, i, j: x[t, j, i] + s_state[t - 1, j, i], name="update"
+        )
         s_scan = tvm.te.scan(s_init, s_update, s_state)
         body = tvm.te.schedule.ScanGetBody(s_scan.op)
         fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op)
-        assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0)
-        assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)
+        assert fxpt[s_scan.op.spatial_axis_[0]].value == 0
+        assert fxpt[s_scan.op.spatial_axis_[1]].value == 0
 
     def test_scan3_not_exact_reach():
-        s_h1 = te.compute((l, n, m), lambda t, j, i: s_state[t-1, i, j], name="h1")
-        s_h2 = te.compute((l, m, n), lambda t, i, j: s_state[t-1, i, 10] * 2, name="h1")
-        s_update = te.compute((l, m, n), lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update")
+        s_h1 = te.compute((l, n, m), lambda t, j, i: s_state[t - 1, i, j], name="h1")
+        s_h2 = te.compute((l, m, n), lambda t, i, j: s_state[t - 1, i, 10] * 2, name="h1")
+        s_update = te.compute(
+            (l, m, n), lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update"
+        )
         s_scan = tvm.te.scan(s_init, s_update, s_state)
         body = tvm.te.schedule.ScanGetBody(s_scan.op)
         fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op)
-        assert(fxpt[s_scan.op.spatial_axis_[0]].value == 1)
-        assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)
+        assert fxpt[s_scan.op.spatial_axis_[0]].value == 1
+        assert fxpt[s_scan.op.spatial_axis_[1]].value == 0
 
     def test_scan4_reach_other():
-        s_h1 = te.compute((l, n, m), lambda t, j, i: s_state[t-1, j, j], name="h1")
-        s_h2 = te.compute((l, m, n), lambda t, i, j: s_state[t-1, i, j] * 2, name="h1")
-        s_update = te.compute((l, m, n),
-                               lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update")
+        s_h1 = te.compute((l, n, m), lambda t, j, i: s_state[t - 1, j, j], name="h1")
+        s_h2 = te.compute((l, m, n), lambda t, i, j: s_state[t - 1, i, j] * 2, name="h1")
+        s_update = te.compute(
+            (l, m, n), lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update"
+        )
         s_scan = tvm.te.scan(s_init, s_update, s_state)
         fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op)
-        assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0)
-        assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)
+        assert fxpt[s_scan.op.spatial_axis_[0]].value == 0
+        assert fxpt[s_scan.op.spatial_axis_[1]].value == 0
 
     def test_scan5_multi_output():
         m = te.var("m")
@@ -99,14 +106,12 @@ def test_scan_fix_point():
         s2 = te.placeholder((m, n))
         s1_init = te.compute((1, n), lambda _, i: x1[0, i])
         s2_init = te.compute((1, n), lambda _, i: x2[0, i])
-        s1_update = te.compute((m, n), lambda t, i: s1[t-1, i] +  x1[t, i])
-        s2_update = te.compute((m, n), lambda t, i: x2[t, i] + s2[t-1,i])
-        r0, r1 = tvm.te.scan([s1_init, s2_init],
-                          [s1_update, s2_update],
-                          [s1, s2])
+        s1_update = te.compute((m, n), lambda t, i: s1[t - 1, i] + x1[t, i])
+        s2_update = te.compute((m, n), lambda t, i: x2[t, i] + s2[t - 1, i])
+        r0, r1 = tvm.te.scan([s1_init, s2_init], [s1_update, s2_update], [s1, s2])
         body = tvm.te.schedule.ScanGetBody(r0.op)
         fxpt = tvm.te.schedule.ScanFixPointAnalysis(r0.op)
-        assert(fxpt[r1.op.spatial_axis_[0]].value == 1)
+        assert fxpt[r1.op.spatial_axis_[0]].value == 1
 
     test_scan0()
     test_scan1()
@@ -114,10 +119,11 @@ def test_scan_fix_point():
     test_scan4_reach_other()
     test_scan5_multi_output()
 
+
 def test_create_read_graph():
-    m = te.var('m')
-    l = te.var('l')
-    A = te.placeholder((m, l), name='A')
+    m = te.var("m")
+    l = te.var("l")
+    A = te.placeholder((m, l), name="A")
     A1 = te.compute((m, l), lambda i, j: A[i, j])
     A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3)
 
@@ -126,8 +132,8 @@ def test_create_read_graph():
     assert g[A2.op][0] == A1
     assert g[A1.op][0] == A
     post_order = tvm.te.schedule.PostDFSOrder([A2.op], g)
-    assert(post_order[0] == A.op)
-    assert(post_order[1] == A1.op)
+    assert post_order[0] == A.op
+    assert post_order[1] == A1.op
 
 
 if __name__ == "__main__":
index 23c7486..abdf81d 100644
@@ -17,6 +17,7 @@
 import tvm
 from tvm import te
 
+
 def test_lstm_cell_inline():
     num_step = 128
     num_input = 256
@@ -29,35 +30,41 @@ def test_lstm_cell_inline():
     # h: output hidden state, c: cell state.
     s_state_h = te.placeholder((num_step, batch_size, num_hidden))
     s_state_c = te.placeholder((num_step, batch_size, num_hidden))
-    s_init_c = te.compute((1, batch_size, num_hidden),
-                           lambda *i: 0.0, name="init_c")
-    s_init_h = te.compute((1, batch_size, num_hidden),
-                           lambda *i: 0.0, name="init_h")
+    s_init_c = te.compute((1, batch_size, num_hidden), lambda *i: 0.0, name="init_c")
+    s_init_h = te.compute((1, batch_size, num_hidden), lambda *i: 0.0, name="init_h")
     # LSTM transition
     k = te.reduce_axis((0, num_input), name="ki2h")
     s_i2h = te.compute(
         (num_step, 4, batch_size, num_hidden),
         lambda t, x, i, j: te.sum(X[t - 1, i, k] * Wi2h[x, j, k], axis=k),
-        name="s_i2h")
+        name="s_i2h",
+    )
     k = te.reduce_axis((0, num_hidden), name="ki2h")
     s_h2h = te.compute(
         (num_step, 4, batch_size, num_hidden),
         lambda t, x, i, j: te.sum(s_state_h[t - 1, i, k] * Wh2h[x, j, k], axis=k),
-        name="s_h2h")
+        name="s_h2h",
+    )
     # Gate rules
-    gates = te.compute(s_i2h.shape, lambda *i:
-                        s_i2h(*i) + s_h2h(*i), name="gates")
+    gates = te.compute(s_i2h.shape, lambda *i: s_i2h(*i) + s_h2h(*i), name="gates")
     gshape = (num_step, batch_size, num_hidden)
     in_gate = te.compute(gshape, lambda t, i, j: te.sigmoid(gates[t, 0, i, j]), name="in_gate")
-    in_transform = te.compute(gshape, lambda t, i, j: te.tanh(gates[t, 1, i, j]), name="in_transform")
-    forget_gate = te.compute(gshape, lambda t, i, j: te.sigmoid(gates[t, 2, i, j]), name="forget_gate")
+    in_transform = te.compute(
+        gshape, lambda t, i, j: te.tanh(gates[t, 1, i, j]), name="in_transform"
+    )
+    forget_gate = te.compute(
+        gshape, lambda t, i, j: te.sigmoid(gates[t, 2, i, j]), name="forget_gate"
+    )
     out_gate = te.compute(gshape, lambda t, i, j: te.sigmoid(gates[t, 3, i, j]), name="out_gate")
-    next_c = te.compute(gshape,
-                         lambda t, i, j:
-                         forget_gate[t, i, j] * s_state_c[t - 1, i, j] +
-                         in_gate[t, i, j] * in_transform[t, i, j], name="next_c")
-    next_h = te.compute(gshape,
-                         lambda t, i, j: out_gate[t, i, j] * te.tanh(next_c[t, i, j]), name="next_h")
+    next_c = te.compute(
+        gshape,
+        lambda t, i, j: forget_gate[t, i, j] * s_state_c[t - 1, i, j]
+        + in_gate[t, i, j] * in_transform[t, i, j],
+        name="next_c",
+    )
+    next_h = te.compute(
+        gshape, lambda t, i, j: out_gate[t, i, j] * te.tanh(next_c[t, i, j]), name="next_h"
+    )
     update_c = te.compute(gshape, lambda *i: next_c(*i), name="update_c")
     update_h = te.compute(gshape, lambda *i: next_h(*i), name="update_h")
     # schedule
@@ -66,7 +73,8 @@ def test_lstm_cell_inline():
         [update_h, update_c],
         [s_state_h, s_state_c],
         inputs=[X],
-        name="lstm_scan")
+        name="lstm_scan",
+    )
     # schedule
     s = te.create_schedule(scan_h.op)
     # Inline gate computations
@@ -78,5 +86,6 @@ def test_lstm_cell_inline():
     # verify we can lower correctly
     tvm.lower(s, [X, Wi2h, Wh2h, scan_h, scan_c])
 
+
 if __name__ == "__main__":
     test_lstm_cell_inline()
index 3f93c77..1555974 100644
@@ -18,26 +18,26 @@ import tvm
 from tvm import te
 import numpy as np
 
+
 def test_schedule0():
-    m = te.var('m')
-    l = te.var('l')
-    A = te.placeholder((m, l), name='A')
-    A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
+    m = te.var("m")
+    l = te.var("l")
+    A = te.placeholder((m, l), name="A")
+    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
     s = te.create_schedule(A1.op)
 
     bounds = tvm.te.schedule.InferBound(s)
     assert isinstance(bounds, tvm.container.Map)
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc(
-        [A, A1], stmt, None)
+    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A1], stmt, None)
     assert isinstance(func, tvm.tir.PrimFunc)
 
 
 def test_schedule1():
-    m = te.var('m')
-    l = te.var('l')
-    A = te.placeholder((m, l), name='A')
-    A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
+    m = te.var("m")
+    l = te.var("l")
+    A = te.placeholder((m, l), name="A")
+    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
 
     s = te.create_schedule(A1.op)
     xo, xi = s[A1].split(A1.op.axis[0], 8)
@@ -46,17 +46,16 @@ def test_schedule1():
     assert isinstance(bounds, tvm.container.Map)
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
 
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc(
-        [A, A1], stmt, None)
+    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A1], stmt, None)
     assert isinstance(func, tvm.tir.PrimFunc)
 
 
 def test_schedule2():
-    m = te.var('m')
-    l = te.var('l')
-    A = te.placeholder((m, l), name='A')
-    A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
-    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
+    m = te.var("m")
+    l = te.var("l")
+    A = te.placeholder((m, l), name="A")
+    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
+    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2")
 
     s = te.create_schedule(A2.op)
     xo, xi = s[A2].split(A2.op.axis[0], 8)
@@ -64,8 +63,7 @@ def test_schedule2():
     bounds = tvm.te.schedule.InferBound(s)
     assert isinstance(bounds, tvm.container.Map)
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc(
-        [A, A2], stmt, None)
+    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
     assert isinstance(func, tvm.tir.PrimFunc)
 
 
@@ -75,7 +73,7 @@ def test_schedule_scan():
     x = te.compute((m, n), lambda i, j: tvm.tir.const(1, "float32"), name="x")
     s_state = te.placeholder((m, n))
     s_init = te.compute((1, n), lambda _, i: x[0, i])
-    s_update = te.compute((m, n), lambda t, i: s_state[t-1, i] + x[t, i])
+    s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + x[t, i])
     res = tvm.te.scan(s_init, s_update, s_state)
 
     assert tuple(res.shape) == (m, n)
@@ -83,27 +81,27 @@ def test_schedule_scan():
     s = s.normalize()
     ir = tvm.lower(s, [s_state], simple_mode=True)
     bounds = tvm.te.schedule.InferBound(s)
-    assert(bounds[res.op.scan_axis].min.value == 1)
+    assert bounds[res.op.scan_axis].min.value == 1
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
 
 
-
 def test_inline_multi_reduce():
     def argmax_comp(x, y):
         idx = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
         val = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
         return idx, val
+
     def argmax_init(idx_typ, val_typ):
         return tvm.tir.const(-1, idx_typ), tvm.te.min_value(val_typ)
 
-    argmax = te.comm_reducer(argmax_comp, argmax_init, name='argmax')
-    m = te.var('m')
-    n = te.var('n')
-    val = te.placeholder((m, n), name='val', dtype='float32')
-    val1 = te.compute((m, n), lambda i, j: val[i, j]+1, name='val1')
-    val2 = te.compute((m, n), lambda i, j: te.exp(val1[i, j]), name='val2')
-    k = te.reduce_axis((0, n), 'k')
-    T_idx, T_val = te.compute((m, ), lambda i: argmax((k.var, val2[i, k]), axis=k), name='T')
+    argmax = te.comm_reducer(argmax_comp, argmax_init, name="argmax")
+    m = te.var("m")
+    n = te.var("n")
+    val = te.placeholder((m, n), name="val", dtype="float32")
+    val1 = te.compute((m, n), lambda i, j: val[i, j] + 1, name="val1")
+    val2 = te.compute((m, n), lambda i, j: te.exp(val1[i, j]), name="val2")
+    k = te.reduce_axis((0, n), "k")
+    T_idx, T_val = te.compute((m,), lambda i: argmax((k.var, val2[i, k]), axis=k), name="T")
     s = te.create_schedule(T_idx.op)
     s[val1].compute_inline()
     s = s.normalize()
@@ -112,13 +110,13 @@ def test_inline_multi_reduce():
 
 
 def test_auto_inline():
-    m = te.var('m')
-    n = te.var('n')
-    A = te.placeholder((m, n), name='A')
-    B = te.placeholder((m, n), name='B')
-    C = te.placeholder((m, n), name='C')
-    T1 = te.compute((m, n), lambda i, j:  A(i, j) * B(i, j), name='T1')
-    T2 = te.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name='T2')
+    m = te.var("m")
+    n = te.var("n")
+    A = te.placeholder((m, n), name="A")
+    B = te.placeholder((m, n), name="B")
+    C = te.placeholder((m, n), name="C")
+    T1 = te.compute((m, n), lambda i, j: A(i, j) * B(i, j), name="T1")
+    T2 = te.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name="T2")
 
     s = te.create_schedule(T2.op)
     tvm.te.schedule.AutoInlineElemWise(s)
@@ -129,8 +127,8 @@ def test_auto_inline():
 
 def test_schedule_const_bound():
     n = 128
-    A = te.placeholder((n,), name='A')
-    A1 = te.compute((n,), lambda i: A[i] + 1, name='A1')
+    A = te.placeholder((n,), name="A")
+    A1 = te.compute((n,), lambda i: A[i] + 1, name="A1")
     s = te.create_schedule(A1.op)
     xo, xi = s[A1].split(A1.op.axis[0], 8)
     bounds = tvm.te.schedule.InferBound(s)
@@ -139,11 +137,11 @@ def test_schedule_const_bound():
 
 
 def test_inline_mixed():
-    n = te.var('n')
-    A = te.placeholder((n, ), name='A')
-    A1 = te.compute(A.shape, lambda *i: A(*i) + 1, name='A1')
-    A2 = te.compute(A.shape, lambda *i: A1(*i) + 2, name='A2')
-    C = te.compute((n,), lambda i: A2[i] + A1[i], name='C')
+    n = te.var("n")
+    A = te.placeholder((n,), name="A")
+    A1 = te.compute(A.shape, lambda *i: A(*i) + 1, name="A1")
+    A2 = te.compute(A.shape, lambda *i: A1(*i) + 2, name="A2")
+    C = te.compute((n,), lambda i: A2[i] + A1[i], name="C")
 
     s = te.create_schedule(C.op)
     xo, xi = s[C].split(C.op.axis[0], factor=8)
@@ -152,9 +150,11 @@ def test_inline_mixed():
     s = s.normalize()
     bounds = tvm.te.schedule.InferBound(s)
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
+
     def check(x):
         if isinstance(x, tvm.tir.Call):
             assert x.func != A2
+
     tvm.tir.stmt_functor.post_order_visit(s[C].op.body[0], check)
 
 
@@ -166,13 +166,11 @@ def test_scan_inline1():
     s_state2 = te.placeholder((m, n))
     s_init1 = te.compute((1, n), lambda _, i: x[0, i])
     s_init2 = te.compute((1, n), lambda _, i: x[0, i])
-    s_x1 = te.compute((m, n), lambda t, i: s_state1[t-1, i] + x[t, i], name="x1")
-    s_x2 = te.compute((m, n), lambda t, i: s_state2[t-1, i] + 1 , name="x2")
+    s_x1 = te.compute((m, n), lambda t, i: s_state1[t - 1, i] + x[t, i], name="x1")
+    s_x2 = te.compute((m, n), lambda t, i: s_state2[t - 1, i] + 1, name="x2")
     s_update1 = te.compute((m, n), lambda t, i: s_x1[t, i], "u1")
     s_update2 = te.compute((m, n), lambda t, i: s_x2[t, i], "u2")
-    res1, res2 = tvm.te.scan([s_init1, s_init2],
-                          [s_update1, s_update2],
-                          [s_state1, s_state2])
+    res1, res2 = tvm.te.scan([s_init1, s_init2], [s_update1, s_update2], [s_state1, s_state2])
     s = te.create_schedule(res1.op)
     s[s_x1].compute_inline()
     stmt = tvm.lower(s, [x, res1, res2])
@@ -186,14 +184,12 @@ def test_scan_inline2():
     s_state2 = te.placeholder((m, n))
     s_init1 = te.compute((1, n), lambda _, i: x[0, i])
     s_init2 = te.compute((1, n), lambda _, i: x[0, i])
-    s_xx = te.compute((m, n), lambda t, i: s_state1[t-1, i] + x[t, i], name="xx")
+    s_xx = te.compute((m, n), lambda t, i: s_state1[t - 1, i] + x[t, i], name="xx")
     s_x1 = te.compute((m, n), lambda t, i: s_xx[t, i] + 1, name="x1")
-    s_x2 = te.compute((m, n), lambda t, i: s_xx[t, i] + s_state2[t-1, 2], name="x2")
+    s_x2 = te.compute((m, n), lambda t, i: s_xx[t, i] + s_state2[t - 1, 2], name="x2")
     s_update1 = te.compute((m, n), lambda t, i: s_x1[t, i], "u1")
     s_update2 = te.compute((m, n), lambda t, i: s_x2[t, i], "u2")
-    res1, res2 = tvm.te.scan([s_init1, s_init2],
-                          [s_update1, s_update2],
-                          [s_state1, s_state2])
+    res1, res2 = tvm.te.scan([s_init1, s_init2], [s_update1, s_update2], [s_state1, s_state2])
     s = te.create_schedule(res1.op)
     s[s_xx].compute_inline()
     s[s_x1].compute_inline()
@@ -202,11 +198,11 @@ def test_scan_inline2():
 
 
 def test_schedule_cache():
-    m = te.var('m')
-    n = te.var('n')
-    A = te.placeholder((m, n), name='A')
-    B = te.placeholder((m, n), name='B')
-    C = te.compute((m, n), lambda i, j:  A(i, j) * B(i, j), name='C')
+    m = te.var("m")
+    n = te.var("n")
+    A = te.placeholder((m, n), name="A")
+    B = te.placeholder((m, n), name="B")
+    C = te.compute((m, n), lambda i, j: A(i, j) * B(i, j), name="C")
 
     s = te.create_schedule(C.op)
     AA = s.cache_read(A, "shared", readers=[C])
@@ -217,30 +213,30 @@ def test_schedule_cache():
 
 
 def test_schedule_middle_cache():
-    m = te.var('m')
-    n = te.var('n')
-    A = te.placeholder((m, n), name='A')
-    B = te.placeholder((m, n), name='B')
+    m = te.var("m")
+    n = te.var("n")
+    A = te.placeholder((m, n), name="A")
+    B = te.placeholder((m, n), name="B")
 
-    C = te.compute((m, n), lambda i, j:  A(i, j) * B(i, j), name='C')
-    D = te.compute((m, n), lambda i, j:  C(i , j) , name='D')
+    C = te.compute((m, n), lambda i, j: A(i, j) * B(i, j), name="C")
+    D = te.compute((m, n), lambda i, j: C(i, j), name="D")
 
     s = te.create_schedule(D.op)
     AA = s.cache_read(A, "local", readers=[C])
     BB = s.cache_read(B, "local", readers=[C])
     CC = s.cache_read(C, "local", readers=[D])
     DD = s.cache_write(D, "local")
-    #s[AA].compute_at(s[CC], CC.op.axis[0])
+    # s[AA].compute_at(s[CC], CC.op.axis[0])
     bounds = tvm.te.schedule.InferBound(s)
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
 
 
 def test_schedule_cache_relayout1():
-    m = te.var('m')
-    n = te.var('n')
-    A = te.placeholder((m, n), name='A')
-    B = te.placeholder((m, n), name='B')
-    C = te.compute((m, n), lambda i, j:  A(i, j) * B(i, j), name='C')
+    m = te.var("m")
+    n = te.var("n")
+    A = te.placeholder((m, n), name="A")
+    B = te.placeholder((m, n), name="B")
+    C = te.compute((m, n), lambda i, j: A(i, j) * B(i, j), name="C")
 
     s = te.create_schedule(C.op)
     s[C].reorder(C.op.axis[1], C.op.axis[0])
@@ -250,11 +246,11 @@ def test_schedule_cache_relayout1():
 
 
 def test_schedule_cache_relayout2():
-    m = te.var('m')
-    n = te.var('n')
-    A = te.placeholder((m*4, n), name='A')
-    B = te.placeholder((m*4, n), name='B')
-    C = te.compute(A.shape, lambda i, j:  A(i, j) * B(i, j), name='C')
+    m = te.var("m")
+    n = te.var("n")
+    A = te.placeholder((m * 4, n), name="A")
+    B = te.placeholder((m * 4, n), name="B")
+    C = te.compute(A.shape, lambda i, j: A(i, j) * B(i, j), name="C")
     s = te.create_schedule(C.op)
     x, y = C.op.axis
     xo, xi = s[C].split(x, factor=4)
@@ -266,13 +262,12 @@ def test_schedule_cache_relayout2():
 
 
 def test_schedule_cache_relayout3():
-    m = te.var('m')
-    n = te.var('n')
-    A = te.placeholder((m*4, n), name='A')
-    B = te.placeholder((m*4, n), name='B')
+    m = te.var("m")
+    n = te.var("n")
+    A = te.placeholder((m * 4, n), name="A")
+    B = te.placeholder((m * 4, n), name="B")
     k = te.reduce_axis((0, n), "k")
-    C = te.compute((A.shape[0],),
-                    lambda i: te.sum(A(i, k) * B(i, k), axis=k), name='C')
+    C = te.compute((A.shape[0],), lambda i: te.sum(A(i, k) * B(i, k), axis=k), name="C")
     s = te.create_schedule(C.op)
     x = C.op.axis[0]
     xo, xi = s[C].split(x, factor=4)
@@ -285,11 +280,12 @@ def test_schedule_cache_relayout3():
 def test_schedule_cache_relayout4():
     def _compute(*indice):
         return A(*indice) + 1, B(*indice) / 2
-    m = te.var('m')
-    n = te.var('n')
-    A = te.placeholder((m*4, n), name='A')
-    B = te.placeholder((m*4, n), name='B')
-    C1, C2 = te.compute(A.shape, _compute, name='C')
+
+    m = te.var("m")
+    n = te.var("n")
+    A = te.placeholder((m * 4, n), name="A")
+    B = te.placeholder((m * 4, n), name="B")
+    C1, C2 = te.compute(A.shape, _compute, name="C")
     s = te.create_schedule([C1.op, C2.op])
     C1_cache, C2_cache = s.cache_write([C1, C2], "local")
     s = s.normalize()
@@ -298,46 +294,45 @@ def test_schedule_cache_relayout4():
 
 
 def intrin_gemv(m, n):
-    w = te.placeholder((m, n), name='w')
-    x = te.placeholder((n,), name='x')
-    k = te.reduce_axis((0, n), name='k')
-    z = te.compute((m,), lambda i:
-                    te.sum(w[i, k] * x[k], axis=k), name='z')
-    Wb = tvm.tir.decl_buffer(w.shape, w.dtype,
-                         name="W",
-                         offset_factor=16,
-                         strides=[te.var('ldw'), 1])
+    w = te.placeholder((m, n), name="w")
+    x = te.placeholder((n,), name="x")
+    k = te.reduce_axis((0, n), name="k")
+    z = te.compute((m,), lambda i: te.sum(w[i, k] * x[k], axis=k), name="z")
+    Wb = tvm.tir.decl_buffer(
+        w.shape, w.dtype, name="W", offset_factor=16, strides=[te.var("ldw"), 1]
+    )
+
     def intrin_func(ins, outs):
         ww, xx = ins
         zz = outs[0]
         ww_ptr = ww.access_ptr("r")
         xx_ptr = xx.access_ptr("r")
         zz_ptr = zz.access_ptr("w")
-        body = tvm.tir.call_packed(
-            "gemm", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
-        reset = tvm.tir.call_packed(
-            "fill_zero", zz_ptr, n)
-        update = tvm.tir.call_packed(
-            "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
+        body = tvm.tir.call_packed("gemm", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
+        reset = tvm.tir.call_packed("fill_zero", zz_ptr, n)
+        update = tvm.tir.call_packed("gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
         return body, reset, update
 
     buffer_params = {"data_alignment": 16, "offset_factor": 16}
     return te.decl_tensor_intrin(
-        z.op, intrin_func, binds={w: Wb}, default_buffer_params=buffer_params)
+        z.op, intrin_func, binds={w: Wb}, default_buffer_params=buffer_params
+    )
 
 
 def test_schedule_tensor_compute1():
     # basic: split, reorder, tile
     M, N, L = 2048, 1024, 512
     factor, rfactor = 16, 16
-    A = te.placeholder((N//factor, L//rfactor, factor, rfactor), name='A')
-    B = te.placeholder((M, L//rfactor, rfactor), name='B')
-    k = te.reduce_axis((0, L//rfactor), name='k')
+    A = te.placeholder((N // factor, L // rfactor, factor, rfactor), name="A")
+    B = te.placeholder((M, L // rfactor, rfactor), name="B")
+    k = te.reduce_axis((0, L // rfactor), name="k")
 
     gemv = intrin_gemv(factor, rfactor)
-    C = te.compute((N, M//factor, factor),
+    C = te.compute(
+        (N, M // factor, factor),
         lambda i, j: gemv(A[i, k, 0:factor, 0:factor], B[j, k, 0:rfactor], reduce_axis=k),
-        name='C')
+        name="C",
+    )
 
     s = te.create_schedule(C.op)
     ai, aj, ax = s[C].op.axis
@@ -351,18 +346,17 @@ def test_schedule_tensor_compute1():
 
 
 def intrin_vadd(n, cache_read=False, cache_write=False):
-    scope_ubuf = 'local'
-    dtype = 'float32'
-    x = te.placeholder((n,), dtype=dtype, name='vx')
-    y = te.placeholder((n,), dtype=dtype, name='vy')
-    z = te.compute(x.shape, lambda i: x[i] + y[i], name='z')
+    scope_ubuf = "local"
+    dtype = "float32"
+    x = te.placeholder((n,), dtype=dtype, name="vx")
+    y = te.placeholder((n,), dtype=dtype, name="vy")
+    z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")
     s = te.create_schedule(z.op)
 
     def create_buffer(t):
-        return tvm.tir.decl_buffer(t.shape, t.dtype,
-                               name='W'+t.name,
-                               scope=scope_ubuf,
-                               offset_factor=16)
+        return tvm.tir.decl_buffer(
+            t.shape, t.dtype, name="W" + t.name, scope=scope_ubuf, offset_factor=16
+        )
 
     binds = {}
     if cache_read:
@@ -373,27 +367,34 @@ def intrin_vadd(n, cache_read=False, cache_write=False):
 
     def intrin_func(ins, outs):
         ib = tvm.tir.ir_builder.create()
-        ib.emit(tvm.tir.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr')))
+        ib.emit(
+            tvm.tir.call_extern(
+                outs[0].dtype,
+                "vadd",
+                ins[0].access_ptr("r"),
+                ins[1].access_ptr("r"),
+                outs[0].access_ptr("wr"),
+            )
+        )
         return ib.get()
 
-    return te.decl_tensor_intrin(z.op, intrin_func, binds=binds, default_buffer_params={
-        "offset_factor": 16
-    })
+    return te.decl_tensor_intrin(
+        z.op, intrin_func, binds=binds, default_buffer_params={"offset_factor": 16}
+    )
 
 
 def test_schedule_tensor_compute2():
     # cache_read, cache_write
     M = 1024
     factor = 16
-    dtype = 'float32'
-    scope_ubuf = 'local'
+    dtype = "float32"
+    scope_ubuf = "local"
 
-    A = te.placeholder((M//factor, factor), name="A", dtype=dtype)
-    B = te.placeholder((M//factor, factor), name="B", dtype=dtype)
+    A = te.placeholder((M // factor, factor), name="A", dtype=dtype)
+    B = te.placeholder((M // factor, factor), name="B", dtype=dtype)
 
     vadd = intrin_vadd(factor, True, True)
-    C = te.compute((M//factor, factor),
-        lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name='C')
+    C = te.compute((M // factor, factor), lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name="C")
 
     s = te.create_schedule(C.op)
     AL = s.cache_read(A, scope_ubuf, C)
@@ -408,14 +409,13 @@ def test_schedule_tensor_compute3():
     # compute_at
     M = 1024
     factor = 16
-    dtype = 'float32'
-    A = te.placeholder((M//factor, factor), name="A", dtype=dtype)
-    B = te.placeholder((M//factor, factor), name="B", dtype=dtype)
-    Bi = te.compute((M//factor, factor), lambda i, j: B[i, j] + 5, name="Bi")
+    dtype = "float32"
+    A = te.placeholder((M // factor, factor), name="A", dtype=dtype)
+    B = te.placeholder((M // factor, factor), name="B", dtype=dtype)
+    Bi = te.compute((M // factor, factor), lambda i, j: B[i, j] + 5, name="Bi")
 
     vadd = intrin_vadd(factor)
-    C = te.compute((M//factor, factor),
-        lambda i: vadd(A[i, 0:factor], Bi[i, 0:factor]), name='C')
+    C = te.compute((M // factor, factor), lambda i: vadd(A[i, 0:factor], Bi[i, 0:factor]), name="C")
     s = te.create_schedule(C.op)
     s[Bi].compute_at(s[C], C.op.axis[0])
     s = s.normalize()
@@ -425,9 +425,11 @@ def test_schedule_tensor_compute3():
 
 def test_loop_dep_reduce():
     X = te.placeholder(shape=(10,), name="x")
+
     def f(n):
         rv = te.reduce_axis((0, n))
         return te.sum(X[rv], axis=rv)
+
     Y = te.compute(X.shape, f, name="y")
     s = te.create_schedule([Y.op])
     f = tvm.build(s, [X, Y])
@@ -435,19 +437,22 @@ def test_loop_dep_reduce():
 
 def test_loop_dep_reduce_cache_write():
     X = te.placeholder(shape=(10,), name="x")
+
     def f(n):
         rv = te.reduce_axis((0, n))
         init = lambda dtype: tvm.tir.Select(n > 1, tvm.tir.const(0, dtype), n.astype(dtype))
-        sum = te.comm_reducer(lambda x, y: tvm.te.max(x + y, n.astype('float32')), init, name='sum')
+        sum = te.comm_reducer(lambda x, y: tvm.te.max(x + y, n.astype("float32")), init, name="sum")
         return sum(X[rv], axis=rv)
+
     Y = te.compute(X.shape, f, name="y")
     s = te.create_schedule([Y.op])
-    s.cache_write(Y, 'local')
+    s.cache_write(Y, "local")
     f = tvm.build(s, [X, Y])
 
+
 def test_reduction_and_dummy_fuse_split():
     n = 10
-    X = te.placeholder(shape=(n,), dtype='int32', name="X")
+    X = te.placeholder(shape=(n,), dtype="int32", name="X")
     k = te.reduce_axis((0, n))
     Y = te.compute((), lambda: te.sum(X[k], k), name="Y")
     s = te.create_schedule([Y.op])
@@ -455,35 +460,39 @@ def test_reduction_and_dummy_fuse_split():
     axo, axi = s[Y.op].split(ax, nparts=20)
     f = tvm.build(s, [Y, X])
 
-    args = [tvm.nd.empty((), 'int32')] + [tvm.nd.array(np.ones((n,), dtype='int32'))]
+    args = [tvm.nd.empty((), "int32")] + [tvm.nd.array(np.ones((n,), dtype="int32"))]
     f(*args)
     assert args[0].asnumpy() == n
 
     n = 10
-    X = te.placeholder(shape=(n,), dtype='int32', name="X")
+    X = te.placeholder(shape=(n,), dtype="int32", name="X")
     k = te.reduce_axis((0, n))
     Y = te.compute((n,), lambda i: te.sum(X[k], k), name="Y")
     s = te.create_schedule([Y.op])
     ax = s[Y.op].fuse(*(list(Y.op.axis) + list(Y.op.reduce_axis)))
     f = tvm.build(s, [Y, X])
 
-    args = [tvm.nd.array(np.ones((n,), dtype='int32'))] + \
-        [tvm.nd.array(np.ones((n,), dtype='int32'))]
+    args = [tvm.nd.array(np.ones((n,), dtype="int32"))] + [
+        tvm.nd.array(np.ones((n,), dtype="int32"))
+    ]
     f(*args)
     assert np.all(args[0].asnumpy() == n)
 
+
 def test_schedule_compute_inline():
     shape = [10, 1024]
     A = te.placeholder(shape, name="A")
     B = te.placeholder(shape, name="B")
-    C = te.compute(shape, lambda *index:A(*index)+ B(*index), name = "C")
-    def _compute(*index) :
-        return C(*index) , C(*index) * B(*index)
-    F,E = te.compute(shape, _compute, name = "F")
+    C = te.compute(shape, lambda *index: A(*index) + B(*index), name="C")
+
+    def _compute(*index):
+        return C(*index), C(*index) * B(*index)
+
+    F, E = te.compute(shape, _compute, name="F")
 
     s = te.create_schedule([F.op, E.op])
     AL = s.cache_read(A, "local", [C])
-    BL = s.cache_read(B, "local", [C,E])
+    BL = s.cache_read(B, "local", [C, E])
     CL = s.cache_write(C, "local")
     FL, EL = s.cache_write([F, E], "local")
     s[C].compute_inline()
@@ -497,14 +506,14 @@ def test_local_stage_predicate():
     m = 1
     n = 3
     p = 2
-    A = tvm.te.placeholder((m, n, p), name='A')
+    A = tvm.te.placeholder((m, n, p), name="A")
     B = tvm.te.compute((m, n, p), lambda bi, bj, bk: A[bi, bj, bk], name="B")
     C = tvm.te.compute((m, n, p), lambda ci, cj, ck: B[ci, cj, ck], name="C")
     by = tvm.te.thread_axis("blockIdx.y")
     tx = tvm.te.thread_axis("threadIdx.x")
     vx = tvm.te.thread_axis("vthread")
 
-    def schedule(thread_tag, mem_scope) :
+    def schedule(thread_tag, mem_scope):
         s = tvm.te.create_schedule(C.op)
         s[B].compute_at(s[C], s[C].op.axis[0])
         s[B].set_scope(mem_scope)
@@ -519,29 +528,25 @@ def test_local_stage_predicate():
         ret = []
         tvm.tir.stmt_functor.post_order_visit(stmt, lambda x: ret.append(f(x)))
         return ret
+
     # local vs. threadIdx
     s = schedule(tx, "local")
     lowered_body = tvm.lower(s, [A, C])["main"].body
-    assert (not any(
-        collect_visit(lowered_body,
-                      lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    assert not any(collect_visit(lowered_body, lambda x: isinstance(x, tvm.tir.IfThenElse)))
     # local vs. vthread
     s = schedule(vx, "local")
     lowered_body = tvm.lower(s, [A, C])["main"].body
-    assert (not any(
-        collect_visit(lowered_body,
-                      lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    assert not any(collect_visit(lowered_body, lambda x: isinstance(x, tvm.tir.IfThenElse)))
     # shared vs. blockIdx
     s = schedule(by, "shared")
     lowered_body = tvm.lower(s, [A, C])["main"].body
-    assert (not any(
-        collect_visit(lowered_body,
-                      lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    assert not any(collect_visit(lowered_body, lambda x: isinstance(x, tvm.tir.IfThenElse)))
+
 
 def test_local_stage_predicate2():
-    A = tvm.te.placeholder((128, ), name="A")
-    B = tvm.te.compute((128, ), lambda bi: A[bi] + 1, name="B")
-    C = tvm.te.compute((128, ), lambda ci: B[ci] + 2, name="C")
+    A = tvm.te.placeholder((128,), name="A")
+    B = tvm.te.compute((128,), lambda bi: A[bi] + 1, name="B")
+    C = tvm.te.compute((128,), lambda ci: B[ci] + 2, name="C")
     s = tvm.te.create_schedule(C.op)
     AA = s.cache_read(A, "local", [B])
     s[B].set_scope("shared")
@@ -567,14 +572,12 @@ def test_local_stage_predicate2():
         return ret
 
     def visit_stmt(op):
-        if (isinstance(op, tvm.tir.Allocate)):
+        if isinstance(op, tvm.tir.Allocate):
             return op.extents[0].value == 97
         return False
 
-    assert (not any(
-        collect_visit(lowered_body,
-                      lambda x: isinstance(x, tvm.tir.IfThenElse))))
-    assert (any(collect_visit(lowered_body, visit_stmt)))
+    assert not any(collect_visit(lowered_body, lambda x: isinstance(x, tvm.tir.IfThenElse)))
+    assert any(collect_visit(lowered_body, visit_stmt))
 
 
 if __name__ == "__main__":
index a57a340..88cf66e 100644
@@ -20,11 +20,14 @@ from tvm import topi
 import numpy as np
 import tvm.testing
 
+
 def tensor_core_matmul(warp_tile_m=16, m=64, n=32, l=96):
-    A = te.placeholder((n, l), name='A', dtype='float16')
-    B = te.placeholder((l, m), name='B', dtype='float16')
-    k = te.reduce_axis((0, l), name='k')
-    C = te.compute((n, m), lambda i, j: te.sum(A[i, k].astype('float32') * B[k, j].astype('float32'), axis=k))
+    A = te.placeholder((n, l), name="A", dtype="float16")
+    B = te.placeholder((l, m), name="B", dtype="float16")
+    k = te.reduce_axis((0, l), name="k")
+    C = te.compute(
+        (n, m), lambda i, j: te.sum(A[i, k].astype("float32") * B[k, j].astype("float32"), axis=k)
+    )
     s = te.create_schedule(C.op)
     y, x = s[C].op.axis
     k = s[C].op.reduce_axis[0]
@@ -47,7 +50,7 @@ def tensor_core_matmul(warp_tile_m=16, m=64, n=32, l=96):
     tile_k = 16
     vthread = 1
 
-    yo, ty = s[C].split(y, tile_y*vthread)
+    yo, ty = s[C].split(y, tile_y * vthread)
     vy, ty = s[C].split(ty, tile_y)
     ty, yi = s[C].split(ty, TY)
 
@@ -69,8 +72,8 @@ def tensor_core_matmul(warp_tile_m=16, m=64, n=32, l=96):
     s[CL].reorder(ko, kl, ki, yo, xo)
 
     s[AA].compute_at(s[CL], ko)
-    xo, xi = s[AA].split(s[AA].op.axis[1], factor=bx*v)
-    tz, tx = s[AA].split(xi, factor=(WX//TX)*v)
+    xo, xi = s[AA].split(s[AA].op.axis[1], factor=bx * v)
+    tz, tx = s[AA].split(xi, factor=(WX // TX) * v)
     tx, vec = s[AA].split(tx, factor=v)
     fused = s[AA].fuse(s[AA].op.axis[0], xo)
     _, ty = s[AA].split(fused, factor=by)
@@ -80,8 +83,8 @@ def tensor_core_matmul(warp_tile_m=16, m=64, n=32, l=96):
     s[AA].vectorize(vec)
 
     s[BB].compute_at(s[CL], ko)
-    xo, xi = s[BB].split(s[BB].op.axis[1], factor=bx*v)
-    tz, tx = s[BB].split(xi, factor=(WX//TX)*v)
+    xo, xi = s[BB].split(s[BB].op.axis[1], factor=bx * v)
+    tz, tx = s[BB].split(xi, factor=(WX // TX) * v)
     tx, vec = s[BB].split(tx, factor=v)
     fused = s[BB].fuse(s[BB].op.axis[0], xo)
     _, ty = s[BB].split(fused, factor=by)
@@ -93,9 +96,9 @@ def tensor_core_matmul(warp_tile_m=16, m=64, n=32, l=96):
     s[AL].compute_at(s[CL], kl)
     s[BL].compute_at(s[CL], kl)
 
-    s[CL].pragma(ko, 'tensor_core')
+    s[CL].pragma(ko, "tensor_core")
 
-    func = tvm.build(s, [A, B, C], 'cuda')
+    func = tvm.build(s, [A, B, C], "cuda")
 
     ctx = tvm.gpu(0)
     a_np = np.random.uniform(size=(n, l)).astype(A.dtype)
@@ -106,16 +109,19 @@ def tensor_core_matmul(warp_tile_m=16, m=64, n=32, l=96):
     c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
     func(a, b, c)
     evaluator = func.time_evaluator(func.entry_name, ctx, number=3)
-    print('gemm m=%d n=%d k=%d: %f ms' % (m, n, l, evaluator(a, b, c).mean * 1e3))
+    print("gemm m=%d n=%d k=%d: %f ms" % (m, n, l, evaluator(a, b, c).mean * 1e3))
 
     c_np = np.dot(a_np, b_np)
     np.testing.assert_allclose(c_np, c.asnumpy(), rtol=1e-3)
 
+
 def tensor_core_batch_matmul(warp_tile_m=16, m=64, n=32, l=96, batch=2):
-    A = te.placeholder((batch, n, l), name='A', dtype='float16')
-    B = te.placeholder((batch, l, m), name='B', dtype='float16')
-    k = te.reduce_axis((0, l), name='k')
-    C = te.compute((batch, n, m), lambda b, i, j: te.sum((A[b, i, k] * B[b, k, j]).astype('float32'), axis=k))
+    A = te.placeholder((batch, n, l), name="A", dtype="float16")
+    B = te.placeholder((batch, l, m), name="B", dtype="float16")
+    k = te.reduce_axis((0, l), name="k")
+    C = te.compute(
+        (batch, n, m), lambda b, i, j: te.sum((A[b, i, k] * B[b, k, j]).astype("float32"), axis=k)
+    )
     s = te.create_schedule(C.op)
     z, y, x = s[C].op.axis
     k = s[C].op.reduce_axis[0]
@@ -138,7 +144,7 @@ def tensor_core_batch_matmul(warp_tile_m=16, m=64, n=32, l=96, batch=2):
     tile_k = 16
     vthread = 1
 
-    yo, ty = s[C].split(y, tile_y*vthread)
+    yo, ty = s[C].split(y, tile_y * vthread)
     vy, ty = s[C].split(ty, tile_y)
     ty, yi = s[C].split(ty, TY)
 
@@ -161,8 +167,8 @@ def tensor_core_batch_matmul(warp_tile_m=16, m=64, n=32, l=96, batch=2):
     s[CL].reorder(ko, kl, ki, zo, yo, xo)
 
     s[AA].compute_at(s[CL], ko)
-    xo, xi = s[AA].split(s[AA].op.axis[2], factor=bx*v)
-    tz, tx = s[AA].split(xi, factor=(WX//TX)*v)
+    xo, xi = s[AA].split(s[AA].op.axis[2], factor=bx * v)
+    tz, tx = s[AA].split(xi, factor=(WX // TX) * v)
     tx, vec = s[AA].split(tx, factor=v)
     fused = s[AA].fuse(s[AA].op.axis[1], xo)
     _, ty = s[AA].split(fused, factor=by)
@@ -172,8 +178,8 @@ def tensor_core_batch_matmul(warp_tile_m=16, m=64, n=32, l=96, batch=2):
     s[AA].vectorize(vec)
 
     s[BB].compute_at(s[CL], ko)
-    xo, xi = s[BB].split(s[BB].op.axis[2], factor=bx*v)
-    tz, tx = s[BB].split(xi, factor=(WX//TX)*v)
+    xo, xi = s[BB].split(s[BB].op.axis[2], factor=bx * v)
+    tz, tx = s[BB].split(xi, factor=(WX // TX) * v)
     tx, vec = s[BB].split(tx, factor=v)
     fused = s[BB].fuse(s[BB].op.axis[1], xo)
     _, ty = s[BB].split(fused, factor=by)
@@ -185,9 +191,9 @@ def tensor_core_batch_matmul(warp_tile_m=16, m=64, n=32, l=96, batch=2):
     s[AL].compute_at(s[CL], kl)
     s[BL].compute_at(s[CL], kl)
 
-    s[CL].pragma(ko, 'tensor_core')
+    s[CL].pragma(ko, "tensor_core")
 
-    func = tvm.build(s, [A, B, C], 'cuda')
+    func = tvm.build(s, [A, B, C], "cuda")
 
     ctx = tvm.gpu(0)
     a_np = np.random.uniform(size=(batch, n, l)).astype(A.dtype)
@@ -198,22 +204,28 @@ def tensor_core_batch_matmul(warp_tile_m=16, m=64, n=32, l=96, batch=2):
     c = tvm.nd.array(np.zeros((batch, n, m), dtype=C.dtype), ctx)
     func(a, b, c)
     evaluator = func.time_evaluator(func.entry_name, ctx, number=3)
-    print('batch gemm m=%d n=%d k=%d batch=%d: %f ms' % (m, n, l, batch, evaluator(a, b, c).mean * 1e3))
+    print(
+        "batch gemm m=%d n=%d k=%d batch=%d: %f ms"
+        % (m, n, l, batch, evaluator(a, b, c).mean * 1e3)
+    )
 
     for bs in range(batch):
-      c_np[bs, :, :] = np.dot(a_np[bs, :, :], b_np[bs, :, :])
+        c_np[bs, :, :] = np.dot(a_np[bs, :, :], b_np[bs, :, :])
     np.testing.assert_allclose(c_np, c.asnumpy(), rtol=1e-3)
 
+
 @tvm.testing.requires_tensorcore
 def test_tensor_core_matmul():
-    tensor_core_matmul(16) #test with warp_tile 16x16x16
-    tensor_core_matmul(8) #test with warp_tile 8x32x16
-    tensor_core_matmul(32) #test with warp_tile 32x8x16
+    tensor_core_matmul(16)  # test with warp_tile 16x16x16
+    tensor_core_matmul(8)  # test with warp_tile 8x32x16
+    tensor_core_matmul(32)  # test with warp_tile 32x8x16
+
 
 @tvm.testing.requires_tensorcore
 def test_tensor_core_batch_matmul():
     tensor_core_batch_matmul()
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     test_tensor_core_matmul()
     test_tensor_core_batch_matmul()
index 8b70c82..01da1a1 100644
@@ -29,19 +29,34 @@ def intrin_wmma_load_matrix(shape, scope):
         row, col = n, l
     elif scope == "wmma.matrix_b":
         row, col = l, m
-    A = te.placeholder((row, col), name='A', dtype='float16')
-    BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=row * col)
-    C = te.compute((row, col), lambda i, j: A[i, j], name='C')
-    BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=row * col)
+    A = te.placeholder((row, col), name="A", dtype="float16")
+    BA = tvm.tir.decl_buffer(
+        A.shape, A.dtype, scope="shared", data_alignment=32, offset_factor=row * col
+    )
+    C = te.compute((row, col), lambda i, j: A[i, j], name="C")
+    BC = tvm.tir.decl_buffer(
+        C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=row * col
+    )
 
     def intrin_func(ins, outs):
         ib = tvm.tir.ir_builder.create()
 
         BA = ins[0]
         BC = outs[0]
-        ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync',
-                                BC.data, n, m, l, BC.elem_offset // (row * col),
-                                BA.access_ptr('r'), col, 'row_major'))
+        ib.emit(
+            tvm.tir.call_intrin(
+                "handle",
+                "tir.tvm_load_matrix_sync",
+                BC.data,
+                n,
+                m,
+                l,
+                BC.elem_offset // (row * col),
+                BA.access_ptr("r"),
+                col,
+                "row_major",
+            )
+        )
         return ib.get()
 
     return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
@@ -49,33 +64,65 @@ def intrin_wmma_load_matrix(shape, scope):
 
 def intrin_wmma_gemm(shape):
     n, m, l = shape
-    A = te.placeholder((n, l), name='A', dtype='float16')
-    B = te.placeholder((l, m), name='B', dtype='float16')
+    A = te.placeholder((n, l), name="A", dtype="float16")
+    B = te.placeholder((l, m), name="B", dtype="float16")
     k = te.reduce_axis((0, l), name="k")
-    C = te.compute((n, m),
-                    lambda ii, jj:
-                    te.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k),
-                    name='C')
-    BA = tvm.tir.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=n * l)
-    BB = tvm.tir.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=l * m)
-    BC = tvm.tir.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=n * m)
+    C = te.compute(
+        (n, m),
+        lambda ii, jj: te.sum(A[ii, k].astype("float") * B[k, jj].astype("float"), axis=k),
+        name="C",
+    )
+    BA = tvm.tir.decl_buffer(
+        A.shape, A.dtype, name="BA", scope="wmma.matrix_a", data_alignment=32, offset_factor=n * l
+    )
+    BB = tvm.tir.decl_buffer(
+        B.shape, B.dtype, name="BB", scope="wmma.matrix_b", data_alignment=32, offset_factor=l * m
+    )
+    BC = tvm.tir.decl_buffer(
+        C.shape,
+        C.dtype,
+        name="BC",
+        scope="wmma.accumulator",
+        data_alignment=32,
+        offset_factor=n * m,
+    )
 
     def intrin_func(ins, outs):
         BA, BB = ins
-        BC, = outs
+        (BC,) = outs
 
         def init():
             ib = tvm.tir.ir_builder.create()
-            ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_fill_fragment', BC.data, n, m, l, BC.elem_offset // (n * m), 0.0))
+            ib.emit(
+                tvm.tir.call_intrin(
+                    "handle",
+                    "tir.tvm_fill_fragment",
+                    BC.data,
+                    n,
+                    m,
+                    l,
+                    BC.elem_offset // (n * m),
+                    0.0,
+                )
+            )
             return ib.get()
 
         def update():
             ib = tvm.tir.ir_builder.create()
-            ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_mma_sync',
-                                    BC.data, BC.elem_offset // (n * m),
-                                    BA.data, BA.elem_offset // (n * l),
-                                    BB.data, BB.elem_offset // (l * m),
-                                    BC.data, BC.elem_offset // (n * m)))
+            ib.emit(
+                tvm.tir.call_intrin(
+                    "handle",
+                    "tir.tvm_mma_sync",
+                    BC.data,
+                    BC.elem_offset // (n * m),
+                    BA.data,
+                    BA.elem_offset // (n * l),
+                    BB.data,
+                    BB.elem_offset // (l * m),
+                    BC.data,
+                    BC.elem_offset // (n * m),
+                )
+            )
             return ib.get()
 
         return update(), init(), update()
@@ -85,19 +132,34 @@ def intrin_wmma_gemm(shape):
 
 def intrin_wmma_store_matrix(shape):
     n, m, l = shape
-    A = te.placeholder((n, m), name='A', dtype='float32')
-    BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=n * m)
-    C = te.compute((n, m), lambda i, j: A[i, j], name='C')
-    BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=n * m)
+    A = te.placeholder((n, m), name="A", dtype="float32")
+    BA = tvm.tir.decl_buffer(
+        A.shape, A.dtype, scope="wmma.accumulator", data_alignment=32, offset_factor=n * m
+    )
+    C = te.compute((n, m), lambda i, j: A[i, j], name="C")
+    BC = tvm.tir.decl_buffer(
+        C.shape, C.dtype, scope="global", data_alignment=32, offset_factor=n * m
+    )
 
     def intrin_func(ins, outs):
         ib = tvm.tir.ir_builder.create()
 
         BA = ins[0]
         BC = outs[0]
-        ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_store_matrix_sync',
-                                BA.data, n, m, l, BA.elem_offset // (n * m),
-                                BC.access_ptr('w'), m, 'row_major'))
+        ib.emit(
+            tvm.tir.call_intrin(
+                "handle",
+                "tir.tvm_store_matrix_sync",
+                BA.data,
+                n,
+                m,
+                l,
+                BA.elem_offset // (n * m),
+                BC.access_ptr("w"),
+                m,
+                "row_major",
+            )
+        )
         return ib.get()
 
     return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
@@ -108,18 +170,21 @@ def test_tensor_core_batch_matmal():
     batch_size = 4
     n = 512
     m, l = n, n
-    assert (n % 32 == 0)
-    assert (m % 8 == 0)
-    assert (l % 16 == 0)
+    assert n % 32 == 0
+    assert m % 8 == 0
+    assert l % 16 == 0
     nn, mm, ll = n // 32, m // 8, l // 16
-    A = te.placeholder((batch_size, nn, ll, 32, 16), name='A', dtype='float16')
-    B = te.placeholder((batch_size, ll, mm, 16, 8), name='B', dtype='float16')
-    k1 = te.reduce_axis((0, ll), name='k1')
-    k2 = te.reduce_axis((0, 16), name='k2')
-    C = te.compute((batch_size, nn, mm, 32, 8),
-                    lambda b, i, j, ii, jj:
-                    te.sum(A[b, i, k1, ii, k2].astype('float') * B[b, k1, j, k2, jj].astype('float'), axis=[k1, k2]),
-                    name='Fragment_C')
+    A = te.placeholder((batch_size, nn, ll, 32, 16), name="A", dtype="float16")
+    B = te.placeholder((batch_size, ll, mm, 16, 8), name="B", dtype="float16")
+    k1 = te.reduce_axis((0, ll), name="k1")
+    k2 = te.reduce_axis((0, 16), name="k2")
+    C = te.compute(
+        (batch_size, nn, mm, 32, 8),
+        lambda b, i, j, ii, jj: te.sum(
+            A[b, i, k1, ii, k2].astype("float") * B[b, k1, j, k2, jj].astype("float"), axis=[k1, k2]
+        ),
+        name="Fragment_C",
+    )
     s = te.create_schedule(C.op)
 
     warp_size = 32
@@ -130,18 +195,18 @@ def test_tensor_core_batch_matmal():
     warp_col_tiles = 2
     chunk = 4
 
-    block_x = te.thread_axis('blockIdx.x')
-    block_y = te.thread_axis('blockIdx.y')
-    block_z = te.thread_axis('blockIdx.z')
-    thread_x = te.thread_axis('threadIdx.x')
-    thread_y = te.thread_axis('threadIdx.y')
-    thread_z = te.thread_axis('threadIdx.z')
+    block_x = te.thread_axis("blockIdx.x")
+    block_y = te.thread_axis("blockIdx.y")
+    block_z = te.thread_axis("blockIdx.z")
+    thread_x = te.thread_axis("threadIdx.x")
+    thread_y = te.thread_axis("threadIdx.y")
+    thread_z = te.thread_axis("threadIdx.z")
 
-    AS = s.cache_read(A, 'shared', [C])
-    BS = s.cache_read(B, 'shared', [C])
-    AF = s.cache_read(AS, 'wmma.matrix_a', [C])
-    BF = s.cache_read(BS, 'wmma.matrix_b', [C])
-    CF = s.cache_write(C, 'wmma.accumulator')
+    AS = s.cache_read(A, "shared", [C])
+    BS = s.cache_read(B, "shared", [C])
+    AF = s.cache_read(AS, "wmma.matrix_a", [C])
+    BF = s.cache_read(BS, "wmma.matrix_b", [C])
+    CF = s.cache_write(C, "wmma.accumulator")
 
     b, i, j, kernel_i, kernel_j = s[C].op.axis
     i, ii = s[C].split(i, factor=warp_row_tiles)
@@ -184,12 +249,12 @@ def test_tensor_core_batch_matmal():
     s[BS].bind(ty, thread_z)
     s[BS].bind(to, thread_x)
 
-    s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix((32, 8, 16), 'wmma.matrix_a'))
-    s[BF].tensorize(BF.op.axis[-2], intrin_wmma_load_matrix((32, 8, 16), 'wmma.matrix_b'))
+    s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix((32, 8, 16), "wmma.matrix_a"))
+    s[BF].tensorize(BF.op.axis[-2], intrin_wmma_load_matrix((32, 8, 16), "wmma.matrix_b"))
     s[C].tensorize(kernel_i, intrin_wmma_store_matrix((32, 8, 16)))
     s[CF].tensorize(_i, intrin_wmma_gemm((32, 8, 16)))
 
-    func = tvm.build(s, [A, B, C], 'cuda')
+    func = tvm.build(s, [A, B, C], "cuda")
 
     ctx = tvm.gpu(0)
     a_np = np.random.uniform(size=(batch_size, nn, ll, 32, 16)).astype(A.dtype)
@@ -199,15 +264,16 @@ def test_tensor_core_batch_matmal():
     c = tvm.nd.array(np.zeros((batch_size, nn, mm, 32, 8), dtype=C.dtype), ctx)
     func(a, b, c)
     evaluator = func.time_evaluator(func.entry_name, ctx, number=3)
-    print('gemm with tensor core: %f ms' % (evaluator(a, b, c).mean * 1e3))
+    print("gemm with tensor core: %f ms" % (evaluator(a, b, c).mean * 1e3))
 
     if VERIFY:
         func(a, b, c)
         a_np = a_np.transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n)
         b_np = b_np.transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n)
         c_np = c.asnumpy().transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n)
-        np.testing.assert_allclose(c_np, np.matmul(a_np.astype(C.dtype), b_np.astype(C.dtype)), rtol=1e-4, atol=1e-4)
-
+        np.testing.assert_allclose(
+            c_np, np.matmul(a_np.astype(C.dtype), b_np.astype(C.dtype)), rtol=1e-4, atol=1e-4
+        )
 
 
 @tvm.testing.requires_tensorcore
@@ -234,70 +300,87 @@ def test_tensor_core_batch_conv():
     chunk = 2
 
     # Input feature map: (N, H, W, IC, n, ic)
-    data_shape = (batch_size // block_size,
-                  height,
-                  width,
-                  in_channels // block_size,
-                  block_size,
-                  block_size)
+    data_shape = (
+        batch_size // block_size,
+        height,
+        width,
+        in_channels // block_size,
+        block_size,
+        block_size,
+    )
     # Kernel: (H, W, IC, OC, ic, oc)
-    kernel_shape = (kernel_h,
-                    kernel_w,
-                    in_channels // block_size,
-                    out_channels // block_size,
-                    block_size,
-                    block_size)
+    kernel_shape = (
+        kernel_h,
+        kernel_w,
+        in_channels // block_size,
+        out_channels // block_size,
+        block_size,
+        block_size,
+    )
 
     # Output feature map: (N, H, W, OC, n, oc)
-    output_shape = (batch_size // block_size,
-                    height,
-                    width,
-                    out_channels // block_size,
-                    block_size,
-                    block_size)
-
-    assert (batch_size % block_size == 0)
-    assert (in_channels % block_size == 0)
-    assert (out_channels % block_size == 0)
-
-    kh = te.reduce_axis((0, kernel_h), name='kh')
-    kw = te.reduce_axis((0, kernel_w), name='kw')
-    ic = te.reduce_axis((0, in_channels // block_size), name='ic')
-    ii = te.reduce_axis((0, block_size), name='ii')
+    output_shape = (
+        batch_size // block_size,
+        height,
+        width,
+        out_channels // block_size,
+        block_size,
+        block_size,
+    )
+
+    assert batch_size % block_size == 0
+    assert in_channels % block_size == 0
+    assert out_channels % block_size == 0
+
+    kh = te.reduce_axis((0, kernel_h), name="kh")
+    kw = te.reduce_axis((0, kernel_w), name="kw")
+    ic = te.reduce_axis((0, in_channels // block_size), name="ic")
+    ii = te.reduce_axis((0, block_size), name="ii")
 
     # Algorithm
-    A = te.placeholder(data_shape, name='A', dtype="float16")
-    W = te.placeholder(kernel_shape, name='W', dtype="float16")
+    A = te.placeholder(data_shape, name="A", dtype="float16")
+    W = te.placeholder(kernel_shape, name="W", dtype="float16")
     Apad = te.compute(
-        (batch_size // block_size, height + 2 * pad_h, width + 2 * pad_w, in_channels // block_size, block_size,
-         block_size),
+        (
+            batch_size // block_size,
+            height + 2 * pad_h,
+            width + 2 * pad_w,
+            in_channels // block_size,
+            block_size,
+            block_size,
+        ),
         lambda n, h, w, i, nn, ii: tvm.tir.if_then_else(
-            tvm.tir.all(h >= pad_h, h - pad_h < height,
-                    w >= pad_w, w - pad_w < width),
-            A[n, h - pad_h, w - pad_w, i, nn, ii], tvm.tir.const(0., "float16")),
-        name='Apad')
-    Conv = te.compute(output_shape,
-                       lambda n, h, w, o, nn, oo: te.sum(
-                           Apad[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("float32") *
-                           W[kh, kw, ic, o, ii, oo].astype("float32"),
-                           axis=[ic, kh, kw, ii]),
-                       name="Conv")
+            tvm.tir.all(h >= pad_h, h - pad_h < height, w >= pad_w, w - pad_w < width),
+            A[n, h - pad_h, w - pad_w, i, nn, ii],
+            tvm.tir.const(0.0, "float16"),
+        ),
+        name="Apad",
+    )
+    Conv = te.compute(
+        output_shape,
+        lambda n, h, w, o, nn, oo: te.sum(
+            Apad[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("float32")
+            * W[kh, kw, ic, o, ii, oo].astype("float32"),
+            axis=[ic, kh, kw, ii],
+        ),
+        name="Conv",
+    )
 
     s = te.create_schedule(Conv.op)
     s[Apad].compute_inline()
 
-    AS = s.cache_read(Apad, 'shared', [Conv])
-    WS = s.cache_read(W, 'shared', [Conv])
-    AF = s.cache_read(AS, 'wmma.matrix_a', [Conv])
-    WF = s.cache_read(WS, 'wmma.matrix_b', [Conv])
-    ConvF = s.cache_write(Conv, 'wmma.accumulator')
+    AS = s.cache_read(Apad, "shared", [Conv])
+    WS = s.cache_read(W, "shared", [Conv])
+    AF = s.cache_read(AS, "wmma.matrix_a", [Conv])
+    WF = s.cache_read(WS, "wmma.matrix_b", [Conv])
+    ConvF = s.cache_write(Conv, "wmma.accumulator")
 
-    block_x = te.thread_axis('blockIdx.x')
-    block_y = te.thread_axis('blockIdx.y')
-    block_z = te.thread_axis('blockIdx.z')
-    thread_x = te.thread_axis('threadIdx.x')
-    thread_y = te.thread_axis('threadIdx.y')
-    thread_z = te.thread_axis('threadIdx.z')
+    block_x = te.thread_axis("blockIdx.x")
+    block_y = te.thread_axis("blockIdx.y")
+    block_z = te.thread_axis("blockIdx.z")
+    thread_x = te.thread_axis("threadIdx.x")
+    thread_y = te.thread_axis("threadIdx.y")
+    thread_z = te.thread_axis("threadIdx.z")
 
     nc, hc, wc, oc, nnc, ooc = Conv.op.axis
     block_k = s[Conv].fuse(hc, wc)
@@ -342,12 +425,12 @@ def test_tensor_core_batch_conv():
     s[WS].bind(to, thread_x)
     s[WS].vectorize(ti)
 
-    s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix((16, 16, 16), 'wmma.matrix_a'))
-    s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix((16, 16, 16), 'wmma.matrix_b'))
+    s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix((16, 16, 16), "wmma.matrix_a"))
+    s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix((16, 16, 16), "wmma.matrix_b"))
     s[Conv].tensorize(nnc, intrin_wmma_store_matrix((16, 16, 16)))
     s[ConvF].tensorize(nnf, intrin_wmma_gemm((16, 16, 16)))
 
-    func = tvm.build(s, [A, W, Conv], 'cuda')
+    func = tvm.build(s, [A, W, Conv], "cuda")
 
     ctx = tvm.gpu(0)
     a_np = np.random.uniform(size=data_shape).astype(A.dtype)
@@ -356,20 +439,25 @@ def test_tensor_core_batch_conv():
     w = tvm.nd.array(w_np, ctx)
     c = tvm.nd.array(np.zeros(output_shape, dtype=Conv.dtype), ctx)
     evaluator = func.time_evaluator(func.entry_name, ctx, number=3)
-    print('conv2d with tensor core: %f ms' % (evaluator(a, w, c).mean * 1e3))
+    print("conv2d with tensor core: %f ms" % (evaluator(a, w, c).mean * 1e3))
 
     if VERIFY:
         func(a, w, c)
         a_np = a_np.transpose(0, 4, 1, 2, 3, 5).reshape(batch_size, height, width, in_channels)
-        w_np = w_np.transpose(0, 1, 2, 4, 3, 5).reshape(kernel_h, kernel_w, in_channels, out_channels)
-        c_np = c.asnumpy().transpose((0, 4, 1, 2, 3, 5)).reshape(batch_size, height, width, out_channels)
-        c_std = conv2d_nhwc_python(a_np.astype(Conv.dtype),
-                                   w_np.astype(Conv.dtype),
-                                   (stride_h, stride_w),
-                                   (pad_h, pad_w)).astype(Conv.dtype)
+        w_np = w_np.transpose(0, 1, 2, 4, 3, 5).reshape(
+            kernel_h, kernel_w, in_channels, out_channels
+        )
+        c_np = (
+            c.asnumpy()
+            .transpose((0, 4, 1, 2, 3, 5))
+            .reshape(batch_size, height, width, out_channels)
+        )
+        c_std = conv2d_nhwc_python(
+            a_np.astype(Conv.dtype), w_np.astype(Conv.dtype), (stride_h, stride_w), (pad_h, pad_w)
+        ).astype(Conv.dtype)
         np.testing.assert_allclose(c_np, c_std, rtol=1e-4, atol=1e-4)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_tensor_core_batch_matmal()
     test_tensor_core_batch_conv()
index 5152235..83a5d30 100644
 import tvm
 from tvm import te
 
+
 def intrin_vadd(n):
-    x = te.placeholder((n,), name='vx')
-    y = te.placeholder((n,), name='vy')
-    z = te.compute(x.shape, lambda i: x[i] + y[i], name='z')
+    x = te.placeholder((n,), name="vx")
+    y = te.placeholder((n,), name="vy")
+    z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")
+
     def intrin_func(ins, outs):
         xx, yy = ins
         zz = outs[0]
         return tvm.tir.call_packed("vadd", xx, yy, zz)
+
     buffer_params = {"offset_factor": 16}
     return te.decl_tensor_intrin(z.op, intrin_func, default_buffer_params=buffer_params)
 
+
 def intrin_gemv(m, n):
-    w = te.placeholder((m, n), name='w')
-    x = te.placeholder((n,), name='x')
-    k = te.reduce_axis((0, n), name='k')
-    z = te.compute((m,), lambda i:
-                    te.sum(w[i, k] * x[k], axis=k), name='z')
-    Wb = tvm.tir.decl_buffer(w.shape, w.dtype,
-                         name="W",
-                         offset_factor=16,
-                         strides=[te.var('ldw'), 1])
+    w = te.placeholder((m, n), name="w")
+    x = te.placeholder((n,), name="x")
+    k = te.reduce_axis((0, n), name="k")
+    z = te.compute((m,), lambda i: te.sum(w[i, k] * x[k], axis=k), name="z")
+    Wb = tvm.tir.decl_buffer(
+        w.shape, w.dtype, name="W", offset_factor=16, strides=[te.var("ldw"), 1]
+    )
+
     def intrin_func(ins, outs):
         ww, xx = ins
         zz = outs[0]
         ww_ptr = ww.access_ptr("r")
         xx_ptr = xx.access_ptr("r")
         zz_ptr = zz.access_ptr("w")
-        body = tvm.tir.call_packed(
-            "gemv", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
-        reset = tvm.tir.call_packed(
-            "fill_zero", zz_ptr, n)
-        update = tvm.tir.call_packed(
-            "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
+        body = tvm.tir.call_packed("gemv", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
+        reset = tvm.tir.call_packed("fill_zero", zz_ptr, n)
+        update = tvm.tir.call_packed("gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
         return body, reset, update
 
     buffer_params = {"offset_factor": 16, "data_alignment": 16}
     return te.decl_tensor_intrin(
-        z.op, intrin_func, binds={w: Wb}, default_buffer_params=buffer_params)
+        z.op, intrin_func, binds={w: Wb}, default_buffer_params=buffer_params
+    )
+
 
 def intrin_gemv_no_reset(m, n):
-    w = te.placeholder((m, n), name='w')
-    x = te.placeholder((n,), name='x')
-    k = te.reduce_axis((0, n), name='k')
-    z = te.compute((m,), lambda i:
-                    te.sum(w[i, k] * x[k], axis=k), name='z')
-    Wb = tvm.tir.decl_buffer(w.shape, w.dtype,
-                         name="W",
-                         offset_factor=16,
-                         strides=[te.var('ldw'), 1])
+    w = te.placeholder((m, n), name="w")
+    x = te.placeholder((n,), name="x")
+    k = te.reduce_axis((0, n), name="k")
+    z = te.compute((m,), lambda i: te.sum(w[i, k] * x[k], axis=k), name="z")
+    Wb = tvm.tir.decl_buffer(
+        w.shape, w.dtype, name="W", offset_factor=16, strides=[te.var("ldw"), 1]
+    )
+
     def intrin_func(ins, outs):
         ww, xx = ins
         zz = outs[0]
         ww_ptr = ww.access_ptr("r")
         xx_ptr = xx.access_ptr("r")
         zz_ptr = zz.access_ptr("w")
-        body = tvm.tir.call_packed(
-            "gemv", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
-        update = tvm.tir.call_packed(
-            "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
+        body = tvm.tir.call_packed("gemv", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
+        update = tvm.tir.call_packed("gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
         return body, None, update
 
-
     buffer_params = {"offset_factor": 16, "data_alignment": 16}
     return te.decl_tensor_intrin(
-        z.op, intrin_func, binds={w: Wb}, default_buffer_params=buffer_params)
+        z.op, intrin_func, binds={w: Wb}, default_buffer_params=buffer_params
+    )
 
 
 def test_tensorize_vadd():
     m = 128
-    x = te.placeholder((m,), name='x')
-    y = te.placeholder((m,), name='y')
-    z = te.compute(x.shape, lambda i: x[i] + y[i], name='z')
+    x = te.placeholder((m,), name="x")
+    y = te.placeholder((m,), name="y")
+    z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")
 
     def check(factor):
         s = te.create_schedule(z.op)
@@ -105,9 +104,7 @@ def test_tensorize_vadd():
         fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
         body = fmatch(s[z], out_dom, in_dom, vadd)
         ana = tvm.arith.Analyzer()
-        assert tvm.ir.structural_equal(
-            ana.simplify(body[0]),
-            ana.simplify(vadd.op.body[0]))
+        assert tvm.ir.structural_equal(ana.simplify(body[0]), ana.simplify(vadd.op.body[0]))
         stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
         tvm.lower(s, [x, y, z])
 
@@ -118,11 +115,10 @@ def test_tensorize_matmul():
     n = 1024
     m = n
     l = n
-    A = te.placeholder((n, l), name='A')
-    B = te.placeholder((m, l), name='B')
-    k = te.reduce_axis((0, l), name='k')
-    C = te.compute((n, m), lambda i, j:
-                    te.sum(B[j, k] * A[i, k], axis=k), name='C')
+    A = te.placeholder((n, l), name="A")
+    B = te.placeholder((m, l), name="B")
+    k = te.reduce_axis((0, l), name="k")
+    C = te.compute((n, m), lambda i, j: te.sum(B[j, k] * A[i, k], axis=k), name="C")
 
     def check(factor):
         s = te.create_schedule(C.op)
@@ -141,13 +137,10 @@ def test_tensorize_matmul():
         body = fmatch(s[C], out_dom, in_dom, gemv)
         ana = tvm.arith.Analyzer()
 
-        assert tvm.ir.structural_equal(
-            ana.simplify(body[0]),
-            ana.simplify(gemv.op.body[0]))
+        assert tvm.ir.structural_equal(ana.simplify(body[0]), ana.simplify(gemv.op.body[0]))
         stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
         tvm.lower(s, [A, B, C])
 
-
     def check_rfactor(factor, rfactor):
         s = te.create_schedule(C.op)
         x, y = C.op.axis
@@ -167,9 +160,7 @@ def test_tensorize_matmul():
         fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
         body = fmatch(s[C], out_dom, in_dom, gemv)
         ana = tvm.arith.Analyzer()
-        assert tvm.ir.structural_equal(
-            ana.simplify(body[0]),
-            ana.simplify(gemv.op.body[0]))
+        assert tvm.ir.structural_equal(ana.simplify(body[0]), ana.simplify(gemv.op.body[0]))
         stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
         tvm.lower(s, [A, B, C])
 
@@ -192,9 +183,7 @@ def test_tensorize_matmul():
         fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
         body = fmatch(s[C], out_dom, in_dom, gemv)
         ana = tvm.arith.Analyzer()
-        assert tvm.ir.structural_equal(
-            ana.simplify(body[0]),
-            ana.simplify(gemv.op.body[0]))
+        assert tvm.ir.structural_equal(ana.simplify(body[0]), ana.simplify(gemv.op.body[0]))
         stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
         tvm.lower(s, [A, B, C])
 
@@ -218,9 +207,7 @@ def test_tensorize_matmul():
         fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
         body = fmatch(s[C], out_dom, in_dom, gemv)
         ana = tvm.arith.Analyzer()
-        assert tvm.ir.structural_equal(
-            ana.simplify(body[0]),
-            ana.simplify(gemv.op.body[0]))
+        assert tvm.ir.structural_equal(ana.simplify(body[0]), ana.simplify(gemv.op.body[0]))
         stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
         tvm.lower(s, [A, B, C])
 
@@ -229,6 +216,7 @@ def test_tensorize_matmul():
     check_rfactor_no_reset(16, 16)
     check_rfactor_no_reset_multi_reduction(16, 16)
 
+
 # This tests whether algorithm and intrinsics expressions are simplified
 # as much as possible first and then checked for equality. See Issue #696
 def test_tensorize_op():
@@ -238,29 +226,27 @@ def test_tensorize_op():
     def op_intrin():
         bh = 9
         bw = 9
-        x = te.placeholder((5, 5), name='A')
-        y = te.compute((bh, bw),
-                        lambda i, j: x[idxd(j,3) + idxm(i,3), idxm(j,3)+ idxd(i,3)])
+        x = te.placeholder((5, 5), name="A")
+        y = te.compute((bh, bw), lambda i, j: x[idxd(j, 3) + idxm(i, 3), idxm(j, 3) + idxd(i, 3)])
 
         def intrin_func(ins, outs):
-            xx, = ins
+            (xx,) = ins
             zz = outs[0]
             return tvm.tir.call_packed("op", xx, zz)
 
-        return te.decl_tensor_intrin(y.op, intrin_func, default_buffer_params={
-            "offset_factor": 2
-        })
+        return te.decl_tensor_intrin(y.op, intrin_func, default_buffer_params={"offset_factor": 2})
 
-    A = te.placeholder((5, 5), name='A')
-    B = te.compute((9,9), lambda i, j: A[idxd(j,3) + idxm(i,3), idxm(j,3) + idxd(i,3)])
+    A = te.placeholder((5, 5), name="A")
+    B = te.compute((9, 9), lambda i, j: A[idxd(j, 3) + idxm(i, 3), idxm(j, 3) + idxd(i, 3)])
     bt = op_intrin()
     s = te.create_schedule(B.op)
 
-    x,y = B.op.axis
+    x, y = B.op.axis
     s[B].tensorize(x, bt)
     s = s.normalize()
     tvm.lower(s, [A, B])
 
+
 # This test asserts that tensorize does not have any effect on
 # TensorComputeOp operations
 def test_tensorize_tensor_compute_op():
@@ -268,18 +254,24 @@ def test_tensorize_tensor_compute_op():
     # is a loop of another intrinsic called "vadd"
     def intrin_multivadd(n):
         n_a = te.var("n_a")
-        Ab = tvm.tir.decl_buffer((n, ), "float32", strides=[n_a])
+        Ab = tvm.tir.decl_buffer((n,), "float32", strides=[n_a])
 
         n_b = te.var("n_b")
-        Bb = tvm.tir.decl_buffer((n, ), "float32", strides=[n_b])
+        Bb = tvm.tir.decl_buffer((n,), "float32", strides=[n_b])
 
         n_c = te.var("n_c")
-        Cb = tvm.tir.decl_buffer((n, ), "float32", strides=[n_c])
-
-        z = te.compute((n,), lambda i: tvm.tir.call_extern("float32", 'vadd',
-                                                        Ab.access_ptr("w", offset=n_a*i),
-                                                        Bb.access_ptr("r", offset=n_b*i),
-                                                        Cb.access_ptr("r", offset=n_c*i)))
+        Cb = tvm.tir.decl_buffer((n,), "float32", strides=[n_c])
+
+        z = te.compute(
+            (n,),
+            lambda i: tvm.tir.call_extern(
+                "float32",
+                "vadd",
+                Ab.access_ptr("w", offset=n_a * i),
+                Bb.access_ptr("r", offset=n_b * i),
+                Cb.access_ptr("r", offset=n_c * i),
+            ),
+        )
 
         # replace the pattern with the multivadd call. I need to figure out
         # how to pass it the right parameters.
@@ -289,36 +281,42 @@ def test_tensorize_tensor_compute_op():
         return te.decl_tensor_intrin(z.op, intrin_func, name="multivadd")
 
     def intrin_vadd(n):
-        dtype = 'float32'
-        x = te.placeholder((n,), dtype=dtype, name='vx')
-        y = te.placeholder((n,), dtype=dtype, name='vy')
-        z = te.compute(x.shape, lambda i: x[i] + y[i], name='z')
+        dtype = "float32"
+        x = te.placeholder((n,), dtype=dtype, name="vx")
+        y = te.placeholder((n,), dtype=dtype, name="vy")
+        z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")
         s = te.create_schedule(z.op)
 
         def create_buffer(t):
-            return tvm.tir.decl_buffer(t.shape, t.dtype, name='W'+t.name, offset_factor=16)
+            return tvm.tir.decl_buffer(t.shape, t.dtype, name="W" + t.name, offset_factor=16)
 
         def intrin_func(ins, outs):
             ib = tvm.tir.ir_builder.create()
-            ib.emit(tvm.tir.call_extern("float32", 'vadd',
-                                    ins[0].access_ptr("r"), ins[1].access_ptr('r'),
-                                    outs[0].access_ptr('wr')))
+            ib.emit(
+                tvm.tir.call_extern(
+                    "float32",
+                    "vadd",
+                    ins[0].access_ptr("r"),
+                    ins[1].access_ptr("r"),
+                    outs[0].access_ptr("wr"),
+                )
+            )
             return ib.get()
-        return te.decl_tensor_intrin(z.op, intrin_func, binds={x: create_buffer(x),
-                                                                y: create_buffer(y),
-                                                                z: create_buffer(z)})
+
+        return te.decl_tensor_intrin(
+            z.op, intrin_func, binds={x: create_buffer(x), y: create_buffer(y), z: create_buffer(z)}
+        )
 
     # cache_read, cache_write
     M = 1024
     factor = 16
-    dtype = 'float32'
+    dtype = "float32"
 
-    A = te.placeholder((M//factor, factor), name="A", dtype=dtype)
-    B = te.placeholder((M//factor, factor), name="B", dtype=dtype)
+    A = te.placeholder((M // factor, factor), name="A", dtype=dtype)
+    B = te.placeholder((M // factor, factor), name="B", dtype=dtype)
 
     vadd = intrin_vadd(factor)
-    C = te.compute((M//factor, factor),
-                    lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name='C')
+    C = te.compute((M // factor, factor), lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name="C")
 
     s = te.create_schedule(C.op)
     multivadd = intrin_multivadd(64)
@@ -332,7 +330,6 @@ def test_tensorize_tensor_compute_op():
     assert stmt.body.body.loop_var.name == C.op.axis[0].var.name
 
 
-
 if __name__ == "__main__":
     test_tensorize_vadd()
     test_tensorize_matmul()
index 6cfc0b1..6e88a12 100644
@@ -19,6 +19,7 @@ import tvm
 from tvm import te
 from tvm import te
 
+
 @tvm.te.tag_scope(tag="conv")
 def compute_conv(data, weight):
     N, IC, H, W = data.shape
@@ -26,27 +27,34 @@ def compute_conv(data, weight):
     OH = H - KH + 1
     OW = W - KW + 1
 
-    ic = te.reduce_axis((0, IC), name='ic')
-    dh = te.reduce_axis((0, KH), name='dh')
-    dw = te.reduce_axis((0, KW), name='dw')
+    ic = te.reduce_axis((0, IC), name="ic")
+    dh = te.reduce_axis((0, KH), name="dh")
+    dw = te.reduce_axis((0, KW), name="dw")
+
+    return te.compute(
+        (N, OC, OH, OW),
+        lambda i, oc, h, w: te.sum(
+            data[i, ic, h + dh, w + dw] * weight[oc, ic, dh, dw], axis=[ic, dh, dw]
+        ),
+    )
 
-    return te.compute((N, OC, OH, OW), lambda i, oc, h, w: \
-        te.sum(data[i, ic, h+dh, w+dw] * weight[oc, ic, dh, dw],
-                axis=[ic, dh, dw]))
 
 def test_with():
-    n = te.size_var('n')
-    m = te.size_var('m')
-    l = te.size_var('l')
+    n = te.size_var("n")
+    m = te.size_var("m")
+    l = te.size_var("l")
 
-    A = te.placeholder((n, l), name='A')
-    B = te.placeholder((m, l), name='B')
+    A = te.placeholder((n, l), name="A")
+    B = te.placeholder((m, l), name="B")
     with tvm.te.tag_scope(tag="gemm"):
-        k = te.reduce_axis((0, l), name='k')
-        C = te.compute((n, m), lambda i, j: te.sum(A[i, k] * B[j, k], axis=k),
-                        attrs={"hello" : 1, "arr": [10, 12]})
+        k = te.reduce_axis((0, l), name="k")
+        C = te.compute(
+            (n, m),
+            lambda i, j: te.sum(A[i, k] * B[j, k], axis=k),
+            attrs={"hello": 1, "arr": [10, 12]},
+        )
 
-    assert C.op.tag == 'gemm'
+    assert C.op.tag == "gemm"
     assert "hello" in C.op.attrs
     assert "xx" not in C.op.attrs
     assert C.op.attrs["hello"].value == 1
@@ -58,31 +66,32 @@ def test_with():
 
 
 def test_decorator():
-    n = te.size_var('n')
-    c = te.size_var('c')
-    h = te.size_var('h')
-    w = te.size_var('w')
-    kh = te.size_var('kh')
-    kw = te.size_var('kw')
-
-    A = te.placeholder((n, c, h, w), name='A')
-    B = te.placeholder((c, c, kh, kw), name='B')
+    n = te.size_var("n")
+    c = te.size_var("c")
+    h = te.size_var("h")
+    w = te.size_var("w")
+    kh = te.size_var("kh")
+    kw = te.size_var("kw")
+
+    A = te.placeholder((n, c, h, w), name="A")
+    B = te.placeholder((c, c, kh, kw), name="B")
     C = compute_conv(A, B)
-    assert C.op.tag == 'conv'
+    assert C.op.tag == "conv"
     assert len(C.op.attrs) == 0
 
+
 def test_nested():
-    n = te.size_var('n')
-    c = te.size_var('c')
-    h = te.size_var('h')
-    w = te.size_var('w')
-    kh = te.size_var('kh')
-    kw = te.size_var('kw')
-
-    A = te.placeholder((n, c, h, w), name='A')
-    B = te.placeholder((c, c, kh, kw), name='B')
+    n = te.size_var("n")
+    c = te.size_var("c")
+    h = te.size_var("h")
+    w = te.size_var("w")
+    kh = te.size_var("kh")
+    kw = te.size_var("kw")
+
+    A = te.placeholder((n, c, h, w), name="A")
+    B = te.placeholder((c, c, kh, kw), name="B")
     try:
-        with te.tag_scope(tag='conv'):
+        with te.tag_scope(tag="conv"):
             C = compute_conv(A, B)
         assert False
     except ValueError:
index 3d22c0f..a3936fa 100644
@@ -19,55 +19,58 @@ import numpy as np
 from tvm import te
 from tvm.topi.nn.pooling import pool
 
+
 def test_tensor():
-    m = te.size_var('m')
-    n = te.size_var('n')
-    l = te.size_var('l')
-    A = te.placeholder((m, l), name='A')
-    B = te.placeholder((n, l), name='B')
+    m = te.size_var("m")
+    n = te.size_var("n")
+    l = te.size_var("l")
+    A = te.placeholder((m, l), name="A")
+    B = te.placeholder((n, l), name="B")
     T = te.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k])
     print(T)
     print(T.op.body)
-    assert(tuple(T.shape) == (m, n, l))
-    assert(isinstance(A.op, tvm.te.PlaceholderOp))
-    assert(A == A)
-    assert(T.op.output(0) == T)
-    assert(T.op.output(0).__hash__() == T.__hash__())
-    d = {T.op.output(0) : 1}
-    assert(d[T] == 1)
-    assert(T[0][0][0].astype('float16').dtype == 'float16')
+    assert tuple(T.shape) == (m, n, l)
+    assert isinstance(A.op, tvm.te.PlaceholderOp)
+    assert A == A
+    assert T.op.output(0) == T
+    assert T.op.output(0).__hash__() == T.__hash__()
+    d = {T.op.output(0): 1}
+    assert d[T] == 1
+    assert T[0][0][0].astype("float16").dtype == "float16"
 
 
 def test_rank_zero():
-    m = te.size_var('m')
-    A = te.placeholder((m,), name='A')
-    scale = te.placeholder((), name='s')
+    m = te.size_var("m")
+    A = te.placeholder((m,), name="A")
+    scale = te.placeholder((), name="s")
     k = te.reduce_axis((0, m), name="k")
-    T = te.compute((), lambda : te.sum(A[k] * scale(), axis=k))
+    T = te.compute((), lambda: te.sum(A[k] * scale(), axis=k))
     print(T)
     print(T.op.body)
-    assert(tuple(T.shape) == ())
+    assert tuple(T.shape) == ()
 
 
 def test_conv1d():
-    n = te.size_var('n')
-    A = te.placeholder((n+2), name='A')
+    n = te.size_var("n")
+    A = te.placeholder((n + 2), name="A")
+
     def computeB(ii):
         i = ii + 1
-        return A[i-1] + A[i] + A[i+1]
+        return A[i - 1] + A[i] + A[i + 1]
+
     B = te.compute(n, computeB)
 
 
 def test_tensor_slice():
-    n = te.size_var('n')
+    n = te.size_var("n")
     A = te.compute((n, n), lambda i, j: 1)
     B = te.compute((n,), lambda i: A[0][i] + A[0][i])
 
 
 def test_tensor_reduce_multi_axis():
-    m = te.size_var('m')
-    n = te.size_var('n')
-    A = te.placeholder((m, n), name='A')
+    m = te.size_var("m")
+    n = te.size_var("n")
+    A = te.placeholder((m, n), name="A")
     k1 = te.reduce_axis((0, n), "k")
     k2 = te.reduce_axis((0, m), "k")
     C = te.compute((1,), lambda _: te.sum(A[k1, k2], axis=(k1, k2)))
@@ -75,38 +78,41 @@ def test_tensor_reduce_multi_axis():
 
 
 def test_tensor_comm_reducer():
-    m = te.size_var('m')
-    n = te.size_var('n')
-    A = te.placeholder((m, n), name='A')
+    m = te.size_var("m")
+    n = te.size_var("n")
+    A = te.placeholder((m, n), name="A")
     k = te.reduce_axis((0, n), "k")
-    mysum = te.comm_reducer(lambda x, y: x+y, lambda t: tvm.tir.const(0, dtype=t))
+    mysum = te.comm_reducer(lambda x, y: x + y, lambda t: tvm.tir.const(0, dtype=t))
     C = te.compute((m,), lambda i: mysum(A[i, k], axis=k))
 
+
 def test_tensor_comm_reducer_overload():
-    m = te.size_var('m')
-    n = te.size_var('n')
-    mysum = te.comm_reducer(lambda x, y: x+y, lambda t: tvm.tir.const(0, dtype=t))
+    m = te.size_var("m")
+    n = te.size_var("n")
+    mysum = te.comm_reducer(lambda x, y: x + y, lambda t: tvm.tir.const(0, dtype=t))
     sum_res = mysum(m, n)
 
+
 def test_tensor_reduce():
-    m = te.size_var('m')
-    n = te.size_var('n')
-    l = te.size_var('l')
-    A = te.placeholder((m, l), name='A')
-    B = te.placeholder((n, l), name='B')
+    m = te.size_var("m")
+    n = te.size_var("n")
+    l = te.size_var("l")
+    A = te.placeholder((m, l), name="A")
+    B = te.placeholder((n, l), name="B")
     T = te.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k])
     rv = te.reduce_axis((0, A.shape[1]), "k")
-    C = te.compute((m, n), lambda i, j: te.sum(T(i, j, rv+1), axis=rv))
+    C = te.compute((m, n), lambda i, j: te.sum(T(i, j, rv + 1), axis=rv))
     # json load save
     C_json = tvm.ir.save_json(C)
     C_loaded = tvm.ir.load_json(C_json)
-    assert(isinstance(C_loaded, te.tensor.Tensor))
-    assert(str(C_loaded) == str(C))
+    assert isinstance(C_loaded, te.tensor.Tensor)
+    assert str(C_loaded) == str(C)
+
 
 def test_tensor_compute1():
     m = 1024
     factor = 16
-    dtype = 'float32'
+    dtype = "float32"
 
     def intrin_vadd(n):
         x = te.placeholder((n,))
@@ -115,24 +121,30 @@ def test_tensor_compute1():
 
         def intrin_func(ins, outs):
             ib = tvm.tir.ir_builder.create()
-            ib.emit(tvm.tir.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr')))
+            ib.emit(
+                tvm.tir.call_extern(
+                    outs[0].dtype,
+                    "vadd",
+                    ins[0].access_ptr("r"),
+                    ins[1].access_ptr("r"),
+                    outs[0].access_ptr("wr"),
+                )
+            )
             return ib.get()
 
-        return te.decl_tensor_intrin(z.op, intrin_func, default_buffer_params={
-            "offset_factor": n
-        })
+        return te.decl_tensor_intrin(z.op, intrin_func, default_buffer_params={"offset_factor": n})
 
     vadd = intrin_vadd(factor)
 
-    A = te.placeholder((m//factor, factor), name="A", dtype=dtype)
-    B = te.placeholder((m//factor, factor), name="B", dtype=dtype)
-    C = te.compute((m//factor, factor),
-          lambda i: vadd(A[i, 0:factor], B[i, 0:factor]))
+    A = te.placeholder((m // factor, factor), name="A", dtype=dtype)
+    B = te.placeholder((m // factor, factor), name="B", dtype=dtype)
+    C = te.compute((m // factor, factor), lambda i: vadd(A[i, 0:factor], B[i, 0:factor]))
 
     s = te.create_schedule(C.op)
     stmt = tvm.lower(s, [A, B, C])["main"].body
     assert isinstance(stmt.body, tvm.tir.Evaluate)
 
+
 def test_tensor_compute2():
     M = 2048
     N = 1024
@@ -140,7 +152,7 @@ def test_tensor_compute2():
     factor = 16
     factor1 = 32
     factor2 = 32
-    dtype = 'float32'
+    dtype = "float32"
 
     def intrin_gemm(m, n, l):
         k = te.reduce_axis((0, l))
@@ -153,40 +165,44 @@ def test_tensor_compute2():
             x_ptr = ins[0].access_ptr("r")
             y_ptr = ins[1].access_ptr("r")
             z_ptr = outs[0].access_ptr("w")
-            body = tvm.tir.call_packed(
-                "gemv", x_ptr, y_ptr, z_ptr, m, n, l)
-            reset = tvm.tir.call_packed(
-                "fill_zero", z_ptr, m, n)
-            update = tvm.tir.call_packed(
-                "gemv_add", x_ptr, y_ptr, z_ptr, m, n, l)
+            body = tvm.tir.call_packed("gemv", x_ptr, y_ptr, z_ptr, m, n, l)
+            reset = tvm.tir.call_packed("fill_zero", z_ptr, m, n)
+            update = tvm.tir.call_packed("gemv_add", x_ptr, y_ptr, z_ptr, m, n, l)
             return body, reset, update
 
-        return te.decl_tensor_intrin(z.op, intrin_func,
-                                     default_buffer_params={"offset_factor": n})
+        return te.decl_tensor_intrin(z.op, intrin_func, default_buffer_params={"offset_factor": n})
 
     vgemm = intrin_gemm(factor1, factor2, factor)
 
-    A = te.placeholder((M//factor1, L//factor, factor1, factor), name="A", dtype=dtype)
-    B = te.placeholder((N//factor2, L//factor, factor2, factor), name="B", dtype=dtype)
-    k = te.reduce_axis((0, L//factor), name='k')
-    C = te.compute((M//factor1, N//factor2, factor1, factor2),
-          lambda i, j: vgemm(A[i, k, 0:factor1, 0:factor], B[j, k, 0:factor2, 0:factor], reduce_axis=k))
+    A = te.placeholder((M // factor1, L // factor, factor1, factor), name="A", dtype=dtype)
+    B = te.placeholder((N // factor2, L // factor, factor2, factor), name="B", dtype=dtype)
+    k = te.reduce_axis((0, L // factor), name="k")
+    C = te.compute(
+        (M // factor1, N // factor2, factor1, factor2),
+        lambda i, j: vgemm(
+            A[i, k, 0:factor1, 0:factor], B[j, k, 0:factor2, 0:factor], reduce_axis=k
+        ),
+    )
 
     s = te.create_schedule(C.op)
     stmt = tvm.lower(s, [A, B, C])["main"].body
     assert isinstance(stmt.body.body[0], tvm.tir.Evaluate)
     assert isinstance(stmt.body.body[1].body, tvm.tir.Evaluate)
 
+
 def test_tensor_scan():
     m = te.size_var("m")
     n = te.size_var("n")
     x = te.placeholder((m, n))
     s = te.placeholder((m, n))
-    res = tvm.te.scan(te.compute((1, n), lambda _, i: x[0, i]),
-                   te.compute((m, n), lambda t, i: s[t-1, i] + x[t, i]),
-                   s)
+    res = tvm.te.scan(
+        te.compute((1, n), lambda _, i: x[0, i]),
+        te.compute((m, n), lambda t, i: s[t - 1, i] + x[t, i]),
+        s,
+    )
     assert tuple(res.shape) == (m, n)
 
+
 def test_scan_multi_out():
     m = te.size_var("m")
     n = te.size_var("n")
@@ -196,63 +212,65 @@ def test_scan_multi_out():
     s2 = te.placeholder((m, n))
     s1_init = te.compute((1, n), lambda _, i: x1[0, i])
     s2_init = te.compute((1, n), lambda _, i: x2[0, i])
-    s1_update = te.compute((m, n), lambda t, i: s1[t-1, i] + s2[t-1, i] + x1[t, i])
-    s2_update = te.compute((m, n), lambda t, i: x2[t, i] + s2[t-1,i])
-
-    r0, r1 = tvm.te.scan([s1_init, s2_init],
-                      [s1_update, s2_update],
-                      [s1, s2])
-    assert(r0.value_index == 0)
-    assert(r1.value_index == 1)
+    s1_update = te.compute((m, n), lambda t, i: s1[t - 1, i] + s2[t - 1, i] + x1[t, i])
+    s2_update = te.compute((m, n), lambda t, i: x2[t, i] + s2[t - 1, i])
+
+    r0, r1 = tvm.te.scan([s1_init, s2_init], [s1_update, s2_update], [s1, s2])
+    assert r0.value_index == 0
+    assert r1.value_index == 1
     json_str = tvm.ir.save_json(r0.op)
     zz = tvm.ir.load_json(json_str)
     assert isinstance(zz, tvm.te.ScanOp)
 
+
 def test_extern():
-    m = te.size_var('m')
-    A = te.placeholder((m,), name='A')
+    m = te.size_var("m")
+    A = te.placeholder((m,), name="A")
 
     def extern_func(ins, outs):
-        assert(isinstance(ins[0], tvm.te.schedule.Buffer))
+        assert isinstance(ins[0], tvm.te.schedule.Buffer)
         return tvm.tir.call_packed("myadd", ins[0].data, outs[0].data, m)
+
     B = te.extern((m,), [A], extern_func)
-    assert(tuple(B.shape) == (m,))
+    assert tuple(B.shape) == (m,)
 
 
 def test_extern_multi_out():
-    m = te.size_var('m')
-    A = te.placeholder((m,), name='A')
+    m = te.size_var("m")
+    A = te.placeholder((m,), name="A")
     B = te.compute((m,), lambda i: A[i] * 10)
 
     def extern_func(ins, outs):
-        assert(isinstance(ins[0], tvm.te.schedule.Buffer))
-        return tvm.tir.call_packed(
-            "myadd", ins[0].data, outs[0].data, outs[1].data, m)
+        assert isinstance(ins[0], tvm.te.schedule.Buffer)
+        return tvm.tir.call_packed("myadd", ins[0].data, outs[0].data, outs[1].data, m)
+
     res = te.extern([A.shape, A.shape], [A, B], extern_func)
-    assert(len(res) == 2)
-    assert(res[1].value_index == 1)
+    assert len(res) == 2
+    assert res[1].value_index == 1
+
 
 def test_tuple_inputs():
-    m = te.size_var('m')
-    n = te.size_var('n')
-    A0 = te.placeholder((m, n), name='A0')
-    A1 = te.placeholder((m, n), name='A1')
-    T0, T1 = te.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='T')
+    m = te.size_var("m")
+    n = te.size_var("n")
+    A0 = te.placeholder((m, n), name="A0")
+    A1 = te.placeholder((m, n), name="A1")
+    T0, T1 = te.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name="T")
     s = te.create_schedule(T0.op)
 
     for i in range(len(T0.shape)):
-      assert(T0.shape[i] == T1.shape[i])
-    assert(T0.op == T1.op)
-    assert(T0.value_index == 0)
-    assert(T1.value_index == 1)
+        assert T0.shape[i] == T1.shape[i]
+    assert T0.op == T1.op
+    assert T0.value_index == 0
+    assert T1.value_index == 1
+
 
 def test_tuple_with_different_deps():
-    m = te.size_var('m')
-    n = te.size_var('n')
-    A0 = te.placeholder((m, n), name='A1')
-    A1 = te.placeholder((m, n), name='A2')
-    B0, B1 = te.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='B')
-    C = te.compute((m, n), lambda i, j: B0[i, j] + 4, name='C')
+    m = te.size_var("m")
+    n = te.size_var("n")
+    A0 = te.placeholder((m, n), name="A1")
+    A1 = te.placeholder((m, n), name="A2")
+    B0, B1 = te.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name="B")
+    C = te.compute((m, n), lambda i, j: B0[i, j] + 4, name="C")
 
     s = te.create_schedule(C.op)
     xo, xi = s[C].split(C.op.axis[0], factor=10)
@@ -262,9 +280,13 @@ def test_tuple_with_different_deps():
     stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
 
     def get_B1_realize(x):
-        if isinstance(x, tvm.tir.ProducerRealize) and \
-           x.producer.op == B1.op and x.producer.value_index == 1:
+        if (
+            isinstance(x, tvm.tir.ProducerRealize)
+            and x.producer.op == B1.op
+            and x.producer.value_index == 1
+        ):
             ret.append(x)
+
     ret = []
     tvm.tir.stmt_functor.post_order_visit(stmt, get_B1_realize)
 
@@ -272,56 +294,61 @@ def test_tuple_with_different_deps():
 
 
 def test_tensor_inputs():
-    x = te.placeholder((1,), name='x')
+    x = te.placeholder((1,), name="x")
     y = te.compute(x.shape, lambda i: x[i] + x[i])
     assert tuple(y.op.input_tensors) == (x,)
 
 
 def test_tensor_pool():
     def intrin_pool():
-        A = te.placeholder((64, 16, 16), name='A')
-        kh = te.reduce_axis((0, 3), name='kh')
-        kw = te.reduce_axis((0, 3), name='kw')
-        P = te.compute((64, 14, 14),
-                        lambda c, oh, ow: tvm.te.max(A[c, oh + kh, ow + kw],
-                                                  axis=[kh, kw]),
-                        name='p')
+        A = te.placeholder((64, 16, 16), name="A")
+        kh = te.reduce_axis((0, 3), name="kh")
+        kw = te.reduce_axis((0, 3), name="kw")
+        P = te.compute(
+            (64, 14, 14),
+            lambda c, oh, ow: tvm.te.max(A[c, oh + kh, ow + kw], axis=[kh, kw]),
+            name="p",
+        )
 
         def intrin_func(ins, outs):
             dinp = ins[0]
             dout = outs[0]
             return tvm.tir.call_packed("op", dinp, dout)
 
-        return te.decl_tensor_intrin(P.op, intrin_func,
-                                     default_buffer_params={"offset_factor": 1})
+        return te.decl_tensor_intrin(P.op, intrin_func, default_buffer_params={"offset_factor": 1})
 
-    A = te.placeholder((1, 64, 16, 16), name='A')
-    P = pool(data=A, kernel=(3, 3), stride=(1, 1), padding=(0, 0, 0, 0),
-             pool_type='max')
+    A = te.placeholder((1, 64, 16, 16), name="A")
+    P = pool(data=A, kernel=(3, 3), stride=(1, 1), padding=(0, 0, 0, 0), pool_type="max")
     s = te.create_schedule(P.op)
     _, oh, _, _ = P.op.axis
     intrin = intrin_pool()
     s[P].tensorize(oh, intrin)
     tvm.lower(s, [A, P])
 
+
 def test_tensor_scalar_mixed():
     # test te with tensor and scalar
-    a = np.array(np.random.uniform(size=(10,)), 'float32')
-    b = np.array(np.random.uniform(size=(1))[0], 'float32')
-    c = np.array(np.random.uniform(size=(10,)), 'float32')
+    a = np.array(np.random.uniform(size=(10,)), "float32")
+    b = np.array(np.random.uniform(size=(1))[0], "float32")
+    c = np.array(np.random.uniform(size=(10,)), "float32")
 
     @tvm.register_func("tvm.test_tensor_scalar_scale")
     def my_scale(tensor, scalar, out):
         out_np = tensor.asnumpy() * scalar.asnumpy()
         tvm.nd.array(out_np).copyto(out)
 
-    A = te.placeholder(a.shape, name='A')
-    B = te.placeholder(b.shape, name='B')
-    C = te.extern(a.shape, [A, B],
-                  lambda ins, outs: tvm.tir.call_packed(
-                  "tvm.test_tensor_scalar_scale", ins[0], ins[1], outs[0]), name="C")
+    A = te.placeholder(a.shape, name="A")
+    B = te.placeholder(b.shape, name="B")
+    C = te.extern(
+        a.shape,
+        [A, B],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.test_tensor_scalar_scale", ins[0], ins[1], outs[0]
+        ),
+        name="C",
+    )
     s = te.create_schedule(C.op)
-    f = tvm.build(s, [A, B, C], 'llvm')
+    f = tvm.build(s, [A, B, C], "llvm")
 
     ta = tvm.nd.array(a)
     tb = tvm.nd.array(b)
@@ -332,25 +359,29 @@ def test_tensor_scalar_mixed():
 
 def test_tensor_scalar():
     # test te with scalar shape
-    a = np.array(np.random.uniform(size=(1))[0], 'float32')
-    b = np.array(0.0, 'float32')
+    a = np.array(np.random.uniform(size=(1))[0], "float32")
+    b = np.array(0.0, "float32")
 
     @tvm.register_func("tvm.test_tensor_scalar_copy")
     def mycopy(x, y):
         x.copyto(y)
 
-    A = te.placeholder(a.shape, name='A')
-    B = te.extern(a.shape, [A],
-                  lambda ins, outs: tvm.tir.call_packed(
-                  "tvm.test_tensor_scalar_copy", ins[0], outs[0]), name="B")
+    A = te.placeholder(a.shape, name="A")
+    B = te.extern(
+        a.shape,
+        [A],
+        lambda ins, outs: tvm.tir.call_packed("tvm.test_tensor_scalar_copy", ins[0], outs[0]),
+        name="B",
+    )
     s = te.create_schedule(B.op)
-    f = tvm.build(s, [A, B], 'llvm')
+    f = tvm.build(s, [A, B], "llvm")
 
     ta = tvm.nd.array(a)
     tb = tvm.nd.array(b)
     f(ta, tb)
     tvm.testing.assert_allclose(ta.asnumpy(), tb.asnumpy())
 
+
 if __name__ == "__main__":
     test_rank_zero()
     test_tensor_inputs()
index 577cdfb..a833915 100644
@@ -25,9 +25,9 @@ import tvm.testing
 
 def test_operator_type_and_tags():
     k = 1
-    n = te.var('n')
-    A = te.placeholder((), name='A')
-    B = te.placeholder((10, 5), name='B')
+    n = te.var("n")
+    A = te.placeholder((), name="A")
+    B = te.placeholder((10, 5), name="B")
     B1 = B[0]
     B2 = B[0, 0]
 
@@ -70,10 +70,10 @@ def test_combination():
     k = 3
     n = 5
     m = 10
-    x = te.var('x')
-    A = te.placeholder((n, m), name='A')
-    B = te.placeholder((n, m), name='B')
-    C = te.placeholder((n, m), name='C')
+    x = te.var("x")
+    A = te.placeholder((n, m), name="A")
+    B = te.placeholder((n, m), name="B")
+    C = te.placeholder((n, m), name="C")
     D = k + A - B * C + x
     s = te.create_schedule(D.op)
     foo = tvm.build(s, [x, A, B, C, D], "llvm")
@@ -84,15 +84,14 @@ def test_combination():
     c = tvm.nd.array(np.random.uniform(size=(n, m)).astype(C.dtype), ctx)
     d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
     foo(x, a, b, c, d)
-    tvm.testing.assert_allclose(
-        d.asnumpy(), k + a.asnumpy() - b.asnumpy() * c.asnumpy() + x)
+    tvm.testing.assert_allclose(d.asnumpy(), k + a.asnumpy() - b.asnumpy() * c.asnumpy() + x)
 
 
 def verify_tensor_scalar_bop(shape, typ="add"):
     """Verify non-constant Tensor and scalar binary operations."""
-    sh = [te.size_var('n%d' % i) for i in range(0, len(shape))]
-    k = te.var('k')
-    A = te.placeholder(sh, name='A')
+    sh = [te.size_var("n%d" % i) for i in range(0, len(shape))]
+    k = te.var("k")
+    A = te.placeholder(sh, name="A")
     if typ == "add":
         B = A + k
     elif typ == "sub":
@@ -132,7 +131,7 @@ def verify_tensor_scalar_bop(shape, typ="add"):
         foo(a_nd, b_nd, k_, *shape)
         tvm.testing.assert_allclose(b_nd.asnumpy(), b_npy, rtol=1e-5)
 
-    for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
+    for device in ["llvm", "cuda", "opencl", "metal", "rocm", "vulkan"]:
         check_device(device)
 
 
@@ -159,8 +158,7 @@ def verify_broadcast_bop(lhs_shape, rhs_shape, typ="add"):
         with tvm.target.Target(device):
             s = tvm.topi.testing.get_broadcast_schedule(device)(C)
 
-        foo = tvm.build(s, [A, B, C], device,
-                        name="broadcast_binary" + "_" + typ)
+        foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + typ)
         lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
         rhs_npy = np.random.uniform(size=rhs_shape).astype(A.dtype)
         if typ == "add":
@@ -180,15 +178,16 @@ def verify_broadcast_bop(lhs_shape, rhs_shape, typ="add"):
         out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), ctx)
         for _ in range(1):
             foo(lhs_nd, rhs_nd, out_nd)
-        tvm.testing.assert_allclose(
-            out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
+        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1e-4, atol=1e-4)
 
-    for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
+    for device in ["llvm", "cuda", "opencl", "metal", "rocm", "vulkan"]:
         check_device(device)
 
 
 @tvm.testing.uses_gpu
-def verify_conv2d_scalar_bop(batch, in_size, in_channel, num_filter, kernel, stride, padding, typ="add"):
+def verify_conv2d_scalar_bop(
+    batch, in_size, in_channel, num_filter, kernel, stride, padding, typ="add"
+):
     def check_device(device):
         ctx = tvm.context(device, 0)
         if not tvm.testing.device_enabled(device):
@@ -196,15 +195,13 @@ def verify_conv2d_scalar_bop(batch, in_size, in_channel, num_filter, kernel, str
             return
         print("Running on target: %s" % device)
 
-        conv2d_nchw, schedule_conv2d_nchw = tvm.topi.testing.get_conv2d_nchw_implement(
-            device)
+        conv2d_nchw, schedule_conv2d_nchw = tvm.topi.testing.get_conv2d_nchw_implement(device)
 
         k = 10.0
         dilation = (1, 1)
         with tvm.target.Target(device):
-            A = te.placeholder((batch, in_channel, in_size, in_size), name='A')
-            W = te.placeholder(
-                (num_filter, in_channel, kernel, kernel), name='W')
+            A = te.placeholder((batch, in_channel, in_size, in_size), name="A")
+            W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W")
             B = conv2d_nchw(A, W, stride, padding, dilation, A.dtype)
             if typ == "add":
                 C = B + k
@@ -220,14 +217,10 @@ def verify_conv2d_scalar_bop(batch, in_size, in_channel, num_filter, kernel, str
 
         foo = tvm.build(s, [A, W, B, C], device, name="conv2d_scalar_" + typ)
 
-        a_npy = np.random.uniform(
-            size=get_const_tuple(A.shape)).astype(A.dtype)
-        w_npy = np.random.uniform(
-            size=get_const_tuple(W.shape)).astype(W.dtype)
-        b_npy = tvm.topi.testing.conv2d_nchw_python(
-            a_npy, w_npy, stride, padding)
-        c_npy = np.random.uniform(
-            size=get_const_tuple(B.shape)).astype(B.dtype)
+        a_npy = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
+        w_npy = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
+        b_npy = tvm.topi.testing.conv2d_nchw_python(a_npy, w_npy, stride, padding)
+        c_npy = np.random.uniform(size=get_const_tuple(B.shape)).astype(B.dtype)
         if typ == "add":
             c_npy = b_npy + k
         elif typ == "sub":
@@ -244,10 +237,9 @@ def verify_conv2d_scalar_bop(batch, in_size, in_channel, num_filter, kernel, str
         b_nd = tvm.nd.array(np.empty(b_npy.shape).astype(B.dtype), ctx)
         c_nd = tvm.nd.array(np.empty(c_npy.shape).astype(C.dtype), ctx)
         foo(a_nd, w_nd, b_nd, c_nd)
-        tvm.testing.assert_allclose(
-            c_nd.asnumpy(), c_npy, rtol=1E-4, atol=1E-4)
+        tvm.testing.assert_allclose(c_nd.asnumpy(), c_npy, rtol=1e-4, atol=1e-4)
 
-    for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
+    for device in ["llvm", "cuda", "opencl", "metal", "rocm", "vulkan"]:
         check_device(device)
 
 
index 4231f48..7ea9321 100644
 import tvm
 from tvm import te
 
+
 def test_verify_compute():
-  n = te.size_var("n")
-  m = te.size_var("m")
-  A = te.placeholder((n, m), name='A')
-  k = te.reduce_axis((0, m), "k")
-  k_ = te.reduce_axis((0, m-1), "k_")
-  f1 = lambda i: te.sum(A[i, k], axis=k)
-  f2 = lambda i: A[i,0] + 1
-  f3 = lambda i: te.sum(A[i, k], axis=k) + 1
-  f4 = lambda i: A[i,0] * (te.sum(A[i, k], axis=k) + 1)
-  f5 = lambda i: (te.sum(A[i, k], axis=k), A[i,0] + 1)
-  f6 = lambda i: (te.sum(A[i, k], axis=k), te.sum(A[i, k_], axis=k_))
+    n = te.size_var("n")
+    m = te.size_var("m")
+    A = te.placeholder((n, m), name="A")
+    k = te.reduce_axis((0, m), "k")
+    k_ = te.reduce_axis((0, m - 1), "k_")
+    f1 = lambda i: te.sum(A[i, k], axis=k)
+    f2 = lambda i: A[i, 0] + 1
+    f3 = lambda i: te.sum(A[i, k], axis=k) + 1
+    f4 = lambda i: A[i, 0] * (te.sum(A[i, k], axis=k) + 1)
+    f5 = lambda i: (te.sum(A[i, k], axis=k), A[i, 0] + 1)
+    f6 = lambda i: (te.sum(A[i, k], axis=k), te.sum(A[i, k_], axis=k_))
 
-  #
-  # Valid compute
-  try:
-    B = te.compute((n,), f1, name="B")
-  except tvm._ffi.base.TVMError as ex:
-    assert False
+    #
+    # Valid compute
+    try:
+        B = te.compute((n,), f1, name="B")
+    except tvm._ffi.base.TVMError as ex:
+        assert False
 
-  #
-  # Valid compute
-  try:
-    B = te.compute((n,), f2, name="B")
-  except tvm._ffi.base.TVMError as ex:
-    assert False
+    #
+    # Valid compute
+    try:
+        B = te.compute((n,), f2, name="B")
+    except tvm._ffi.base.TVMError as ex:
+        assert False
 
-  #
-  # Invalid compute with non top level reduction
-  try:
-    B = te.compute((n,), f3, name="B")
-    assert False
-  except tvm._ffi.base.TVMError as ex:
-    pass
+    #
+    # Invalid compute with non top level reduction
+    try:
+        B = te.compute((n,), f3, name="B")
+        assert False
+    except tvm._ffi.base.TVMError as ex:
+        pass
 
-  #
-  # Invalid compute with non top level reduction
-  try:
-    B = te.compute((n,), f4, name="B")
-    assert False
-  except tvm._ffi.base.TVMError as ex:
-    pass
+    #
+    # Invalid compute with non top level reduction
+    try:
+        B = te.compute((n,), f4, name="B")
+        assert False
+    except tvm._ffi.base.TVMError as ex:
+        pass
 
-  #
-  # Invalid compute with reduction and non-reduction batch ops
-  try:
-    B0, B1 = te.compute((n,), f5, name="B")
-    assert False
-  except tvm._ffi.base.TVMError as ex:
-    pass
+    #
+    # Invalid compute with reduction and non-reduction batch ops
+    try:
+        B0, B1 = te.compute((n,), f5, name="B")
+        assert False
+    except tvm._ffi.base.TVMError as ex:
+        pass
 
-  #
-  # Invalid compute with unequal batch reduction ops
-  try:
-    B0, B1 = te.compute((n,), f6, name="B")
-    assert False
-  except tvm._ffi.base.TVMError as ex:
-    pass
+    #
+    # Invalid compute with unequal batch reduction ops
+    try:
+        B0, B1 = te.compute((n,), f6, name="B")
+        assert False
+    except tvm._ffi.base.TVMError as ex:
+        pass
 
 
 if __name__ == "__main__":
-  test_verify_compute()
+    test_verify_compute()
index c7be325..10e3914 100644
@@ -19,18 +19,19 @@ import tvm
 from tvm import te
 import tvm.testing
 
+
 def test_check_numerical_grads():
     # Functions and their derivatives
     functions = [
-        lambda x: (x*x*x, 3*x*x),
-        lambda x: (x*x, 2*x),
+        lambda x: (x * x * x, 3 * x * x),
+        lambda x: (x * x, 2 * x),
         lambda x: (np.abs(x), np.sign(x)),
-        lambda x: (np.log(np.abs(x)), 1/x),
-        lambda x: (np.sqrt(np.abs(x)), np.sign(x)/(2*np.sqrt(np.abs(x)))),
-        lambda x: (1/x, -1/(x*x)),
-        lambda x: (np.sign(np.sin(1/x)), np.zeros_like(x)),
-        lambda x: (x*np.sin(1/x), np.sin(1/x) - np.cos(1/x)/x),
-        lambda x: (np.sin(1/x), - np.cos(1/x)/(x*x)),
+        lambda x: (np.log(np.abs(x)), 1 / x),
+        lambda x: (np.sqrt(np.abs(x)), np.sign(x) / (2 * np.sqrt(np.abs(x)))),
+        lambda x: (1 / x, -1 / (x * x)),
+        lambda x: (np.sign(np.sin(1 / x)), np.zeros_like(x)),
+        lambda x: (x * np.sin(1 / x), np.sin(1 / x) - np.cos(1 / x) / x),
+        lambda x: (np.sin(1 / x), -np.cos(1 / x) / (x * x)),
         lambda x: (np.tan(x), 1.0 / (np.cos(x) * np.cos(x))),
     ]
 
@@ -61,15 +62,15 @@ def test_check_numerical_grads():
 
             # Same thing but with keyword arguments
             func_forw = lambda x, y: np.sum(f1(x)[0] + f2(y)[0])
-            grads = {'x': f1(x_input)[1], 'y': f2(y_input)[1]}
+            grads = {"x": f1(x_input)[1], "y": f2(y_input)[1]}
 
-            tvm.testing.check_numerical_grads(func_forw, {'x': x_input, 'y': y_input}, grads)
+            tvm.testing.check_numerical_grads(func_forw, {"x": x_input, "y": y_input}, grads)
 
     def _noise1(x, atol=1e-2, rtol=0.1):
         # We go in random direction using twice the original tolerance to be sure this
         # results in an error
         sqrt_n = np.sqrt(float(np.prod(x.shape)))
-        tol = 2*(np.linalg.norm(x)*rtol + atol*sqrt_n)
+        tol = 2 * (np.linalg.norm(x) * rtol + atol * sqrt_n)
         noise = np.random.normal(size=x.shape)
         noise = tol * noise / np.linalg.norm(noise)
         return x + noise
@@ -77,7 +78,7 @@ def test_check_numerical_grads():
     def _noise2(x, atol=1e-2, rtol=0.1):
         # This noise affects just a single component
         sqrt_n = np.sqrt(float(np.prod(x.shape)))
-        tol = 2*(np.linalg.norm(x)*rtol + atol*sqrt_n)
+        tol = 2 * (np.linalg.norm(x) * rtol + atol * sqrt_n)
         n = np.random.randint(np.prod(x.shape))
         noise = np.zeros_like(x)
         noise.reshape(-1)[n] = tol
@@ -100,10 +101,10 @@ def test_check_numerical_grads():
                 raise AssertionError("tvm.testing.check_numerical_grads didn't raise an exception")
 
             func_forw = lambda x, y: np.sum(f1(x)[0] + f2(y)[0])
-            grads = {'x': _noise2(f1(x_input)[1]), 'y': _noise2(f2(y_input)[1])}
+            grads = {"x": _noise2(f1(x_input)[1]), "y": _noise2(f2(y_input)[1])}
 
             try:
-                tvm.testing.check_numerical_grads(func_forw, {'x': x_input, 'y': y_input}, grads)
+                tvm.testing.check_numerical_grads(func_forw, {"x": x_input, "y": y_input}, grads)
             except AssertionError as e:
                 pass
             else:
@@ -112,4 +113,3 @@ def test_check_numerical_grads():
 
 if __name__ == "__main__":
     test_check_numerical_grads()
-
index 86a1ed7..c3ae417 100644
 import tvm
 from tvm import te
 
+
 def test_equal_expr():
-    x = te.var('x')
-    y = te.var('y')
+    x = te.var("x")
+    y = te.var("y")
 
     def func1():
         return x + y + 1
index 449a462..940355e 100644
@@ -18,11 +18,12 @@ import pytest
 import tvm
 from tvm import te
 
+
 @pytest.mark.xfail
 def test_loop_dependent_allocate():
     N = te.size_var("N")
-    A = te.placeholder((2*N,), "float32", "A")
-    C = te.compute((N, ), lambda i: A[2*i] + A[i+1], name='C')
+    A = te.placeholder((2 * N,), "float32", "A")
+    C = te.compute((N,), lambda i: A[2 * i] + A[i + 1], name="C")
     s = te.create_schedule(C.op)
     AA = s.cache_read(A, "local", [C])
     s[AA].compute_at(s[C], s[C].op.axis[0])
@@ -30,5 +31,6 @@ def test_loop_dependent_allocate():
     # referencing undefined variable
     tvm.lower(s, [A, C])
 
+
 if __name__ == "__main__":
     test_loop_dependent_allocate()
index ec3c762..7a23c8a 100644
@@ -19,10 +19,12 @@ import tvm
 from tvm import te
 import tvm.testing
 
+
 def get_verify_pass(valid, **kwargs):
     def _fverify(f, *_):
         valid[0] = tvm.tir.analysis.verify_gpu_code(f, kwargs)
         return f
+
     return tvm.tir.transform.prim_func_pass(_fverify, opt_level=0)
 
 
@@ -35,8 +37,8 @@ def test_shared_memory():
         tvm_type = tvm.runtime.DataType(dtype)
         type_size = tvm_type.bits // 8 * tvm_type.lanes
 
-        A = te.placeholder((N,), name='A', dtype=dtype)
-        B = te.compute((N, ), lambda i: A[i], name='B')
+        A = te.placeholder((N,), name="A", dtype=dtype)
+        B = te.compute((N,), lambda i: A[i], name="B")
 
         s = te.create_schedule([B.op])
         AA = s.cache_read(A, "shared", [B])
@@ -48,33 +50,55 @@ def test_shared_memory():
         # shared memory usage: M * sizeof(dtype) Bytes
         # thread usage: M
 
-        for target in ['opencl', 'cuda']:
+        for target in ["opencl", "cuda"]:
             if not tvm.testing.device_enabled(target):
                 continue
             valid = [None]
-            with tvm.transform.PassContext(config={"tir.add_lower_pass": [
-                (2, get_verify_pass(valid,
-                                    max_shared_memory_per_block=type_size * M - 1,
-                                    max_threads_per_block=M))]}):
+            with tvm.transform.PassContext(
+                config={
+                    "tir.add_lower_pass": [
+                        (
+                            2,
+                            get_verify_pass(
+                                valid,
+                                max_shared_memory_per_block=type_size * M - 1,
+                                max_threads_per_block=M,
+                            ),
+                        )
+                    ]
+                }
+            ):
                 tvm.build(s, [A, B], target)
             assert not valid[0]
 
-            with tvm.transform.PassContext(config={"tir.add_lower_pass": [
-                (2, get_verify_pass(valid,
-                                    max_shared_memory_per_block=type_size * M,
-                                    max_threads_per_block=M))]}):
+            with tvm.transform.PassContext(
+                config={
+                    "tir.add_lower_pass": [
+                        (
+                            2,
+                            get_verify_pass(
+                                valid,
+                                max_shared_memory_per_block=type_size * M,
+                                max_threads_per_block=M,
+                            ),
+                        )
+                    ]
+                }
+            ):
                 tvm.build(s, [A, B], target)
             assert valid[0]
-    check_shared_memory('float32')
-    check_shared_memory('int8x4')
+
+    check_shared_memory("float32")
+    check_shared_memory("int8x4")
+
 
 @tvm.testing.requires_gpu
 def test_local_memory():
     N = 1024
     M = 128
 
-    A = te.placeholder((N,), name='A', dtype='float32')
-    B = te.compute((N, ), lambda i: A[i], name='B')
+    A = te.placeholder((N,), name="A", dtype="float32")
+    B = te.compute((N,), lambda i: A[i], name="B")
 
     s = te.create_schedule([B.op])
     AA = s.cache_read(A, "local", [B])
@@ -85,82 +109,136 @@ def test_local_memory():
     # local memory usage: M * 4B
     # thread usage: M
 
-    for target in ['opencl', 'cuda']:
+    for target in ["opencl", "cuda"]:
         if not tvm.testing.device_enabled(target):
             continue
 
         valid = [None]
-        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
-            (2, get_verify_pass(valid,
-                                max_local_memory_per_block=4 * M - 1,
-                                max_threads_per_block=1))]}):
+        with tvm.transform.PassContext(
+            config={
+                "tir.add_lower_pass": [
+                    (
+                        2,
+                        get_verify_pass(
+                            valid, max_local_memory_per_block=4 * M - 1, max_threads_per_block=1
+                        ),
+                    )
+                ]
+            }
+        ):
             tvm.build(s, [A, B], target)
         assert not valid[0]
 
-        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
-            (2, get_verify_pass(valid,
-                                max_local_memory_per_block=4 * M,
-                                max_threads_per_block=1))]}):
+        with tvm.transform.PassContext(
+            config={
+                "tir.add_lower_pass": [
+                    (
+                        2,
+                        get_verify_pass(
+                            valid, max_local_memory_per_block=4 * M, max_threads_per_block=1
+                        ),
+                    )
+                ]
+            }
+        ):
             tvm.build(s, [A, B], target)
         assert valid[0]
 
+
 @tvm.testing.requires_gpu
 def test_num_thread():
     N = 1024
     M = 128
 
-    A = te.placeholder((N,), name='A', dtype='float32')
-    B = te.compute((N, ), lambda i: A[i], name='B')
+    A = te.placeholder((N,), name="A", dtype="float32")
+    B = te.compute((N,), lambda i: A[i], name="B")
 
     s = te.create_schedule([B.op])
     o, i = s[B].split(s[B].op.axis[0], M)
 
-    s[B].bind(o, te.thread_axis('threadIdx.x'))
+    s[B].bind(o, te.thread_axis("threadIdx.x"))
     s[B].bind(i, te.thread_axis("threadIdx.y"))
 
     # shared memory usage: 0
     # thread usage: N
 
-    for target in ['opencl', 'cuda']:
+    for target in ["opencl", "cuda"]:
         if not tvm.testing.device_enabled(target):
             continue
 
         valid = [None]
-        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
-            (2, get_verify_pass(valid,
-                                max_shared_memory_per_block=0,
-                                max_threads_per_block=N - 1))]}):
+        with tvm.transform.PassContext(
+            config={
+                "tir.add_lower_pass": [
+                    (
+                        2,
+                        get_verify_pass(
+                            valid, max_shared_memory_per_block=0, max_threads_per_block=N - 1
+                        ),
+                    )
+                ]
+            }
+        ):
             tvm.build(s, [A, B], target)
         assert not valid[0]
 
-        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
-            (2, get_verify_pass(valid,
-                                max_shared_memory_per_block=0,
-                                max_threads_per_block=N))]}):
+        with tvm.transform.PassContext(
+            config={
+                "tir.add_lower_pass": [
+                    (
+                        2,
+                        get_verify_pass(
+                            valid, max_shared_memory_per_block=0, max_threads_per_block=N
+                        ),
+                    )
+                ]
+            }
+        ):
             tvm.build(s, [A, B], target)
         assert valid[0]
 
-        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
-            (2, get_verify_pass(valid,
-                                max_shared_memory_per_block=0,
-                                max_threads_per_block=N,
-                                max_thread_y=M-1))]}):
+        with tvm.transform.PassContext(
+            config={
+                "tir.add_lower_pass": [
+                    (
+                        2,
+                        get_verify_pass(
+                            valid,
+                            max_shared_memory_per_block=0,
+                            max_threads_per_block=N,
+                            max_thread_y=M - 1,
+                        ),
+                    )
+                ]
+            }
+        ):
             tvm.build(s, [A, B], target)
         assert not valid[0]
 
-        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
-            (2, get_verify_pass(valid,
-                                max_shared_memory_per_block=0,
-                                max_threads_per_block=N,
-                                max_thread_y=M))]}):
+        with tvm.transform.PassContext(
+            config={
+                "tir.add_lower_pass": [
+                    (
+                        2,
+                        get_verify_pass(
+                            valid,
+                            max_shared_memory_per_block=0,
+                            max_threads_per_block=N,
+                            max_thread_y=M,
+                        ),
+                    )
+                ]
+            }
+        ):
             tvm.build(s, [A, B], target)
         assert valid[0]
 
+
 @tvm.testing.requires_gpu
 def test_multiple_kernels():
     N = 1024
 
-    A = te.placeholder((N, N), name='A')
+    A = te.placeholder((N, N), name="A")
     B = te.compute((N, N), lambda i, j: A[i, j])
     C = te.compute((N, N), lambda i, j: B[i, j])
 
@@ -172,31 +250,48 @@ def test_multiple_kernels():
     # shared memory usage: 0
     # thread usage: N
 
-    for target in ['opencl', 'cuda']:
+    for target in ["opencl", "cuda"]:
         if not tvm.testing.device_enabled(target):
             continue
 
         valid = [None]
-        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
-            (2, get_verify_pass(valid,
-                                max_shared_memory_per_block=0,
-                                max_threads_per_block=N - 1))]}):
+        with tvm.transform.PassContext(
+            config={
+                "tir.add_lower_pass": [
+                    (
+                        2,
+                        get_verify_pass(
+                            valid, max_shared_memory_per_block=0, max_threads_per_block=N - 1
+                        ),
+                    )
+                ]
+            }
+        ):
             tvm.build(s, [A, C], target)
         assert not valid[0]
 
-        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
-            (2, get_verify_pass(valid,
-                                max_shared_memory_per_block=0,
-                                max_threads_per_block=N))]}):
+        with tvm.transform.PassContext(
+            config={
+                "tir.add_lower_pass": [
+                    (
+                        2,
+                        get_verify_pass(
+                            valid, max_shared_memory_per_block=0, max_threads_per_block=N
+                        ),
+                    )
+                ]
+            }
+        ):
             tvm.build(s, [A, C], target)
         assert valid[0]
 
+
 @tvm.testing.requires_gpu
 def test_wrong_bind():
     N = 1024
 
-    A = te.placeholder((N, N-1), name='A')
-    B = te.compute((N, N-1), lambda i, j: A[i, j])
+    A = te.placeholder((N, N - 1), name="A")
+    B = te.compute((N, N - 1), lambda i, j: A[i, j])
 
     s = te.create_schedule([B.op])
 
@@ -204,21 +299,25 @@ def test_wrong_bind():
     s[B].bind(s[B].op.axis[0], te.thread_axis("threadIdx.x"))
     s[B].bind(s[B].op.axis[1], te.thread_axis("threadIdx.x"))
 
-    for target in ['opencl', 'cuda']:
+    for target in ["opencl", "cuda"]:
         if not tvm.testing.device_enabled(target):
             continue
 
         valid = [None]
-        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
-                (2, get_verify_pass(valid, max_threads_per_block=N*N))]}):
+        with tvm.transform.PassContext(
+            config={
+                "tir.add_lower_pass": [(2, get_verify_pass(valid, max_threads_per_block=N * N))]
+            }
+        ):
             tvm.build(s, [A, B], target)
         assert not valid[0]
 
+
 @tvm.testing.requires_gpu
 def test_vectorize():
     N = 1024
 
-    A = te.placeholder((N, N), name='A')
+    A = te.placeholder((N, N), name="A")
     B = te.compute((N, N), lambda i, j: A[i, j])
 
     s = te.create_schedule([B.op])
@@ -230,21 +329,23 @@ def test_vectorize():
     s[B].bind(jo, te.thread_axis("threadIdx.x"))
     s[B].vectorize(ji)
 
-    for target in ['opencl', 'cuda']:
+    for target in ["opencl", "cuda"]:
         if not tvm.testing.device_enabled(target):
             continue
 
         valid = [None]
-        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
-                (2, get_verify_pass(valid, max_vector_bytes=16))]}):
+        with tvm.transform.PassContext(
+            config={"tir.add_lower_pass": [(2, get_verify_pass(valid, max_vector_bytes=16))]}
+        ):
             tvm.lower(s, [A, B])
         assert not valid[0]
 
+
 @tvm.testing.requires_gpu
 def test_vthread():
     N = 1024
 
-    A = te.placeholder((N, 16), name='A')
+    A = te.placeholder((N, 16), name="A")
     B = te.compute((N, 16), lambda i, j: A[i, j])
 
     s = te.create_schedule([B.op])
@@ -252,20 +353,22 @@ def test_vthread():
     s[B].bind(s[B].op.axis[0], te.thread_axis("blockIdx.x"))
     s[B].bind(s[B].op.axis[1], te.thread_axis("vthread"))
 
-    for target in ['opencl', 'cuda']:
+    for target in ["opencl", "cuda"]:
         if not tvm.testing.device_enabled(target):
             continue
 
         valid = [None]
 
         for phase in [1, 2]:
-            with tvm.transform.PassContext(config={"tir.add_lower_pass": [
-                (phase, get_verify_pass(valid, max_vthread=16))]}):
+            with tvm.transform.PassContext(
+                config={"tir.add_lower_pass": [(phase, get_verify_pass(valid, max_vthread=16))]}
+            ):
                 tvm.build(s, [A, B], target)
             assert valid[0]
 
-            with tvm.transform.PassContext(config={"tir.add_lower_pass": [
-                (phase, get_verify_pass(valid, max_vthread=15))]}):
+            with tvm.transform.PassContext(
+                config={"tir.add_lower_pass": [(phase, get_verify_pass(valid, max_vthread=15))]}
+            ):
                 tvm.build(s, [A, B], target)
             assert not valid[0]
 
index 7ec3fde..4c89ff1 100644
@@ -31,7 +31,7 @@ other_devices = ["llvm", "ext_dev"]
 @tvm.testing.uses_gpu
 def test_verify_memory_all_bind():
     n = te.var("n")
-    A = te.placeholder((n,), name='A')
+    A = te.placeholder((n,), name="A")
     B = te.compute(A.shape, lambda i: A[i] + 1.0, name="B")
 
     # B is bound to threads.
@@ -45,7 +45,8 @@ def test_verify_memory_all_bind():
     for dev_type in gpu_devices + other_devices:
         if tvm.testing.device_enabled(dev_type):
             binded_mod = tvm.tir.transform.Apply(
-                lambda f: f.with_attr("target", tvm.target.Target(dev_type)))(mod)
+                lambda f: f.with_attr("target", tvm.target.Target(dev_type))
+            )(mod)
             tvm.tir.transform.VerifyMemory()(binded_mod)
 
 
@@ -55,7 +56,7 @@ def test_verify_memory_all_bind():
 @tvm.testing.uses_gpu
 def test_verify_memory_not_bind():
     n = te.var("n")
-    A = te.placeholder((n,), name='A')
+    A = te.placeholder((n,), name="A")
     B = te.compute(A.shape, lambda i: A[i] + 1.0, name="B")
 
     # B is not bound to threads.
@@ -66,14 +67,16 @@ def test_verify_memory_not_bind():
     for dev_type in gpu_devices:
         if tvm.testing.device_enabled(dev_type):
             binded_mod = tvm.tir.transform.Apply(
-                lambda f: f.with_attr("target", tvm.target.Target(dev_type)))(mod)
+                lambda f: f.with_attr("target", tvm.target.Target(dev_type))
+            )(mod)
             with pytest.raises(RuntimeError):
                 tvm.tir.transform.VerifyMemory()(binded_mod)
 
     for dev_type in other_devices:
         if tvm.testing.device_enabled(dev_type):
             binded_mod = tvm.tir.transform.Apply(
-                lambda f: f.with_attr("target", tvm.target.Target(dev_type)))(mod)
+                lambda f: f.with_attr("target", tvm.target.Target(dev_type))
+            )(mod)
             tvm.tir.transform.VerifyMemory()(binded_mod)
 
 
@@ -83,7 +86,7 @@ def test_verify_memory_not_bind():
 @tvm.testing.uses_gpu
 def test_verify_memory_partially_bind():
     n = te.var("n")
-    A = te.placeholder((n,), name='A')
+    A = te.placeholder((n,), name="A")
     B = te.compute(A.shape, lambda i: A[i] + 1.0, name="B")
     C = te.compute(B.shape, lambda i: B[i] + 2.0, name="C")
     D = te.compute(C.shape, lambda i: C[i] + 2.0, name="D")
@@ -94,19 +97,21 @@ def test_verify_memory_partially_bind():
     s[C].bind(bx, te.thread_axis("blockIdx.x"))
     s[C].bind(tx, te.thread_axis("threadIdx.x"))
 
-    mod = tvm. lower(s, [A, B, C, D])
+    mod = tvm.lower(s, [A, B, C, D])
 
     for dev_type in gpu_devices:
         if tvm.testing.device_enabled(dev_type):
             binded_mod = tvm.tir.transform.Apply(
-                lambda f: f.with_attr("target", tvm.target.Target(dev_type)))(mod)
+                lambda f: f.with_attr("target", tvm.target.Target(dev_type))
+            )(mod)
             with pytest.raises(RuntimeError):
                 tvm.tir.transform.VerifyMemory()(binded_mod)
 
     for dev_type in other_devices:
         if tvm.testing.device_enabled(dev_type):
             binded_mod = tvm.tir.transform.Apply(
-                lambda f: f.with_attr("target", tvm.target.Target(dev_type)))(mod)
+                lambda f: f.with_attr("target", tvm.target.Target(dev_type))
+            )(mod)
             tvm.tir.transform.VerifyMemory()(binded_mod)
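Each VerifyMemory hunk above reflows the same two-step pattern: attach a `target` attribute to every PrimFunc with `tvm.tir.transform.Apply`, then run the `VerifyMemory` analysis pass over the module. A minimal standalone sketch of that pattern, assuming an LLVM-enabled build:

import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute(A.shape, lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)
mod = tvm.lower(s, [A, B])

# Attach the target to every PrimFunc in the module, then verify.
binded_mod = tvm.tir.transform.Apply(
    lambda f: f.with_attr("target", tvm.target.Target("llvm"))
)(mod)
# Passes for a host target; a GPU target with no thread binding would raise RuntimeError.
tvm.tir.transform.VerifyMemory()(binded_mod)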
 
 
index 57dd826..a6db37f 100644
 import tvm
 from tvm import te
 
+
 def test_verify_ssa():
-    x = te.var('x')
+    x = te.var("x")
     y = te.var()
     z = tvm.tir.Evaluate(x + y)
-    assert(tvm.tir.analysis.verify_ssa(
-        tvm.tir.PrimFunc([x, y],z)))
+    assert tvm.tir.analysis.verify_ssa(tvm.tir.PrimFunc([x, y], z))
+
+    assert not tvm.tir.analysis.verify_ssa(tvm.tir.PrimFunc([x, y], tvm.tir.LetStmt(x, 1, z)))
 
-    assert(not tvm.tir.analysis.verify_ssa(
-        tvm.tir.PrimFunc([x, y], tvm.tir.LetStmt(x, 1, z))))
 
 def test_verify_weak_let_ssa():
-    x = te.var('x')
+    x = te.var("x")
     z1 = tvm.tir.Let(x, 1, x + 1)
     z2 = tvm.tir.Let(x, 2, x + 2)
 
-    assert(tvm.tir.analysis.verify_ssa(
-        tvm.tir.PrimFunc([], tvm.tir.Evaluate(z1 + z1))))
-    assert(not tvm.tir.analysis.verify_ssa(
-        tvm.tir.PrimFunc([], tvm.tir.Evaluate(z1 * z2))))
+    assert tvm.tir.analysis.verify_ssa(tvm.tir.PrimFunc([], tvm.tir.Evaluate(z1 + z1)))
+    assert not tvm.tir.analysis.verify_ssa(tvm.tir.PrimFunc([], tvm.tir.Evaluate(z1 * z2)))
+
 
 if __name__ == "__main__":
     test_verify_ssa()
index f7e8f2f..3007509 100644
@@ -19,10 +19,11 @@ from tvm import te
 from tvm.tir import Buffer
 import numpy as np
 
+
 def test_buffer():
-    m = te.size_var('m')
-    n = te.size_var('n')
-    l = te.size_var('l')
+    m = te.size_var("m")
+    n = te.size_var("n")
+    l = te.size_var("l")
     Ab = tvm.tir.decl_buffer((m, n), "float32")
     Bb = tvm.tir.decl_buffer((n, l), "float32")
 
@@ -32,9 +33,9 @@ def test_buffer():
 
 
 def test_buffer_access_ptr():
-    m = te.size_var('m')
-    n = te.size_var('n')
-    Ab = tvm.tir.decl_buffer((m, n), "float32", strides=[n + 1 , 1])
+    m = te.size_var("m")
+    n = te.size_var("n")
+    Ab = tvm.tir.decl_buffer((m, n), "float32", strides=[n + 1, 1])
     aptr = Ab.access_ptr("rw")
     assert tvm.ir.structural_equal(aptr.args[3], Ab.strides[0] * m)
     assert aptr.args[0].dtype == Ab.dtype
@@ -44,78 +45,89 @@ def test_buffer_access_ptr():
 
 
 def test_buffer_access_ptr_offset():
-    m = te.size_var('m')
-    n = te.size_var('n')
+    m = te.size_var("m")
+    n = te.size_var("n")
     Ab = tvm.tir.decl_buffer((m, n), "float32")
     aptr = Ab.access_ptr("rw", offset=100)
     tvm.testing.assert_prim_expr_equal(aptr.args[2], 100)
     assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
-    v = te.size_var('int32')
+    v = te.size_var("int32")
     aptr = Ab.access_ptr("rw", offset=100 + 100 + v)
     tvm.testing.assert_prim_expr_equal(aptr.args[2], 200 + v)
     assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
-    aptr = Ab.access_ptr("rw", offset=tvm.tir.call_extern('int32', "test_call", 100 + 100 + v))
-    tvm.testing.assert_prim_expr_equal(aptr.args[2], tvm.tir.call_extern('int32', "test_call", 200 + v))
+    aptr = Ab.access_ptr("rw", offset=tvm.tir.call_extern("int32", "test_call", 100 + 100 + v))
+    tvm.testing.assert_prim_expr_equal(
+        aptr.args[2], tvm.tir.call_extern("int32", "test_call", 200 + v)
+    )
     assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
 
 
 def test_buffer_access_ptr_extent():
-    m = te.size_var('m')
-    n = te.size_var('n')
+    m = te.size_var("m")
+    n = te.size_var("n")
     Ab = tvm.tir.decl_buffer((m, n), "float32")
     aptr = Ab.access_ptr("rw")
     assert tvm.ir.structural_equal(aptr.args[3], m * n)
     aptr = Ab.access_ptr("rw", offset=100)
     assert tvm.ir.structural_equal(aptr.args[3], m * n - 100)
-    Ab = tvm.tir.decl_buffer((m, n), "float32", strides=[n + 1 , 1])
+    Ab = tvm.tir.decl_buffer((m, n), "float32", strides=[n + 1, 1])
     aptr = Ab.access_ptr("rw", offset=100)
     assert tvm.ir.structural_equal(aptr.args[3], Ab.strides[0] * m - 100)
 
 
 def test_buffer_vload():
-    m = te.size_var('m')
-    n = te.size_var('n')
+    m = te.size_var("m")
+    n = te.size_var("n")
     Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100)
     load = Ab.vload([2, 3])
     tvm.testing.assert_prim_expr_equal(load.index, n * 2 + 103)
 
 
 def test_buffer_index_merge_mult_mod():
-    m = te.size_var('m')
-    n = te.size_var('n')
-    s = te.size_var('s')
-    k0 = te.size_var('k0')
-    k1 = te.size_var('k1')
+    m = te.size_var("m")
+    n = te.size_var("n")
+    s = te.size_var("s")
+    k0 = te.size_var("k0")
+    k1 = te.size_var("k1")
     A = tvm.tir.decl_buffer((m, n), "float32")
     A_stride = tvm.tir.decl_buffer((m, n), "float32", strides=(s, 1))
+
     def assert_simplified_equal(index_simplified, index_direct):
-        assert tvm.ir.structural_equal(index_simplified, index_direct),\
-        "index_simplified=%s, index_direct=%s" %(index_simplified, index_direct)
+        assert tvm.ir.structural_equal(
+            index_simplified, index_direct
+        ), "index_simplified=%s, index_direct=%s" % (index_simplified, index_direct)
+
     idxd = tvm.tir.indexdiv
     idxm = tvm.tir.indexmod
     # Test Case1
     index_simplified = A_stride.vload(
-        (idxd(idxm(k0, k1), s), idxm(idxm(k0, k1), s) + idxd(k0, k1) * k1))
+        (idxd(idxm(k0, k1), s), idxm(idxm(k0, k1), s) + idxd(k0, k1) * k1)
+    )
     index_direct = A_stride.vload((0, k0))
     assert_simplified_equal(index_simplified, index_direct)
 
     # Test Case2
-    index_simplified = A.vload((idxd(idxm(k0, idxd(k1, s)), n),
-                                idxm(idxm(k0, idxd(k1, s)), n) + idxm(k0, k1)))
+    index_simplified = A.vload(
+        (idxd(idxm(k0, idxd(k1, s)), n), idxm(idxm(k0, idxd(k1, s)), n) + idxm(k0, k1))
+    )
     index_direct = A.vload((0, idxm(k0, k1) + idxm(k0, idxd(k1, s))))
     assert_simplified_equal(index_simplified, index_direct)
     # Test Case3
-    index_simplified = A.vload((idxd((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) +
-                                idxd(idxm(k0, idxd(k1, s)), n),
-                                idxm((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) +
-                                idxm(idxm(k0, idxd(k1, s)), n)))
+    index_simplified = A.vload(
+        (
+            idxd((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) + idxd(idxm(k0, idxd(k1, s)), n),
+            idxm((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) + idxm(idxm(k0, idxd(k1, s)), n),
+        )
+    )
     index_direct = A.vload((0, k0))
     assert_simplified_equal(index_simplified, index_direct)
     # Test Case4 (not able to simplify)
-    index_simplified = A.vload((idxd(idxm(k0, idxd(k1, s)), n),
-                                idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1)))
-    index_direct = A.vload((0, idxd(idxm(k0, idxd(k1, s)), n) * n +
-                            (idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1))))
+    index_simplified = A.vload(
+        (idxd(idxm(k0, idxd(k1, s)), n), idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1))
+    )
+    index_direct = A.vload(
+        (0, idxd(idxm(k0, idxd(k1, s)), n) * n + (idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1)))
+    )
     assert_simplified_equal(index_simplified, index_direct)
 
 
@@ -125,17 +137,17 @@ def test_buffer_broadcast():
     n0, n1, n2 = te.size_var("n0"), te.size_var("n1"), te.size_var("n2")
     o0, o1, o2 = te.size_var("o0"), te.size_var("o1"), te.size_var("o2")
 
-    A = te.placeholder((m0, m1, m2), name='A')
-    B = te.placeholder((n0, n1, n2), name='B')
+    A = te.placeholder((m0, m1, m2), name="A")
+    B = te.placeholder((n0, n1, n2), name="B")
 
-    C = te.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
+    C = te.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name="C")
 
     Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
     Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
     s = te.create_schedule(C.op)
 
     def check():
-        fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb})
+        fadd = tvm.build(s, [A, B, C], target="llvm", name="bcast_add", binds={A: Ab, B: Bb})
         ctx = tvm.cpu(0)
         a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx)
         b = tvm.nd.array(np.random.uniform(size=(2, 1, 1)).astype(B.dtype), ctx)
@@ -148,13 +160,13 @@ def test_buffer_broadcast():
 
 @tvm.testing.requires_llvm
 def test_buffer_broadcast_expr():
-    n0, m0, x = te.size_var('n0'), te.size_var('m0'), te.size_var('x')
-    n1, m1 = te.size_var('n1'), te.size_var('m1')
-    o0, o1 = te.size_var('o0'), te.size_var('o1')
+    n0, m0, x = te.size_var("n0"), te.size_var("m0"), te.size_var("x")
+    n1, m1 = te.size_var("n1"), te.size_var("m1")
+    o0, o1 = te.size_var("o0"), te.size_var("o1")
 
-    A = te.placeholder((m0, n0), name='A')
-    B = te.placeholder((m1, n1), name='B')
-    C = te.compute((o0, o1//x), lambda i, j: A[i, j] + B[i, j], name='C')
+    A = te.placeholder((m0, n0), name="A")
+    B = te.placeholder((m1, n1), name="B")
+    C = te.compute((o0, o1 // x), lambda i, j: A[i, j] + B[i, j], name="C")
 
     Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
     Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
@@ -162,8 +174,9 @@ def test_buffer_broadcast_expr():
     s = te.create_schedule(C.op)
 
     def check_stride():
-        fadd = tvm.build(s, [A, B, C, o1, x], target='llvm', name='bcast_add',
-                         binds={A:Ab, B:Bb, C:Cc})
+        fadd = tvm.build(
+            s, [A, B, C, o1, x], target="llvm", name="bcast_add", binds={A: Ab, B: Bb, C: Cc}
+        )
         ctx = tvm.cpu(0)
         a = tvm.nd.array(np.random.uniform(size=(2, 4)).astype(A.dtype), ctx)
         b = tvm.nd.array(np.random.uniform(size=(2, 4)).astype(B.dtype), ctx)
@@ -172,8 +185,9 @@ def test_buffer_broadcast_expr():
         tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
     def check_no_stride():
-        fadd = tvm.build(s, [A, B, C, o1, x], target='llvm', name='bcast_add',
-                         binds={A: Ab, B: Bb, C: Cc})
+        fadd = tvm.build(
+            s, [A, B, C, o1, x], target="llvm", name="bcast_add", binds={A: Ab, B: Bb, C: Cc}
+        )
         ctx = tvm.cpu(0)
         a = tvm.nd.array(np.random.uniform(size=(1, 4)).astype(A.dtype), ctx)
         b = tvm.nd.array(np.random.uniform(size=(2, 4)).astype(B.dtype), ctx)
@@ -183,7 +197,7 @@ def test_buffer_broadcast_expr():
 
     def check_auto_bind():
         # Let build bind buffers
-        fadd = tvm.build(s, [A, B, C, o1, x], target='llvm', name='bcast_add')
+        fadd = tvm.build(s, [A, B, C, o1, x], target="llvm", name="bcast_add")
         ctx = tvm.cpu(0)
         a = tvm.nd.array(np.random.uniform(size=(1, 4)).astype(A.dtype), ctx)
         b = tvm.nd.array(np.random.uniform(size=(2, 4)).astype(B.dtype), ctx)
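For reference, the auto-broadcast buffer test reformatted above consolidated into one standalone, runnable piece (a sketch assuming an LLVM-enabled build): buffers declared with `buffer_type="auto_broadcast"` and passed through `binds` let a single compiled kernel accept broadcastable input shapes.

import numpy as np
import tvm
import tvm.testing
from tvm import te

m0, m1, m2 = te.size_var("m0"), te.size_var("m1"), te.size_var("m2")
n0, n1, n2 = te.size_var("n0"), te.size_var("n1"), te.size_var("n2")
o0, o1, o2 = te.size_var("o0"), te.size_var("o1"), te.size_var("o2")
A = te.placeholder((m0, m1, m2), name="A")
B = te.placeholder((n0, n1, n2), name="B")
C = te.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name="C")

Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
s = te.create_schedule(C.op)
fadd = tvm.build(s, [A, B, C], target="llvm", name="bcast_add", binds={A: Ab, B: Bb})

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(2, 1, 1)).astype(B.dtype), ctx)  # broadcast on the last two axes
c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx)
fadd(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())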
index 578e32f..3cde5d7 100644
 import tvm
 from tvm import te
 
+
 def test_expr_constructor():
     x = tvm.tir.Var("xx", "float32")
     assert isinstance(x, tvm.tir.Var)
     assert x.name == "xx"
 
-    x = tvm.tir.Reduce(None, [1],
-                       [tvm.tir.IterVar((0, 1), "x", 2)],
-                       None, 0)
+    x = tvm.tir.Reduce(None, [1], [tvm.tir.IterVar((0, 1), "x", 2)], None, 0)
     assert isinstance(x, tvm.tir.Reduce)
     assert x.combiner == None
     assert x.value_index == 0
@@ -51,28 +50,28 @@ def test_expr_constructor():
     a = tvm.tir.const(1.0, dtype="float32")
     b = te.var("x", dtype="float32")
 
-    for cls in [tvm.tir.Add,
-                tvm.tir.Sub,
-                tvm.tir.Mul,
-                tvm.tir.Div,
-                tvm.tir.Mod,
-                tvm.tir.Min,
-                tvm.tir.Max,
-                tvm.tir.LT,
-                tvm.tir.LE,
-                tvm.tir.GT,
-                tvm.tir.GE]:
+    for cls in [
+        tvm.tir.Add,
+        tvm.tir.Sub,
+        tvm.tir.Mul,
+        tvm.tir.Div,
+        tvm.tir.Mod,
+        tvm.tir.Min,
+        tvm.tir.Max,
+        tvm.tir.LT,
+        tvm.tir.LE,
+        tvm.tir.GT,
+        tvm.tir.GE,
+    ]:
         x = cls(a, b)
         assert isinstance(x, cls)
         assert x.a == a
         assert x.b.same_as(b)
 
-
     a = tvm.runtime.convert(te.var("x") > 1)
     b = tvm.runtime.convert(te.var("x") == 1)
 
-    for cls in [tvm.tir.And,
-                tvm.tir.Or]:
+    for cls in [tvm.tir.And, tvm.tir.Or]:
         x = cls(a, b)
         assert isinstance(x, cls)
         assert x.a == a
@@ -139,9 +138,7 @@ def test_stmt_constructor():
     assert isinstance(x, tvm.tir.AttrStmt)
     assert x.value.value == 1
 
-    x = tvm.tir.AssertStmt(tvm.tir.const(1, "uint1"),
-                            tvm.runtime.convert("hellow"),
-                            nop)
+    x = tvm.tir.AssertStmt(tvm.tir.const(1, "uint1"), tvm.runtime.convert("hellow"), nop)
     assert isinstance(x, tvm.tir.AssertStmt)
     assert x.body == nop
 
@@ -157,8 +154,7 @@ def test_stmt_constructor():
     assert x.index.value == 10
     assert x.value.value == 1
 
-    x = tvm.tir.Allocate(buffer_var, "float32", [10],
-                          tvm.tir.const(1, "uint1"), nop)
+    x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop)
     assert isinstance(x, tvm.tir.Allocate)
     assert x.dtype == "float32"
     assert x.buffer_var == buffer_var
@@ -170,9 +166,7 @@ def test_stmt_constructor():
     assert x.attr_key == "xyz"
     assert x.body == nop
 
-    x = tvm.tir.IfThenElse(tvm.tir.const(1, "uint1"),
-                            tvm.tir.Evaluate(11),
-                            nop)
+    x = tvm.tir.IfThenElse(tvm.tir.const(1, "uint1"), tvm.tir.Evaluate(11), nop)
     assert isinstance(x, tvm.tir.IfThenElse)
     assert x.then_case.value.value == 11
     assert x.else_case == nop
index c3a6661..22c24fa 100644
@@ -20,6 +20,7 @@ import tvm
 from tvm import te
 from tvm.topi.util import get_const_tuple
 
+
 def test_layout():
     layout = tvm.tir.layout("NCHW16c")
     assert layout is not None
@@ -50,6 +51,7 @@ def test_layout():
     assert layout[4] == "c"
     assert layout[-1] == "c"
 
+
 def test_bilayout_convertible():
     # not convertible
     assert tvm.tir.bijective_layout("NCHW", "ABCD") is None
@@ -62,6 +64,7 @@ def test_bilayout_convertible():
     # convertible
     assert tvm.tir.bijective_layout("NCHW", "NCHW16c") is not None
 
+
 def test_bilayout_shape():
     bilayout = tvm.tir.bijective_layout("NCHW", "NCHW16c")
     assert isinstance(bilayout, tvm.tir.BijectiveLayout)
@@ -72,6 +75,7 @@ def test_bilayout_shape():
     src_shape = bilayout.backward_shape(dst_shape)
     assert get_const_tuple(src_shape) == (1, 32, 7, 7)
 
+
 def test_bilayout_index():
     bilayout = tvm.tir.bijective_layout("NCHW", "NCHW16c")
 
@@ -81,6 +85,7 @@ def test_bilayout_index():
     src_index = bilayout.backward_index([0, 1, 6, 6, 2])
     assert get_const_tuple(src_index) == (0, 18, 6, 6)
 
+
 if __name__ == "__main__":
     test_layout()
     test_bilayout_convertible()
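Condensed from the assertions above, the layout conversions these tests exercise in one reference sketch (illustration only, not part of the patch):

import tvm
from tvm.topi.util import get_const_tuple

bilayout = tvm.tir.bijective_layout("NCHW", "NCHW16c")

# Shapes: 32 channels split into 2 outer blocks of 16 inner channels.
dst_shape = bilayout.forward_shape((1, 32, 7, 7))
assert get_const_tuple(dst_shape) == (1, 2, 7, 7, 16)
assert get_const_tuple(bilayout.backward_shape(dst_shape)) == (1, 32, 7, 7)

# Indices: channel 18 maps to block 1, offset 2 (18 = 1 * 16 + 2).
assert get_const_tuple(bilayout.forward_index([0, 18, 6, 6])) == (0, 1, 6, 6, 2)
assert get_const_tuple(bilayout.backward_index([0, 1, 6, 6, 2])) == (0, 18, 6, 6)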
index 0920603..0a1fe56 100644
@@ -24,16 +24,17 @@ import math
 
 
 def test_nearbyint():
-    m = te.var("m",)
-    A = te.placeholder((m,), name='A')
-    A_rounded = te.compute((m,), lambda *i: tvm.tir.nearbyint(A(*i)), name='A')
+    m = te.var(
+        "m",
+    )
+    A = te.placeholder((m,), name="A")
+    A_rounded = te.compute((m,), lambda *i: tvm.tir.nearbyint(A(*i)), name="A")
     s = te.create_schedule(A_rounded.op)
     f = tvm.build(s, [A, A_rounded], "llvm")
     ctx = tvm.cpu(0)
     n = 10
     a = tvm.nd.array(np.random.uniform(high=100, size=n).astype(A.dtype), ctx)
-    a_rounded = tvm.nd.array( \
-            np.random.uniform(size=n).astype(A_rounded.dtype), ctx)
+    a_rounded = tvm.nd.array(np.random.uniform(size=n).astype(A_rounded.dtype), ctx)
     f(a, a_rounded)
     # Note that numpy's rint rounds to the nearest integer, with
     # ties to halfway broken by rounding to even.
@@ -41,49 +42,49 @@ def test_nearbyint():
     # This is the default rounding mode with libc as well.
     # However one can set a different rounding mode and in that
     # case numpy result might differ.
-    tvm.testing.assert_allclose(
-        a_rounded.asnumpy(), np.rint(a.asnumpy()))
+    tvm.testing.assert_allclose(a_rounded.asnumpy(), np.rint(a.asnumpy()))
+
 
 def test_round_intrinsics_on_int():
-    i = tvm.te.var("i", 'int32')
-    for op in [tvm.tir.round, tvm.tir.trunc, tvm.tir.ceil,
-                            tvm.tir.floor, tvm.tir.nearbyint]:
-        assert op(tvm.tir.const(10,'int32')).value == 10
-        assert op(tvm.tir.const(True,'bool')).value == True
+    i = tvm.te.var("i", "int32")
+    for op in [tvm.tir.round, tvm.tir.trunc, tvm.tir.ceil, tvm.tir.floor, tvm.tir.nearbyint]:
+        assert op(tvm.tir.const(10, "int32")).value == 10
+        assert op(tvm.tir.const(True, "bool")).value == True
         assert op(i).same_as(i)
 
-    assert tvm.tir.isnan(tvm.tir.const(10, 'int32')).value == False
+    assert tvm.tir.isnan(tvm.tir.const(10, "int32")).value == False
 
 
 def test_unary_intrin():
     test_funcs = [
-        (tvm.tir.exp10, lambda x : np.power(10, x)),
-        (tvm.tir.log2, lambda x : np.log2(x)),
-        (tvm.tir.log10, lambda x : np.log10(x)),
-        (tvm.tir.sinh, lambda x : np.sinh(x)),
-        (tvm.tir.cosh, lambda x : np.cosh(x)),
-        (tvm.tir.log1p, lambda x : np.log1p(x)),
-        (tvm.tir.asin, lambda x : np.arcsin(x)),
-        (tvm.tir.acos, lambda x : np.arccos(x)),
-        (tvm.tir.atan, lambda x : np.arctan(x)),
-        (tvm.tir.asinh, lambda x : np.arcsinh(x)),
-        (tvm.tir.acosh, lambda x : np.arccosh(x)),
-        (tvm.tir.atanh, lambda x : np.arctanh(x)),
+        (tvm.tir.exp10, lambda x: np.power(10, x)),
+        (tvm.tir.log2, lambda x: np.log2(x)),
+        (tvm.tir.log10, lambda x: np.log10(x)),
+        (tvm.tir.sinh, lambda x: np.sinh(x)),
+        (tvm.tir.cosh, lambda x: np.cosh(x)),
+        (tvm.tir.log1p, lambda x: np.log1p(x)),
+        (tvm.tir.asin, lambda x: np.arcsin(x)),
+        (tvm.tir.acos, lambda x: np.arccos(x)),
+        (tvm.tir.atan, lambda x: np.arctan(x)),
+        (tvm.tir.asinh, lambda x: np.arcsinh(x)),
+        (tvm.tir.acosh, lambda x: np.arccosh(x)),
+        (tvm.tir.atanh, lambda x: np.arctanh(x)),
     ]
+
     def run_test(tvm_intrin, np_func):
-        m = te.var("m",)
-        A = te.placeholder((m,), name='A')
-        B = te.compute((m,), lambda *i: tvm_intrin(A(*i)), name='B')
+        m = te.var(
+            "m",
+        )
+        A = te.placeholder((m,), name="A")
+        B = te.compute((m,), lambda *i: tvm_intrin(A(*i)), name="B")
         s = te.create_schedule(B.op)
         f = tvm.build(s, [A, B], "llvm")
         ctx = tvm.cpu(0)
         n = 10
         a = tvm.nd.array(np.random.uniform(0.1, 0.5, size=n).astype(A.dtype), ctx)
-        b = tvm.nd.array( \
-            np.random.uniform(size=n).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
         f(a, b)
-        tvm.testing.assert_allclose(
-            b.asnumpy(), np_func(a.asnumpy()), atol=1e-5, rtol=1e-5)
+        tvm.testing.assert_allclose(b.asnumpy(), np_func(a.asnumpy()), atol=1e-5, rtol=1e-5)
 
     for func in test_funcs:
         run_test(*func)
@@ -91,37 +92,42 @@ def test_unary_intrin():
 
 def test_binary_intrin():
     test_funcs = [
-        (tvm.tir.atan2, lambda x1, x2 : np.arctan2(x1, x2)),
-        (tvm.tir.nextafter, lambda x1, x2 : np.nextafter(x1, x2)),
-        (tvm.tir.copysign, lambda x1, x2 : np.copysign(x1, x2)),
-        (tvm.tir.hypot, lambda x1, x2 : np.hypot(x1, x2)),
+        (tvm.tir.atan2, lambda x1, x2: np.arctan2(x1, x2)),
+        (tvm.tir.nextafter, lambda x1, x2: np.nextafter(x1, x2)),
+        (tvm.tir.copysign, lambda x1, x2: np.copysign(x1, x2)),
+        (tvm.tir.hypot, lambda x1, x2: np.hypot(x1, x2)),
     ]
+
     def run_test(tvm_intrin, np_func):
-        m = te.var("m",)
-        A = te.placeholder((m,), name='A')
-        B = te.placeholder((m,), name='B')
-        C = te.compute((m,), lambda *i: tvm_intrin(A(*i), B(*i)), name='C')
+        m = te.var(
+            "m",
+        )
+        A = te.placeholder((m,), name="A")
+        B = te.placeholder((m,), name="B")
+        C = te.compute((m,), lambda *i: tvm_intrin(A(*i), B(*i)), name="C")
         s = te.create_schedule(C.op)
         f = tvm.build(s, [A, B, C], "llvm")
         ctx = tvm.cpu(0)
         n = 10
         a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), ctx)
         b = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(B.dtype), ctx)
-        c = tvm.nd.array( \
-            np.random.uniform(size=n).astype(A.dtype), ctx)
+        c = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
         f(a, b, c)
         tvm.testing.assert_allclose(
-            c.asnumpy(), np_func(a.asnumpy(), b.asnumpy()), atol=1e-5, rtol=1e-5)
+            c.asnumpy(), np_func(a.asnumpy(), b.asnumpy()), atol=1e-5, rtol=1e-5
+        )
 
     for func in test_funcs:
         run_test(*func)
 
 
 def test_ldexp():
-    m = te.var("m",)
-    A = te.placeholder((m,), name='A')
-    B = te.placeholder((m,), name='B', dtype="int32")
-    C = te.compute((m,), lambda *i: tvm.tir.ldexp(A(*i), B(*i)), name='C')
+    m = te.var(
+        "m",
+    )
+    A = te.placeholder((m,), name="A")
+    B = te.placeholder((m,), name="B", dtype="int32")
+    C = te.compute((m,), lambda *i: tvm.tir.ldexp(A(*i), B(*i)), name="C")
     s = te.create_schedule(C.op)
     f = tvm.build(s, [A, B, C], "llvm")
     ctx = tvm.cpu(0)
@@ -131,7 +137,8 @@ def test_ldexp():
     c = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
     f(a, b, c)
     tvm.testing.assert_allclose(
-        c.asnumpy(), np.ldexp(a.asnumpy(), b.asnumpy()), atol=1e-5, rtol=1e-5)
+        c.asnumpy(), np.ldexp(a.asnumpy(), b.asnumpy()), atol=1e-5, rtol=1e-5
+    )
 
 
 if __name__ == "__main__":
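The three-line spelling of `te.var` in several hunks above is most likely black's magic-trailing-comma behaviour rather than a deliberate layout choice: the original source reads `te.var("m",)`, and black keeps any call that already ends in a trailing comma exploded one argument per line. A small illustration, not part of the patch:

from tvm import te

m = te.var("m")  # no trailing comma: black keeps the call on one line
m = te.var(
    "m",
)  # pre-existing trailing comma: black preserves the exploded layout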
index 7664806..b84ee09 100644
@@ -19,6 +19,7 @@ from tvm import te
 import numpy as np
 import tvm.testing
 
+
 def test_for():
     ib = tvm.tir.ir_builder.create()
     n = te.size_var("n")
@@ -38,6 +39,7 @@ def test_for():
     assert isinstance(body, tvm.tir.SeqStmt)
     assert isinstance(body[1], tvm.tir.For)
 
+
 def test_if():
     ib = tvm.tir.ir_builder.create()
     n = te.size_var("n")
@@ -58,6 +60,7 @@ def test_if():
     assert isinstance(body.then_case.index, tvm.tir.Var)
     assert body.else_case.index.value == 0
 
+
 def test_prefetch():
     A = tvm.tir.decl_buffer((10, 20), name="A")
     ib = tvm.tir.ir_builder.create()
@@ -65,17 +68,20 @@ def test_prefetch():
 
     with ib.for_range(0, n, name="i") as i:
         ib.emit(
-            tvm.tir.Prefetch(A,
-                [tvm.ir.Range.from_min_extent(i+1, 2),
-                 tvm.ir.Range.from_min_extent(0, 20)]))
+            tvm.tir.Prefetch(
+                A, [tvm.ir.Range.from_min_extent(i + 1, 2), tvm.ir.Range.from_min_extent(0, 20)]
+            )
+        )
     body = ib.get()
     assert body.body.bounds[0].extent.value == 2
 
+
 def test_cpu():
     n = 1024
     dtype = "float32"
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
+
     def test_device_ir(A, B, C):
         n = A.shape[0]
         max_threads = 8
@@ -87,9 +93,16 @@ def test_cpu():
             Cptr[i] = Aptr[i] + Bptr[i]
         body = ib.get()
         return body
-    C = te.extern(A.shape, [A, B], lambda ins, outs: test_device_ir(ins[0], ins[1], outs[0]),
-                   name="vector_add", dtype=dtype)
+
+    C = te.extern(
+        A.shape,
+        [A, B],
+        lambda ins, outs: test_device_ir(ins[0], ins[1], outs[0]),
+        name="vector_add",
+        dtype=dtype,
+    )
     s = te.create_schedule(C.op)
+
     def check_target(target):
         if not tvm.testing.device_enabled(target):
             return
@@ -102,14 +115,16 @@ def test_cpu():
         c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
         fadd(a, b, c)
         tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
     check_target("llvm")
 
+
 @tvm.testing.requires_gpu
 def test_gpu():
-    n = te.size_var('n')
+    n = te.size_var("n")
     dtype = "float32"
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
     idxd = tvm.tir.indexdiv
 
     def test_device_ir(A, B, C):
@@ -118,21 +133,28 @@ def test_gpu():
         ib = tvm.tir.ir_builder.create()
         bx = te.thread_axis("blockIdx.x")
         tx = te.thread_axis("threadIdx.x")
-        ib.scope_attr(bx, "thread_extent", idxd(n+max_threads-1, max_threads))
+        ib.scope_attr(bx, "thread_extent", idxd(n + max_threads - 1, max_threads))
         ib.scope_attr(tx, "thread_extent", max_threads)
         idx = bx.var * max_threads + tx.var
         Aptr = ib.buffer_ptr(A)
         Bptr = ib.buffer_ptr(B)
         Cptr = ib.buffer_ptr(C)
-        with ib.if_scope(ib.likely(idx<n)):
+        with ib.if_scope(ib.likely(idx < n)):
             Cptr[idx] = Aptr[idx] + Bptr[idx]
         body = ib.get()
         return body
-    C = te.extern(A.shape, [A, B], lambda ins, outs: test_device_ir(ins[0], ins[1], outs[0]),
-                   name="vector_add", dtype=dtype)
+
+    C = te.extern(
+        A.shape,
+        [A, B],
+        lambda ins, outs: test_device_ir(ins[0], ins[1], outs[0]),
+        name="vector_add",
+        dtype=dtype,
+    )
     s = te.create_schedule(C.op)
     bounds = tvm.te.schedule.InferBound(s)
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
+
     def check_target(target):
         n = 1024
         if not tvm.testing.device_enabled(target):
@@ -146,9 +168,11 @@ def test_gpu():
         c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
         fadd(a, b, c)
         tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
     check_target("opencl")
     check_target("cuda")
 
+
 if __name__ == "__main__":
     test_prefetch()
     test_if()
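The CPU and GPU hunks above reflow the same `ir_builder` plus `te.extern` recipe. A condensed sketch of the CPU path, pulled together from the lines above and assuming an LLVM-enabled build:

import numpy as np
import tvm
import tvm.testing
from tvm import te

n = 1024
A = te.placeholder((n,), name="A")
B = te.placeholder((n,), name="B")


def vector_add_ir(a_buf, b_buf, c_buf):
    # Emit a scalar loop with ir_builder; te.extern wires it in as C's body.
    ib = tvm.tir.ir_builder.create()
    Aptr, Bptr, Cptr = ib.buffer_ptr(a_buf), ib.buffer_ptr(b_buf), ib.buffer_ptr(c_buf)
    with ib.for_range(0, a_buf.shape[0], name="i") as i:
        Cptr[i] = Aptr[i] + Bptr[i]
    return ib.get()


C = te.extern(
    A.shape,
    [A, B],
    lambda ins, outs: vector_add_ir(ins[0], ins[1], outs[0]),
    name="vector_add",
    dtype="float32",
)
s = te.create_schedule(C.op)
fadd = tvm.build(s, [A, B, C], target="llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
fadd(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())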
index c182d9e..4d57ed8 100644
@@ -19,7 +19,6 @@ from tvm import te
 import numpy as np
 
 
-
 def test_const():
     x = tvm.tir.const(1, "int32")
     print(x.dtype)
@@ -28,19 +27,43 @@ def test_const():
 
 
 def test_scalar_dtype_inference():
-    for data in [True, np.bool(1), np.uint8(1), np.uint16(1), np.uint32(1), np.uint64(1),
-                 np.int8(1), np.int16(1), np.int32(1), np.int64(1),
-                 np.float16(1), np.float32(1), np.float64(1)]:
+    for data in [
+        True,
+        np.bool(1),
+        np.uint8(1),
+        np.uint16(1),
+        np.uint32(1),
+        np.uint64(1),
+        np.int8(1),
+        np.int16(1),
+        np.int32(1),
+        np.int64(1),
+        np.float16(1),
+        np.float32(1),
+        np.float64(1),
+    ]:
         assert tvm.tir.const(data).dtype == str(np.array(data).dtype)
-    assert tvm.tir.const(1).dtype == 'int32'
-    assert tvm.tir.const(1.0).dtype == 'float32'
-
-    for data in [True, np.bool(1), np.uint8(1), np.uint16(1), np.uint32(1), np.uint64(1),
-                 np.int8(1), np.int16(1), np.int32(1), np.int64(1),
-                 np.float16(1), np.float32(1), np.float64(1)]:
+    assert tvm.tir.const(1).dtype == "int32"
+    assert tvm.tir.const(1.0).dtype == "float32"
+
+    for data in [
+        True,
+        np.bool(1),
+        np.uint8(1),
+        np.uint16(1),
+        np.uint32(1),
+        np.uint64(1),
+        np.int8(1),
+        np.int16(1),
+        np.int32(1),
+        np.int64(1),
+        np.float16(1),
+        np.float32(1),
+        np.float64(1),
+    ]:
         assert tvm.runtime.convert(data).dtype == str(np.array(data).dtype)
-    assert tvm.runtime.convert(1).dtype == 'int32'
-    assert tvm.runtime.convert(1.0).dtype == 'float32'
+    assert tvm.runtime.convert(1).dtype == "int32"
+    assert tvm.runtime.convert(1.0).dtype == "float32"
 
 
 def test_make():
@@ -53,7 +76,7 @@ def test_make():
 
 def test_ir():
     x = tvm.tir.const(1, "int32")
-    y = tvm.tir.IntImm('int32', 1)
+    y = tvm.tir.IntImm("int32", 1)
     z = x + y
     stmt = tvm.tir.Evaluate(z)
     assert isinstance(stmt, tvm.tir.Evaluate)
@@ -64,18 +87,17 @@ def test_ir2():
     a = te.var("array", "handle")
     st = tvm.tir.Store(a, x + 1, 1)
     assert isinstance(st, tvm.tir.Store)
-    assert(st.buffer_var == a)
+    assert st.buffer_var == a
 
 
 def test_let():
-    x = te.var('x')
-    y = te.var('y')
-    stmt = tvm.tir.LetStmt(
-        x, 10, tvm.tir.Evaluate(x + 1));
+    x = te.var("x")
+    y = te.var("y")
+    stmt = tvm.tir.LetStmt(x, 10, tvm.tir.Evaluate(x + 1))
 
 
 def test_cast():
-    x = te.var('x', dtype="float32")
+    x = te.var("x", dtype="float32")
     y = x.astype("int32")
     z = x.astype("float32x4")
     assert isinstance(y, tvm.tir.Cast)
@@ -84,10 +106,9 @@ def test_cast():
 
 
 def test_attr():
-    x = te.var('x')
-    y = te.var('y')
-    stmt = tvm.tir.AttrStmt(
-        y, "stride", 10, tvm.tir.Evaluate(x + 1));
+    x = te.var("x")
+    y = te.var("y")
+    stmt = tvm.tir.AttrStmt(y, "stride", 10, tvm.tir.Evaluate(x + 1))
     assert stmt.node == y
 
     a = tvm.runtime.convert(1)
@@ -100,34 +121,33 @@ def test_attr():
 
 
 def test_basic():
-    a = te.var('a')
-    b = te.var('b')
-    c =  a + b
-    assert str(c) == '(%s: int32 + %s: int32)' % (a.name, b.name)
+    a = te.var("a")
+    b = te.var("b")
+    c = a + b
+    assert str(c) == "(%s: int32 + %s: int32)" % (a.name, b.name)
 
 
 def test_stmt():
     x = tvm.tir.Evaluate(0)
-    tvm.tir.For(te.var('i'), 0, 1,
-                 tvm.tir.For.Serial, 0,
-                 x)
+    tvm.tir.For(te.var("i"), 0, 1, tvm.tir.For.Serial, 0, x)
+
 
 def test_dir():
-    x = te.var('x')
+    x = te.var("x")
     dir(x)
 
 
 def test_dtype():
-    x = te.var('x')
-    assert x.dtype == 'int32'
-    y = te.var('y')
-    assert (x > y).dtype == 'bool'
+    x = te.var("x")
+    assert x.dtype == "int32"
+    y = te.var("y")
+    assert (x > y).dtype == "bool"
 
 
 def test_any():
-    x = te.var('x')
-    y = te.var('y')
-    z = te.var('z')
+    x = te.var("x")
+    y = te.var("y")
+    z = te.var("z")
     try:
         t = x or x
         assert False
@@ -138,18 +158,29 @@ def test_any():
         assert False
     except ValueError:
         pass
-    assert str(tvm.tir.any(x < y)) == '(%s: int32 < %s: int32)' % (x.name, y.name)
-    assert str(tvm.tir.any(x < y, x > z)) == '((%s: int32 < %s: int32) || (%s > %s: int32))' % (
-        x.name, y.name, x.name, z.name)
-    assert str(tvm.tir.any(x < y, y > z + 1, x < z * 2)) == \
-        '(((%s: int32 < %s: int32) || (%s > (%s: int32 + 1))) || (%s < (%s*2)))' % (
-            x.name, y.name, y.name, z.name, x.name, z.name)
+    assert str(tvm.tir.any(x < y)) == "(%s: int32 < %s: int32)" % (x.name, y.name)
+    assert str(tvm.tir.any(x < y, x > z)) == "((%s: int32 < %s: int32) || (%s > %s: int32))" % (
+        x.name,
+        y.name,
+        x.name,
+        z.name,
+    )
+    assert str(
+        tvm.tir.any(x < y, y > z + 1, x < z * 2)
+    ) == "(((%s: int32 < %s: int32) || (%s > (%s: int32 + 1))) || (%s < (%s*2)))" % (
+        x.name,
+        y.name,
+        y.name,
+        z.name,
+        x.name,
+        z.name,
+    )
 
 
 def test_all():
-    x = te.var('x')
-    y = te.var('y')
-    z = te.var('z')
+    x = te.var("x")
+    y = te.var("y")
+    z = te.var("z")
     try:
         t = x and x
         assert False
@@ -160,44 +191,57 @@ def test_all():
         assert False
     except ValueError:
         pass
-    assert str(tvm.tir.all(x < y)) == '(%s: int32 < %s: int32)' % (x.name, y.name)
-    assert str(tvm.tir.all(x < y, x > z)) == '((%s: int32 < %s: int32) && (%s > %s: int32))' % (
-        x.name, y.name, x.name, z.name)
-    assert str(tvm.tir.all(x < y, y > z + 1, x < z * 2)) == \
-        '(((%s: int32 < %s: int32) && (%s > (%s: int32 + 1))) && (%s < (%s*2)))' % (
-            x.name, y.name, y.name, z.name, x.name, z.name)
+    assert str(tvm.tir.all(x < y)) == "(%s: int32 < %s: int32)" % (x.name, y.name)
+    assert str(tvm.tir.all(x < y, x > z)) == "((%s: int32 < %s: int32) && (%s > %s: int32))" % (
+        x.name,
+        y.name,
+        x.name,
+        z.name,
+    )
+    assert str(
+        tvm.tir.all(x < y, y > z + 1, x < z * 2)
+    ) == "(((%s: int32 < %s: int32) && (%s > (%s: int32 + 1))) && (%s < (%s*2)))" % (
+        x.name,
+        y.name,
+        y.name,
+        z.name,
+        x.name,
+        z.name,
+    )
 
 
 def test_bitwise():
-    x = te.var('x')
-    y = te.var('y')
-    assert str(x << y) == '@tir.shift_left(x: int32, y: int32, dtype=int32)'
-    assert str(x >> y) == '@tir.shift_right(x: int32, y: int32, dtype=int32)'
-    assert str(x & y) == '@tir.bitwise_and(x: int32, y: int32, dtype=int32)'
-    assert str(x | y) == '@tir.bitwise_or(x: int32, y: int32, dtype=int32)'
-    assert str(x ^ y) == '@tir.bitwise_xor(x: int32, y: int32, dtype=int32)'
-    assert str(10 & x) == '@tir.bitwise_and(10, x: int32, dtype=int32)'
-    assert str(10 | x) == '@tir.bitwise_or(10, x: int32, dtype=int32)'
-    assert str(10 ^ x) == '@tir.bitwise_xor(10, x: int32, dtype=int32)'
-    assert str(10 >> x) == '@tir.shift_right(10, x: int32, dtype=int32)'
-    assert str(10 << x) == '@tir.shift_left(10, x: int32, dtype=int32)'
-    assert str(10 % x) == 'floormod(10, x: int32)'
-
-    assert str(~x) == '@tir.bitwise_not(x: int32, dtype=int32)'
-    assert(tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2"
-    assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2"
-    assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2"
+    x = te.var("x")
+    y = te.var("y")
+    assert str(x << y) == "@tir.shift_left(x: int32, y: int32, dtype=int32)"
+    assert str(x >> y) == "@tir.shift_right(x: int32, y: int32, dtype=int32)"
+    assert str(x & y) == "@tir.bitwise_and(x: int32, y: int32, dtype=int32)"
+    assert str(x | y) == "@tir.bitwise_or(x: int32, y: int32, dtype=int32)"
+    assert str(x ^ y) == "@tir.bitwise_xor(x: int32, y: int32, dtype=int32)"
+    assert str(10 & x) == "@tir.bitwise_and(10, x: int32, dtype=int32)"
+    assert str(10 | x) == "@tir.bitwise_or(10, x: int32, dtype=int32)"
+    assert str(10 ^ x) == "@tir.bitwise_xor(10, x: int32, dtype=int32)"
+    assert str(10 >> x) == "@tir.shift_right(10, x: int32, dtype=int32)"
+    assert str(10 << x) == "@tir.shift_left(10, x: int32, dtype=int32)"
+    assert str(10 % x) == "floormod(10, x: int32)"
+
+    assert str(~x) == "@tir.bitwise_not(x: int32, dtype=int32)"
+    assert (tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2"
+    assert (x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2"
+    assert (te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2"
 
 
 def test_float_bitwise():
-    t = tvm.tir.const(1.5,dtype='float32')
-    for test in [lambda lhs, rhs : lhs << rhs,
-                    lambda lhs, rhs : lhs >> rhs,
-                    lambda lhs, rhs : lhs | rhs,
-                    lambda lhs, rhs : lhs ^ rhs,
-                    lambda lhs, rhs : lhs & rhs]:
+    t = tvm.tir.const(1.5, dtype="float32")
+    for test in [
+        lambda lhs, rhs: lhs << rhs,
+        lambda lhs, rhs: lhs >> rhs,
+        lambda lhs, rhs: lhs | rhs,
+        lambda lhs, rhs: lhs ^ rhs,
+        lambda lhs, rhs: lhs & rhs,
+    ]:
         try:
-            test(t,10.0)
+            test(t, 10.0)
             assert False
         except tvm.TVMError:
             pass
@@ -209,71 +253,71 @@ def test_float_bitwise():
 
 
 def test_shift_bounds():
-    x = te.var('x')
-    for test in [lambda lhs, rhs : lhs << rhs,
-                    lambda lhs, rhs : lhs >> rhs]:
-        #negative case
-        for testcase in [(x,-1), (x,32)]:
+    x = te.var("x")
+    for test in [lambda lhs, rhs: lhs << rhs, lambda lhs, rhs: lhs >> rhs]:
+        # negative case
+        for testcase in [(x, -1), (x, 32)]:
             try:
                 test(*testcase)
                 assert False
             except tvm.TVMError:
                 pass
 
-        #positive case
-        for testcase in [(x,0), (x,16), (x,31)]:
+        # positive case
+        for testcase in [(x, 0), (x, 16), (x, 31)]:
             test(*testcase)
 
 
 def test_divide_by_zero():
-    for test in [lambda lhs, rhs : tvm.tir.floormod(lhs,rhs),
-                    lambda lhs, rhs : tvm.tir.floordiv(lhs,rhs),
-                    lambda lhs, rhs : tvm.tir.truncmod(lhs,rhs),
-                    lambda lhs, rhs : tvm.tir.truncdiv(lhs,rhs),
-                    lambda lhs, rhs : tvm.tir.div(lhs,rhs)]:
+    for test in [
+        lambda lhs, rhs: tvm.tir.floormod(lhs, rhs),
+        lambda lhs, rhs: tvm.tir.floordiv(lhs, rhs),
+        lambda lhs, rhs: tvm.tir.truncmod(lhs, rhs),
+        lambda lhs, rhs: tvm.tir.truncdiv(lhs, rhs),
+        lambda lhs, rhs: tvm.tir.div(lhs, rhs),
+    ]:
         try:
-            test(tvm.tir.const(5,'int32'), tvm.tir.const(0,'int32'))
+            test(tvm.tir.const(5, "int32"), tvm.tir.const(0, "int32"))
             assert False
         except tvm.TVMError:
             pass
 
 
 def test_isnan():
-    x = te.var('x', 'float32')
-    assert str(tvm.tir.isnan(x)) == '@tir.isnan(x: float32, dtype=bool)'
-    assert str(tvm.tir.isnan(x).dtype) == 'bool'
-    y = te.var('y', 'float16')
-    assert str(tvm.tir.isnan(y)) == '@tir.isnan(cast(float32, y: float16), dtype=bool)'
-    z = te.var('z', 'int32')
-    assert str(tvm.tir.isnan(z)) == 'False'
-    k = te.var('k', 'int8x2')
-    assert str(tvm.tir.isnan(k).dtype) == 'uint1x2'
+    x = te.var("x", "float32")
+    assert str(tvm.tir.isnan(x)) == "@tir.isnan(x: float32, dtype=bool)"
+    assert str(tvm.tir.isnan(x).dtype) == "bool"
+    y = te.var("y", "float16")
+    assert str(tvm.tir.isnan(y)) == "@tir.isnan(cast(float32, y: float16), dtype=bool)"
+    z = te.var("z", "int32")
+    assert str(tvm.tir.isnan(z)) == "False"
+    k = te.var("k", "int8x2")
+    assert str(tvm.tir.isnan(k).dtype) == "uint1x2"
 
 
 def test_equality():
-    a = te.var('a')
-    b = te.var('b')
-    c = (a == b)
+    a = te.var("a")
+    b = te.var("b")
+    c = a == b
     assert not c
-    d = (c != c)
+    d = c != c
     assert not d
 
 
 def test_equality_string_imm():
-    x = 'a'
+    x = "a"
     y = tvm.tir.StringImm(x)
     x == y.value
     x == y
 
+
 def test_prim_func():
-    x = te.var('x')
-    y = te.var('y')
+    x = te.var("x")
+    y = te.var("y")
     b = tvm.tir.decl_buffer((x,), "float32")
-    stmt = tvm.tir.LetStmt(
-        x, 10, tvm.tir.Evaluate(x + 1));
+    stmt = tvm.tir.LetStmt(x, 10, tvm.tir.Evaluate(x + 1))
 
-    func = tvm.tir.PrimFunc(
-        [x, y, b], stmt)
+    func = tvm.tir.PrimFunc([x, y, b], stmt)
     # make sure we can print
     func.astext()
     assert func.buffer_map[func.params[2]].same_as(b)
@@ -303,8 +347,7 @@ def test_buffer_load_store():
     s = tvm.tir.BufferStore(b, 0.1, [0])
     assert isinstance(s, tvm.tir.BufferStore)
 
-    s = tvm.tir.BufferRealize(b, [tvm.ir.Range(0, 1)],
-                              True, tvm.tir.Evaluate(0))
+    s = tvm.tir.BufferRealize(b, [tvm.ir.Range(0, 1)], True, tvm.tir.Evaluate(0))
     assert isinstance(s, tvm.tir.BufferRealize)
 
 
index 65d87be..f1f8cf7 100644
@@ -17,6 +17,7 @@
 import tvm
 from tvm import te
 
+
 def check_throws(f):
     try:
         f()
@@ -59,24 +60,30 @@ def test_const_fold2():
     assert (1 * x).same_as(x)
     assert isinstance(tdiv(1, x), tvm.tir.Div)
 
+
 def test_const_fold3():
     # Test that using ints with logic operations is forbidden
     x = te.var("x")
     for val in [0, 1]:
         for func in [tvm.tir.all, tvm.tir.any]:
-            check_throws(lambda: func(tvm.tir.const(val, 'uint1'), x))
-            check_throws(lambda: func(x, tvm.tir.const(val, 'uint1')))
+            check_throws(lambda: func(tvm.tir.const(val, "uint1"), x))
+            check_throws(lambda: func(x, tvm.tir.const(val, "uint1")))
 
     # Test const folding when both arguments are const
-    for tvm_func, py_func in [(tvm.tir.all, lambda a, b: a and b), (tvm.tir.any, lambda a, b: a or b)]:
+    for tvm_func, py_func in [
+        (tvm.tir.all, lambda a, b: a and b),
+        (tvm.tir.any, lambda a, b: a or b),
+    ]:
         for v1 in [0, 1]:
             for v2 in [0, 1]:
-                assert tvm.ir.structural_equal(tvm_func(tvm.tir.const(v1, 'uint1'), tvm.tir.const(v2, 'uint1')),
-                                         tvm.tir.const(py_func(v1, v2), 'uint1'))
+                assert tvm.ir.structural_equal(
+                    tvm_func(tvm.tir.const(v1, "uint1"), tvm.tir.const(v2, "uint1")),
+                    tvm.tir.const(py_func(v1, v2), "uint1"),
+                )
 
-    x = te.var("x", 'uint1')
-    true = tvm.tir.const(1, 'uint1')
-    false = tvm.tir.const(0, 'uint1')
+    x = te.var("x", "uint1")
+    true = tvm.tir.const(1, "uint1")
+    false = tvm.tir.const(0, "uint1")
 
     assert tvm.tir.all(x, true).same_as(x)
     assert tvm.tir.all(true, x).same_as(x)
@@ -100,46 +107,48 @@ def test_const_fold4():
     assert isinstance(x4, tvm.tir.FloatImm) and abs(x4.value - 3.55) < 1e-6
     x5 = te.ceil(x4)
     assert isinstance(x5, tvm.tir.FloatImm) and x5.value == 4
-    x6 = x5.astype('int')
+    x6 = x5.astype("int")
     assert isinstance(x6, tvm.tir.IntImm) and x6.value == 4, "x6={}".format(x6)
-    y = (te.round((tvm.tir.const(6.5, 'float32') - 1) / 1.5) + 2).astype('int')
+    y = (te.round((tvm.tir.const(6.5, "float32") - 1) / 1.5) + 2).astype("int")
     assert isinstance(y, tvm.tir.IntImm) and y.value == 6
 
 
 def test_binary_dtype_match():
     def verify_general_dtype_support(f, is_conditional=False):
-        rules = [[('bool', 'int32'), 'int32'],
-                 [('int32', 'float32'), 'float32'],
-                 [('int32', 'int64'), 'int64'],
-                 [('uint32', 'int32'), 'int32']]
+        rules = [
+            [("bool", "int32"), "int32"],
+            [("int32", "float32"), "float32"],
+            [("int32", "int64"), "int64"],
+            [("uint32", "int32"), "int32"],
+        ]
         for (lhs_dtype, rhs_dtype), out_dtype in rules:
-            lhs = te.var('lhs', dtype=lhs_dtype)
-            rhs = te.var('rhs', dtype=rhs_dtype)
+            lhs = te.var("lhs", dtype=lhs_dtype)
+            rhs = te.var("rhs", dtype=rhs_dtype)
             out = f(lhs, rhs)
             if not is_conditional:
                 assert out.dtype == out_dtype
             else:
-                assert out.dtype == 'bool'
-            if hasattr(out, 'a'):
+                assert out.dtype == "bool"
+            if hasattr(out, "a"):
                 assert out.a.dtype == out_dtype
                 assert out.b.dtype == out_dtype
-            elif hasattr(out, 'args'):
+            elif hasattr(out, "args"):
                 # CallOp
                 assert out.args[0].dtype == out_dtype
                 assert out.args[1].dtype == out_dtype
             else:
-                raise ValueError('Unknown binary op format!')
+                raise ValueError("Unknown binary op format!")
 
     def verify_callop_float_only(f):
-        for lhs_dtype in ['int32', 'float32', 'float64']:
-            for rhs_dtype in ['int32', 'float32', 'float64']:
-                lhs = te.var('lhs', dtype=lhs_dtype)
-                rhs = te.var('rhs', dtype=rhs_dtype)
-                if 'float' not in lhs_dtype and 'float' not in rhs_dtype:
+        for lhs_dtype in ["int32", "float32", "float64"]:
+            for rhs_dtype in ["int32", "float32", "float64"]:
+                lhs = te.var("lhs", dtype=lhs_dtype)
+                rhs = te.var("rhs", dtype=rhs_dtype)
+                if "float" not in lhs_dtype and "float" not in rhs_dtype:
                     check_throws(lambda: f(lhs, rhs))
-                elif 'float' in lhs_dtype and 'float' in rhs_dtype and lhs_dtype != rhs_dtype:
+                elif "float" in lhs_dtype and "float" in rhs_dtype and lhs_dtype != rhs_dtype:
                     check_throws(lambda: f(lhs, rhs))
-                elif 'float' in lhs_dtype:
+                elif "float" in lhs_dtype:
                     out = f(lhs, rhs)
                     assert out.dtype == lhs_dtype
                     assert out.args[0].dtype == lhs_dtype
@@ -158,14 +167,16 @@ def test_binary_dtype_match():
 
 
 def test_if_then_else():
-    cases = [[(te.var('cond', dtype='bool'), 'bool', 'int32'), 'int32'],
-             [(True, 'int32', 'float32'), 'float32'],
-             [(False, 'int32', 'int64'), 'int64'],
-             [(te.var('cond', dtype='bool'), 'uint32', 'int32'), 'int32'],
-             [(te.var('cond', dtype='int32'), 'uint32', 'int32'), 'int32']]
+    cases = [
+        [(te.var("cond", dtype="bool"), "bool", "int32"), "int32"],
+        [(True, "int32", "float32"), "float32"],
+        [(False, "int32", "int64"), "int64"],
+        [(te.var("cond", dtype="bool"), "uint32", "int32"), "int32"],
+        [(te.var("cond", dtype="int32"), "uint32", "int32"), "int32"],
+    ]
     for (cond, lhs_dtype, rhs_dtype), out_dtype in cases:
-        lhs = te.var('lhs', dtype=lhs_dtype)
-        rhs = te.var('rhs', dtype=rhs_dtype)
+        lhs = te.var("lhs", dtype=lhs_dtype)
+        rhs = te.var("rhs", dtype=rhs_dtype)
         if cond is True or cond is False:
             out = tvm.tir.if_then_else(cond, lhs, rhs)
             out2 = tvm.tir.if_then_else(not cond, rhs, lhs)
@@ -177,15 +188,15 @@ def test_if_then_else():
             else:
                 assert tvm.ir.structural_equal(out, rhs.astype(out_dtype)) == 1
                 assert tvm.ir.structural_equal(out3, lhs.astype(out_dtype)) == 1
-        elif cond.dtype == 'bool':
+        elif cond.dtype == "bool":
             out = tvm.tir.if_then_else(cond, lhs, rhs)
             assert out.dtype == out_dtype
             assert out.args[1].dtype == out_dtype
             assert out.args[2].dtype == out_dtype
-        elif cond.dtype != 'bool':
+        elif cond.dtype != "bool":
             check_throws(lambda: tvm.tir.if_then_else(cond, lhs, rhs))
         else:
-            raise ValueError('Unknown combinations')
+            raise ValueError("Unknown combinations")
 
 
 if __name__ == "__main__":
index 61accf2..c79ba68 100644 (file)
@@ -17,6 +17,7 @@
 import tvm
 from tvm import te
 
+
 def test_ir_transform():
     ib = tvm.tir.ir_builder.create()
     n = te.var("n")
@@ -38,10 +39,12 @@ def test_ir_transform():
         if op.op.same_as(builtin_call_extern) and op.args[0].value == "TestA":
             return tvm.tir.call_extern("int32", "TestB", op.args[1] + 1)
         return op
+
     body = tvm.tir.stmt_functor.ir_transform(body, preorder, postorder, ["tir.Call"])
     stmt_list = tvm.tir.stmt_list(body.body.body)
     assert stmt_list[0].value.args[1].args[0].value == "TestB"
     assert stmt_list[1].value.value == 0
 
+
 if __name__ == "__main__":
     test_ir_transform()
index 593b845..1e2f846 100644 (file)
@@ -30,14 +30,18 @@ def consistent_equal(x, y, map_free_vars=False):
     if struct_equal0 != struct_equal1:
         raise ValueError(
             "Non-commutative {} vs {}, sequal0={}, sequal1={}".format(
-                x, y, struct_equal0, struct_equal1))
+                x, y, struct_equal0, struct_equal1
+            )
+        )
 
     # NOTE: hash collisions can happen but should be rare.
     # we can confirm that hash collisions don't happen for our test cases
     if struct_equal0 != (xhash == yhash):
         raise ValueError(
             "Inconsistent {} vs {}, sequal={}, xhash={}, yhash={}".format(
-                x, y, struct_equal0, xhash, yhash))
+                x, y, struct_equal0, xhash, yhash
+            )
+        )
     return struct_equal0
 
 
@@ -51,8 +55,7 @@ def test_exprs():
     zx = vx + vx
     zy = vy + vy
 
-    assert consistent_equal(zx * zx, (vx + vx) * (vx + vx),
-                            map_free_vars=False)
+    assert consistent_equal(zx * zx, (vx + vx) * (vx + vx), map_free_vars=False)
 
     # test assert trigger.
     with pytest.raises(ValueError):
@@ -68,11 +71,9 @@ def test_exprs():
     assert consistent_equal(vx + vy + vz, vy + vz + vx, map_free_vars=True)
     assert not consistent_equal(vx + 1, vy + 1, map_free_vars=False)
     # Definition remap
-    assert consistent_equal(tvm.tir.Let(vx, 1, vx - 1),
-                            tvm.tir.Let(vy, 1, vy - 1))
+    assert consistent_equal(tvm.tir.Let(vx, 1, vx - 1), tvm.tir.Let(vy, 1, vy - 1))
     # By default, a free var only maps to itself (same address)
-    assert consistent_equal(tvm.tir.Let(vx, 1, vx // vz),
-                            tvm.tir.Let(vy, 1, vy // vz))
+    assert consistent_equal(tvm.tir.Let(vx, 1, vx // vz), tvm.tir.Let(vy, 1, vy // vz))
 
     assert consistent_equal(zx * zx, zx * zx)
     assert consistent_equal(zx * zx, zy * zy, map_free_vars=True)
@@ -80,21 +81,17 @@ def test_exprs():
 
 
 def test_prim_func():
-    x = te.var('x')
-    y = te.var('y')
+    x = te.var("x")
+    y = te.var("y")
     # counterexample: these two are not structurally equal
-    func0 = tvm.tir.PrimFunc(
-        [x, y], tvm.tir.Evaluate(x + y))
-    func1 = tvm.tir.PrimFunc(
-        [x, y], tvm.tir.Evaluate(y + x))
+    func0 = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(x + y))
+    func1 = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(y + x))
     assert not consistent_equal(func0, func1)
 
     # new cases
     b = tvm.tir.decl_buffer((x,), "float32")
-    stmt = tvm.tir.LetStmt(
-        x, 10, tvm.tir.Evaluate(x + 1))
-    func0 = tvm.tir.PrimFunc(
-        [x, y, b], stmt)
+    stmt = tvm.tir.LetStmt(x, 10, tvm.tir.Evaluate(x + 1))
+    func0 = tvm.tir.PrimFunc([x, y, b], stmt)
     # easiest way to deep copy is via save/load
     func1 = tvm.ir.load_json(tvm.ir.save_json(func0))
     tvm.ir.assert_structural_equal(func0, func1)
@@ -109,6 +106,7 @@ def test_prim_func():
     mod1 = tvm.IRModule.from_expr(func1)
     tvm.ir.assert_structural_equal(mod0, mod1)
 
+
 def test_array():
     x = np.arange(10)
     nx = tvm.nd.array(x)
@@ -117,6 +115,7 @@ def test_array():
     assert consistent_equal(nx, ny)
     assert not consistent_equal(nx, nz)
 
+
 def test_env_func():
     @tvm.register_func("test.sequal.env_func")
     def test(x):
@@ -127,7 +126,6 @@ def test_env_func():
     assert consistent_equal(y, x)
 
 
-
 def test_attrs():
     x = tvm.ir.make_node("attrs.TestAttrs", axis=1, name="xx")
     y = tvm.ir.make_node("attrs.TestAttrs", axis=1, name="xx")
@@ -143,16 +141,17 @@ def test_attrs():
 
 
 def test_stmt():
-    x = te.var('x')
-    y = te.var('y')
+    x = te.var("x")
+    y = te.var("y")
     n = 128
-    A = te.placeholder((n, n), name='A')
-    B = te.placeholder((n, n), name='B')
-    ii = te.var('i')
-    jj = te.var('j')
+    A = te.placeholder((n, n), name="A")
+    B = te.placeholder((n, n), name="B")
+    ii = te.var("i")
+    jj = te.var("j")
 
-    Ab = tvm.tir.decl_buffer((n,), name='A')
+    Ab = tvm.tir.decl_buffer((n,), name="A")
     n = te.var("n")
+
     def func2():
         ib = tvm.tir.ir_builder.create()
         A = ib.buffer_ptr(Ab)
index 1d57db6..c5163a8 100644 (file)
@@ -20,37 +20,35 @@ from tvm import te
 
 
 def lower_stmt(sche, params, passfunc):
-    func = tvm.driver.build_module.form_irmodule(
-        sche, params, "main", None)["main"]
-    func = passfunc()(
-        tvm.IRModule.from_expr(func))["main"]
+    func = tvm.driver.build_module.form_irmodule(sche, params, "main", None)["main"]
+    func = passfunc()(tvm.IRModule.from_expr(func))["main"]
     stmt = func.body
     return stmt
 
 
 def test_promote():
     def runpass(op, passfunc):
-        a = te.placeholder((100,), dtype='bfloat16')
-        b = te.placeholder((100,), dtype='bfloat16')
+        a = te.placeholder((100,), dtype="bfloat16")
+        b = te.placeholder((100,), dtype="bfloat16")
         c = te.compute((100,), lambda i: op(a[i], b[i]))
         s = te.create_schedule(c.op)
         return lower_stmt(s, [a, b, c], passfunc)
 
     def get_promoted(op):
-        a = te.placeholder((100,), dtype='bfloat16')
-        b = te.placeholder((100,), dtype='bfloat16')
-        c = te.compute((100,), lambda i:
-                       topi.cast(op(topi.cast(a[i], 'float'),
-                                    topi.cast(b[i], 'float')), 'bfloat16')
-                       )
+        a = te.placeholder((100,), dtype="bfloat16")
+        b = te.placeholder((100,), dtype="bfloat16")
+        c = te.compute(
+            (100,),
+            lambda i: topi.cast(op(topi.cast(a[i], "float"), topi.cast(b[i], "float")), "bfloat16"),
+        )
         s = te.create_schedule(c.op)
-        func = tvm.driver.build_module.form_irmodule(
-            s, [a, b, c], "main", None)["main"]
+        func = tvm.driver.build_module.form_irmodule(s, [a, b, c], "main", None)["main"]
         return func.body
 
     def test_promoted(op):
         stmt = runpass(op, tvm.tir.transform.BF16Promote)
         tvm.ir.assert_structural_equal(stmt, get_promoted(op))
+
     test_promoted(topi.add)
     test_promoted(topi.subtract)
     test_promoted(topi.multiply)
@@ -59,56 +57,63 @@ def test_promote():
 
 def test_eliminate():
     def to32(v):
-        return topi.cast(v, 'float')
+        return topi.cast(v, "float")
 
     def to16(v):
-        return topi.cast(v, 'bfloat16')
+        return topi.cast(v, "bfloat16")
 
     def get_eliminated():
-        a = te.placeholder((100,), dtype='bfloat16')
-        b = te.placeholder((100,), dtype='bfloat16')
-        c = te.compute((100,), lambda i: to16(
-            topi.add(
-                to32(
-                    to16(
-                        topi.add(
-                            to32(a[i]),
-                            to32(b[i]),
+        a = te.placeholder((100,), dtype="bfloat16")
+        b = te.placeholder((100,), dtype="bfloat16")
+        c = te.compute(
+            (100,),
+            lambda i: to16(
+                topi.add(
+                    to32(
+                        to16(
+                            topi.add(
+                                to32(a[i]),
+                                to32(b[i]),
+                            )
                         )
-                    )
-                ),
-                to32(
-                    to16(
-                        topi.add(
-                            to32(a[i]),
-                            to32(b[i]),
+                    ),
+                    to32(
+                        to16(
+                            topi.add(
+                                to32(a[i]),
+                                to32(b[i]),
+                            )
                         )
-                    )
+                    ),
                 )
-            )
-        ))
+            ),
+        )
         s = te.create_schedule(c.op)
         stmt = lower_stmt(s, [a, b, c], tvm.tir.transform.BF16CastElimination)
         return stmt
 
     def get_target():
-        a = te.placeholder((100,), dtype='bfloat16')
-        b = te.placeholder((100,), dtype='bfloat16')
-        c = te.compute((100,), lambda i: to16(
-            topi.add(topi.add(
-                to32(a[i]),
-                to32(b[i]),
+        a = te.placeholder((100,), dtype="bfloat16")
+        b = te.placeholder((100,), dtype="bfloat16")
+        c = te.compute(
+            (100,),
+            lambda i: to16(
+                topi.add(
+                    topi.add(
+                        to32(a[i]),
+                        to32(b[i]),
+                    ),
+                    topi.add(
+                        to32(a[i]),
+                        to32(b[i]),
+                    ),
+                )
             ),
-                     topi.add(
-                         to32(a[i]),
-                         to32(b[i]),
-                     )
-                    )
-        ))
+        )
         s = te.create_schedule(c.op)
-        func = tvm.driver.build_module.form_irmodule(
-            s, [a, b, c], "main", None)["main"]
+        func = tvm.driver.build_module.form_irmodule(s, [a, b, c], "main", None)["main"]
         return func.body
+
     tvm.ir.assert_structural_equal(get_eliminated(), get_target())
 
 
@@ -116,47 +121,52 @@ def test_legalize():
     def to32(v):
         uint32_v = topi.cast(v, "uint32")
         uint32_v = tvm.tir.call_intrin(
-            "uint32", "tir.shift_left", uint32_v, tvm.tir.const(16, "uint32"))
+            "uint32", "tir.shift_left", uint32_v, tvm.tir.const(16, "uint32")
+        )
         return tvm.tir.call_intrin("float32", "tir.reinterpret", uint32_v)
 
     def to16(v):
         uint32_v = tvm.tir.call_intrin("uint32", "tir.reinterpret", v)
         rounding_bias = tvm.tir.call_intrin(
-            "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32"))
+            "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32")
+        )
         rounding_bias = tvm.tir.call_intrin(
-            "uint32", "tir.bitwise_and", rounding_bias, tvm.tir.const(1, "uint32"))
+            "uint32", "tir.bitwise_and", rounding_bias, tvm.tir.const(1, "uint32")
+        )
         rounding_bias = rounding_bias + tvm.tir.const(0x7FFF, "uint16")
         uint32_v = uint32_v + rounding_bias
         uint32_v = tvm.tir.call_intrin(
-            "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32"))
-        return topi.cast(uint32_v, 'uint16')
+            "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32")
+        )
+        return topi.cast(uint32_v, "uint16")
 
     def check(fcompute_before, fcompute_after):
-        a = te.placeholder((100,), dtype='bfloat16', name='A')
-        b = te.placeholder((100,), dtype='bfloat16', name='B')
-        c = te.compute((100,), fcompute_before(a, b), name='C')
+        a = te.placeholder((100,), dtype="bfloat16", name="A")
+        b = te.placeholder((100,), dtype="bfloat16", name="B")
+        c = te.compute((100,), fcompute_before(a, b), name="C")
         s = te.create_schedule(c.op)
         stmt = lower_stmt(s, [a, b, c], tvm.tir.transform.BF16Legalize)
 
-        a = te.placeholder((100,), dtype='uint16', name='A')
-        b = te.placeholder((100,), dtype='uint16', name='B')
-        c = te.compute((100,), fcompute_after(a, b), name='C')
+        a = te.placeholder((100,), dtype="uint16", name="A")
+        b = te.placeholder((100,), dtype="uint16", name="B")
+        c = te.compute((100,), fcompute_after(a, b), name="C")
         s = te.create_schedule(c.op)
-        func = tvm.driver.build_module.form_irmodule(
-            s, [a, b, c], "main", None)["main"]
+        func = tvm.driver.build_module.form_irmodule(s, [a, b, c], "main", None)["main"]
         tvm.ir.assert_structural_equal(stmt, func.body)
 
     def orig1(a, b):
-        return lambda i: a[i] + b[i] + a[99-i] + b[99-i]
+        return lambda i: a[i] + b[i] + a[99 - i] + b[99 - i]
 
     def after1(a, b):
-        return lambda i: to16(to32(a[i]) + to32(b[i] ) + to32(a[99 - i]) + to32(b[99 - i]))
+        return lambda i: to16(to32(a[i]) + to32(b[i]) + to32(a[99 - i]) + to32(b[99 - i]))
 
     def orig2(a, b):
         return lambda i: a[i] * b[i] + a[99 - i] * b[99 - i] + a[i]
 
     def after2(a, b):
-        return lambda i: to16(to32(a[i]) * to32(b[i]) + to32(a[99 - i]) * to32(b[99 - i]) + to32(a[i]))
+        return lambda i: to16(
+            to32(a[i]) * to32(b[i]) + to32(a[99 - i]) * to32(b[99 - i]) + to32(a[i])
+        )
 
     check(orig1, after1)
     check(orig2, after2)
index 2886958..191aec4 100644 (file)
 import tvm
 from tvm import te
 
+
 def test_for():
     dev_type = te.var("dev_type")
+
     def device_context(dev_id):
         ctx = tvm.tir.call_extern("handle", "device_context", dev_type, dev_id)
-        return tvm.tir.Call(
-            "handle", "tir.tvm_thread_context", [ctx])
+        return tvm.tir.Call("handle", "tir.tvm_thread_context", [ctx])
 
     ib = tvm.tir.ir_builder.create()
     n = te.var("n")
     A = ib.allocate("float32", n, name="A", scope="global")
     with ib.for_range(0, n, name="i") as i:
-        ib.emit(tvm.tir.call_extern
-                ("int32", "fadd", device_context(0), A))
+        ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(0), A))
         with ib.for_range(0, 10, name="j") as j:
-            ib.emit(tvm.tir.call_extern
-                    ("int32", "fadd", device_context(1), A))
-            ib.emit(tvm.tir.call_extern
-                    ("int32", "fadd", device_context(0), A))
+            ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(1), A))
+            ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(0), A))
     body = ib.get()
-    mod = tvm.IRModule({
-        "func" : tvm.tir.PrimFunc([dev_type, n], body)
-    })
+    mod = tvm.IRModule({"func": tvm.tir.PrimFunc([dev_type, n], body)})
 
     mod = tvm.tir.transform.CombineContextCall()(mod)
 
index 8469bc9..2d45118 100644 (file)
@@ -33,7 +33,8 @@ def test_coproc_sync():
             unit_bits=8,
             max_simd_bits=32,
             max_num_bits=128,
-            head_address=tvm.tir.call_extern("handle", "global_cache"))
+            head_address=tvm.tir.call_extern("handle", "global_cache"),
+        )
 
     ib = tvm.tir.ir_builder.create()
     n = te.size_var("n")
@@ -53,11 +54,11 @@ def test_coproc_sync():
     body = stmt.body.body.body
     blist = tvm.tir.stmt_list(body)
 
-    assert(blist[1].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_read_barrier")))
-    assert(blist[1].value.args[3].value == 80)
-    assert(blist[-2].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_sync")))
-    assert(blist[-1].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_write_barrier")))
-    assert(blist[-1].value.args[3].value == 10)
+    assert blist[1].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_read_barrier"))
+    assert blist[1].value.args[3].value == 80
+    assert blist[-2].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_sync"))
+    assert blist[-1].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_write_barrier"))
+    assert blist[-1].value.args[3].value == 10
 
 
 def test_coproc_sync2():
@@ -116,10 +117,10 @@ def test_coproc_sync3():
     slist = tvm.tir.stmt_list(slist[-1])
     pop_st = slist[0].body[0]
 
-    assert(push_st.value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_dep_push")))
-    assert(__check_list(push_st.value.args, [2,3]))
-    assert(pop_st.value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_dep_pop")))
-    assert(__check_list(pop_st.value.args, [2,3]))
+    assert push_st.value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_dep_push"))
+    assert __check_list(push_st.value.args, [2, 3])
+    assert pop_st.value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_dep_pop"))
+    assert __check_list(pop_st.value.args, [2, 3])
 
 
 if __name__ == "__main__":
index cf9ea9e..224905b 100644 (file)
@@ -17,6 +17,7 @@
 import tvm
 from tvm import te
 
+
 def test_decorate_device():
     x = te.var("x")
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x)))
@@ -24,5 +25,6 @@ def test_decorate_device():
     stmt = tvm.tir.transform.DecorateDeviceScope()(mod)["main"].body
     assert stmt.attr_key == "device_scope"
 
+
 if __name__ == "__main__":
     test_decorate_device()
index 7c93b4e..2cbb665 100644 (file)
@@ -23,9 +23,11 @@ from tvm.testing import enabled_targets
 
 var_list = []
 
+
 def verify_structure(stmt, expected_struct):
     node_dict = {}
     struct = {}
+
     def _extract_vars(op):
         global var_list
         if isinstance(op, tvm.tir.Var):
@@ -48,18 +50,22 @@ def verify_structure(stmt, expected_struct):
 
     tvm.tir.stmt_functor.post_order_visit(stmt, _visit)
     for key, val in node_dict.items():
-        struct[val[1]] = tuple(node_dict[child][1] if child in node_dict
-                               else None for child in val[0])
-
-    assert struct == expected_struct, "Structure mismatch: expect %s but got %s" \
-                                      % (expected_struct, struct)
+        struct[val[1]] = tuple(
+            node_dict[child][1] if child in node_dict else None for child in val[0]
+        )
+
+    assert struct == expected_struct, "Structure mismatch: expect %s but got %s" % (
+        expected_struct,
+        struct,
+    )
     var_list.clear()
 
+
 def test_hoist_top_for():
     ib = tvm.tir.ir_builder.create()
-    l = te.var('l')
-    m = te.var('m')
-    n = te.var('n')
+    l = te.var("l")
+    m = te.var("m")
+    n = te.var("n")
     data = ib.pointer("float32", name="data")
 
     with ib.for_range(0, l, "i") as i:
@@ -73,16 +79,20 @@ def test_hoist_top_for():
     stmt = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
     new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
-    expected_struct = {('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'k'),),
-                       ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), ('tir.For', 'j')),
-                       ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
+    expected_struct = {
+        ("tir.For", "k"): (None,),
+        ("tir.For", "j"): (("tir.For", "k"),),
+        ("tir.IfThenElse", ("i",)): (("tir.For", "j"), ("tir.For", "j")),
+        ("tir.For", "i"): (("tir.IfThenElse", ("i",)),),
+    }
     verify_structure(new_stmt, expected_struct)
 
+
 def test_hoist_multi_var_if():
     ib = tvm.tir.ir_builder.create()
-    l = te.var('l')
-    m = te.var('m')
-    n = te.var('n')
+    l = te.var("l")
+    m = te.var("m")
+    n = te.var("n")
     data = ib.pointer("float32", name="data")
 
     with ib.for_range(0, l, "i") as i:
@@ -96,17 +106,20 @@ def test_hoist_multi_var_if():
     stmt = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
     new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
-    expected_struct = {('tir.For', 'k'): (None,),
-                       ('tir.IfThenElse', ('i', 'j')): (('tir.For', 'k'), ('tir.For', 'k')),
-                       ('tir.For', 'j'): (('tir.IfThenElse', ('i', 'j')),),
-                       ('tir.For', 'i'): (('tir.For', 'j'),)}
+    expected_struct = {
+        ("tir.For", "k"): (None,),
+        ("tir.IfThenElse", ("i", "j")): (("tir.For", "k"), ("tir.For", "k")),
+        ("tir.For", "j"): (("tir.IfThenElse", ("i", "j")),),
+        ("tir.For", "i"): (("tir.For", "j"),),
+    }
     verify_structure(new_stmt, expected_struct)
 
+
 def test_hoist_no_match_for():
     ib = tvm.tir.ir_builder.create()
-    l = te.var('l')
-    m = te.var('m')
-    n = te.var('n')
+    l = te.var("l")
+    m = te.var("m")
+    n = te.var("n")
     data = ib.pointer("float32", name="data")
 
     with ib.for_range(0, l, "i") as i:
@@ -121,17 +134,20 @@ def test_hoist_no_match_for():
     stmt = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
     new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
-    expected_struct = {('tir.For', 'k'): (None,),
-                       ('tir.IfThenElse', ('i', )): (('tir.For', 'k'), ('tir.For', 'k')),
-                       ('tir.For', 'j'): (None,),
-                       ('tir.For', 'i'): (('tir.For', 'j'),)}
+    expected_struct = {
+        ("tir.For", "k"): (None,),
+        ("tir.IfThenElse", ("i",)): (("tir.For", "k"), ("tir.For", "k")),
+        ("tir.For", "j"): (None,),
+        ("tir.For", "i"): (("tir.For", "j"),),
+    }
     verify_structure(new_stmt, expected_struct)
 
+
 def test_no_else():
     ib = tvm.tir.ir_builder.create()
-    l = te.var('l')
-    m = te.var('m')
-    n = te.var('n')
+    l = te.var("l")
+    m = te.var("m")
+    n = te.var("n")
 
     with ib.for_range(0, l, "i") as i:
         with ib.for_range(0, m, "j") as j:
@@ -142,18 +158,22 @@ def test_no_else():
     stmt = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
     new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
-    expected_struct = {('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'k'),),
-                       ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None),
-                       ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
+    expected_struct = {
+        ("tir.For", "k"): (None,),
+        ("tir.For", "j"): (("tir.For", "k"),),
+        ("tir.IfThenElse", ("i",)): (("tir.For", "j"), None),
+        ("tir.For", "i"): (("tir.IfThenElse", ("i",)),),
+    }
     verify_structure(new_stmt, expected_struct)
 
+
 def test_attr_stmt():
     ib = tvm.tir.ir_builder.create()
     dshape = (32, 64)
     data = ib.pointer("float32", name="data")
-    l = te.var('l')
-    m = te.var('m')
-    n = te.var('n')
+    l = te.var("l")
+    m = te.var("m")
+    n = te.var("n")
 
     tx = te.thread_axis("threadIdx.x")
     bx = te.thread_axis("blockIdx.x")
@@ -163,24 +183,28 @@ def test_attr_stmt():
         with ib.for_range(0, m, "j") as j:
             with ib.for_range(0, n, "k") as k:
                 with ib.if_scope(tvm.tir.any(i < 4, j >= 8)):
-                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.5
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.5
                 with ib.else_scope():
-                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 1.0
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.0
 
     stmt = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
     new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
-    expected_struct = {('tir.For', 'k'): (None,), ('tir.IfThenElse', ('i', 'j')): (('tir.For', 'k'), ('tir.For', 'k')),
-                       ('tir.For', 'j'): (('tir.IfThenElse', ('i', 'j')),), ('tir.For', 'i'): (('tir.For', 'j'),),
-                       ('tir.AttrStmt', 'thread_extent', 64): (('tir.For', 'i'),),
-                       ('tir.AttrStmt', 'thread_extent', 32): (('tir.AttrStmt', 'thread_extent', 64),)}
+    expected_struct = {
+        ("tir.For", "k"): (None,),
+        ("tir.IfThenElse", ("i", "j")): (("tir.For", "k"), ("tir.For", "k")),
+        ("tir.For", "j"): (("tir.IfThenElse", ("i", "j")),),
+        ("tir.For", "i"): (("tir.For", "j"),),
+        ("tir.AttrStmt", "thread_extent", 64): (("tir.For", "i"),),
+        ("tir.AttrStmt", "thread_extent", 32): (("tir.AttrStmt", "thread_extent", 64),),
+    }
     verify_structure(new_stmt, expected_struct)
 
+
 def test_nested_for():
     ib = tvm.tir.ir_builder.create()
     data = ib.pointer("float32", name="data")
 
-
     with ib.for_range(0, 5, "i") as i:
         with ib.for_range(0, 10, "j") as j:
             with ib.if_scope(i >= 3):
@@ -195,19 +219,22 @@ def test_nested_for():
     stmt = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
     new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
-    expected_struct = {('tir.For', 'l'): (None,), ('tir.For', 'k'): (('tir.For', 'l'),),
-                       ('tir.IfThenElse', ('i', 'j')): (('tir.For', 'k'), ('tir.For', 'k')),
-                       ('tir.For', 'j'): (None,),
-                       ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None),
-                       ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
+    expected_struct = {
+        ("tir.For", "l"): (None,),
+        ("tir.For", "k"): (("tir.For", "l"),),
+        ("tir.IfThenElse", ("i", "j")): (("tir.For", "k"), ("tir.For", "k")),
+        ("tir.For", "j"): (None,),
+        ("tir.IfThenElse", ("i",)): (("tir.For", "j"), None),
+        ("tir.For", "i"): (("tir.IfThenElse", ("i",)),),
+    }
     verify_structure(new_stmt, expected_struct)
 
+
 def test_if_block():
     ib = tvm.tir.ir_builder.create()
     data = ib.pointer("float32", name="data")
     n = te.var("n")
 
-
     with ib.for_range(0, 5, "i") as i:
         with ib.for_range(0, 10, "j") as j:
             with ib.if_scope(i >= 3):
@@ -218,23 +245,28 @@ def test_if_block():
                             data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2
                         with ib.else_scope():
                             data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5
-                        with ib.if_scope(j <5):
+                        with ib.if_scope(j < 5):
                             data[i * 3 + j + k + l] = data[i * 3 + j + k + l] - 1
 
-
     with ib.for_range(0, 5, "i") as i:
         with ib.for_range(0, 10, "j") as j:
-                with ib.for_range(0, 15, "k") as k:
-                    with ib.if_scope(n >= 3):
-                        data[i * 3 + j + k] = data[i * 3 + j + k] + 0.6
+            with ib.for_range(0, 15, "k") as k:
+                with ib.if_scope(n >= 3):
+                    data[i * 3 + j + k] = data[i * 3 + j + k] + 0.6
 
     stmt = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
     new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
-    expected_struct = {('tir.IfThenElse', ('i', 'j')): (None, None), ('tir.IfThenElse', ('j',)): (None, None),
-                       ('tir.For', 'l'): (None,), ('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'j'),),
-                       ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None), ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),),
-                       ('tir.IfThenElse', ('n',)): (('tir.For', 'j'), None)}
+    expected_struct = {
+        ("tir.IfThenElse", ("i", "j")): (None, None),
+        ("tir.IfThenElse", ("j",)): (None, None),
+        ("tir.For", "l"): (None,),
+        ("tir.For", "k"): (None,),
+        ("tir.For", "j"): (("tir.For", "j"),),
+        ("tir.IfThenElse", ("i",)): (("tir.For", "j"), None),
+        ("tir.For", "i"): (("tir.IfThenElse", ("i",)),),
+        ("tir.IfThenElse", ("n",)): (("tir.For", "j"), None),
+    }
     verify_structure(new_stmt, expected_struct)
 
 
@@ -252,13 +284,16 @@ def test_multi_if():
     stmt = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
     new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
-    expected_struct = {('tir.For', 'k'): (None,),
-                       ('tir.IfThenElse', ('j',)): (('tir.For', 'k'), None),
-                       ('tir.For', 'j'): (('tir.IfThenElse', ('j',)),),
-                       ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None),
-                       ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
+    expected_struct = {
+        ("tir.For", "k"): (None,),
+        ("tir.IfThenElse", ("j",)): (("tir.For", "k"), None),
+        ("tir.For", "j"): (("tir.IfThenElse", ("j",)),),
+        ("tir.IfThenElse", ("i",)): (("tir.For", "j"), None),
+        ("tir.For", "i"): (("tir.IfThenElse", ("i",)),),
+    }
     verify_structure(new_stmt, expected_struct)
 
+
 def test_no_hoisting_1():
     ib = tvm.tir.ir_builder.create()
     data = ib.pointer("float32", name="data")
@@ -275,12 +310,13 @@ def test_no_hoisting_1():
     new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
-    with tvm.transform.PassContext(config={
-        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
-    }):
+    with tvm.transform.PassContext(
+        config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
+    ):
         new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
+
 def test_no_hoisting_2():
     ib = tvm.tir.ir_builder.create()
     data = ib.pointer("float32", name="data")
@@ -299,20 +335,21 @@ def test_no_hoisting_2():
     new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
-    with tvm.transform.PassContext(config={
-        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
-    }):
+    with tvm.transform.PassContext(
+        config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
+    ):
         new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
+
 def test_no_hoisting_3():
     ib = tvm.tir.ir_builder.create()
     dshape = (32, 64)
     dshape_inner = (33, 63)
     data = ib.pointer("float32", name="data")
-    l = te.var('l')
-    m = te.var('m')
-    n = te.var('n')
+    l = te.var("l")
+    m = te.var("m")
+    n = te.var("n")
 
     tx = te.thread_axis("threadIdx.x")
     bx = te.thread_axis("blockIdx.x")
@@ -324,29 +361,30 @@ def test_no_hoisting_3():
                 ib.scope_attr(tx, "thread_extent", dshape_inner[0])
                 ib.scope_attr(bx, "thread_extent", dshape_inner[1])
                 with ib.if_scope(tx < 3):
-                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.3
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
                 with ib.else_scope():
-                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 1.3
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
 
     stmt = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
     new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
-    with tvm.transform.PassContext(config={
-        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
-    }):
+    with tvm.transform.PassContext(
+        config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
+    ):
         new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
+
 def test_no_hoisting_4():
     ib = tvm.tir.ir_builder.create()
     dshape = (32, 64)
     dshape_inner = (33, 63)
     data = ib.pointer("float32", name="data")
-    l = te.var('l')
-    m = te.var('m')
-    n = te.var('n')
+    l = te.var("l")
+    m = te.var("m")
+    n = te.var("n")
 
     tx = te.thread_axis("threadIdx.x")
     bx = te.thread_axis("blockIdx.x")
@@ -356,29 +394,30 @@ def test_no_hoisting_4():
             with ib.for_range(0, n, "k") as k:
                 ib.scope_attr(tx, "thread_extent", dshape_inner[0])
                 with ib.if_scope(tx < 3):
-                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.3
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
                 with ib.else_scope():
-                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 1.3
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
 
     stmt = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
     new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
-    with tvm.transform.PassContext(config={
-        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
-    }):
+    with tvm.transform.PassContext(
+        config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
+    ):
         new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
+
 def test_no_hoisting_5():
     ib = tvm.tir.ir_builder.create()
     dshape = (32, 64)
     dshape_inner = (33, 63)
     data = ib.pointer("float32", name="data")
-    l = te.var('l')
-    m = te.var('m')
-    n = te.var('n')
+    l = te.var("l")
+    m = te.var("m")
+    n = te.var("n")
 
     tx = te.thread_axis("threadIdx.x")
     bx = te.thread_axis("blockIdx.x")
@@ -390,28 +429,29 @@ def test_no_hoisting_5():
             with ib.for_range(0, n, "k") as k:
                 ib.scope_attr(tx, "thread_extent", dshape_inner[0])
                 with ib.if_scope(tx < 3):
-                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.3
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
                 with ib.else_scope():
-                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 1.3
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
 
     stmt = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
     new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
-    with tvm.transform.PassContext(config={
-        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
-    }):
+    with tvm.transform.PassContext(
+        config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
+    ):
         new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
+
 def test_no_hoisting_6():
     ib = tvm.tir.ir_builder.create()
     dshape = (32, 64)
     data = ib.pointer("float32", name="data")
-    l = te.var('l')
-    m = te.var('m')
-    n = te.var('n')
+    l = te.var("l")
+    m = te.var("m")
+    n = te.var("n")
 
     tx = te.thread_axis("threadIdx.x")
     bx = te.thread_axis("blockIdx.x")
@@ -421,28 +461,29 @@ def test_no_hoisting_6():
         with ib.for_range(0, m, "j") as j:
             with ib.for_range(0, n, "k") as k:
                 with ib.if_scope((tx + k) < 3):
-                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.3
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
                 with ib.else_scope():
-                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 1.3
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
 
     stmt = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
     new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
-    with tvm.transform.PassContext(config={
-        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
-    }):
+    with tvm.transform.PassContext(
+        config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
+    ):
         new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
+
 def test_no_hoisting_7():
     ib = tvm.tir.ir_builder.create()
     dshape = (32, 64)
     data = ib.pointer("float32", name="data")
-    l = te.var('l')
-    m = te.var('m')
-    n = te.var('n')
+    l = te.var("l")
+    m = te.var("m")
+    n = te.var("n")
 
     tx = te.thread_axis("threadIdx.x")
     bx = te.thread_axis("blockIdx.x")
@@ -453,23 +494,24 @@ def test_no_hoisting_7():
             with ib.if_scope((tx + j) < 9):
                 with ib.for_range(0, n, "k") as k:
                     with ib.if_scope((tx + k) < 3):
-                        data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.3
+                        data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
 
     stmt = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
     new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
-    with tvm.transform.PassContext(config={
-        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
-    }):
+    with tvm.transform.PassContext(
+        config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
+    ):
         new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
+
 def test_hoisting_block_scope_1():
     n = te.size_var("n")
     m = te.size_var("m")
-    A = te.placeholder((n, m), name='A')
+    A = te.placeholder((n, m), name="A")
     k = te.reduce_axis((0, m), "k")
     B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
     s = te.create_schedule(B.op)
@@ -480,60 +522,61 @@ def test_hoisting_block_scope_1():
     s[B.op].bind(xi, te.thread_axis("threadIdx.y"))
     s[B].bind(s[B].op.reduce_axis[0], te.thread_axis("threadIdx.x"))
     s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
-    func = tvm.driver.build_module.form_irmodule(
-            s, [A, B], "main", None)["main"]
+    func = tvm.driver.build_module.form_irmodule(s, [A, B], "main", None)["main"]
     stmt = func.body
     new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
-    with tvm.transform.PassContext(config={
-        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
-    }):
+    with tvm.transform.PassContext(
+        config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
+    ):
         new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
-    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+    assert not tvm.ir.structural_equal(new_stmt, stmt)
+
 
 def test_hoisting_block_scope_2():
     ib = tvm.tir.ir_builder.create()
     dshape = (32, 64)
     dshape_inner = (33, 63)
     data = ib.pointer("float32", name="data")
-    l = te.var('l')
-    m = te.var('m')
-    n = te.var('n')
+    l = te.var("l")
+    m = te.var("m")
+    n = te.var("n")
 
     tx = te.thread_axis("threadIdx.x")
     bx = te.thread_axis("blockIdx.x")
     ib.scope_attr(tx, "thread_extent", dshape[0])
-    #ib.scope_attr(bx, "thread_extent", dshape[1])
+    # ib.scope_attr(bx, "thread_extent", dshape[1])
     with ib.for_range(0, l, "i") as i:
         with ib.for_range(0, m, "j") as j:
             with ib.for_range(0, n, "k") as k:
                 ib.scope_attr(bx, "thread_extent", dshape[1])
                 with ib.if_scope(tx < 3):
-                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.3
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
                 with ib.else_scope():
-                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 1.3
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
 
     stmt = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
     new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
-    with tvm.transform.PassContext(config={
-        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
-    }):
+    with tvm.transform.PassContext(
+        config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
+    ):
         new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
-    #tvm.ir.assert_structural_equal(new_stmt, stmt)
-    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+    # tvm.ir.assert_structural_equal(new_stmt, stmt)
+    assert not tvm.ir.structural_equal(new_stmt, stmt)
+
 
 def test_hoisting_block_scope_3():
     ib = tvm.tir.ir_builder.create()
     dshape = (32, 64)
     dshape_inner = (33, 63)
     data = ib.pointer("float32", name="data")
-    l = te.var('l')
-    m = te.var('m')
-    n = te.var('n')
+    l = te.var("l")
+    m = te.var("m")
+    n = te.var("n")
 
     tx = te.thread_axis("threadIdx.x")
     bx = te.thread_axis("blockIdx.x")
@@ -545,31 +588,32 @@ def test_hoisting_block_scope_3():
             ib.scope_attr(bx, "thread_extent", dshape_inner[1])
             with ib.for_range(0, n, "k") as k:
                 with ib.if_scope(tx < 3):
-                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.3
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
                 with ib.else_scope():
-                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 1.3
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
 
     stmt = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
     new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
-    with tvm.transform.PassContext(config={
-        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
-    }):
+    with tvm.transform.PassContext(
+        config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
+    ):
         new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
-    #tvm.ir.assert_structural_equal(new_stmt, stmt)
-    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+    # tvm.ir.assert_structural_equal(new_stmt, stmt)
+    assert not tvm.ir.structural_equal(new_stmt, stmt)
+
 
 def test_hoisting_block_scope_4():
     nn = 1024
     n = tvm.runtime.convert(nn)
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
-    AA = te.compute((n,), lambda *i: A(*i), name='A')
-    BB = te.compute((n,), lambda *i: B(*i), name='B')
-    T = te.compute(A.shape, lambda *i: AA(*i) + BB(*i), name='T')
-    C = te.compute(A.shape, lambda *i: T(*i), name='C')
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
+    AA = te.compute((n,), lambda *i: A(*i), name="A")
+    BB = te.compute((n,), lambda *i: B(*i), name="B")
+    T = te.compute(A.shape, lambda *i: AA(*i) + BB(*i), name="T")
+    C = te.compute(A.shape, lambda *i: T(*i), name="C")
     s = te.create_schedule(C.op)
     xo, xi = s[C].split(C.op.axis[0], factor=4)
     xo1, xo2 = s[C].split(xo, factor=13)
@@ -578,56 +622,57 @@ def test_hoisting_block_scope_4():
     s[C].pragma(xo2, "parallel_stride_pattern")
     s[C].pragma(xo2, "parallel_barrier_when_finish")
     s[C].vectorize(xi)
-    func = tvm.driver.build_module.form_irmodule(
-            s, [A, B, C], "main", None)["main"]
+    func = tvm.driver.build_module.form_irmodule(s, [A, B, C], "main", None)["main"]
     stmt = func.body
     new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
-    with tvm.transform.PassContext(config={
-        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
-    }):
+    with tvm.transform.PassContext(
+        config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
+    ):
         new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
-    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+    assert not tvm.ir.structural_equal(new_stmt, stmt)
+
 
 def test_hoisting_block_scope_5():
     ib = tvm.tir.ir_builder.create()
     data = ib.pointer("float32", name="data")
-    l = te.var('l')
-    m = te.var('m')
-    n = te.var('n')
-    g = te.var('g')
+    l = te.var("l")
+    m = te.var("m")
+    n = te.var("n")
+    g = te.var("g")
 
     ib.scope_attr(data, "storage_scope", "global")
     with ib.for_range(0, l, "i") as i:
         with ib.for_range(0, m, "j") as j:
             with ib.for_range(0, n, "k") as k:
                 with ib.if_scope(data[g] < 3):
-                    data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k]  + 0.3
+                    data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k] + 0.3
                 with ib.else_scope():
-                    data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k]  + 1.3
+                    data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k] + 1.3
 
     stmt = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
     new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
-    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+    assert not tvm.ir.structural_equal(new_stmt, stmt)
 
     stmt = new_stmt
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
 
-    with tvm.transform.PassContext(config={
-        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
-    }):
+    with tvm.transform.PassContext(
+        config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
+    ):
         new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
+
 def test_hoisting_block_scope_6():
     ib = tvm.tir.ir_builder.create()
     dshape = (32, 64)
     data = ib.pointer("float32", name="data")
-    l = te.var('l')
-    m = te.var('m')
-    n = te.var('n')
+    l = te.var("l")
+    m = te.var("m")
+    n = te.var("n")
 
     tx = te.thread_axis("threadIdx.x")
     bx = te.thread_axis("blockIdx.x")
@@ -637,28 +682,29 @@ def test_hoisting_block_scope_6():
         with ib.for_range(0, m, "j") as j:
             with ib.for_range(0, n, "k") as k:
                 with ib.if_scope((tx + n) < 3):
-                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.3
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
                 with ib.else_scope():
-                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 1.3
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
 
     stmt = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
     new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
-    with tvm.transform.PassContext(config={
-        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
-    }):
+    with tvm.transform.PassContext(
+        config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
+    ):
         new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
-    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+    assert not tvm.ir.structural_equal(new_stmt, stmt)
+
 
 def test_hoisting_block_scope_7():
     ib = tvm.tir.ir_builder.create()
     dshape = (32, 64)
     data = ib.pointer("float32", name="data")
-    l = te.var('l')
-    m = te.var('m')
-    n = te.var('n')
+    l = te.var("l")
+    m = te.var("m")
+    n = te.var("n")
 
     tx = te.thread_axis("threadIdx.x")
     bx = te.thread_axis("blockIdx.x")
@@ -668,79 +714,85 @@ def test_hoisting_block_scope_7():
         with ib.for_range(0, m, "j") as j:
             with ib.for_range(0, n, "k") as k:
                 with ib.if_scope((tx + i) < 3):
-                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.3
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
                 with ib.else_scope():
-                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 1.3
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
 
     stmt = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
     new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
-    with tvm.transform.PassContext(config={
-        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
-    }):
+    with tvm.transform.PassContext(
+        config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
+    ):
         new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
-    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+    assert not tvm.ir.structural_equal(new_stmt, stmt)
+
 
 @pytest.mark.skip()
 def test_hoisting_op_conv():
     dtype = "float32"
     dshape = (1, 80, 73, 73)
     kshape = (192, 80, 3, 3)
-    padding=(1, 1)
-    groups=1
-    dilation=(1, 1)
-    kernel_size=(3, 3)
-    channels=192
-    scale=1
+    padding = (1, 1)
+    groups = 1
+    dilation = (1, 1)
+    kernel_size = (3, 3)
+    channels = 192
+    scale = 1
     x = relay.var("x", shape=dshape, dtype=dtype)
     w = relay.var("w", shape=kshape, dtype=dtype)
-    y = relay.nn.conv2d(x, w, padding=padding,
-                                dilation=dilation,
-                                groups=groups,
-                                channels=channels,
-                                kernel_size=kernel_size)
+    y = relay.nn.conv2d(
+        x,
+        w,
+        padding=padding,
+        dilation=dilation,
+        groups=groups,
+        channels=channels,
+        kernel_size=kernel_size,
+    )
 
     func = relay.Function([x, w], y)
     mod = tvm.IRModule()
-    mod['main'] = func
+    mod["main"] = func
     mod = relay.transform.InferType()(mod)
 
     data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
     kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
 
-    params = {'w': tvm.nd.array(kernel)}
+    params = {"w": tvm.nd.array(kernel)}
     for target, ctx in enabled_targets():
         with tvm.transform.PassContext(opt_level=3):
             graph, lib, params = relay.build_module.build(mod, target=target, params=params)
             m = tvm.contrib.graph_runtime.create(graph, lib, ctx)
             x = np.random.uniform(size=dshape)
             data_tvm = tvm.nd.array(data)
-            m.set_input('x', data_tvm)
+            m.set_input("x", data_tvm)
             m.set_input(**params)
             m.run()
             e = m.module.time_evaluator("run", ctx, number=300, repeat=3)
             t1 = e(data_tvm).results
             t1 = np.array(t1) * 1000
-            print('{} ms'.format(t1.mean()))
+            print("{} ms".format(t1.mean()))
 
-        with tvm.transform.PassContext(opt_level=3, config={
-        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
-        }):
+        with tvm.transform.PassContext(
+            opt_level=3, config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
+        ):
             graph, lib, params = relay.build_module.build(mod, target=target, params=params)
             m = tvm.contrib.graph_runtime.create(graph, lib, ctx)
             x = np.random.uniform(size=dshape)
             data_tvm = tvm.nd.array(data)
-            m.set_input('x', data_tvm)
+            m.set_input("x", data_tvm)
             m.set_input(**params)
             m.run()
             e = m.module.time_evaluator("run", ctx, number=300, repeat=3)
             t2 = e(data_tvm).results
             t2 = np.array(t2) * 1000
 
-            print('{} ms'.format(t2.mean()))
+            print("{} ms".format(t2.mean()))
         tvm.testing.assert_allclose(t1.mean(), t2.mean(), atol=1, rtol=1e-1)
 
+
 if __name__ == "__main__":
     pytest.main([__file__])
index 887b8b0..49a643a 100644 (file)
 import tvm
 from tvm import te
 
+
 def test_copy2d():
-    m = te.var('m')
-    l = te.var('l')
-    A = te.placeholder((m, l), name='A')
-    B = te.compute((m, l), lambda i, j: A[i, j], name='B')
+    m = te.var("m")
+    l = te.var("l")
+    A = te.placeholder((m, l), name="A")
+    B = te.compute((m, l), lambda i, j: A[i, j], name="B")
     s = te.create_schedule(B.op)
     s[B].pragma(B.op.axis[0], "memcpy")
     bounds = tvm.te.schedule.InferBound(s)
@@ -41,12 +42,14 @@ def test_copy2d():
 
 
 def test_copy_pad():
-    m = te.var('m')
-    l = te.var('l')
-    A = te.placeholder((m, l), name='A')
-    B = te.compute((m + 2, l), lambda i, j:
-                    tvm.tir.if_then_else(tvm.tir.all(i >= 1, i < m + 1),
-                                     A[i - 1, j], 1.0), name='B')
+    m = te.var("m")
+    l = te.var("l")
+    A = te.placeholder((m, l), name="A")
+    B = te.compute(
+        (m + 2, l),
+        lambda i, j: tvm.tir.if_then_else(tvm.tir.all(i >= 1, i < m + 1), A[i - 1, j], 1.0),
+        name="B",
+    )
     s = te.create_schedule(B.op)
     s[B].pragma(B.op.axis[0], "memcpy")
     bounds = tvm.te.schedule.InferBound(s)
@@ -69,9 +72,8 @@ def test_copy_pad():
 
 
 def test_single_point_test():
-    A = te.placeholder((1,), name='A')
-    B = te.compute((1,), lambda i:
-                    A[i], name='B')
+    A = te.placeholder((1,), name="A")
+    B = te.compute((1,), lambda i: A[i], name="B")
     s = te.create_schedule(B.op)
     s[B].pragma(B.op.axis[0], "memcpy")
     bounds = tvm.te.schedule.InferBound(s)
@@ -93,10 +95,10 @@ def test_single_point_test():
 
 def test_copy_pad_split():
     m = 4 * 3
-    A = te.placeholder((m, ), name="A")
-    Apad = te.compute((m + 2,), lambda i:
-                       tvm.tir.if_then_else(tvm.tir.all(i >= 1, i <= m),
-                                        A[i - 1], 0.0), "Apad")
+    A = te.placeholder((m,), name="A")
+    Apad = te.compute(
+        (m + 2,), lambda i: tvm.tir.if_then_else(tvm.tir.all(i >= 1, i <= m), A[i - 1], 0.0), "Apad"
+    )
     B = te.compute((m,), lambda i: Apad[i] + Apad[i + 1] + Apad[i + 2])
     s = te.create_schedule(B.op)
     xo, xi = s[B].split(B.op.axis[0], factor=4)
@@ -111,7 +113,7 @@ def test_copy_pad_split():
     mod = tvm.tir.transform.Simplify()(mod._move())
 
     def cb(src, dst, pad_before, pad_after, pad_value):
-        assert(dst.elem_offset.value == 0)
+        assert dst.elem_offset.value == 0
         tvm.testing.assert_prim_expr_equal(src.elem_offset, tvm.te.max(xo * 4, 1) - 1)
 
         rpad_before = tvm.te.max(1 - xo * 4, 0)
@@ -124,7 +126,6 @@ def test_copy_pad_split():
     stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
 
 
-
 if __name__ == "__main__":
     test_copy2d()
     test_copy_pad()
index cf58632..ceb32c4 100644
@@ -17,8 +17,9 @@
 import tvm
 from tvm import te
 
+
 def test_double_buffer():
-    dtype = 'int64'
+    dtype = "int64"
     n = 100
     m = 4
     tx = te.thread_axis("threadIdx.x")
@@ -36,17 +37,13 @@ def test_double_buffer():
             C[j] = B[j] + 1
 
     stmt = ib.get()
-    mod = tvm.IRModule({
-        "db" : tvm.tir.PrimFunc([A.asobject(), C.asobject()], stmt)
-    })
+    mod = tvm.IRModule({"db": tvm.tir.PrimFunc([A.asobject(), C.asobject()], stmt)})
 
     opt = tvm.transform.Sequential(
-        [tvm.tir.transform.InjectDoubleBuffer(),
-         tvm.tir.transform.Simplify()])
+        [tvm.tir.transform.InjectDoubleBuffer(), tvm.tir.transform.Simplify()]
+    )
 
-    with tvm.transform.PassContext(config={
-        "tir.InjectDoubleBuffer" : {"split_loop" : 2}
-    }):
+    with tvm.transform.PassContext(config={"tir.InjectDoubleBuffer": {"split_loop": 2}}):
         mod = opt(mod)
     stmt = mod["db"].body
 
@@ -55,9 +52,11 @@ def test_double_buffer():
 
     f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
     count = [0]
+
     def count_sync(op):
         if isinstance(op, tvm.tir.Call) and op.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")):
             count[0] += 1
+
     tvm.tir.stmt_functor.post_order_visit(f.body, count_sync)
     assert count[0] == 4
 
index be725d6..3e7a5a0 100644
 import tvm
 from tvm import te
 
+
 def test_vthread():
-    dtype = 'int64'
+    dtype = "int64"
     n = 100
     m = 4
     nthread = 2
+
     def get_vthread(name):
         tx = te.thread_axis(name)
         ty = te.thread_axis(name)
@@ -34,28 +36,36 @@ def test_vthread():
             B = ib.allocate("float32", m, name="B", scope="shared")
             B[i] = A[i * nthread + tx]
             bbuffer = tvm.tir.decl_buffer((m,), dtype=B.dtype, data=B.asobject())
-            ib.emit(tvm.tir.call_extern("int32", "Run",
-                                    bbuffer.access_ptr("r"),
-                                    tvm.tir.call_intrin("int32", "tir.tvm_context_id")))
+            ib.emit(
+                tvm.tir.call_extern(
+                    "int32",
+                    "Run",
+                    bbuffer.access_ptr("r"),
+                    tvm.tir.call_intrin("int32", "tir.tvm_context_id"),
+                )
+            )
             C[i * nthread + tx] = B[i] + 1
         return ib.get()
 
-    stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([], get_vthread("vthread"))))["main"].body
+    stmt = tvm.tir.transform.InjectVirtualThread()(
+        tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("vthread")))
+    )["main"].body
 
     assert stmt.body.body.extents[0].value == 2
 
-    stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([], get_vthread("cthread"))))["main"].body
+    stmt = tvm.tir.transform.InjectVirtualThread()(
+        tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("cthread")))
+    )["main"].body
 
     assert len(stmt.body.body.extents) == 3
 
 
 def test_vthread_extern():
-    dtype = 'int64'
+    dtype = "int64"
     n = 100
     m = 4
     nthread = 2
+
     def get_vthread(name):
         tx = te.thread_axis(name)
         ty = te.thread_axis(name)
@@ -71,15 +81,20 @@ def test_vthread_extern():
             bbuffer = tvm.tir.decl_buffer((m,), dtype=B.dtype, data=B.asobject())
             A[tx] = tx + 1.0
             B[ty] = ty + 1.0
-            ib.emit(tvm.tir.call_extern("int32", "Run",
-                                        abuffer.access_ptr("r"),
-                                        bbuffer.access_ptr("r"),
-                                        cbuffer.access_ptr("rw")))
+            ib.emit(
+                tvm.tir.call_extern(
+                    "int32",
+                    "Run",
+                    abuffer.access_ptr("r"),
+                    bbuffer.access_ptr("r"),
+                    cbuffer.access_ptr("rw"),
+                )
+            )
         return ib.get()
 
-
-    stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([], get_vthread("cthread"))))["main"].body
+    stmt = tvm.tir.transform.InjectVirtualThread()(
+        tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("cthread")))
+    )["main"].body
 
     assert stmt.body.body.extents[0].value == 2
     assert stmt.body.body.body.body.body.body.extents[0].value == 2
@@ -102,12 +117,14 @@ def test_vthread_if_then_else():
             B[i] = A[i * nthread + tx] + 2
     stmt = ib.get()
 
-    stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([], stmt)))["main"].body
+    stmt = tvm.tir.transform.InjectVirtualThread()(
+        tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    )["main"].body
 
     assert stmt.body.body.body[0].else_case != None
     assert stmt.body.body.body[1].else_case == None
 
+
 if __name__ == "__main__":
     test_vthread_extern()
     test_vthread()
index bb35f32..bb9e5ee 100644
@@ -19,6 +19,7 @@ import tvm
 from tvm import te
 import numpy as np
 
+
 def collect_visit(stmt, f):
     ret = []
     tvm.tir.stmt_functor.post_order_visit(stmt, lambda x: ret.append(f(x)))
@@ -29,52 +30,54 @@ def collect_visit(stmt, f):
 @pytest.mark.xfail
 def test_out_of_bounds_llvm(index_a, index_b):
     n = te.size_var("n")
-    A = te.placeholder ((n,), name='A')
-    B = te.placeholder ((n,), name='B')
-    C = te.compute(A.shape, lambda i: A[i + index_a] + B[i + index_b], name='C')
-    s = te.create_schedule (C.op)
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
+    C = te.compute(A.shape, lambda i: A[i + index_a] + B[i + index_b], name="C")
+    s = te.create_schedule(C.op)
     tgt = "llvm"
     tgt_host = "llvm"
-    stmt = tvm.lower (s, [A, B, C], simple_mode=True)
-    print (stmt)
-    fadd = tvm.build (s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
+    stmt = tvm.lower(s, [A, B, C], simple_mode=True)
+    print(stmt)
+    fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
     ctx = tvm.context(tgt, 0)
     a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
     b = tvm.nd.array(np.random.uniform(size=1024).astype(B.dtype), ctx)
     c = tvm.nd.array(np.zeros(1024, dtype=C.dtype), ctx)
-    fadd (a, b, c)
+    fadd(a, b, c)
+
 
 @tvm.testing.requires_llvm
 def test_in_bounds_llvm():
     n = te.size_var("n")
-    A = te.placeholder ((n,), name='A')
-    B = te.placeholder ((n,), name='B')
-    C = te.compute(A.shape, lambda i: A[i] + B[i], name='C')
-    s = te.create_schedule (C.op)
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
+    C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
+    s = te.create_schedule(C.op)
     tgt = "llvm"
     tgt_host = "llvm"
-    stmt = tvm.lower (s, [A, B, C], simple_mode=True)
-    fadd = tvm.build (s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
+    stmt = tvm.lower(s, [A, B, C], simple_mode=True)
+    fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
     ctx = tvm.context(tgt, 0)
     a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
     b = tvm.nd.array(np.random.uniform(size=1024).astype(B.dtype), ctx)
     c = tvm.nd.array(np.zeros(1024, dtype=C.dtype), ctx)
-    fadd (a, b, c)
+    fadd(a, b, c)
+
 
 @tvm.testing.requires_llvm
 @pytest.mark.xfail
 def test_out_of_bounds_vectorize_llvm(nn, index_a, index_b):
     n = tvm.runtime.convert(nn)
-    a = te.placeholder((n), name='a')
-    b = te.placeholder((n), name='b')
-    c = te.compute((n,), lambda i: a[i + index_a] + b[i + index_b], name='c')
+    a = te.placeholder((n), name="a")
+    b = te.placeholder((n), name="b")
+    c = te.compute((n,), lambda i: a[i + index_a] + b[i + index_b], name="c")
     s = te.create_schedule(c.op)
     xo, xi = s[c].split(c.op.axis[0], factor=8)
     s[c].parallel(xo)
     s[c].vectorize(xi)
     tgt = "llvm"
     tgt_host = "llvm"
-    stmt = tvm.lower (s, [a, b, c], simple_mode=True)
+    stmt = tvm.lower(s, [a, b, c], simple_mode=True)
     f = tvm.build(s, [a, b, c], tgt, target_host=tgt_host, name="myaddvec")
     ctx = tvm.cpu(0)
     n = nn
@@ -83,13 +86,14 @@ def test_out_of_bounds_vectorize_llvm(nn, index_a, index_b):
     c = tvm.nd.array(np.zeros(n, dtype=c.dtype), ctx)
     f(a, b, c)
 
+
 @tvm.testing.requires_llvm
 def test_in_bounds_vectorize_llvm():
     n = 512
     lanes = 2
-    A = te.placeholder((n,), name='A', dtype="float32x%d" % lanes)
-    B = te.compute((n,), lambda i: A[i], name='B')
-    C = te.compute((n,), lambda i: B[i] + tvm.tir.const(1, A.dtype), name='C')
+    A = te.placeholder((n,), name="A", dtype="float32x%d" % lanes)
+    B = te.compute((n,), lambda i: A[i], name="B")
+    C = te.compute((n,), lambda i: B[i] + tvm.tir.const(1, A.dtype), name="C")
     s = te.create_schedule(C.op)
     xo, xi = s[C].split(C.op.axis[0], nparts=2)
     _, xi = s[C].split(xi, factor=2)
@@ -99,26 +103,26 @@ def test_in_bounds_vectorize_llvm():
     xo, xi = s[B].split(B.op.axis[0], factor=2)
     s[B].vectorize(xi)
     # build and invoke the kernel.
-    lowered_func = tvm.lower (s, [A, C], "llvm", simple_mode=False)
+    lowered_func = tvm.lower(s, [A, C], "llvm", simple_mode=False)
     f = tvm.build(s, [A, C], "llvm")
     ctx = tvm.cpu(0)
     # launch the kernel.
-    a = tvm.nd.empty((n,), A.dtype).copyfrom(
-        np.random.uniform(size=(n, lanes)))
+    a = tvm.nd.empty((n,), A.dtype).copyfrom(np.random.uniform(size=(n, lanes)))
     c = tvm.nd.empty((n,), C.dtype, ctx)
     f(a, c)
     tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1)
 
+
 @tvm.testing.requires_llvm
 def test_in_bounds_loop_partition_basic_llvm():
-    n = te.size_var('n')
-    A = te.placeholder((n, ), name='A')
-    B = te.placeholder((n, ), name='B')
+    n = te.size_var("n")
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
 
-    T = te.compute((n, ), lambda i: A[i]+B[i])
+    T = te.compute((n,), lambda i: A[i] + B[i])
     s = te.create_schedule(T.op)
     xo, xi = s[T].split(T.op.axis[0], factor=4)
-    lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
+    lowered_func = tvm.lower(s, [A, B, T], "llvm", simple_mode=False)
     ctx = tvm.cpu(0)
 
     f = tvm.build(s, [A, B, T], "llvm")
@@ -127,17 +131,18 @@ def test_in_bounds_loop_partition_basic_llvm():
     t = tvm.nd.empty((32,), T.dtype, ctx)
     f(a, b, t)
 
+
 @tvm.testing.requires_llvm
 @pytest.mark.xfail
 def test_out_of_bounds_loop_partition_basic_llvm(index_a, index_b):
-    n = te.size_var('n')
-    A = te.placeholder((n, ), name='A')
-    B = te.placeholder((n, ), name='B')
+    n = te.size_var("n")
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
 
-    T = te.compute((n, ), lambda i: A[i + index_a]+B[i + index_b])
+    T = te.compute((n,), lambda i: A[i + index_a] + B[i + index_b])
     s = te.create_schedule(T.op)
     xo, xi = s[T].split(T.op.axis[0], factor=4)
-    lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
+    lowered_func = tvm.lower(s, [A, B, T], "llvm", simple_mode=False)
     ctx = tvm.cpu(0)
 
     f = tvm.build(s, [A, B, T], "llvm")
@@ -146,13 +151,18 @@ def test_out_of_bounds_loop_partition_basic_llvm(index_a, index_b):
     t = tvm.nd.empty((32,), T.dtype, ctx)
     f(a, b, t)
 
+
 def test_in_bounds_const_loop_partition_ir():
-    def check_attr_stmt (x):
-        if isinstance(x, tvm.tir.AttrStmt) and x.attr_key == "buffer_bound" and str(x.value) == str(n):
+    def check_attr_stmt(x):
+        if (
+            isinstance(x, tvm.tir.AttrStmt)
+            and x.attr_key == "buffer_bound"
+            and str(x.value) == str(n)
+        ):
             return True
         return False
 
-    def check_branch_stmt (x):
+    def check_branch_stmt(x):
         if isinstance(x, tvm.tir.IfThenElse):
             return True
         return False
@@ -161,25 +171,27 @@ def test_in_bounds_const_loop_partition_ir():
         count = 0
         for i in collect_visit(stmt, f):
             if i is True:
-              count = count + 1
-        assert (count == nums)
+                count = count + 1
+        assert count == nums
 
-    def collect_branch_stmt (x):
+    def collect_branch_stmt(x):
         if isinstance(x, tvm.tir.IfThenElse):
             branch_collector.append(x)
 
     n = 21
-    A = te.placeholder((n, ), name='A')
-    B = te.placeholder((n, ), name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
 
-    T = te.compute((n, ), lambda i: A[i]+B[i])
+    T = te.compute((n,), lambda i: A[i] + B[i])
     s = te.create_schedule(T.op)
     xo, xi = s[T].split(T.op.axis[0], factor=4)
 
-    with tvm.transform.PassContext(config={
-        "tir.instrument_bound_checkers": True,
-        "tir.LoopPartition": {"partition_const_loop": True}
-    }):
+    with tvm.transform.PassContext(
+        config={
+            "tir.instrument_bound_checkers": True,
+            "tir.LoopPartition": {"partition_const_loop": True},
+        }
+    ):
         mod = tvm.driver.lower(s, [A, B, T], name="main")
 
     stmt = mod["main"].body
@@ -189,23 +201,25 @@ def test_in_bounds_const_loop_partition_ir():
 
     branch_collector = list()
     collect_visit(stmt, collect_branch_stmt)
-    assert(len(branch_collector) ==  2)
+    assert len(branch_collector) == 2
 
 
 @tvm.testing.requires_llvm
 def test_in_bounds_const_loop_partition_llvm():
-    with tvm.transform.PassContext(config={
-        "tir.instrument_bound_checkers": True,
-        "tir.LoopPartition": {"partition_const_loop": True}
-    }):
+    with tvm.transform.PassContext(
+        config={
+            "tir.instrument_bound_checkers": True,
+            "tir.LoopPartition": {"partition_const_loop": True},
+        }
+    ):
         n = 21
-        A = te.placeholder((n, ), name='A')
-        B = te.placeholder((n, ), name='B')
+        A = te.placeholder((n,), name="A")
+        B = te.placeholder((n,), name="B")
 
-        T = te.compute((n, ), lambda i: A[i]+B[i])
+        T = te.compute((n,), lambda i: A[i] + B[i])
         s = te.create_schedule(T.op)
         xo, xi = s[T].split(T.op.axis[0], factor=4)
-        lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
+        lowered_func = tvm.lower(s, [A, B, T], "llvm", simple_mode=False)
         ctx = tvm.cpu(0)
 
         f = tvm.build(s, [A, B, T], "llvm")
@@ -214,21 +228,24 @@ def test_in_bounds_const_loop_partition_llvm():
         t = tvm.nd.empty((n,), T.dtype, ctx)
         f(a, b, t)
 
+
 @tvm.testing.requires_llvm
 @pytest.mark.xfail
 def test_out_of_bounds_const_loop_partition_llvm(index_a, index_b):
-    with tvm.transform.PassContext(config={
-        "tir.instrument_bound_checkers": True,
-        "tir.LoopPartition": {"partition_const_loop": True}
-    }):
+    with tvm.transform.PassContext(
+        config={
+            "tir.instrument_bound_checkers": True,
+            "tir.LoopPartition": {"partition_const_loop": True},
+        }
+    ):
         n = 21
-        A = te.placeholder((n, ), name='A')
-        B = te.placeholder((n, ), name='B')
+        A = te.placeholder((n,), name="A")
+        B = te.placeholder((n,), name="B")
 
-        T = te.compute((n, ), lambda i: A[i + index_a]+B[i + index_b])
+        T = te.compute((n,), lambda i: A[i + index_a] + B[i + index_b])
         s = te.create_schedule(T.op)
         xo, xi = s[T].split(T.op.axis[0], factor=4)
-        lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
+        lowered_func = tvm.lower(s, [A, B, T], "llvm", simple_mode=False)
         ctx = tvm.cpu(0)
 
         f = tvm.build(s, [A, B, T], "llvm")
@@ -237,6 +254,7 @@ def test_out_of_bounds_const_loop_partition_llvm(index_a, index_b):
         t = tvm.nd.empty((n,), T.dtype, ctx)
         f(a, b, t)
 
+
 @tvm.testing.requires_llvm
 def test_in_bounds_conv_llvm(loop_tiling=False):
     HSTR = WSTR = 1
@@ -246,33 +264,40 @@ def test_in_bounds_conv_llvm(loop_tiling=False):
     batch_size = 1
     in_height = in_width = 64
     out_height = out_width = in_height - kernel_height + 1
-    data = te.placeholder((batch_size, in_channel, in_height, in_width), name='data')
-    kernel = te.placeholder((kernel_height, kernel_width, in_channel,
-        out_channel), name='kernel')
-    ic = te.reduce_axis((0, in_channel), name='ic')
-    kh = te.reduce_axis((0, kernel_height), name='kh')
-    kw = te.reduce_axis((0, kernel_width), name='kw')
-    conv = te.compute((batch_size, out_channel, out_height, out_width),
-                       lambda n, oc, oh, ow: te.sum(data[n, ic, oh*HSTR + kh, ow*WSTR + kw] *
-                                                     kernel[kh, kw, ic, oc],
-                                                     axis=[ic, kh, kw]),
-                       name="conv2d")
+    data = te.placeholder((batch_size, in_channel, in_height, in_width), name="data")
+    kernel = te.placeholder((kernel_height, kernel_width, in_channel, out_channel), name="kernel")
+    ic = te.reduce_axis((0, in_channel), name="ic")
+    kh = te.reduce_axis((0, kernel_height), name="kh")
+    kw = te.reduce_axis((0, kernel_width), name="kw")
+    conv = te.compute(
+        (batch_size, out_channel, out_height, out_width),
+        lambda n, oc, oh, ow: te.sum(
+            data[n, ic, oh * HSTR + kh, ow * WSTR + kw] * kernel[kh, kw, ic, oc], axis=[ic, kh, kw]
+        ),
+        name="conv2d",
+    )
     s = te.create_schedule(conv.op)
 
     n, oc, oh, ow = conv.op.axis
     if loop_tiling:
         oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
     lowered_func = tvm.lower(s, [data, kernel, conv], simple_mode=True)
-    ctx = tvm.cpu (0)
+    ctx = tvm.cpu(0)
 
     f = tvm.build(s, [data, kernel, conv], "llvm")
-    data_input = tvm.nd.array(np.random.uniform(
-          size=(batch_size, in_channel, in_height, in_width)).astype("float32"), ctx)
-    kernel_input = tvm.nd.array(np.random.uniform(
-          size=(kernel_height, kernel_width, in_channel, out_channel)).astype("float32"), ctx)
-    conv_out = tvm.nd.empty ((batch_size, out_channel, out_height, out_width), "float32", ctx)
+    data_input = tvm.nd.array(
+        np.random.uniform(size=(batch_size, in_channel, in_height, in_width)).astype("float32"), ctx
+    )
+    kernel_input = tvm.nd.array(
+        np.random.uniform(size=(kernel_height, kernel_width, in_channel, out_channel)).astype(
+            "float32"
+        ),
+        ctx,
+    )
+    conv_out = tvm.nd.empty((batch_size, out_channel, out_height, out_width), "float32", ctx)
     f(data_input, kernel_input, conv_out)
 
+
 @tvm.testing.requires_llvm
 @pytest.mark.xfail
 def test_out_of_bounds_conv_llvm(data_offsets, kernel_offsets, loop_tiling=False):
@@ -283,71 +308,84 @@ def test_out_of_bounds_conv_llvm(data_offsets, kernel_offsets, loop_tiling=False
     batch_size = 1
     in_height = in_width = 64
     out_height = out_width = in_height - kernel_height + 1
-    data = te.placeholder((batch_size, in_channel, in_height, in_width), name='data')
-    kernel = te.placeholder((kernel_height, kernel_width, in_channel,
-        out_channel), name='kernel')
-    ic = te.reduce_axis((0, in_channel), name='ic')
-    kh = te.reduce_axis((0, kernel_height), name='kh')
-    kw = te.reduce_axis((0, kernel_width), name='kw')
-    conv = te.compute((batch_size, out_channel, out_height, out_width),
-                       lambda n, oc, oh, ow: te.sum(data[n + data_offsets[0],
-                                                          ic + data_offsets[1],
-                                                          oh*HSTR + kh + data_offsets[2],
-                                                          ow*WSTR + kw + data_offsets[3]]
-                                                          *
-                                                     kernel[kh + kernel_offsets[0],
-                                                     kw + kernel_offsets[1],
-                                                     ic + kernel_offsets[2],
-                                                     oc + kernel_offsets[3]],
-                                                     axis=[ic, kh, kw]),
-                       name="conv2d")
+    data = te.placeholder((batch_size, in_channel, in_height, in_width), name="data")
+    kernel = te.placeholder((kernel_height, kernel_width, in_channel, out_channel), name="kernel")
+    ic = te.reduce_axis((0, in_channel), name="ic")
+    kh = te.reduce_axis((0, kernel_height), name="kh")
+    kw = te.reduce_axis((0, kernel_width), name="kw")
+    conv = te.compute(
+        (batch_size, out_channel, out_height, out_width),
+        lambda n, oc, oh, ow: te.sum(
+            data[
+                n + data_offsets[0],
+                ic + data_offsets[1],
+                oh * HSTR + kh + data_offsets[2],
+                ow * WSTR + kw + data_offsets[3],
+            ]
+            * kernel[
+                kh + kernel_offsets[0],
+                kw + kernel_offsets[1],
+                ic + kernel_offsets[2],
+                oc + kernel_offsets[3],
+            ],
+            axis=[ic, kh, kw],
+        ),
+        name="conv2d",
+    )
     s = te.create_schedule(conv.op)
 
     n, oc, oh, ow = conv.op.axis
     if loop_tiling:
         oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
     lowered_func = tvm.lower(s, [data, kernel, conv], simple_mode=True)
-    ctx = tvm.cpu (0)
+    ctx = tvm.cpu(0)
 
     f = tvm.build(s, [data, kernel, conv], "llvm")
-    data_input = tvm.nd.array(np.random.uniform(
-          size=(batch_size, in_channel, in_height, in_width)).astype("float32"), ctx)
-    kernel_input = tvm.nd.array(np.random.uniform(
-          size=(kernel_height, kernel_width, in_channel, out_channel)).astype("float32"), ctx)
-    conv_out = tvm.nd.empty ((batch_size, out_channel, out_height, out_width), "float32", ctx)
+    data_input = tvm.nd.array(
+        np.random.uniform(size=(batch_size, in_channel, in_height, in_width)).astype("float32"), ctx
+    )
+    kernel_input = tvm.nd.array(
+        np.random.uniform(size=(kernel_height, kernel_width, in_channel, out_channel)).astype(
+            "float32"
+        ),
+        ctx,
+    )
+    conv_out = tvm.nd.empty((batch_size, out_channel, out_height, out_width), "float32", ctx)
     f(data_input, kernel_input, conv_out)
 
+
 @tvm.testing.requires_llvm
 def test_in_bounds_tensors_with_same_shapes1D_llvm():
-    n = te.size_var('n')
-    k = te.size_var('k')
-    m = te.size_var('m')
-    A = te.placeholder((n, ), name='A')
-    B = te.placeholder((k, ), name='B')
+    n = te.size_var("n")
+    k = te.size_var("k")
+    m = te.size_var("m")
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((k,), name="B")
 
-    T = te.compute((m, ), lambda i: A[i]*B[i])
+    T = te.compute((m,), lambda i: A[i] * B[i])
     s = te.create_schedule(T.op)
-    lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
+    lowered_func = tvm.lower(s, [A, B, T], "llvm", simple_mode=False)
     ctx = tvm.cpu(0)
 
     f = tvm.build(s, [A, B, T], "llvm")
-    a = tvm.nd.array(np.random.uniform(size=(32, )).astype(A.dtype), ctx)
+    a = tvm.nd.array(np.random.uniform(size=(32,)).astype(A.dtype), ctx)
     b = tvm.nd.array(np.random.uniform(size=(32,)).astype(B.dtype), ctx)
     t = tvm.nd.empty((32,), T.dtype, ctx)
     f(a, b, t)
 
+
 @tvm.testing.requires_llvm
 @pytest.mark.xfail
 def test_out_of_bounds_tensors_with_diff_shapes1D_llvm(a_shape, b_shape, c_shape):
-    n = te.size_var('n')
-    k = te.size_var('k')
-    m = te.size_var('m')
-    A = te.placeholder((n, ), name='A')
-    B = te.placeholder((k, ), name='B')
+    n = te.size_var("n")
+    k = te.size_var("k")
+    m = te.size_var("m")
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((k,), name="B")
 
-    T = te.compute((m, ), lambda i: A[i]*B[i])
+    T = te.compute((m,), lambda i: A[i] * B[i])
     s = te.create_schedule(T.op)
-    lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
+    lowered_func = tvm.lower(s, [A, B, T], "llvm", simple_mode=False)
     ctx = tvm.cpu(0)
 
     f = tvm.build(s, [A, B, T], "llvm")
@@ -356,17 +394,18 @@ def test_out_of_bounds_tensors_with_diff_shapes1D_llvm(a_shape, b_shape, c_shape
     t = tvm.nd.empty((c_shape,), T.dtype, ctx)
     f(a, b, t)
 
+
 @tvm.testing.requires_llvm
 def test_in_bounds_tensors_with_same_shapes2D_llvm():
-    n = te.size_var('n')
-    k = te.size_var('k')
-    m = te.size_var('m')
-    A = te.placeholder((n, n), name='A')
-    B = te.placeholder((k, k), name='B')
+    n = te.size_var("n")
+    k = te.size_var("k")
+    m = te.size_var("m")
+    A = te.placeholder((n, n), name="A")
+    B = te.placeholder((k, k), name="B")
 
-    T = te.compute((m, m), lambda i, j: A[i][j]*B[i][j])
+    T = te.compute((m, m), lambda i, j: A[i][j] * B[i][j])
     s = te.create_schedule(T.op)
-    lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
+    lowered_func = tvm.lower(s, [A, B, T], "llvm", simple_mode=False)
     ctx = tvm.cpu(0)
 
     f = tvm.build(s, [A, B, T], "llvm")
@@ -375,95 +414,105 @@ def test_in_bounds_tensors_with_same_shapes2D_llvm():
     t = tvm.nd.empty((32, 32), T.dtype, ctx)
     f(a, b, t)
 
+
 @tvm.testing.requires_llvm
 @pytest.mark.xfail
 def test_out_of_bounds_tensors_with_diff_shapes2D_llvm(a_shape, b_shape, c_shape):
-    n = te.size_var('n')
-    k = te.size_var('k')
-    m = te.size_var('m')
-    A = te.placeholder((n, n), name='A')
-    B = te.placeholder((k, k), name='B')
+    n = te.size_var("n")
+    k = te.size_var("k")
+    m = te.size_var("m")
+    A = te.placeholder((n, n), name="A")
+    B = te.placeholder((k, k), name="B")
 
-    T = te.compute((m, m), lambda i, j: A[i][j]*B[i][j])
+    T = te.compute((m, m), lambda i, j: A[i][j] * B[i][j])
     s = te.create_schedule(T.op)
-    lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
+    lowered_func = tvm.lower(s, [A, B, T], "llvm", simple_mode=False)
     ctx = tvm.cpu(0)
 
     f = tvm.build(s, [A, B, T], "llvm")
-    a = tvm.nd.array(np.random.uniform(size=(a_shape[0],a_shape[1])).astype(A.dtype), ctx)
-    b = tvm.nd.array(np.random.uniform(size=(b_shape[0],b_shape[1])).astype(B.dtype), ctx)
-    t = tvm.nd.empty((c_shape[0],c_shape[1]), T.dtype, ctx)
+    a = tvm.nd.array(np.random.uniform(size=(a_shape[0], a_shape[1])).astype(A.dtype), ctx)
+    b = tvm.nd.array(np.random.uniform(size=(b_shape[0], b_shape[1])).astype(B.dtype), ctx)
+    t = tvm.nd.empty((c_shape[0], c_shape[1]), T.dtype, ctx)
     f(a, b, t)
 
+
 @tvm.testing.requires_llvm
 def test_in_bounds_tensors_with_same_shapes3D_llvm():
-    n = te.size_var('n')
-    k = te.size_var('k')
-    m = te.size_var('m')
-    A = te.placeholder((n, n, n), name='A')
-    B = te.placeholder((k, k, k), name='B')
+    n = te.size_var("n")
+    k = te.size_var("k")
+    m = te.size_var("m")
+    A = te.placeholder((n, n, n), name="A")
+    B = te.placeholder((k, k, k), name="B")
 
-    T = te.compute((m, m, m), lambda i, j, p: A[i][j][p]*B[i][j][p])
+    T = te.compute((m, m, m), lambda i, j, p: A[i][j][p] * B[i][j][p])
     s = te.create_schedule(T.op)
-    lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
+    lowered_func = tvm.lower(s, [A, B, T], "llvm", simple_mode=False)
 
     ctx = tvm.cpu(0)
 
     f = tvm.build(s, [A, B, T], "llvm")
-    a = tvm.nd.array(np.random.uniform(size=(32,32,32)).astype(A.dtype), ctx)
-    b = tvm.nd.array(np.random.uniform(size=(32,32,32)).astype(B.dtype), ctx)
+    a = tvm.nd.array(np.random.uniform(size=(32, 32, 32)).astype(A.dtype), ctx)
+    b = tvm.nd.array(np.random.uniform(size=(32, 32, 32)).astype(B.dtype), ctx)
     t = tvm.nd.empty((32, 32, 32), T.dtype, ctx)
     f(a, b, t)
 
+
 @tvm.testing.requires_llvm
 @pytest.mark.xfail
 def test_out_of_bounds_tensors_with_diff_shapes3D_llvm(a_shape, b_shape, c_shape):
-    n = te.size_var('n')
-    k = te.size_var('k')
-    m = te.size_var('m')
-    A = te.placeholder((n, n, n), name='A')
-    B = te.placeholder((k, k, k), name='B')
+    n = te.size_var("n")
+    k = te.size_var("k")
+    m = te.size_var("m")
+    A = te.placeholder((n, n, n), name="A")
+    B = te.placeholder((k, k, k), name="B")
 
-    T = te.compute((m, m, m), lambda i, j, p: A[i][j][p]*B[i][j][p])
+    T = te.compute((m, m, m), lambda i, j, p: A[i][j][p] * B[i][j][p])
     s = te.create_schedule(T.op)
-    lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
+    lowered_func = tvm.lower(s, [A, B, T], "llvm", simple_mode=False)
 
     ctx = tvm.cpu(0)
 
     f = tvm.build(s, [A, B, T], "llvm")
-    a = tvm.nd.array(np.random.uniform(size=(a_shape[0],a_shape[1], c_shape[2])).astype(A.dtype), ctx)
-    b = tvm.nd.array(np.random.uniform(size=(b_shape[0],b_shape[1], b_shape[2])).astype(B.dtype), ctx)
-    t = tvm.nd.empty((c_shape[0],c_shape[1],c_shape[2]), T.dtype, ctx)
+    a = tvm.nd.array(
+        np.random.uniform(size=(a_shape[0], a_shape[1], c_shape[2])).astype(A.dtype), ctx
+    )
+    b = tvm.nd.array(
+        np.random.uniform(size=(b_shape[0], b_shape[1], b_shape[2])).astype(B.dtype), ctx
+    )
+    t = tvm.nd.empty((c_shape[0], c_shape[1], c_shape[2]), T.dtype, ctx)
     f(a, b, t)
 
+
 @tvm.testing.requires_llvm
 @pytest.mark.xfail
 def test_out_of_bounds_tensors_with_zero_shape_op_with_not_zero_shape_llvm():
     n = 64
-    A = te.placeholder((n, ), name='A')
-    scale = te.placeholder((), name='scale')
+    A = te.placeholder((n,), name="A")
+    scale = te.placeholder((), name="scale")
     k = te.reduce_axis((0, n), name="k")
-    C = te.compute((), lambda : te.sum(A[k + k + k] * scale, axis=k), name="C")
-    D = te.compute((), lambda : C + 1)
+    C = te.compute((), lambda: te.sum(A[k + k + k] * scale, axis=k), name="C")
+    D = te.compute((), lambda: C + 1)
     s = te.create_schedule(D.op)
-    stmt = tvm.lower (s, [A, scale, D], simple_mode=True)
+    stmt = tvm.lower(s, [A, scale, D], simple_mode=True)
 
     # build and invoke the kernel.
     f = tvm.build(s, [A, scale, D], "llvm")
     ctx = tvm.cpu(0)
     # launch the kernel.
     a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), ctx)
-    sc = tvm.nd.array(
-        np.random.randint(0, 2, size=()).astype(scale.dtype), ctx)
+    sc = tvm.nd.array(np.random.randint(0, 2, size=()).astype(scale.dtype), ctx)
     d = tvm.nd.empty((), D.dtype, ctx)
     f(a, sc, d)
     d_np = np.sum(a.asnumpy()) * sc.asnumpy() + 1
     tvm.testing.assert_allclose(d.asnumpy(), d_np)
 
+
 if __name__ == "__main__":
-    with tvm.transform.PassContext(config={
-        "tir.instrument_bound_checkers": True,
-    }):
+    with tvm.transform.PassContext(
+        config={
+            "tir.instrument_bound_checkers": True,
+        }
+    ):
         # zero scale
         test_out_of_bounds_tensors_with_zero_shape_op_with_not_zero_shape_llvm()
         # in bound
@@ -537,8 +586,8 @@ if __name__ == "__main__":
         test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 0, -1, 0], True)
         test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 0, 0, -1], True)
         # tensors with diff shapes basic operation such as mul
-        test_out_of_bounds_tensors_with_diff_shapes1D_llvm (32, 64, 64)
-        test_out_of_bounds_tensors_with_diff_shapes1D_llvm (64, 32, 64)
+        test_out_of_bounds_tensors_with_diff_shapes1D_llvm(32, 64, 64)
+        test_out_of_bounds_tensors_with_diff_shapes1D_llvm(64, 32, 64)
         test_out_of_bounds_tensors_with_diff_shapes2D_llvm([64, 64], [32, 32], [64, 64])
         test_out_of_bounds_tensors_with_diff_shapes2D_llvm([32, 32], [64, 64], [64, 64])
         test_out_of_bounds_tensors_with_diff_shapes3D_llvm([64, 64, 64], [32, 32, 32], [64, 64, 64])
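The __main__ block above wraps every call in a PassContext that turns on "tir.instrument_bound_checkers". A minimal self-contained sketch of that pattern, assuming an LLVM-enabled TVM build; the deliberately out-of-range read is expected to trap at run time instead of silently reading past the buffer, which is what the xfail tests above rely on. The helper name and constant shape are illustrative, not part of this commit.

import numpy as np
import tvm
from tvm import te


def build_with_bound_checks(n=1024):
    # Deliberately read one element past the end of A, as the out-of-bounds
    # tests above do via their index offsets.
    A = te.placeholder((n,), name="A")
    C = te.compute((n,), lambda i: A[i + 1], name="C")
    s = te.create_schedule(C.op)
    with tvm.transform.PassContext(config={"tir.instrument_bound_checkers": True}):
        f = tvm.build(s, [A, C], "llvm")
    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
    c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
    return f, a, c  # calling f(a, c) should abort with a bound-check error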
index f5f4030..12ad16d 100644
@@ -17,6 +17,7 @@
 import tvm
 from tvm import te
 
+
 def test_coproc_lift():
     ib = tvm.tir.ir_builder.create()
     n = te.var("n")
@@ -62,5 +63,6 @@ def test_coproc_lift():
     assert body.body.body.body[1].node == cp
     assert len(body.body.body.body) == 2
 
+
 if __name__ == "__main__":
     test_coproc_lift()
index 73642e0..f9beac7 100644
@@ -18,18 +18,19 @@ import tvm
 from tvm import te
 import numpy
 
+
 def collect_visit(stmt, f):
     ret = []
-    tvm.tir.stmt_functor.post_order_visit(stmt, lambda x : ret.append(f(x)))
+    tvm.tir.stmt_functor.post_order_visit(stmt, lambda x: ret.append(f(x)))
     return ret
 
 
 def test_basic():
-    n = te.size_var('n')
-    A = te.placeholder((n, ), name='A')
-    B = te.placeholder((n, ), name='B')
+    n = te.size_var("n")
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
 
-    T = te.compute((n, ), lambda i: A[i]+B[i])
+    T = te.compute((n,), lambda i: A[i] + B[i])
     s = te.create_schedule(T.op)
     xo, xi = s[T].split(T.op.axis[0], factor=4)
 
@@ -40,18 +41,16 @@ def test_basic():
     mod = tvm.tir.transform.LoopPartition()(mod)
     stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
-    assert(not any(
-        collect_visit(stmt.body.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))))
-    assert(any(
-        collect_visit(stmt.body.body[1], lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    assert not any(collect_visit(stmt.body.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse)))
+    assert any(collect_visit(stmt.body.body[1], lambda x: isinstance(x, tvm.tir.IfThenElse)))
 
 
 def test_const_loop():
     n = 21
-    A = te.placeholder((n, ), name='A')
-    B = te.placeholder((n, ), name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
 
-    T = te.compute((n, ), lambda i: A[i]+B[i])
+    T = te.compute((n,), lambda i: A[i] + B[i])
     s = te.create_schedule(T.op)
     xo, xi = s[T].split(T.op.axis[0], factor=4)
 
@@ -59,22 +58,21 @@ def test_const_loop():
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
-    with tvm.transform.PassContext(config={
-        "tir.LoopPartition": {"partition_const_loop": True}
-    }):
+    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
         mod = tvm.tir.transform.LoopPartition()(mod)
         stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
-    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse)))
+
 
 def test_multi_loop():
     ib = tvm.tir.ir_builder.create()
-    m = te.size_var('m')
-    n = te.size_var('n')
+    m = te.size_var("m")
+    n = te.size_var("n")
     with ib.for_range(0, 4, "i") as i:
         with ib.for_range(0, n, "j") as j:
             with ib.for_range(0, m, "k") as k:
-                with ib.if_scope(ib.likely(i*m+j+k < n)):
+                with ib.if_scope(ib.likely(i * m + j + k < n)):
                     ib.emit(tvm.tir.Evaluate(m))
                 with ib.else_scope():
                     ib.emit(tvm.tir.Evaluate(n))
@@ -84,20 +82,21 @@ def test_multi_loop():
     mod = tvm.tir.transform.LoopPartition()(mod)
     stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
-    assert(not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    assert not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse)))
+
 
 def test_multi_if():
     ib = tvm.tir.ir_builder.create()
-    m = te.size_var('m')
-    n = te.size_var('n')
-    with ib.for_range(0, 4, 'i') as i:
-        with ib.for_range(0, n, 'j') as j:
-            with ib.for_range(0, m, 'k') as k:
-                with ib.if_scope(ib.likely(i*m+j+k < n)):
+    m = te.size_var("m")
+    n = te.size_var("n")
+    with ib.for_range(0, 4, "i") as i:
+        with ib.for_range(0, n, "j") as j:
+            with ib.for_range(0, m, "k") as k:
+                with ib.if_scope(ib.likely(i * m + j + k < n)):
                     ib.emit(tvm.tir.Evaluate(m))
                 with ib.else_scope():
                     ib.emit(tvm.tir.Evaluate(n))
-                with ib.if_scope(ib.likely(i*m+j-k < n)):
+                with ib.if_scope(ib.likely(i * m + j - k < n)):
                     ib.emit(tvm.tir.Evaluate(m))
                 with ib.else_scope():
                     ib.emit(tvm.tir.Evaluate(n))
@@ -107,15 +106,14 @@ def test_multi_if():
     mod = tvm.tir.transform.LoopPartition()(mod)
     stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
-    assert(not any(
-        collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    assert not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse)))
 
 
 def test_thread_axis():
-    m = te.size_var('m')
-    l = te.size_var('l')
-    A = te.placeholder((m, l), name='A')
-    B = te.compute((m, l), lambda i, j: A[i, j] + 3, name='B')
+    m = te.size_var("m")
+    l = te.size_var("l")
+    A = te.placeholder((m, l), name="A")
+    B = te.compute((m, l), lambda i, j: A[i, j] + 3, name="B")
     s = te.create_schedule(B.op)
 
     s[B].set_scope("shared")
@@ -131,22 +129,21 @@ def test_thread_axis():
     mod = tvm.tir.transform.LoopPartition()(mod)
     stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
-    assert(not any(
-        collect_visit(stmt.body. body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    assert not any(collect_visit(stmt.body.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse)))
 
 
 def test_vectorize():
-    n = te.size_var('n')
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
+    n = te.size_var("n")
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
     bias = te.size_var("bias", dtype="float32")
     scale = te.size_var("scale", dtype="float32")
-    C = te.compute(A.shape, lambda *i: A(*i) + B(*i) * scale + bias, name='C')
+    C = te.compute(A.shape, lambda *i: A(*i) + B(*i) * scale + bias, name="C")
     # schedule
     s = te.create_schedule(C.op)
     # create iter var and assign them tags.
     num_thread = 32
-    bx, x = s[C].split(C.op.axis[0], factor=num_thread*4)
+    bx, x = s[C].split(C.op.axis[0], factor=num_thread * 4)
     tx, x = s[C].split(x, nparts=num_thread)
     _, x = s[C].split(x, factor=4)
     s[C].bind(bx, te.thread_axis("blockIdx.x"))
@@ -154,70 +151,67 @@ def test_vectorize():
     s[C].vectorize(x)
     stmt = tvm.lower(s, [A, B], name="main")["main"].body
     body = stmt.body.body.body.body
-    assert(x.var.name not in str(body.condition))
-    assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp))))
+    assert x.var.name not in str(body.condition)
+    assert any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp)))
 
 
 def test_condition():
     ib = tvm.tir.ir_builder.create()
-    m = te.size_var('m')
-    n = te.size_var('n')
-    with ib.for_range(0, tvm.tir.truncdiv(n+3,4), 'i') as i:
-      with ib.for_range(0, 4, 'j') as j:
-        ib.emit(tvm.tir.Evaluate(
-          tvm.tir.Select(ib.likely(i*4+j<n), m, n)))
+    m = te.size_var("m")
+    n = te.size_var("n")
+    with ib.for_range(0, tvm.tir.truncdiv(n + 3, 4), "i") as i:
+        with ib.for_range(0, 4, "j") as j:
+            ib.emit(tvm.tir.Evaluate(tvm.tir.Select(ib.likely(i * 4 + j < n), m, n)))
     stmt = ib.get()
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt))
     mod = tvm.tir.transform.LoopPartition()(mod)
     stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
-    assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select))))
+    assert not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select)))
 
 
 def test_condition_EQ():
     ib = tvm.tir.ir_builder.create()
-    m = te.size_var('m')
-    n = te.size_var('n')
-    with ib.for_range(0, 10, 'i') as i:
-            ib.emit(tvm.tir.Evaluate(
-                tvm.tir.Select(ib.likely(tvm.tir.EQ(i, 5)), m, n)))
+    m = te.size_var("m")
+    n = te.size_var("n")
+    with ib.for_range(0, 10, "i") as i:
+        ib.emit(tvm.tir.Evaluate(tvm.tir.Select(ib.likely(tvm.tir.EQ(i, 5)), m, n)))
     stmt = ib.get()
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt))
-    with tvm.transform.PassContext(config={
-        "tir.LoopPartition": {"partition_const_loop": True}
-    }):
+    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
         mod = tvm.tir.transform.LoopPartition()(mod)
         stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
-    assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select))))
+    assert not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select)))
 
 
 def test_thread_axis2():
     n = tvm.runtime.convert(4096)
-    m = te.size_var('m')
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
-    C = te.compute(A.shape, lambda i: A[i] + B[i], name='C')
+    m = te.size_var("m")
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
+    C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
     s = te.create_schedule(C.op)
     num_thread = 32
     bx, x = s[C].split(C.op.axis[0], factor=32)
     tx, x = s[C].split(x, nparts=num_thread)
-    _,  x = s[C].split(x, factor=m)
+    _, x = s[C].split(x, factor=m)
     s[C].bind(bx, te.thread_axis("blockIdx.x"))
     s[C].bind(tx, te.thread_axis("threadIdx.x"))
     stmt = tvm.lower(s, [A, B], name="main")["main"].body
     for_body = stmt.body.body.body.body[0]
-    assert('threadIdx' not in str(for_body.extent))
+    assert "threadIdx" not in str(for_body.extent)
+
 
 def test_everything_during_deduction():
-    m = te.size_var('m')
-    n = te.size_var('n')
+    m = te.size_var("m")
+    n = te.size_var("n")
     ib = tvm.tir.ir_builder.create()
-    with ib.for_range(0, n, 'i') as i:
-        with ib.for_range(0, 32, 'j') as j:
-            with ib.if_scope(ib.likely(tvm.tir.truncdiv(i,j) < m)):
+    with ib.for_range(0, n, "i") as i:
+        with ib.for_range(0, 32, "j") as j:
+            with ib.if_scope(ib.likely(tvm.tir.truncdiv(i, j) < m)):
                 # this guard will produce everything during deduction
                 ib.emit(tvm.tir.Evaluate(m))
     stmt = ib.get()
@@ -226,15 +220,15 @@ def test_everything_during_deduction():
     mod = tvm.tir.transform.LoopPartition()(mod)
     stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
+    assert isinstance(stmt.body.body, tvm.tir.IfThenElse)
 
-    assert(isinstance(stmt.body.body, tvm.tir.IfThenElse))
 
 def test_single_likely():
     n = 60
-    A = te.placeholder((n, ), name='A')
-    B = te.placeholder((n, ), name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
 
-    T = te.compute((n, ), lambda i: A[i]+B[i])
+    T = te.compute((n,), lambda i: A[i] + B[i])
     s = te.create_schedule(T.op)
     x = T.op.axis[0]
     xo, xi = s[T].split(x, factor=16)
@@ -244,21 +238,20 @@ def test_single_likely():
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
 
-    with tvm.transform.PassContext(config={
-        "tir.LoopPartition": {"partition_const_loop": True}
-    }):
+    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
         mod = tvm.tir.transform.LoopPartition()(mod)
         stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
-    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse)))
+
 
 def test_multi_likely():
     n = 94
     m = 62
-    A = te.placeholder((n, m), name='A')
-    B = te.placeholder((n, m), name='B')
+    A = te.placeholder((n, m), name="A")
+    B = te.placeholder((n, m), name="B")
 
-    T = te.compute((n, m), lambda i, j: A[i, j]+B[i, j])
+    T = te.compute((n, m), lambda i, j: A[i, j] + B[i, j])
     s = te.create_schedule(T.op)
     bounds = tvm.te.schedule.InferBound(s)
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
@@ -272,33 +265,31 @@ def test_multi_likely():
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
 
-    with tvm.transform.PassContext(config={
-        "tir.LoopPartition": {"partition_const_loop": True}
-    }):
+    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
         mod = tvm.tir.transform.LoopPartition()(mod)
         stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
-    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse)))
 
 
 def test_oneD_pool():
-    m = te.size_var('m')
+    m = te.size_var("m")
     ib = tvm.tir.ir_builder.create()
-    #data = te.placeholder((16,), name = 'data')
+    # data = te.placeholder((16,), name = 'data')
     data = ib.pointer("float32", name="A")
     out = ib.pointer("float32", name="A")
-    with ib.for_range(0, 16, 'ow') as ow:
-        with ib.for_range(0, 3, 'kw') as kw:
+    with ib.for_range(0, 16, "ow") as ow:
+        with ib.for_range(0, 3, "kw") as kw:
             with ib.if_scope(ib.likely(ow > 0)):
                 with ib.if_scope(ib.likely(ow < 15)):
                     out[ow] = tvm.te.max(out[ow], data[ow + kw - 1])
-    with ib.for_range(0, 16, 'ow') as ow:
-        with ib.for_range(0, 3, 'kw') as kw:
+    with ib.for_range(0, 16, "ow") as ow:
+        with ib.for_range(0, 3, "kw") as kw:
             with ib.if_scope(ib.likely(ow < 1)):
                 with ib.if_scope(ib.likely(kw > 0)):
                     out[ow] = tvm.te.max(out[ow], data[ow + kw - 1])
-    with ib.for_range(0, 16, 'ow') as ow:
-        with ib.for_range(0, 3, 'kw') as kw:
+    with ib.for_range(0, 16, "ow") as ow:
+        with ib.for_range(0, 3, "kw") as kw:
             with ib.if_scope(ib.likely(ow > 14)):
                 with ib.if_scope(ib.likely(kw < 2)):
                     out[ow] = tvm.te.max(out[ow], data[ow + kw - 1])
@@ -307,66 +298,63 @@ def test_oneD_pool():
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, data, out], stmt))
 
-    with tvm.transform.PassContext(config={
-        "tir.LoopPartition": {"partition_const_loop": True}
-    }):
+    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
         mod = tvm.tir.transform.LoopPartition()(mod)
         stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
-    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse)))
 
 
 def test_cce_loop_1():
-  ib = tvm.tir.ir_builder.create()
-  dtype = 'float16'
-  n = 514
-  m = 514
-  _A = te.placeholder((n*m,), name = 'A')
-  Ab = tvm.tir.decl_buffer((n*m,), dtype, name="A")
-  A = ib.buffer_ptr(Ab)
-  _B = te.placeholder((n*m,), name = 'B')
-  Bb = tvm.tir.decl_buffer((n*m,), dtype, name="B")
-  B = ib.buffer_ptr(Bb)
-  #for i in 0 to n-1:
-  with ib.for_range(0, 11, name="i") as i:
-      with ib.for_range(0, 160, name="j") as j:
-          with ib.if_scope(ib.likely(((i*160) + j) < 1600)):
-               A[(i+1)*m+j+1] = B[(i)*m+j+1] + B[(i+1)*m+j+1] + B[(i+2)*m+j+1]
-  stmt = ib.get()
-
-  mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
-  with tvm.transform.PassContext(config={
-        "tir.LoopPartition": {"partition_const_loop": True}
-  }):
-    mod = tvm.tir.transform.LoopPartition()(mod)
-    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+    ib = tvm.tir.ir_builder.create()
+    dtype = "float16"
+    n = 514
+    m = 514
+    _A = te.placeholder((n * m,), name="A")
+    Ab = tvm.tir.decl_buffer((n * m,), dtype, name="A")
+    A = ib.buffer_ptr(Ab)
+    _B = te.placeholder((n * m,), name="B")
+    Bb = tvm.tir.decl_buffer((n * m,), dtype, name="B")
+    B = ib.buffer_ptr(Bb)
+    # for i in 0 to n-1:
+    with ib.for_range(0, 11, name="i") as i:
+        with ib.for_range(0, 160, name="j") as j:
+            with ib.if_scope(ib.likely(((i * 160) + j) < 1600)):
+                A[(i + 1) * m + j + 1] = (
+                    B[(i) * m + j + 1] + B[(i + 1) * m + j + 1] + B[(i + 2) * m + j + 1]
+                )
+    stmt = ib.get()
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
+        mod = tvm.tir.transform.LoopPartition()(mod)
+        stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+    assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse)))
 
-  assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
 def test_cce_loop_2():
-  ib = tvm.tir.ir_builder.create()
-  len = 112
-  tile = 32
-  loop = (len + tile - 1) // tile
-  with ib.for_range(0, loop, 'i') as i:
-    head = i * tile
-    with ib.if_scope(ib.likely(head + tile > len)):
-      tail = len
-      ib.emit(tvm.tir.call_extern('float32', "cce_intrisic", head, tail))
-    with ib.else_scope():
-      tail = head + tile
-      ib.emit(tvm.tir.call_extern('float32', "cce_intrisic", head, tail))
-
-  stmt = ib.get()
-
-  mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
-  with tvm.transform.PassContext(config={
-      "tir.LoopPartition": {"partition_const_loop": True}
-  }):
-    mod = tvm.tir.transform.LoopPartition()(mod)
-    stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+    ib = tvm.tir.ir_builder.create()
+    len = 112
+    tile = 32
+    loop = (len + tile - 1) // tile
+    with ib.for_range(0, loop, "i") as i:
+        head = i * tile
+        with ib.if_scope(ib.likely(head + tile > len)):
+            tail = len
+            ib.emit(tvm.tir.call_extern("float32", "cce_intrisic", head, tail))
+        with ib.else_scope():
+            tail = head + tile
+            ib.emit(tvm.tir.call_extern("float32", "cce_intrisic", head, tail))
 
-  assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    stmt = ib.get()
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
+        mod = tvm.tir.transform.LoopPartition()(mod)
+        stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+    assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse)))
 
 
 def test_cce_loop_3():
@@ -374,23 +362,21 @@ def test_cce_loop_3():
     loop1 = 4
     loop2 = 9998
     tile = 39991
-    with ib.for_range(0,loop2,'i') as i:
-        with ib.for_range(0,loop1,'j') as j:
+    with ib.for_range(0, loop2, "i") as i:
+        with ib.for_range(0, loop1, "j") as j:
             head1 = i
             head2 = j
-            with ib.if_scope(ib.likely(head1*loop1 + head2 < tile)):
-                ib.emit(tvm.tir.call_extern('float16',"cce_intrisic",head1))
+            with ib.if_scope(ib.likely(head1 * loop1 + head2 < tile)):
+                ib.emit(tvm.tir.call_extern("float16", "cce_intrisic", head1))
 
     stmt = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
 
-    with tvm.transform.PassContext(config={
-        "tir.LoopPartition": {"partition_const_loop": True}
-    }):
+    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
         mod = tvm.tir.transform.LoopPartition()(mod)
         stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
-    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse)))
 
 
 def test_conv_tiling():
@@ -401,17 +387,18 @@ def test_conv_tiling():
     batch_size = 1
     in_height = in_width = 64
     out_height = out_width = in_height - kernel_height + 1
-    data = te.placeholder((batch_size, in_channel, in_height, in_width), name='data')
-    kernel = te.placeholder((kernel_height, kernel_width, in_channel,
-        out_channel), name='kernel')
-    ic = te.reduce_axis((0, in_channel), name='ic')
-    kh = te.reduce_axis((0, kernel_height), name='kh')
-    kw = te.reduce_axis((0, kernel_width), name='kw')
-    conv = te.compute((batch_size, out_channel, out_height, out_width),
-                       lambda n, oc, oh, ow: te.sum(data[n, ic, oh*HSTR + kh, ow*WSTR + kw] *
-                                                     kernel[kh, kw, ic, oc],
-                                                     axis=[ic, kh, kw]),
-                       name="conv2d")
+    data = te.placeholder((batch_size, in_channel, in_height, in_width), name="data")
+    kernel = te.placeholder((kernel_height, kernel_width, in_channel, out_channel), name="kernel")
+    ic = te.reduce_axis((0, in_channel), name="ic")
+    kh = te.reduce_axis((0, kernel_height), name="kh")
+    kw = te.reduce_axis((0, kernel_width), name="kw")
+    conv = te.compute(
+        (batch_size, out_channel, out_height, out_width),
+        lambda n, oc, oh, ow: te.sum(
+            data[n, ic, oh * HSTR + kh, ow * WSTR + kw] * kernel[kh, kw, ic, oc], axis=[ic, kh, kw]
+        ),
+        name="conv2d",
+    )
     s = te.create_schedule(conv.op)
 
     n, oc, oh, ow = conv.op.axis
@@ -420,17 +407,16 @@ def test_conv_tiling():
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
-    with tvm.transform.PassContext(config={
-        "tir.LoopPartition": {"partition_const_loop": True}
-    }):
+    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
         mod = tvm.tir.transform.LoopPartition()(mod)
         stmt = tvm.tir.transform.Simplify()(mod)["main"].body
 
-    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse)))
 
 
 def test_multilevel_splitting_with_indivisble_factors():
     from tvm import topi
+
     A = te.placeholder((130,), dtype="float32")
     B = topi.nn.relu(A)
     s = te.create_schedule(B.op)
@@ -441,57 +427,68 @@ def test_multilevel_splitting_with_indivisble_factors():
     s[B].unroll(yi)
 
     ## But this does the right thing.
-    with tvm.transform.PassContext(config={
-        "tir.LoopPartition": {"partition_const_loop": True}
-    }):
+    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
         lowered_body = tvm.lower(s, [A, B], name="x")["x"].body
+
         def visit_stmt(op):
-            return(isinstance(op, tvm.tir.Max))
+            return isinstance(op, tvm.tir.Max)
+
         num_max = collect_visit(lowered_body, visit_stmt)
         assert num_max.count(True) == 10
 
 
 def test_double_splitting_with_indivisible_factors():
     m = 48
-    dtype="float32"
-    A = te.placeholder((m,), name='A', dtype=dtype)
-    C = te.compute((m,), lambda i: A[i], name='C')
-    D = te.compute((m,), lambda i: C[i], name='D')
+    dtype = "float32"
+    A = te.placeholder((m,), name="A", dtype=dtype)
+    C = te.compute((m,), lambda i: A[i], name="C")
+    D = te.compute((m,), lambda i: C[i], name="D")
 
     s = te.create_schedule(D.op)
     co, ci = s[C].split(C.op.axis[0], factor=10)
     do, di = s[D].split(D.op.axis[0], 32)
     s[C].compute_at(s[D], do)
 
-    target = 'llvm'
-    with tvm.transform.PassContext(config={
-        "tir.LoopPartition": {"partition_const_loop": True}
-    }):
+    target = "llvm"
+    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
         f = tvm.lower(s, [A, C, D], name="fadd1", simple_mode=False)
         func = tvm.build(f, target=target)
 
     top_produce = f["fadd1"].body
-    assert(not any(collect_visit(top_produce, lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    assert not any(collect_visit(top_produce, lambda x: isinstance(x, tvm.tir.IfThenElse)))
 
     # check functional correctness of generated code
     ctx = tvm.context(target, 0)
-    a = tvm.nd.array(numpy.ones(m,).astype(dtype), ctx)
-    c = tvm.nd.array(numpy.zeros(m,).astype(dtype), ctx)
-    d = tvm.nd.array(numpy.zeros(m,).astype(dtype), ctx)
+    a = tvm.nd.array(
+        numpy.ones(
+            m,
+        ).astype(dtype),
+        ctx,
+    )
+    c = tvm.nd.array(
+        numpy.zeros(
+            m,
+        ).astype(dtype),
+        ctx,
+    )
+    d = tvm.nd.array(
+        numpy.zeros(
+            m,
+        ).astype(dtype),
+        ctx,
+    )
     func(a, c, d)
     tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy(), rtol=1e-5)
     tvm.testing.assert_allclose(d.asnumpy(), a.asnumpy(), rtol=1e-5)
 
+
 def test_simple_rfactor():
-    K = 16*4+4
-    k = te.reduce_axis((0, K), 'k')
+    K = 16 * 4 + 4
+    k = te.reduce_axis((0, K), "k")
 
-    A = te.placeholder((1, K), name='A')
+    A = te.placeholder((1, K), name="A")
 
-    B = te.compute( (1,), lambda b:
-            te.sum(A[b, k], axis=k),
-            name='B'
-    )
+    B = te.compute((1,), lambda b: te.sum(A[b, k], axis=k), name="B")
 
     s = te.create_schedule(B.op)
     ko, _ = s[B].split(s[B].op.reduce_axis[0], 16)
@@ -504,9 +501,7 @@ def test_simple_rfactor():
     mod1 = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt1))
     stmt1 = tvm.tir.transform.Simplify()(mod1)["main"].body
 
-    with tvm.transform.PassContext(config={
-        "tir.LoopPartition": {"partition_const_loop": True}
-    }):
+    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
         mod2 = tvm.tir.transform.LoopPartition()(mod1)
         stmt2 = tvm.tir.transform.Simplify()(mod2)["main"].body
 
index 3042f9e..fb3790e 100644 (file)
@@ -24,12 +24,12 @@ def lower_intrin(params, stmt):
     """wrapper to call transformation in stmt"""
     lower_expr = isinstance(stmt, tvm.tir.PrimExpr)
     stmt = tvm.tir.Evaluate(stmt) if lower_expr else stmt
-    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc(params, stmt).with_attr(
-        "target", tvm.target.Target("llvm")))
-    mod = tvm.transform.Sequential([
-        tvm.tir.transform.Simplify(),
-        tvm.tir.transform.LowerIntrin()
-    ])(mod)
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc(params, stmt).with_attr("target", tvm.target.Target("llvm"))
+    )
+    mod = tvm.transform.Sequential([tvm.tir.transform.Simplify(), tvm.tir.transform.LowerIntrin()])(
+        mod
+    )
     func = mod["main"]
     stmt = func.body
     return stmt.value if lower_expr else stmt.body
@@ -61,6 +61,7 @@ def check_value(expr, vx, vy, data, fref):
 def get_ref_data():
     """Get reference data for every pairs"""
     import itertools
+
     x = range(-10, 10)
     y = list(range(-10, 10))
     y.remove(0)
@@ -78,24 +79,21 @@ def test_lower_floordiv():
         res = lower_intrin([x, y], tvm.te.floordiv(x, y))
         check_value(res, x, y, data, lambda a, b: a // b)
         # rhs >= 0
-        res = lower_intrin([x, y], tvm.tir.Select(
-            y >= 0, tvm.te.floordiv(x, y), zero))
+        res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floordiv(x, y), zero))
         check_value(res, x, y, data, lambda a, b: a // b if b > 0 else 0)
         # involves max
-        res = lower_intrin([x, y], tvm.tir.Select(
-            y >= 0, tvm.te.max(tvm.te.floordiv(x, y), zero), zero))
-        check_value(res, x, y, data, lambda a,
-                    b: max(a // b, 0) if b > 0 else 0)
+        res = lower_intrin(
+            [x, y], tvm.tir.Select(y >= 0, tvm.te.max(tvm.te.floordiv(x, y), zero), zero)
+        )
+        check_value(res, x, y, data, lambda a, b: max(a // b, 0) if b > 0 else 0)
         # lhs >= 0
-        res = lower_intrin([x, y], tvm.tir.Select(
-            tvm.tir.all(y >= 0, x >= 0), tvm.te.floordiv(x, y), zero))
-        check_value(res, x, y, data, lambda a, b: a //
-                    b if b > 0 and a >= 0 else 0)
+        res = lower_intrin(
+            [x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floordiv(x, y), zero)
+        )
+        check_value(res, x, y, data, lambda a, b: a // b if b > 0 and a >= 0 else 0)
         # const power of two
-        res = lower_intrin([x, y], tvm.te.floordiv(
-            x, tvm.tir.const(8, dtype=dtype)))
-        check_value(res, x, y, [(a, b)
-                                for a, b in data if b == 8], lambda a, b: a // b)
+        res = lower_intrin([x, y], tvm.te.floordiv(x, tvm.tir.const(8, dtype=dtype)))
+        check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a // b)
 
 
 @tvm.testing.requires_llvm
@@ -109,19 +107,16 @@ def test_lower_floormod():
         res = lower_intrin([x, y], tvm.te.floormod(x, y))
         check_value(res, x, y, data, lambda a, b: a % b)
         # rhs >= 0
-        res = lower_intrin([x, y], tvm.tir.Select(
-            y >= 0, tvm.te.floormod(x, y), zero))
+        res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floormod(x, y), zero))
         check_value(res, x, y, data, lambda a, b: a % b if b > 0 else 0)
         # lhs >= 0
-        res = lower_intrin([x, y], tvm.tir.Select(
-            tvm.tir.all(y >= 0, x >= 0), tvm.te.floormod(x, y), zero))
-        check_value(res, x, y, data, lambda a, b: a %
-                    b if b > 0 and a >= 0 else 0)
+        res = lower_intrin(
+            [x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floormod(x, y), zero)
+        )
+        check_value(res, x, y, data, lambda a, b: a % b if b > 0 and a >= 0 else 0)
         # const power of two
-        res = lower_intrin([x, y], tvm.te.floormod(
-            x, tvm.tir.const(8, dtype=dtype)))
-        check_value(res, x, y, [(a, b)
-                                for a, b in data if b == 8], lambda a, b: a % b)
+        res = lower_intrin([x, y], tvm.te.floormod(x, tvm.tir.const(8, dtype=dtype)))
+        check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a % b)
 
 
 if __name__ == "__main__":
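
As an aside, the lower_intrin helper above amounts to a two-pass Sequential run. A minimal standalone sketch, assuming an LLVM target is available:

    import tvm
    from tvm import te

    x = te.var("x", dtype="int32")
    y = te.var("y", dtype="int32")

    # Wrap a floordiv in a PrimFunc, attach the target, then simplify and lower
    # the intrinsic into arithmetic the target supports natively.
    func = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(tvm.te.floordiv(x, y)))
    mod = tvm.IRModule.from_expr(func.with_attr("target", tvm.target.Target("llvm")))
    mod = tvm.transform.Sequential(
        [tvm.tir.transform.Simplify(), tvm.tir.transform.LowerIntrin()]
    )(mod)
    print(mod["main"].body.value)  # floordiv expressed via truncated division plus sign fix-ups
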
index 38bf89c..28179c2 100644 (file)
@@ -25,8 +25,8 @@ import tvm.testing
 @tvm.testing.requires_cuda
 def test_lower_warp_memory_local_scope():
     m = 128
-    A = te.placeholder((m,), name='A')
-    B = te.compute((m,), lambda i: A[i] + 3, name='B')
+    A = te.placeholder((m,), name="A")
+    B = te.compute((m,), lambda i: A[i] + 3, name="B")
 
     s = te.create_schedule(B.op)
     AA = s.cache_read(A, "warp", [B])
@@ -43,20 +43,19 @@ def test_lower_warp_memory_local_scope():
     assert cuda_target.thread_warp_size == 32
     mod = tvm.lower(s, [A, B], name="f")
 
-    mod = tvm.tir.transform.Apply(
-        lambda f: f.with_attr("target", cuda_target))(mod)
+    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
     fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
     mod = tvm.IRModule.from_expr(fdevice)
     fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"]
-    assert(fdevice.body.body.value.value == "local")
-    assert(fdevice.body.body.body.extents[0].value == 2)
+    assert fdevice.body.body.value.value == "local"
+    assert fdevice.body.body.body.extents[0].value == 2
 
 
 @tvm.testing.requires_cuda
 def test_lower_warp_memory_correct_indices():
     n = 32
-    A = te.placeholder((2, n, n), name='A', dtype="float32")
-    C = te.compute((2, n, n), lambda x, i, j: A(x, i, (j + 1) % n), name='C')
+    A = te.placeholder((2, n, n), name="A", dtype="float32")
+    C = te.compute((2, n, n), lambda x, i, j: A(x, i, (j + 1) % n), name="C")
 
     s = te.create_schedule(C.op)
     bk_x = te.thread_axis("blockIdx.x")
@@ -84,8 +83,7 @@ def test_lower_warp_memory_correct_indices():
     # 2. If we are accessing from different warps (different threadIdx.y), we are actually
     #    accessing different buffers, so there is no need to distinguish between elements,
     #    and therefore threadIdx.y is NOT an index.
-    idx_names = map(lambda x: x.name,
-                    filter(lambda x: type(x) is tvm.tir.expr.Var, indices))
+    idx_names = map(lambda x: x.name, filter(lambda x: type(x) is tvm.tir.expr.Var, indices))
     assert "threadIdx.x" in idx_names
     assert "threadIdx.y" not in idx_names
 
@@ -99,9 +97,8 @@ def test_lower_warp_memory_cuda_end_to_end():
             return
 
         m = 128
-        A = te.placeholder((m,), name='A', dtype=dtype)
-        B = te.compute(
-            (m,), lambda i: A[i // 32 * 32 + (i + 1) % 32], name='B')
+        A = te.placeholder((m,), name="A", dtype=dtype)
+        B = te.compute((m,), lambda i: A[i // 32 * 32 + (i + 1) % 32], name="B")
 
         cuda_target = tvm.target.Target("cuda")
         assert cuda_target.thread_warp_size == 32
@@ -121,11 +118,16 @@ def test_lower_warp_memory_cuda_end_to_end():
             func = tvm.build(s, [A, B], "cuda")
             A_np = np.array(list(range(m)), dtype=dtype)
             B_np = np.array(
-                list(range(1, 32)) + [0] +
-                list(range(33, 64)) + [32] +
-                list(range(65, 96)) + [64] +
-                list(range(97, 128)) + [96],
-                dtype=dtype)
+                list(range(1, 32))
+                + [0]
+                + list(range(33, 64))
+                + [32]
+                + list(range(65, 96))
+                + [64]
+                + list(range(97, 128))
+                + [96],
+                dtype=dtype,
+            )
             A_nd = tvm.nd.array(A_np, ctx)
             B_nd = tvm.nd.array(np.zeros(B_np.shape, dtype=B_np.dtype), ctx)
             func(A_nd, B_nd)
@@ -144,8 +146,22 @@ def test_lower_warp_memory_cuda_half_a_warp():
             return
 
         n, m = 16, 16
-        A = te.placeholder((n, m,), name='A', dtype=dtype)
-        B = te.compute((n, m,), lambda j, i: A[j, (i + 1) % m], name='B')
+        A = te.placeholder(
+            (
+                n,
+                m,
+            ),
+            name="A",
+            dtype=dtype,
+        )
+        B = te.compute(
+            (
+                n,
+                m,
+            ),
+            lambda j, i: A[j, (i + 1) % m],
+            name="B",
+        )
 
         cuda_target = tvm.target.Target("cuda")
         assert cuda_target.thread_warp_size == 2 * m
@@ -167,10 +183,8 @@ def test_lower_warp_memory_cuda_half_a_warp():
 
             ctx = tvm.gpu(0)
             func = tvm.build(s, [A, B], "cuda")
-            A_np = np.array([list(range(i, m + i))
-                             for i in range(n)], dtype=dtype)
-            B_np = np.array([list(range(1 + i, m + i)) + [i]
-                             for i in range(n)], dtype=dtype)
+            A_np = np.array([list(range(i, m + i)) for i in range(n)], dtype=dtype)
+            B_np = np.array([list(range(1 + i, m + i)) + [i] for i in range(n)], dtype=dtype)
             A_nd = tvm.nd.array(A_np, ctx)
             B_nd = tvm.nd.array(np.zeros(B_np.shape, dtype=B_np.dtype), ctx)
             func(A_nd, B_nd)
@@ -189,10 +203,9 @@ def test_lower_warp_memory_cuda_2_buffers():
             return
 
         m = 32
-        A = te.placeholder((m,), name='A', dtype=dtype)
-        B = te.placeholder((m,), name='B', dtype=dtype)
-        C = te.compute((m,), lambda i: A[(i + 1) %
-                                         m] + B[(i + 1) % m], name='C')
+        A = te.placeholder((m,), name="A", dtype=dtype)
+        B = te.placeholder((m,), name="B", dtype=dtype)
+        C = te.compute((m,), lambda i: A[(i + 1) % m] + B[(i + 1) % m], name="C")
 
         cuda_target = tvm.target.Target("cuda")
         assert m <= cuda_target.thread_warp_size
@@ -232,8 +245,8 @@ def test_lower_warp_memory_cuda_2_buffers():
 @tvm.testing.requires_gpu
 def test_lower_warp_memory_roundup():
     def check(device, m):
-        A = te.placeholder((m,), name='A')
-        B = te.compute((m,), lambda i: A[i] + 1, name='B')
+        A = te.placeholder((m,), name="A")
+        B = te.compute((m,), lambda i: A[i] + 1, name="B")
 
         with tvm.target.Target(device):
             s = te.create_schedule(B.op)
@@ -257,7 +270,7 @@ def test_lower_warp_memory_roundup():
             B_np = A_np + 1
             tvm.testing.assert_allclose(B_nd.asnumpy(), B_np)
 
-    for device in ['cuda', 'rocm']:
+    for device in ["cuda", "rocm"]:
         if not tvm.testing.device_enabled(device):
             print("skip because", device, "is not enabled..")
             continue
index 4797eea..15f9940 100644 (file)
@@ -18,12 +18,13 @@ import tvm
 from tvm import te
 import numpy
 
+
 def test_makeapi():
     """Not yet working, mock design"""
-    n = te.size_var('n')
-    A = te.placeholder((n,), name='A')
-    B = te.placeholder((n,), name='B')
-    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    n = te.size_var("n")
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
+    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C")
     s = te.create_schedule(C.op)
 
     bounds = tvm.te.schedule.InferBound(s)
@@ -32,14 +33,17 @@ def test_makeapi():
     mod = tvm.IRModule.from_expr(func)
     mod = tvm.tir.transform.StorageFlatten(64)(mod)
     mod = tvm.tir.transform.Apply(
-        lambda f: f.with_attr({
-            "target": tvm.target.Target("llvm"),
-            "global_symbol": "main",
-        }))(mod)
+        lambda f: f.with_attr(
+            {
+                "target": tvm.target.Target("llvm"),
+                "global_symbol": "main",
+            }
+        )
+    )(mod)
 
     num_unpacked_args = 2
     f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"]
-    assert(len(f.params) == 8)
+    assert len(f.params) == 8
 
 
 if __name__ == "__main__":
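
The MakePackedAPI flow above, minus the elided schedule boilerplate, follows the same SchedulePostProcToPrimFunc / StorageFlatten pattern used elsewhere in this patch. A minimal sketch; the toy compute and the choice of one unpacked argument are illustrative:

    import tvm
    from tvm import te

    n = te.size_var("n")
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)

    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)

    mod = tvm.IRModule.from_expr(func)
    mod = tvm.tir.transform.StorageFlatten(64)(mod)
    mod = tvm.tir.transform.Apply(
        lambda f: f.with_attr({"target": tvm.target.Target("llvm"), "global_symbol": "main"})
    )(mod)

    # Keep one argument in plain form; the rest go through the packed
    # TVMArgs/TVMValue calling convention, which adds the usual handle parameters.
    f = tvm.tir.transform.MakePackedAPI(1)(mod)["main"]
    print(f.params)
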
index 6571499..e1e5adb 100644 (file)
@@ -22,8 +22,7 @@ from tvm.tir import const
 
 def lower_stmt(params, stmt, target_bits):
     func = tvm.tir.PrimFunc(params, stmt)
-    func = tvm.tir.transform.NarrowDataType(target_bits)(
-        tvm.IRModule.from_expr(func))["main"]
+    func = tvm.tir.transform.NarrowDataType(target_bits)(tvm.IRModule.from_expr(func))["main"]
     stmt = func.body
     return stmt
 
@@ -52,12 +51,12 @@ def lower_sch(sch, args, target_bits):
 def test_basic():
     def check(m, n, target_bits, target_dtype):
         ib = tvm.tir.ir_builder.create()
-        Ab = tvm.tir.decl_buffer((m, n), name='A')
+        Ab = tvm.tir.decl_buffer((m, n), name="A")
         A = ib.buffer_ptr(Ab)
-        Bb = tvm.tir.decl_buffer((m, n), name='B')
+        Bb = tvm.tir.decl_buffer((m, n), name="B")
         B = ib.buffer_ptr(Bb)
-        with ib.for_range(0, m, name='i') as i:
-            with ib.for_range(0, n, name='j') as j:
+        with ib.for_range(0, m, name="i") as i:
+            with ib.for_range(0, n, name="j") as j:
                 B[i * n + j] = A[i * n + j] + 1
         stmt = ib.get()
         stmt = lower_stmt([Ab, Bb], stmt, target_bits)
@@ -67,25 +66,25 @@ def test_basic():
     # const shape
     # i32 -> i32
     check(2, 2, 32, "int32")
-    check(2**16, 2**16, 32, "int32")  # i32 + i32 is not promoted to i64 even if overflow
+    check(2 ** 16, 2 ** 16, 32, "int32")  # i32 + i32 is not promoted to i64 even if overflow
     # i64 -> i32
-    check(const(2, dtype='int64'), const(2, dtype='int64'), 32, "int32")
-    check(const(2**16, dtype='int64'), const(2**16, dtype='int64'), 32, "int64")
+    check(const(2, dtype="int64"), const(2, dtype="int64"), 32, "int32")
+    check(const(2 ** 16, dtype="int64"), const(2 ** 16, dtype="int64"), 32, "int64")
     # i32 -> i16
     check(2, 2, 16, "int16")
-    check(2**10, 2**10, 16, "int32")
+    check(2 ** 10, 2 ** 10, 16, "int32")
 
     # symbolic shape
-    check(te.size_var(name='m', dtype='int32'), te.size_var(name='n', dtype='int32'), 32, "int32")
-    check(te.size_var(name='m', dtype='int64'), te.size_var(name='n', dtype='int64'), 32, "int64")
+    check(te.size_var(name="m", dtype="int32"), te.size_var(name="n", dtype="int32"), 32, "int32")
+    check(te.size_var(name="m", dtype="int64"), te.size_var(name="n", dtype="int64"), 32, "int64")
 
 
 def test_thread_axis():
     def check(m, n, target_bits, target_dtype):
         ib = tvm.tir.ir_builder.create()
-        Ab = tvm.tir.decl_buffer((m, n), name='A')
+        Ab = tvm.tir.decl_buffer((m, n), name="A")
         A = ib.buffer_ptr(Ab)
-        Bb = tvm.tir.decl_buffer((m, n), name='B')
+        Bb = tvm.tir.decl_buffer((m, n), name="B")
         B = ib.buffer_ptr(Bb)
         bx = te.thread_axis("blockIdx.x")
         tx = te.thread_axis("threadIdx.x")
@@ -98,85 +97,81 @@ def test_thread_axis():
         assert stmt.body.node.var.dtype == target_dtype
 
     # i32 -> i32
-    check(2, 32,
-          target_bits=32, target_dtype='int32')
-    check(2**30, 32,  # i32 + i32 is not promoted to i64 even in the case of overflow
-          target_bits=32, target_dtype='int32')
+    check(2, 32, target_bits=32, target_dtype="int32")
+    check(
+        2 ** 30,
+        32,  # i32 + i32 is not promoted to i64 even in the case of overflow
+        target_bits=32,
+        target_dtype="int32",
+    )
     # i64 -> i32
-    check(const(2, dtype='int64'),
-          const(32, dtype='int64'),
-          target_bits=32, target_dtype='int32')
-    check(const(2**30, dtype='int64'),
-          const(32, dtype='int64'),
-          target_bits=32, target_dtype='int64')
+    check(const(2, dtype="int64"), const(32, dtype="int64"), target_bits=32, target_dtype="int32")
+    check(
+        const(2 ** 30, dtype="int64"),
+        const(32, dtype="int64"),
+        target_bits=32,
+        target_dtype="int64",
+    )
     # i32 -> i16
-    check(2, 32,
-          target_bits=16, target_dtype='int16')
-    check(2**14, 32,
-          target_bits=16, target_dtype='int32')
+    check(2, 32, target_bits=16, target_dtype="int16")
+    check(2 ** 14, 32, target_bits=16, target_dtype="int32")
 
 
 def test_multilanes():
     def check(m, lanes, target_bits, target_dtype):
         ib = tvm.tir.ir_builder.create()
-        Ab = tvm.tir.decl_buffer((m,), dtype='float32x{}'.format(lanes), name='A')
+        Ab = tvm.tir.decl_buffer((m,), dtype="float32x{}".format(lanes), name="A")
         A = ib.buffer_ptr(Ab)
-        Bb = tvm.tir.decl_buffer((m,), dtype='float32x{}'.format(lanes), name='B')
+        Bb = tvm.tir.decl_buffer((m,), dtype="float32x{}".format(lanes), name="B")
         B = ib.buffer_ptr(Bb)
-        with ib.for_range(0, m, name='i', dtype=m.dtype) as i:
+        with ib.for_range(0, m, name="i", dtype=m.dtype) as i:
             B[i] = A[i] + 1
         stmt = ib.get()
         stmt = lower_stmt([Ab, Bb], stmt, target_bits)
         assert stmt.loop_var.dtype == target_dtype
 
     # i32 -> i32
-    check(const(2 ** 10, dtype='int32'), 2,
-          target_bits=32, target_dtype='int32')
-    check(const(2 ** 32, dtype='int32'), 2,
-          target_bits=32, target_dtype='int32')
+    check(const(2 ** 10, dtype="int32"), 2, target_bits=32, target_dtype="int32")
+    check(const(2 ** 32, dtype="int32"), 2, target_bits=32, target_dtype="int32")
     # i64 -> i32
-    check(const(2 ** 10, dtype='int64'), 2,
-          target_bits=32, target_dtype='int32')
-    check(const(2 ** 32, dtype='int64'), 2,
-          target_bits=32, target_dtype='int64')
+    check(const(2 ** 10, dtype="int64"), 2, target_bits=32, target_dtype="int32")
+    check(const(2 ** 32, dtype="int64"), 2, target_bits=32, target_dtype="int64")
     # i32 -> i16
-    check(const(2 ** 10, dtype='int32'), 2,
-          target_bits=16, target_dtype='int16')
-    check(const(2 ** 16, dtype='int32'), 2,
-          target_bits=16, target_dtype='int32')
+    check(const(2 ** 10, dtype="int32"), 2, target_bits=16, target_dtype="int16")
+    check(const(2 ** 16, dtype="int32"), 2, target_bits=16, target_dtype="int32")
 
 
 def test_reduce():
     def check(m, target_bits, target_dtype):
-        A = te.placeholder((m,), name='A', dtype='float32')
+        A = te.placeholder((m,), name="A", dtype="float32")
         k = te.reduce_axis((0, m), "k")
-        B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name='B')
+        B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name="B")
         s = te.create_schedule(B.op)
         stmt = lower_sch(s, [A, B], target_bits)
         assert stmt[1].loop_var.dtype == target_dtype
 
     # i32 -> i32
-    check(const(64, dtype='int32'), 32, 'int32')
+    check(const(64, dtype="int32"), 32, "int32")
     # i64 -> i32
-    check(const(64, dtype='int64'), 32, 'int32')
+    check(const(64, dtype="int64"), 32, "int32")
     # i32 -> i16
-    check(const(64, dtype='int32'), 16, 'int16')
-    check(const(2**16, dtype='int32'), 16, 'int32')
+    check(const(64, dtype="int32"), 16, "int16")
+    check(const(2 ** 16, dtype="int32"), 16, "int32")
     # symbolic
-    check(te.var('n', dtype='int32'), 32, 'int32')
-    check(te.var('n', dtype='int64'), 32, 'int64')
+    check(te.var("n", dtype="int32"), 32, "int32")
+    check(te.var("n", dtype="int64"), 32, "int64")
 
 
 def test_slice():
     def check(m, n, target_bits, target_dtype):
         # The index may overflow in B, while not in A
         ib = tvm.tir.ir_builder.create()
-        Ab = tvm.tir.decl_buffer((m, n), name='A')
+        Ab = tvm.tir.decl_buffer((m, n), name="A")
         A = ib.buffer_ptr(Ab)
-        Bb = tvm.tir.decl_buffer((m, n * 2), name='B')
+        Bb = tvm.tir.decl_buffer((m, n * 2), name="B")
         B = ib.buffer_ptr(Bb)
-        with ib.for_range(0, m, name='i') as i:
-            with ib.for_range(0, n, name='j') as j:
+        with ib.for_range(0, m, name="i") as i:
+            with ib.for_range(0, n, name="j") as j:
                 A[i * n + j] = B[i * 2 * n + 2 * j] + 1
         stmt = ib.get()
         stmt = lower_stmt([Ab, Bb], stmt, target_bits)
@@ -184,18 +179,19 @@ def test_slice():
         assert stmt.body.loop_var.dtype == target_dtype
 
     # The maximum index is (2**15 * 2**15 - 1) * 2 <= 2**31 - 1
-    check(const(2**15, 'int64'), const(2**15, 'int64'),
-          target_bits=32, target_dtype='int32')
+    check(const(2 ** 15, "int64"), const(2 ** 15, "int64"), target_bits=32, target_dtype="int32")
     # The maximum index is (2**15 * 2**15 - 1 + 2**15) * 2 > 2**31 - 1
-    check(const(2**15, 'int64'), const((2**15 + 1), 'int64'),
-          target_bits=32, target_dtype='int64')
+    check(
+        const(2 ** 15, "int64"), const((2 ** 15 + 1), "int64"), target_bits=32, target_dtype="int64"
+    )
 
 
 def test_relay_basic():
     engine = relay.backend.compile_engine.get()
+
     def check(shapex, shapey, target_bits, target_dtype):
-        x = relay.var('x', shape=shapex)
-        y = relay.var('y', shape=shapey)
+        x = relay.var("x", shape=shapex)
+        y = relay.var("y", shape=shapey)
         z = relay.add(x, y)
         func = relay.Function([x, y], z)
         mod = tvm.IRModule.from_expr(func)
@@ -208,18 +204,32 @@ def test_relay_basic():
         if len(shapex) > 1 or len(shapey) > 1:
             assert stmt.body.loop_var.dtype == target_dtype
 
-    check((const(2**16, 'int64'), const(2**15 + 1, 'int64')), (1, const(2**15 + 1, 'int64')),
-          target_bits=32, target_dtype="int64")
-    check((const(2**16, 'int64'), const(2**15, 'int64')), (1, const(2**15, 'int64')),
-          target_bits=32, target_dtype="int32")
-    check((const(2**31, 'int64'),), (const(2**31, 'int64'),),
-          target_bits=32, target_dtype="int32")
-    check((const(2**31 + 1, 'int64'),), (const(2**31 + 1, 'int64'),),
-          target_bits=32, target_dtype="int64")
+    check(
+        (const(2 ** 16, "int64"), const(2 ** 15 + 1, "int64")),
+        (1, const(2 ** 15 + 1, "int64")),
+        target_bits=32,
+        target_dtype="int64",
+    )
+    check(
+        (const(2 ** 16, "int64"), const(2 ** 15, "int64")),
+        (1, const(2 ** 15, "int64")),
+        target_bits=32,
+        target_dtype="int32",
+    )
+    check(
+        (const(2 ** 31, "int64"),), (const(2 ** 31, "int64"),), target_bits=32, target_dtype="int32"
+    )
+    check(
+        (const(2 ** 31 + 1, "int64"),),
+        (const(2 ** 31 + 1, "int64"),),
+        target_bits=32,
+        target_dtype="int64",
+    )
 
 
 def test_relay_take():
     engine = relay.backend.compile_engine.get()
+
     def check(shape, index, target_bits, target_dtype):
         x = relay.var("x", shape=shape)
         y = relay.op.take(x, indices=index)
@@ -230,10 +240,18 @@ def test_relay_take():
         stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32)
         assert stmt.value.index.dtype == target_dtype
 
-    check((const(2**16, 'int64'), const(2**15 + 1, 'int64')), relay.const(0, dtype="int64"),
-          target_bits=32, target_dtype="int32")
-    check((const(2**16, 'int64'), const(2**15 + 1, 'int64')), relay.const(2**31, dtype="int64"),
-          target_bits=32, target_dtype="int64")
+    check(
+        (const(2 ** 16, "int64"), const(2 ** 15 + 1, "int64")),
+        relay.const(0, dtype="int64"),
+        target_bits=32,
+        target_dtype="int32",
+    )
+    check(
+        (const(2 ** 16, "int64"), const(2 ** 15 + 1, "int64")),
+        relay.const(2 ** 31, dtype="int64"),
+        target_bits=32,
+        target_dtype="int64",
+    )
 
 
 if __name__ == "__main__":
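
For context, the narrowing checked above can be reproduced on a hand-built ir_builder loop nest. A minimal sketch; the int64 shape constants and buffer names are illustrative:

    import tvm
    from tvm import te

    # Shapes are int64 constants, but the bounds fit comfortably in 32 bits, so
    # NarrowDataType(32) is expected to narrow the loop variables to int32.
    m = tvm.tir.const(2, dtype="int64")
    n = tvm.tir.const(2, dtype="int64")
    ib = tvm.tir.ir_builder.create()
    Ab = tvm.tir.decl_buffer((m, n), name="A")
    A = ib.buffer_ptr(Ab)
    Bb = tvm.tir.decl_buffer((m, n), name="B")
    B = ib.buffer_ptr(Bb)
    with ib.for_range(0, m, name="i", dtype=m.dtype) as i:
        with ib.for_range(0, n, name="j", dtype=n.dtype) as j:
            B[i * n + j] = A[i * n + j] + 1

    func = tvm.tir.PrimFunc([Ab, Bb], ib.get())
    func = tvm.tir.transform.NarrowDataType(32)(tvm.IRModule.from_expr(func))["main"]
    print(func.body.loop_var.dtype)  # expected to be "int32" for these small bounds
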
index 977f50e..c513640 100644 (file)
@@ -22,23 +22,21 @@ def test_prim_func_pass():
     @tvm.tir.transform.prim_func_pass(opt_level=1)
     class TestReplaceFunc:
         """Simple test function to replace one argument to another."""
+
         def __init__(self, new_func):
             self.new_func = new_func
 
         def transform_function(self, func, mod, ctx):
             return self.new_func
 
-    x = te.var('x')
-    y = te.var('y')
+    x = te.var("x")
+    y = te.var("y")
     b = tvm.tir.decl_buffer((x,), "float32")
-    stmt = tvm.tir.LetStmt(
-        x, 10, tvm.tir.Evaluate(x + 1));
+    stmt = tvm.tir.LetStmt(x, 10, tvm.tir.Evaluate(x + 1))
 
-    func = tvm.tir.PrimFunc(
-        [x, y, b], stmt)
+    func = tvm.tir.PrimFunc([x, y, b], stmt)
 
-    new_func = tvm.tir.PrimFunc(
-        [x, y, b], tvm.tir.Evaluate(0))
+    new_func = tvm.tir.PrimFunc([x, y, b], tvm.tir.Evaluate(0))
 
     mod = tvm.IRModule({"main": func})
     mod = TestReplaceFunc(new_func)(mod)
@@ -52,19 +50,18 @@ def test_cow_pass():
         return f
 
     pidentity = tvm.tir.transform.Apply(fapply)
-    x = te.var('x')
-    func = tvm.tir.PrimFunc(
-        [x], tvm.tir.Evaluate(x)).with_attr("target_bits", 32)
+    x = te.var("x")
+    func = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x)).with_attr("target_bits", 32)
     func_hash = func.__hash__()
     mod = tvm.IRModule({"main": func})
     del func
     # copy on write
     mod_hash = mod.__hash__()
-    mod = tvm.transform.Sequential(
-        [pidentity, tvm.tir.transform.NarrowDataType(32)])(mod._move())
+    mod = tvm.transform.Sequential([pidentity, tvm.tir.transform.NarrowDataType(32)])(mod._move())
     assert mod_hash == mod.__hash__()
     assert func_hash == mod["main"].__hash__()
 
+
 if __name__ == "__main__":
     test_cow_pass()
     test_prim_func_pass()
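
As a reminder of what the decorated class form above gives you, here is a minimal sketch of a user-defined PrimFunc pass; the names are illustrative:

    import tvm
    from tvm import te

    @tvm.tir.transform.prim_func_pass(opt_level=1)
    class SwapBody:
        """Toy pass that replaces every PrimFunc with a fixed one."""

        def __init__(self, new_func):
            self.new_func = new_func

        def transform_function(self, func, mod, ctx):
            return self.new_func

    x = te.var("x")
    old_func = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x + 1))
    new_func = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(0))

    mod = tvm.IRModule({"main": old_func})
    mod = SwapBody(new_func)(mod)
    print(mod["main"].body)  # the Evaluate(0) body taken from new_func
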
index c58b8b4..2edb8cf 100644 (file)
 import tvm
 from tvm import te
 
+
 def nop():
     return tvm.tir.Evaluate(0)
 
+
 def test_remove_no_op():
-    i = te.var('i')
-    j = te.var('j')
-    k = te.var('k')
-    m = te.var('m')
-    n = te.var('n')
-    dtype = 'int64'
-    Ab = tvm.tir.decl_buffer((n, ), dtype)
+    i = te.var("i")
+    j = te.var("j")
+    k = te.var("k")
+    m = te.var("m")
+    n = te.var("n")
+    dtype = "int64"
+    Ab = tvm.tir.decl_buffer((n,), dtype)
     stmt = tvm.tir.For(
-        i, 0, 4, 0, 0,
+        i,
+        0,
+        4,
+        0,
+        0,
         tvm.tir.For(
-            j, 0, n, 0, 0,
+            j,
+            0,
+            n,
+            0,
+            0,
             tvm.tir.For(
-                k, 0, m, 0, 0,
-                tvm.tir.IfThenElse(
-                    (i*m+j+k < n), tvm.tir.Evaluate(m), tvm.tir.Evaluate(n)))))
+                k,
+                0,
+                m,
+                0,
+                0,
+                tvm.tir.IfThenElse((i * m + j + k < n), tvm.tir.Evaluate(m), tvm.tir.Evaluate(n)),
+            ),
+        ),
+    )
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt))
     ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body
 
-    assert(isinstance(ret, tvm.tir.Evaluate))
-    store = tvm.tir.Store(Ab.data,
-                           tvm.tir.Load(dtype, Ab.data, i) + 1,
-                           i + 1)
+    assert isinstance(ret, tvm.tir.Evaluate)
+    store = tvm.tir.Store(Ab.data, tvm.tir.Load(dtype, Ab.data, i) + 1, i + 1)
     stmt2 = tvm.tir.SeqStmt([nop(), tvm.tir.SeqStmt([store, nop()])])
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt2))
     ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body
-    assert(ret == store)
+    assert ret == store
 
     # remove zero extent loop
     stmt3 = tvm.tir.For(i, 0, 0, 0, 0, store)
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt3))
     ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body
-    assert(isinstance(ret, tvm.tir.Evaluate))
+    assert isinstance(ret, tvm.tir.Evaluate)
 
 
 if __name__ == "__main__":
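
For context, RemoveNoOp can be seen on an even smaller hand-built loop than the one above. A minimal sketch; the loop extent is arbitrary:

    import tvm
    from tvm import te

    i = te.var("i")
    n = te.var("n")
    Ab = tvm.tir.decl_buffer((n,), "int64")

    # The loop body only evaluates an expression and never stores anything,
    # so the whole loop is a no-op and collapses to a single Evaluate node.
    loop = tvm.tir.For(i, 0, 16, 0, 0, tvm.tir.Evaluate(i))
    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], loop))
    ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body
    assert isinstance(ret, tvm.tir.Evaluate)
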
index 9f1104d..1c5f2f4 100644 (file)
@@ -22,22 +22,18 @@ def test_rewrite_Select():
     ib = tvm.tir.ir_builder.create()
     A = ib.allocate("float32", 100, name="A", scope="global")
     i = te.var("i")
-    y = tvm.tir.Select(i > 1, A[i-1], 1.0)
+    y = tvm.tir.Select(i > 1, A[i - 1], 1.0)
 
-    mod = tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([i], tvm.tir.Evaluate(y)))
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i], tvm.tir.Evaluate(y)))
     yy = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value
 
-    z = tvm.tir.Select(
-        tvm.tir.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1)
-    mod = tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([i], tvm.tir.Evaluate(z)))
+    z = tvm.tir.Select(tvm.tir.Select(i > 1, A[i - 1], 1.0) > 0.0, A[i], 0.1)
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i], tvm.tir.Evaluate(z)))
     zz = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value
 
     a = tvm.tir.Select(tvm.tir.floordiv(i, 4) > 10, y, z)
 
-    mod = tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([i], tvm.tir.Evaluate(a)))
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i], tvm.tir.Evaluate(a)))
     aa = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value
     builtin_if_then_else = tvm.ir.Op.get("tir.if_then_else")
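
For context, the unsafe-select rewrite above reduces to a few lines; a minimal sketch mirroring the first case:

    import tvm
    from tvm import te

    ib = tvm.tir.ir_builder.create()
    A = ib.allocate("float32", 100, name="A", scope="global")
    i = te.var("i")

    # Evaluating both branches would read A[i - 1] even when i <= 1, so the
    # select is unsafe and gets rewritten into a tir.if_then_else call.
    unsafe = tvm.tir.Select(i > 1, A[i - 1], 1.0)
    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i], tvm.tir.Evaluate(unsafe)))
    rewritten = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value
    print(rewritten)  # expected to be a call to tir.if_then_else
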
 
index 48d0849..f298288 100644 (file)
@@ -17,6 +17,7 @@
 import tvm
 from tvm import te
 
+
 def test_stmt_simplify():
     ib = tvm.tir.ir_builder.create()
     A = ib.pointer("float32", name="A")
@@ -27,8 +28,7 @@ def test_stmt_simplify():
             A[i] = C[i]
 
     body = tvm.tir.LetStmt(n, 10, ib.get())
-    mod = tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([A, C, n], body))
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C, n], body))
     body = tvm.tir.transform.Simplify()(mod)["main"].body
     assert isinstance(body.body, tvm.tir.Store)
 
@@ -46,8 +46,7 @@ def test_thread_extent_simplify():
     with ib.if_scope(tx + ty < 12):
         A[tx] = C[tx + ty]
     body = tvm.tir.LetStmt(n, 10, ib.get())
-    mod = tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([A, C, n], body))
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C, n], body))
     body = tvm.tir.transform.Simplify()(mod)["main"].body
     assert isinstance(body.body.body.body, tvm.tir.Store)
 
@@ -65,37 +64,38 @@ def test_if_likely():
         with ib.if_scope(ib.likely(tx * 32 + ty < n)):
             A[tx] = C[tx * 32 + ty]
     body = ib.get()
-    mod = tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([A, C, n], body))
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C, n], body))
     body = tvm.tir.transform.Simplify()(mod)["main"].body
     assert isinstance(body.body.body, tvm.tir.IfThenElse)
     assert not isinstance(body.body.body.then_case, tvm.tir.IfThenElse)
 
 
 def test_basic_likely_elimination():
-    n = te.size_var('n')
+    n = te.size_var("n")
     X = te.placeholder(shape=(n,), name="x")
     W = te.placeholder(shape=(n + 1,), dtype="int32", name="w")
 
     def f(i):
         start = W[i]
-        extent = W[i+1] - W[i]
+        extent = W[i + 1] - W[i]
         rv = te.reduce_axis((0, extent))
         return te.sum(X[rv + start], axis=rv)
+
     Y = te.compute(X.shape, f, name="y")
     s = te.create_schedule([Y.op])
     stmt = tvm.lower(s, [X, W, Y], simple_mode=True)
-    assert('if' not in str(stmt))
+    assert "if" not in str(stmt)
+
 
 def test_complex_likely_elimination():
     def cumsum(X):
         """
         Y[i] = sum(X[:i])
         """
-        (m, ) = X.shape
-        s_state = te.placeholder((m + 1, ), dtype="int32", name="state")
-        s_init = te.compute((1, ), lambda _: tvm.tir.const(0, "int32"))
-        s_update = te.compute((m + 1, ), lambda l: s_state[l - 1] + X[l - 1])
+        (m,) = X.shape
+        s_state = te.placeholder((m + 1,), dtype="int32", name="state")
+        s_init = te.compute((1,), lambda _: tvm.tir.const(0, "int32"))
+        s_update = te.compute((m + 1,), lambda l: s_state[l - 1] + X[l - 1])
         return tvm.te.scan(s_init, s_update, s_state, inputs=[X], name="cumsum")
 
     def sparse_lengths_sum(data, indices, lengths):
@@ -112,8 +112,13 @@ def test_complex_likely_elimination():
 
         return te.compute(oshape, sls)
 
-    m, n, d, i, l = te.size_var('m'), te.size_var('n'), te.size_var('d'),\
-                    te.size_var('i'), te.size_var('l')
+    m, n, d, i, l = (
+        te.size_var("m"),
+        te.size_var("n"),
+        te.size_var("d"),
+        te.size_var("i"),
+        te.size_var("l"),
+    )
     data_ph = te.placeholder((m, d * 32), name="data")
     indices_ph = te.placeholder((i,), name="indices", dtype="int32")
     lengths_ph = te.placeholder((n,), name="lengths", dtype="int32")
@@ -125,7 +130,8 @@ def test_complex_likely_elimination():
     s[Y].reorder(n, do, gg, di)
     s[Y].vectorize(di)
     stmt = tvm.lower(s, [data_ph, indices_ph, lengths_ph, Y], simple_mode=True)
-    assert('if' not in str(stmt))
+    assert "if" not in str(stmt)
+
 
 if __name__ == "__main__":
     test_stmt_simplify()
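
The simplification asserted in test_stmt_simplify above hinges on the LetStmt binding. A minimal sketch; the guard condition here is illustrative:

    import tvm
    from tvm import te

    ib = tvm.tir.ir_builder.create()
    A = ib.pointer("float32", name="A")
    C = ib.pointer("float32", name="C")
    n = te.size_var("n")
    with ib.for_range(0, n, name="i") as i:
        with ib.if_scope(i < 12):
            A[i] = C[i]

    # Binding n = 10 lets Simplify prove the guard holds on every iteration,
    # drop the branch, and inline the Let, leaving the Store under the loop.
    body = tvm.tir.LetStmt(n, 10, ib.get())
    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C, n], body))
    body = tvm.tir.transform.Simplify()(mod)["main"].body
    assert isinstance(body.body, tvm.tir.Store)
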
index b0acc6c..2d1fea0 100644 (file)
 import tvm
 from tvm import te
 
+
 def test_flatten2():
-    m = te.size_var('m')
-    l = te.size_var('l')
-    A = te.placeholder((m, l), name='A')
-    A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
-    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
+    m = te.size_var("m")
+    l = te.size_var("l")
+    A = te.placeholder((m, l), name="A")
+    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
+    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2")
 
     s = te.create_schedule(A2.op)
     xo, xi = s[A2].split(A2.op.axis[0], 8)
@@ -30,30 +31,28 @@ def test_flatten2():
     bounds = tvm.te.schedule.InferBound(s)
     assert isinstance(bounds, tvm.container.Map)
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-    Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
-    A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
+    Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="A")
+    A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name="A2")
 
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc(
-        [Ab, A2b], stmt, {A: Ab, A2: A2b})
+    func = tvm.te.schedule.SchedulePostProcToPrimFunc([Ab, A2b], stmt, {A: Ab, A2: A2b})
     mod = tvm.IRModule.from_expr(func)
     mod = tvm.tir.transform.StorageFlatten(64)(mod)
 
 
 def test_flatten_prefetch():
-    A = te.placeholder((25, 100, 4), name = 'A')
-    _A= tvm.tir.decl_buffer(A.shape, A.dtype, name = 'A');
-    i = te.size_var('i')
-    j = te.size_var('j')
+    A = te.placeholder((25, 100, 4), name="A")
+    _A = tvm.tir.decl_buffer(A.shape, A.dtype, name="A")
+    i = te.size_var("i")
+    j = te.size_var("j")
     region = [tvm.ir.Range.from_min_extent(i[0], i[1]) for i in [(i, 2), (j, 8), (0, 4)]]
     stmt = tvm.tir.Prefetch(_A, region)
 
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc(
-        [_A], stmt, {A: _A})
+    func = tvm.te.schedule.SchedulePostProcToPrimFunc([_A], stmt, {A: _A})
 
     mod = tvm.IRModule.from_expr(func)
-    mod = tvm.transform.Sequential([
-        tvm.tir.transform.StorageFlatten(64),
-        tvm.tir.transform.Simplify()])(mod)
+    mod = tvm.transform.Sequential(
+        [tvm.tir.transform.StorageFlatten(64), tvm.tir.transform.Simplify()]
+    )(mod)
     stmt = mod["main"].body
     assert stmt.extent.value == 2
     assert isinstance(stmt.body, tvm.tir.For)
@@ -63,9 +62,9 @@ def test_flatten_prefetch():
 def test_flatten_storage_align():
     m = 8
     l = 16
-    A = te.placeholder((m, l), name='A')
-    A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
-    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
+    A = te.placeholder((m, l), name="A")
+    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
+    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2")
 
     s = te.create_schedule(A2.op)
     s[A1].storage_align(A1.op.axis[0], 2, 1)
@@ -75,16 +74,16 @@ def test_flatten_storage_align():
 
     func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
     mod = tvm.IRModule.from_expr(func)
-    mod = tvm.transform.Sequential([
-        tvm.tir.transform.StorageFlatten(64),
-        tvm.tir.transform.Simplify()])(mod)
+    mod = tvm.transform.Sequential(
+        [tvm.tir.transform.StorageFlatten(64), tvm.tir.transform.Simplify()]
+    )(mod)
 
     stmt = mod["main"].body
-    assert(stmt.body.extents[0].value == 17 * 8)
+    assert stmt.body.extents[0].value == 17 * 8
 
 
 def test_flatten_double_buffer():
-    dtype = 'int64'
+    dtype = "int64"
     n = 100
     m = 4
     tx = te.thread_axis("threadIdx.x")
@@ -103,33 +102,34 @@ def test_flatten_double_buffer():
 
     stmt = ib.get()
 
-    mod = tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([A, C], stmt))
-
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C], stmt))
 
-    with tvm.transform.PassContext(config={
-        "tir.InjectDoubleBuffer" : {"split_loop" : 2}
-    }):
-        mod = tvm.transform.Sequential([
-            tvm.tir.transform.StorageFlatten(64),
-            tvm.tir.transform.InjectDoubleBuffer(),
-            tvm.tir.transform.Simplify()])(mod)
+    with tvm.transform.PassContext(config={"tir.InjectDoubleBuffer": {"split_loop": 2}}):
+        mod = tvm.transform.Sequential(
+            [
+                tvm.tir.transform.StorageFlatten(64),
+                tvm.tir.transform.InjectDoubleBuffer(),
+                tvm.tir.transform.Simplify(),
+            ]
+        )(mod)
 
     stmt = mod["main"].body
     assert isinstance(stmt.body.body, tvm.tir.Allocate)
     assert stmt.body.body.extents[0].value == 2
 
-    mod = tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([A, C], stmt).with_attr("global_symbol", "db"))
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C], stmt).with_attr("global_symbol", "db"))
     f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
 
     count = [0]
+
     def count_sync(op):
-        if isinstance(op, tvm.tir.Call) and  op.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")):
+        if isinstance(op, tvm.tir.Call) and op.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")):
             count[0] += 1
+
     tvm.tir.stmt_functor.post_order_visit(f.body, count_sync)
     assert count[0] == 4
 
+
 if __name__ == "__main__":
     test_flatten2()
     test_flatten_storage_align()
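
For context, the flattening pipeline above can be run end to end on the same toy compute. A minimal sketch without the extra split/compute_at scheduling:

    import tvm
    from tvm import te

    m = te.size_var("m")
    l = te.size_var("l")
    A = te.placeholder((m, l), name="A")
    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2")

    s = te.create_schedule(A2.op)
    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)

    # Bind explicit buffers for the I/O tensors, then flatten the 2-D accesses
    # into 1-D loads/stores over those buffers.
    Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="A")
    A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name="A2")
    func = tvm.te.schedule.SchedulePostProcToPrimFunc([Ab, A2b], stmt, {A: Ab, A2: A2b})
    mod = tvm.tir.transform.StorageFlatten(64)(tvm.IRModule.from_expr(func))
    print(mod["main"].body)
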
index 46ba687..cc2b427 100644 (file)
 import tvm
 from tvm import te
 
+
 def test_storage_share():
-    m = te.var('m')
-    l = te.var('l')
-    A = te.placeholder((m, l), name='A')
+    m = te.var("m")
+    l = te.var("l")
+    A = te.placeholder((m, l), name="A")
     num_stage = 5
     B = A
     for t in range(num_stage):
-        B = te.compute((m, l), lambda i, j: B[i, j] + (t+1), name='A%d' % t)
+        B = te.compute((m, l), lambda i, j: B[i, j] + (t + 1), name="A%d" % t)
 
     s = te.create_schedule(B.op)
     bounds = tvm.te.schedule.InferBound(s)
@@ -42,21 +43,23 @@ def test_storage_share():
     # verify only have one allocations.
     # verify inplace folding works
     num_alloc = [0]
+
     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
+
     tvm.tir.stmt_functor.post_order_visit(stmt, verify)
     assert num_alloc[0] == 1
 
+
 def register_mem(scope_tb, max_bits):
-    #Register mem
+    # Register mem
     @tvm.register_func("tvm.info.mem.%s" % scope_tb)
     def mem_info_inp_buffer():
-        return tvm.ir.make_node("MemoryInfo",
-                        unit_bits= 16,
-                        max_simd_bits=32,
-                        max_num_bits=max_bits,
-                        head_address=None)
+        return tvm.ir.make_node(
+            "MemoryInfo", unit_bits=16, max_simd_bits=32, max_num_bits=max_bits, head_address=None
+        )
+
 
 def test_alloc_seq():
     scope_tb = "local.L0A"
@@ -80,35 +83,38 @@ def test_alloc_seq():
     body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
 
     num_alloc = [0]
+
     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
             assert n.extents[0].value == 200
+
     tvm.tir.stmt_functor.post_order_visit(body, verify)
     assert num_alloc[0] == 1
 
+
 def test_alloc_different_dtypes():
     def stmt_generater(dtype_list, length):
         ib = tvm.tir.ir_builder.create()
         base_dtype = dtype_list[0]
-        global_a = te.placeholder((length,), name = "global_a", dtype = base_dtype)
+        global_a = te.placeholder((length,), name="global_a", dtype=base_dtype)
         assert len(dtype_list) == 4
         with ib.for_range(0, length, name="j") as j:
             dtype = dtype_list[0]
             A = ib.allocate(dtype, length, name="A", scope="local.L0A")
-            A[j] = tvm.tir.const(1, dtype = dtype)
+            A[j] = tvm.tir.const(1, dtype=dtype)
         with ib.for_range(0, length, name="j") as j:
             dtype = dtype_list[1]
             B = ib.allocate(dtype, length, name="B", scope="local.L0A")
-            B[j] = tvm.tir.const(1, dtype = dtype)
+            B[j] = tvm.tir.const(1, dtype=dtype)
         with ib.for_range(0, length, name="j") as j:
             dtype = dtype_list[2]
             C = ib.allocate(dtype, length, name="C", scope="local.L0A")
-            C[j] = tvm.tir.const(1, dtype = dtype)
+            C[j] = tvm.tir.const(1, dtype=dtype)
         with ib.for_range(0, length, name="j") as j:
             dtype = dtype_list[3]
             D = ib.allocate(dtype, length, name="D", scope="local.L0A")
-            D[j] = tvm.tir.const(1, dtype = dtype)
+            D[j] = tvm.tir.const(1, dtype=dtype)
         with ib.for_range(0, length, name="j") as j:
             dtype = "int8"
             E = ib.allocate(dtype, length, name="E", scope="local.L0A")
@@ -157,11 +163,11 @@ def test_alloc_different_dtypes():
 
 def test_inplace_rule():
     m = 10
-    A = te.placeholder((m,), name='A')
-    A0 = te.compute((m,), lambda i: A[i], name='A0')
-    A1 = te.compute((m,), lambda i: A[i] + 1, name='A1')
-    AA =  te.compute((m,), lambda i: A0[i] + A1[i] + A1[0], name='AA')
-    B = te.compute((m,), lambda i: AA[i] + 1, name='B')
+    A = te.placeholder((m,), name="A")
+    A0 = te.compute((m,), lambda i: A[i], name="A0")
+    A1 = te.compute((m,), lambda i: A[i] + 1, name="A1")
+    AA = te.compute((m,), lambda i: A0[i] + A1[i] + A1[0], name="AA")
+    B = te.compute((m,), lambda i: AA[i] + 1, name="B")
     s = te.create_schedule(B.op)
     bounds = tvm.te.schedule.InferBound(s)
     assert isinstance(bounds, tvm.container.Map)
@@ -178,21 +184,23 @@ def test_inplace_rule():
     # verify only have one allocations.
     # verify inplace folding works
     num_alloc = [0]
+
     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
+
     tvm.tir.stmt_functor.post_order_visit(stmt, verify)
     assert num_alloc[0] == 2
 
 
 def test_storage_combine():
     n = 8
-    A = te.placeholder((4,), name='A')
+    A = te.placeholder((4,), name="A")
     num_stage = 5
     B = A
     stages = []
     for t in range(num_stage):
-        B = te.compute((n, ), lambda i: B[i] + B[0] + (t+1), name='A%d' % t)
+        B = te.compute((n,), lambda i: B[i] + B[0] + (t + 1), name="A%d" % t)
         stages.append(B)
 
     s = te.create_schedule(B.op)
@@ -210,29 +218,31 @@ def test_storage_combine():
     stmt = mod["main"].body
 
     num_alloc = [0]
+
     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
-            assert (n.extents[0].value == 16)
+            assert n.extents[0].value == 16
+
     tvm.tir.stmt_functor.post_order_visit(stmt, verify)
     assert num_alloc[0] == 1
 
 
 def test_storage_share_gpu():
-    m = te.var('m')
-    A = [te.placeholder((m), name='A')]
+    m = te.var("m")
+    A = [te.placeholder((m), name="A")]
     num_stage = 5
     for t in range(num_stage):
-        A.append(te.compute((m,), lambda i: A[-1][i] + (t+1), name='A%d_s' % t))
-        A.append(te.compute((m,), lambda i: A[-1][i], name='A%d' % t))
+        A.append(te.compute((m,), lambda i: A[-1][i] + (t + 1), name="A%d_s" % t))
+        A.append(te.compute((m,), lambda i: A[-1][i], name="A%d" % t))
     s = te.create_schedule(A[-1].op)
     for t in range(num_stage):
-        x = A[2*t+2].op.axis[0]
-        bx, tx = s[A[2*t+2]].split(x, factor=32)
-        s[A[2*t+2]].bind(bx, te.thread_axis("blockIdx.x"))
-        s[A[2*t+2]].bind(tx, te.thread_axis("threadIdx.x"))
-        s[A[2*t+1]].compute_at(s[A[2*t+2]], tx)
-        s[A[2*t+1]].set_scope("shared")
+        x = A[2 * t + 2].op.axis[0]
+        bx, tx = s[A[2 * t + 2]].split(x, factor=32)
+        s[A[2 * t + 2]].bind(bx, te.thread_axis("blockIdx.x"))
+        s[A[2 * t + 2]].bind(tx, te.thread_axis("threadIdx.x"))
+        s[A[2 * t + 1]].compute_at(s[A[2 * t + 2]], tx)
+        s[A[2 * t + 1]].set_scope("shared")
 
     bounds = tvm.te.schedule.InferBound(s)
     assert isinstance(bounds, tvm.container.Map)
@@ -250,10 +260,12 @@ def test_storage_share_gpu():
         if isinstance(n, tvm.tir.AttrStmt):
             if n.attr_key == "storage_scope":
                 alloc_stats[n.value.value] += 1
+
     tvm.tir.stmt_functor.post_order_visit(stmt, verify)
     assert alloc_stats["global"] == 2
     assert alloc_stats["shared"] == num_stage
 
+
 def test_parallel_alloc():
     ib = tvm.tir.ir_builder.create()
     n = te.var("n")
@@ -266,14 +278,14 @@ def test_parallel_alloc():
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
     body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
 
-    assert (isinstance(body.body.body, tvm.tir.Allocate))
+    assert isinstance(body.body.body, tvm.tir.Allocate)
 
     ib = tvm.tir.ir_builder.create()
     n = te.var("n")
     with ib.for_range(0, n, name="t") as i:
         ib.scope_attr(
-            tvm.tir.const(1, "int32") , "pragma_scope",
-            tvm.tir.StringImm("parallel_launch_point"))
+            tvm.tir.const(1, "int32"), "pragma_scope", tvm.tir.StringImm("parallel_launch_point")
+        )
         with ib.for_range(0, n, name="i", for_type="parallel") as i:
             with ib.for_range(0, 10, name="j") as j:
                 A = ib.allocate("float32", n, name="A", scope="global")
@@ -283,19 +295,20 @@ def test_parallel_alloc():
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
     body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
 
-    assert(isinstance(body.body.body.body.body, tvm.tir.Allocate))
+    assert isinstance(body.body.body.body.body, tvm.tir.Allocate)
 
-def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024):
-    #Test Buffer
+
+def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024):
+    # Test Buffer
     register_mem(scope_tb, max_bits)
     m = 10
-    A = te.placeholder((m,), name='A')
-    C = te.placeholder((m,), name='C')
-    D = te.placeholder((m,), name='D')
-    A0 = te.compute((m,), lambda i: A[i] + C[i], name='A0')
-    A1 = te.compute((m,), lambda i: D[i] * D[i], name='A1')
-    A2 = te.compute((m,), lambda i: A0[i] + A1[i], name='A2')
-    B = te.compute((m,), lambda i: A2[i], name='B')
+    A = te.placeholder((m,), name="A")
+    C = te.placeholder((m,), name="C")
+    D = te.placeholder((m,), name="D")
+    A0 = te.compute((m,), lambda i: A[i] + C[i], name="A0")
+    A1 = te.compute((m,), lambda i: D[i] * D[i], name="A1")
+    A2 = te.compute((m,), lambda i: A0[i] + A1[i], name="A2")
+    B = te.compute((m,), lambda i: A2[i], name="B")
     s = te.create_schedule(B.op)
     A0L = s.cache_read(A0, scope_tb, [A2])
     A1L = s.cache_read(A1, scope_tb, [A2])
@@ -315,12 +328,15 @@ def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024):
     # verify only have one allocations.
     # verify inplace folding works
     num_alloc = [0]
+
     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
+
     tvm.tir.stmt_functor.post_order_visit(stmt, verify)
     assert num_alloc[0] == 2
 
+
 def test_exceed_mem():
     max_bits = 639
     # The critical max_num_bits is between 639 and 640
@@ -329,36 +345,37 @@ def test_exceed_mem():
         test_inplace_rule2("local_TEM", max_bits)
     except Exception as e:
         estr = str(e)
-        loc = estr.find('Allocation exceed bound of memory')
+        loc = estr.find("Allocation exceed bound of memory")
         assert loc != -1
 
+
 def test_inplace_rule3():
-    #Test Buffer
+    # Test Buffer
     scope_tb = "local_TB3"
-    max_bits=1024 * 1024 * 1024
+    max_bits = 1024 * 1024 * 1024
 
     register_mem(scope_tb, max_bits)
     m = 10
-    B0 = te.placeholder((m,), name='B0')
-    B1 = te.placeholder((m,), name='B1')
-    B2 = te.placeholder((m,), name='B2')
-    B3 = te.placeholder((m,), name='B3')
-    B4 = te.placeholder((m,), name='B4')
-    B5 = te.placeholder((m,), name='B5')
-
-    B6 = te.compute((m,), lambda i: B1[i] * B5[i], name='B6')
-    B7 = te.compute((m,), lambda i: B2[i] * B4[i], name='B7')
-    B8 = te.compute((m,), lambda i: B6[i] - B7[i], name='B8')
-
-    B9 = te.compute((m,), lambda i: B2[i] * B3[i], name='B9')
-    B10 = te.compute((m,), lambda i: B0[i] * B5[i], name='B10')
-    B11 = te.compute((m,), lambda i: B9[i] - B10[i], name='B11')
-
-    B12 = te.compute((m,), lambda i: B0[i] * B4[i], name='B12')
-    B13 = te.compute((m,), lambda i: B1[i] * B3[i], name='B13')
-    B14 = te.compute((m,), lambda i: B12[i] - B13[i], name='B14')
-
-    B = te.compute((m,), lambda i: B8[i] * B11[i] + B14[i], name='B')
+    B0 = te.placeholder((m,), name="B0")
+    B1 = te.placeholder((m,), name="B1")
+    B2 = te.placeholder((m,), name="B2")
+    B3 = te.placeholder((m,), name="B3")
+    B4 = te.placeholder((m,), name="B4")
+    B5 = te.placeholder((m,), name="B5")
+
+    B6 = te.compute((m,), lambda i: B1[i] * B5[i], name="B6")
+    B7 = te.compute((m,), lambda i: B2[i] * B4[i], name="B7")
+    B8 = te.compute((m,), lambda i: B6[i] - B7[i], name="B8")
+
+    B9 = te.compute((m,), lambda i: B2[i] * B3[i], name="B9")
+    B10 = te.compute((m,), lambda i: B0[i] * B5[i], name="B10")
+    B11 = te.compute((m,), lambda i: B9[i] - B10[i], name="B11")
+
+    B12 = te.compute((m,), lambda i: B0[i] * B4[i], name="B12")
+    B13 = te.compute((m,), lambda i: B1[i] * B3[i], name="B13")
+    B14 = te.compute((m,), lambda i: B12[i] - B13[i], name="B14")
+
+    B = te.compute((m,), lambda i: B8[i] * B11[i] + B14[i], name="B")
     s = te.create_schedule(B.op)
 
     B1L = s.cache_read(B1, scope_tb, [B6, B13])
@@ -393,8 +410,7 @@ def test_inplace_rule3():
     assert isinstance(bounds, tvm.container.Map)
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
 
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc(
-        [B0, B1, B2, B3, B4, B5, B], stmt, None)
+    func = tvm.te.schedule.SchedulePostProcToPrimFunc([B0, B1, B2, B3, B4, B5, B], stmt, None)
     mod = tvm.IRModule.from_expr(func)
     mod = tvm.tir.transform.StorageFlatten(64)(mod)
 
@@ -407,8 +423,10 @@ def test_inplace_rule3():
     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
             assert n.extents[0].value == 70
+
     tvm.tir.stmt_functor.post_order_visit(stmt, verify)
 
+
 def test_alloc_seq_type():
     ib = tvm.tir.ir_builder.create()
     n = te.var("n")
@@ -433,16 +451,19 @@ def test_alloc_seq_type():
     body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
 
     num_alloc = [0]
+
     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
             assert n.extents[0].value == 500
+
     tvm.tir.stmt_functor.post_order_visit(body, verify)
     assert num_alloc[0] == 1
 
+
 def test_alloc_seq_type2():
     scope_tb = "local.L0A2"
-    max_bits=1024 * 1024 * 1024
+    max_bits = 1024 * 1024 * 1024
 
     register_mem(scope_tb, max_bits)
 
@@ -465,10 +486,12 @@ def test_alloc_seq_type2():
     body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
 
     num_alloc = [0]
+
     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
             assert n.extents[0].value == 200
+
     tvm.tir.stmt_functor.post_order_visit(body, verify)
     assert num_alloc[0] == 1
 
@@ -502,16 +525,18 @@ def test_reuse_small_buffer():
         if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
             assert n.extents[0].value == 800
+
     tvm.tir.stmt_functor.post_order_visit(body, verify)
     assert num_alloc[0] == 1
 
+
 def test_replace_dataflow():
     shape = (255,)
-    A = te.placeholder(shape, name = "A")
-    B = te.compute(shape, lambda i: A[i] + A[i], name = "B")
-    C = te.compute(shape, lambda i: A[i] + B[i], name = "C")
-    D = te.compute(shape, lambda i: A[i] + C[i], name = "D")
-    E = te.compute(shape, lambda i: A[i] + D[i], name = "E")
+    A = te.placeholder(shape, name="A")
+    B = te.compute(shape, lambda i: A[i] + A[i], name="B")
+    C = te.compute(shape, lambda i: A[i] + B[i], name="C")
+    D = te.compute(shape, lambda i: A[i] + C[i], name="D")
+    E = te.compute(shape, lambda i: A[i] + D[i], name="E")
 
     s = te.create_schedule(E.op)
     s.cache_read(A, "local", [B, C, D, E])
@@ -523,7 +548,7 @@ def test_large_input():
     @te.hybrid.script
     def compute(a, b):
         n = 16384
-        c = output_tensor((n, n), 'int32')
+        c = output_tensor((n, n), "int32")
         for i in range(n):
             for j in range(n):
                 c[i, j] = a[i, j] - b[i, j]
@@ -531,15 +556,17 @@ def test_large_input():
 
     n = 16384
     shape = (n, n)
-    a = te.placeholder(shape, name='a', dtype='int32')
-    b = te.placeholder(shape, name='b', dtype='int32')
+    a = te.placeholder(shape, name="a", dtype="int32")
+    b = te.placeholder(shape, name="b", dtype="int32")
     c = te.compute(shape, lambda i, j: compute(a, b)[i, j])
     c = te.compute(shape, lambda i, j: 1 + c[i, j])
     s = te.create_schedule(c.op)
     stmt = tvm.lower(s, [a, b, c])["main"].body
+
     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
             assert n.extents[0].value == 268435456
+
     tvm.tir.stmt_functor.post_order_visit(stmt, verify)
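
The allocation counting used throughout this file boils down to a post-order visit after StorageRewrite. A minimal sketch over a chain of elementwise stages; the stage count and names are arbitrary:

    import tvm
    from tvm import te

    m = te.var("m")
    l = te.var("l")
    A = te.placeholder((m, l), name="A")
    B = A
    for t in range(5):
        B = te.compute((m, l), lambda i, j: B[i, j] + (t + 1), name="A%d" % t)

    s = te.create_schedule(B.op)
    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)

    Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="A")
    Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="B")
    func = tvm.te.schedule.SchedulePostProcToPrimFunc([Ab, Bb], stmt, {A: Ab, B: Bb})
    mod = tvm.IRModule.from_expr(func)
    mod = tvm.tir.transform.StorageFlatten(64)(mod)
    mod = tvm.tir.transform.Simplify()(mod)
    stmt = tvm.tir.transform.StorageRewrite()(mod)["main"].body

    # Count the allocations that survive rewriting; the intermediate stages are
    # expected to share a single buffer.
    num_alloc = [0]

    def count_alloc(op):
        if isinstance(op, tvm.tir.Allocate):
            num_alloc[0] += 1

    tvm.tir.stmt_functor.post_order_visit(stmt, count_alloc)
    print(num_alloc[0])
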
 
 
index e87302c..030c017 100644 (file)
@@ -21,12 +21,12 @@ import tvm.testing
 
 @tvm.testing.requires_cuda
 def test_thread_storage_sync():
-    m = te.size_var('m')
-    l = te.size_var('l')
-    A = te.placeholder((m, l), name='A')
+    m = te.size_var("m")
+    l = te.size_var("l")
+    A = te.placeholder((m, l), name="A")
 
-    A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
-    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
+    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
+    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2")
 
     s = te.create_schedule(A2.op)
     xo, xi = s[A2].split(A2.op.axis[0], factor=8)
@@ -44,16 +44,16 @@ def test_thread_storage_sync():
 
     cuda_target = tvm.target.Target("cuda")
 
-    mod = tvm.tir.transform.Apply(lambda f: f.with_attr({
-        "global_symbol": "test", "target": cuda_target}))(mod._move())
+    mod = tvm.tir.transform.Apply(
+        lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target})
+    )(mod._move())
 
     fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
     mod = tvm.IRModule.from_expr(fdevice)
     cuda_target = tvm.target.Target("cuda")
     f = tvm.tir.transform.ThreadSync("shared")(mod)["test_kernel0"]
     body_list = tvm.tir.stmt_list(f.body.body.body.body)
-    assert(body_list[1].value.op.same_as(
-        tvm.ir.Op.get("tir.tvm_storage_sync")))
+    assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))
 
 
 if __name__ == "__main__":
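The hunk above only reflows the Apply(... with_attr ...) call; as a hedged, self-contained sketch of that attribute-attaching pattern (the llvm target and the trivial compute are assumptions, not part of this test):

    import tvm
    from tvm import te

    A = te.placeholder((8,), name="A")
    B = te.compute((8,), lambda i: A[i] + 1, name="B")
    s = te.create_schedule(B.op)
    mod = tvm.lower(s, [A, B])

    # Attach function attributes before running device-related passes.
    target = tvm.target.Target("llvm")
    mod = tvm.tir.transform.Apply(
        lambda f: f.with_attr({"global_symbol": "main", "target": target})
    )(mod)
    print(mod["main"])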
index 6863994..57b7810 100644 (file)
@@ -21,9 +21,9 @@ import os
 
 def test_unroll_loop():
     ib = tvm.tir.ir_builder.create()
-    dtype = 'int64'
-    n = te.size_var('n')
-    Ab = tvm.tir.decl_buffer((n, ), dtype)
+    dtype = "int64"
+    n = te.size_var("n")
+    Ab = tvm.tir.decl_buffer((n,), dtype)
     Aptr = ib.buffer_ptr(Ab)
     # for i in 0 to n-1:
     with ib.for_range(n, n + 2, name="i") as i:
@@ -43,9 +43,9 @@ def test_unroll_loop():
         ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
         assert isinstance(ret, tvm.tir.For)
 
-    with tvm.transform.PassContext(config={
-            "tir.UnrollLoop": {"auto_max_step": 16, "explicit_unroll": False}
-    }):
+    with tvm.transform.PassContext(
+        config={"tir.UnrollLoop": {"auto_max_step": 16, "explicit_unroll": False}}
+    ):
         ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
         assert isinstance(ret, tvm.tir.For)
         assert ret.for_type == tvm.tir.For.Unrolled
@@ -58,24 +58,25 @@ def test_unroll_loop():
     assert isinstance(ret, tvm.tir.For)
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], wrapped))
 
-    with tvm.transform.PassContext(config={
-            "tir.UnrollLoop": {"auto_max_depth": 8, "explicit_unroll": False}
-    }):
+    with tvm.transform.PassContext(
+        config={"tir.UnrollLoop": {"auto_max_depth": 8, "explicit_unroll": False}}
+    ):
         ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
         assert isinstance(ret[0], tvm.tir.For)
         assert ret[0].for_type == tvm.tir.For.Unrolled
         assert isinstance(ret[1], tvm.tir.For)
         assert ret[1].for_type != tvm.tir.For.Unrolled
 
+
 def test_unroll_fake_loop():
     ib = tvm.tir.ir_builder.create()
-    dtype = 'int32'
-    n = te.size_var('n')
-    Ab = tvm.tir.decl_buffer((n, ), dtype)
+    dtype = "int32"
+    n = te.size_var("n")
+    Ab = tvm.tir.decl_buffer((n,), dtype)
     Aptr = ib.buffer_ptr(Ab)
     # for i in 0 to n-1:
     with ib.for_range(0, 1, name="i") as i:
-        Aptr[i*2] = 3
+        Aptr[i * 2] = 3
         with ib.for_range(0, 10, name="j") as j:
             Aptr[j + 1] = Aptr[i] + 1
 
@@ -83,19 +84,19 @@ def test_unroll_fake_loop():
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt))
 
-    with tvm.transform.PassContext(config={
-            "tir.UnrollLoop": {
-                "auto_max_depth": 8,
-                "auto_max_extent": 1,
-                "explicit_unroll": False
-            }}):
+    with tvm.transform.PassContext(
+        config={
+            "tir.UnrollLoop": {"auto_max_depth": 8, "auto_max_extent": 1, "explicit_unroll": False}
+        }
+    ):
         ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
         assert isinstance(ret[0], tvm.tir.Store)
 
+
 def test_unroll_single_count_loops():
-    n = te.size_var('n')
-    A = te.placeholder((n,), name='A')
-    B = te.compute((n,), lambda *i: A(*i), name='B')
+    n = te.size_var("n")
+    A = te.placeholder((n,), name="A")
+    B = te.compute((n,), lambda *i: A(*i), name="B")
     s = te.create_schedule(B.op)
     s = s.normalize()
     dom_map = tvm.te.schedule.InferBound(s)
@@ -104,12 +105,11 @@ def test_unroll_single_count_loops():
     # auto_unroll_max_extent which has been set to 1 (default:0)
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
 
-    with tvm.transform.PassContext(config={
-            "tir.UnrollLoop": {"auto_max_step": 1}
-    }):
+    with tvm.transform.PassContext(config={"tir.UnrollLoop": {"auto_max_step": 1}}):
         ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
         assert ret == stmt
 
+
 if __name__ == "__main__":
     test_unroll_loop()
     test_unroll_fake_loop()
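The reformatted PassContext calls above all take the same shape: the unroll thresholds are plain nested dicts under the "tir.UnrollLoop" key. A hedged sketch of that configuration pattern, using an illustrative loop rather than the tests' own buffers:

    import tvm
    from tvm import te

    ib = tvm.tir.ir_builder.create()
    n = te.size_var("n")
    Ab = tvm.tir.decl_buffer((n,), "int32")
    Aptr = ib.buffer_ptr(Ab)
    with ib.for_range(0, 8, name="i") as i:
        Aptr[i] = Aptr[i] + 1

    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], ib.get()))

    # Unroll thresholds are ordinary nested dicts keyed by "tir.UnrollLoop".
    with tvm.transform.PassContext(
        config={"tir.UnrollLoop": {"auto_max_step": 16, "explicit_unroll": True}}
    ):
        print(tvm.tir.transform.UnrollLoop()(mod)["main"].body)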
index 0516b4a..204e26f 100644 (file)
 import tvm
 from tvm import te
 
+
 def test_vectorize_loop():
-    dtype = 'int64'
-    n = te.var('n')
+    dtype = "int64"
+    n = te.var("n")
     ib = tvm.tir.ir_builder.create()
     A = ib.pointer("float32", name="A")
     with ib.for_range(0, n) as i:
@@ -39,8 +40,8 @@ def test_vectorize_loop():
 
 
 def test_vectorize_vector():
-    dtype = 'int64'
-    n = te.var('n')
+    dtype = "int64"
+    n = te.var("n")
     ib = tvm.tir.ir_builder.create()
     A = ib.pointer("float32x4", name="A")
     with ib.for_range(0, n) as i:
@@ -59,8 +60,8 @@ def test_vectorize_vector():
 
 
 def test_vectorize_with_if():
-    n = te.var('n')
-    x = te.var('x')
+    n = te.var("n")
+    x = te.var("x")
     ib = tvm.tir.ir_builder.create()
     A = ib.pointer("float32", name="A")
     with ib.for_range(0, 4, for_type="vectorize") as i:
@@ -96,7 +97,7 @@ def test_vectorize_let():
 
 
 def test_vectorize_with_le_cond():
-    n = te.var('n')
+    n = te.var("n")
     ib = tvm.tir.ir_builder.create()
     A = ib.pointer("float32", name="A")
     with ib.for_range(0, 4, for_type="vectorize") as i:
@@ -111,7 +112,7 @@ def test_vectorize_with_le_cond():
 
 
 def test_vectorize_with_ge_cond():
-    n = te.var('n')
+    n = te.var("n")
     ib = tvm.tir.ir_builder.create()
     A = ib.pointer("float32", name="A")
     with ib.for_range(0, 4, for_type="vectorize") as i:
@@ -126,14 +127,12 @@ def test_vectorize_with_ge_cond():
 
 
 def test_vectorize_if_then_else():
-    n = te.var('n')
-    x = te.var('x')
+    n = te.var("n")
+    x = te.var("x")
     ib = tvm.tir.ir_builder.create()
     A = ib.pointer("float32", name="A")
     with ib.for_range(0, 4, for_type="vectorize") as i:
-        A[i] = tvm.tir.call_intrin("float32", "tir.if_then_else",
-                               i > 0,
-                               A[i] + 1, A[i])
+        A[i] = tvm.tir.call_intrin("float32", "tir.if_then_else", i > 0, A[i] + 1, A[i])
     stmt = ib.get()
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt))
@@ -141,14 +140,13 @@ def test_vectorize_if_then_else():
 
     assert isinstance(stmt, tvm.tir.For)
 
-
     ib = tvm.tir.ir_builder.create()
     A = ib.pointer("float32", name="A")
     with ib.for_range(0, n) as k:
         with ib.for_range(0, 4, for_type="vectorize") as i:
-            A[k * 4 + i] = tvm.tir.call_intrin("float32", "tir.if_then_else",
-                                           k > 0,
-                                           A[k * 4 + i], 0)
+            A[k * 4 + i] = tvm.tir.call_intrin(
+                "float32", "tir.if_then_else", k > 0, A[k * 4 + i], 0
+            )
     stmt = ib.get()
 
     assert isinstance(stmt.body, tvm.tir.For)
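The loops above are built with for_type="vectorize" and then handed to the vectorization pass; a hedged sketch of that flow (the VectorizeLoop invocation itself is not shown in these hunks, so treat the exact call as an assumption):

    import tvm
    from tvm import te

    n = te.var("n")
    ib = tvm.tir.ir_builder.create()
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, 4, for_type="vectorize") as i:
        A[i] = A[i] + 1.0
    stmt = ib.get()

    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
    # The loop marked for_type="vectorize" is rewritten into vector loads/stores.
    print(tvm.tir.transform.VectorizeLoop()(mod)["main"].body)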
index 07f7f9a..ce9c198 100644 (file)
@@ -79,13 +79,14 @@ from tvm import autotvm
 # can be very large (at the level of 10^9 for some input shapes)
 #
 
+
 @autotvm.template("tutorial/conv2d_no_batching")
 def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
     assert N == 1, "Only consider batch_size = 1 in this template"
 
-    data = te.placeholder((N, CI, H, W), name='data')
-    kernel = te.placeholder((CO, CI, KH, KW), name='kernel')
-    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype='float32')
+    data = te.placeholder((N, CI, H, W), name="data")
+    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
+    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype="float32")
     s = te.create_schedule([conv.op])
 
     ##### space definition begin #####
@@ -109,13 +110,13 @@ def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
     data, raw_data = pad_data, data
 
     output = conv
-    OL = s.cache_write(conv, 'local')
+    OL = s.cache_write(conv, "local")
 
     # create cache stage
-    AA = s.cache_read(data, 'shared', [OL])
-    WW = s.cache_read(kernel, 'shared', [OL])
-    AL = s.cache_read(AA, 'local', [OL])
-    WL = s.cache_read(WW, 'local', [OL])
+    AA = s.cache_read(data, "shared", [OL])
+    WW = s.cache_read(kernel, "shared", [OL])
+    AL = s.cache_read(AA, "local", [OL])
+    WL = s.cache_read(WW, "local", [OL])
 
     # tile and bind spatial axes
     n, f, y, x = s[output].op.axis
@@ -139,9 +140,9 @@ def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
     # tile reduction axes
     n, f, y, x = s[OL].op.axis
     rc, ry, rx = s[OL].op.reduce_axis
-    rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
-    ryo, rym, ryi = cfg['tile_rx'].apply(s, OL, ry)
-    rxo, rxm, rxi = cfg['tile_ry'].apply(s, OL, rx)
+    rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
+    ryo, rym, ryi = cfg["tile_rx"].apply(s, OL, ry)
+    rxo, rxm, rxi = cfg["tile_ry"].apply(s, OL, rx)
     s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x)
 
     s[AA].compute_at(s[OL], rxo)
@@ -161,11 +162,12 @@ def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
         s[load].bind(tx, te.thread_axis("threadIdx.x"))
 
     # tune unroll
-    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
-    s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
 
     return s, [raw_data, kernel, conv]
 
+
 ######################################################################
 # Step 2:  Search through the space
 # ---------------------------------
@@ -176,30 +178,32 @@ def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
 # for this template
 
 # logging config (for printing tuning log to screen)
-logging.getLogger('autotvm').setLevel(logging.DEBUG)
-logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))
+logging.getLogger("autotvm").setLevel(logging.DEBUG)
+logging.getLogger("autotvm").addHandler(logging.StreamHandler(sys.stdout))
 
 # the last layer in resnet
 N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
-task = autotvm.task.create("tutorial/conv2d_no_batching",
-                           args=(N, H, W, CO, CI, KH, KW, strides, padding),
-                           target='cuda')
+task = autotvm.task.create(
+    "tutorial/conv2d_no_batching", args=(N, H, W, CO, CI, KH, KW, strides, padding), target="cuda"
+)
 print(task.config_space)
 
 # Use local gpu, measure 10 times for every config to reduce variance
 # The timeout of compiling a program is 10 seconds, the timeout for running is 4 seconds
 measure_option = autotvm.measure_option(
     builder=autotvm.LocalBuilder(),
-    runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
+    runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4),
 )
 
 # Begin tuning, log records to file `conv2d.log`
 # During tuning we will also try many invalid configs, so you are expected to
 # see many error reports. As long as you can see non-zero GFLOPS, it is okay.
 tuner = autotvm.tuner.XGBTuner(task)
-tuner.tune(n_trial=20,
-           measure_option=measure_option,
-           callbacks=[autotvm.callback.log_to_file('conv2d.log')])
+tuner.tune(
+    n_trial=20,
+    measure_option=measure_option,
+    callbacks=[autotvm.callback.log_to_file("conv2d.log")],
+)
 
 #########################################################################
 # Finally we can inspect the best config from log file, check correctness,
@@ -212,7 +216,7 @@ print("\nBest config:")
 print(best_config)
 
 # apply history best from log file
-with autotvm.apply_history_best('conv2d.log'):
+with autotvm.apply_history_best("conv2d.log"):
     with tvm.target.Target("cuda"):
         s, arg_bufs = conv2d_no_batching(N, H, W, CO, CI, KH, KW, strides, padding)
         func = tvm.build(s, arg_bufs)
@@ -233,5 +237,4 @@ tvm.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)
 # Evaluate running time. Here we choose a large repeat number (400) to reduce the noise
 # and the overhead of kernel launch. You can also use nvprof to validate the result.
 evaluator = func.time_evaluator(func.entry_name, ctx, number=400)
-print('Time cost of this operator: %f' % evaluator(a_tvm, w_tvm, c_tvm).mean)
-
+print("Time cost of this operator: %f" % evaluator(a_tvm, w_tvm, c_tvm).mean)
index 71b1d3e..4cfde72 100644 (file)
@@ -77,31 +77,41 @@ import tvm.contrib.graph_runtime as runtime
 # We can load some pre-defined network from :code:`relay.testing`.
 # We can also load models from MXNet, ONNX and TensorFlow.
 
+
 def get_network(name, batch_size):
     """Get the symbol definition and random weight of a network"""
     input_shape = (batch_size, 3, 224, 224)
     output_shape = (batch_size, 1000)
 
     if "resnet" in name:
-        n_layer = int(name.split('-')[1])
-        mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
+        n_layer = int(name.split("-")[1])
+        mod, params = relay.testing.resnet.get_workload(
+            num_layers=n_layer, batch_size=batch_size, dtype=dtype
+        )
     elif "vgg" in name:
-        n_layer = int(name.split('-')[1])
-        mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
-    elif name == 'mobilenet':
+        n_layer = int(name.split("-")[1])
+        mod, params = relay.testing.vgg.get_workload(
+            num_layers=n_layer, batch_size=batch_size, dtype=dtype
+        )
+    elif name == "mobilenet":
         mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size)
-    elif name == 'squeezenet_v1.1':
-        mod, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
-    elif name == 'inception_v3':
+    elif name == "squeezenet_v1.1":
+        mod, params = relay.testing.squeezenet.get_workload(
+            batch_size=batch_size, version="1.1", dtype=dtype
+        )
+    elif name == "inception_v3":
         input_shape = (1, 3, 299, 299)
         mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
-    elif name == 'mxnet':
+    elif name == "mxnet":
         # an example for mxnet model
         from mxnet.gluon.model_zoo.vision import get_model
-        block = get_model('resnet18_v1', pretrained=True)
-        mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+
+        block = get_model("resnet18_v1", pretrained=True)
+        mod, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype)
         net = mod["main"]
-        net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
+        net = relay.Function(
+            net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs
+        )
         mod = tvm.IRModule.from_expr(net)
     else:
         raise ValueError("Unsupported network: " + name)
@@ -190,31 +200,30 @@ def get_network(name, batch_size):
 
 # Replace "aarch64-linux-gnu" with the correct target of your board.
 # This target is used for cross compilation. You can query it by :code:`gcc -v` on your device.
-target = tvm.target.Target('llvm -device=arm_cpu -mtriple=aarch64-linux-gnu')
+target = tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu")
 
 # Also replace this with the device key in your tracker
-device_key = 'rk3399'
+device_key = "rk3399"
 
 # Set this to True if you use android phone
 use_android = False
 
 #### TUNING OPTION ####
-network = 'resnet-18'
+network = "resnet-18"
 log_file = "%s.%s.log" % (device_key, network)
-dtype = 'float32'
+dtype = "float32"
 
 tuning_option = {
-    'log_filename': log_file,
-
-    'tuner': 'xgb',
-    'n_trial': 1500,
-    'early_stopping': 800,
-
-    'measure_option': autotvm.measure_option(
-        builder=autotvm.LocalBuilder(
-            build_func='ndk' if use_android else 'default'),
+    "log_filename": log_file,
+    "tuner": "xgb",
+    "n_trial": 1500,
+    "early_stopping": 800,
+    "measure_option": autotvm.measure_option(
+        builder=autotvm.LocalBuilder(build_func="ndk" if use_android else "default"),
         runner=autotvm.RPCRunner(
-            device_key, host='0.0.0.0', port=9190,
+            device_key,
+            host="0.0.0.0",
+            port=9190,
             number=5,
             timeout=10,
         ),
@@ -245,31 +254,33 @@ tuning_option = {
 # We will introduce a more sophisticated tuning scheduler in the future.
 
 # You can skip the implementation of this function for this tutorial.
-def tune_tasks(tasks,
-               measure_option,
-               tuner='xgb',
-               n_trial=1000,
-               early_stopping=None,
-               log_filename='tuning.log',
-               use_transfer_learning=True):
+def tune_tasks(
+    tasks,
+    measure_option,
+    tuner="xgb",
+    n_trial=1000,
+    early_stopping=None,
+    log_filename="tuning.log",
+    use_transfer_learning=True,
+):
     # create tmp log file
     tmp_log_file = log_filename + ".tmp"
     if os.path.exists(tmp_log_file):
         os.remove(tmp_log_file)
 
     for i, tsk in enumerate(reversed(tasks)):
-        prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
+        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
 
         # create tuner
-        if tuner == 'xgb' or tuner == 'xgb-rank':
-            tuner_obj = XGBTuner(tsk, loss_type='rank')
-        elif tuner == 'xgb_knob':
-            tuner_obj = XGBTuner(tsk, loss_type='rank', feature_type='knob')
-        elif tuner == 'ga':
+        if tuner == "xgb" or tuner == "xgb-rank":
+            tuner_obj = XGBTuner(tsk, loss_type="rank")
+        elif tuner == "xgb_knob":
+            tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="knob")
+        elif tuner == "ga":
             tuner_obj = GATuner(tsk, pop_size=50)
-        elif tuner == 'random':
+        elif tuner == "random":
             tuner_obj = RandomTuner(tsk)
-        elif tuner == 'gridsearch':
+        elif tuner == "gridsearch":
             tuner_obj = GridSearchTuner(tsk)
         else:
             raise ValueError("Invalid tuner: " + tuner)
@@ -280,13 +291,15 @@ def tune_tasks(tasks,
 
         # do tuning
         tsk_trial = min(n_trial, len(tsk.config_space))
-        tuner_obj.tune(n_trial=tsk_trial,
-                       early_stopping=early_stopping,
-                       measure_option=measure_option,
-                       callbacks=[
-                           autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
-                           autotvm.callback.log_to_file(tmp_log_file)
-                       ])
+        tuner_obj.tune(
+            n_trial=tsk_trial,
+            early_stopping=early_stopping,
+            measure_option=measure_option,
+            callbacks=[
+                autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
+                autotvm.callback.log_to_file(tmp_log_file),
+            ],
+        )
 
     # pick best records to a cache file
     autotvm.record.pick_best(tmp_log_file, log_filename)
@@ -296,13 +309,14 @@ def tune_tasks(tasks,
 ########################################################################
 # Finally, we launch tuning jobs and evaluate the end-to-end performance.
 
+
 def tune_and_evaluate(tuning_opt):
     # extract workloads from relay program
     print("Extract tasks...")
     mod, params, input_shape, _ = get_network(network, batch_size=1)
-    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
-                                              params=params,
-                                              ops=(relay.op.get("nn.conv2d"),))
+    tasks = autotvm.task.extract_from_program(
+        mod["main"], target=target, params=params, ops=(relay.op.get("nn.conv2d"),)
+    )
 
     # run tuning tasks
     print("Tuning...")
@@ -312,13 +326,13 @@ def tune_and_evaluate(tuning_opt):
     with autotvm.apply_history_best(log_file):
         print("Compile...")
         with tvm.transform.PassContext(opt_level=3):
-            graph, lib, params = relay.build_module.build(
-                mod, target=target, params=params)
+            graph, lib, params = relay.build_module.build(mod, target=target, params=params)
 
         # export library
         tmp = tempdir()
         if use_android:
             from tvm.contrib import ndk
+
             filename = "net.so"
             lib.export_library(tmp.relpath(filename), ndk.create_shared)
         else:
@@ -327,8 +341,7 @@ def tune_and_evaluate(tuning_opt):
 
         # upload module to device
         print("Upload...")
-        remote = autotvm.measure.request_remote(device_key, '0.0.0.0', 9190,
-                                                timeout=10000)
+        remote = autotvm.measure.request_remote(device_key, "0.0.0.0", 9190, timeout=10000)
         remote.upload(tmp.relpath(filename))
         rlib = remote.load_module(filename)
 
@@ -336,15 +349,18 @@ def tune_and_evaluate(tuning_opt):
         ctx = remote.context(str(target), 0)
         module = runtime.create(graph, rlib, ctx)
         data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
-        module.set_input('data', data_tvm)
+        module.set_input("data", data_tvm)
         module.set_input(**params)
 
         # evaluate
         print("Evaluate inference time cost...")
         ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=10)
         prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
-        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
-              (np.mean(prof_res), np.std(prof_res)))
+        print(
+            "Mean inference time (std dev): %.2f ms (%.2f ms)"
+            % (np.mean(prof_res), np.std(prof_res))
+        )
+
 
 # We do not run the tuning on our web server since it takes too long.
 # Uncomment the following line to run it by yourself.
index 6a995af..64be5eb 100644 (file)
@@ -75,37 +75,48 @@ import tvm.contrib.graph_runtime as runtime
 # We can load some pre-defined network from :code:`tvm.relay.testing`.
 # We can also load models from MXNet, ONNX and TensorFlow.
 
+
 def get_network(name, batch_size):
     """Get the symbol definition and random weight of a network"""
     input_shape = (batch_size, 3, 224, 224)
     output_shape = (batch_size, 1000)
 
     if "resnet" in name:
-        n_layer = int(name.split('-')[1])
-        mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
+        n_layer = int(name.split("-")[1])
+        mod, params = relay.testing.resnet.get_workload(
+            num_layers=n_layer, batch_size=batch_size, dtype=dtype
+        )
     elif "vgg" in name:
-        n_layer = int(name.split('-')[1])
-        mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
-    elif name == 'mobilenet':
+        n_layer = int(name.split("-")[1])
+        mod, params = relay.testing.vgg.get_workload(
+            num_layers=n_layer, batch_size=batch_size, dtype=dtype
+        )
+    elif name == "mobilenet":
         mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
-    elif name == 'squeezenet_v1.1':
-        mod, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
-    elif name == 'inception_v3':
+    elif name == "squeezenet_v1.1":
+        mod, params = relay.testing.squeezenet.get_workload(
+            batch_size=batch_size, version="1.1", dtype=dtype
+        )
+    elif name == "inception_v3":
         input_shape = (1, 3, 299, 299)
         mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
-    elif name == 'mxnet':
+    elif name == "mxnet":
         # an example for mxnet model
         from mxnet.gluon.model_zoo.vision import get_model
-        block = get_model('resnet18_v1', pretrained=True)
-        mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+
+        block = get_model("resnet18_v1", pretrained=True)
+        mod, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype)
         net = mod["main"]
-        net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
+        net = relay.Function(
+            net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs
+        )
         mod = tvm.IRModule.from_expr(net)
     else:
         raise ValueError("Unsupported network: " + name)
 
     return mod, params, input_shape, output_shape
 
+
 ###########################################
 # Set Tuning Options
 # ------------------
@@ -115,18 +126,16 @@ def get_network(name, batch_size):
 target = tvm.target.cuda()
 
 #### TUNING OPTION ####
-network = 'resnet-18'
+network = "resnet-18"
 log_file = "%s.log" % network
-dtype = 'float32'
+dtype = "float32"
 
 tuning_option = {
-    'log_filename': log_file,
-
-    'tuner': 'xgb',
-    'n_trial': 2000,
-    'early_stopping': 600,
-
-    'measure_option': autotvm.measure_option(
+    "log_filename": log_file,
+    "tuner": "xgb",
+    "n_trial": 2000,
+    "early_stopping": 600,
+    "measure_option": autotvm.measure_option(
         builder=autotvm.LocalBuilder(timeout=10),
         runner=autotvm.LocalRunner(number=20, repeat=3, timeout=4, min_repeat_ms=150),
     ),
@@ -154,29 +163,31 @@ tuning_option = {
 # We will introduce a more sophisticated tuning scheduler in the future.
 
 # You can skip the implementation of this function for this tutorial.
-def tune_tasks(tasks,
-               measure_option,
-               tuner='xgb',
-               n_trial=1000,
-               early_stopping=None,
-               log_filename='tuning.log',
-               use_transfer_learning=True):
+def tune_tasks(
+    tasks,
+    measure_option,
+    tuner="xgb",
+    n_trial=1000,
+    early_stopping=None,
+    log_filename="tuning.log",
+    use_transfer_learning=True,
+):
     # create tmp log file
     tmp_log_file = log_filename + ".tmp"
     if os.path.exists(tmp_log_file):
         os.remove(tmp_log_file)
 
     for i, tsk in enumerate(reversed(tasks)):
-        prefix = "[Task %2d/%2d] " %(i+1, len(tasks))
+        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
 
         # create tuner
-        if tuner == 'xgb' or tuner == 'xgb-rank':
-            tuner_obj = XGBTuner(tsk, loss_type='rank')
-        elif tuner == 'ga':
+        if tuner == "xgb" or tuner == "xgb-rank":
+            tuner_obj = XGBTuner(tsk, loss_type="rank")
+        elif tuner == "ga":
             tuner_obj = GATuner(tsk, pop_size=100)
-        elif tuner == 'random':
+        elif tuner == "random":
             tuner_obj = RandomTuner(tsk)
-        elif tuner == 'gridsearch':
+        elif tuner == "gridsearch":
             tuner_obj = GridSearchTuner(tsk)
         else:
             raise ValueError("Invalid tuner: " + tuner)
@@ -187,13 +198,15 @@ def tune_tasks(tasks,
 
         # do tuning
         tsk_trial = min(n_trial, len(tsk.config_space))
-        tuner_obj.tune(n_trial=tsk_trial,
-                       early_stopping=early_stopping,
-                       measure_option=measure_option,
-                       callbacks=[
-                           autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
-                           autotvm.callback.log_to_file(tmp_log_file)
-                       ])
+        tuner_obj.tune(
+            n_trial=tsk_trial,
+            early_stopping=early_stopping,
+            measure_option=measure_option,
+            callbacks=[
+                autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
+                autotvm.callback.log_to_file(tmp_log_file),
+            ],
+        )
 
     # pick best records to a cache file
     autotvm.record.pick_best(tmp_log_file, log_filename)
@@ -203,13 +216,14 @@ def tune_tasks(tasks,
 ########################################################################
 # Finally, we launch tuning jobs and evaluate the end-to-end performance.
 
+
 def tune_and_evaluate(tuning_opt):
     # extract workloads from relay program
     print("Extract tasks...")
     mod, params, input_shape, out_shape = get_network(network, batch_size=1)
-    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
-                                              params=params,
-                                              ops=(relay.op.get("nn.conv2d"),))
+    tasks = autotvm.task.extract_from_program(
+        mod["main"], target=target, params=params, ops=(relay.op.get("nn.conv2d"),)
+    )
 
     # run tuning tasks
     print("Tuning...")
@@ -219,8 +233,7 @@ def tune_and_evaluate(tuning_opt):
     with autotvm.apply_history_best(log_file):
         print("Compile...")
         with tvm.transform.PassContext(opt_level=3):
-            graph, lib, params = relay.build_module.build(
-                mod, target=target, params=params)
+            graph, lib, params = relay.build_module.build(mod, target=target, params=params)
 
         # export library
         tmp = tempdir()
@@ -231,15 +244,18 @@ def tune_and_evaluate(tuning_opt):
         ctx = tvm.context(str(target), 0)
         module = runtime.create(graph, lib, ctx)
         data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
-        module.set_input('data', data_tvm)
+        module.set_input("data", data_tvm)
         module.set_input(**params)
 
         # evaluate
         print("Evaluate inference time cost...")
         ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=600)
         prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
-        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
-              (np.mean(prof_res), np.std(prof_res)))
+        print(
+            "Mean inference time (std dev): %.2f ms (%.2f ms)"
+            % (np.mean(prof_res), np.std(prof_res))
+        )
+
 
 # We do not run the tuning on our web server since it takes too long.
 # Uncomment the following line to run it by yourself.
@@ -358,17 +374,20 @@ def tune_and_evaluate(tuning_opt):
 # to replace the corresponding part above.
 
 tuning_option = {
-    'log_filename': log_file,
-
-    'tuner': 'xgb',
-    'n_trial': 2000,
-    'early_stopping': 600,
-
-    'measure_option': autotvm.measure_option(
+    "log_filename": log_file,
+    "tuner": "xgb",
+    "n_trial": 2000,
+    "early_stopping": 600,
+    "measure_option": autotvm.measure_option(
         builder=autotvm.LocalBuilder(timeout=10),
         runner=autotvm.RPCRunner(
-            '1080ti',  # change the device key to your key
-            '0.0.0.0', 9190,
-            number=20, repeat=3, timeout=4, min_repeat_ms=150),
+            "1080ti",  # change the device key to your key
+            "0.0.0.0",
+            9190,
+            number=20,
+            repeat=3,
+            timeout=4,
+            min_repeat_ms=150,
+        ),
     ),
 }
index 27e4bd6..1fa2326 100644 (file)
@@ -76,31 +76,41 @@ import tvm.contrib.graph_runtime as runtime
 # We can load some pre-defined network from :code:`relay.testing`.
 # We can also load models from MXNet, ONNX and TensorFlow.
 
+
 def get_network(name, batch_size):
     """Get the symbol definition and random weight of a network"""
     input_shape = (batch_size, 3, 224, 224)
     output_shape = (batch_size, 1000)
 
     if "resnet" in name:
-        n_layer = int(name.split('-')[1])
-        mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
+        n_layer = int(name.split("-")[1])
+        mod, params = relay.testing.resnet.get_workload(
+            num_layers=n_layer, batch_size=batch_size, dtype=dtype
+        )
     elif "vgg" in name:
-        n_layer = int(name.split('-')[1])
-        mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
-    elif name == 'mobilenet':
+        n_layer = int(name.split("-")[1])
+        mod, params = relay.testing.vgg.get_workload(
+            num_layers=n_layer, batch_size=batch_size, dtype=dtype
+        )
+    elif name == "mobilenet":
         mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
-    elif name == 'squeezenet_v1.1':
-        mod, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
-    elif name == 'inception_v3':
+    elif name == "squeezenet_v1.1":
+        mod, params = relay.testing.squeezenet.get_workload(
+            batch_size=batch_size, version="1.1", dtype=dtype
+        )
+    elif name == "inception_v3":
         input_shape = (1, 3, 299, 299)
         mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
-    elif name == 'mxnet':
+    elif name == "mxnet":
         # an example for mxnet model
         from mxnet.gluon.model_zoo.vision import get_model
-        block = get_model('resnet18_v1', pretrained=True)
-        mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+
+        block = get_model("resnet18_v1", pretrained=True)
+        mod, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype)
         net = mod["main"]
-        net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
+        net = relay.Function(
+            net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs
+        )
         mod = tvm.IRModule.from_expr(net)
     else:
         raise ValueError("Unsupported network: " + name)
@@ -187,35 +197,34 @@ def get_network(name, batch_size):
 
 #### DEVICE CONFIG ####
 
-target = tvm.target.Target('opencl -device=mali')
+target = tvm.target.Target("opencl -device=mali")
 
 # Replace "aarch64-linux-gnu" with the correct target of your board.
 # This target host is used for cross compilation. You can query it by :code:`gcc -v` on your device.
-target_host = 'llvm -mtriple=aarch64-linux-gnu'
+target_host = "llvm -mtriple=aarch64-linux-gnu"
 
 # Also replace this with the device key in your tracker
-device_key = 'rk3399'
+device_key = "rk3399"
 
 # Set this to True if you use android phone
 use_android = False
 
 #### TUNING OPTION ####
-network = 'resnet-18'
+network = "resnet-18"
 log_file = "%s.%s.log" % (device_key, network)
-dtype = 'float32'
+dtype = "float32"
 
 tuning_option = {
-    'log_filename': log_file,
-
-    'tuner': 'xgb',
-    'n_trial': 1000,
-    'early_stopping': 450,
-
-    'measure_option': autotvm.measure_option(
-        builder=autotvm.LocalBuilder(
-            build_func='ndk' if use_android else 'default'),
+    "log_filename": log_file,
+    "tuner": "xgb",
+    "n_trial": 1000,
+    "early_stopping": 450,
+    "measure_option": autotvm.measure_option(
+        builder=autotvm.LocalBuilder(build_func="ndk" if use_android else "default"),
         runner=autotvm.RPCRunner(
-            device_key, host='0.0.0.0', port=9190,
+            device_key,
+            host="0.0.0.0",
+            port=9190,
             number=10,
             timeout=5,
         ),
@@ -242,29 +251,31 @@ tuning_option = {
 # We will introduce a more sophisticated tuning scheduler in the future.
 
 # You can skip the implementation of this function for this tutorial.
-def tune_tasks(tasks,
-               measure_option,
-               tuner='xgb',
-               n_trial=1000,
-               early_stopping=None,
-               log_filename='tuning.log',
-               use_transfer_learning=True):
+def tune_tasks(
+    tasks,
+    measure_option,
+    tuner="xgb",
+    n_trial=1000,
+    early_stopping=None,
+    log_filename="tuning.log",
+    use_transfer_learning=True,
+):
     # create tmp log file
     tmp_log_file = log_filename + ".tmp"
     if os.path.exists(tmp_log_file):
         os.remove(tmp_log_file)
 
     for i, tsk in enumerate(reversed(tasks)):
-        prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
+        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
 
         # create tuner
-        if tuner == 'xgb' or tuner == 'xgb-rank':
-            tuner_obj = XGBTuner(tsk, loss_type='rank')
-        elif tuner == 'ga':
+        if tuner == "xgb" or tuner == "xgb-rank":
+            tuner_obj = XGBTuner(tsk, loss_type="rank")
+        elif tuner == "ga":
             tuner_obj = GATuner(tsk, pop_size=50)
-        elif tuner == 'random':
+        elif tuner == "random":
             tuner_obj = RandomTuner(tsk)
-        elif tuner == 'gridsearch':
+        elif tuner == "gridsearch":
             tuner_obj = GridSearchTuner(tsk)
         else:
             raise ValueError("Invalid tuner: " + tuner)
@@ -275,13 +286,15 @@ def tune_tasks(tasks,
 
         # do tuning
         tsk_trial = min(n_trial, len(tsk.config_space))
-        tuner_obj.tune(n_trial=tsk_trial,
-                       early_stopping=early_stopping,
-                       measure_option=measure_option,
-                       callbacks=[
-                           autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
-                           autotvm.callback.log_to_file(tmp_log_file)
-                       ])
+        tuner_obj.tune(
+            n_trial=tsk_trial,
+            early_stopping=early_stopping,
+            measure_option=measure_option,
+            callbacks=[
+                autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
+                autotvm.callback.log_to_file(tmp_log_file),
+            ],
+        )
 
     # pick best records to a cache file
     autotvm.record.pick_best(tmp_log_file, log_filename)
@@ -291,15 +304,18 @@ def tune_tasks(tasks,
 ########################################################################
 # Finally, we launch tuning jobs and evaluate the end-to-end performance.
 
+
 def tune_and_evaluate(tuning_opt):
     # extract workloads from relay program
     print("Extract tasks...")
     mod, params, input_shape, _ = get_network(network, batch_size=1)
-    tasks = autotvm.task.extract_from_program(mod["main"],
-                                              target=target,
-                                              target_host=target_host,
-                                              params=params,
-                                              ops=(relay.op.get("nn.conv2d"),))
+    tasks = autotvm.task.extract_from_program(
+        mod["main"],
+        target=target,
+        target_host=target_host,
+        params=params,
+        ops=(relay.op.get("nn.conv2d"),),
+    )
 
     # run tuning tasks
     print("Tuning...")
@@ -310,11 +326,13 @@ def tune_and_evaluate(tuning_opt):
         print("Compile...")
         with tvm.transform.PassContext(opt_level=3):
             graph, lib, params = relay.build_module.build(
-                mod, target=target, params=params, target_host=target_host)
+                mod, target=target, params=params, target_host=target_host
+            )
         # export library
         tmp = tempdir()
         if use_android:
             from tvm.contrib import ndk
+
             filename = "net.so"
             lib.export_library(tmp.relpath(filename), ndk.create_shared)
         else:
@@ -323,8 +341,7 @@ def tune_and_evaluate(tuning_opt):
 
         # upload module to device
         print("Upload...")
-        remote = autotvm.measure.request_remote(device_key, '0.0.0.0', 9190,
-                                                timeout=10000)
+        remote = autotvm.measure.request_remote(device_key, "0.0.0.0", 9190, timeout=10000)
         remote.upload(tmp.relpath(filename))
         rlib = remote.load_module(filename)
 
@@ -332,15 +349,18 @@ def tune_and_evaluate(tuning_opt):
         ctx = remote.context(str(target), 0)
         module = runtime.create(graph, rlib, ctx)
         data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
-        module.set_input('data', data_tvm)
+        module.set_input("data", data_tvm)
         module.set_input(**params)
 
         # evaluate
         print("Evaluate inference time cost...")
         ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=30)
         prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
-        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
-              (np.mean(prof_res), np.std(prof_res)))
+        print(
+            "Mean inference time (std dev): %.2f ms (%.2f ms)"
+            % (np.mean(prof_res), np.std(prof_res))
+        )
+
 
 # We do not run the tuning on our web server since it takes too long.
 # Uncomment the following line to run it by yourself.
index 92fdafb..8816824 100644 (file)
@@ -53,25 +53,34 @@ def get_network(name, batch_size):
     output_shape = (batch_size, 1000)
 
     if "resnet" in name:
-        n_layer = int(name.split('-')[1])
-        mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
+        n_layer = int(name.split("-")[1])
+        mod, params = relay.testing.resnet.get_workload(
+            num_layers=n_layer, batch_size=batch_size, dtype=dtype
+        )
     elif "vgg" in name:
-        n_layer = int(name.split('-')[1])
-        mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
-    elif name == 'mobilenet':
+        n_layer = int(name.split("-")[1])
+        mod, params = relay.testing.vgg.get_workload(
+            num_layers=n_layer, batch_size=batch_size, dtype=dtype
+        )
+    elif name == "mobilenet":
         mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
-    elif name == 'squeezenet_v1.1':
-        mod, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
-    elif name == 'inception_v3':
+    elif name == "squeezenet_v1.1":
+        mod, params = relay.testing.squeezenet.get_workload(
+            batch_size=batch_size, version="1.1", dtype=dtype
+        )
+    elif name == "inception_v3":
         input_shape = (1, 3, 299, 299)
         mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
-    elif name == 'mxnet':
+    elif name == "mxnet":
         # an example for mxnet model
         from mxnet.gluon.model_zoo.vision import get_model
-        block = get_model('resnet18_v1', pretrained=True)
+
+        block = get_model("resnet18_v1", pretrained=True)
         mod, params = relay.frontend.from_mxnet(block, shape={input_name: input_shape}, dtype=dtype)
         net = mod["main"]
-        net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
+        net = relay.Function(
+            net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs
+        )
         mod = tvm.IRModule.from_expr(net)
     else:
         raise ValueError("Unsupported network: " + name)
@@ -121,55 +130,57 @@ os.environ["TVM_NUM_THREADS"] = str(num_threads)
 # latency of one operator closer to its actual latency during end-to-end inference.
 
 tuning_option = {
-    'log_filename': log_file,
-    'tuner': 'random',
-    'early_stopping': None,
-
-    'measure_option': autotvm.measure_option(
+    "log_filename": log_file,
+    "tuner": "random",
+    "early_stopping": None,
+    "measure_option": autotvm.measure_option(
         builder=autotvm.LocalBuilder(),
-        runner=autotvm.LocalRunner(number=1, repeat=10,
-                                   min_repeat_ms=0,
-                                   enable_cpu_cache_flush=True),
+        runner=autotvm.LocalRunner(
+            number=1, repeat=10, min_repeat_ms=0, enable_cpu_cache_flush=True
+        ),
     ),
 }
 
 
 # You can skip the implementation of this function for this tutorial.
-def tune_kernels(tasks,
-                 measure_option,
-                 tuner='gridsearch',
-                 early_stopping=None,
-                 log_filename='tuning.log'):
+def tune_kernels(
+    tasks, measure_option, tuner="gridsearch", early_stopping=None, log_filename="tuning.log"
+):
 
     for i, task in enumerate(tasks):
-        prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
+        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
 
         # create tuner
-        if tuner == 'xgb' or tuner == 'xgb-rank':
-            tuner_obj = XGBTuner(task, loss_type='rank')
-        elif tuner == 'ga':
+        if tuner == "xgb" or tuner == "xgb-rank":
+            tuner_obj = XGBTuner(task, loss_type="rank")
+        elif tuner == "ga":
             tuner_obj = GATuner(task, pop_size=50)
-        elif tuner == 'random':
+        elif tuner == "random":
             tuner_obj = RandomTuner(task)
-        elif tuner == 'gridsearch':
+        elif tuner == "gridsearch":
             tuner_obj = GridSearchTuner(task)
         else:
             raise ValueError("Invalid tuner: " + tuner)
 
         # do tuning
-        n_trial=len(task.config_space)
-        tuner_obj.tune(n_trial=n_trial,
-                       early_stopping=early_stopping,
-                       measure_option=measure_option,
-                       callbacks=[
-                           autotvm.callback.progress_bar(n_trial, prefix=prefix),
-                           autotvm.callback.log_to_file(log_filename)])
+        n_trial = len(task.config_space)
+        tuner_obj.tune(
+            n_trial=n_trial,
+            early_stopping=early_stopping,
+            measure_option=measure_option,
+            callbacks=[
+                autotvm.callback.progress_bar(n_trial, prefix=prefix),
+                autotvm.callback.log_to_file(log_filename),
+            ],
+        )
 
 
 # Use graph tuner to achieve graph level optimal schedules
 # Set use_DP=False if it takes too long to finish.
 def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True):
-    target_op = [relay.op.get("nn.conv2d"),]
+    target_op = [
+        relay.op.get("nn.conv2d"),
+    ]
     Tuner = DPTuner if use_DP else PBQPTuner
     executor = Tuner(graph, {input_name: dshape}, records, target_op, target)
     executor.benchmark_layout_transform(min_exec_num=2000)
@@ -180,13 +191,14 @@ def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True):
 ########################################################################
 # Finally, we launch tuning jobs and evaluate the end-to-end performance.
 
+
 def tune_and_evaluate(tuning_opt):
     # extract workloads from relay program
     print("Extract tasks...")
     mod, params, data_shape, out_shape = get_network(model_name, batch_size)
-    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
-                                              params=params,
-                                              ops=(relay.op.get("nn.conv2d"),))
+    tasks = autotvm.task.extract_from_program(
+        mod["main"], target=target, params=params, ops=(relay.op.get("nn.conv2d"),)
+    )
 
     # run tuning tasks
     tune_kernels(tasks, **tuning_opt)
@@ -196,8 +208,7 @@ def tune_and_evaluate(tuning_opt):
     with autotvm.apply_graph_best(graph_opt_sch_file):
         print("Compile...")
         with tvm.transform.PassContext(opt_level=3):
-            graph, lib, params = relay.build_module.build(
-                mod, target=target, params=params)
+            graph, lib, params = relay.build_module.build(mod, target=target, params=params)
 
         # upload parameters to device
         ctx = tvm.cpu()
@@ -210,8 +221,11 @@ def tune_and_evaluate(tuning_opt):
         print("Evaluate inference time cost...")
         ftimer = module.module.time_evaluator("run", ctx, number=100, repeat=3)
         prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
-        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
-              (np.mean(prof_res), np.std(prof_res)))
+        print(
+            "Mean inference time (std dev): %.2f ms (%.2f ms)"
+            % (np.mean(prof_res), np.std(prof_res))
+        )
+
 
 # We do not run the tuning on our web server since it takes too long.
 # Uncomment the following line to run it by yourself.
index fd22eec..357abf1 100644 (file)
@@ -71,11 +71,11 @@ from tvm import autotvm
 
 # Matmul V0: Constant tiling factor
 def matmul_v0(N, L, M, dtype):
-    A = te.placeholder((N, L), name='A', dtype=dtype)
-    B = te.placeholder((L, M), name='B', dtype=dtype)
+    A = te.placeholder((N, L), name="A", dtype=dtype)
+    B = te.placeholder((L, M), name="B", dtype=dtype)
 
-    k = te.reduce_axis((0, L), name='k')
-    C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name='C')
+    k = te.reduce_axis((0, L), name="k")
+    C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
     s = te.create_schedule(C.op)
 
     # schedule
@@ -89,6 +89,7 @@ def matmul_v0(N, L, M, dtype):
 
     return s, [A, B, C]
 
+
 #####################################################################
 # Parametrize the schedule
 # ^^^^^^^^^^^^^^^^^^^^^^^^
@@ -105,11 +106,11 @@ def matmul_v0(N, L, M, dtype):
 # Matmul V1: List candidate values
 @autotvm.template("tutorial/matmul_v1")  # 1. use a decorator
 def matmul_v1(N, L, M, dtype):
-    A = te.placeholder((N, L), name='A', dtype=dtype)
-    B = te.placeholder((L, M), name='B', dtype=dtype)
+    A = te.placeholder((N, L), name="A", dtype=dtype)
+    B = te.placeholder((L, M), name="B", dtype=dtype)
 
-    k = te.reduce_axis((0, L), name='k')
-    C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name='C')
+    k = te.reduce_axis((0, L), name="k")
+    C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
     s = te.create_schedule(C.op)
 
     # schedule
@@ -124,13 +125,14 @@ def matmul_v1(N, L, M, dtype):
     cfg.define_knob("tile_x", [1, 2, 4, 8, 16])
 
     # 4. schedule according to config
-    yo, yi = s[C].split(y, cfg['tile_y'].val)
-    xo, xi = s[C].split(x, cfg['tile_x'].val)
+    yo, yi = s[C].split(y, cfg["tile_y"].val)
+    xo, xi = s[C].split(x, cfg["tile_x"].val)
 
     s[C].reorder(yo, xo, k, yi, xi)
 
     return s, [A, B, C]
 
+
 ###############################################################################
 # Here we make four modifications to the previous schedule code and get
 # a tunable "template". We can explain the modifications one by one.
@@ -183,13 +185,14 @@ def matmul_v1(N, L, M, dtype):
 # When the high level API cannot meet your requirement, you can always fall
 # back to use low level API.
 
+
 @autotvm.template("tutorial/matmul")
 def matmul(N, L, M, dtype):
-    A = te.placeholder((N, L), name='A', dtype=dtype)
-    B = te.placeholder((L, M), name='B', dtype=dtype)
+    A = te.placeholder((N, L), name="A", dtype=dtype)
+    B = te.placeholder((L, M), name="B", dtype=dtype)
 
-    k = te.reduce_axis((0, L), name='k')
-    C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name='C')
+    k = te.reduce_axis((0, L), name="k")
+    C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
     s = te.create_schedule(C.op)
 
     # schedule
@@ -210,6 +213,7 @@ def matmul(N, L, M, dtype):
 
     return s, [A, B, C]
 
+
 ######################################################################
 # .. note:: More Explanation on :code:`cfg.define_split`
 #
@@ -273,7 +277,7 @@ def matmul(N, L, M, dtype):
 # In this case, for a 512x512 square matrix multiplication, the space size
 # is 10x10=100
 N, L, M = 512, 512, 512
-task = autotvm.task.create("tutorial/matmul", args=(N, L, M, 'float32'), target='llvm')
+task = autotvm.task.create("tutorial/matmul", args=(N, L, M, "float32"), target="llvm")
 print(task.config_space)
 
 ################################################################
@@ -286,22 +290,22 @@ print(task.config_space)
 # used to get the best config later.
 
 # logging config (for printing tuning log to the screen)
-logging.getLogger('autotvm').setLevel(logging.DEBUG)
-logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))
+logging.getLogger("autotvm").setLevel(logging.DEBUG)
+logging.getLogger("autotvm").addHandler(logging.StreamHandler(sys.stdout))
 
 # There are two steps for measuring a config: build and run.
 # By default, we use all CPU cores to compile the programs, then measure them sequentially.
 # We measure 5 times and take average to reduce variance.
-measure_option = autotvm.measure_option(
-    builder='local',
-    runner=autotvm.LocalRunner(number=5))
+measure_option = autotvm.measure_option(builder="local", runner=autotvm.LocalRunner(number=5))
 
 # Begin tuning with RandomTuner, log records to file `matmul.log`
 # You can use alternatives like XGBTuner.
 tuner = autotvm.tuner.RandomTuner(task)
-tuner.tune(n_trial=10,
-           measure_option=measure_option,
-           callbacks=[autotvm.callback.log_to_file('matmul.log')])
+tuner.tune(
+    n_trial=10,
+    measure_option=measure_option,
+    callbacks=[autotvm.callback.log_to_file("matmul.log")],
+)
 
 #########################################################################
 # Finally we apply history best from the cache file and check its correctness.
@@ -311,9 +315,9 @@ tuner.tune(n_trial=10,
 # with the same argument.
 
 # apply history best from log file
-with autotvm.apply_history_best('matmul.log'):
+with autotvm.apply_history_best("matmul.log"):
     with tvm.target.Target("llvm"):
-        s, arg_bufs = matmul(N, L, M, 'float32')
+        s, arg_bufs = matmul(N, L, M, "float32")
         func = tvm.build(s, arg_bufs)
 
 # check correctness
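The records in matmul.log can also be inspected one by one; a hedged sketch assuming autotvm.record.load_from_file, which is not shown in this diff:

    from tvm import autotvm

    # Walk the raw (MeasureInput, MeasureResult) records stored in the log.
    for inp, res in autotvm.record.load_from_file("matmul.log"):
        print(inp.config, res.costs)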
index 17f864f..44fe59f 100644 (file)
@@ -50,12 +50,12 @@ import numpy as np
 #
 
 n = tvm.tir.const(128, "int32")
-a = te.placeholder((n, ), name="a")
-b = te.placeholder((n, ), name="b")
-c = te.compute((n, ), lambda i: a[i] + b[i], name='c')
+a = te.placeholder((n,), name="a")
+b = te.placeholder((n,), name="b")
+c = te.compute((n,), lambda i: a[i] + b[i], name="c")
 
 sch = te.create_schedule(c.op)
-ir  = tvm.lower(sch, [a, b, c])
+ir = tvm.lower(sch, [a, b, c])
 print(ir)
 
 ######################################################################
@@ -83,6 +83,8 @@ print(ir)
 #
 
 loops = []
+
+
 def find_width8(op):
     """ Find all the 'tir.For' nodes whose extent can be divided by 8. """
     if isinstance(op, tvm.tir.For):
@@ -90,6 +92,7 @@ def find_width8(op):
             if op.extent.value % 8 == 0:
                 loops.append(op)
 
+
 #####################################################################
 # IR Transformation
 # ~~~~~~~~~~~~~~~~~
@@ -105,18 +108,20 @@ def find_width8(op):
 #     function will be skipped.
 #
 
+
 def vectorize8(op):
     """ Split can vectorize the loops found in `find_width8`. """
     if op in loops:
         extent = op.extent.value
         name = op.loop_var.name
-        lo, li = te.var(name + '.outer'), te.var(name + '.inner')
+        lo, li = te.var(name + ".outer"), te.var(name + ".inner")
         body = tvm.tir.stmt_functor.substitute(op.body, {op.loop_var: lo * 8 + li})
         body = tvm.tir.For(li, 0, 8, tvm.tir.For.Vectorized, 0, body)
         body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.For.Serial, 0, body)
         return body
     return None
 
+
 @tvm.tir.transform.prim_func_pass(opt_level=0)
 def vectorize(f, mod, ctx):
     global loops
@@ -128,8 +133,7 @@ def vectorize(f, mod, ctx):
 
     # The last list argument indicates what kinds of nodes will be transformed.
     # Thus, in this case only `For` nodes will call `vectorize8`
-    return f.with_body(
-        tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, ['tir.For']))
+    return f.with_body(tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, ["tir.For"]))
 
 
 #####################################################################
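A hedged sketch of applying the vectorize pass defined above directly; it reuses sch, a, b and c from earlier in this tutorial, and the tutorial itself may wire the pass in through the lowering config instead:

    # Pass objects returned by tvm.tir.transform.prim_func_pass are callable
    # on an IRModule, so the custom pass can be applied to the lowered module.
    mod = tvm.lower(sch, [a, b, c])
    mod = vectorize(mod)
    print(mod["main"].body)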
index 0452801..b16eb93 100644 (file)
@@ -52,11 +52,12 @@ import tvm.relay as relay
 # will be used by various optimizations of the examples in this tutorial.
 # Similarly, users can write a tir primitive function and apply the tir passes.
 
+
 def example():
     shape = (1, 64, 54, 54)
     c_data = np.empty(shape).astype("float32")
     c = relay.const(c_data)
-    weight = relay.var('weight', shape=(64, 64, 3, 3))
+    weight = relay.var("weight", shape=(64, 64, 3, 3))
     x = relay.var("x", relay.TensorType((1, 64, 56, 56), "float32"))
     conv = relay.nn.conv2d(x, weight)
     y = relay.add(c, c)
@@ -67,18 +68,21 @@ def example():
     z2 = relay.add(z, z1)
     return relay.Function([x, weight], z2)
 
+
 ###############################################################################
 # Let us register layout alteration for a conv2d op so that we can apply the
 # layout alteration pass on the example. How alter layout pass works is out
 # the scope of this tutorial.
 
+
 @relay.op.register_alter_op_layout("nn.conv2d", level=101)
 def alter_conv2d(attrs, inputs, tinfos, out_type):
     data, weight = inputs
     new_attrs = dict(attrs)
-    new_attrs['data_layout'] = 'NCHW16c'
+    new_attrs["data_layout"] = "NCHW16c"
     return relay.nn.conv2d(data, weight, **new_attrs)
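A hedged sketch of how the registration takes effect (illustrative, not part of this diff): running the AlterOpLayout pass over the example module invokes `alter_conv2d` for every conv2d.

mod = tvm.IRModule.from_expr(example())
mod = relay.transform.InferType()(mod)      # type information is needed before altering layouts
mod = relay.transform.AlterOpLayout()(mod)  # calls the alter_conv2d hook registered above
print(mod)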
 
+
 ###############################################################################
 # Optimize the Program
 # --------------------
@@ -148,9 +152,13 @@ print(mod)
 f = example()
 mod = tvm.IRModule.from_expr(f)
 # Glob the interested passes.
-seq = tvm.transform.Sequential([relay.transform.FoldConstant(),
-                                  relay.transform.EliminateCommonSubexpr(),
-                                  relay.transform.FuseOps(fuse_opt_level=2)])
+seq = tvm.transform.Sequential(
+    [
+        relay.transform.FoldConstant(),
+        relay.transform.EliminateCommonSubexpr(),
+        relay.transform.FuseOps(fuse_opt_level=2),
+    ]
+)
 mod1 = seq(mod)
 print(mod1)
 
@@ -207,6 +215,7 @@ print(mod5)
 # visited and each constant in the function will be replaced when we invoke the
 # customized pass.
 
+
 @relay.transform.function_pass(opt_level=1)
 class CustomPipeline:
     """Simple test function to replace one argument to another."""
@@ -221,8 +230,10 @@ class CustomPipeline:
         class ReplaceConstant(tvm.relay.ExprMutator):
             def visit_constant(self, c):
                 return relay.multiply(obj.multiplier, c)
+
         return ReplaceConstant().visit(func)
 
+
 f = example()
 mod = tvm.IRModule.from_expr(f)
 custom_pass = CustomPipeline(multiplier=relay.const(3, "float32"))
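The invocation itself falls outside this hunk; presumably the pass is applied to the module to produce the `mod3` printed below, roughly:

mod3 = custom_pass(mod)  # runs transform_function on every function in the module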
@@ -240,16 +251,20 @@ print(mod3)
 
 f = example()
 mod = tvm.IRModule.from_expr(f)
-seq = tvm.transform.Sequential([relay.transform.FoldConstant(),
-                                tvm.transform.PrintIR(),
-                                relay.transform.EliminateCommonSubexpr(),
-                                relay.transform.FuseOps(),
-                                relay.transform.AlterOpLayout()])
+seq = tvm.transform.Sequential(
+    [
+        relay.transform.FoldConstant(),
+        tvm.transform.PrintIR(),
+        relay.transform.EliminateCommonSubexpr(),
+        relay.transform.FuseOps(),
+        relay.transform.AlterOpLayout(),
+    ]
+)
 
 # By inserting the ``PrintIR`` pass after ``FoldConstant``, the pass infra will
 # dump out the module IR when ``FoldConstant`` is done. Users can plug in this
 # pass after any pass they want to debug for viewing the optimization effect.
-# 
+#
 # There is a more flexible debugging mechanism also exposed by the build configuration
 # object. One can pass a tracing function which can be used to execute arbitrary code
 # before and/or after each pass. A tracing function will receive a :py:class:`tvm.IRModule`,
@@ -257,12 +272,14 @@ seq = tvm.transform.Sequential([relay.transform.FoldConstant(),
 # and a boolean indicating whether you are executing before, or after a pass.
 # An example is below.
 
+
 def print_ir(mod, info, is_before):
     """Print the name of the pass, the IR, only before passes execute."""
     if is_before:
         print("Running pass: {}", info)
         print(mod)
 
+
 with tvm.transform.PassContext(opt_level=3, trace=print_ir):
     with tvm.target.Target("llvm"):
         # Perform the optimizations.
index 3f4efeb..b478694 100644 (file)
@@ -46,14 +46,9 @@ import dgl
 import networkx as nx
 from dgl.nn.pytorch import GraphConv
 
+
 class GCN(nn.Module):
-    def __init__(self,
-                 g,
-                 n_infeat,
-                 n_hidden,
-                 n_classes,
-                 n_layers,
-                 activation):
+    def __init__(self, g, n_infeat, n_hidden, n_classes, n_layers, activation):
         super(GCN, self).__init__()
         self.g = g
         self.layers = nn.ModuleList()
@@ -66,7 +61,7 @@ class GCN(nn.Module):
         h = features
         for i, layer in enumerate(self.layers):
             # handle API changes for different DGL versions
-            if dgl.__version__ > '0.3':
+            if dgl.__version__ > "0.3":
                 h = layer(self.g, h)
             else:
                 h = layer(h, self.g)
@@ -80,6 +75,7 @@ class GCN(nn.Module):
 from dgl.data import load_data
 from collections import namedtuple
 
+
 def load_dataset(dataset="cora"):
     args = namedtuple("args", ["dataset"])
     data = load_data(args(dataset))
@@ -93,7 +89,7 @@ def load_dataset(dataset="cora"):
 
 
 def evaluate(data, logits):
-    test_mask = data.test_mask # the test set which isn't included in the training phase
+    test_mask = data.test_mask  # the test set which isn't included in the training phase
 
     pred = logits.argmax(axis=1)
     acc = ((pred == data.labels) * test_mask).sum() / test_mask.sum()
@@ -142,16 +138,11 @@ from dgl import DGLGraph
 features = torch.FloatTensor(data.features)
 dgl_g = DGLGraph(g)
 
-torch_model = GCN(dgl_g,
-                  infeat_dim,
-                  num_hidden,
-                  num_classes,
-                  num_layers,
-                  F.relu)
+torch_model = GCN(dgl_g, infeat_dim, num_hidden, num_classes, num_layers, F.relu)
 
 # Download the pretrained weights
-model_url = "https://homes.cs.washington.edu/~cyulin/media/gnn_model/gcn_%s.torch"%(dataset)
-model_path = download_testdata(model_url, "gcn_%s.pickle"%(dataset), module='gcn_model')
+model_url = "https://homes.cs.washington.edu/~cyulin/media/gnn_model/gcn_%s.torch" % (dataset)
+model_path = download_testdata(model_url, "gcn_%s.pickle" % (dataset), module="gcn_model")
 
 # Load the weights into the model
 torch_model.load_state_dict(torch.load(model_path))
@@ -188,14 +179,8 @@ from tvm.contrib import graph_runtime
 import tvm
 from tvm import te
 
-def GraphConv(layer_name,
-              input_dim,
-              output_dim,
-              adj,
-              input,
-              norm=None,
-              bias=True,
-              activation=None):
+
+def GraphConv(layer_name, input_dim, output_dim, adj, input, norm=None, bias=True, activation=None):
     """
     Parameters
     ----------
@@ -246,6 +231,7 @@ def GraphConv(layer_name,
         output_t = activation(output_t)
     return output_t
 
+
 ######################################################################
 # Prepare the parameters needed in the GraphConv layers
 # -----------------------------------------------------
@@ -253,29 +239,33 @@ def GraphConv(layer_name,
 import numpy as np
 import networkx as nx
 
+
 def prepare_params(g, data):
     params = {}
-    params['infeats'] = data.features.astype('float32') # Only support float32 as feature for now
+    params["infeats"] = data.features.astype("float32")  # Only support float32 as feature for now
 
     # Generate adjacency matrix
     adjacency = nx.to_scipy_sparse_matrix(g)
-    params['g_data'] = adjacency.data.astype('float32')
-    params['indices'] = adjacency.indices.astype('int32')
-    params['indptr'] = adjacency.indptr.astype('int32')
+    params["g_data"] = adjacency.data.astype("float32")
+    params["indices"] = adjacency.indices.astype("int32")
+    params["indptr"] = adjacency.indptr.astype("int32")
 
     # Normalization w.r.t. node degrees
     degs = [g.in_degree[i] for i in range(g.number_of_nodes())]
-    params['norm'] = np.power(degs, -0.5).astype('float32')
-    params['norm'] = params['norm'].reshape((params['norm'].shape[0], 1))
+    params["norm"] = np.power(degs, -0.5).astype("float32")
+    params["norm"] = params["norm"].reshape((params["norm"].shape[0], 1))
 
     return params
 
+
 params = prepare_params(g, data)
 
 # Check shape of features and the validity of adjacency matrix
-assert len(params['infeats'].shape) == 2
-assert params['g_data'] is not None and params['indices'] is not None and params['indptr'] is not None
-assert params['infeats'].shape[0] == params['indptr'].shape[0] - 1
+assert len(params["infeats"].shape) == 2
+assert (
+    params["g_data"] is not None and params["indices"] is not None and params["indptr"] is not None
+)
+assert params["infeats"].shape[0] == params["indptr"].shape[0] - 1
 
 ######################################################################
 # Put layers together
@@ -283,34 +273,38 @@ assert params['infeats'].shape[0] == params['indptr'].shape[0] - 1
 
 # Define input features, norms, adjacency matrix in Relay
 infeats = relay.var("infeats", shape=data.features.shape)
-norm = relay.Constant(tvm.nd.array(params['norm']))
-g_data = relay.Constant(tvm.nd.array(params['g_data']))
-indices = relay.Constant(tvm.nd.array(params['indices']))
-indptr = relay.Constant(tvm.nd.array(params['indptr']))
+norm = relay.Constant(tvm.nd.array(params["norm"]))
+g_data = relay.Constant(tvm.nd.array(params["g_data"]))
+indices = relay.Constant(tvm.nd.array(params["indices"]))
+indptr = relay.Constant(tvm.nd.array(params["indptr"]))
 
-Adjacency = namedtuple('Adjacency', ['data', 'indices', 'indptr'])
+Adjacency = namedtuple("Adjacency", ["data", "indices", "indptr"])
 adj = Adjacency(g_data, indices, indptr)
 
 # Construct the 2-layer GCN
 layers = []
-layers.append(GraphConv(
-    layer_name="layers.0",
-    input_dim=infeat_dim,
-    output_dim=num_hidden,
-    adj=adj,
-    input=infeats,
-    norm=norm,
-    activation=relay.nn.relu
-))
-layers.append(GraphConv(
-    layer_name="layers.1",
-    input_dim=num_hidden,
-    output_dim=num_classes,
-    adj=adj,
-    input=layers[-1],
-    norm=norm,
-    activation=None
-))
+layers.append(
+    GraphConv(
+        layer_name="layers.0",
+        input_dim=infeat_dim,
+        output_dim=num_hidden,
+        adj=adj,
+        input=infeats,
+        norm=norm,
+        activation=relay.nn.relu,
+    )
+)
+layers.append(
+    GraphConv(
+        layer_name="layers.1",
+        input_dim=num_hidden,
+        output_dim=num_classes,
+        adj=adj,
+        input=layers[-1],
+        norm=norm,
+        activation=None,
+    )
+)
 
 # Analyze free variables and generate Relay function
 output = layers[-1]
@@ -324,24 +318,24 @@ model_params = {}
 for param_tensor in torch_model.state_dict():
     model_params[param_tensor] = torch_model.state_dict()[param_tensor].numpy()
 
-for i in range(num_layers+1):
-    params["layers.%d.weight"%(i)] = model_params["layers.%d.weight"%(i)]
-    params["layers.%d.bias"%(i)] = model_params["layers.%d.bias"%(i)]
+for i in range(num_layers + 1):
+    params["layers.%d.weight" % (i)] = model_params["layers.%d.weight" % (i)]
+    params["layers.%d.bias" % (i)] = model_params["layers.%d.bias" % (i)]
 
 # Set the TVM build target
-target = 'llvm' # Currently only support `llvm` as target
+target = "llvm"  # Currently only support `llvm` as target
 
 func = relay.Function(relay.analysis.free_vars(output), output)
 func = relay.build_module.bind_params_by_name(func, params)
 mod = tvm.IRModule()
 mod["main"] = func
 # Build with Relay
-with tvm.transform.PassContext(opt_level=0): # Currently only support opt_level=0
+with tvm.transform.PassContext(opt_level=0):  # Currently only support opt_level=0
     lib = relay.build(mod, target, params=params)
 
 # Generate graph runtime
 ctx = tvm.context(target, 0)
-m = graph_runtime.GraphModule(lib['default'](ctx))
+m = graph_runtime.GraphModule(lib["default"](ctx))
 
 ######################################################################
 # Run the TVM model, test for accuracy and verify with DGL
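The diff for this file stops at the banner above; a hedged sketch of the run-and-verify step, reusing `m`, `data`, and `evaluate` from the earlier hunks (illustrative, not part of this diff):

m.run()
logits_tvm = m.get_output(0).asnumpy()

acc = evaluate(data, logits_tvm)  # compare predictions on the test mask against the labels
print("Test accuracy of TVM results: {:.2%}".format(acc))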
index fe16dac..3bf55d9 100644 (file)
@@ -189,42 +189,53 @@ from tvm.contrib.download import download_testdata
 # ---------------------------
 # We load a pretrained MobileNetV2(alpha=0.5) classification model provided by keras.
 keras.backend.clear_session()  # Destroys the current TF graph and creates a new one.
-weights_url = ''.join(['https://github.com/JonathanCMitchell/',
-                       'mobilenet_v2_keras/releases/download/v1.1/',
-                       'mobilenet_v2_weights_tf_dim_ordering_tf_kernels_0.5_224.h5'])
-weights_file = 'mobilenet_v2_weights.h5'
-weights_path = download_testdata(weights_url, weights_file, module='keras')
-keras_mobilenet_v2 = MobileNetV2(alpha=0.5, include_top=True, weights=None,
-                                input_shape=(224, 224, 3), classes=1000)
+weights_url = "".join(
+    [
+        "https://github.com/JonathanCMitchell/",
+        "mobilenet_v2_keras/releases/download/v1.1/",
+        "mobilenet_v2_weights_tf_dim_ordering_tf_kernels_0.5_224.h5",
+    ]
+)
+weights_file = "mobilenet_v2_weights.h5"
+weights_path = download_testdata(weights_url, weights_file, module="keras")
+keras_mobilenet_v2 = MobileNetV2(
+    alpha=0.5, include_top=True, weights=None, input_shape=(224, 224, 3), classes=1000
+)
 keras_mobilenet_v2.load_weights(weights_path)
 
 ######################################################################
 # In order to test our model, here we download an image of a cat and
 # transform its format.
-img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
-img_name = 'cat.png'
-img_path = download_testdata(img_url, img_name, module='data')
+img_url = "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true"
+img_name = "cat.png"
+img_path = download_testdata(img_url, img_name, module="data")
 image = Image.open(img_path).resize((224, 224))
-dtype = 'float32'
+dtype = "float32"
+
 
 def transform_image(image):
-    image = np.array(image) - np.array([123., 117., 104.])
+    image = np.array(image) - np.array([123.0, 117.0, 104.0])
     image /= np.array([58.395, 57.12, 57.375])
     image = image.transpose((2, 0, 1))
     image = image[np.newaxis, :]
     return image
 
+
 x = transform_image(image)
 
 ######################################################################
 # synset is used to transform the label from an ImageNet class number to
 # a word humans can understand.
-synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
-                      '4d0b62f3d01426887599d4f7ede23ee5/raw/',
-                      '596b27d23537e5a1b5751d2b0481ef172f58b539/',
-                      'imagenet1000_clsid_to_human.txt'])
-synset_name = 'imagenet1000_clsid_to_human.txt'
-synset_path = download_testdata(synset_url, synset_name, module='data')
+synset_url = "".join(
+    [
+        "https://gist.githubusercontent.com/zhreshold/",
+        "4d0b62f3d01426887599d4f7ede23ee5/raw/",
+        "596b27d23537e5a1b5751d2b0481ef172f58b539/",
+        "imagenet1000_clsid_to_human.txt",
+    ]
+)
+synset_name = "imagenet1000_clsid_to_human.txt"
+synset_path = download_testdata(synset_url, synset_name, module="data")
 with open(synset_path) as f:
     synset = eval(f.read())
 
@@ -241,31 +252,30 @@ local_demo = True
 
 # By default the model will execute on the CPU target.
 # Select 'cpu', 'opencl' or 'vulkan' as the test target.
-test_target = 'cpu'
+test_target = "cpu"
 
 # Change target configuration.
 # Run `adb shell cat /proc/cpuinfo` to find the arch.
-arch = 'arm64'
-target = 'llvm -mtriple=%s-linux-android' % arch
+arch = "arm64"
+target = "llvm -mtriple=%s-linux-android" % arch
 target_host = None
 
 if local_demo:
     target_host = None
-    target = 'llvm'
-elif test_target == 'opencl':
+    target = "llvm"
+elif test_target == "opencl":
     target_host = target
-    target = 'opencl'
-elif test_target == 'vulkan':
+    target = "opencl"
+elif test_target == "vulkan":
     target_host = target
-    target = 'vulkan'
+    target = "vulkan"
 
-input_name = 'input_1'
+input_name = "input_1"
 shape_dict = {input_name: x.shape}
 mod, params = relay.frontend.from_keras(keras_mobilenet_v2, shape_dict)
 
 with tvm.transform.PassContext(opt_level=3):
-    lib = relay.build(mod, target=target,
-                      target_host=target_host, params=params)
+    lib = relay.build(mod, target=target, target_host=target_host, params=params)
 
 # After `relay.build`, you will get three return values: graph,
 # library and the new parameter, since we do some optimization that will
@@ -273,7 +283,7 @@ with tvm.transform.PassContext(opt_level=3):
 
 # Save the library at local temporary directory.
 tmp = util.tempdir()
-lib_fname = tmp.relpath('net.so')
+lib_fname = tmp.relpath("net.so")
 fcompile = ndk.create_shared if not local_demo else None
 lib.export_library(lib_fname, fcompile)
 
@@ -283,33 +293,32 @@ lib.export_library(lib_fname, fcompile)
 # With RPC, you can deploy the model remotely from your host machine
 # to the remote android device.
 
-tracker_host = os.environ.get('TVM_TRACKER_HOST', '0.0.0.0')
-tracker_port = int(os.environ.get('TVM_TRACKER_PORT', 9190))
-key = 'android'
+tracker_host = os.environ.get("TVM_TRACKER_HOST", "0.0.0.0")
+tracker_port = int(os.environ.get("TVM_TRACKER_PORT", 9190))
+key = "android"
 
 if local_demo:
     remote = rpc.LocalSession()
 else:
     tracker = rpc.connect_tracker(tracker_host, tracker_port)
     # When running a heavy model, we should increase the `session_timeout`
-    remote = tracker.request(key, priority=0,
-                             session_timeout=60)
+    remote = tracker.request(key, priority=0, session_timeout=60)
 
 if local_demo:
     ctx = remote.cpu(0)
-elif test_target == 'opencl':
+elif test_target == "opencl":
     ctx = remote.cl(0)
-elif test_target == 'vulkan':
+elif test_target == "vulkan":
     ctx = remote.vulkan(0)
 else:
     ctx = remote.cpu(0)
 
 # upload the library to remote device and load it
 remote.upload(lib_fname)
-rlib = remote.load_module('net.so')
+rlib = remote.load_module("net.so")
 
 # create the remote runtime module
-module = runtime.GraphModule(rlib['default'](ctx))
+module = runtime.GraphModule(rlib["default"](ctx))
 
 ######################################################################
 # Execute on TVM
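The input-feeding and run calls fall between hunks; presumably they look roughly like the following (illustrative, not part of this diff), ahead of the `out = module.get_output(0)` shown in the next hunk:

module.set_input(input_name, tvm.nd.array(x.astype(dtype)))  # feed the preprocessed image
module.run()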
@@ -324,13 +333,12 @@ out = module.get_output(0)
 
 # get top1 result
 top1 = np.argmax(out.asnumpy())
-print('TVM prediction top-1: {}'.format(synset[top1]))
+print("TVM prediction top-1: {}".format(synset[top1]))
 
-print('Evaluate inference time cost...')
-ftimer = module.module.time_evaluator('run', ctx, number=1, repeat=10)
+print("Evaluate inference time cost...")
+ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=10)
 prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
-print('Mean inference time (std dev): %.2f ms (%.2f ms)' % (np.mean(prof_res),
-                                                            np.std(prof_res)))
+print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res)))
 
 ######################################################################
 # Sample Output
index 4a88a74..c6e2d8f 100644 (file)
@@ -104,34 +104,40 @@ from PIL import Image
 import numpy as np
 
 # one line to get the model
-block = get_model('resnet18_v1', pretrained=True)
+block = get_model("resnet18_v1", pretrained=True)
 
 ######################################################################
 # In order to test our model, here we download an image of a cat and
 # transform its format.
-img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
-img_name = 'cat.png'
-img_path = download_testdata(img_url, img_name, module='data')
+img_url = "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true"
+img_name = "cat.png"
+img_path = download_testdata(img_url, img_name, module="data")
 image = Image.open(img_path).resize((224, 224))
 
+
 def transform_image(image):
-    image = np.array(image) - np.array([123., 117., 104.])
+    image = np.array(image) - np.array([123.0, 117.0, 104.0])
     image /= np.array([58.395, 57.12, 57.375])
     image = image.transpose((2, 0, 1))
     image = image[np.newaxis, :]
     return image
 
+
 x = transform_image(image)
 
 ######################################################################
 # synset is used to transform the label from an ImageNet class number to
 # a word humans can understand.
-synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
-                      '4d0b62f3d01426887599d4f7ede23ee5/raw/',
-                      '596b27d23537e5a1b5751d2b0481ef172f58b539/',
-                      'imagenet1000_clsid_to_human.txt'])
-synset_name = 'imagenet1000_clsid_to_human.txt'
-synset_path = download_testdata(synset_url, synset_name, module='data')
+synset_url = "".join(
+    [
+        "https://gist.githubusercontent.com/zhreshold/",
+        "4d0b62f3d01426887599d4f7ede23ee5/raw/",
+        "596b27d23537e5a1b5751d2b0481ef172f58b539/",
+        "imagenet1000_clsid_to_human.txt",
+    ]
+)
+synset_name = "imagenet1000_clsid_to_human.txt"
+synset_path = download_testdata(synset_url, synset_name, module="data")
 with open(synset_path) as f:
     synset = eval(f.read())
 
@@ -140,7 +146,7 @@ with open(synset_path) as f:
 # It's as easy as several lines.
 
 # We support MXNet static graph (symbol) and HybridBlock in mxnet.gluon
-shape_dict = {'data': x.shape}
+shape_dict = {"data": x.shape}
 mod, params = relay.frontend.from_mxnet(block, shape_dict)
 # we want a probability so add a softmax operator
 func = mod["main"]
@@ -173,9 +179,9 @@ data_shape = (batch_size,) + image_shape
 local_demo = True
 
 if local_demo:
-    target = tvm.target.Target('llvm')
+    target = tvm.target.Target("llvm")
 else:
-    target = tvm.target.arm_cpu('rasp3b')
+    target = tvm.target.arm_cpu("rasp3b")
     # The above line is a simple form of
     # target = tvm.target.Target('llvm -device=arm_cpu -model=bcm2837 -mtriple=armv7l-linux-gnueabihf -mattr=+neon')
 
@@ -188,7 +194,7 @@ with tvm.transform.PassContext(opt_level=3):
 
 # Save the library at local temporary directory.
 tmp = util.tempdir()
-lib_fname = tmp.relpath('net.tar')
+lib_fname = tmp.relpath("net.tar")
 lib.export_library(lib_fname)
 
 ######################################################################
@@ -202,23 +208,23 @@ if local_demo:
     remote = rpc.LocalSession()
 else:
     # The following is my environment; change it to the IP address of your target device.
-    host = '10.77.1.162'
+    host = "10.77.1.162"
     port = 9090
     remote = rpc.connect(host, port)
 
 # upload the library to remote device and load it
 remote.upload(lib_fname)
-rlib = remote.load_module('net.tar')
+rlib = remote.load_module("net.tar")
 
 # create the remote runtime module
 ctx = remote.cpu(0)
-module = runtime.GraphModule(rlib['default'](ctx))
+module = runtime.GraphModule(rlib["default"](ctx))
 # set input data
-module.set_input('data', tvm.nd.array(x.astype('float32')))
+module.set_input("data", tvm.nd.array(x.astype("float32")))
 # run
 module.run()
 # get output
 out = module.get_output(0)
 # get top1 result
 top1 = np.argmax(out.asnumpy())
-print('TVM prediction top-1: {}'.format(synset[top1]))
+print("TVM prediction top-1: {}".format(synset[top1]))
index ca741b3..81959db 100644 (file)
@@ -46,19 +46,21 @@ from tvm.contrib.download import download_testdata
 # Helper functions to run the demo
 def get_transform():
     import torchvision.transforms as transforms
-    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                     std=[0.229, 0.224, 0.225])
-    return transforms.Compose([
+
+    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    return transforms.Compose(
+        [
             transforms.Resize(256),
             transforms.CenterCrop(224),
             transforms.ToTensor(),
             normalize,
-        ])
+        ]
+    )
 
 
 def get_real_image(im_height, im_width):
-    img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
-    img_path = download_testdata(img_url, 'cat.png', module='data')
+    img_url = "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true"
+    img_path = download_testdata(img_url, "cat.png", module="data")
     return Image.open(img_path).resize((im_height, im_width))
 
 
@@ -70,12 +72,16 @@ def get_imagenet_input():
 
 
 def get_synset():
-    synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
-                          '4d0b62f3d01426887599d4f7ede23ee5/raw/',
-                          '596b27d23537e5a1b5751d2b0481ef172f58b539/',
-                          'imagenet1000_clsid_to_human.txt'])
-    synset_name = 'imagenet1000_clsid_to_human.txt'
-    synset_path = download_testdata(synset_url, synset_name, module='data')
+    synset_url = "".join(
+        [
+            "https://gist.githubusercontent.com/zhreshold/",
+            "4d0b62f3d01426887599d4f7ede23ee5/raw/",
+            "596b27d23537e5a1b5751d2b0481ef172f58b539/",
+            "imagenet1000_clsid_to_human.txt",
+        ]
+    )
+    synset_name = "imagenet1000_clsid_to_human.txt"
+    synset_path = download_testdata(synset_url, synset_name, module="data")
     with open(synset_path) as f:
         return eval(f.read())
 
@@ -84,7 +90,7 @@ def run_tvm_model(mod, params, input_name, inp, target="llvm"):
     with tvm.transform.PassContext(opt_level=3):
         lib = relay.build(mod, target=target, params=params)
 
-    runtime = tvm.contrib.graph_runtime.GraphModule(lib['default'](tvm.context(target, 0)))
+    runtime = tvm.contrib.graph_runtime.GraphModule(lib["default"](tvm.context(target, 0)))
 
     runtime.set_input(input_name, inp)
     runtime.run()
@@ -114,9 +120,10 @@ inp = get_imagenet_input()
 # In short, this function takes a floating point model and converts it to uint8.
 # The model is per-channel quantized.
 
+
 def quantize_model(model, inp):
     model.fuse_model()
-    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
+    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
     torch.quantization.prepare(model, inplace=True)
     # Dummy calibration
     model(inp)
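The hunk ends mid-function; presumably `quantize_model` finishes by converting the calibrated model in place (illustrative sketch, not part of this diff):

    torch.quantization.convert(model, inplace=True)  # produce the quantized uint8 model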
@@ -192,8 +199,7 @@ print("%d in 1000 raw floating outputs identical." % np.sum(tvm_result[0] == pt_
 # Here we give an example of how to measure performance of TVM compiled models.
 n_repeat = 100  # should be bigger to make the measurement more accurate
 ctx = tvm.cpu(0)
-ftimer = rt_mod.module.time_evaluator("run", ctx, number=1,
-                                      repeat=n_repeat)
+ftimer = rt_mod.module.time_evaluator("run", ctx, number=1, repeat=n_repeat)
 prof_res = np.array(ftimer().results) * 1e3
 print("Elapsed average ms:", np.mean(prof_res))
 
index 0e5f9af..52321b1 100644 (file)
@@ -61,12 +61,15 @@ from tvm import relay
 # Download mobilenet V2 TFLite model provided by Google
 from tvm.contrib.download import download_testdata
 
-model_url = "https://storage.googleapis.com/download.tensorflow.org/models/" \
-             "tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz"
+model_url = (
+    "https://storage.googleapis.com/download.tensorflow.org/models/"
+    "tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz"
+)
 
 # Download model tar file and extract it to get mobilenet_v2_1.0_224.tflite
-model_path = download_testdata(model_url, "mobilenet_v2_1.0_224_quant.tgz",
-                               module=['tf', 'official'])
+model_path = download_testdata(
+    model_url, "mobilenet_v2_1.0_224_quant.tgz", module=["tf", "official"]
+)
 model_dir = os.path.dirname(model_path)
 
 
@@ -75,13 +78,15 @@ model_dir = os.path.dirname(model_path)
 # ----------------------------------------------
 def extract(path):
     import tarfile
+
     if path.endswith("tgz") or path.endswith("gz"):
         dir_path = os.path.dirname(path)
         tar = tarfile.open(path)
         tar.extractall(path=dir_path)
         tar.close()
     else:
-        raise RuntimeError('Could not decompress the file: ' + path)
+        raise RuntimeError("Could not decompress the file: " + path)
+
 
 extract(model_path)
 
@@ -95,15 +100,17 @@ extract(model_path)
 # --------------------------------
 def get_real_image(im_height, im_width):
     from PIL import Image
-    repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
-    img_name = 'elephant-299.jpg'
+
+    repo_base = "https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/"
+    img_name = "elephant-299.jpg"
     image_url = os.path.join(repo_base, img_name)
-    img_path = download_testdata(image_url, img_name, module='data')
+    img_path = download_testdata(image_url, img_name, module="data")
     image = Image.open(img_path).resize((im_height, im_width))
-    x = np.array(image).astype('uint8')
+    x = np.array(image).astype("uint8")
     data = np.reshape(x, (1, im_height, im_width, 3))
     return data
 
+
 data = get_real_image(224, 224)
 
 ######################################################################
@@ -118,9 +125,11 @@ tflite_model_buf = open(tflite_model_file, "rb").read()
 # Get TFLite model from buffer
 try:
     import tflite
+
     tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
 except AttributeError:
     import tflite.Model
+
     tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
 
 ###############################################################################
@@ -143,7 +152,7 @@ def run_tflite_model(tflite_model_buf, input_data):
     # set input
     assert len(input_data) == len(input_details)
     for i in range(len(input_details)):
-        interpreter.set_tensor(input_details[i]['index'], input_data[i])
+        interpreter.set_tensor(input_details[i]["index"], input_data[i])
 
     # Run
     interpreter.invoke()
@@ -151,16 +160,18 @@ def run_tflite_model(tflite_model_buf, input_data):
     # get output
     tflite_output = list()
     for i in range(len(output_details)):
-        tflite_output.append(interpreter.get_tensor(output_details[i]['index']))
+        tflite_output.append(interpreter.get_tensor(output_details[i]["index"]))
 
     return tflite_output
 
+
 ###############################################################################
 # Let's run the TVM-compiled pre-quantized model inference and get the TVM prediction.
 def run_tvm(lib):
     from tvm.contrib import graph_runtime
-    rt_mod = graph_runtime.GraphModule(lib['default'](tvm.cpu(0)))
-    rt_mod.set_input('input', data)
+
+    rt_mod = graph_runtime.GraphModule(lib["default"](tvm.cpu(0)))
+    rt_mod.set_input("input", data)
     rt_mod.run()
     tvm_res = rt_mod.get_output(0).asnumpy()
     tvm_pred = np.squeeze(tvm_res).argsort()[-5:][::-1]
@@ -185,18 +196,16 @@ tflite_pred = np.squeeze(tflite_res).argsort()[-5:][::-1]
 # frontend parser call for a pre-quantized model is exactly the same as the frontend parser call for
 # an FP32 model. We encourage you to uncomment print(mod) and inspect the Relay module. You
 # will see many QNN operators, such as Requantize, Quantize and QNN Conv2D.
-dtype_dict = {'input': data.dtype.name}
-shape_dict = {'input': data.shape}
+dtype_dict = {"input": data.dtype.name}
+shape_dict = {"input": data.shape}
 
-mod, params = relay.frontend.from_tflite(tflite_model,
-                                         shape_dict=shape_dict,
-                                         dtype_dict=dtype_dict)
+mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict)
 # print(mod)
 
 ###############################################################################
 # Let's now compile the Relay module. We use the "llvm" target here. Please replace it with the
 # target platform that you are interested in.
-target = 'llvm'
+target = "llvm"
 with tvm.transform.PassContext(opt_level=3):
     lib = relay.build_module.build(mod, target=target, params=params)
 
index 2586318..093bd73 100644 (file)
@@ -38,7 +38,7 @@ import os
 
 batch_size = 1
 model_name = "resnet18_v1"
-target = 'cuda'
+target = "cuda"
 ctx = tvm.context(target)
 
 ###############################################################################
@@ -47,8 +47,10 @@ ctx = tvm.context(target)
 # We will demonstrate how to prepare the calibration dataset for quantization.
 # We first download the validation set of ImageNet and pre-process the dataset.
 calibration_rec = download_testdata(
-    'http://data.mxnet.io.s3-website-us-west-1.amazonaws.com/data/val_256_q90.rec',
-    'val_256_q90.rec')
+    "http://data.mxnet.io.s3-website-us-west-1.amazonaws.com/data/val_256_q90.rec",
+    "val_256_q90.rec",
+)
+
 
 def get_val_data(num_workers=4):
     mean_rgb = [123.68, 116.779, 103.939]
@@ -57,7 +59,7 @@ def get_val_data(num_workers=4):
     def batch_fn(batch):
         return batch.data[0].asnumpy(), batch.label[0].asnumpy()
 
-    img_size = 299 if model_name == 'inceptionv3' else 224
+    img_size = 299 if model_name == "inceptionv3" else 224
     val_data = mx.io.ImageRecordIter(
         path_imgrec=calibration_rec,
         preprocess_threads=num_workers,
@@ -82,6 +84,7 @@ def get_val_data(num_workers=4):
 
 calibration_samples = 10
 
+
 def calibrate_dataset():
     val_data, batch_fn = get_val_data()
     val_data.reset()
@@ -89,7 +92,7 @@ def calibrate_dataset():
         if i * batch_size >= calibration_samples:
             break
         data, _ = batch_fn(batch)
-        yield {'data': data}
+        yield {"data": data}
 
 
 ###############################################################################
@@ -98,7 +101,7 @@ def calibrate_dataset():
 # We use the Relay MxNet frontend to import a model from the Gluon model zoo.
 def get_model():
     gluon_model = gluon.model_zoo.vision.get_model(model_name, pretrained=True)
-    img_size = 299 if model_name == 'inceptionv3' else 224
+    img_size = 299 if model_name == "inceptionv3" else 224
     data_shape = (batch_size, 3, img_size, img_size)
     mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
     return mod, params
@@ -127,12 +130,13 @@ def get_model():
 # Alternatively, we can also use pre-defined global scales. This saves the time
 # for calibration. But the accuracy might be impacted.
 
+
 def quantize(mod, params, data_aware):
     if data_aware:
-        with relay.quantize.qconfig(calibrate_mode='kl_divergence', weight_scale='max'):
+        with relay.quantize.qconfig(calibrate_mode="kl_divergence", weight_scale="max"):
             mod = relay.quantize.quantize(mod, params, dataset=calibrate_dataset())
     else:
-        with relay.quantize.qconfig(calibrate_mode='global_scale', global_scale=8.0):
+        with relay.quantize.qconfig(calibrate_mode="global_scale", global_scale=8.0):
             mod = relay.quantize.quantize(mod, params)
     return mod
 
@@ -142,7 +146,7 @@ def quantize(mod, params, data_aware):
 # -------------
 # We create a Relay VM to build and execute the model.
 def run_inference(mod):
-    executor = relay.create_executor('vm', mod, ctx, target)
+    executor = relay.create_executor("vm", mod, ctx, target)
     val_data, batch_fn = get_val_data()
     for i, batch in enumerate(val_data):
         data, label = batch_fn(batch)
@@ -150,10 +154,12 @@ def run_inference(mod):
         if i > 10:  # only run inference on a few samples in this tutorial
             break
 
+
 def main():
     mod, params = get_model()
     mod = quantize(mod, params, data_aware=True)
     run_inference(mod)
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     main()
index 11cd63c..dcf2fc4 100644 (file)
@@ -132,9 +132,9 @@ def load_keras_model(module, name, seq_len, batch_size, report_runtime=True):
     dummy_input = tf.keras.Input(shape=[seq_len], batch_size=batch_size, dtype="int32")
     dummy_out = model(dummy_input)  # Propagate shapes through the keras model.
     if report_runtime:
-        np_input = np.random.uniform(
-            size=[batch_size, seq_len], low=0, high=seq_len
-        ).astype("int32")
+        np_input = np.random.uniform(size=[batch_size, seq_len], low=0, high=seq_len).astype(
+            "int32"
+        )
         start = time.time()
         repeats = 50
         for i in range(repeats):
@@ -180,12 +180,8 @@ def import_graphdef(
 ):
     abs_path = os.path.dirname(os.path.abspath(__file__))
     shape_dict = {"input_1": (batch_size, seq_len)}
-    relay_file = ("%s_%d_%d_%s" % (name, batch_size, seq_len, relay_file)).replace(
-        "/", "_"
-    )
-    relay_params = ("%s_%d_%d_%s" % (name, batch_size, seq_len, relay_params)).replace(
-        "/", "_"
-    )
+    relay_file = ("%s_%d_%d_%s" % (name, batch_size, seq_len, relay_file)).replace("/", "_")
+    relay_params = ("%s_%d_%d_%s" % (name, batch_size, seq_len, relay_params)).replace("/", "_")
     if os.path.exists(os.path.join(abs_path, relay_file)) and os.path.exists(
         os.path.join(abs_path, relay_params)
     ):
@@ -218,11 +214,9 @@ def run_relay_graph(mod, params, shape_dict, target, ctx):
     with relay.build_config(opt_level=3):
         lib = relay.build(mod, target=target, params=params)
     input_shape = shape_dict["input_1"]
-    dummy_data = np.random.uniform(size=input_shape, low=0, high=input_shape[1]).astype(
-        "int32"
-    )
+    dummy_data = np.random.uniform(size=input_shape, low=0, high=input_shape[1]).astype("int32")
 
-    m = graph_runtime.GraphModule(lib['default'](ctx))
+    m = graph_runtime.GraphModule(lib["default"](ctx))
     m.set_input(0, dummy_data)
     m.run()
     tvm_output = m.get_output(0)
@@ -252,7 +246,7 @@ def run_dense(mod, params, shape_dict, target, ctx):
 # into the parameters. This makes it easier to convert to matrix multiplies
 # to sparse versions. Next we apply `bsr_dense.convert` to identify all
 # weight matrices that can be sparse, and automatically replace them.
-# 
+#
 # The `bsr_dense.convert` call below is doing the heavy lifting of identifying
 # which weights in the model can be made sparse by checking if they are
 # at least `sparsity_threshold` percent sparse. If so, it converts those
@@ -269,9 +263,7 @@ def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype="float32"):
     assert N % BS_C == 0
     nnz = int(density * M * N)
     num_blocks = int(nnz / (BS_R * BS_C)) + 1
-    candidate_blocks = np.asarray(
-        list(itertools.product(range(0, M, BS_R), range(0, N, BS_C)))
-    )
+    candidate_blocks = np.asarray(list(itertools.product(range(0, M, BS_R), range(0, N, BS_C))))
     assert candidate_blocks.shape[0] == M // BS_R * N // BS_C
     chosen_blocks = candidate_blocks[
         np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False)
@@ -308,9 +300,7 @@ def random_sparse_bert_params(func, params, density, BS_R, BS_C):
 def run_sparse(mod, params, shape_dict, target, ctx, bs_r, sparsity, gen_weights):
     mod, params = ddo.simplify_fc_transpose.convert(mod["main"], params)
     if gen_weights:
-        params = random_sparse_bert_params(
-            mod, params, BS_R=bs_r, BS_C=1, density=1 - sparsity
-        )
+        params = random_sparse_bert_params(mod, params, BS_R=bs_r, BS_C=1, density=1 - sparsity)
     mod, params = ddo.bsr_dense.convert(mod, params, (bs_r, 1), sparsity_threshold=0.8)
     print("Block Sparse Model with {blocksize}x1 blocks:".format(blocksize=bs_r))
     return run_relay_graph(mod, params, shape_dict, target, ctx)
index 3643c8d..d874487 100644 (file)
@@ -58,13 +58,12 @@ from gluoncv import model_zoo, data, utils
 #   to your device.
 
 supported_model = [
-    'ssd_512_resnet50_v1_voc',
-    'ssd_512_resnet50_v1_coco',
-    'ssd_512_resnet101_v2_voc',
-    'ssd_512_mobilenet1.0_voc',
-    'ssd_512_mobilenet1.0_coco',
-    'ssd_300_vgg16_atrous_voc'
-    'ssd_512_vgg16_atrous_coco',
+    "ssd_512_resnet50_v1_voc",
+    "ssd_512_resnet50_v1_coco",
+    "ssd_512_resnet101_v2_voc",
+    "ssd_512_mobilenet1.0_voc",
+    "ssd_512_mobilenet1.0_coco",
+    "ssd_300_vgg16_atrous_voc" "ssd_512_vgg16_atrous_coco",
 ]
 
 model_name = supported_model[0]
@@ -73,9 +72,11 @@ dshape = (1, 3, 512, 512)
 ######################################################################
 # Download and pre-process demo image
 
-im_fname = download_testdata('https://github.com/dmlc/web-data/blob/master/' +
-                             'gluoncv/detection/street_small.jpg?raw=true',
-                             'street_small.jpg', module='data')
+im_fname = download_testdata(
+    "https://github.com/dmlc/web-data/blob/master/" + "gluoncv/detection/street_small.jpg?raw=true",
+    "street_small.jpg",
+    module="data",
+)
 x, img = data.transforms.presets.ssd.load_test(im_fname, short=512)
 
 ######################################################################
@@ -83,26 +84,30 @@ x, img = data.transforms.presets.ssd.load_test(im_fname, short=512)
 
 block = model_zoo.get_model(model_name, pretrained=True)
 
+
 def build(target):
     mod, params = relay.frontend.from_mxnet(block, {"data": dshape})
     with tvm.transform.PassContext(opt_level=3):
         lib = relay.build(mod, target, params=params)
     return lib
 
+
 ######################################################################
 # Create TVM runtime and do inference
 
+
 def run(lib, ctx):
     # Build TVM runtime
-    m = graph_runtime.GraphModule(lib['default'](ctx))
+    m = graph_runtime.GraphModule(lib["default"](ctx))
     tvm_input = tvm.nd.array(x.asnumpy(), ctx=ctx)
-    m.set_input('data', tvm_input)
+    m.set_input("data", tvm_input)
     # execute
     m.run()
     # get outputs
     class_IDs, scores, bounding_boxs = m.get_output(0), m.get_output(1), m.get_output(2)
     return class_IDs, scores, bounding_boxs
 
+
 for target in ["llvm", "cuda"]:
     ctx = tvm.context(target, 0)
     if ctx.exist:
@@ -112,6 +117,11 @@ for target in ["llvm", "cuda"]:
 ######################################################################
 # Display result
 
-ax = utils.viz.plot_bbox(img, bounding_boxs.asnumpy()[0], scores.asnumpy()[0],
-                         class_IDs.asnumpy()[0], class_names=block.classes)
+ax = utils.viz.plot_bbox(
+    img,
+    bounding_boxs.asnumpy()[0],
+    scores.asnumpy()[0],
+    class_IDs.asnumpy()[0],
+    class_names=block.classes,
+)
 plt.show()
index 66ea0bb..4f6f647 100644 (file)
@@ -41,13 +41,16 @@ https://caffe2.ai/docs/getting-started.html
 # ----------------------------
 # We load a pretrained resnet50 classification model provided by Caffe2.
 from caffe2.python.models.download import ModelDownloader
+
 mf = ModelDownloader()
 
+
 class Model:
     def __init__(self, model_name):
         self.init_net, self.predict_net, self.value_info = mf.get_c2_model(model_name)
 
-resnet50 = Model('resnet50')
+
+resnet50 = Model("resnet50")
 
 ######################################################################
 # Load a test image
@@ -57,19 +60,21 @@ from tvm.contrib.download import download_testdata
 from PIL import Image
 from matplotlib import pyplot as plt
 import numpy as np
-img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
-img_path = download_testdata(img_url, 'cat.png', module='data')
+
+img_url = "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true"
+img_path = download_testdata(img_url, "cat.png", module="data")
 img = Image.open(img_path).resize((224, 224))
 plt.imshow(img)
 plt.show()
 # input preprocess
 def transform_image(image):
-    image = np.array(image) - np.array([123., 117., 104.])
+    image = np.array(image) - np.array([123.0, 117.0, 104.0])
     image /= np.array([58.395, 57.12, 57.375])
     image = image.transpose((2, 0, 1))
-    image = image[np.newaxis, :].astype('float32')
+    image = image[np.newaxis, :].astype("float32")
     return image
 
+
 data = transform_image(img)
 
 ######################################################################
@@ -83,11 +88,14 @@ dtype_dict = {input_name: data.dtype}
 
 # parse Caffe2 model and convert into Relay computation graph
 from tvm import relay, transform
-mod, params = relay.frontend.from_caffe2(resnet50.init_net, resnet50.predict_net, shape_dict, dtype_dict)
+
+mod, params = relay.frontend.from_caffe2(
+    resnet50.init_net, resnet50.predict_net, shape_dict, dtype_dict
+)
 
 # compile the model
 # target x86 CPU
-target = 'llvm'
+target = "llvm"
 with transform.PassContext(opt_level=3):
     lib = relay.build(mod, target, params=params)
 
@@ -98,12 +106,13 @@ with transform.PassContext(opt_level=3):
 import tvm
 from tvm import te
 from tvm.contrib import graph_runtime
+
 # context x86 CPU, use tvm.gpu(0) if you run on GPU
 ctx = tvm.cpu(0)
 # create a runtime executor module
-m = graph_runtime.GraphModule(lib['default'](ctx))
+m = graph_runtime.GraphModule(lib["default"](ctx))
 # set inputs
-m.set_input(input_name, tvm.nd.array(data.astype('float32')))
+m.set_input(input_name, tvm.nd.array(data.astype("float32")))
 # execute
 m.run()
 # get outputs
@@ -115,17 +124,22 @@ top1_tvm = np.argmax(tvm_out.asnumpy()[0])
 # -------------------
 # Look up prediction top 1 index in 1000 class synset.
 from caffe2.python import workspace
-synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
-                      '4d0b62f3d01426887599d4f7ede23ee5/raw/',
-                      '596b27d23537e5a1b5751d2b0481ef172f58b539/',
-                      'imagenet1000_clsid_to_human.txt'])
-synset_name = 'imagenet1000_clsid_to_human.txt'
-synset_path = download_testdata(synset_url, synset_name, module='data')
+
+synset_url = "".join(
+    [
+        "https://gist.githubusercontent.com/zhreshold/",
+        "4d0b62f3d01426887599d4f7ede23ee5/raw/",
+        "596b27d23537e5a1b5751d2b0481ef172f58b539/",
+        "imagenet1000_clsid_to_human.txt",
+    ]
+)
+synset_name = "imagenet1000_clsid_to_human.txt"
+synset_path = download_testdata(synset_url, synset_name, module="data")
 with open(synset_path) as f:
     synset = eval(f.read())
-print('Relay top-1 id: {}, class name: {}'.format(top1_tvm, synset[top1_tvm]))
+print("Relay top-1 id: {}, class name: {}".format(top1_tvm, synset[top1_tvm]))
 # confirm correctness with caffe2 output
 p = workspace.Predictor(resnet50.init_net, resnet50.predict_net)
 caffe2_out = p.run({input_name: data})
 top1_caffe2 = np.argmax(caffe2_out)
-print('Caffe2 top-1 id: {}, class name: {}'.format(top1_caffe2, synset[top1_caffe2]))
+print("Caffe2 top-1 id: {}, class name: {}".format(top1_caffe2, synset[top1_caffe2]))
index f5db0f5..4e3f391 100644 (file)
@@ -47,9 +47,9 @@ from PIL import Image
 # ----------------------------
 # We will download and load a pretrained mobilenet classification network
 # provided by apple in this example
-model_url = 'https://docs-assets.developer.apple.com/coreml/models/MobileNet.mlmodel'
-model_file = 'mobilenet.mlmodel'
-model_path = download_testdata(model_url, model_file, module='coreml')
+model_url = "https://docs-assets.developer.apple.com/coreml/models/MobileNet.mlmodel"
+model_file = "mobilenet.mlmodel"
+model_path = download_testdata(model_url, model_file, module="coreml")
 # Now you have mobilenet.mlmodel on disk
 mlmodel = cm.models.MLModel(model_path)
 
@@ -57,19 +57,19 @@ mlmodel = cm.models.MLModel(model_path)
 # Load a test image
 # ------------------
 # A single cat dominates the examples!
-img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
-img_path = download_testdata(img_url, 'cat.png', module='data')
+img_url = "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true"
+img_path = download_testdata(img_url, "cat.png", module="data")
 img = Image.open(img_path).resize((224, 224))
 # Mobilenet.mlmodel's input is BGR format
-img_bgr = np.array(img)[:,:,::-1]
+img_bgr = np.array(img)[:, :, ::-1]
 x = np.transpose(img_bgr, (2, 0, 1))[np.newaxis, :]
 
 ######################################################################
 # Compile the model on Relay
 # ---------------------------
 # We should be familiar with the process by now.
-target = 'llvm'
-shape_dict = {'image': x.shape}
+target = "llvm"
+shape_dict = {"image": x.shape}
 
 # Parse CoreML model and convert into Relay computation graph
 mod, params = relay.frontend.from_coreml(mlmodel, shape_dict)
@@ -82,11 +82,12 @@ with tvm.transform.PassContext(opt_level=3):
 # -------------------
 # The process is no different from the other examples.
 from tvm.contrib import graph_runtime
+
 ctx = tvm.cpu(0)
-dtype = 'float32'
-m = graph_runtime.GraphModule(lib['default'](ctx))
+dtype = "float32"
+m = graph_runtime.GraphModule(lib["default"](ctx))
 # set inputs
-m.set_input('image', tvm.nd.array(x.astype(dtype)))
+m.set_input("image", tvm.nd.array(x.astype(dtype)))
 # execute
 m.run()
 # get outputs
@@ -97,13 +98,17 @@ top1 = np.argmax(tvm_output.asnumpy()[0])
 # Look up synset name
 # -------------------
 # Look up prediction top 1 index in 1000 class synset.
-synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
-                      '4d0b62f3d01426887599d4f7ede23ee5/raw/',
-                      '596b27d23537e5a1b5751d2b0481ef172f58b539/',
-                      'imagenet1000_clsid_to_human.txt'])
-synset_name = 'imagenet1000_clsid_to_human.txt'
-synset_path = download_testdata(synset_url, synset_name, module='data')
+synset_url = "".join(
+    [
+        "https://gist.githubusercontent.com/zhreshold/",
+        "4d0b62f3d01426887599d4f7ede23ee5/raw/",
+        "596b27d23537e5a1b5751d2b0481ef172f58b539/",
+        "imagenet1000_clsid_to_human.txt",
+    ]
+)
+synset_name = "imagenet1000_clsid_to_human.txt"
+synset_path = download_testdata(synset_url, synset_name, module="data")
 with open(synset_path) as f:
     synset = eval(f.read())
 # You should see the following result: Top-1 id 282 class name tiger cat
-print('Top-1 id', top1, 'class name', synset[top1])
+print("Top-1 id", top1, "class name", synset[top1])
index c49fc8b..bbfb410 100644 (file)
@@ -52,28 +52,28 @@ import tvm.relay.testing.darknet
 # Models are: 'yolov2', 'yolov3' or 'yolov3-tiny'
 
 # Model name
-MODEL_NAME = 'yolov3'
+MODEL_NAME = "yolov3"
 
 ######################################################################
 # Download required files
 # -----------------------
 # Download cfg and weights file if first time.
-CFG_NAME = MODEL_NAME + '.cfg'
-WEIGHTS_NAME = MODEL_NAME + '.weights'
-REPO_URL = 'https://github.com/dmlc/web-data/blob/master/darknet/'
-CFG_URL = REPO_URL + 'cfg/' + CFG_NAME + '?raw=true'
-WEIGHTS_URL = 'https://pjreddie.com/media/files/' + WEIGHTS_NAME
+CFG_NAME = MODEL_NAME + ".cfg"
+WEIGHTS_NAME = MODEL_NAME + ".weights"
+REPO_URL = "https://github.com/dmlc/web-data/blob/master/darknet/"
+CFG_URL = REPO_URL + "cfg/" + CFG_NAME + "?raw=true"
+WEIGHTS_URL = "https://pjreddie.com/media/files/" + WEIGHTS_NAME
 
 cfg_path = download_testdata(CFG_URL, CFG_NAME, module="darknet")
 weights_path = download_testdata(WEIGHTS_URL, WEIGHTS_NAME, module="darknet")
 
 # Download and Load darknet library
-if sys.platform in ['linux', 'linux2']:
-    DARKNET_LIB = 'libdarknet2.0.so'
-    DARKNET_URL = REPO_URL + 'lib/' + DARKNET_LIB + '?raw=true'
-elif sys.platform == 'darwin':
-    DARKNET_LIB = 'libdarknet_mac2.0.so'
-    DARKNET_URL = REPO_URL + 'lib_osx/' + DARKNET_LIB + '?raw=true'
+if sys.platform in ["linux", "linux2"]:
+    DARKNET_LIB = "libdarknet2.0.so"
+    DARKNET_URL = REPO_URL + "lib/" + DARKNET_LIB + "?raw=true"
+elif sys.platform == "darwin":
+    DARKNET_LIB = "libdarknet_mac2.0.so"
+    DARKNET_URL = REPO_URL + "lib_osx/" + DARKNET_LIB + "?raw=true"
 else:
     err = "Darknet lib is not supported on {} platform".format(sys.platform)
     raise NotImplementedError(err)
@@ -81,12 +81,12 @@ else:
 lib_path = download_testdata(DARKNET_URL, DARKNET_LIB, module="darknet")
 
 DARKNET_LIB = __darknetffi__.dlopen(lib_path)
-net = DARKNET_LIB.load_network(cfg_path.encode('utf-8'), weights_path.encode('utf-8'), 0)
-dtype = 'float32'
+net = DARKNET_LIB.load_network(cfg_path.encode("utf-8"), weights_path.encode("utf-8"), 0)
+dtype = "float32"
 batch_size = 1
 
 data = np.empty([batch_size, net.c, net.h, net.w], dtype)
-shape_dict = {'data': data.shape}
+shape_dict = {"data": data.shape}
 print("Converting darknet to relay functions...")
 mod, params = relay.frontend.from_darknet(net, dtype=dtype, shape=data.shape)
 
@@ -94,22 +94,22 @@ mod, params = relay.frontend.from_darknet(net, dtype=dtype, shape=data.shape)
 # Import the graph to Relay
 # -------------------------
 # compile the model
-target = 'llvm'
-target_host = 'llvm'
+target = "llvm"
+target_host = "llvm"
 ctx = tvm.cpu(0)
 data = np.empty([batch_size, net.c, net.h, net.w], dtype)
-shape = {'data': data.shape}
+shape = {"data": data.shape}
 print("Compiling the model...")
 with tvm.transform.PassContext(opt_level=3):
     lib = relay.build(mod, target=target, target_host=target_host, params=params)
 
-[neth, netw] = shape['data'][2:] # Current image shape is 608x608
+[neth, netw] = shape["data"][2:]  # Current image shape is 608x608
 ######################################################################
 # Load a test image
 # -----------------
-test_image = 'dog.jpg'
+test_image = "dog.jpg"
 print("Loading the test image...")
-img_url = REPO_URL + 'data/' + test_image + '?raw=true'
+img_url = REPO_URL + "data/" + test_image + "?raw=true"
 img_path = download_testdata(img_url, test_image, "data")
 
 data = tvm.relay.testing.darknet.load_image(img_path, netw, neth)
@@ -119,10 +119,10 @@ data = tvm.relay.testing.darknet.load_image(img_path, netw, neth)
 # The process is no different from other examples.
 from tvm.contrib import graph_runtime
 
-m = graph_runtime.GraphModule(lib['default'](ctx))
+m = graph_runtime.GraphModule(lib["default"](ctx))
 
 # set inputs
-m.set_input('data', tvm.nd.array(data.astype(dtype)))
+m.set_input("data", tvm.nd.array(data.astype(dtype)))
 # execute
 print("Running the test image...")
 
@@ -134,69 +134,69 @@ nms_thresh = 0.45
 m.run()
 # get outputs
 tvm_out = []
-if MODEL_NAME == 'yolov2':
+if MODEL_NAME == "yolov2":
     layer_out = {}
-    layer_out['type'] = 'Region'
+    layer_out["type"] = "Region"
     # Get the region layer attributes (n, out_c, out_h, out_w, classes, coords, background)
     layer_attr = m.get_output(2).asnumpy()
-    layer_out['biases'] = m.get_output(1).asnumpy()
-    out_shape = (layer_attr[0], layer_attr[1]//layer_attr[0],
-                 layer_attr[2], layer_attr[3])
-    layer_out['output'] = m.get_output(0).asnumpy().reshape(out_shape)
-    layer_out['classes'] = layer_attr[4]
-    layer_out['coords'] = layer_attr[5]
-    layer_out['background'] = layer_attr[6]
+    layer_out["biases"] = m.get_output(1).asnumpy()
+    out_shape = (layer_attr[0], layer_attr[1] // layer_attr[0], layer_attr[2], layer_attr[3])
+    layer_out["output"] = m.get_output(0).asnumpy().reshape(out_shape)
+    layer_out["classes"] = layer_attr[4]
+    layer_out["coords"] = layer_attr[5]
+    layer_out["background"] = layer_attr[6]
     tvm_out.append(layer_out)
 
-elif MODEL_NAME == 'yolov3':
+elif MODEL_NAME == "yolov3":
     for i in range(3):
         layer_out = {}
-        layer_out['type'] = 'Yolo'
+        layer_out["type"] = "Yolo"
         # Get the yolo layer attributes (n, out_c, out_h, out_w, classes, total)
-        layer_attr = m.get_output(i*4+3).asnumpy()
-        layer_out['biases'] = m.get_output(i*4+2).asnumpy()
-        layer_out['mask'] = m.get_output(i*4+1).asnumpy()
-        out_shape = (layer_attr[0], layer_attr[1]//layer_attr[0],
-                     layer_attr[2], layer_attr[3])
-        layer_out['output'] = m.get_output(i*4).asnumpy().reshape(out_shape)
-        layer_out['classes'] = layer_attr[4]
+        layer_attr = m.get_output(i * 4 + 3).asnumpy()
+        layer_out["biases"] = m.get_output(i * 4 + 2).asnumpy()
+        layer_out["mask"] = m.get_output(i * 4 + 1).asnumpy()
+        out_shape = (layer_attr[0], layer_attr[1] // layer_attr[0], layer_attr[2], layer_attr[3])
+        layer_out["output"] = m.get_output(i * 4).asnumpy().reshape(out_shape)
+        layer_out["classes"] = layer_attr[4]
         tvm_out.append(layer_out)
 
-elif MODEL_NAME == 'yolov3-tiny':
+elif MODEL_NAME == "yolov3-tiny":
     for i in range(2):
         layer_out = {}
-        layer_out['type'] = 'Yolo'
+        layer_out["type"] = "Yolo"
         # Get the yolo layer attributes (n, out_c, out_h, out_w, classes, total)
-        layer_attr = m.get_output(i*4+3).asnumpy()
-        layer_out['biases'] = m.get_output(i*4+2).asnumpy()
-        layer_out['mask'] = m.get_output(i*4+1).asnumpy()
-        out_shape = (layer_attr[0], layer_attr[1]//layer_attr[0],
-                     layer_attr[2], layer_attr[3])
-        layer_out['output'] = m.get_output(i*4).asnumpy().reshape(out_shape)
-        layer_out['classes'] = layer_attr[4]
+        layer_attr = m.get_output(i * 4 + 3).asnumpy()
+        layer_out["biases"] = m.get_output(i * 4 + 2).asnumpy()
+        layer_out["mask"] = m.get_output(i * 4 + 1).asnumpy()
+        out_shape = (layer_attr[0], layer_attr[1] // layer_attr[0], layer_attr[2], layer_attr[3])
+        layer_out["output"] = m.get_output(i * 4).asnumpy().reshape(out_shape)
+        layer_out["classes"] = layer_attr[4]
         tvm_out.append(layer_out)
         thresh = 0.560
 
 # do the detection and bring up the bounding boxes
 img = tvm.relay.testing.darknet.load_image_color(img_path)
 _, im_h, im_w = img.shape
-dets = tvm.relay.testing.yolo_detection.fill_network_boxes((netw, neth), (im_w, im_h), thresh,
-                                                      1, tvm_out)
+dets = tvm.relay.testing.yolo_detection.fill_network_boxes(
+    (netw, neth), (im_w, im_h), thresh, 1, tvm_out
+)
 last_layer = net.layers[net.n - 1]
 tvm.relay.testing.yolo_detection.do_nms_sort(dets, last_layer.classes, nms_thresh)
 
-coco_name = 'coco.names'
-coco_url = REPO_URL + 'data/' + coco_name + '?raw=true'
-font_name = 'arial.ttf'
-font_url = REPO_URL + 'data/' + font_name + '?raw=true'
-coco_path = download_testdata(coco_url, coco_name, module='data')
-font_path = download_testdata(font_url, font_name, module='data')
+coco_name = "coco.names"
+coco_url = REPO_URL + "data/" + coco_name + "?raw=true"
+font_name = "arial.ttf"
+font_url = REPO_URL + "data/" + font_name + "?raw=true"
+coco_path = download_testdata(coco_url, coco_name, module="data")
+font_path = download_testdata(font_url, font_name, module="data")
 
 with open(coco_path) as f:
     content = f.readlines()
 
 names = [x.strip() for x in content]
 
-tvm.relay.testing.yolo_detection.draw_detections(font_path, img, dets, thresh, names, last_layer.classes)
+tvm.relay.testing.yolo_detection.draw_detections(
+    font_path, img, dets, thresh, names, last_layer.classes
+)
 plt.imshow(img.transpose(1, 2, 0))
 plt.show()
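As a side note on the region-layer attributes unpacked in the hunks above (n, out_c, out_h, out_w, classes, coords, background), the reshape simply folds the anchor count out of the channel dimension. A minimal sketch with illustrative YOLOv2-style numbers, not taken from this diff:

import numpy as np

# Illustrative values only: 5 anchors, 425 output channels, a 19x19 grid, 80 classes
layer_attr = np.array([5, 425, 19, 19, 80, 4, 0])
n, out_c, out_h, out_w = (int(v) for v in layer_attr[:4])
out_shape = (n, out_c // n, out_h, out_w)  # -> (5, 85, 19, 19): 80 classes + 4 coords + 1 objectness per anchor
print(out_shape)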
index 7ece790..a68df55 100644 (file)
@@ -45,12 +45,17 @@ import numpy as np
 # Load pretrained keras model
 # ----------------------------
 # We load a pretrained resnet-50 classification model provided by keras.
-weights_url = ''.join(['https://github.com/fchollet/deep-learning-models/releases/',
-                       'download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5'])
-weights_file = 'resnet50_weights.h5'
-weights_path = download_testdata(weights_url, weights_file, module='keras')
-keras_resnet50 = keras.applications.resnet50.ResNet50(include_top=True, weights=None,
-                                                      input_shape=(224, 224, 3), classes=1000)
+weights_url = "".join(
+    [
+        "https://github.com/fchollet/deep-learning-models/releases/",
+        "download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5",
+    ]
+)
+weights_file = "resnet50_weights.h5"
+weights_path = download_testdata(weights_url, weights_file, module="keras")
+keras_resnet50 = keras.applications.resnet50.ResNet50(
+    include_top=True, weights=None, input_shape=(224, 224, 3), classes=1000
+)
 keras_resnet50.load_weights(weights_path)
 
 ######################################################################
@@ -60,32 +65,33 @@ keras_resnet50.load_weights(weights_path)
 from PIL import Image
 from matplotlib import pyplot as plt
 from keras.applications.resnet50 import preprocess_input
-img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
-img_path = download_testdata(img_url, 'cat.png', module='data')
+
+img_url = "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true"
+img_path = download_testdata(img_url, "cat.png", module="data")
 img = Image.open(img_path).resize((224, 224))
 plt.imshow(img)
 plt.show()
 # input preprocess
-data = np.array(img)[np.newaxis, :].astype('float32')
+data = np.array(img)[np.newaxis, :].astype("float32")
 data = preprocess_input(data).transpose([0, 3, 1, 2])
-print('input_1', data.shape)
+print("input_1", data.shape)
 
 ######################################################################
 # Compile the model with Relay
 # ----------------------------
 # convert the keras model(NHWC layout) to Relay format(NCHW layout).
-shape_dict = {'input_1': data.shape}
+shape_dict = {"input_1": data.shape}
 mod, params = relay.frontend.from_keras(keras_resnet50, shape_dict)
 # compile the model
-target = 'cuda'
+target = "cuda"
 ctx = tvm.gpu(0)
 with tvm.transform.PassContext(opt_level=3):
-    executor = relay.build_module.create_executor('graph', mod, ctx, target)
+    executor = relay.build_module.create_executor("graph", mod, ctx, target)
 
 ######################################################################
 # Execute on TVM
 # ---------------
-dtype = 'float32'
+dtype = "float32"
 tvm_out = executor.evaluate()(tvm.nd.array(data.astype(dtype)), **params)
 top1_tvm = np.argmax(tvm_out.asnumpy()[0])
 
@@ -93,16 +99,20 @@ top1_tvm = np.argmax(tvm_out.asnumpy()[0])
 # Look up synset name
 # -------------------
 # Look up prediction top 1 index in 1000 class synset.
-synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
-                      '4d0b62f3d01426887599d4f7ede23ee5/raw/',
-                      '596b27d23537e5a1b5751d2b0481ef172f58b539/',
-                      'imagenet1000_clsid_to_human.txt'])
-synset_name = 'imagenet1000_clsid_to_human.txt'
-synset_path = download_testdata(synset_url, synset_name, module='data')
+synset_url = "".join(
+    [
+        "https://gist.githubusercontent.com/zhreshold/",
+        "4d0b62f3d01426887599d4f7ede23ee5/raw/",
+        "596b27d23537e5a1b5751d2b0481ef172f58b539/",
+        "imagenet1000_clsid_to_human.txt",
+    ]
+)
+synset_name = "imagenet1000_clsid_to_human.txt"
+synset_path = download_testdata(synset_url, synset_name, module="data")
 with open(synset_path) as f:
     synset = eval(f.read())
-print('Relay top-1 id: {}, class name: {}'.format(top1_tvm, synset[top1_tvm]))
+print("Relay top-1 id: {}, class name: {}".format(top1_tvm, synset[top1_tvm]))
 # confirm correctness with keras output
 keras_out = keras_resnet50.predict(data.transpose([0, 2, 3, 1]))
 top1_keras = np.argmax(keras_out)
-print('Keras top-1 id: {}, class name: {}'.format(top1_keras, synset[top1_keras]))
+print("Keras top-1 id: {}, class name: {}".format(top1_keras, synset[top1_keras]))
index d75ec00..d81b211 100644 (file)
@@ -49,31 +49,38 @@ from tvm.contrib.download import download_testdata
 from mxnet.gluon.model_zoo.vision import get_model
 from PIL import Image
 from matplotlib import pyplot as plt
-block = get_model('resnet18_v1', pretrained=True)
-img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
-img_name = 'cat.png'
-synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
-                      '4d0b62f3d01426887599d4f7ede23ee5/raw/',
-                      '596b27d23537e5a1b5751d2b0481ef172f58b539/',
-                      'imagenet1000_clsid_to_human.txt'])
-synset_name = 'imagenet1000_clsid_to_human.txt'
-img_path = download_testdata(img_url, 'cat.png', module='data')
-synset_path = download_testdata(synset_url, synset_name, module='data')
+
+block = get_model("resnet18_v1", pretrained=True)
+img_url = "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true"
+img_name = "cat.png"
+synset_url = "".join(
+    [
+        "https://gist.githubusercontent.com/zhreshold/",
+        "4d0b62f3d01426887599d4f7ede23ee5/raw/",
+        "596b27d23537e5a1b5751d2b0481ef172f58b539/",
+        "imagenet1000_clsid_to_human.txt",
+    ]
+)
+synset_name = "imagenet1000_clsid_to_human.txt"
+img_path = download_testdata(img_url, "cat.png", module="data")
+synset_path = download_testdata(synset_url, synset_name, module="data")
 with open(synset_path) as f:
     synset = eval(f.read())
 image = Image.open(img_path).resize((224, 224))
 plt.imshow(image)
 plt.show()
 
+
 def transform_image(image):
-    image = np.array(image) - np.array([123., 117., 104.])
+    image = np.array(image) - np.array([123.0, 117.0, 104.0])
     image /= np.array([58.395, 57.12, 57.375])
     image = image.transpose((2, 0, 1))
     image = image[np.newaxis, :]
     return image
 
+
 x = transform_image(image)
-print('x', x.shape)
+print("x", x.shape)
 
 ######################################################################
 # Compile the Graph
@@ -81,7 +88,7 @@ print('x', x.shape)
 # Now we would like to port the Gluon model to a portable computational graph.
 # It's as easy as several lines.
 # We support MXNet static graph(symbol) and HybridBlock in mxnet.gluon
-shape_dict = {'data': x.shape}
+shape_dict = {"data": x.shape}
 mod, params = relay.frontend.from_mxnet(block, shape_dict)
 ## we want a probability so add a softmax operator
 func = mod["main"]
@@ -89,7 +96,7 @@ func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_
 
 ######################################################################
 # now compile the graph
-target = 'cuda'
+target = "cuda"
 with tvm.transform.PassContext(opt_level=3):
     lib = relay.build(func, target, params=params)
 
@@ -98,17 +105,18 @@ with tvm.transform.PassContext(opt_level=3):
 # ---------------------------------
 # Now, we would like to reproduce the same forward computation using TVM.
 from tvm.contrib import graph_runtime
+
 ctx = tvm.gpu(0)
-dtype = 'float32'
-m = graph_runtime.GraphModule(lib['default'](ctx))
+dtype = "float32"
+m = graph_runtime.GraphModule(lib["default"](ctx))
 # set inputs
-m.set_input('data', tvm.nd.array(x.astype(dtype)))
+m.set_input("data", tvm.nd.array(x.astype(dtype)))
 # execute
 m.run()
 # get outputs
 tvm_output = m.get_output(0)
 top1 = np.argmax(tvm_output.asnumpy()[0])
-print('TVM prediction top-1:', top1, synset[top1])
+print("TVM prediction top-1:", top1, synset[top1])
 
 ######################################################################
 # Use MXNet symbol with pretrained weights
@@ -116,22 +124,23 @@ print('TVM prediction top-1:', top1, synset[top1])
 # MXNet often uses `arg_params` and `aux_params` to store network parameters
 # separately; here we show how to use these weights with the existing API
 def block2symbol(block):
-    data = mx.sym.Variable('data')
+    data = mx.sym.Variable("data")
     sym = block(data)
     args = {}
     auxs = {}
     for k, v in block.collect_params().items():
         args[k] = mx.nd.array(v.data().asnumpy())
     return sym, args, auxs
+
+
 mx_sym, args, auxs = block2symbol(block)
 # usually we would save/load it as checkpoint
-mx.model.save_checkpoint('resnet18_v1', 0, mx_sym, args, auxs)
+mx.model.save_checkpoint("resnet18_v1", 0, mx_sym, args, auxs)
 # there are 'resnet18_v1-0000.params' and 'resnet18_v1-symbol.json' on disk
 
 ######################################################################
 # for a normal mxnet model, we start from here
-mx_sym, args, auxs = mx.model.load_checkpoint('resnet18_v1', 0)
+mx_sym, args, auxs = mx.model.load_checkpoint("resnet18_v1", 0)
 # now we use the same API to get Relay computation graph
-mod, relay_params = relay.frontend.from_mxnet(mx_sym, shape_dict,
-                                              arg_params=args, aux_params=auxs)
+mod, relay_params = relay.frontend.from_mxnet(mx_sym, shape_dict, arg_params=args, aux_params=auxs)
 # repeat the same steps to run this model using TVM
index 9973a08..e68a398 100644 (file)
@@ -45,11 +45,15 @@ from tvm.contrib.download import download_testdata
 # The example super resolution model used here is exactly the same model in onnx tutorial
 # http://pytorch.org/tutorials/advanced/super_resolution_with_caffe2.html
 # we skip the pytorch model construction part, and download the saved onnx model
-model_url = ''.join(['https://gist.github.com/zhreshold/',
-                     'bcda4716699ac97ea44f791c24310193/raw/',
-                     '93672b029103648953c4e5ad3ac3aadf346a4cdc/',
-                     'super_resolution_0.2.onnx'])
-model_path = download_testdata(model_url, 'super_resolution.onnx', module='onnx')
+model_url = "".join(
+    [
+        "https://gist.github.com/zhreshold/",
+        "bcda4716699ac97ea44f791c24310193/raw/",
+        "93672b029103648953c4e5ad3ac3aadf346a4cdc/",
+        "super_resolution_0.2.onnx",
+    ]
+)
+model_path = download_testdata(model_url, "super_resolution.onnx", module="onnx")
 # now you have super_resolution.onnx on disk
 onnx_model = onnx.load(model_path)
 
@@ -58,8 +62,9 @@ onnx_model = onnx.load(model_path)
 # ---------------------------------------------
 # A single cat dominates the examples!
 from PIL import Image
-img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
-img_path = download_testdata(img_url, 'cat.png', module='data')
+
+img_url = "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true"
+img_path = download_testdata(img_url, "cat.png", module="data")
 img = Image.open(img_path).resize((224, 224))
 img_ycbcr = img.convert("YCbCr")  # convert to YCbCr
 img_y, img_cb, img_cr = img_ycbcr.split()
@@ -68,19 +73,19 @@ x = np.array(img_y)[np.newaxis, np.newaxis, :, :]
 ######################################################################
 # Compile the model with relay
 # ---------------------------------------------
-target = 'llvm'
+target = "llvm"
 
-input_name = '1'
+input_name = "1"
 shape_dict = {input_name: x.shape}
 mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
 
 with tvm.transform.PassContext(opt_level=1):
-    intrp = relay.build_module.create_executor('graph', mod, tvm.cpu(0), target)
+    intrp = relay.build_module.create_executor("graph", mod, tvm.cpu(0), target)
 
 ######################################################################
 # Execute on TVM
 # ---------------------------------------------
-dtype = 'float32'
+dtype = "float32"
 tvm_output = intrp.evaluate()(tvm.nd.array(x.astype(dtype)), **params).asnumpy()
 
 ######################################################################
@@ -88,11 +93,12 @@ tvm_output = intrp.evaluate()(tvm.nd.array(x.astype(dtype)), **params).asnumpy()
 # ---------------------------------------------
 # We put input and output image neck to neck
 from matplotlib import pyplot as plt
-out_y = Image.fromarray(np.uint8((tvm_output[0, 0]).clip(0, 255)), mode='L')
+
+out_y = Image.fromarray(np.uint8((tvm_output[0, 0]).clip(0, 255)), mode="L")
 out_cb = img_cb.resize(out_y.size, Image.BICUBIC)
 out_cr = img_cr.resize(out_y.size, Image.BICUBIC)
-result = Image.merge('YCbCr', [out_y, out_cb, out_cr]).convert('RGB')
-canvas = np.full((672, 672*2, 3), 255)
+result = Image.merge("YCbCr", [out_y, out_cb, out_cr]).convert("RGB")
+canvas = np.full((672, 672 * 2, 3), 255)
 canvas[0:224, 0:224, :] = np.asarray(img)
 canvas[:, 672:, :] = np.asarray(result)
 plt.imshow(canvas.astype(np.uint8))
index b0639f5..2328651 100644 (file)
@@ -55,7 +55,7 @@ import torchvision
 ######################################################################
 # Load a pretrained PyTorch model
 # -------------------------------
-model_name = 'resnet18'
+model_name = "resnet18"
 model = getattr(torchvision.models, model_name)(pretrained=True)
 model = model.eval()
 
@@ -69,19 +69,22 @@ scripted_model = torch.jit.trace(model, input_data).eval()
 # -----------------
 # Classic cat example!
 from PIL import Image
-img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
-img_path = download_testdata(img_url, 'cat.png', module='data')
+
+img_url = "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true"
+img_path = download_testdata(img_url, "cat.png", module="data")
 img = Image.open(img_path).resize((224, 224))
 
 # Preprocess the image and convert to tensor
 from torchvision import transforms
-my_preprocess = transforms.Compose([
-    transforms.Resize(256),
-    transforms.CenterCrop(224),
-    transforms.ToTensor(),
-    transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                         std=[0.229, 0.224, 0.225])
-])
+
+my_preprocess = transforms.Compose(
+    [
+        transforms.Resize(256),
+        transforms.CenterCrop(224),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    ]
+)
 img = my_preprocess(img)
 img = np.expand_dims(img, 0)
 
@@ -89,17 +92,16 @@ img = np.expand_dims(img, 0)
 # Import the graph to Relay
 # -------------------------
 # Convert PyTorch graph to Relay graph. The input name can be arbitrary.
-input_name = 'input0'
+input_name = "input0"
 shape_list = [(input_name, img.shape)]
-mod, params = relay.frontend.from_pytorch(scripted_model,
-                                          shape_list)
+mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
 
 ######################################################################
 # Relay Build
 # -----------
 # Compile the graph to llvm target with given input specification.
-target = 'llvm'
-target_host = 'llvm'
+target = "llvm"
+target_host = "llvm"
 ctx = tvm.cpu(0)
 with tvm.transform.PassContext(opt_level=3):
     lib = relay.build(mod, target=target, target_host=target_host, params=params)
@@ -109,8 +111,9 @@ with tvm.transform.PassContext(opt_level=3):
 # ---------------------------------
 # Now we can try deploying the compiled model on target.
 from tvm.contrib import graph_runtime
-dtype = 'float32'
-m = graph_runtime.GraphModule(lib['default'](ctx))
+
+dtype = "float32"
+m = graph_runtime.GraphModule(lib["default"](ctx))
 # Set inputs
 m.set_input(input_name, tvm.nd.array(img.astype(dtype)))
 # Execute
@@ -122,23 +125,31 @@ tvm_output = m.get_output(0)
 # Look up synset name
 # -------------------
 # Look up prediction top 1 index in 1000 class synset.
-synset_url = ''.join(['https://raw.githubusercontent.com/Cadene/',
-                      'pretrained-models.pytorch/master/data/',
-                      'imagenet_synsets.txt'])
-synset_name = 'imagenet_synsets.txt'
-synset_path = download_testdata(synset_url, synset_name, module='data')
+synset_url = "".join(
+    [
+        "https://raw.githubusercontent.com/Cadene/",
+        "pretrained-models.pytorch/master/data/",
+        "imagenet_synsets.txt",
+    ]
+)
+synset_name = "imagenet_synsets.txt"
+synset_path = download_testdata(synset_url, synset_name, module="data")
 with open(synset_path) as f:
     synsets = f.readlines()
 
 synsets = [x.strip() for x in synsets]
-splits = [line.split(' ') for line in synsets]
-key_to_classname = {spl[0]:' '.join(spl[1:]) for spl in splits}
-
-class_url = ''.join(['https://raw.githubusercontent.com/Cadene/',
-                      'pretrained-models.pytorch/master/data/',
-                      'imagenet_classes.txt'])
-class_name = 'imagenet_classes.txt'
-class_path = download_testdata(class_url, class_name, module='data')
+splits = [line.split(" ") for line in synsets]
+key_to_classname = {spl[0]: " ".join(spl[1:]) for spl in splits}
+
+class_url = "".join(
+    [
+        "https://raw.githubusercontent.com/Cadene/",
+        "pretrained-models.pytorch/master/data/",
+        "imagenet_classes.txt",
+    ]
+)
+class_name = "imagenet_classes.txt"
+class_path = download_testdata(class_url, class_name, module="data")
 with open(class_path) as f:
     class_id_to_key = f.readlines()
 
@@ -157,5 +168,5 @@ with torch.no_grad():
     top1_torch = np.argmax(output.numpy())
     torch_class_key = class_id_to_key[top1_torch]
 
-print('Relay top-1 id: {}, class name: {}'.format(top1_tvm, key_to_classname[tvm_class_key]))
-print('Torch top-1 id: {}, class name: {}'.format(top1_torch, key_to_classname[torch_class_key]))
+print("Relay top-1 id: {}, class name: {}".format(top1_tvm, key_to_classname[tvm_class_key]))
+print("Torch top-1 id: {}, class name: {}".format(top1_torch, key_to_classname[torch_class_key]))
index 10d505c..a3e8173 100644 (file)
@@ -35,6 +35,7 @@ import os.path
 
 # Tensorflow imports
 import tensorflow as tf
+
 try:
     tf_compat_v1 = tf.compat.v1
 except ImportError:
@@ -44,10 +45,10 @@ except ImportError:
 import tvm.relay.testing.tf as tf_testing
 
 # Base location for model related files.
-repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
+repo_base = "https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/"
 
 # Test image
-img_name = 'elephant-299.jpg'
+img_name = "elephant-299.jpg"
 image_url = os.path.join(repo_base, img_name)
 
 ######################################################################
@@ -56,25 +57,25 @@ image_url = os.path.join(repo_base, img_name)
 # Please refer to docs/frontend/tensorflow.md for more details on various models
 # from tensorflow.
 
-model_name = 'classify_image_graph_def-with_shapes.pb'
+model_name = "classify_image_graph_def-with_shapes.pb"
 model_url = os.path.join(repo_base, model_name)
 
 # Image label map
-map_proto = 'imagenet_2012_challenge_label_map_proto.pbtxt'
+map_proto = "imagenet_2012_challenge_label_map_proto.pbtxt"
 map_proto_url = os.path.join(repo_base, map_proto)
 
 # Human readable text for labels
-label_map = 'imagenet_synset_to_human_label_map.txt'
+label_map = "imagenet_synset_to_human_label_map.txt"
 label_map_url = os.path.join(repo_base, label_map)
 
 # Target settings
 # Use these commented settings to build for cuda.
-#target = 'cuda'
-#target_host = 'llvm'
-#layout = "NCHW"
-#ctx = tvm.gpu(0)
-target = 'llvm'
-target_host = 'llvm'
+# target = 'cuda'
+# target_host = 'llvm'
+# layout = "NCHW"
+# ctx = tvm.gpu(0)
+target = "llvm"
+target_host = "llvm"
 layout = None
 ctx = tvm.cpu(0)
 
@@ -84,25 +85,25 @@ ctx = tvm.cpu(0)
 # Download files listed above.
 from tvm.contrib.download import download_testdata
 
-img_path = download_testdata(image_url, img_name, module='data')
-model_path = download_testdata(model_url, model_name, module=['tf', 'InceptionV1'])
-map_proto_path = download_testdata(map_proto_url, map_proto, module='data')
-label_path = download_testdata(label_map_url, label_map, module='data')
+img_path = download_testdata(image_url, img_name, module="data")
+model_path = download_testdata(model_url, model_name, module=["tf", "InceptionV1"])
+map_proto_path = download_testdata(map_proto_url, map_proto, module="data")
+label_path = download_testdata(label_map_url, label_map, module="data")
 
 ######################################################################
 # Import model
 # ------------
 # Creates tensorflow graph definition from protobuf file.
 
-with tf_compat_v1.gfile.GFile(model_path, 'rb') as f:
+with tf_compat_v1.gfile.GFile(model_path, "rb") as f:
     graph_def = tf_compat_v1.GraphDef()
     graph_def.ParseFromString(f.read())
-    graph = tf.import_graph_def(graph_def, name='')
+    graph = tf.import_graph_def(graph_def, name="")
     # Call the utility to import the graph definition into default graph.
     graph_def = tf_testing.ProcessGraphDefParam(graph_def)
     # Add shapes to the graph.
     with tf_compat_v1.Session() as sess:
-        graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax')
+        graph_def = tf_testing.AddShapesToGraphDef(sess, "softmax")
 
 ######################################################################
 # Decode image
@@ -115,6 +116,7 @@ with tf_compat_v1.gfile.GFile(model_path, 'rb') as f:
 #
 
 from PIL import Image
+
 image = Image.open(img_path).resize((299, 299))
 
 x = np.array(image)
@@ -127,11 +129,9 @@ x = np.array(image)
 # Results:
 #   sym: relay expr for given tensorflow protobuf.
 #   params: params converted from tensorflow params (tensor protobuf).
-shape_dict = {'DecodeJpeg/contents': x.shape}
-dtype_dict = {'DecodeJpeg/contents': 'uint8'}
-mod, params = relay.frontend.from_tensorflow(graph_def,
-                                             layout=layout,
-                                             shape=shape_dict)
+shape_dict = {"DecodeJpeg/contents": x.shape}
+dtype_dict = {"DecodeJpeg/contents": "uint8"}
+mod, params = relay.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict)
 
 print("Tensorflow protobuf imported to relay frontend.")
 ######################################################################
@@ -153,14 +153,15 @@ with tvm.transform.PassContext(opt_level=3):
 # Now we can try deploying the compiled model on target.
 
 from tvm.contrib import graph_runtime
-dtype = 'uint8'
-m = graph_runtime.GraphModule(lib['default'](ctx))
+
+dtype = "uint8"
+m = graph_runtime.GraphModule(lib["default"](ctx))
 # set inputs
-m.set_input('DecodeJpeg/contents', tvm.nd.array(x.astype(dtype)))
+m.set_input("DecodeJpeg/contents", tvm.nd.array(x.astype(dtype)))
 # execute
 m.run()
 # get outputs
-tvm_output = m.get_output(0, tvm.nd.empty(((1, 1008)), 'float32'))
+tvm_output = m.get_output(0, tvm.nd.empty(((1, 1008)), "float32"))
 
 ######################################################################
 # Process the output
@@ -170,31 +171,32 @@ predictions = tvm_output.asnumpy()
 predictions = np.squeeze(predictions)
 
 # Creates node ID --> English string lookup.
-node_lookup = tf_testing.NodeLookup(label_lookup_path=map_proto_path,
-                                    uid_lookup_path=label_path)
+node_lookup = tf_testing.NodeLookup(label_lookup_path=map_proto_path, uid_lookup_path=label_path)
 
 # Print top 5 predictions from TVM output.
 top_k = predictions.argsort()[-5:][::-1]
 for node_id in top_k:
     human_string = node_lookup.id_to_string(node_id)
     score = predictions[node_id]
-    print('%s (score = %.5f)' % (human_string, score))
+    print("%s (score = %.5f)" % (human_string, score))
 
 ######################################################################
 # Inference on tensorflow
 # -----------------------
 # Run the corresponding model on tensorflow
 
+
 def create_graph():
     """Creates a graph from saved GraphDef file and returns a saver."""
     # Creates graph from saved graph_def.pb.
-    with tf_compat_v1.gfile.GFile(model_path, 'rb') as f:
+    with tf_compat_v1.gfile.GFile(model_path, "rb") as f:
         graph_def = tf_compat_v1.GraphDef()
         graph_def.ParseFromString(f.read())
-        graph = tf.import_graph_def(graph_def, name='')
+        graph = tf.import_graph_def(graph_def, name="")
         # Call the utility to import the graph definition into default graph.
         graph_def = tf_testing.ProcessGraphDefParam(graph_def)
 
+
 def run_inference_on_image(image):
     """Runs inference on an image.
 
@@ -208,29 +210,30 @@ def run_inference_on_image(image):
         Nothing
     """
     if not tf_compat_v1.gfile.Exists(image):
-        tf.logging.fatal('File does not exist %s', image)
-    image_data = tf_compat_v1.gfile.GFile(image, 'rb').read()
+        tf.logging.fatal("File does not exist %s", image)
+    image_data = tf_compat_v1.gfile.GFile(image, "rb").read()
 
     # Creates graph from saved GraphDef.
     create_graph()
 
     with tf_compat_v1.Session() as sess:
-        softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
-        predictions = sess.run(softmax_tensor,
-                               {'DecodeJpeg/contents:0': image_data})
+        softmax_tensor = sess.graph.get_tensor_by_name("softmax:0")
+        predictions = sess.run(softmax_tensor, {"DecodeJpeg/contents:0": image_data})
 
         predictions = np.squeeze(predictions)
 
         # Creates node ID --> English string lookup.
-        node_lookup = tf_testing.NodeLookup(label_lookup_path=map_proto_path,
-                                            uid_lookup_path=label_path)
+        node_lookup = tf_testing.NodeLookup(
+            label_lookup_path=map_proto_path, uid_lookup_path=label_path
+        )
 
         # Print top 5 predictions from tensorflow.
         top_k = predictions.argsort()[-5:][::-1]
-        print ("===== TENSORFLOW RESULTS =======")
+        print("===== TENSORFLOW RESULTS =======")
         for node_id in top_k:
             human_string = node_lookup.id_to_string(node_id)
             score = predictions[node_id]
-            print('%s (score = %.5f)' % (human_string, score))
+            print("%s (score = %.5f)" % (human_string, score))
+
 
 run_inference_on_image(img_path)
index c0b2a03..ee7da62 100644 (file)
@@ -57,15 +57,17 @@ Below you can find an example on how to compile TFLite model using TVM.
 # ----------------------------------------------
 import os
 
+
 def extract(path):
     import tarfile
+
     if path.endswith("tgz") or path.endswith("gz"):
         dir_path = os.path.dirname(path)
         tar = tarfile.open(path)
         tar.extractall(path=dir_path)
         tar.close()
     else:
-        raise RuntimeError('Could not decompress the file: ' + path)
+        raise RuntimeError("Could not decompress the file: " + path)
 
 
 ######################################################################
@@ -77,7 +79,7 @@ from tvm.contrib.download import download_testdata
 model_url = "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz"
 
 # Download model tar file and extract it to get mobilenet_v1_1.0_224.tflite
-model_path = download_testdata(model_url, "mobilenet_v1_1.0_224.tgz", module=['tf', 'official'])
+model_path = download_testdata(model_url, "mobilenet_v1_1.0_224.tgz", module=["tf", "official"])
 model_dir = os.path.dirname(model_path)
 extract(model_path)
 
@@ -88,9 +90,11 @@ tflite_model_buf = open(tflite_model_file, "rb").read()
 # Get TFLite model from buffer
 try:
     import tflite
+
     tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
 except AttributeError:
     import tflite.Model
+
     tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
 
 ######################################################################
@@ -101,8 +105,8 @@ from PIL import Image
 from matplotlib import pyplot as plt
 import numpy as np
 
-image_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
-image_path = download_testdata(image_url, 'cat.png', module='data')
+image_url = "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true"
+image_path = download_testdata(image_url, "cat.png", module="data")
 resized_image = Image.open(image_path).resize((224, 224))
 plt.imshow(resized_image)
 plt.show()
@@ -116,7 +120,7 @@ image_data = np.expand_dims(image_data, axis=0)
 image_data[:, :, :, 0] = 2.0 / 255.0 * image_data[:, :, :, 0] - 1
 image_data[:, :, :, 1] = 2.0 / 255.0 * image_data[:, :, :, 1] - 1
 image_data[:, :, :, 2] = 2.0 / 255.0 * image_data[:, :, :, 2] - 1
-print('input', image_data.shape)
+print("input", image_data.shape)
 
 ######################################################################
 # Compile the model with relay
@@ -129,9 +133,10 @@ input_dtype = "float32"
 
 # Parse TFLite model and convert it to a Relay module
 from tvm import relay, transform
-mod, params = relay.frontend.from_tflite(tflite_model,
-                                         shape_dict={input_tensor: input_shape},
-                                         dtype_dict={input_tensor: input_dtype})
+
+mod, params = relay.frontend.from_tflite(
+    tflite_model, shape_dict={input_tensor: input_shape}, dtype_dict={input_tensor: input_dtype}
+)
 
 # Build the module against to x86 CPU
 target = "llvm"
@@ -146,7 +151,7 @@ from tvm import te
 from tvm.contrib import graph_runtime as runtime
 
 # Create a runtime executor module
-module = runtime.GraphModule(lib['default'](tvm.cpu()))
+module = runtime.GraphModule(lib["default"](tvm.cpu()))
 
 # Feed input data
 module.set_input(input_tensor, tvm.nd.array(image_data))
@@ -162,12 +167,16 @@ tvm_output = module.get_output(0).asnumpy()
 # ---------------
 
 # Load label file
-label_file_url = ''.join(['https://raw.githubusercontent.com/',
-                          'tensorflow/tensorflow/master/tensorflow/lite/java/demo/',
-                          'app/src/main/assets/',
-                          'labels_mobilenet_quant_v1_224.txt'])
+label_file_url = "".join(
+    [
+        "https://raw.githubusercontent.com/",
+        "tensorflow/tensorflow/master/tensorflow/lite/java/demo/",
+        "app/src/main/assets/",
+        "labels_mobilenet_quant_v1_224.txt",
+    ]
+)
 label_file = "labels_mobilenet_quant_v1_224.txt"
-label_path = download_testdata(label_file_url, label_file, module='data')
+label_path = download_testdata(label_file_url, label_file, module="data")
 
 # List of 1001 classes
 with open(label_path) as f:
index bc47023..a150b68 100644 (file)
@@ -54,7 +54,9 @@ bn_beta = relay.var("bn_beta")
 bn_mmean = relay.var("bn_mean")
 bn_mvar = relay.var("bn_var")
 
-simple_net = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3,3), channels=out_channels, padding=(1, 1))
+simple_net = relay.nn.conv2d(
+    data=data, weight=weight, kernel_size=(3, 3), channels=out_channels, padding=(1, 1)
+)
 simple_net = relay.nn.batch_norm(simple_net, bn_gamma, bn_beta, bn_mmean, bn_mvar)[0]
 simple_net = relay.nn.relu(simple_net)
 simple_net = relay.Function(relay.analysis.free_vars(simple_net), simple_net)
@@ -68,14 +70,15 @@ net, params = testing.create_workload(simple_net)
 # We build and run this network with cuda backend, as usual.
 # By setting the logging level to DEBUG, the result of Relay graph compilation will be dumped as pseudo code.
 import logging
-logging.basicConfig(level=logging.DEBUG) # to dump TVM IR after fusion
+
+logging.basicConfig(level=logging.DEBUG)  # to dump TVM IR after fusion
 
 target = "cuda"
 lib = relay.build_module.build(net, target, params=params)
 
 ctx = tvm.context(target, 0)
 data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
-module = runtime.GraphModule(lib['default'](ctx))
+module = runtime.GraphModule(lib["default"](ctx))
 module.set_input("data", data)
 module.run()
 out_shape = (batch_size, out_channels, 224, 224)
@@ -491,12 +494,12 @@ out_cuda = out.asnumpy()
 # We can use cuDNN to replace convolution kernels with cuDNN ones.
 # To do that, all we need to do is to append the option " -libs=cudnn" to the target string.
 net, params = testing.create_workload(simple_net)
-target = "cuda -libs=cudnn" # use cudnn for convolution
+target = "cuda -libs=cudnn"  # use cudnn for convolution
 lib = relay.build_module.build(net, target, params=params)
 
 ctx = tvm.context(target, 0)
 data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
-module = runtime.GraphModule(lib['default'](ctx))
+module = runtime.GraphModule(lib["default"](ctx))
 module.set_input("data", data)
 module.run()
 out_shape = (batch_size, out_channels, 224, 224)
index eaf6f03..572ebb8 100644 (file)
@@ -101,8 +101,8 @@ from tvm import rpc
 from tvm.contrib import util
 
 n = tvm.runtime.convert(1024)
-A = te.placeholder((n,), name='A')
-B = te.compute((n,), lambda i: A[i] + 1.0, name='B')
+A = te.placeholder((n,), name="A")
+B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
 s = te.create_schedule(B.op)
 
 ######################################################################
@@ -114,14 +114,14 @@ s = te.create_schedule(B.op)
 local_demo = True
 
 if local_demo:
-    target = 'llvm'
+    target = "llvm"
 else:
-    target = 'llvm -mtriple=armv7l-linux-gnueabihf'
+    target = "llvm -mtriple=armv7l-linux-gnueabihf"
 
-func = tvm.build(s, [A, B], target=target, name='add_one')
+func = tvm.build(s, [A, B], target=target, name="add_one")
 # save the lib at a local temp folder
 temp = util.tempdir()
-path = temp.relpath('lib.tar')
+path = temp.relpath("lib.tar")
 func.export_library(path)
 
 ######################################################################
@@ -168,7 +168,7 @@ if local_demo:
     remote = rpc.LocalSession()
 else:
     # The following is my environment, change this to the IP address of your target device
-    host = '10.77.1.162'
+    host = "10.77.1.162"
     port = 9090
     remote = rpc.connect(host, port)
 
@@ -177,7 +177,7 @@ else:
 # compiler to relink them. Now `func` is a remote module object.
 
 remote.upload(path)
-func = remote.load_module('lib.tar')
+func = remote.load_module("lib.tar")
 
 # create arrays on the remote device
 ctx = remote.cpu()
@@ -196,7 +196,7 @@ np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
 
 time_f = func.time_evaluator(func.entry_name, ctx, number=10)
 cost = time_f(a, b).mean
-print('%g secs/op' % cost)
+print("%g secs/op" % cost)
 
 #########################################################################
 # Run OpenCL Kernel Remotely by RPC
@@ -221,11 +221,12 @@ print('%g secs/op' % cost)
 #
 # The following function shows how we run an OpenCL kernel remotely
 
+
 def run_opencl():
     # NOTE: This is the setting for my rk3399 board. You need to modify
     # them according to your environment.
     target_host = "llvm -mtriple=aarch64-linux-gnu"
-    opencl_device_host = '10.77.1.145'
+    opencl_device_host = "10.77.1.145"
     opencl_device_port = 9090
 
     # create schedule for the above "add one" compute declaration
@@ -238,10 +239,10 @@ def run_opencl():
     remote = rpc.connect(opencl_device_host, opencl_device_port)
 
     # export and upload
-    path = temp.relpath('lib_cl.tar')
+    path = temp.relpath("lib_cl.tar")
     func.export_library(path)
     remote.upload(path)
-    func = remote.load_module('lib_cl.tar')
+    func = remote.load_module("lib_cl.tar")
 
     # run
     ctx = remote.cl()
@@ -251,6 +252,7 @@ def run_opencl():
     np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
     print("OpenCL test passed!")
 
+
 ######################################################################
 # Summary
 # -------
index e52a99a..437a22f 100644 (file)
@@ -66,7 +66,8 @@ data_shape = (batch_size,) + image_shape
 out_shape = (batch_size, num_class)
 
 mod, params = relay.testing.resnet.get_workload(
-    num_layers=18, batch_size=batch_size, image_shape=image_shape)
+    num_layers=18, batch_size=batch_size, image_shape=image_shape
+)
 
 # set show_meta_data=True if you want to show meta data
 print(mod.astext(show_meta_data=False))
index d31dc1e..76e0262 100644 (file)
@@ -35,9 +35,9 @@ import numpy as np
 
 # Global declarations of environment.
 
-tgt_host="llvm"
+tgt_host = "llvm"
 # Change it to the respective GPU if GPU is enabled, e.g. cuda, opencl, rocm
-tgt="cuda"
+tgt = "cuda"
 
 ######################################################################
 # Vector Add Example
@@ -66,8 +66,8 @@ tgt="cuda"
 # the computation should be done.
 #
 n = te.var("n")
-A = te.placeholder((n,), name='A')
-B = te.placeholder((n,), name='B')
+A = te.placeholder((n,), name="A")
+B = te.placeholder((n,), name="B")
 C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
 print(type(C))
 
@@ -116,9 +116,9 @@ bx, tx = s[C].split(C.op.axis[0], factor=64)
 # compute grid. These are GPU specific constructs that allow us
 # to generate code that runs on GPU.
 #
-if tgt == "cuda" or tgt == "rocm" or tgt.startswith('opencl'):
-  s[C].bind(bx, te.thread_axis("blockIdx.x"))
-  s[C].bind(tx, te.thread_axis("threadIdx.x"))
+if tgt == "cuda" or tgt == "rocm" or tgt.startswith("opencl"):
+    s[C].bind(bx, te.thread_axis("blockIdx.x"))
+    s[C].bind(tx, te.thread_axis("threadIdx.x"))
 
 ######################################################################
 # Compilation
@@ -171,7 +171,7 @@ tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 #
 # The following code fetches the device module and prints the content code.
 #
-if tgt == "cuda" or tgt == "rocm" or tgt.startswith('opencl'):
+if tgt == "cuda" or tgt == "rocm" or tgt.startswith("opencl"):
     dev_module = fadd.imported_modules[0]
     print("-----GPU code-----")
     print(dev_module.get_source())
@@ -217,7 +217,7 @@ if tgt == "cuda":
     fadd.imported_modules[0].save(temp.relpath("myadd.ptx"))
 if tgt == "rocm":
     fadd.imported_modules[0].save(temp.relpath("myadd.hsaco"))
-if tgt.startswith('opencl'):
+if tgt.startswith("opencl"):
     fadd.imported_modules[0].save(temp.relpath("myadd.cl"))
 cc.create_shared(temp.relpath("myadd.so"), [temp.relpath("myadd.o")])
 print(temp.listdir())
@@ -247,7 +247,7 @@ if tgt == "rocm":
     fadd1_dev = tvm.runtime.load_module(temp.relpath("myadd.hsaco"))
     fadd1.import_module(fadd1_dev)
 
-if tgt.startswith('opencl'):
+if tgt.startswith("opencl"):
     fadd1_dev = tvm.runtime.load_module(temp.relpath("myadd.cl"))
     fadd1.import_module(fadd1_dev)
 
@@ -289,7 +289,7 @@ tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 # The following code blocks generate OpenCL code, creates array on an OpenCL
 # device, and verifies the correctness of the code.
 #
-if tgt.startswith('opencl'):
+if tgt.startswith("opencl"):
     fadd_cl = tvm.build(s, [A, B, C], tgt, name="myadd")
     print("------opencl code------")
     print(fadd_cl.imported_modules[0].get_source())
index c20339d..454237a 100644 (file)
@@ -57,14 +57,18 @@ if not tvm.get_global_func("tvm.contrib.cblas.matmul", allow_missing=True):
 n = 1024
 l = 128
 m = 235
-bias = te.var('bias', dtype="float32")
-A = te.placeholder((n, l), name='A')
-B = te.placeholder((l, m), name='B')
-C = te.extern((n, m), [A, B],
-               lambda ins, outs: tvm.tir.call_packed(
-                   "tvm.contrib.cblas.matmul",
-                   ins[0], ins[1], outs[0], False, False), name="C")
-D = te.compute(C.shape, lambda i, j: C[i,j] + bias, name="D")
+bias = te.var("bias", dtype="float32")
+A = te.placeholder((n, l), name="A")
+B = te.placeholder((l, m), name="B")
+C = te.extern(
+    (n, m),
+    [A, B],
+    lambda ins, outs: tvm.tir.call_packed(
+        "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], False, False
+    ),
+    name="C",
+)
+D = te.compute(C.shape, lambda i, j: C[i, j] + bias, name="D")
 s = te.create_schedule(D.op)
 
 ######################################################################
@@ -79,8 +83,7 @@ b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
 d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
 bb = 10.0
 f(a, b, d, bb)
-tvm.testing.assert_allclose(
-    d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + 10, rtol=1e-5)
+tvm.testing.assert_allclose(d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + 10, rtol=1e-5)
 
 ######################################################################
 # Extern Contrib Wrappers
@@ -89,8 +92,9 @@ tvm.testing.assert_allclose(
 # the following line is equivalent to the previous example.
 #
 from tvm.contrib import cblas
+
 C = cblas.matmul(A, B)
-D = te.compute(C.shape, lambda i, j: C[i,j] + bias, name="D")
+D = te.compute(C.shape, lambda i, j: C[i, j] + bias, name="D")
 s = te.create_schedule(D.op)
 
 ######################################################################
@@ -110,9 +114,14 @@ def my_tvm_addone(x, y):
     print("my_tvm_addone signatures: %s, %s" % (type(x), type(y)))
     tvm.nd.array(x.asnumpy() + 1).copyto(y)
 
-A = te.placeholder((n,), name='A')
-B = te.extern(A.shape, [A], lambda ins, outs: tvm.tir.call_packed(
-    "tvm.contrib.my_tvm_addone", ins[0], outs[0]), name="C")
+
+A = te.placeholder((n,), name="A")
+B = te.extern(
+    A.shape,
+    [A],
+    lambda ins, outs: tvm.tir.call_packed("tvm.contrib.my_tvm_addone", ins[0], outs[0]),
+    name="C",
+)
 s = te.create_schedule(B.op)
 f = tvm.build(s, [A, B], "llvm")
 a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
index 4a4ff96..1453225 100644 (file)
@@ -43,10 +43,8 @@ import numpy as np
 # :code:`__expf` function, which is only available under CUDA.
 #
 n = te.var("n")
-A = te.placeholder((n,), name='A')
-B = te.compute(A.shape,
-                lambda i: tvm.tir.call_pure_extern("float32", "__expf", A[i]),
-                name="B")
+A = te.placeholder((n,), name="A")
+B = te.compute(A.shape, lambda i: tvm.tir.call_pure_extern("float32", "__expf", A[i]), name="B")
 s = te.create_schedule(B.op)
 num_thread = 64
 bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
@@ -69,7 +67,7 @@ print(f.imported_modules[0].get_source())
 # :py:func:`tvm.te.exp` to do the exponential.
 #
 n = te.var("n")
-A = te.placeholder((n,), name='A')
+A = te.placeholder((n,), name="A")
 B = te.compute(A.shape, lambda i: te.exp(A[i]), name="B")
 s = te.create_schedule(B.op)
 num_thread = 64
@@ -147,12 +145,13 @@ def my_cuda_mylog_rule(op):
     else:
         return op
 
+
 # new op registration is triggered by registering an attribute of the op
 tvm.ir.register_op_attr("tir.mylog", "TCallEffectKind", tvm.tir.CallEffectKind.Pure)
 tvm.target.register_intrin_rule("cuda", "mylog", my_cuda_mylog_rule, override=True)
 
 n = te.var("n")
-A = te.placeholder((n,), name='A')
+A = te.placeholder((n,), name="A")
 B = te.compute(A.shape, lambda i: mylog(A[i]), name="B")
 s = te.create_schedule(B.op)
 num_thread = 64
index cdfc94e..ecefc28 100644 (file)
@@ -56,7 +56,7 @@ import numpy as np
 #
 n = te.var("n")
 m = te.var("m")
-A = te.placeholder((n, m), name='A')
+A = te.placeholder((n, m), name="A")
 k = te.reduce_axis((0, m), "k")
 B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
 
@@ -136,12 +136,11 @@ print(fcuda.imported_modules[0].get_source())
 # Verify the correctness of the result kernel by comparing it to numpy.
 #
 nn = 128
-ctx  = tvm.gpu(0)
+ctx = tvm.gpu(0)
 a = tvm.nd.array(np.random.uniform(size=(nn, nn)).astype(A.dtype), ctx)
 b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx)
 fcuda(a, b)
-tvm.testing.assert_allclose(
-    b.asnumpy(),  np.sum(a.asnumpy(), axis=1), rtol=1e-4)
+tvm.testing.assert_allclose(b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)
 
 ######################################################################
 # Describe Convolution via 2D Reduction
@@ -149,15 +148,16 @@ tvm.testing.assert_allclose(
 # In TVM, we can describe convolution via 2D reduction in a simple way.
 # Here is an example for 2D convolution with filter size = [3, 3] and strides = [1, 1].
 #
-n = te.var('n')
-Input = te.placeholder((n, n), name='Input')
-Filter = te.placeholder((3, 3), name='Filter')
-di = te.reduce_axis((0, 3), name='di')
-dj = te.reduce_axis((0, 3), name='dj')
+n = te.var("n")
+Input = te.placeholder((n, n), name="Input")
+Filter = te.placeholder((3, 3), name="Filter")
+di = te.reduce_axis((0, 3), name="di")
+dj = te.reduce_axis((0, 3), name="dj")
 Output = te.compute(
     (n - 2, n - 2),
     lambda i, j: te.sum(Input[i + di, j + dj] * Filter[di, dj], axis=[di, dj]),
-    name='Output')
+    name="Output",
+)
 s = te.create_schedule(Output.op)
 print(tvm.lower(s, [Input, Filter, Output], simple_mode=True))
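The 2D-reduction convolution declared above is a plain valid cross-correlation; a minimal NumPy sketch of the reference result, with illustrative sizes, for readers who want to check the te.compute definition:

import numpy as np

n = 8  # illustrative input size
inp = np.random.rand(n, n).astype("float32")
filt = np.random.rand(3, 3).astype("float32")
ref = np.zeros((n - 2, n - 2), dtype="float32")
for i in range(n - 2):
    for j in range(n - 2):
        # Output[i, j] = sum over di, dj of Input[i + di, j + dj] * Filter[di, dj]
        ref[i, j] = np.sum(inp[i : i + 3, j : j + 3] * filt)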
 
@@ -171,13 +171,12 @@ print(tvm.lower(s, [Input, Filter, Output], simple_mode=True))
 # commutative reduction operation by :any:`te.comm_reducer`.
 #
 
-n = te.var('n')
-m = te.var('m')
-product = te.comm_reducer(lambda x, y: x*y,
-    lambda t: tvm.tir.const(1, dtype=t), name="product")
-A = te.placeholder((n, m), name='A')
-k = te.reduce_axis((0, m), name='k')
-B = te.compute((n,), lambda i: product(A[i, k], axis=k), name='B')
+n = te.var("n")
+m = te.var("m")
+product = te.comm_reducer(lambda x, y: x * y, lambda t: tvm.tir.const(1, dtype=t), name="product")
+A = te.placeholder((n, m), name="A")
+k = te.reduce_axis((0, m), name="k")
+B = te.compute((n,), lambda i: product(A[i, k], axis=k), name="B")
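The product reducer defined here multiplies the elements along the reduce axis, so each B[i] is the row product of A; a minimal NumPy sketch of the expected values (array contents are illustrative):

import numpy as np

a = np.random.rand(4, 6).astype("float32")
ref = np.prod(a, axis=1)  # B[i] = product of A[i, k] over the reduce axis k
print(ref.shape)  # (4,)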
 
 ######################################################################
 # .. note::
index 73790da..fdb6ec9 100644 (file)
@@ -52,7 +52,7 @@ n = te.var("n")
 X = te.placeholder((m, n), name="X")
 s_state = te.placeholder((m, n))
 s_init = te.compute((1, n), lambda _, i: X[0, i])
-s_update = te.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
+s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
 s_scan = tvm.te.scan(s_init, s_update, s_state, inputs=[X])
 
 ######################################################################
@@ -106,7 +106,7 @@ n = te.var("n")
 X = te.placeholder((m, n), name="X")
 s_state = te.placeholder((m, n))
 s_init = te.compute((1, n), lambda _, i: X[0, i])
-s_update_s1 = te.compute((m, n), lambda t, i: s_state[t-1, i] * 2, name="s1")
+s_update_s1 = te.compute((m, n), lambda t, i: s_state[t - 1, i] * 2, name="s1")
 s_update_s2 = te.compute((m, n), lambda t, i: s_update_s1[t, i] + X[t, i], name="s2")
 s_scan = tvm.te.scan(s_init, s_update_s2, s_state, inputs=[X])
 
@@ -135,11 +135,11 @@ s_state1 = te.placeholder((m, n))
 s_state2 = te.placeholder((m, l))
 s_init1 = te.compute((1, n), lambda _, i: X[0, i])
 s_init2 = te.compute((1, l), lambda _, i: 0.0)
-s_update1 = te.compute((m, n), lambda t, i: s_state1[t-1, i] + X[t, i])
-s_update2 = te.compute((m, l), lambda t, i: s_state2[t-1, i] + s_state1[t-1, 0])
-s_scan1, s_scan2 = tvm.te.scan([s_init1, s_init2],
-                            [s_update1, s_update2],
-                            [s_state1, s_state2], inputs=[X])
+s_update1 = te.compute((m, n), lambda t, i: s_state1[t - 1, i] + X[t, i])
+s_update2 = te.compute((m, l), lambda t, i: s_state2[t - 1, i] + s_state1[t - 1, 0])
+s_scan1, s_scan2 = tvm.te.scan(
+    [s_init1, s_init2], [s_update1, s_update2], [s_state1, s_state2], inputs=[X]
+)
 s = te.create_schedule(s_scan1.op)
 print(tvm.lower(s, [X, s_scan1, s_scan2], simple_mode=True))
 
index 61bfcad..eb48dc2 100644 (file)
@@ -42,17 +42,17 @@ import numpy as np
 #
 
 # declare some variables for use later
-n = te.var('n')
-m = te.var('m')
+n = te.var("n")
+m = te.var("m")
 
 ######################################################################
 # A schedule can be created from a list of ops. By default the
 # schedule computes tensors in a serial manner in row-major order.
 
 # declare a matrix element-wise multiply
-A = te.placeholder((m, n), name='A')
-B = te.placeholder((m, n), name='B')
-C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name='C')
+A = te.placeholder((m, n), name="A")
+B = te.placeholder((m, n), name="B")
+C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name="C")
 
 s = te.create_schedule([C.op])
 # lower will transform the computation from definition to the real
@@ -71,8 +71,8 @@ print(tvm.lower(s, [A, B, C], simple_mode=True))
 # -----
 # :code:`split` can split a specified axis into two axes by
 # :code:`factor`.
-A = te.placeholder((m,), name='A')
-B = te.compute((m,), lambda i: A[i]*2, name='B')
+A = te.placeholder((m,), name="A")
+B = te.compute((m,), lambda i: A[i] * 2, name="B")
 
 s = te.create_schedule(B.op)
 xo, xi = s[B].split(B.op.axis[0], factor=32)
@@ -81,8 +81,8 @@ print(tvm.lower(s, [A, B], simple_mode=True))
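As a rough mental model for the factor-32 split above, the lowered code iterates an outer/inner pair whose product covers the original axis; a hand-written pure-Python sketch of the resulting loop nest (assuming the extent is a multiple of 32, tail handling elided):

m = 128  # illustrative extent
A_vals = [float(i) for i in range(m)]
B_vals = [0.0] * m
for i_outer in range(m // 32):
    for i_inner in range(32):
        i = i_outer * 32 + i_inner  # original index recovered from the split pair
        B_vals[i] = A_vals[i] * 2.0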
 ######################################################################
 # You can also split an axis by :code:`nparts`, which splits the axis
 # contrary to :code:`factor` (it fixes the outer extent rather than the inner one).
-A = te.placeholder((m,), name='A')
-B = te.compute((m,), lambda i: A[i], name='B')
+A = te.placeholder((m,), name="A")
+B = te.compute((m,), lambda i: A[i], name="B")
 
 s = te.create_schedule(B.op)
 bx, tx = s[B].split(B.op.axis[0], nparts=32)
@@ -93,8 +93,8 @@ print(tvm.lower(s, [A, B], simple_mode=True))
 # ----
 # :code:`tile` helps you execute the computation tile by tile over two
 # axes.
-A = te.placeholder((m, n), name='A')
-B = te.compute((m, n), lambda i, j: A[i, j], name='B')
+A = te.placeholder((m, n), name="A")
+B = te.compute((m, n), lambda i, j: A[i, j], name="B")
 
 s = te.create_schedule(B.op)
 xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
@@ -104,8 +104,8 @@ print(tvm.lower(s, [A, B], simple_mode=True))
 # fuse
 # ----
 # :code:`fuse` can fuse two consecutive axes of one computation.
-A = te.placeholder((m, n), name='A')
-B = te.compute((m, n), lambda i, j: A[i, j], name='B')
+A = te.placeholder((m, n), name="A")
+B = te.compute((m, n), lambda i, j: A[i, j], name="B")
 
 s = te.create_schedule(B.op)
 # tile to four axises first: (i.outer, j.outer, i.inner, j.inner)
@@ -118,8 +118,8 @@ print(tvm.lower(s, [A, B], simple_mode=True))
 # reorder
 # -------
 # :code:`reorder` can reorder the axes in the specified order.
-A = te.placeholder((m, n), name='A')
-B = te.compute((m, n), lambda i, j: A[i, j], name='B')
+A = te.placeholder((m, n), name="A")
+B = te.compute((m, n), lambda i, j: A[i, j], name="B")
 
 s = te.create_schedule(B.op)
 # tile to four axises first: (i.outer, j.outer, i.inner, j.inner)
@@ -133,8 +133,8 @@ print(tvm.lower(s, [A, B], simple_mode=True))
 # ----
 # :code:`bind` can bind a specified axis with a thread axis, often used
 # in gpu programming.
-A = te.placeholder((n,), name='A')
-B = te.compute(A.shape, lambda i: A[i] * 2, name='B')
+A = te.placeholder((n,), name="A")
+B = te.compute(A.shape, lambda i: A[i] * 2, name="B")
 
 s = te.create_schedule(B.op)
 bx, tx = s[B].split(B.op.axis[0], factor=64)
@@ -147,9 +147,9 @@ print(tvm.lower(s, [A, B], simple_mode=True))
 # ----------
 # For a schedule that consists of multiple operators, TVM will compute
 # tensors at the root separately by default.
-A = te.placeholder((m,), name='A')
-B = te.compute((m,), lambda i: A[i]+1, name='B')
-C = te.compute((m,), lambda i: B[i]*2, name='C')
+A = te.placeholder((m,), name="A")
+B = te.compute((m,), lambda i: A[i] + 1, name="B")
+C = te.compute((m,), lambda i: B[i] * 2, name="C")
 
 s = te.create_schedule(C.op)
 print(tvm.lower(s, [A, B, C], simple_mode=True))
@@ -157,9 +157,9 @@ print(tvm.lower(s, [A, B, C], simple_mode=True))
 ######################################################################
 # :code:`compute_at` can move computation of `B` into the first axis
 # of computation of `C`.
-A = te.placeholder((m,), name='A')
-B = te.compute((m,), lambda i: A[i]+1, name='B')
-C = te.compute((m,), lambda i: B[i]*2, name='C')
+A = te.placeholder((m,), name="A")
+B = te.compute((m,), lambda i: A[i] + 1, name="B")
+C = te.compute((m,), lambda i: B[i] * 2, name="C")
 
 s = te.create_schedule(C.op)
 s[B].compute_at(s[C], C.op.axis[0])
@@ -171,9 +171,9 @@ print(tvm.lower(s, [A, B, C], simple_mode=True))
 # :code:`compute_inline` can mark one stage as inline, then the body of
 # computation will be expanded and inserted at the address where the
 # tensor is required.
-A = te.placeholder((m,), name='A')
-B = te.compute((m,), lambda i: A[i]+1, name='B')
-C = te.compute((m,), lambda i: B[i]*2, name='C')
+A = te.placeholder((m,), name="A")
+B = te.compute((m,), lambda i: A[i] + 1, name="B")
+C = te.compute((m,), lambda i: B[i] * 2, name="C")
 
 s = te.create_schedule(C.op)
 s[B].compute_inline()
@@ -183,9 +183,9 @@ print(tvm.lower(s, [A, B, C], simple_mode=True))
 # compute_root
 # ------------
 # :code:`compute_root` can move computation of one stage to the root.
-A = te.placeholder((m,), name='A')
-B = te.compute((m,), lambda i: A[i]+1, name='B')
-C = te.compute((m,), lambda i: B[i]*2, name='C')
+A = te.placeholder((m,), name="A")
+B = te.compute((m,), lambda i: A[i] + 1, name="B")
+C = te.compute((m,), lambda i: B[i] * 2, name="C")
 
 s = te.create_schedule(C.op)
 s[B].compute_at(s[C], C.op.axis[0])
index 6d22037..e0b8038 100644 (file)
@@ -56,11 +56,11 @@ num_filter = 256
 kernel = 3
 stride = 1
 padding = "SAME"
-dilation=1
+dilation = 1
 
-A = te.placeholder((in_size, in_size, in_channel, batch), name='A')
-W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W')
-B = te.placeholder((1, num_filter, 1), name='bias')
+A = te.placeholder((in_size, in_size, in_channel, batch), name="A")
+W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W")
+B = te.placeholder((1, num_filter, 1), name="bias")
 
 with tvm.target.Target("llvm"):
     t_conv = topi.nn.conv2d_hwcn(A, W, stride, padding, dilation)
@@ -77,8 +77,8 @@ with tvm.target.Target("llvm"):
 # to render SVG figures showing in notebook directly.
 #
 
-tedd.viz_dataflow_graph(s, dot_file_path = '/tmp/dfg.dot')
-#tedd.viz_dataflow_graph(s, show_svg = True)
+tedd.viz_dataflow_graph(s, dot_file_path="/tmp/dfg.dot")
+# tedd.viz_dataflow_graph(s, show_svg = True)
 
 ######################################################################
 # .. image:: https://github.com/dmlc/web-data/raw/master/tvm/tutorial/tedd_dfg.png
@@ -89,8 +89,8 @@ tedd.viz_dataflow_graph(s, dot_file_path = '/tmp/dfg.dot')
 # Edges show nodes' dependency.
 #
 
-tedd.viz_schedule_tree(s, dot_file_path = '/tmp/scheduletree.dot')
-#tedd.viz_schedule_tree(s, show_svg = True)
+tedd.viz_schedule_tree(s, dot_file_path="/tmp/scheduletree.dot")
+# tedd.viz_schedule_tree(s, show_svg = True)
 
 ######################################################################
 # We just rendered the schedule tree graph.  You may notice a warning about ranges not
@@ -101,8 +101,8 @@ tedd.viz_schedule_tree(s, dot_file_path = '/tmp/scheduletree.dot')
 #
 
 s = s.normalize()
-tedd.viz_schedule_tree(s, dot_file_path = '/tmp/scheduletree2.dot')
-#tedd.viz_schedule_tree(s, show_svg = True)
+tedd.viz_schedule_tree(s, dot_file_path="/tmp/scheduletree2.dot")
+# tedd.viz_schedule_tree(s, show_svg = True)
 
 ######################################################################
 # .. image:: https://github.com/dmlc/web-data/raw/master/tvm/tutorial/tedd_st.png
@@ -134,8 +134,8 @@ tedd.viz_schedule_tree(s, dot_file_path = '/tmp/scheduletree2.dot')
 #   omitted, making every stage a block, for better readability.
 #
 
-tedd.viz_itervar_relationship_graph(s, dot_file_path = '/tmp/itervar.dot')
-#tedd.viz_itervar_relationship_graph(s, show_svg = True)
+tedd.viz_itervar_relationship_graph(s, dot_file_path="/tmp/itervar.dot")
+# tedd.viz_itervar_relationship_graph(s, show_svg = True)
 
 ######################################################################
 # .. image:: https://github.com/dmlc/web-data/raw/master/tvm/tutorial/tedd_itervar_rel.png
index ac5b50f..601adb8 100644 (file)
@@ -47,11 +47,10 @@ import numpy as np
 # The following lines describe the computation :code:`A * B^T` in TVM.
 #
 N, M, L = 1024, 512, 64
-A = te.placeholder((N, L), name='A')
-B = te.placeholder((M, L), name='B')
-k = te.reduce_axis((0, L), name='k')
-C = te.compute((N, M), lambda i, j:
-                te.sum(A[i, k] * B[j, k], axis=k), name='C')
+A = te.placeholder((N, L), name="A")
+B = te.placeholder((M, L), name="B")
+k = te.reduce_axis((0, L), name="k")
+C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[j, k], axis=k), name="C")
 s = te.create_schedule(C.op)
 print(tvm.lower(s, [A, B, C], simple_mode=True))
 
@@ -66,7 +65,7 @@ print(tvm.lower(s, [A, B, C], simple_mode=True))
 #
 factor = 16
 x, y = C.op.axis
-z, = C.op.reduce_axis
+(z,) = C.op.reduce_axis
 yo, yi = s[C].split(y, factor=factor)
 s[C].reorder(x, yo, yi, z)
 print(tvm.lower(s, [A, B, C], simple_mode=True))
@@ -89,34 +88,35 @@ print(tvm.lower(s, [A, B, C], simple_mode=True))
 # which is done in :code:`intrin_func` below.
 #
 def intrin_gemv(m, l):
-    a = te.placeholder((l,), name='a')
-    b = te.placeholder((m, l), name='b')
-    k = te.reduce_axis((0, l), name='k')
-    c = te.compute((m,), lambda i: te.sum(a[k] * b[i, k], axis=k), name='c')
-    Ab = tvm.tir.decl_buffer(a.shape, a.dtype,
-                         name="A",
-                         offset_factor=1,
-                         strides=[1])
-    Bb = tvm.tir.decl_buffer(b.shape, b.dtype,
-                         name="B",
-                         offset_factor=1,
-                         strides=[te.var("s1"), 1])
-    Cb = tvm.tir.decl_buffer(c.shape, c.dtype,
-                         name="C",
-                         offset_factor=1,
-                         strides=[1])
+    a = te.placeholder((l,), name="a")
+    b = te.placeholder((m, l), name="b")
+    k = te.reduce_axis((0, l), name="k")
+    c = te.compute((m,), lambda i: te.sum(a[k] * b[i, k], axis=k), name="c")
+    Ab = tvm.tir.decl_buffer(a.shape, a.dtype, name="A", offset_factor=1, strides=[1])
+    Bb = tvm.tir.decl_buffer(b.shape, b.dtype, name="B", offset_factor=1, strides=[te.var("s1"), 1])
+    Cb = tvm.tir.decl_buffer(c.shape, c.dtype, name="C", offset_factor=1, strides=[1])
+
     def intrin_func(ins, outs):
         ib = tvm.tir.ir_builder.create()
         aa, bb = ins
         cc = outs[0]
-        ib.emit(tvm.tir.call_extern("int32", "gemv_update",
-                                cc.access_ptr("w"),
-                                aa.access_ptr("r"),
-                                bb.access_ptr("r"),
-                                m, l, bb.strides[0]))
+        ib.emit(
+            tvm.tir.call_extern(
+                "int32",
+                "gemv_update",
+                cc.access_ptr("w"),
+                aa.access_ptr("r"),
+                bb.access_ptr("r"),
+                m,
+                l,
+                bb.strides[0],
+            )
+        )
         return ib.get()
+
     return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb})
 
+
 ######################################################################
 # Here :code:`te.decl_tensor_intrin` declares how to execute the computation :code:`c.op`.
 # Our implementation simply takes the inputs and outputs,
@@ -161,12 +161,14 @@ def gemv_impl():
       }
     """
     from tvm.contrib import util, clang
+
     temp = util.tempdir()
     ll_path = temp.relpath("temp.ll")
     # Create LLVM ir from c source code
     ll_code = clang.create_llvm(cc_code, output=ll_path)
     return ll_code
 
+
 ######################################################################
 # Now we leverage the pragma attribute :code:`import_llvm` to import llvm asm inline.
 # The import needs to happen before the tensorized GEMV is executed.
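Hedged sketch of how the pieces fit together (the scheduling lines themselves sit outside this hunk, so the axis names below are assumptions based on the split earlier in the file):

gemv = intrin_gemv(factor, L)                # the tensor intrinsic declared above
s[C].tensorize(yi, gemv)                     # replace the inner tile with the intrinsic call
s[C].pragma(x, "import_llvm", gemv_impl())   # inline the LLVM IR of gemv_update before it runs
print(tvm.lower(s, [A, B, C], simple_mode=True))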
@@ -181,6 +183,7 @@ print(tvm.lower(s, [A, B, C], simple_mode=True))
 func = tvm.build(s, [A, B, C], target="llvm", name="gemv")
 
 from tvm.topi.util import get_const_tuple
+
 dtype = A.dtype
 ctx = tvm.context("cpu", 0)
 a = np.random.uniform(size=get_const_tuple(A.shape)).astype(dtype)
@@ -226,50 +229,56 @@ def gemv_impl():
       }
     """
     from tvm.contrib import util, clang
+
     temp = util.tempdir()
     ll_path = temp.relpath("temp.ll")
     # Create LLVM ir from c source code
     ll_code = clang.create_llvm(cc_code, output=ll_path)
     return ll_code
 
+
 def intrin_gemv(m, l):
-    a = te.placeholder((l,), name='a')
-    b = te.placeholder((m, l), name='b')
-    k = te.reduce_axis((0, l), name='k')
-    c = te.compute((m,), lambda i:
-    te.sum(a[k] * b[i, k], axis=k), name='c')
-    Ab = tvm.tir.decl_buffer(a.shape, a.dtype,
-                         name="A",
-                         offset_factor=1,
-                         strides=[1])
-    Bb = tvm.tir.decl_buffer(b.shape, b.dtype,
-                         name="B",
-                         offset_factor=1,
-                         strides=[te.var("s1"), 1])
-    Cb = tvm.tir.decl_buffer(c.shape, c.dtype,
-                         name="C",
-                         offset_factor=1,
-                         strides=[1])
+    a = te.placeholder((l,), name="a")
+    b = te.placeholder((m, l), name="b")
+    k = te.reduce_axis((0, l), name="k")
+    c = te.compute((m,), lambda i: te.sum(a[k] * b[i, k], axis=k), name="c")
+    Ab = tvm.tir.decl_buffer(a.shape, a.dtype, name="A", offset_factor=1, strides=[1])
+    Bb = tvm.tir.decl_buffer(b.shape, b.dtype, name="B", offset_factor=1, strides=[te.var("s1"), 1])
+    Cb = tvm.tir.decl_buffer(c.shape, c.dtype, name="C", offset_factor=1, strides=[1])
+
     def intrin_func(ins, outs):
         aa, bb = ins
         cc = outs[0]
+
         def _body():
             ib = tvm.tir.ir_builder.create()
-            ib.emit(tvm.tir.call_extern("int32", "gemv_update",
-                                    cc.access_ptr("w"),
-                                    aa.access_ptr("r"),
-                                    bb.access_ptr("r"),
-                                    m, l, bb.strides[0]))
+            ib.emit(
+                tvm.tir.call_extern(
+                    "int32",
+                    "gemv_update",
+                    cc.access_ptr("w"),
+                    aa.access_ptr("r"),
+                    bb.access_ptr("r"),
+                    m,
+                    l,
+                    bb.strides[0],
+                )
+            )
             return ib.get()
+
         def _reduce_reset():
             ib = tvm.tir.ir_builder.create()
             ib.emit(tvm.tir.call_extern("int32", "gemv_reset", cc.access_ptr("w"), m))
             return ib.get()
+
         def _reduce_update():
             return _body()
+
         return _body(), _reduce_reset(), _reduce_update()
+
     return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb})
 
+
 ######################################################################
 # Note that :code:`intrin_func` now returns a triplet:
 # :code:`(body, reduce_reset, reduce_update)`.
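Hedged sketch (an assumption, since the scheduling code for this variant lies outside the hunk): once the reduction axis is split, :code:`tensorize` can rely on the reset/update pair instead of the monolithic body:

zo, zi = s[C].split(z, factor=factor)            # split the reduction axis
s[C].reorder(x, yo, zo, yi, zi)
s[C].tensorize(yi, intrin_gemv(factor, factor))  # body / reduce_reset / reduce_update chosen per loop structure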
index 828797a..73db7b9 100644 (file)
@@ -40,9 +40,9 @@ import numpy as np
 #
 n = te.var("n")
 m = te.var("m")
-A0 = te.placeholder((m, n), name='A0')
-A1 = te.placeholder((m, n), name='A1')
-B0, B1 = te.compute((m, n), lambda i, j: (A0[i, j] + 2, A1[i, j] * 3), name='B')
+A0 = te.placeholder((m, n), name="A0")
+A1 = te.placeholder((m, n), name="A1")
+B0, B1 = te.compute((m, n), lambda i, j: (A0[i, j] + 2, A1[i, j] * 3), name="B")
 
 # The generated IR code would be:
 s = te.create_schedule(B0.op)
@@ -66,20 +66,22 @@ def fcombine(x, y):
     rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
     return lhs, rhs
 
+
 # our identity element also needs to be a tuple, so `fidentity` accepts
 # two types as inputs.
 def fidentity(t0, t1):
     return tvm.tir.const(-1, t0), tvm.te.min_value(t1)
 
-argmax = te.comm_reducer(fcombine, fidentity, name='argmax')
+
+argmax = te.comm_reducer(fcombine, fidentity, name="argmax")
 
 # describe the reduction computation
-m = te.var('m')
-n = te.var('n')
-idx = te.placeholder((m, n), name='idx', dtype='int32')
-val = te.placeholder((m, n), name='val', dtype='int32')
-k = te.reduce_axis((0, n), 'k')
-T0, T1 = te.compute((m, ), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name='T')
+m = te.var("m")
+n = te.var("n")
+idx = te.placeholder((m, n), name="idx", dtype="int32")
+val = te.placeholder((m, n), name="val", dtype="int32")
+k = te.reduce_axis((0, n), "k")
+T0, T1 = te.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="T")
 
 # the generated IR code would be:
 s = te.create_schedule(T0.op)
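A minimal usage sketch, not present in the diff: both tuple outputs can be compiled together and compared against :code:`numpy.argmax`, assuming an llvm target is available:

func = tvm.build(s, [idx, val, T0, T1], target="llvm")  # T0 holds the argmax index, T1 the max value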
@@ -100,10 +102,10 @@ print(tvm.lower(s, [idx, val, T0, T1], simple_mode=True))
 
 n = te.var("n")
 m = te.var("m")
-A0 = te.placeholder((m, n), name='A0')
-B0, B1 = te.compute((m, n), lambda i, j: (A0[i, j] + 2, A0[i, j] * 3), name='B')
-A1 = te.placeholder((m, n), name='A1')
-C = te.compute((m, n), lambda i, j: A1[i, j] + B0[i, j], name='C')
+A0 = te.placeholder((m, n), name="A0")
+B0, B1 = te.compute((m, n), lambda i, j: (A0[i, j] + 2, A0[i, j] * 3), name="B")
+A1 = te.placeholder((m, n), name="A1")
+C = te.compute((m, n), lambda i, j: A1[i, j] + B0[i, j], name="C")
 
 s = te.create_schedule(C.op)
 s[B0].compute_at(s[C], C.op.axis[0])
index 0b0af57..ce30c0a 100644 (file)
@@ -108,9 +108,9 @@ from tvm import relay
 # Load the pretrained TFLite model from a file in your current
 # directory into a buffer
 
-model_url = 'https://people.linaro.org/~tom.gall/sine_model.tflite'
-model_file = 'sine_model.tflite'
-model_path = download_testdata(model_url, model_file, module='data')
+model_url = "https://people.linaro.org/~tom.gall/sine_model.tflite"
+model_file = "sine_model.tflite"
+model_path = download_testdata(model_url, model_file, module="data")
 
 tflite_model_buf = open(model_path, "rb").read()
 
@@ -118,15 +118,17 @@ tflite_model_buf = open(model_path, "rb").read()
 # Using the buffer, transform into a tflite model python object
 try:
     import tflite
+
     tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
 except AttributeError:
     import tflite.Model
+
     tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
 
 ######################################################################
 # Print out the version of the model
 version = tflite_model.Version()
-print ("Model Version: " + str(version))
+print("Model Version: " + str(version))
 
 ######################################################################
 # Parse the python model object to convert it into a relay module
@@ -137,14 +139,14 @@ print ("Model Version: " + str(version))
 # If you are unsure what that might be, this can be discovered by using
 # the visualize.py script within the Tensorflow project.
 # See : How do I inspect a .tflite file? `<https://www.tensorflow.org/lite/guide/faq>`_
+
 input_tensor = "dense_4_input"
 input_shape = (1,)
 input_dtype = "float32"
 
-mod, params = relay.frontend.from_tflite(tflite_model,
-                                         shape_dict={input_tensor: input_shape},
-                                         dtype_dict={input_tensor: input_dtype})
+mod, params = relay.frontend.from_tflite(
+    tflite_model, shape_dict={input_tensor: input_shape}, dtype_dict={input_tensor: input_dtype}
+)
 
 # %%
 # Running on device
@@ -152,7 +154,7 @@ mod, params = relay.frontend.from_tflite(tflite_model,
 #
 # Setup the device config which is what will be used to communicate
 # with the microcontroller (a STM32F746 Discovery board)
-TARGET = 'c --system-lib  --runtime=c'
+TARGET = "c --system-lib  --runtime=c"
 dev_config = micro.device.arm.stm32f746xx.generate_config("127.0.0.1", 6666)
 
 ######################################################################
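Hedged sketch of the step that presumably follows (the exact pass configuration is not visible in this hunk): compile the Relay module for the C-runtime TARGET above, with vectorization disabled for the C backend:

with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
    graph, c_mod, c_params = relay.build(mod, target=TARGET, params=params)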
index 025e53e..f50d302 100644 (file)
@@ -54,28 +54,31 @@ pad = 1
 stride = 1
 
 # Algorithm
-A = te.placeholder((in_size, in_size, in_channel, batch), name='A')
-W = te.placeholder((kernel, kernel, in_channel, out_channel), name='W')
-out_size = (in_size - kernel + 2*pad) // stride + 1
+A = te.placeholder((in_size, in_size, in_channel, batch), name="A")
+W = te.placeholder((kernel, kernel, in_channel, out_channel), name="W")
+out_size = (in_size - kernel + 2 * pad) // stride + 1
 # Pad input
 Apad = te.compute(
-    (in_size + 2*pad, in_size + 2*pad, in_channel, batch),
+    (in_size + 2 * pad, in_size + 2 * pad, in_channel, batch),
     lambda yy, xx, cc, nn: tvm.tir.if_then_else(
-        tvm.tir.all(yy >= pad, yy - pad < in_size,
-                xx >= pad, xx - pad < in_size),
-        A[yy - pad, xx - pad, cc, nn], tvm.tir.const(0., "float32")),
-    name='Apad')
+        tvm.tir.all(yy >= pad, yy - pad < in_size, xx >= pad, xx - pad < in_size),
+        A[yy - pad, xx - pad, cc, nn],
+        tvm.tir.const(0.0, "float32"),
+    ),
+    name="Apad",
+)
 # Create reduction variables
-rc = te.reduce_axis((0, in_channel), name='rc')
-ry = te.reduce_axis((0, kernel), name='ry')
-rx = te.reduce_axis((0, kernel), name='rx')
+rc = te.reduce_axis((0, in_channel), name="rc")
+ry = te.reduce_axis((0, kernel), name="ry")
+rx = te.reduce_axis((0, kernel), name="rx")
 # Compute the convolution
 B = te.compute(
     (out_size, out_size, out_channel, batch),
     lambda yy, xx, ff, nn: te.sum(
-        Apad[yy * stride + ry, xx * stride + rx, rc, nn] * W[ry, rx, rc, ff],
-        axis=[ry, rx, rc]),
-    name='B')
+        Apad[yy * stride + ry, xx * stride + rx, rc, nn] * W[ry, rx, rc, ff], axis=[ry, rx, rc]
+    ),
+    name="B",
+)
 
 
 ###############################################################################
@@ -103,8 +106,8 @@ B = te.compute(
 
 # Designate the memory hierarchy
 s = te.create_schedule(B.op)
-s[Apad].compute_inline() # compute Apad inline
-AA = s.cache_read(Apad, 'shared', [B])
+s[Apad].compute_inline()  # compute Apad inline
+AA = s.cache_read(Apad, "shared", [B])
 WW = s.cache_read(W, "shared", [B])
 AL = s.cache_read(AA, "local", [B])
 WL = s.cache_read(WW, "local", [B])
@@ -234,7 +237,7 @@ s[WW].vectorize(fi)  # vectorize memory load
 # latency of convolution.
 #
 
-func = tvm.build(s, [A, W, B], 'cuda')
+func = tvm.build(s, [A, W, B], "cuda")
 ctx = tvm.gpu(0)
 a_np = np.random.uniform(size=(in_size, in_size, in_channel, batch)).astype(A.dtype)
 w_np = np.random.uniform(size=(kernel, kernel, in_channel, out_channel)).astype(W.dtype)
@@ -243,4 +246,4 @@ w = tvm.nd.array(w_np, ctx)
 b = tvm.nd.array(np.zeros((out_size, out_size, out_channel, batch), dtype=B.dtype), ctx)
 func(a, w, b)
 evaluator = func.time_evaluator(func.entry_name, ctx, number=1)
-print('Convolution: %f ms' % (evaluator(a, w, b).mean * 1e3))
+print("Convolution: %f ms" % (evaluator(a, w, b).mean * 1e3))
index 4b2823c..0cbcf7e 100644 (file)
@@ -72,55 +72,72 @@ stride_w = 1
 # TensorCore shape
 block_size = 16
 
-assert (batch_size % block_size == 0)
-assert (in_channels % block_size == 0)
-assert (out_channels % block_size == 0)
+assert batch_size % block_size == 0
+assert in_channels % block_size == 0
+assert out_channels % block_size == 0
 
 # Input feature map: (N, H, W, IC, n, ic)
-data_shape = (batch_size // block_size,
-              height,
-              width,
-              in_channels // block_size,
-              block_size,
-              block_size)
+data_shape = (
+    batch_size // block_size,
+    height,
+    width,
+    in_channels // block_size,
+    block_size,
+    block_size,
+)
 # Kernel: (H, W, IC, OC, ic, oc)
-kernel_shape = (kernel_h,
-                kernel_w,
-                in_channels // block_size,
-                out_channels // block_size,
-                block_size,
-                block_size)
+kernel_shape = (
+    kernel_h,
+    kernel_w,
+    in_channels // block_size,
+    out_channels // block_size,
+    block_size,
+    block_size,
+)
 # Output feature map: (N, H, W, OC, n, oc)
-output_shape = (batch_size // block_size,
-                height,
-                width,
-                out_channels // block_size,
-                block_size,
-                block_size)
+output_shape = (
+    batch_size // block_size,
+    height,
+    width,
+    out_channels // block_size,
+    block_size,
+    block_size,
+)
 
 # Reduction axes
-kh = te.reduce_axis((0, kernel_h), name='kh')
-kw = te.reduce_axis((0, kernel_w), name='kw')
-ic = te.reduce_axis((0, in_channels // block_size), name='ic')
-ii = te.reduce_axis((0, block_size), name='ii')
+kh = te.reduce_axis((0, kernel_h), name="kh")
+kw = te.reduce_axis((0, kernel_w), name="kw")
+ic = te.reduce_axis((0, in_channels // block_size), name="ic")
+ii = te.reduce_axis((0, block_size), name="ii")
 
 # Algorithm
-A = te.placeholder(data_shape, name='A', dtype="float16")
-W = te.placeholder(kernel_shape, name='W', dtype="float16")
+A = te.placeholder(data_shape, name="A", dtype="float16")
+W = te.placeholder(kernel_shape, name="W", dtype="float16")
 Apad = te.compute(
-    (batch_size // block_size, height + 2 * pad_h, width + 2 * pad_w, in_channels // block_size, block_size,
-     block_size),
+    (
+        batch_size // block_size,
+        height + 2 * pad_h,
+        width + 2 * pad_w,
+        in_channels // block_size,
+        block_size,
+        block_size,
+    ),
     lambda n, h, w, i, nn, ii: tvm.tir.if_then_else(
-        tvm.tir.all(h >= pad_h, h - pad_h < height,
-                w >= pad_w, w - pad_w < width),
-        A[n, h - pad_h, w - pad_w, i, nn, ii], tvm.tir.const(0., "float16")),
-    name='Apad')
-Conv = te.compute(output_shape,
-                   lambda n, h, w, o, nn, oo: te.sum(
-                       Apad[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("float32") *
-                       W[kh, kw, ic, o, ii, oo].astype("float32"),
-                       axis=[ic, kh, kw, ii]),
-                   name="Conv")
+        tvm.tir.all(h >= pad_h, h - pad_h < height, w >= pad_w, w - pad_w < width),
+        A[n, h - pad_h, w - pad_w, i, nn, ii],
+        tvm.tir.const(0.0, "float16"),
+    ),
+    name="Apad",
+)
+Conv = te.compute(
+    output_shape,
+    lambda n, h, w, o, nn, oo: te.sum(
+        Apad[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("float32")
+        * W[kh, kw, ic, o, ii, oo].astype("float32"),
+        axis=[ic, kh, kw, ii],
+    ),
+    name="Conv",
+)
 
 s = te.create_schedule(Conv.op)
 s[Apad].compute_inline()
@@ -134,11 +151,11 @@ s[Apad].compute_inline()
 # stores at the on-chip registers level, the same place with local memory.
 
 # Designate the memory hierarchy
-AS = s.cache_read(Apad, 'shared', [Conv])
-WS = s.cache_read(W, 'shared', [Conv])
-AF = s.cache_read(AS, 'wmma.matrix_a', [Conv])
-WF = s.cache_read(WS, 'wmma.matrix_b', [Conv])
-ConvF = s.cache_write(Conv, 'wmma.accumulator')
+AS = s.cache_read(Apad, "shared", [Conv])
+WS = s.cache_read(W, "shared", [Conv])
+AF = s.cache_read(AS, "wmma.matrix_a", [Conv])
+WF = s.cache_read(WS, "wmma.matrix_b", [Conv])
+ConvF = s.cache_write(Conv, "wmma.accumulator")
 
 ###############################################################################
 # Define Tensor Intrinsic
@@ -151,11 +168,12 @@ ConvF = s.cache_write(Conv, 'wmma.accumulator')
 # :code:`mma_sync` and :code:`store_matrix`. Since :code:`fill_fragment` and :code:`mma_sync`
 # are both used in matrix multiplication, we can just write the following three intrinsics.
 
+
 def intrin_wmma_load_matrix(scope):
     n = 16
-    A = te.placeholder((n, n), name='A', dtype='float16')
-    BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256)
-    C = te.compute((n, n), lambda i, j: A[i, j], name='C')
+    A = te.placeholder((n, n), name="A", dtype="float16")
+    BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope="shared", data_alignment=32, offset_factor=256)
+    C = te.compute((n, n), lambda i, j: A[i, j], name="C")
     BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256)
 
     def intrin_func(ins, outs):
@@ -163,9 +181,20 @@ def intrin_wmma_load_matrix(scope):
 
         BA = ins[0]
         BC = outs[0]
-        ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync',
-                                BC.data, n, n, n, BC.elem_offset // 256,
-                                BA.access_ptr('r'), n, 'row_major'))
+        ib.emit(
+            tvm.tir.call_intrin(
+                "handle",
+                "tir.tvm_load_matrix_sync",
+                BC.data,
+                n,
+                n,
+                n,
+                BC.elem_offset // 256,
+                BA.access_ptr("r"),
+                n,
+                "row_major",
+            )
+        )
         return ib.get()
 
     return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
@@ -173,33 +202,53 @@ def intrin_wmma_load_matrix(scope):
 
 def intrin_wmma_gemm():
     n = 16
-    A = te.placeholder((n, n), name='A', dtype='float16')
-    B = te.placeholder((n, n), name='B', dtype='float16')
+    A = te.placeholder((n, n), name="A", dtype="float16")
+    B = te.placeholder((n, n), name="B", dtype="float16")
     k = te.reduce_axis((0, n), name="k")
-    C = te.compute((n, n),
-                    lambda ii, jj:
-                    te.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k),
-                    name='C')
-    BA = tvm.tir.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256)
-    BB = tvm.tir.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256)
-    BC = tvm.tir.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=256)
+    C = te.compute(
+        (n, n),
+        lambda ii, jj: te.sum(A[ii, k].astype("float") * B[k, jj].astype("float"), axis=k),
+        name="C",
+    )
+    BA = tvm.tir.decl_buffer(
+        A.shape, A.dtype, name="BA", scope="wmma.matrix_a", data_alignment=32, offset_factor=256
+    )
+    BB = tvm.tir.decl_buffer(
+        B.shape, B.dtype, name="BB", scope="wmma.matrix_b", data_alignment=32, offset_factor=256
+    )
+    BC = tvm.tir.decl_buffer(
+        C.shape, C.dtype, name="BC", scope="wmma.accumulator", data_alignment=32, offset_factor=256
+    )
 
     def intrin_func(ins, outs):
         BA, BB = ins
-        BC, = outs
+        (BC,) = outs
 
         def init():
             ib = tvm.tir.ir_builder.create()
-            ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0))
+            ib.emit(
+                tvm.tir.call_intrin(
+                    "handle", "tir.tvm_fill_fragment", BC.data, n, n, n, BC.elem_offset // 256, 0.0
+                )
+            )
             return ib.get()
 
         def update():
             ib = tvm.tir.ir_builder.create()
-            ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_mma_sync',
-                                    BC.data, BC.elem_offset // 256,
-                                    BA.data, BA.elem_offset // 256,
-                                    BB.data, BB.elem_offset // 256,
-                                    BC.data, BC.elem_offset // 256))
+            ib.emit(
+                tvm.tir.call_intrin(
+                    "handle",
+                    "tir.tvm_mma_sync",
+                    BC.data,
+                    BC.elem_offset // 256,
+                    BA.data,
+                    BA.elem_offset // 256,
+                    BB.data,
+                    BB.elem_offset // 256,
+                    BC.data,
+                    BC.elem_offset // 256,
+                )
+            )
             return ib.get()
 
         return update(), init(), update()
@@ -209,22 +258,36 @@ def intrin_wmma_gemm():
 
 def intrin_wmma_store_matrix():
     n = 16
-    A = te.placeholder((n, n), name='A', dtype='float32')
-    BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=256)
-    C = te.compute((n, n), lambda i, j: A[i, j], name='C')
-    BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=256)
+    A = te.placeholder((n, n), name="A", dtype="float32")
+    BA = tvm.tir.decl_buffer(
+        A.shape, A.dtype, scope="wmma.accumulator", data_alignment=32, offset_factor=256
+    )
+    C = te.compute((n, n), lambda i, j: A[i, j], name="C")
+    BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope="global", data_alignment=32, offset_factor=256)
 
     def intrin_func(ins, outs):
         ib = tvm.tir.ir_builder.create()
         BA = ins[0]
         BC = outs[0]
-        ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_store_matrix_sync',
-                                BA.data, n, n, n, BA.elem_offset // 256,
-                                BC.access_ptr('w'), n, 'row_major'))
+        ib.emit(
+            tvm.tir.call_intrin(
+                "handle",
+                "tir.tvm_store_matrix_sync",
+                BA.data,
+                n,
+                n,
+                n,
+                BA.elem_offset // 256,
+                BC.access_ptr("w"),
+                n,
+                "row_major",
+            )
+        )
         return ib.get()
 
     return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
 
+
 ###############################################################################
 # Scheduling the Computation
 # --------------------------
@@ -256,12 +319,12 @@ warp_col_tiles = 4
 warp_size = 32
 chunk = 2
 
-block_x = te.thread_axis('blockIdx.x')
-block_y = te.thread_axis('blockIdx.y')
-block_z = te.thread_axis('blockIdx.z')
-thread_x = te.thread_axis('threadIdx.x')
-thread_y = te.thread_axis('threadIdx.y')
-thread_z = te.thread_axis('threadIdx.z')
+block_x = te.thread_axis("blockIdx.x")
+block_y = te.thread_axis("blockIdx.y")
+block_z = te.thread_axis("blockIdx.z")
+thread_x = te.thread_axis("threadIdx.x")
+thread_y = te.thread_axis("threadIdx.y")
+thread_z = te.thread_axis("threadIdx.z")
 
 nc, hc, wc, oc, nnc, ooc = Conv.op.axis
 block_k = s[Conv].fuse(hc, wc)
@@ -316,8 +379,8 @@ print(tvm.lower(s, [A, W, Conv], simple_mode=True))
 # The last phase is to lower the computation loops down to TensorCore hardware intrinsics
 # by mapping the 2D convolution to tensor intrinsics
 
-s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_a'))
-s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b'))
+s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix("wmma.matrix_a"))
+s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix("wmma.matrix_b"))
 s[Conv].tensorize(nnc, intrin_wmma_store_matrix())
 s[ConvF].tensorize(nnf, intrin_wmma_gemm())
 print(tvm.lower(s, [A, W, Conv], simple_mode=True))
@@ -331,17 +394,15 @@ print(tvm.lower(s, [A, W, Conv], simple_mode=True))
 
 ctx = tvm.gpu(0)
 if nvcc.have_tensorcore(ctx.compute_version):
-    with tvm.transform.PassContext(config={"tir.UnrollLoop": {
-        "auto_max_step": 16
-    }}):
-        func = tvm.build(s, [A, W, Conv], 'cuda')
+    with tvm.transform.PassContext(config={"tir.UnrollLoop": {"auto_max_step": 16}}):
+        func = tvm.build(s, [A, W, Conv], "cuda")
     a_np = np.random.uniform(size=data_shape).astype(A.dtype)
     w_np = np.random.uniform(size=kernel_shape).astype(W.dtype)
     a = tvm.nd.array(a_np, ctx)
     w = tvm.nd.array(w_np, ctx)
     c = tvm.nd.array(np.zeros(output_shape, dtype=Conv.dtype), ctx)
     evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
-    print('conv2d with tensor core: %f ms' % (evaluator(a, w, c).mean * 1e3))
+    print("conv2d with tensor core: %f ms" % (evaluator(a, w, c).mean * 1e3))
 
 ###############################################################################
 # Summary
index daca89b..96a75ac 100644 (file)
@@ -73,7 +73,7 @@ dtype = "float32"
 # using the Intel AVX2 (Advanced Vector Extensions) ISA for SIMD
 # To get the best performance, please change the following line
 # to llvm -mcpu=core-avx2, or the specific type of CPU you use
-target = 'llvm'
+target = "llvm"
 ctx = tvm.context(target, 0)
 
 # Random generated tensor for testing
@@ -81,31 +81,30 @@ a = tvm.nd.array(numpy.random.rand(M, K).astype(dtype), ctx)
 b = tvm.nd.array(numpy.random.rand(K, N).astype(dtype), ctx)
 
 np_repeat = 100
-np_runing_time = timeit.timeit(setup='import numpy\n'
-                                     'M = ' + str(M) + '\n'
-                                     'K = ' + str(K) + '\n'
-                                     'N = ' + str(N) + '\n'
-                                     'dtype = "float32"\n'
-                                     'a = numpy.random.rand(M, K).astype(dtype)\n'
-                                     'b = numpy.random.rand(K, N).astype(dtype)\n',
-                               stmt='answer = numpy.dot(a, b)',
-                               number=np_repeat)
+np_runing_time = timeit.timeit(
+    setup="import numpy\n"
+    "M = " + str(M) + "\n"
+    "K = " + str(K) + "\n"
+    "N = " + str(N) + "\n"
+    'dtype = "float32"\n'
+    "a = numpy.random.rand(M, K).astype(dtype)\n"
+    "b = numpy.random.rand(K, N).astype(dtype)\n",
+    stmt="answer = numpy.dot(a, b)",
+    number=np_repeat,
+)
 print("Numpy running time: %f" % (np_runing_time / np_repeat))
 
 answer = numpy.dot(a.asnumpy(), b.asnumpy())
 
 # Algorithm
-k = te.reduce_axis((0, K), 'k')
-A = te.placeholder((M, K), name='A')
-B = te.placeholder((K, N), name='B')
-C = te.compute(
-           (M, N),
-           lambda x, y: te.sum(A[x, k] * B[k, y], axis=k),
-           name='C')
+k = te.reduce_axis((0, K), "k")
+A = te.placeholder((M, K), name="A")
+B = te.placeholder((K, N), name="B")
+C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")
 
 # Default schedule
 s = te.create_schedule(C.op)
-func = tvm.build(s, [A, B, C], target=target, name='mmult')
+func = tvm.build(s, [A, B, C], target=target, name="mmult")
 assert func
 
 c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), ctx)
@@ -113,7 +112,7 @@ func(a, b, c)
 tvm.testing.assert_allclose(c.asnumpy(), answer, rtol=1e-5)
 
 evaluator = func.time_evaluator(func.entry_name, ctx, number=1)
-print('Baseline: %f' % evaluator(a, b, c).mean)
+print("Baseline: %f" % evaluator(a, b, c).mean)
 
 ################################################################################################
 # In TVM, we can always inspect lower level IR to debug or optimize our schedule.
@@ -134,23 +133,23 @@ s = te.create_schedule(C.op)
 
 # Blocking by loop tiling
 xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
-k, = s[C].op.reduce_axis
+(k,) = s[C].op.reduce_axis
 ko, ki = s[C].split(k, factor=4)
 
 # Hoist reduction domain outside the blocking loop
 s[C].reorder(xo, yo, ko, ki, xi, yi)
 
-func = tvm.build(s, [A, B, C], target=target, name='mmult')
+func = tvm.build(s, [A, B, C], target=target, name="mmult")
 assert func
 
-c = tvm.nd.array(numpy.zeros((M, N), dtype = dtype), ctx)
+c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), ctx)
 func(a, b, c)
 tvm.testing.assert_allclose(c.asnumpy(), answer, rtol=1e-5)
 
 # By simply tiling the loop 32x32, and hoisting ko, ki outside the blocking loops,
 # we can see big speedup compared with the baseline.
 evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
-print('Opt1: %f' % evaluator(a, b, c).mean)
+print("Opt1: %f" % evaluator(a, b, c).mean)
 
 ################################################################################################
 # Here is the generated IR after blocking.
@@ -168,7 +167,7 @@ print(tvm.lower(s, [A, B, C], simple_mode=True))
 
 s = te.create_schedule(C.op)
 xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
-k, = s[C].op.reduce_axis
+(k,) = s[C].op.reduce_axis
 ko, ki = s[C].split(k, factor=4)
 
 s[C].reorder(xo, yo, ko, ki, xi, yi)
@@ -176,15 +175,15 @@ s[C].reorder(xo, yo, ko, ki, xi, yi)
 # Vectorization
 s[C].vectorize(yi)
 
-func = tvm.build(s, [A, B, C], target=target, name='mmult')
+func = tvm.build(s, [A, B, C], target=target, name="mmult")
 assert func
 
-c = tvm.nd.array(numpy.zeros((M, N), dtype = dtype), ctx)
+c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), ctx)
 func(a, b, c)
 tvm.testing.assert_allclose(c.asnumpy(), answer, rtol=1e-5)
 
 evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
-print('Opt2: %f' % evaluator(a, b, c).mean)
+print("Opt2: %f" % evaluator(a, b, c).mean)
 
 ################################################################################################
 # Here is the generated IR after vectorization.
@@ -202,22 +201,22 @@ print(tvm.lower(s, [A, B, C], simple_mode=True))
 
 s = te.create_schedule(C.op)
 xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
-k, = s[C].op.reduce_axis
+(k,) = s[C].op.reduce_axis
 ko, ki = s[C].split(k, factor=4)
 
 # re-ordering
 s[C].reorder(xo, yo, ko, xi, ki, yi)
 s[C].vectorize(yi)
 
-func = tvm.build(s, [A, B, C], target=target, name='mmult')
+func = tvm.build(s, [A, B, C], target=target, name="mmult")
 assert func
 
-c = tvm.nd.array(numpy.zeros((M, N), dtype = dtype), ctx)
+c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), ctx)
 func(a, b, c)
 tvm.testing.assert_allclose(c.asnumpy(), answer, rtol=1e-5)
 
 evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
-print('Opt3: %f' % evaluator(a, b, c).mean)
+print("Opt3: %f" % evaluator(a, b, c).mean)
 
 ################################################################################################
 # Here is the generated IR after loop permutation.
@@ -245,15 +244,17 @@ print(tvm.lower(s, [A, B, C], simple_mode=True))
 #
 
 # We have to re-write the algorithm slightly.
-packedB = te.compute((N / bn, K, bn), lambda x, y, z: B[y, x * bn + z], name='packedB')
-C = te.compute((M, N),
-                lambda x, y: te.sum(A[x, k] * packedB[y // bn, k, tvm.tir.indexmod(y, bn)], axis=k),
-                name = 'C')
+packedB = te.compute((N / bn, K, bn), lambda x, y, z: B[y, x * bn + z], name="packedB")
+C = te.compute(
+    (M, N),
+    lambda x, y: te.sum(A[x, k] * packedB[y // bn, k, tvm.tir.indexmod(y, bn)], axis=k),
+    name="C",
+)
 
 s = te.create_schedule(C.op)
 
 xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
-k, = s[C].op.reduce_axis
+(k,) = s[C].op.reduce_axis
 ko, ki = s[C].split(k, factor=4)
 
 s[C].reorder(xo, yo, ko, xi, ki, yi)
@@ -263,15 +264,15 @@ x, y, z = s[packedB].op.axis
 s[packedB].vectorize(z)
 s[packedB].parallel(x)
 
-func = tvm.build(s, [A, B, C], target=target, name='mmult')
+func = tvm.build(s, [A, B, C], target=target, name="mmult")
 assert func
 
-c = tvm.nd.array(numpy.zeros((M, N), dtype = dtype), ctx)
+c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), ctx)
 func(a, b, c)
 tvm.testing.assert_allclose(c.asnumpy(), answer, rtol=1e-5)
 
 evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
-print('Opt4: %f' % evaluator(a, b, c).mean)
+print("Opt4: %f" % evaluator(a, b, c).mean)
 
 ################################################################################################
 # Here is the generated IR after array packing.
@@ -289,7 +290,7 @@ print(tvm.lower(s, [A, B, C], simple_mode=True))
 s = te.create_schedule(C.op)
 
 # Allocate write cache
-CC = s.cache_write(C, 'global')
+CC = s.cache_write(C, "global")
 
 xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
 
@@ -299,7 +300,7 @@ s[CC].compute_at(s[C], yo)
 # New inner axes
 xc, yc = s[CC].op.axis
 
-k, = s[CC].op.reduce_axis
+(k,) = s[CC].op.reduce_axis
 ko, ki = s[CC].split(k, factor=4)
 s[CC].reorder(ko, xc, ki, yc)
 s[CC].unroll(ki)
@@ -309,15 +310,15 @@ x, y, z = s[packedB].op.axis
 s[packedB].vectorize(z)
 s[packedB].parallel(x)
 
-func = tvm.build(s, [A, B, C], target=target, name='mmult')
+func = tvm.build(s, [A, B, C], target=target, name="mmult")
 assert func
 
-c = tvm.nd.array(numpy.zeros((M, N), dtype = dtype), ctx)
+c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), ctx)
 func(a, b, c)
 tvm.testing.assert_allclose(c.asnumpy(), answer, rtol=1e-5)
 
 evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
-print('Opt5: %f' % evaluator(a, b, c).mean)
+print("Opt5: %f" % evaluator(a, b, c).mean)
 
 ################################################################################################
 # Here is the generated IR after blocking.
@@ -331,7 +332,7 @@ print(tvm.lower(s, [A, B, C], simple_mode=True))
 
 s = te.create_schedule(C.op)
 
-CC = s.cache_write(C, 'global')
+CC = s.cache_write(C, "global")
 
 xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
 
@@ -339,7 +340,7 @@ s[CC].compute_at(s[C], yo)
 
 xc, yc = s[CC].op.axis
 
-k, = s[CC].op.reduce_axis
+(k,) = s[CC].op.reduce_axis
 ko, ki = s[CC].split(k, factor=4)
 s[CC].reorder(ko, xc, ki, yc)
 s[CC].unroll(ki)
@@ -352,16 +353,16 @@ x, y, z = s[packedB].op.axis
 s[packedB].vectorize(z)
 s[packedB].parallel(x)
 
-func = tvm.build(s, [A, B, C], target=target, name = 'mmult')
+func = tvm.build(s, [A, B, C], target=target, name="mmult")
 assert func
 
-c = tvm.nd.array(numpy.zeros((M, N), dtype = dtype), ctx)
+c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), ctx)
 func(a, b, c)
 tvm.testing.assert_allclose(c.asnumpy(), answer, rtol=1e-5)
 
 evaluator = func.time_evaluator(func.entry_name, ctx, number=50)
 opt6_time = evaluator(a, b, c).mean
-print('Opt6: %f' % opt6_time)
+print("Opt6: %f" % opt6_time)
 
 ################################################################################################
 # Here is the generated IR after parallelization.
index 7ca04a6..45a0d48 100644 (file)
@@ -51,22 +51,32 @@ from tvm import te
 from tvm import autotvm
 from tvm.contrib import nvcc
 
-def matmul_nn(A, B, L, dtype='float16', layout='NN'):
-    k = te.reduce_axis((0, L), name='k')
-    if dtype == 'float16':
-      out_type = 'float'
-    elif dtype == 'int8':
-      out_type = 'int'
-    elif dtype == 'int4' or dtype == 'int1':
-      out_type = 'int'
-    if (layout == 'NN'):
-      return te.compute((N, M), lambda i, j: te.sum(A[i, k].astype(out_type) * B[k, j].astype(out_type), axis=k))
-    if (layout == 'NT'):
-      return te.compute((N, M), lambda i, j: te.sum(A[k, i].astype(out_type) * B[k, j].astype(out_type), axis=k))
-    if (layout == 'TN'):
-      return te.compute((N, M), lambda i, j: te.sum(A[i, k].astype(out_type) * B[j, k].astype(out_type), axis=k))
-    if (layout == 'TT'):
-      return te.compute((N, M), lambda i, j: te.sum(A[k, i].astype(out_type) * B[j, k].astype(out_type), axis=k))
+
+def matmul_nn(A, B, L, dtype="float16", layout="NN"):
+    k = te.reduce_axis((0, L), name="k")
+    if dtype == "float16":
+        out_type = "float"
+    elif dtype == "int8":
+        out_type = "int"
+    elif dtype == "int4" or dtype == "int1":
+        out_type = "int"
+    if layout == "NN":
+        return te.compute(
+            (N, M), lambda i, j: te.sum(A[i, k].astype(out_type) * B[k, j].astype(out_type), axis=k)
+        )
+    if layout == "NT":
+        return te.compute(
+            (N, M), lambda i, j: te.sum(A[k, i].astype(out_type) * B[k, j].astype(out_type), axis=k)
+        )
+    if layout == "TN":
+        return te.compute(
+            (N, M), lambda i, j: te.sum(A[i, k].astype(out_type) * B[j, k].astype(out_type), axis=k)
+        )
+    if layout == "TT":
+        return te.compute(
+            (N, M), lambda i, j: te.sum(A[k, i].astype(out_type) * B[j, k].astype(out_type), axis=k)
+        )
+
 
 ###############################################################################
 # Scheduling the Computation
@@ -95,25 +105,26 @@ def matmul_nn(A, B, L, dtype='float16', layout='NN'):
 #
 # We use AutoTVM to search for best configurations in this schedule.
 
+
 @autotvm.template("tutorial/auto_tensorcore/test_gemm")
 def test_gemm(N, L, M, dtype, layout):
-    if (layout == "NN"):
-      shape_a = (N, L)
-      shape_b = (L, M)
-    elif (layout == "NT"):
-      shape_a = (L, N)
-      shape_b = (L, M)
-    elif (layout == "TN"):
-      shape_a = (N, L)
-      shape_b = (M, L)
-    elif (layout == "TT"):
-      shape_a = (L, N)
-      shape_b = (M, L)
+    if layout == "NN":
+        shape_a = (N, L)
+        shape_b = (L, M)
+    elif layout == "NT":
+        shape_a = (L, N)
+        shape_b = (L, M)
+    elif layout == "TN":
+        shape_a = (N, L)
+        shape_b = (M, L)
+    elif layout == "TT":
+        shape_a = (L, N)
+        shape_b = (M, L)
     else:
-      print ("Unsupported layout:", layout)
-      sys.exit(1);
-    A = te.placeholder(shape_a, name='A', dtype=dtype)
-    B = te.placeholder(shape_b, name='B', dtype=dtype)
+        print("Unsupported layout:", layout)
+        sys.exit(1)
+    A = te.placeholder(shape_a, name="A", dtype=dtype)
+    B = te.placeholder(shape_b, name="B", dtype=dtype)
     C = matmul_nn(A, B, L, dtype, layout)
 
     s = te.create_schedule(C.op)
@@ -123,53 +134,53 @@ def test_gemm(N, L, M, dtype, layout):
     # storage_align params
     factor = 16
     offset = 8
-    if dtype == 'int8':
-      factor = 32
-      offset = 16
-    elif dtype == 'int4':
-      factor = 64
-      offset = 32
-    elif dtype == 'int1':
-      factor = 256
-      offset = 128
+    if dtype == "int8":
+        factor = 32
+        offset = 16
+    elif dtype == "int4":
+        factor = 64
+        offset = 32
+    elif dtype == "int1":
+        factor = 256
+        offset = 128
 
     # create cache stages
     AA = s.cache_read(A, "shared", [C])
-    if (layout == "NN" or layout == "TN"):
-      s[AA].storage_align(AA.op.axis[0], factor, offset)
+    if layout == "NN" or layout == "TN":
+        s[AA].storage_align(AA.op.axis[0], factor, offset)
     AL = s.cache_read(AA, "local", [C])
     BB = s.cache_read(B, "shared", [C])
-    if (layout == "TT" or layout == "NT"):
-      s[BB].storage_align(BB.op.axis[0], factor, offset)
+    if layout == "TT" or layout == "NT":
+        s[BB].storage_align(BB.op.axis[0], factor, offset)
     BL = s.cache_read(BB, "local", [C])
     CL = s.cache_write(C, "local")
 
-    #autotvm search space definition
+    # autotvm search space definition
     cfg = autotvm.get_config()
 
     cfg.define_knob("bx", [2, 4, 8])
     cfg.define_knob("by", [8, 16, 32, 64])
     cfg.define_knob("step_k", [1, 2, 4, 8, 16, 32])
     cfg.define_knob("v", [4, 8, 16, 32])
-    by = cfg['by'].val
-    bx = cfg['bx'].val
-    step_k = cfg['step_k'].val
-    v = cfg['v'].val
+    by = cfg["by"].val
+    bx = cfg["bx"].val
+    step_k = cfg["step_k"].val
+    v = cfg["v"].val
 
     # thread tile
     TX = 8
     TY = 1
-    if dtype == 'int4' or dtype == 'int1':
-      TX = 2
+    if dtype == "int4" or dtype == "int1":
+        TX = 2
     # warp tile
-    warp_tile_m = 16 # it could also be 8 or 32 on CUDA version >= 10.0
-    warp_tile_k = 16 # it must be 16 for fp16/int8 data type
-    if dtype == 'int4':
-      warp_tile_m = 8
-      warp_tile_k = 32
-    elif dtype == 'int1':
-      warp_tile_m = 8
-      warp_tile_k = 128
+    warp_tile_m = 16  # it could also be 8 or 32 on CUDA version >= 10.0
+    warp_tile_k = 16  # it must be 16 for fp16/int8 data type
+    if dtype == "int4":
+        warp_tile_m = 8
+        warp_tile_k = 32
+    elif dtype == "int1":
+        warp_tile_m = 8
+        warp_tile_k = 128
     # block tile
     tile_x = bx * TX
     tile_y = by * TY
@@ -198,8 +209,8 @@ def test_gemm(N, L, M, dtype, layout):
 
     # schedule for AA stage
     s[AA].compute_at(s[CL], ko)
-    xo, xi = s[AA].split(s[AA].op.axis[1], factor=bx*v)
-    tz, tx = s[AA].split(xi, factor=(WX//TX)*v)
+    xo, xi = s[AA].split(s[AA].op.axis[1], factor=bx * v)
+    tz, tx = s[AA].split(xi, factor=(WX // TX) * v)
     tx, vec = s[AA].split(tx, factor=v)
     fused = s[AA].fuse(s[AA].op.axis[0], xo)
     _, ty = s[AA].split(fused, factor=by)
@@ -211,8 +222,8 @@ def test_gemm(N, L, M, dtype, layout):
 
     # schedule for BB stage
     s[BB].compute_at(s[CL], ko)
-    xo, xi = s[BB].split(s[BB].op.axis[1], factor=bx*v)
-    tz, tx = s[BB].split(xi, factor=(WX//TX)*v)
+    xo, xi = s[BB].split(s[BB].op.axis[1], factor=bx * v)
+    tz, tx = s[BB].split(xi, factor=(WX // TX) * v)
     tx, vec = s[BB].split(tx, factor=v)
     fused = s[BB].fuse(s[BB].op.axis[0], xo)
     _, ty = s[BB].split(fused, factor=by)
@@ -225,10 +236,11 @@ def test_gemm(N, L, M, dtype, layout):
     s[BL].compute_at(s[CL], kl)
 
     # set the 'tensor_core' pragma for tensorcore codegen
-    s[CL].pragma(ko, 'tensor_core')
+    s[CL].pragma(ko, "tensor_core")
 
     return s, [A, B, C]
 
+
 ###############################################################################
 # AutoTune and Test
 # -----------------
@@ -237,148 +249,151 @@ def test_gemm(N, L, M, dtype, layout):
 
 # check whether the gpu has tensorcore
 if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
-  raise Exception("skip building this tutorial because cuda is not enabled..")
+    raise Exception("skip building this tutorial because cuda is not enabled..")
 
 ctx = tvm.gpu()
 if not nvcc.have_tensorcore(ctx.compute_version):
-  raise Exception("the gpu has no tensorcore, skipping...")
+    raise Exception("the gpu has no tensorcore, skipping...")
 
 M, N, L = 512, 32, 512
-dtype = 'float16'
-layout = 'NN'
+dtype = "float16"
+layout = "NN"
 if len(sys.argv) >= 4:
-  M, N, L = int(sys.argv[1]), int(sys.argv[2]), int(sys.argv[3])
+    M, N, L = int(sys.argv[1]), int(sys.argv[2]), int(sys.argv[3])
 if len(sys.argv) >= 5:
-  dtype = sys.argv[4]
+    dtype = sys.argv[4]
 if len(sys.argv) >= 6:
-  layout = sys.argv[5]
+    layout = sys.argv[5]
 
 # check whether the current gpu arch supports the current dtype's wmma codegen
 cuda_compute_capability = tvm.runtime._ffi_api.GetDeviceAttr(2, 0, 4)
-major, minor= nvcc.parse_compute_version(cuda_compute_capability)
-if dtype == 'int8':
-  assert(major == 7 and minor >= 2)
-elif dtype == 'int4' or dtype == 'int1':
-  # int4/int1 only support layout TN
-  assert(major == 7 and minor == 5 and layout == 'TN')
+major, minor = nvcc.parse_compute_version(cuda_compute_capability)
+if dtype == "int8":
+    assert major == 7 and minor >= 2
+elif dtype == "int4" or dtype == "int1":
+    # int4/int1 only support layout TN
+    assert major == 7 and minor == 5 and layout == "TN"
+
 
 def tune_and_evaluate(M, N, L, dtype, layout):
-  task = autotvm.task.create("tutorial/auto_tensorcore/test_gemm", args=(N, L, M, dtype, layout),
-                             target='cuda')
-  print(task.config_space)
-
-  logging.getLogger('autotvm').setLevel(logging.DEBUG)
-  logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))
-
-  measure_option = autotvm.measure_option(
-    builder='local',
-    runner=autotvm.LocalRunner(number=5))
-
-  tuner = autotvm.tuner.XGBTuner(task)
-  tuner.tune(n_trial=1000,
-             measure_option=measure_option,
-             callbacks=[autotvm.callback.log_to_file('matmul.log')])
-
-  dispatch_context = autotvm.apply_history_best("matmul.log")
-  best_config = dispatch_context.query(task.target, task.workload)
-  print("\nBest config:")
-  print(best_config)
-  with autotvm.apply_history_best('matmul.log'):
-    with tvm.target.Target("cuda"):
-          s, arg_bufs = test_gemm(N, L, M, dtype, layout)
-          print(tvm.lower(s, arg_bufs, simple_mode=True))
-          func = tvm.build(s, arg_bufs)
-  dev_module = func.imported_modules[0]
-  print(dev_module.get_source())
-
-  # check correctness
-  if (layout == "NN"):
-    shape_a = (N, L)
-    shape_b = (L, M)
-  elif (layout == "NT"):
-    shape_a = (L, N)
-    shape_b = (L, M)
-  elif (layout == "TN"):
-    shape_a = (N, L)
-    shape_b = (M, L)
-  elif (layout == "TT"):
-    shape_a = (L, N)
-    shape_b = (M, L)
-
-  a_np = None
-  b_np = None
-  c_np = None
-  c_np_type = None
-  if dtype == 'float16':
-    c_np_type = np.float32
-    a_np = np.random.uniform(size=shape_a).astype(np.float16)
-    b_np = np.random.uniform(size=shape_b).astype(np.float16)
-    if (layout == "NN"):
-      c_np = np.dot(a_np, b_np)
-    elif (layout == "NT"):
-      c_np = np.dot(a_np.T, b_np)
-    elif (layout == "TN"):
-      c_np = np.dot(a_np, b_np.T)
-    elif (layout == "TT"):
-      c_np = np.dot(a_np.T, b_np.T)
-  elif dtype == 'int8':
-    c_np_type = np.int32
-    a_np = np.random.randint(low=-128, high=127, size=shape_a).astype(np.int8)
-    b_np = np.random.randint(low=-128, high=127, size=shape_b).astype(np.int8)
-    if (layout == "NN"):
-      c_np = np.dot(a_np.astype(np.int32), b_np.astype(np.int32))
-    elif (layout == "NT"):
-      c_np = np.dot(a_np.astype(np.int32).T, b_np.astype(np.int32))
-    elif (layout == "TN"):
-      c_np = np.dot(a_np.astype(np.int32), b_np.astype(np.int32).T)
-    elif (layout == "TT"):
-      c_np = np.dot(a_np.astype(np.int32).T, b_np.astype(np.int32).T)
-  elif dtype == 'int4':
-    c_np_type = np.int32
-    a_np_int = np.random.randint(low=-8, high=7, size=shape_a).astype(np.int32)
-    b_np_int = np.random.randint(low=-8, high=7, size=shape_b).astype(np.int32)
-    # "TN"
-    c_np = np.dot(a_np_int.astype(np.int32), b_np_int.astype(np.int32).T)
-    a_np = np.zeros(shape=(N, int(L/8)), dtype = np.int32)
-    b_np = np.zeros(shape=(M, int(L/8)), dtype = np.int32)
-    # a_np --> col_major
-    for i in range(N):
-      for j in range(int(L/8)):
-        for k in range(8):
-          a_np[i, j] = a_np[i, j] | ((a_np_int[i, j * 8 + k] & 0xf) << ((7 - k) * 4))
-
-    # b_np --> row_major
-    for i in range(M):
-      for j in range(int(L/8)):
-        for k in range(8):
-          b_np[i, j] = b_np[i, j] | ((b_np_int[i, j * 8 + k] & 0xf) << ((7 - k) * 4))
-  elif dtype == 'int1':
-    c_np_type = np.int32
-    a_np_int = np.random.randint(low=0, high=1, size=shape_a).astype(np.int32)
-    b_np_int = np.random.randint(low=0, high=1, size=shape_b).astype(np.int32)
-    # "TN"
-    c_np = np.dot(a_np_int.astype(np.int32), b_np_int.astype(np.int32).T)
-    a_np = np.zeros(shape=(N, int(L/32)), dtype = np.int32)
-    b_np = np.zeros(shape=(M, int(L/32)), dtype = np.int32)
-    for i in range(N):
-      for j in range(int(L/32)):
-        for k in range(32):
-          a_np[i, j] = a_np[i, j] | ((a_np_int[i, j * 32 + k] & 0xf) << (31 - k))
-
-    for i in range(M):
-      for j in range(int(L/32)):
-        for k in range(32):
-          b_np[i, j] = b_np[i, j] | ((b_np_int[i, j * 32 + k] & 0xf) << (31 - k))
-
-  c_tvm = tvm.nd.array(np.zeros(c_np.shape, dtype=c_np_type), ctx=ctx)
-  a_tvm = tvm.nd.array(a_np, ctx=ctx)
-  b_tvm = tvm.nd.array(b_np, ctx=ctx)
-  func(a_tvm, b_tvm, c_tvm)
-
-  tvm.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-3)
-
-  evaluator = func.time_evaluator(func.entry_name, ctx, number=100)
-  print('Time cost of this operator: %f' % evaluator(a_tvm, b_tvm, c_tvm).mean)
+    task = autotvm.task.create(
+        "tutorial/auto_tensorcore/test_gemm", args=(N, L, M, dtype, layout), target="cuda"
+    )
+    print(task.config_space)
+
+    logging.getLogger("autotvm").setLevel(logging.DEBUG)
+    logging.getLogger("autotvm").addHandler(logging.StreamHandler(sys.stdout))
+
+    measure_option = autotvm.measure_option(builder="local", runner=autotvm.LocalRunner(number=5))
+
+    tuner = autotvm.tuner.XGBTuner(task)
+    tuner.tune(
+        n_trial=1000,
+        measure_option=measure_option,
+        callbacks=[autotvm.callback.log_to_file("matmul.log")],
+    )
+
+    dispatch_context = autotvm.apply_history_best("matmul.log")
+    best_config = dispatch_context.query(task.target, task.workload)
+    print("\nBest config:")
+    print(best_config)
+    with autotvm.apply_history_best("matmul.log"):
+        with tvm.target.Target("cuda"):
+            s, arg_bufs = test_gemm(N, L, M, dtype, layout)
+            print(tvm.lower(s, arg_bufs, simple_mode=True))
+            func = tvm.build(s, arg_bufs)
+    dev_module = func.imported_modules[0]
+    print(dev_module.get_source())
+
+    # check correctness
+    if layout == "NN":
+        shape_a = (N, L)
+        shape_b = (L, M)
+    elif layout == "NT":
+        shape_a = (L, N)
+        shape_b = (L, M)
+    elif layout == "TN":
+        shape_a = (N, L)
+        shape_b = (M, L)
+    elif layout == "TT":
+        shape_a = (L, N)
+        shape_b = (M, L)
+
+    a_np = None
+    b_np = None
+    c_np = None
+    c_np_type = None
+    if dtype == "float16":
+        c_np_type = np.float32
+        a_np = np.random.uniform(size=shape_a).astype(np.float16)
+        b_np = np.random.uniform(size=shape_b).astype(np.float16)
+        if layout == "NN":
+            c_np = np.dot(a_np, b_np)
+        elif layout == "NT":
+            c_np = np.dot(a_np.T, b_np)
+        elif layout == "TN":
+            c_np = np.dot(a_np, b_np.T)
+        elif layout == "TT":
+            c_np = np.dot(a_np.T, b_np.T)
+    elif dtype == "int8":
+        c_np_type = np.int32
+        a_np = np.random.randint(low=-128, high=127, size=shape_a).astype(np.int8)
+        b_np = np.random.randint(low=-128, high=127, size=shape_b).astype(np.int8)
+        if layout == "NN":
+            c_np = np.dot(a_np.astype(np.int32), b_np.astype(np.int32))
+        elif layout == "NT":
+            c_np = np.dot(a_np.astype(np.int32).T, b_np.astype(np.int32))
+        elif layout == "TN":
+            c_np = np.dot(a_np.astype(np.int32), b_np.astype(np.int32).T)
+        elif layout == "TT":
+            c_np = np.dot(a_np.astype(np.int32).T, b_np.astype(np.int32).T)
+    elif dtype == "int4":
+        c_np_type = np.int32
+        a_np_int = np.random.randint(low=-8, high=7, size=shape_a).astype(np.int32)
+        b_np_int = np.random.randint(low=-8, high=7, size=shape_b).astype(np.int32)
+        # "TN"
+        c_np = np.dot(a_np_int.astype(np.int32), b_np_int.astype(np.int32).T)
+        a_np = np.zeros(shape=(N, int(L / 8)), dtype=np.int32)
+        b_np = np.zeros(shape=(M, int(L / 8)), dtype=np.int32)
+        # a_np --> col_major
+        for i in range(N):
+            for j in range(int(L / 8)):
+                for k in range(8):
+                    a_np[i, j] = a_np[i, j] | ((a_np_int[i, j * 8 + k] & 0xF) << ((7 - k) * 4))
+
+        # b_np --> row_major
+        for i in range(M):
+            for j in range(int(L / 8)):
+                for k in range(8):
+                    b_np[i, j] = b_np[i, j] | ((b_np_int[i, j * 8 + k] & 0xF) << ((7 - k) * 4))
+    elif dtype == "int1":
+        c_np_type = np.int32
+        a_np_int = np.random.randint(low=0, high=1, size=shape_a).astype(np.int32)
+        b_np_int = np.random.randint(low=0, high=1, size=shape_b).astype(np.int32)
+        # "TN"
+        c_np = np.dot(a_np_int.astype(np.int32), b_np_int.astype(np.int32).T)
+        a_np = np.zeros(shape=(N, int(L / 32)), dtype=np.int32)
+        b_np = np.zeros(shape=(M, int(L / 32)), dtype=np.int32)
+        for i in range(N):
+            for j in range(int(L / 32)):
+                for k in range(32):
+                    a_np[i, j] = a_np[i, j] | ((a_np_int[i, j * 32 + k] & 0xF) << (31 - k))
+
+        for i in range(M):
+            for j in range(int(L / 32)):
+                for k in range(32):
+                    b_np[i, j] = b_np[i, j] | ((b_np_int[i, j * 32 + k] & 0xF) << (31 - k))
+
+    c_tvm = tvm.nd.array(np.zeros(c_np.shape, dtype=c_np_type), ctx=ctx)
+    a_tvm = tvm.nd.array(a_np, ctx=ctx)
+    b_tvm = tvm.nd.array(b_np, ctx=ctx)
+    func(a_tvm, b_tvm, c_tvm)
+
+    tvm.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-3)
+
+    evaluator = func.time_evaluator(func.entry_name, ctx, number=100)
+    print("Time cost of this operator: %f" % evaluator(a_tvm, b_tvm, c_tvm).mean)
+
 
 # We do not run the tuning on our web server since it takes some time.
 # Uncomment the following line to run it by yourself.
index 82d7892..c9812ff 100644 (file)
@@ -39,7 +39,7 @@ import numpy as np
 #
 n = te.var("n")
 m = te.var("m")
-A = te.placeholder((n, m), name='A')
+A = te.placeholder((n, m), name="A")
 k = te.reduce_axis((0, m), "k")
 B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
 s = te.create_schedule(B.op)
@@ -97,7 +97,7 @@ print(sg.stages)
 ######################################################################
 # We can test the correctness by comparing with :code:`numpy` result as follows
 #
-func = tvm.build(sg, [a, b, g], 'cuda')
+func = tvm.build(sg, [a, b, g], "cuda")
 ctx = tvm.gpu(0)
 a_np = np.random.uniform(size=(x, y, y)).astype(a.dtype)
 b_np = np.random.uniform(size=(y, y)).astype(b.dtype)
index 015ee24..acef8ad 100644 (file)
@@ -27,6 +27,7 @@ List of affected files:
 """
 import os
 import re
+
 # current version
 # We use the version of the incoming release for code
 # that is under development
@@ -62,15 +63,25 @@ def update(file_name, pattern, repl):
 def main():
     proj_root = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
     # python path
-    update(os.path.join(proj_root, "python", "tvm", "_ffi", "libinfo.py"),
-           r"(?<=__version__ = \")[.0-9a-z]+", __version__)
+    update(
+        os.path.join(proj_root, "python", "tvm", "_ffi", "libinfo.py"),
+        r"(?<=__version__ = \")[.0-9a-z]+",
+        __version__,
+    )
     # C++ header
-    update(os.path.join(proj_root, "include", "tvm", "runtime", "c_runtime_api.h"),
-           "(?<=TVM_VERSION \")[.0-9a-z]+", __version__)
+    update(
+        os.path.join(proj_root, "include", "tvm", "runtime", "c_runtime_api.h"),
+        '(?<=TVM_VERSION ")[.0-9a-z]+',
+        __version__,
+    )
     # conda
     for path in ["tvm", "tvm-libs"]:
-        update(os.path.join(proj_root, "conda", path, "meta.yaml"),
-               "(?<=version = \")[.0-9a-z]+", __version__)
+        update(
+            os.path.join(proj_root, "conda", path, "meta.yaml"),
+            '(?<=version = ")[.0-9a-z]+',
+            __version__,
+        )
+
 
 if __name__ == "__main__":
     main()
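The update() helper invoked above (its body is not shown in this hunk) presumably applies a regex substitution whose fixed-width look-behind keeps the surrounding text intact and replaces only the version string. A minimal sketch of that technique, with the input line and new version invented for the example:

    import re

    text = '__version__ = "0.7.dev1"'            # hypothetical file content
    new_version = "0.7.0"                        # hypothetical release version
    # The look-behind anchors on the '__version__ = "' prefix, so only the
    # dotted version that follows it is rewritten.
    print(re.sub(r'(?<=__version__ = ")[.0-9a-z]+', new_version, text))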
index 3ee39ca..254243d 100644 (file)
@@ -31,6 +31,7 @@ else:
 # bitstream repo
 BITSTREAM_URL = "https://github.com/uwsaml/vta-distro/raw/master/bitstreams/"
 
+
 def get_bitstream_path():
     """Returns the path to the cached bitstream corresponding to the current config
 
@@ -45,7 +46,7 @@ def get_bitstream_path():
     # Derive destination path
     cache_dir = os.getenv("VTA_CACHE_PATH", os.path.join(os.getenv("HOME"), ".vta_cache/"))
     cache_dir = os.path.join(cache_dir, env.TARGET)
-    cache_dir = os.path.join(cache_dir, env.HW_VER.replace('.', '_'))
+    cache_dir = os.path.join(cache_dir, env.HW_VER.replace(".", "_"))
     # Create the directory if it didn't exist
     if not os.path.exists(cache_dir):
         os.makedirs(cache_dir)
@@ -55,8 +56,7 @@ def get_bitstream_path():
 
 
 def download_bitstream():
-    """Downloads a cached bitstream corresponding to the current config
-    """
+    """Downloads a cached bitstream corresponding to the current config"""
 
     env = get_env()
 
@@ -77,12 +77,16 @@ def download_bitstream():
 bitstream has not been cached. Please compile your own bitstream (see hardware \
 compilation guide to get Xilinx toolchains setup) and add it to your \
 $VTA_CACHE_PATH. Alternatively edit your config.json back to its default \
-settings. You can see the list of available bitstreams under {}"
-                .format(url, BITSTREAM_URL))
+settings. You can see the list of available bitstreams under {}".format(
+                    url, BITSTREAM_URL
+                )
+            )
         raise RuntimeError(
             # This could happen when trying to access the URL behind a proxy
             "Something went wrong when trying to access {}. Check your \
-internet connection or proxy settings."
-            .format(url))
+internet connection or proxy settings.".format(
+                url
+            )
+        )
 
     return success
index 5becc5e..3b62edd 100644 (file)
@@ -23,13 +23,14 @@ from .environment import get_env
 
 def EarlyRewrite():
     """Try to do storage rewrite in early pass."""
+
     def _transform(mod, ctx):
         try:
             return tvm.tir.transform.StorageRewrite()(mod)
         except tvm.error.TVMError:
             return mod
-    return tvm.transform.module_pass(
-        _transform, opt_level=0, name="tir.vta.EarlyRewrite")
+
+    return tvm.transform.module_pass(_transform, opt_level=0, name="tir.vta.EarlyRewrite")
 
 
 def build_config(debug_flag=0, **kwargs):
@@ -60,32 +61,28 @@ def build_config(debug_flag=0, **kwargs):
 
     @tvm.tir.transform.prim_func_pass(opt_level=0)
     def add_debug(f, *_):
-        debug = tvm.tir.call_extern(
-            "int32", "VTASetDebugMode",
-            env.dev.command_handle,
-            debug_flag)
+        debug = tvm.tir.call_extern("int32", "VTASetDebugMode", env.dev.command_handle, debug_flag)
 
         return f.with_body(tvm.tir.stmt_seq(debug, f.body))
 
-
-    pass_list = [(0, transform.InjectConv2DTransposeSkip()),
-                 (1, transform.InjectDMAIntrin()),
-                 (1, transform.InjectSkipCopy()),
-                 (1, transform.AnnotateALUCoProcScope()),
-                 (1, tvm.tir.transform.LiftAttrScope("coproc_uop_scope")),
-                 (1, transform.LiftAllocToScopeBegin()),
-                 (1, tvm.tir.transform.LiftAttrScope("coproc_scope")),
-                 (1, transform.InjectCoProcSync()),
-                 (1, EarlyRewrite())]
+    pass_list = [
+        (0, transform.InjectConv2DTransposeSkip()),
+        (1, transform.InjectDMAIntrin()),
+        (1, transform.InjectSkipCopy()),
+        (1, transform.AnnotateALUCoProcScope()),
+        (1, tvm.tir.transform.LiftAttrScope("coproc_uop_scope")),
+        (1, transform.LiftAllocToScopeBegin()),
+        (1, tvm.tir.transform.LiftAttrScope("coproc_scope")),
+        (1, transform.InjectCoProcSync()),
+        (1, EarlyRewrite()),
+    ]
     if debug_flag:
         pass_list.append((1, add_debug))
     pass_list.append((2, transform.InjectALUIntrin()))
     pass_list.append((3, tvm.tir.transform.LowerDeviceStorageAccessInfo()))
     pass_list.append((3, transform.FoldUopLoop()))
     pass_list.append((3, transform.CPUAccessRewrite()))
-    config = {
-        "tir.add_lower_pass": pass_list
-    }
+    config = {"tir.add_lower_pass": pass_list}
     if kwargs.get("config"):
         config.update(kwargs["config"])
         del kwargs["config"]
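For orientation, the (phase, pass) pairs collected under "tir.add_lower_pass" are run by TVM at the given lowering phase. A minimal, self-contained sketch of that mechanism with a do-nothing pass; the pass and the toy compute exist only for illustration:

    import tvm
    from tvm import te

    @tvm.tir.transform.prim_func_pass(opt_level=0)
    def _identity_pass(func, mod, ctx):
        # Illustrative no-op; a real pass would return a rewritten PrimFunc.
        return func

    n = 16
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)
    with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, _identity_pass)]}):
        print(tvm.lower(s, [A, B]))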
index 3a18cf7..b334e01 100644 (file)
@@ -25,13 +25,15 @@ import tvm
 from tvm import te
 from . import intrin
 
+
 def get_vta_hw_path():
     """Get the VTA HW path."""
     curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
     vta_hw_default = os.path.abspath(os.path.join(curr_path, "../../../3rdparty/vta-hw"))
-    VTA_HW_PATH = os.getenv('VTA_HW_PATH', vta_hw_default)
+    VTA_HW_PATH = os.getenv("VTA_HW_PATH", vta_hw_default)
     return os.path.abspath(VTA_HW_PATH)
 
+
 def pkg_config(cfg):
     """Returns PkgConfig pkg config object."""
     pkg_config_py = os.path.join(get_vta_hw_path(), "config/pkg_config.py")
@@ -40,6 +42,7 @@ def pkg_config(cfg):
     PkgConfig = libpkg["PkgConfig"]
     return PkgConfig(cfg)
 
+
 class DevContext(object):
     """Internal development context
 
@@ -56,6 +59,7 @@ class DevContext(object):
     This class is introduced so we have a clear separation
     of developer related, and user facing attributes.
     """
+
     # Memory id for DMA
     MEM_ID_UOP = 0
     MEM_ID_WGT = 1
@@ -78,8 +82,7 @@ class DevContext(object):
         self.vta_axis = te.thread_axis("vta")
         self.vta_push_uop = tvm.tir.StringImm("VTAPushGEMMOp")
         ctx = tvm.tir.call_intrin("handle", "tir.vta.command_handle")
-        self.command_handle = tvm.tir.Call(
-            "handle", "tir.tvm_thread_context", [ctx])
+        self.command_handle = tvm.tir.Call("handle", "tir.tvm_thread_context", [ctx])
         self.DEBUG_NO_SYNC = False
         env._dev_ctx = self
         self.gemm = intrin.gemm(env, env.mock_mode)
@@ -111,14 +114,15 @@ class Environment(object):
           # env works on the new environment
           env = vta.get_env()
     """
+
     current = None
     # constants
     MAX_XFER = 1 << 22
     # debug flags
-    DEBUG_DUMP_INSN = (1 << 1)
-    DEBUG_DUMP_UOP = (1 << 2)
-    DEBUG_SKIP_READ_BARRIER = (1 << 3)
-    DEBUG_SKIP_WRITE_BARRIER = (1 << 4)
+    DEBUG_DUMP_INSN = 1 << 1
+    DEBUG_DUMP_UOP = 1 << 2
+    DEBUG_SKIP_READ_BARRIER = 1 << 3
+    DEBUG_SKIP_WRITE_BARRIER = 1 << 4
     # memory scopes
     inp_scope = "local.inp_buffer"
     wgt_scope = "local.wgt_buffer"
@@ -145,18 +149,10 @@ class Environment(object):
         self.ACC_BUFF_SIZE = 1 << self.LOG_ACC_BUFF_SIZE
         self.OUT_BUFF_SIZE = 1 << self.LOG_OUT_BUFF_SIZE
         # bytes per buffer
-        self.INP_ELEM_BITS = (self.BATCH *
-                              self.BLOCK_IN *
-                              self.INP_WIDTH)
-        self.WGT_ELEM_BITS = (self.BLOCK_OUT *
-                              self.BLOCK_IN *
-                              self.WGT_WIDTH)
-        self.ACC_ELEM_BITS = (self.BATCH *
-                              self.BLOCK_OUT *
-                              self.ACC_WIDTH)
-        self.OUT_ELEM_BITS = (self.BATCH *
-                              self.BLOCK_OUT *
-                              self.OUT_WIDTH)
+        self.INP_ELEM_BITS = self.BATCH * self.BLOCK_IN * self.INP_WIDTH
+        self.WGT_ELEM_BITS = self.BLOCK_OUT * self.BLOCK_IN * self.WGT_WIDTH
+        self.ACC_ELEM_BITS = self.BATCH * self.BLOCK_OUT * self.ACC_WIDTH
+        self.OUT_ELEM_BITS = self.BATCH * self.BLOCK_OUT * self.OUT_WIDTH
         self.INP_ELEM_BYTES = self.INP_ELEM_BITS // 8
         self.WGT_ELEM_BYTES = self.WGT_ELEM_BITS // 8
         self.ACC_ELEM_BYTES = self.ACC_ELEM_BITS // 8
@@ -213,16 +209,12 @@ class Environment(object):
     @property
     def dma_copy(self):
         """DMA copy pragma"""
-        return ("dma_copy"
-                if not self.mock_mode
-                else "skip_dma_copy")
+        return "dma_copy" if not self.mock_mode else "skip_dma_copy"
 
     @property
     def alu(self):
         """ALU pragma"""
-        return ("alu"
-                if not self.mock_mode
-                else "skip_alu")
+        return "alu" if not self.mock_mode else "skip_alu"
 
     @property
     def gemm(self):
@@ -248,6 +240,7 @@ class Environment(object):
     def target_vta_cpu(self):
         return tvm.target.arm_cpu(model=self.TARGET)
 
+
 def get_env():
     """Get the current VTA Environment.
 
@@ -263,55 +256,63 @@ def get_env():
 @tvm.register_func("tvm.info.mem.%s" % Environment.inp_scope)
 def mem_info_inp_buffer():
     spec = get_env()
-    return tvm.ir.make_node("MemoryInfo",
-                            unit_bits=spec.INP_ELEM_BITS,
-                            max_simd_bits=spec.INP_ELEM_BITS,
-                            max_num_bits=spec.INP_BUFF_SIZE * 8,
-                            head_address=None)
+    return tvm.ir.make_node(
+        "MemoryInfo",
+        unit_bits=spec.INP_ELEM_BITS,
+        max_simd_bits=spec.INP_ELEM_BITS,
+        max_num_bits=spec.INP_BUFF_SIZE * 8,
+        head_address=None,
+    )
+
 
 @tvm.register_func("tvm.info.mem.%s" % Environment.wgt_scope)
 def mem_info_wgt_buffer():
     spec = get_env()
-    return tvm.ir.make_node("MemoryInfo",
-                            unit_bits=spec.WGT_ELEM_BITS,
-                            max_simd_bits=spec.WGT_ELEM_BITS,
-                            max_num_bits=spec.WGT_BUFF_SIZE * 8,
-                            head_address=None)
+    return tvm.ir.make_node(
+        "MemoryInfo",
+        unit_bits=spec.WGT_ELEM_BITS,
+        max_simd_bits=spec.WGT_ELEM_BITS,
+        max_num_bits=spec.WGT_BUFF_SIZE * 8,
+        head_address=None,
+    )
+
 
 @tvm.register_func("tvm.info.mem.%s" % Environment.acc_scope)
 def mem_info_acc_buffer():
     spec = get_env()
-    return tvm.ir.make_node("MemoryInfo",
-                            unit_bits=spec.ACC_ELEM_BITS,
-                            max_simd_bits=spec.ACC_ELEM_BITS,
-                            max_num_bits=spec.ACC_BUFF_SIZE * 8,
-                            head_address=None)
+    return tvm.ir.make_node(
+        "MemoryInfo",
+        unit_bits=spec.ACC_ELEM_BITS,
+        max_simd_bits=spec.ACC_ELEM_BITS,
+        max_num_bits=spec.ACC_BUFF_SIZE * 8,
+        head_address=None,
+    )
+
 
 # TVM related registration
 @tvm.register_func("tvm.intrin.rule.default.vta.coproc_sync")
 def coproc_sync(op):
     _ = op
     return tvm.tir.call_extern(
-        "int32", "VTASynchronize",
+        "int32",
+        "VTASynchronize",
         get_env().dev.command_handle,
-        tvm.runtime.const(1<<31, dtype="uint32"))
-
+        tvm.runtime.const(1 << 31, dtype="uint32"),
+    )
 
 
 @tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_push")
 def coproc_dep_push(op):
     return tvm.tir.call_extern(
-        "int32", "VTADepPush",
-        get_env().dev.command_handle,
-        op.args[0], op.args[1])
+        "int32", "VTADepPush", get_env().dev.command_handle, op.args[0], op.args[1]
+    )
 
 
 @tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_pop")
 def coproc_dep_pop(op):
     return tvm.tir.call_extern(
-        "int32", "VTADepPop",
-        get_env().dev.command_handle,
-        op.args[0], op.args[1])
+        "int32", "VTADepPop", get_env().dev.command_handle, op.args[0], op.args[1]
+    )
 
 
 def _init_env():
@@ -322,4 +323,5 @@ def _init_env():
     cfg = json.load(open(config_path))
     return Environment(cfg)
 
+
 Environment.current = _init_env()
index 220da43..cd51913 100644 (file)
@@ -38,8 +38,7 @@ from ..libinfo import find_libvta
 def server_start():
     """VTA RPC server extension."""
     # pylint: disable=unused-variable
-    curr_path = os.path.dirname(
-        os.path.abspath(os.path.expanduser(__file__)))
+    curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
     proj_root = os.path.abspath(os.path.join(curr_path, "../../../../"))
     dll_path = find_libvta("libvta")[0]
     cfg_path = os.path.abspath(os.path.join(proj_root, "3rdparty/vta-hw/config/vta_config.json"))
@@ -69,6 +68,7 @@ def server_start():
         env = get_env()
         if env.TARGET == "pynq":
             from pynq import xlnk
+
             # Reset xilinx driver
             xlnk.Xlnk().xlnk_reset()
         elif env.TARGET == "de10nano":
@@ -112,8 +112,13 @@ def server_start():
         ldflags = pkg.ldflags
         lib_name = dll_path
         source = pkg.lib_source
-        logging.info("Rebuild runtime:\n output=%s,\n cflags=%s,\n source=%s,\n ldflags=%s",
-                     dll_path, '\n\t'.join(cflags), '\n\t'.join(source), '\n\t'.join(ldflags))
+        logging.info(
+            "Rebuild runtime:\n output=%s,\n cflags=%s,\n source=%s,\n ldflags=%s",
+            dll_path,
+            "\n\t".join(cflags),
+            "\n\t".join(source),
+            "\n\t".join(ldflags),
+        )
         cc.create_shared(lib_name, source, cflags + ldflags)
         with open(cfg_path, "w") as outputfile:
             outputfile.write(pkg.cfg_json)
@@ -122,16 +127,13 @@ def server_start():
 def main():
     """Main funciton"""
     parser = argparse.ArgumentParser()
-    parser.add_argument('--host', type=str, default="0.0.0.0",
-                        help='the hostname of the server')
-    parser.add_argument('--port', type=int, default=9091,
-                        help='The port of the RPC')
-    parser.add_argument('--port-end', type=int, default=9199,
-                        help='The end search port of the RPC')
-    parser.add_argument('--key', type=str, default="",
-                        help="RPC key used to identify the connection type.")
-    parser.add_argument('--tracker', type=str, default="",
-                        help="Report to RPC tracker")
+    parser.add_argument("--host", type=str, default="0.0.0.0", help="the hostname of the server")
+    parser.add_argument("--port", type=int, default=9091, help="The port of the RPC")
+    parser.add_argument("--port-end", type=int, default=9199, help="The end search port of the RPC")
+    parser.add_argument(
+        "--key", type=str, default="", help="RPC key used to identify the connection type."
+    )
+    parser.add_argument("--tracker", type=str, default="", help="Report to RPC tracker")
     args = parser.parse_args()
     logging.basicConfig(level=logging.INFO)
 
@@ -140,17 +142,15 @@ def main():
         port = int(port)
         tracker_addr = (url, port)
         if not args.key:
-            raise RuntimeError(
-                "Need key to present type of resource when tracker is available")
+            raise RuntimeError("Need key to present type of resource when tracker is available")
     else:
         tracker_addr = None
 
-    server = rpc.Server(args.host,
-                        args.port,
-                        args.port_end,
-                        key=args.key,
-                        tracker_addr=tracker_addr)
+    server = rpc.Server(
+        args.host, args.port, args.port_end, key=args.key, tracker_addr=tracker_addr
+    )
     server.proc.join()
 
+
 if __name__ == "__main__":
     main()
index 897bbcb..52bf586 100644 (file)
@@ -20,6 +20,7 @@ from __future__ import absolute_import as _abs
 import tvm
 from tvm import te
 
+
 def gemm(env, mock=False):
     """Matrix-matrix multiply intrinsic
 
@@ -46,66 +47,93 @@ def gemm(env, mock=False):
     out_shape = (env.BATCH, env.BLOCK_OUT)
     assert out_shape[0] * out_shape[1] == out_lanes
 
-    wgt = te.placeholder((wgt_shape[0], wgt_shape[1]),
-                         dtype="int%d" % env.WGT_WIDTH,
-                         name=env.wgt_scope)
-    inp = te.placeholder((inp_shape[0], inp_shape[1]),
-                         dtype="int%d" % env.INP_WIDTH,
-                         name=env.inp_scope)
+    wgt = te.placeholder(
+        (wgt_shape[0], wgt_shape[1]), dtype="int%d" % env.WGT_WIDTH, name=env.wgt_scope
+    )
+    inp = te.placeholder(
+        (inp_shape[0], inp_shape[1]), dtype="int%d" % env.INP_WIDTH, name=env.inp_scope
+    )
     k = te.reduce_axis((0, wgt_shape[1]), name="k")
     out_dtype = "int%d" % env.ACC_WIDTH
-    out = te.compute((out_shape[0], out_shape[1]),
-                     lambda i, j: te.sum(inp[i, k].astype(out_dtype) *
-                                         wgt[j, k].astype(out_dtype),
-                                         axis=[k]),
-                     name="out")
+    out = te.compute(
+        (out_shape[0], out_shape[1]),
+        lambda i, j: te.sum(inp[i, k].astype(out_dtype) * wgt[j, k].astype(out_dtype), axis=[k]),
+        name="out",
+    )
     wgt_layout = tvm.tir.decl_buffer(
-        wgt.shape, wgt.dtype, env.wgt_scope,
-        scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes)
+        wgt.shape,
+        wgt.dtype,
+        env.wgt_scope,
+        scope=env.wgt_scope,
+        offset_factor=wgt_lanes,
+        data_alignment=wgt_lanes,
+    )
     inp_layout = tvm.tir.decl_buffer(
-        inp.shape, inp.dtype, env.inp_scope,
-        scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes)
+        inp.shape,
+        inp.dtype,
+        env.inp_scope,
+        scope=env.inp_scope,
+        offset_factor=inp_lanes,
+        data_alignment=inp_lanes,
+    )
     out_layout = tvm.tir.decl_buffer(
-        out.shape, out.dtype, env.acc_scope,
-        scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes)
+        out.shape,
+        out.dtype,
+        env.acc_scope,
+        scope=env.acc_scope,
+        offset_factor=out_lanes,
+        data_alignment=out_lanes,
+    )
 
     def intrin_func(ins, outs):
         """Matrix-matrix multiply intrinsic function"""
         dinp, dwgt = ins
         dout = outs[0]
+
         def instr(index):
             """Generate matrix-matrix multiply VTA instruction"""
             irb = tvm.tir.ir_builder.create()
             dev = env.dev
-            irb.scope_attr(dev.vta_axis, "coproc_scope",
-                           dev.get_task_qid(dev.QID_COMPUTE))
-            irb.scope_attr(dev.vta_axis, "coproc_uop_scope",
-                           dev.vta_push_uop)
+            irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
+            irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
             if index in (0, 2):
-                irb.emit(tvm.tir.call_intrin(
-                    "int32", "tir.vta.uop_push",
-                    0, 0,
-                    dout.access_ptr("rw", "int32"),
-                    dinp.access_ptr("r", "int32"),
-                    dwgt.access_ptr("r", "int32"),
-                    0, 0, 0))
+                irb.emit(
+                    tvm.tir.call_intrin(
+                        "int32",
+                        "tir.vta.uop_push",
+                        0,
+                        0,
+                        dout.access_ptr("rw", "int32"),
+                        dinp.access_ptr("r", "int32"),
+                        dwgt.access_ptr("r", "int32"),
+                        0,
+                        0,
+                        0,
+                    )
+                )
             else:
-                irb.emit(tvm.tir.call_intrin(
-                    "int32", "tir.vta.uop_push",
-                    0, 1,
-                    dout.access_ptr("rw", "int32"),
-                    0,
-                    0,
-                    0, 0, 0))
+                irb.emit(
+                    tvm.tir.call_intrin(
+                        "int32",
+                        "tir.vta.uop_push",
+                        0,
+                        1,
+                        dout.access_ptr("rw", "int32"),
+                        0,
+                        0,
+                        0,
+                        0,
+                        0,
+                    )
+                )
             return irb.get()
+
         # return a triple of normal-set, reset, update
         nop = tvm.tir.Evaluate(0)
         if mock:
             return (nop, nop, nop)
         return (instr(0), instr(1), instr(2))
 
-    return te.decl_tensor_intrin(out.op, intrin_func,
-                                 name="GEMM",
-                                 binds={inp: inp_layout,
-                                        wgt: wgt_layout,
-                                        out: out_layout})
+    return te.decl_tensor_intrin(
+        out.op, intrin_func, name="GEMM", binds={inp: inp_layout, wgt: wgt_layout, out: out_layout}
+    )
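The intrinsic declared here is consumed by the VTA schedules through tensorization. A rough sketch of that step; the schedule, stage, and axis names are assumptions made purely for illustration:

    import vta

    def apply_gemm_intrin(sched, gemm_stage, inner_batch_axis):
        """Sketch only: bind the innermost packed axes of a GEMM-like stage
        to the tensor intrinsic declared by intrin.gemm()."""
        env = vta.get_env()
        sched[gemm_stage].tensorize(inner_batch_axis, env.gemm)
        return sched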
index 3816e47..8b300c3 100644 (file)
@@ -21,6 +21,7 @@ import os
 
 from .environment import get_vta_hw_path
 
+
 def _get_lib_name(lib_name):
     """Get lib name with extension
 
@@ -34,9 +35,9 @@ def _get_lib_name(lib_name):
     lib_name : str
         Name of VTA shared library
     """
-    if sys.platform.startswith('win32'):
+    if sys.platform.startswith("win32"):
         return lib_name + ".dll"
-    if sys.platform.startswith('darwin'):
+    if sys.platform.startswith("darwin"):
         return lib_name + ".dylib"
     return lib_name + ".so"
 
@@ -58,12 +59,21 @@ def find_libvta(lib_vta, optional=False):
         Enable error check
     """
     curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-    lib_search = [os.path.join(curr_path, "..", "..", "..", "build",)]
+    lib_search = [
+        os.path.join(
+            curr_path,
+            "..",
+            "..",
+            "..",
+            "build",
+        )
+    ]
     lib_search += [os.path.join(get_vta_hw_path(), "build")]
     lib_name = _get_lib_name(lib_vta)
     lib_path = [os.path.join(x, lib_name) for x in lib_search]
     lib_found = [x for x in lib_path if os.path.exists(x)]
     if not lib_found and not optional:
-        raise RuntimeError('Cannot find the files.\n' +
-                           'List of candidates:\n' + str('\n'.join(lib_path)))
+        raise RuntimeError(
+            "Cannot find the files.\n" + "List of candidates:\n" + str("\n".join(lib_path))
+        )
     return lib_found
index 62cb5f2..556933a 100644 (file)
 import os
 import argparse
 
+
 def main():
     """Main function"""
     parser = argparse.ArgumentParser()
-    parser.add_argument("target", type=str, default="",
-                        help="target")
-    parser.add_argument("bitstream", type=str, default="",
-                        help="bitstream path")
+    parser.add_argument("target", type=str, default="", help="target")
+    parser.add_argument("bitstream", type=str, default="", help="bitstream path")
     args = parser.parse_args()
 
-    if args.target not in ('pynq', 'ultra96', 'de10nano', 'sim', 'tsim'):
+    if args.target not in ("pynq", "ultra96", "de10nano", "sim", "tsim"):
         raise RuntimeError("Unknown target {}".format(args.target))
 
-    curr_path = os.path.dirname(
-        os.path.abspath(os.path.expanduser(__file__)))
+    curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
     path_list = [
         os.path.join(curr_path, "/{}".format(args.bitstream)),
-        os.path.join('./', "{}".format(args.bitstream))
+        os.path.join("./", "{}".format(args.bitstream)),
     ]
     ok_path_list = [p for p in path_list if os.path.exists(p)]
     if not ok_path_list:
@@ -42,28 +40,34 @@ def main():
 
     bitstream_program(args.target, args.bitstream)
 
+
 def pynq_bitstream_program(bitstream_path):
     # pylint: disable=import-outside-toplevel
     from pynq import Bitstream
+
     bitstream = Bitstream(bitstream_path)
     bitstream.download()
 
+
 def de10nano_bitstream_program(bitstream_path):
     # pylint: disable=import-outside-toplevel
     from tvm import get_global_func
+
     program = get_global_func("vta.de10nano.program")
     program(bitstream_path)
 
+
 def bitstream_program(target, bitstream):
-    if target in ['pynq', 'ultra96']:
+    if target in ["pynq", "ultra96"]:
         pynq_bitstream_program(bitstream)
-    elif target in ['de10nano']:
+    elif target in ["de10nano"]:
         de10nano_bitstream_program(bitstream)
-    elif target in ['sim', 'tsim']:
+    elif target in ["sim", "tsim"]:
         # In simulation, bit stream programming is a no-op
         return
     else:
         raise RuntimeError("Unknown target {}".format(target))
 
+
 if __name__ == "__main__":
     main()
index 097ea8e..02ff8be 100644 (file)
@@ -20,6 +20,7 @@ import os
 from .environment import get_env
 from .bitstream import download_bitstream, get_bitstream_path
 
+
 def reconfig_runtime(remote):
     """Reconfigure remote runtime based on current hardware spec.
 
@@ -50,7 +51,7 @@ def program_fpga(remote, bitstream=None):
         bitstream = get_bitstream_path()
         if not os.path.isfile(bitstream):
             env = get_env()
-            if env.TARGET == 'de10nano':
+            if env.TARGET == "de10nano":
                 return
             download_bitstream()
 
index 8703b1c..fbc50b0 100644 (file)
@@ -17,4 +17,4 @@
 
 """Testing utilities, this namespace is not imported by default."""
 
-from . util import run
+from .util import run
index 16827c4..7f2471c 100644 (file)
@@ -21,6 +21,7 @@ import tvm
 from ..environment import get_env
 from ..libinfo import find_libvta
 
+
 def _load_sw():
     """Load hardware library for simulator."""
 
@@ -37,7 +38,7 @@ def _load_sw():
 
     if env.TARGET == "tsim":
         lib_hw = find_libvta("libvta_hw", optional=True)
-        assert lib_hw # make sure to make in ${VTA_HW_PATH}/hardware/chisel
+        assert lib_hw  # make sure to make in ${VTA_HW_PATH}/hardware/chisel
         try:
             f = tvm.get_global_func("vta.tsim.init")
             m = tvm.runtime.load_module(lib_hw[0], "vta-tsim")
@@ -85,6 +86,7 @@ def stats():
 # debug flag to skip execution.
 DEBUG_SKIP_EXEC = 1
 
+
 def debug_mode(flag):
     """Set debug mode
     Parameters
index afbf00d..99d8d40 100644 (file)
@@ -63,10 +63,9 @@ def run(run_func):
         pynq_port = os.environ.get("VTA_RPC_PORT", None)
         # Run device from fleet node if env variables are defined
         if tracker_host and tracker_port:
-            remote = autotvm.measure.request_remote(env.TARGET,
-                                                    tracker_host,
-                                                    int(tracker_port),
-                                                    timeout=10000)
+            remote = autotvm.measure.request_remote(
+                env.TARGET, tracker_host, int(tracker_port), timeout=10000
+            )
             run_func(env, remote)
         else:
             # Next, run on PYNQ if env variables are defined
@@ -75,7 +74,8 @@ def run(run_func):
                 run_func(env, remote)
             else:
                 raise RuntimeError(
-                    "Please set the VTA_RPC_HOST and VTA_RPC_PORT environment variables")
+                    "Please set the VTA_RPC_HOST and VTA_RPC_PORT environment variables"
+                )
 
     else:
         raise RuntimeError("Unknown target %s" % env.TARGET)
index 48a5c1c..52bd13d 100644 (file)
@@ -26,6 +26,7 @@ from tvm.topi import util
 from tvm.relay.op.op import register_compute, register_injective_schedule
 from tvm.relay.op.op import register_pattern, OpPattern
 
+
 def bitpack(data, bits, pack_type="int8", name="bitpack"):
     """Packs lowest dimension into format needed by VTA
 
@@ -42,11 +43,11 @@ def bitpack(data, bits, pack_type="int8", name="bitpack"):
         The packed tensor.
     """
     shape_vec = list(data.shape)
-    if pack_type == 'int8':
+    if pack_type == "int8":
         data_width = 8
-    elif pack_type == 'int16':
+    elif pack_type == "int16":
         data_width = 16
-    elif pack_type == 'int32':
+    elif pack_type == "int32":
         data_width = 32
     else:
         raise RuntimeError("Unknown pack type %s" % pack_type)
@@ -72,8 +73,7 @@ def bitpack(data, bits, pack_type="int8", name="bitpack"):
                 ret = ret | val
         return ret
 
-    return te.compute(
-        oshape, _bitpack, name=name, tag='bitpack')
+    return te.compute(oshape, _bitpack, name=name, tag="bitpack")
 
 
 @register_compute("bitpack", level=15)
@@ -86,5 +86,6 @@ def compute_bitpack(attrs, inputs):
     bits = 8 // lanes
     return bitpack(inputs[0], bits, dtype)
 
+
 register_injective_schedule("bitpack")
 register_pattern("bitpack", OpPattern.INJECTIVE)
index 633ef3f..15421c9 100644 (file)
@@ -22,6 +22,7 @@ from tvm import relay
 from tvm.relay import op, transform
 from tvm.relay import ExprMutator
 
+
 def run_opt_pass(expr, opt_pass):
     """Exectue a relay pass."""
     assert isinstance(opt_pass, tvm.transform.Pass)
@@ -30,47 +31,51 @@ def run_opt_pass(expr, opt_pass):
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
+
 def _to_shape(shape):
-    """ convert shape into tuple.
-    """
+    """convert shape into tuple."""
     return tuple(int(sh) for sh in shape)
 
+
 def _pack_batch_channel(data, dshape, bfactor, cfactor):
-    """Pack the data channel dimension.
-    """
+    """Pack the data channel dimension."""
     assert int(dshape[0]) % bfactor == 0
     assert int(dshape[1]) % cfactor == 0
-    data = op.reshape(data,
-                      newshape=(int(dshape[0]) // bfactor, bfactor,
-                                int(dshape[1]) // cfactor, cfactor,
-                                int(dshape[2]), int(dshape[3])))
-    data = op.transpose(
-        data, axes=(0, 2, 4, 5, 1, 3))
+    data = op.reshape(
+        data,
+        newshape=(
+            int(dshape[0]) // bfactor,
+            bfactor,
+            int(dshape[1]) // cfactor,
+            cfactor,
+            int(dshape[2]),
+            int(dshape[3]),
+        ),
+    )
+    data = op.transpose(data, axes=(0, 2, 4, 5, 1, 3))
     return data
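The reshape/transpose pair in _pack_batch_channel realizes the NCHW -> NCHW{bfactor}n{cfactor}c packing. The same index shuffle in plain NumPy, with made-up shapes and factors for illustration:

    import numpy as np

    bfactor, cfactor = 1, 16                     # assumed example factors
    x = np.arange(1 * 32 * 4 * 4).reshape(1, 32, 4, 4)          # NCHW input
    packed = x.reshape(1 // bfactor, bfactor, 32 // cfactor, cfactor, 4, 4)
    packed = packed.transpose(0, 2, 4, 5, 1, 3)  # -> (1, 2, 4, 4, 1, 16)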
 
 
 def _unpack_batch_channel(data, old_shape):
-    """Unpack the data channel dimension.
-    """
+    """Unpack the data channel dimension."""
     data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
     data = op.reshape(data, newshape=old_shape)
     return data
 
 
 def _const_shape_match(data, dshape, cfactor_out):
-    """ Pad the constant if the shape[0] not divisible by cfactor_out.
-    """
+    """Pad the constant if the shape[0] not divisible by cfactor_out."""
     assert len(dshape) == 3
     pad_width = int(dshape[0]) % cfactor_out
     if pad_width != 0:
-        pad_width = cfactor_out -pad_width
+        pad_width = cfactor_out - pad_width
         data = op.nn.pad(data, [[0, pad_width], [0, 0], [0, 0]])
         dshape = tuple([dshape[0] + pad_width, dshape[1], dshape[2]])
     return data, dshape
 
+
 def _weight_shape_match(data, dshape, channels, cfactor_out, transpose=False):
-    """ Pad the weight if the shape[0] not divisible by cfactor_out.
-    """
+    """Pad the weight if the shape[0] not divisible by cfactor_out."""
     assert len(dshape) == 4
     pad_width = int(dshape[0]) % cfactor_out
     channels_pad = int(channels) % cfactor_out
@@ -84,9 +89,9 @@ def _weight_shape_match(data, dshape, channels, cfactor_out, transpose=False):
 
     return data, dshape, channels
 
+
 def _weight_shape_match_transpose(data, dshape, channels, cfactor_out):
-    """ Pad the weight if the shape[1] not divisible by cfactor_out.
-    """
+    """Pad the weight if the shape[1] not divisible by cfactor_out."""
     assert len(dshape) == 4
     pad_width = int(dshape[1]) % cfactor_out
     channels_pad = int(channels) % cfactor_out
@@ -100,91 +105,97 @@ def _weight_shape_match_transpose(data, dshape, channels, cfactor_out):
 
     return data, dshape, channels
 
+
 def _pack_weight(data, dshape, cfactor):
-    """Pack the weight into packed format.
-    """
+    """Pack the weight into packed format."""
     assert len(dshape) == 4
     assert int(dshape[0]) % cfactor == 0
     assert int(dshape[1]) % cfactor == 0
-    data = op.reshape(data,
-                      newshape=(int(dshape[0]) // cfactor, cfactor,
-                                int(dshape[1]) // cfactor, cfactor,
-                                int(dshape[2]), int(dshape[3])))
-    data = op.transpose(
-        data, axes=(0, 2, 4, 5, 1, 3))
+    data = op.reshape(
+        data,
+        newshape=(
+            int(dshape[0]) // cfactor,
+            cfactor,
+            int(dshape[1]) // cfactor,
+            cfactor,
+            int(dshape[2]),
+            int(dshape[3]),
+        ),
+    )
+    data = op.transpose(data, axes=(0, 2, 4, 5, 1, 3))
     return data
 
 
 def _pack_weight_conv2d_transpose(data, dshape, cfactor):
-    """Pack the weight into packed format.
-    """
+    """Pack the weight into packed format."""
     dshape = _to_shape(dshape)
     assert len(dshape) == 4
     assert dshape[0] % cfactor == 0
     assert dshape[1] % cfactor == 0
-    data = op.reshape(data,
-                      newshape=(dshape[0] // cfactor, cfactor,
-                                dshape[1] // cfactor, cfactor,
-                                dshape[2], dshape[3]))
-    data = op.transpose(
-        data, axes=(2, 0, 4, 5, 3, 1))
+    data = op.reshape(
+        data,
+        newshape=(
+            dshape[0] // cfactor,
+            cfactor,
+            dshape[1] // cfactor,
+            cfactor,
+            dshape[2],
+            dshape[3],
+        ),
+    )
+    data = op.transpose(data, axes=(2, 0, 4, 5, 3, 1))
     return data
 
 
 def _pack_const(data, dshape, dtype, bfactor, cfactor):
-    """Pack a constant parameter.
-    """
+    """Pack a constant parameter."""
     dshape = _to_shape(dshape)
     assert len(dshape) == 3
     assert dshape[0] % cfactor == 0
-    data = op.reshape(data,
-                      newshape=(dshape[0] // cfactor,
-                                cfactor, dshape[1],
-                                dshape[2], 1))
-    data = op.transpose(
-        data, axes=(0, 2, 3, 4, 1))
+    data = op.reshape(data, newshape=(dshape[0] // cfactor, cfactor, dshape[1], dshape[2], 1))
+    data = op.transpose(data, axes=(0, 2, 3, 4, 1))
 
     # broadcast batch dimension to bfactor
     data = op.broadcast_to(
-        data,
-        shape=(dshape[0] // cfactor, dshape[1], dshape[2], bfactor, cfactor))
+        data, shape=(dshape[0] // cfactor, dshape[1], dshape[2], bfactor, cfactor)
+    )
     return data
 
 
 def _get_tensor_shape(node):
-    """Get node shape.
-    """
+    """Get node shape."""
     if isinstance(node.checked_type, relay.ty.TensorType):
         return _to_shape(node.checked_type.shape)
     return []
 
+
 def _get_tensor_type(node):
-    """Get node type.
-    """
+    """Get node type."""
     if isinstance(node.checked_type, relay.ty.TensorType):
         return node.checked_type.dtype
     return "float32"
 
+
 def _operator_idx_inc(expr, count_meta, operator_current_idx):
-    """Increase operator index
-    """
+    """Increase operator index"""
     if isinstance(expr, relay.expr.Constant):
         operator_current_idx = operator_current_idx + 1 if count_meta else operator_current_idx
     else:
         operator_current_idx = operator_current_idx + 1
     return operator_current_idx
 
+
 class ExprPack(ExprMutator):
-    """Visitor to perform graph packing on an AST.
-    """
+    """Visitor to perform graph packing on an AST."""
+
     def __init__(self, bfactor, cfactor, weight_bits):
         self.bfactor = bfactor
         self.cfactor = cfactor
         self.weight_bits = weight_bits
         self.start_pack = False
         # Cache Operator the algorithm matches against.
-        self.bitpack_start = op.op.get('annotation.bitpack_start')
-        self.bitpack_end = op.op.get('annotation.bitpack_end')
+        self.bitpack_start = op.op.get("annotation.bitpack_start")
+        self.bitpack_end = op.op.get("annotation.bitpack_end")
         self.conv2d = op.op.get("nn.conv2d")
         self.conv2d_transpose = op.op.get("nn.conv2d_transpose")
         self.add = op.op.get("add")
@@ -217,7 +228,7 @@ class ExprPack(ExprMutator):
                 return _unpack_batch_channel(data, data_shape)
         if self.start_pack:
             # Operator cases
-            if call.op == self.conv2d and odtype == 'int32':
+            if call.op == self.conv2d and odtype == "int32":
                 self.number_of_conv2d += 1
                 assert 8 % self.weight_bits == 0
                 w_lanes = 8 // self.weight_bits
@@ -227,10 +238,9 @@ class ExprPack(ExprMutator):
                 data_shape = _to_shape(input_types[0].shape)
                 kernel_shape = _to_shape(input_types[1].shape)
                 channels = call.attrs.channels
-                weight, kernel_shape, channels = _weight_shape_match(weight,
-                                                                     kernel_shape,
-                                                                     channels,
-                                                                     self.cfactor)
+                weight, kernel_shape, channels = _weight_shape_match(
+                    weight, kernel_shape, channels, self.cfactor
+                )
                 kernel = _pack_weight(weight, kernel_shape, self.cfactor)
                 # insert bit packing when necessary
                 if w_lanes != 1:
@@ -248,10 +258,11 @@ class ExprPack(ExprMutator):
                     kernel_size=call.attrs.kernel_size,
                     data_layout=data_layout,
                     kernel_layout=kernel_layout,
-                    out_dtype=call.attrs.out_dtype)
+                    out_dtype=call.attrs.out_dtype,
+                )
                 return conv2d
 
-            if call.op == self.conv2d_transpose and odtype == 'int32':
+            if call.op == self.conv2d_transpose and odtype == "int32":
                 self.number_of_conv2d += 1
                 assert 8 % self.weight_bits == 0
                 w_lanes = 8 // self.weight_bits
@@ -262,10 +273,9 @@ class ExprPack(ExprMutator):
                     data_shape = _to_shape(input_types[0].shape)
                     kernel_shape = _to_shape(input_types[1].shape)
                     channels = call.attrs.channels
-                    weight, kernel_shape, channels = _weight_shape_match_transpose(weight,
-                                                                                   kernel_shape,
-                                                                                   channels,
-                                                                                   self.cfactor)
+                    weight, kernel_shape, channels = _weight_shape_match_transpose(
+                        weight, kernel_shape, channels, self.cfactor
+                    )
                     kernel = _pack_weight_conv2d_transpose(weight, kernel_shape, self.cfactor)
                     conv2d = op.nn.conv2d_transpose(
                         data,
@@ -279,100 +289,98 @@ class ExprPack(ExprMutator):
                         data_layout=data_layout,
                         kernel_layout=kernel_layout,
                         output_padding=call.attrs.output_padding,
-                        out_dtype=call.attrs.out_dtype)
+                        out_dtype=call.attrs.out_dtype,
+                    )
                 return conv2d
-            if call.op == self.add and \
-                    tuple(input_types[0].shape) == tuple(input_types[1].shape):
+            if call.op == self.add and tuple(input_types[0].shape) == tuple(input_types[1].shape):
                 pass
             elif call.op == self.add and len(input_types[1].shape) == 3:
                 data, const = args
-                const, input_shape = _const_shape_match(const,
-                                                        input_types[1].shape,
-                                                        self.cfactor)
-                const = _pack_const(const,
-                                    _to_shape(input_shape),
-                                    input_types[1].dtype,
-                                    self.bfactor,
-                                    self.cfactor)
+                const, input_shape = _const_shape_match(const, input_types[1].shape, self.cfactor)
+                const = _pack_const(
+                    const, _to_shape(input_shape), input_types[1].dtype, self.bfactor, self.cfactor
+                )
                 return relay.Call(self.add, [data, const])
-            elif call.op == self.multiply and \
-                    tuple(input_types[0].shape) == tuple(input_types[1].shape):
+            elif call.op == self.multiply and tuple(input_types[0].shape) == tuple(
+                input_types[1].shape
+            ):
                 pass
             elif call.op == self.multiply and len(input_types[1].shape) == 3:
                 data, const = args
-                const = _pack_const(const,
-                                    _to_shape(input_types[1].shape),
-                                    input_types[1].dtype,
-                                    self.bfactor,
-                                    self.cfactor)
+                const = _pack_const(
+                    const,
+                    _to_shape(input_types[1].shape),
+                    input_types[1].dtype,
+                    self.bfactor,
+                    self.cfactor,
+                )
                 return relay.Call(self.multiply, [data, const])
             elif self.start_pack and call.op == self.bias_add:
                 data, bias = args
-                bias = _pack_const(bias,
-                                   _to_shape(input_types[1].shape),
-                                   input_types[1].dtype,
-                                   self.bfactor,
-                                   self.cfactor)
+                bias = _pack_const(
+                    bias,
+                    _to_shape(input_types[1].shape),
+                    input_types[1].dtype,
+                    self.bfactor,
+                    self.cfactor,
+                )
                 return relay.Call(self.add, [data, bias])
-            elif self.start_pack and call.op == op.op.get('cast') and \
-                    input_types[0].dtype == 'int32':
-                cast = relay.Call(op.op.get('cast'), [args[0]], call.attrs)
-                return relay.Call(op.op.get('copy'), [cast])
+            elif (
+                self.start_pack and call.op == op.op.get("cast") and input_types[0].dtype == "int32"
+            ):
+                cast = relay.Call(op.op.get("cast"), [args[0]], call.attrs)
+                return relay.Call(op.op.get("copy"), [cast])
             elif call.op == self.pad:
                 pad_width = call.attrs.pad_width
                 if len(pad_width) == 6:
                     pass
                 elif len(pad_width) == 4:
-                    data, = args
+                    (data,) = args
                     new_pad_width = []
                     new_pad_width.extend(pad_width)
                     for _ in range(2):
                         new_pad_width.append([0, 0])
-                    return op.nn.pad(data,
-                                     pad_value=call.attrs.pad_value,
-                                     pad_width=new_pad_width)
+                    return op.nn.pad(data, pad_value=call.attrs.pad_value, pad_width=new_pad_width)
             elif call.op == self.upsampling:
-                data, = args
+                (data,) = args
                 scale_h = call.attrs.scale_h
                 scale_w = call.attrs.scale_w
                 data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor)
                 method = call.attrs.method
                 align_corners = call.attrs.align_corners
-                return op.nn.upsampling(data,
-                                        scale_h,
-                                        scale_w,
-                                        data_layout,
-                                        method,
-                                        align_corners)
+                return op.nn.upsampling(data, scale_h, scale_w, data_layout, method, align_corners)
             elif call.op == self.reshape and len(input_types[0].shape) == 4:
-                data, = args
+                (data,) = args
                 data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
                 return op.reshape(data, [int(x) for x in input_types[0].shape])
 
-        return relay.Call(
-            self.visit(call.op),
-            args,
-            call.attrs)
+        return relay.Call(self.visit(call.op), args, call.attrs)
+
 
 class BT(Exception):
     pass
+
+
 def get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta):
-    """ We assume stop_name only appears once for simplicity.
-        This constraint will be lifted in the future.
-        bitpack_start and bitpack_end are both inclusive.
+    """We assume stop_name only appears once for simplicity.
+    This constraint will be lifted in the future.
+    bitpack_start and bitpack_end are both inclusive.
     """
-    bitpack_start = op.op.get('annotation.bitpack_start')
-    bitpack_end = op.op.get('annotation.bitpack_end')
+    bitpack_start = op.op.get("annotation.bitpack_start")
+    bitpack_end = op.op.get("annotation.bitpack_end")
     anf = run_opt_pass(expr, transform.ToANormalForm())
     operator_current_idx = 0
+
     def _recursion(anf, start_found, stop_found, operator_current_idx):
-        """ Helper to obtain the subgraph.
-        """
+        """Helper to obtain the subgraph."""
         if isinstance(anf, relay.Function):
-            return relay.Function(anf.params,
-                                  _recursion(anf.body, start_found, stop_found,
-                                             operator_current_idx),
-                                  anf.ret_type, anf.type_params, anf.attrs)
+            return relay.Function(
+                anf.params,
+                _recursion(anf.body, start_found, stop_found, operator_current_idx),
+                anf.ret_type,
+                anf.type_params,
+                anf.attrs,
+            )
         if isinstance(anf, relay.expr.Let):
             value = anf.value
             if isinstance(value, relay.expr.Call):
@@ -388,8 +396,11 @@ def get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, cou
             operator_current_idx = _operator_idx_inc(value, count_meta, operator_current_idx)
 
             try:
-                return relay.expr.Let(anf.var, value, _recursion(anf.body, start_found, stop_found,
-                                                                 operator_current_idx))
+                return relay.expr.Let(
+                    anf.var,
+                    value,
+                    _recursion(anf.body, start_found, stop_found, operator_current_idx),
+                )
             except BT:
                 assert start_found
                 assert not stop_found
@@ -401,18 +412,22 @@ def get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, cou
             assert start_found
             assert stop_found
             return anf
+
     annotated = _recursion(anf, False, False, operator_current_idx)
     return run_opt_pass(annotated, transform.ToGraphNormalForm())
 
-def graph_pack(expr,
-               bfactor,
-               cfactor,
-               weight_bits,
-               start_name="nn.max_pool2d",
-               stop_name="nn.global_avg_pool2d",
-               start_name_idx=None,
-               stop_name_idx=None,
-               count_meta=False):
+
+def graph_pack(
+    expr,
+    bfactor,
+    cfactor,
+    weight_bits,
+    start_name="nn.max_pool2d",
+    stop_name="nn.global_avg_pool2d",
+    start_name_idx=None,
+    stop_name_idx=None,
+    count_meta=False,
+):
     """Pack the graph into batch&channel packed format.
 
     Parameters
@@ -455,12 +470,10 @@ def graph_pack(expr,
         The transformed expression.
     """
     assert isinstance(expr, relay.Function)
-    assert ((start_name != stop_name) or (start_name_idx < stop_name_idx))
+    assert (start_name != stop_name) or (start_name_idx < stop_name_idx)
     expr = get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta)
     expr = run_opt_pass(expr, transform.InferType())
-    packer = ExprPack(
-        bfactor, cfactor,
-        weight_bits)
+    packer = ExprPack(bfactor, cfactor, weight_bits)
     expr = packer.visit(expr)
     assert not packer.start_pack
     return run_opt_pass(expr, transform.InferType())
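For context, a hypothetical sketch of how graph_pack is typically called on an already-quantized Relay module; the module, the environment attributes, and the start/stop operator names are assumptions mirroring the defaults above:

    import vta
    from vta.top import graph_pack

    def pack_for_vta(relay_mod):
        """Sketch only: relay_mod is assumed to be a quantized Relay module."""
        env = vta.get_env()
        return graph_pack(
            relay_mod["main"],
            env.BATCH,
            env.BLOCK_OUT,
            env.WGT_WIDTH,
            start_name="nn.max_pool2d",
            stop_name="nn.global_avg_pool2d",
        )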
index 8280798..2710557 100644 (file)
@@ -46,22 +46,24 @@ def compute_clip_vta(attrs, inputs, output_type):
     const_min = tvm.tir.const(a_min, x.dtype)
     const_max = tvm.tir.const(a_max, x.dtype)
     with tvm.te.tag_scope(topi.tag.ELEMWISE):
-        x = te.compute(
-            x.shape, lambda *i: tvm.te.min(x(*i), const_max), name="clipA")
-        x = te.compute(
-            x.shape, lambda *i: tvm.te.max(x(*i), const_min), name="clipB")
+        x = te.compute(x.shape, lambda *i: tvm.te.min(x(*i), const_max), name="clipA")
+        x = te.compute(x.shape, lambda *i: tvm.te.max(x(*i), const_min), name="clipB")
     return [x]
 
+
 def clip_strategy_vta(attrs, inputs, out_type, target):
     strategy = OpStrategy()
     strategy.add_implementation(
         compute_clip_vta,
         _strategy.wrap_topi_schedule(topi.generic.schedule_injective),
-        name="clip.vta")
+        name="clip.vta",
+    )
     return strategy
 
+
 reg.get("clip").get_attr("FTVMStrategy").register(clip_strategy_vta, "vta")
 
+
 @_strategy.conv2d_strategy.register("vta")
 def conv2d_strategy_vta(attrs, inputs, out_type, target):
     """conv2d vta strategy"""
@@ -82,12 +84,14 @@ def conv2d_strategy_vta(attrs, inputs, out_type, target):
             strategy.add_implementation(
                 _strategy.wrap_compute_conv2d(conv2d_packed, True),
                 _strategy.wrap_topi_schedule(schedule_conv2d_packed),
-                name="conv2d_packed.vta")
-        else: # group_conv2d
+                name="conv2d_packed.vta",
+            )
+        else:  # group_conv2d
             strategy.add_implementation(
                 _strategy.wrap_compute_conv2d(group_conv2d_packed, has_groups=True),
                 _strategy.wrap_topi_schedule(schedule_group_conv2d_packed),
-                name="group_conv2d_packed.vta")
+                name="group_conv2d_packed.vta",
+            )
         return strategy
 
     # If it's not packed, run on ARM CPU
@@ -107,7 +111,8 @@ def conv2d_transpose_strategy_vta(attrs, inputs, out_type, target):
         strategy.add_implementation(
             _strategy.wrap_compute_conv2d_transpose(conv2d_transpose_packed),
             _strategy.wrap_topi_schedule(schedule_conv2d_transpose_packed),
-            name="conv2d_transpose_packed.vta")
+            name="conv2d_transpose_packed.vta",
+        )
         return strategy
 
     # If it's not packed, run on ARM CPU
@@ -118,12 +123,13 @@ def conv2d_transpose_strategy_vta(attrs, inputs, out_type, target):
 @_strategy.dense_strategy.register("vta")
 def dense_strategy_vta(attrs, inputs, out_type, target):
     """dense vta strategy"""
-    if inputs[0].shape == 4: # this implies the layout is packed
+    if len(inputs[0].shape) == 4:  # this implies the layout is packed
         strategy = OpStrategy()
         strategy.add_implementation(
             _strategy.wrap_compute_dense(dense_packed),
             _strategy.wrap_topi_schedule(schedule_dense_packed),
-            name="dense_packed.vta")
+            name="dense_packed.vta",
+        )
         return strategy
     # If it's not packed, run on ARM CPU
     arm_tgt = tvm.target.arm_cpu(target.model)
index 0fbdb2f..46a3a88 100644 (file)
@@ -16,6 +16,7 @@
 # under the License.
 """VTA TOPI Utils."""
 
+
 def is_packed_layout(layout):
     """Check if layout is packed layout"""
     if layout == "NCHW":
index 799b105..e155565 100644 (file)
@@ -26,6 +26,7 @@ from tvm import topi
 from .util import is_packed_layout
 from ..environment import get_env
 
+
 @autotvm.register_topi_compute("conv2d_packed.vta")
 def conv2d_packed(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
     """ Packed conv2d function."""
@@ -45,24 +46,34 @@ def conv2d_packed(cfg, data, kernel, strides, padding, dilation, layout, out_dty
 
     ishape = topi.util.get_const_tuple(data.shape)
     kshape = topi.util.get_const_tuple(kernel.shape)
-    d_i = te.reduce_axis((0, kshape[2]), name='d_i')
-    d_j = te.reduce_axis((0, kshape[3]), name='d_j')
-    k_o = te.reduce_axis((0, ishape[1]), name='k_o')
-    k_i = te.reduce_axis((0, ishape[-1]), name='k_i')
+    d_i = te.reduce_axis((0, kshape[2]), name="d_i")
+    d_j = te.reduce_axis((0, kshape[3]), name="d_j")
+    k_o = te.reduce_axis((0, ishape[1]), name="k_o")
+    k_i = te.reduce_axis((0, ishape[-1]), name="k_i")
     hstride, wstride = strides
     res = te.compute(
         oshape,
         lambda b_o, c_o, i, j, b_i, c_i: te.sum(
-            pad_data[b_o, k_o, i*hstride+d_i, j*wstride+d_j, b_i, k_i].astype(out_dtype) *
-            kernel[c_o, k_o, d_i, d_j, c_i, k_i].astype(out_dtype),
-            axis=[k_o, d_i, d_j, k_i]),
-        name="res", tag="conv2d_dense")
-
-    cfg.add_flop(2 * np.prod(topi.util.get_const_tuple(oshape)) *
-                 kshape[2] * kshape[3] * ishape[1] * ishape[-1])
+            pad_data[b_o, k_o, i * hstride + d_i, j * wstride + d_j, b_i, k_i].astype(out_dtype)
+            * kernel[c_o, k_o, d_i, d_j, c_i, k_i].astype(out_dtype),
+            axis=[k_o, d_i, d_j, k_i],
+        ),
+        name="res",
+        tag="conv2d_dense",
+    )
+
+    cfg.add_flop(
+        2
+        * np.prod(topi.util.get_const_tuple(oshape))
+        * kshape[2]
+        * kshape[3]
+        * ishape[1]
+        * ishape[-1]
+    )
 
     return res
 
+
 @autotvm.register_topi_schedule("conv2d_packed.vta")
 def schedule_conv2d_packed(cfg, outs):
     """Schedule packed conv2d"""
@@ -98,13 +109,13 @@ def schedule_conv2d_packed(cfg, outs):
     ##### space definition begin #####
     b, c_o, x_i, x_j, _, _ = s[conv2d_stage].op.axis
     c_i, _, _, _ = s[conv2d_stage].op.reduce_axis
-    cfg.define_split('tile_b', b, num_outputs=2)
-    cfg.define_split('tile_h', x_i, num_outputs=2)
-    cfg.define_split('tile_w', x_j, num_outputs=2)
-    cfg.define_split('tile_ci', c_i, num_outputs=2)
-    cfg.define_split('tile_co', c_o, num_outputs=2)
-    cfg.define_knob('oc_nthread', [1, 2])
-    cfg.define_knob('h_nthread', [1, 2])
+    cfg.define_split("tile_b", b, num_outputs=2)
+    cfg.define_split("tile_h", x_i, num_outputs=2)
+    cfg.define_split("tile_w", x_j, num_outputs=2)
+    cfg.define_split("tile_ci", c_i, num_outputs=2)
+    cfg.define_split("tile_co", c_o, num_outputs=2)
+    cfg.define_knob("oc_nthread", [1, 2])
+    cfg.define_knob("h_nthread", [1, 2])
     ###### space definition end ######
 
     data, kernel = conv2d_stage.op.input_tensors
@@ -129,8 +140,7 @@ def schedule_conv2d_packed(cfg, outs):
     # cache read input
     cache_read_ewise = []
     for consumer, tensor in ewise_inputs:
-        cache_read_ewise.append(
-            s.cache_read(tensor, env.acc_scope, [consumer]))
+        cache_read_ewise.append(s.cache_read(tensor, env.acc_scope, [consumer]))
 
     # set ewise scope
     for op in ewise_ops:
@@ -142,9 +152,9 @@ def schedule_conv2d_packed(cfg, outs):
 
     # tile
     x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis
-    x_co0, x_co1 = cfg['tile_co'].apply(s, output, x_co)
-    x_i0, x_i1 = cfg['tile_h'].apply(s, output, x_i)
-    x_j0, x_j1 = cfg['tile_w'].apply(s, output, x_j)
+    x_co0, x_co1 = cfg["tile_co"].apply(s, output, x_co)
+    x_i0, x_i1 = cfg["tile_h"].apply(s, output, x_i)
+    x_j0, x_j1 = cfg["tile_w"].apply(s, output, x_j)
     s[output].reorder(x_bo, x_i0, x_co0, x_j0, x_co1, x_i1, x_j1, x_bi, x_ci)
     store_pt = x_j0
 
@@ -158,14 +168,14 @@ def schedule_conv2d_packed(cfg, outs):
         s[tensor].pragma(s[tensor].op.axis[0], env.dma_copy)
 
     # virtual threading along output channel axes
-    if cfg['oc_nthread'].val > 1:
-        _, v_t = s[output].split(x_co0, factor=cfg['oc_nthread'].val)
+    if cfg["oc_nthread"].val > 1:
+        _, v_t = s[output].split(x_co0, factor=cfg["oc_nthread"].val)
         s[output].reorder(v_t, x_bo)
         s[output].bind(v_t, te.thread_axis("cthread"))
 
     # virtual threading along spatial rows
-    if cfg['h_nthread'].val > 1:
-        _, v_t = s[output].split(x_i0, factor=cfg['h_nthread'].val)
+    if cfg["h_nthread"].val > 1:
+        _, v_t = s[output].split(x_i0, factor=cfg["h_nthread"].val)
         s[output].reorder(v_t, x_bo)
         s[output].bind(v_t, te.thread_axis("cthread"))
 
@@ -173,7 +183,7 @@ def schedule_conv2d_packed(cfg, outs):
     k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis
     s[conv2d_stage].reorder(x_bo, k_o, x_j, d_j, d_i, x_co, x_i, x_bi, x_ci, k_i)
 
-    k_o, _ = cfg['tile_ci'].apply(s, conv2d_stage, k_o)
+    k_o, _ = cfg["tile_ci"].apply(s, conv2d_stage, k_o)
     s[cdata].compute_at(s[conv2d_stage], k_o)
     s[ckernel].compute_at(s[conv2d_stage], k_o)
 
index ea0dfce..c020747 100644
@@ -27,9 +27,9 @@ from tvm.topi.nn.util import get_pad_tuple
 
 from ..environment import get_env
 
+
 @autotvm.register_topi_compute("conv2d_transpose_packed.vta")
-def conv2d_transpose_packed(cfg, data, kernel, strides, padding, out_dtype,
-                            output_padding=(0, 0)):
+def conv2d_transpose_packed(cfg, data, kernel, strides, padding, out_dtype, output_padding=(0, 0)):
     """Packed conv2d_transpose compute"""
     ishape = get_const_tuple(data.shape)
     kshape = get_const_tuple(kernel.shape)
@@ -49,33 +49,42 @@ def conv2d_transpose_packed(cfg, data, kernel, strides, padding, out_dtype,
 
     # padding stage
     dilated_input = topi.nn.dilate(data, [1, 1, stride_h, stride_w, 1, 1])
-    data_pad = topi.nn.pad(dilated_input,
-                           [0, 0, bpad_top, bpad_left, 0, 0],
-                           [0, 0, bpad_bottom, bpad_right, 0, 0])
+    data_pad = topi.nn.pad(
+        dilated_input, [0, 0, bpad_top, bpad_left, 0, 0], [0, 0, bpad_bottom, bpad_right, 0, 0]
+    )
 
     # convolution transpose stage
     out_h = (i_h - 1) * stride_h - fpad_top - fpad_bottom + k_h + opad_h
     out_w = (i_w - 1) * stride_w - fpad_left - fpad_right + k_w + opad_w
     oshape = (b, c_o, out_h, out_w, t_b, t_co)
-    d_c = te.reduce_axis((0, c_i), name='d_c')
-    d_h = te.reduce_axis((0, k_h), name='d_h')
-    d_w = te.reduce_axis((0, k_w), name='d_w')
-    d_ci = te.reduce_axis((0, t_ci), name='d_ci')
+    d_c = te.reduce_axis((0, c_i), name="d_c")
+    d_h = te.reduce_axis((0, k_h), name="d_h")
+    d_w = te.reduce_axis((0, k_w), name="d_w")
+    d_ci = te.reduce_axis((0, t_ci), name="d_ci")
 
     out = te.compute(
         oshape,
         lambda i_n, i_c, i_h, i_w, j_n, j_c: te.sum(
-            data_pad(i_n, d_c, i_h + d_h, i_w + d_w, j_n, d_ci).astype(out_dtype) *
-            kernel[i_c, d_c, d_h, d_w, j_c, d_ci].astype(out_dtype),
-            axis=[d_c, d_h, d_w, d_ci]),
+            data_pad(i_n, d_c, i_h + d_h, i_w + d_w, j_n, d_ci).astype(out_dtype)
+            * kernel[i_c, d_c, d_h, d_w, j_c, d_ci].astype(out_dtype),
+            axis=[d_c, d_h, d_w, d_ci],
+        ),
         tag="packed_conv2d_transpose",
-        name='res')
-
-    cfg.add_flop(2 * np.prod(topi.util.get_const_tuple(oshape)) *
-                 kshape[2] * kshape[3] * ishape[1] * ishape[-1])
+        name="res",
+    )
+
+    cfg.add_flop(
+        2
+        * np.prod(topi.util.get_const_tuple(oshape))
+        * kshape[2]
+        * kshape[3]
+        * ishape[1]
+        * ishape[-1]
+    )
 
     return out
 
+
 @autotvm.register_topi_schedule("conv2d_transpose_packed.vta")
 def schedule_conv2d_transpose_packed(cfg, outs):
     """Schedule packed conv2d_transpose"""
@@ -108,13 +117,13 @@ def schedule_conv2d_transpose_packed(cfg, outs):
     ##### space definition begin #####
     b, c_o, x_i, x_j, _, c_i = s[conv2d_stage].op.axis
     c_i, _, _, _ = s[conv2d_stage].op.reduce_axis
-    cfg.define_split('tile_b', b, num_outputs=2)
-    cfg.define_split('tile_h', x_i, num_outputs=2)
-    cfg.define_split('tile_w', x_j, num_outputs=2)
-    cfg.define_split('tile_ci', c_i, num_outputs=2)
-    cfg.define_split('tile_co', c_o, num_outputs=2)
-    cfg.define_knob('oc_nthread', [1, 2])
-    cfg.define_knob('h_nthread', [1, 2])
+    cfg.define_split("tile_b", b, num_outputs=2)
+    cfg.define_split("tile_h", x_i, num_outputs=2)
+    cfg.define_split("tile_w", x_j, num_outputs=2)
+    cfg.define_split("tile_ci", c_i, num_outputs=2)
+    cfg.define_split("tile_co", c_o, num_outputs=2)
+    cfg.define_knob("oc_nthread", [1, 2])
+    cfg.define_knob("h_nthread", [1, 2])
     ###### space definition end ######
 
     data, kernel = conv2d_stage.op.input_tensors
@@ -139,8 +148,7 @@ def schedule_conv2d_transpose_packed(cfg, outs):
     # cache read input
     cache_read_ewise = []
     for consumer, tensor in ewise_inputs:
-        cache_read_ewise.append(
-            s.cache_read(tensor, env.acc_scope, [consumer]))
+        cache_read_ewise.append(s.cache_read(tensor, env.acc_scope, [consumer]))
     # set ewise scope
     for op in ewise_ops:
         s[op].set_scope(env.acc_scope)
@@ -148,9 +156,9 @@ def schedule_conv2d_transpose_packed(cfg, outs):
 
     # tile
     x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis
-    x_co0, x_co1 = cfg['tile_co'].apply(s, output, x_co)
-    x_i0, x_i1 = cfg['tile_h'].apply(s, output, x_i)
-    x_j0, x_j1 = cfg['tile_w'].apply(s, output, x_j)
+    x_co0, x_co1 = cfg["tile_co"].apply(s, output, x_co)
+    x_i0, x_i1 = cfg["tile_h"].apply(s, output, x_i)
+    x_j0, x_j1 = cfg["tile_w"].apply(s, output, x_j)
     s[output].reorder(x_bo, x_i0, x_co0, x_j0, x_co1, x_i1, x_j1, x_bi, x_ci)
     store_pt = x_j0
 
@@ -164,14 +172,14 @@ def schedule_conv2d_transpose_packed(cfg, outs):
         s[tensor].pragma(s[tensor].op.axis[0], env.dma_copy)
 
     # virtual threading along output channel axes
-    if cfg['oc_nthread'].val > 1:
-        _, v_t = s[output].split(x_co0, factor=cfg['oc_nthread'].val)
+    if cfg["oc_nthread"].val > 1:
+        _, v_t = s[output].split(x_co0, factor=cfg["oc_nthread"].val)
         s[output].reorder(v_t, x_bo)
         s[output].bind(v_t, te.thread_axis("cthread"))
 
     # virtual threading along spatial rows
-    if cfg['h_nthread'].val > 1:
-        _, v_t = s[output].split(x_i0, factor=cfg['h_nthread'].val)
+    if cfg["h_nthread"].val > 1:
+        _, v_t = s[output].split(x_i0, factor=cfg["h_nthread"].val)
         s[output].reorder(v_t, x_bo)
         s[output].bind(v_t, te.thread_axis("cthread"))
 
@@ -184,7 +192,7 @@ def schedule_conv2d_transpose_packed(cfg, outs):
     for axis in [d_j, d_i, x_ii, x_jj]:
         s[conv2d_stage].unroll(axis)
 
-    k_o, _ = cfg['tile_ci'].apply(s, conv2d_stage, k_o)
+    k_o, _ = cfg["tile_ci"].apply(s, conv2d_stage, k_o)
     s[cdata].compute_at(s[conv2d_stage], k_o)
     s[ckernel].compute_at(s[conv2d_stage], k_o)
 
index 0b98261..4a618f1 100644
@@ -25,6 +25,7 @@ from tvm import topi
 
 from ..environment import get_env
 
+
 def is_packed_layout(layout):
     """Check if layout is packed layout"""
     if layout == "NCHW":
@@ -33,6 +34,7 @@ def is_packed_layout(layout):
         return True
     return False
 
+
 @autotvm.register_topi_compute("dense_packed.vta")
 def dense_packed(cfg, data, weight, bias=None, out_dtype=None):
     """Dense function declaration."""
@@ -49,21 +51,24 @@ def dense_packed(cfg, data, weight, bias=None, out_dtype=None):
     # Reduction axes (input channel)
     assert ishape[1] == wshape[1]
     assert ishape[3] == wshape[3]
-    k_o = te.reduce_axis((0, ishape[1]), name='k_o')
-    k_i = te.reduce_axis((0, ishape[3]), name='k_i')
+    k_o = te.reduce_axis((0, ishape[1]), name="k_o")
+    k_i = te.reduce_axis((0, ishape[3]), name="k_i")
     res = te.compute(
         oshape,
         lambda b_o, c_o, b_i, c_i: te.sum(
-            data[b_o, k_o, b_i, k_i].astype(out_dtype) *
-            weight[c_o, k_o, c_i, k_i].astype(out_dtype),
-            axis=[k_o, k_i]),
-        name="res", tag="dense_pack")
+            data[b_o, k_o, b_i, k_i].astype(out_dtype)
+            * weight[c_o, k_o, c_i, k_i].astype(out_dtype),
+            axis=[k_o, k_i],
+        ),
+        name="res",
+        tag="dense_pack",
+    )
 
-    cfg.add_flop(2 * np.prod(topi.util.get_const_tuple(oshape)) *
-                 ishape[1] * ishape[3])
+    cfg.add_flop(2 * np.prod(topi.util.get_const_tuple(oshape)) * ishape[1] * ishape[3])
 
     return res
 
+
 @autotvm.register_topi_schedule("dense_packed.vta")
 def schedule_dense_packed(cfg, outs):
     """Packed dense schedule."""
@@ -100,10 +105,10 @@ def schedule_dense_packed(cfg, outs):
     ##### space definition begin #####
     b, c_o, _, _ = s[dense_stage].op.axis
     c_i, _ = s[dense_stage].op.reduce_axis
-    cfg.define_split('tile_b', b, num_outputs=2)
-    cfg.define_split('tile_ci', c_i, num_outputs=2)
-    cfg.define_split('tile_co', c_o, num_outputs=2)
-    cfg.define_knob('oc_nthread', [1, 2])
+    cfg.define_split("tile_b", b, num_outputs=2)
+    cfg.define_split("tile_ci", c_i, num_outputs=2)
+    cfg.define_split("tile_co", c_o, num_outputs=2)
+    cfg.define_knob("oc_nthread", [1, 2])
     ###### space definition end ######
 
     data, weight = dense_stage.op.input_tensors
@@ -117,8 +122,7 @@ def schedule_dense_packed(cfg, outs):
     # cache read input
     cache_read_ewise = []
     for consumer, tensor in ewise_inputs:
-        cache_read_ewise.append(
-            s.cache_read(tensor, env.acc_scope, [consumer]))
+        cache_read_ewise.append(s.cache_read(tensor, env.acc_scope, [consumer]))
 
     # set ewise scope
     for op in ewise_ops:
@@ -130,8 +134,8 @@ def schedule_dense_packed(cfg, outs):
 
     # apply tiling for SRAM reuse
     x_b, x_c, _, _ = s[output].op.axis
-    x_bo, x_bi = cfg['tile_b'].apply(s, output, x_b)
-    x_co, x_ci = cfg['tile_co'].apply(s, output, x_c)
+    x_bo, x_bi = cfg["tile_b"].apply(s, output, x_b)
+    x_co, x_ci = cfg["tile_co"].apply(s, output, x_c)
     s[output].reorder(x_bo, x_co, x_bi, x_ci)
     store_pt = x_co
 
@@ -145,8 +149,8 @@ def schedule_dense_packed(cfg, outs):
         s[tensor].pragma(s[tensor].op.axis[0], env.dma_copy)
 
     # virtual threading along output channel axes
-    if cfg['oc_nthread'].val > 1:
-        _, v_t = s[output].split(x_co, factor=cfg['oc_nthread'].val)
+    if cfg["oc_nthread"].val > 1:
+        _, v_t = s[output].split(x_co, factor=cfg["oc_nthread"].val)
         s[output].reorder(v_t, x_bo)
         s[output].bind(v_t, te.thread_axis("cthread"))
 
@@ -154,7 +158,7 @@ def schedule_dense_packed(cfg, outs):
     k_o, _ = s[dense_stage].op.reduce_axis
     s[dense_stage].reorder(x_bo, k_o, x_co)
 
-    k_o, _ = cfg['tile_ci'].apply(s, dense_stage, k_o)
+    k_o, _ = cfg["tile_ci"].apply(s, dense_stage, k_o)
     s[cdata].compute_at(s[dense_stage], k_o)
     s[cweight].compute_at(s[dense_stage], k_o)
 
index 36768c3..b2661b3 100644
@@ -25,15 +25,9 @@ from tvm import topi
 
 from ..environment import get_env
 
+
 @autotvm.register_topi_compute("group_conv2d_packed.vta")
-def group_conv2d_packed(cfg,
-                        data,
-                        kernel,
-                        strides,
-                        padding,
-                        dilation,
-                        group,
-                        out_dtype):
+def group_conv2d_packed(cfg, data, kernel, strides, padding, dilation, group, out_dtype):
     """ Packed group conv2d nchw function."""
     assert dilation == (1, 1)
 
@@ -55,30 +49,44 @@ def group_conv2d_packed(cfg,
     kshape = topi.util.get_const_tuple(kernel.shape)
     assert group * kshape[1] == ishape[1]
     assert kshape[0] % group == 0
-    d_i = te.reduce_axis((0, kshape[2]), name='d_i')
-    d_j = te.reduce_axis((0, kshape[3]), name='d_j')
-    k_o = te.reduce_axis((0, kshape[1]), name='k_o')
-    k_i = te.reduce_axis((0, kshape[-1]), name='k_i')
+    d_i = te.reduce_axis((0, kshape[2]), name="d_i")
+    d_j = te.reduce_axis((0, kshape[3]), name="d_j")
+    k_o = te.reduce_axis((0, kshape[1]), name="k_o")
+    k_i = te.reduce_axis((0, kshape[-1]), name="k_i")
     hstride, wstride = strides
     out = te.compute(
         oshape,
         lambda b_o, c_o, i, j, b_i, c_i: te.sum(
-            pad_data[b_o, c_o // (kshape[0] // group) * kshape[1] + k_o, i * hstride + d_i,
-                     j * wstride + d_j, b_i, k_i].astype(out_dtype) *
-            kernel[c_o, k_o, d_i, d_j, c_i, k_i].astype(out_dtype),
-            axis=[k_o, d_i, d_j, k_i]),
-        name="res", tag="packed_group_conv2d")
-
-    cfg.add_flop(2 * np.prod(topi.util.get_const_tuple(oshape)) *
-                 kshape[2] * kshape[3] * ishape[1] * kshape[-1])
+            pad_data[
+                b_o,
+                c_o // (kshape[0] // group) * kshape[1] + k_o,
+                i * hstride + d_i,
+                j * wstride + d_j,
+                b_i,
+                k_i,
+            ].astype(out_dtype)
+            * kernel[c_o, k_o, d_i, d_j, c_i, k_i].astype(out_dtype),
+            axis=[k_o, d_i, d_j, k_i],
+        ),
+        name="res",
+        tag="packed_group_conv2d",
+    )
+
+    cfg.add_flop(
+        2
+        * np.prod(topi.util.get_const_tuple(oshape))
+        * kshape[2]
+        * kshape[3]
+        * ishape[1]
+        * kshape[-1]
+    )
 
     return out
 
 
 @autotvm.register_topi_schedule("group_conv2d_packed.vta")
 def schedule_group_conv2d_packed(cfg, outs):
-    """ Schedule the packed conv2d.
-    """
+    """Schedule the packed conv2d."""
     assert len(outs) == 1
     output = outs[0]
     const_ops = []
@@ -112,13 +120,13 @@ def schedule_group_conv2d_packed(cfg, outs):
     ##### space definition begin #####
     b, c_o, x_i, x_j, _, _ = s[conv2d_stage].op.axis
     c_i, _, _, _ = s[conv2d_stage].op.reduce_axis
-    cfg.define_split('tile_b', b, num_outputs=2)
-    cfg.define_split('tile_h', x_i, num_outputs=2)
-    cfg.define_split('tile_w', x_j, num_outputs=2)
-    cfg.define_split('tile_ci', c_i, num_outputs=2)
-    cfg.define_split('tile_co', c_o, num_outputs=2)
-    cfg.define_knob('oc_nthread', [1, 2])
-    cfg.define_knob('h_nthread', [1, 2])
+    cfg.define_split("tile_b", b, num_outputs=2)
+    cfg.define_split("tile_h", x_i, num_outputs=2)
+    cfg.define_split("tile_w", x_j, num_outputs=2)
+    cfg.define_split("tile_ci", c_i, num_outputs=2)
+    cfg.define_split("tile_co", c_o, num_outputs=2)
+    cfg.define_knob("oc_nthread", [1, 2])
+    cfg.define_knob("h_nthread", [1, 2])
     ###### space definition end ######
 
     data, kernel = conv2d_stage.op.input_tensors
@@ -143,8 +151,7 @@ def schedule_group_conv2d_packed(cfg, outs):
     # cache read input
     cache_read_ewise = []
     for consumer, tensor in ewise_inputs:
-        cache_read_ewise.append(
-            s.cache_read(tensor, env.acc_scope, [consumer]))
+        cache_read_ewise.append(s.cache_read(tensor, env.acc_scope, [consumer]))
 
     # set ewise scope
     for op in ewise_ops:
@@ -156,9 +163,9 @@ def schedule_group_conv2d_packed(cfg, outs):
 
     # tile
     x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis
-    x_co0, x_co1 = cfg['tile_co'].apply(s, output, x_co)
-    x_i0, x_i1 = cfg['tile_h'].apply(s, output, x_i)
-    x_j0, x_j1 = cfg['tile_w'].apply(s, output, x_j)
+    x_co0, x_co1 = cfg["tile_co"].apply(s, output, x_co)
+    x_i0, x_i1 = cfg["tile_h"].apply(s, output, x_i)
+    x_j0, x_j1 = cfg["tile_w"].apply(s, output, x_j)
     s[output].reorder(x_bo, x_i0, x_co0, x_j0, x_co1, x_i1, x_j1, x_bi, x_ci)
     store_pt = x_j0
 
@@ -172,14 +179,14 @@ def schedule_group_conv2d_packed(cfg, outs):
         s[tensor].pragma(s[tensor].op.axis[0], env.dma_copy)
 
     # virtual threading along output channel axes
-    if cfg['oc_nthread'].val > 1:
-        _, v_t = s[output].split(x_co0, factor=cfg['oc_nthread'].val)
+    if cfg["oc_nthread"].val > 1:
+        _, v_t = s[output].split(x_co0, factor=cfg["oc_nthread"].val)
         s[output].reorder(v_t, x_bo)
         s[output].bind(v_t, te.thread_axis("cthread"))
 
     # virtual threading along spatial rows
-    if cfg['h_nthread'].val > 1:
-        _, v_t = s[output].split(x_i0, factor=cfg['h_nthread'].val)
+    if cfg["h_nthread"].val > 1:
+        _, v_t = s[output].split(x_i0, factor=cfg["h_nthread"].val)
         s[output].reorder(v_t, x_bo)
         s[output].bind(v_t, te.thread_axis("cthread"))
 
@@ -187,7 +194,7 @@ def schedule_group_conv2d_packed(cfg, outs):
     k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis
     s[conv2d_stage].reorder(x_bo, k_o, x_j, d_j, d_i, x_co, x_i, x_bi, x_ci, k_i)
 
-    k_o, _ = cfg['tile_ci'].apply(s, conv2d_stage, k_o)
+    k_o, _ = cfg["tile_ci"].apply(s, conv2d_stage, k_o)
     s[cdata].compute_at(s[conv2d_stage], k_o)
     s[ckernel].compute_at(s[conv2d_stage], k_o)
 
index eb051f5..ed64ba3 100644
@@ -34,8 +34,9 @@ def _match_pragma(stmt, key):
     key : str
         The pragma key
     """
-    return ((stmt.attr_key == "pragma_" + key) or
-            (stmt.attr_key == "pragma_scope" and stmt.value.value == key))
+    return (stmt.attr_key == "pragma_" + key) or (
+        stmt.attr_key == "pragma_scope" and stmt.value.value == key
+    )
 
 
 def FoldUopLoop():
@@ -51,6 +52,7 @@ def FoldUopLoop():
     fpass : tvm.transform.Pass
         The pass
     """
+
     def _fold_outermost_loop(body):
         stmt = body
         if not isinstance(stmt, tvm.tir.For):
@@ -68,8 +70,7 @@ def FoldUopLoop():
                 args = []
                 args += op.args[:base_args]
                 for i in range(3):
-                    m = tvm.arith.detect_linear_equation(
-                        op.args[i + base_args], [loop_var])
+                    m = tvm.arith.detect_linear_equation(op.args[i + base_args], [loop_var])
                     if not m:
                         fail[0] = True
                         return op
@@ -81,32 +82,34 @@ def FoldUopLoop():
                     else:
                         gemm_offsets[i] = m[0]
                         args.append(m[1])
-                args += op.args[base_args+3:]
+                args += op.args[base_args + 3 :]
                 return tvm.tir.call_intrin("int32", builtin_uop_push, *args)
             if op.op.name not in ("tir.vta.command_handle", "tir.tvm_thread_context"):
                 raise RuntimeError("unexpected op %s" % op)
             return op
 
-        ret = tvm.tir.stmt_functor.ir_transform(
-            stmt.body, None, _post_order, ["tir.Call"])
+        ret = tvm.tir.stmt_functor.ir_transform(stmt.body, None, _post_order, ["tir.Call"])
 
         if not fail[0] and all(x is not None for x in gemm_offsets):
+
             def _visit(op):
                 if op.same_as(loop_var):
                     fail[0] = True
+
             tvm.tir.stmt_functor.post_order_visit(ret, _visit)
             if not fail[0]:
-                begin = tvm.tir.call_extern(
-                    "int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets)
+                begin = tvm.tir.call_extern("int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets)
                 end = tvm.tir.call_extern("int32", "VTAUopLoopEnd")
                 return [begin, ret, end]
         raise ValueError("Failed to fold the GEMM instructions..")
 
     def _do_fold(stmt):
         env = get_env()
-        if (stmt.attr_key == "coproc_uop_scope" and
-                isinstance(stmt.value, tvm.tir.StringImm) and
-                stmt.value.value == env.dev.vta_push_uop.value):
+        if (
+            stmt.attr_key == "coproc_uop_scope"
+            and isinstance(stmt.value, tvm.tir.StringImm)
+            and stmt.value.value == env.dev.vta_push_uop.value
+        ):
             body = stmt.body
             begins = []
             ends = []
@@ -127,16 +130,15 @@ def FoldUopLoop():
                 return stmt
             ends = list(reversed(ends))
             body = tvm.tir.stmt_seq(*(begins + [body] + ends))
-            return tvm.tir.AttrStmt(
-                stmt.node, stmt.attr_key, stmt.value, body)
+            return tvm.tir.AttrStmt(stmt.node, stmt.attr_key, stmt.value, body)
         return None
 
     def _ftransform(f, mod, ctx):
-        return f.with_body(tvm.tir.stmt_functor.ir_transform(
-            f.body, _do_fold, None, ["tir.AttrStmt"]))
+        return f.with_body(
+            tvm.tir.stmt_functor.ir_transform(f.body, _do_fold, None, ["tir.AttrStmt"])
+        )
 
-    return tvm.tir.transform.prim_func_pass(
-        _ftransform, opt_level=0, name="tir.vta.FoldUopLoop")
+    return tvm.tir.transform.prim_func_pass(_ftransform, opt_level=0, name="tir.vta.FoldUopLoop")
 
 
 def CPUAccessRewrite():
@@ -152,9 +154,11 @@ def CPUAccessRewrite():
     fpass : tvm.transform.Pass
         The pass
     """
+
     def _ftransform(f, mod, ctx):
         rw_info = {}
         env = get_env()
+
         def _post_order(op):
             if isinstance(op, tvm.tir.Allocate):
                 buffer_var = op.buffer_var
@@ -162,44 +166,47 @@ def CPUAccessRewrite():
                     return None
                 new_var = rw_info[buffer_var]
                 let_stmt = tvm.tir.LetStmt(
-                    new_var, tvm.tir.call_extern(
-                        "handle", "VTABufferCPUPtr",
-                        env.dev.command_handle,
-                        buffer_var), op.body)
-                alloc = tvm.tir.Allocate(
-                    buffer_var, op.dtype, op.extents,
-                    op.condition, let_stmt)
+                    new_var,
+                    tvm.tir.call_extern(
+                        "handle", "VTABufferCPUPtr", env.dev.command_handle, buffer_var
+                    ),
+                    op.body,
+                )
+                alloc = tvm.tir.Allocate(buffer_var, op.dtype, op.extents, op.condition, let_stmt)
                 del rw_info[buffer_var]
                 return alloc
             if isinstance(op, tvm.tir.Load):
                 buffer_var = op.buffer_var
                 if not buffer_var in rw_info:
-                    rw_info[buffer_var] = te.var(
-                        buffer_var.name + "_ptr", "handle")
+                    rw_info[buffer_var] = te.var(buffer_var.name + "_ptr", "handle")
                 new_var = rw_info[buffer_var]
                 return tvm.tir.Load(op.dtype, new_var, op.index)
             if isinstance(op, tvm.tir.Store):
                 buffer_var = op.buffer_var
                 if not buffer_var in rw_info:
-                    rw_info[buffer_var] = te.var(
-                        buffer_var.name + "_ptr", "handle")
+                    rw_info[buffer_var] = te.var(buffer_var.name + "_ptr", "handle")
                 new_var = rw_info[buffer_var]
                 return tvm.tir.Store(new_var, op.value, op.index)
             raise RuntimeError("not reached")
 
         stmt_in = f.body
         stmt = tvm.tir.stmt_functor.ir_transform(
-            stmt_in, None, _post_order, ["tir.Allocate", "tir.Load", "tir.Store"])
+            stmt_in, None, _post_order, ["tir.Allocate", "tir.Load", "tir.Store"]
+        )
 
         for buffer_var, new_var in rw_info.items():
             stmt = tvm.tir.LetStmt(
-                new_var, tvm.tir.call_extern(
-                    "handle", "VTABufferCPUPtr",
-                    env.dev.command_handle,
-                    buffer_var), stmt)
+                new_var,
+                tvm.tir.call_extern(
+                    "handle", "VTABufferCPUPtr", env.dev.command_handle, buffer_var
+                ),
+                stmt,
+            )
         return f.with_body(stmt)
+
     return tvm.tir.transform.prim_func_pass(
-        _ftransform, opt_level=0, name="tir.vta.CPUAccessRewrite")
+        _ftransform, opt_level=0, name="tir.vta.CPUAccessRewrite"
+    )
 
 
 def LiftAllocToScopeBegin():
@@ -210,23 +217,22 @@ def LiftAllocToScopeBegin():
     fpass : tvm.transform.Pass
         The pass
     """
+
     def _ftransform(f, mod, ctx):
         lift_stmt = [[]]
+
         def _merge_block(slist, body):
             for op in slist:
                 if op.body == body:
                     body = op
                 elif isinstance(op, tvm.tir.Allocate):
-                    body = tvm.tir.Allocate(
-                        op.buffer_var, op.dtype,
-                        op.extents, op.condition, body)
+                    body = tvm.tir.Allocate(op.buffer_var, op.dtype, op.extents, op.condition, body)
                 elif isinstance(op, tvm.tir.AttrStmt):
-                    body = tvm.tir.AttrStmt(
-                        op.node, op.attr_key, op.value, body)
+                    body = tvm.tir.AttrStmt(op.node, op.attr_key, op.value, body)
                 elif isinstance(op, tvm.tir.For):
                     body = tvm.tir.For(
-                        op.loop_var, op.min, op.extent, op.for_type,
-                        op.device_api, body)
+                        op.loop_var, op.min, op.extent, op.for_type, op.device_api, body
+                    )
                 else:
                     raise RuntimeError("unexpected op")
             del slist[:]
@@ -253,14 +259,17 @@ def LiftAllocToScopeBegin():
             if isinstance(op, tvm.tir.For):
                 return _merge_block(lift_stmt.pop() + [op], op.body)
             raise RuntimeError("not reached")
+
         stmt_in = f.body
         stmt = tvm.tir.stmt_functor.ir_transform(
-            stmt_in, _pre_order, _post_order, ["tir.Allocate", "tir.AttrStmt", "tir.For"])
+            stmt_in, _pre_order, _post_order, ["tir.Allocate", "tir.AttrStmt", "tir.For"]
+        )
         assert len(lift_stmt) == 1
         return f.with_body(_merge_block(lift_stmt[0], stmt))
 
     return tvm.tir.transform.prim_func_pass(
-        _ftransform, opt_level=0, name="tir.vta.LiftAllocToScopeBegin")
+        _ftransform, opt_level=0, name="tir.vta.LiftAllocToScopeBegin"
+    )
 
 
 def InjectSkipCopy():
@@ -271,17 +280,18 @@ def InjectSkipCopy():
     fpass : tvm.transform.Pass
         The pass
     """
+
     def _do_fold(stmt):
         if _match_pragma(stmt, "skip_dma_copy"):
             return tvm.tir.Evaluate(0)
         return None
 
     def _ftransform(f, mod, ctx):
-        return f.with_body(tvm.tir.stmt_functor.ir_transform(
-            f.body, _do_fold, None, ["tir.AttrStmt"]))
+        return f.with_body(
+            tvm.tir.stmt_functor.ir_transform(f.body, _do_fold, None, ["tir.AttrStmt"])
+        )
 
-    return tvm.tir.transform.prim_func_pass(
-        _ftransform, opt_level=0, name="tir.vta.InjectSkipCopy")
+    return tvm.tir.transform.prim_func_pass(_ftransform, opt_level=0, name="tir.vta.InjectSkipCopy")
 
 
 def InjectCoProcSync():
@@ -292,27 +302,33 @@ def InjectCoProcSync():
     fpass : tvm.transform.Pass
         The pass
     """
+
     def _ftransform(f, *_):
         success = [False]
+
         def _do_fold(stmt):
             if _match_pragma(stmt, "coproc_sync"):
                 success[0] = True
-                sync = tvm.tir.Call(
-                    "int32", "vta.coproc_sync", [])
+                sync = tvm.tir.Call("int32", "vta.coproc_sync", [])
                 return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)])
             if _match_pragma(stmt, "trim_loop"):
                 op = stmt.body
                 assert isinstance(op, tvm.tir.For)
-                return tvm.tir.For(
-                    op.loop_var, op.min, 2, op.for_type,
-                    op.device_api, op.body)
+                return tvm.tir.For(op.loop_var, op.min, 2, op.for_type, op.device_api, op.body)
             return None
-        return f.with_body(tvm.tir.stmt_functor.ir_transform(
-            f.body, None, _do_fold, ["tir.AttrStmt"]))
+
+        return f.with_body(
+            tvm.tir.stmt_functor.ir_transform(f.body, None, _do_fold, ["tir.AttrStmt"])
+        )
+
     return tvm.transform.Sequential(
-        [tvm.tir.transform.prim_func_pass(_ftransform, 0, "tir.vta.InjectCoProcSync"),
-         tvm.tir.transform.CoProcSync()],
-        opt_level=0, name="tir.vta.InjectCoProcSync")
+        [
+            tvm.tir.transform.prim_func_pass(_ftransform, 0, "tir.vta.InjectCoProcSync"),
+            tvm.tir.transform.CoProcSync(),
+        ],
+        opt_level=0,
+        name="tir.vta.InjectCoProcSync",
+    )
 
 
 def InjectDMAIntrin():
@@ -332,7 +348,8 @@ def InjectDMAIntrin():
         for i in reversed(range(ndim)):
             if not util.equal_const_int(size - buf.strides[i], 0):
                 raise RuntimeError(
-                    "Cannot prove compact: shape=%s, strides=%s" % (buf.shape, buf.strides))
+                    "Cannot prove compact: shape=%s, strides=%s" % (buf.shape, buf.strides)
+                )
             size = size * buf.shape[i]
 
     def _fold_buffer_dim(buf, scope, elem_block):
@@ -347,8 +364,9 @@ def InjectDMAIntrin():
                 base = i + 1
                 break
         if base == 0:
-            raise RuntimeError("scope %s need to have block=%d, shape=%s" % (
-                scope, elem_block, buf.shape))
+            raise RuntimeError(
+                "scope %s need to have block=%d, shape=%s" % (scope, elem_block, buf.shape)
+            )
         shape = [elem_block]
         strides = [1]
 
@@ -363,8 +381,9 @@ def InjectDMAIntrin():
             next_base = base
             if not util.equal_const_int(idxm(x_stride, elem_block), 0):
                 raise RuntimeError(
-                    "scope %s need to have block=%d, shape=%s, strides=%s" % (
-                        scope, elem_block, buf.shape, buf.strides))
+                    "scope %s need to have block=%d, shape=%s, strides=%s"
+                    % (scope, elem_block, buf.shape, buf.strides)
+                )
             for i in range(base, ndim + 1):
                 k = ndim - i
                 if not util.equal_const_int(x_size * x_stride - buf.strides[k], 0):
@@ -383,8 +402,7 @@ def InjectDMAIntrin():
     def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold):
         elem_block = elem_bytes * 8 // elem_width
         if buf.dtype != dtype:
-            raise RuntimeError("Expect buffer type to be %s instead of %s" %
-                               (dtype, buf.dtype))
+            raise RuntimeError("Expect buffer type to be %s instead of %s" % (dtype, buf.dtype))
         shape, strides = buf.shape, buf.strides
         if not util.equal_const_int(idxm(buf.elem_offset, elem_block), 0):
             raise RuntimeError("scope %s need to have block=%d" % (scope, elem_block))
@@ -397,8 +415,12 @@ def InjectDMAIntrin():
         def raise_error():
             """Internal function to raise error """
             raise RuntimeError(
-                ("Scope[%s]: cannot detect 2d pattern with elem_block=%d:" +
-                 " shape=%s, strides=%s") % (scope, elem_block, buf.shape, buf.strides))
+                (
+                    "Scope[%s]: cannot detect 2d pattern with elem_block=%d:"
+                    + " shape=%s, strides=%s"
+                )
+                % (scope, elem_block, buf.shape, buf.strides)
+            )
 
         ndim = len(shape)
 
@@ -463,7 +485,6 @@ def InjectDMAIntrin():
 
         raise_error()
 
-
     def _inject_copy(src, dst, pad_before, pad_after, pad_value):
         # FIXME: pad_value is ignored...
         env = get_env()
@@ -482,15 +503,24 @@ def InjectDMAIntrin():
                 raise RuntimeError("Do not support copy %s->dram" % (src.scope))
             _check_compact(src)
             x_size, y_size, x_stride, offset = _get_2d_pattern(
-                dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True)
+                dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True
+            )
             irb = tvm.tir.ir_builder.create()
-            irb.scope_attr(env.dev.vta_axis, "coproc_scope",
-                           env.dev.get_task_qid(task_qid))
-            irb.emit(tvm.tir.call_extern(
-                "int32", "VTAStoreBuffer2D",
-                env.dev.command_handle,
-                src.access_ptr("r", "int32"),
-                mem_type, dst.data, offset, x_size, y_size, x_stride))
+            irb.scope_attr(env.dev.vta_axis, "coproc_scope", env.dev.get_task_qid(task_qid))
+            irb.emit(
+                tvm.tir.call_extern(
+                    "int32",
+                    "VTAStoreBuffer2D",
+                    env.dev.command_handle,
+                    src.access_ptr("r", "int32"),
+                    mem_type,
+                    dst.data,
+                    offset,
+                    x_size,
+                    y_size,
+                    x_stride,
+                )
+            )
             return irb.get()
         elif src.scope == "global":
             if dst.scope == env.acc_scope:
@@ -550,20 +580,30 @@ def InjectDMAIntrin():
 
             _check_compact(dst)
             x_size, y_size, x_stride, offset = _get_2d_pattern(
-                src, elem_width, elem_bytes, data_type,
-                dst.scope, allow_fold=allow_fold)
+                src, elem_width, elem_bytes, data_type, dst.scope, allow_fold=allow_fold
+            )
 
             irb = tvm.tir.ir_builder.create()
-            irb.scope_attr(env.dev.vta_axis, "coproc_scope",
-                           env.dev.get_task_qid(task_qid))
-
-            irb.emit(tvm.tir.call_extern(
-                "int32", "VTALoadBuffer2D",
-                env.dev.command_handle,
-                src.data, offset, x_size, y_size, x_stride,
-                x_pad_before, y_pad_before,
-                x_pad_after, y_pad_after,
-                dst.access_ptr("r", "int32"), mem_type))
+            irb.scope_attr(env.dev.vta_axis, "coproc_scope", env.dev.get_task_qid(task_qid))
+
+            irb.emit(
+                tvm.tir.call_extern(
+                    "int32",
+                    "VTALoadBuffer2D",
+                    env.dev.command_handle,
+                    src.data,
+                    offset,
+                    x_size,
+                    y_size,
+                    x_stride,
+                    x_pad_before,
+                    y_pad_before,
+                    x_pad_after,
+                    y_pad_after,
+                    dst.access_ptr("r", "int32"),
+                    mem_type,
+                )
+            )
             return irb.get()
 
         else:
@@ -586,28 +626,43 @@ def _get_gemm_intrin_buffer():
     assert out_lanes == env.BATCH * env.BLOCK_OUT
     out_shape = (env.BATCH, env.BLOCK_OUT)
     assert out_shape[0] * out_shape[1] == out_lanes
-    wgt = te.placeholder((wgt_shape[0], wgt_shape[1]),
-                         dtype="int%d" % env.WGT_WIDTH,
-                         name=env.wgt_scope)
-    inp = te.placeholder((inp_shape[0], inp_shape[1]),
-                         dtype="int%d" % env.INP_WIDTH,
-                         name=env.inp_scope)
+    wgt = te.placeholder(
+        (wgt_shape[0], wgt_shape[1]), dtype="int%d" % env.WGT_WIDTH, name=env.wgt_scope
+    )
+    inp = te.placeholder(
+        (inp_shape[0], inp_shape[1]), dtype="int%d" % env.INP_WIDTH, name=env.inp_scope
+    )
     k = te.reduce_axis((0, wgt_shape[1]), name="k")
     out_dtype = "int%d" % env.ACC_WIDTH
-    out = te.compute((out_shape[0], out_shape[1]),
-                     lambda i, j: te.sum(inp[i, k].astype(out_dtype) *
-                                         wgt[j, k].astype(out_dtype),
-                                         axis=[k]),
-                     name="out")
+    out = te.compute(
+        (out_shape[0], out_shape[1]),
+        lambda i, j: te.sum(inp[i, k].astype(out_dtype) * wgt[j, k].astype(out_dtype), axis=[k]),
+        name="out",
+    )
     wgt_layout = tvm.tir.decl_buffer(
-        wgt.shape, wgt.dtype, env.wgt_scope,
-        scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes)
+        wgt.shape,
+        wgt.dtype,
+        env.wgt_scope,
+        scope=env.wgt_scope,
+        offset_factor=wgt_lanes,
+        data_alignment=wgt_lanes,
+    )
     inp_layout = tvm.tir.decl_buffer(
-        inp.shape, inp.dtype, env.inp_scope,
-        scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes)
+        inp.shape,
+        inp.dtype,
+        env.inp_scope,
+        scope=env.inp_scope,
+        offset_factor=inp_lanes,
+        data_alignment=inp_lanes,
+    )
     out_layout = tvm.tir.decl_buffer(
-        out.shape, out.dtype, env.acc_scope,
-        scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes)
+        out.shape,
+        out.dtype,
+        env.acc_scope,
+        scope=env.acc_scope,
+        offset_factor=out_lanes,
+        data_alignment=out_lanes,
+    )
 
     return wgt_layout, inp_layout, out_layout
 
@@ -620,6 +675,7 @@ def InjectConv2DTransposeSkip():
     fpass : tvm.transform.Pass
         The pass
     """
+
     def _ftransform(func, mod, ctx):
         env = get_env()
         dwgt, dinp, dout = _get_gemm_intrin_buffer()
@@ -644,11 +700,20 @@ def InjectConv2DTransposeSkip():
                     dev = env.dev
                     irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
                     irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
-                    irb.emit(tvm.tir.call_intrin("int32", "tir.vta.uop_push",
-                                                 0, 1,
-                                                 dout.access_ptr("rw", "int32"),
-                                                 0, 0,
-                                                 0, 0, 0))
+                    irb.emit(
+                        tvm.tir.call_intrin(
+                            "int32",
+                            "tir.vta.uop_push",
+                            0,
+                            1,
+                            dout.access_ptr("rw", "int32"),
+                            0,
+                            0,
+                            0,
+                            0,
+                            0,
+                        )
+                    )
                     inner = irb.get()
                     # TODO(@tmoreau89): This is only a temporary fix, please take a look.
                     body = op.body.body
@@ -658,8 +723,11 @@ def InjectConv2DTransposeSkip():
                     res_buffer = body.buffer
                     tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
                     inner = tvm.tir.AttrStmt(
-                        [dout, res_buffer], 'buffer_bind_scope',
-                        tvm.tir.call_intrin('handle', 'tir.tvm_tuple', *tpl), inner)
+                        [dout, res_buffer],
+                        "buffer_bind_scope",
+                        tvm.tir.call_intrin("handle", "tir.tvm_tuple", *tpl),
+                        inner,
+                    )
                     return inner
                 else:
                     conv_call, data_call, kernel_call = calls[-3:]
@@ -670,48 +738,79 @@ def InjectConv2DTransposeSkip():
                     if selects:
                         condition = selects[0].condition
                     else:
-                        condition = tvm.tir.const(1, 'int')
+                        condition = tvm.tir.const(1, "int")
 
                     # create inner most block
                     irb = tvm.tir.ir_builder.create()
                     with irb.if_scope(condition):
                         dev = env.dev
                         irb.scope_attr(
-                            dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
+                            dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE)
+                        )
                         irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
-                        irb.emit(tvm.tir.call_intrin("int32", "tir.vta.uop_push",
-                                                     0, 0,
-                                                     dout.access_ptr("rw", "int32"),
-                                                     dinp.access_ptr("r", "int32"),
-                                                     dwgt.access_ptr("r", "int32"),
-                                                     0, 0, 0))
+                        irb.emit(
+                            tvm.tir.call_intrin(
+                                "int32",
+                                "tir.vta.uop_push",
+                                0,
+                                0,
+                                dout.access_ptr("rw", "int32"),
+                                dinp.access_ptr("r", "int32"),
+                                dwgt.access_ptr("r", "int32"),
+                                0,
+                                0,
+                                0,
+                            )
+                        )
                     inner = irb.get()
 
                     args = conv_call.indices
-                    tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
-                           1, 0, 1, 0, env.BLOCK_OUT)
+                    tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
                     inner = tvm.tir.AttrStmt(
-                        [dout, res_tensor], 'buffer_bind_scope',
-                        tvm.tir.call_intrin('handle', 'tir.tvm_tuple', *tpl), inner)
+                        [dout, res_tensor],
+                        "buffer_bind_scope",
+                        tvm.tir.call_intrin("handle", "tir.tvm_tuple", *tpl),
+                        inner,
+                    )
                     args = kernel_call.indices
-                    tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
-                           1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN)
+                    tpl = (
+                        args[0],
+                        1,
+                        args[1],
+                        1,
+                        args[2],
+                        1,
+                        args[3],
+                        1,
+                        0,
+                        env.BLOCK_OUT,
+                        0,
+                        env.BLOCK_IN,
+                    )
                     inner = tvm.tir.AttrStmt(
-                        [dwgt, kernel_tensor], 'buffer_bind_scope',
-                        tvm.tir.call_intrin('handle', 'tir.tvm_tuple', *tpl), inner)
+                        [dwgt, kernel_tensor],
+                        "buffer_bind_scope",
+                        tvm.tir.call_intrin("handle", "tir.tvm_tuple", *tpl),
+                        inner,
+                    )
                     args = data_call.indices
-                    tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
-                           1, 0, 1, 0, env.BLOCK_IN)
+                    tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_IN)
                     inner = tvm.tir.AttrStmt(
-                        [dinp, pad_data_tensor], 'buffer_bind_scope',
-                        tvm.tir.call_intrin('handle', 'tir.tvm_tuple', *tpl), inner)
+                        [dinp, pad_data_tensor],
+                        "buffer_bind_scope",
+                        tvm.tir.call_intrin("handle", "tir.tvm_tuple", *tpl),
+                        inner,
+                    )
                     return inner
             return None
 
-        return func.with_body(tvm.tir.stmt_functor.ir_transform(
-            func.body, _do_fold, None, ["tir.AttrStmt"]))
+        return func.with_body(
+            tvm.tir.stmt_functor.ir_transform(func.body, _do_fold, None, ["tir.AttrStmt"])
+        )
+
     return tvm.tir.transform.prim_func_pass(
-        _ftransform, opt_level=0, name="tir.vta.InjectConv2DTrasnposeSkip")
+        _ftransform, opt_level=0, name="tir.vta.InjectConv2DTrasnposeSkip"
+    )
 
 
 def AnnotateALUCoProcScope():
@@ -722,25 +821,32 @@ def AnnotateALUCoProcScope():
     fpass : tvm.transform.Pass
         The pass
     """
+
     def _ftransform(func, mod, ctx):
         env = get_env()
+
         def _do_fold(stmt):
             if _match_pragma(stmt, "alu"):
                 irb = tvm.tir.ir_builder.create()
-                irb.scope_attr(env.dev.vta_axis, "coproc_scope",
-                               env.dev.get_task_qid(env.dev.QID_COMPUTE))
-                irb.scope_attr(env.dev.vta_axis, "coproc_uop_scope",
-                               tvm.tir.StringImm("VTAPushALUOp"))
+                irb.scope_attr(
+                    env.dev.vta_axis, "coproc_scope", env.dev.get_task_qid(env.dev.QID_COMPUTE)
+                )
+                irb.scope_attr(
+                    env.dev.vta_axis, "coproc_uop_scope", tvm.tir.StringImm("VTAPushALUOp")
+                )
                 irb.emit(stmt)
                 return irb.get()
             if _match_pragma(stmt, "skip_alu"):
                 return tvm.tir.Evaluate(0)
             return stmt
 
-        return func.with_body(tvm.tir.stmt_functor.ir_transform(
-            func.body, None, _do_fold, ["tir.AttrStmt"]))
+        return func.with_body(
+            tvm.tir.stmt_functor.ir_transform(func.body, None, _do_fold, ["tir.AttrStmt"])
+        )
+
     return tvm.tir.transform.prim_func_pass(
-        _ftransform, opt_level=0, name="tir.vta.AnnotateALUCoProcScope")
+        _ftransform, opt_level=0, name="tir.vta.AnnotateALUCoProcScope"
+    )
 
 
 def InjectALUIntrin():
@@ -751,6 +857,7 @@ def InjectALUIntrin():
     fpass : tvm.transform.Pass
         The pass
     """
+
     def _ftransform(func, mod, ctx):
         env = get_env()
         idxm = tvm.tir.indexmod
@@ -834,25 +941,27 @@ def InjectALUIntrin():
                     lhs = loop_body.value.a
                     rhs = loop_body.value.b
                 elif isinstance(loop_body.value, tvm.tir.Call):
-                    if loop_body.value.op.name == 'tir.shift_left':
+                    if loop_body.value.op.name == "tir.shift_left":
                         alu_opcode = env.dev.ALU_OPCODE_SHR
                         lhs = loop_body.value.args[0]
                         rhs = analyzer.simplify(-loop_body.value.args[1])
-                    elif loop_body.value.op.name == 'tir.shift_right':
+                    elif loop_body.value.op.name == "tir.shift_right":
                         alu_opcode = env.dev.ALU_OPCODE_SHR
                         lhs = loop_body.value.args[0]
                         rhs = loop_body.value.args[1]
                     else:
                         raise RuntimeError(
-                            "Function call not recognized %s" % (loop_body.value.name))
+                            "Function call not recognized %s" % (loop_body.value.name)
+                        )
                 elif isinstance(loop_body.value, tvm.tir.Load):
                     alu_opcode = env.dev.ALU_OPCODE_SHR
                     lhs = loop_body.value
                     rhs = tvm.tir.const(0, "int32")
                 else:
                     raise RuntimeError(
-                        "Expression not recognized %s, %s, %s" % (
-                            type(loop_body.value), str(loop_body.value), str(stmt)))
+                        "Expression not recognized %s, %s, %s"
+                        % (type(loop_body.value), str(loop_body.value), str(stmt))
+                    )
 
                 # Derive array index coefficients
                 dst_coeff = tvm.arith.detect_linear_equation(dst_idx, indices)
@@ -900,11 +1009,11 @@ def InjectALUIntrin():
                 assert len(dst_coeff) > 1
                 assert len(extents) != 0
                 assert tvm.ir.structural_equal(
-                    analyzer.simplify(
-                        idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
+                    analyzer.simplify(idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0
+                )
                 assert tvm.ir.structural_equal(
-                    analyzer.simplify(
-                        idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
+                    analyzer.simplify(idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0
+                )
                 assert tvm.ir.structural_equal(src_coeff[-2], 1)
                 assert tvm.ir.structural_equal(dst_coeff[-2], 1)
                 if env.BATCH > 1:
@@ -927,10 +1036,8 @@ def InjectALUIntrin():
                     extents = extents[:-2]
                 src_coeff.append(src_offset)
                 dst_coeff.append(dst_offset)
-                src_coeff = [
-                    analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff]
-                dst_coeff = [
-                    analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff]
+                src_coeff = [analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff]
+                dst_coeff = [analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff]
 
                 # Flatten the outer loops
                 if extents:
@@ -939,25 +1046,35 @@ def InjectALUIntrin():
                 # Insert ALU micro-ops
                 irb = tvm.tir.ir_builder.create()
                 for idx, extent in enumerate(extents):
-                    irb.emit(tvm.tir.call_extern(
-                        "int32", "VTAUopLoopBegin",
-                        extent, dst_coeff[idx], src_coeff[idx], 0))
+                    irb.emit(
+                        tvm.tir.call_extern(
+                            "int32", "VTAUopLoopBegin", extent, dst_coeff[idx], src_coeff[idx], 0
+                        )
+                    )
                 use_imm = int(use_imm)
-                irb.emit(tvm.tir.call_intrin(
-                    "int32", "tir.vta.uop_push",
-                    1, 0,
-                    dst_coeff[len(dst_coeff)-1],
-                    src_coeff[len(src_coeff)-1],
-                    0,
-                    alu_opcode, use_imm, imm_val))
+                irb.emit(
+                    tvm.tir.call_intrin(
+                        "int32",
+                        "tir.vta.uop_push",
+                        1,
+                        0,
+                        dst_coeff[len(dst_coeff) - 1],
+                        src_coeff[len(src_coeff) - 1],
+                        0,
+                        alu_opcode,
+                        use_imm,
+                        imm_val,
+                    )
+                )
                 for extent in extents:
-                    irb.emit(tvm.tir.call_extern(
-                        "int32", "VTAUopLoopEnd"))
+                    irb.emit(tvm.tir.call_extern("int32", "VTAUopLoopEnd"))
                 return irb.get()
             return stmt
 
-        return func.with_body(tvm.tir.stmt_functor.ir_transform(
-            func.body, None, _do_fold, ["tir.AttrStmt"]))
+        return func.with_body(
+            tvm.tir.stmt_functor.ir_transform(func.body, None, _do_fold, ["tir.AttrStmt"])
+        )
 
     return tvm.tir.transform.prim_func_pass(
-        _ftransform, opt_level=0, name="tir.vta.InjectALUIntrin")
+        _ftransform, opt_level=0, name="tir.vta.InjectALUIntrin"
+    )
index 6095d96..2a1331f 100644
@@ -30,25 +30,39 @@ import vta.testing
 
 env = vta.get_env()
 
-Workload = namedtuple("Conv2DWorkload",
-                      ['batch', 'height', 'width', 'in_filter', 'out_filter',
-                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
+Workload = namedtuple(
+    "Conv2DWorkload",
+    [
+        "batch",
+        "height",
+        "width",
+        "in_filter",
+        "out_filter",
+        "hkernel",
+        "wkernel",
+        "hpad",
+        "wpad",
+        "hstride",
+        "wstride",
+    ],
+)
 
 resnet_wkls = [
     # Workloads of resnet18 on imagenet
     # ('resnet-18.C1',  Workload(env.BATCH, 224, 224, 3,   64,  7, 7, 3, 3, 2, 2)),
-    ('resnet-18.C2',  Workload(env.BATCH,  56,  56, 64,  64,  3, 3, 1, 1, 1, 1)),
-    ('resnet-18.C3',  Workload(env.BATCH,  56,  56, 64,  128, 3, 3, 1, 1, 2, 2)),
-    ('resnet-18.C4',  Workload(env.BATCH,  56,  56, 64,  128, 1, 1, 0, 0, 2, 2)),
-    ('resnet-18.C5',  Workload(env.BATCH,  28,  28, 128, 128, 3, 3, 1, 1, 1, 1)),
-    ('resnet-18.C6',  Workload(env.BATCH,  28,  28, 128, 256, 3, 3, 1, 1, 2, 2)),
-    ('resnet-18.C7',  Workload(env.BATCH,  28,  28, 128, 256, 1, 1, 0, 0, 2, 2)),
-    ('resnet-18.C8',  Workload(env.BATCH,  14,  14, 256, 256, 3, 3, 1, 1, 1, 1)),
-    ('resnet-18.C9',  Workload(env.BATCH,  14,  14, 256, 512, 3, 3, 1, 1, 2, 2)),
-    ('resnet-18.C10', Workload(env.BATCH,  14,  14, 256, 512, 1, 1, 0, 0, 2, 2)),
-    ('resnet-18.C11', Workload(env.BATCH,   7,   7, 512, 512, 3, 3, 1, 1, 1, 1)),
+    ("resnet-18.C2", Workload(env.BATCH, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)),
+    ("resnet-18.C3", Workload(env.BATCH, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2)),
+    ("resnet-18.C4", Workload(env.BATCH, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2)),
+    ("resnet-18.C5", Workload(env.BATCH, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1)),
+    ("resnet-18.C6", Workload(env.BATCH, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2)),
+    ("resnet-18.C7", Workload(env.BATCH, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2)),
+    ("resnet-18.C8", Workload(env.BATCH, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1)),
+    ("resnet-18.C9", Workload(env.BATCH, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2)),
+    ("resnet-18.C10", Workload(env.BATCH, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2)),
+    ("resnet-18.C11", Workload(env.BATCH, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1)),
 ]
 
+
 @tvm.te.tag_scope(tag=topi.tag.ELEMWISE)
 def my_clip(x, a_min, a_max):
     """Unlike topi's current clip, put min and max into two stages."""
@@ -58,10 +72,11 @@ def my_clip(x, a_min, a_max):
     x = te.compute(x.shape, lambda *i: tvm.te.max(x(*i), const_min), name="clipB")
     return x
 
+
 def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation):
-    data_shape = (N//env.BATCH, CI//env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN)
-    kernel_shape = (CO//env.BLOCK_OUT, CI//env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN)
-    bias_shape = (N//env.BATCH, CO//env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT)
+    data_shape = (N // env.BATCH, CI // env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN)
+    kernel_shape = (CO // env.BLOCK_OUT, CI // env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN)
+    bias_shape = (N // env.BATCH, CO // env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT)
 
     data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
     kernel = te.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
@@ -74,21 +89,23 @@ def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation):
             padding=padding,
             strides=strides,
             dilation=dilation,
-            layout='NCHW%dn%dc' % (env.BATCH, env.BLOCK_IN),
-            out_dtype=env.acc_dtype)
+            layout="NCHW%dn%dc" % (env.BATCH, env.BLOCK_IN),
+            out_dtype=env.acc_dtype,
+        )
         res = topi.right_shift(res, env.WGT_WIDTH)
         res = topi.add(res, bias)
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res = topi.cast(res, env.out_dtype)
 
-    if tvm.target.Target.current().device_name == 'vta':
+    if tvm.target.Target.current().device_name == "vta":
         s = topi.generic.schedule_conv2d_nchw([res])
     else:
         s = te.create_schedule([res.op])
 
     return s, [data, kernel, bias, res]
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
 
     # Logging config (for printing tuning log to the screen)
     logging.basicConfig()
@@ -125,20 +142,26 @@ if __name__ == '__main__':
 
         # Create task
         task = autotvm.task.create(
-                conv2d,
-                args=(N, CI, H, W, CO, KH, KW, strides, padding, dilation),
-                target=tvm.target.vta(),
-                target_host=env.target_host,
-                template_key='direct')
+            conv2d,
+            args=(N, CI, H, W, CO, KH, KW, strides, padding, dilation),
+            target=tvm.target.vta(),
+            target_host=env.target_host,
+            template_key="direct",
+        )
         print(task.config_space)
 
         # Tune
         measure_option = autotvm.measure_option(
-                builder=autotvm.LocalBuilder(),
-                runner=autotvm.RPCRunner(
-                    env.TARGET, host=tracker_host, port=int(tracker_port),
-                    number=5, timeout=60,
-                    check_correctness=True))
+            builder=autotvm.LocalBuilder(),
+            runner=autotvm.RPCRunner(
+                env.TARGET,
+                host=tracker_host,
+                port=int(tracker_port),
+                number=5,
+                timeout=60,
+                check_correctness=True,
+            ),
+        )
 
         # Run Tuner
         tuner = autotvm.tuner.RandomTuner(task)
@@ -147,8 +170,10 @@ if __name__ == '__main__':
             early_stopping=None,
             measure_option=measure_option,
             callbacks=[
-                    autotvm.callback.progress_bar(len(task.config_space), prefix=prefix),
-                    autotvm.callback.log_to_file(tmp_log_file)])
+                autotvm.callback.progress_bar(len(task.config_space), prefix=prefix),
+                autotvm.callback.log_to_file(tmp_log_file),
+            ],
+        )
 
     # Pick best records to a cache file
     autotvm.record.pick_best(tmp_log_file, log_file)
index 551e6f9..ebfe7eb 100644
@@ -31,19 +31,34 @@ import vta.testing
 # Get batch info from env
 env = vta.get_env()
 
-Workload = namedtuple("Conv2DTransposeWorkload",
-                      ['batch', 'height', 'width', 'in_filter', 'out_filter',
-                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride',
-                       'o_hpad', 'o_wpad'])
+Workload = namedtuple(
+    "Conv2DTransposeWorkload",
+    [
+        "batch",
+        "height",
+        "width",
+        "in_filter",
+        "out_filter",
+        "hkernel",
+        "wkernel",
+        "hpad",
+        "wpad",
+        "hstride",
+        "wstride",
+        "o_hpad",
+        "o_wpad",
+    ],
+)
 
 # DCGAN workloads
 dcgan_wkls = [
     # dcgan
-    ('DCGAN.CT1', Workload(env.BATCH,  4,  4, 1024, 512, 4, 4, 1, 1, 2, 2, 0, 0)),
-    ('DCGAN.CT2', Workload(env.BATCH,  8,  8,  512, 256, 4, 4, 1, 1, 2, 2, 0, 0)),
-    ('DCGAN.CT3', Workload(env.BATCH, 16, 16,  256, 128, 4, 4, 1, 1, 2, 2, 0, 0)),
+    ("DCGAN.CT1", Workload(env.BATCH, 4, 4, 1024, 512, 4, 4, 1, 1, 2, 2, 0, 0)),
+    ("DCGAN.CT2", Workload(env.BATCH, 8, 8, 512, 256, 4, 4, 1, 1, 2, 2, 0, 0)),
+    ("DCGAN.CT3", Workload(env.BATCH, 16, 16, 256, 128, 4, 4, 1, 1, 2, 2, 0, 0)),
 ]
 
+
 @tvm.te.tag_scope(tag=topi.tag.ELEMWISE)
 def my_clip(x, a_min, a_max):
     """Unlike topi's current clip, put min and max into two stages."""
@@ -53,9 +68,10 @@ def my_clip(x, a_min, a_max):
     x = te.compute(x.shape, lambda *i: tvm.te.max(x(*i), const_min), name="clipB")
     return x
 
+
 def conv2d_transpose(N, CI, H, W, CO, KH, KW, strides, padding, opadding):
-    data_shape = (N//env.BATCH, CI//env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN)
-    kernel_shape = (CO//env.BLOCK_OUT, CI//env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN)
+    data_shape = (N // env.BATCH, CI // env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN)
+    kernel_shape = (CO // env.BLOCK_OUT, CI // env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN)
 
     data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
     kernel = te.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
@@ -67,20 +83,21 @@ def conv2d_transpose(N, CI, H, W, CO, KH, KW, strides, padding, opadding):
             strides=strides,
             padding=padding,
             out_dtype=env.acc_dtype,
-            output_padding=opadding
+            output_padding=opadding,
         )
         res = topi.right_shift(res, env.WGT_WIDTH)
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res = topi.cast(res, env.out_dtype)
 
-    if tvm.target.Target.current().device_name == 'vta':
+    if tvm.target.Target.current().device_name == "vta":
         s = topi.generic.schedule_conv2d_transpose_nchw([res])
     else:
         s = te.create_schedule([res.op])
 
     return s, [data, kernel, res]
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
 
     # Logging config (for printing tuning log to the screen)
     logging.basicConfig()
@@ -117,20 +134,26 @@ if __name__ == '__main__':
 
         # Create task
         task = autotvm.task.create(
-                conv2d_transpose,
-                args=(N, CI, H, W, CO, KH, KW, strides, padding, opadding),
-                target=tvm.target.vta(),
-                target_host=env.target_host,
-                template_key='direct')
+            conv2d_transpose,
+            args=(N, CI, H, W, CO, KH, KW, strides, padding, opadding),
+            target=tvm.target.vta(),
+            target_host=env.target_host,
+            template_key="direct",
+        )
         print(task.config_space)
 
         # Tune
         measure_option = autotvm.measure_option(
-                builder=autotvm.LocalBuilder(),
-                runner=autotvm.RPCRunner(
-                    env.TARGET, host=tracker_host, port=int(tracker_port),
-                    number=5, timeout=60,
-                    check_correctness=True))
+            builder=autotvm.LocalBuilder(),
+            runner=autotvm.RPCRunner(
+                env.TARGET,
+                host=tracker_host,
+                port=int(tracker_port),
+                number=5,
+                timeout=60,
+                check_correctness=True,
+            ),
+        )
 
         # Run Tuner
         tuner = autotvm.tuner.RandomTuner(task)
@@ -139,8 +162,10 @@ if __name__ == '__main__':
             early_stopping=None,
             measure_option=measure_option,
             callbacks=[
-                    autotvm.callback.progress_bar(len(task.config_space), prefix=prefix),
-                    autotvm.callback.log_to_file(tmp_log_file)])
+                autotvm.callback.progress_bar(len(task.config_space), prefix=prefix),
+                autotvm.callback.log_to_file(tmp_log_file),
+            ],
+        )
 
     # Pick best records to a cache file
     autotvm.record.pick_best(tmp_log_file, log_file)
index b1711fa..7e3aec8 100644
@@ -30,14 +30,14 @@ import vta.testing
 
 env = vta.get_env()
 
-Workload = namedtuple("DenseWorkload",
-                      ['batch', 'in_filter', 'out_filter'])
+Workload = namedtuple("DenseWorkload", ["batch", "in_filter", "out_filter"])
 
 dense_wkls = [
-    ('lstm.dense.1',  Workload(1, 256, 128)),
-    ('lstm.dense.4',  Workload(4, 256, 128)),
+    ("lstm.dense.1", Workload(1, 256, 128)),
+    ("lstm.dense.4", Workload(4, 256, 128)),
 ]
 
+
 @tvm.te.tag_scope(tag=topi.tag.ELEMWISE)
 def my_clip(x, a_min, a_max):
     """Unlike topi's current clip, put min and max into two stages."""
@@ -47,27 +47,29 @@ def my_clip(x, a_min, a_max):
     x = te.compute(x.shape, lambda *i: tvm.te.max(x(*i), const_min), name="clipB")
     return x
 
+
 def dense(N, CI, CO):
-    data_shape = (N//env.BATCH, CI//env.BLOCK_IN, env.BATCH, env.BLOCK_IN)
-    kernel_shape = (CO//env.BLOCK_OUT, CI//env.BLOCK_IN, env.BLOCK_OUT, env.BLOCK_IN)
+    data_shape = (N // env.BATCH, CI // env.BLOCK_IN, env.BATCH, env.BLOCK_IN)
+    kernel_shape = (CO // env.BLOCK_OUT, CI // env.BLOCK_IN, env.BLOCK_OUT, env.BLOCK_IN)
 
     data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
     kernel = te.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
 
     with tvm.target.vta():
-        res = topi.nn.dense(data, kernel, None, 'int32')
+        res = topi.nn.dense(data, kernel, None, "int32")
         res = topi.right_shift(res, 8)
         res = my_clip(res, 0, 127)
         res = topi.cast(res, "int8")
 
-    if tvm.target.Target.current().device_name == 'vta':
+    if tvm.target.Target.current().device_name == "vta":
         s = topi.generic.schedule_dense([res])
     else:
         s = te.create_schedule([res.op])
 
     return s, [data, kernel, res]
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
 
     # Logging config (for printing tuning log to the screen)
     logging.basicConfig()
@@ -96,28 +98,40 @@ if __name__ == '__main__':
         CI = wl.in_filter
         CO = wl.out_filter
 
-        task = autotvm.task.create(dense, args=(N, CI, CO),
-                target=tvm.target.vta(), target_host=env.target_host, template_key='direct')
+        task = autotvm.task.create(
+            dense,
+            args=(N, CI, CO),
+            target=tvm.target.vta(),
+            target_host=env.target_host,
+            template_key="direct",
+        )
         print(task.config_space)
 
         # Tune
         measure_option = autotvm.measure_option(
-                builder=autotvm.LocalBuilder(),
-                runner=autotvm.RPCRunner(
-                        env.TARGET, host=tracket_host, port=int(tracket_port),
-                        number=5, timeout=60,
-                        check_correctness=True))
+            builder=autotvm.LocalBuilder(),
+            runner=autotvm.RPCRunner(
+                env.TARGET,
+                host=tracket_host,
+                port=int(tracket_port),
+                number=5,
+                timeout=60,
+                check_correctness=True,
+            ),
+        )
 
         # Run Tuner
         tuner = autotvm.tuner.RandomTuner(task)
         tuner.tune(
-                n_trial=len(task.config_space),
-                early_stopping=None,
-                measure_option=measure_option,
-                callbacks=[
-                    autotvm.callback.progress_bar(len(task.config_space), prefix=prefix),
-                    autotvm.callback.log_to_file(tmp_log_file)])
+            n_trial=len(task.config_space),
+            early_stopping=None,
+            measure_option=measure_option,
+            callbacks=[
+                autotvm.callback.progress_bar(len(task.config_space), prefix=prefix),
+                autotvm.callback.log_to_file(tmp_log_file),
+            ],
+        )
 
     # Pick best records to a cache file
     autotvm.record.pick_best(tmp_log_file, log_file)
-    os.remove(tmp_log_file)
\ No newline at end of file
+    os.remove(tmp_log_file)
index d8dcc02..bfac499 100644
@@ -30,23 +30,38 @@ import vta.testing
 
 env = vta.get_env()
 
-Workload = namedtuple("GroupConv2DWorkload",
-                      ['batch', 'height', 'width', 'in_filter', 'out_filter', 'groups',
-                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
+Workload = namedtuple(
+    "GroupConv2DWorkload",
+    [
+        "batch",
+        "height",
+        "width",
+        "in_filter",
+        "out_filter",
+        "groups",
+        "hkernel",
+        "wkernel",
+        "hpad",
+        "wpad",
+        "hstride",
+        "wstride",
+    ],
+)
 
 # Mobilenet (grouped variant) workloads
 mobilenet_wkls = [
-    ('mobilenet.D1', Workload(env.BATCH, 112, 112,   32,   32,  2, 3, 3, 1, 1, 1, 1)),
-    ('mobilenet.D2', Workload(env.BATCH, 112, 112,   64,   64,  4, 3, 3, 1, 1, 2, 2)),
-    ('mobilenet.D3', Workload(env.BATCH,  56,  56,  128,  128,  8, 3, 3, 1, 1, 1, 1)),
-    ('mobilenet.D4', Workload(env.BATCH,  56,  56,  128,  128,  8, 3, 3, 1, 1, 2, 2)),
-    ('mobilenet.D5', Workload(env.BATCH,  28,  28,  256,  256, 16, 3, 3, 1, 1, 1, 1)),
-    ('mobilenet.D6', Workload(env.BATCH,  28,  28,  256,  256, 16, 3, 3, 1, 1, 2, 2)),
-    ('mobilenet.D7', Workload(env.BATCH,  14,  14,  512,  512, 32, 3, 3, 1, 1, 1, 1)),
-    ('mobilenet.D8', Workload(env.BATCH,  14,  14,  512,  512, 32, 3, 3, 1, 1, 2, 2)),
-    ('mobilenet.D9', Workload(env.BATCH,   7,  7,  1024, 1024, 64, 3, 3, 1, 1, 1, 1)),
+    ("mobilenet.D1", Workload(env.BATCH, 112, 112, 32, 32, 2, 3, 3, 1, 1, 1, 1)),
+    ("mobilenet.D2", Workload(env.BATCH, 112, 112, 64, 64, 4, 3, 3, 1, 1, 2, 2)),
+    ("mobilenet.D3", Workload(env.BATCH, 56, 56, 128, 128, 8, 3, 3, 1, 1, 1, 1)),
+    ("mobilenet.D4", Workload(env.BATCH, 56, 56, 128, 128, 8, 3, 3, 1, 1, 2, 2)),
+    ("mobilenet.D5", Workload(env.BATCH, 28, 28, 256, 256, 16, 3, 3, 1, 1, 1, 1)),
+    ("mobilenet.D6", Workload(env.BATCH, 28, 28, 256, 256, 16, 3, 3, 1, 1, 2, 2)),
+    ("mobilenet.D7", Workload(env.BATCH, 14, 14, 512, 512, 32, 3, 3, 1, 1, 1, 1)),
+    ("mobilenet.D8", Workload(env.BATCH, 14, 14, 512, 512, 32, 3, 3, 1, 1, 2, 2)),
+    ("mobilenet.D9", Workload(env.BATCH, 7, 7, 1024, 1024, 64, 3, 3, 1, 1, 1, 1)),
 ]
 
+
 @tvm.te.tag_scope(tag=topi.tag.ELEMWISE)
 def my_clip(x, a_min, a_max):
     """Unlike topi's current clip, put min and max into two stages."""
@@ -56,12 +71,13 @@ def my_clip(x, a_min, a_max):
     x = te.compute(x.shape, lambda *i: tvm.te.max(x(*i), const_min), name="clipB")
     return x
 
+
 def group_conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation, group):
 
     CI_G = CI // groups
-    data_shape = (N//env.BATCH, CI//env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN)
-    kernel_shape = (CO//env.BLOCK_OUT, CI_G//env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN)
-    bias_shape = (N//env.BATCH, CO//env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT)
+    data_shape = (N // env.BATCH, CI // env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN)
+    kernel_shape = (CO // env.BLOCK_OUT, CI_G // env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN)
+    bias_shape = (N // env.BATCH, CO // env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT)
 
     data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
     kernel = te.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
@@ -69,26 +85,22 @@ def group_conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation, group):
 
     with tvm.target.vta():
         res = topi.nn.group_conv2d_nchw(
-            data,
-            kernel,
-            strides,
-            padding,
-            dilation,
-            groups,
-            env.acc_dtype)
+            data, kernel, strides, padding, dilation, groups, env.acc_dtype
+        )
         res = topi.right_shift(res, env.WGT_WIDTH)
         res = topi.add(res, bias)
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res = topi.cast(res, env.out_dtype)
 
-    if tvm.target.Target.current().device_name == 'vta':
+    if tvm.target.Target.current().device_name == "vta":
         s = topi.generic.schedule_group_conv2d_nchw([res])
     else:
         s = te.create_schedule([res.op])
 
     return s, [data, kernel, bias, res]
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
 
     # Logging config (for printing tuning log to the screen)
     logging.basicConfig()
@@ -125,20 +137,26 @@ if __name__ == '__main__':
 
         # Create task
         task = autotvm.task.create(
-                group_conv2d,
-                args=(N, CI, H, W, CO, KH, KW, strides, padding, dilation, groups),
-                target=tvm.target.vta(),
-                target_host=env.target_host,
-                template_key='direct')
+            group_conv2d,
+            args=(N, CI, H, W, CO, KH, KW, strides, padding, dilation, groups),
+            target=tvm.target.vta(),
+            target_host=env.target_host,
+            template_key="direct",
+        )
         print(task.config_space)
 
         # Tune
         measure_option = autotvm.measure_option(
-                builder=autotvm.LocalBuilder(),
-                runner=autotvm.RPCRunner(
-                    env.TARGET, host=tracker_host, port=int(tracker_port),
-                    number=5, timeout=60,
-                    check_correctness=True))
+            builder=autotvm.LocalBuilder(),
+            runner=autotvm.RPCRunner(
+                env.TARGET,
+                host=tracker_host,
+                port=int(tracker_port),
+                number=5,
+                timeout=60,
+                check_correctness=True,
+            ),
+        )
 
         # Run Tuner
         tuner = autotvm.tuner.RandomTuner(task)
@@ -147,8 +165,10 @@ if __name__ == '__main__':
             early_stopping=None,
             measure_option=measure_option,
             callbacks=[
-                    autotvm.callback.progress_bar(len(task.config_space), prefix=prefix),
-                    autotvm.callback.log_to_file(tmp_log_file)])
+                autotvm.callback.progress_bar(len(task.config_space), prefix=prefix),
+                autotvm.callback.log_to_file(tmp_log_file),
+            ],
+        )
 
     # Pick best records to a cache file
     autotvm.record.pick_best(tmp_log_file, log_file)
index 2d15335..6d64904 100644
@@ -35,25 +35,42 @@ from vta.testing import simulator
 from vta.top import graph_pack
 from tvm.autotvm.task import extract_from_program
 
+
 def parse_arguments():
 
-    parser = argparse.ArgumentParser(description='Train a model for image classification.')
-    parser.add_argument('--model', type=str, default='resnet18_v1', choices=['resnet18_v1'],
-                        help='Input model name.')
-    parser.add_argument('--start-name', type=str, default='nn.max_pool2d',
-                        help='The name of the node where packing starts')
-    parser.add_argument('--stop-name', type=str, default='nn.global_avg_pool2d',
-                        help='The name of the node where packing stops')
-    parser.add_argument('--debug-profile', action='store_true',
-                        help='Show layer-wise time cost profiling results')
-    parser.add_argument('--device', default='vta',  choices=['vta', 'arm_cpu'],
-                        help='Select device target')
-    parser.add_argument('--measurements', type=int, default=1,
-                        help='Number of measurements during AutoTVM search')
-    parser.add_argument('--tuner', type=str, default="random",
-                        help='AutoTVM search strategy')
-    parser.add_argument('--log-filename', type=str, default="resnet-18.log",
-                        help='AutoTVM log file name')
+    parser = argparse.ArgumentParser(description="Train a model for image classification.")
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="resnet18_v1",
+        choices=["resnet18_v1"],
+        help="Input model name.",
+    )
+    parser.add_argument(
+        "--start-name",
+        type=str,
+        default="nn.max_pool2d",
+        help="The name of the node where packing starts",
+    )
+    parser.add_argument(
+        "--stop-name",
+        type=str,
+        default="nn.global_avg_pool2d",
+        help="The name of the node where packing stops",
+    )
+    parser.add_argument(
+        "--debug-profile", action="store_true", help="Show layer-wise time cost profiling results"
+    )
+    parser.add_argument(
+        "--device", default="vta", choices=["vta", "arm_cpu"], help="Select device target"
+    )
+    parser.add_argument(
+        "--measurements", type=int, default=1, help="Number of measurements during AutoTVM search"
+    )
+    parser.add_argument("--tuner", type=str, default="random", help="AutoTVM search strategy")
+    parser.add_argument(
+        "--log-filename", type=str, default="resnet-18.log", help="AutoTVM log file name"
+    )
 
     return parser.parse_args()
 
@@ -85,7 +102,7 @@ def register_vta_tuning_tasks():
             res = my_clip(res, 0, 127)
             res = topi.cast(res, "int8")
 
-        if tvm.target.Target.current().device_name == 'vta':
+        if tvm.target.Target.current().device_name == "vta":
             s = topi.generic.schedule_conv2d_nchw([res])
         else:
             s = te.create_schedule([res.op])
@@ -103,7 +120,7 @@ def register_vta_tuning_tasks():
             res = my_clip(res, 0, 127)
             res = topi.cast(res, "int8")
 
-        if tvm.target.Target.current().device_name == 'vta':
+        if tvm.target.Target.current().device_name == "vta":
             s = topi.generic.schedule_dense([res])
         else:
             s = te.create_schedule([res.op])
@@ -114,7 +131,7 @@ def register_vta_tuning_tasks():
 def compile_network(opt, env, target):
 
     # Populate the shape and data type dictionary
-    dtype_dict = {"data": 'float32'}
+    dtype_dict = {"data": "float32"}
     shape_dict = {"data": (env.BATCH, 3, 224, 224)}
 
     # Get off the shelf gluon model, and convert to relay
@@ -128,8 +145,7 @@ def compile_network(opt, env, target):
     # Perform quantization in Relay
     # Note: We set opt_level to 3 in order to fold batch norm
     with tvm.transform.PassContext(opt_level=3):
-        with relay.quantize.qconfig(global_scale=8.0,
-                                    skip_conv_layers=[0]):
+        with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]):
             relay_prog = relay.quantize.quantize(mod["main"], params=params)
 
     # Perform graph packing and constant folding for VTA target
@@ -141,19 +157,22 @@ def compile_network(opt, env, target):
             env.BLOCK_OUT,
             env.WGT_WIDTH,
             start_name=opt.start_name,
-            stop_name=opt.stop_name)
+            stop_name=opt.stop_name,
+        )
 
     return relay_prog, params
 
 
-def tune_tasks(tasks,
-               measure_option,
-               tuner='xgb',
-               n_trial=1000,
-               early_stopping=None,
-               log_filename='tuning.log',
-               use_transfer_learning=True,
-               try_winograd=True):
+def tune_tasks(
+    tasks,
+    measure_option,
+    tuner="xgb",
+    n_trial=1000,
+    early_stopping=None,
+    log_filename="tuning.log",
+    use_transfer_learning=True,
+    try_winograd=True,
+):
 
     # create tmp log file
     tmp_log_file = log_filename + ".tmp"
@@ -161,16 +180,16 @@ def tune_tasks(tasks,
         os.remove(tmp_log_file)
 
     for i, tsk in enumerate(reversed(tasks)):
-        prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
+        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
 
         # create tuner
-        if tuner == 'xgb' or tuner == 'xgb-rank':
-            tuner_obj = XGBTuner(tsk, loss_type='rank')
-        elif tuner == 'ga':
+        if tuner == "xgb" or tuner == "xgb-rank":
+            tuner_obj = XGBTuner(tsk, loss_type="rank")
+        elif tuner == "ga":
             tuner_obj = GATuner(tsk, pop_size=50)
-        elif tuner == 'random':
+        elif tuner == "random":
             tuner_obj = RandomTuner(tsk)
-        elif tuner == 'gridsearch':
+        elif tuner == "gridsearch":
             tuner_obj = GridSearchTuner(tsk)
         else:
             raise ValueError("Invalid tuner: " + tuner)
@@ -181,18 +200,22 @@ def tune_tasks(tasks,
 
         # do tuning
         n_trial_ = min(n_trial, len(tsk.config_space))
-        tuner_obj.tune(n_trial_,
-                       early_stopping=early_stopping,
-                       measure_option=measure_option,
-                       callbacks=[
-                           autotvm.callback.progress_bar(n_trial_, prefix=prefix),
-                           autotvm.callback.log_to_file(tmp_log_file)])
+        tuner_obj.tune(
+            n_trial_,
+            early_stopping=early_stopping,
+            measure_option=measure_option,
+            callbacks=[
+                autotvm.callback.progress_bar(n_trial_, prefix=prefix),
+                autotvm.callback.log_to_file(tmp_log_file),
+            ],
+        )
 
     # pick best records to a cache file
     autotvm.record.pick_best(tmp_log_file, log_filename)
     os.remove(tmp_log_file)
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
 
     opt = parse_arguments()
 
@@ -216,7 +239,9 @@ if __name__ == '__main__':
         reconfig_start = time.time()
 
         # Get remote from fleet node
-        remote = autotvm.measure.request_remote(env.TARGET, tracker_host, int(tracker_port), timeout=10000)
+        remote = autotvm.measure.request_remote(
+            env.TARGET, tracker_host, int(tracker_port), timeout=10000
+        )
 
         # Reconfigure the JIT runtime and FPGA.
         # You can program the FPGA with your own custom bitstream
@@ -245,24 +270,34 @@ if __name__ == '__main__':
 
     # Perform task extraction on Relay program
     print("Extracting tasks...")
-    tasks = extract_from_program(func=relay_prog,
-                                 params=params,
-                                 ops=(relay.op.get("nn.conv2d"),),
-                                 target=target,
-                                 target_host=env.target_host)
+    tasks = extract_from_program(
+        func=relay_prog,
+        params=params,
+        ops=(relay.op.get("nn.conv2d"),),
+        target=target,
+        target_host=env.target_host,
+    )
 
     # Perform Autotuning
     print("Tuning...")
     tuning_opt = {
-        'log_filename': opt.log_filename,
-        'tuner': opt.tuner,
-        'n_trial': 1e9,
-        'early_stopping': None,
-        'measure_option': autotvm.measure_option(
-                builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
-                runner=autotvm.RPCRunner(env.TARGET, tracker_host, tracker_port,
-                    number=4, min_repeat_ms=150, repeat=opt.measurements, timeout=60,
-                    check_correctness=True))
+        "log_filename": opt.log_filename,
+        "tuner": opt.tuner,
+        "n_trial": 1e9,
+        "early_stopping": None,
+        "measure_option": autotvm.measure_option(
+            builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
+            runner=autotvm.RPCRunner(
+                env.TARGET,
+                tracker_host,
+                tracker_port,
+                number=4,
+                min_repeat_ms=150,
+                repeat=opt.measurements,
+                timeout=60,
+                check_correctness=True,
+            ),
+        ),
     }
     tune_tasks(tasks, **tuning_opt)
 
@@ -274,13 +309,13 @@ if __name__ == '__main__':
         if target.device_name != "vta":
             with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}):
                 graph, lib, params = relay.build(
-                    relay_prog, target=target,
-                    params=params, target_host=env.target_host)
+                    relay_prog, target=target, params=params, target_host=env.target_host
+                )
         else:
             with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
                 graph, lib, params = relay.build(
-                    relay_prog, target=target,
-                    params=params, target_host=env.target_host)
+                    relay_prog, target=target, params=params, target_host=env.target_host
+                )
 
         # Export library
         temp = util.tempdir()
@@ -295,17 +330,18 @@ if __name__ == '__main__':
             m = graph_runtime.create(graph, lib, ctx)
 
         # Set the network parameters and synthetic input
-        image = tvm.nd.array(
-            (np.random.uniform(size=(1, 3, 224, 224))).astype('float32'))
+        image = tvm.nd.array((np.random.uniform(size=(1, 3, 224, 224))).astype("float32"))
         m.set_input(**params)
-        m.set_input('data', image)
+        m.set_input("data", image)
 
         # Perform inference
         timer = m.module.time_evaluator("run", ctx, number=4, repeat=opt.measurements)
         tcost = timer()
         prof_res = np.array(tcost.results) * 1000  # convert to millisecond
-        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
-              (np.mean(prof_res), np.std(prof_res)))
+        print(
+            "Mean inference time (std dev): %.2f ms (%.2f ms)"
+            % (np.mean(prof_res), np.std(prof_res))
+        )
 
         # Display profile information
         if opt.debug_profile:
index 6dbd457..ea62f3f 100644
@@ -22,6 +22,7 @@ from vta import get_bitstream_path, download_bitstream, program_fpga, reconfig_r
 host = os.environ.get("VTA_RPC_HOST", "de10nano")
 port = int(os.environ.get("VTA_RPC_PORT", "9091"))
 
+
 def program_rpc_bitstream(path=None):
     """Program the FPGA on the RPC server
 
@@ -33,13 +34,14 @@ def program_rpc_bitstream(path=None):
     remote = rpc.connect(host, port)
     program_fpga(remote, path)
 
+
 def reconfig_rpc_runtime():
-    """Reconfig the RPC server runtime
-    """
+    """Reconfig the RPC server runtime"""
     assert tvm.runtime.enabled("rpc")
     remote = rpc.connect(host, port)
     reconfig_runtime(remote)
 
+
 bitstream = sys.argv[1] if len(sys.argv) == 2 else None
 program_rpc_bitstream(bitstream)
 reconfig_rpc_runtime()
index e023c20..f37b418 100644
@@ -24,58 +24,42 @@ from vta.testing import simulator
 
 def test_gemm():
     def run_gemm_packed(env, remote, batch_size, channel, block):
-        data_shape = (batch_size // env.BATCH,
-                      channel // env.BLOCK_IN,
-                      env.BATCH,
-                      env.BLOCK_IN)
-        weight_shape = (channel // env.BLOCK_OUT,
-                        channel // env.BLOCK_IN,
-                        env.BLOCK_OUT,
-                        env.BLOCK_IN)
-        res_shape = (batch_size // env.BATCH,
-                     channel // env.BLOCK_OUT,
-                     env.BATCH,
-                     env.BLOCK_OUT)
+        data_shape = (batch_size // env.BATCH, channel // env.BLOCK_IN, env.BATCH, env.BLOCK_IN)
+        weight_shape = (
+            channel // env.BLOCK_OUT,
+            channel // env.BLOCK_IN,
+            env.BLOCK_OUT,
+            env.BLOCK_IN,
+        )
+        res_shape = (batch_size // env.BATCH, channel // env.BLOCK_OUT, env.BATCH, env.BLOCK_OUT)
         # To compute number of ops, use a x2 factor for FMA
         num_ops = 2 * channel * channel * batch_size
 
-        ko = te.reduce_axis((0, channel // env.BLOCK_IN), name='ko')
-        ki = te.reduce_axis((0, env.BLOCK_IN), name='ki')
-
-        data = te.placeholder(data_shape,
-                               name="data",
-                               dtype=env.inp_dtype)
-        weight = te.placeholder(weight_shape,
-                                 name="weight",
-                                 dtype=env.wgt_dtype)
-        data_buf = te.compute(data_shape,
-                               lambda *i: data(*i),
-                               "data_buf")
-        weight_buf = te.compute(weight_shape,
-                                 lambda *i: weight(*i),
-                                 "weight_buf")
-        res_gem = te.compute(res_shape,
-                              lambda bo, co, bi, ci: te.sum(
-                                  data_buf[bo, ko, bi, ki].astype(env.acc_dtype) *
-                                  weight_buf[co, ko, ci, ki].astype(env.acc_dtype),
-                                  axis=[ko, ki]),
-                              name="res_gem")
-        res_shf = te.compute(res_shape,
-                              lambda *i: res_gem(*i)>>8,
-                            name="res_shf")
-        res_max = te.compute(res_shape,
-                              lambda *i: tvm.te.max(res_shf(*i), 0),
-                              "res_max") #relu
-        res_min = te.compute(res_shape,
-                              lambda *i: tvm.te.min(res_max(*i), (1<<(env.INP_WIDTH-1))-1),
-                              "res_min") #relu
-        res = te.compute(res_shape,
-                          lambda *i: res_min(*i).astype(env.inp_dtype),
-                          name="res")
+        ko = te.reduce_axis((0, channel // env.BLOCK_IN), name="ko")
+        ki = te.reduce_axis((0, env.BLOCK_IN), name="ki")
+
+        data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
+        weight = te.placeholder(weight_shape, name="weight", dtype=env.wgt_dtype)
+        data_buf = te.compute(data_shape, lambda *i: data(*i), "data_buf")
+        weight_buf = te.compute(weight_shape, lambda *i: weight(*i), "weight_buf")
+        res_gem = te.compute(
+            res_shape,
+            lambda bo, co, bi, ci: te.sum(
+                data_buf[bo, ko, bi, ki].astype(env.acc_dtype)
+                * weight_buf[co, ko, ci, ki].astype(env.acc_dtype),
+                axis=[ko, ki],
+            ),
+            name="res_gem",
+        )
+        res_shf = te.compute(res_shape, lambda *i: res_gem(*i) >> 8, name="res_shf")
+        res_max = te.compute(res_shape, lambda *i: tvm.te.max(res_shf(*i), 0), "res_max")  # relu
+        res_min = te.compute(
+            res_shape, lambda *i: tvm.te.min(res_max(*i), (1 << (env.INP_WIDTH - 1)) - 1), "res_min"
+        )  # relu
+        res = te.compute(res_shape, lambda *i: res_min(*i).astype(env.inp_dtype), name="res")
 
         def verify(s, check_correctness=True):
-            mod = vta.build(s, [data, weight, res],
-                            "ext_dev", env.target_host, name="gemm")
+            mod = vta.build(s, [data, weight, res], "ext_dev", env.target_host, name="gemm")
             temp = util.tempdir()
             mod.save(temp.relpath("gemm.o"))
             remote.upload(temp.relpath("gemm.o"))
@@ -83,16 +67,14 @@ def test_gemm():
             # verify
             ctx = remote.ext_dev(0)
             # Data in original format
-            data_orig = np.random.randint(
-                -128, 128, size=(batch_size, channel)).astype(data.dtype)
-            weight_orig = np.random.randint(
-                -128, 128, size=(channel, channel)).astype(weight.dtype)
+            data_orig = np.random.randint(-128, 128, size=(batch_size, channel)).astype(data.dtype)
+            weight_orig = np.random.randint(-128, 128, size=(channel, channel)).astype(weight.dtype)
             data_packed = data_orig.reshape(
-                batch_size // env.BATCH, env.BATCH,
-                channel // env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
+                batch_size // env.BATCH, env.BATCH, channel // env.BLOCK_IN, env.BLOCK_IN
+            ).transpose((0, 2, 1, 3))
             weight_packed = weight_orig.reshape(
-                channel // env.BLOCK_OUT, env.BLOCK_OUT,
-                channel // env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
+                channel // env.BLOCK_OUT, env.BLOCK_OUT, channel // env.BLOCK_IN, env.BLOCK_IN
+            ).transpose((0, 2, 1, 3))
             res_np = np.zeros(res_shape).astype(res.dtype)
             data_arr = tvm.nd.array(data_packed, ctx)
             weight_arr = tvm.nd.array(weight_packed, ctx)
@@ -101,10 +83,12 @@ def test_gemm():
             for b in range(batch_size // env.BATCH):
                 for i in range(channel // env.BLOCK_OUT):
                     for j in range(channel // env.BLOCK_IN):
-                        res_ref[b,i,:] += np.dot(data_packed[b,j,:].astype(env.acc_dtype),
-                                                 weight_packed[i,j].T.astype(env.acc_dtype))
+                        res_ref[b, i, :] += np.dot(
+                            data_packed[b, j, :].astype(env.acc_dtype),
+                            weight_packed[i, j].T.astype(env.acc_dtype),
+                        )
             res_ref = np.right_shift(res_ref, 8)
-            res_ref = np.clip(res_ref, 0, (1<<(env.INP_WIDTH-1))-1).astype(res.dtype)
+            res_ref = np.clip(res_ref, 0, (1 << (env.INP_WIDTH - 1)) - 1).astype(res.dtype)
             time_f = f.time_evaluator("gemm", ctx, number=20)
             if env.TARGET in ["sim", "tsim"]:
                 simulator.clear_stats()
@@ -114,21 +98,14 @@ def test_gemm():
                 print("Execution statistics:")
                 for k, v in stats.items():
                     print("\t{:<16}: {:>16}".format(k, v))
-            res_unpack = res_arr.asnumpy().reshape(batch_size // env.BATCH,
-                                                   channel // env.BLOCK_OUT,
-                                                   env.BATCH,
-                                                   env.BLOCK_OUT)
+            res_unpack = res_arr.asnumpy().reshape(
+                batch_size // env.BATCH, channel // env.BLOCK_OUT, env.BATCH, env.BLOCK_OUT
+            )
             if check_correctness:
                 tvm.testing.assert_allclose(res_unpack, res_ref)
             return cost
 
-        def run_schedule(load_inp,
-                         load_wgt,
-                         gemm,
-                         alu,
-                         store_out,
-                         print_ir,
-                         check_correctness):
+        def run_schedule(load_inp, load_wgt, gemm, alu, store_out, print_ir, check_correctness):
             s = te.create_schedule(res.op)
             s[data_buf].set_scope(env.inp_scope)
             s[weight_buf].set_scope(env.wgt_scope)
@@ -176,7 +153,6 @@ def test_gemm():
                 s[res_max].pragma(s[res_max].op.axis[0], alu)
                 s[res].pragma(s[res].op.axis[0], store_out)
 
-
             if print_ir:
                 print(tvm.lower(s, [data, weight, res], simple_mode=True))
             return verify(s, check_correctness)
@@ -184,39 +160,51 @@ def test_gemm():
         def gemm_normal(print_ir):
             mock = env.mock
             print("----- GEMM GOPS End-to-End Test-------")
+
             def run_test(header, print_ir, check_correctness):
                 cost = run_schedule(
-                    env.dma_copy, env.dma_copy, env.gemm, env.alu, env.dma_copy,
-                    print_ir, check_correctness)
+                    env.dma_copy,
+                    env.dma_copy,
+                    env.gemm,
+                    env.alu,
+                    env.dma_copy,
+                    print_ir,
+                    check_correctness,
+                )
                 gops = (num_ops / cost.mean) / float(10 ** 9)
                 print(header)
                 print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
+
             with vta.build_config():
                 run_test("NORMAL", print_ir, True)
 
         def gemm_unittest(print_ir):
             mock = env.mock
             print("----- GEMM Unit Test-------")
+
             def run_test(header, print_ir):
                 cost = run_schedule(
-                    mock.dma_copy, mock.dma_copy, env.gemm, mock.alu, mock.dma_copy,
-                    print_ir, False)
+                    mock.dma_copy, mock.dma_copy, env.gemm, mock.alu, mock.dma_copy, print_ir, False
+                )
                 gops = (num_ops / cost.mean) / float(10 ** 9)
                 print(header)
                 print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
+
             with vta.build_config():
                 run_test("NORMAL", print_ir)
 
         def alu_unittest(print_ir):
             mock = env.mock
             print("----- ALU Unit Test-------")
+
             def run_test(header, print_ir):
                 cost = run_schedule(
-                    mock.dma_copy, mock.dma_copy, mock.gemm, env.alu, mock.dma_copy,
-                    print_ir, False)
+                    mock.dma_copy, mock.dma_copy, mock.gemm, env.alu, mock.dma_copy, print_ir, False
+                )
                 gops = (num_ops / cost.mean) / float(10 ** 9)
                 print(header)
                 print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
+
             with vta.build_config():
                 run_test("NORMAL", print_ir)
             print("")
@@ -224,14 +212,19 @@ def test_gemm():
         def load_inp_unittest(print_ir):
             mock = env.mock
             print("----- LoadInp Unit Test-------")
+
             def run_test(header, print_ir):
                 cost = run_schedule(
-                    env.dma_copy, mock.dma_copy, mock.gemm, mock.alu, mock.dma_copy, print_ir, False)
+                    env.dma_copy, mock.dma_copy, mock.gemm, mock.alu, mock.dma_copy, print_ir, False
+                )
                 gops = (num_ops / cost.mean) / float(10 ** 9)
+                bandwidth = (batch_size * channel * env.INP_WIDTH / cost.mean) / float(10 ** 9)
                 print(header)
-                print("\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits" % (
-                    cost.mean, gops, bandwith))
+                print(
+                    "\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits"
+                    % (cost.mean, gops, bandwidth)
+                )
+
             with vta.build_config():
                 run_test("NORMAL", print_ir)
             print("")
@@ -239,14 +232,19 @@ def test_gemm():
         def load_wgt_unittest(print_ir):
             mock = env.mock
             print("----- LoadWgt Unit Test-------")
+
             def run_test(header, print_ir):
                 cost = run_schedule(
-                    mock.dma_copy, env.dma_copy, mock.gemm, mock.alu, mock.dma_copy, print_ir, False)
+                    mock.dma_copy, env.dma_copy, mock.gemm, mock.alu, mock.dma_copy, print_ir, False
+                )
                 gops = (num_ops / cost.mean) / float(10 ** 9)
+                bandwidth = (channel * channel * env.WGT_WIDTH / cost.mean) / float(10 ** 9)
                 print(header)
-                print("\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits" % (
-                    cost.mean, gops, bandwith))
+                print(
+                    "\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits"
+                    % (cost.mean, gops, bandwidth)
+                )
+
             with vta.build_config():
                 run_test("NORMAL", print_ir)
             print("")
@@ -254,21 +252,23 @@ def test_gemm():
         def store_out_unittest(print_ir):
             mock = env.mock
             print("----- StoreOut Unit Test-------")
+
             def run_test(header, print_ir):
                 cost = run_schedule(
-                    mock.dma_copy, mock.dma_copy, mock.gemm, mock.alu, env.dma_copy,
-                    print_ir, False)
+                    mock.dma_copy, mock.dma_copy, mock.gemm, mock.alu, env.dma_copy, print_ir, False
+                )
                 gops = (num_ops / cost.mean) / float(10 ** 9)
+                bandwidth = (batch_size * channel * env.OUT_WIDTH / cost.mean) / float(10 ** 9)
                 print(header)
-                print("\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits" % (
-                    cost.mean, gops, bandwith))
+                print(
+                    "\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits"
+                    % (cost.mean, gops, bandwidth)
+                )
+
             with vta.build_config():
                 run_test("NORMAL", print_ir)
             print("")
 
-
-
         gemm_normal(False)
         gemm_unittest(False)
         alu_unittest(False)
@@ -279,5 +279,6 @@ def test_gemm():
 
     vta.testing.run(_run)
 
+
 if __name__ == "__main__":
     test_gemm()
index 3affbac..004cc6b 100644
@@ -38,9 +38,22 @@ import vta.testing
 from vta.testing import simulator
 
 
-Workload = namedtuple("Conv2DWorkload",
-                      ['batch', 'height', 'width', 'in_filter', 'out_filter',
-                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
+Workload = namedtuple(
+    "Conv2DWorkload",
+    [
+        "batch",
+        "height",
+        "width",
+        "in_filter",
+        "out_filter",
+        "hkernel",
+        "wkernel",
+        "hpad",
+        "wpad",
+        "hstride",
+        "wstride",
+    ],
+)
 
 # Get batch info from env
 env = vta.get_env()
@@ -49,16 +62,16 @@ env = vta.get_env()
 resnet_wkls = [
     # Workloads of resnet18 on imagenet
     # ('resnet-18.C1',  Workload(env.BATCH, 224, 224, 3,   64,  7, 7, 3, 3, 2, 2)),
-    ('resnet-18.C2',  Workload(env.BATCH,  56,  56, 64,  64,  3, 3, 1, 1, 1, 1)),
-    ('resnet-18.C3',  Workload(env.BATCH,  56,  56, 64,  128, 3, 3, 1, 1, 2, 2)),
-    ('resnet-18.C4',  Workload(env.BATCH,  56,  56, 64,  128, 1, 1, 0, 0, 2, 2)),
-    ('resnet-18.C5',  Workload(env.BATCH,  28,  28, 128, 128, 3, 3, 1, 1, 1, 1)),
-    ('resnet-18.C6',  Workload(env.BATCH,  28,  28, 128, 256, 3, 3, 1, 1, 2, 2)),
-    ('resnet-18.C7',  Workload(env.BATCH,  28,  28, 128, 256, 1, 1, 0, 0, 2, 2)),
-    ('resnet-18.C8',  Workload(env.BATCH,  14,  14, 256, 256, 3, 3, 1, 1, 1, 1)),
-    ('resnet-18.C9',  Workload(env.BATCH,  14,  14, 256, 512, 3, 3, 1, 1, 2, 2)),
-    ('resnet-18.C10', Workload(env.BATCH,  14,  14, 256, 512, 1, 1, 0, 0, 2, 2)),
-    ('resnet-18.C11', Workload(env.BATCH,   7,   7, 512, 512, 3, 3, 1, 1, 1, 1)),
+    ("resnet-18.C2", Workload(env.BATCH, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)),
+    ("resnet-18.C3", Workload(env.BATCH, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2)),
+    ("resnet-18.C4", Workload(env.BATCH, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2)),
+    ("resnet-18.C5", Workload(env.BATCH, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1)),
+    ("resnet-18.C6", Workload(env.BATCH, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2)),
+    ("resnet-18.C7", Workload(env.BATCH, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2)),
+    ("resnet-18.C8", Workload(env.BATCH, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1)),
+    ("resnet-18.C9", Workload(env.BATCH, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2)),
+    ("resnet-18.C10", Workload(env.BATCH, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2)),
+    ("resnet-18.C11", Workload(env.BATCH, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1)),
 ]
 
 # FIXME: we need a custom clip operator to circumvent a pattern detection limitation
@@ -71,9 +84,8 @@ def my_clip(x, a_min, a_max):
     x = te.compute(x.shape, lambda *i: tvm.te.max(x(*i), const_min), name="clipB")
     return x
 
-def run_conv2d(env, remote, wl, target,
-               check_correctness=True, print_ir=False,
-               samples=4):
+
+def run_conv2d(env, remote, wl, target, check_correctness=True, print_ir=False, samples=4):
 
     # Workload assertions
     assert wl.hpad == wl.wpad
@@ -95,12 +107,30 @@ def run_conv2d(env, remote, wl, target,
     w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)
     b_shape = (wl.batch, wl.out_filter, 1, 1)
     if data_pack:
-        data_shape = (wl.batch//env.BATCH, wl.in_filter//env.BLOCK_IN,
-                      wl.height, wl.width, env.BATCH, env.BLOCK_IN)
-        kernel_shape = (wl.out_filter//env.BLOCK_OUT, wl.in_filter//env.BLOCK_IN,
-                        wl.hkernel, wl.wkernel, env.BLOCK_OUT, env.BLOCK_IN)
-        bias_shape = (wl.batch//env.BATCH, wl.out_filter//env.BLOCK_OUT,
-                      1, 1, env.BATCH, env.BLOCK_OUT)
+        data_shape = (
+            wl.batch // env.BATCH,
+            wl.in_filter // env.BLOCK_IN,
+            wl.height,
+            wl.width,
+            env.BATCH,
+            env.BLOCK_IN,
+        )
+        kernel_shape = (
+            wl.out_filter // env.BLOCK_OUT,
+            wl.in_filter // env.BLOCK_IN,
+            wl.hkernel,
+            wl.wkernel,
+            env.BLOCK_OUT,
+            env.BLOCK_IN,
+        )
+        bias_shape = (
+            wl.batch // env.BATCH,
+            wl.out_filter // env.BLOCK_OUT,
+            1,
+            1,
+            env.BATCH,
+            env.BLOCK_OUT,
+        )
     else:
         data_shape = a_shape
         kernel_shape = w_shape
@@ -114,12 +144,12 @@ def run_conv2d(env, remote, wl, target,
     with target:
         if data_pack:
             res = conv2d_fcompute(
-                data, kernel, (wl.hstride, wl.wstride), padding, (1, 1),
-                layout, env.acc_dtype)
+                data, kernel, (wl.hstride, wl.wstride), padding, (1, 1), layout, env.acc_dtype
+            )
         else:
             res = conv2d_fcompute(
-                data, kernel, (wl.hstride, wl.wstride), padding, (1, 1),
-                env.acc_dtype)
+                data, kernel, (wl.hstride, wl.wstride), padding, (1, 1), env.acc_dtype
+            )
         res = topi.right_shift(res, 8)
         res = topi.add(res, bias)
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
@@ -132,47 +162,68 @@ def run_conv2d(env, remote, wl, target,
     # Derive number of ops
     fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
     fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
-    num_ops = 2 * wl.batch * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
+    num_ops = (
+        2
+        * wl.batch
+        * fout_height
+        * fout_width
+        * wl.hkernel
+        * wl.wkernel
+        * wl.out_filter
+        * wl.in_filter
+    )
 
     # @memoize("vta.tests.test_benchmark_topi.conv2d.verify_nhwc")
     def get_ref_data():
         # derive min max for act, wgt, and bias types (max non inclusive)
         a_min, a_max = 0 - (1 << (env.INP_WIDTH - 1)), (1 << (env.INP_WIDTH - 1))
         w_min, w_max = 0 - (1 << (env.WGT_WIDTH - 1)), (1 << (env.WGT_WIDTH - 1))
-        b_min, b_max = 0 - 1 << (env.INP_WIDTH + env.WGT_WIDTH - 2), 1 << (env.INP_WIDTH + env.WGT_WIDTH - 2)
+        b_min, b_max = 0 - 1 << (env.INP_WIDTH + env.WGT_WIDTH - 2), 1 << (
+            env.INP_WIDTH + env.WGT_WIDTH - 2
+        )
         a_np = np.random.randint(a_min, a_max, size=a_shape).astype(data.dtype)
         w_np = np.random.randint(w_min, w_max, size=w_shape).astype(kernel.dtype)
         b_np = np.random.randint(b_min, b_max, size=b_shape).astype(env.acc_dtype)
         r_np = tvm.topi.testing.conv2d_nchw_python(
-            a_np.astype(env.acc_dtype), w_np.astype(env.acc_dtype), (wl.hstride, wl.wstride), wl.hpad).astype(env.acc_dtype)
+            a_np.astype(env.acc_dtype),
+            w_np.astype(env.acc_dtype),
+            (wl.hstride, wl.wstride),
+            wl.hpad,
+        ).astype(env.acc_dtype)
         return a_np, w_np, b_np, r_np
 
     # Data in original format
     data_np, kernel_np, bias_np, res_ref = get_ref_data()
     if data_pack:
         data_np = data_np.reshape(
-            wl.batch//env.BATCH, env.BATCH,
-            wl.in_filter//env.BLOCK_IN, env.BLOCK_IN,
-            wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3))
+            wl.batch // env.BATCH,
+            env.BATCH,
+            wl.in_filter // env.BLOCK_IN,
+            env.BLOCK_IN,
+            wl.height,
+            wl.width,
+        ).transpose((0, 2, 4, 5, 1, 3))
         kernel_np = kernel_np.reshape(
-            wl.out_filter//env.BLOCK_OUT, env.BLOCK_OUT,
-            wl.in_filter//env.BLOCK_IN, env.BLOCK_IN,
-            wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
+            wl.out_filter // env.BLOCK_OUT,
+            env.BLOCK_OUT,
+            wl.in_filter // env.BLOCK_IN,
+            env.BLOCK_IN,
+            wl.hkernel,
+            wl.wkernel,
+        ).transpose((0, 2, 4, 5, 1, 3))
         bias_np = bias_np.reshape(
-            wl.batch//env.BATCH, wl.out_filter//env.BLOCK_OUT,
-            1, 1, env.BATCH, env.BLOCK_OUT)
+            wl.batch // env.BATCH, wl.out_filter // env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT
+        )
 
     # Build
     if "vta" in target.keys:
-        mod = vta.build(s, [data, kernel, bias, res],
-                        target=target,
-                        target_host=env.target_host,
-                        name="conv2d")
+        mod = vta.build(
+            s, [data, kernel, bias, res], target=target, target_host=env.target_host, name="conv2d"
+        )
     else:
-        mod = tvm.build(s, [data, kernel, bias, res],
-                        target=target,
-                        target_host=env.target_host,
-                        name="conv2d")
+        mod = tvm.build(
+            s, [data, kernel, bias, res], target=target, target_host=env.target_host, name="conv2d"
+        )
     temp = util.tempdir()
     mod.save(temp.relpath("conv2d.o"))
     remote.upload(temp.relpath("conv2d.o"))
@@ -215,10 +266,10 @@ def run_conv2d(env, remote, wl, target,
     if check_correctness:
         res_orig = res_arr.asnumpy()
         if data_pack:
-            res_orig = res_orig.transpose(
-                (0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, fout_height, fout_width)
-            bias_np = bias_np.transpose(
-                (0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, 1, 1)
+            res_orig = res_orig.transpose((0, 4, 1, 5, 2, 3)).reshape(
+                wl.batch, wl.out_filter, fout_height, fout_width
+            )
+            bias_np = bias_np.transpose((0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, 1, 1)
         res_ref = res_ref >> env.WGT_WIDTH
         res_ref += bias_np
         res_ref = np.clip(res_ref, 0, (1 << env.OUT_WIDTH - 1) - 1)
@@ -235,6 +286,7 @@ def run_conv2d(env, remote, wl, target,
 
     return correct, cost, stats
 
+
 @pytest.mark.parametrize("device", ["vta", "arm_cpu"])
 def test_conv2d(device):
     def _run(env, remote):
@@ -246,12 +298,14 @@ def test_conv2d(device):
                 reconfig_runtime(remote)
         elif device == "arm_cpu":
             target = env.target_vta_cpu
-        with autotvm.tophub.context(target): # load pre-tuned schedule parameters
+        with autotvm.tophub.context(target):  # load pre-tuned schedule parameters
             for _, wl in resnet_wkls:
                 print(wl)
                 run_conv2d(env, remote, wl, target)
+
     vta.testing.run(_run)
 
+
 if __name__ == "__main__":
     test_conv2d(device="arm_cpu")
     test_conv2d(device="vta")
index 80a6848..23c4a5c 100644
@@ -38,10 +38,24 @@ import vta.testing
 from vta.testing import simulator
 
 
-Workload = namedtuple("Conv2DTransposeWorkload",
-                      ['batch', 'height', 'width', 'in_filter', 'out_filter',
-                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride',
-                       'o_hpad', 'o_wpad'])
+Workload = namedtuple(
+    "Conv2DTransposeWorkload",
+    [
+        "batch",
+        "height",
+        "width",
+        "in_filter",
+        "out_filter",
+        "hkernel",
+        "wkernel",
+        "hpad",
+        "wpad",
+        "hstride",
+        "wstride",
+        "o_hpad",
+        "o_wpad",
+    ],
+)
 
 # Get batch info from env
 env = vta.get_env()
@@ -49,9 +63,9 @@ env = vta.get_env()
 # DCGAN workloads
 dcgan_wklds = [
     # dcgan
-    ('DCGAN.CT1', Workload(env.BATCH,  4,  4, 1024, 512, 4, 4, 1, 1, 2, 2, 0, 0)),
-    ('DCGAN.CT2', Workload(env.BATCH,  8,  8,  512, 256, 4, 4, 1, 1, 2, 2, 0, 0)),
-    ('DCGAN.CT3', Workload(env.BATCH, 16, 16,  256, 128, 4, 4, 1, 1, 2, 2, 0, 0)),
+    ("DCGAN.CT1", Workload(env.BATCH, 4, 4, 1024, 512, 4, 4, 1, 1, 2, 2, 0, 0)),
+    ("DCGAN.CT2", Workload(env.BATCH, 8, 8, 512, 256, 4, 4, 1, 1, 2, 2, 0, 0)),
+    ("DCGAN.CT3", Workload(env.BATCH, 16, 16, 256, 128, 4, 4, 1, 1, 2, 2, 0, 0)),
 ]
 
 # FIXME: we need a custom clip operator to circumvent a pattern detection limitation
@@ -64,6 +78,7 @@ def my_clip(x, a_min, a_max):
     x = te.compute(x.shape, lambda *i: tvm.te.max(x(*i), const_min), name="clipB")
     return x
 
+
 # Helper function to get factors
 def _find_factors(n):
     factors = []
@@ -73,9 +88,9 @@ def _find_factors(n):
     return factors
 
 
-def run_conv2d_transpose(env, remote, wl, target,
-               check_correctness=True, print_ir=False,
-               samples=4):
+def run_conv2d_transpose(
+    env, remote, wl, target, check_correctness=True, print_ir=False, samples=4
+):
 
     # Workload assertions
     assert wl.hpad == wl.wpad
@@ -97,10 +112,22 @@ def run_conv2d_transpose(env, remote, wl, target,
     a_shape = (wl.batch, wl.in_filter, wl.height, wl.width)
     w_shape = (wl.in_filter, wl.out_filter, wl.hkernel, wl.wkernel)
     if data_pack:
-        data_shape = (wl.batch//env.BATCH, wl.in_filter//env.BLOCK_IN,
-                      wl.height, wl.width, env.BATCH, env.BLOCK_IN)
-        kernel_shape = (wl.out_filter//env.BLOCK_OUT, wl.in_filter//env.BLOCK_IN,
-                        wl.hkernel, wl.wkernel, env.BLOCK_OUT, env.BLOCK_IN)
+        data_shape = (
+            wl.batch // env.BATCH,
+            wl.in_filter // env.BLOCK_IN,
+            wl.height,
+            wl.width,
+            env.BATCH,
+            env.BLOCK_IN,
+        )
+        kernel_shape = (
+            wl.out_filter // env.BLOCK_OUT,
+            wl.in_filter // env.BLOCK_IN,
+            wl.hkernel,
+            wl.wkernel,
+            env.BLOCK_OUT,
+            env.BLOCK_IN,
+        )
     else:
         data_shape = a_shape
         kernel_shape = w_shape
@@ -112,8 +139,8 @@ def run_conv2d_transpose(env, remote, wl, target,
     with target:
 
         res = fcompute(
-            data, kernel, (wl.hstride, wl.wstride), padding, env.acc_dtype,
-            (wl.o_hpad, wl.o_wpad))
+            data, kernel, (wl.hstride, wl.wstride), padding, env.acc_dtype, (wl.o_hpad, wl.o_wpad)
+        )
         res = topi.right_shift(res, env.WGT_WIDTH)
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res = topi.cast(res, env.out_dtype)
@@ -125,7 +152,16 @@ def run_conv2d_transpose(env, remote, wl, target,
     # Derive number of ops
     fout_height = (wl.height - 1) * wl.hstride - 2 * wl.hpad + wl.hkernel + wl.o_hpad
     fout_width = (wl.width - 1) * wl.wstride - 2 * wl.wpad + wl.wkernel + wl.o_wpad
-    num_ops = 2 * wl.batch * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
+    num_ops = (
+        2
+        * wl.batch
+        * fout_height
+        * fout_width
+        * wl.hkernel
+        * wl.wkernel
+        * wl.out_filter
+        * wl.in_filter
+    )
 
     # @memoize("vta.tests.test_benchmark_topi.conv2d.verify_nhwc")
     def get_ref_data():
@@ -133,36 +169,57 @@ def run_conv2d_transpose(env, remote, wl, target,
         a_min, a_max = 0 - (1 << (env.INP_WIDTH - 1)), (1 << (env.INP_WIDTH - 1))
         w_min, w_max = 0 - (1 << (env.WGT_WIDTH - 1)), (1 << (env.WGT_WIDTH - 1))
         a_np = np.random.randint(a_min, a_max, size=a_shape).astype(data.dtype)
-        w_np = np.random.randint(w_min, w_max, size=(wl.in_filter, wl.out_filter, wl.hkernel, wl.wkernel)).astype(kernel.dtype)
+        w_np = np.random.randint(
+            w_min, w_max, size=(wl.in_filter, wl.out_filter, wl.hkernel, wl.wkernel)
+        ).astype(kernel.dtype)
         r_np = tvm.topi.testing.conv2d_transpose_nchw_python(
-            a_np.astype(env.acc_dtype), w_np.astype(env.acc_dtype), (wl.hstride, wl.wstride), wl.hpad, (wl.o_hpad, wl.o_wpad)).astype(env.acc_dtype)
+            a_np.astype(env.acc_dtype),
+            w_np.astype(env.acc_dtype),
+            (wl.hstride, wl.wstride),
+            wl.hpad,
+            (wl.o_hpad, wl.o_wpad),
+        ).astype(env.acc_dtype)
         return a_np, w_np, r_np
 
     # Data in original format
     data_np, kernel_np, res_ref = get_ref_data()
     if data_pack:
         data_np = data_np.reshape(
-            wl.batch//env.BATCH, env.BATCH,
-            wl.in_filter//env.BLOCK_IN, env.BLOCK_IN,
-            wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3))
+            wl.batch // env.BATCH,
+            env.BATCH,
+            wl.in_filter // env.BLOCK_IN,
+            env.BLOCK_IN,
+            wl.height,
+            wl.width,
+        ).transpose((0, 2, 4, 5, 1, 3))
         kernel_np = kernel_np.reshape(
-            wl.in_filter//env.BLOCK_IN, env.BLOCK_IN,
-            wl.out_filter//env.BLOCK_OUT, env.BLOCK_OUT,
-            wl.hkernel, wl.wkernel).transpose((2, 0, 4, 5, 3, 1))
+            wl.in_filter // env.BLOCK_IN,
+            env.BLOCK_IN,
+            wl.out_filter // env.BLOCK_OUT,
+            env.BLOCK_OUT,
+            wl.hkernel,
+            wl.wkernel,
+        ).transpose((2, 0, 4, 5, 3, 1))
         kernel_np = np.flip(kernel_np, 2)
         kernel_np = np.flip(kernel_np, 3)
 
     # Build
     if "vta" in target.keys:
-        mod = vta.build(s, [data, kernel, res],
-                        target=target,
-                        target_host=env.target_host,
-                        name="conv2d_transpose")
+        mod = vta.build(
+            s,
+            [data, kernel, res],
+            target=target,
+            target_host=env.target_host,
+            name="conv2d_transpose",
+        )
     else:
-        mod = tvm.build(s, [data, kernel, res],
-                        target=target,
-                        target_host=env.target_host,
-                        name="conv2d_transpose")
+        mod = tvm.build(
+            s,
+            [data, kernel, res],
+            target=target,
+            target_host=env.target_host,
+            name="conv2d_transpose",
+        )
     temp = util.tempdir()
     mod.save(temp.relpath("conv2d_transpose.o"))
     remote.upload(temp.relpath("conv2d_transpose.o"))
@@ -204,8 +261,9 @@ def run_conv2d_transpose(env, remote, wl, target,
     if check_correctness:
         res_orig = res_arr.asnumpy()
         if data_pack:
-            res_orig = res_orig.transpose(
-                (0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, fout_height, fout_width)
+            res_orig = res_orig.transpose((0, 4, 1, 5, 2, 3)).reshape(
+                wl.batch, wl.out_filter, fout_height, fout_width
+            )
         res_ref = res_ref >> env.WGT_WIDTH
         res_ref = np.clip(res_ref, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res_ref = res_ref.astype(env.out_dtype)
@@ -221,6 +279,7 @@ def run_conv2d_transpose(env, remote, wl, target,
 
     return correct, cost, stats
 
+
 @pytest.mark.parametrize("device", ["vta", "arm_cpu"])
 def test_conv2d_transpose(device):
     def _run(env, remote):
@@ -232,12 +291,14 @@ def test_conv2d_transpose(device):
                 reconfig_runtime(remote)
         elif device == "arm_cpu":
             target = env.target_vta_cpu
-        with autotvm.tophub.context(target): # load pre-tuned schedule parameters
+        with autotvm.tophub.context(target):  # load pre-tuned schedule parameters
             for _, wl in dcgan_wklds:
                 print(wl)
                 run_conv2d_transpose(env, remote, wl, target)
+
     vta.testing.run(_run)
 
+
 if __name__ == "__main__":
     test_conv2d_transpose(device="arm_cpu")
     test_conv2d_transpose(device="vta")
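
The conv2d_transpose hunks above only change layout, not logic: fout_height and fout_width still follow the standard transposed-convolution output-size formula, merely wrapped across lines now. A minimal sketch (illustrative only; the helper name is hypothetical) that checks the formula against the DCGAN workloads listed above:

def deconv_out_size(size, stride, pad, kernel, out_pad):
    # same expression as fout_height / fout_width in run_conv2d_transpose
    return (size - 1) * stride - 2 * pad + kernel + out_pad

# DCGAN.CT1: 4x4 input, 4x4 kernel, pad 1, stride 2, no output padding -> 8x8
assert deconv_out_size(4, 2, 1, 4, 0) == 8
# DCGAN.CT2 and CT3 likewise double the spatial size: 8 -> 16, 16 -> 32
assert deconv_out_size(8, 2, 1, 4, 0) == 16
assert deconv_out_size(16, 2, 1, 4, 0) == 32
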
index 3affb36..37cfac1 100644 (file)
@@ -45,10 +45,18 @@ def my_clip(x, a_min, a_max):
     x = te.compute(x.shape, lambda *i: tvm.te.max(x(*i), const_min), name="clipB")
     return x
 
-def run_gemm(env, remote, target,
-             batch_size, in_feat, out_feat,
-             check_correctness=True, print_ir=True,
-             samples=4):
+
+def run_gemm(
+    env,
+    remote,
+    target,
+    batch_size,
+    in_feat,
+    out_feat,
+    check_correctness=True,
+    print_ir=True,
+    samples=4,
+):
 
     # Perform packing only if we are targeting the accelerator
     if "arm_cpu" in target.keys:
@@ -60,10 +68,13 @@ def run_gemm(env, remote, target,
     a_shape = (batch_size, in_feat)
     w_shape = (out_feat, in_feat)
     if data_pack:
-        data_shape = (batch_size//env.BATCH, in_feat//env.BLOCK_IN,
-                      env.BATCH, env.BLOCK_IN)
-        kernel_shape = (out_feat//env.BLOCK_OUT, in_feat//env.BLOCK_IN,
-                        env.BLOCK_OUT, env.BLOCK_IN)
+        data_shape = (batch_size // env.BATCH, in_feat // env.BLOCK_IN, env.BATCH, env.BLOCK_IN)
+        kernel_shape = (
+            out_feat // env.BLOCK_OUT,
+            in_feat // env.BLOCK_IN,
+            env.BLOCK_OUT,
+            env.BLOCK_IN,
+        )
         fcompute = vta.top.dense_packed
         fschedule = vta.top.schedule_dense_packed
     else:
@@ -76,8 +87,7 @@ def run_gemm(env, remote, target,
 
     # Define base computation schedule
     with target:
-        res = fcompute(
-            data, kernel, None, env.acc_dtype)
+        res = fcompute(data, kernel, None, env.acc_dtype)
         res = topi.right_shift(res, 8)
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res = topi.cast(res, env.out_dtype)
@@ -97,30 +107,30 @@ def run_gemm(env, remote, target,
         a_np = np.random.randint(a_min, a_max, size=a_shape).astype(data.dtype)
         w_np = np.random.randint(w_min, w_max, size=w_shape).astype(kernel.dtype)
 
-        r_np = np.dot(a_np.astype(env.acc_dtype), w_np.T.astype(env.acc_dtype)).astype(env.acc_dtype)
+        r_np = np.dot(a_np.astype(env.acc_dtype), w_np.T.astype(env.acc_dtype)).astype(
+            env.acc_dtype
+        )
         return a_np, w_np, r_np
 
     # Data in original format
     data_np, kernel_np, res_ref = get_ref_data()
     if data_pack:
         data_np = data_np.reshape(
-            batch_size//env.BATCH, env.BATCH,
-            in_feat//env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
+            batch_size // env.BATCH, env.BATCH, in_feat // env.BLOCK_IN, env.BLOCK_IN
+        ).transpose((0, 2, 1, 3))
         kernel_np = kernel_np.reshape(
-            out_feat//env.BLOCK_OUT, env.BLOCK_OUT,
-            in_feat//env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
+            out_feat // env.BLOCK_OUT, env.BLOCK_OUT, in_feat // env.BLOCK_IN, env.BLOCK_IN
+        ).transpose((0, 2, 1, 3))
 
     # Build
     if "vta" in target.keys:
-        mod = vta.build(s, [data, kernel, res],
-                        target=target,
-                        target_host=env.target_host,
-                        name="dense")
+        mod = vta.build(
+            s, [data, kernel, res], target=target, target_host=env.target_host, name="dense"
+        )
     else:
-        mod = tvm.build(s, [data, kernel, res],
-                        target=target,
-                        target_host=env.target_host,
-                        name="dense")
+        mod = tvm.build(
+            s, [data, kernel, res], target=target, target_host=env.target_host, name="dense"
+        )
     temp = util.tempdir()
     mod.save(temp.relpath("dense.o"))
     remote.upload(temp.relpath("dense.o"))
@@ -178,6 +188,7 @@ def run_gemm(env, remote, target,
 
     return correct, cost, stats
 
+
 def test_gemm(device="vta", batch=128, in_feat=128, out_feat=128):
     def _run(env, remote):
         if device == "vta":
@@ -188,9 +199,11 @@ def test_gemm(device="vta", batch=128, in_feat=128, out_feat=128):
                 reconfig_runtime(remote)
         elif device == "arm_cpu":
             target = env.target_vta_cpu
-        with autotvm.tophub.context(target): # load pre-tuned schedule parameters
+        with autotvm.tophub.context(target):  # load pre-tuned schedule parameters
             run_gemm(env, remote, target, batch, in_feat, out_feat)
+
     vta.testing.run(_run)
 
+
 if __name__ == "__main__":
     test_gemm("vta", 16, 512, 1008)
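
The dense test's packing step is easiest to follow with concrete shapes. A small sketch reproducing the reshape/transpose used in run_gemm; the BATCH/BLOCK_IN values below are assumptions for illustration, not read from a real VTA configuration, while the matrix sizes match the test_gemm("vta", 16, 512, 1008) call above:

import numpy as np

BATCH, BLOCK_IN = 1, 16            # assumed tensor-core geometry
batch_size, in_feat = 16, 512      # sizes from the test_gemm call above

data_np = np.random.randint(-128, 128, size=(batch_size, in_feat)).astype("int8")
packed = data_np.reshape(
    batch_size // BATCH, BATCH, in_feat // BLOCK_IN, BLOCK_IN
).transpose((0, 2, 1, 3))

# matches the packed data_shape computed in run_gemm
assert packed.shape == (batch_size // BATCH, in_feat // BLOCK_IN, BATCH, BLOCK_IN)
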
index 1fed5a0..08d5e4b 100644 (file)
@@ -37,24 +37,38 @@ import vta.testing
 from vta.testing import simulator
 
 
-Workload = namedtuple("GroupConv2DWorkload",
-                      ['batch', 'height', 'width', 'in_filter', 'out_filter', 'groups',
-                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
+Workload = namedtuple(
+    "GroupConv2DWorkload",
+    [
+        "batch",
+        "height",
+        "width",
+        "in_filter",
+        "out_filter",
+        "groups",
+        "hkernel",
+        "wkernel",
+        "hpad",
+        "wpad",
+        "hstride",
+        "wstride",
+    ],
+)
 
 # Get batch info from env
 env = vta.get_env()
 
 # Mobilenet (grouped variant) workloads
 mobilenet_wkls = [
-    ('mobilenet.D1', Workload(env.BATCH, 112, 112,   32,   32,  2, 3, 3, 1, 1, 1, 1)),
-    ('mobilenet.D2', Workload(env.BATCH, 112, 112,   64,   64,  4, 3, 3, 1, 1, 2, 2)),
-    ('mobilenet.D3', Workload(env.BATCH,  56,  56,  128,  128,  8, 3, 3, 1, 1, 1, 1)),
-    ('mobilenet.D4', Workload(env.BATCH,  56,  56,  128,  128,  8, 3, 3, 1, 1, 2, 2)),
-    ('mobilenet.D5', Workload(env.BATCH,  28,  28,  256,  256, 16, 3, 3, 1, 1, 1, 1)),
-    ('mobilenet.D6', Workload(env.BATCH,  28,  28,  256,  256, 16, 3, 3, 1, 1, 2, 2)),
-    ('mobilenet.D7', Workload(env.BATCH,  14,  14,  512,  512, 32, 3, 3, 1, 1, 1, 1)),
-    ('mobilenet.D8', Workload(env.BATCH,  14,  14,  512,  512, 32, 3, 3, 1, 1, 2, 2)),
-    ('mobilenet.D9', Workload(env.BATCH,   7,  7,  1024, 1024, 64, 3, 3, 1, 1, 1, 1)),
+    ("mobilenet.D1", Workload(env.BATCH, 112, 112, 32, 32, 2, 3, 3, 1, 1, 1, 1)),
+    ("mobilenet.D2", Workload(env.BATCH, 112, 112, 64, 64, 4, 3, 3, 1, 1, 2, 2)),
+    ("mobilenet.D3", Workload(env.BATCH, 56, 56, 128, 128, 8, 3, 3, 1, 1, 1, 1)),
+    ("mobilenet.D4", Workload(env.BATCH, 56, 56, 128, 128, 8, 3, 3, 1, 1, 2, 2)),
+    ("mobilenet.D5", Workload(env.BATCH, 28, 28, 256, 256, 16, 3, 3, 1, 1, 1, 1)),
+    ("mobilenet.D6", Workload(env.BATCH, 28, 28, 256, 256, 16, 3, 3, 1, 1, 2, 2)),
+    ("mobilenet.D7", Workload(env.BATCH, 14, 14, 512, 512, 32, 3, 3, 1, 1, 1, 1)),
+    ("mobilenet.D8", Workload(env.BATCH, 14, 14, 512, 512, 32, 3, 3, 1, 1, 2, 2)),
+    ("mobilenet.D9", Workload(env.BATCH, 7, 7, 1024, 1024, 64, 3, 3, 1, 1, 1, 1)),
 ]
 
 # FIXME: we need a custom clip operator to circumvent a pattern detection limitation
@@ -67,9 +81,8 @@ def my_clip(x, a_min, a_max):
     x = te.compute(x.shape, lambda *i: tvm.te.max(x(*i), const_min), name="clipB")
     return x
 
-def run_group_conv2d(env, remote, wl, target,
-                     check_correctness=True, print_ir=False,
-                     samples=4):
+
+def run_group_conv2d(env, remote, wl, target, check_correctness=True, print_ir=False, samples=4):
 
     # Workload assertions
     assert wl.hpad == wl.wpad
@@ -92,12 +105,30 @@ def run_group_conv2d(env, remote, wl, target,
     w_shape = (wl.out_filter, CI_G, wl.hkernel, wl.wkernel)
     b_shape = (wl.batch, wl.out_filter, 1, 1)
     if data_pack:
-        data_shape = (wl.batch//env.BATCH, wl.in_filter//env.BLOCK_IN,
-                      wl.height, wl.width, env.BATCH, env.BLOCK_IN)
-        kernel_shape = (wl.out_filter//env.BLOCK_OUT, CI_G//env.BLOCK_IN,
-                        wl.hkernel, wl.wkernel, env.BLOCK_OUT, env.BLOCK_IN)
-        bias_shape = (wl.batch//env.BATCH, wl.out_filter//env.BLOCK_OUT,
-                      1, 1, env.BATCH, env.BLOCK_OUT)
+        data_shape = (
+            wl.batch // env.BATCH,
+            wl.in_filter // env.BLOCK_IN,
+            wl.height,
+            wl.width,
+            env.BATCH,
+            env.BLOCK_IN,
+        )
+        kernel_shape = (
+            wl.out_filter // env.BLOCK_OUT,
+            CI_G // env.BLOCK_IN,
+            wl.hkernel,
+            wl.wkernel,
+            env.BLOCK_OUT,
+            env.BLOCK_IN,
+        )
+        bias_shape = (
+            wl.batch // env.BATCH,
+            wl.out_filter // env.BLOCK_OUT,
+            1,
+            1,
+            env.BATCH,
+            env.BLOCK_OUT,
+        )
     else:
         data_shape = a_shape
         kernel_shape = w_shape
@@ -110,8 +141,8 @@ def run_group_conv2d(env, remote, wl, target,
     # Define base computation schedule
     with target:
         res = fcompute(
-            data, kernel, (wl.hstride, wl.wstride), padding, (1, 1),
-            wl.groups, env.acc_dtype)
+            data, kernel, (wl.hstride, wl.wstride), padding, (1, 1), wl.groups, env.acc_dtype
+        )
         res = topi.right_shift(res, 8)
         res = topi.add(res, bias)
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
@@ -124,48 +155,69 @@ def run_group_conv2d(env, remote, wl, target,
     # Derive number of ops
     fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
     fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
-    num_ops = 2 * wl.batch * fout_height * fout_width * wl.hkernel * wl.wkernel * \
-        wl.out_filter * wl.in_filter // wl.groups
+    num_ops = (
+        2
+        * wl.batch
+        * fout_height
+        * fout_width
+        * wl.hkernel
+        * wl.wkernel
+        * wl.out_filter
+        * wl.in_filter
+        // wl.groups
+    )
 
     def get_ref_data():
         # derive min max for act, wgt, and bias types (max non inclusive)
         a_min, a_max = 0 - (1 << (env.INP_WIDTH - 1)), (1 << (env.INP_WIDTH - 1))
         w_min, w_max = 0 - (1 << (env.WGT_WIDTH - 1)), (1 << (env.WGT_WIDTH - 1))
-        b_min, b_max = 0 - 1 << (env.INP_WIDTH + env.WGT_WIDTH - 2), 1 << (env.INP_WIDTH + env.WGT_WIDTH - 2)
+        b_min, b_max = 0 - 1 << (env.INP_WIDTH + env.WGT_WIDTH - 2), 1 << (
+            env.INP_WIDTH + env.WGT_WIDTH - 2
+        )
         a_np = np.random.randint(a_min, a_max, size=a_shape).astype(data.dtype)
         w_np = np.random.randint(w_min, w_max, size=w_shape).astype(kernel.dtype)
         b_np = np.random.randint(b_min, b_max, size=b_shape).astype(env.acc_dtype)
         r_np = tvm.topi.testing.conv2d_nchw_python(
-            a_np.astype(env.acc_dtype), w_np.astype(env.acc_dtype),
-            (wl.hstride, wl.wstride), wl.hpad, wl.groups).astype(env.acc_dtype)
+            a_np.astype(env.acc_dtype),
+            w_np.astype(env.acc_dtype),
+            (wl.hstride, wl.wstride),
+            wl.hpad,
+            wl.groups,
+        ).astype(env.acc_dtype)
         return a_np, w_np, b_np, r_np
 
     # Data in original format
     data_np, kernel_np, bias_np, res_ref = get_ref_data()
     if data_pack:
         data_np = data_np.reshape(
-            wl.batch//env.BATCH, env.BATCH,
-            wl.in_filter//env.BLOCK_IN, env.BLOCK_IN,
-            wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3))
+            wl.batch // env.BATCH,
+            env.BATCH,
+            wl.in_filter // env.BLOCK_IN,
+            env.BLOCK_IN,
+            wl.height,
+            wl.width,
+        ).transpose((0, 2, 4, 5, 1, 3))
         kernel_np = kernel_np.reshape(
-            wl.out_filter//env.BLOCK_OUT, env.BLOCK_OUT,
-            CI_G//env.BLOCK_IN, env.BLOCK_IN,
-            wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
+            wl.out_filter // env.BLOCK_OUT,
+            env.BLOCK_OUT,
+            CI_G // env.BLOCK_IN,
+            env.BLOCK_IN,
+            wl.hkernel,
+            wl.wkernel,
+        ).transpose((0, 2, 4, 5, 1, 3))
         bias_np = bias_np.reshape(
-            wl.batch//env.BATCH, wl.out_filter//env.BLOCK_OUT,
-            1, 1, env.BATCH, env.BLOCK_OUT)
+            wl.batch // env.BATCH, wl.out_filter // env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT
+        )
 
     # Build
     if "vta" in target.keys:
-        mod = vta.build(s, [data, kernel, bias, res],
-                        target=target,
-                        target_host=env.target_host,
-                        name="conv2d")
+        mod = vta.build(
+            s, [data, kernel, bias, res], target=target, target_host=env.target_host, name="conv2d"
+        )
     else:
-        mod = tvm.build(s, [data, kernel, bias, res],
-                        target=target,
-                        target_host=env.target_host,
-                        name="conv2d")
+        mod = tvm.build(
+            s, [data, kernel, bias, res], target=target, target_host=env.target_host, name="conv2d"
+        )
     temp = util.tempdir()
     mod.save(temp.relpath("conv2d.o"))
     remote.upload(temp.relpath("conv2d.o"))
@@ -208,10 +260,10 @@ def run_group_conv2d(env, remote, wl, target,
     if check_correctness:
         res_orig = res_arr.asnumpy()
         if data_pack:
-            res_orig = res_orig.transpose(
-                (0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, fout_height, fout_width)
-            bias_np = bias_np.transpose(
-                (0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, 1, 1)
+            res_orig = res_orig.transpose((0, 4, 1, 5, 2, 3)).reshape(
+                wl.batch, wl.out_filter, fout_height, fout_width
+            )
+            bias_np = bias_np.transpose((0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, 1, 1)
         res_ref = res_ref >> env.WGT_WIDTH
         res_ref += bias_np
         res_ref = np.clip(res_ref, 0, (1 << env.OUT_WIDTH - 1) - 1)
@@ -224,10 +276,14 @@ def run_group_conv2d(env, remote, wl, target,
         device = "CPU"
     elif "vta" in target.keys:
         device = "VTA"
-    print("%s GROUP CONV2D TEST %s: Time cost = %g sec/op, %g GOPS" % (device, status, cost.mean, gops))
+    print(
+        "%s GROUP CONV2D TEST %s: Time cost = %g sec/op, %g GOPS"
+        % (device, status, cost.mean, gops)
+    )
 
     return correct, cost, stats
 
+
 @pytest.mark.parametrize("device", ["vta", "arm_cpu"])
 def test_conv2d(device):
     def _run(env, remote):
@@ -239,12 +295,14 @@ def test_conv2d(device):
                 reconfig_runtime(remote)
         elif device == "arm_cpu":
             target = env.target_vta_cpu
-        with autotvm.tophub.context(target): # load pre-tuned schedule parameters
+        with autotvm.tophub.context(target):  # load pre-tuned schedule parameters
             for _, wl in mobilenet_wkls:
                 print(wl)
                 run_group_conv2d(env, remote, wl, target)
+
     vta.testing.run(_run)
 
+
 if __name__ == "__main__":
     test_conv2d(device="arm_cpu")
     test_conv2d(device="vta")
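
The grouped-conv op count that feeds the GOPS report is likewise unchanged, just reflowed. A worked example for the mobilenet.D1 workload above (env.BATCH is assumed to be 1 here):

batch, height, width = 1, 112, 112
in_filter, out_filter, groups = 32, 32, 2
hkernel, wkernel, pad, stride = 3, 3, 1, 1

fout_height = (height + 2 * pad - hkernel) // stride + 1   # 112
fout_width = (width + 2 * pad - wkernel) // stride + 1     # 112
num_ops = (
    2 * batch * fout_height * fout_width
    * hkernel * wkernel * out_filter * in_filter // groups
)
print(num_ops / 1e9)   # ~0.116 GOPs for one pass over this layer
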
index ad6e43e..5e47153 100644 (file)
@@ -23,6 +23,7 @@ from vta import get_bitstream_path, download_bitstream, program_fpga, reconfig_r
 host = os.environ.get("VTA_RPC_HOST", "pynq")
 port = int(os.environ.get("VTA_RPC_PORT", "9091"))
 
+
 def program_rpc_bitstream(path=None):
     """Program the FPGA on the RPC server
 
@@ -34,12 +35,13 @@ def program_rpc_bitstream(path=None):
     remote = rpc.connect(host, port)
     program_fpga(remote, path)
 
+
 def reconfig_rpc_runtime():
-    """Reconfig the RPC server runtime
-    """
+    """Reconfig the RPC server runtime"""
     assert tvm.runtime.enabled("rpc")
     remote = rpc.connect(host, port)
     reconfig_runtime(remote)
 
+
 program_rpc_bitstream()
 reconfig_rpc_runtime()
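
For completeness, a hedged sketch of what this small script does when imported: connect to the board named by VTA_RPC_HOST/VTA_RPC_PORT and reprogram it. The defaults below are placeholders, and running the sketch requires a reachable FPGA board:

import os
from tvm import rpc
from vta import program_fpga, reconfig_runtime

host = os.environ.get("VTA_RPC_HOST", "pynq")        # placeholder board name
port = int(os.environ.get("VTA_RPC_PORT", "9091"))

remote = rpc.connect(host, port)
program_fpga(remote, None)    # None selects the default bitstream, as in the script above
reconfig_runtime(remote)      # sync the runtime with the current VTA config
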
index 605a9e0..61219b6 100644 (file)
@@ -22,6 +22,7 @@ def test_env():
     mock = env.mock
     assert mock.alu == "skip_alu"
 
+
 def test_env_scope():
     env = vta.get_env()
     cfg = env.cfg_dict
index be347a0..fb0acf1 100644 (file)
@@ -25,26 +25,21 @@ import vta
 import vta.testing
 from vta.testing import simulator
 
-np.random.seed(0xdeadb)
+np.random.seed(0xDEADB)
+
 
 def test_save_load_out():
     """Test save/store output command"""
+
     def _run(env, remote):
         n = 6
-        x = te.placeholder(
-            (n, n, env.BATCH, env.BLOCK_OUT),
-            name="x",
-            dtype=env.acc_dtype)
-        x_buf = te.compute(
-            (n, n, env.BATCH, env.BLOCK_OUT),
-            lambda *i: x(*i), "x_buf")
+        x = te.placeholder((n, n, env.BATCH, env.BLOCK_OUT), name="x", dtype=env.acc_dtype)
+        x_buf = te.compute((n, n, env.BATCH, env.BLOCK_OUT), lambda *i: x(*i), "x_buf")
         # insert no-op that won't be optimized away
-        y_buf = te.compute(
-            (n, n, env.BATCH, env.BLOCK_OUT),
-            lambda *i: x_buf(*i)>>0, "y_buf")
+        y_buf = te.compute((n, n, env.BATCH, env.BLOCK_OUT), lambda *i: x_buf(*i) >> 0, "y_buf")
         y = te.compute(
-            (n, n, env.BATCH, env.BLOCK_OUT),
-            lambda *i: y_buf(*i).astype(env.inp_dtype), "y")
+            (n, n, env.BATCH, env.BLOCK_OUT), lambda *i: y_buf(*i).astype(env.inp_dtype), "y"
+        )
         # schedule
         s = te.create_schedule(y.op)
         s[x_buf].set_scope(env.acc_scope)
@@ -65,8 +60,7 @@ def test_save_load_out():
         f = remote.load_module("load_act.o")
         # verify
         ctx = remote.ext_dev(0)
-        x_np = np.random.randint(
-            1, 10, size=(n, n, env.BATCH, env.BLOCK_OUT)).astype(x.dtype)
+        x_np = np.random.randint(1, 10, size=(n, n, env.BATCH, env.BLOCK_OUT)).astype(x.dtype)
         y_np = x_np.astype(y.dtype)
         x_nd = tvm.nd.array(x_np, ctx)
         y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
@@ -89,25 +83,35 @@ def test_save_load_out():
 
 def test_padded_load():
     """Test padded load."""
+
     def _run(env, remote):
         def check_padded_load(pad_before, pad_after, test_name=None):
             # declare
             n = 3
             m = 5
-            x = te.placeholder(
-                (n, m, env.BATCH, env.BLOCK_OUT),
-                name="x",
-                dtype=env.acc_dtype)
+            x = te.placeholder((n, m, env.BATCH, env.BLOCK_OUT), name="x", dtype=env.acc_dtype)
             x_buf = topi.nn.pad(x, pad_before, pad_after, name="y")
             # insert no-op that won't be optimized away
-            y_buf = te.compute((n + pad_before[0] + pad_after[0],
-                                 m + pad_before[1] + pad_after[1],
-                                 env.BATCH,
-                                 env.BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf")
-            y = te.compute((n + pad_before[0] + pad_after[0],
-                             m + pad_before[1] + pad_after[1],
-                             env.BATCH,
-                             env.BLOCK_OUT), lambda *i: y_buf(*i).astype(env.inp_dtype), "y")
+            y_buf = te.compute(
+                (
+                    n + pad_before[0] + pad_after[0],
+                    m + pad_before[1] + pad_after[1],
+                    env.BATCH,
+                    env.BLOCK_OUT,
+                ),
+                lambda *i: x_buf(*i) >> 0,
+                "y_buf",
+            )
+            y = te.compute(
+                (
+                    n + pad_before[0] + pad_after[0],
+                    m + pad_before[1] + pad_after[1],
+                    env.BATCH,
+                    env.BLOCK_OUT,
+                ),
+                lambda *i: y_buf(*i).astype(env.inp_dtype),
+                "y",
+            )
             # schedule
             s = te.create_schedule(y.op)
             s[x_buf].set_scope(env.acc_scope)
@@ -127,15 +131,16 @@ def test_padded_load():
             f = remote.load_module("padded_load.o")
             # verify
             ctx = remote.ext_dev(0)
-            x_np = np.random.randint(0, 10, size=(
-                n, m, env.BATCH, env.BLOCK_OUT)).astype(x.dtype)
-            y_np = np.zeros((n + pad_before[0] + pad_after[0],
-                             m + pad_before[1] + pad_after[1],
-                             env.BATCH,
-                             env.BLOCK_OUT)).astype(y.dtype)
-            y_np[pad_before[0]:pad_before[0] + n,
-                 pad_before[1]:pad_before[1] + m,
-                 :] = x_np
+            x_np = np.random.randint(0, 10, size=(n, m, env.BATCH, env.BLOCK_OUT)).astype(x.dtype)
+            y_np = np.zeros(
+                (
+                    n + pad_before[0] + pad_after[0],
+                    m + pad_before[1] + pad_after[1],
+                    env.BATCH,
+                    env.BLOCK_OUT,
+                )
+            ).astype(y.dtype)
+            y_np[pad_before[0] : pad_before[0] + n, pad_before[1] : pad_before[1] + m, :] = x_np
             x_nd = tvm.nd.array(x_np, ctx)
             y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
 
@@ -163,6 +168,7 @@ def test_padded_load():
 
 def test_gemm():
     """Test GEMM."""
+
     def _run(env, remote):
         # declare
         o = 4
@@ -176,27 +182,27 @@ def test_gemm():
         ki = te.reduce_axis((0, env.BLOCK_IN), name="ki")
         y_gem = te.compute(
             (o, m, env.BATCH, env.BLOCK_OUT),
-            lambda bo, co, bi, ci:
-            te.sum(x_buf[bo, ko, bi, ki].astype(env.acc_dtype) *
-                    w_buf[co, ko, ci, ki].astype(env.acc_dtype),
-                    axis=[ko, ki]),
-            name="y_gem")
+            lambda bo, co, bi, ci: te.sum(
+                x_buf[bo, ko, bi, ki].astype(env.acc_dtype)
+                * w_buf[co, ko, ci, ki].astype(env.acc_dtype),
+                axis=[ko, ki],
+            ),
+            name="y_gem",
+        )
         y_shf = te.compute(
-            (o, m, env.BATCH, env.BLOCK_OUT),
-            lambda *i: y_gem(*i)>>8,
-            name="y_shf")
+            (o, m, env.BATCH, env.BLOCK_OUT), lambda *i: y_gem(*i) >> 8, name="y_shf"
+        )
         y_max = te.compute(
-            (o, m, env.BATCH, env.BLOCK_OUT),
-            lambda *i: tvm.te.max(y_shf(*i), 0),
-            "y_max") #relu
+            (o, m, env.BATCH, env.BLOCK_OUT), lambda *i: tvm.te.max(y_shf(*i), 0), "y_max"
+        )  # relu
         y_min = te.compute(
             (o, m, env.BATCH, env.BLOCK_OUT),
-            lambda *i: tvm.te.min(y_max(*i), (1<<(env.INP_WIDTH-1))-1),
-            "y_min") #relu
+            lambda *i: tvm.te.min(y_max(*i), (1 << (env.INP_WIDTH - 1)) - 1),
+            "y_min",
+        )  # relu
         y = te.compute(
-            (o, m, env.BATCH, env.BLOCK_OUT),
-            lambda *i: y_min(*i).astype(env.inp_dtype),
-            name="y")
+            (o, m, env.BATCH, env.BLOCK_OUT), lambda *i: y_min(*i).astype(env.inp_dtype), name="y"
+        )
 
         if not remote:
             return
@@ -209,10 +215,12 @@ def test_gemm():
             f = remote.load_module("gemm.o")
             # verify
             ctx = remote.ext_dev(0)
-            x_np = np.random.randint(
-                -128, 128, size=(o, n, env.BATCH, env.BLOCK_IN)).astype(x.dtype)
-            w_np = np.random.randint(
-                -128, 128, size=(m, n, env.BLOCK_OUT, env.BLOCK_IN)).astype(w.dtype)
+            x_np = np.random.randint(-128, 128, size=(o, n, env.BATCH, env.BLOCK_IN)).astype(
+                x.dtype
+            )
+            w_np = np.random.randint(-128, 128, size=(m, n, env.BLOCK_OUT, env.BLOCK_IN)).astype(
+                w.dtype
+            )
             y_np = np.zeros((o, m, env.BATCH, env.BLOCK_OUT)).astype(y.dtype)
             x_nd = tvm.nd.array(x_np, ctx)
             w_nd = tvm.nd.array(w_np, ctx)
@@ -221,10 +229,11 @@ def test_gemm():
             for b in range(o):
                 for i in range(m):
                     for j in range(n):
-                        y_np[b,i,:] += np.dot(x_np[b,j,:].astype(env.acc_dtype),
-                                              w_np[i,j].T.astype(env.acc_dtype))
+                        y_np[b, i, :] += np.dot(
+                            x_np[b, j, :].astype(env.acc_dtype), w_np[i, j].T.astype(env.acc_dtype)
+                        )
             y_np = np.right_shift(y_np, 8)
-            y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype)
+            y_np = np.clip(y_np, 0, (1 << (env.INP_WIDTH - 1)) - 1).astype(y.dtype)
 
             if env.TARGET in ["sim", "tsim"]:
                 simulator.clear_stats()
@@ -265,7 +274,8 @@ def test_gemm():
                 s[y_gem].op.axis[1],
                 s[y_gem].op.axis[2],
                 s[y_gem].op.axis[3],
-                ki)
+                ki,
+            )
             s[y_gem].tensorize(s[y_gem].op.axis[2], env.gemm)
             verify(s, name="default")
 
@@ -291,7 +301,8 @@ def test_gemm():
                 s[y_gem].op.axis[1],
                 s[y_gem].op.axis[2],
                 s[y_gem].op.axis[3],
-                ki)
+                ki,
+            )
             s[y_gem].tensorize(s[y_gem].op.axis[2], env.gemm)
             s[y_shf].pragma(s[y_shf].op.axis[0], env.alu)
             s[y_max].pragma(s[y_max].op.axis[0], env.alu)
@@ -305,6 +316,7 @@ def test_gemm():
 
         test_schedule1()
         test_smt()
+
     vta.testing.run(_run)
 
 
@@ -314,48 +326,41 @@ def test_alu():
             """Test ALU"""
             m = 8
             n = 8
-            imm = np.random.randint(1,5)
+            imm = np.random.randint(1, 5)
             # compute
-            a = te.placeholder(
-                (m, n, env.BATCH, env.BLOCK_OUT),
-                name="a",
-                dtype=env.acc_dtype)
+            a = te.placeholder((m, n, env.BATCH, env.BLOCK_OUT), name="a", dtype=env.acc_dtype)
             a_buf = te.compute(
-                (m, n, env.BATCH, env.BLOCK_OUT),
-                lambda *i: a(*i),
-                "a_buf") #DRAM->SRAM
+                (m, n, env.BATCH, env.BLOCK_OUT), lambda *i: a(*i), "a_buf"
+            )  # DRAM->SRAM
             if use_imm:
                 res_buf = te.compute(
-                    (m, n, env.BATCH, env.BLOCK_OUT),
-                    lambda *i: tvm_op(a_buf(*i), imm),
-                    "res_buf") #compute
+                    (m, n, env.BATCH, env.BLOCK_OUT), lambda *i: tvm_op(a_buf(*i), imm), "res_buf"
+                )  # compute
             else:
-                b = te.placeholder(
-                    (m, n, env.BATCH, env.BLOCK_OUT),
-                    name="b",
-                    dtype=env.acc_dtype)
+                b = te.placeholder((m, n, env.BATCH, env.BLOCK_OUT), name="b", dtype=env.acc_dtype)
                 b_buf = te.compute(
-                    (m, n, env.BATCH, env.BLOCK_OUT),
-                    lambda *i: b(*i),
-                    "b_buf") #DRAM->SRAM
+                    (m, n, env.BATCH, env.BLOCK_OUT), lambda *i: b(*i), "b_buf"
+                )  # DRAM->SRAM
                 res_buf = te.compute(
                     (m, n, env.BATCH, env.BLOCK_OUT),
                     lambda *i: tvm_op(a_buf(*i), b_buf(*i)),
-                    "res_buf") #compute5B
+                    "res_buf",
+                )  # compute5B
             res = te.compute(
                 (m, n, env.BATCH, env.BLOCK_OUT),
                 lambda *i: res_buf(*i).astype(env.inp_dtype),
-                "res") #SRAM->DRAM
+                "res",
+            )  # SRAM->DRAM
             # schedule
             s = te.create_schedule(res.op)
-            s[a_buf].set_scope(env.acc_scope) # SRAM
-            s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
-            s[res_buf].set_scope(env.acc_scope) # SRAM
-            s[res_buf].pragma(res_buf.op.axis[0], env.alu) # compute
-            s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
+            s[a_buf].set_scope(env.acc_scope)  # SRAM
+            s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy)  # DRAM->SRAM
+            s[res_buf].set_scope(env.acc_scope)  # SRAM
+            s[res_buf].pragma(res_buf.op.axis[0], env.alu)  # compute
+            s[res].pragma(res.op.axis[0], env.dma_copy)  # SRAM->DRAM
             if not use_imm:
-                s[b_buf].set_scope(env.acc_scope) # SRAM
-                s[b_buf].pragma(b_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
+                s[b_buf].set_scope(env.acc_scope)  # SRAM
+                s[b_buf].pragma(b_buf.op.axis[0], env.dma_copy)  # DRAM->SRAM
 
             if not remote:
                 return
@@ -372,18 +377,17 @@ def test_alu():
             f = remote.load_module("load_act.o")
             # verify
             ctx = remote.ext_dev(0)
-            a_np = np.random.randint(
-                -16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
+            a_np = np.random.randint(-16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
             if use_imm:
                 res_np = np_op(a_np, imm) if np_op else tvm_op(a_np, imm)
             else:
-                b_np = np.random.randint(
-                    -16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(b.dtype)
+                b_np = np.random.randint(-16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(
+                    b.dtype
+                )
                 res_np = np_op(a_np, b_np) if np_op else tvm_op(a_np, b_np)
             res_np = res_np.astype(res.dtype)
             a_nd = tvm.nd.array(a_np, ctx)
-            res_nd = tvm.nd.array(
-                np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
+            res_nd = tvm.nd.array(np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
 
             if env.TARGET in ["sim", "tsim"]:
                 simulator.clear_stats()
@@ -414,39 +418,37 @@ def test_alu():
 
 def test_relu():
     """Test RELU on ALU"""
+
     def _run(env, remote):
         m = 8
         n = 10
         # compute
-        a = te.placeholder(
-            (m, n, env.BATCH, env.BLOCK_OUT),
-            name="a",
-            dtype=env.acc_dtype)
+        a = te.placeholder((m, n, env.BATCH, env.BLOCK_OUT), name="a", dtype=env.acc_dtype)
         a_buf = te.compute(
-            (m, n, env.BATCH, env.BLOCK_OUT),
-            lambda *i: a(*i),
-            "a_buf") # DRAM->SRAM
+            (m, n, env.BATCH, env.BLOCK_OUT), lambda *i: a(*i), "a_buf"
+        )  # DRAM->SRAM
         max_buf = te.compute(
-            (m, n, env.BATCH, env.BLOCK_OUT),
-            lambda *i: tvm.te.max(a_buf(*i), 0),
-            "res_buf") # relu
+            (m, n, env.BATCH, env.BLOCK_OUT), lambda *i: tvm.te.max(a_buf(*i), 0), "res_buf"
+        )  # relu
         min_buf = te.compute(
             (m, n, env.BATCH, env.BLOCK_OUT),
-            lambda *i: tvm.te.min(max_buf(*i), (1<<(env.INP_WIDTH-1))-1),
-            "max_buf") # relu
+            lambda *i: tvm.te.min(max_buf(*i), (1 << (env.INP_WIDTH - 1)) - 1),
+            "max_buf",
+        )  # relu
         res = te.compute(
             (m, n, env.BATCH, env.BLOCK_OUT),
             lambda *i: min_buf(*i).astype(env.inp_dtype),
-            "min_buf") # SRAM->DRAM
+            "min_buf",
+        )  # SRAM->DRAM
         # schedule
         s = te.create_schedule(res.op)
-        s[a_buf].set_scope(env.acc_scope) # SRAM
-        s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
-        s[max_buf].set_scope(env.acc_scope) # SRAM
-        s[min_buf].set_scope(env.acc_scope) # SRAM
-        s[max_buf].pragma(max_buf.op.axis[0], env.alu) # compute
-        s[min_buf].pragma(min_buf.op.axis[0], env.alu) # compute
-        s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
+        s[a_buf].set_scope(env.acc_scope)  # SRAM
+        s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy)  # DRAM->SRAM
+        s[max_buf].set_scope(env.acc_scope)  # SRAM
+        s[min_buf].set_scope(env.acc_scope)  # SRAM
+        s[max_buf].pragma(max_buf.op.axis[0], env.alu)  # compute
+        s[min_buf].pragma(min_buf.op.axis[0], env.alu)  # compute
+        s[res].pragma(res.op.axis[0], env.dma_copy)  # SRAM->DRAM
         # build
         with vta.build_config():
             mod = vta.build(s, [a, res], "ext_dev", env.target_host)
@@ -458,12 +460,10 @@ def test_relu():
         f = remote.load_module("load_act.o")
         # verify
         ctx = remote.ext_dev(0)
-        a_np = np.random.randint(
-            -256, 256, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
-        res_np = np.clip(a_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(res.dtype)
+        a_np = np.random.randint(-256, 256, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
+        res_np = np.clip(a_np, 0, (1 << (env.INP_WIDTH - 1)) - 1).astype(res.dtype)
         a_nd = tvm.nd.array(a_np, ctx)
-        res_nd = tvm.nd.array(
-            np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
+        res_nd = tvm.nd.array(np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
 
         if env.TARGET in ["sim", "tsim"]:
             simulator.clear_stats()
@@ -483,40 +483,35 @@ def test_relu():
 
 def test_shift_and_scale():
     """Test shift and scale on ALU"""
+
     def _run(env, remote):
         m = 2
         n = 8
-        imm_shift = np.random.randint(0,8)
-        imm_scale = np.random.randint(1,5)
+        imm_shift = np.random.randint(0, 8)
+        imm_scale = np.random.randint(1, 5)
         # compute
-        a = te.placeholder(
-            (m, n, env.BATCH, env.BLOCK_OUT),
-            name="a", dtype=env.acc_dtype)
+        a = te.placeholder((m, n, env.BATCH, env.BLOCK_OUT), name="a", dtype=env.acc_dtype)
         a_buf = te.compute(
-            (m, n, env.BATCH, env.BLOCK_OUT),
-            lambda *i: a(*i),
-            "a_buf") # DRAM->SRAM
+            (m, n, env.BATCH, env.BLOCK_OUT), lambda *i: a(*i), "a_buf"
+        )  # DRAM->SRAM
         res_shift = te.compute(
-            (m, n, env.BATCH, env.BLOCK_OUT),
-            lambda *i: a_buf(*i)+imm_shift,
-            "res_shift") # compute
+            (m, n, env.BATCH, env.BLOCK_OUT), lambda *i: a_buf(*i) + imm_shift, "res_shift"
+        )  # compute
         res_scale = te.compute(
-            (m, n, env.BATCH, env.BLOCK_OUT),
-            lambda *i: res_shift(*i)>>imm_scale,
-            "res_scale") # compute
+            (m, n, env.BATCH, env.BLOCK_OUT), lambda *i: res_shift(*i) >> imm_scale, "res_scale"
+        )  # compute
         res = te.compute(
-            (m, n, env.BATCH, env.BLOCK_OUT),
-            lambda *i: res_scale(*i).astype(env.inp_dtype),
-            "res") # SRAM->DRAM
+            (m, n, env.BATCH, env.BLOCK_OUT), lambda *i: res_scale(*i).astype(env.inp_dtype), "res"
+        )  # SRAM->DRAM
         # schedule
         s = te.create_schedule(res.op)
-        s[a_buf].set_scope(env.acc_scope) # SRAM
-        s[res_shift].set_scope(env.acc_scope) # SRAM
-        s[res_scale].set_scope(env.acc_scope) # SRAM
-        s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
-        s[res_shift].pragma(res_shift.op.axis[0], env.alu) # compute
-        s[res_scale].pragma(res_scale.op.axis[0], env.alu) # compute
-        s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
+        s[a_buf].set_scope(env.acc_scope)  # SRAM
+        s[res_shift].set_scope(env.acc_scope)  # SRAM
+        s[res_scale].set_scope(env.acc_scope)  # SRAM
+        s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy)  # DRAM->SRAM
+        s[res_shift].pragma(res_shift.op.axis[0], env.alu)  # compute
+        s[res_scale].pragma(res_scale.op.axis[0], env.alu)  # compute
+        s[res].pragma(res.op.axis[0], env.dma_copy)  # SRAM->DRAM
         # build
         mod = vta.build(s, [a, res], "ext_dev", env.target_host)
         if not remote:
@@ -527,13 +522,11 @@ def test_shift_and_scale():
         f = remote.load_module("load_act.o")
         # verify
         ctx = remote.ext_dev(0)
-        a_np = np.random.randint(
-            -10, 10, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
+        a_np = np.random.randint(-10, 10, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
         res_np = np.right_shift((a_np + imm_shift), imm_scale)
         res_np = res_np.astype(res.dtype)
         a_nd = tvm.nd.array(a_np, ctx)
-        res_nd = tvm.nd.array(
-            np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
+        res_nd = tvm.nd.array(np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
 
         if env.TARGET in ["sim", "tsim"]:
             simulator.clear_stats()
@@ -555,8 +548,7 @@ def test_runtime_array():
     def _run(env, remote):
         n = 100
         ctx = remote.ext_dev(0)
-        x_np = np.random.randint(
-            1, 10, size=(n, n, env.BATCH, env.BLOCK_OUT)).astype("int8")
+        x_np = np.random.randint(1, 10, size=(n, n, env.BATCH, env.BLOCK_OUT)).astype("int8")
         x_nd = tvm.nd.array(x_np, ctx)
         np.testing.assert_equal(x_np, x_nd.asnumpy())
 
index 1184006..d1a2e85 100644 (file)
@@ -79,7 +79,7 @@ from vta.top import graph_pack
 def compile_network(env, target, model, start_pack, stop_pack):
 
     # Populate the shape and data type dictionary
-    dtype_dict = {"data": 'float32'}
+    dtype_dict = {"data": "float32"}
     shape_dict = {"data": (env.BATCH, 3, 224, 224)}
 
     # Get off the shelf gluon model, and convert to relay
@@ -99,12 +99,14 @@ def compile_network(env, target, model, start_pack, stop_pack):
     # Perform graph packing and constant folding for VTA target
     if target.device_name == "vta":
         assert env.BLOCK_IN == env.BLOCK_OUT
-        relay_prog = graph_pack(mod["main"],
-                                env.BATCH,
-                                env.BLOCK_OUT,
-                                env.WGT_WIDTH,
-                                start_name=start_pack,
-                                stop_name=stop_pack)
+        relay_prog = graph_pack(
+            mod["main"],
+            env.BATCH,
+            env.BLOCK_OUT,
+            env.WGT_WIDTH,
+            start_name=start_pack,
+            stop_name=stop_pack,
+        )
 
     return relay_prog, params
 
@@ -178,7 +180,7 @@ def compile_network(env, target, model, start_pack, stop_pack):
 # Here we use an Pynq-Z1 board as an example.
 
 # Tracker host and port can be set by your environment
-tracker_host = os.environ.get("TVM_TRACKER_HOST", '0.0.0.0')
+tracker_host = os.environ.get("TVM_TRACKER_HOST", "0.0.0.0")
 tracker_port = int(os.environ.get("TVM_TRACKER_PORT", 9190))
 
 # Load VTA parameters from the 3rdparty/vta-hw/config/vta_config.json file
@@ -201,20 +203,20 @@ stop_pack = "nn.global_avg_pool2d"
 # Tuning option
 log_file = "%s.%s.log" % (device, network)
 tuning_option = {
-    'log_filename': log_file,
-
-    'tuner': 'random',
-    'n_trial': 1000,
-    'early_stopping': None,
-
-    'measure_option': autotvm.measure_option(
+    "log_filename": log_file,
+    "tuner": "random",
+    "n_trial": 1000,
+    "early_stopping": None,
+    "measure_option": autotvm.measure_option(
         builder=autotvm.LocalBuilder(),
-        runner=autotvm.RPCRunner(env.TARGET,
-                                 host=tracker_host,
-                                 port=tracker_port,
-                                 number=5,
-                                 timeout=60,
-                                 check_correctness=True),
+        runner=autotvm.RPCRunner(
+            env.TARGET,
+            host=tracker_host,
+            port=tracker_port,
+            number=5,
+            timeout=60,
+            check_correctness=True,
+        ),
     ),
 }
 
@@ -242,13 +244,15 @@ tuning_option = {
 
 
 # You can skip the implementation of this function for this tutorial.
-def tune_tasks(tasks,
-               measure_option,
-               tuner='xgb',
-               n_trial=1000,
-               early_stopping=None,
-               log_filename='tuning.log',
-               use_transfer_learning=True):
+def tune_tasks(
+    tasks,
+    measure_option,
+    tuner="xgb",
+    n_trial=1000,
+    early_stopping=None,
+    log_filename="tuning.log",
+    use_transfer_learning=True,
+):
 
     # create tmp log file
     tmp_log_file = log_filename + ".tmp"
@@ -259,15 +263,15 @@ def tune_tasks(tasks,
         prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
 
         # create tuner
-        if tuner == 'xgb' or tuner == 'xgb-rank':
-            tuner_obj = XGBTuner(tsk, loss_type='rank')
-        elif tuner == 'xgb_knob':
-            tuner_obj = XGBTuner(tsk, loss_type='rank', feature_type='knob')
-        elif tuner == 'ga':
+        if tuner == "xgb" or tuner == "xgb-rank":
+            tuner_obj = XGBTuner(tsk, loss_type="rank")
+        elif tuner == "xgb_knob":
+            tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="knob")
+        elif tuner == "ga":
             tuner_obj = GATuner(tsk, pop_size=50)
-        elif tuner == 'random':
+        elif tuner == "random":
             tuner_obj = RandomTuner(tsk)
-        elif tuner == 'gridsearch':
+        elif tuner == "gridsearch":
             tuner_obj = GridSearchTuner(tsk)
         else:
             raise ValueError("Invalid tuner: " + tuner)
@@ -278,13 +282,15 @@ def tune_tasks(tasks,
 
         # do tuning
         tsk_trial = min(n_trial, len(tsk.config_space))
-        tuner_obj.tune(n_trial=tsk_trial,
-                       early_stopping=early_stopping,
-                       measure_option=measure_option,
-                       callbacks=[
-                           autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
-                           autotvm.callback.log_to_file(tmp_log_file)
-                       ])
+        tuner_obj.tune(
+            n_trial=tsk_trial,
+            early_stopping=early_stopping,
+            measure_option=measure_option,
+            callbacks=[
+                autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
+                autotvm.callback.log_to_file(tmp_log_file),
+            ],
+        )
 
     # pick best records to a cache file
     autotvm.record.pick_best(tmp_log_file, log_filename)
@@ -321,7 +327,7 @@ def register_vta_tuning_tasks():
             res = my_clip(res, 0, 127)
             res = topi.cast(res, "int8")
 
-        if tvm.target.Target.current().device_name == 'vta':
+        if tvm.target.Target.current().device_name == "vta":
             s = vta.top.schedule_conv2d_packed([res])
         else:
             s = te.create_schedule([res.op])
@@ -336,10 +342,9 @@ def tune_and_evaluate(tuning_opt):
 
     if env.TARGET != "sim":
         # Get remote from fleet node
-        remote = autotvm.measure.request_remote(env.TARGET,
-                                                tracker_host,
-                                                tracker_port,
-                                                timeout=10000)
+        remote = autotvm.measure.request_remote(
+            env.TARGET, tracker_host, tracker_port, timeout=10000
+        )
         # Reconfigure the JIT runtime and FPGA.
         vta.reconfig_runtime(remote)
         vta.program_fpga(remote, bitstream=None)
@@ -354,11 +359,13 @@ def tune_and_evaluate(tuning_opt):
     print("Extract tasks...")
     relay_prog, params = compile_network(env, target, network, start_pack, stop_pack)
     mod = tvm.IRModule.from_expr(relay_prog)
-    tasks = autotvm.task.extract_from_program(mod,
-                                              params=params,
-                                              ops=(relay.op.get("nn.conv2d"),),
-                                              target=target,
-                                              target_host=env.target_host)
+    tasks = autotvm.task.extract_from_program(
+        mod,
+        params=params,
+        ops=(relay.op.get("nn.conv2d"),),
+        target=target,
+        target_host=env.target_host,
+    )
 
     # filter out non-packed conv2d task
     tasks = list(filter(lambda t: len(t.args[0][1]) > 4, tasks))
@@ -376,9 +383,21 @@ def tune_and_evaluate(tuning_opt):
         hkernel, wkernel = wgt[2], wgt[3]
         hstride, wstride = tsk.args[2][0], tsk.args[2][1]
         hpad, wpad = tsk.args[3][0], tsk.args[3][1]
-        print("({}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {})".format(
-            batch, height, width, in_filter, out_filter, hkernel, wkernel,
-            hpad, wpad, hstride, wstride))
+        print(
+            "({}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {})".format(
+                batch,
+                height,
+                width,
+                in_filter,
+                out_filter,
+                hkernel,
+                wkernel,
+                hpad,
+                wpad,
+                hstride,
+                wstride,
+            )
+        )
 
     # We do not run the tuning in our webpage server since it takes too long.
     # Comment the following line to run it by yourself.
@@ -394,17 +413,14 @@ def tune_and_evaluate(tuning_opt):
         print("Compile...")
         if target.device_name != "vta":
             with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}):
-                graph, lib, params = relay.build(relay_prog,
-                                                target=target,
-                                                params=params,
-                                                target_host=env.target_host)
+                graph, lib, params = relay.build(
+                    relay_prog, target=target, params=params, target_host=env.target_host
+                )
         else:
             with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
                 graph, lib, params = relay.build(
-                    relay_prog,
-                    target=target,
-                    params=params,
-                    target_host=env.target_host)
+                    relay_prog, target=target, params=params, target_host=env.target_host
+                )
 
         # Export library
         print("Upload...")
@@ -418,18 +434,19 @@ def tune_and_evaluate(tuning_opt):
         m = graph_runtime.create(graph, lib, ctx)
 
         # upload parameters to device
-        image = tvm.nd.array(
-            (np.random.uniform(size=(1, 3, 224, 224))).astype('float32'))
+        image = tvm.nd.array((np.random.uniform(size=(1, 3, 224, 224))).astype("float32"))
         m.set_input(**params)
-        m.set_input('data', image)
+        m.set_input("data", image)
 
         # evaluate
         print("Evaluate inference time cost...")
         timer = m.module.time_evaluator("run", ctx, number=1, repeat=10)
         tcost = timer()
         prof_res = np.array(tcost.results) * 1000  # convert to millisecond
-        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
-              (np.mean(prof_res), np.std(prof_res)))
+        print(
+            "Mean inference time (std dev): %.2f ms (%.2f ms)"
+            % (np.mean(prof_res), np.std(prof_res))
+        )
 
 
 # Run the tuning and evaluate the results
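
Since the tuner dispatch above enumerates the accepted tuner names, a quick-run variant of the tuning_option dict can be handy when smoke-testing the flow. The trial counts below are illustrative, not recommendations, and env, tracker_host, and tracker_port refer to the variables defined earlier in this tutorial:

from tvm import autotvm

quick_tuning_option = {
    "log_filename": "vta.quick.log",
    "tuner": "random",   # accepted: xgb, xgb-rank, xgb_knob, ga, random, gridsearch
    "n_trial": 20,       # far fewer trials than the 1000 used above
    "early_stopping": None,
    "measure_option": autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.RPCRunner(
            env.TARGET, host=tracker_host, port=tracker_port, number=2, timeout=60
        ),
    ),
}
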
index 3a36785..74c7412 100644 (file)
@@ -114,7 +114,9 @@ if env.TARGET not in ["sim", "tsim"]:
     if not tracker_host or not tracker_port:
         remote = rpc.connect(device_host, int(device_port))
     else:
-        remote = autotvm.measure.request_remote(env.TARGET, tracker_host, int(tracker_port), timeout=10000)
+        remote = autotvm.measure.request_remote(
+            env.TARGET, tracker_host, int(tracker_port), timeout=10000
+        )
 
     # Reconfigure the JIT runtime and FPGA.
     # You can program the FPGA with your own custom bitstream
@@ -152,7 +154,7 @@ ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
 with autotvm.tophub.context(target):
 
     # Populate the shape and data type dictionary for ImageNet classifier input
-    dtype_dict = {"data": 'float32'}
+    dtype_dict = {"data": "float32"}
     shape_dict = {"data": (env.BATCH, 3, 224, 224)}
 
     # Get off the shelf gluon model, and convert to relay
@@ -172,8 +174,7 @@ with autotvm.tophub.context(target):
         # Perform quantization in Relay
         # Note: We set opt_level to 3 in order to fold batch norm
         with tvm.transform.PassContext(opt_level=3):
-            with relay.quantize.qconfig(global_scale=8.0,
-                                        skip_conv_layers=[0]):
+            with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]):
                 mod = relay.quantize.quantize(mod, params=params)
             # Perform graph packing and constant folding for VTA target
             assert env.BLOCK_IN == env.BLOCK_OUT
@@ -183,7 +184,8 @@ with autotvm.tophub.context(target):
                 env.BLOCK_OUT,
                 env.WGT_WIDTH,
                 start_name=pack_dict[model][0],
-                stop_name=pack_dict[model][1])
+                stop_name=pack_dict[model][1],
+            )
     else:
         relay_prog = mod["main"]
 
@@ -191,13 +193,13 @@ with autotvm.tophub.context(target):
     if target.device_name != "vta":
         with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}):
             graph, lib, params = relay.build(
-                relay_prog, target=target,
-                params=params, target_host=env.target_host)
+                relay_prog, target=target, params=params, target_host=env.target_host
+            )
     else:
         with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
             graph, lib, params = relay.build(
-                relay_prog, target=target,
-                params=params, target_host=env.target_host)
+                relay_prog, target=target, params=params, target_host=env.target_host
+            )
 
     # Measure Relay build time
     build_time = time.time() - build_start
@@ -226,15 +228,15 @@ download.download(join(categ_url, categ_fn), categ_fn)
 synset = eval(open(categ_fn).read())
 
 # Download test image
-image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg'
-image_fn = 'cat.png'
+image_url = "https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg"
+image_fn = "cat.png"
 download.download(image_url, image_fn)
 
 # Prepare test image for inference
 image = Image.open(image_fn).resize((224, 224))
 plt.imshow(image)
 plt.show()
-image = np.array(image) - np.array([123., 117., 104.])
+image = np.array(image) - np.array([123.0, 117.0, 104.0])
 image /= np.array([58.395, 57.12, 57.375])
 image = image.transpose((2, 0, 1))
 image = image[np.newaxis, :]
@@ -242,12 +244,12 @@ image = np.repeat(image, env.BATCH, axis=0)
 
 # Set the network parameters and inputs
 m.set_input(**params)
-m.set_input('data', image)
+m.set_input("data", image)
 
 # Perform inference and gather execution statistics
 # More on: :py:method:`tvm.runtime.Module.time_evaluator`
-num = 4 # number of times we run module for a single measurement
-rep = 3 # number of measurements (we derive std dev from this)
+num = 4  # number of times we run module for a single measurement
+rep = 3  # number of measurements (we derive std dev from this)
 timer = m.module.time_evaluator("run", ctx, number=num, repeat=rep)
 
 if env.TARGET in ["sim", "tsim"]:
@@ -265,7 +267,7 @@ else:
     std = np.std(tcost.results) * 1000
     mean = tcost.mean * 1000
     print("\nPerformed inference in %.2fms (std = %.2f) for %d samples" % (mean, std, env.BATCH))
-    print("Average per sample inference time: %.2fms" % (mean/env.BATCH))
+    print("Average per sample inference time: %.2fms" % (mean / env.BATCH))
 
 # Get classification results
 tvm_output = m.get_output(0, tvm.nd.empty((env.BATCH, 1000), "float32", remote.cpu(0)))
@@ -287,4 +289,4 @@ for b in range(env.BATCH):
     for k in top_categories[-5:]:
         if "cat" in synset[k]:
             cat_detected = True
-    assert(cat_detected)
+    assert cat_detected
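
The hunk above also rewrites assert(cat_detected) as a bare assert statement. Beyond style, the parenthesized form is a known pitfall once a message is added, since assert (cond, msg) asserts a non-empty tuple and can never fail. A minimal, self-contained sketch (the value/expected names are purely illustrative):

    value, expected = 1, 2
    # Pitfall: this asserts a two-element tuple, which is always truthy;
    # CPython even warns "SyntaxWarning: assertion is always true".
    assert (value == expected, "value mismatch")
    # Correct two-argument form: condition first, message second.
    assert value == expected, "value mismatch"
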
index 5039488..f6fd462 100644 (file)
@@ -62,6 +62,7 @@ from tvm.contrib import graph_runtime, graph_runtime, util
 from tvm.contrib.download import download_testdata
 from vta.testing import simulator
 from vta.top import graph_pack
+
 # Make sure that TVM was compiled with RPC=1
 assert tvm.runtime.enabled("rpc")
 
@@ -69,39 +70,42 @@ assert tvm.runtime.enabled("rpc")
 # Download the yolo net config file, weight file, and darknet library file based on
 # the model name
 # ----------------------------------------------------------------------------
-MODEL_NAME = 'yolov3-tiny'
-REPO_URL = 'https://github.com/dmlc/web-data/blob/master/darknet/'
-
-cfg_path = download_testdata('https://github.com/pjreddie/darknet/blob/master/cfg/'
-                             + MODEL_NAME + '.cfg' + '?raw=true',
-                             MODEL_NAME + '.cfg',
-                             module="darknet")
-weights_path = download_testdata('https://pjreddie.com/media/files/'
-                                 + MODEL_NAME + '.weights' + '?raw=true',
-                                 MODEL_NAME + '.weights',
-                                 module="darknet")
-
-if sys.platform in ['linux', 'linux2']:
-    darknet_lib_path = download_testdata(REPO_URL + 'lib/' + 'libdarknet2.0.so' + '?raw=true',
-                                         'libdarknet2.0.so',
-                                         module="darknet")
-elif sys.platform == 'darwin':
-    darknet_lib_path = download_testdata(REPO_URL+'lib_osx/'+'libdarknet_mac2.0.so'+'?raw=true',
-                                         'libdarknet_mac2.0.so',
-                                         module="darknet")
+MODEL_NAME = "yolov3-tiny"
+REPO_URL = "https://github.com/dmlc/web-data/blob/master/darknet/"
+
+cfg_path = download_testdata(
+    "https://github.com/pjreddie/darknet/blob/master/cfg/" + MODEL_NAME + ".cfg" + "?raw=true",
+    MODEL_NAME + ".cfg",
+    module="darknet",
+)
+weights_path = download_testdata(
+    "https://pjreddie.com/media/files/" + MODEL_NAME + ".weights" + "?raw=true",
+    MODEL_NAME + ".weights",
+    module="darknet",
+)
+
+if sys.platform in ["linux", "linux2"]:
+    darknet_lib_path = download_testdata(
+        REPO_URL + "lib/" + "libdarknet2.0.so" + "?raw=true", "libdarknet2.0.so", module="darknet"
+    )
+elif sys.platform == "darwin":
+    darknet_lib_path = download_testdata(
+        REPO_URL + "lib_osx/" + "libdarknet_mac2.0.so" + "?raw=true",
+        "libdarknet_mac2.0.so",
+        module="darknet",
+    )
 else:
-    raise NotImplementedError("Darknet lib is not supported on {} platform"
-                              .format(sys.platform))
+    raise NotImplementedError("Darknet lib is not supported on {} platform".format(sys.platform))
 
 ##################################################
 # Download yolo categories and illustration font.
 # ------------------------------------------------
-coco_path = download_testdata(REPO_URL + 'data/' + 'coco.names' + '?raw=true',
-                              'coco.names',
-                              module='data')
-font_path = download_testdata(REPO_URL + 'data/' + 'arial.ttf' + '?raw=true',
-                              'arial.ttf',
-                              module='data')
+coco_path = download_testdata(
+    REPO_URL + "data/" + "coco.names" + "?raw=true", "coco.names", module="data"
+)
+font_path = download_testdata(
+    REPO_URL + "data/" + "arial.ttf" + "?raw=true", "arial.ttf", module="data"
+)
 with open(coco_path) as f:
     content = f.readlines()
 names = [x.strip() for x in content]
@@ -154,10 +158,9 @@ if env.TARGET not in ["sim", "tsim"]:
     if not tracker_host or not tracker_port:
         remote = rpc.connect(device_host, int(device_port))
     else:
-        remote = autotvm.measure.request_remote(env.TARGET,
-                                                tracker_host,
-                                                int(tracker_port),
-                                                timeout=10000)
+        remote = autotvm.measure.request_remote(
+            env.TARGET, tracker_host, int(tracker_port), timeout=10000
+        )
     # Reconfigure the JIT runtime and FPGA.
     # You can program the FPGA with your own custom bitstream
     # by passing the path to the bitstream file instead of None.
@@ -192,11 +195,11 @@ ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
 
 # Load pre-configured AutoTVM schedules
 with autotvm.tophub.context(target):
-    net = __darknetffi__.dlopen(darknet_lib_path).load_network(cfg_path.encode('utf-8'),
-                                                               weights_path.encode('utf-8'),
-                                                               0)
+    net = __darknetffi__.dlopen(darknet_lib_path).load_network(
+        cfg_path.encode("utf-8"), weights_path.encode("utf-8"), 0
+    )
     dshape = (env.BATCH, net.c, net.h, net.w)
-    dtype = 'float32'
+    dtype = "float32"
 
     # Measure build start time
     build_start = time.time()
@@ -205,13 +208,15 @@ with autotvm.tophub.context(target):
     mod, params = relay.frontend.from_darknet(net, dtype=dtype, shape=dshape)
 
     if target.device_name == "vta":
-    # Perform quantization in Relay
-    # Note: We set opt_level to 3 in order to fold batch norm
+        # Perform quantization in Relay
+        # Note: We set opt_level to 3 in order to fold batch norm
         with tvm.transform.PassContext(opt_level=3):
-            with relay.quantize.qconfig(global_scale=33.0,
-                                        skip_conv_layers=[0],
-                                        store_lowbit_output=True,
-                                        round_for_shift=True):
+            with relay.quantize.qconfig(
+                global_scale=33.0,
+                skip_conv_layers=[0],
+                store_lowbit_output=True,
+                round_for_shift=True,
+            ):
                 mod = relay.quantize.quantize(mod, params=params)
             # Perform graph packing and constant folding for VTA target
             mod = graph_pack(
@@ -222,17 +227,16 @@ with autotvm.tophub.context(target):
                 start_name=pack_dict[MODEL_NAME][0],
                 stop_name=pack_dict[MODEL_NAME][1],
                 start_name_idx=pack_dict[MODEL_NAME][2],
-                stop_name_idx=pack_dict[MODEL_NAME][3])
+                stop_name_idx=pack_dict[MODEL_NAME][3],
+            )
     else:
         mod = mod["main"]
 
     # Compile Relay program with AlterOpLayout disabled
     with vta.build_config(disabled_pass={"AlterOpLayout"}):
         graph, lib, params = relay.build(
-            mod,
-            target=target,
-            params=params,
-            target_host=env.target_host)
+            mod, target=target, params=params, target_host=env.target_host
+        )
 
     # Measure Relay build time
     build_time = time.time() - build_start
@@ -253,8 +257,8 @@ with autotvm.tophub.context(target):
 # We run detection on a downloaded image
 # Download test image
 [neth, netw] = dshape[2:]
-test_image = 'person.jpg'
-img_url = REPO_URL + 'data/' + test_image + '?raw=true'
+test_image = "person.jpg"
+img_url = REPO_URL + "data/" + test_image + "?raw=true"
 img_path = download_testdata(img_url, test_image, "data")
 data = darknet.load_image(img_path, neth, netw).transpose(1, 2, 0)
 
@@ -266,13 +270,13 @@ data = data[np.newaxis, :]
 data = np.repeat(data, env.BATCH, axis=0)
 
 # Set the network parameters and inputs
-m.set_input('data', data)
+m.set_input("data", data)
 m.set_input(**params)
 
 # Perform inference and gather execution statistics
 # More on: :py:meth:`tvm.runtime.Module.time_evaluator`
-num = 4 # number of times we run module for a single measurement
-rep = 3 # number of measurements (we derive std dev from this)
+num = 4  # number of times we run module for a single measurement
+rep = 3  # number of measurements (we derive std dev from this)
 timer = m.module.time_evaluator("run", ctx, number=num, repeat=rep)
 
 if env.TARGET in ["sim", "tsim"]:
@@ -290,7 +294,7 @@ else:
     std = np.std(tcost.results) * 1000
     mean = tcost.mean * 1000
     print("\nPerformed inference in %.2fms (std = %.2f) for %d samples" % (mean, std, env.BATCH))
-    print("Average per sample inference time: %.2fms" % (mean/env.BATCH))
+    print("Average per sample inference time: %.2fms" % (mean / env.BATCH))
 
 # Get detection results from out
 thresh = 0.5
@@ -298,33 +302,23 @@ nms_thresh = 0.45
 tvm_out = []
 for i in range(2):
     layer_out = {}
-    layer_out['type'] = 'Yolo'
+    layer_out["type"] = "Yolo"
     # Get the yolo layer attributes (n, out_c, out_h, out_w, classes, total)
-    layer_attr = m.get_output(i*4+3).asnumpy()
-    layer_out['biases'] = m.get_output(i*4+2).asnumpy()
-    layer_out['mask'] = m.get_output(i*4+1).asnumpy()
-    out_shape = (layer_attr[0], layer_attr[1]//layer_attr[0],
-                 layer_attr[2], layer_attr[3])
-    layer_out['output'] = m.get_output(i*4).asnumpy().reshape(out_shape)
-    layer_out['classes'] = layer_attr[4]
+    layer_attr = m.get_output(i * 4 + 3).asnumpy()
+    layer_out["biases"] = m.get_output(i * 4 + 2).asnumpy()
+    layer_out["mask"] = m.get_output(i * 4 + 1).asnumpy()
+    out_shape = (layer_attr[0], layer_attr[1] // layer_attr[0], layer_attr[2], layer_attr[3])
+    layer_out["output"] = m.get_output(i * 4).asnumpy().reshape(out_shape)
+    layer_out["classes"] = layer_attr[4]
     tvm_out.append(layer_out)
     thresh = 0.560
 
 # Show detection results
 img = darknet.load_image_color(img_path)
 _, im_h, im_w = img.shape
-dets = yolo_detection.fill_network_boxes((netw, neth),
-                                         (im_w, im_h),
-                                         thresh,
-                                         1,
-                                         tvm_out)
+dets = yolo_detection.fill_network_boxes((netw, neth), (im_w, im_h), thresh, 1, tvm_out)
 last_layer = net.layers[net.n - 1]
 yolo_detection.do_nms_sort(dets, last_layer.classes, nms_thresh)
-yolo_detection.draw_detections(font_path,
-                               img,
-                               dets,
-                               thresh,
-                               names,
-                               last_layer.classes)
+yolo_detection.draw_detections(font_path, img, dets, thresh, names, last_layer.classes)
 plt.imshow(img.transpose(1, 2, 0))
 plt.show()
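
The quote flips throughout this file come from Black's string normalization: single-quoted literals are rewritten with double quotes unless that would add backslash escapes, in which case the original quoting is kept. A small sketch with made-up values:

    MODEL = "yolov3-tiny"      # was 'yolov3-tiny'; Black prefers double quotes
    GREETING = 'say "hello"'   # left alone: switching would force escapes
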
index 024e179..77fc805 100644 (file)
@@ -205,11 +205,12 @@ ki = te.reduce_axis((0, env.BLOCK_IN), name="ki")
 # Describe the in-VTA matrix multiplication
 C_buf = te.compute(
     (o, m, env.BATCH, env.BLOCK_OUT),
-    lambda bo, co, bi, ci:
-        te.sum(A_buf[bo, ko, bi, ki].astype(env.acc_dtype) *
-                B_buf[co, ko, ci, ki].astype(env.acc_dtype),
-                axis=[ko, ki]),
-    name="C_buf")
+    lambda bo, co, bi, ci: te.sum(
+        A_buf[bo, ko, bi, ki].astype(env.acc_dtype) * B_buf[co, ko, ci, ki].astype(env.acc_dtype),
+        axis=[ko, ki],
+    ),
+    name="C_buf",
+)
 
 ######################################################################
 # Casting the Results
@@ -236,9 +237,8 @@ C_buf = te.compute(
 
 # Cast to output type, and send to main memory
 C = te.compute(
-    (o, m, env.BATCH, env.BLOCK_OUT),
-    lambda *i: C_buf(*i).astype(env.inp_dtype),
-    name="C")
+    (o, m, env.BATCH, env.BLOCK_OUT), lambda *i: C_buf(*i).astype(env.inp_dtype), name="C"
+)
 
 ######################################################################
 # This concludes the computation declaration part of this tutorial.
@@ -369,12 +369,8 @@ print(tvm.lower(s, [A, B, C], simple_mode=True))
 # by the VTA runtime JIT compiler.
 
 s[C_buf].reorder(
-    ko,
-    s[C_buf].op.axis[0],
-    s[C_buf].op.axis[1],
-    s[C_buf].op.axis[2],
-    s[C_buf].op.axis[3],
-    ki)
+    ko, s[C_buf].op.axis[0], s[C_buf].op.axis[1], s[C_buf].op.axis[2], s[C_buf].op.axis[3], ki
+)
 s[C_buf].tensorize(s[C_buf].op.axis[2], env.gemm)
 
 # Let's take a look at the finalized schedule
@@ -422,16 +418,12 @@ f = remote.load_module("gemm.o")
 ctx = remote.ext_dev(0)
 
 # Initialize the A and B arrays randomly in the int range of (-128, 128]
-A_orig = np.random.randint(
-    -128, 128, size=(o * env.BATCH, n * env.BLOCK_IN)).astype(A.dtype)
-B_orig = np.random.randint(
-    -128, 128, size=(m * env.BLOCK_OUT, n * env.BLOCK_IN)).astype(B.dtype)
+A_orig = np.random.randint(-128, 128, size=(o * env.BATCH, n * env.BLOCK_IN)).astype(A.dtype)
+B_orig = np.random.randint(-128, 128, size=(m * env.BLOCK_OUT, n * env.BLOCK_IN)).astype(B.dtype)
 
 # Apply packing to the A and B arrays from a 2D to a 4D packed layout
-A_packed = A_orig.reshape(
-    o, env.BATCH, n, env.BLOCK_IN).transpose((0, 2, 1, 3))
-B_packed = B_orig.reshape(
-    m, env.BLOCK_OUT, n, env.BLOCK_IN).transpose((0, 2, 1, 3))
+A_packed = A_orig.reshape(o, env.BATCH, n, env.BLOCK_IN).transpose((0, 2, 1, 3))
+B_packed = B_orig.reshape(m, env.BLOCK_OUT, n, env.BLOCK_IN).transpose((0, 2, 1, 3))
 
 # Format the input/output arrays with tvm.nd.array to the DLPack standard
 A_nd = tvm.nd.array(A_packed, ctx)
@@ -452,10 +444,8 @@ f(A_nd, B_nd, C_nd)
 # matrix multiplication indeed is correct
 
 # Compute reference result with numpy
-C_ref = np.dot(A_orig.astype(env.acc_dtype),
-               B_orig.T.astype(env.acc_dtype)).astype(C.dtype)
-C_ref = C_ref.reshape(
-    o, env.BATCH, m, env.BLOCK_OUT).transpose((0, 2, 1, 3))
+C_ref = np.dot(A_orig.astype(env.acc_dtype), B_orig.T.astype(env.acc_dtype)).astype(C.dtype)
+C_ref = C_ref.reshape(o, env.BATCH, m, env.BLOCK_OUT).transpose((0, 2, 1, 3))
 np.testing.assert_equal(C_ref, C_nd.asnumpy())
 
 # Print stats
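
The call reflow in these tutorials follows one rule: a call that fits within the configured line length (evidently wider than Black's default of 88 columns, likely 100 judging by the longest reflowed lines) is collapsed onto a single line; anything longer is exploded with one argument per line and a trailing comma, so a future change touches only the line it edits. A self-contained sketch with hypothetical names:

    def configure(name, scale, skip_layers, lowbit, rounding):
        # Stand-in for any multi-argument call in the tutorials above.
        return (name, scale, skip_layers, lowbit, rounding)

    # Fits within the line length: kept on one line.
    cfg = configure("resnet18_v1", 8.0, [0], True, True)

    # Overflows the line length: one argument per line, plus a trailing comma.
    cfg = configure(
        "a-much-longer-model-name-that-pushes-this-call-well-past-the-line-length-limit",
        33.0,
        [0, 1, 2, 3],
        True,
        True,
    )
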
index d364fef..05479c3 100644 (file)
@@ -143,78 +143,72 @@ assert in_channels % env.BLOCK_IN == 0
 assert out_channels % env.BLOCK_OUT == 0
 
 # Input feature map: (N, IC, H, W, n, ic)
-data_shape = (batch_size // env.BATCH,
-              in_channels // env.BLOCK_IN,
-              height,
-              width,
-              env.BATCH,
-              env.BLOCK_IN)
+data_shape = (
+    batch_size // env.BATCH,
+    in_channels // env.BLOCK_IN,
+    height,
+    width,
+    env.BATCH,
+    env.BLOCK_IN,
+)
 # Kernel: (OC, IC, H, W, oc, ic)
-kernel_shape = (out_channels // env.BLOCK_OUT,
-                in_channels // env.BLOCK_IN,
-                kernel_h,
-                kernel_w,
-                env.BLOCK_OUT,
-                env.BLOCK_IN)
+kernel_shape = (
+    out_channels // env.BLOCK_OUT,
+    in_channels // env.BLOCK_IN,
+    kernel_h,
+    kernel_w,
+    env.BLOCK_OUT,
+    env.BLOCK_IN,
+)
 # Derive output feature map dimensions
 fout_height = (height + 2 * pad_h - kernel_h) // stride_h + 1
 fout_width = (width + 2 * pad_w - kernel_w) // stride_w + 1
 # Output feature map: (N, OC, H, W, n, oc)
-output_shape = (batch_size // env.BATCH,
-                out_channels // env.BLOCK_OUT,
-                fout_height,
-                fout_width,
-                env.BATCH,
-                env.BLOCK_OUT)
+output_shape = (
+    batch_size // env.BATCH,
+    out_channels // env.BLOCK_OUT,
+    fout_height,
+    fout_width,
+    env.BATCH,
+    env.BLOCK_OUT,
+)
 
 # Convolution reduction axes
-dy = te.reduce_axis((0, kernel_h), name='dy')
-dx = te.reduce_axis((0, kernel_w), name='dx')
-ic = te.reduce_axis((0, in_channels // env.BLOCK_IN), name='ic')
-ic_tns = te.reduce_axis((0, env.BLOCK_IN), name='ic_tns')
+dy = te.reduce_axis((0, kernel_h), name="dy")
+dx = te.reduce_axis((0, kernel_w), name="dx")
+ic = te.reduce_axis((0, in_channels // env.BLOCK_IN), name="ic")
+ic_tns = te.reduce_axis((0, env.BLOCK_IN), name="ic_tns")
 
 # Input placeholder tensors
-data = te.placeholder(data_shape,
-                       name="data",
-                       dtype=env.inp_dtype)
-kernel = te.placeholder(kernel_shape,
-                         name="kernel",
-                         dtype=env.wgt_dtype)
+data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
+kernel = te.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
 
 # Copy buffers:
 #   Apply spatial padding to input feature map
-data_buf = topi.nn.pad(data,
-                       [0, 0, pad_h, pad_w, 0, 0],
-                       name="data_buf")
+data_buf = topi.nn.pad(data, [0, 0, pad_h, pad_w, 0, 0], name="data_buf")
 kernel_buf = te.compute(kernel_shape, lambda *i: kernel(*i), "kernel_buf")
 
 # Declare 2D convolution
 res_conv = te.compute(
     output_shape,
     lambda bo, co, i, j, bi, ci: te.sum(
-      data_buf[bo, ic, i*stride_h+dy, j*stride_w+dx, bi, ic_tns].astype(env.acc_dtype) *
-      kernel_buf[co, ic, dy, dx, ci, ic_tns].astype(env.acc_dtype),
-    axis=[ic, dy, dx, ic_tns]),
-    name="res_conv")
+        data_buf[bo, ic, i * stride_h + dy, j * stride_w + dx, bi, ic_tns].astype(env.acc_dtype)
+        * kernel_buf[co, ic, dy, dx, ci, ic_tns].astype(env.acc_dtype),
+        axis=[ic, dy, dx, ic_tns],
+    ),
+    name="res_conv",
+)
 
 # Add shift stage for fixed-point normalization
-res_shr = te.compute(output_shape,
-                      lambda *i: res_conv(*i) >> 8,
-                      name="res_shr")
+res_shr = te.compute(output_shape, lambda *i: res_conv(*i) >> 8, name="res_shr")
 
 # Apply clipping between (0, input max value)
 inp_max = (1 << (env.INP_WIDTH - 1)) - 1
-res_max = te.compute(output_shape,
-                      lambda *i: tvm.te.max(res_shr(*i), 0),
-                      "res_max")
-res_min = te.compute(output_shape,
-                      lambda *i: tvm.te.min(res_max(*i), inp_max),
-                      "res_min")
+res_max = te.compute(output_shape, lambda *i: tvm.te.max(res_shr(*i), 0), "res_max")
+res_min = te.compute(output_shape, lambda *i: tvm.te.min(res_max(*i), inp_max), "res_min")
 
 # Result Tensor
-res = te.compute(output_shape,
-                  lambda *i: res_min(*i).astype(env.inp_dtype),
-                  name="res")
+res = te.compute(output_shape, lambda *i: res_min(*i).astype(env.inp_dtype), name="res")
 
 
 ######################################################################
@@ -385,28 +379,27 @@ ctx = remote.ext_dev(0)
 
 # Initialize the data and kernel arrays randomly in the int range
 # of (-128, 128] in NCHW layout
-data_np = np.random.randint(
-    -128, 128,
-    size=(batch_size, in_channels, height, width)).astype(data.dtype)
+data_np = np.random.randint(-128, 128, size=(batch_size, in_channels, height, width)).astype(
+    data.dtype
+)
 kernel_np = np.random.randint(
-    -128, 128,
-    size=(out_channels, in_channels, kernel_h, kernel_w)).astype(kernel.dtype)
+    -128, 128, size=(out_channels, in_channels, kernel_h, kernel_w)
+).astype(kernel.dtype)
 
 # Apply packing to the data and kernel arrays from a 2D NCHW
 # to a 4D NCHWnc packed layout
-data_packed = data_np.reshape(batch_size // env.BATCH,
-                              env.BATCH,
-                              in_channels // env.BLOCK_IN,
-                              env.BLOCK_IN,
-                              height,
-                              width).transpose((0, 2, 4, 5, 1, 3))
-
-kernel_packed = kernel_np.reshape(out_channels // env.BLOCK_OUT,
-                                  env.BLOCK_OUT,
-                                  in_channels // env.BLOCK_IN,
-                                  env.BLOCK_IN,
-                                  kernel_h,
-                                  kernel_w).transpose((0, 2, 4, 5, 1, 3))
+data_packed = data_np.reshape(
+    batch_size // env.BATCH, env.BATCH, in_channels // env.BLOCK_IN, env.BLOCK_IN, height, width
+).transpose((0, 2, 4, 5, 1, 3))
+
+kernel_packed = kernel_np.reshape(
+    out_channels // env.BLOCK_OUT,
+    env.BLOCK_OUT,
+    in_channels // env.BLOCK_IN,
+    env.BLOCK_IN,
+    kernel_h,
+    kernel_w,
+).transpose((0, 2, 4, 5, 1, 3))
 
 # Format the input/output arrays with tvm.nd.array to the DLPack standard
 data_nd = tvm.nd.array(data_packed, ctx)
@@ -421,19 +414,25 @@ if env.TARGET in ["sim", "tsim"]:
 f(data_nd, kernel_nd, res_nd)
 
 # Verify against numpy implementation
-res_ref = conv2d_nchw_python(data_np.astype(env.acc_dtype),
-                            kernel_np.astype(env.acc_dtype),
-                            (stride_h, stride_w),
-                            (pad_h, pad_w)).astype(env.acc_dtype)
+res_ref = conv2d_nchw_python(
+    data_np.astype(env.acc_dtype),
+    kernel_np.astype(env.acc_dtype),
+    (stride_h, stride_w),
+    (pad_h, pad_w),
+).astype(env.acc_dtype)
 res_ref = res_ref >> env.INP_WIDTH
 res_ref = np.clip(res_ref, 0, inp_max)
 res_ref = res_ref.astype(res.dtype)
-res_ref = res_ref.reshape((batch_size // env.BATCH,
-                           env.BATCH,
-                           out_channels // env.BLOCK_OUT,
-                           env.BLOCK_OUT,
-                           fout_height,
-                           fout_width)).transpose((0, 2, 4, 5, 1, 3))
+res_ref = res_ref.reshape(
+    (
+        batch_size // env.BATCH,
+        env.BATCH,
+        out_channels // env.BLOCK_OUT,
+        env.BLOCK_OUT,
+        fout_height,
+        fout_width,
+    )
+).transpose((0, 2, 4, 5, 1, 3))
 tvm.testing.assert_allclose(res_ref, res_nd.asnumpy())
 
 # Print stats
index 77b0381..28600d4 100644 (file)
@@ -105,62 +105,49 @@ assert in_channels % env.BLOCK_IN == 0
 assert out_channels % env.BLOCK_OUT == 0
 
 # Let's derive the tiled input tensor shapes
-data_shape = (batch_size // env.BATCH,
-              in_channels // env.BLOCK_IN,
-              env.BATCH,
-              env.BLOCK_IN)
-weight_shape = (out_channels // env.BLOCK_OUT,
-                in_channels // env.BLOCK_IN,
-                env.BLOCK_OUT,
-                env.BLOCK_IN)
-output_shape = (batch_size // env.BATCH,
-                out_channels // env.BLOCK_OUT,
-                env.BATCH,
-                env.BLOCK_OUT)
+data_shape = (batch_size // env.BATCH, in_channels // env.BLOCK_IN, env.BATCH, env.BLOCK_IN)
+weight_shape = (
+    out_channels // env.BLOCK_OUT,
+    in_channels // env.BLOCK_IN,
+    env.BLOCK_OUT,
+    env.BLOCK_IN,
+)
+output_shape = (batch_size // env.BATCH, out_channels // env.BLOCK_OUT, env.BATCH, env.BLOCK_OUT)
 num_ops = in_channels * out_channels * batch_size * 2
 
 # Reduction axes
-ic = te.reduce_axis((0, in_channels // env.BLOCK_IN), name='ic')
-ic_tns = te.reduce_axis((0, env.BLOCK_IN), name='ic_tns')
+ic = te.reduce_axis((0, in_channels // env.BLOCK_IN), name="ic")
+ic_tns = te.reduce_axis((0, env.BLOCK_IN), name="ic_tns")
 
 # Input placeholder tensors
 data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
 weight = te.placeholder(weight_shape, name="weight", dtype=env.wgt_dtype)
 
 # Copy buffers
-data_buf = te.compute(data_shape,
-                       lambda *i: data(*i),
-                       "data_buf")
-weight_buf = te.compute(weight_shape,
-                         lambda *i: weight(*i),
-                         "weight_buf")
+data_buf = te.compute(data_shape, lambda *i: data(*i), "data_buf")
+weight_buf = te.compute(weight_shape, lambda *i: weight(*i), "weight_buf")
 
 # Declare matrix multiply computation
-res_gemm = te.compute(output_shape,
-                       lambda bo, co, bi, ci: te.sum(
-                            data_buf[bo, ic, bi, ic_tns].astype(env.acc_dtype) *
-                            weight_buf[co, ic, ci, ic_tns].astype(env.acc_dtype),
-                            axis=[ic, ic_tns]),
-                       name="res_gem")
+res_gemm = te.compute(
+    output_shape,
+    lambda bo, co, bi, ci: te.sum(
+        data_buf[bo, ic, bi, ic_tns].astype(env.acc_dtype)
+        * weight_buf[co, ic, ci, ic_tns].astype(env.acc_dtype),
+        axis=[ic, ic_tns],
+    ),
+    name="res_gem",
+)
 
 # Add shift stage for fixed-point normalization
-res_shr = te.compute(output_shape,
-                      lambda *i: res_gemm(*i) >> env.INP_WIDTH,
-                      name="res_shr")
+res_shr = te.compute(output_shape, lambda *i: res_gemm(*i) >> env.INP_WIDTH, name="res_shr")
 
 # Apply clipping between (0, input max value)
-inp_max = (1<<(env.INP_WIDTH-1))-1
-res_max = te.compute(output_shape,
-                      lambda *i: tvm.te.max(res_shr(*i), 0),
-                      "res_max")
-res_min = te.compute(output_shape,
-                      lambda *i: tvm.te.min(res_max(*i), inp_max),
-                      "res_min")
+inp_max = (1 << (env.INP_WIDTH - 1)) - 1
+res_max = te.compute(output_shape, lambda *i: tvm.te.max(res_shr(*i), 0), "res_max")
+res_min = te.compute(output_shape, lambda *i: tvm.te.min(res_max(*i), inp_max), "res_min")
 
 # Apply typecast to input data type before sending results back
-res = te.compute(output_shape,
-                  lambda *i: res_min(*i).astype(env.inp_dtype),
-                  name="res")
+res = te.compute(output_shape, lambda *i: res_min(*i).astype(env.inp_dtype), name="res")
 
 ######################################################################
 # Scheduling the Computation
@@ -333,20 +320,16 @@ f = remote.load_module("gemm.o")
 ctx = remote.ext_dev(0)
 
 # Initialize the data and weight arrays randomly in the int range of (-128, 128]
-data_np = np.random.randint(
-    -128, 128, size=(batch_size, in_channels)).astype(data.dtype)
-weight_np = np.random.randint(
-    -128, 128, size=(out_channels, in_channels)).astype(weight.dtype)
+data_np = np.random.randint(-128, 128, size=(batch_size, in_channels)).astype(data.dtype)
+weight_np = np.random.randint(-128, 128, size=(out_channels, in_channels)).astype(weight.dtype)
 
 # Apply packing to the data and weight arrays from a 2D to a 4D packed layout
-data_packed = data_np.reshape(batch_size // env.BATCH,
-                              env.BATCH,
-                              in_channels // env.BLOCK_IN,
-                              env.BLOCK_IN).transpose((0, 2, 1, 3))
-weight_packed = weight_np.reshape(out_channels // env.BLOCK_OUT,
-                                  env.BLOCK_OUT,
-                                  in_channels // env.BLOCK_IN,
-                                  env.BLOCK_IN).transpose((0, 2, 1, 3))
+data_packed = data_np.reshape(
+    batch_size // env.BATCH, env.BATCH, in_channels // env.BLOCK_IN, env.BLOCK_IN
+).transpose((0, 2, 1, 3))
+weight_packed = weight_np.reshape(
+    out_channels // env.BLOCK_OUT, env.BLOCK_OUT, in_channels // env.BLOCK_IN, env.BLOCK_IN
+).transpose((0, 2, 1, 3))
 
 # Format the input/output arrays with tvm.nd.array to the DLPack standard
 data_nd = tvm.nd.array(data_packed, ctx)
@@ -361,15 +344,13 @@ if env.TARGET in ["sim", "tsim"]:
 f(data_nd, weight_nd, res_nd)
 
 # Verify against numpy implementation
-res_ref = np.dot(data_np.astype(env.acc_dtype),
-                 weight_np.T.astype(env.acc_dtype))
+res_ref = np.dot(data_np.astype(env.acc_dtype), weight_np.T.astype(env.acc_dtype))
 res_ref = res_ref >> env.INP_WIDTH
 res_ref = np.clip(res_ref, 0, inp_max)
 res_ref = res_ref.astype(res.dtype)
-res_ref = res_ref.reshape(batch_size // env.BATCH,
-                          env.BATCH,
-                          out_channels // env.BLOCK_OUT,
-                          env.BLOCK_OUT).transpose((0, 2, 1, 3))
+res_ref = res_ref.reshape(
+    batch_size // env.BATCH, env.BATCH, out_channels // env.BLOCK_OUT, env.BLOCK_OUT
+).transpose((0, 2, 1, 3))
 np.testing.assert_equal(res_ref, res_nd.asnumpy())
 
 # Print stats
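
The arithmetic rewrites in this file, such as (1<<(env.INP_WIDTH-1))-1 becoming (1 << (env.INP_WIDTH - 1)) - 1, together with the double-spaced inline comments in earlier hunks (num = 4  # ...), are Black's whitespace normalization along PEP 8 lines: single spaces around binary operators and two spaces before a trailing comment. A tiny runnable sketch with stand-in values:

    INP_WIDTH = 8
    inp_max = (1 << (INP_WIDTH - 1)) - 1  # spaces around the shift and subtraction
    mean, batch = 12.5, 4
    per_sample = mean / batch  # two spaces before an inline comment
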
index ab41687..46b050f 100644 (file)
@@ -178,7 +178,8 @@ B_buf = te.compute((o, m, env.BATCH, env.BLOCK_OUT), lambda *i: B(*i), "B_buf")
 C_buf = te.compute(
     (o, m, env.BATCH, env.BLOCK_OUT),
     lambda *i: A_buf(*i).astype(env.acc_dtype) + B_buf(*i).astype(env.acc_dtype),
-    name="C_buf")
+    name="C_buf",
+)
 
 ######################################################################
 # Casting the Results
@@ -201,9 +202,8 @@ C_buf = te.compute(
 
 # Cast to output type, and send to main memory
 C = te.compute(
-    (o, m, env.BATCH, env.BLOCK_OUT),
-    lambda *i: C_buf(*i).astype(env.inp_dtype),
-    name="C")
+    (o, m, env.BATCH, env.BLOCK_OUT), lambda *i: C_buf(*i).astype(env.inp_dtype), name="C"
+)
 
 ######################################################################
 # This concludes the computation declaration part of this tutorial.
@@ -353,16 +353,12 @@ f = remote.load_module("vadd.o")
 ctx = remote.ext_dev(0)
 
 # Initialize the A and B arrays randomly in the int range of (-128, 128]
-A_orig = np.random.randint(
-    -128, 128, size=(o * env.BATCH, m * env.BLOCK_OUT)).astype(A.dtype)
-B_orig = np.random.randint(
-    -128, 128, size=(o * env.BATCH, m * env.BLOCK_OUT)).astype(B.dtype)
+A_orig = np.random.randint(-128, 128, size=(o * env.BATCH, m * env.BLOCK_OUT)).astype(A.dtype)
+B_orig = np.random.randint(-128, 128, size=(o * env.BATCH, m * env.BLOCK_OUT)).astype(B.dtype)
 
 # Apply packing to the A and B arrays from a 2D to a 4D packed layout
-A_packed = A_orig.reshape(
-    o, env.BATCH, m, env.BLOCK_OUT).transpose((0, 2, 1, 3))
-B_packed = B_orig.reshape(
-    o, env.BATCH, m, env.BLOCK_OUT).transpose((0, 2, 1, 3))
+A_packed = A_orig.reshape(o, env.BATCH, m, env.BLOCK_OUT).transpose((0, 2, 1, 3))
+B_packed = B_orig.reshape(o, env.BATCH, m, env.BLOCK_OUT).transpose((0, 2, 1, 3))
 
 # Format the input/output arrays with tvm.nd.array to the DLPack standard
 A_nd = tvm.nd.array(A_packed, ctx)
@@ -380,8 +376,7 @@ f(A_nd, B_nd, C_nd)
 
 # Compute reference result with numpy
 C_ref = (A_orig.astype(env.acc_dtype) + B_orig.astype(env.acc_dtype)).astype(C.dtype)
-C_ref = C_ref.reshape(
-    o, env.BATCH, m, env.BLOCK_OUT).transpose((0, 2, 1, 3))
+C_ref = C_ref.reshape(o, env.BATCH, m, env.BLOCK_OUT).transpose((0, 2, 1, 3))
 np.testing.assert_equal(C_ref, C_nd.asnumpy())
 print("Successful vector add test!")
 
index f529e04..fa086e6 100644 (file)
@@ -27,8 +27,8 @@ def prepare_test_libs(base_path):
     if not tvm.runtime.enabled(target):
         raise RuntimeError("Target %s is not enabled" % target)
     n = te.var("n")
-    A = te.placeholder((n,), name='A')
-    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
     s = te.create_schedule(B.op)
     fadd = tvm.build(s, [A, B], target, name="add_one")
 
index 6bef33a..5de8cf8 100644 (file)
@@ -40,8 +40,8 @@ def test_rpc():
         raise RuntimeError("Target %s is not enabled" % target_host)
 
     n = 2048
-    A = te.placeholder((n,), name='A')
-    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
     s = te.create_schedule(B.op)
 
     num_thread = 2
@@ -49,7 +49,6 @@ def test_rpc():
     s[B].bind(xi, te.thread_axis("threadIdx.x"))
     s[B].bind(xo, te.thread_axis("blockIdx.x"))
 
-
     fadd = tvm.build(s, [A, B], target_device, target_host=target_host, name="addone")
     temp = util.tempdir()
 
@@ -57,8 +56,12 @@ def test_rpc():
     fadd.export_library(wasm_path, emcc.create_tvmjs_wasm)
 
     wasm_binary = open(wasm_path, "rb").read()
-    remote = rpc.connect(proxy_host, proxy_port, key="wasm",
-                         session_constructor_args=["rpc.WasmSession", wasm_binary])
+    remote = rpc.connect(
+        proxy_host,
+        proxy_port,
+        key="wasm",
+        session_constructor_args=["rpc.WasmSession", wasm_binary],
+    )
 
     def check(remote):
         # basic function checks.
@@ -76,4 +79,5 @@ def test_rpc():
 
     check(remote)
 
+
 test_rpc()
index 80f39a3..6729964 100644 (file)
@@ -29,6 +29,7 @@ import numpy as np
 proxy_host = "localhost"
 proxy_port = 9090
 
+
 def test_rpc():
     if not tvm.runtime.enabled("rpc"):
         return
@@ -37,8 +38,8 @@ def test_rpc():
     if not tvm.runtime.enabled(target):
         raise RuntimeError("Target %s is not enabled" % target)
     n = te.var("n")
-    A = te.placeholder((n,), name='A')
-    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
     s = te.create_schedule(B.op)
 
     fadd = tvm.build(s, [A, B], target, name="addone")
@@ -49,19 +50,23 @@ def test_rpc():
 
     wasm_binary = open(wasm_path, "rb").read()
 
-    remote = rpc.connect(proxy_host, proxy_port, key="wasm",
-                         session_constructor_args=["rpc.WasmSession", wasm_binary])
+    remote = rpc.connect(
+        proxy_host,
+        proxy_port,
+        key="wasm",
+        session_constructor_args=["rpc.WasmSession", wasm_binary],
+    )
 
     def check(remote):
         # basic function checks.
         faddone = remote.get_function("testing.asyncAddOne")
         fecho = remote.get_function("testing.echo")
-        assert(faddone(100) == 101)
-        assert(fecho(1, 2, 3) == 1)
-        assert(fecho(1, 2, 3) == 1)
-        assert(fecho(100, 2, 3) == 100)
-        assert(fecho("xyz") == "xyz")
-        assert(bytes(fecho(bytearray(b"123"))) == b"123")
+        assert faddone(100) == 101
+        assert fecho(1, 2, 3) == 1
+        assert fecho(1, 2, 3) == 1
+        assert fecho(100, 2, 3) == 100
+        assert fecho("xyz") == "xyz"
+        assert bytes(fecho(bytearray(b"123"))) == b"123"
 
         # run the generated library.
         f1 = remote.system_lib()
@@ -76,9 +81,10 @@ def test_rpc():
         time_f = f1.time_evaluator("addone", ctx, number=100, repeat=10)
         time_f(a, b)
         cost = time_f(a, b).mean
-        print('%g secs/op' % cost)
+        print("%g secs/op" % cost)
         np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
 
     check(remote)
 
+
 test_rpc()
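
The lone added blank lines in the last two files, before def test_rpc(): and before each module-level test_rpc() call, come from Black's blank-line handling: a top-level definition is kept two blank lines apart from the surrounding module-level code. A minimal sketch:

    def add_one(x):
        # Stand-in for the tutorials' test functions.
        return x + 1


    # Two blank lines above: Black separates a top-level def from the call site,
    # just like the blank lines added before test_rpc() in the hunks above.
    print(add_one(41))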